MemCNN

A PyTorch framework for developing memory-efficient invertible neural networks.

Features

  • Enable memory savings during training by wrapping arbitrary invertible PyTorch functions with the InvertibleModuleWrapper class.
  • Simple toggling of memory saving by setting the keep_input property of the InvertibleModuleWrapper.
  • Turn arbitrary non-linear PyTorch functions into invertible versions using the AdditiveCoupling or the AffineCoupling classes.
  • Training and evaluation code for reproducing RevNet experiments using MemCNN.
  • CI tests for Python v3.7 and torch v1.0, v1.1, v1.4 and v1.7 with good code coverage.
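
The reason an additive coupling is invertible even when its sub-networks are not can be sketched in plain Python. This sketch assumes the RevNet-style additive coupling (the input is split into two halves x1 and x2; y1 = x1 + F(x2), then y2 = x2 + G(y1)). F and G are hypothetical toy functions standing in for the PyTorch sub-modules Fm and Gm. It is an illustration of the idea, not memcnn's implementation, and it does not cover AffineCoupling.

```python
import math
import random


# Hypothetical toy stand-ins for the sub-networks Fm and Gm; neither is invertible on its own.
def F(v):
    return [math.tanh(2.0 * a + 0.5) for a in v]


def G(v):
    return [max(0.0, a) ** 2 for a in v]  # ReLU then square: every negative input maps to 0


def coupling_forward(x1, x2):
    # y1 = x1 + F(x2), then y2 = x2 + G(y1)
    y1 = [a + b for a, b in zip(x1, F(x2))]
    y2 = [a + b for a, b in zip(x2, G(y1))]
    return y1, y2


def coupling_inverse(y1, y2):
    # Undo the steps in reverse order: x2 = y2 - G(y1), then x1 = y1 - F(x2).
    # Only forward evaluations of F and G are needed, never their inverses.
    x2 = [a - b for a, b in zip(y2, G(y1))]
    x1 = [a - b for a, b in zip(y1, F(x2))]
    return x1, x2


random.seed(0)
# The input is split into two equal halves.
x1 = [random.uniform(-1.0, 1.0) for _ in range(5)]
x2 = [random.uniform(-1.0, 1.0) for _ in range(5)]

y1, y2 = coupling_forward(x1, x2)
r1, r2 = coupling_inverse(y1, y2)
print("max reconstruction error:", max(abs(a - b) for a, b in zip(x1 + x2, r1 + r2)))
```

The reconstruction matches the input up to floating-point rounding, which is why the example below can check the inverse with a small atol.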

Examples

Creating an AdditiveCoupling with memory savings

import torch
import torch.nn as nn
import memcnn


# define a new torch Module with a sequence of operations: Conv2d, then BatchNorm2d, then ReLU
class ExampleOperation(nn.Module):
    def __init__(self, channels):
        super(ExampleOperation, self).__init__()
        self.seq = nn.Sequential(
                                    nn.Conv2d(in_channels=channels, out_channels=channels,
                                              kernel_size=(3, 3), padding=1),
                                    nn.BatchNorm2d(num_features=channels),
                                    nn.ReLU(inplace=True)
                                )

    def forward(self, x):
        return self.seq(x)


# generate some random input data (batch_size, num_channels, y_elements, x_elements)
X = torch.rand(2, 10, 8, 8)

# application of the operation(s) the normal way
model_normal = ExampleOperation(channels=10)
model_normal.eval()

Y = model_normal(X)

# turn the ExampleOperation invertible using an additive coupling
invertible_module = memcnn.AdditiveCoupling(
    Fm=ExampleOperation(channels=10 // 2),
    Gm=ExampleOperation(channels=10 // 2)
)

# test that it is actually a valid invertible module (has a valid inverse method)
assert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape)

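# wrap our invertible_module using the InvertibleModuleWrapper and benefit from memory savings during training
# keep_input=True and keep_input_inverse=True retain the inputs in memory; set them to False to let the wrapper
# discard the inputs and reconstruct them from the outputs during the backward pass, which is what saves memory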
invertible_module_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module, keep_input=True, keep_input_inverse=True)

# modules are in training mode by default; the following call switches to evaluation mode
# this is required to pass input tensors with requires_grad=False to the model (inference only)
invertible_module_wrapper.eval()

# test that the wrapped module is also a valid invertible module
assert memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=X.shape)

# compute the forward pass using the wrapper
Y2 = invertible_module_wrapper.forward(X)

# the input (X) can be approximated (X2) by applying the inverse method of the wrapper on Y2
X2 = invertible_module_wrapper.inverse(Y2)

# test that the input and approximation are similar
assert torch.allclose(X, X2, atol=1e-06)
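
The memory savings of the InvertibleModuleWrapper rest on the fact that an invertible module's input can be recomputed from its output, so intermediate activations do not have to be stored for the backward pass. The plain-Python sketch below illustrates only that idea, using a stack of hypothetical additive coupling layers with toy F and G functions (no PyTorch, no gradients); it is a conceptual illustration, not how memcnn implements the wrapper.

```python
import math
import random


def make_coupling(seed):
    """Hypothetical additive coupling layer with its own toy F and G (not memcnn code)."""
    rng = random.Random(seed)
    wf, wg = rng.uniform(0.5, 1.5), rng.uniform(0.5, 1.5)

    def F(v):
        return [math.tanh(wf * a) for a in v]

    def G(v):
        return [math.sin(wg * a) for a in v]

    def forward(x1, x2):
        y1 = [a + b for a, b in zip(x1, F(x2))]
        y2 = [a + b for a, b in zip(x2, G(y1))]
        return y1, y2

    def inverse(y1, y2):
        x2 = [a - b for a, b in zip(y2, G(y1))]
        x1 = [a - b for a, b in zip(y1, F(x2))]
        return x1, x2

    return forward, inverse


layers = [make_coupling(seed) for seed in range(8)]
rng = random.Random(42)
x = ([rng.uniform(-1, 1) for _ in range(4)], [rng.uniform(-1, 1) for _ in range(4)])

# Conventional backprop: every intermediate activation is kept (len(layers) + 1 of them).
stored = [x]
for forward, _ in layers:
    stored.append(forward(*stored[-1]))

# Invertible variant: only the current activation is needed during the forward pass.
h = x
for forward, _ in layers:
    h = forward(*h)
output = h

# Backward walk: recompute each layer's input from its output. In a real backward pass
# each recomputed activation would be used for that layer's gradients and then dropped;
# here they are collected only so they can be compared with `stored`.
recomputed = [output]
for _, inverse in reversed(layers):
    recomputed.append(inverse(*recomputed[-1]))
recomputed.reverse()

max_err = max(
    abs(p - q)
    for (a1, a2), (b1, b2) in zip(stored, recomputed)
    for p, q in zip(a1 + a2, b1 + b2)
)
print(f"max reconstruction error over {len(layers)} layers: {max_err:.2e}")
```

The trade-off is extra computation: each layer's inverse is evaluated once more during the backward pass.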

Run PyTorch Experiments

After installing MemCNN run:

python -m memcnn.train [MODEL] [DATASET] [--fresh] [--no-cuda]
  • Available values for DATASET are cifar10 and cifar100.
  • Available values for MODEL are resnet32, resnet110, resnet164, revnet38, revnet110, revnet164.
  • Use the --fresh flag to remove earlier experiment results.
  • Use the --no-cuda flag to train on the CPU rather than the GPU through CUDA.

Datasets are automatically downloaded if they are not available.

When using Python 3, replace python in the command above with the appropriate Python 3 executable. For example, when using the MemCNN Docker image, use python3.6.

When MemCNN was installed using pip or from source, you might need to set up a configuration file before running this command. See the corresponding section of the installation documentation: https://memcnn.readthedocs.io/en/latest/installation.html
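
When several experiments are launched from a script, one way to follow the advice about the Python 3 executable is to reuse the interpreter that runs the script. The helper below is a hypothetical sketch, not part of MemCNN: it only assembles the documented command line (python -m memcnn.train MODEL DATASET with the optional --fresh and --no-cuda flags) and checks MODEL and DATASET against the values listed above.

```python
import shlex
import sys

# Allowed values as listed above for `python -m memcnn.train [MODEL] [DATASET]`.
MODELS = ("resnet32", "resnet110", "resnet164", "revnet38", "revnet110", "revnet164")
DATASETS = ("cifar10", "cifar100")


def build_train_command(model, dataset, fresh=False, no_cuda=False):
    """Hypothetical helper returning the argv list for one memcnn.train run."""
    if model not in MODELS:
        raise ValueError(f"unknown MODEL {model!r}; expected one of {MODELS}")
    if dataset not in DATASETS:
        raise ValueError(f"unknown DATASET {dataset!r}; expected one of {DATASETS}")
    # sys.executable is the interpreter running this script, so the training
    # command runs under the same Python 3 installation.
    cmd = [sys.executable, "-m", "memcnn.train", model, dataset]
    if fresh:
        cmd.append("--fresh")
    if no_cuda:
        cmd.append("--no-cuda")
    return cmd


# Print the commands for a small sweep over the RevNet models on CIFAR-10.
for model in ("revnet38", "revnet110", "revnet164"):
    print(shlex.join(build_train_command(model, "cifar10", fresh=True)))

# To launch a run instead of printing it (requires MemCNN to be installed):
#     import subprocess
#     subprocess.run(build_train_command("revnet38", "cifar10"), check=True)
```

Returning an argument list rather than a single string avoids shell quoting issues when the interpreter path contains spaces.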

Results

The TensorFlow results were obtained by running the code from the GitHub repository of the reversible residual network (RevNet) authors.

The PyTorch results listed were recomputed on June 11th 2018 and differ from the results in the ICLR paper. The TensorFlow results are unchanged.

Prediction accuracy (%)

             CIFAR-10               CIFAR-100
Model        TensorFlow  PyTorch    TensorFlow  PyTorch
resnet-32    92.74       92.86      69.10       69.81
resnet-110   93.99       93.55      73.30       72.40
resnet-164   94.57       94.80      76.79       76.47
revnet-38    93.14       92.80      71.17       69.90
revnet-110   94.02       94.10      74.00       73.30
revnet-164   94.56       94.90      76.39       76.90

Training time (hours:minutes)

             CIFAR-10               CIFAR-100
Model        TensorFlow  PyTorch    TensorFlow  PyTorch
resnet-32    2:04        1:51       1:58        1:51
resnet-110   4:11        2:51       6:44        2:39
resnet-164   11:05       4:59       10:59       3:45
revnet-38    2:17        2:09       2:20        2:16
revnet-110   6:59        3:42       7:03        3:50
revnet-164   13:09       7:21       13:12       7:17

Memory consumption of model training in PyTorch

Layers          Parameters          Parameters (MB)     Activations (MB)
ResNet  RevNet  ResNet    RevNet    ResNet    RevNet    ResNet  RevNet
32      38      466906    573994    1.9       2.3       238.6   85.6
110     110     1730714   1854890   6.8       7.3       810.7   85.7
164     164     1704154   1983786   6.8       7.9       2452.8  432.7

The ResNet model is the conventional Residual Network implementation in PyTorch, while the RevNet model uses the memcnn.InvertibleModuleWrapper to achieve memory savings.
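
The size of the activation-memory savings can be read off the memory table with a short calculation. The sketch below simply copies the Activations (MB) columns from that table and divides the ResNet figure by the RevNet figure for each depth pair; no new measurements are involved.

```python
# Activation memory in MB (ResNet, RevNet), copied from the memory table above.
activations_mb = {
    "ResNet-32 vs RevNet-38": (238.6, 85.6),
    "ResNet-110 vs RevNet-110": (810.7, 85.7),
    "ResNet-164 vs RevNet-164": (2452.8, 432.7),
}

# A ratio above 1 means the conventional ResNet needs that many times more activation memory.
ratios = {pair: resnet / revnet for pair, (resnet, revnet) in activations_mb.items()}

for pair, ratio in ratios.items():
    resnet, revnet = activations_mb[pair]
    print(f"{pair}: {resnet} MB -> {revnet} MB (reduction factor {ratio:.1f})")
```

This works out to reduction factors of roughly 2.8, 9.5 and 5.7. Note also that the 110-layer RevNet needs about the same activation memory as the 38-layer one (85.7 MB versus 85.6 MB), at the cost of slightly more parameter memory.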

Works using MemCNN

Citation

Sil C. van de Leemput, Jonas Teuwen, Bram van Ginneken, and Rashindra Manniesing. MemCNN: A Python/PyTorch package for creating memory-efficient invertible neural networks. Journal of Open Source Software, 4, 1576, http://dx.doi.org/10.21105/joss.01576, 2019.

If you use our code, please cite:

@article{vandeLeemput2019MemCNN,
  journal = {Journal of Open Source Software},
  doi = {10.21105/joss.01576},
  issn = {2475-9066},
  number = {39},
  publisher = {The Open Journal},
  title = {MemCNN: A Python/PyTorch package for creating memory-efficient invertible neural networks},
  url = {http://dx.doi.org/10.21105/joss.01576},
  volume = {4},
  author = {Sil C. {van de} Leemput and Jonas Teuwen and Bram {van} Ginneken and Rashindra Manniesing},
  pages = {1576},
  date = {2019-07-30},
  year = {2019},
  month = {7},
  day = {30},
}
