
Medical Cross-Modality Domain Adaptation (Med-CMDA)

This repository contains the implementations for the following papers:

PnP-AdaNet: Plug-and-Play Adversarial Domain Adaptation Network with a Benchmark at Cross-modality Cardiac Segmentation. (https://arxiv.org/abs/1812.07907) (long version)

Unsupervised Cross-Modality Domain Adaptation of ConvNets for Biomedical Image Segmentations with Adversarial Loss, IJCAI, pp. 691-697, 2018. (https://arxiv.org/abs/1804.10916) (short version)

Introduction

Deep convolutional networks have demonstrated state-of-the-art performance on various medical image computing tasks. However, the generalization capability of deep models to test data with different distributions remains a major challenge. In this project, we tackle the interesting problem setting of unsupervised domain adaptation between CT and MRI, by proposing a plug-and-play adversarial domain adaptation network that aligns the feature spaces of the two domains, which exhibit significant domain shift.

Usage

0. Packages

nibabel==2.1.0
nilearn==0.3.1
numpy==1.13.3
tensorflow-gpu==1.4.0 
python 2.7

(Note: other TensorFlow versions have not been tested; please notify us if the code also works with them.)

Update (Mar 2020): a user kindly confirmed that TensorFlow 1.6 with CUDA 10.0 also works.

1. Data preprocessing

The original data, 20 cardiac CT and 20 cardiac MR images, come from the MMWHS Challenge; the original data release license also applies to this project.

The pre-processed and augmented training data can be downloaded here, in tfrecord format for direct loading. The testing CT data can be downloaded here, as .nii files with the heart region cropped.
The same data are also used in our SIFA paper.

Briefly, the images were pre-processed as follows: 1. cropped around the heart region, with four cardiac substructures selected for segmentation considering their mutual visibility in 2D views; 2. for each cropped 3D image, the top 2% of its intensity histogram was cut off to alleviate artifacts; 3. each 3D image was then normalized to zero mean and unit standard deviation; 4. 2D coronal slices were sampled with data augmentation.
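
The pre-processing script itself is not included in this repository. Purely as a rough sketch of steps 2-4 above for a single cropped volume, the following could be used; it assumes nibabel-readable .nii inputs and that the coronal axis is axis 1 of the cropped array (both are assumptions, not guaranteed to match the released data).

import nibabel as nib
import numpy as np

def preprocess_volume(image_path):
    """Steps 2-3: clip the top 2% of intensities, then normalize to zero mean, unit std."""
    vol = nib.load(image_path).get_data().astype(np.float32)
    upper = np.percentile(vol, 98.0)                  # step 2: cut off top 2% of the histogram
    vol = np.clip(vol, vol.min(), upper)
    return (vol - vol.mean()) / (vol.std() + 1e-8)    # step 3: zero mean, unit std

def sample_coronal_slices(vol, label):
    """Step 4: yield 3 consecutive image slices together with the middle label slice."""
    for y in range(1, vol.shape[1] - 1):                          # axis 1 assumed coronal
        img = np.transpose(vol[:, y - 1:y + 2, :], (0, 2, 1))     # [H, W, 3]
        lbl = label[:, y, :][..., np.newaxis]                     # [H, W, 1]
        yield img, lbl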

To adapt a segmenter from MR to CT, use:
ct_train_tfs: training slices from 14 cases, 600 slices each, 8400 slices in total.
ct_val_tfs: validation slices from 2 cases, 600 slices each, 1200 slices in total.
mr_train_tfs: training slices from 16 cases, 600 slices each, 9600 slices in total.
mr_val_tfs: validation slices from 4 cases, 600 slices each, 2400 slices in total.
Since we are adapting from MR to CT, no separate MR testing set is needed.

For ease of training, after data augmentation, training samples are expected to be written into tfrecords with the following format (a writing sketch follows the field list below):

feature = {
            # image size, dimensions of 3 consecutive slices
            'dsize_dim0': tf.FixedLenFeature([], tf.int64), # 256
            'dsize_dim1': tf.FixedLenFeature([], tf.int64), # 256
            'dsize_dim2': tf.FixedLenFeature([], tf.int64), # 3
            # label size, dimension of the middle slice
            'lsize_dim0': tf.FixedLenFeature([], tf.int64), # 256
            'lsize_dim1': tf.FixedLenFeature([], tf.int64), # 256
            'lsize_dim2': tf.FixedLenFeature([], tf.int64), # 1
            # image slices of size [256, 256, 3]
            'data_vol': tf.FixedLenFeature([], tf.string),
            # label slice of size [256, 256, 1]
            'label_vol': tf.FixedLenFeature([], tf.string)}
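
The script that produces these tfrecords is not part of this repository; a minimal sketch of serializing one augmented sample into this format could look as follows. The helper names here are placeholders, and storing the arrays as raw float32 bytes is an assumption; check the decoding in source_segmenter.py to make sure the dtype matches.

import numpy as np
import tensorflow as tf

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def write_example(writer, image, label):
    """image: [256, 256, 3] array; label: [256, 256, 1] array."""
    feature = {
        'dsize_dim0': _int64_feature(image.shape[0]),
        'dsize_dim1': _int64_feature(image.shape[1]),
        'dsize_dim2': _int64_feature(image.shape[2]),
        'lsize_dim0': _int64_feature(label.shape[0]),
        'lsize_dim1': _int64_feature(label.shape[1]),
        'lsize_dim2': _int64_feature(label.shape[2]),
        # raw bytes of the arrays; the dtype must match what the reader decodes
        'data_vol': _bytes_feature(image.astype(np.float32).tostring()),
        'label_vol': _bytes_feature(label.astype(np.float32).tostring()),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())

# one slice per file, following the naming convention of the provided lists
# with tf.python_io.TFRecordWriter('mr_train_slice0.tfrecords') as writer:
#     write_example(writer, image, label)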

2. Training base segmentation network

Run train_segmenter.py, where training configurations are specified.

This calls source_segmenter.py, where the network structure and training functions are defined.
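
The actual architecture and hyper-parameters live in those two files. Purely as an illustration of the supervised source-domain training objective, a toy TF1-style training step could be sketched as below; toy_segmenter is a stand-in, not the repo's DRN segmenter.

import tensorflow as tf

def toy_segmenter(x, num_classes=5):
    """Stand-in for the DRN segmenter in source_segmenter.py (illustration only)."""
    h = tf.layers.conv2d(x, 32, 3, padding='same', activation=tf.nn.relu)
    h = tf.layers.conv2d(h, 32, 3, padding='same', activation=tf.nn.relu)
    return tf.layers.conv2d(h, num_classes, 1, padding='same')  # per-pixel logits

# placeholders matching the tfrecord layout: 3 consecutive slices in, 1 label slice out
images = tf.placeholder(tf.float32, [None, 256, 256, 3])
labels = tf.placeholder(tf.int32, [None, 256, 256])  # 4 cardiac substructures + background

logits = toy_segmenter(images)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)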

3. Training adversarial domain adaptation

3.1 Warming-up the discriminator

To obtain a good initial estimate of the Wasserstein distance between the feature maps of the two domains, we first pre-train the feature-domain discriminator. To do this, run

python train_gan.py --phase pre-train

3.2 Training adversarial domain adaptation

After warming the discriminator up, we can then jointly train the feature domain discriminator and the domain adaptation module (generator). To do this, run

python train_gan.py --phase train-gan

The experiment configurations can be found in train_gan.py. It calls adversarial.py, where network structures and training functions are defined.
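
Conceptually, the adversarial adaptation follows a WGAN-style scheme: a critic scores feature maps from the two domains, the critic loss approximates the (negative) Wasserstein distance, and the domain adaptation module (generator) is updated so that target features score like source features. The sketch below only illustrates that scheme; the actual networks and losses are defined in adversarial.py, and names such as toy_critic, toy_adapter, and the feature shapes are placeholders (the Lipschitz constraint used in the repo may also differ).

import tensorflow as tf

def toy_critic(feat, reuse=False):
    """Stand-in for the feature-domain discriminator defined in adversarial.py."""
    with tf.variable_scope('critic', reuse=reuse):
        h = tf.layers.conv2d(feat, 64, 3, strides=2, activation=tf.nn.leaky_relu)
        # unbounded per-sample score (no sigmoid), as in WGAN critics
        return tf.reduce_mean(tf.layers.conv2d(h, 1, 3), axis=[1, 2, 3])

def toy_adapter(image, reuse=False):
    """Stand-in for the plug-and-play adaptation layers producing target features."""
    with tf.variable_scope('adapt', reuse=reuse):
        h = tf.layers.conv2d(image, 64, 3, strides=4, padding='same', activation=tf.nn.relu)
        return tf.layers.conv2d(h, 128, 3, strides=2, padding='same')

src_feat = tf.placeholder(tf.float32, [None, 32, 32, 128])   # source-domain feature maps
tgt_img = tf.placeholder(tf.float32, [None, 256, 256, 3])    # target-domain input slices
tgt_feat = toy_adapter(tgt_img)

# critic loss: negative estimate of the Wasserstein distance between feature distributions
d_loss = tf.reduce_mean(toy_critic(tgt_feat)) - tf.reduce_mean(toy_critic(src_feat, reuse=True))
# generator loss: pull target features towards the source feature distribution
g_loss = -tf.reduce_mean(toy_critic(tgt_feat, reuse=True))

d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='critic')
g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='adapt')
d_train_op = tf.train.RMSPropOptimizer(5e-5).minimize(d_loss, var_list=d_vars)
g_train_op = tf.train.RMSPropOptimizer(5e-5).minimize(g_loss, var_list=g_vars)
# weight clipping keeps the critic roughly 1-Lipschitz, as in the original WGAN
clip_ops = [v.assign(tf.clip_by_value(v, -0.01, 0.01)) for v in d_vars]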

4. Evaluation

The evaluation code has been released with our SIFA repo; please refer to it here.
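
For reference, a commonly used metric on this benchmark is the per-class Dice coefficient between predicted and ground-truth label volumes. Below is a minimal NumPy sketch of that metric (an assumption-laden illustration, not the released evaluation code), with label 0 treated as background and the four substructures as labels 1-4.

import numpy as np

def dice_per_class(pred, gt, num_classes=5):
    """Dice coefficient for each foreground class (labels 1..num_classes-1)."""
    dices = []
    for c in range(1, num_classes):
        p, g = (pred == c), (gt == c)
        denom = p.sum() + g.sum()
        dices.append(2.0 * np.logical_and(p, g).sum() / denom if denom > 0 else np.nan)
    return dices

# example: pred and gt are [D, H, W] integer arrays loaded from the .nii test volumes
# print(dice_per_class(pred, gt))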

5. Citations

If you make use of the code, please cite the papers in resulting publications.

@inproceedings{dou2018unsupervised,
  title={Unsupervised cross-modality domain adaptation of convnets for biomedical image segmentations with adversarial loss},
  author={Dou, Qi and Ouyang, Cheng and Chen, Cheng and Chen, Hao and Heng, Pheng-Ann},
  booktitle={Proceedings of the 27th International Joint Conference on Artificial Intelligence (IJCAI)},
  pages={691--697},
  year={2018}
}

or

@article{dou2018pnp,
  title={PnP-AdaNet: Plug-and-play adversarial domain adaptation network with a benchmark at cross-modality cardiac segmentation},
  author={Dou, Qi and Ouyang, Cheng and Chen, Cheng and Chen, Hao and Glocker, Ben and Zhuang, Xiahai and Heng, Pheng-Ann},
  journal={arXiv preprint arXiv:1812.07907},
  year={2018}
}

6. Acknowledgements

Special thanks to Ryan Neph for the PyMedImage package, which was used for debugging in the original project.

Contact

For general questions, please email [email protected] (Qi Dou) and [email protected] (Cheng Ouyang).
For questions on the data license, please contact [email protected] (Qi Dou) and [email protected] (Xiahai Zhuang).

medical-cross-modality-domain-adaptation's Issues

Test code

Hi! I am very interested in your work. The model has been trained, but the test code is missing. Could you upload it to the project? In an earlier answer you suggested using evaluate.py from the SIFA project, but the trained model is different, and testing reports missing parameters. Looking forward to your reply!

ValueError: At least two variables have the same name: BatchNorm/beta

When running this code, I encountered the following error:

File "Medical-Cross-Modality-Domain-Adaptation-master/source_segmenter.py", line 282, in restore
    saver = tf.train.Saver(tf.contrib.framework.get_variables() + tf.get_collection_ref("internal_batchnorm_variables"))
File "/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 825, in __init__
    self.build()
File "/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 837, in build
    self._build(self._filename, build_save=True, build_restore=True)
File "/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 875, in _build
    build_restore=build_restore)
File "/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 482, in _build_internal
    names_to_saveables)
File "/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/saving/saveable_object_util.py", line 335, in validate_and_slice_inputs
    names_to_saveables = op_list_to_dict(names_to_saveables)
File "/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/saving/saveable_object_util.py", line 292, in op_list_to_dict
    name)
ValueError: At least two variables have the same name: BatchNorm/beta

I tried changing the TensorFlow version to 1.14, but it didn't work. What should I do to solve this problem? Thanks very much!

The Preprocess Data Issue

Dear author, I'm trying to train your code on my MRI dataset, but I encountered some problems when preprocessing the training data. Would you please kindly share the preprocessing script used to prepare the data for training the network? Thanks a lot!

OutOfRangeError: RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 1, current size 0)

First, thanks for your code. I have this problem and I don't know why the data is not being read.
I downloaded the dataset and changed the paths (in ct_train_list) to absolute paths (like /home/lijiachen/下载/Medical-Cross-Modality-Domain-Adaptation-master/lists/mr_train_tfs/mr_train_slice0.tfrecords), but it didn't work.
I also tried changing the batch size, capacity, and num_threads, but that didn't work either.
Can you help me solve this problem? It has confused me for many days. Thanks!

ValueError: At least two variables have the same name: BatchNorm/beta

When running source_segmenter.py, the checkpoint-restoring lines

if last_ckpt and last_ckpt.model_checkpoint_path:
    self.net.restore(sess, last_ckpt.model_checkpoint_path)

raise the error above, and I found that the following line in restore causes the problem:

saver = tf.train.Saver(tf.contrib.framework.get_variables() + tf.get_collection_ref("internal_batchnorm_variables"))

binary classification error

  1. What should I do if I only have 2 classes to classify? In adversarial.py, when defining the classifier, the code needs to concatenate some features into 32 channels, but it only gets 26 channels for a binary classification task (lines 333-338 of the original code).
  2. The code crashes with the error "At least two variables have the same name: adapt_1/adapt_1_1/beta".

thanks!

Exception has occurred: OutOfRangeError RandomShuffleQueue '_2_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 10, current size 0)

There is no problem with my file path, but this error occurred.

Exception has occurred: OutOfRangeError
RandomShuffleQueue '_2_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 10, current size 0)
  [[Node: shuffle_batch = QueueDequeueManyV2[component_types=[DT_FLOAT, DT_FLOAT, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch/random_shuffle_queue, shuffle_batch/n)]]
  [[Node: shuffle_batch/_67 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_5_shuffle_batch", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

Caused by op u'shuffle_batch', defined at:
  File "/usr/lib/python2.7/runpy.py", line 174, in _run_module_as_main
    "__main__", fname, loader, pkg_name)
  File "/usr/lib/python2.7/runpy.py", line 72, in _run_code
    exec code in run_globals
  File "/home/hfcui/.vscode-server/extensions/ms-python.python-2020.9.114305/pythonFiles/lib/python/debugpy/__main__.py", line 45, in <module>
    cli.main()
  File "/home/hfcui/.vscode-server/extensions/ms-python.python-2020.9.114305/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/hfcui/.vscode-server/extensions/ms-python.python-2020.9.114305/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 267, in run_file
    runpy.run_path(options.target, run_name=compat.force_str("__main__"))
  File "/usr/lib/python2.7/runpy.py", line 252, in run_path
    return _run_module_code(code, init_globals, run_name, path_name)
  File "/usr/lib/python2.7/runpy.py", line 82, in _run_module_code
    mod_name, mod_fname, mod_loader, pkg_name)
  File "/usr/lib/python2.7/runpy.py", line 72, in _run_code
    exec code in run_globals
  File "/media/nomachine/project/Medical-Cross-Modality-Domain-Adaptation/train_segmenter.py", line 88, in <module>
    main()
  File "/media/nomachine/project/Medical-Cross-Modality-Domain-Adaptation/train_segmenter.py", line 82, in main
    trainer.train(output_path = output_path, training_iters = training_iters, epochs = epochs, restore = True, restored_path = restored_path)
  File "/media/nomachine/project/Medical-Cross-Modality-Domain-Adaptation/source_segmenter.py", line 466, in train
    feed_all, feed_fid = self.next_batch(self.train_queue)
  File "/media/nomachine/project/Medical-Cross-Modality-Domain-Adaptation/source_segmenter.py", line 351, in next_batch
    num_threads = num_threads, min_after_dequeue = min_after_dequeue)
  File "/home/hfcui/.local/lib/python2.7/site-packages/tensorflow/python/training/input.py", line 1225, in shuffle_batch
    name=name)
  File "/home/hfcui/.local/lib/python2.7/site-packages/tensorflow/python/training/input.py", line 796, in _shuffle_batch
    dequeued = queue.dequeue_many(batch_size, name=name)
  File "/home/hfcui/.local/lib/python2.7/site-packages/tensorflow/python/ops/data_flow_ops.py", line 464, in dequeue_many
    self._queue_ref, n=n, component_types=self._dtypes, name=name)
  File "/home/hfcui/.local/lib/python2.7/site-packages/tensorflow/python/ops/gen_data_flow_ops.py", line 2418, in _queue_dequeue_many_v2
    component_types=component_types, timeout_ms=timeout_ms, name=name)
  File "/home/hfcui/.local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/hfcui/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/home/hfcui/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

OutOfRangeError (see above for traceback): RandomShuffleQueue '_2_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 10, current size 0)
  [[Node: shuffle_batch = QueueDequeueManyV2[component_types=[DT_FLOAT, DT_FLOAT, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch/random_shuffle_queue, shuffle_batch/n)]]
  [[Node: shuffle_batch/_67 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_5_shuffle_batch", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
  File "/media/nomachine/project/Medical-Cross-Modality-Domain-Adaptation/source_segmenter.py", line 478, in train
    batch, fid = sess.run([feed_all, feed_fid])
  File "/media/nomachine/project/Medical-Cross-Modality-Domain-Adaptation/train_segmenter.py", line 82, in main
    trainer.train(output_path = output_path, training_iters = training_iters, epochs = epochs, restore = True, restored_path = restored_path)
  File "/media/nomachine/project/Medical-Cross-Modality-Domain-Adaptation/train_segmenter.py", line 88, in <module>
    main()

Code for making the tfrecords ?

Hello authors, and thank you for your work!
I am using my own augmentation on the WHS data, but I'm not able to write the tfrecords as you specify below:

"For the ease of training, after data augmentation, training samples are expected to be written into tfrecord with the following format:
feature = {
# image size, dimensions of 3 consecutive slices
'dsize_dim0': tf.FixedLenFeature([], tf.int64), # 256
'dsize_dim1': tf.FixedLenFeature([], tf.int64), # 256
'dsize_dim2': tf.FixedLenFeature([], tf.int64), # 3
# label size, dimension of the middle slice
'lsize_dim0': tf.FixedLenFeature([], tf.int64), # 256
'lsize_dim1': tf.FixedLenFeature([], tf.int64), # 256
'lsize_dim2': tf.FixedLenFeature([], tf.int64), # 1
# image slices of size [256, 256, 3]
'data_vol': tf.FixedLenFeature([], tf.string),
# label slice of size [256, 256, 1]
'label_vol': tf.FixedLenFeature([], tf.string)}"

I followed this tutorial here:
https://gist.github.com/CihanSoylu/02117f2d77136baf41ddb46789d8a331#file-tfrecords-and-tf-train-example-ipynb

basically my script does everything above and serializes the slices like this:

def serialize_example(image, label):
    feature = {
        'dsize_dim0': _int64_feature(256),
        'dsize_dim1': _int64_feature(256),
        'dsize_dim2': _int64_feature(3),
        'lsize_dim0': _int64_feature(256),
        'lsize_dim1': _int64_feature(256),
        'lsize_dim2': _int64_feature(1),
        'data_vol': _bytes_feature(image),
        'label_vol': _bytes_feature(label)
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

I save to a new folder. It appears to have worked, and all my augmented slices have been saved in the above tfrecords. I also made a new list, mr_train_list_augmented, with paths just like yours. But then your train_segmenter code cannot read the tfrecords.

The error is
tensorflow.python.framework.errors_impl.OutOfRangeError: RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 10, current size 0)

Do you have code available to make the tfrecords from .nii files or from tensors?

Thanks !
Mathilde

Problem with test code

Dear author!
Thanks for your work!
Your project does not include the json config file referenced in evaluate.py (line 231: main(config_filename='./config_param.json')), and SIFA uses a different network model. How can evaluate.py be used as the test script for this project?
Looking forward to your answer.

Missing BN variable

Hi! During training, if restore = True, the code reports an error that two BatchNorm variables are missing. Is it necessary to save these two variables when saving the model, or does only the last saved model lack the BN variables? (TF 1.9.0)

evaluation code

Hi, thanks for the impressive work.

The README says the evaluation code will be released in the future. May I ask whether you could release it soon?

BTW, TensorFlow 1.6 and CUDA 10.0 can run the code successfully.

Thanks.

PnpAda dataset preprocessing

Hello! Do consecutive slices in the dataset belong to the same case?
For example, do ct_train_slice0.tfrecords through ct_train_slice599.tfrecords belong to the same case?

Available dataset

Hi @carrenD

We just saw your paper and think it is very interesting. We were wondering when you are planning to also make the data publicly available, so that others can access it and compare against your method.

Best regards,

Jose

Running train_gan.py: AttributeError: 'Full_DRN' object has no attribute 'predicter' at self.compact_pred = tf.argmax(self.predicter, 3)

When warming up the discriminator by running train_gan.py, this error occurred:

File "D:\E\python\domian-adaptation\Medical-Cross-Modality-Domain-Adaptation-master\Medical-Cross-Modality-Domain-Adaptation-master\adversarial.py", line 105, in init
self.compact_pred = tf.argmax(self.predicter, 3) # predictions
AttributeError: 'Full_DRN' object has no attribute 'predicter'
