
structuredinference's Introduction

structuredInf

Code to fully reproduce benchmark results (and to extend for your own purposes) from the paper:

Krishnan, Shalit, Sontag. Structured Inference Networks for Nonlinear State Space Models, AAAI 2017.

See here for a simplified, easier-to-use version of the code.

Goal

The goal of this package is to provide a black-box inference algorithm for learning models of time-series data. Inference during learning and at test time is performed with a compiled recognition (inference) network.

Model

The figure below describes a simple model of time-series data.

This method is a good fit if:

  • You have an arbitrarily specified state space model whose parameters you're interested in fitting.
  • You would like a method for fast posterior inference at train and test time.
  • Your temporal generative model has Gaussian latent variables (the mean and variance can be nonlinear functions of the previous timestep's latent variables).

Deep Kalman Filter

The code uses variational inference during learning to maximize a lower bound (the evidence lower bound, or ELBO) on the likelihood of the observed data:

Evidence Lower Bound
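The bound being maximized has the standard variational form (a sketch in LaTeX; the paper derives a per-timestep decomposition of the KL term):

$$\mathcal{L}(x_{1:T}; \theta, \phi) = \sum_{t=1}^{T} \mathbb{E}_{q_\phi(z_t \mid x_{1:T})}\left[\log p_\theta(x_t \mid z_t)\right] - \mathrm{KL}\left(q_\phi(z_{1:T} \mid x_{1:T}) \,\|\, p_\theta(z_{1:T})\right) \le \log p_\theta(x_{1:T})$$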

Generative Model

  • The latent variables z1...zT and the observations x1...xT describe the generative process for the data.
  • The figure depicts a state space model for time-varying data.
  • The emission and transition functions may be pre-specified to have a fixed functional form, a parametric functional form, a function parameterized by a deep neural network, or some combination thereof.
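Concretely, in the notation of the Deep Kalman Filters paper (a sketch, omitting external inputs/actions for clarity), the generative process is roughly:

$$z_1 \sim \mathcal{N}(\mu_0, \Sigma_0), \qquad z_t \mid z_{t-1} \sim \mathcal{N}\big(G_\alpha(z_{t-1}), S_\beta(z_{t-1})\big), \qquad x_t \mid z_t \sim \Pi\big(F_\kappa(z_t)\big),$$

where the transition functions $G_\alpha$, $S_\beta$ and the emission function $F_\kappa$ may be fixed, parametric, or deep neural network functions as described above.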

Inference Model

The box q(z1...zT | x1...xT) represents the inference network. Several inference networks are supported within this package; each corresponds to a different factorization of the variational posterior (sketched after the list below).

  • Inference implemented with a bi-directional LSTM
  • Inference implemented with an LSTM conditioned on observations in the future
  • Inference implemented with an LSTM conditioned on observations from the past
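Schematically, all three variants share the Markov factorization of the structured variational posterior (a sketch; see the paper for the exact parameterizations):

$$q_\phi(z_{1:T} \mid x_{1:T}) = \prod_{t=1}^{T} q_\phi\big(z_t \mid z_{t-1}, x_{\bullet}\big),$$

where $x_{\bullet}$ is $x_{1:T}$ for the bi-directional LSTM, $x_{t:T}$ when conditioning on future observations, and $x_{1:t}$ when conditioning on past observations.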

Installation

Requirements

This package has the following requirements:

  • Python 2.7
  • Theano: used for automatic differentiation.
  • [theanomodels](https://github.com/clinicalml/theanomodels): a wrapper around Theano that takes care of bookkeeping, saving/loading models, etc. Clone the GitHub repository and add its location to the PYTHONPATH environment variable so that it is accessible by Python.
  • [pykalman](https://pykalman.github.io/) [Optional: for running the baseline UKFs/KFs].
  • An NVIDIA GPU with at least 6 GB of memory is recommended.

Once the requirements have been met, clone this repository and it's ready to run.

Folder Structure

The following folders contain code to reproduce the results reported in our paper:

  • expt-synthetic, expt-polyphonic: Contains code and instructions for reproducing results from the paper.
  • baselines/: Contains code to run some of the baseline algorithms on the synthetic data.
  • ipynb/: IPython notebooks for visualizing saved checkpoints and building plots.

The main files of interest are:

  • parse_args_dkf.py: Arguments that the model expects to be present. Looking through it is useful to understand the different knobs available to tune the model.
  • stinfmodel/dkf.py: Code to construct the inference and generative model. The code is commented to enable easy modification for different scenarios.
  • stinfmodel/evaluate.py: Code to evaluate the Deep Kalman Filter's performance during learning.
  • stinfmodel/learning.py: Code for performing stochastic gradient ascent on the Evidence Lower Bound (ELBO).

Dataset

We store the datasets as numpy tensors together with binary numpy masks, so that mini-batches may contain sequences of variable length. We train the models using mini-batch gradient descent on the negative ELBO.

Format

The code to load the polyphonic and synthetic datasets is already provided in the theanomodels repository. See theanomodels/datasets/load.py for how each dataset is created and loaded.

The datasets are stored as three-dimensional numpy tensors. To deal with datapoints of different lengths, we use matrices of binary masks. Depending on your needs you may organize the data differently; this is merely a guideline.

assert type(dataset) is dict,'Expecting dictionary'
dataset['train'] # N_train x T_train_max x dim_observation : training data
dataset['test']  # N_test  x T_test_max  x dim_observation : test data
dataset['valid'] # N_valid x T_valid_max x dim_observation : validation data
dataset['mask_train'] # N_train x T_train_max : training masks
dataset['mask_test']  # N_test  x T_test_max  : test masks
dataset['mask_valid'] # N_valid x T_valid_max : validation masks
dataset['data_type'] # real/binary
dataset['has_masks'] # true/false

During learning, we select a minibatch of these tensors to update the weights of the model.
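As an illustration, here is a minimal sketch (the function and variable names are ours, not part of the package) of how variable-length sequences can be packed into the padded tensors and binary masks described above; the validation and test splits would be packed the same way.

import numpy as np

def pack_sequences(sequences, dim_observation):
    # Pad variable-length sequences into an N x T_max x dim_observation tensor
    # and build the corresponding N x T_max binary mask (1 = observed, 0 = padding).
    N     = len(sequences)
    T_max = max(len(seq) for seq in sequences)
    data  = np.zeros((N, T_max, dim_observation), dtype='float32')
    mask  = np.zeros((N, T_max), dtype='float32')
    for n, seq in enumerate(sequences):
        T           = len(seq)
        data[n, :T] = seq
        mask[n, :T] = 1.
    return data, mask

# Example: 100 binary sequences of random length between 5 and 19
dim  = 10
seqs = [(np.random.rand(np.random.randint(5, 20), dim) > 0.5).astype('float32') for _ in range(100)]
X, M = pack_sequences(seqs, dim)
dataset = {'train': X, 'mask_train': M, 'data_type': 'binary', 'has_masks': True}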

Running on different datasets

See the folder expt-template for an example of how to set up your data and run the code on it.

References:

@inproceedings{krishnan2016structured,
  title={Structured Inference Networks for Nonlinear State Space Models},
  author={Krishnan, Rahul G and Shalit, Uri and Sontag, David},
  booktitle={AAAI},
  year={2017}
}

This paper subsumes the work in [Deep Kalman Filters](https://arxiv.org/abs/1511.05121).

structuredinference's People

Contributors

dsontag, rahulk90


structuredinference's Issues

Paper and code could benefit from rewrite

Both the arXiv paper and the code in this repo for the deep Kalman filter algorithm are somewhat unclear. From the paper it is rather difficult to understand the exact algorithmic steps used to train the network. I believe clarifying this would lead to wider use of the algorithm by the community.

ValueError: masked arrays are not supported

What kind of pre-processing of the data format should I do for the synthetic experiment?

root@yht-t:/home/yht/related17/feb17/structuredinference/expt-synthetic# python run_baselines.py 
Loading linear matrices
Loading linear matrices
Reading from:  /home/yht/related17/config/theanomodels/datasets/synthetic/synthetic18.h5  Saving to:      ./baselines/synthetic18-baseline.h5
Running filter:  KF  on train
Traceback (most recent call last):
  File "run_baselines.py", line 52, in <module>
    runBaselines('./baselines')
  File "run_baselines.py", line 24, in runBaselines
    mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType)
  File "../baselines/filters.py", line 73, in runFilter
    ll      += model.loglikelihood(X[n,:])
  File "/usr/local/lib/python2.7/dist-packages/pykalman-0.9.5-py2.7.egg/pykalman/standard.py", line 1474, in loglikelihood
    predicted_state_means, predicted_state_covariances, Z
  File "/usr/local/lib/python2.7/dist-packages/pykalman-0.9.5-py2.7.egg/pykalman/standard.py", line 170, in _loglikelihoods
    predicted_observation_covariance[np.newaxis, :, :]
  File "/usr/local/lib/python2.7/dist-packages/pykalman-0.9.5-py2.7.egg/pykalman/utils.py", line 73, in log_multivariate_normal_density
    cv_sol = solve_triangular(cv_chol, (X - mu).T, lower=True).T
  File "/usr/local/lib/python2.7/dist-packages/scipy/linalg/basic.py", line 158, in solve_triangular
    b1 = _asarray_validated(b, check_finite=check_finite)
  File "/usr/local/lib/python2.7/dist-packages/scipy/_lib/_util.py", line 226, in _asarray_validated
    raise ValueError('masked arrays are not supported')
ValueError: masked arrays are not supported
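One possible workaround (a sketch, not a fix from the repository): pykalman's loglikelihood does not accept numpy masked arrays, so if the mask only marks padding, strip it and keep the fully observed timesteps before calling it. The helper name below is illustrative.

import numpy as np

def loglikelihood_unmasked(model, X_masked):
    # model: a fitted pykalman KalmanFilter; X_masked: (T x dim) masked array where
    # masked rows correspond to padded/unobserved timesteps.
    X     = np.ma.asarray(X_masked)
    valid = ~np.ma.getmaskarray(X).any(axis=-1)   # timesteps with no masked dimensions
    return model.loglikelihood(np.ma.getdata(X)[valid])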

error: ImportError: No module named datasets.synthp

Hi, I was trying to run VisualizeSynthetic in an IPython notebook, but I always get the following error. My environment is an Ubuntu 14.04 server with theanomodels and pykalman installed and tested, and with CUDA 8.0 correctly installed. Can anyone help?

ImportError                               Traceback (most recent call last)
<ipython-input> in <module>()
     17 #http://stackoverflow.com/questions/22408237/named-colors-in-matplotlib
     18 import cPickle as pickle
---> 19 from datasets.synthp import params_synthetic
     20
     21 #visualize synthetic results

ImportError: No module named datasets.synthp
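A possible cause, guessing from the installation notes above: the datasets package (which contains synthp) lives inside the theanomodels checkout, so that directory must be on the Python path when the notebook kernel starts. A sketch, with an illustrative path:

import sys, os
# Point this at your own clone of theanomodels (path below is illustrative).
sys.path.append(os.path.expanduser('~/code/theanomodels'))
from datasets.synthp import params_synthetic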

Supporting params['data_type']=real

In the model implementation for the DKF (dkf.py), it seems that only binary data types are supported, as shown below. Must we discretize continuous observations to use the model? The paper mentions binning A1c levels and glucose into clinically meaningful bins and quantiles, respectively.

if self.params['data_type']=='binary':
    npWeights['p_emis_W_ber'] = self._getWeight((self.params['dim_hidden'], self.params['dim_observations']))
    npWeights['p_emis_b_ber'] = self._getWeight((self.params['dim_observations'],))
elif self.params['data_type']=='binary_nade':
    n_visible, n_hidden   = self.params['dim_observations'], self.params['dim_hidden']
    npWeights['p_nade_W'] = self._getWeight((n_visible, n_hidden))
    npWeights['p_nade_U'] = self._getWeight((n_visible, n_hidden))
    npWeights['p_nade_b'] = self._getWeight((n_visible,))
else:
    assert False,'Invalid datatype: '+params['data_type']
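To support params['data_type']=='real', one would need to add a continuous (e.g. Gaussian) emission. A rough sketch of what the weight initialization might look like, mirroring the binary branch above (the weight names are hypothetical, and the emission function and the log-likelihood term in the ELBO would also have to be extended accordingly):

elif self.params['data_type']=='real':
    # mean and log-variance of a Gaussian p(x_t | z_t), each a linear map from the hidden layer
    npWeights['p_emis_W_mu']     = self._getWeight((self.params['dim_hidden'], self.params['dim_observations']))
    npWeights['p_emis_b_mu']     = self._getWeight((self.params['dim_observations'],))
    npWeights['p_emis_W_logcov'] = self._getWeight((self.params['dim_hidden'], self.params['dim_observations']))
    npWeights['p_emis_b_logcov'] = self._getWeight((self.params['dim_observations'],))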
