Config File is All You Need: An Image Classification Codebase Written in PyTorch

This project aims at providing the necessary building blocks for easily creating an image classification model using PyTorch.

Note: I finished this project in my spare time within a week, so there is still a lot of work to be done.

Highlights

  • Convenient: You can use a config file to create an image classification model and train on your own datasets without writing any code.
  • Extensible: You can write your own modules (Dataset, Transform, Network, Loss and so on) and register them to the default config easily.
  • Parameter-is-Module: You can create a module from a parameter that consists of a module name and an argument list.
  • Multi-GPU training and inference: You can train and test your model on a single GPU or on multiple GPUs in parallel.
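The Extensible and Parameter-is-Module ideas can be illustrated with a minimal sketch. The `Registry` class, the `build_from_param` helper and the `(name, [args])` parameter layout below are hypothetical, modeled loosely on the registry pattern used in maskrcnn-benchmark; the codebase's actual API may differ.

```python
from typing import Any, Callable, List, Tuple


class Registry(dict):
    """Hypothetical registry mapping a module name to the callable that builds it."""

    def register(self, name: str) -> Callable:
        def decorator(fn: Callable) -> Callable:
            if name in self:
                raise KeyError(f"'{name}' is already registered")
            self[name] = fn
            return fn  # return the class unchanged so it stays usable directly
        return decorator


TRANSFORMS = Registry()


@TRANSFORMS.register("Resize")
class Resize:
    def __init__(self, height: int, width: int) -> None:
        self.size = (height, width)


@TRANSFORMS.register("RandomFlip")
class RandomFlip:
    def __init__(self, prob: float = 0.5) -> None:
        self.prob = prob


def build_from_param(registry: Registry, param: Tuple[str, List[Any]]) -> Any:
    """Build a module from a (name, argument list) pair, e.g. ("Resize", [32, 32])."""
    name, args = param
    if name not in registry:
        raise KeyError(f"unknown module '{name}'")
    return registry[name](*args)


# A config entry then only needs to list (name, args) pairs.
pipeline = [
    build_from_param(TRANSFORMS, p)
    for p in [("Resize", [32, 32]), ("RandomFlip", [0.5])]
]
```

With this design, adding a new transform means writing the class and decorating it; no builder code has to change.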

Accuracy

The top-1 accuracies (%) of different models on CIFAR-10 are shown below, with NETWORK_STRIDE set to (2,2,2,2,2) and (1,1,2,2,2), respectively. Refer to DETAILS.md for more details about the NETWORK_STRIDE parameter.

Model                 NETWORK_STRIDE (2,2,2,2,2)   NETWORK_STRIDE (1,1,2,2,2)
ResNet-18             86.10                        92.64
ResNet-34             86.14                        92.73
ResNet-50             86.65                        92.20
ResNet-101            87.41                        93.27
ResNet-152            87.01                        93.25
ResNeXt-50, 32x4d     87.56                        93.65
ResNeXt-101, 32x8d    88.24                        93.75
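One way to read the gap between the two columns is through the total downsampling factor. Assuming, as in a standard ResNet, that the five entries are the strides of the stem convolution, the stem max-pooling layer and the first blocks of the last three stages (DETAILS.md has the authoritative definition), their product tells how far a 32x32 CIFAR-10 image is shrunk:

```python
from functools import reduce
from operator import mul
from typing import Sequence


def final_feature_size(input_size: int, strides: Sequence[int]) -> int:
    """Spatial size of the last feature map, assuming each stride divides
    the size exactly (true for 32x32 inputs with strides of 1 or 2)."""
    total_stride = reduce(mul, strides, 1)
    return input_size // total_stride


for strides in [(2, 2, 2, 2, 2), (1, 1, 2, 2, 2)]:
    size = final_feature_size(32, strides)
    print(strides, "->", f"{size}x{size}")
# (2, 2, 2, 2, 2) -> 1x1
# (1, 1, 2, 2, 2) -> 4x4
```

Under that assumption, the first setting reduces the final feature map to a single pixel, while the second keeps a 4x4 map. This is a plausible reason for the higher accuracies in the second column, although the table by itself does not isolate the cause.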

Installation

pip3 install -r requirements.txt

Inference in a few lines

We provide a helper class to simplify writing inference pipelines with pre-trained models. Run the following code from the demo folder. (The pre-trained models and sample images can be downloaded from here [Baidu(PWD: f25u)][OneDrive]. You can choose any config in ./configs.)

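import sys

# Run from ./demo: make the repository root importable
sys.path.append("../")

import cv2

from predictor import ClsDemo
from pytorch_classification.config import cfg


config = "/path_to_config"  # any YAML file from ./configs
img_path = "/path_to_image"
checkpoint_path = "/path_to_pre-trained_model"

# Load the config file, then override CHECKPOINT with the downloaded weights
cfg.merge_from_file(config)
cfg.merge_from_list(["CHECKPOINT", checkpoint_path])

cls_demo = ClsDemo(cfg)

# cv2.imread returns a BGR array, or None if the image cannot be read
image = cv2.imread(img_path)
if image is None:
    raise FileNotFoundError(img_path)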
pred = cls_demo.run_on_opencv_image(image)

Perform training on CIFAR-10 dataset

You need to download the CIFAR-10 dataset and convert it to the GeneralDataset format required by this codebase. (You can also download an already reformatted copy from here [Baidu(PWD: f25u)][OneDrive].) We recommend symlinking the path to the CIFAR-10 dataset into ./datasets as follows:

# symlink the cifar-10 dataset (create only the parent directory,
# so that datasets/cifar-10 itself becomes the link)
cd pytorch-classification
mkdir -p datasets
ln -s /path_to_cifar-10_dataset datasets/cifar-10

You can also configure your own dataset paths. To do so, modify ./pytorch_classification/config/data_catalog.py to point to the location where your dataset is stored (see DETAILS.md for more details).
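The exact layout of data_catalog.py is documented in DETAILS.md. As a rough, hypothetical sketch modeled on maskrcnn-benchmark's DatasetCatalog (which this codebase acknowledges as an influence), a catalog is typically a mapping from a dataset name to a path relative to a data root. The class name, the dataset keys, the `train`/`test` sub-directories and the `get` method below are all assumptions:

```python
import os
from typing import Dict


class DatasetCatalog:
    # Hypothetical layout: a root directory plus per-dataset relative paths.
    DATA_DIR = "datasets"
    DATASETS: Dict[str, Dict[str, str]] = {
        "cifar10_train": {"root": "cifar-10/train"},
        "cifar10_test": {"root": "cifar-10/test"},
        # Add your own dataset here, e.g.
        # "my_dataset_train": {"root": "my_dataset/train"},
    }

    @staticmethod
    def get(name: str) -> Dict[str, str]:
        """Return the dataset entry with its path resolved against DATA_DIR."""
        if name not in DatasetCatalog.DATASETS:
            raise KeyError(f"Dataset '{name}' is not in the catalog")
        attrs = DatasetCatalog.DATASETS[name]
        return {"root": os.path.join(DatasetCatalog.DATA_DIR, attrs["root"])}
```

Keeping paths relative to one root is what makes the symlink approach work: only the link under ./datasets has to change when the data moves.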

Single GPU training

You can run the following without modifications to train your model on a single GPU.

python3 tools/train.py --config-file "configs/config_cifar10_R50_1gpu.yaml"

Multi-GPU training

Internally, we use torch.distributed.launch to start multi-GPU training. This PyTorch utility spawns as many Python processes as the number of GPUs we want to use, and each Python process uses only a single GPU.

export NGPUS=8
python3 -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train.py --config-file "configs/config_cifar10_R50_8gpu.yaml"
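Each spawned process has to find out which GPU it owns. A common pattern, shown here as a sketch rather than as the contents of tools/train.py, is to read the local rank provided by the launcher: depending on the PyTorch version and flags, torch.distributed.launch passes it as a `--local_rank` or `--local-rank` command-line argument, or through the `LOCAL_RANK` environment variable.

```python
import argparse
import os
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Training entry point (sketch)")
    parser.add_argument("--config-file", default="", help="path to a YAML config")
    # Accept both spellings used by different PyTorch versions; fall back to
    # the LOCAL_RANK environment variable, or 0 for a plain single-GPU run.
    parser.add_argument(
        "--local_rank", "--local-rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    return parser.parse_args(argv)


def world_size() -> int:
    # Total number of processes started by the launcher; 1 without it.
    return int(os.environ.get("WORLD_SIZE", 1))


# A real training script would then typically call
#   torch.cuda.set_device(args.local_rank)
#   torch.distributed.init_process_group(backend="nccl", init_method="env://")
# so that every process drives exactly one GPU.
```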

If you want to train your model on more GPUs, you should adjust the batch size (SOLVER.BATCH_SIZE) and the learning rate (SOLVER.BASE_LR) accordingly.
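A common heuristic for this adjustment is the linear scaling rule: keep the learning rate proportional to the total batch size, often combined with a short learning-rate warmup for large batches. This README does not say whether SOLVER.BATCH_SIZE is per GPU or summed over all GPUs, so the helper below takes the total batch size explicitly; the reference values in the example are placeholders, not the settings of the shipped configs.

```python
def scale_lr(base_lr: float, base_batch_size: int, new_batch_size: int) -> float:
    """Linear scaling rule: the learning rate grows with the total batch size."""
    if base_batch_size <= 0 or new_batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    return base_lr * new_batch_size / base_batch_size


# Hypothetical reference setting: lr 0.1 at a total batch size of 128.
# Moving to 16 GPUs with 64 images each (total 1024) gives lr 0.8.
new_lr = scale_lr(0.1, 128, 16 * 64)
```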

Evaluation

You can test your model directly on single or multiple GPUs. Here is an example for multi-GPU testing:

export NGPUS=8
python3 -m torch.distributed.launch --nproc_per_node=$NGPUS tools/test.py --config-file "configs/config_cifar10_R50_8gpu.yaml"
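The numbers in the accuracy table are top-1 accuracies, i.e. the percentage of images whose highest-scoring class equals the ground-truth label. The plain-Python sketch below only illustrates the metric and is independent of how tools/test.py actually computes it. When evaluating on several GPUs, the correct counts and sample counts from all processes have to be summed before dividing.

```python
from typing import Sequence


def top1_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Percentage of samples whose arg-max class matches the label."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    correct = 0
    for row, label in zip(scores, labels):
        # Index of the highest score; ties resolve to the first maximum.
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == label)
    return 100.0 * correct / len(labels)


scores = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.3, 0.3, 0.4], [0.2, 0.5, 0.3]]
labels = [1, 0, 2, 0]
print(top1_accuracy(scores, labels))  # 75.0
```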

Details

You can refer to DETAILS.md for more details.

License

This project is released under the MIT license. See LICENSE for additional details.

Acknowledgement

This codebase is heavily influenced by the project maskrcnn-benchmark.
