Hi @fchollet, thanks for the quick fix, but I can't find it at the HEAD of the master branch.
So are model_a and model_b in the first snippet actually equivalent? Or will they behave differently (even though they have the same number of parameters) if I do something like:
import keras
from keras.layers import RNN, LSTM, LSTMCell

# model_a: two chained LSTM layers; the second layer is initialized
# with the final [h, c] states of the first layer.
inputs = keras.Input(shape=(5, 10))
first_lstm_layer_out, *cell_states = LSTM(10, return_sequences=True, return_state=True)(inputs)
second_lstm_layer_out = LSTM(10)(first_lstm_layer_out, initial_state=cell_states)
model_a = keras.Model(inputs, second_lstm_layer_out)
model_a.summary()

# model_b: two LSTM cells stacked inside a single RNN layer.
inputs = keras.Input(shape=(5, 10))
stacked_lstm_outputs = RNN([LSTMCell(10), LSTMCell(10)])(inputs)
model_b = keras.Model(inputs, stacked_lstm_outputs)
model_b.summary()
From what I understand, the main difference should be that model_b returns the states of both LSTM layers, while model_a returns only the final ones (as expected).
Yes, that's right. The second model would return 4 state tensors (2 per cell).
But in the stacked implementation of model_b, are the states of the first layer used to initialize the states of the second one?
No, states are initialized at zero by each cell. To get a non-zero state you would have to pass the initial state when calling the layer.
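To make that answer concrete, here is a pure-Python sketch (not Keras code) that uses a hypothetical scalar `ToyCell` computing h' = tanh(w*x + u*h) in place of a real LSTMCell; the weights and inputs are arbitrary. A real LSTMCell carries two state tensors (h and c), while the toy keeps one state per cell to stay readable. It compares the stacked setup of model_b, where every cell starts from zero, with the chained setup of model_a, where the second layer is seeded with the first layer's final state.

```python
import math


class ToyCell:
    """Hypothetical scalar RNN cell standing in for an LSTMCell: h' = tanh(w*x + u*h)."""

    def __init__(self, w, u):
        self.w = w
        self.u = u

    def step(self, x, h):
        h_new = math.tanh(self.w * x + self.u * h)
        return h_new, h_new  # (output, new_state)


def run_layer(cell, xs, h0=0.0):
    """Run one cell over the whole sequence; return (per-step outputs, final state)."""
    outputs, h = [], h0
    for x in xs:
        y, h = cell.step(x, h)
        outputs.append(y)
    return outputs, h


def chained_layers(cell_1, cell_2, xs):
    """Like model_a: layer 2 starts from layer 1's *final* state."""
    seq_1, state_1 = run_layer(cell_1, xs, h0=0.0)
    seq_2, _ = run_layer(cell_2, seq_1, h0=state_1)
    return seq_2[-1]


def stacked_cells(cell_1, cell_2, xs):
    """Like model_b: both cells advance together at each timestep,
    each with its own state, and both states start at zero."""
    h_1, h_2 = 0.0, 0.0
    y = None
    for x in xs:
        y_1, h_1 = cell_1.step(x, h_1)
        y, h_2 = cell_2.step(y_1, h_2)
    return y, [h_1, h_2]  # final output plus one state per cell


cell_1, cell_2 = ToyCell(0.8, 0.5), ToyCell(-1.2, 0.9)
xs = [1.0, -0.5, 0.3, 2.0, 0.0]

stacked_out, stacked_states = stacked_cells(cell_1, cell_2, xs)

# Stacking matches running the layers one after another, each from a zero state...
seq_1, _ = run_layer(cell_1, xs)
seq_2, _ = run_layer(cell_2, seq_1)
print(math.isclose(stacked_out, seq_2[-1]))  # True

# ...but not model_a, where layer 2 is seeded with layer 1's final state.
print(math.isclose(stacked_out, chained_layers(cell_1, cell_2, xs)))  # False
```

Stacking cells in one RNN only changes the loop order over the same per-cell computations, so it matches chained layers that each start from zero. Seeding the second cell with the first layer's final state, as model_a does, cannot happen inside the lockstep loop, because that state only exists after the first layer has consumed the whole sequence.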
Is this a Keras issue?
Yes, that's actually a bug. I've fixed it at HEAD. Check that it works for you.
Note that since your layer returns (outputs, [cell_1_state_1, cell_1_state_2], [cell_2_state_1, cell_2_state_2]), you cannot use it with a Sequential model. Instead, you could do something like:
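import keras

inputs = keras.Input(shape=(5, 10))
# With return_state=True, the stacked RNN returns the output followed by
# one state list per cell.
outputs, cell_1_states, cell_2_states = keras.layers.RNN(
    [keras.layers.LSTMCell(10), keras.layers.LSTMCell(10)],
    return_state=True,
)(inputs)
# Functional models need a flat list of outputs.
model = keras.Model(inputs, [outputs, *cell_1_states, *cell_2_states])
model.summary()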
Yes, there's only one gotcha: Functional model inputs and outputs must be flat, and here stacked_lstm_outputs is nested. You have to flatten it (like in my example above). If you want to keep it structured, write a subclassed model.
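The flattening step can be sketched in plain Python. Strings stand in for the symbolic tensors, and the `flatten` helper is hypothetical (it is not a Keras API); it walks the nested structure depth-first, which yields the output first and then each cell's states in order, the same order as listing them by hand.

```python
def flatten(structure):
    """Hypothetical helper: recursively flatten nested lists/tuples, depth-first."""
    if isinstance(structure, (list, tuple)):
        flat = []
        for item in structure:
            flat.extend(flatten(item))
        return flat
    return [structure]


# Placeholders for what the stacked RNN returns with return_state=True:
# (outputs, [h, c] of cell 1, [h, c] of cell 2).
nested = ("outputs", ["cell_1_h", "cell_1_c"], ["cell_2_h", "cell_2_c"])
print(flatten(nested))
# ['outputs', 'cell_1_h', 'cell_1_c', 'cell_2_h', 'cell_2_c']
```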