
multi-scale-attention's Introduction

Medical Image Segmentation with Guided Attention

This repository contains the code of our paper:
"'Multi-scale self-guided attention for medical image segmentation'", which has been recently accepted at the Journal of Biomedical And Health Informatics (JBHI).

Abstract

Even though convolutional neural networks (CNNs) are driving progress in medical image segmentation, standard models still have some drawbacks. First, the use of multi-scale approaches, i.e., encoder-decoder architectures, leads to a redundant use of information, where similar low-level features are extracted multiple times at multiple scales. Second, long-range feature dependencies are not efficiently modeled, resulting in nonoptimal discriminative feature representations associated with each semantic class. In this paper we attempt to overcome these limitations with the proposed architecture, by capturing richer contextual dependencies based on the use of guided self-attention mechanisms. This approach is able to integrate local features with their corresponding global dependencies, as well as highlight interdependent channel maps in an adaptive manner. Further, the additional loss between different modules guides the attention mechanisms to neglect irrelevant information and focus on more discriminant regions of the image by emphasizing relevant feature associations. We evaluate the proposed model in the context of abdominal organ segmentation on magnetic resonance imaging (MRI). A series of ablation experiments support the importance of these attention modules in the proposed architecture. In addition, compared to other state-of-the-art segmentation networks our model yields better segmentation performance, increasing the accuracy of the predictions while reducing the standard deviation. This demonstrates the efficiency of our approach to generate precise and reliable automatic segmentations of medical images.

Design of the Proposed Model

[Figure: architecture of the proposed multi-scale guided attention model]

Results

[Figure: qualitative segmentation results]

Requirements

  • The code has been written in Python 3.6 and requires PyTorch 1.1.0
  • Install the dependencies using pip install -r requirements.txt

Preparing your data

You have to split your data into three folders: train/val/test. Each folder must contain two sub-folders, Img and GT, holding the PNG files for the images and their corresponding ground truths. The naming of these images is important, as the code relies on it, for example when temporarily saving results to compute the 3D DSC.

Specifically, the convention we follow for the names is as follows:

  • Subj_Xslice_Y.png, where X indicates the subject number (or ID) and Y is the slice number within the whole volume. (Do not zero-pad the numbers, i.e., the first slice should be 1, not 01.)
  • The corresponding mask must be named in the same way as the image.

An example image is included in the dataset folder. A minimal sanity-check sketch for this layout is shown below.
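
For reference, here is a minimal sanity-check sketch (not part of the released code) that walks an assumed data root and flags any file that does not follow the Subj_Xslice_Y.png convention. The data_root path and the regular expression are assumptions based on this README; adjust them to your own setup.

import os
import re

data_root = "./DataSet"  # assumed location; adjust to your own setup
pattern = re.compile(r"^Subj_(\d+)slice_(\d+)\.png$")  # X = subject id, Y = slice index

for split in ("train", "val", "test"):
    for sub in ("Img", "GT"):
        folder = os.path.join(data_root, split, sub)
        for name in sorted(os.listdir(folder)):
            if pattern.match(name) is None:
                print("Unexpected file name:", os.path.join(folder, name))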

Running the code

Note: Set the data path appropriately in src/main.py before running the code.

To run the code you simply need to use the following script:

bash train.sh

If you use this code for your research, please consider citing our paper:

@article{sinha2020multi,
  title={Multi-scale self-guided attention for medical image segmentation},
  author={Sinha, Ashish and Dolz, Jose},
  journal={IEEE Journal of Biomedical and Health Informatics},
  year={2020}
}

multi-scale-attention's People

Contributors

josedolz, sinashish


multi-scale-attention's Issues

Results not matching the paper?

Hi, I am trying to reproduce the results but could not get results close to those reported in the paper. Did you use the validation set or the test set for reporting the results (in Table 4)? Also, could you provide the training, validation, and test splits?

AttributeError: cannot assign module before Module.__init__() call

Hi @sinAshish,
I'm trying to run your code on my dataset, but I got the following error.
(I think the problem may be with the __init__() call in runTraining(args).)
(Do you have any idea how to solve it? Thanks!)

docker@warriors:[/Desktop/CodeFolder/attention/multi_scale_guided_attention]$ bash train.sh

----------------------------------------
 Dataset: ./DataSet/
~~~~~~~~~~~ Creating the DAF Stacked model ~~~~~~~~~~
Traceback (most recent call last):
  File "src/main.py", line 349, in <module>
    runTraining(args)
  File "src/main.py", line 89, in runTraining
    net = DAF_stack()
  File "/Desktop/CodeFolder/attention/multi_scale_guided_attention/src/models/my_stacked_danet.py", line 61, in __init__
    self.pam_attention_1_1= PAM_CAM_Layer(64, True)
  File "/Desktop/CodeFolder/attention/multi_scale_guided_attention/src/models/attention.py", line 175, in __init__
    nn.PReLU()
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 565, in __setattr__
    "cannot assign module before Module.__init__() call")
AttributeError: cannot assign module before Module.__init__() call
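
For anyone hitting the same error: in PyTorch this message generally means that a submodule is assigned in __init__ before nn.Module's own constructor has run. A minimal sketch of the usual fix (this is the general pattern, not the repository's actual class body):

import torch.nn as nn

class PAM_CAM_Layer(nn.Module):
    def __init__(self, in_ch, use_pam=True):
        # nn.Module.__init__ must run before any submodule is assigned to self
        super(PAM_CAM_Layer, self).__init__()
        self.attn = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1),
            nn.PReLU(),
        )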

Issue in common/utils.py

On line 33, the following file location is requested:
path_GT = './DataSet_Challenge/GT_Nifti/Val_1'
What exactly does this correspond to? You instructed us to put the slices in Img/ and the masks in GT/ for each of train, val and test, so kindly clarify what this path refers to.
It would be a great help if you could host the entire dataset you are using, in the exact format, on a drive link.
@sinAshish @josedolz

datasets

Could you share a link to the datasets? Thanks!

Out of memory

Hello, how much memory is needed for training? It reports an out-of-memory error. Thank you!

Code

Hello,

I can't find the code.

Questions about output segmentation

Thanks for your creative work! I would like to know how to turn L0, L1, L2 and L3 into the segmentation. And are L0, L1, L2 and L3 the same?
Best wishes!

About the pre-trained weights

Sorry to bother you: we have re-trained your code on the BRATS and CHAOS datasets in our experiments, but we did not obtain the results reported in the paper. Could you release the trained weights, which would help us report your results more accurately?

How to generate png data?

Hi!
I am interested in the repo, but the released code does not mention how to generate the PNG data.
Could you share the dataset preprocessing code?
Thanks!
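
Since the preprocessing code is not included, here is a rough sketch of one possible way to slice a volume (e.g., a CHAOS NIfTI file) into PNGs following the Subj_Xslice_Y.png convention from the README. The paths, the use of nibabel, and the min-max normalisation are my own assumptions, not the authors' pipeline; ground-truth masks should keep their original label intensities instead of being rescaled.

import os
import nibabel as nib
import numpy as np
from PIL import Image

def nifti_to_png(nifti_path, out_dir, subject_id):
    volume = nib.load(nifti_path).get_fdata()  # H x W x Slices
    os.makedirs(out_dir, exist_ok=True)
    for y in range(volume.shape[2]):
        sl = volume[:, :, y].astype(np.float32)
        # scale each image slice to 8-bit grayscale (do NOT do this for label masks)
        sl = (255 * (sl - sl.min()) / (sl.ptp() + 1e-8)).astype(np.uint8)
        # slice numbering starts at 1, without zero padding
        Image.fromarray(sl).save(os.path.join(out_dir, "Subj_{}slice_{}.png".format(subject_id, y + 1)))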

Model complexity

Hello, and congratulations on your work.

I studied your paper and decided to reproduce the models in TensorFlow; however, I get a higher number of parameters for a 128×128 input: about 106 million with ResNet101 in my implementation. Moreover, when I check the number of parameters of the ResNeXt101 you use in PyTorch, I get about 88 million for the ResNeXt alone. Since this does not correspond to Table VIII of your appendix at all, I wonder what the numbers in that table correspond to exactly.
Anyway, thanks for your contribution.

Joris Fournel

Questions about picture size and output channels

Thank you for sharing your code!

I have been reading your paper and code these days, and I found that the image size of 7 subjects in the CHAOS dataset (T1 DUAL, In Phase) is 288×288, which differs from the other 13 subjects, so I am not sure whether to crop or resize them to 256×256.

Besides, I found that the output channels in your my_stacked_danet.py file are set to 5:
self.predict4 = nn.Conv2d(64, 5, kernel_size=1)
self.predict3 = nn.Conv2d(64, 5, kernel_size=1)
self.predict2 = nn.Conv2d(64, 5, kernel_size=1)
self.predict1 = nn.Conv2d(64, 5, kernel_size=1)

self.predict4_2 = nn.Conv2d(64, 5, kernel_size=1)
self.predict3_2 = nn.Conv2d(64, 5, kernel_size=1)
self.predict2_2 = nn.Conv2d(64, 5, kernel_size=1)
self.predict1_2 = nn.Conv2d(64, 5, kernel_size=1)

But there are 4 classes in the dataset, if I am right, so here comes another question: why is it 5?

Looking forward to your guidance in your spare time, thanks!

out of memory

Why is more than 22 GB of CUDA memory occupied during training when the batch size is 2?

Segmentation target and output size mismatch during loss calculation

Hey, thanks for sharing your research. I was trying to train the model using your code but could not, due to the following issue:

ValueError: Expected input batch_size (1) to match target batch_size (256).

I am using batch size 1, and the sizes of the tensors fed into loss0 = CE_loss(outputs0, Segmentation_class) are:

Segmentation_class torch.Size([256, 256])
outputs0 torch.Size([1, 5, 256, 256])

It seems the target segmentation needs to be converted to the shape the loss expects.
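
For context (this is standard PyTorch behaviour, not a statement about the authors' intent): nn.CrossEntropyLoss expects an input of shape [N, C, H, W] and a target of shape [N, H, W] holding integer class indices, so a [256, 256] target needs a batch dimension and a long dtype. A minimal sketch:

import torch
import torch.nn as nn

CE_loss = nn.CrossEntropyLoss()
outputs0 = torch.randn(1, 5, 256, 256)                 # [N, C, H, W], as produced by the network
Segmentation_class = torch.randint(0, 5, (256, 256))   # [H, W] class-index map, as loaded

# add the batch dimension and make sure the target holds integer class indices
loss0 = CE_loss(outputs0, Segmentation_class.unsqueeze(0).long())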

About the one-hot encoding function

def getOneHotSegmentation(batch):
    backgroundVal = 0

    # Chaos MRI (these values are to set label values as 0, 1, 2, 3 and 4)
    label1 = 0.24705882
    label2 = 0.49411765
    label3 = 0.7411765
    label4 = 0.9882353

    oneHotLabels = torch.cat(
        (batch == backgroundVal, batch == label1, batch == label2, batch == label3, batch == label4),
        dim=1)

    return oneHotLabels.float()

How are these float values of the ground truth obtained?
Why not set the labels directly to one-hot vectors such as (1, 0, 0), (0, 1, 0), ...?
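
One observation that may help (this is my reading of the constants, not an official answer): the four float values are exactly the 8-bit mask intensities 63, 126, 189 and 252 divided by 255, which is what you get when a grayscale PNG mask is loaded and scaled to [0, 1], e.g. by torchvision's ToTensor:

# 63/255, 126/255, 189/255 and 252/255 reproduce the constants above
for intensity in (63, 126, 189, 252):
    print(intensity / 255.0)   # 0.24705882..., 0.49411765..., 0.74117647..., 0.98823529...

The torch.cat over the five boolean maps along dim=1 is then what produces the one-hot encoding, so each pixel does end up with a 0/1 vector over the five classes.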

For the dataset split

Hi, thank you for your contribution, which has inspired me a lot.
Could you provide the CHAOS dataset split lists (i.e., train/val/test for the three folds) for reproduction? Thank you!

Parameters for training on other models (UNet, DANet, PAN, DAF)

Hi, I can see in your paper that you compared the performance of your model with several other models (UNet, DANet, PAN, DAF); however, I am not sure whether you used the same training parameters (batch size, learning rate, total number of epochs, etc.) when training these models on the CHAOS dataset. Are the training parameters the same for all models?

About the code before softmax in CAM_Module

Hi, thanks for sharing this awesome project :)

Here is a question from reading your source code. In CAM_Module there is one line of code before the softmax function that does not exist in PAM_Module. As I understand it, for each channel vector you take the maximum value (computed as query dot key) and subtract every value from it. But then wouldn't the larger numbers correspond to the more irrelevant channels?
Sorry, I cannot understand this; could you kindly explain it to me? Thanks a lot!
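
For reference, the line in question, as it appears in the original DANet implementation that this module seems to follow (reproduced here as a sketch, not copied from this repository), is roughly:

import torch

energy = torch.randn(2, 64, 64)  # [B, C, C] channel-similarity matrix (query · key)
energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
attention = torch.softmax(energy_new, dim=-1)
# Note: this computes max - energy rather than the usual stability trick energy - max,
# so the raw similarities are indeed inverted before the softmax, which is what the
# question above is asking about.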


Evaluation metrics (dice coefficient, volume similarity and mean surface distance) on 2D or 3D segmentations?

Hi, I am trying to reproduce the results given in the paper. Initially I assumed the Dice scores and volume similarity were calculated on the 2D segmentation results, but you mention in the paper that 'Since inter-slice distances and x-y spacing for each individual scan are not provided, we report these results on voxels.' Does that mean the Dice scores, volume similarity, and MSD were calculated on 3D segmentations (i.e., first reconstructing the 3D segmentations and then evaluating)? Also, I could not find formulations for volume similarity and MSD; it would be great if you could add these evaluation metrics to your source code.
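
As a point of reference only (these are commonly used definitions and may not be exactly the ones used in the paper): volume similarity is often defined as VS = 1 - |V_pred - V_gt| / (V_pred + V_gt), and the mean surface distance symmetrically averages the distances from each boundary voxel of one segmentation to the nearest boundary voxel of the other. A minimal sketch of the volume-similarity part:

import numpy as np

def volume_similarity(pred, gt):
    # pred, gt: binary 3D arrays for a single organ class
    v_pred, v_gt = float(pred.sum()), float(gt.sum())
    return 1.0 - abs(v_pred - v_gt) / (v_pred + v_gt + 1e-8)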
