
torch-em's Introduction


torch-em

Deep-learning-based semantic and instance segmentation for 3D Electron Microscopy and other bioimage analysis problems, built on PyTorch. Any feedback is highly appreciated, just open an issue!

Highlights:

  • Functional API with sensible defaults to train a state-of-the-art segmentation model with a few lines of code.
  • Differentiable augmentations on GPU and CPU thanks to kornia.
  • Off-the-shelf logging with tensorboard or wandb.
  • Export trained models to the bioimage.io model format with one function call to deploy them in ilastik or deepImageJ.

Design:

  • All parameters are specified in code, no configuration files.
  • No callback logic; to extend the core functionality inherit from trainer.DefaultTrainer instead.
  • All data loading is lazy to support training on large datasets.
# train a 2d U-Net for foreground and boundary segmentation of nuclei
# using data from https://github.com/mpicbg-csbd/stardist/releases/download/0.1.0/dsb2018.zip

import torch
import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets import get_dsb_loader

model = UNet2d(in_channels=1, out_channels=2)

# transform to go from instance segmentation labels
# to foreground/background and boundary channel
label_transform = torch_em.transform.BoundaryTransform(
    add_binary_target=True, ndim=2
)

# training and validation data loader
data_path = "./dsb"  # the training data will be downloaded and saved here
train_loader = get_dsb_loader(
    data_path, 
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="train",
    download=True,
    label_transform=label_transform
)
val_loader = get_dsb_loader(
    data_path, 
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="test",
    label_transform=label_transform
)

# the trainer object that handles the training details
# the model checkpoints will be saved in "checkpoints/dsb-boundary-model"
# the tensorboard logs will be saved in "logs/dsb-boundary-model"
trainer = torch_em.default_segmentation_trainer(
    name="dsb-boundary-model",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    learning_rate=1e-4,
    device=torch.device("cuda")
)
trainer.fit(iterations=5000)

# export bioimage.io model format
from glob import glob
import imageio
from torch_em.util import export_bioimageio_model

# load one of the images to use as reference image
# and crop it to a shape that is guaranteed to fit the network
test_im = imageio.imread(glob(f"{data_path}/test/images/*.tif")[0])[:256, :256]

export_bioimageio_model("./checkpoints/dsb-boundary-model", "./bioimageio-model", test_im)

For a more in-depth example, check out one of the example notebooks in the repository.

Installation

From mamba

mamba is a drop-in replacement for conda, but much faster. While the steps below may also work with conda, it's highly recommended to use mamba. You can follow the instructions here to install mamba.

You can install torch_em from conda-forge:

mamba install -c conda-forge torch_em

Please check out pytorch.org for more information on how to install a PyTorch version compatible with your system.

From source

It's recommended to set up a conda environment for using torch_em. Two conda environment files are provided: environment_cpu.yaml for a pure CPU set-up and environment_gpu.yaml for a GPU set-up. If you want to use the GPU version, make sure to set the correct CUDA version for your system in the environment file, by modifying this line.

You can set up a conda environment using one of these files like this:

mamba env create -f <ENV>.yaml -n <ENV_NAME>
mamba activate <ENV_NAME>
pip install -e .

where <ENV>.yaml is either environment_cpu.yaml or environment_gpu.yaml.

Features

  • Training of 2d U-Nets and 3d U-Nets for various segmentation tasks.
  • Random forest based domain adaptation from Shallow2Deep
  • Training models for embedding prediction with sparse instance labels from SPOCO
  • Training of UNETR for various 2d segmentation tasks, with a flexible choice of vision transformer backbone from Segment Anything or Masked Autoencoder.
  • Training of ViM-UNet for various 2d segmentation tasks.

Command Line Scripts

A command line interface for training, prediction and conversion to the bioimage.io modelzoo format will be installed with torch_em:

  • torch_em.train_unet_2d: train a 2D U-Net.
  • torch_em.train_unet_3d: train a 3D U-Net.
  • torch_em.predict: run prediction with a trained model.
  • torch_em.predict_with_tiling: run prediction with tiling.
  • torch_em.export_bioimageio_model: export a model to the modelzoo format.

For more details run <COMMAND> -h for any of these commands. The folder scripts/cli contains some examples for how to use the CLI.

Note: this functionality was recently added and is not fully tested.

Research Projects using torch-em

torch-em's People

Contributors

anwai98, buglakova, constantinpape, czaki, fynnbe, jonashell, oeway


torch-em's Issues

Follow ups: SPOCO

I merged #87 without finishing all TODOs (because other important stuff got mixed up there...). Finish them:

  • add spoco experiment in experiments/spoco and validate that this reproduces the original spoco experiments (try with dsb)
  • clarify the different spoco losses and when to use them
    • ExtendedContrastiveLoss
    • SpocoLoss
    • SPOCOConsistencyLoss

Add support for multichannel compressed tif files

Currently torch-em only supports reading files with tifffile if it can use memmap to read them:

def load_image(image_path):
    if supports_memmap(image_path):
        return tifffile.memmap(image_path, mode='r')
    else:
        # TODO handle multi-channel images
        return imageio.imread(image_path)

but memmap cannot be used if the file is saved with compression (from the tifffile docs):

Memory-mapping requires the image data stored in native byte order,
    without tiling, compression, predictors, etc

The fallback is to use imageio, which reads only 2d single-channel data. Why not use tifffile.imread if the extension is tif, tiff, TIF, or TIFF?
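A minimal sketch of that suggestion, keeping the existing supports_memmap helper (assumed to be in scope):

import os
import imageio
import tifffile

TIFF_EXTENSIONS = (".tif", ".tiff", ".TIF", ".TIFF")

def load_image(image_path):
    # memory-map uncompressed, untiled tiffs for lazy access
    if supports_memmap(image_path):
        return tifffile.memmap(image_path, mode="r")
    # tifffile.imread also handles compressed and multi-channel tiffs,
    # which the memmap path cannot read
    if os.path.splitext(image_path)[1] in TIFF_EXTENSIONS:
        return tifffile.imread(image_path)
    # fall back to imageio for non-tiff files
    return imageio.imread(image_path)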

Are you open to contributions?

MitoEM rat volume download doesn't work

https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/mitoem.py#L17 (installation for the rat volume doesn't work)

from the data loader:

requests.exceptions.HTTPError: 404 Client Error: Not Found for url: https://www.dropbox.com/s/dl/kobmxbrabdfkx7y/EM30-R-im.zip    

from wget:

(base) glogin9:/scratch/usr/nimanwai $ wget https://www.dropbox.com/s/dl/kobmxbrabdfkx7y/EM30-R-im.zip                                           
--2024-01-19 17:36:35--  https://www.dropbox.com/s/dl/kobmxbrabdfkx7y/EM30-R-im.zip                                                              
Resolving www.dropbox.com (www.dropbox.com)... 162.125.66.18
Connecting to www.dropbox.com (www.dropbox.com)|162.125.66.18|:443... connected.                                                                 
HTTP request sent, awaiting response... 404 Not Found
2024-01-19 17:36:35 ERROR 404: Not Found.

Make installation easier

  • Bring affogato to conda-forge, so that it can be installed on all platforms and current python versions
  • Add elf and affogato to the env file (solving this env currently takes forever; I blame vigra, we should eliminate it from elf...)

Update google drive downloads

environment_cpu.yaml does not exist (perhaps a typo or a missing channel)

The install instructions as per readme never worked for me:

mamba create -f <ENV>.yaml -n <ENV_NAME>

results in

Looking for: ['environment_cpu.yaml']
(...)
The following package could not be installed
└─ environment_cpu.yaml   does not exist (perhaps a typo or a missing channel

The solution to this is to use the update command instead:
mamba env update -n <your-env> --file environment_cpu.yaml

Source: mamba-org/mamba#633

In case somebody else is facing the same issue.

Support for optional internal padding for SegmentationDataset

I think it would be nice to support optional internal padding (I roughly remember that ImageCollectionDataset has this feature): if a user wants to use SegmentationDataset with a patch shape of, say, (1, 512, 512), but the sample only has shape (1, 436, 456), then the dataset zero-pads the remaining region to return the desired patch shape.
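A rough sketch of the requested behavior; pad_to_patch_shape is a hypothetical helper, not existing torch-em API:

import numpy as np

def pad_to_patch_shape(patch, patch_shape):
    # zero-pad at the end of every axis that is smaller than the target shape
    pad_width = [(0, max(target - size, 0)) for size, target in zip(patch.shape, patch_shape)]
    return np.pad(patch, pad_width, mode="constant")

sample = np.zeros((1, 436, 456))
assert pad_to_patch_shape(sample, (1, 512, 512)).shape == (1, 512, 512)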

Meta issue: notes from modelzoo sprint

  • Issues with torchscript export
    • check if adding the torch.jit.scripted_method makes the model scriptable
    • if not, make an issue here and in pytorch forums
  • upload / update via zenodo API? make issue in bioimage.io
  • tiktorch preprocessing issues: data-type and per image preprocessing (make issue)
  • bring up napari infra grant again
  • report issue with running ilastik remote classification (have the screenshots!)

UNETR with SAM initialization is not working yet

I created https://github.com/constantinpape/torch-em/blob/main/experiments/vision-transformer/unetr/initialize_with_sam.py to check it. But it fails when trying to feed a tensor into it.

cc @anwai98

$ python initialize_with_sam.py 
Traceback (most recent call last):
  File "/home/pape/Work/my_projects/torch-em/experiments/vision-transformer/unetr/initialize_with_sam.py", line 8, in <module>
    y = model(x)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/Work/my_projects/torch-em/torch_em/model/unetr.py", line 209, in forward
    z12, from_encoder = self.encoder(x)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/Work/my_projects/torch-em/torch_em/model/unetr.py", line 54, in forward
    x = blk(x)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/Work/my_projects/SegmentAnythingModel/segment_anything/modeling/image_encoder.py", line 174, in forward
    x = self.attn(x)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/Work/my_projects/SegmentAnythingModel/segment_anything/modeling/image_encoder.py", line 227, in forward
    qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x768 and 64x27)

Pin pytorch versions in modelzoo export

Pytorch weights are not always backward compatible (even for versions >1.0):

Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 2. Your PyTorch installation may be too old.

Maybe it's possible to find out this version programmatically, otherwise just pin ">=CURRENT_VERSION;<=2.0".
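A hedged sketch of deriving the lower bound programmatically at export time; the serialized-file format version is internal to pytorch, so pinning on torch.__version__ is only an approximation:

import torch

# e.g. "2.1.0+cu118" -> "2.1.0"; use the exporting environment's version as the lower bound
pytorch_version = torch.__version__.split("+")[0]
pytorch_pin = f"pytorch>={pytorch_version}"
print(pytorch_pin)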

questions about image shape

Hi,
I am now training micro-sam with a custom dataset. The shape of my raw images is HxWx3 and the labels' shape is HxW. I get the following error:

File "/home/jovyan/.conda/micro-sam/lib/python3.10/site-packages/torch_em/data/segmentation_dataset.py", line 61, in __init__
    assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"
AssertionError: (2048, 3), (2048, 2048)

I am confused by this error. Please give me some tips on why it happens and how to solve it.
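The assertion fires because the trailing RGB channel axis makes the raw shape look different from the label shape. A generic, hedged workaround (not confirmed as the intended torch-em solution) is to move the channel axis to the front before building the dataset:

import numpy as np

def to_channel_first(raw):
    # (H, W, 3) -> (3, H, W); most PyTorch pipelines expect channel-first data
    return raw.transpose(2, 0, 1) if raw.ndim == 3 else raw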

Readme example does not work

because the tif images have different shapes.
Need to catch this in the dataset and implement a special case for it.

2D-UNet training notebook doesn't work for Colab

For code in https://colab.research.google.com/github/constantinpape/torch-em/blob/main/experiments/2D-UNet-Training.ipynb

!nvidia-smi

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.

!conda install -c conda-forge torch_em "pyyaml<5.0"

Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.json): done
Solving environment: /
The environment is inconsistent, please check the package plan carefully
The following packages are causing the inconsistency:

  • conda-forge/linux-64::mamba==0.8.0=py37h7f483ca_

failed with initial frozen solve. Retrying with flexible solve.
Solving environment: /
Found conflicts! Looking for incompatible packages.
This can take several minutes. Press CTRL-C to abort.
failed

UnsatisfiableError: The following specifications were found
to be incompatible with the existing python installation in your environment:

Specifications:

  • pyyaml[version='<5.0'] -> python[version='>=3.10,<3.11.0a0|>=3.9,<3.10.0a0']

Your python: python=3.7

If python is on the left-most side of the chain, that's the version you've asked for.
When python appears to the right, that indicates that the thing on the left is somehow
not available for the python version you are constrained to. Note that conda will not
change your python version to a different minor version unless you explicitly specify
that.

The following specifications were found to be incompatible with each other:

Output in format: Requested package -> Available versions

Package setuptools conflicts for:
python=3.7 -> pip -> setuptools
torch_em -> pytorch -> setuptools[version='<59.6|>=41.0.0']

Package pypy3.6 conflicts for:
torch_em -> python[version='>=3.6'] -> pypy3.6[version='7.3.0.*|7.3.1.*|7.3.2.*|7.3.3.*']
pyyaml[version='<5.0'] -> python[version='>=3.6,<3.7.0a0'] -> pypy3.6[version='7.3.0.*|7.3.1.*|7.3.2.*|7.3.3.*|>=7.3.3|>=7.3.2|>=7.3.1']

Package _libgcc_mutex conflicts for:
python=3.7 -> libgcc-ng[version='>=9.4.0'] -> _libgcc_mutex[version='*|0.1',build='conda_forge|main']
pyyaml[version='<5.0'] -> libgcc-ng[version='>=7.3.0'] -> _libgcc_mutex[version='*|0.1',build='conda_forge|main']

The following specifications were found to be incompatible with your system:

  • feature:/linux-64::__glibc==2.27=0
  • torch_em -> pytorch -> __cuda
  • torch_em -> pytorch -> __glibc[version='>=2.17|>=2.17,<3.0.a0']

Your installed version is: 2.27

Platynereis cilia val split

I just checked that platy-cilia has a val split provided in the dataset itself (3 train volumes and 2 val volumes).

TODO:

  • Make use of the val split explicitly, if it's meant to be used (@constantinpape ?)
    • Pass an argument which gets the val split and train split (on top of RoIs, if needed)

Enable UNETR for dynamic input shape

Currently our UNETR implementation has a fixed input shape, see https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py#L64.

This is due to the fixed input shape of the underlying ViT implementation (either TIMM/MAE or SAM). However, this is only because of the fixed positional encoding size. Otherwise the transformer could process sequences of arbitrary length (and consequently images of dynamic shape, as long as they are divisible by the patch shape).

It would be nice to update this so that arbitrary input shapes are supported. But this is currently not a priority. cc @anwai98.
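For reference, a common way to lift this restriction is to interpolate the learned positional embeddings to the new patch-grid size at run time. This is a sketch of the general technique, assuming a SAM-style (1, H, W, C) embedding grid, not of torch-em code:

import torch.nn.functional as F

def resize_pos_embed(pos_embed, new_grid_hw):
    # pos_embed: (1, H, W, C) learned positional embedding grid
    as_channels = pos_embed.permute(0, 3, 1, 2)  # (1, C, H, W) for interpolation
    resized = F.interpolate(as_channels, size=new_grid_hw, mode="bilinear", align_corners=False)
    return resized.permute(0, 2, 3, 1)  # back to (1, H', W', C)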

Sample code from Readme ends with error

After trying to run the sample code on my machine, it ends with an error.

  File "/home/czaki/Projekty/entropy-train/sample_run.py", line 53, in <module>
    export_bioimageio_model("./checkpoints/dsb-boundary-model", "./bioimageio-model", test_im)
  File "/home/czaki/mambaforge/envs/entropy_train/lib/python3.10/site-packages/torch_em/util/modelzoo.py", line 500, in export_bioimageio_model
    test_in_paths, test_out_paths = _write_data(input_data, model, trainer, export_folder)
  File "/home/czaki/mambaforge/envs/entropy_train/lib/python3.10/site-packages/torch_em/util/modelzoo.py", line 161, in _write_data
    test_outputs = model(*test_tensors)
  File "/home/czaki/mambaforge/envs/entropy_train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/czaki/mambaforge/envs/entropy_train/lib/python3.10/site-packages/torch_em/model/unet.py", line 214, in forward
    self._check_shape(x)
  File "/home/czaki/mambaforge/envs/entropy_train/lib/python3.10/site-packages/torch_em/model/unet.py", line 208, in _check_shape
    raise ValueError(msg)
ValueError: Invalid shape for U-Net: (520, 696) is not divisible by [16, 16]

After checking the code, it looks like the glob order on my machine is different from the author's.

I added the following code to get it working:

size_dim = min(test_im.shape[0]//16*16, test_im.shape[1]//16*16)
test_im = test_im[:size_dim, :size_dim]

But maybe there is a better solution?
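One small variation on the snippet above (not an official fix): crop each axis independently so the image is not forced to be square:

# crop each axis down to the nearest multiple of 16 independently
h = test_im.shape[0] // 16 * 16
w = test_im.shape[1] // 16 * 16
test_im = test_im[:h, :w]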

Error in _get_kwarg function when handling lists as input

Hi,
firstly, I want to thank you for this repo! I've been using the 2D-UNet notebook, and I encountered an issue when attempting to add a description, tags, authors, etc. of the model as input. I think that to enable both single and double quotes, the line of code should be modified to something like this?
val = val.replace(" ' ", ' " ')

here is the line of code.

val = val.replace(""", """) # enable single quotes

Multiple class output

If I read the code correctly, currently only object segmentation is supported and there is no support for assigning objects to one of several predefined classes?

If I'm wrong, could you point me to how to do this? Or maybe it is possible to somehow modify the code to achieve this?
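A hedged sketch of how multi-class (semantic) training could look with the pieces shown in the README; the one-hot label transform here is hypothetical, not existing torch-em API:

import numpy as np
from torch_em.model import UNet2d

n_classes = 3
# one output channel per semantic class
model = UNet2d(in_channels=1, out_channels=n_classes)

def one_hot_label_transform(labels):
    # labels: 2d integer array with class ids in [0, n_classes)
    return np.stack([(labels == class_id).astype("float32") for class_id in range(n_classes)])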

Issue when running deform:compress augmentation

  File "torch_em/transform/defect.py", line 184, in deform_slice
    raw = self.compress_slice(raw)
  File "torch_em/transform/defect.py", line 136, in compress_slice
    assert len(np.unique(components)) == 3, "%i" % len(np.unique(components))
AssertionError: 2

Need to add a test for all the augmentation modes! (And then fix this and reactivate the compress augmentation in cremi)

Request for DistributedDataParallel

To allow for the training of larger models split across GPUs using the DefaultTrainer class, implement model-parallel capabilities using the torch.nn.parallel.DistributedDataParallel functionality.
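For reference, a minimal sketch of the standard PyTorch DDP setup (generic PyTorch usage, not DefaultTrainer integration):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch_em.model import UNet2d

# assumes launch via torchrun, which sets LOCAL_RANK and the rendezvous variables
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = UNet2d(in_channels=1, out_channels=2).to(local_rank)
model = DDP(model, device_ids=[local_rank])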

Save running metric for each checkpoint

It would be helpful to save the current metric (stored as current_metric) here for:
a. the latest checkpoint
b. the save_every_kth_epoch checkpoints

(It is potentially helpful for various experiments to check the metric at certain checkpoints, which may or may not be the best one.)

request for more examples

Thank you for the effort of building this repo.
I'm wondering if it is possible to add more experiment examples, such as the neural segmentation scripts written in outdated Python 2.7, for example the neural network training for mutex-watershed and multicut.
It would not only provide more context for users of torch-em, but also help maintain the previous algorithms' utility.
Any kind of feedback from you would be appreciated. :)

Error messages for ROIs

Description

When using ROIs in the torch_em dataloader (I used torch_em.default_segmentation_loader), there is a generic error message if the ROIs do not fit (e.g. they are empty or of incorrect size).
Error message:

File "/home/freckmann15/miniforge3/envs/sam/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/freckmann15/miniforge3/envs/sam/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1319, in _next_data
    raise StopIteration
StopIteration

This message is quite difficult to interpret, and therefore the problem is difficult to fix.
The error can be avoided by enforcing a minimum shape for the ROIs, so that every ROI is at least as large as the patch_shape used.
I would suggest a more specific error message for this case.
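A sketch of what such a check could look like, assuming the ROI is given as a tuple of slices with explicit start and stop values (hypothetical helper, not existing torch-em API):

def check_roi(roi, patch_shape):
    # compare the extent of each ROI axis against the requested patch shape
    roi_shape = tuple(s.stop - s.start for s in roi)
    if any(rs < ps for rs, ps in zip(roi_shape, patch_shape)):
        raise ValueError(
            f"ROI shape {roi_shape} is smaller than the patch shape {patch_shape}; "
            "each ROI axis must be at least as large as the corresponding patch axis."
        )

check_roi((slice(0, 128), slice(0, 128)), patch_shape=(256, 256))  # raises ValueError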

Can't Install torch'em

Hi,

I installed torch_em about three weeks ago easily using this command:

conda install -c conda-forge torch_em

but now I can't install it; it gets stuck in the installation process.


Can you help me with it?
