Torch Classification

Torch Classification is a PyTorch-based image classification project that implements the EfficientNetV2 family of models. It covers both training from scratch and transfer learning from pre-trained weights on the CIFAR-100 dataset. It also examines how super-resolving the CIFAR-100 images with BSRGAN and SwinIR affects classification. The project was developed as part of a Machine Learning course at NUST, with an emphasis on practical applications of deep learning.

Installation

To get started with this project, follow the steps below:

  • Clone the repository to your local machine and change into its directory:

    git clone https://github.com/muhd-umer/torch-classification.git
    cd torch-classification
  • It is recommended to create a new virtual environment so that package upgrades or downgrades do not break other projects. To create the environment with conda, run:

    conda env create -f environment.yml
  • Alternatively, you can use the mamba package manager, which is faster than conda. The following commands install Miniforge and then create the environment with mamba:

    wget -O miniforge.sh \
         "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
    bash miniforge.sh -b -p "${HOME}/conda"
    
    source "${HOME}/conda/etc/profile.d/conda.sh"
    
    # For mamba support also run the following command
    source "${HOME}/conda/etc/profile.d/mamba.sh"
    
    conda activate
    mamba env create -f environment.yml
  • Activate the newly created environment:

    conda activate torch-classification
  • Install the PyTorch Ecosystem:

    # pip will take care of necessary CUDA packages
    pip3 install torch torchvision torchaudio
    
    # additional packages (already included in environment.yml)
    pip3 install einops python-box timm torchinfo \
                 lightning rich wandb rawpy
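After installation, a quick sanity check is to confirm that every package listed above can be imported. The stdlib-only sketch below is not part of the repository; it assumes the import names shown in `REQUIRED` (for example, python-box is imported as `box`).

```python
import importlib.util

# Assumed import names for the packages installed above
# (pip "python-box" -> import "box"; the others match their pip names).
REQUIRED = [
    "torch", "torchvision", "torchaudio", "einops", "box", "timm",
    "torchinfo", "lightning", "rich", "wandb", "rawpy",
]


def missing_packages(names):
    """Return the top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        import torch  # safe: we know it is installed at this point

        print("All packages found; CUDA available:", torch.cuda.is_available())
```

Running it inside the activated `torch-classification` environment should print an empty "missing" list.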

Dataset

The CIFAR-100 dataset is used for training and testing the model. It can be downloaded from the official CIFAR page (https://www.cs.toronto.edu/~kriz/cifar.html).

Alternatively, use the following commands to download it:

# download as python pickle
mkdir -p data && cd data
curl -O https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
tar -xvzf cifar-100-python.tar.gz
cd ..

# download in ImageFolder format (one sub-directory per class, as in ImageNet)
pip3 install cifar2png
cifar2png cifar100 data/cifar100
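The pickle download unpacks to `data/cifar-100-python/` with `train`, `test`, and `meta` files. Each row of the `data` array stores one 32x32 image as 3072 bytes: the 1024 red values, then green, then blue, each in row-major order. The sketch below shows one way to load a split into a channels-last array with NumPy; it is an illustration, not the loader used by `train.py`, and the function names are hypothetical.

```python
import pickle

import numpy as np


def load_cifar100_split(path):
    """Load one CIFAR-100 pickle file ('train' or 'test').

    Returns (images, fine_labels, coarse_labels); images has shape
    (N, 32, 32, 3) and dtype uint8 (channels-last).
    """
    with open(path, "rb") as f:
        # The files were pickled under Python 2, so keys come back as bytes.
        batch = pickle.load(f, encoding="bytes")
    # 3072 bytes per row: R plane, G plane, B plane, each 32x32 row-major.
    data = np.asarray(batch[b"data"], dtype=np.uint8)
    images = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    fine = np.asarray(batch[b"fine_labels"], dtype=np.int64)
    coarse = np.asarray(batch[b"coarse_labels"], dtype=np.int64)
    return images, fine, coarse


def load_label_names(meta_path):
    """Return the fine (100) and coarse (20) class names from the 'meta' file."""
    with open(meta_path, "rb") as f:
        meta = pickle.load(f, encoding="bytes")
    fine = [name.decode("utf-8") for name in meta[b"fine_label_names"]]
    coarse = [name.decode("utf-8") for name in meta[b"coarse_label_names"]]
    return fine, coarse
```

For example, `load_cifar100_split("data/cifar-100-python/train")` should return 50,000 images, and the `test` file 10,000.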

We also provide super-resolution variants of the CIFAR-100 dataset, in which the images have been upscaled 4x, from 32x32 to 128x128, using BSRGAN and SwinIR. You can download these datasets from the Weights & Data section or with the following commands:

wget -O data/bsrgan_4x_cifar100.zip \
    "https://github.com/muhd-umer/torch-classification/releases/download/v0.0.1/bsrgan_4x_cifar100.zip"

# unzip the dataset
unzip -q data/bsrgan_4x_cifar100.zip -d data/

# or
wget -O data/swinir_4x_cifar100.zip \
    "https://github.com/muhd-umer/torch-classification/releases/download/v0.0.1/swinir_4x_cifar100.zip"

# unzip the dataset
unzip -q data/swinir_4x_cifar100.zip -d data/
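Whichever ImageFolder-style copy you use, it can help to check the class counts and image sizes before training. The stdlib sketch below assumes the images are PNG files stored one folder per class (as cifar2png produces); the layout inside the super-resolution archives is an assumption, so point it at whatever directory `unzip` creates. The helper names are hypothetical.

```python
import struct
from collections import Counter
from pathlib import Path

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_size(path):
    """Read (width, height) from a PNG file's IHDR chunk without decoding pixels."""
    with open(path, "rb") as f:
        header = f.read(24)
    if len(header) < 24 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError(f"{path} is not a PNG file")
    # IHDR data starts at byte 16: 4-byte big-endian width, then height.
    return struct.unpack(">II", header[16:24])


def summarize_split(split_dir, expected_size=None):
    """Count PNGs per class folder and collect files with an unexpected size.

    Returns (counts, bad_sizes): a Counter of class name -> image count and a
    list of paths whose (width, height) differs from expected_size (if given).
    """
    counts = Counter()
    bad_sizes = []
    for class_dir in sorted(p for p in Path(split_dir).iterdir() if p.is_dir()):
        for image in class_dir.glob("*.png"):
            counts[class_dir.name] += 1
            if expected_size is not None and png_size(image) != expected_size:
                bad_sizes.append(image)
    return counts, bad_sizes
```

For the original CIFAR-100 training split, expect 100 classes with 500 images each at `(32, 32)`; for the super-resolution variants, pass `expected_size=(128, 128)`.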

Usage

To train the model, run one of the following commands:

# train the model from scratch using default config
python3 train.py

# train the model using overrides (all flags are optional)
#   --mode MODE                      train or finetune
#   --data-dir DATA_DIR              directory containing the data
#   --model-dir MODEL_DIR            directory in which to save the model
#   --batch-size BATCH_SIZE          batch size
#   --dataset-type DATASET_TYPE      default or imagefolder
#   --num-workers NUM_WORKERS        number of data-loader workers
#   --num-epochs NUM_EPOCHS          number of epochs
#   --lr LR                          learning rate
#   --rich-progress                  use the rich progress bar
#   --accelerator ACCELERATOR        type of accelerator
#   --devices DEVICES                number of devices
#   --weights WEIGHTS                path to a weights file
#   --resume                         resume training from a checkpoint
#   --test-only                      only evaluate the model on the test set
#   --logger-backend LOGGER_BACKEND  wandb or tensorboard
python3 train.py --mode MODE \
                 --data-dir DATA_DIR \
                 --model-dir MODEL_DIR \
                 --batch-size BATCH_SIZE \
                 --dataset-type DATASET_TYPE \
                 --num-workers NUM_WORKERS \
                 --num-epochs NUM_EPOCHS \
                 --lr LR \
                 --rich-progress \
                 --accelerator ACCELERATOR \
                 --devices DEVICES \
                 --weights WEIGHTS \
                 --resume \
                 --test-only \
                 --logger-backend LOGGER_BACKEND
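The overrides follow a common pattern: any flag left off the command line falls back to the default config. The sketch below shows one way to implement that pattern with `argparse`. It is not the repository's `train.py`; the plain-dict config and the helper names are hypothetical (the project lists python-box among its dependencies, so its real config may be a `Box`).

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags documented above."""
    parser = argparse.ArgumentParser(description="Train EfficientNetV2 on CIFAR-100")
    parser.add_argument("--mode", choices=["train", "finetune"])
    parser.add_argument("--data-dir")
    parser.add_argument("--model-dir")
    parser.add_argument("--batch-size", type=int)
    parser.add_argument("--dataset-type", choices=["default", "imagefolder"])
    parser.add_argument("--num-workers", type=int)
    parser.add_argument("--num-epochs", type=int)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--accelerator")
    parser.add_argument("--devices", type=int)
    parser.add_argument("--weights")
    parser.add_argument("--logger-backend", choices=["wandb", "tensorboard"])
    # Switches default to None (not False) so "not given" is distinguishable.
    for switch in ("--rich-progress", "--resume", "--test-only"):
        parser.add_argument(switch, action="store_true", default=None)
    return parser


def apply_overrides(config, argv=None):
    """Return a copy of config with every flag given on the command line applied.

    argparse turns "--data-dir" into the key "data_dir"; unset flags parse as
    None and therefore keep the config's value.
    """
    args = vars(build_parser().parse_args(argv))
    merged = dict(config)
    for key, value in args.items():
        if value is not None:
            merged[key] = value
    return merged
```

For example, `apply_overrides(defaults, ["--mode", "finetune", "--lr", "0.01"])` changes only `mode` and `lr` in a copy of `defaults`.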

To evaluate the models, download the appropriate weights from the Weights & Data section and place them in the weights/ directory. Then run one of the following commands:

bash run.sh

# or
python3 train.py --weights WEIGHTS --test-only
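With `--test-only`, the model is evaluated on the CIFAR-100 test set. Classification on CIFAR-100 is commonly reported as top-1 and top-5 accuracy; the NumPy sketch below shows how both are computed from a matrix of logits. Which metrics `train.py` actually logs is not stated here, so treat this as an illustration of the metric, not of the project's code.

```python
import numpy as np


def topk_accuracy(logits, labels, ks=(1, 5)):
    """Fraction of samples whose true label is among the k highest-scoring classes.

    logits: array of shape (N, num_classes); labels: integer array of shape (N,).
    Returns a dict mapping each k to an accuracy in [0, 1].
    """
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    max_k = max(ks)
    # Indices of the max_k largest logits per row, best first.
    top = np.argsort(-logits, axis=1)[:, :max_k]
    hits = top == labels[:, None]
    return {k: float(hits[:, :k].any(axis=1).mean()) for k in ks}
```

Top-5 accuracy is always at least as high as top-1, since it accepts any of the five best guesses.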

Project Structure

The project is structured as follows:

torch-classification
├── data/             # data directory
├── models/           # model directory
├── resources/        # resources directory
├── utils/            # utility directory
├── LICENSE           # license file
├── README.md         # readme file
├── environment.yml   # conda environment file
├── upscale.py        # upscaling script
└── train.py          # training script

Contributing ❤️

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
