
patrick-kidger / fasterneuraldiffeq

Code for "'Hey, that's not an ODE:' Faster ODE Adjoints via Seminorms" (ICML 2021)

License: Apache License 2.0

deep-neural-networks deep-learning pytorch dynamical-systems differential-equations ordinary-differential-equations controlled-differential-equations neural-differential-equations numerical-methods numerical-analysis machine-learning

"Hey, that's not an ODE": Faster ODE Adjoints via Seminorms
(ICML 2021)
[arXiv]

One simple-to-implement trick dramatically improves the speed at which Neural ODEs and Neural CDEs can be trained, in some cases by as much as a factor of two.

Backpropagation through a Neural ODE/CDE can be performed via the "adjoint method", which involves solving another differential equation backwards in time. However, it turns out that default numerical solvers are unnecessarily stringent when solving this adjoint equation: they take too many steps, and steps that are too small.
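To make the adjoint method concrete, here is a minimal pure-Python sketch on a scalar toy problem, assumed for illustration rather than taken from the paper: dy/dt = θy with loss L = y(T). The backward pass solves an augmented system for y, its adjoint a_y, and a_θ, which accumulates dL/dθ. Note that a_θ never feeds back into the other two equations; it is just a running integral. Fixed-step RK4 stands in for an adaptive solver, and all function names here are hypothetical.

```python
import math


def rk4_step(f, t, state, h):
    """One classical Runge-Kutta step for a system stored as a list of floats."""
    k1 = f(t, state)
    k2 = f(t + h / 2, [s + h / 2 * k for s, k in zip(state, k1)])
    k3 = f(t + h / 2, [s + h / 2 * k for s, k in zip(state, k2)])
    k4 = f(t + h, [s + h * k for s, k in zip(state, k3)])
    return [s + h / 6 * (a + 2 * b + 2 * c + d)
            for s, a, b, c, d in zip(state, k1, k2, k3, k4)]


def solve(f, t0, t1, state, steps):
    """Integrate from t0 to t1 with fixed RK4 steps (t1 < t0 integrates backwards)."""
    h = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        state = rk4_step(f, t, state, h)
        t += h
    return state


def adjoint_gradient(theta, y0, T, steps=100):
    """dL/dtheta for L = y(T), where dy/dt = theta * y, via the adjoint method."""
    # Forward pass: solve the original ODE from 0 to T.
    (yT,) = solve(lambda t, s: [theta * s[0]], 0.0, T, [y0], steps)

    # Backward pass: augmented state [y, a_y, a_theta], solved from T back to 0.
    def aug_dynamics(t, s):
        y, a_y, _a_theta = s
        return [theta * y,       # recompute y backwards in time
                -a_y * theta,    # da_y/dt = -a_y * df/dy
                -a_y * y]        # da_theta/dt = -a_y * df/dtheta (never read above)

    # Initial conditions at t = T: a_y(T) = dL/dy(T) = 1, a_theta(T) = 0.
    _y, _a_y, a_theta = solve(aug_dynamics, T, 0.0, [yT, 1.0, 0.0], steps)
    return a_theta


print(adjoint_gradient(0.5, 1.0, 1.0), math.exp(0.5))
```

For this problem the exact gradient is y0·T·e^(θT), so both printed numbers should agree to many decimal places.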

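Tweaking things slightly, by measuring the error of the backward solve with a seminorm that ignores the channels holding the parameter gradients, reduces the number of function evaluations on the backward pass by as much as 62%. (The exact number will be problem-dependent, of course.)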
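The effect on step acceptance can be sketched in pure Python. This assumes the usual convention for adaptive solvers: each error component is divided by `atol + rtol * |state|`, and a step is accepted when the norm of that ratio is at most 1. The augmented state and error values below are made up for illustration. The `seminorm` helper (a hypothetical name) follows the `make_norm` function in the Summary by taking the larger RMS of the `y` and `adj_y` blocks and ignoring the parameter adjoints; the leading scalar that `make_norm` skips is simply omitted here.

```python
import math


def rms(xs):
    """Root-mean-square of a list of floats."""
    return math.sqrt(sum(x * x for x in xs) / len(xs))


def error_ratio(err, state, rtol, atol):
    """Scale each error component by the mixed tolerance atol + rtol * |state|."""
    return [e / (atol + rtol * abs(s)) for e, s in zip(err, state)]


def full_norm(ratio):
    """Norm-style error control: every channel of the augmented state counts."""
    return rms(ratio)


def seminorm(ratio, n):
    """Only y (first n entries) and adj_y (next n entries) count;
    the parameter-adjoint entries after them are ignored."""
    return max(rms(ratio[:n]), rms(ratio[n:2 * n]))


# Augmented state: y (2 entries), adj_y (2 entries), parameter adjoints (4 entries).
state = [1.0, 2.0, 0.5, 0.5, 10.0, 10.0, 10.0, 10.0]
err = [1e-4, 1e-4, 1e-4, 1e-4, 5e-2, 5e-2, 5e-2, 5e-2]
ratio = error_ratio(err, state, rtol=1e-3, atol=1e-5)

print(full_norm(ratio) <= 1.0)    # False: the step would be rejected
print(seminorm(ratio, 2) <= 1.0)  # True: the step is accepted
```

Here the only large errors sit in the parameter-adjoint channels, which never influence the rest of the backward solve, so rejecting the step on their account just forces extra, smaller steps.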

torchdiffeq now supports this feature natively!


Summary:

If you're using torchdiffeq (at least version 0.1.0), then replace

```python
import torchdiffeq

func = ...
y0 = ...
t = ...
torchdiffeq.odeint_adjoint(func=func, y0=y0, t=t)
```

with

```python
import torchdiffeq

def rms_norm(tensor):
    return tensor.pow(2).mean().sqrt()

def make_norm(state):
    state_size = state.numel()
    def norm(aug_state):
        # The adjoint solver's state is one flat tensor: a leading scalar,
        # then y, then adj_y, then the parameter adjoints.
        y = aug_state[1:1 + state_size]
        adj_y = aug_state[1 + state_size:1 + 2 * state_size]
        # Seminorm: the parameter adjoints are deliberately left out.
        return max(rms_norm(y), rms_norm(adj_y))
    return norm

func = ...
y0 = ...
t = ...
torchdiffeq.odeint_adjoint(func=func, y0=y0, t=t,
                           adjoint_options=dict(norm=make_norm(y0)))
```
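The slicing in `make_norm` implies that the adjoint solver hands the norm one flat tensor: a leading scalar entry, then `y`, then its adjoint, then the parameter adjoints. A quick sanity check, using hand-built tensors with that assumed layout rather than anything produced by torchdiffeq, shows that even enormous parameter-adjoint entries leave the result unchanged.

```python
import torch


def rms_norm(tensor):
    return tensor.pow(2).mean().sqrt()


def make_norm(state):
    state_size = state.numel()

    def norm(aug_state):
        y = aug_state[1:1 + state_size]
        adj_y = aug_state[1 + state_size:1 + 2 * state_size]
        return max(rms_norm(y), rms_norm(adj_y))

    return norm


y0 = torch.tensor([1.0, 2.0, 3.0])
norm = make_norm(y0)

# Hand-built flat augmented state: [leading scalar, y, adj_y, parameter adjoints].
lead = torch.tensor([0.0])
y = torch.ones(3)
adj_y = torch.full((3,), 2.0)
small_params = torch.zeros(5)
huge_params = torch.full((5,), 1e6)

a = norm(torch.cat([lead, y, adj_y, small_params]))
b = norm(torch.cat([lead, y, adj_y, huge_params]))
print(a.item(), b.item())  # 2.0 2.0: the parameter adjoints are ignored
```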

That's it.

Reproducing experiments

The code for the Neural CDE and Symplectic ODE-Net experiments is available.

Requirements

PyTorch >= 1.6
torchdiffeq >= 0.1.0
torchcde >= 0.1.0
torchaudio >= 0.6.0
scikit-learn >= 0.23.1
gym >= 0.17.2
tqdm >= 4.47.0

In summary:

```bash
conda install pytorch torchaudio -c pytorch
pip install torchdiffeq scikit-learn gym tqdm
pip install git+https://github.com/patrick-kidger/torchcde.git
```

Neural CDEs

```python
>>> import speech_commands
>>> device = 'cuda'
>>> norm = True   # use our trick; set to False for the baseline
>>> rtol = 1e-3
>>> atol = 1e-5
>>> results = speech_commands.main(device, norm, rtol, atol)
>>> print(results.keys())  # inspect results object
>>> print(results.test_metrics.accuracy)  # query results object
```

Symplectic ODE-Net

```python
>>> import acrobot
>>> device = 'cuda'
>>> norm = True   # use our trick; set to False for the baseline
>>> results = acrobot.main(device, norm)
>>> print(results.keys())  # inspect results object
>>> print(results.test_metrics.loss)  # query results object
```

Citation

@article{kidger2021hey,
    author={Kidger, Patrick and Chen, Ricky T. Q. and Lyons, Terry},
    title={{``Hey, that's not an ODE'': Faster ODE Adjoints via Seminorms}},
    year={2021},
    journal={International Conference on Machine Learning}
}
