
Comments (12)

jfsantos commented on May 2, 2024

Would something like this work?

test_layers = []
test_params = []
for layer in model.layers:
    # keep only layers that should run at test time
    if type(layer) not in [BatchNormalization, Dropout]:
        test_layers.append(layer)
        if len(test_layers) > 1:
            # re-wire this layer to the previous kept layer
            layer.connect(test_layers[-2])
        test_params += layer.params

model.layers = test_layers
model.params = test_params
model.compile(...)

If that's the case, we would only need to add a utility function like this one to remove layers of arbitrary types (provided by the user as a list) and redo the connections.


fchollet commented on May 2, 2024

You can't replace Dropout with an identity layer, because Dropout does not behave like an identity at test time: it rescales its input by a factor retain_proba, which is necessary since the weights were learned with only a fraction retain_proba of the input units active at a time.

In general, calling a layer with train=False will always make it production-compliant. This is what happens when you use a model's test, evaluate, predict_proba, or predict_classes methods.
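To illustrate the point about test-time behavior, here is a minimal NumPy sketch of classic (non-inverted) dropout as described above. The value of retain_proba and the function itself are illustrative, not Keras code:

```python
import numpy as np

rng = np.random.default_rng(0)
retain_proba = 0.5  # fraction of units kept during training (illustrative)

def dropout(x, train):
    """Classic (non-inverted) dropout: stochastic mask at train time,
    deterministic rescaling by retain_proba at test time."""
    if train:
        mask = rng.random(x.shape) < retain_proba
        return x * mask
    # Not the identity at test time: rescale so the expected activation
    # matches what the downstream weights saw during training.
    return x * retain_proba

x = np.ones(4)
print(dropout(x, train=False))  # [0.5 0.5 0.5 0.5]
```

At train time each unit is kept with probability retain_proba and passed through unscaled, so the expected activation is retain_proba * x; multiplying by retain_proba at test time reproduces that expectation deterministically.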


jfsantos commented on May 2, 2024

Right, I overlooked this detail. Thanks for the clarification! I didn't want batch normalization at test time anyway, but that's just a matter of using batch_size=1 in predict_proba and friends.


fchollet commented on May 2, 2024

Statistically, if the weights are learned with batch normalization, wouldn't it be necessary to keep batch normalization at test time, with the same batch size? Otherwise, wouldn't it affect the distribution of the inputs and therefore impact performance?


jfsantos commented on May 2, 2024

According to the paper (page 4, section 3.1), during inference you should use the population statistics instead of the mini-batch statistics.


fchollet commented on May 2, 2024

Ok, interesting. In that case maybe we should update the train=False behavior of the layer to reflect this. It would simply return its input unchanged, right?


jfsantos commented on May 2, 2024

My prior understanding of this was wrong. The layer actually works the same way at inference, but the normalization uses the mean and variance of the training data. The variance is computed with the unbiased estimate over the size-m mini-batches seen during training: Var[x] = m/(m-1) * E_B[sigma^2_B]. So instead of being stochastic (depending on the batch contents), the batch normalization layer is a fixed linear transform over the activations during inference.

See Algorithm 2 in the paper, where this is explained more clearly; steps 10 and 11 show how to use it for inference.
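A NumPy sketch of that inference-time behavior. The statistics and learned parameters (pop_mean, pop_var, gamma, beta) are made-up values, just to show that the layer reduces to a fixed affine map of the activations:

```python
import numpy as np

# Illustrative (made-up) population statistics and learned parameters.
# E[x] comes from averaging mini-batch means; Var[x] uses the unbiased
# correction over size-m mini-batches: Var[x] = m/(m-1) * E_B[sigma^2_B].
pop_mean, pop_var = 0.3, 1.7
gamma, beta, eps = 2.0, 0.5, 1e-5

def bn_inference(x):
    # A fixed affine transform of the activations:
    # y = gamma * (x - E[x]) / sqrt(Var[x] + eps) + beta
    scale = gamma / np.sqrt(pop_var + eps)
    return scale * x + (beta - scale * pop_mean)

# Each element's result is independent of batch size or batch content.
print(bn_inference(np.array([0.3])))  # [0.5]
```

Because scale and the offset are constants at inference, feeding the same value in batches of different sizes always produces the same output, which is exactly why batch_size no longer matters at test time.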


fchollet commented on May 2, 2024

Ok, I see. So we should indeed update the test behavior (train=False) of the layer to use the mean and variance of the training data (learned at train=True).


jfsantos commented on May 2, 2024

Maybe obvious, but this also means we have to save these as model attributes (so we can serialize them with the layers as well).
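A toy sketch of that point, assuming a hand-rolled layer object rather than Keras's actual implementation: keeping the statistics as attributes of the layer means they round-trip through serialization along with the rest of the layer state.

```python
import pickle

import numpy as np

class BatchNormState:
    """Minimal sketch (not Keras's actual class): the population
    statistics live on the layer object, so serializing the layer
    also serializes them."""
    def __init__(self, dim):
        self.gamma = np.ones(dim)          # learned scale
        self.beta = np.zeros(dim)          # learned shift
        self.running_mean = np.zeros(dim)  # E[x], accumulated during training
        self.running_var = np.ones(dim)    # Var[x], accumulated during training

layer = BatchNormState(4)
layer.running_mean[:] = 0.25  # pretend training updated the statistics

restored = pickle.loads(pickle.dumps(layer))
print(restored.running_mean)  # [0.25 0.25 0.25 0.25]
```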


isaacgerg commented on May 2, 2024

@jfsantos Has this been fixed?


jfsantos commented on May 2, 2024

@isaacgerg this was fixed a long time ago, and I forgot to mark the issue as solved. Sorry!


Nimi42 commented on May 2, 2024

First off, I'm sorry to post this here, but I could not find a solution to my problem anywhere else and this seems like the best place for it.

Let's say I want to train a GRU. Because I need stateful=True, the batch size and the number of timesteps both have to be known in advance.

Using the functional API I would have an Input as follows:

input_1 = Input(batch_shape=(batch_size, time_steps, features))

But when I evaluate the model, I don't want to pass my test data in batches with a fixed number of timesteps. My current solution is to load the saved model and rebuild it with:

input_1 = Input(shape=(None, num_input_dim))

To do that, though, I need a method that goes through every layer of the model and then sets the weights afterwards.

input_1 = Input(shape=(None, num_input_dim))
x1 = input_1
weights = []
for layer in model.layers:
    if isinstance(layer, keras.layers.GRU):
        x1 = GRU(layer.output_shape[-1], return_sequences=True)(x1)
        weights.append(layer.get_weights())
    elif isinstance(layer, keras.layers.Dense):
        x1 = Dense(layer.output_shape[-1], activation='tanh')(x1)
        weights.append(layer.get_weights())

# rebuild the model and copy the trained weights into the new layers
new_model = Model(input_1, x1)
rebuilt = [l for l in new_model.layers
           if isinstance(l, (keras.layers.GRU, keras.layers.Dense))]
for new_layer, w in zip(rebuilt, weights):
    new_layer.set_weights(w)

(This is just an example, and I find this solution very inelegant.)

There must be a better way to redefine the input shape. Can somebody help me out here, please?


From the discussion above, I take it that I do not have to redefine the layers with

stateful = False

for testing purposes.
