
3d-unet--tensorflow's Introduction

Non-local U-Nets for Biomedical Image Segmentation

This repository provides the experimental code for our paper "Non-local U-Nets for Biomedical Image Segmentation" accepted by AAAI-20.

This repository includes a re-implementation, using updated TensorFlow APIs, of the 3D U-Net for isointense infant brain image segmentation. In addition, we implement our proposed global aggregation blocks, which modify self-attention layers for the 3D U-Net. The user can optionally insert these blocks into the standard 3D U-Net.

If you want to use the standard 3D U-Net, modify network.py by removing lines 62-67 and 72-79, and do not use "_att_decoding_block_layer" in "_build_network". Should you have any questions, open an issue and I will respond.
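For reference, a plain (non-attention) decoding step in a 3D U-Net is just a transposed convolution, concatenation with the skip connection, and a couple of convolutions. The sketch below only illustrates that idea with the tf.layers API from TensorFlow 1.x; it is not the code in network.py, and the function name and arguments are assumptions.

```python
import tensorflow as tf

def standard_decoding_block(inputs, skip, filters, training):
    """Illustrative plain 3D U-Net decoding step (not the repository's code):
    upsample with a transposed convolution, concatenate the skip connection,
    then apply two 3x3x3 convolutions with batch normalization and ReLU."""
    up = tf.layers.conv3d_transpose(inputs, filters, kernel_size=3,
                                    strides=2, padding='same')
    x = tf.concat([up, skip], axis=-1)  # channels-last layout assumed
    for _ in range(2):
        x = tf.layers.conv3d(x, filters, kernel_size=3, padding='same')
        x = tf.layers.batch_normalization(x, training=training)
        x = tf.nn.relu(x)
    return x
```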

Created by Zhengyang Wang and Shuiwang Ji at Texas A&M University.

Update

11/10/2019:

Our paper "Non-local U-Nets for Biomedical Image Segmentation" has been accepted by AAAI-20!

10/01/2018:

  1. The code now works when we have subjects of different spatial sizes.

  2. During training, validation and prediction, you only need to change the settings in configure.py. In the old version, you had to change the settings correspondingly in several files such as main.py, utils/input_fn.py, etc.

Publication

The paper is available at https://www.aaai.org/Papers/AAAI/2020GB/AAAI-WangZ.5933.pdf.

If you use this code, please cite our paper:

@inproceedings{wang2020non,
  title={Non-local U-Nets for Biomedical Image Segmentation},
  author={Wang, Zhengyang and Zou, Na and Shen, Dinggang and Ji, Shuiwang},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2020}
}

Dataset

The dataset is from UNC and used as the training dataset in iSeg-2017. Basically, it is composed of multi-modality isointense infant brain MR images (3D) of 10 subjects. Each subject has two 3D images (modalities), T1WI and T2WI, with a manually created 3D segmentation label.

It is an important step in brain development study to perform automatic segmentation of infant brain magnetic resonance (MR) images into white matter (WM), grey matter (GM) and cerebrospinal fluid (CSF) regions. This task is especially challenging in the isointense stage (approximately 6-8 months of age) when WM and GM exhibit similar levels of intensities in MR images.

Results

Below is a glance at the effect of our proposed model. The baseline is 3-D Fully Convolutional Networks for Multimodal Isointense Infant Brain Image Segmentation.

Visualization of the segmentation results on the 10th subject by our proposed model and the baseline model (figure available in the repository).

Comparison of training processes between our proposed model and the baseline model (figure available in the repository).

System requirement

Programming language

Python 3.5+

Python Packages

tensorflow-gpu 1.7 - 1.10, numpy, scipy

Configure the network

All network hyperparameters are configured in configure.py; a hypothetical sketch of such a configuration appears at the end of this section.

Training

raw_data_dir: the directory where the raw data is stored

data_dir: the directory where the input data is stored

num_training_subs: the number of subjects used for training

train_epochs: the number of epochs to use for training

epochs_per_eval: the number of training epochs to run between evaluations

batch_size: the number of examples processed in each training batch

learning_rate: learning rate

weight_decay: weight decay rate

num_parallel_calls: the number of records processed in parallel during input processing. This can be tuned per dataset; for generally homogeneous datasets, it should be approximately the number of available CPU cores.

model_dir: the directory where the model will be stored

Validation

patch_size: spatial size of patches

overlap_step: overlap step size when performing testing

validation_id: 1-10, which subject is used for validation

checkpoint_num: which checkpoint is used for validation

save_dir: the directory where the prediction is stored

raw_data_dir: the directory where the raw data is stored

Network architecture

network_depth: the network depth

num_classes: the number of classes

num_filters: number of filters for initial_conv
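The option names above come from the README; the snippet below is only a hypothetical illustration of what such a configuration might look like, with placeholder paths and values. The actual configure.py in the repository defines its own defaults.

```python
# Hypothetical configure.py-style settings (placeholder values, not the
# repository's defaults).

# Training
raw_data_dir = '/path/to/iSeg-2017/raw'   # where the raw data is stored
data_dir = '/path/to/tfrecords'           # where the preprocessed inputs live
num_training_subs = 9                     # subjects used for training
train_epochs = 10
epochs_per_eval = 1
batch_size = 5
learning_rate = 1e-3
weight_decay = 2e-6
num_parallel_calls = 8                    # roughly the number of CPU cores
model_dir = '/path/to/checkpoints'

# Validation / prediction
patch_size = 32                           # spatial size of patches
overlap_step = 8                          # sliding-window overlap at test time
validation_id = 10                        # subject (1-10) held out for validation
checkpoint_num = 100000
save_dir = '/path/to/predictions'

# Network architecture
network_depth = 3
num_classes = 4                           # assumed: background, CSF, GM, WM
num_filters = 32                          # filters for initial_conv
```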

Training and Evaluation

Preprocess data

Before training, we preprocess the data into the TFRecords format, which is optimized for TensorFlow. A good example of how to preprocess data and use TFRecords files as inputs can be found in generate_tfrecord.py and input_fn.py.
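As a rough, self-contained illustration of the TFRecords idea (not the repository's generate_tfrecord.py, which also handles patch extraction and the iSeg file layout), one subject with two modalities and a label volume could be serialized like this using the TensorFlow 1.x API:

```python
import numpy as np
import tensorflow as tf

def write_subject(t1, t2, label, filename):
    """Serialize one subject (two 3D modalities plus a 3D label volume)
    into a TFRecord file. Illustrative sketch only."""
    def _bytes(arr):
        return tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[arr.tobytes()]))

    def _int64(values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    example = tf.train.Example(features=tf.train.Features(feature={
        't1': _bytes(t1.astype(np.float32)),
        't2': _bytes(t2.astype(np.float32)),
        'label': _bytes(label.astype(np.uint8)),
        'shape': _int64(list(t1.shape)),   # needed to decode the raw bytes
    }))
    with tf.python_io.TFRecordWriter(filename) as writer:
        writer.write(example.SerializeToString())
```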

Start training

After configuring configure.py, we can start training by running

python main.py

Training process visualization

We employ TensorBoard to visualize the training process.

tensorboard --logdir=model_dir/

Testing and prediction

If you want to do testing, first make predictions by running

python main.py --option='predict'

Then, if you have access to the labels, set up evaluation.py and run

python evaluation.py
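evaluation.py presumably reports segmentation metrics; for reference, the Dice coefficient commonly reported for this task can be computed per class from two integer-labeled volumes as in the minimal sketch below (not the repository's evaluation code):

```python
import numpy as np

def dice_coefficient(pred, label, class_id):
    """Dice score for one class between two integer-labeled 3D volumes."""
    p = (pred == class_id)
    l = (label == class_id)
    denom = p.sum() + l.sum()
    return 2.0 * np.logical_and(p, l).sum() / denom if denom > 0 else 1.0
```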

You may also visualize the results: set up visualize.py and run

python visualize.py


3d-unet--tensorflow's Issues

opening of file

Hi Zhengyang,
I am trying to get results, but it gives results for only one slice. How can we open it slice-wise, like DICOM or NIfTI, as with our input data?
I am waiting for your kind reply.
Thanks

An AttributeError

When I run main.py, the following error is reported: AttributeError: 'int' object has no attribute 'value', at line 152 of attention.py. Is it because I am using TensorFlow 2.0?

Invalid argument error

Hi there,

I am trying to run your code on my data where I have a very similar multimodal problem. After I generate tfrecords successfully, I run main.py and keep getting:
"Invalid argument: Input to reshape is a tensor with 2823480 values, but the requested shape has 1411740
[[{{node Reshape_2}}]]
[[input/IteratorGetNext]]"
error. Interestingly, the ratio of the two numbers is always 2:1. Moreover, my data size is 13015669 and they are all divisors of 1411740. I did not change anything in the code other than adapting the generate_tfrecords and input_fn functions to my data. Do you have any recommendations for solving this error?

dataset

Where should we put the 2017 dataset after downloading it from the website?

The problem with dataset

Hello,
Thanks a lot for your open-source code, but I still have questions about your dataset. I want to train on my own dataset, which contains many 2D image slices stored in PNG format. Could you please explain how to process the data, or give more detail about your dataset?

How should I modify input_fn.py if I only use the T1 dataset?

I modified every part that needed to be changed, but the following error occurred at runtime. I looked into it and it happens here; please take a look (a screenshot is attached in the original issue).

At line 135, the error is
ValueError: Dimension must be 2 but is 3 for 'unstack' (op: 'Unpack') with input shapes: [32,32,32,3].
Please take a look when you have time.

label data format

Hello, I'm trying to use your network. For multi-label segmentation, what is the proper structure for the label data?

Your code seems to load *.tfrecord files (generated from *.hdr files by generate_tfrecord.py). I'm confused about how to prepare the HDR files as intended, because each slice can have multiple binary label images.
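For what it is worth, a common convention for multi-class segmentation (an assumption here, not confirmed against the repository) is a single integer label volume per subject, where each voxel stores a class index; per-class binary masks can be collapsed into that form:

```python
import numpy as np

def masks_to_label_volume(masks):
    """Collapse one-hot binary masks of shape (num_classes, x, y, z),
    including a background channel at index 0, into a single (x, y, z)
    volume of class indices. Illustrative sketch only."""
    return np.argmax(masks, axis=0).astype(np.uint8)
```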

Data set problem

Hello, the iSeg-2017 and iSeg-2019 datasets cannot be downloaded now. Can you provide another way to download them, such as Baidu Cloud?
Any help would be appreciated.

global_attention_3d error

it should be

new_shape = tf.concat([tf.shape(input=q)[0:-1],[v.shape[-1]]],0)

instead of

new_shape=tf.concat([tf.shape(input=q)[0:-1],[v.shape[-1].value]],0)
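The difference comes from the TensorFlow version: in TF 1.x, Tensor.shape[-1] returns a Dimension object that needs .value, while in TF 2.x it is already a plain int. A small version-tolerant sketch, assuming q and v are tensors as in attention.py:

```python
import tensorflow as tf

def output_shape(q, v):
    """Build the shape [dynamic leading dims of q, static channel dim of v]
    in a way that works on both TF 1.x and TF 2.x."""
    last_dim = v.shape[-1]
    if hasattr(last_dim, 'value'):   # TF 1.x returns a Dimension object
        last_dim = last_dim.value
    return tf.concat([tf.shape(q)[0:-1], [last_dim]], axis=0)
```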

CPU exhaustion when executing prediction

Thank you for your contribution. I am using your code to segment a set of images whose size is about 200x200x200. Training works fine, but when I run prediction the CPU is exhausted very quickly and the process blocks.
How can I fix this problem?

InvalidArgumentError

Hi,

Thank you for this implementation of UNET.

I am trying to use your code and replicate your results. However I am getting the following error.

I was wondering if you can help.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/potis/Desktop/Pycharm/unettf/main.py", line 30, in
tf.app.run()
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "/home/potis/Desktop/Pycharm/unettf/main.py", line 23, in main
getattr(model, args.option)()
File "/home/potis/Desktop/Pycharm/unettf/model.py", line 143, in train
classifier.train(input_fn=input_fn_train, hooks=[logging_hook])
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 366, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1119, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1135, in _train_model_default
saving_listeners)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1333, in _train_with_estimator_spec
log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 415, in MonitoredTrainingSession
stop_grace_period_secs=stop_grace_period_secs)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 826, in init
stop_grace_period_secs=stop_grace_period_secs)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 549, in init
self._sess = _RecoverableSession(self._coordinated_creator)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 1012, in init
_WrappedSession.init(self, self._create_session())
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 1017, in _create_session
return self._sess_creator.create_session()
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 706, in create_session
self.tf_sess = self._session_creator.create_session()
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 477, in create_session
init_fn=self._scaffold.init_fn)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/session_manager.py", line 281, in prepare_session
config=config)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/session_manager.py", line 211, in _restore_checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1752, in restore
{self.saver_def.filename_tensor_name: save_path})
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 900, in run
run_metadata_ptr)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1135, in _run
feed_dict_tensor, options, run_metadata)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
run_metadata)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [3,3,3,2,32] rhs shape= [3,3,3,1,32]
[[Node: save/Assign_106 = Assign[T=DT_FLOAT, _class=["loc:@conv3d/kernel"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](conv3d/kernel, save/RestoreV2:106)]]

Caused by op 'save/Assign_106', defined at:
File "/home/potis/Desktop/Pycharm/unettf/main.py", line 30, in
tf.app.run()
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "/home/potis/Desktop/Pycharm/unettf/main.py", line 23, in main
getattr(model, args.option)()
File "/home/potis/Desktop/Pycharm/unettf/model.py", line 143, in train
classifier.train(input_fn=input_fn_train, hooks=[logging_hook])
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 366, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1119, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1135, in _train_model_default
saving_listeners)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1333, in _train_with_estimator_spec
log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 415, in MonitoredTrainingSession
stop_grace_period_secs=stop_grace_period_secs)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 826, in init
stop_grace_period_secs=stop_grace_period_secs)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 549, in init
self._sess = _RecoverableSession(self._coordinated_creator)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 1012, in init
_WrappedSession.init(self, self._create_session())
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 1017, in _create_session
return self._sess_creator.create_session()
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 706, in create_session
self.tf_sess = self._session_creator.create_session()
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 468, in create_session
self._scaffold.finalize()
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 214, in finalize
self._saver.build()
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1296, in build
self._build(self._filename, build_save=True, build_restore=True)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1333, in _build
build_save=build_save, build_restore=build_restore)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 775, in _build_internal
restore_sequentially, reshape)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 453, in _AddShardedRestoreOps
name="restore_shard"))
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 422, in _AddRestoreOps
assign_ops.append(saveable.restore(saveable_tensors, shapes))
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 113, in restore
self.op.get_shape().is_fully_defined())
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/ops/state_ops.py", line 219, in assign
validate_shape=validate_shape)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/ops/gen_state_ops.py", line 60, in assign
use_locking=use_locking, name=name)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3414, in create_op
op_def=op_def)
File "/home/potis/Desktop/Python_projects/VirtualEnv/skynet3/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1740, in init
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [3,3,3,2,32] rhs shape= [3,3,3,1,32]
[[Node: save/Assign_106 = Assign[T=DT_FLOAT, _class=["loc:@conv3d/kernel"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](conv3d/kernel, save/RestoreV2:106)]]

Process finished with exit code 1

A little bug in generate_tfrecord.py

On line 279, the 'subject_name' in 'valid_filename' should probably be 'valid_subject_name'; otherwise the validation patch file will not be generated.

Performance issues in input_fn.py (by P3)

Hello! I've found a performance issue in input_fn.py: batch() should be called before map(), which could make your program more efficient. The TensorFlow documentation supports this (the detailed code suggestion is in the original issue).

Besides, you need to check whether the functions called in map() (e.g., normalize_image, called in dataset.map(normalize_image, num_parallel_calls=num_parallel_calls)) are affected by the change, so that the modified code still works properly. For example, if normalize_image took data of shape (x, y, z) as input before the fix, it will receive data of shape (batch_size, x, y, z) afterwards.

Looking forward to your reply. Btw, I am very glad to create a PR to fix it if you are too busy.
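For illustration, the suggested reordering in a tf.data pipeline looks like the sketch below; parse_fn, filenames, batch_size and num_parallel_calls are assumed names, and the repository's input_fn.py defines its own pipeline:

```python
import tensorflow as tf

def normalize_batch(images, labels):
    # After batching, the map function sees a leading batch dimension,
    # e.g. (batch_size, x, y, z, channels) instead of (x, y, z, channels).
    mean, var = tf.nn.moments(images, axes=[1, 2, 3, 4], keep_dims=True)
    return (images - mean) / tf.sqrt(var + 1e-6), labels

def make_dataset(filenames, parse_fn, batch_size, num_parallel_calls):
    return (tf.data.TFRecordDataset(filenames)
            .map(parse_fn, num_parallel_calls=num_parallel_calls)
            .batch(batch_size)       # batch first ...
            .map(normalize_batch))   # ... then apply the vectorized map
```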

Dataset

Can the dataset be provided?

I went to the official website, but the registration page could not be opened and the dataset could not be downloaded.

Thank you
