brainstorm

This repository contains the authors' implementation of "Data Augmentation using Learned Transformations for One-shot Medical Image Segmentation", presented as an oral at CVPR 2019. We provide code for training spatial and appearance transform models, and for using those models to synthesize training examples for segmentation.

If you use our code, please cite:

Data augmentation using learned transformations for one-shot medical image segmentation
Amy Zhao, Guha Balakrishnan, Fredo Durand, John Guttag, Adrian V. Dalca
CVPR 2019. eprint arXiv:1902.09383

Getting started

Prerequisites

To run this code, you will need:

  • Python 3.6+ (Python 2.7 may work but has not been tested)
  • CUDA 10.0+
  • TensorFlow 1.13+ and Keras 2.2.4+
  • one GPU with 12 GB of memory (we used a single NVIDIA Titan X)

Download module dependencies

Run the script setup.sh. It automatically pulls the required module dependencies and places them in the correct subdirectories.

Setting up your dataset

We have included a few sample MRI scans (including volumes and segmentations) in the data/ folder. If you wish to use the datasets mentioned in the paper, you should download them directly from the respective dataset sites.

If you wish to use your own dataset, place your volume and segmentation files in the data/ folder. The data loading code in src/mri_loader.py expects each example to be stored as a volume file {example_id}_vol.npz and, for labeled examples, a corresponding {example_id}_seg.npz file, with the arrays stored under the keys vol_data and seg_data, respectively. The functions load_dataset_files and load_vol_and_seg in src/mri_loader.py can easily be modified to suit your data format.
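As a minimal sketch of the file layout described above, the snippet below writes and reads examples with NumPy. The helpers save_example and load_example are hypothetical (they are not the repo's load_dataset_files or load_vol_and_seg), and the toy shapes and dtypes are assumptions; check src/mri_loader.py for what your copy of the loader actually expects.

```python
import os
import tempfile

import numpy as np


def save_example(data_dir, example_id, vol, seg=None):
    """Write {example_id}_vol.npz (key 'vol_data') and, if seg is given,
    {example_id}_seg.npz (key 'seg_data'). Hypothetical helper."""
    np.savez_compressed(os.path.join(data_dir, f"{example_id}_vol.npz"), vol_data=vol)
    if seg is not None:
        np.savez_compressed(os.path.join(data_dir, f"{example_id}_seg.npz"), seg_data=seg)


def load_example(data_dir, example_id):
    """Return (vol, seg); seg is None when no segmentation file exists."""
    with np.load(os.path.join(data_dir, f"{example_id}_vol.npz")) as f:
        vol = f["vol_data"]
    seg_path = os.path.join(data_dir, f"{example_id}_seg.npz")
    seg = None
    if os.path.exists(seg_path):
        with np.load(seg_path) as f:
            seg = f["seg_data"]
    return vol, seg


# Toy round trip: one labeled and one unlabeled example (shapes are illustrative only).
rng = np.random.default_rng(0)
vol = rng.random((16, 16, 16)).astype(np.float32)
seg = (vol > 0.5).astype(np.uint8)
with tempfile.TemporaryDirectory() as data_dir:
    save_example(data_dir, "subj001", vol, seg)
    save_example(data_dir, "subj002", vol)  # unlabeled: no _seg.npz file
    files = sorted(os.listdir(data_dir))
    vol1, seg1 = load_example(data_dir, "subj001")
    vol2, seg2 = load_example(data_dir, "subj002")
print(files)  # ['subj001_seg.npz', 'subj001_vol.npz', 'subj002_vol.npz']
```

Unlabeled examples simply omit the _seg.npz file, which is how the sketch's loader distinguishes them.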

Training transform models

This repo does not include any pre-trained models. You may train your own spatial and appearance transform models by specifying the GPU ID, dataset name, and model type.

python main.py trans --gpu 0 --data mri-100unlabeled --model flow-fwd
python main.py trans --gpu 0 --data mri-100unlabeled --model flow-bck
python main.py trans --gpu 0 --data mri-100unlabeled --model color-unet

Results are saved under experiments/. To train an appearance (color) transform model, first edit main.py to point at your trained forward and backward spatial transform models. We have provided pretrained forward/backward spatial transform models for testing.

As described in the paper, each model is implemented using a simple architecture based on U-Net. You can change hyperparameters by modifying transform_model_arch_params in main.py. We encourage you to experiment with your favorite model architecture, and to adjust the model parameters to suit your dataset.
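For intuition about how a trained spatial model and a trained appearance model combine into a new labeled example (the idea behind the --aug_tm option), here is a minimal NumPy/SciPy sketch. It assumes the models' outputs are already available as a dense displacement field in voxels, with shape volume.shape + (3,), and an additive intensity map in atlas space; the helpers warp and synthesize are hypothetical and are not the repo's API. The appearance change is applied in atlas space before warping, the order used in the paper's synthesis step, and labels are warped with nearest-neighbor interpolation so they stay valid label values.

```python
import numpy as np
from scipy.ndimage import map_coordinates


def warp(volume, flow, order):
    """Resample volume at x + flow(x); flow is in voxels with shape volume.shape + (ndim,).
    order=1 is linear interpolation, order=0 is nearest neighbor."""
    grid = np.meshgrid(*[np.arange(n) for n in volume.shape], indexing="ij")
    coords = [g + flow[..., i] for i, g in enumerate(grid)]
    return map_coordinates(volume, coords, order=order, mode="nearest")


def synthesize(atlas_vol, atlas_seg, flow, appearance_delta):
    """Apply an additive appearance change in atlas space, then warp image and labels."""
    new_vol = warp(atlas_vol + appearance_delta, flow, order=1)
    new_seg = warp(atlas_seg, flow, order=0)  # nearest neighbor keeps label values intact
    return new_vol, new_seg


# Toy stand-ins for model outputs: a constant 2-voxel shift and a +0.1 brightness change.
rng = np.random.default_rng(0)
atlas_vol = rng.random((32, 32, 32)).astype(np.float32)
atlas_seg = (atlas_vol > 0.5).astype(np.int32)
flow = np.zeros(atlas_vol.shape + (3,), dtype=np.float32)
flow[..., 0] = 2.0
delta = np.full(atlas_vol.shape, 0.1, dtype=np.float32)
new_vol, new_seg = synthesize(atlas_vol, atlas_seg, flow, delta)
```

Because the same flow warps both the image and the atlas labels, the synthesized segmentation stays aligned with the synthesized image.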

Training a segmentation network

You may train a segmentation model by specifying the GPU ID and dataset name.

python main.py seg --gpu 0 --data mri-100unlabeled

Again, results will be placed under experiments/.

You can use additional flags:

  • --aug_rand applies random augmentation to each training example: a random smooth deformation combined with a random global multiplicative intensity factor.
  • --aug_sas pseudo-labels any unlabeled examples in the training set using the specified spatial registration model.
  • --aug_tm synthesizes training examples using our method.

If you wish to use --aug_sas or --aug_tm, it is important to specify the spatial and appearance transform models to use in seg_model_arch_params in main.py.
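The --aug_rand baseline can be sketched as follows, assuming NumPy and SciPy. Gaussian-filtered white noise is one common way to build a random smooth deformation, not necessarily the one main.py uses, and the values of sigma, max_disp, and the ±10% intensity range are illustrative choices made here, not the repo's defaults.

```python
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates


def random_smooth_flow(shape, rng, sigma=4.0, max_disp=3.0):
    """Smooth random displacement field (in voxels): Gaussian-filtered white noise,
    rescaled so its largest component magnitude is about max_disp."""
    flow = np.stack([gaussian_filter(rng.standard_normal(shape), sigma) for _ in shape], axis=-1)
    return flow * (max_disp / (np.abs(flow).max() + 1e-8))


def random_augment(vol, seg, rng, intensity_range=0.1):
    """Random smooth deformation plus one global multiplicative intensity factor."""
    flow = random_smooth_flow(vol.shape, rng)
    grid = np.meshgrid(*[np.arange(n) for n in vol.shape], indexing="ij")
    coords = [g + flow[..., i] for i, g in enumerate(grid)]
    aug_vol = map_coordinates(vol, coords, order=1, mode="nearest")
    aug_seg = map_coordinates(seg, coords, order=0, mode="nearest")  # labels: nearest neighbor
    scale = rng.uniform(1.0 - intensity_range, 1.0 + intensity_range)  # one factor per example
    return aug_vol * scale, aug_seg


rng = np.random.default_rng(0)
vol = rng.random((32, 32, 32)).astype(np.float32)
seg = (vol > 0.5).astype(np.int32)
aug_vol, aug_seg = random_augment(vol, seg, rng)
```

The intensity factor is a single scalar for the whole volume, matching the "global multiplicative" description, rather than a per-voxel change like the learned appearance transform.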

Evaluation

To evaluate trained segmenters, look at the code in evaluate_segmenters.py. You will have to modify the code to point at your trained models.
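This section does not say which metrics evaluate_segmenters.py reports. Dice overlap is a standard choice for comparing a predicted segmentation with a reference, so here is a minimal per-label Dice sketch in NumPy; the function name dice_per_label is hypothetical, and background (label 0) is left out of the example only as an illustrative choice.

```python
import numpy as np


def dice_per_label(pred, truth, labels):
    """Dice = 2|P ∩ T| / (|P| + |T|) per label; NaN when the label is absent from both."""
    scores = {}
    for lab in labels:
        p, t = pred == lab, truth == lab
        denom = int(p.sum() + t.sum())
        scores[lab] = float(2.0 * np.logical_and(p, t).sum() / denom) if denom > 0 else float("nan")
    return scores


truth = np.array([[0, 1, 1], [0, 2, 2]])
pred = np.array([[0, 1, 0], [0, 2, 2]])
scores = dice_per_label(pred, truth, labels=[1, 2])
# label 1: 2*1/(1+2) ≈ 0.667; label 2: 2*2/(2+2) = 1.0
```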

Repo name inspired by Magic: The Gathering.