

A PyTorch Implementation of "Riemannian approach to batch normalization"

A PyTorch implementation of "Riemannian approach to batch normalization" (NIPS 2017) by Minhyung Cho and Jaehyung Lee (https://arxiv.org/abs/1709.09603). The conference poster can be found here.

Refer to https://github.com/MinhyungCho/riemannian-batch-normalization/ for a TensorFlow implementation.

Abstract

Batch Normalization (BN) has proven to be an effective algorithm for deep neural network training by normalizing the input to each neuron and reducing the internal covariate shift. The space of weight vectors in the BN layer can be naturally interpreted as a Riemannian manifold, which is invariant to linear scaling of weights. Following the intrinsic geometry of this manifold provides a new learning rule that is more efficient and easier to analyze. We also propose intuitive and effective gradient clipping and regularization methods for the proposed algorithm by utilizing the geometry of the manifold. The resulting algorithm consistently outperforms the original BN on various types of network architectures and datasets.
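The scale invariance mentioned in the abstract can be checked numerically. The following is a minimal, framework-free sketch and not code from this repository: the helper names `batch_norm` and `pre_activations` and the toy data are assumptions made for illustration, and the normalization omits the learnable scale and shift and uses `eps = 0` so that the invariance holds up to floating-point rounding (with the small positive `eps` used in practice it holds only approximately).

```python
import math


def batch_norm(values, eps=0.0):
    """Normalize a list of scalars to zero mean and unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def pre_activations(weights, batch):
    """Compute the dot product w^T x for every sample x in the batch."""
    return [sum(w * x for w, x in zip(weights, sample)) for sample in batch]


# Toy batch of four 2-dimensional inputs (illustrative values only).
batch = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0], [-2.0, 1.5]]
w = [0.3, -0.7]
scaled_w = [5.0 * wi for wi in w]  # same direction, different (positive) scale

out = batch_norm(pre_activations(w, batch))
out_scaled = batch_norm(pre_activations(scaled_w, batch))

# The normalized outputs agree up to floating-point rounding.
max_diff = max(abs(a - b) for a, b in zip(out, out_scaled))
print(max_diff)
```

Because only the direction of `w` affects the normalized output, the weight vectors can be treated as points on a manifold rather than in Euclidean space, which is the viewpoint the abstract describes.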

Results

See https://arxiv.org/abs/1709.09603 for details. Only the results for CIFAR-10 and CIFAR-100 have been verified.

Requirements

The script depends on the OpenCV Python bindings, which can be installed via conda:

conda install -c menpo opencv3

Then, after installing PyTorch, install the remaining dependencies:

pip install -r requirements.txt

Train

The commands below are examples.

CIFAR-10:

[SGD] python main.py --save ./logs/resnet_$RANDOM$RANDOM --depth 28 --width 10
[SGD-G] python main.py --save ./logs/resnet_$RANDOM$RANDOM --depth 28 --width 10 --optim_method SGDG --lr 0.01 --lrg 0.2
[Adam-G] python main.py --save ./logs/resnet_$RANDOM$RANDOM --depth 28 --width 10 --optim_method AdamG --lr 0.01 --lrg 0.05

CIFAR-100:

[SGD] python main.py --save ./logs/resnet_$RANDOM$RANDOM --depth 28 --width 10 --dataset CIFAR100
[SGD-G] python main.py --save ./logs/resnet_$RANDOM$RANDOM --depth 28 --width 10 --optim_method SGDG --lr 0.01 --lrg 0.2 --dataset CIFAR100
[Adam-G] python main.py --save ./logs/resnet_$RANDOM$RANDOM --depth 28 --width 10 --optim_method AdamG --lr 0.01 --lrg 0.05 --dataset CIFAR100

To apply this algorithm to your model

grassmann_optimizer.py is the main implementation; it provides the proposed SGD-G and Adam-G optimizers. main.py includes all the steps needed to apply the provided optimizers to your model.

  1. Collect all the weight parameters that need to be optimized on the Grassmann manifold (and initialize them to unit scale):

    key_g = []
    if opt.optim_method == 'SGDG' or opt.optim_method == 'AdamG':
        param_g = []
        param_e0 = []
        param_e1 = []
    
        for key, value in params.items():
            if 'conv' in key and value.size()[0] < np.prod(value.size()[1:]):
                param_g.append(value)
                key_g.append(key)
                # initialize to unit scale
                unitp, _ = unit(value.data.view(value.size(0), -1))
                value.data.copy_(unitp.view(value.size()))
            elif 'bn' in key or 'bias' in key:
                param_e0.append(value)
            else:
                param_e1.append(value)

  2. Create the optimizer with proper parameters:

    import grassmann_optimizer
    dict_g = {'params': param_g, 'lr': lrg, 'momentum': 0.9, 'grassmann': True, 'omega': opt.omega, 'grad_clip': opt.grad_clip}
    dict_e0 = {'params': param_e0, 'lr': lr, 'momentum': 0.9, 'grassmann': False, 'weight_decay': opt.bnDecay, 'nesterov': True}
    dict_e1 = {'params': param_e1, 'lr': lr, 'momentum': 0.9, 'grassmann': False, 'weight_decay': opt.weightDecay, 'nesterov': True}
    return grassmann_optimizer.SGDG([dict_g, dict_e0, dict_e1])  # or use AdamG
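
The grouping rule in step 1 depends only on parameter names and shapes, so it can be tried out without a model. The sketch below is a framework-free illustration under stated assumptions, not part of this repository: `split_param_names`, `unit_rows`, and the example parameter names and shapes are hypothetical, and `unit_rows` only approximates what "initialize to unit scale" is assumed to mean here (each flattened row scaled to Euclidean norm 1).

```python
import math


def split_param_names(shapes):
    """Partition parameter names with the same tests as step 1.

    `shapes` maps a parameter name to its shape tuple. Returns three lists:
    the Grassmann group, the bn/bias group, and all remaining parameters.
    """
    group_g, group_e0, group_e1 = [], [], []
    for name, shape in shapes.items():
        rest = math.prod(shape[1:])  # product of all dimensions but the first
        if 'conv' in name and shape[0] < rest:
            group_g.append(name)
        elif 'bn' in name or 'bias' in name:
            group_e0.append(name)
        else:
            group_e1.append(name)
    return group_g, group_e0, group_e1


def unit_rows(matrix, eps=1e-12):
    """Scale each row of a list-of-lists matrix to unit Euclidean norm."""
    result = []
    for row in matrix:
        norm = math.sqrt(sum(v * v for v in row))
        result.append([v / max(norm, eps) for v in row])
    return result


# Hypothetical names and shapes, for illustration only.
example_shapes = {
    'conv1.weight': (16, 3, 3, 3),            # 16 < 27, so Grassmann group
    'block.conv_1x1.weight': (64, 16, 1, 1),  # 64 >= 16, so falls to "else"
    'bn1.weight': (16,),
    'fc.bias': (10,),
    'fc.weight': (10, 64),
}
groups = split_param_names(example_shapes)
print(groups)
print(unit_rows([[3.0, 4.0]]))
```

Note that a convolution whose output-channel count is not smaller than its flattened input size (the second entry above) fails the first test and ends up in the last group, which is optimized without the Grassmann update.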
