Torch'em

Deep learning-based semantic and instance segmentation for 3D electron microscopy and other bioimage analysis problems, built on PyTorch. Feedback is highly appreciated; just open an issue!

Highlights:

  • Functional API with sensible defaults to train a state-of-the-art segmentation model with a few lines of code.
  • Differentiable augmentations on GPU and CPU thanks to kornia.
  • Off-the-shelf logging with tensorboard or wandb.
  • Export trained models to the bioimage.io model format with one function call to deploy them in ilastik or deepImageJ.

Design:

  • All parameters are specified in code, no configuration files.
  • No callback logic; to extend the core functionality inherit from trainer.DefaultTrainer instead.
  • All data-loading is lazy to support training on large data-sets.
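
The "inherit instead of callbacks" design can be illustrated with a small, dependency-free sketch. Note that MiniTrainer, BestLossTrainer, train_step and step_fn are hypothetical names made up for this illustration; they are not torch_em's API, and the methods that trainer.DefaultTrainer actually exposes for overriding may look different.

```python
class MiniTrainer:
    """Hypothetical stand-in for a trainer base class (NOT torch_em's DefaultTrainer)."""

    def __init__(self, step_fn):
        self.step_fn = step_fn  # callable: iteration index -> loss value
        self.history = []

    def train_step(self, iteration):
        # Core behaviour: run one training step and return its loss.
        return self.step_fn(iteration)

    def fit(self, iterations):
        # The training loop calls train_step; subclasses change behaviour by overriding it.
        for iteration in range(iterations):
            self.history.append(self.train_step(iteration))
        return self.history


class BestLossTrainer(MiniTrainer):
    """Adds best-loss tracking by overriding a method rather than registering a callback."""

    def __init__(self, step_fn):
        super().__init__(step_fn)
        self.best_loss = float("inf")

    def train_step(self, iteration):
        loss = super().train_step(iteration)
        self.best_loss = min(self.best_loss, loss)
        return loss


# Toy "loss" that shrinks with each iteration: 1.0, 0.5, 0.333..., 0.25
trainer = BestLossTrainer(step_fn=lambda i: 1.0 / (i + 1))
trainer.fit(iterations=4)
print(trainer.best_loss)  # 0.25
```

The extra behaviour lives in the subclass, so the training loop itself stays free of hook registration and dispatch code.
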
# train a 2d U-Net for foreground and boundary segmentation of nuclei
# using data from https://github.com/mpicbg-csbd/stardist/releases/download/0.1.0/dsb2018.zip

import torch
import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets import get_dsb_loader

model = UNet2d(in_channels=1, out_channels=2)

# transform to go from instance segmentation labels
# to foreground/background and boundary channel
label_transform = torch_em.transform.BoundaryTransform(
    add_binary_target=True, ndim=2
)

# training and validation data loader
data_path = "./dsb"  # the training data will be downloaded and saved here
train_loader = get_dsb_loader(
    data_path,
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="train",
    download=True,
    label_transform=label_transform
)
val_loader = get_dsb_loader(
    data_path,
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="test",
    label_transform=label_transform
)

# the trainer object that handles the training details
# the model checkpoints will be saved in "checkpoints/dsb-boundary-model"
# the tensorboard logs will be saved in "logs/dsb-boundary-model"
trainer = torch_em.default_segmentation_trainer(
    name="dsb-boundary-model",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    learning_rate=1e-4,
    device=torch.device("cuda")
)
trainer.fit(iterations=5000)

# export bioimage.io model format
from glob import glob
import imageio
from torch_em.util import export_bioimageio_model

# load one of the images to use as reference image
# and crop it to a shape that is guaranteed to fit the network
test_im = imageio.imread(glob(f"{data_path}/test/images/*.tif")[0])[:256, :256]

export_bioimageio_model("./checkpoints/dsb-boundary-model", "./bioimageio-model", test_im)
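
The label transform in the example above turns instance labels into a foreground channel and a boundary channel, which is why the model is created with out_channels=2. The following is a minimal, dependency-free sketch of that idea. The function name foreground_and_boundary and the 4-neighbour boundary rule are assumptions chosen for illustration; torch_em's BoundaryTransform may define boundaries differently.

```python
def foreground_and_boundary(labels):
    """Turn a 2D instance label image (0 = background) into two binary channels.

    Simplified illustration only: an object pixel counts as boundary if any of
    its 4-connected neighbours carries a different label.
    """
    height, width = len(labels), len(labels[0])
    foreground = [[int(labels[y][x] > 0) for x in range(width)] for y in range(height)]
    boundary = [[0] * width for _ in range(height)]
    for y in range(height):
        for x in range(width):
            if labels[y][x] == 0:
                continue  # background pixels are never marked in this sketch
            for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                ny, nx = y + dy, x + dx
                inside = 0 <= ny < height and 0 <= nx < width
                if inside and labels[ny][nx] != labels[y][x]:
                    boundary[y][x] = 1
                    break
    return foreground, boundary


# One 3x3 object (label 1) surrounded by background.
labels = [
    [0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
]
foreground, boundary = foreground_and_boundary(labels)
for row in boundary:
    print(row)
```

In this toy input only the centre pixel of the object is foreground without also being boundary, so the two channels carry different information for the network to predict.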

For a more in-depth example, check out one of the example notebooks in the torch-em repository.

Installation

From conda

You can install torch_em from conda-forge:

conda install -c conda-forge torch_em

Please check out pytorch.org for more information on how to install a pytorch version compatible with your system.

From source

It's recommended to set up a conda environment for using torch_em. Two conda environment files are provided: environment_cpu.yaml for a pure CPU setup and environment_gpu.yaml for a GPU setup. If you want to use the GPU version, make sure to set the correct CUDA version for your system by modifying the corresponding line in the environment file.

You can set up a conda environment using one of these files like this:

conda env create -f <ENV>.yaml -n <ENV_NAME>
conda activate <ENV_NAME>
pip install -e .

where <ENV>.yaml is either environment_cpu.yaml or environment_gpu.yaml.
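
After installing, a quick sanity check can confirm that the packages are importable and whether a GPU is visible. This is only a sketch: the helper names installed and pick_device are made up here and are not part of torch_em. It relies on importlib from the standard library and on torch.cuda.is_available().

```python
import importlib.util


def installed(package):
    """Return True if `package` can be found in the current environment."""
    return importlib.util.find_spec(package) is not None


def pick_device():
    """Return "cuda" if PyTorch is installed and sees a GPU, otherwise "cpu"."""
    if not installed("torch"):
        return "cpu"
    import torch  # imported lazily so the check also runs without PyTorch

    return "cuda" if torch.cuda.is_available() else "cpu"


status = {name: installed(name) for name in ("torch", "torch_em")}
device = pick_device()
print(status, device)
```

The resulting string could then be wrapped as torch.device(device) and passed to the trainer instead of hard-coding "cuda", so the same script also runs in the CPU environment.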

Features

  • Training of 2d U-Nets and 3d U-Nets for various segmentation tasks.
  • Random forest-based domain adaptation from Shallow2Deep.
  • Training models for embedding prediction with sparse instance labels from SPOCO.

Command Line Scripts

TODO

Contributors

constantinpape, fynnbe, jonashell, czaki, paulhfu, buglakova, oeway
