
minREV

Inspired by minGPT

A PyTorch reimplementation of the Reversible Vision Transformer architecture that prefers simplicity over tricks, hackability over tedious organization, and interpretability over generality.

It is meant to serve as an educational guide for newcomers who are not familiar with the reversible backpropagation algorithm or the Reversible Vision Transformer.

The entire Reversible Vision Transformer is implemented from scratch in under 300 lines of PyTorch code, including the memory-efficient reversible backpropagation algorithm (under 100 lines). Even the driver code is under 150 lines. The repo supports both memory-efficient training and testing on CIFAR-10.
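The core trick can be sketched in a few dependency-free lines. A reversible block routes two residual streams through a coupling whose inputs can be reconstructed exactly from its outputs, so intermediate activations never need to be stored. Here `F` and `G` are hypothetical scalar stand-ins for the attention and MLP sub-blocks:

```python
# Toy sketch of the two-stream reversible coupling used by RevViT.
# F and G stand in for the attention and MLP sub-blocks; the scalar
# functions below are arbitrary stand-ins to keep the demo dependency-free.

def forward_block(x1, x2, F, G):
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2

def invert_block(y1, y2, F, G):
    # Reconstruct the inputs exactly from the outputs --
    # nothing from the forward pass has to be kept in memory.
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2

F = lambda v: 3.0 * v   # stand-in for attention
G = lambda v: v * v     # stand-in for MLP

x1, x2 = 2.0, 5.0
y1, y2 = forward_block(x1, x2, F, G)
assert invert_block(y1, y2, F, G) == (x1, x2)  # exact recovery
```

During backpropagation, each block re-derives its inputs from its outputs this way, then recomputes the local forward pass to obtain gradients.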

💥 The CVPR 2022 oral talk: a 5-minute introduction to RevViT.

💥 A gentle and in-depth 15-minute introduction to RevViT.

💥 (4/22 Update) We have added an implementation of fast, parallelized reversible backpropagation (paper coming soon)!

Setting Up

Simple! 🌟

(If using conda for the environment; otherwise use pip.)

conda create -n revvit python=3.8
conda activate revvit
conda install pytorch torchvision pytorch-cuda=11.7 -c pytorch -c nvidia

Code Organization

The code organization is also minimal 💫:

  • rev.py defines the reversible vision model and supports:
    • vanilla backpropagation
    • memory-efficient reversible backpropagation
  • fast_rev.py contains a fast, parallelized reversible backpropagation (paper coming soon). Use --pareprop True to enable it.
  • main.py contains the driver code for training on CIFAR-10.

Running CIFAR-10

python main.py --lr 1e-3 --bs 128 --embed_dim 128 --depth 6 --n_head 8 --epochs 100

This will achieve 80%+ validation accuracy on CIFAR-10 with from-scratch training!

Here are the Training/Validation Logs 💯

python main.py --lr 1e-3 --bs 128 --embed_dim 128 --depth 6 --n_head 8 --epochs 100 --vanilla_bp True

This will train the same network, but without memory-efficient backpropagation, to the same accuracy as above. Hence, there is no accuracy drop from memory-efficient reversible backpropagation.
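The no-accuracy-drop claim follows from the fact that both modes compute identical gradients; they differ only in whether activations are stored or reconstructed. This can be checked by hand on a scalar toy coupling (`F`, `G` and their derivatives below are hypothetical stand-ins, with loss = y1 + y2):

```python
# Gradients of loss = y1 + y2 through the coupling
#   y1 = x1 + F(x2),  y2 = x2 + G(y1)
# computed two ways: vanilla (activations stored) vs. reversible
# (inputs reconstructed from outputs). The results are identical.

def F(v):  return 3.0 * v
def dF(v): return 3.0          # derivative of F
def G(v):  return v * v
def dG(v): return 2.0 * v      # derivative of G

def vanilla_grads(x1, x2):
    y1 = x1 + F(x2)                    # activation kept in memory
    dL_dy1 = 1.0 + dG(y1)              # y2 depends on y1 through G
    return dL_dy1, dL_dy1 * dF(x2) + 1.0

def reversible_grads(y1, y2):
    x2 = y2 - G(y1)                    # reconstructed, not stored
    dL_dy1 = 1.0 + dG(y1)              # identical chain rule from here on
    return dL_dy1, dL_dy1 * dF(x2) + 1.0

x1, x2 = 2.0, 5.0
y1 = x1 + F(x2)
y2 = x2 + G(y1)
assert vanilla_grads(x1, x2) == reversible_grads(y1, y2)
```

Since the reconstruction is exact (not an approximation), the two training runs follow the same optimization trajectory up to floating-point rounding.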

Here are the Training/Validation Logs 💯

๐Ÿ‘๏ธ Note: The relatively low accuracy is due to difficulty in training vision transformer (reversible or vanilla) from scratch on small datasets like CIFAR-10. Also likely is that6 a much higher accuracy can be achieved with the same code, using a better chosen model design and optimization parameters. The authors have done no tuning since this repository is meant for understanding code, not pushing performance.

Mixed Precision Training

Mixed precision training is also supported and can be enabled by adding the --amp True flag to the above commands. Training progresses smoothly and achieves 80%+ validation accuracy on CIFAR-10, similar to training without AMP.

๐Ÿ“ Note: Pytorch vanilla AMP, maintains full precision (fp32) on weights and only uses half-precision (fp16) on intermediate activations. Since reversible is already saving up on almost all intermediate activations (see video for examplanation), using AMP (ie half-precision on activations) brings little additional memory savings. For example, on a 16G V100 setup, AMP can improve rev maximum CIFAR-10 batch size from 12000 to 14500 ( ~20%). At usual training batch size (128) there is small gain in GPU training memory (about 4%).

Distributed Data Parallel Training

DDP training with reversible models adds no extra overhead and progresses the same as vanilla training. All results in the paper (also see below) were obtained in DDP setups (>64 GPUs per run). However, implementing distributed training is beyond the scope of this repo; it can instead be found in the PySlowFast distributed training setup.

Running ImageNet, Kinetics-400 and more

For more use cases, such as reproducing numbers from the original paper, see the full code in PySlowFast, which supports

  • ImageNet
  • Kinetics-400/600/700
  • RevViT, all sizes with configs
  • RevMViT, all sizes with configs

at state-of-the-art accuracies.
