
resnets

Companion code for "Stability of deep ResNets via discrete rough paths"

This repository contains the companion code for the numerical experiments presented in the paper "Stability of deep ResNets via discrete rough paths".

The numerical experiments consist of two parts:

  1. Training of a Residual Network consisting of 512 residual blocks on the MNIST data set for a total of 100 epochs. The weights are then saved to disk.
  2. Computing the $p$-variation (for $p\in[1,3]$) of those weights, using the Euclidean norm for vectors and the Frobenius norm for matrices. The results of this step are then plotted and saved to disk (a minimal sketch of this computation follows the list).
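For orientation, the $p$-variation of a sequence $x_0,\dots,x_N$ is $\big(\sup_{0\le t_0<t_1<\cdots<t_k\le N}\sum_{i}\|x_{t_{i+1}}-x_{t_i}\|^p\big)^{1/p}$, the supremum running over all finite partitions. Below is a minimal NumPy sketch of the standard $O(N^2)$ dynamic programme for this quantity; it is only an illustration, not the implementation used in iss.py (which relies on the faster p-var Rust extension), and all names and sizes are made up.

```python
import numpy as np

def p_variation(points, p, dist):
    # best[j] = largest value of sum dist(., .)**p over all partitions of
    # points[0..j] whose last point is j (classic O(n^2) dynamic programme)
    n = len(points)
    best = np.zeros(n)
    for j in range(1, n):
        for i in range(j):
            cand = best[i] + dist(points[i], points[j]) ** p
            if cand > best[j]:
                best[j] = cand
    return best[-1] ** (1.0 / p)

# Illustration with random stand-ins for the residual-block weight matrices
weights = [np.random.randn(16, 16) for _ in range(513)]
frobenius = lambda a, b: np.linalg.norm(a - b)  # Frobenius norm of the difference
print(p_variation(weights, p=2.0, dist=frobenius))
```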

Step 1 is performed by the code contained in resnet.py, which uses the PyTorch machine learning framework [1]. Step 2 is performed by the code contained in iss.py, which implements the iterated-sums signature in PyTorch and uses a port of T. Lyons', A. Korepanov's and P. Zorin-Kranich's p-var C++ library [2] to the Rust programming language, exported as a Python module. This port was written by the authors of the current package.
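As a small illustration of the first ingredient of iss.py (this is not the repository's implementation), a few low-order entries of the iterated-sums signature of a scalar series can be computed with cumulative sums in PyTorch:

```python
import torch

def iss_low_order(x):
    """Low-order iterated-sums signature entries of a 1-d series x, corresponding
    to the compositions <1>, <2> and <11>:
      <1>  = sum_i dx_i,   <2> = sum_i dx_i**2,   <11> = sum_{i<j} dx_i * dx_j."""
    dx = x[1:] - x[:-1]                    # increments of the series
    s1 = dx.sum()                          # <1>
    s2 = (dx ** 2).sum()                   # <2>
    # running sum of earlier increments, shifted so position j only sees i < j
    prefix = torch.cat([torch.zeros(1, dtype=dx.dtype), torch.cumsum(dx, 0)[:-1]])
    s11 = (prefix * dx).sum()              # <11>
    return s1, s2, s11

x = torch.randn(100)
print(iss_low_order(x))
```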

All the code runs inside a pipenv environment pinning all Python modules to the exact versions used during development. In order to build the Rust extension, a working Rust installation is needed (see the Rust installation instructions).

How to run the code

First, set up the pipenv environment by running pipenv sync to install all the dependencies. Then build the p_var extension by compiling it with pipenv run maturin develop --release. This compiles the Rust extension and installs it inside the pipenv virtual environment.

Train the ResNet by running pipenv run python3 resnet.py. This produces a file on disk with the .pth extension containing the trained weights. Note that some configuration of the training setup may be needed for the code to train fully on your particular machine (e.g. the type and number of accelerators), although these settings should be discovered automatically in most cases.
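For orientation, a network of the kind trained in this step stacks residual blocks of the form $x \mapsto x + f(x)$. The sketch below shows one way such a model could be written in PyTorch; the actual architecture, widths and hyper-parameters are those in resnet.py and may well differ, so treat every name and number here as an assumption.

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """One residual block: x -> x + f(x), with f a small linear layer + activation."""
    def __init__(self, width):
        super().__init__()
        self.linear = nn.Linear(width, width)

    def forward(self, x):
        return x + torch.tanh(self.linear(x))

class DeepResNet(nn.Module):
    """Flattened MNIST images -> linear embedding -> 512 residual blocks -> classifier."""
    def __init__(self, width=64, depth=512, num_classes=10):
        super().__init__()
        self.embed = nn.Linear(28 * 28, width)
        self.blocks = nn.Sequential(*[ResidualBlock(width) for _ in range(depth)])
        self.head = nn.Linear(width, num_classes)

    def forward(self, x):
        x = self.embed(x.flatten(1))
        return self.head(self.blocks(x))

model = DeepResNet()
torch.save(model.state_dict(), "resnet_mnist.pth")  # weights saved with the .pth extension
```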

Finally, compute the $p$-variations by running pipenv run python3 iss.py. This produces a numpy file with the .pth extension containing the numerical data and a PDF figure showing the evolution and $p$-variation norm of the weights.
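The figure produced in this step plots the $p$-variation norm of the trained weights against $p$. A sketch of such a plot, with synthetic placeholder values rather than the data computed by iss.py and a made-up file name, could look like this:

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic placeholder values; iss.py computes the real norms from the trained weights.
p_values = np.linspace(1.0, 3.0, 21)
pvar_norms = 1.0 / p_values

plt.plot(p_values, pvar_norms, marker="o")
plt.xlabel(r"$p$")
plt.ylabel(r"$p$-variation norm of the weights")
plt.savefig("p_variation.pdf")
```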

Footnotes

  1. https://pytorch.org/

  2. https://github.com/khumarahn/p-var.git
