segmentation_models_pytorch_3d's Introduction

Segmentation Models PyTorch 3D

Python library with Neural Networks for Volume (3D) Segmentation based on PyTorch.

This library is based on the well-known Segmentation Models PyTorch library for images. Most of that library's documentation applies here directly.

Installation

  • Type 1: pip install segmentation_models_pytorch_3d
  • Type 2: Copy the segmentation_models_pytorch_3d folder from this repository into your project folder.

Quick start

A segmentation model is just a PyTorch nn.Module, which can be created as easily as:

import segmentation_models_pytorch_3d as smp
import torch

model = smp.Unet(
    encoder_name="efficientnet-b0", # choose encoder, e.g. resnet34
    in_channels=1,                  # model input channels (1 for gray-scale volumes, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)

# Shape of input (B, C, H, W, D). B - batch size, C - channels, H - height, W - width, D - depth
res = model(torch.randn(4, 1, 64, 64, 64)) 

Models

Architectures

Encoders

The following is a list of the encoders supported in SMP, grouped by family. Each table lists the specific encoders in that family and their pre-trained weights (the encoder_name and encoder_weights parameters).

ResNet

Encoder    Weights                Params
resnet18   imagenet / ssl / swsl  11M
resnet34   imagenet               21M
resnet50   imagenet / ssl / swsl  23M
resnet101  imagenet               42M
resnet152  imagenet               58M

ResNeXt

Encoder            Weights                            Params
resnext50_32x4d    imagenet / ssl / swsl              22M
resnext101_32x4d   ssl / swsl                         42M
resnext101_32x8d   imagenet / instagram / ssl / swsl  86M
resnext101_32x16d  instagram / ssl / swsl             191M
resnext101_32x32d  instagram                          466M
resnext101_32x48d  instagram                          826M

SE-Net

Encoder              Weights   Params
senet154             imagenet  113M
se_resnet50          imagenet  26M
se_resnet101         imagenet  47M
se_resnet152         imagenet  64M
se_resnext50_32x4d   imagenet  25M
se_resnext101_32x4d  imagenet  46M

DenseNet

Encoder      Weights   Params
densenet121  imagenet  6M
densenet169  imagenet  12M
densenet201  imagenet  18M
densenet161  imagenet  26M

EfficientNet

Encoder          Weights   Params
efficientnet-b0  imagenet  4M
efficientnet-b1  imagenet  6M
efficientnet-b2  imagenet  7M
efficientnet-b3  imagenet  10M
efficientnet-b4  imagenet  17M
efficientnet-b5  imagenet  28M
efficientnet-b6  imagenet  40M
efficientnet-b7  imagenet  63M

DPN

Encoder  Weights      Params
dpn68    imagenet     11M
dpn68b   imagenet+5k  11M
dpn92    imagenet+5k  34M
dpn98    imagenet     58M
dpn107   imagenet+5k  84M
dpn131   imagenet     76M

VGG

Encoder   Weights   Params
vgg11     imagenet  9M
vgg11_bn  imagenet  9M
vgg13     imagenet  9M
vgg13_bn  imagenet  9M
vgg16     imagenet  14M
vgg16_bn  imagenet  14M
vgg19     imagenet  20M
vgg19_bn  imagenet  20M

Mix Vision Transformer

Backbone from SegFormer, pretrained on ImageNet. It can be used with other decoders from the package: you can combine Mix Vision Transformer with Unet, FPN, and others.

Limitations:

  • the encoder is not supported by Linknet or Unet++
  • the encoder is supported by FPN only for encoder depth = 5

Encoder  Weights   Params
mit_b0   imagenet  3M
mit_b1   imagenet  13M
mit_b2   imagenet  24M
mit_b3   imagenet  44M
mit_b4   imagenet  60M
mit_b5   imagenet  81M

MobileOne

Apple's "sub-one-ms" backbone, pretrained on ImageNet. It can be used with all decoders.

Note: In the official GitHub repo the s0 variant has additional num_conv_branches, which gives it more parameters than s1.

Encoder       Weights   Params
mobileone_s0  imagenet  4.6M
mobileone_s1  imagenet  4.0M
mobileone_s2  imagenet  6.5M
mobileone_s3  imagenet  8.8M
mobileone_s4  imagenet  13.6M

Notes for 3D version

Input size

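The recommended input size for a backbone can be calculated as K = pow(N, 2/3), where N is the input image size for the same model in its 2D variant. Equivalently, K^3 = N^2, so a K x K x K volume holds the same number of voxels as an N x N image has pixels.

For example, for N = 224, K ≈ 37, rounded to 32 in practice. For N = 512, K = 64.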
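The rule above can be sketched in a few lines of plain Python. This is only an illustration: the helper name recommended_3d_size is hypothetical and not part of the library, and rounding down to a multiple of 32 is an assumption chosen because it matches the two examples given (five stride-2 stages divide each dimension by 32).

```python
def recommended_3d_size(n_2d: int, multiple: int = 32) -> int:
    """Return K = N^(2/3), rounded down to a multiple of `multiple`.

    The rounding step is an assumption of this sketch, not library behaviour.
    """
    # Round first so that floating-point noise (e.g. 63.999...) does not
    # push an exact result below the next multiple.
    k = round(n_2d ** (2 / 3))
    return max(multiple, (k // multiple) * multiple)


sizes = {n: recommended_3d_size(n) for n in (224, 512)}
print(sizes)  # {224: 32, 512: 64}
```

Passing multiple=1 returns the unrounded integer value of K, which is 37 for N = 224.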

Strides

The typical stride in the 2D case is 2 for both H and W. It is applied depth times (5 in almost all cases), so an input image is reduced from (224, 224) to (7, 7) at the final layers. In the 3D case the input is very large, so it is sometimes useful to control the stride of each dimension independently. For this, use the strides argument, whose default value is strides=((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)). Example:

Say your input data has size (224, 128, 12). You can use strides such as ((2, 2, 2), (4, 2, 1), (2, 2, 2), (2, 2, 1), (1, 2, 3)). With these strides the feature map at the final encoder stage has shape (7, 4, 1).

import segmentation_models_pytorch_3d as smp
import torch

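model = smp.Unet(
    encoder_name="resnet50",
    in_channels=1,
    # one (H, W, D) stride per encoder stage; total reduction is (32, 32, 12)
    strides=((2, 2, 2), (4, 2, 1), (2, 2, 2), (2, 2, 1), (1, 2, 3)),
    classes=3,
)

# input shape is (B, C, H, W, D)
res = model(torch.randn(4, 1, 224, 128, 12))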

Note: Custom strides are currently supported only by resnet-family and densenet encoders with the Unet decoder.
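The shape arithmetic in this section can be checked without building a model. The sketch below is plain Python and does not call the library; the helper name deepest_feature_shape is hypothetical. It assumes every dimension is exactly divisible by the product of its strides, which holds for both examples here; real convolution layers with padding may treat other sizes differently.

```python
from math import prod


def deepest_feature_shape(input_shape, strides):
    """Divide each spatial dimension by the product of its per-stage strides."""
    shape = []
    for axis, size in enumerate(input_shape):
        total = prod(stage[axis] for stage in strides)
        if size % total:
            raise ValueError(
                f"dimension {axis} of size {size} is not divisible by {total}"
            )
        shape.append(size // total)
    return tuple(shape)


default_strides = ((2, 2, 2),) * 5
custom_strides = ((2, 2, 2), (4, 2, 1), (2, 2, 2), (2, 2, 1), (1, 2, 3))

print(deepest_feature_shape((224, 224, 224), default_strides))  # (7, 7, 7)
print(deepest_feature_shape((224, 128, 12), custom_strides))    # (7, 4, 1)
```

A check like this is a cheap way to catch a stride tuple that would shrink a thin dimension, such as a depth of 12, below 1 before any memory is allocated.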

Related repositories

Citation

If you find this code useful, please cite it as:

@article{solovyev20223d,
  title={3D convolutional neural networks for stalled brain capillary detection},
  author={Solovyev, Roman and Kalinin, Alexandr A and Gabruseva, Tatiana},
  journal={Computers in Biology and Medicine},
  volume={141},
  pages={105089},
  year={2022},
  publisher={Elsevier},
  doi={10.1016/j.compbiomed.2021.105089}
}

To Do List

  • Support for strides for all encoders
  • Add timm_ models

segmentation_models_pytorch_3d's Issues

find module path for python by modifying to relative module path

Hi, I just installed your package via Type 2 (copying the segmentation_models_pytorch_3d folder from this repository into my project folder).

I discovered that certain files, particularly those in the model and decoder code, import from segmentation_models_pytorch_3d as a top-level module, so they expect it to sit at the root of the import path. This made it difficult for Python to locate the module when I copied the segmentation_models_pytorch_3d folder to my project directory. To resolve this, I changed the imports to relative imports (as attached). Is that a good choice, or is there a way to make Python find the files without modifying the source code?

[Attached screenshot: 2024-03-14 at 17 22 38]

Here is the list of files where I found the same issue (there may be more):

modified:   segmentation_models_pytorch_3d/decoders/deeplabv3/model.py
modified:   segmentation_models_pytorch_3d/decoders/fpn/model.py
modified:   segmentation_models_pytorch_3d/decoders/linknet/decoder.py
modified:   segmentation_models_pytorch_3d/decoders/linknet/model.py
modified:   segmentation_models_pytorch_3d/decoders/manet/decoder.py
modified:   segmentation_models_pytorch_3d/decoders/manet/model.py
modified:   segmentation_models_pytorch_3d/decoders/pan/model.py
modified:   segmentation_models_pytorch_3d/decoders/pspnet/decoder.py
modified:   segmentation_models_pytorch_3d/decoders/pspnet/model.py
modified:   segmentation_models_pytorch_3d/decoders/unet/decoder.py
modified:   segmentation_models_pytorch_3d/decoders/unet/model.py
modified:   segmentation_models_pytorch_3d/decoders/unetplusplus/decoder.py
modified:   segmentation_models_pytorch_3d/decoders/unetplusplus/model.py
modified:   segmentation_models_pytorch_3d/encoders/__init__.py
modified:   segmentation_models_pytorch_3d/encoders/efficientnet.py
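One possible way to avoid editing the source, offered as a sketch and not as the maintainer's recommendation: absolute imports of a package resolve whenever the directory that contains the package folder is on sys.path. The example below demonstrates this with a throwaway stand-in package. The names demo_pkg and third_party are hypothetical; the stand-in's submodule imports a sibling through the package's absolute name, as the files listed above do.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Build a tiny stand-in package inside a temporary "third_party" directory.
root = Path(tempfile.mkdtemp())
vendor_dir = root / "third_party"
pkg_dir = vendor_dir / "demo_pkg"
pkg_dir.mkdir(parents=True)

# Every file uses absolute imports that start with the package name.
(pkg_dir / "__init__.py").write_text("from demo_pkg.model import build\n")
(pkg_dir / "model.py").write_text(
    "from demo_pkg.base import NAME\n\n\ndef build():\n    return NAME\n"
)
(pkg_dir / "base.py").write_text("NAME = 'ok'\n")

# Put the directory that CONTAINS the package (not the package itself)
# on sys.path; the absolute imports then resolve without source changes.
sys.path.insert(0, str(vendor_dir))

demo_pkg = importlib.import_module("demo_pkg")
result = demo_pkg.build()
print(result)  # ok
```

In a real project the same effect usually comes from copying the folder next to the script being run, or from adding its parent directory to the PYTHONPATH environment variable.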

About Licenses

Thanks for releasing a great repository!
We used it for a community competition and it was very easy to use!

It would be great if you could clarify the license to allow commercial use, etc.
