
Self-Supervised Pretraining for 2D Medical Image Segmentation

This repository is the official implementation of Self-Supervised Pretraining for 2D Medical Image Segmentation (accepted for the AIMIA workshop at ECCV 2022).

(Figure: overview of the pretraining strategies; the arrows referenced in the tables below are labelled in this figure.)

If you use our code or results, please cite our paper:

@InProceedings{Kalapos2022,
  author    = {Kalapos, Andr{\'a}s and Gyires-T{\'o}th, B{\'a}lint},
  booktitle = {Computer Vision -- ECCV 2022 Workshops},
  title     = {{Self-supervised Pretraining for 2D Medical Image Segmentation}},
  year      = {2023},
  address   = {Cham},
  pages     = {472--484},
  publisher = {Springer Nature Switzerland},
  doi       = {10.1007/978-3-031-25082-8_31},
  isbn      = {978-3-031-25082-8},
}

Requirements

Required python packages

To install the PyPI requirements:

pip install -r requirements.txt

For self-supervised pre-training, solo-learn==1.0.6 is also required. For its installation, follow the instructions in solo-learn's documentation (DALI and UMAP support are not needed) or use the following commands:

git clone https://github.com/vturrisi/solo-learn.git
cd solo-learn
pip3 install -e .

Dataset setup

Download the ACDC Segmentation dataset from: https://acdc.creatis.insa-lyon.fr (registration required)

Specify the path for the dataset in:

Slurm

For hyperparameter sweeps and for running many experiments in a batch, we use Slurm jobs; an installed and configured Slurm environment is therefore required for these runs. However, based on supervised_segmentation/sweeps/data_eff_learning.sh and supervised_segmentation/sweeps/grid_search_helper.py, other methods of running experiments in a batch can be implemented.
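The batch-running idea can be sketched without Slurm: expand a parameter grid into one command line per combination and submit each however your environment allows. This is a minimal stdlib-only sketch, loosely modelled on what a helper like grid_search_helper.py might do; the parameter names in the example grid are hypothetical, not the repository's actual sweep config.

```python
# Minimal grid-search launcher sketch: expand a dict of parameter lists
# into one training command per combination.
from itertools import product

def build_commands(script, grid):
    """Expand {param: [values...]} into one command line per combination."""
    keys = sorted(grid)
    commands = []
    for values in product(*(grid[k] for k in keys)):
        flags = " ".join(f"--{k} {v}" for k, v in zip(keys, values))
        commands.append(f"PYTHONPATH=. python {script} {flags}")
    return commands

commands = build_commands(
    "supervised_segmentation/train.py",
    {"seed": [0, 1], "data_fraction": [0.1, 1.0]},  # hypothetical sweep axes
)
for cmd in commands:
    print(cmd)  # in a Slurm setup, each of these would be submitted via sbatch/srun
```

Each generated command could equally be wrapped in an `sbatch` script, which is essentially what the sweep scripts in this repository do.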

Training

Supervised segmentation (downstream) training

To train the model(s) in the paper, run this command:

PYTHONPATH=. python supervised_segmentation/train.py

On a Slurm cluster:

PYTHONPATH=. srun -p gpu --gres=gpu --cpus-per-task=10 python supervised_segmentation/train.py

To initialize the downstream training with different pretrained models, we provide the pretrained weights used in our paper. These can be selected by setting the encoder_weights config in supervised_segmentation/config_acdc.yaml:

| Pre-training approach | Corresponding arrow on the figure above | encoder_weights |
|---|---|---|
| Supervised ImageNet | arrow-generalist-supervised | supervised_imagenet |
| BYOL ImageNet | arrow-generalist-selfsupervised | resnet50_byol_imagenet2012.pth.tar |
| Supervised ImageNet + BYOL ACDC | arrow-hierarchical-supervised (2nd step) | supervised-imagenet-byol-acdc-ep=25.pth |
| BYOL ImageNet + BYOL ACDC | arrow-hierarchical-selfsupervised (2nd step) | byol-imagenet-acdc-ep=34.ckpt or byol-imagenet-acdc-ep=25.pth |
| BYOL ACDC | arrow-specialist | byol_acdc_backbone_last.pth |

Pretraining

For ImageNet pretraining, we acquire weights from:

To pretrain a model on the ACDC dataset, run this command:

python self_supervised/main_pretrain.py --config-path configs/ --config-name byol_acdc.yaml

Different pretraining strategies can be configured by specifying different --pretrained_weights and --max_epochs values, or by modifying the corresponding fields in self_supervised/config_acdc.yaml.

| Pre-training approach | Corresponding arrow on the figure above | --pretrained_weights | --max_epochs | Published pretrained model [models.zip] |
|---|---|---|---|---|
| Supervised ImageNet + BYOL ACDC | arrow-hierarchical-supervised (1st step) | supervised_imagenet | 25 | models/supervised-imagenet-byol-acdc-ep=25.pth |
| BYOL ImageNet + BYOL ACDC | arrow-hierarchical-selfsupervised (1st step) | resnet50_byol_imagenet2012.pth.tar | 25 | models/byol-imagenet-acdc-ep=34.ckpt and models/byol-imagenet-acdc-ep=25.pth |
| BYOL ACDC | arrow-specialist | None | 400 | models/byol_acdc_backbone_last.pth |

We publish pretrained models for these pretraining runs, as specified in the last column of the table.
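When scripting several pretraining runs, the table above can be captured programmatically. A minimal sketch, assuming only the --pretrained_weights and --max_epochs flags described earlier; the string keys are informal labels for this illustration, not identifiers used by the repository:

```python
# Pretraining strategies from the table above, mapped to the
# (--pretrained_weights, --max_epochs) values each one uses.
PRETRAINING_STRATEGIES = {
    "supervised-imagenet+byol-acdc": ("supervised_imagenet", 25),
    "byol-imagenet+byol-acdc": ("resnet50_byol_imagenet2012.pth.tar", 25),
    "byol-acdc": (None, 400),
}

def pretrain_command(strategy):
    """Build a main_pretrain.py invocation for a given strategy label."""
    weights, epochs = PRETRAINING_STRATEGIES[strategy]
    cmd = ["python", "self_supervised/main_pretrain.py",
           "--config-path", "configs/", "--config-name", "byol_acdc.yaml",
           "--max_epochs", str(epochs)]
    if weights is not None:  # BYOL ACDC starts from random weights
        cmd += ["--pretrained_weights", weights]
    return " ".join(cmd)

print(pretrain_command("byol-acdc"))
```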

Evaluation

To evaluate the segmentation model on the ACDC dataset, run:

PYTHONPATH=. python supervised_segmentation/inference_acdc.py

Custom dataset

Supervised segmentation (downstream) training

Custom segmentation datasets can be loaded via mmsegmentation's custom dataset API. This requires the data to be organized in the following directory structure:

├── data
│   ├── custom_dataset
│   │   ├── img_dir
│   │   │   ├── train
│   │   │   │   ├── xxx{img_suffix}
│   │   │   │   ├── yyy{img_suffix}
│   │   │   │   ├── zzz{img_suffix}
│   │   │   ├── val
│   │   ├── ann_dir
│   │   │   ├── train
│   │   │   │   ├── xxx{seg_map_suffix}
│   │   │   │   ├── yyy{seg_map_suffix}
│   │   │   │   ├── zzz{seg_map_suffix}
│   │   │   ├── val
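The skeleton above can also be created programmatically before you copy your data in. A minimal stdlib-only sketch; the data/custom_dataset root is the placeholder path from the tree above:

```python
# Create the mmsegmentation-style directory skeleton described above.
from pathlib import Path

def make_custom_dataset_dirs(root):
    """Create img_dir/ann_dir with train/val subfolders under the dataset root."""
    root = Path(root)
    for top in ("img_dir", "ann_dir"):
        for split in ("train", "val"):
            (root / top / split).mkdir(parents=True, exist_ok=True)
    return root

root = make_custom_dataset_dirs("data/custom_dataset")
# Images then go under root/img_dir/{train,val} and the matching
# segmentation maps under root/ann_dir/{train,val}.
```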

The following configurations must be specified in data/config/custom.py:

data_root = PATH_TO_DATASET    # Path to custom dataset
num_classes = NUM_CLASSES      # Number of segmentation classes
in_channels = IN_CHANNELS      # Number of input channels (e.g. 3 for RGB data)
class_labels = CLASS_LABELS    # Class labels used for logging
ignore_label = IGNORE_LABEL    # Label ignored during IoU metric computation
img_suffix = '.tiff'
seg_map_suffix = '_gt.tiff'
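With these suffixes, every image xxx{img_suffix} in img_dir is paired with the annotation xxx{seg_map_suffix} in ann_dir. A small sketch of that pairing rule, using the example suffixes from the config above:

```python
# Derive the annotation filename that corresponds to an image, following
# the img_suffix / seg_map_suffix convention from the config above.
img_suffix = '.tiff'
seg_map_suffix = '_gt.tiff'

def seg_map_name(img_name):
    """Map 'xxx.tiff' -> 'xxx_gt.tiff'; reject files with the wrong suffix."""
    if not img_name.endswith(img_suffix):
        raise ValueError(f"expected an image ending in {img_suffix!r}")
    return img_name[: -len(img_suffix)] + seg_map_suffix

print(seg_map_name("patient001_slice05.tiff"))  # patient001_slice05_gt.tiff
```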

To train the model(s) run this command:

PYTHONPATH=. python supervised_segmentation/train.py --config_path supervised_segmentation/config_custom.yaml

Pretraining

If your custom dataset is in a simple image-folder format, solo-learn's built-in data loading should handle it (including DALI support). In this case you only need to specify the following paths in self_supervised/configs/byol_custom.yaml:

train_path: "PATH_TO_TRAIN_DIR"
val_path: "PATH_TO_VAL_DIR" # optional

To run the SSL pretraining on a custom dataset:

python self_supervised/main_pretrain.py --config-path configs/ --config-name byol_custom.yaml

For more complex cases, you can build a custom dataset class (similar to data.acdc_dataset.ACDCDatasetUnlabeleld) and instantiate it in self_supervised/main_pretrain.py#L183.

Copyright

Segmentation code is based on: ternaus/cloths_segmentation

Self-supervised training script is based on: vturrisi/solo-learn

BYOL ImageNet pretrained model from: yaox12/BYOL-PyTorch [Google drive link for their model file]

This Readme is based on: paperswithcode/releasing-research-code


