MGU-Net

Multi-Scale GCN-Assisted Two-Stage Network for Joint Segmentation of Retinal Layers and Disc in Peripapillary OCT Images

The code is implemented in PyTorch and trained on NVIDIA Tesla V100 GPUs.

Introduction

An accurate and automated tissue segmentation algorithm for retinal optical coherence tomography (OCT) images is crucial for the diagnosis of glaucoma. However, because of the optic disc, the anatomy of the peripapillary retina is complicated, which makes this region difficult to segment. To address this issue, we develop a novel graph convolutional network (GCN)-assisted two-stage framework that simultaneously labels nine retinal layers and the optic disc. Specifically, a multi-scale global reasoning module is inserted between the encoder and decoder of a U-shaped neural network to exploit anatomical prior knowledge and perform spatial reasoning. We conduct experiments on human peripapillary retinal OCT images and provide public access to the collected dataset, which we hope will contribute to research in biomedical image processing. The proposed segmentation network achieves a Dice score of 0.820 ± 0.001 and a pixel accuracy of 0.830 ± 0.002, both higher than those of other state-of-the-art techniques.
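The global reasoning idea comes from GloRe (credited in the Acknowledgements): pixel features are projected onto a small set of graph nodes, a graph convolution mixes information between the nodes, and the result is projected back and added to the input. The sketch below is a minimal single-scale PyTorch unit in that style, not the multi-scale module used in MGU-Net; the class names `GCN` and `GloReUnit`, the node/state sizes and the 1/(H·W) normalization are illustrative assumptions.

```python
import torch
import torch.nn as nn


class GCN(nn.Module):
    """Graph convolution over a fully connected graph of nodes (GloRe style)."""

    def __init__(self, num_nodes: int, num_state: int):
        super().__init__()
        # 1x1 conv across the node axis acts as a learned adjacency matrix.
        self.node_conv = nn.Conv1d(num_nodes, num_nodes, kernel_size=1)
        self.relu = nn.ReLU()
        # 1x1 conv across the state axis updates each node's feature vector.
        self.state_conv = nn.Conv1d(num_state, num_state, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, num_state, num_nodes); move nodes to the channel axis and back.
        h = self.node_conv(x.permute(0, 2, 1).contiguous()).permute(0, 2, 1)
        h = h + x  # residual connection over the node features
        return self.state_conv(self.relu(h))


class GloReUnit(nn.Module):
    """Single-scale global reasoning unit (hypothetical sketch, not MGU-Net's code)."""

    def __init__(self, in_channels: int, num_state: int, num_nodes: int):
        super().__init__()
        self.num_state = num_state
        self.num_nodes = num_nodes
        # Reduce channels to the node-state dimension.
        self.to_state = nn.Conv2d(in_channels, num_state, kernel_size=1)
        # Predict, for every pixel, its projection weight onto each node.
        self.to_proj = nn.Conv2d(in_channels, num_nodes, kernel_size=1)
        self.gcn = GCN(num_nodes, num_state)
        # Restore the original channel count before the residual sum.
        self.extend = nn.Conv2d(num_state, in_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, _, h, w = x.shape
        state = self.to_state(x).reshape(b, self.num_state, h * w)  # (B, S, HW)
        proj = self.to_proj(x).reshape(b, self.num_nodes, h * w)    # (B, N, HW)
        # Coordinate space -> interaction space: aggregate pixels into nodes.
        nodes = torch.bmm(state, proj.transpose(1, 2)) / (h * w)    # (B, S, N)
        nodes = self.gcn(nodes)                                     # (B, S, N)
        # Interaction space -> coordinate space: redistribute node features.
        back = torch.bmm(nodes, proj).reshape(b, self.num_state, h, w)
        return x + self.bn(self.extend(back))


if __name__ == "__main__":
    feats = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
    unit = GloReUnit(in_channels=64, num_state=32, num_nodes=16)
    print(unit(feats).shape)  # torch.Size([2, 64, 32, 32])
```

Because the unit preserves the input shape, it can sit between a U-shaped encoder and decoder without changing either; how MGU-Net combines reasoning at several scales is described in the paper, not reproduced here.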

Experiments

Dataset

  1. Collected dataset: Download our collected dataset.
    The label images are grayscale; each class is encoded by the following pixel value:
    RNFL=26, GCL=51, IPL=77, INL=102, OPL=128, ONL=153, IS/OS=179, RPE=204, Choroid=230, Optic Disc=255

  2. Public dataset: Download the Duke SD-OCT dataset.
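Training code usually needs class indices rather than gray values. The values listed for the collected dataset are approximately 25.5·k for class k = 1…10, which suggests (an assumption, not stated in this README) that background pixels have value 0. A minimal NumPy lookup-table conversion under that assumption, using hypothetical helpers `gray_to_class` and `class_to_gray`:

```python
import numpy as np

# Gray value of each class in the collected dataset's label images (from the README).
# Class index k = 1..10 follows this order; background (assumed gray value 0) is class 0.
LABEL_VALUES = [
    ("RNFL", 26), ("GCL", 51), ("IPL", 77), ("INL", 102), ("OPL", 128),
    ("ONL", 153), ("IS/OS", 179), ("RPE", 204), ("Choroid", 230), ("Optic Disc", 255),
]

# Lookup tables built once: gray value -> class index, class index -> gray value.
_GRAY_TO_CLASS = np.zeros(256, dtype=np.int64)  # unlisted values fall back to class 0
for _idx, (_name, _value) in enumerate(LABEL_VALUES, start=1):
    _GRAY_TO_CLASS[_value] = _idx
_CLASS_TO_GRAY = np.array([0] + [v for _, v in LABEL_VALUES], dtype=np.uint8)


def gray_to_class(label_img: np.ndarray) -> np.ndarray:
    """Map a uint8 grayscale label image to int64 class indices 0..10."""
    return _GRAY_TO_CLASS[label_img]


def class_to_gray(class_map: np.ndarray) -> np.ndarray:
    """Inverse mapping, e.g. for saving predictions in the same format."""
    return _CLASS_TO_GRAY[class_map]


if __name__ == "__main__":
    img = np.array([[0, 26], [255, 128]], dtype=np.uint8)
    print(gray_to_class(img))  # classes 0, 1, 10 and 5
```

Any gray value outside the list also maps to 0, so label images should be resized with nearest-neighbour interpolation; smoothing would create in-between values that silently become background. The repository's own data loader may handle this differently.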

Train and test

Run the following command to train and test our model:

python main_ts.py --name tsmgunet -d ./data/dataset --batch-size 1 --epoch 50 --lr 0.001

Results

Results on the collected dataset

Results on the public dataset

For more details, please refer to our paper.
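The Dice score and pixel accuracy reported in the introduction can be computed from predicted and ground-truth class maps roughly as below. This is a NumPy sketch, not the repository's evaluation code; how the paper averages over classes (for example, whether background is included) is not stated in this README, so the averaging shown is an assumption.

```python
import numpy as np


def dice_per_class(pred: np.ndarray, target: np.ndarray, num_classes: int) -> np.ndarray:
    """Dice = 2|P ∩ T| / (|P| + |T|) for each class index 0..num_classes-1."""
    scores = np.empty(num_classes)
    for c in range(num_classes):
        p, t = pred == c, target == c
        denom = p.sum() + t.sum()
        # Assumption: a class absent from both maps counts as a perfect match.
        scores[c] = 1.0 if denom == 0 else 2.0 * np.logical_and(p, t).sum() / denom
    return scores


def pixel_accuracy(pred: np.ndarray, target: np.ndarray) -> float:
    """Fraction of pixels whose predicted class equals the ground truth."""
    return float((pred == target).mean())


if __name__ == "__main__":
    pred = np.array([[0, 1, 1], [2, 2, 0]])
    target = np.array([[0, 1, 2], [2, 2, 0]])
    dice = dice_per_class(pred, target, num_classes=3)
    print(dice)                          # per-class Dice, approx. 1.0, 0.667, 0.8
    print(dice[1:].mean())               # mean over foreground classes only (assumption)
    print(pixel_accuracy(pred, target))  # 5 of 6 pixels correct
```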

Citation

If you use the code or the collected dataset in your research, please cite the following paper:

@article{li2021mgunet,
author = {Jiaxuan Li and Peiyao Jin and Jianfeng Zhu and Haidong Zou and Xun Xu and Min Tang and Minwen Zhou and Yu Gan and Jiangnan He and Yuye Ling and Yikai Su},
journal = {Biomed. Opt. Express},
number = {4},
pages = {2204--2220},
title = {Multi-scale GCN-assisted two-stage network for joint segmentation of retinal layers and discs in peripapillary OCT images},
volume = {12},
year = {2021},
url = {http://www.osapublishing.org/boe/abstract.cfm?URI=boe-12-4-2204},
doi = {10.1364/BOE.417212},
}

Acknowledgements

The code is built on AI-Challenger-Retinal-Edema-Segmentation and GloRe. We sincerely thank the authors for sharing their code.

Contact

If you have any questions, please do not hesitate to contact [email protected]

mgu-net's Issues

Dice loss

In the Dice loss, there's this line:
input = torch.exp(input)
Shouldn't it be the sigmoid function instead? The input is the output of the network's last conv layer, and a sigmoid would convert it to a probability within [0, 1]:
input = torch.sigmoid(input)
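Which activation is correct depends on what the network outputs. `torch.exp` is right only if the final layer already applies `log_softmax` (a common pairing with `nn.NLLLoss`), because exp(log_softmax(z)) = softmax(z); on raw logits, exp() gives unbounded positive values rather than probabilities. For mutually exclusive classes such as these layers and the disc, softmax over the class dimension is the usual choice, while sigmoid treats every channel as an independent binary problem. Whether MGU-Net's last layer uses `log_softmax` is not shown here. A hedged sketch of a soft Dice loss that expects probabilities (the function name, `eps` and the 11-class setup are assumptions):

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Mean soft Dice loss over classes.

    probs:  (B, C, H, W) class probabilities (e.g. after softmax).
    target: (B, H, W) int64 class indices in [0, C).
    """
    num_classes = probs.shape[1]
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).to(probs.dtype)
    dims = (0, 2, 3)  # sum over batch and spatial axes, keep one score per class
    intersection = (probs * one_hot).sum(dims)
    denominator = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + eps) / (denominator + eps)
    return 1 - dice.mean()


if __name__ == "__main__":
    logits = torch.randn(2, 11, 8, 8)         # raw output of a last conv layer
    target = torch.randint(0, 11, (2, 8, 8))  # assumption: 11 classes incl. background

    # exp() is only a valid "to probability" step for log-probabilities:
    log_probs = F.log_softmax(logits, dim=1)
    print(torch.allclose(log_probs.exp(), F.softmax(logits, dim=1), atol=1e-6))  # True

    # For raw logits of mutually exclusive classes, apply softmax instead.
    print(soft_dice_loss(F.softmax(logits, dim=1), target).item())
```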

About test_dataset

Wouldn't it be more reasonable to set the phase parameter of test_dataset to predict? Setting it to test or eval actually calls the same function.

Script for using the public dataset

First of all, thank you, Jiaxuan-Li, for your work and for keeping the code base so tidy. It worked for me almost straight off the shelf!

Regarding the public dataset: it is not in the format the code expects; it is distributed as MATLAB .mat files. I guess you've been using a script to convert it. Is that the case? If so, would it be possible to share it?

Thanks!
Dan
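In the Duke SD-OCT releases, manual annotations are typically stored as layer-boundary positions per A-scan rather than as label images, so converting them mainly means filling the rows between consecutive boundaries. The sketch below assumes a boundary array of shape (K, W) holding 1-based row positions (MATLAB convention), with NaN where a column is unannotated. The function name, the ignore value 255 and the 1-based convention are assumptions, and this is not the script the authors used; check the actual `.mat` keys and array shapes with `scipy.io.loadmat(path).keys()` before relying on it.

```python
import numpy as np


def boundaries_to_mask(boundaries, height, ignore_value=255, one_based=True):
    """Turn per-column layer boundaries into a label mask (hypothetical helper).

    boundaries: (K, W) array, row position of each of K boundaries in every
                column (A-scan); NaN marks unannotated columns.
    Returns an (height, W) uint8 mask where rows between boundary i and i+1
    get label i+1, everything outside the layers is 0, and unannotated
    columns are set to `ignore_value`.
    """
    b = np.asarray(boundaries, dtype=float)
    if one_based:  # assumption: positions come from MATLAB and start at 1
        b = b - 1.0
    num_boundaries, width = b.shape
    mask = np.zeros((height, width), dtype=np.uint8)
    rows = np.arange(height)[:, None]  # (H, 1), broadcasts against (1, W)
    for i in range(num_boundaries - 1):
        top, bottom = b[i][None, :], b[i + 1][None, :]
        # Comparisons with NaN are False, so unannotated columns stay 0 here.
        mask[(rows >= top) & (rows < bottom)] = i + 1
    mask[:, np.isnan(b).any(axis=0)] = ignore_value
    return mask


if __name__ == "__main__":
    # Two boundaries over three columns; the last column is unannotated.
    demo = np.array([[2, 2, np.nan], [4, 3, np.nan]])
    print(boundaries_to_mask(demo, height=5))
```

For a full volume, the same function would be applied B-scan by B-scan, pairing each image slice with its boundary slice, provided the arrays have the (assumed) layouts above.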

Training on my own dataset

(MGU-Net-main) E:\_BACKUP\lbc\MGU-Net-main>python main_ts.py --name tsmgunet -d ./data/dataset --batch-size 1 --epoch 50 --lr 0.001
torch version: 1.10.0
Total amount of train images is : 96
Total amount of eval images is : 96
Total amount of test images is : 96
data_dir : ./data/dataset
name : tsmgunet
workers : 2
step : 20
batch_size : 1
epochs : 50
lr : 0.001
lr_mode : step
momentum : 0.9
weight_decay : 0.0001
t : t1
model_path :
############### tsmgunet ###############
[2021-12-03 04:27:43,359 main_ts.py:278 train_seg] Epoch: [0]
Traceback (most recent call last):
  File "main_ts.py", line 394, in <module>
    main()
  File "main_ts.py", line 386, in main
    train_seg(args,train_result_path,train_loader,eval_loader)
  File "main_ts.py", line 280, in train_seg
    loss,dice_train,dice_1,dice_2,dice_3,dice_4,dice_5,dice_6,dice_7,dice_8,dice_9,dice_10 = train(args,train_loader, model,criterion1, criterion2, optimizer,epoch)
  File "main_ts.py", line 53, in train
    output_seg1,_,output_seg = model(input_var1)
  File "E:\anaconda\envs\MGU-Net-main\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:\anaconda\envs\MGU-Net-main\lib\site-packages\torch\nn\parallel\data_parallel.py", line 166, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "E:\anaconda\envs\MGU-Net-main\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:\_BACKUP\lbc\MGU-Net-main\models\nets\TSNet.py", line 16, in forward
    out1 = self.stage1(inputs)
  File "E:\anaconda\envs\MGU-Net-main\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:\_BACKUP\lbc\MGU-Net-main\models\nets\MGUNet.py", line 106, in forward
    up3 = self.up_concat3(center, conv3)
  File "E:\anaconda\envs\MGU-Net-main\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:\_BACKUP\lbc\MGU-Net-main\models\utils\utils.py", line 91, in forward
    outputs0 = torch.cat([outputs0,input[i]], 1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 52 but got size 53 for tensor number 1 in the list.
Could you tell me what went wrong?
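A mismatch like 52 vs 53 is the usual symptom of an input whose height or width is not divisible by the network's total downsampling factor: each 2× pooling floors odd sizes, so an upsampled decoder feature ends up one pixel off from the encoder skip connection it is concatenated with. Two common workarounds are sketched below: pad the input so both dimensions are a multiple of 2^d (d = number of downsampling steps; `multiple=16` assumes four and should be checked against the model), or pad/crop the decoder feature to the skip connection's size before `torch.cat`. Both helper names are hypothetical and not part of the repository.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, multiple: int = 16) -> torch.Tensor:
    """Zero-pad the bottom/right of a (B, C, H, W) tensor so H and W divide by `multiple`.

    multiple=16 assumes four 2x downsampling steps; adjust to the model's depth.
    """
    h, w = x.shape[-2:]
    pad_h, pad_w = (-h) % multiple, (-w) % multiple
    # F.pad pads the last dims first: (left, right, top, bottom).
    return F.pad(x, (0, pad_w, 0, pad_h))


def match_and_concat(up: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
    """Pad (or crop, via negative padding) `up` to `skip`'s H x W, then concat channels."""
    dh = skip.shape[-2] - up.shape[-2]
    dw = skip.shape[-1] - up.shape[-1]
    up = F.pad(up, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))
    return torch.cat([up, skip], dim=1)


if __name__ == "__main__":
    image = torch.randn(1, 1, 106, 100)
    print(pad_to_multiple(image).shape)  # torch.Size([1, 1, 112, 112])

    up = torch.randn(1, 8, 52, 52)    # upsampled decoder feature
    skip = torch.randn(1, 8, 53, 53)  # encoder skip connection
    print(match_and_concat(up, skip).shape)  # torch.Size([1, 16, 53, 53])
```

If the input is padded, the prediction should be cropped back to the original H × W before computing metrics; padding the input avoids editing the model, while matching inside the decoder tolerates any input size.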

Dataset request

First of all, thank you for your contribution, Jiaxuan Li. I filled out the data collection form but have not received a response, so could you let me know another way to obtain the dataset? My Google email is [email protected].

Thanks!
Tom

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.