
miccai22_adn's Introduction

ADN

The PyTorch implementation of our MICCAI22 paper Asymmetry Disentanglement Network for Interpretable Acute Ischemic Stroke Infarct Segmentation in Non-Contrast CT Scans.

Example Results

ADN separates the different kinds of asymmetry in NCCT images ($A$: total asymmetry map, $P$: pathological asymmetry map, $Q$: intrinsic anatomical asymmetry map) and generates pathology-salient ($X+Q$) or pathology-compensated ($X+P$) images for better clinical examination.
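The relationship between these maps can be illustrated on a toy tensor. This is only a sketch that mirrors the arithmetic in the Quick Start code: the mirrored image is assumed to be a plain left-right flip of an already aligned slice, and `p` is a hand-written, hypothetical stand-in for the output of $D$.

```python
import torch

# One row of a toy "aligned" slice: uniform intensity except one darker voxel.
x = torch.tensor([[0.5, 0.2, 0.5, 0.5]])
# Assumption: after alignment, the mirrored image is a flip about the midline.
x_flip = torch.flip(x, dims=[-1])

a = torch.relu(x_flip - x)                # total asymmetry map A
p = torch.tensor([[0.0, 0.2, 0.0, 0.0]])  # hypothetical pathological asymmetry P
q = torch.relu(a - p)                     # anatomical asymmetry Q = ReLU(A - P)

x_hat = x + q                             # pathology-salient image X + Q
# Pathology-compensated image X + P, capped at the symmetric upper bound.
x_bar = torch.minimum(x + p, torch.maximum(x, x_flip))
```

In this toy case only the darker voxel carries asymmetry, and it is split between `p` and `q`.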

Dependencies

Python 3.7.10, PyTorch 1.10.2, etc.

Quick Start

ADN consists of three parts: a transformation network $T$, an asymmetry extraction network $D$, and a segmentation network $F$. In our experiments, we first train $T$, then fix $T$ and jointly train $D$ and $F$. The following code is a simple example of how to train the network $T$.

# A toy example to show how to train transform network
import os
import torch
from model.transform_net import PlaneFinder
import torch.optim as optim

# set GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# transformation network T
align_model = PlaneFinder(is_train=True)
align_model.train()
align_model.cuda()

optimizer = optim.AdamW(align_model.parameters(),
                        lr=1e-5, betas=(0.9, 0.999), weight_decay=5e-4)

# load CT data
# size: (batch_size, num_channel, num_slices, height, width)
x = torch.rand((4, 1, 40, 256, 256)).cuda()

optimizer.zero_grad()

# x_t: x transformed to be symmetric
# view: x, y, z rotation and translation
# M: transformation matrix
# please refer to the comments of forward function for the definition of each return variable
x_t, _, _, view, M, _ = align_model(x)

align_model.loss_total.backward()

optimizer.step()
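Because $T$ is kept fixed in the second stage, its weights have to be saved after training and reloaded without gradients. The sketch below shows one common PyTorch way to do that. `TinyT` is a hypothetical stand-in for `PlaneFinder`, and the checkpoint file name and temporary directory are assumptions, not paths used by this repository.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyT(nn.Module):
    """Hypothetical stand-in for the transformation network T."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 6)

    def forward(self, x):
        return self.fc(x)


def save_checkpoint(model, path):
    # Store only the weights (state dict), not the whole module object.
    torch.save(model.state_dict(), path)


def load_frozen(model, path):
    # Load the weights, switch to eval mode, and disable gradients.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)
    return model


ckpt_path = os.path.join(tempfile.mkdtemp(), "transform_net.pth")  # assumed name
trained = TinyT()
save_checkpoint(trained, ckpt_path)
frozen = load_frozen(TinyT(), ckpt_path)
```

With the parameters frozen, an optimizer built only from the parameters of $D$ and $F$ leaves $T$ untouched.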

After training $T$, we jointly train $D$ and $F$. The following code is a simple example of how to do this. We first set warm_start=1 in the beginning stages and then set warm_start=0.

# A toy example to show how to train the whole network without using tissue segmentation maps
# unet3d is borrowed from https://github.com/wolny/pytorch-3dunet/tree/master/pytorch3dunet/unet3d
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from model.transform_net import PlaneFinder
import torch.optim as optim
import torch.nn as nn
from model.unet3d.unet_model import ResidualUNet3D
from model.unet3d.losses import GeneralizedDiceLoss

# affine transform
def stn(x, theta):
    # theta must be (Bs, 3, 4) = [R|t]
    grid = nn.functional.affine_grid(theta, x.size(), align_corners=False)
    out = nn.functional.grid_sample(x, grid, padding_mode='zeros', align_corners=False)
    return out

def loss_calc(pred, label):
    """
    This function returns the binary cross-entropy loss and the generalized Dice loss for segmentation
    """
    label = label.cuda()
    BCELoss = nn.BCELoss()
    DiceLoss = GeneralizedDiceLoss(normalization="none")
    return BCELoss(pred, label), DiceLoss(pred.unsqueeze(dim=1), label.unsqueeze(dim=1))

# transformation network T
align_model = PlaneFinder(is_train=False)
align_model.cuda()
# load pretrained transformation network T
# Note that we train T first, then fix it and train D and F
align_model.eval()

# asymmetry extraction network D
asym_model = ResidualUNet3D(in_channels=1, out_channels=1, f_maps=32, use_transconv=False, use_dp=True, p=0.2)
asym_model.cuda()
asym_model.train()

# segmentation network F
seg_model = ResidualUNet3D(in_channels=1, out_channels=1, f_maps=32, use_transconv=False, use_dp=True, p=0.2)
seg_model.cuda()
seg_model.train()

optimizer = optim.AdamW([{'params': asym_model.parameters(), 'lr': 1e-4},
                         {'params': seg_model.parameters(), 'lr': 1e-4}],
                        weight_decay=5e-4,
                        betas=(0.9, 0.999))

# load CT data
# size: (batch_size, num_channel, num_slices, height, width)
images = torch.rand((1, 1, 40, 256, 256)).cuda()
labels = torch.randint(size=(1, 40, 256, 256), low=0, high=2).float().cuda()

# Perform transformation
with torch.no_grad():
    images_t, images_r, images_t_f, _, M, M_inv = align_model(images)
    diff_t = images_t - images_t_f
    sym_comp_t = torch.zeros_like(images_t)
    sym_comp_t[diff_t > 0] = images_t[diff_t > 0]
    sym_comp_t[diff_t == 0] = images_t[diff_t == 0]
    sym_comp_t[diff_t < 0] = images_t_f[diff_t < 0]
    asym_map_t = nn.ReLU()(images_t_f - images_t)  # total asym map A

labels_t = stn(labels.unsqueeze(dim=1), M[:, :3, :]).squeeze(dim=1)

optimizer.zero_grad()
# separate asym to be anatomical asym Q and pathological asym P
subject_asym_conf_t = asym_model(images_t)  # pathological asym P
anatomy_asym_conf_t = asym_map_t - subject_asym_conf_t
anatomy_asym_conf_t = nn.ReLU()(anatomy_asym_conf_t)  # anatomical asym Q
subject_asym_images_t = images_t + anatomy_asym_conf_t  # X_hat = X + Q

anatomy_asym_images_t = images_t + subject_asym_conf_t
anatomy_asym_images_t = torch.clamp(anatomy_asym_images_t, max=sym_comp_t)  # X_bar = X + P

# perform segmentation on X_hat
pred_t = seg_model(subject_asym_images_t)
pred_t = pred_t.squeeze(dim=1)
bce_loss, dice_loss = loss_calc(pred_t, labels_t)
seg_loss = bce_loss + dice_loss

lambda_seg = 1
lambda_reg = 10
warm_start = 1  # 1 use warm start else use regularization loss
# for warm start stage
if warm_start:
    reg_bce_loss, reg_dice_loss = loss_calc(subject_asym_conf_t.squeeze(dim=1), labels_t)
    reg_loss = reg_bce_loss + reg_dice_loss
else:
    subject_asym_msk_t = labels_t.unsqueeze(dim=1) == 1
    subject_asym_gt_t = asym_map_t * subject_asym_msk_t
    # 1. the size of subject asym (i.e. pathological asym) should be similar to the size of stroke
    reg_loss1 = nn.L1Loss()(subject_asym_conf_t.mean(), subject_asym_gt_t.mean())
    # 2. subject asym should come from the total asym (subject + anatomy), so it must be zero where the total asym is zero
    sym_map_mask_t = asym_map_t == 0
    reg_loss2 = nn.L1Loss()(subject_asym_conf_t*sym_map_mask_t, torch.zeros_like(subject_asym_conf_t))
    # 3. anatomical asym should be as large as possible
    reg_loss3 = -anatomy_asym_conf_t.mean()
    reg_loss = reg_loss1 + reg_loss2 + reg_loss3

loss = lambda_seg * seg_loss + lambda_reg * reg_loss
loss.backward()
optimizer.step()
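The switch from `warm_start=1` to `warm_start=0` can be driven by the epoch counter. This is a minimal sketch only: the helper name `warm_start_flag` and the number of warm-start epochs are hypothetical, since the README does not state when the switch happens.

```python
def warm_start_flag(epoch, warm_epochs):
    """Return 1 during the first `warm_epochs` epochs, then 0."""
    return 1 if epoch < warm_epochs else 0


# Hypothetical schedule: 2 warm-start epochs out of 5.
flags = [warm_start_flag(epoch, warm_epochs=2) for epoch in range(5)]
```

Inside a training loop, the returned value would replace the hard-coded `warm_start = 1` above.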

DONE

The network structure of ADN

The code for the complete training and testing process

TODO

The illustration of data preparation, training and testing

Citing ADN

If you find our approaches useful in your research, please consider citing:

@inproceedings{ni2022asymmetry,
  title={Asymmetry Disentanglement Network for Interpretable Acute Ischemic Stroke Infarct Segmentation in Non-Contrast CT Scans},
  author={Ni, Haomiao and Xue, Yuan and Wong, Kelvin and Volpi, John and Wong, Stephen TC and Wang, James Z and Huang, Xiaolei},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={416--426},
  year={2022},
  organization={Springer}
}

For questions about the code, please feel free to open an issue or contact me: [email protected]

Acknowledgement

Part of our code was borrowed from unsup3d and unet3d. We thank the authors of these repositories for their valuable implementations.


miccai22_adn's Issues

Some results and questions

Following some of your suggestions, I have some results and questions that I would like to discuss with you.
1. Which normalization method is used in data preprocessing?
I tried two normalization methods: global min-max normalization, and clipping to [0, 80] followed by min-max normalization. When starting from your pretrained checkpoint, the initial results differ considerably, and clipping first gives the better result. Did you clip during normalization? If so, what is the clip range?
2. For the brain tissue segmentation network, why does the output have 5 channels instead of 4?
I count four categories: background, gray matter, white matter, and cerebrospinal fluid.
3. Do you still have the checkpoint of the brain tissue segmentation network ("gwm_seg_model")?
4. The following results are my best so far. Some test cases give anomalous results, and I am still investigating the reasons.
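The two normalization schemes compared in question 1 can be sketched as follows. This only illustrates why clipping first changes the input contrast; it is not confirmed to be the preprocessing used in the paper. The [0, 80] window comes from the question, and the sample intensities are made-up values.

```python
import numpy as np


def minmax(x):
    """Scale an array to [0, 1] using its own minimum and maximum."""
    x = x.astype(np.float32)
    value_range = x.max() - x.min()
    if value_range == 0:
        return np.zeros_like(x)
    return (x - x.min()) / value_range


def clip_then_minmax(x, lo=0.0, hi=80.0):
    """Clip to [lo, hi] first, then apply min-max scaling."""
    return minmax(np.clip(x, lo, hi))


# Made-up sample intensities covering a wide range.
hu = np.array([-1000.0, 0.0, 30.0, 40.0, 80.0, 1500.0])

global_norm = minmax(hu)
clipped_norm = clip_then_minmax(hu)

# Contrast between the 30 and 40 samples under each scheme.
contrast_global = float(global_norm[3] - global_norm[2])
contrast_clipped = float(clipped_norm[3] - clipped_norm[2])
```

With global min-max scaling, the extreme values compress the soft-tissue range. Clipping first keeps the difference between nearby intensities much larger.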

About dice metrics

Is the Dice metric reported in the paper computed per case or globally?
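The two conventions can give quite different numbers, as the sketch below shows with their usual definitions. Per-case Dice averages one score per scan, while global Dice pools all voxels before computing a single score. The function names and toy masks are hypothetical, and the sketch does not say which convention the paper uses.

```python
import numpy as np


def dice(pred, gt):
    """Dice score of two binary masks (1.0 if both are empty)."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        return 1.0
    return 2.0 * np.logical_and(pred, gt).sum() / denom


def dice_per_case(preds, gts):
    """Mean of the per-scan Dice scores."""
    return float(np.mean([dice(p, g) for p, g in zip(preds, gts)]))


def dice_global(preds, gts):
    """Dice over the voxels of all scans pooled together."""
    return float(dice(np.concatenate(preds), np.concatenate(gts)))


# Toy case 1: large lesion, 90 of 100 voxels overlap.
pred_large, gt_large = np.zeros(200), np.zeros(200)
pred_large[0:100] = 1
gt_large[10:110] = 1

# Toy case 2: tiny lesion that is missed entirely.
pred_small, gt_small = np.zeros(200), np.zeros(200)
pred_small[0:2] = 1
gt_small[5:7] = 1

per_case = dice_per_case([pred_large, pred_small], [gt_large, gt_small])
pooled = dice_global([pred_large, pred_small], [gt_large, gt_small])
```

Here the missed small lesion halves the per-case score but barely affects the pooled score.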

A sincere request

Can I add your WeChat account to ask about issues related to the ADN paper? If possible, I am willing to pay you in return.

problem

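"we resample all NCCTs to be 1.2 × 1.2 × 5 mm^3 and reshape them to be 256 × 256 × 40 using the Python package nilearn." Here you resample and then resize, which makes the resampling meaningless. My point is that sliding patches should be used instead, but because symmetry requires looking at the whole image, this is self-contradictory.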

Confusion

"Yes, as mentioned in our paper, 'after skull stripping, we resample all NCCTs to be 1.2 × 1.2 × 5 mm^3 and reshape them to be 256 × 256 × 40 using the Python package nilearn.'"

"I think it is unreasonable to reshape the image to 40 × 256 × 256 after resampling. You should use the original volume to calculate the relevant metrics, because the asymmetry after resampling cannot represent the asymmetry of the original image."
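The concern in this thread depends on how the 256 × 256 × 40 grid is reached after resampling, and the quoted sentence does not say whether volumes are cropped or padded (spacing preserved) or rescaled (spacing changed). The sketch below only works through the arithmetic for both readings. The input volume of 512 × 512 × 32 voxels at 0.45 × 0.45 × 5 mm is a hypothetical example.

```python
def resampled_shape(shape, spacing, new_spacing):
    """Voxel counts after resampling to `new_spacing` with the same extent."""
    return tuple(round(n * s / t) for n, s, t in zip(shape, spacing, new_spacing))


def spacing_if_rescaled(shape, spacing, grid):
    """Effective spacing if the whole extent is squeezed into `grid`."""
    return tuple(round(n * s / g, 3) for n, s, g in zip(shape, spacing, grid))


target_spacing = (1.2, 1.2, 5.0)   # mm, from the quoted sentence
target_grid = (256, 256, 40)       # voxels, from the quoted sentence

# Physical extent covered by the target grid at the target spacing.
fov_mm = tuple(round(s * n, 1) for s, n in zip(target_spacing, target_grid))

# Hypothetical input volume.
in_shape, in_spacing = (512, 512, 32), (0.45, 0.45, 5.0)

shape_after_resample = resampled_shape(in_shape, in_spacing, target_spacing)
spacing_after_rescale = spacing_if_rescaled(in_shape, in_spacing, target_grid)
```

For this example, resampling alone gives a 192 × 192 × 32 volume. Padding it to the grid keeps the 1.2 × 1.2 × 5 mm spacing, while rescaling it to the grid would change the spacing to 0.9 × 0.9 × 4 mm.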

Puzzled

I'm sorry to bother you again. I downloaded SPM12 and would like to ask about the specific details of skull stripping and GWM segmentation. Is your GWM segmentation computed from the CT .nii.gz file?

some questions

First of all, thank you for sharing. I have recently been reproducing your paper and code, and I have some questions.
(1) I trained the first network, the transformation network T, on an NVIDIA 3060 for one day, but the result was not very good. How long did you train this network?
(2) For the masks in the AISD dataset, are all of the label values 1, 2, 3, and 5 used as the lesion area?
(3) Could you share the checkpoints of the three trained networks T, D, and F, so that the segmentation can be tested directly on NCCT images?
