Comments (1)
def forward(self, x):
    # Set initial states, each of shape (num_layers, batch_size, hidden_size).
    # torch.autograd.Variable is deprecated; plain tensors work directly.
    h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
    c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
    # Forward propagate the LSTM
    out, (h_last, c_last) = self.lstm(x, (h0, c0))  # same as self.lstm(x); the default initial state is zeros
    # Decode the hidden state of the last time step
    out = self.fc(out[:, -1, :])  # out[:, -1, :] equals h_last[-1] (the last layer's final hidden state)
    return out
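The comment in the code says that passing zero initial states is the same as passing none. A minimal sketch to check this, assuming an nn.LSTM built with batch_first=True and hypothetical sizes that are not taken from the original model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for illustration only.
batch_size, seq_length, input_size = 4, 28, 28
hidden_size, num_layers = 16, 2

lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
x = torch.randn(batch_size, seq_length, input_size)

# Explicit zero initial states, shaped (num_layers, batch_size, hidden_size).
h0 = torch.zeros(num_layers, batch_size, hidden_size)
c0 = torch.zeros(num_layers, batch_size, hidden_size)

with torch.no_grad():
    out_explicit, (h_explicit, c_explicit) = lstm(x, (h0, c0))
    out_default, (h_default, c_default) = lstm(x)  # initial states omitted

# Both calls should produce the same outputs and final states.
same_output = torch.allclose(out_explicit, out_default)
same_hidden = torch.allclose(h_explicit, h_default)
same_cell = torch.allclose(c_explicit, c_default)
print(same_output, same_hidden, same_cell)
```

So the explicit h0 and c0 are only needed when you want a non-zero starting state.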
@sinhasam Note that images (the x passed to forward) is a 3D tensor of shape (batch_size, seq_length, input_size). In out, (h_last, c_last) = lstm(x, (h0, c0)), with batch_first=True, out is a tensor of shape (batch_size, seq_length, hidden_size), and h_last and c_last are tensors of shape (num_layers, batch_size, hidden_size) that hold each layer's final hidden and cell state after the LSTM has processed the whole sequence (forward propagation over all time steps, not a single step). You can also see here to understand the mechanism of nn.LSTM.
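The shapes discussed above can be checked directly. This is a small sketch with hypothetical sizes, assuming a unidirectional nn.LSTM created with batch_first=True; note that for a multi-layer LSTM, h_last and c_last carry a leading num_layers dimension:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
batch_size, seq_length, input_size = 4, 28, 28
hidden_size, num_layers = 16, 2

lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
x = torch.randn(batch_size, seq_length, input_size)

with torch.no_grad():
    out, (h_last, c_last) = lstm(x)

out_shape = tuple(out.shape)    # (batch_size, seq_length, hidden_size)
h_shape = tuple(h_last.shape)   # (num_layers, batch_size, hidden_size)
c_shape = tuple(c_last.shape)   # (num_layers, batch_size, hidden_size)
print(out_shape, h_shape, c_shape)
```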
In addition, in the many-to-one case we need only the last hidden state of the LSTM. For this, we can use out[:, -1, :] or h_last[-1] (the last layer's entry of h_last). In our case (MNIST), each sequence in x has a fixed length of 28, so out[:, -1, :] and h_last[-1] are exactly the same. For variable-length many-to-one, you can use pack_padded_sequence and h_last. Please see here for the details.
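A hedged sketch of both cases follows. The sizes and the lengths tensor are made up for illustration, and the LSTM is assumed to be unidirectional with batch_first=True. It first checks that the last time step of out matches the last layer of h_last for fixed-length input, then shows that with padded variable-length input the final state returned for a packed sequence corresponds to each sequence's last valid step, while the last padded step does not:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Hypothetical sizes for illustration only.
batch_size, max_length, input_size = 3, 5, 8
hidden_size, num_layers = 6, 2

lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
x = torch.randn(batch_size, max_length, input_size)

with torch.no_grad():
    # Fixed length: the last time step of out equals the last layer of h_last.
    out, (h_last, _) = lstm(x)
    fixed_match = torch.allclose(out[:, -1, :], h_last[-1])

    # Variable length: hypothetical true lengths of each padded sequence (CPU tensor).
    lengths = torch.tensor([3, 5, 2])
    packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    packed_out, (h_packed, _) = lstm(packed)
    padded_out, _ = pad_packed_sequence(packed_out, batch_first=True)

    # Gather the output at each sequence's last valid time step.
    idx = (lengths - 1).view(-1, 1, 1).expand(-1, 1, hidden_size)
    last_valid = padded_out.gather(1, idx).squeeze(1)

    # h_packed[-1] matches the last valid step of every sequence...
    packed_match = torch.allclose(last_valid, h_packed[-1])
    # ...whereas the last padded step is zero-filled for the shorter sequences.
    naive_match = torch.allclose(padded_out[:, -1, :], h_packed[-1])

print(fixed_match, packed_match, naive_match)
```

This is why h_last is the convenient choice for variable-length input: it avoids indexing the padded output by hand.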
from pytorch-tutorial.