Pretrained Backbones with UNet

A PyTorch-based Python library with UNet architecture and multiple backbones for Image Semantic Segmentation.

Overview

This is a simple package for semantic segmentation with UNet and pretrained backbones. It uses timm models as the pre-trained encoders.

When a dataset is relatively small, initializing a model with weights pre-trained on a large dataset often makes training more reliable. Using a state-of-the-art model such as ConvNeXt as the encoder lets you start from strong learned features instead of training the whole network from scratch.

The primary characteristics of this library are as follows:

  • 430 pre-trained backbone networks are available for the UNet semantic segmentation model.

  • Supports popular, state-of-the-art backbone families for the UNet model, including ConvNeXt, ResNet, EfficientNet, DenseNet, RegNet, and VGG.

  • Which layers of the backbone are trainable can be controlled through model parameters.

  • It includes a DataSet class for binary and multi-class semantic segmentation.

  • It comes with a pre-built training class for quick custom training.
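
The idea behind freezing backbone layers can be sketched in plain PyTorch. This is a minimal illustration, not the library's own mechanism: the small nn.Sequential model and the freeze_layers helper are hypothetical stand-ins, and the parameters that Unet actually exposes for this purpose are not shown here.

```python
import torch
from torch import nn

# Hypothetical stand-in for an encoder/decoder model; this is NOT the library's Unet.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # index 0: "encoder" stage
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),  # index 2: "encoder" stage
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=1),             # index 4: "decoder" head
)


def freeze_layers(module, layer_indices):
    """Disable gradients for the child layers at the given indices (hypothetical helper)."""
    for idx in layer_indices:
        for param in module[idx].parameters():
            param.requires_grad = False


# Freeze both "encoder" stages and leave the head trainable.
freeze_layers(model, layer_indices=(0, 2))

# Only parameters that still require gradients are handed to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)
```

Frozen layers keep their pre-trained weights, which tends to help on small datasets where fine-tuning every layer risks overfitting.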

Installation

PyPI version:

pip install pretrained-backbones-unet

Source code version:

pip install git+https://github.com/mberkay0/pretrained-backbones-unet

Usage

import torch
from torch.utils.data import DataLoader

from backbones_unet.model.unet import Unet
from backbones_unet.utils.dataset import SemanticSegmentationDataset
from backbones_unet.model.losses import DiceLoss
from backbones_unet.utils.trainer import Trainer

# create a torch.utils.data.Dataset/DataLoader
train_img_path = 'example_data/train/images' 
train_mask_path = 'example_data/train/masks'

val_img_path = 'example_data/val/images' 
val_mask_path = 'example_data/val/masks'

train_dataset = SemanticSegmentationDataset(train_img_path, train_mask_path)
val_dataset = SemanticSegmentationDataset(val_img_path, val_mask_path)

train_loader = DataLoader(train_dataset, batch_size=2)
val_loader = DataLoader(val_dataset, batch_size=2)

model = Unet(
    backbone='convnext_base', # backbone network name
    in_channels=3,            # input channels (1 for gray-scale images, 3 for RGB, etc.)
    num_classes=1,            # output channels (number of classes in your dataset)
)

# optimize only the parameters that are still trainable
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params, lr=1e-4)

trainer = Trainer(
    model,                    # UNet model with pretrained backbone
    criterion=DiceLoss(),     # loss function to minimize
    optimizer=optimizer,      # optimizer that updates the trainable parameters
    epochs=10                 # number of epochs for model training
)

trainer.fit(train_loader, val_loader)
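
To give a feel for what a Dice criterion measures in the binary setup above, here is a minimal sketch in plain Python. It is not the library's DiceLoss: the function name soft_dice_loss, the eps smoothing term, and the flat-list inputs are assumptions made for illustration, and the real class may differ in smoothing, activation, and per-class handling.

```python
def soft_dice_loss(probs, targets, eps=1e-7):
    """Return 1 - Dice coefficient for two equal-length flat sequences.

    probs holds predicted foreground probabilities in [0, 1];
    targets holds the ground-truth mask values (0 or 1).
    """
    if len(probs) != len(targets):
        raise ValueError("probs and targets must have the same length")
    intersection = sum(p * t for p, t in zip(probs, targets))
    total = sum(probs) + sum(targets)
    dice = (2.0 * intersection + eps) / (total + eps)
    return 1.0 - dice


mask = [1, 0, 1, 0]
perfect = soft_dice_loss([1.0, 0.0, 1.0, 0.0], mask)   # prediction matches the mask
partial = soft_dice_loss([1.0, 1.0, 0.0, 0.0], mask)   # half of the foreground overlaps
disjoint = soft_dice_loss([0.0, 1.0, 0.0, 1.0], mask)  # no overlap at all
```

The loss is close to 0 for a perfect match and close to 1 when prediction and mask do not overlap, so minimizing it pushes the predicted region toward the annotated one.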

Available Pretrained Backbones

import backbones_unet

print(backbones_unet.__available_models__)
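
With several hundred names, it can help to filter the list by family. The sketch below assumes that __available_models__ is an iterable of backbone name strings; to stay self-contained it uses a short hypothetical sample instead of the real list, and find_backbones is an illustrative helper, not part of the package.

```python
from fnmatch import fnmatchcase

# Hypothetical sample standing in for backbones_unet.__available_models__;
# the real list is much longer and its exact contents may differ.
available_models = [
    "convnext_base",
    "convnext_tiny",
    "resnet50",
    "efficientnet_b0",
    "vgg16",
]


def find_backbones(names, pattern):
    """Return the names matching a shell-style wildcard pattern, sorted."""
    return sorted(name for name in names if fnmatchcase(name, pattern))


convnext_models = find_backbones(available_models, "convnext*")
print(convnext_models)
```

With the package installed, the same helper could be called on backbones_unet.__available_models__ directly, provided the assumption about its type holds.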
