Comments (12)
Since most layers seem to only check whether train is a "true value" (i.e., not zero or False), maybe we could use a special flag to indicate that we are measuring activation statistics in the BatchNormalization layers. The rest of the network should behave as at inference time, but the batch normalization layers should store activations so that statistics can be computed later. I'm just not sure how to deal with their outputs: if you have more than one BatchNormalization layer, the statistics of the activations at a given layer will depend on the outputs of the layers below it.
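The flag idea could look something like this minimal sketch (illustrative names only, not the actual Keras API): a "collect" mode that runs the layer as at inference while storing its inputs, so global statistics can be computed afterwards:

```python
import numpy as np

class BatchNormSketch:
    """Hypothetical minimal BatchNormalization layer illustrating a
    separate 'collect statistics' mode (names are illustrative)."""

    def __init__(self, dim, epsilon=1e-6):
        self.gamma = np.ones(dim)
        self.beta = np.zeros(dim)
        self.epsilon = epsilon
        self.collected = []          # activations stored while collecting
        self.mean = np.zeros(dim)    # global statistics used at test time
        self.std = np.ones(dim)

    def forward(self, X, mode="test"):
        if mode == "train":
            # normalize with the current batch's own statistics
            mean, std = X.mean(axis=0), X.std(axis=0)
        elif mode == "collect":
            # behave as at inference, but remember inputs for later statistics
            self.collected.append(X)
            mean, std = self.mean, self.std
        else:
            mean, std = self.mean, self.std
        return self.gamma * (X - mean) / (std + self.epsilon) + self.beta

    def finalize(self):
        # compute global statistics from everything collected so far
        data = np.concatenate(self.collected, axis=0)
        self.mean, self.std = data.mean(axis=0), data.std(axis=0)
        self.collected = []
```

This also makes the dependency problem visible: while a layer is collecting, it still normalizes with its old statistics, so with stacked BatchNormalization layers you would have to finalize them one at a time, bottom-up.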
from keras.
Apparently, you are supposed to measure the statistics for one batch normalization layer at a time, starting from the bottom of the network. I have an idea to implement this, and might give it a try today or over the weekend.
There's an initial implementation to fix this in #82.
I've implemented a fix that keeps the layer self-contained and its usage unchanged, using exponential averages. Here's a speed and performance benchmark on the Kaggle Otto challenge, on GPU (network is 93-512-512-512-512-9, with PReLU):
With BatchNormalization:
Train on 52596 samples, validate on 9282 samples
Epoch 0
52596/52596 [==============================] - 35s - loss: 0.7061 - val. loss: 0.6020
Epoch 1
52596/52596 [==============================] - 34s - loss: 0.5996 - val. loss: 0.5720
Epoch 2
52596/52596 [==============================] - 33s - loss: 0.5729 - val. loss: 0.5582
Epoch 3
52596/52596 [==============================] - 35s - loss: 0.5534 - val. loss: 0.5456
Epoch 4
52596/52596 [==============================] - 32s - loss: 0.5354 - val. loss: 0.5360
Epoch 5
52596/52596 [==============================] - 33s - loss: 0.5224 - val. loss: 0.5284
Epoch 6
52596/52596 [==============================] - 33s - loss: 0.5118 - val. loss: 0.5213
Epoch 7
52596/52596 [==============================] - 33s - loss: 0.5014 - val. loss: 0.5130
Epoch 8
52596/52596 [==============================] - 35s - loss: 0.4913 - val. loss: 0.5174
Epoch 9
52596/52596 [==============================] - 33s - loss: 0.4843 - val. loss: 0.5157
Epoch 10
52596/52596 [==============================] - 35s - loss: 0.4743 - val. loss: 0.5056
No BatchNormalization:
Train on 52596 samples, validate on 9282 samples
Epoch 0
52596/52596 [==============================] - 30s - loss: 0.7199 - val. loss: 0.6082
Epoch 1
52596/52596 [==============================] - 28s - loss: 0.6163 - val. loss: 0.5776
Epoch 2
52596/52596 [==============================] - 25s - loss: 0.5854 - val. loss: 0.5564
Epoch 3
52596/52596 [==============================] - 24s - loss: 0.5680 - val. loss: 0.5453
Epoch 4
52596/52596 [==============================] - 25s - loss: 0.5525 - val. loss: 0.5392
Epoch 5
52596/52596 [==============================] - 24s - loss: 0.5434 - val. loss: 0.5346
Epoch 6
52596/52596 [==============================] - 24s - loss: 0.5346 - val. loss: 0.5309
Epoch 7
52596/52596 [==============================] - 24s - loss: 0.5277 - val. loss: 0.5223
Epoch 8
52596/52596 [==============================] - 24s - loss: 0.5182 - val. loss: 0.5243
Epoch 9
52596/52596 [==============================] - 23s - loss: 0.5118 - val. loss: 0.5165
Epoch 10
52596/52596 [==============================] - 23s - loss: 0.5051 - val. loss: 0.5173
BatchNormalization slows things down quite a bit (roughly 33s vs 24s per epoch here, probably more than it rightfully should) and improves validation loss by a small margin (0.5056 vs 0.5173 at epoch 10).
For kicks, here's the old implementation of BatchNormalization:
Train on 52596 samples, validate on 9282 samples
Epoch 0
52596/52596 [==============================] - 36s - loss: 0.7061 - val. loss: 0.5956
Epoch 1
52596/52596 [==============================] - 33s - loss: 0.6011 - val. loss: 0.5666
Epoch 2
52596/52596 [==============================] - 35s - loss: 0.5725 - val. loss: 0.5530
Epoch 3
52596/52596 [==============================] - 35s - loss: 0.5518 - val. loss: 0.5430
Epoch 4
52596/52596 [==============================] - 33s - loss: 0.5367 - val. loss: 0.5360
Epoch 5
52596/52596 [==============================] - 33s - loss: 0.5236 - val. loss: 0.5227
Epoch 6
52596/52596 [==============================] - 33s - loss: 0.5120 - val. loss: 0.5128
Epoch 7
52596/52596 [==============================] - 35s - loss: 0.5022 - val. loss: 0.5094
Epoch 8
52596/52596 [==============================] - 33s - loss: 0.4920 - val. loss: 0.5104
Epoch 9
52596/52596 [==============================] - 33s - loss: 0.4827 - val. loss: 0.5146
Epoch 10
52596/52596 [==============================] - 33s - loss: 0.4777 - val. loss: 0.5046
Same speed as the new implementation (unsurprisingly), but remarkably it performs slightly better. I think this might be because switching from batch normalization at training time to global normalization at test time slightly changes the input distributions seen by each layer, which would hurt performance.
I will try to repro this with other datasets and other networks. It's quite possible that keeping batch normalization at both training and testing time is a better algorithm.
Needless to say, the tests are all fully deterministic (so the same initial weights are used in all 3 runs). However, a sample size of 1 is not statistically representative : )
Impressive!
You might have already done this, but just to be sure: the global averaging is only done in the last epoch, right?
The validation losses are computed after each epoch, and in the global case the normalization uses exponential running averages of the per-batch means and stds accumulated over the previous epoch. I still need to check what happens if you use an exact average (computed after each epoch is over).
I added normalization modes to BatchNormalization, to let the user choose between samplewise normalization or featurewise normalization (default).
Regarding the behavior at test time: when it comes to the mean, a running average can reproduce the exact mean, so we can use that. As for the std, I am not sure what would be the best way to compute it in an incremental fashion; any thoughts?
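For the mean, an equally-weighted incremental update does reproduce the exact value; a quick sketch of the recurrence, combining per-batch means (an illustrative helper, not part of the layer):

```python
import numpy as np

def incremental_mean(batch_means, batch_sizes):
    """Incrementally combine per-batch means into the exact overall mean,
    weighting each batch by its size."""
    total, mean = 0, 0.0
    for m, n in zip(batch_means, batch_sizes):
        total += n
        mean += (m - mean) * n / total   # running weighted average
    return mean
```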
We could try something like the online/parallel variance calculation methods described here. I've never used them, and this is the first time I've seen anything about this, but it may be something interesting to play with.
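One of those methods is Welford's online algorithm, which maintains the mean and a sum of squared deviations in a single pass; a minimal sketch:

```python
import numpy as np

def online_mean_var(samples):
    """Welford's online algorithm: one pass over the data,
    numerically stable, no need to store the samples."""
    n, mean, m2 = 0, 0.0, 0.0
    for x in samples:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)   # note: uses the *updated* mean
    variance = m2 / n              # population variance, as np.var computes
    return mean, variance
```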
Hi there!
I wonder if you are able to reproduce the MNIST curves (batch normalization vs. not) claimed in the paper? We seem to have trouble reproducing the gain, although we used our own implementation, written from scratch in NumPy.
Thanks!
-zz
Another interesting alternative at http://arxiv.org/abs/1602.07868
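For reference, that paper (weight normalization) reparameterizes each weight vector as a learned scale times a direction; the core idea fits in a few lines (a sketch of the reparameterization only, not the paper's full training procedure):

```python
import numpy as np

def weight_norm(v, g):
    """Weight normalization: w = g * v / ||v||, decoupling the weight
    vector's norm (the scalar g) from its direction (the vector v)."""
    return g * v / np.linalg.norm(v)
```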
@fchollet This looks like a really old bug that has been fixed. Is this true?