tensorflow-phased-lstm's Introduction

Phased LSTM: Accelerating Recurrent Network Training for Long or Event-based Sequences (NIPS 2016)

TensorFlow has released an official implementation of the Phased LSTM. I wrote a script that shows how to use it:

https://github.com/philipperemy/tensorflow-phased-lstm/blob/master/official_tensorflow_phased_lstm.py



Training on MNIST digit classification with Phased LSTM and Basic LSTM (official TF implementation)


NOTE: You can still use this alternative implementation (tested on TensorFlow v1.10). The code is very similar, and it is as fast as the official one (0.8 seconds for a forward-backward pass on a Titan X Maxwell GPU).

How to use it?

git clone git@github.com:philipperemy/tensorflow-phased-lstm.git plstm
cd plstm
sudo pip install -r requirements.txt
# Make sure at least TensorFlow 1.2.0 is installed.
# To reproduce the Phased LSTM results on the MNIST dataset:
python mnist_phased_lstm.py -m BasicLSTMCell
python mnist_phased_lstm.py -m PhasedLSTMCell

Phased LSTM

The Phased LSTM model extends the LSTM model by adding a new time gate, k_t (Fig. 1(b)). The opening and closing of this gate is controlled by an independent rhythmic oscillation specified by three parameters; updates to the cell state c_t and h_t are permitted only when the gate is open. The first parameter, τ, controls the real-time period of the oscillation. The second, r_on, controls the ratio of the duration of the “open” phase to the full period. The third, s, controls the phase shift of the oscillation to each Phased LSTM cell.
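To make the three parameters concrete, here is a minimal NumPy sketch of the time gate as defined in the Phased LSTM paper (arXiv:1610.09513): the phase is φ_t = ((t - s) mod τ) / τ, the gate rises linearly during the first half of the open phase, falls during the second half, and leaks with a small slope α while closed. The function names are hypothetical, and this is an illustration of the equations, not the code in this repository or in TensorFlow.

```python
import numpy as np

def phase(t, tau, s):
    """Fraction of the current oscillation cycle: phi_t = ((t - s) mod tau) / tau."""
    return np.mod(t - s, tau) / tau

def time_gate(t, tau, s, r_on, alpha=0.001):
    """Piecewise-linear time gate k_t (hypothetical helper, following the paper).

    Rises from 0 to 1 while phi < r_on / 2, falls back to 0 while
    r_on / 2 <= phi < r_on, and leaks with slope alpha while closed.
    """
    phi = phase(t, tau, s)
    rising = 2.0 * phi / r_on
    falling = 2.0 - 2.0 * phi / r_on
    closed = alpha * phi
    return np.where(phi < 0.5 * r_on, rising,
                    np.where(phi < r_on, falling, closed))

def masked_update(k, proposed, previous):
    """Only let the state change in proportion to the gate: k * new + (1 - k) * old."""
    return k * proposed + (1.0 - k) * previous

# Period 10, no shift, open 10% of the time: the gate is only open for t mod 10 in [0, 1).
k = time_gate(np.array([0.25, 0.75, 5.0, 10.25]), tau=10.0, s=0.0, r_on=0.1)
print(k)
```

With τ = 10 and r_on = 0.1, the gate is half open at t = 0.25 and t = 0.75, nearly closed (α · φ = 0.0005) at t = 5, and the pattern repeats every 10 time units. While k is close to 0, masked_update keeps the previous state almost unchanged, which is what "updates are permitted only when the gate is open" means.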







Results on the MNIST dataset

Here are the training-set results on the MNIST dataset, where each digit is treated as a long sequence. With 32 cells, the Basic LSTM clearly fails to learn, whereas the Phased LSTM performs well.

Training Accuracy



Training Loss



The Phased LSTM has many surprising advantages. With its rhythmic periodicity, it acts like a learnable, gated Fourier transform on its input, permitting very fine timing discrimination. Alternatively, the rhythmic periodicity can be viewed as a kind of persistent dropout that preserves state [27], enhancing model diversity. The rhythmic inactivation can even be viewed as a shortcut to the past for gradient backpropagation, accelerating training. The presented results support these interpretations, demonstrating the ability to discriminate rhythmic signals and to learn long memory traces. Importantly, in all experiments, Phased LSTM converges more quickly and theoretically requires only 5% of the computes at runtime, while often improving in accuracy compared to standard LSTM. The presented methods can also easily be extended to GRUs [6], and it is likely that even simpler models, such as ones that use a square-wave-like oscillation, will perform well, thereby making even more efficient and encouraging alternative Phased LSTM formulations. An inspiration for using oscillations in recurrent networks comes from computational neuroscience [3], where rhythms have been shown to play important roles for synchronization and plasticity [22]. Phased LSTMs were not designed as biologically plausible models, but may help explain some of the advantages and robustness of learning in large spiking recurrent networks.

From: https://arxiv.org/pdf/1610.09513v1.pdf

tensorflow-phased-lstm's People

Contributors

philipperemy, tomrunia

tensorflow-phased-lstm's Issues

API 1.0

Dear Philippe,

Your work is amazing.
Do you plan to update it to TensorFlow 1.0, which was just released?

Thanks.
Cheers.

No experiment code for the 4 tasks in the paper?

Hello, the code here only implements the PhasedLSTMCell, right? Could you also provide the experiment code for the tasks in the paper, as the original Theano code does? Thanks.

wrong gradient for mod?

In phased_lstm.py, the gradient of the modulo operation is defined as:

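import tensorflow as tf
from tensorflow.python.framework import ops


@ops.RegisterGradient('FloorMod')
def _mod_grad(op, grad):
    x, y = op.inputs
    gz = grad
    # d(x mod y)/dx = 1, so the upstream gradient passes through unchanged
    x_grad = gz
    # d(x mod y)/dy = -floor(x / y), here averaged over axis 0
    y_grad = tf.reduce_mean(-(x // y) * gz, axis=[0], keep_dims=True)[0]
    return x_grad, y_grad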

Why reduce_mean and not reduce_sum? As I understand it, broadcasting y to the size of x makes multiple copies of y, so the gradient has to be summed over all of those uses, as in a convolution. Is that right?
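One way to check the question numerically is a finite-difference test. The NumPy sketch below assumes y has shape [features] and is broadcast along axis 0 of x, and that no x / y lands exactly on an integer (where floor mod is discontinuous). The helper names are hypothetical; this is not the repository's code. It compares the summed analytic gradient with a central finite difference.

```python
import numpy as np

def floormod_grads(x, y, upstream):
    """Analytic gradients of z = x mod y (floor mod) with y broadcast along axis 0.

    dz/dx = 1 and dz/dy = -floor(x / y) almost everywhere. Each entry of y is
    reused by every row of x, so its gradient SUMS the contributions over axis 0.
    """
    grad_x = upstream
    grad_y = np.sum(-np.floor(x / y) * upstream, axis=0)
    return grad_x, grad_y

def numerical_grad_y(x, y, upstream, eps=1e-6):
    """Central finite difference of sum(upstream * (x mod y)) with respect to y."""
    grad = np.zeros_like(y)
    for j in range(y.shape[0]):
        step = np.zeros_like(y)
        step[j] = eps
        f_plus = np.sum(upstream * np.mod(x, y + step))
        f_minus = np.sum(upstream * np.mod(x, y - step))
        grad[j] = (f_plus - f_minus) / (2 * eps)
    return grad

# Values chosen so that no x / y is close to an integer.
x = np.array([[7.3, 2.2, 9.1], [1.4, 5.5, 3.3], [8.8, 0.7, 6.6], [4.1, 3.9, 2.5]])
y = np.array([2.5, 1.5, 4.0])
upstream = np.ones_like(x)
_, analytic = floormod_grads(x, y, upstream)
print(analytic, numerical_grad_y(x, y, upstream))
```

In this check the summed gradient matches the finite difference, while dividing by the number of rows (what a mean over axis 0 does) does not. This supports the reduce_sum reading when y is broadcast over a batch axis.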

not sure how to use it with my own dataset

Hey,
I'm not sure about the shape of the MNIST data used in your code.

I have 9 sensors that record simultaneously for 3072 time steps, and each recording belongs to one of 3 possible classes.

That means:

  • x values are 9x3072
  • y values are 1x3

I want to use your code to predict the correct class.

I'm sure it is a minor change, but I'm struggling.

What do you think?

Thanks
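A possible starting point is reshaping the recordings, sketched below in NumPy. It assumes the model expects the batch-major layout [batch, time, features] used by TensorFlow's dynamic_rnn, and that the Phased LSTM cell also needs one timestamp per step. As far as I know, the official tf.contrib.rnn.PhasedLSTMCell takes its input as a (times, features) tuple. Uniformly spaced timestamps (dt) and the helper name to_sequence_batch are assumptions for illustration.

```python
import numpy as np

def to_sequence_batch(recordings, labels, num_classes=3, dt=1.0):
    """Hypothetical helper: turn sensor recordings into RNN-ready arrays.

    recordings: [batch, sensors, time], e.g. [N, 9, 3072]
    labels:     integer class ids, shape [batch]
    Returns:
      x:     [batch, time, sensors], one 9-dim feature vector per time step
      times: [batch, time, 1], timestamps assuming a uniform step of dt
      y:     [batch, num_classes], one-hot targets
    """
    # Put time on axis 1 so the RNN reads one sensor vector per step.
    x = np.transpose(recordings, (0, 2, 1)).astype(np.float32)
    batch, steps, _ = x.shape
    t = np.arange(steps, dtype=np.float32) * dt
    times = np.broadcast_to(t[None, :, None], (batch, steps, 1)).copy()
    # One-hot encode the class ids (the 1x3 y values).
    y = np.eye(num_classes, dtype=np.float32)[labels]
    return x, times, y

recordings = np.random.rand(2, 9, 3072)
x, times, y = to_sequence_batch(recordings, np.array([0, 2]))
print(x.shape, times.shape, y.shape)
```

The key change from a 9x3072 matrix is the transpose. A recurrent cell consumes one time step at a time, so each of the 3072 steps should carry the 9 sensor readings as its feature vector.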

Any idea for regression

Hi, nice work.
I am wondering whether there is a Phased LSTM variant for regression problems.
Could you offer me any advice?
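A common generic way to adapt a recurrent classifier to regression is to replace the softmax output with a linear read-out and train with mean squared error. The NumPy sketch below only shows that head. It assumes the model exposes its final hidden state h_T with shape [batch, hidden]. It is not something this repository provides, and the function names are hypothetical.

```python
import numpy as np

def regression_head(h_last, weights, bias):
    """Linear read-out from the final hidden state: [batch, hidden] -> [batch, outputs]."""
    return h_last @ weights + bias

def mse_loss(predictions, targets):
    """Mean squared error, used here in place of softmax cross-entropy."""
    return float(np.mean((predictions - targets) ** 2))

h_last = np.ones((2, 4))           # stand-in for the RNN's last hidden state
weights = np.full((4, 1), 0.5)     # one real-valued output per example
bias = np.array([1.0])
preds = regression_head(h_last, weights, bias)
print(preds.ravel(), mse_loss(preds, np.array([[3.0], [5.0]])))
```

The recurrent part, including the time gate, stays the same. Only the output layer and the loss change.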

initial state should not be plain tuple

env:
ubuntu: 16.04
tf_version: 1.3.0
python: 2.7

When running mnist_phased_lstm.py -m BasicLSTMCell, I get this error:

TypeError: The two structures don't have the same sequence type. First structure has type <type 'tuple'>, while second structure has type <class 'tensorflow.python.ops.rnn_cell_impl.LSTMStateTuple'>.

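Easy fix: replace the tuple initial_states at line 22 with:

import tensorflow as tf
from tensorflow.contrib import rnn  # batch_size and hidden_size are defined in the script
initial_states = rnn.LSTMStateTuple(tf.random_normal([batch_size, hidden_size], stddev=0.1),
                                    tf.random_normal([batch_size, hidden_size], stddev=0.1))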

Why is leak_rate = 1.0 during test?

if not training_phase:
    leak_rate = 1.0

This appears in time_gate_fast in phased_lstm.py, but the paper says:
'Unless otherwise specified, the leak rate was set to
α = 0.001 during training and α = 0 during test'

module 'tensorflow.python.ops.rnn_cell_impl' has no attribute '_linear' - SOLVED

I received this error:

Traceback (most recent call last):
  File "mnist_phased_lstm.py", line 12, in <module>
    from phased_lstm import PhasedLSTMCell
  File "C:\Users\HILA\il_google_drive\code\PLSTM\tensorflow-phased-lstm\phased_lstm.py", line 13, in <module>
    _linear = rnn_cell_impl._linear
AttributeError: module 'tensorflow.python.ops.rnn_cell_impl' has no attribute '_linear'

I'm using TF 1.1 with Python 3.5.
I solved it by changing the _linear definition in phased_lstm.py.
I added:
from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear
I removed:
_linear = rnn_cell_impl._linear

Optimize the way Phi is computed.

Hi Philippe,

Thank you very much for your implementation; it is really helpful. While going through the code, I found that your computation of Phi differs from the equation in the paper, because you apply tf.mod twice.

Current code:
def phi(times, s, tau): return tf.div(tf.mod(tf.mod(times - s, tau) + tau, tau), tau)

Implementation following Equation 6 in the paper:
def phi(times, s, tau): return tf.div(tf.mod(times - s, tau), tau)

Is there any reason why you did this modification? Or is it an error?

Thanks,
Victor
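A quick way to compare the two formulas is NumPy, whose np.mod is a floor modulo (the result takes the sign of the divisor), like tf.mod / tf.floormod. np.fmod is the truncated, C-style remainder. The sketch below assumes τ > 0. Under floor modulo, (t - s) mod τ already lies in [0, τ), so the extra "+ τ, mod τ" is redundant up to rounding. The double form only matters with a truncated remainder, where it maps negative results back into [0, τ). Function names are hypothetical.

```python
import numpy as np

def phi_double_mod(times, s, tau):
    # Form from the current code: ((t - s) mod tau + tau) mod tau, divided by tau.
    return np.mod(np.mod(times - s, tau) + tau, tau) / tau

def phi_paper(times, s, tau):
    # Equation 6 in the paper: ((t - s) mod tau) / tau.
    return np.mod(times - s, tau) / tau

def phi_truncated(times, s, tau):
    # Same as phi_paper but with a truncated remainder, which keeps the sign of t - s.
    return np.fmod(times - s, tau) / tau

def phi_truncated_fixed(times, s, tau):
    # The "+ tau, then mod again" trick is what repairs the truncated version.
    return np.fmod(np.fmod(times - s, tau) + tau, tau) / tau

times = np.array([-7.5, -0.2, 0.0, 3.0, 12.4])
s, tau = 1.0, 5.0
print(phi_paper(times, s, tau))       # all values in [0, 1)
print(phi_truncated(times, s, tau))   # negative entries where t - s < 0
```

So with a floor-mod operation, the single tf.mod from Equation 6 should be enough. The double mod looks like a safeguard carried over from truncated-remainder semantics rather than a different formula.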
