Comments (12)

jfsantos commented on May 2, 2024

Since most layers seem to only check whether train is a "true value" (i.e., not zero or False), maybe we could use a special flag to indicate that we are measuring activation statistics in the BatchNormalization layers. The rest of the network should behave as at inference time, but the batch normalization layers should store activations so that statistics can be computed later. I'm just not sure how to deal with their outputs, since if you have more than one BatchNormalization layer, the statistics of the activations at a given layer will depend on the outputs of the layers below it.
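
A minimal sketch of what such a flag could look like (purely illustrative numpy; the class, method, and attribute names are made up and not part of Keras):

```python
import numpy as np

class BatchNormStats:
    """Sketch of a BN-like layer with a 'collect statistics' flag."""

    def __init__(self, epsilon=1e-6):
        self.epsilon = epsilon
        self.collected = []   # raw activations stored while collecting
        self.mean = None      # population statistics, computed later
        self.std = None

    def forward(self, X, collect_stats=False):
        if collect_stats:
            # store activations; statistics are computed in finalize().
            # Passing X through unchanged is only one possible choice for
            # the output while collecting -- this is the open question above.
            self.collected.append(X)
            return X
        # inference: normalize with previously computed statistics
        return (X - self.mean) / (self.std + self.epsilon)

    def finalize(self):
        data = np.concatenate(self.collected, axis=0)
        self.mean = data.mean(axis=0)
        self.std = data.std(axis=0)
        self.collected = []
```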

jfsantos commented on May 2, 2024

Apparently, you are supposed to measure the statistics for one batch normalization layer at a time, starting from the bottom of the network. I have an idea to implement this, and might give it a try today or over the weekend.
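
Roughly, the bottom-up procedure could look like the following sketch (made-up layer interface, not actual Keras code; for simplicity it pushes the whole dataset through in a single pass, finalizing each BatchNormalization layer before propagating to the layers above it):

```python
import numpy as np

def estimate_bn_statistics(layers, X):
    """Walk the network bottom-up, fixing each BN layer's statistics
    from the activations produced by the already-fixed layers below."""
    activations = X
    for layer in layers:
        if getattr(layer, "is_batch_norm", False):
            # this layer's statistics depend only on the layers below it,
            # which have already been finalized
            layer.mean = activations.mean(axis=0)
            layer.std = activations.std(axis=0)
        # propagate using inference-time behaviour, so upper layers see
        # the same inputs they will see at test time
        activations = layer.forward(activations)
    return layers
```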

jfsantos commented on May 2, 2024

There's an initial implementation to fix this in #82.

fchollet commented on May 2, 2024

I've implemented a fix that keeps the layer self-contained and its usage unchanged, using exponential averages. Here's a speed and performance benchmark on the Kaggle Otto challenge, on GPU (network is 93-512-512-512-512-9, with PReLU):

With BatchNormalization:

Train on 52596 samples, validate on 9282 samples
Epoch 0
52596/52596 [==============================] - 35s - loss: 0.7061 - val. loss: 0.6020
Epoch 1
52596/52596 [==============================] - 34s - loss: 0.5996 - val. loss: 0.5720
Epoch 2
52596/52596 [==============================] - 33s - loss: 0.5729 - val. loss: 0.5582
Epoch 3
52596/52596 [==============================] - 35s - loss: 0.5534 - val. loss: 0.5456
Epoch 4
52596/52596 [==============================] - 32s - loss: 0.5354 - val. loss: 0.5360
Epoch 5
52596/52596 [==============================] - 33s - loss: 0.5224 - val. loss: 0.5284
Epoch 6
52596/52596 [==============================] - 33s - loss: 0.5118 - val. loss: 0.5213
Epoch 7
52596/52596 [==============================] - 33s - loss: 0.5014 - val. loss: 0.5130
Epoch 8
52596/52596 [==============================] - 35s - loss: 0.4913 - val. loss: 0.5174
Epoch 9
52596/52596 [==============================] - 33s - loss: 0.4843 - val. loss: 0.5157
Epoch 10
52596/52596 [==============================] - 35s - loss: 0.4743 - val. loss: 0.5056

No BatchNormalization:

Train on 52596 samples, validate on 9282 samples
Epoch 0
52596/52596 [==============================] - 30s - loss: 0.7199 - val. loss: 0.6082
Epoch 1
52596/52596 [==============================] - 28s - loss: 0.6163 - val. loss: 0.5776
Epoch 2
52596/52596 [==============================] - 25s - loss: 0.5854 - val. loss: 0.5564
Epoch 3
52596/52596 [==============================] - 24s - loss: 0.5680 - val. loss: 0.5453
Epoch 4
52596/52596 [==============================] - 25s - loss: 0.5525 - val. loss: 0.5392
Epoch 5
52596/52596 [==============================] - 24s - loss: 0.5434 - val. loss: 0.5346
Epoch 6
52596/52596 [==============================] - 24s - loss: 0.5346 - val. loss: 0.5309
Epoch 7
52596/52596 [==============================] - 24s - loss: 0.5277 - val. loss: 0.5223
Epoch 8
52596/52596 [==============================] - 24s - loss: 0.5182 - val. loss: 0.5243
Epoch 9
52596/52596 [==============================] - 23s - loss: 0.5118 - val. loss: 0.5165
Epoch 10
52596/52596 [==============================] - 23s - loss: 0.5051 - val. loss: 0.5173

BatchNormalization slows things down quite a bit (probably more than it rightfully should) and improves performance by a small margin.
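
For reference, here is a minimal numpy sketch of the exponential-average idea described above (not the actual Keras code; the momentum value is just an illustrative choice):

```python
import numpy as np

class EMANorm:
    """Sketch: batch statistics during training, exponential running
    averages of those statistics at test time."""

    def __init__(self, momentum=0.9, epsilon=1e-6):
        self.momentum = momentum
        self.epsilon = epsilon
        self.running_mean = 0.0
        self.running_std = 1.0

    def train_step(self, batch):
        mean, std = batch.mean(axis=0), batch.std(axis=0)
        # update the running estimates used at inference
        self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mean
        self.running_std = self.momentum * self.running_std + (1 - self.momentum) * std
        # training-time output uses the current batch statistics
        return (batch - mean) / (std + self.epsilon)

    def test_step(self, batch):
        # inference uses the accumulated running estimates
        return (batch - self.running_mean) / (self.running_std + self.epsilon)
```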

fchollet commented on May 2, 2024

For kicks, here's the old implementation of BatchNormalization:

Train on 52596 samples, validate on 9282 samples
Epoch 0
52596/52596 [==============================] - 36s - loss: 0.7061 - val. loss: 0.5956
Epoch 1
52596/52596 [==============================] - 33s - loss: 0.6011 - val. loss: 0.5666
Epoch 2
52596/52596 [==============================] - 35s - loss: 0.5725 - val. loss: 0.5530
Epoch 3
52596/52596 [==============================] - 35s - loss: 0.5518 - val. loss: 0.5430
Epoch 4
52596/52596 [==============================] - 33s - loss: 0.5367 - val. loss: 0.5360
Epoch 5
52596/52596 [==============================] - 33s - loss: 0.5236 - val. loss: 0.5227
Epoch 6
52596/52596 [==============================] - 33s - loss: 0.5120 - val. loss: 0.5128
Epoch 7
52596/52596 [==============================] - 35s - loss: 0.5022 - val. loss: 0.5094
Epoch 8
52596/52596 [==============================] - 33s - loss: 0.4920 - val. loss: 0.5104
Epoch 9
52596/52596 [==============================] - 33s - loss: 0.4827 - val. loss: 0.5146
Epoch 10
52596/52596 [==============================] - 33s - loss: 0.4777 - val. loss: 0.5046

Same speed as the new implementation (unsurprisingly), but remarkably it performs slightly better. I think this might be because switching from batch normalization to global normalization at test time changes the input distributions ever so slightly, which would hurt performance.

I will try to repro this with other datasets and other networks. It's quite possible that keeping batch normalization at both training and testing time is a better algorithm.

Needless to say, the tests are all fully deterministic (so the same initial weights are used in all 3 runs). However, a sample size of 1 is not statistically representative : )

pranv commented on May 2, 2024

Impressive!
You might have already done this, but just to be sure: global averaging is only done in the last epoch, right?

fchollet commented on May 2, 2024

The validation losses are computed after each epoch, and in the global case the normalization uses the mean and std of an exponential average of the inputs over the previous epoch. I still need to check what happens if you use an exact average (computed after each epoch is over).

fchollet commented on May 2, 2024

I added normalization modes to BatchNormalization, to let the user choose between samplewise normalization and featurewise normalization (the default).
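
As a quick illustration of the difference between the two modes (plain numpy, not the layer implementation):

```python
import numpy as np

X = np.random.randn(32, 10)   # batch of 32 samples, 10 features
epsilon = 1e-6

# featurewise (default): statistics per feature, computed over the batch axis
feat_norm = (X - X.mean(axis=0)) / (X.std(axis=0) + epsilon)

# samplewise: statistics per sample, computed over the feature axis
samp_norm = (X - X.mean(axis=1, keepdims=True)) / (X.std(axis=1, keepdims=True) + epsilon)
```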

Regarding the behavior at test time: when it comes to the mean, an exponential average is identical to the exact mean, so we can use that. As for the std, I am not sure what the best way to compute it in an incremental fashion would be; any thoughts?

jfsantos commented on May 2, 2024

We could try something like the online/parallel variance calculation methods described here. I have never used them and this is the first time I have seen anything about this, but it may be something interesting to play with.
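
For example, Welford's online algorithm is one of the standard single-pass variance methods; here is a generic sketch (not tied to any Keras code):

```python
def welford_update(count, mean, m2, new_value):
    """Update the running (count, mean, M2) triple with one observation."""
    count += 1
    delta = new_value - mean
    mean += delta / count
    delta2 = new_value - mean
    m2 += delta * delta2          # running sum of squared deviations
    return count, mean, m2

count, mean, m2 = 0, 0.0, 0.0
for x in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
    count, mean, m2 = welford_update(count, mean, m2, x)

variance = m2 / count             # population variance; use (count - 1) for sample variance
```

The parallel variants work by merging such (count, mean, M2) triples computed on separate chunks of data.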

zzhang-cn commented on May 2, 2024

Hi there!

I wonder if you are able to reproduce the MNIST curves (batch normalization vs. not) claimed in the paper? We seem to have trouble reproducing the gain, although we used our own implementation, written from scratch in numpy.

Thanks!

-zz

bhack commented on May 2, 2024

Another interesting alternative is described at http://arxiv.org/abs/1602.07868

isaacgerg commented on May 2, 2024

@fchollet This looks like a really old bug that has been fixed. Is this true?
