Keras RNN Timeseries

This Python project uses LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit) based Recurrent Neural Networks to forecast (predict) timeseries using Keras + Theano. We compare the results produced by each of these deep neural networks with those from a linear regression model.

Dataset: Number of daily births in Quebec, Jan. 01 '77 - Dec. 31 '90 (Hipel & McLeod, 1994)

## Usage

I suggest installing Virtualenv before trying this out.

    git clone https://github.com/dhrushilbadani/deeplearning-timeseries.git
    cd deeplearning-timeseries
    virtualenv ENV
    source ENV/bin/activate
    pip install --upgrade pip
    pip install keras h5py pandas scikit-learn
    python evaluate.py

## Architecture & Model Properties

We use Keras' Sequential model to construct the recurrent neural networks. Each network has 3 layers:

  • Layer 1 : Either an LSTM layer (output dimension 10, statefulness enabled) or a GRU layer (output dimension 4).
  • Layer 2 : A Dropout layer with dropout probability 0.2, to prevent overfitting.
  • Layer 3 : A fully-connected Dense layer with output dimension 1.
  • Default optimizer: rmsprop; default number of epochs: 150.
  • Evaluation metric: Mean Squared Error (MSE).

This architecture can certainly be optimized further; I just haven't had the chance to experiment much because of my laptop's constraints.
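
To get a feel for how small these networks are, here is a rough parameter count for both variants. It is only a sketch: it assumes one input feature per time step (the daily birth count), which the README does not state, and it uses the textbook LSTM/GRU formulations. Some Keras versions give the GRU an extra bias vector per block, so the count reported by `model.summary()` may differ slightly. The helper names are made up for this example.

```python
def lstm_params(units, input_dim):
    # Four blocks (input gate, forget gate, cell candidate, output gate),
    # each with input weights, recurrent weights and one bias vector.
    return 4 * (units * (input_dim + units) + units)


def gru_params(units, input_dim):
    # Three blocks (update gate, reset gate, candidate state) in the
    # classic formulation with one bias vector per block.
    return 3 * (units * (input_dim + units) + units)


def dense_params(input_dim, units):
    # Weight matrix plus bias; the Dropout layer in between has no parameters.
    return input_dim * units + units


INPUT_DIM = 1  # assumption: a single feature per time step

lstm_total = lstm_params(10, INPUT_DIM) + dense_params(10, 1)
gru_total = gru_params(4, INPUT_DIM) + dense_params(4, 1)

print("LSTM variant parameters:", lstm_total)
print("GRU variant parameters:", gru_total)
```

Under these assumptions the LSTM variant has roughly six times as many trainable parameters as the GRU variant, which is worth keeping in mind when comparing their errors.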

## Results & Observations

  1. The LSTM-RNN model performed best, with an MSE of 1464.78 (look back = 37).
  2. Naively making the RNN "deeper" did not yield an immediate improvement; I didn't fine-tune the parameters (output_dim, for example), though.
  3. Making the LSTM network stateful (setting stateful=True when initializing the LSTM layer) did yield a significant performance improvement. In stateless LSTM layers, the cell states are reset for each sequence. With stateful=True, however, the states are propagated to the next batch, i.e. the final state of sample trainX[i] is used as the initial state of sample trainX[i+k] in the next batch, where k is the batch size. You can read more about this in the Keras docs.
  4. Using Glorot initialization yielded a performance improvement. However, He uniform initialization (uniform initialization scaled by fan_in) yielded even better results than Glorot.
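
Observations 3 and 4 can be illustrated without Keras. The sketch below is plain Python, not the project's code, and all function names are made up for the example. The first part simulates consecutive, unshuffled batches and records which earlier sample hands its final state to each sample in a stateful layer. The second part computes the sampling limits of the Glorot uniform and He uniform schemes, using the Dense layer's shape (fan_in = 10, fan_out = 1) as an assumed example.

```python
import math
import random


def make_batches(n_samples, batch_size):
    """Split indices 0..n_samples-1 into consecutive, unshuffled batches."""
    if n_samples % batch_size != 0:
        raise ValueError("n_samples must be divisible by batch_size")
    return [list(range(start, start + batch_size))
            for start in range(0, n_samples, batch_size)]


def state_sources(n_samples, batch_size):
    """Map each sample index to the sample whose final state seeds it.

    None means the sample starts from a fresh state (first batch).
    """
    sources = {}
    previous = None
    for batch in make_batches(n_samples, batch_size):
        for slot, index in enumerate(batch):
            # Slot j of a batch continues from slot j of the previous batch.
            sources[index] = previous[slot] if previous is not None else None
        previous = batch
    return sources


def glorot_uniform_limit(fan_in, fan_out):
    # Glorot uniform samples from [-limit, limit], limit = sqrt(6 / (fan_in + fan_out)).
    return math.sqrt(6.0 / (fan_in + fan_out))


def he_uniform_limit(fan_in):
    # He uniform samples from [-limit, limit], limit = sqrt(6 / fan_in).
    return math.sqrt(6.0 / fan_in)


def sample_uniform(limit, count, seed=0):
    """Draw `count` weights uniformly from [-limit, limit]."""
    rng = random.Random(seed)
    return [rng.uniform(-limit, limit) for _ in range(count)]


# With batch size k = 2, sample i + 2 continues from sample i.
print(state_sources(6, 2))

print("Glorot limit:", glorot_uniform_limit(10, 1))
print("He limit:", he_uniform_limit(10))
print(sample_uniform(he_uniform_limit(10), 3))
```

The index mapping only holds if the samples are fed in order, so shuffling has to be disabled during training, and the states of a stateful layer need to be reset explicitly between passes over the data.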

## Files

  • ```data/number-of-daily-births-in-quebec.csv``` : Dataset.
  • ```lstm_model.py```: Contains the class ```LSTM_RNN``` for LSTM-based Recurrent Neural Networks.
  • ```gru_model.py```: Contains the class ```GRU_RNN``` for GRU-based Recurrent Neural Networks.
  • ```evaluate.py```: Loads and preprocesses the dataset, creates LSTM-RNN, GRU-RNN and Linear Regression models, and outputs results.
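
The "look back" value in the results suggests that the preprocessing turns the series into sliding windows: the previous `look_back` values are the input and the next value is the target. The sketch below shows that idea together with the MSE metric on a tiny made-up series. It is an assumption about what `evaluate.py` does, not a copy of it, and `make_windows` and `mean_squared_error` are hypothetical names.

```python
def make_windows(series, look_back):
    """Turn a 1-D series into (inputs, targets) with `look_back` past values each."""
    inputs, targets = [], []
    for i in range(len(series) - look_back):
        inputs.append(series[i:i + look_back])   # values at t-look_back .. t-1
        targets.append(series[i + look_back])    # value at t
    return inputs, targets


def mean_squared_error(y_true, y_pred):
    """Average of squared differences between targets and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Toy numbers for illustration only, not the Quebec births data.
series = [10, 12, 11, 13, 15, 14]
X, y = make_windows(series, look_back=3)

# Naive baseline: predict that the next value equals the last value in the window.
predictions = [window[-1] for window in X]

print(X)
print(y)
print(mean_squared_error(y, predictions))
```

With look back = 37, each input would hold the 37 preceding daily counts, so the first 37 days of the series serve only as context and produce no targets of their own.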

## To-do

  • K-fold cross validation.
  • Add plots to aid in visualization.

## References

  1. On the use of 'Long-Short Term Memory' neural networks for time series prediction, Gomez-Gil et al., 2014.
  2. Dropout: A Simple Way to Prevent Neural Networks from Overfitting, Srivastava et al., 2014.
  3. Learning to forget, Gers, Schmidhuber & Cummins, 2000.
  4. Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling, Chung et al., 2014.
