
efficient_densenet_pytorch

A PyTorch >=1.0 implementation of DenseNets, optimized to save GPU memory.

Recent updates

  1. Now works on PyTorch 1.0! It uses the checkpointing feature, which makes this code WAY more efficient!!!

Motivation

While DenseNets are fairly easy to implement in deep learning frameworks, most implementations (such as the original) tend to be memory-hungry. In particular, the number of intermediate feature maps generated by batch normalization and concatenation operations grows quadratically with network depth. It is worth emphasizing that this is not a property inherent to DenseNets, but rather to the implementation.
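To make the quadratic growth concrete, here is a rough count for a single dense block. It is only a sketch: the function names are made up for illustration, and it counts channels rather than bytes, assuming every layer adds `growth_rate` channels and that a naive implementation keeps each layer's concatenated input around.

```python
def naive_stored_channels(num_layers, growth_rate, init_channels):
    """Total channels across the concatenated inputs of every layer in one block."""
    # Layer i (0-based) receives the block input plus the outputs of the
    # i preceding layers, i.e. init_channels + i * growth_rate channels.
    return sum(init_channels + i * growth_rate for i in range(num_layers))


def new_feature_channels(num_layers, growth_rate, init_channels):
    """Channels that are genuinely new: the block input plus one output per layer."""
    return init_channels + num_layers * growth_rate


if __name__ == "__main__":
    # Hypothetical block sizes; growth rate 12 and 24 input channels are assumptions.
    for num_layers in (8, 16, 32):
        naive = naive_stored_channels(num_layers, 12, 24)
        new = new_feature_channels(num_layers, 12, 24)
        print(f"layers={num_layers:2d}  concatenated={naive:5d}  new={new:4d}")
```

Doubling the number of layers roughly quadruples the first count but only doubles the second, which is why the quadratic cost belongs to the stored copies and not to the architecture.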

This implementation uses a new strategy to reduce the memory consumption of DenseNets. We use checkpointing to compute the Batch Norm and concatenation feature maps. These intermediate feature maps are discarded during the forward pass and recomputed for the backward pass. This adds 15-20% of time overhead for training, but reduces feature map memory consumption from quadratic to linear.
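The idea can be sketched with `torch.utils.checkpoint`. This is a minimal illustration, not the code in this repository: the class name `CheckpointedDenseLayer` and its layer layout are hypothetical, and it assumes a PyTorch version recent enough for `checkpoint` to accept the `use_reentrant` argument.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class CheckpointedDenseLayer(nn.Module):
    """Hypothetical dense layer: concat -> BN -> ReLU -> 1x1 conv."""

    def __init__(self, in_channels, out_channels, efficient=True):
        super().__init__()
        self.norm = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.efficient = efficient

    def _bottleneck(self, *prev_features):
        # The concatenated tensor and the BN/ReLU outputs created here are the
        # intermediate maps that checkpointing avoids keeping after the forward pass.
        concatenated = torch.cat(prev_features, dim=1)
        return self.conv(self.relu(self.norm(concatenated)))

    def forward(self, *prev_features):
        if self.efficient and any(f.requires_grad for f in prev_features):
            # Only the inputs and the output are stored; _bottleneck is re-run
            # during the backward pass to rebuild the intermediates.
            return cp.checkpoint(self._bottleneck, *prev_features, use_reentrant=False)
        return self._bottleneck(*prev_features)


if __name__ == "__main__":
    layer = CheckpointedDenseLayer(in_channels=8, out_channels=4)
    a = torch.randn(2, 3, 5, 5, requires_grad=True)
    b = torch.randn(2, 5, 5, 5, requires_grad=True)
    out = layer(a, b)
    out.sum().backward()  # triggers the recomputation of _bottleneck
    print(out.shape)
```

One caveat of this sketch: because the function runs twice per training step, the batch norm running statistics are also updated twice, so the effective momentum differs from the non-checkpointed path.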

This implementation is inspired by the technical report listed in the Reference section, which outlines a strategy for efficient DenseNets via memory sharing.

Requirements

  • PyTorch >=1.0.0
  • CUDA

Usage

In your existing project: There is one file in the models folder.

If you care about speed and memory is not a concern, pass the efficient=False argument into the DenseNet constructor. Otherwise, pass in efficient=True.

Options:

  • All options are described in the docstrings of the model files
  • The depth is controlled by the block_config option
  • efficient=True uses the memory-efficient version
  • If you want to use the model for ImageNet, set small_inputs=False. For CIFAR or SVHN, set small_inputs=True.

Running the demo:

The only extra package you need to install is python-fire:

    pip install fire

  • Single GPU:

    CUDA_VISIBLE_DEVICES=0 python demo.py --efficient True --data <path_to_folder_with_cifar10> --save <path_to_save_dir>

  • Multiple GPU:

    CUDA_VISIBLE_DEVICES=0,1,2 python demo.py --efficient True --data <path_to_folder_with_cifar10> --save <path_to_save_dir>

Options:

  • --depth (int) - depth of the network (number of convolution layers) (default 40)
  • --growth_rate (int) - number of features added per DenseNet layer (default 12)
  • --n_epochs (int) - number of epochs for training (default 300)
  • --batch_size (int) - size of minibatch (default 256)
  • --seed (int) - manually set the random seed (default None)
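For reference, here is one way a depth value could be turned into a block_config. This is a sketch under assumptions: the helper name `bc_block_config` is hypothetical, and it follows the usual DenseNet-BC convention for CIFAR-style models (three dense blocks, two convolutions per bottleneck layer, and four layers outside the blocks). Check demo.py for the mapping it actually uses.

```python
def bc_block_config(depth):
    """Layers per dense block for a 3-block DenseNet-BC (assumed convention)."""
    # Four layers sit outside the dense blocks: the initial convolution,
    # two transition convolutions, and the final classifier.
    remaining = depth - 4
    # Each bottleneck dense layer has a 1x1 and a 3x3 convolution, and the
    # layers are split evenly over three blocks, hence the factor of 6.
    if remaining <= 0 or remaining % 6 != 0:
        raise ValueError("depth must be of the form 6 * n + 4 with n >= 1")
    return [remaining // 6] * 3


if __name__ == "__main__":
    for depth in (40, 100):
        print(depth, bc_block_config(depth))
```

Under this convention the default depth of 40 corresponds to 6 layers per block, and the 100-layer model from the Performance section corresponds to 16 layers per block.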

Performance

A comparison of the two implementations (each is a DenseNet-BC with 100 layers, batch size 64, tested on an NVIDIA Pascal Titan-X):

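| Implementation        | Memory consumption (GB/GPU) | Speed (sec/mini-batch) |
|-----------------------|-----------------------------|------------------------|
| Naive                 | 2.863                       | 0.165                  |
| Efficient             | 1.605                       | 0.207                  |
| Efficient (multi-GPU) | 0.985                       | -                      |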
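The single-GPU trade-off can be read off these numbers with a little arithmetic. The snippet below only restates the measurements from the table; the variable names are illustrative.

```python
# Measurements copied from the table (single GPU).
naive_mem_gb, naive_sec = 2.863, 0.165
efficient_mem_gb, efficient_sec = 1.605, 0.207

# Fraction of memory saved and fraction of extra time per mini-batch.
memory_saving = 1 - efficient_mem_gb / naive_mem_gb
time_overhead = efficient_sec / naive_sec - 1

if __name__ == "__main__":
    print(f"memory saved:  {memory_saving:.1%}")
    print(f"time overhead: {time_overhead:.1%}")
```

For this particular run that works out to roughly 44% less memory for roughly 25% more time per mini-batch, somewhat above the 15-20% overhead quoted in the Motivation section; the gap may simply reflect the model size and hardware used here.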

Other efficient implementations

Reference

@article{pleiss2017memory,
  title={Memory-Efficient Implementation of DenseNets},
  author={Pleiss, Geoff and Chen, Danlu and Huang, Gao and Li, Tongcheng and van der Maaten, Laurens and Weinberger, Kilian Q},
  journal={arXiv preprint arXiv:1707.06990},
  year={2017}
}
