
atlas-gan's Introduction

Generative Adversarial Template Construction

dHCP Templates

TensorFlow 2 code repository for Generative Adversarial Registration for Improved Conditional Deformable Templates, arXiv 2021.

train_script.py is the main template construction script that implements all methods considered in the paper for the 3D datasets.

The current code repository will be heavily refactored (e.g., improving data loading, better abstraction) in the coming days and weeks. FFHQ-Aging scripts only require a change from 3D to 2D and will be added as well.

Dependencies

We recommend setting up an Anaconda environment and installing all dependencies as follows:

conda env create -f environment.yml
conda activate tf2

Usage

Example training call for conditional templates:

python train_script.py --name phd-ours-cond --dataset pHD --oversample --nonorm_reg --clip --losswt_gp 5e-4 --gen_config ours

CLI args are:

usage: train_script.py [-h] [--epochs EPOCHS] [--batch_size BATCH_SIZE] 
                       [--dataset DATASET] [--name NAME] [--d_train_steps D_TRAIN_STEPS]
                       [--g_train_steps G_TRAIN_STEPS] [--lr_g LR_G] [--lr_d LR_D]
                       [--beta1_g BETA1_G] [--beta2_g BETA2_G] [--beta1_d BETA1_D]
                       [--beta2_d BETA2_D] [--unconditional] [--nonorm_reg] [--oversample]
                       [--d_snout] [--clip] [--reg_loss REG_LOSS] [--losswt_reg LOSSWT_REG]
                       [--losswt_gan LOSSWT_GAN] [--losswt_tv LOSSWT_TV] [--losswt_gp LOSSWT_GP]
                       [--gen_config GEN_CONFIG] [--steps_per_epoch STEPS_PER_EPOCH]
                       [--rng_seed RNG_SEED] [--start_step START_STEP] [--resume_ckpt RESUME_CKPT]
                       [--g_ch G_CH] [--d_ch D_CH] [--init INIT] [--lazy_reg LAZY_REG]
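If you are scripting around the CLI or reimplementing it, the following minimal argparse sketch mirrors the flag names in the usage string above. It is a hypothetical reconstruction, not the actual parser in train_script.py: the types follow the verbose descriptions, boolean options are assumed to be store_true switches, and because the real defaults are not documented here they are left as None/False.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Hypothetical reconstruction of the train_script.py flags; the real
    # default values are not documented, so none are set here.
    p = argparse.ArgumentParser(prog="train_script.py")
    int_args = ["epochs", "batch_size", "d_train_steps", "g_train_steps",
                "steps_per_epoch", "rng_seed", "start_step", "resume_ckpt",
                "g_ch", "d_ch", "lazy_reg"]
    float_args = ["lr_g", "lr_d", "beta1_g", "beta2_g", "beta1_d", "beta2_d",
                  "losswt_reg", "losswt_gan", "losswt_tv", "losswt_gp"]
    str_args = ["dataset", "name", "reg_loss", "gen_config", "init"]
    # Assumed to be on/off switches (present => True).
    flags = ["unconditional", "nonorm_reg", "oversample", "d_snout", "clip"]
    for a in int_args:
        p.add_argument(f"--{a}", type=int)
    for a in float_args:
        p.add_argument(f"--{a}", type=float)
    for a in str_args:
        p.add_argument(f"--{a}", type=str)
    for a in flags:
        p.add_argument(f"--{a}", action="store_true")
    return p

# Parse the example training call from this README.
args = build_parser().parse_args([
    "--name", "phd-ours-cond", "--dataset", "pHD", "--oversample",
    "--nonorm_reg", "--clip", "--losswt_gp", "5e-4", "--gen_config", "ours",
])
print(args.losswt_gp, args.oversample, args.unconditional)  # 0.0005 True False
```

Treating the booleans as store_true switches matches how the example call passes --oversample, --nonorm_reg, and --clip without values.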

With verbose descriptions:

CLI args:
    epochs: int
        Number of epochs to train for.
    batch_size: int
        Batch size for training. GPU memory typically only allows small batches.
    dataset: str
        Dataset of interest. Currently one of {'dHCP', 'pHD'}
    name: str
        Name of experiment. Will be prepended to saved folders.
    d_train_steps: int
        Number of discriminator updates in each GAN cycle.
    g_train_steps: int
        Number of generator updates in each GAN cycle.
    lr_g: float
        Learning rate for generator.
    lr_d: float
        Learning rate for discriminator.
    beta1_g: float
        Adam beta1 parameter for the generator.
    beta2_g: float
        Adam beta2 parameter for the generator.
    beta1_d: float
        Adam beta1 parameter for the discriminator.
    beta2_d: float
        Adam beta2 parameter for the discriminator.
    unconditional: bool
        If set, train unconditional templates; otherwise train conditional templates.
    nonorm_reg: bool
        If set, disable instance normalization in the registration branch.
        Instance normalization in the registration branch was not used in the paper.
    oversample: bool
        Whether to oversample rare ages during training.
    d_snout: bool
        Whether to apply Spectral Norm to the last layer of the Discriminator.
    clip: bool
        Whether to clip the template background during training.        
    reg_loss: str
        Type of registration loss. One of {'NCC', 'NonSquareNCC'}.
    losswt_reg: float
        Multiplier for deformation regularizers.
    losswt_gan: float
        GAN loss weight in generator loss.
    losswt_tv: float
        Weight of TV penalty on generated templates.
        Not used in paper.
    losswt_gp: float
        Weight of the gradient penalty in the discriminator loss.
    gen_config: str
        Template generator architecture. One of {'ours', 'voxelmorph'}.
    steps_per_epoch: int
        Number of steps per epoch.
    rng_seed: int
        Seed for random number generators.
    start_step: int
        Step to activate GAN training (as opposed to just registration).
        Not used in paper. GAN training is active from the first iteration.
    resume_ckpt: int
        If > 0, resume training from the given checkpoint index.
    g_ch: int
        Channel width multiplier for generator.
    d_ch: int
        Channel width multiplier for discriminator.
    init: str
        Weight initialization. One of {'default', 'orthogonal'}.
    lazy_reg: int
        Calculate/apply gradient penalty only once every lazy_reg iterations.
        Not used in the paper.
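The interplay of d_train_steps, g_train_steps, and lazy_reg can be sketched as a plain-Python update schedule. This is an illustrative assumption about the loop structure, not code from train_script.py: it assumes each GAN cycle runs all discriminator updates before the generator updates, and that lazy_reg counts discriminator updates. The function name gan_cycle_schedule is hypothetical.

```python
from typing import Iterator, Tuple

def gan_cycle_schedule(num_cycles: int, d_train_steps: int, g_train_steps: int,
                       lazy_reg: int = 1) -> Iterator[Tuple[str, bool]]:
    """Yield (network, apply_gp) pairs for a hypothetical alternating GAN loop.

    Each cycle runs d_train_steps discriminator ("D") updates followed by
    g_train_steps generator ("G") updates. The gradient penalty is applied
    only on every lazy_reg-th discriminator update (lazy_reg=1: every update).
    """
    if lazy_reg < 1:
        raise ValueError("lazy_reg must be >= 1")
    d_updates = 0
    for _ in range(num_cycles):
        for _ in range(d_train_steps):
            # Penalty on D updates 0, lazy_reg, 2 * lazy_reg, ...
            yield ("D", d_updates % lazy_reg == 0)
            d_updates += 1
        for _ in range(g_train_steps):
            yield ("G", False)  # no gradient penalty on generator updates

schedule = list(gan_cycle_schedule(2, d_train_steps=2, g_train_steps=1, lazy_reg=3))
print(schedule)
# [('D', True), ('D', False), ('G', False), ('D', False), ('D', True), ('G', False)]
```

Lazy regularization as used in StyleGAN2 also scales the penalty by the interval so that its average strength is unchanged; whether train_script.py does this is not documented here.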

Data loaders:

The training script expects data points to be in the form of npz files. To construct a usable npz from a nifti file, the following code snippet was used:

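import os

import numpy as np
import SimpleITK as sitk

# Read the NIfTI volume; GetArrayFromImage returns the axes in (z, y, x) order.
simg = sitk.ReadImage('/path/to/nifti.nii.gz')
npy_img = sitk.GetArrayFromImage(simg)

# Assuming that you have 'age' and 'attribute' loaded:
out_dir = './data/dataset_name/train_npz'
os.makedirs(out_dir, exist_ok=True)  # savez_compressed does not create directories
np.savez_compressed(
    os.path.join(out_dir, 'fname.npz'),
    vol=npy_img,
    age=age,
    attribute=attribute,
)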
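Before training, it can help to confirm that a converted file round-trips with the expected keys. The sketch below writes a dummy volume to a temporary directory and reads it back with np.load; the shape and the age/attribute values are placeholders, not values from either dataset, and the exact dtype or shape that train_script.py expects is not stated here.

```python
import os
import tempfile

import numpy as np

# Placeholder stand-ins for a real volume and its metadata (hypothetical values).
vol = np.random.rand(8, 10, 12).astype(np.float32)
age, attribute = 35.0, 1

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "fname.npz")
    np.savez_compressed(path, vol=vol, age=age, attribute=attribute)
    with np.load(path) as data:
        print(sorted(data.files))  # ['age', 'attribute', 'vol']
        loaded_vol = data["vol"]            # read fully into memory here
        loaded_age = float(data["age"])     # scalars come back as 0-d arrays
```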

We recommend inspecting L196-L238 of train_script.py and ./src/data_generators.py for more details of how to modify the data loaders for your use case.
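The --oversample flag is described only as oversampling rare ages; the scheme actually used in ./src/data_generators.py is not documented here. One common approach is inverse-frequency sampling, sketched below in NumPy. The function names and the bin_width parameter are hypothetical, and sensible bin widths depend on the dataset's age units.

```python
import numpy as np

def inverse_frequency_weights(ages, bin_width=1.0):
    """Per-sample sampling probabilities that upweight rare age bins.

    Ages are grouped into bins of width bin_width; each sample's weight is
    1 / (count of its bin), normalized to sum to 1, so every occupied bin
    receives equal total probability.
    """
    ages = np.asarray(ages, dtype=float)
    bins = np.floor(ages / bin_width).astype(int)
    _, inverse, counts = np.unique(bins, return_inverse=True, return_counts=True)
    weights = 1.0 / counts[inverse]
    return weights / weights.sum()

def sample_batch_indices(probs, batch_size, rng):
    # Draw with replacement so rare samples can appear repeatedly.
    return rng.choice(len(probs), size=batch_size, replace=True, p=probs)

rng = np.random.default_rng(0)
probs = inverse_frequency_weights([30, 31, 30, 30, 80])
print(probs)  # approximately [1/9, 1/3, 1/9, 1/9, 1/3]
batch = sample_batch_indices(probs, batch_size=4, rng=rng)
```

Sampling with replacement keeps the batch size fixed even when a rare bin holds a single subject; the trade-off is that those subjects are seen far more often, which can encourage overfitting to them.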

Acknowledgements:

This repo makes extensive use of the VoxelMorph library.


atlas-gan's Issues

dataset

Hi, thanks for providing such excellent work. Could you share your preprocessed dataset?

The structure of the generation network part

Hi @neel-dey,
Sorry to bother you again. I previously tried to reimplement parts of the code in PyTorch, but I ran into issues with the generator network: it consistently behaves differently from what I expected. Could you please explain the structure of the generator network in detail? Alternatively, could you take a look at my code and help me identify any issues?

Possibility of use on a prostate dataset

Hello, thanks for sharing your work!
I would like to use your algorithm on a prostate dataset and have a couple of questions.
As far as I can tell from the code, it does not take spacing, orientation, or direction into account, so I suppose these should all be the same across images before starting training?

I have T2-weighted images in the transverse, sagittal, and coronal planes, each at three resolutions: 256x256, 512x512, and 1024x1024. Should I train on all T2 views together or on each view separately? And what would be best for handling the differences in resolution: padding the smaller images, resizing with interpolation, or training a separate conditional template for each resolution?

I understand that some of this is unknown and can only be established by experimentation, but you know your tool best, so any comment would be highly valuable.

Thank you!

Patches of noise on the edge of the generated atlas images

Hi @neel-dey,

First of all, thanks for making the code available. I tried the model on my dataset and everything turned out great, except that during training, starting at around iteration 10,000, strange-looking patches of noise started to appear.


I saved the volume as a .nii file and observed that there seemed to be more than one patch.


I tried to figure out the problem on my own. I confirmed that my initial rough template did not contain these patches, and all my data were normalized to the range 0 to 1.
I can't figure out what is wrong. Intuitively, the deformation regularizer should prevent voxels from moving that far out to the corners. Also, the similarity loss should penalize the generation of random patches of noise where they shouldn't be. I noticed that the script sets clip_bckgnd = False, and I wonder whether setting it to True could resolve this problem.

I would really appreciate it if you could give me some insight into how this happened. Hope to hear back from you soon!
