Comments (12)
Would something like this work?
test_layers = []
test_params = []
for layer in model.layers:
    if type(layer) not in [BatchNormalization, Dropout]:
        test_layers.append(layer)
        if len(test_layers) > 1:
            layer.connect(test_layers[-2])
        test_params += [p for p in layer.params]
model.layers = test_layers
model.params = test_params
model.compile(...)
If that's the case, we would only need to add a utility function similar to this one to remove layers of arbitrary types (provided by the user as a list) and redo the connections.
You can't replace Dropout with an identity layer, because Dropout does not behave like an identity layer at test time. It rescales the input by a factor retain_proba, which is necessary since the weights were learned with only a fraction retain_proba of the input activated at a time.
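A minimal numpy sketch of that rescaling argument (illustrative only, not Keras internals): with standard dropout, a unit survives with probability retain_proba during training, so at test time, when nothing is dropped, the activations must be multiplied by retain_proba to keep the same expected magnitude.

```python
import numpy as np

rng = np.random.default_rng(0)
retain_proba = 0.5
x = rng.random(100_000)

# Training: each unit is kept with probability retain_proba.
mask = rng.random(x.shape) < retain_proba
train_out = x * mask

# Test: no units are dropped, so the output is rescaled by retain_proba
# to match the expected magnitude seen during training.
test_out = x * retain_proba

# Both means are ~ retain_proba * E[x]; they agree up to sampling noise.
```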
In general, calling a layer with train=False will always make it production-compliant. This is what happens when you use the methods test, evaluate, predict_proba or predict_classes of a model.
Right, I overlooked this detail. Thanks for the clarification! I didn't want to have batch normalization during test, but it's just a matter of using batch_size=1 in predict_proba and friends.
Statistically, if the weights are learned with batch normalization, wouldn't it be necessary to keep batch normalization at test time, with the same batch size? Otherwise, wouldn't it affect the distribution of the inputs and therefore impact performance?
According to the paper (page 4, section 3.1), during inference you should use the population statistics instead of the mini-batch statistics.
Ok, interesting. In that case maybe we should update the train=False behavior of the layer to reflect this. It would simply return its input unchanged, right?
My prior understanding of this was wrong. The layer actually works the same way, but the normalization is done using the mean and variance of the training data. The variance is the unbiased estimate computed over the size-m mini-batches of the training data: Var[x] = m/(m-1) * E_B[σ²_B]. So, instead of being stochastic (depending on the batch content), the batch normalization layer is a fixed linear transform of the activations during inference.
See Algorithm 2 in the paper, where this is explained more clearly; steps 10 and 11 describe how to use it for inference.
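To make the "fixed linear transform" point concrete, here is a small numpy sketch of the inference-time computation (the function name and values are illustrative, not Keras internals): once the population mean and variance are frozen, batch norm collapses into an affine map y = scale * x + shift with scale = gamma / sqrt(Var[x] + eps).

```python
import numpy as np

def bn_inference(x, pop_mean, pop_var, gamma, beta, eps=1e-5):
    """Batch norm at inference: a fixed affine transform that uses the
    population statistics instead of the current batch's statistics."""
    scale = gamma / np.sqrt(pop_var + eps)
    return scale * x + (beta - scale * pop_mean)

# Population statistics would be accumulated during training; here we
# pick arbitrary values to show the transform is deterministic and does
# not depend on which other samples are in the batch.
x = np.array([0.0, 1.0, 2.0])
y = bn_inference(x, pop_mean=1.0, pop_var=4.0, gamma=2.0, beta=0.5)
```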
Ok, I see. So we should indeed update the test behavior (train=False) of the layer to use the mean and variance of the training data (learned at train=True).
Maybe obvious, but this also means we have to save these as model attributes (so we can serialize them with the layers as well).
@jfsantos Has this been fixed?
@isaacgerg this was fixed a long time ago and I forgot to mark the issue as solved. Sorry!
First of all, I'm sorry to post this here, but I could not find a solution to my problem anywhere on the net, and this seems to be the best place for it.
Let's say I want to train a GRU, and because I need stateful=True, both the batch size and the number of timesteps have to be known in advance.
Using the functional API I would have an Input as follows:
input_1 = Input(batch_shape=(batch_size, time_steps, features))
But when I evaluate the model, I don't want to pass my test data in batches with a fixed number of timesteps. My solution at the moment is to load the saved model and rebuild it with:
input_1 = Input(shape=(None, num_input_dim))
To do that, though, I need a method that goes through every layer of the model and then sets the weights afterwards:
input_1 = Input(shape=(None, num_input_dim))
x1 = input_1
weights = []
for layer in layers:
    if isinstance(layer, keras.layers.GRU):
        x1 = GRU(layer.output_shape[-1], return_sequences=True)(x1)
        weights.append(layer.get_weights())
    elif isinstance(layer, keras.layers.Dense):
        x1 = Dense(layer.output_shape[-1], activation='tanh')(x1)
        weights.append(layer.get_weights())
(This is just an example, and I find this solution very inelegant.)
There must be a better way to redefine the input shape. Can somebody help me out here, please?
From the discussion above, I take it that I do not have to set the layers to stateful=False for testing purposes.
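One common pattern for this situation is the following sketch (my own illustration, not an official Keras utility; build_model, the layer sizes, and the shapes are all hypothetical): define the architecture once as a function, build a stateful training model with a fixed batch_shape and a non-stateful inference model with free timesteps, then copy the learned weights across with get_weights/set_weights.

```python
import numpy as np
from keras import layers, models

def build_model(batch_shape=None, input_shape=None, stateful=False):
    """Build the same architecture either with a fixed batch_shape
    (stateful training) or a flexible shape (inference)."""
    if batch_shape is not None:
        inp = layers.Input(batch_shape=batch_shape)
    else:
        inp = layers.Input(shape=input_shape)
    x = layers.GRU(8, return_sequences=True, stateful=stateful)(inp)
    out = layers.Dense(1, activation='tanh')(x)
    return models.Model(inp, out)

# Stateful training model: batch size and timesteps fixed in advance.
train_model = build_model(batch_shape=(4, 10, 3), stateful=True)
# Inference model: any batch size, any number of timesteps.
infer_model = build_model(input_shape=(None, 3), stateful=False)

# The layer order matches because the architectures are identical, so
# the full weight lists line up one-to-one.
infer_model.set_weights(train_model.get_weights())

# The inference model now accepts shapes the training model could not.
y = infer_model.predict(np.zeros((1, 25, 3)), verbose=0)
```

This avoids iterating over layers by hand: because only the input shape differs, set_weights on the whole model transfers everything in one call.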