
trans2seg's Introduction

Transparent Transformer Segmentation

Introduction

This repository contains the data and code for the IJCAI 2021 paper "Segmenting transparent object in the wild with transformer".

Environments

  • python 3
  • torch == 1.4.0
  • torchvision
  • pyyaml
  • Pillow
  • numpy
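Before running setup.py, it can help to confirm that every package above is importable. The sketch below is not part of the repository and uses only the standard library. It maps the pip names to their import names, since pyyaml is imported as yaml and Pillow as PIL.

```python
import importlib.util

# pip package name -> import name for the requirements listed above.
REQUIRED = {
    "torch": "torch",  # the list above pins torch 1.4.0
    "torchvision": "torchvision",
    "pyyaml": "yaml",
    "Pillow": "PIL",
    "numpy": "numpy",
}


def missing_packages(import_names):
    """Return the import names that cannot be located in this environment.

    find_spec only looks a top-level module up; it does not import it.
    """
    return [name for name in import_names
            if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED.values())
    print("missing:", missing if missing else "none")
```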

INSTALL

python setup.py develop --user

Data Preparation

  1. Create the directory './datasets/transparent/Trans10K_v2'.
  2. Put the train/validation/test data under './datasets/transparent/Trans10K_v2'. The expected directory structure is shown below.
Trans10K_v2
├── test
│   ├── images
│   └── masks_12
├── train
│   ├── images
│   └── masks_12
└── validation
    ├── images
    └── masks_12

Download the dataset: Google Drive or Baidu Drive (extraction code: oqms).
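The Trans10K_v2 layout above can be checked with a short script before training. This is a minimal sketch, not a tool shipped with the repository. The function name check_trans10k is hypothetical, and the script assumes that each image in a split has exactly one mask file in masks_12.

```python
from pathlib import Path

SPLITS = ("train", "validation", "test")
SUBDIRS = ("images", "masks_12")


def check_trans10k(root="./datasets/transparent/Trans10K_v2"):
    """Return a list of problems found in the expected Trans10K_v2 layout."""
    root = Path(root)
    problems = []
    for split in SPLITS:
        for sub in SUBDIRS:
            if not (root / split / sub).is_dir():
                problems.append("missing directory: {}".format(root / split / sub))
        images, masks = root / split / "images", root / split / "masks_12"
        if images.is_dir() and masks.is_dir():
            # Assumption: one mask file per image, so the counts should match.
            n_img = sum(1 for p in images.iterdir() if p.is_file())
            n_msk = sum(1 for p in masks.iterdir() if p.is_file())
            if n_img != n_msk:
                problems.append(
                    "{}: {} images but {} masks".format(split, n_img, n_msk))
    return problems


if __name__ == "__main__":
    for problem in check_trans10k():
        print(problem)
```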

Network Definition

The network pipeline is implemented in segmentron/models/trans2seg.py.

The Transformer encoder-decoder is implemented in segmentron/modules/transformer.py.

Train

Our experiments were run on a single machine with 8 V100 GPUs (32 GB memory each); training takes about 1 hour.

bash tools/dist_train.sh $CONFIG-FILE $GPUS

For example:

bash tools/dist_train.sh configs/trans10kv2/trans2seg/trans2seg_medium.yaml 8

Test

bash tools/dist_train.sh $CONFIG-FILE $GPUS --test TEST.TEST_MODEL_PATH $MODEL_PATH
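The shell traces quoted in the issues below show tools/dist_train.sh expanding to 'python -m torch.distributed.launch --nproc_per_node=8 tools/train.py --config-file <config> 1'. In other words, the process count stays at 8 and the GPU argument is passed on to tools/train.py. The following sketch of a hypothetical Python launcher works differently: it consumes the GPU count as --nproc_per_node and forwards only the remaining arguments, such as --test and KEY VALUE pairs. The names build_launch_cmd and run_launch are illustrative, not repository code.

```python
import subprocess
import sys


def build_launch_cmd(config_file, gpus, extra_args=()):
    """Build a torch.distributed.launch command line (hypothetical helper).

    The GPU count becomes --nproc_per_node and is NOT forwarded to
    tools/train.py; only extra_args (e.g. --test, KEY VALUE pairs) are.
    """
    return [
        sys.executable, "-m", "torch.distributed.launch",
        "--nproc_per_node={}".format(int(gpus)),
        "tools/train.py",
        "--config-file", config_file,
        *extra_args,
    ]


def run_launch(config_file, gpus, extra_args=()):
    """Run the command; raises subprocess.CalledProcessError on failure."""
    return subprocess.run(build_launch_cmd(config_file, gpus, extra_args),
                          check=True)
```

For example, run_launch("configs/trans10kv2/trans2seg/trans2seg_medium.yaml", 1) would start a single process. Passing ["--test", "TEST.TEST_MODEL_PATH", model_path] as extra_args would mirror the test command above.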

Citations

Please consider citing our paper in your publications if the project helps your research. The BibTeX entry is as follows.

@article{xie2021segmenting,
  title={Segmenting transparent object in the wild with transformer},
  author={Xie, Enze and Wang, Wenjia and Wang, Wenhai and Sun, Peize and Xu, Hang and Liang, Ding and Luo, Ping},
  journal={arXiv preprint arXiv:2101.08461},
  year={2021}
}

trans2seg's People

Contributors

cclauss, lielinjiang, likely-journey, lxtgh, xieenze


trans2seg's Issues

ValueError: Command line options config format error! Please check it:

By running this command:
'''
bash tools/dist_train.sh configs/trans10kv2/trans2seg/trans2seg_medium.yaml 2
'''
I get the following error:
'''
Traceback (most recent call last):
  File "tools/train.py", line 296, in <module>
    cfg.update_from_list(args.opts)
  File "/home/lthpc/Algo/xjw/Trans2Seg/segmentron/config/config.py", line 107, in update_from_list
    raise ValueError(
ValueError: Command line options config format error! Please check it: ['2']
'''
I tried checking the source code:

'''
def update_from_list(self, config_list):
    if len(config_list) % 2 != 0:
        raise ValueError(
            "Command line options config format error! Please check it: {}".
            format(config_list))
    for key, value in zip(config_list[0::2], config_list[1::2]):
        try:
            self.setattr(key, value, create_if_not_exist=False)
        except KeyError:
            raise KeyError('Non-existent config key: {}'.format(key))
'''
I don't know how to fix it.
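The error follows from the pairing rule in update_from_list. Everything left in args.opts must form KEY VALUE pairs, so a lone '2' makes the list odd-length. The shell trace in the "Cannot follow the train demo" issue suggests that dist_train.sh forwards its GPU argument to tools/train.py, where it ends up in args.opts. The standalone sketch below reproduces only the pairing check. pair_opts is a hypothetical name, and the real method also writes each value into the config.

```python
def pair_opts(config_list):
    """Pair KEY VALUE overrides the way update_from_list does (sketch only).

    Raises ValueError for an odd-length list such as ['2'], which is what
    a stray GPU count forwarded to tools/train.py looks like.
    """
    if len(config_list) % 2 != 0:
        raise ValueError(
            "Command line options config format error! Please check it: {}"
            .format(config_list))
    # Even positions are keys, odd positions are the matching values.
    return list(zip(config_list[0::2], config_list[1::2]))
```

One possible workaround is inferred from the traces and has not been confirmed by the authors. Call python -m torch.distributed.launch directly with the desired --nproc_per_node, and pass no trailing GPU count to tools/train.py.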

about sbu result

Hi, thanks for your paper and code. I noticed that the code supports the SBU dataset, but your paper reports no results on it. What is the performance on SBU? If possible, could you provide detailed scores, e.g., BER or accuracy?
Thank you.
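For reference, BER (balanced error rate) averages the error rates on positive and negative pixels, so a small foreground class is not swamped by the background: BER = 100 * (1 - 0.5 * (TP / N_pos + TN / N_neg)). The sketch below computes BER and plain pixel accuracy for binary masks flattened into 0/1 sequences. It is a generic illustration, not code from this repository. Because it works on any set of pixels, it can score a single image or pixels pooled over a whole test set.

```python
def balanced_error_rate(pred, gt):
    """BER in percent for binary masks given as flat 0/1 sequences."""
    if len(pred) != len(gt):
        raise ValueError("pred and gt must have the same length")
    tp = sum(1 for p, g in zip(pred, gt) if p == 1 and g == 1)
    tn = sum(1 for p, g in zip(pred, gt) if p == 0 and g == 0)
    n_pos = sum(1 for g in gt if g == 1)
    n_neg = len(gt) - n_pos  # assumes gt contains only 0 and 1
    if n_pos == 0 or n_neg == 0:
        raise ValueError("gt must contain both positive and negative pixels")
    return 100.0 * (1.0 - 0.5 * (tp / n_pos + tn / n_neg))


def pixel_accuracy(pred, gt):
    """Fraction of pixels where the prediction equals the ground truth."""
    return sum(1 for p, g in zip(pred, gt) if p == g) / len(gt)
```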

Pretrained Model

Hi,
I'm interested in your project and working on transparent object segmentation. Can you provide the pre-trained model? Thanks for your help!

How and where should I modify the code when reducing the number of classes?

When I try to reduce the categories to two while testing, I get the following error:

THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1579027003190/work/aten/src/THC/generic/THCTensorMath.cu line=26 error=59 : device-side assert triggered
Traceback (most recent call last):
  File "tools/train.py", line 330, in <module>
    trainer.train()
  File "tools/train.py", line 170, in train
    losses.backward()
  File "/opt/conda/envs/trans_seg/lib/python3.6/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/opt/conda/envs/trans_seg/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1579027003190/work/aten/src/THC/generic/THCTensorMath.cu:26
Traceback (most recent call last):
  File "/opt/conda/envs/trans_seg/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/envs/trans_seg/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/conda/envs/trans_seg/lib/python3.6/site-packages/torch/distributed/launch.py", line 263, in <module>
    main()
  File "/opt/conda/envs/trans_seg/lib/python3.6/site-packages/torch/distributed/launch.py", line 259, in main
    cmd=cmd)

How should I modify the network parameters to make it run?
Thank you!
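After the number of classes is reduced, one common cause of a device-side assert during training is a mask pixel whose class index falls outside 0..num_classes-1, which the loss cannot index. The minimal stdlib sketch below checks and remaps label values, for example the unique values found in each mask. Two details are assumptions for illustration: the ignore value of -1, and merging every non-background class into class 1. The actual ignore index should be taken from the loss configuration.

```python
def out_of_range_labels(label_values, num_classes, ignore_index=-1):
    """Return sorted label values that a num_classes-way loss cannot accept.

    Valid labels are 0..num_classes-1; ignore_index is skipped, mirroring a
    cross-entropy loss configured with that ignore_index.
    """
    bad = {v for v in label_values
           if v != ignore_index and not 0 <= v < num_classes}
    return sorted(bad)


def remap_labels(label_values, mapping, default, ignore_index=-1):
    """Map original class ids to a reduced label set (hypothetical merge).

    Ids missing from `mapping` become `default`; ignore_index is kept.
    """
    return [v if v == ignore_index else mapping.get(v, default)
            for v in label_values]
```

For a two-class setup, remap_labels(values, {0: 0}, 1) keeps background as 0 and turns every other class into 1. Running out_of_range_labels on the result should then return an empty list.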

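GPU memory runs out

Even with the smallest model (trans2seg_small.yaml) on four 2080 Ti GPUs and batch_size set to 1, I still run out of GPU memory. What are the minimum hardware requirements for running this model?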

question

Can I run it on Windows 10? Thanks.

Cannot follow the train demo

When I followed the training demo, I got errors like the following:

(trans2seg) jinrae@costar-Oryx-Pro:~/workspace/Trans2Seg$ bash tools/dist_train.sh configs/trans10kv2/trans2seg/trans2seg_medium.yaml 1
+ CONFIG=configs/trans10kv2/trans2seg/trans2seg_medium.yaml
+ GPUS=8
++ dirname tools/dist_train.sh
+ python -m torch.distributed.launch --nproc_per_node=8 tools/train.py --config-file configs/trans10kv2/trans2seg/trans2seg_medium.yaml 1
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Traceback (most recent call last):
  File "tools/train.py", line 20, in <module>
    from segmentron.solver.loss import get_segmentation_loss
  File "/home/jinrae/workspace/Trans2Seg/segmentron/solver/loss.py", line 9, in <module>
    from ..models.pointrend import point_sample
ModuleNotFoundError: No module named 'segmentron.models.pointrend'
Traceback (most recent call last):
  File "/home/jinrae/miniconda3/envs/trans2seg/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/jinrae/miniconda3/envs/trans2seg/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/jinrae/miniconda3/envs/trans2seg/lib/python3.6/site-packages/torch/distributed/launch.py", line 263, in <module>
    main()
  File "/home/jinrae/miniconda3/envs/trans2seg/lib/python3.6/site-packages/torch/distributed/launch.py", line 259, in main
    cmd=cmd)
subprocess.CalledProcessError: Command '['/home/jinrae/miniconda3/envs/trans2seg/bin/python', '-u', 'tools/train.py', '--local_rank=7', '--config-file', 'configs/trans10kv2/trans2seg/trans2seg_medium.yaml', '1']' returned non-zero exit status 1.

ModuleNotFoundError: No module named 'segmentron.models.pointrend'

Hi,

I have been trying to run training (on a single GPU):
-> bash tools/dist_train.sh configs/trans10kv2/trans2seg/trans2seg_medium.yaml 1

Below is the error stack. Can you please help?
Many thanks.

+ CONFIG=configs/trans10kv2/trans2seg/trans2seg_medium.yaml
+ GPUS=8
++ dirname tools/dist_train.sh
+ python -m torch.distributed.launch --nproc_per_node=8 tools/train.py --config-file configs/trans10kv2/trans2seg/trans2seg_medium.yaml 1

Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.


Traceback (most recent call last):
  File "tools/train.py", line 20, in <module>
    from segmentron.solver.loss import get_segmentation_loss
  File "/home/paperspace/Trans2Seg/segmentron/solver/loss.py", line 9, in <module>
    from ..models.pointrend import point_sample
ModuleNotFoundError: No module named 'segmentron.models.pointrend'
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/paperspace/Torch/lib/python3.8/site-packages/torch/distributed/launch.py", line 263, in <module>
    main()
  File "/home/paperspace/Torch/lib/python3.8/site-packages/torch/distributed/launch.py", line 258, in main
    raise subprocess.CalledProcessError(returncode=process.returncode,
subprocess.CalledProcessError: Command '['/home/paperspace/Torch/bin/python', '-u', 'tools/train.py', '--local_rank=7', '--config-file', 'configs/trans10kv2/trans2seg/trans2seg_medium.yaml', '1']' returned non-zero exit status 1.

about demo

Does anyone know how to use the demo? It always raises an AssertionError when I use it:
  File "E:\python\Trans2Seg-master\segmentron\models\trans2seg.py", line 74, in forward
    cls_token, x = self.vit.forward_encoder(x)
  File "E:\python\Trans2Seg-master\segmentron\modules\transformer.py", line 224, in forward_encoder
    pos_embed = self.resize_pos_embed(x, pos_embed)
  File "E:\python\Trans2Seg-master\segmentron\modules\transformer.py", line 242, in resize_pos_embed
    assert x_h * x_w == hw-1
AssertionError
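The assertion requires a grid of x_h by x_w patches to account for all but one of hw tokens. These are presumably the positional-embedding tokens, with the extra one reserved for the class token that forward_encoder returns. The arithmetic sketch below reproduces that condition for a hypothetical input size and downsampling stride; neither value is taken from the Trans2Seg configs. It suggests checking whether the demo feeds images at a size different from the one the model was built for.

```python
def tokens_match(img_h, img_w, stride, num_tokens):
    """Mimic the `x_h * x_w == hw - 1` check for a hypothetical setup.

    x_h, x_w: feature-map size after downsampling the input by `stride`.
    num_tokens: token count that includes one class token.
    """
    x_h, x_w = img_h // stride, img_w // stride
    return x_h * x_w == num_tokens - 1
```

With a 32 x 32 token grid plus one class token (1025 tokens) and a stride of 16, a 512 x 512 input passes the check. A 500 x 512 input yields a 31 x 32 grid and fails it.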

Pretrained models?

First of all, thank you for this nice work!
Secondly, are you planning to upload a pretrained model?

fail running

Hello, thanks for sharing. I assume this repo is still undergoing significant clean-up? It looks like a lot of the code was moved from segmentron, but a deeper clean-up seems necessary. For example, line 9 of loss.py is "from ..models.pointrend import point_sample". However, pointrend only exists in segmentron and has been removed from this repo, so importing loss.py fails with an import error.
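One way to keep loss.py importable without the missing module is to make that import optional. The sketch below is a hedged suggestion, not the authors' fix, and optional_import is a hypothetical helper. It is only safe if the configuration never selects a loss that calls point_sample. If such a loss is needed, the pointrend module would have to be restored from segmentron.

```python
import importlib


def optional_import(module_name, attr_name):
    """Return module_name.attr_name, or None if it cannot be imported.

    ModuleNotFoundError is a subclass of ImportError, so a missing
    segmentron.models.pointrend lands in the except branch.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, attr_name, None)


# Hypothetical replacement for the import on line 9 of loss.py; any loss
# that uses point_sample should check for None and fail with a clear message.
point_sample = optional_import("segmentron.models.pointrend", "point_sample")
```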
