
knee-segmentation's Introduction

Automated Tissue Segmentation from High-Resolution 3D Steady-State MRI with Deep Learning

Albert Ugwudike, Joe Arrowsmith, Joonsu Gha, Kamal Shah, Lapo Rastrelli, Olivia Gallupova, Pietro Vitiello


2D Models Implemented

  • SegNet
  • Vanilla UNet
  • Attention UNet
  • Multi-res UNet
  • R2_UNet
  • R2_Attention UNet
  • UNet++
  • 100-layer Tiramisu
  • DeepLabv3+

3D Models Implemented

  • 3D UNet
  • Relative 3D UNet
  • Slice 3D UNet
  • VNet
  • Relative VNet
  • Slice VNet

Results

Baseline Comparison of 3D Methods

| Model | Input Shape | Loss | Val Loss | Duration / min |
|---|---|---|---|---|
| Small Highwayless 3D UNet | (160,160,160) | 0.777 | 0.847 | 86.6 |
| Small 3D UNet | (160,160,160) | 0.728 | 0.416 | 89.1 |
| Small Relative 3D UNet | (160,160,160),(3) | 0.828 | 0.889 | 90.1 |
| Small VNet | (160,160,160) | 0.371 | 0.342 | 89.5 |

Small Highwayless 3D UNet (160,160,160)

[Figures: training loss and training progress]

Small 3D UNet (160,160,160)

[Figures: training loss and training progress]

Small Relative 3D UNet (160,160,160),(3)

[Figures: training loss and training progress]

Small VNet (160,160,160)

[Figures: training loss and training progress]

Comparison of VNet Methods

| Model | Input Shape | Loss | Val Loss | Roll Loss | Roll Val Loss | Duration / min |
|---|---|---|---|---|---|---|
| Tiny | (64,64,64) | 0.627 ± 0.066 | 0.684 ± 0.078 | 0.652 ± 0.071 | 0.686 ± 0.077 | 61.5 ± 5.32 |
| Tiny | (160,160,160) | 0.773 ± 0.01 | 0.779 ± 0.019 | 0.778 ± 0.007 | 0.787 ± 0.016 | 101.8 ± 2.52 |
| Small | (160,160,160) | 0.648 ± 0.156 | 0.676 ± 0.106 | 0.656 ± 0.152 | 0.698 ± 0.076 | 110.1 ± 4.64 |
| Small Relative | (160,160,160),(3) | 0.653 ± 0.168 | 0.639 ± 0.176 | 0.659 ± 0.167 | 0.644 ± 0.172 | 104.6 ± 9.43 |
| Slice | (160,160,5) | 0.546 ± 0.019 | 0.845 ± 0.054 | 0.559 ± 0.020 | 0.860 ± 0.072 | 68.6 ± 9.68 |
| Small | (240,240,160) | 0.577 ± 0.153 | 0.657 ± 0.151 | 0.583 ± 0.151 | 0.666 ± 0.149 | 109.7 ± 0.37 |
| Large | (240,240,160) | 0.505 ± 0.262 | 0.554 ± 0.254 | 0.508 ± 0.262 | 0.574 ± 0.243 | 129.2 ± 0.50 |
| Large Relative | (240,240,160),(3) | 0.709 ± 0.103 | 0.880 ± 0.078 | 0.725 ± 0.094 | 0.913 ± 0.081 | 148.6 ± 0.20 |
Baseline results from training the VNet models for 50 epochs, exploring how quickly the models converge. Models were optimised for dice loss using a scheduled Adam optimizer with a start learning rate of $5 \times 10^{-5}$, a schedule drop of $0.9$, and a drop every $3$ epochs. Z-score normalisation was applied to the inputs and outliers were replaced with the mean pixel value. Subsample locations were drawn from a normal distribution centred on the volume centre. GitHub commit: cb39158
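
For reference, a minimal sketch of the step-decay schedule described above, assuming it is applied through a standard Keras `LearningRateScheduler` callback (the repo's actual training loop may wire this up differently):

```python
import tensorflow as tf

START_LR = 5e-5      # start learning rate
DROP = 0.9           # schedule drop
DROP_EVERY = 3       # schedule drop epoch frequency

def step_decay(epoch):
    # Multiply the start learning rate by DROP once every DROP_EVERY epochs.
    return START_LR * (DROP ** (epoch // DROP_EVERY))

optimizer = tf.keras.optimizers.Adam(learning_rate=START_LR)
lr_callback = tf.keras.callbacks.LearningRateScheduler(step_decay)
# model.compile(optimizer=optimizer, loss=dice_loss)
# model.fit(train_ds, epochs=50, callbacks=[lr_callback])
```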

The best training session is chosen for each visualisation below.

Tiny VNet (64,64,64)

[Figures: training loss and training progress]

Tiny VNet (160,160,160)

[Figures: training loss and training progress]

Small VNet (160,160,160)

[Figures: training loss and training progress]

Small Relative VNet (160,160,160),(3)

[Figures: training loss and training progress]

Small Slice VNet (160,160,5)

[Figures: training loss and training progress]

Small VNet (240,240,160)

[Figures: training loss and training progress]

Large VNet (240,240,160)

[Figures: training loss and training progress]

Large Relative VNet (240,240,160),(3)

[Figures: training loss and training progress]

Useful Code Snippets

Run 3D Train

python Segmentation/model/vnet_train.py
Unit Testing and Unit-Test Coverage

python -m pytest --cov-report term-missing:skip-covered --cov=Segmentation && coverage html && open ./htmlcov/index.html
Start tensorboard on Pompeii

On Pompeii: tensorboard --logdir logs --samples_per_plugin images=100

On your local machine: ssh -L 16006:127.0.0.1:6006 username@ip

Then open http://localhost:16006/ in your browser.

Valid 3D Configs

| Batch / GPU | Crop Size | Depth Crop Size | Num Channels | Num Conv Layers | Kernel Size |
|---|---|---|---|---|---|
| 1 | 32 | 32 | 20 | 2 | (5,5,5) |
| 1 | 64 | 64 | 32 | 2 | (3,3,3) |
| 1 | 64 | 64 | 32 | 2 | (5,5,5) |
| 3 | 64 | 32 | 16 | 2 | (3,3,3) |

knee-segmentation's People

Contributors

arahosu, joearrowsmith, laporastrelli, olive004, olivemicrographia, pietrovitiello


knee-segmentation's Issues

bfloat16 cannot be used with tf.keras.backend

Our current loss functions are implemented using tf.keras.backend. This throws an error when we use more memory-efficient data types such as bfloat16 or other mixed-precision strategies. We could potentially reimplement the loss functions using low-level TensorFlow functions, without relying on the Keras backend.
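
As an illustration, a dice loss can be written with plain TensorFlow ops so the computation stays in the input dtype (e.g. bfloat16 under a mixed-precision policy). This is a hedged sketch, not the repo's actual loss function, and it assumes a channels-last layout with a known static rank:

```python
import tensorflow as tf

def dice_loss_tf(y_true, y_pred, smooth=1e-6):
    """Soft dice loss using only low-level TF ops (no tf.keras.backend)."""
    y_true = tf.cast(y_true, y_pred.dtype)
    # Reduce over every axis except the batch axis (assumes static rank).
    axes = list(range(1, y_pred.shape.rank))
    intersection = tf.reduce_sum(y_true * y_pred, axis=axes)
    denom = tf.reduce_sum(y_true, axis=axes) + tf.reduce_sum(y_pred, axis=axes)
    dice = (2.0 * intersection + smooth) / (denom + smooth)
    return 1.0 - tf.reduce_mean(dice)
```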

Downsampler

  • Bounding box of bounding boxes
  • Validate model performance using the data obtained from Kamal and Albert.

TODO list before abstract submission

  • Move the 'train' class and its helper functions to the train folder (Olivia & Joe)
  • Integrate 3D models into main.py
  • Integrate training functions from vnet_train.py into main.py (Olivia)
  • Rewrite evaluation functions to compute the dice coefficient for each cartilage class (multi-class) (Joonsu & Joe)
  • Rewrite all tf.slice operations to use more Pythonic expressions
  • Implement weighted categorical cross-entropy, computing the weights from the distribution of cartilage classes in the training dataset; see the sketch after this list (Olivia & Pietro)
  • Refactor visualisation code (everything in vnet_train.py and main.py should be in the same folder/file) (Olivia & Pietro)
  • Rewrite unit tests (Lapo)
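
A possible sketch of the weighted categorical cross-entropy item above, assuming inverse-frequency class weights computed from per-class voxel counts in the training set (the weighting scheme and function names are illustrative, not the repo's):

```python
import numpy as np
import tensorflow as tf

def class_weights_from_counts(class_counts):
    """class_counts: per-class voxel counts accumulated over the training set."""
    counts = np.asarray(class_counts, dtype=np.float64)
    freqs = counts / counts.sum()
    weights = 1.0 / (freqs + 1e-8)        # inverse-frequency weighting
    return weights / weights.sum()        # normalise so the weights sum to 1

def weighted_categorical_crossentropy(weights):
    w = tf.constant(weights, dtype=tf.float32)
    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0)
        # Per-voxel cross-entropy, weighted per class, summed over the channel axis.
        wce = -tf.reduce_sum(w * y_true * tf.math.log(y_pred), axis=-1)
        return tf.reduce_mean(wce)
    return loss
```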

Confusion Matrix for Multi-class Segmentation results

In Segmentation/utils/evaluation_metrics.py, we should include a function that computes the confusion matrix of the model's predictions. That should give us a clearer idea of which cartilages are the most difficult to segment.
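
A minimal sketch of such a helper, assuming one-hot labels and softmax predictions with a channels-last layout (the function name and signature are placeholders):

```python
import tensorflow as tf

def segmentation_confusion_matrix(y_true, y_pred, num_classes):
    """Voxel-wise confusion matrix; y_true is one-hot, y_pred is softmax output."""
    true_idx = tf.reshape(tf.argmax(y_true, axis=-1), [-1])
    pred_idx = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])
    return tf.math.confusion_matrix(true_idx, pred_idx, num_classes=num_classes)
```

Row i of the result shows how voxels of true class i were distributed across predicted classes, which should make the hardest cartilages easy to spot.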

UNet

  • Attention U-Net as class
  • Multi-Res U-Net as class
  • Reformat attention block to use Conv2D_Block
  • Make the number of convolution layers configurable in all classes
  • Implement an option for the 'channels_first' data format for batch normalisation across all layers

GAN

Where we are at building the baseline GAN

  • Generator
  • Discriminator
  • Loss functions
  • Optimizer
  • Training script

Issue with gradient flows in VNet

W0227 10:43:46.311452 140678072145728 optimizer_v2.py:1043] Gradients do not exist for variables ['vnet_small_relative/c3/conv3d_4/kernel:0', 'vnet_small_relative/c3/conv3d_4/bias:0', 'vnet_small_relative/c3/conv3d_5/kernel:0', 'vnet_small_relative/c3/conv3d_5/bias:0', 'vnet_small_relative/c3/batch_normalization_2/gamma:0', 'vnet_small_relative/c3/batch_normalization_2/beta:0', 'vnet_small_relative/cu3/conv3d_6/kernel:0', 'vnet_small_relative/cu3/conv3d_6/bias:0', 'vnet_small_relative/cu3/batch_normalization_3/gamma:0', 'vnet_small_relative/cu3/batch_normalization_3/beta:0', 'vnet_small_relative/upc2/conv3d_8/kernel:0', 'vnet_small_relative/upc2/conv3d_8/bias:0', 'vnet_small_relative/upc2/conv3d_9/kernel:0', 'vnet_small_relative/upc2/conv3d_9/bias:0', 'vnet_small_relative/upc2/batch_normalization_5/gamma:0', 'vnet_small_relative/upc2/batch_normalization_5/beta:0', 'vnet_small_relative/compessor_0/kernel:0', 'vnet_small_relative/compessor_0/bias:0', 'vnet_small_relative/compessor_1/kernel:0', 'vnet_small_relative/compessor_1/bias:0', 'vnet_small_relative/compessor_2/kernel:0', 'vnet_small_relative/compessor_2/bias:0', 'vnet_small_relative/dense/kernel:0', 'vnet_small_relative/dense/bias:0', 'vnet_small_relative/dense_1/kernel:0', 'vnet_small_relative/dense_1/bias:0', 'vnet_small_relative/dense_2/kernel:0', 'vnet_small_relative/dense_2/bias:0'] when minimizing the loss.
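
One way to track this down (a hedged sketch, not the repo's tooling) is to run a single step under tf.GradientTape and list every trainable variable whose gradient comes back as None; those are the variables the warning refers to, typically because their outputs never reach the loss:

```python
import tensorflow as tf

def report_missing_gradients(model, loss_fn, x, y):
    """Prints the trainable variables that receive no gradient for one batch."""
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = loss_fn(y, y_pred)
    grads = tape.gradient(loss, model.trainable_variables)
    missing = [v.name for v, g in zip(model.trainable_variables, grads) if g is None]
    print("Variables with no gradient:", missing)
    return missing
```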

TODO List before poster submission @ IWOAI 2020

Changes to be made in dev_v2 branch

Essential

  • Replace all tf.slice operations with more Pythonic indexing and slicing operations in Segmentation/train/reshape.py (see the example after this list)
  • Rewrite loss functions and metrics to generalise to 2D and 3D
  • Check that visualise_sample in train.py works with both 2D and 3D data
  • Rename enable_function to run_eagerly
  • Show metrics for all (verbose=True) or selected (verbose=False) segmentation classes during training/evaluation
  • Unit tests for all models and training/evaluation functions in the pipeline
  • Implement a separate Evaluator class in evaluate folder along with its helper functions
  • The dataset loading function should return num_classes, or otherwise be made less hardcoded
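
As an illustration of the tf.slice rewrite above, the two forms below produce the same crop of a (batch, depth, height, width, channels) volume; the shapes and offsets are made up for the example:

```python
import tensorflow as tf

volume = tf.zeros((1, 160, 160, 160, 1))

# Explicit tf.slice: begin and size must be spelled out for every axis.
crop_a = tf.slice(volume, begin=[0, 48, 48, 48, 0], size=[1, 64, 64, 64, 1])

# Equivalent Pythonic indexing, easier to read and maintain.
crop_b = volume[:, 48:112, 48:112, 48:112, :]

tf.debugging.assert_equal(crop_a, crop_b)
```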

Optional

  • Include an option to select training mode ('fully-supervised', 'semi-supervised', 'self-supervised', 'transfer') in Trainer class
  • Implement a function that allows the users to choose colours for multi-class segmentation evaluation (Similar to 'color' argument in matplotlib.pyplot functions)
  • Implement an option to perform k-fold cross-validation/bootstrapping to validate model performance in the Evaluator class
  • Write functions for downloading public medical imaging datasets

Data augmentation fails on TPU

System Information

  • OS Platform and Distribution: Linux Ubuntu 16.04
  • TensorFlow Version: 2.1
  • TensorFlow Addons Version: 0.9.1
  • Python version: 3.7
  • TPU hardware used: v3-8

Describe the bug

Data augmentation transformations implemented with the TensorFlow Addons library fail on TPU (but not on GPU). The custom_ops in TensorFlow Addons are implemented in TensorFlow rather than XLA HLO, so they cannot run under XLA compilation on the TPU.

The issue was recently raised on the TensorFlow Addons repository (https://github.com/tensorflow/addons/issues/1553).
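
One possible workaround, sketched below under the assumption that the augmentations can be expressed with plain TensorFlow ops (which lower to XLA on TPU), is to avoid the TFA custom ops altogether; the flip-based augmentation here is only an example, not the project's actual pipeline:

```python
import tensorflow as tf

def random_flip(image, label, axis=1):
    """Randomly flips image and label together along one spatial axis."""
    flip = tf.random.uniform(()) > 0.5
    image = tf.cond(flip, lambda: tf.reverse(image, axis=[axis]), lambda: image)
    label = tf.cond(flip, lambda: tf.reverse(label, axis=[axis]), lambda: label)
    return image, label

# dataset = dataset.map(random_flip, num_parallel_calls=tf.data.experimental.AUTOTUNE)
```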
