Training neural networks with local error signals

This repo contains PyTorch code for training neural networks without global backprop. The experiments were performed by Arild Nøkland and Lars Hiller Eidnes.

A more detailed description of the experiments is available on arXiv here: https://arxiv.org/abs/1901.06656

Supervised training of neural networks for classification is typically performed with a global loss function. The loss function provides a gradient for the output layer, and this gradient is back-propagated to hidden layers to dictate an update direction for the weights. An alternative approach is to train the network with layer-wise loss functions. In this paper we demonstrate, for the first time, that layer-wise training can approach the state-of-the-art on a variety of image datasets. We use single-layer sub-networks and two different supervised loss functions to generate local error signals for the hidden layers, and we show that the combination of these losses helps with optimization in the context of local learning. Using local errors could be a step towards more biologically plausible deep learning because the global error does not have to be transported back to hidden layers.

The tables below report test error in percent. 'pred' indicates a layer-wise cross-entropy loss, 'sim' indicates a layer-wise similarity matching loss, and 'predsim' indicates a combination of these losses. For the local losses, the computational graph is detached after each hidden layer, so no gradient flows from one layer's loss into the layers before it.
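The detaching step can be illustrated with a minimal PyTorch sketch. It is not the repo's implementation: the class name `LocalLossBlock`, the layer sizes, and the choice of a linear decoder are assumptions made for illustration. Each block computes its own 'pred' cross-entropy loss through a single-layer decoder and hands a detached activation to the next block.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LocalLossBlock(nn.Module):
    """Hypothetical hidden layer trained by its own cross-entropy ('pred') loss."""

    def __init__(self, num_in, num_out, num_classes):
        super().__init__()
        self.encoder = nn.Linear(num_in, num_out)
        # Single-layer sub-network mapping hidden activations to class scores.
        self.decoder = nn.Linear(num_out, num_classes)

    def forward(self, x, y):
        h = F.leaky_relu(self.encoder(x))
        local_loss = F.cross_entropy(self.decoder(h), y)
        # Detach so that later blocks cannot send gradients into this one.
        return h.detach(), local_loss


torch.manual_seed(0)
blocks = [LocalLossBlock(20, 16, 5), LocalLossBlock(16, 16, 5)]
x = torch.randn(8, 20)           # toy batch: 8 samples, 20 features
y = torch.randint(0, 5, (8,))    # toy labels for 5 classes

h = x
losses = []
for block in blocks:
    h, loss = block(h, y)
    losses.append(loss)

# Backpropagating the second block's loss leaves the first block untouched.
losses[1].backward()
first_block_grad = blocks[0].encoder.weight.grad
second_block_grad = blocks[1].encoder.weight.grad
print(first_block_grad is None, second_block_grad is not None)  # True True
```

In a full training loop each block's loss would drive an update of that block's own parameters only, which is what makes the error signal local.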

Experiments

Results on MNIST with 2 pixel jittering:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
|---|---|---|---|---|---|
| mlp | 2.9M | 0.75 | 0.68 | 0.80 | 0.62 |
| vgg8b | 7.3M | 0.26 | 0.40 | 0.65 | 0.31 |
| vgg8b + cutout | 7.3M | - | - | - | 0.26 |

Results on Fashion-MNIST with 2 pixel jittering and horizontal flipping:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
|---|---|---|---|---|---|
| mlp | 2.9M | 8.37 | 8.60 | 9.70 | 8.54 |
| vgg8b | 7.3M | 4.53 | 5.66 | 5.12 | 4.65 |
| vgg8b (2x) | 28.2M | 4.55 | 5.11 | 4.92 | 4.33 |
| vgg8b (2x) + cutout | 28.2M | - | - | - | 4.14 |

Results on Kuzushiji-MNIST with no data augmentation:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
|---|---|---|---|---|---|
| mlp | 2.9M | 5.99 | 7.26 | 9.80 | 7.33 |
| vgg8b | 7.3M | 1.53 | 2.22 | 2.19 | 1.36 |
| vgg8b + cutout | 7.3M | - | - | - | 0.99 |

Results on CIFAR-10 with data augmentation:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
|---|---|---|---|---|---|
| mlp | 27.3M | 33.56 | 32.33 | 33.48 | 30.93 |
| vgg8b | 8.9M | 5.99 | 8.40 | 7.16 | 5.58 |
| vgg11b | 11.6M | 5.56 | 8.39 | 6.70 | 5.30 |
| vgg11b (2x) | 42.0M | 4.91 | 7.30 | 6.66 | 4.42 |
| vgg11b (3x) | 91.3M | 5.02 | 7.37 | 9.34 | 3.97 |
| vgg11b (3x) + cutout | 91.3M | - | - | - | 3.60 |

Results on CIFAR-100 with data augmentation:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
|---|---|---|---|---|---|
| mlp | 27.3M | 62.57 | 58.87 | 62.46 | 56.88 |
| vgg8b | 9.0M | 26.24 | 29.32 | 32.64 | 24.07 |
| vgg11b | 11.7M | 25.18 | 29.58 | 30.82 | 24.05 |
| vgg11b (2x) | 42.1M | 23.44 | 26.91 | 28.03 | 21.20 |
| vgg11b (3x) | 91.4M | 23.69 | 25.90 | 28.01 | 20.13 |

Results on SVHN with extra training data, but no augmentation:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
|---|---|---|---|---|---|
| vgg8b | 8.9M | 2.29 | 2.12 | 1.89 | 1.74 |
| vgg8b + cutout | 8.9M | - | - | - | 1.65 |

Results on STL-10 with no data augmentation:

| Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
|---|---|---|---|---|---|
| vgg8b | 11.5M | 33.08 | 26.83 | 23.15 | 20.51 |
| vgg8b + cutout | 11.5M | - | - | - | 19.25 |

Training recipes

To replicate training of MLP on MNIST with local loss 'predsim':

python train.py --model mlp --dataset MNIST --dropout 0.1 --lr 5e-4 --num-layers 3 --epochs 100 --lr-decay-milestones 50 75 89 94 --nonlin leakyrelu

To replicate training of VGG8b on MNIST with local loss 'predsim':

python train.py --model vgg8b --dataset MNIST --dropout 0.2 --lr 5e-4 --epochs 100 --lr-decay-milestones 50 75 89 94 --nonlin leakyrelu --dim-in-decoder 1024
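Both MNIST recipes pass `--lr-decay-milestones 50 75 89 94`. As a rough illustration of how a milestone schedule of this kind behaves, the sketch below multiplies the learning rate by a fixed factor at each milestone. The helper `lr_at_epoch` is hypothetical, and both the decay factor of 0.25 and the choice to apply the decay from the milestone epoch onward are assumptions, since the README does not state them.

```python
def lr_at_epoch(base_lr, epoch, milestones, decay_factor):
    """Step schedule: multiply the rate by decay_factor for each milestone reached."""
    num_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * decay_factor ** num_decays


milestones = [50, 75, 89, 94]   # from the MNIST recipes
decay_factor = 0.25             # assumed value; not stated in this README
for epoch in (0, 50, 75, 89, 94):
    print(epoch, lr_at_epoch(5e-4, epoch, milestones, decay_factor))
```

With closely spaced late milestones such as 89 and 94, the rate drops quickly over the last few epochs of a 100-epoch run.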

To replicate training of MLP on CIFAR10 with local loss 'predsim':

python train.py --model mlp --dataset CIFAR10 --dropout 0.1 --lr 5e-4 --num-layers 3 --num-hidden 3000 --nonlin leakyrelu

To replicate training of VGG8b on CIFAR10 with local loss 'predsim':

python train.py --model vgg8b --dataset CIFAR10 --dropout 0.2 --lr 5e-4 --nonlin leakyrelu --dim-in-decoder 2048

To replicate training of VGG11b (3x) on CIFAR10 with local loss 'predsim':

python train.py --model vgg11b --dataset CIFAR10 --dropout 0.3 --lr 3e-4 --feat-mult 3 --nonlin leakyrelu

For all the above recipes, to train with local cross-entropy loss, add argument

--loss-sup pred

For all the above recipes, to train with local similarity matching loss, add argument

--loss-sup sim
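For readers unfamiliar with similarity matching, here is a minimal sketch of what such a loss could look like. It assumes the loss is the mean squared error between the pairwise cosine similarity matrix of a mini-batch's hidden activations and that of its one-hot labels, with rows mean-centered before normalization. The function names are hypothetical, and the repo's actual implementation may differ, for example in how convolutional feature maps are reduced.

```python
import torch
import torch.nn.functional as F


def similarity_matrix(x):
    """Pairwise cosine similarity between mean-centered rows of a batch."""
    x = x.flatten(start_dim=1)
    x = x - x.mean(dim=1, keepdim=True)
    x = F.normalize(x, dim=1)
    return x @ x.t()


def sim_loss(h, y, num_classes):
    """MSE between the similarity matrices of activations and one-hot labels."""
    y_onehot = F.one_hot(y, num_classes).float()
    return F.mse_loss(similarity_matrix(h), similarity_matrix(y_onehot))


y = torch.tensor([0, 1, 2, 0])
perfect = F.one_hot(y, 3).float()     # activations that mirror the labels exactly
torch.manual_seed(0)
random_h = torch.randn(4, 10)         # unstructured activations

print(sim_loss(perfect, y, 3).item())   # 0.0
print(sim_loss(random_h, y, 3).item())  # some positive value
```

The loss is zero when samples of the same class have identical similarity structure to their labels, so minimizing it pushes same-class activations together and different-class activations apart without requiring a classifier on top of the layer.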

For all the above recipes, to train with global loss, add argument

--backprop

For all the above recipes, to train with a more biologically plausible version of local loss, add argument

--bio

To add cutout regularization with cutout hole size 14, add arguments

--cutout --length 14
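As an illustration of what cutout does, the following sketch masks one square region of a single image tensor. It follows the common description of the technique rather than this repo's code: the function name `cutout`, the zero fill value, and the clipping of the square at image borders are assumptions.

```python
import torch


def cutout(image, length, generator=None):
    """Zero out one length x length square centered at a random pixel.

    `image` is a (channels, height, width) tensor. The square is clipped
    at the image borders, so the masked area can be smaller than length**2.
    """
    _, height, width = image.shape
    cy = int(torch.randint(height, (1,), generator=generator))
    cx = int(torch.randint(width, (1,), generator=generator))
    y1, y2 = max(cy - length // 2, 0), min(cy + length // 2, height)
    x1, x2 = max(cx - length // 2, 0), min(cx + length // 2, width)
    out = image.clone()          # leave the input image unchanged
    out[:, y1:y2, x1:x2] = 0.0
    return out


image = torch.ones(1, 28, 28)    # MNIST-sized dummy image
masked = cutout(image, length=14)
print(int((masked == 0).sum()))  # number of masked pixels, at most 14 * 14
```

A hole size of 14 covers up to a quarter of a 28x28 image, which is why it acts as a fairly strong regularizer.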

To replicate all the above experiments, run

./run_experiments.sh
