tokenfusion's Introduction

Multimodal Token Fusion for Vision Transformers

By Yikai Wang, Xinghao Chen, Lele Cao, Wenbing Huang, Fuchun Sun, Yunhe Wang.

[Paper]

This repository is a PyTorch implementation of "Multimodal Token Fusion for Vision Transformers", in CVPR 2022.

Homogeneous predictions,

Heterogeneous predictions,

Datasets

For the semantic segmentation task on NYUDv2 (official dataset), we provide a link to download the dataset here. The provided dataset was originally preprocessed in this repository, and we added depth data to it.

For the image-to-image translation task, we use the sample dataset of Taskonomy; a link to download the sample dataset is here.

Please modify the data paths in the code, at the places marked with the comment 'Modify data path'.

Dependencies

python==3.6
pytorch==1.7.1
torchvision==0.8.2
numpy==1.19.2

Semantic Segmentation

First,

cd semantic_segmentation

Download the SegFormer pretrained model (pretrained on ImageNet) from weights, e.g., mit_b3.pth. Move this pretrained model to the folder 'pretrained'.

Training script for segmentation with RGB and Depth input,

python main.py --backbone mit_b3 -c exp_name --lamda 1e-6 --gpu 0 1 2

Evaluation script,

python main.py --gpu 0 --resume path_to_pth --evaluate  # optionally use --save-img to visualize results

Checkpoint models, training logs, mask ratios and the single-scale performance on NYUDv2 are provided as follows:

Method  Backbone      Pixel Acc. (%)  Mean Acc. (%)  Mean IoU (%)  Download
CEN     ResNet101     76.2            62.8           51.1          Google Drive
CEN     ResNet152     77.0            64.4           51.6          Google Drive
Ours    SegFormer-B3  78.7            67.5           54.8          Google Drive
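
For readers unfamiliar with the three columns above, the following is a minimal sketch of how pixel accuracy, mean accuracy, and mean IoU are commonly derived from a confusion matrix. It is not taken from this repository: the function name `segmentation_metrics`, the toy matrix, and the choice to skip classes absent from the ground truth are assumptions for illustration, and the evaluation code in main.py may differ in such details.

```python
def segmentation_metrics(conf):
    """Compute (pixel_acc, mean_acc, mean_iou) from a confusion matrix.

    conf[i][j] = number of pixels whose ground-truth class is i and
    whose predicted class is j. Returned values are fractions in [0, 1].
    """
    n = len(conf)
    total = sum(sum(row) for row in conf)
    correct = sum(conf[i][i] for i in range(n))
    pixel_acc = correct / total

    class_accs, ious = [], []
    for i in range(n):
        gt_i = sum(conf[i])                         # pixels whose true class is i
        pred_i = sum(conf[r][i] for r in range(n))  # pixels predicted as class i
        if gt_i == 0:
            continue  # assumption: ignore classes absent from the ground truth
        class_accs.append(conf[i][i] / gt_i)
        ious.append(conf[i][i] / (gt_i + pred_i - conf[i][i]))

    mean_acc = sum(class_accs) / len(class_accs)
    mean_iou = sum(ious) / len(ious)
    return pixel_acc, mean_acc, mean_iou


# Toy two-class example (hypothetical numbers, not NYUDv2 results).
conf = [[8, 2],
        [1, 9]]
pa, ma, miou = segmentation_metrics(conf)
print(f"Pixel Acc. {100 * pa:.1f}  Mean Acc. {100 * ma:.1f}  Mean IoU {100 * miou:.1f}")
```

Mean IoU is the strictest of the three because each class is penalized for both its missed pixels and its false positives.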

A MindSpore implementation is available at: https://gitee.com/mindspore/models/tree/master/research/cv/TokenFusion

Image-to-Image Translation

First,

cd image2image_translation

Training script, from Shade and Texture to RGB,

python main.py --gpu 0 -c exp_name

This script will auto-evaluate on the validation dataset every 5 training epochs.

Predicted images will be automatically saved during training, in the following folder structure:

code_root/ckpt/exp_name/results
  ├── input0  # 1st modality input
  ├── input1  # 2nd modality input
  ├── fake0   # 1st branch output 
  ├── fake1   # 2nd branch output
  ├── fake2   # ensemble output
  ├── best    # current best output
  │    ├── fake0
  │    ├── fake1
  │    └── fake2
  └── real    # ground truth output
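
The layout above can be checked programmatically after a run. The sketch below only mirrors the folder names from the tree; the helper names `result_dirs` and `missing_dirs` are hypothetical and not part of the repository.

```python
import os


def result_dirs(code_root, exp_name):
    """List the result folders expected under code_root/ckpt/exp_name/results."""
    base = os.path.join(code_root, "ckpt", exp_name, "results")
    top_level = ["input0", "input1", "fake0", "fake1", "fake2", "real"]
    best_level = ["fake0", "fake1", "fake2"]  # subfolders of 'best'
    dirs = [os.path.join(base, name) for name in top_level]
    dirs += [os.path.join(base, "best", name) for name in best_level]
    return dirs


def missing_dirs(code_root, exp_name):
    """Return the expected result folders that do not exist on disk yet."""
    return [d for d in result_dirs(code_root, exp_name) if not os.path.isdir(d)]


if __name__ == "__main__":
    for path in missing_dirs("code_root", "exp_name"):
        print("not found:", path)
```

Some of these folders may only appear after the first validation pass, so an empty or partial listing early in training is not necessarily a problem.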

Checkpoint models:

Method  Task                FID   KID   Download
CEN     Texture+Shade->RGB  62.6  1.65  -
Ours    Texture+Shade->RGB  45.5  1.00  Google Drive

3D Object Detection (under construction)

Data preparation, environments, and training scripts follow Group-Free and ImVoteNet.

E.g.,

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch \
  --master_port 2229 --nproc_per_node 4 train_dist.py \
  --max_epoch 600 --val_freq 25 --save_freq 25 --lr_decay_epochs 420 480 540 \
  --num_point 20000 --num_decoder_layers 6 --size_cls_agnostic \
  --size_delta 0.0625 --heading_delta 0.04 --center_delta 0.1111111111111 \
  --weight_decay 0.00000001 --query_points_generator_loss_coef 0.2 --obj_loss_coef 0.4 \
  --dataset sunrgbd --data_root . --use_img --log_dir log/exp_name

Citation

If you find our work useful for your research, please consider citing the following paper.

@inproceedings{wang2022tokenfusion,
  title={Multimodal Token Fusion for Vision Transformers},
  author={Wang, Yikai and Chen, Xinghao and Cao, Lele and Huang, Wenbing and Sun, Fuchun and Wang, Yunhe},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}

tokenfusion's People

Contributors

yikaiw

tokenfusion's Issues

cannot run evaluation for semantic segmentation with the checkpoint

I followed the README file to evaluate semantic segmentation, with the command below:

python3 main.py --gpu 0 --resume ./pretrained_models/model-best.pth.tar --evaluate

where model-best.pth.tar was downloaded from

Method  Backbone      Pixel Acc. (%)  Mean Acc. (%)  Mean IoU (%)  Download
Ours    SegFormer-B3  78.7            67.5           54.8          Google Drive

and I got the following error message:

Loaded Segmenter mit_b1, ImageNet-Pre-Trained=True, #PARAMS=14.96M
Traceback (most recent call last):
  File "main.py", line 459, in <module>
    main()
  File "main.py", line 382, in main
    best_val, epoch_start = load_ckpt(args.resume, {'segmenter': segmenter})
  File "main.py", line 211, in load_ckpt
    v.load_state_dict(ckpt[k])
  File "/home/kb/anaconda3/envs/tokenfusion/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1605, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
	Unexpected key(s) in state_dict: "module.encoder.block1.2.norm1.ln_0.weight", "module.encoder.block1.2.norm1.ln_0.bias", "module.encoder.block1.2.norm1.ln_1.weight", "module.encoder.block1.2.norm1.ln_1.bias", "module.encoder.block1.2.attn.q.module.weight", "module.encoder.block1.2.attn.q.module.bias", "module.encoder.block1.2.attn.kv.module.weight", "module.encoder.block1.2.attn.kv.module.bias", "module.encoder.block1.2.attn.proj.module.weight", "module.encoder.block1.2.attn.proj.module.bias", "module.encoder.block1.2.attn.sr.module.weight", "module.encoder.block1.2.attn.sr.module.bias", "module.encoder.block1.2.attn.norm.ln_0.weight", "module.encoder.block1.2.attn.norm.ln_0.bias", "module.encoder.block1.2.attn.norm.ln_1.weight", "module.encoder.block1.2.attn.norm.ln_1.bias", "module.encoder.block1.2.norm2.ln_0.weight", "module.encoder.block1.2.norm2.ln_0.bias", "module.encoder.block1.2.norm2.ln_1.weight", "module.encoder.block1.2.norm2.ln_1.bias", "module.encoder.block1.2.mlp.fc1.module.weight", "module.encoder.block1.2.mlp.fc1.module.bias", "module.encoder.block1.2.mlp.dwconv.dwconv.weight", "module.encoder.block1.2.mlp.dwconv.dwconv.bias", "module.encoder.block1.2.mlp.fc2.module.weight", "module.encoder.block1.2.mlp.fc2.module.bias", "module.encoder.block2.2.norm1.ln_0.weight", "module.encoder.block2.2.norm1.ln_0.bias", "module.encoder.block2.2.norm1.ln_1.weight", "module.encoder.block2.2.norm1.ln_1.bias", "module.encoder.block2.2.attn.q.module.weight", "module.encoder.block2.2.attn.q.module.bias", "module.encoder.block2.2.attn.kv.module.weight", "module.encoder.block2.2.attn.kv.module.bias", "module.encoder.block2.2.attn.proj.module.weight", "module.encoder.block2.2.attn.proj.module.bias", "module.encoder.block2.2.attn.sr.module.weight", "module.encoder.block2.2.attn.sr.module.bias", "module.encoder.block2.2.attn.norm.ln_0.weight", "module.encoder.block2.2.attn.norm.ln_0.bias", "module.encoder.block2.2.attn.norm.ln_1.weight", 
"module.encoder.block2.2.attn.norm.ln_1.bias", "module.encoder.block2.2.norm2.ln_0.weight", "module.encoder.block2.2.norm2.ln_0.bias", "module.encoder.block2.2.norm2.ln_1.weight", "module.encoder.block2.2.norm2.ln_1.bias", "module.encoder.block2.2.mlp.fc1.module.weight", "module.encoder.block2.2.mlp.fc1.module.bias", "module.encoder.block2.2.mlp.dwconv.dwconv.weight", "module.encoder.block2.2.mlp.dwconv.dwconv.bias", "module.encoder.block2.2.mlp.fc2.module.weight", "module.encoder.block2.2.mlp.fc2.module.bias", "module.encoder.block2.3.norm1.ln_0.weight", "module.encoder.block2.3.norm1.ln_0.bias", "module.encoder.block2.3.norm1.ln_1.weight", "module.encoder.block2.3.norm1.ln_1.bias", "module.encoder.block2.3.attn.q.module.weight", "module.encoder.block2.3.attn.q.module.bias", "module.encoder.block2.3.attn.kv.module.weight", "module.encoder.block2.3.attn.kv.module.bias", "module.encoder.block2.3.attn.proj.module.weight", "module.encoder.block2.3.attn.proj.module.bias", "module.encoder.block2.3.attn.sr.module.weight", "module.encoder.block2.3.attn.sr.module.bias", "module.encoder.block2.3.attn.norm.ln_0.weight", "module.encoder.block2.3.attn.norm.ln_0.bias", "module.encoder.block2.3.attn.norm.ln_1.weight", "module.encoder.block2.3.attn.norm.ln_1.bias", "module.encoder.block2.3.norm2.ln_0.weight", "module.encoder.block2.3.norm2.ln_0.bias", "module.encoder.block2.3.norm2.ln_1.weight", "module.encoder.block2.3.norm2.ln_1.bias", "module.encoder.block2.3.mlp.fc1.module.weight", "module.encoder.block2.3.mlp.fc1.module.bias", "module.encoder.block2.3.mlp.dwconv.dwconv.weight", "module.encoder.block2.3.mlp.dwconv.dwconv.bias", "module.encoder.block2.3.mlp.fc2.module.weight", "module.encoder.block2.3.mlp.fc2.module.bias", "module.encoder.block3.2.norm1.ln_0.weight", "module.encoder.block3.2.norm1.ln_0.bias", "module.encoder.block3.2.norm1.ln_1.weight", "module.encoder.block3.2.norm1.ln_1.bias", "module.encoder.block3.2.attn.q.module.weight", 
"module.encoder.block3.2.attn.q.module.bias", "module.encoder.block3.2.attn.kv.module.weight", "module.encoder.block3.2.attn.kv.module.bias", "module.encoder.block3.2.attn.proj.module.weight", "module.encoder.block3.2.attn.proj.module.bias", "module.encoder.block3.2.attn.sr.module.weight", "module.encoder.block3.2.attn.sr.module.bias", "module.encoder.block3.2.attn.norm.ln_0.weight", "module.encoder.block3.2.attn.norm.ln_0.bias", "module.encoder.block3.2.attn.norm.ln_1.weight", "module.encoder.block3.2.attn.norm.ln_1.bias", "module.encoder.block3.2.norm2.ln_0.weight", "module.encoder.block3.2.norm2.ln_0.bias", "module.encoder.block3.2.norm2.ln_1.weight", "module.encoder.block3.2.norm2.ln_1.bias", "module.encoder.block3.2.mlp.fc1.module.weight", "module.encoder.block3.2.mlp.fc1.module.bias", "module.encoder.block3.2.mlp.dwconv.dwconv.weight", "module.encoder.block3.2.mlp.dwconv.dwconv.bias", "module.encoder.block3.2.mlp.fc2.module.weight", "module.encoder.block3.2.mlp.fc2.module.bias", "module.encoder.block3.3.norm1.ln_0.weight", "module.encoder.block3.3.norm1.ln_0.bias", "module.encoder.block3.3.norm1.ln_1.weight", "module.encoder.block3.3.norm1.ln_1.bias", "module.encoder.block3.3.attn.q.module.weight", "module.encoder.block3.3.attn.q.module.bias", "module.encoder.block3.3.attn.kv.module.weight", "module.encoder.block3.3.attn.kv.module.bias", "module.encoder.block3.3.attn.proj.module.weight", "module.encoder.block3.3.attn.proj.module.bias", "module.encoder.block3.3.attn.sr.module.weight", "module.encoder.block3.3.attn.sr.module.bias", "module.encoder.block3.3.attn.norm.ln_0.weight", "module.encoder.block3.3.attn.norm.ln_0.bias", "module.encoder.block3.3.attn.norm.ln_1.weight", "module.encoder.block3.3.attn.norm.ln_1.bias", "module.encoder.block3.3.norm2.ln_0.weight", "module.encoder.block3.3.norm2.ln_0.bias", "module.encoder.block3.3.norm2.ln_1.weight", "module.encoder.block3.3.norm2.ln_1.bias", "module.encoder.block3.3.mlp.fc1.module.weight", 
"module.encoder.block3.3.mlp.fc1.module.bias", "module.encoder.block3.3.mlp.dwconv.dwconv.weight", "module.encoder.block3.3.mlp.dwconv.dwconv.bias", "module.encoder.block3.3.mlp.fc2.module.weight", "module.encoder.block3.3.mlp.fc2.module.bias", "module.encoder.block3.4.norm1.ln_0.weight", "module.encoder.block3.4.norm1.ln_0.bias", "module.encoder.block3.4.norm1.ln_1.weight", "module.encoder.block3.4.norm1.ln_1.bias", "module.encoder.block3.4.attn.q.module.weight", "module.encoder.block3.4.attn.q.module.bias", "module.encoder.block3.4.attn.kv.module.weight", "module.encoder.block3.4.attn.kv.module.bias", "module.encoder.block3.4.attn.proj.module.weight", "module.encoder.block3.4.attn.proj.module.bias", "module.encoder.block3.4.attn.sr.module.weight", "module.encoder.block3.4.attn.sr.module.bias", "module.encoder.block3.4.attn.norm.ln_0.weight", "module.encoder.block3.4.attn.norm.ln_0.bias", "module.encoder.block3.4.attn.norm.ln_1.weight", "module.encoder.block3.4.attn.norm.ln_1.bias", "module.encoder.block3.4.norm2.ln_0.weight", "module.encoder.block3.4.norm2.ln_0.bias", "module.encoder.block3.4.norm2.ln_1.weight", "module.encoder.block3.4.norm2.ln_1.bias", "module.encoder.block3.4.mlp.fc1.module.weight", "module.encoder.block3.4.mlp.fc1.module.bias", "module.encoder.block3.4.mlp.dwconv.dwconv.weight", "module.encoder.block3.4.mlp.dwconv.dwconv.bias", "module.encoder.block3.4.mlp.fc2.module.weight", "module.encoder.block3.4.mlp.fc2.module.bias", "module.encoder.block3.5.norm1.ln_0.weight", "module.encoder.block3.5.norm1.ln_0.bias", "module.encoder.block3.5.norm1.ln_1.weight", "module.encoder.block3.5.norm1.ln_1.bias", "module.encoder.block3.5.attn.q.module.weight", "module.encoder.block3.5.attn.q.module.bias", "module.encoder.block3.5.attn.kv.module.weight", "module.encoder.block3.5.attn.kv.module.bias", "module.encoder.block3.5.attn.proj.module.weight", "module.encoder.block3.5.attn.proj.module.bias", "module.encoder.block3.5.attn.sr.module.weight", 
"module.encoder.block3.5.attn.sr.module.bias", "module.encoder.block3.5.attn.norm.ln_0.weight", "module.encoder.block3.5.attn.norm.ln_0.bias", "module.encoder.block3.5.attn.norm.ln_1.weight", "module.encoder.block3.5.attn.norm.ln_1.bias", "module.encoder.block3.5.norm2.ln_0.weight", "module.encoder.block3.5.norm2.ln_0.bias", "module.encoder.block3.5.norm2.ln_1.weight", "module.encoder.block3.5.norm2.ln_1.bias", "module.encoder.block3.5.mlp.fc1.module.weight", "module.encoder.block3.5.mlp.fc1.module.bias", "module.encoder.block3.5.mlp.dwconv.dwconv.weight", "module.encoder.block3.5.mlp.dwconv.dwconv.bias", "module.encoder.block3.5.mlp.fc2.module.weight", "module.encoder.block3.5.mlp.fc2.module.bias", "module.encoder.block3.6.norm1.ln_0.weight", "module.encoder.block3.6.norm1.ln_0.bias", "module.encoder.block3.6.norm1.ln_1.weight", "module.encoder.block3.6.norm1.ln_1.bias", "module.encoder.block3.6.attn.q.module.weight", "module.encoder.block3.6.attn.q.module.bias", "module.encoder.block3.6.attn.kv.module.weight", "module.encoder.block3.6.attn.kv.module.bias", "module.encoder.block3.6.attn.proj.module.weight", "module.encoder.block3.6.attn.proj.module.bias", "module.encoder.block3.6.attn.sr.module.weight", "module.encoder.block3.6.attn.sr.module.bias", "module.encoder.block3.6.attn.norm.ln_0.weight", "module.encoder.block3.6.attn.norm.ln_0.bias", "module.encoder.block3.6.attn.norm.ln_1.weight", "module.encoder.block3.6.attn.norm.ln_1.bias", "module.encoder.block3.6.norm2.ln_0.weight", "module.encoder.block3.6.norm2.ln_0.bias", "module.encoder.block3.6.norm2.ln_1.weight", "module.encoder.block3.6.norm2.ln_1.bias", "module.encoder.block3.6.mlp.fc1.module.weight", "module.encoder.block3.6.mlp.fc1.module.bias", "module.encoder.block3.6.mlp.dwconv.dwconv.weight", "module.encoder.block3.6.mlp.dwconv.dwconv.bias", "module.encoder.block3.6.mlp.fc2.module.weight", "module.encoder.block3.6.mlp.fc2.module.bias", "module.encoder.block3.7.norm1.ln_0.weight", 
"module.encoder.block3.7.norm1.ln_0.bias", "module.encoder.block3.7.norm1.ln_1.weight", "module.encoder.block3.7.norm1.ln_1.bias", "module.encoder.block3.7.attn.q.module.weight", "module.encoder.block3.7.attn.q.module.bias", "module.encoder.block3.7.attn.kv.module.weight", "module.encoder.block3.7.attn.kv.module.bias", "module.encoder.block3.7.attn.proj.module.weight", "module.encoder.block3.7.attn.proj.module.bias", "module.encoder.block3.7.attn.sr.module.weight", "module.encoder.block3.7.attn.sr.module.bias", "module.encoder.block3.7.attn.norm.ln_0.weight", "module.encoder.block3.7.attn.norm.ln_0.bias", "module.encoder.block3.7.attn.norm.ln_1.weight", "module.encoder.block3.7.attn.norm.ln_1.bias", "module.encoder.block3.7.norm2.ln_0.weight", "module.encoder.block3.7.norm2.ln_0.bias", "module.encoder.block3.7.norm2.ln_1.weight", "module.encoder.block3.7.norm2.ln_1.bias", "module.encoder.block3.7.mlp.fc1.module.weight", "module.encoder.block3.7.mlp.fc1.module.bias", "module.encoder.block3.7.mlp.dwconv.dwconv.weight", "module.encoder.block3.7.mlp.dwconv.dwconv.bias", "module.encoder.block3.7.mlp.fc2.module.weight", "module.encoder.block3.7.mlp.fc2.module.bias", "module.encoder.block3.8.norm1.ln_0.weight", "module.encoder.block3.8.norm1.ln_0.bias", "module.encoder.block3.8.norm1.ln_1.weight", "module.encoder.block3.8.norm1.ln_1.bias", "module.encoder.block3.8.attn.q.module.weight", "module.encoder.block3.8.attn.q.module.bias", "module.encoder.block3.8.attn.kv.module.weight", "module.encoder.block3.8.attn.kv.module.bias", "module.encoder.block3.8.attn.proj.module.weight", "module.encoder.block3.8.attn.proj.module.bias", "module.encoder.block3.8.attn.sr.module.weight", "module.encoder.block3.8.attn.sr.module.bias", "module.encoder.block3.8.attn.norm.ln_0.weight", "module.encoder.block3.8.attn.norm.ln_0.bias", "module.encoder.block3.8.attn.norm.ln_1.weight", "module.encoder.block3.8.attn.norm.ln_1.bias", "module.encoder.block3.8.norm2.ln_0.weight", 
"module.encoder.block3.8.norm2.ln_0.bias", "module.encoder.block3.8.norm2.ln_1.weight", "module.encoder.block3.8.norm2.ln_1.bias", "module.encoder.block3.8.mlp.fc1.module.weight", "module.encoder.block3.8.mlp.fc1.module.bias", "module.encoder.block3.8.mlp.dwconv.dwconv.weight", "module.encoder.block3.8.mlp.dwconv.dwconv.bias", "module.encoder.block3.8.mlp.fc2.module.weight", "module.encoder.block3.8.mlp.fc2.module.bias", "module.encoder.block3.9.norm1.ln_0.weight", "module.encoder.block3.9.norm1.ln_0.bias", "module.encoder.block3.9.norm1.ln_1.weight", "module.encoder.block3.9.norm1.ln_1.bias", "module.encoder.block3.9.attn.q.module.weight", "module.encoder.block3.9.attn.q.module.bias", "module.encoder.block3.9.attn.kv.module.weight", "module.encoder.block3.9.attn.kv.module.bias", "module.encoder.block3.9.attn.proj.module.weight", "module.encoder.block3.9.attn.proj.module.bias", "module.encoder.block3.9.attn.sr.module.weight", "module.encoder.block3.9.attn.sr.module.bias", "module.encoder.block3.9.attn.norm.ln_0.weight", "module.encoder.block3.9.attn.norm.ln_0.bias", "module.encoder.block3.9.attn.norm.ln_1.weight", "module.encoder.block3.9.attn.norm.ln_1.bias", "module.encoder.block3.9.norm2.ln_0.weight", "module.encoder.block3.9.norm2.ln_0.bias", "module.encoder.block3.9.norm2.ln_1.weight", "module.encoder.block3.9.norm2.ln_1.bias", "module.encoder.block3.9.mlp.fc1.module.weight", "module.encoder.block3.9.mlp.fc1.module.bias", "module.encoder.block3.9.mlp.dwconv.dwconv.weight", "module.encoder.block3.9.mlp.dwconv.dwconv.bias", "module.encoder.block3.9.mlp.fc2.module.weight", "module.encoder.block3.9.mlp.fc2.module.bias", "module.encoder.block3.10.norm1.ln_0.weight", "module.encoder.block3.10.norm1.ln_0.bias", "module.encoder.block3.10.norm1.ln_1.weight", "module.encoder.block3.10.norm1.ln_1.bias", "module.encoder.block3.10.attn.q.module.weight", "module.encoder.block3.10.attn.q.module.bias", "module.encoder.block3.10.attn.kv.module.weight", 
"module.encoder.block3.10.attn.kv.module.bias", "module.encoder.block3.10.attn.proj.module.weight", "module.encoder.block3.10.attn.proj.module.bias", "module.encoder.block3.10.attn.sr.module.weight", "module.encoder.block3.10.attn.sr.module.bias", "module.encoder.block3.10.attn.norm.ln_0.weight", "module.encoder.block3.10.attn.norm.ln_0.bias", "module.encoder.block3.10.attn.norm.ln_1.weight", "module.encoder.block3.10.attn.norm.ln_1.bias", "module.encoder.block3.10.norm2.ln_0.weight", "module.encoder.block3.10.norm2.ln_0.bias", "module.encoder.block3.10.norm2.ln_1.weight", "module.encoder.block3.10.norm2.ln_1.bias", "module.encoder.block3.10.mlp.fc1.module.weight", "module.encoder.block3.10.mlp.fc1.module.bias", "module.encoder.block3.10.mlp.dwconv.dwconv.weight", "module.encoder.block3.10.mlp.dwconv.dwconv.bias", "module.encoder.block3.10.mlp.fc2.module.weight", "module.encoder.block3.10.mlp.fc2.module.bias", "module.encoder.block3.11.norm1.ln_0.weight", "module.encoder.block3.11.norm1.ln_0.bias", "module.encoder.block3.11.norm1.ln_1.weight", "module.encoder.block3.11.norm1.ln_1.bias", "module.encoder.block3.11.attn.q.module.weight", "module.encoder.block3.11.attn.q.module.bias", "module.encoder.block3.11.attn.kv.module.weight", "module.encoder.block3.11.attn.kv.module.bias", "module.encoder.block3.11.attn.proj.module.weight", "module.encoder.block3.11.attn.proj.module.bias", "module.encoder.block3.11.attn.sr.module.weight", "module.encoder.block3.11.attn.sr.module.bias", "module.encoder.block3.11.attn.norm.ln_0.weight", "module.encoder.block3.11.attn.norm.ln_0.bias", "module.encoder.block3.11.attn.norm.ln_1.weight", "module.encoder.block3.11.attn.norm.ln_1.bias", "module.encoder.block3.11.norm2.ln_0.weight", "module.encoder.block3.11.norm2.ln_0.bias", "module.encoder.block3.11.norm2.ln_1.weight", "module.encoder.block3.11.norm2.ln_1.bias", "module.encoder.block3.11.mlp.fc1.module.weight", "module.encoder.block3.11.mlp.fc1.module.bias", 
"module.encoder.block3.11.mlp.dwconv.dwconv.weight", "module.encoder.block3.11.mlp.dwconv.dwconv.bias", "module.encoder.block3.11.mlp.fc2.module.weight", "module.encoder.block3.11.mlp.fc2.module.bias", "module.encoder.block3.12.norm1.ln_0.weight", "module.encoder.block3.12.norm1.ln_0.bias", "module.encoder.block3.12.norm1.ln_1.weight", "module.encoder.block3.12.norm1.ln_1.bias", "module.encoder.block3.12.attn.q.module.weight", "module.encoder.block3.12.attn.q.module.bias", "module.encoder.block3.12.attn.kv.module.weight", "module.encoder.block3.12.attn.kv.module.bias", "module.encoder.block3.12.attn.proj.module.weight", "module.encoder.block3.12.attn.proj.module.bias", "module.encoder.block3.12.attn.sr.module.weight", "module.encoder.block3.12.attn.sr.module.bias", "module.encoder.block3.12.attn.norm.ln_0.weight", "module.encoder.block3.12.attn.norm.ln_0.bias", "module.encoder.block3.12.attn.norm.ln_1.weight", "module.encoder.block3.12.attn.norm.ln_1.bias", "module.encoder.block3.12.norm2.ln_0.weight", "module.encoder.block3.12.norm2.ln_0.bias", "module.encoder.block3.12.norm2.ln_1.weight", "module.encoder.block3.12.norm2.ln_1.bias", "module.encoder.block3.12.mlp.fc1.module.weight", "module.encoder.block3.12.mlp.fc1.module.bias", "module.encoder.block3.12.mlp.dwconv.dwconv.weight", "module.encoder.block3.12.mlp.dwconv.dwconv.bias", "module.encoder.block3.12.mlp.fc2.module.weight", "module.encoder.block3.12.mlp.fc2.module.bias", "module.encoder.block3.13.norm1.ln_0.weight", "module.encoder.block3.13.norm1.ln_0.bias", "module.encoder.block3.13.norm1.ln_1.weight", "module.encoder.block3.13.norm1.ln_1.bias", "module.encoder.block3.13.attn.q.module.weight", "module.encoder.block3.13.attn.q.module.bias", "module.encoder.block3.13.attn.kv.module.weight", "module.encoder.block3.13.attn.kv.module.bias", "module.encoder.block3.13.attn.proj.module.weight", "module.encoder.block3.13.attn.proj.module.bias", "module.encoder.block3.13.attn.sr.module.weight", 
"module.encoder.block3.13.attn.sr.module.bias", "module.encoder.block3.13.attn.norm.ln_0.weight", "module.encoder.block3.13.attn.norm.ln_0.bias", "module.encoder.block3.13.attn.norm.ln_1.weight", "module.encoder.block3.13.attn.norm.ln_1.bias", "module.encoder.block3.13.norm2.ln_0.weight", "module.encoder.block3.13.norm2.ln_0.bias", "module.encoder.block3.13.norm2.ln_1.weight", "module.encoder.block3.13.norm2.ln_1.bias", "module.encoder.block3.13.mlp.fc1.module.weight", "module.encoder.block3.13.mlp.fc1.module.bias", "module.encoder.block3.13.mlp.dwconv.dwconv.weight", "module.encoder.block3.13.mlp.dwconv.dwconv.bias", "module.encoder.block3.13.mlp.fc2.module.weight", "module.encoder.block3.13.mlp.fc2.module.bias", "module.encoder.block3.14.norm1.ln_0.weight", "module.encoder.block3.14.norm1.ln_0.bias", "module.encoder.block3.14.norm1.ln_1.weight", "module.encoder.block3.14.norm1.ln_1.bias", "module.encoder.block3.14.attn.q.module.weight", "module.encoder.block3.14.attn.q.module.bias", "module.encoder.block3.14.attn.kv.module.weight", "module.encoder.block3.14.attn.kv.module.bias", "module.encoder.block3.14.attn.proj.module.weight", "module.encoder.block3.14.attn.proj.module.bias", "module.encoder.block3.14.attn.sr.module.weight", "module.encoder.block3.14.attn.sr.module.bias", "module.encoder.block3.14.attn.norm.ln_0.weight", "module.encoder.block3.14.attn.norm.ln_0.bias", "module.encoder.block3.14.attn.norm.ln_1.weight", "module.encoder.block3.14.attn.norm.ln_1.bias", "module.encoder.block3.14.norm2.ln_0.weight", "module.encoder.block3.14.norm2.ln_0.bias", "module.encoder.block3.14.norm2.ln_1.weight", "module.encoder.block3.14.norm2.ln_1.bias", "module.encoder.block3.14.mlp.fc1.module.weight", "module.encoder.block3.14.mlp.fc1.module.bias", "module.encoder.block3.14.mlp.dwconv.dwconv.weight", "module.encoder.block3.14.mlp.dwconv.dwconv.bias", "module.encoder.block3.14.mlp.fc2.module.weight", "module.encoder.block3.14.mlp.fc2.module.bias", 
"module.encoder.block3.15.norm1.ln_0.weight", "module.encoder.block3.15.norm1.ln_0.bias", "module.encoder.block3.15.norm1.ln_1.weight", "module.encoder.block3.15.norm1.ln_1.bias", "module.encoder.block3.15.attn.q.module.weight", "module.encoder.block3.15.attn.q.module.bias", "module.encoder.block3.15.attn.kv.module.weight", "module.encoder.block3.15.attn.kv.module.bias", "module.encoder.block3.15.attn.proj.module.weight", "module.encoder.block3.15.attn.proj.module.bias", "module.encoder.block3.15.attn.sr.module.weight", "module.encoder.block3.15.attn.sr.module.bias", "module.encoder.block3.15.attn.norm.ln_0.weight", "module.encoder.block3.15.attn.norm.ln_0.bias", "module.encoder.block3.15.attn.norm.ln_1.weight", "module.encoder.block3.15.attn.norm.ln_1.bias", "module.encoder.block3.15.norm2.ln_0.weight", "module.encoder.block3.15.norm2.ln_0.bias", "module.encoder.block3.15.norm2.ln_1.weight", "module.encoder.block3.15.norm2.ln_1.bias", "module.encoder.block3.15.mlp.fc1.module.weight", "module.encoder.block3.15.mlp.fc1.module.bias", "module.encoder.block3.15.mlp.dwconv.dwconv.weight", "module.encoder.block3.15.mlp.dwconv.dwconv.bias", "module.encoder.block3.15.mlp.fc2.module.weight", "module.encoder.block3.15.mlp.fc2.module.bias", "module.encoder.block3.16.norm1.ln_0.weight", "module.encoder.block3.16.norm1.ln_0.bias", "module.encoder.block3.16.norm1.ln_1.weight", "module.encoder.block3.16.norm1.ln_1.bias", "module.encoder.block3.16.attn.q.module.weight", "module.encoder.block3.16.attn.q.module.bias", "module.encoder.block3.16.attn.kv.module.weight", "module.encoder.block3.16.attn.kv.module.bias", "module.encoder.block3.16.attn.proj.module.weight", "module.encoder.block3.16.attn.proj.module.bias", "module.encoder.block3.16.attn.sr.module.weight", "module.encoder.block3.16.attn.sr.module.bias", "module.encoder.block3.16.attn.norm.ln_0.weight", "module.encoder.block3.16.attn.norm.ln_0.bias", "module.encoder.block3.16.attn.norm.ln_1.weight", 
"module.encoder.block3.16.attn.norm.ln_1.bias", "module.encoder.block3.16.norm2.ln_0.weight", "module.encoder.block3.16.norm2.ln_0.bias", "module.encoder.block3.16.norm2.ln_1.weight", "module.encoder.block3.16.norm2.ln_1.bias", "module.encoder.block3.16.mlp.fc1.module.weight", "module.encoder.block3.16.mlp.fc1.module.bias", "module.encoder.block3.16.mlp.dwconv.dwconv.weight", "module.encoder.block3.16.mlp.dwconv.dwconv.bias", "module.encoder.block3.16.mlp.fc2.module.weight", "module.encoder.block3.16.mlp.fc2.module.bias", "module.encoder.block3.17.norm1.ln_0.weight", "module.encoder.block3.17.norm1.ln_0.bias", "module.encoder.block3.17.norm1.ln_1.weight", "module.encoder.block3.17.norm1.ln_1.bias", "module.encoder.block3.17.attn.q.module.weight", "module.encoder.block3.17.attn.q.module.bias", "module.encoder.block3.17.attn.kv.module.weight", "module.encoder.block3.17.attn.kv.module.bias", "module.encoder.block3.17.attn.proj.module.weight", "module.encoder.block3.17.attn.proj.module.bias", "module.encoder.block3.17.attn.sr.module.weight", "module.encoder.block3.17.attn.sr.module.bias", "module.encoder.block3.17.attn.norm.ln_0.weight", "module.encoder.block3.17.attn.norm.ln_0.bias", "module.encoder.block3.17.attn.norm.ln_1.weight", "module.encoder.block3.17.attn.norm.ln_1.bias", "module.encoder.block3.17.norm2.ln_0.weight", "module.encoder.block3.17.norm2.ln_0.bias", "module.encoder.block3.17.norm2.ln_1.weight", "module.encoder.block3.17.norm2.ln_1.bias", "module.encoder.block3.17.mlp.fc1.module.weight", "module.encoder.block3.17.mlp.fc1.module.bias", "module.encoder.block3.17.mlp.dwconv.dwconv.weight", "module.encoder.block3.17.mlp.dwconv.dwconv.bias", "module.encoder.block3.17.mlp.fc2.module.weight", "module.encoder.block3.17.mlp.fc2.module.bias", "module.encoder.block4.2.norm1.ln_0.weight", "module.encoder.block4.2.norm1.ln_0.bias", "module.encoder.block4.2.norm1.ln_1.weight", "module.encoder.block4.2.norm1.ln_1.bias", 
"module.encoder.block4.2.attn.q.module.weight", "module.encoder.block4.2.attn.q.module.bias", "module.encoder.block4.2.attn.kv.module.weight", "module.encoder.block4.2.attn.kv.module.bias", "module.encoder.block4.2.attn.proj.module.weight", "module.encoder.block4.2.attn.proj.module.bias", "module.encoder.block4.2.norm2.ln_0.weight", "module.encoder.block4.2.norm2.ln_0.bias", "module.encoder.block4.2.norm2.ln_1.weight", "module.encoder.block4.2.norm2.ln_1.bias", "module.encoder.block4.2.mlp.fc1.module.weight", "module.encoder.block4.2.mlp.fc1.module.bias", "module.encoder.block4.2.mlp.dwconv.dwconv.weight", "module.encoder.block4.2.mlp.dwconv.dwconv.bias", "module.encoder.block4.2.mlp.fc2.module.weight", "module.encoder.block4.2.mlp.fc2.module.bias". 

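One way to read the traceback above: the first log line reports a mit_b1 segmenter, while the checkpoint is listed as SegFormer-B3, and every unexpected key refers to a block index that a shallower encoder would not contain. The sketch below infers the per-stage block count from key names alone, which helps confirm which backbone a checkpoint was trained with. The function name `stage_depths` is hypothetical, the four sample keys are copied from the error message, and the regular expression assumes the `encoder.blockN.M.` naming seen there.

```python
import re
from collections import defaultdict


def stage_depths(keys):
    """Infer {stage_name: number_of_blocks} from state_dict key names."""
    pattern = re.compile(r"(?:^|\.)encoder\.(block\d+)\.(\d+)\.")
    max_index = defaultdict(lambda: -1)
    for key in keys:
        match = pattern.search(key)
        if match:
            stage, idx = match.group(1), int(match.group(2))
            max_index[stage] = max(max_index[stage], idx)
    # Block indices are zero-based, so depth = highest index + 1.
    return {stage: idx + 1 for stage, idx in sorted(max_index.items())}


# Highest-index keys per stage, taken from the error message above.
sample_keys = [
    "module.encoder.block1.2.norm1.ln_0.weight",
    "module.encoder.block2.3.attn.q.module.weight",
    "module.encoder.block3.17.mlp.fc2.module.bias",
    "module.encoder.block4.2.mlp.fc2.module.bias",
]
depths = stage_depths(sample_keys)
print(depths)
```

If the depths inferred from the checkpoint do not match the model that was built, a likely cause is that the evaluation command omitted the backbone flag; the training command in the README passes `--backbone mit_b3`, so adding the same flag at evaluation time is a reasonable thing to try.
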
l1_lamda value for SUN RGBD

Hi,
I am trying to reproduce the results for SUN RGBD. The paper mentions 1e-3, whereas the README mentions 1e-6.
Thanks!
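
For context on what this coefficient controls: the flag name and this issue's title suggest that `--lamda` weights an L1 penalty that is added to the task loss. The sketch below illustrates the effect of the two values under that assumption only; the function `total_loss`, the score values, and the task loss of 1.0 are hypothetical and are not the repository's code.

```python
def total_loss(task_loss, scores, lamda):
    """Task loss plus an L1 penalty on a list of scores (assumed form)."""
    l1_penalty = sum(abs(s) for s in scores)
    return task_loss + lamda * l1_penalty


# Hypothetical scores and task loss, for illustration only.
scores = [0.9, 0.2, 0.05, 0.7]
loss_readme = total_loss(1.0, scores, 1e-6)  # value used in the README command
loss_paper = total_loss(1.0, scores, 1e-3)   # value the issue attributes to the paper
print(loss_readme, loss_paper)
```

With the same scores, the penalty term is 1000 times larger at 1e-3 than at 1e-6, so the two settings are not interchangeable.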

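Use of LN

In the image-to-image translation task, why is LN applied after every block? The first layer of each block is already an LN, so as it stands two LN layers are applied back to back.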

Proj from equation 6

Hi,
Can you help me with the code for equation 6 from the paper for segmentation? Where is the projection implemented in the segmentation model?

SUN RGBD data preparation for semantic segmentation

Hi, thanks for your creative work!
I am trying to reproduce the results for SUN RGBD on the semantic segmentation task. The README only gives instructions on how to prepare NYUDv2. I would appreciate some guidance on how to handle the SUN RGBD dataset. Thank you so much!

loss2d_class_error 100.0000

Hi! Is loss2d_class_error 100.0000 normal?

Thanks

[12/21 19:28:05] group-free INFO: Train: [2][30/95] 0head_neg_ratio 0.6538 0head_pos_ratio 0.3462 1head_neg_ratio 0.6538 1head_pos_ratio 0.3462 2head_neg_ratio 0.6538 2head_pos_ratio 0.3462 3head_neg_ratio 0.6538 3head_pos_ratio 0.3462 4head_neg_ratio 0.6538 4head_pos_ratio 0.3462 grad_norm 3.1559 last_neg_ratio 0.6538 last_pos_ratio 0.3462 points_hard_topk4_neg_ratio 0.9848 points_hard_topk4_pos_ratio 0.0152 points_hard_topk4_upper_recall_ratio 1.0000 proposal_neg_ratio 0.6538 proposal_pos_ratio 0.3462
[12/21 19:28:05] group-free INFO: grad_norm: 31.558673858642578
[12/21 19:28:05] group-free INFO: loss 14.4250 loss2d_class_error 100.0000 loss2d_loss_bbox 0.2392 loss2d_loss_ce 0.9980 loss2d_loss_giou 0.6201 query_points_generation_loss 0.0068 sum_heads_box_loss 6.3427 sum_heads_objectness_loss 0.4403 sum_heads_sem_cls_loss 11.6525
[12/21 19:28:05] group-free INFO: last_box_loss 0.9117 last_center_loss 0.5684 last_heading_cls_loss 2.4032 last_heading_reg_loss 0.0199 last_neg_ratio 0.6538 last_objectness_loss 0.0636 last_pos_ratio 0.3462 last_sem_cls_loss 1.6602 last_size_reg_loss 0.0831
[12/21 19:28:05] group-free INFO: proposal_box_loss 0.8211 proposal_center_loss 0.4900 proposal_heading_cls_loss 2.3392 proposal_heading_reg_loss 0.0211 proposal_neg_ratio 0.6538 proposal_objectness_loss 0.0593 proposal_pos_ratio 0.3462 proposal_sem_cls_loss 1.4890 proposal_size_reg_loss 0.0760
[12/21 19:28:05] group-free INFO: 4head_box_loss 0.9417 4head_center_loss 0.5890 4head_heading_cls_loss 2.4262 4head_heading_reg_loss 0.0201 4head_neg_ratio 0.6538 4head_objectness_loss 0.0636 4head_pos_ratio 0.3462 4head_sem_cls_loss 1.7456 4head_size_reg_loss 0.0900
[12/21 19:28:05] group-free INFO: 3head_box_loss 0.9164 3head_center_loss 0.5662 3head_heading_cls_loss 2.4262 3head_heading_reg_loss 0.0206 3head_neg_ratio 0.6538 3head_objectness_loss 0.0620 3head_pos_ratio 0.3462 3head_sem_cls_loss 1.7198 3head_size_reg_loss 0.0870
[12/21 19:28:05] group-free INFO: 2head_box_loss 0.9181 2head_center_loss 0.5710 2head_heading_cls_loss 2.4070 2head_heading_reg_loss 0.0204 2head_neg_ratio 0.6538 2head_objectness_loss 0.0644 2head_pos_ratio 0.3462 2head_sem_cls_loss 1.6655 2head_size_reg_loss 0.0860
[12/21 19:28:05] group-free INFO: 1head_box_loss 0.9180 1head_center_loss 0.5718 1head_heading_cls_loss 2.4213 1head_heading_reg_loss 0.0200 1head_neg_ratio 0.6538 1head_objectness_loss 0.0640 1head_pos_ratio 0.3462 1head_sem_cls_loss 1.6628 1head_size_reg_loss 0.0841
[12/21 19:28:05] group-free INFO: 0head_box_loss 0.9157 0head_center_loss 0.5666 0head_heading_cls_loss 2.4174 0head_heading_reg_loss 0.0201 0head_neg_ratio 0.6538 0head_objectness_loss 0.0633 0head_pos_ratio 0.3462 0head_sem_cls_loss 1.7096 0head_size_reg_loss 0.0873
[12/21 19:28:05] group-free INFO: Mean IoU2d 0.47
[12/21 19:28:05] group-free INFO: loss2d_class_error 100.0000 loss2d_loss_bbox 0.2392 loss2d_loss_ce 0.9980 loss2d_loss_giou 0.6201
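A note on how a value of 100 could arise, offered as an assumption and not a confirmed reading of this repository's code: in DETR-style detection losses, a metric named class_error is commonly logged as 100 minus the top-1 classification accuracy (in percent) over the matched predictions. Under that convention, 100.0000 would mean that none of the matched predictions had the correct top-1 class in that batch. The function below is a hypothetical, framework-free sketch of that convention; the name `class_error` and its signature are illustrative only.

```python
def class_error(logits, targets):
    """Return 100 minus top-1 accuracy (%) over matched predictions.

    logits  : list of per-prediction score lists, one score per class
    targets : list of ground-truth class indices, same length as logits

    Assumption: this mirrors the DETR-style logging convention, not
    necessarily the exact code that produced the log above.
    """
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    if not targets:
        raise ValueError("at least one matched prediction is required")

    correct = 0
    for scores, target in zip(logits, targets):
        # Top-1 prediction is the index of the largest score.
        predicted = max(range(len(scores)), key=scores.__getitem__)
        if predicted == target:
            correct += 1

    return 100.0 - 100.0 * correct / len(targets)
```

If the metric here follows the same convention, checking whether it stays at 100 over many iterations, or only early in training, would help separate a label-mapping problem from ordinary slow convergence.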

Training error on single GPU

Hi, first of all, this is excellent work! I am trying to train on some data with the segmentation code.
When I train on the NYUDv2 dataset with a single GPU using the command "python main.py --backbone mit_b3 -c exp_name --lamda 1e-6 --gpu 0", I always get an error. The error is shown in the screenshot below.
[Screenshot: 微信图片_20230530214548]

Checkpoint model architecture

Is the model architecture of the checkpoint you provided for semantic segmentation different from the model architecture in the code?

No pytorch models (.pth) saved during training?

Hi, first of all, this is excellent work! I am trying to test some data with the segmentation code.
When I train with the pretrained SegFormer weights, the checkpoints are saved as data.pkl, and no .pth file is saved. How am I supposed to evaluate without any .pth file (PyTorch model) generated during training? In other words, which path does 'path_to_pth' refer to in "python main.py --gpu 0 --resume path_to_pth --evaluate", given that no .pth file is generated during training? Am I supposed to convert data.pkl to .pth?
Also, what was the total training time on a single GPU with the SegFormer-B3 weights?
Thank you!
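One possible explanation, stated as an assumption and not as the authors' answer: since PyTorch 1.6, torch.save writes a zip archive by default, and that archive contains an entry named data.pkl. A checkpoint that has been opened or extracted with an archive tool would therefore show data.pkl, even though the original file can still be passed to torch.load regardless of its extension. The sketch below uses only the Python standard library to check whether a file looks like such an archive; `checkpoint_format` is a hypothetical helper and is not part of this repository.

```python
import zipfile


def checkpoint_format(path):
    """Classify a checkpoint file by its container format.

    Returns:
      "zip-with-data.pkl" if the file is a zip archive holding a data.pkl
                          entry (the layout written by recent torch.save),
      "zip-other"         if it is a zip archive without such an entry,
      "not-zip"           otherwise (for example, a legacy pickle file).
    """
    if not zipfile.is_zipfile(path):
        return "not-zip"

    with zipfile.ZipFile(path) as archive:
        names = archive.namelist()

    # The entry is stored under a top-level folder, e.g. "archive/data.pkl".
    if any(name.endswith("data.pkl") for name in names):
        return "zip-with-data.pkl"
    return "zip-other"
```

If the file written during training is reported as "zip-with-data.pkl", it may be usable directly as the path_to_pth argument without any conversion; this has not been verified against this repository's training script.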
