
Jackknife Variational Inference, Python implementation

This repository contains code related to the following ICLR 2018 paper:

  • Sebastian Nowozin, "Debiasing Evidence Approximations: On Importance-weighted Autoencoders and Jackknife Variational Inference", Forum, PDF.

Citation

If you use this code or build upon it, please cite the following paper (BibTeX format):

@InProceedings{nowozin2018jvi,
	title = "Debiasing Evidence Approximations: On Importance-weighted Autoencoders and Jackknife Variational Inference",
	author = "Sebastian Nowozin",
	booktitle = "International Conference on Learning Representations (ICLR 2018)",
	year = "2018"
}

Installation

Install the required Python 2 prerequisites by running:

pip install -r requirements.txt

Currently this installs:

  • Chainer, the deep learning framework, version 3.1.0
  • CuPy, a CUDA linear algebra framework compatible with NumPy, version 2.1.0
  • NumPy, numerical linear algebra for Python, version 1.11.0
  • SciPy, scientific computing framework for Python, version 1.0.0
  • h5py, an HDF5 interface for Python, version 2.6.0
  • docopt, Pythonic command line arguments parser, version 0.6.2
  • PyYAML, Python library for YAML data language, version 3.12

Running the MNIST experiment

To train the MNIST model from the paper, run the following command:

python ./train.py -g 0 -d mnist -e 1000 -b 2048 --opt adam \
    --vae-type jvi --vae-samples 8 --jvi-order 1 --nhidden 300 --nlatent 40 \
    -o modeloutput

Here the parameters are:

  • -g 0: train on GPU device 0
  • -d mnist: use the dynamically binarized MNIST data set
  • -e 1000: train for 1000 epochs
  • -b 2048: use a batch size of 2048 samples
  • --opt adam: use the Adam optimizer
  • --vae-type jvi: use jackknife variational inference
  • --vae-samples 8: use eight Monte Carlo samples
  • --jvi-order 1: use first-order JVI bias correction
  • --nhidden 300: in each hidden layer use 300 hidden neurons
  • --nlatent 40: use 40 dimensions for the VAE latent variable

The training process creates a file modeloutput.meta.yaml containing the training parameters, as well as a directory modeloutput/ which contains a log file and the serialized model that performed best on the validation set.
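
For orientation, --nhidden and --nlatent describe a fully connected encoder network. Below is a minimal Chainer 3 sketch of an encoder with these dimensions; this is an illustration under assumed layer names, not the repository's actual model class, and a single hidden layer is shown for brevity:

    import chainer
    import chainer.functions as F
    import chainer.links as L

    class Encoder(chainer.Chain):
        """Map a flattened 28x28 MNIST image to the mean and
        log-variance of a diagonal Gaussian over the latent variable."""
        def __init__(self, nhidden=300, nlatent=40):
            super(Encoder, self).__init__()
            with self.init_scope():
                self.hidden = L.Linear(784, nhidden)
                self.mu = L.Linear(nhidden, nlatent)
                self.ln_var = L.Linear(nhidden, nlatent)

        def __call__(self, x):
            h = F.tanh(self.hidden(x))
            return self.mu(h), self.ln_var(h)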

To evaluate the trained model on the test set, use

python ./evaluate.py -g 0 -d mnist -E iwae -s 256 modeloutput

This evaluates the model trained previously using the following test-time evaluation setup:

  • -g 0: use GPU device 0 for evaluation
  • -d mnist: evaluate on the mnist data set
  • -E iwae: use the IWAE objective for evaluation
  • -s 256: use 256 Monte Carlo samples in the IWAE objective

Because test-time evaluation does not require backpropagation, we can evaluate the IWAE and JVI objectives accurately using a large number of samples, e.g. -s 65536.
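
As a point of reference, the IWAE evidence estimate is a log-mean of importance weights, computed stably via logsumexp. Here is a minimal NumPy/SciPy sketch, assuming a (batch, k) array of log importance weights log p(x, z_i) - log q(z_i | x); this is an illustration, not the repository's evaluation code:

    import numpy as np
    from scipy.special import logsumexp

    def iwae_estimate(log_w):
        """IWAE estimate of log p(x), averaged over the batch, from a
        (batch, k) array of log importance weights."""
        k = log_w.shape[1]
        return np.mean(logsumexp(log_w, axis=1) - np.log(k))

With k = 256 samples this corresponds to the -s 256 setting shown above.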

The evaluate.py script also supports a --reps 10 parameter which would evaluate the same model ten times to investigate variance in the Monte Carlo approximation to the evaluation objective.

Choosing different objectives

As illustrated in the paper, the JVI objective generalizes both the ELBO and the IWAE objectives.

For example, you can train on the importance-weighted autoencoder (IWAE) objective using the parameter --jvi-order 0 instead of --jvi-order 1.

You can train using the regular evidence lower bound (ELBO) by using the special case of JVI, --jvi-order 0 --vae-samples 1, or directly via --vae-type vae.
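
Concretely, the order-one jackknife debiases the n-sample IWAE estimate using the average of its n leave-one-out estimates. The following is a naive NumPy sketch of this correction; the function name and the brute-force leave-one-out loop are our assumptions, not the repository's implementation:

    import numpy as np
    from scipy.special import logsumexp

    def jvi1_estimate(log_w):
        """First-order JVI estimate from a (batch, n) array of log
        importance weights: n * IWAE_n minus (n-1) times the mean of
        the n leave-one-out IWAE_{n-1} estimates."""
        n = log_w.shape[1]
        full = logsumexp(log_w, axis=1) - np.log(n)  # IWAE_n per instance
        loo = np.stack([logsumexp(np.delete(log_w, j, axis=1), axis=1)
                        - np.log(n - 1) for j in range(n)], axis=1)
        return np.mean(n * full - (n - 1) * loo.mean(axis=1))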

Counting JVI sets

We include a small utility to count the number of subsets used by the different JVI approximations. There are two parameters, n and order, where n is the number of samples of latent space variables per instance, and order is the order of the JVI approximation (order zero corresponds to the IWAE).

To run the utility, use:

python ./jvicount.py 16 2

This utility is useful because the number of subsets grows very rapidly for larger JVI orders. It therefore lets you quickly assess the total number of terms and make informed choices about the batch size and the order of the approximation.
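
For reference, an order-m JVI estimate on n samples touches every subset of sizes n, n-1, ..., n-m, so the subset count is a sum of binomial coefficients. A small sketch of that count follows; this is our reading of what the utility computes, and jvicount.py's exact output format may differ:

    from math import factorial

    def binom(n, k):
        return factorial(n) // (factorial(k) * factorial(n - k))

    def jvi_subset_count(n, order):
        """Number of sample subsets used by an order-`order` JVI
        estimate on n samples: all subsets of sizes n, n-1, ..., n-order."""
        return sum(binom(n, n - j) for j in range(order + 1))

    print(jvi_subset_count(16, 2))  # 1 + 16 + 120 = 137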

Contact

Sebastian Nowozin, [email protected]

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.microsoft.com.

When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact [email protected] with any additional questions or comments.
