Keras SWA - Stochastic Weight Averaging

This is an implementation of SWA for Keras and TF-Keras. It supports a constant learning rate schedule as well as the cyclic schedule described in the paper.

Introduction

Stochastic weight averaging (SWA) is built upon the same principle as snapshot ensembling and fast geometric ensembling: averaging selected stages of training can lead to better models. Whereas the two former methods average by sampling and ensembling models, SWA instead averages weights. This has been shown to give comparable improvements while being contained within a single model.
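At its core, SWA maintains a running average of the model weights collected after a chosen starting epoch. Below is a minimal sketch of that update rule, with NumPy arrays standing in for layer weights (the function and variable names are illustrative, not part of the keras-swa API):

import numpy as np

# running SWA average: w_swa <- (w_swa * n + w_new) / (n + 1),
# applied element-wise to every layer's weights
def update_swa(swa_weights, new_weights, n_models):
    return [(swa * n_models + new) / (n_models + 1)
            for swa, new in zip(swa_weights, new_weights)]

# example: average three snapshots of a single 2x2 weight matrix
w1 = [np.array([[1.0, 2.0], [3.0, 4.0]])]
w2 = [np.array([[3.0, 2.0], [1.0, 0.0]])]
w3 = [np.array([[2.0, 2.0], [2.0, 2.0]])]

swa = w1
swa = update_swa(swa, w2, 1)
swa = update_swa(swa, w3, 2)
print(swa[0])  # element-wise mean of the three snapshots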

Illustration

Paper: Averaging Weights Leads to Wider Optima and Better Generalization (Izmailov et al., 2018)

Installation

pip install keras-swa

SWA API

Keras callback object for SWA.

Arguments

start_epoch - Starting epoch for SWA.

lr_schedule - Learning rate schedule: 'manual', 'constant' or 'cyclic'.

swa_lr - Learning rate used when averaging weights.

swa_lr2 - Upper bound of learning rate for the cyclic schedule.

swa_freq - Frequency of weight averaging. Only used with the cyclic schedule.

batch_size - Batch size. Only needed in the Keras API when using both batch normalization and a data generator.

verbose - Verbosity mode, 0 or 1.
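As a quick reference, here is a hedged sketch of a callback configured for the cyclic schedule, combining the arguments above (the specific values are illustrative only):

from swa.keras import SWA

# cyclic schedule: the learning rate oscillates between swa_lr and
# swa_lr2, and weights are averaged every swa_freq epochs starting
# from start_epoch
swa = SWA(start_epoch=75,
          lr_schedule='cyclic',
          swa_lr=0.001,   # lower bound of the cycle
          swa_lr2=0.003,  # upper bound of the cycle
          swa_freq=3,     # average every 3 epochs
          verbose=1)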

Batch Normalization

For models with batch normalization, the last epoch will be a forward pass only, i.e. run with the learning rate set to zero. This is because batch normalization uses the running mean and variance of its preceding layer to normalize activations, and SWA invalidates these statistics by abruptly replacing the weights at the end of training. The last epoch is therefore used to reset and recompute the batch normalization statistics for the averaged weights.
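A hedged sketch of how this might look in practice: when the model contains batch normalization layers and is trained from a data generator, the Keras API additionally needs batch_size so the callback can run its final batch normalization epoch (the model definition here is illustrative):

from keras.models import Sequential
from keras.layers import Dense, BatchNormalization

from swa.keras import SWA

# model with batch normalization
model = Sequential()
model.add(Dense(50, input_dim=2, activation='relu'))
model.add(BatchNormalization())
model.add(Dense(3, activation='softmax'))

# batch_size must be passed to SWA when training this model from a
# data generator, so the final forward-pass epoch can recompute the
# batch normalization statistics
swa = SWA(start_epoch=75,
          lr_schedule='constant',
          swa_lr=0.01,
          batch_size=32,
          verbose=1)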

Learning Rate Schedules

The default schedule is 'manual', which leaves the learning rate under the control of an external learning rate scheduler or the optimizer. In that case SWA only affects the final weights, and the learning rate of the last epoch if batch normalization is used. The two predefined schedules, 'constant' and 'cyclic', are illustrated below.

(Figure: the 'constant' and 'cyclic' learning rate schedules)
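With the 'manual' schedule, SWA can be combined with any external scheduler. A minimal sketch, assuming a standard Keras LearningRateScheduler with a step-decay function (the decay function itself is illustrative):

from keras.callbacks import LearningRateScheduler

from swa.keras import SWA

# external step-decay schedule, controlled outside of SWA:
# halve the learning rate every 25 epochs
def step_decay(epoch, lr):
    return lr * 0.5 if epoch > 0 and epoch % 25 == 0 else lr

scheduler = LearningRateScheduler(step_decay)

# with lr_schedule='manual', SWA only averages the weights
swa = SWA(start_epoch=75, lr_schedule='manual', verbose=1)

# pass both callbacks to model.fit, e.g.:
# model.fit(X, y, epochs=100, callbacks=[scheduler, swa])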

Example

For Keras

from sklearn.datasets import make_blobs
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD

from swa.keras import SWA
 
# make dataset
X, y = make_blobs(n_samples=1000, 
                  centers=3, 
                  n_features=2, 
                  cluster_std=2, 
                  random_state=2)

y = to_categorical(y)

# build model
model = Sequential()
model.add(Dense(50, input_dim=2, activation='relu'))
model.add(Dense(3, activation='softmax'))

model.compile(loss='categorical_crossentropy', 
              optimizer=SGD(learning_rate=0.1))

epochs = 100
start_epoch = 75

# define swa callback
swa = SWA(start_epoch=start_epoch, 
          lr_schedule='constant', 
          swa_lr=0.01, 
          verbose=1)

# train
model.fit(X, y, epochs=epochs, verbose=1, callbacks=[swa])

Or for Keras in TensorFlow

from sklearn.datasets import make_blobs
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD

from swa.tfkeras import SWA

# make dataset
X, y = make_blobs(n_samples=1000, 
                  centers=3, 
                  n_features=2, 
                  cluster_std=2, 
                  random_state=2)

y = to_categorical(y)

# build model
model = Sequential()
model.add(Dense(50, input_dim=2, activation='relu'))
model.add(Dense(3, activation='softmax'))

model.compile(loss='categorical_crossentropy', 
              optimizer=SGD(learning_rate=0.1))

epochs = 100
start_epoch = 75

# define swa callback
swa = SWA(start_epoch=start_epoch, 
          lr_schedule='constant', 
          swa_lr=0.01, 
          verbose=1)

# train
model.fit(X, y, epochs=epochs, verbose=1, callbacks=[swa])

Output

Epoch 1/100
1000/1000 [==============================] - 1s 703us/step - loss: 0.7518
Epoch 2/100
1000/1000 [==============================] - 0s 47us/step - loss: 0.5997
...
Epoch 74/100
1000/1000 [==============================] - 0s 31us/step - loss: 0.3913
Epoch 75/100
Epoch 00075: starting stochastic weight averaging
1000/1000 [==============================] - 0s 202us/step - loss: 0.3907
Epoch 76/100
1000/1000 [==============================] - 0s 47us/step - loss: 0.3911
...
Epoch 99/100
1000/1000 [==============================] - 0s 31us/step - loss: 0.3910
Epoch 100/100
1000/1000 [==============================] - 0s 47us/step - loss: 0.3905

Epoch 00100: final model weights set to stochastic weight average
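After training completes, the model holds the averaged weights, so it can be evaluated or saved with the usual Keras calls. A short follow-up to the example above:

# the model now holds the stochastic weight average
loss = model.evaluate(X, y, verbose=0)
model.save('swa_model.h5')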
