
mpscl's Introduction

Margin Preserving Self-paced Contrastive Learning Towards Domain Adaptation for Medical Image Segmentation

PyTorch implementation of our MPSCL for adapting semantic segmentation from the MR/CT modality (source domain) to the CT/MR modality (target domain).

Paper

Margin Preserving Self-paced Contrastive Learning Towards Domain Adaptation for Medical Image Segmentation
IEEE Journal of Biomedical and Health Informatics (JBHI), vol. 26, no. 2, pp. 638-647, 2022

Please cite our paper if you find it useful for your research.

@ARTICLE{9672690,
  author={Liu, Zhizhe and Zhu, Zhenfeng and Zheng, Shuai and Liu, Yang and Zhou, Jiayu and Zhao, Yao},
  journal={IEEE Journal of Biomedical and Health Informatics},
  title={Margin Preserving Self-Paced Contrastive Learning Towards Domain Adaptation for Medical Image Segmentation},
  year={2022},
  volume={26},
  number={2},
  pages={638-647},
  doi={10.1109/JBHI.2022.3140853}}

Dependencies

This code requires the following:

  • Python 3.6
  • PyTorch 1.3.0

Configure Dataset

  • Thanks to SIFA for sharing the pre-processed data. We have converted the TFRecord data to NumPy arrays. Please download the data from data and put it in the ./data folder.
  • Please run ./dataset/create_datalist.py to create the file listing the training data paths.
  • Please run ./dataset/create_test_datalist.py to create the file listing the testing data paths.
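For orientation, such datalist scripts typically just write one data-file path per line into a text file that the data loader later reads. A minimal sketch of that step (the `write_datalist` helper, the `.npy` extension, and the directory layout are my own illustrative assumptions, not the repo's exact script):

```python
import os

def write_datalist(data_dir, out_file, ext='.npy'):
    """Collect all files with the given extension under data_dir
    and write one path per line to out_file."""
    paths = []
    for root, _, files in os.walk(data_dir):
        for name in files:
            if name.endswith(ext):
                paths.append(os.path.join(root, name))
    with open(out_file, 'w') as f:
        f.write('\n'.join(sorted(paths)))
    return len(paths)
```

The training script would then read this list and load each .npy volume/slice in turn.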

Configure Pretrained Model

  • Please download the pretrained model from pretrained_model and put it in the ./pretrained_model folder. The pretrained model file contains two folders:

  • training contains the initialized models of our MPSCL for generating representative category prototypes and informative pseudo-labels, as described in the implementation details of our paper.
  • testing contains the models corresponding to the results reported in our paper.

Training

To train MPSCL

  • cd <root_dir>/MPSCL/scripts/

For MR2CT

  • CUDA_VISIBLE_DEVICES=#device_id# python train.py --cfg ./configs/MPSCL_MR2CT.yml

For CT2MR

  • CUDA_VISIBLE_DEVICES=#device_id# python train.py --cfg ./configs/MPSCL_CT2MR.yml

Testing

To test MPSCL

If you want to test our released pretrained model

  • cd <root_dir>/MPSCL/scripts

For MR2CT

  • CUDA_VISIBLE_DEVICES=#device_id# python test.py --target_modality 'CT' --pretrained_model_pth '../pretrained_model/testing/MPSCL_MR2CT_best.pth'

For CT2MR

  • CUDA_VISIBLE_DEVICES=#device_id# python test.py --target_modality 'MR' --pretrained_model_pth '../pretrained_model/testing/MPSCL_CT2MR_best.pth'

If you want to test your model

For MR2CT

  • CUDA_VISIBLE_DEVICES=#device_id# python test.py --target_modality 'CT' --pretrained_model_pth 'your model path'

For CT2MR

  • CUDA_VISIBLE_DEVICES=#device_id# python test.py --target_modality 'MR' --pretrained_model_pth 'your model path'

Acknowledgements

This codebase borrows heavily from AdvEnt and SupContrast.

mpscl's People

Contributors

tfboys-lzz


Forkers

mingxuangu

mpscl's Issues

pretrain on source domain

Dear author, I have a question. To run train_UDA.py and train MPSCL, I need the pretrained model on the source domain, which you have provided, but you did not provide the code to pretrain the model on the source data. Could you release it?

Reproduction of the results of the paper.

I found that using the provided YAML file directly reproduces only about half of the results in the paper. How can I reproduce the reported results? For example, what number of optimization iterations, batch size, etc. did you use?

The warmup CT2MR pre-trained model.

Hello, nice work. I have a question about how to obtain the warmup model for CT2MR.

The pre-trained model could achieve ~54.1% Dice score on target test_mr. And when I loaded the warmup model and continued the warmup stage, I could obtain a ~63% Dice score on target test_mr, which was very close to your reported result of AdvEnt/AdaSeg in the MPSCL paper.

But when I conducted the CT2MR warmup stage with the default config.yml from scratch, I could only achieve ~30% Dice on target test_mr, which confuses me.

Could you please provide some advice? Thanks very much.

Getting nan results when training

Hi, thanks for the work. I have a question about the training process. When I run training, the first 25 iterations work fine, but from the 26th iteration onward I get NaN for all numerical values (loss_seg_src_aux, loss_dice_src_aux, etc.). This seems to be caused by the model itself predicting NaN.
Testing works fine, so the issue is only with training.
Do you have any idea what might be causing this or how I can resolve it? Thank you for the help.

About the cosine similarity between the pixel feature and the anchor

In the paper, the logits are defined as the cosine similarity between the pixel feature and the anchor. However, the code seems to use "cosine = torch.matmul(anchor_feature, class_center_feas)" directly instead of an explicit cosine similarity. Maybe I do not fully understand the paper or the code. Could you explain this? Thanks a lot.
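For what it's worth, if both operands are L2-normalized beforehand, a plain matrix product does equal cosine similarity, which would reconcile the paper and the code. A small NumPy check of that identity (variable names borrowed from the issue for readability, not taken from the repo):

```python
import numpy as np

rng = np.random.default_rng(0)
anchor_feature = rng.normal(size=(4, 8))     # 4 pixel features of dim 8
class_center_feas = rng.normal(size=(8, 5))  # 5 class prototypes as columns

# L2-normalize along the shared feature dimension
anchor_feature /= np.linalg.norm(anchor_feature, axis=1, keepdims=True)
class_center_feas /= np.linalg.norm(class_center_feas, axis=0, keepdims=True)

# For unit vectors, the dot product IS the cosine similarity
cosine = anchor_feature @ class_center_feas  # shape (4, 5)

# Verify against the explicit cosine definition for one pair
a, c = anchor_feature[0], class_center_feas[:, 0]
explicit = a.dot(c) / (np.linalg.norm(a) * np.linalg.norm(c))
assert np.isclose(cosine[0, 0], explicit)
```

So the matmul form is only a cosine similarity under the assumption that the features and prototypes are normalized somewhere upstream.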

CLASS_CENTER_FEA_INIT

Hello, thank you for your work and code. May I ask how the initial prototypes are obtained? Are they random, or computed from the dataset?

train on another dataset

Hi author! I have a question about using your code to train on another dataset. First, I need to pre-train a segmentation model on the source domain to obtain the initial category prototypes. Is it also necessary to pre-train the main and aux discriminators?

about supervised training

Hi! I am very interested in your work. I used your code for supervised training on cardiac CT images, and my average Dice result was only 86.6% at best, but the supervised-training result on cardiac CT in your paper is 90.40%. Did you change the hyperparameters for supervised training?
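For context on the scores quoted throughout these issues, the Dice coefficient is typically computed per class from binary masks. A minimal NumPy sketch (my own illustration, not the repo's evaluation code):

```python
import numpy as np

def dice_score(pred, label):
    """Dice coefficient between two binary masks:
    2 * |pred ∩ label| / (|pred| + |label|)."""
    pred = pred.astype(bool)
    label = label.astype(bool)
    inter = np.logical_and(pred, label).sum()
    denom = pred.sum() + label.sum()
    return 2.0 * inter / denom if denom > 0 else 1.0

# Example: masks overlapping at one of three foreground pixels
p = np.array([1, 1, 0, 0])
l = np.array([1, 0, 1, 0])
print(dice_score(p, l))  # 2*1 / (2+2) = 0.5
```

Per-class scores are usually averaged over classes (and cases) to get the single number reported in the paper.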

How to initialize the category center

Hi, it's me again.

I want to try your code on another dataset, the abdominal dataset, which was reported in the SIFA [TMI] paper. But I cannot find the initialization code for calculating the category centers. Could you please release the corresponding code?
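In case it helps, category prototypes of this kind are commonly initialized as the mean feature over all pixels of each class under the source labels. A hedged NumPy sketch of that masked averaging (my own illustration under that assumption, not the authors' released code):

```python
import numpy as np

def init_class_centers(features, labels, num_classes):
    """features: (N, D) pixel features; labels: (N,) integer class ids.
    Returns (num_classes, D): the mean feature per class
    (zeros if a class is absent from the batch)."""
    centers = np.zeros((num_classes, features.shape[1]), dtype=features.dtype)
    for c in range(num_classes):
        mask = labels == c
        if mask.any():
            centers[c] = features[mask].mean(axis=0)
    return centers

feats = np.array([[1., 0.], [3., 0.], [0., 2.]])
labs = np.array([0, 0, 1])
print(init_class_centers(feats, labs, 3))
# class 0 -> [2., 0.], class 1 -> [0., 2.], class 2 absent -> [0., 0.]
```

In practice the features would come from the source-pretrained segmentation network, and the centers would then be refined iteratively during adaptation.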

Thanks!

shape error

Hi, I'm pretraining DeepLab-v2 to generate prototypes on another dataset. I find that the prediction is down-sampled to 33 $\times$ 33 but not up-sampled. Why is that? Did you also down-sample the labels? It seems that you did not down-sample the labels before calculating the BCE and Dice losses in train_UDA.py.

Note that I noticed there is a label_downsample function in train_UDA.py, but it is only used in the update_class_center_iter function, so it does not affect the shape of the labels used for the BCE and Dice loss calculation.

Thanks.
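For reference, down-sampling integer label maps is usually done with nearest-neighbour sampling so that class ids are never interpolated into meaningless in-between values. A minimal sketch (my own illustration, not the repo's label_downsample):

```python
import numpy as np

def downsample_labels(labels, out_h, out_w):
    """Nearest-neighbour down-sampling of an integer label map (H, W)
    by index selection, so output values stay valid class ids."""
    in_h, in_w = labels.shape
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    return labels[np.ix_(rows, cols)]

lab = np.arange(16).reshape(4, 4)
print(downsample_labels(lab, 2, 2))  # picks rows 0,2 and cols 0,2 -> [[0, 2], [8, 10]]
```

This is the same idea as resizing with nearest interpolation; bilinear resizing would be wrong here because averaging class ids produces labels that belong to no class.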

data processing

Hello, thanks for your code. I like your work.

In domain_adaptation/eval_UDA.py, line 68:
img_mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
If I want to test on my own dataset, should I replace img_mean with values for my dataset? Should the value equal the per-channel mean of my test dataset?
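Regarding the values in that question: (104.00698793, 116.66876762, 122.67891434) appear to be the widely used Caffe-style ImageNet BGR channel means, so one option when adapting to a new dataset is to recompute the per-channel means over your own training images. A minimal sketch, assuming the images are stacked as an (N, H, W, C) array:

```python
import numpy as np

def per_channel_mean(images):
    """images: (N, H, W, C) array. Returns the mean of each channel."""
    return images.reshape(-1, images.shape[-1]).mean(axis=0)

# Toy example: constant channels 10 / 20 / 30
imgs = np.zeros((2, 4, 4, 3), dtype=np.float32)
imgs[..., 0] = 10.0
imgs[..., 1] = 20.0
imgs[..., 2] = 30.0
print(per_channel_mean(imgs))  # [10. 20. 30.]
```

Whether to recompute is a judgment call: whatever mean is subtracted at test time should match the normalization used when the model was trained.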

how to train from scratch

Dear author, I want to train from scratch on another dataset. How should I do that? Should I first train with warmup.yml for 4,000 iterations, and then use the trained parameters to continue training with mpscl.yml?
