
HMCTuning

This repository contains a Python package for running Hamiltonian Monte Carlo (HMC) with PyTorch, including automatic optimization of its hyperparameters. You will be able to i) sample from any distribution given as an unnormalized target, and ii) automatically tune the HMC hyperparameters to improve the efficiency of exploring the density.

For further details about the algorithm, see Section 3.5 of our paper, where we adapted HMC tuning to improve inference in a hierarchical VAE for partially observed, mixed-type data. The original idea of optimizing HMC via Variational Inference can be found here. If you refer to this algorithm, please consider citing both works. If you use this code, please cite:

@article{peis2022missing,
  title={Missing Data Imputation and Acquisition with Deep Hierarchical Models and Hamiltonian Monte Carlo},
  author={Peis, Ignacio and Ma, Chao and Hern{\'a}ndez-Lobato, Jos{\'e} Miguel},
  journal={arXiv preprint arXiv:2202.04599},
  year={2022}
}

Installation

The installation is straightforward: the following command creates a conda virtual environment named HMCTuning from the provided environment.yml file:

conda env create -f environment.yml

Usage

For an extended usage guide, check notebooks/usage.ipynb. For basic usage, continue reading here. An HMC object can be created as in the following example:

from examples.distributions import *
from examples.utils import *

# Load the log probability function of MoG, and the initial proposal
logp = get_logp('gaussian_mixture')
mu0, var0 = initial_proposal('gaussian_mixture')   # [0, 0],  [0.01, 0.01]

# Create the HMC object
hmc = HMC(dim=2, logp=logp, T=5,  L=5, chains=1000, chains_sksd=30, mu0=mu0, var0=var0, vector_scale=True)

where:

  • dim is an int with the dimension of the target space.
  • logp is a Callable (function) that returns the log probability $\log p(\mathbf{x})$ for an input $\mathbf{x}$.
  • T is an int with the length of the chains.
  • L is an int with the number of Leapfrog steps.
  • chains is an int with the number of parallel chains used for each optimization step.
  • chains_sksd is an int with the number of parallel chains used independently for computing the SKSD discrepancy within each optimization step.
  • mu0 is a (batch_size, D) tensor with the means of the Gaussian initial proposal.
  • var0 is a (batch_size, D) tensor with the variances of the Gaussian initial proposal.
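For reference, a logp callable of the kind expected above can be sketched as follows. This is a minimal, hypothetical example of an unnormalized mixture-of-Gaussians log-density; make_mog_logp and its arguments are illustrative and not part of the package:

```python
import torch

def make_mog_logp(means, var):
    """Log-density of an equally weighted Gaussian mixture.

    Unnormalized is fine for HMC: constant offsets do not affect the dynamics.
    """
    def logp(x):
        # x: (batch, dim); means: (K, dim)
        diffs = x.unsqueeze(1) - means.unsqueeze(0)   # (batch, K, dim)
        comp = -0.5 * (diffs ** 2).sum(-1) / var      # (batch, K), up to constants
        return torch.logsumexp(comp, dim=1)           # (batch,)
    return logp

means = torch.tensor([[-2.0, 0.0], [2.0, 0.0]])
logp = make_mog_logp(means, var=1.0)
x = torch.zeros(3, 2)
print(logp(x).shape)  # torch.Size([3])
```

Any function with this signature (a batch of points in, a batch of log-probabilities out) can be passed as logp.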

Sampling

To sample from the created HMC object, just call:

samples, chains = hmc.sample(N)

Your final N samples will be stored in samples, and, if needed, you can inspect the full chains.
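Internally, each HMC transition is built from L leapfrog steps. As a rough, self-contained sketch of what a single proposal looks like (illustrative only; the package's actual integrator also handles learned per-dimension step sizes and the Metropolis correction):

```python
import torch

def leapfrog(x, p, logp, eps, L):
    """L leapfrog steps for the Hamiltonian -logp(x) + 0.5 * p**2 (sketch)."""
    def grad_logp(x):
        x = x.detach().requires_grad_(True)
        return torch.autograd.grad(logp(x).sum(), x)[0]
    p = p + 0.5 * eps * grad_logp(x)        # initial half step for momentum
    for i in range(L):
        x = x + eps * p                     # full step for position
        if i < L - 1:
            p = p + eps * grad_logp(x)      # full momentum steps in between
    p = p + 0.5 * eps * grad_logp(x)        # final half step for momentum
    return x.detach(), p.detach()

# Standard Gaussian target: the energy should be approximately conserved
torch.manual_seed(0)
logp = lambda x: -0.5 * (x ** 2).sum(-1)
x0, p0 = torch.randn(5, 2), torch.randn(5, 2)
x1, p1 = leapfrog(x0, p0, logp, eps=0.1, L=5)
H0 = -logp(x0) + 0.5 * (p0 ** 2).sum(-1)
H1 = -logp(x1) + 0.5 * (p1 ** 2).sum(-1)
print(float((H1 - H0).abs().max()))  # small: leapfrog nearly conserves energy
```

The near-conservation of the Hamiltonian is what gives HMC its high acceptance rates, and it is why the step size eps and number of steps L are worth tuning.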

Training

To train the HMC hyperparameters, call:

hmc.fit(steps=100)

This will run the gradient-based optimization algorithm that tunes the hyperparameters using Variational Inference.
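As a toy illustration of the idea (not the package's actual objective, which also involves the SKSD discrepancy and the HMC step sizes), gradient-based tuning of a Gaussian proposal against an unnormalized target can be sketched as:

```python
import torch

# Illustrative only: fit a reparameterized Gaussian to an unnormalized target
# by maximizing a Monte Carlo ELBO, mirroring the variational flavor of hmc.fit.
torch.manual_seed(0)
logp = lambda x: -0.5 * ((x - 3.0) ** 2).sum(-1)   # unnormalized N(3, 1) target

mu = torch.zeros(1, requires_grad=True)
log_std = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([mu, log_std], lr=0.05)

for step in range(500):
    eps = torch.randn(256, 1)
    x = mu + eps * log_std.exp()          # reparameterized proposal samples
    entropy = log_std.sum()               # Gaussian entropy, up to constants
    loss = -(logp(x).mean() + entropy)    # negative ELBO
    opt.zero_grad()
    loss.backward()
    opt.step()

print(float(mu), float(log_std.exp()))  # close to the target's mean 3.0 and std 1.0
```

In the package, the analogous gradients flow through the HMC chains themselves, so the tuned quantities are the sampler's hyperparameters rather than the parameters of a fixed proposal.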

Example 1: 2D densities

The following animations show two simple examples of how effective the training algorithm is for wave-shaped (left) and dual-moon (right) densities. The horizontal scale is automatically increased during training to inflate the proposal so that it covers the density.


Example 2: improving inference in Deep Generative Models

We used our method to improve inference in advanced Variational Autoencoders like the one presented in our paper. The following figure includes an illustrative example for a simple VAE: when the VAE parameters are trained jointly with the HMC hyperparameters, the multimodal true posterior (green) is successfully explored by HMC samples (orange) starting from the Gaussian proposal provided by the encoder (blue).

Help

Use the --help option for documentation on the usage of any of the mentioned scripts.

Contributors

Ignacio Peis

Contact

For further information: [email protected]

