
pytorch-lars

Layer-wise Adaptive Rate Scaling in PyTorch

This repo contains a PyTorch implementation of layer-wise adaptive rate scaling (LARS) from the paper "Large Batch Training of Convolutional Networks" by You, Gitman, and Ginsburg. Another version of this was recently included in PyTorch Lightning.

To run, do

python train.py --optimizer LARS --cuda lars_results

It uses skeletor-ml for experiment logging, but the main optimizer file does not depend on that framework.
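The core LARS update rescales each layer's step by a "trust ratio" eta * ||w|| / (||g|| + wd * ||w||) before applying momentum. The sketch below is an illustrative, minimal PyTorch optimizer based on the paper's update rule, not the repo's actual `train.py` code; hyperparameter defaults mirror the values quoted later in this README.

```python
import torch
from torch.optim.optimizer import Optimizer

class LARS(Optimizer):
    """Minimal sketch of layer-wise adaptive rate scaling (You et al., 2017)."""

    def __init__(self, params, lr=0.1, momentum=0.9, weight_decay=5e-4, eta=1e-3):
        defaults = dict(lr=lr, momentum=momentum,
                        weight_decay=weight_decay, eta=eta)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                w_norm = p.norm()
                g_norm = p.grad.norm()
                # Layer-wise trust ratio: eta * ||w|| / (||g|| + wd * ||w||).
                if w_norm > 0 and g_norm > 0:
                    local_lr = (group['eta'] * w_norm /
                                (g_norm + group['weight_decay'] * w_norm)).item()
                else:
                    local_lr = 1.0
                # Decoupled L2 term, then momentum on the rescaled gradient.
                d_p = p.grad + group['weight_decay'] * p
                buf = self.state[p].setdefault('momentum_buffer',
                                               torch.zeros_like(p))
                buf.mul_(group['momentum']).add_(d_p, alpha=group['lr'] * local_lr)
                p.add_(-buf)
```

The trust ratio keeps each layer's update magnitude proportional to its weight norm, which is what allows stable training at very large batch sizes.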

Preliminary results

I tested this using a ResNet18 on CIFAR-10, using a standard gradient accumulation trick to simulate very large batch sizes.
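The gradient accumulation trick splits each large batch into micro-batches, sums their gradients, and applies a single optimizer step, so the effective batch size exceeds what fits in GPU memory. A minimal sketch of the idea (the function name and arguments are hypothetical, not from this repo):

```python
import torch

def accumulated_step(model, loss_fn, optimizer, micro_batches):
    """Simulate one large-batch step by accumulating gradients
    over `micro_batches`, an iterable of (inputs, targets) pairs
    that together form one large effective batch."""
    optimizer.zero_grad()
    n = len(micro_batches)
    for inputs, targets in micro_batches:
        loss = loss_fn(model(inputs), targets)
        # Scale so the accumulated gradient equals the mean
        # over the full effective batch.
        (loss / n).backward()
    optimizer.step()
```

Note that batch-norm statistics are still computed per micro-batch, which is one of the caveats discussed below.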


Batch Size    Test Accuracy
64            89.39
256           85.45
1024          81.20
4096          73.41
16384         64.13

As a comparison, using SGD with momentum and a geometric decay schedule, I am able to achieve about 93.5% test accuracy in 200 epochs (using this implementation). I have not done extensive hyperparameter tuning, though -- I used the default parameters suggested by the paper: a base learning rate of 0.1, 200 epochs, eta = 0.001, momentum 0.9, weight decay of 5e-4, and the polynomial learning rate decay schedule.
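The polynomial decay schedule referenced above shrinks the learning rate as (1 - t/T)^power over T epochs; the LARS paper uses power 2. A one-line sketch:

```python
def polynomial_decay(base_lr, epoch, total_epochs, power=2.0):
    """Polynomial learning-rate decay: lr = base_lr * (1 - t/T)**power."""
    return base_lr * (1.0 - epoch / total_epochs) ** power
```

With base_lr = 0.1 and 200 epochs, this starts at 0.1, passes through 0.025 at the halfway point, and reaches 0 at the final epoch.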

There are two likely explanations for the difference in performance. The first is hyperparameter tuning: ResNet18 may have different optimal hyperparameters than ResNet50, or CIFAR-10 may have different ones than ImageNet, or both. Plugging in a geometric schedule in place of (or in addition to) the polynomial decay schedule may be the main culprit. The second is that the gradient accumulation trick mentioned above interacts in unexpected ways with batch normalization, since normalization statistics are computed per micro-batch rather than over the full effective batch. Either could cause a performance regression.
