
blackjax's Introduction

BlackJAX


BlackJAX animation: sampling BlackJAX with BlackJAX

What is BlackJAX?

BlackJAX is a library of samplers for JAX that works on CPU as well as GPU.

It is not a probabilistic programming library. However, it integrates really well with PPLs as long as they can provide a (potentially unnormalized) log-probability density function compatible with JAX.

Who should use BlackJAX?

BlackJAX should appeal to those who:

  • Have a logpdf and just need a sampler;
  • Need more than a general-purpose sampler;
  • Want to sample on GPU;
  • Want to build upon robust elementary blocks for their research;
  • Are building a probabilistic programming language;
  • Want to learn how sampling algorithms work.

Quickstart

Installation

You can install BlackJAX using pip:

pip install blackjax

or via conda-forge:

conda install -c conda-forge blackjax

Nightly builds (bleeding edge) of BlackJAX can also be installed using pip:

pip install blackjax-nightly

BlackJAX is written in pure Python but depends on XLA via JAX. By default, the version of JAX that will be installed along with BlackJAX will make your code run on CPU only. If you want to use BlackJAX on GPU/TPU we recommend you follow these instructions to install JAX with the relevant hardware acceleration support.
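
For example, at the time of writing the JAX documentation suggests a command along these lines for NVIDIA GPUs (check the JAX installation guide for the variant matching your CUDA setup):

pip install -U "jax[cuda12]"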

Example

Let us look at a simple self-contained example sampling with NUTS:

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np

import blackjax

observed = np.random.normal(10, 20, size=1_000)
def logdensity_fn(x):
    logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
    return jnp.sum(logpdf)

# Build the kernel
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])
nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)

# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.key(0)
for step in range(100):
    nuts_key = jax.random.fold_in(rng_key, step)
    state, _ = nuts.step(nuts_key, state)

See the documentation for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc. As a taste, below is a minimal sketch of an inference loop that wraps the NUTS kernel from the example above in jax.lax.scan (the helper name inference_loop is illustrative, not part of the BlackJAX API):
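
def inference_loop(rng_key, kernel, initial_state, num_samples):
    # Run `kernel` for `num_samples` steps and stack the visited states.
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states

states = inference_loop(rng_key, nuts.step, state, 1_000)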

Philosophy

What is BlackJAX?

BlackJAX bridges the gap between "one liner" frameworks and modular, customizable libraries.

Users can import the library and interact with robust, well-tested and performant samplers with a few lines of code. These samplers are aimed at PPL developers, or people who have a logpdf and just need a sampler that works.

But the true strength of BlackJAX lies in its internals and how they can be used to experiment quickly with existing or new sampling schemes. This lower level exposes the building blocks of inference algorithms: integrators, proposals, momentum generators, etc., and makes it easy to combine them to build new algorithms. It provides an opportunity to accelerate research on sampling algorithms by providing robust, performant and reusable code.

Why BlackJAX?

Sampling algorithms are too often integrated into PPLs and not decoupled from the rest of the framework, making them hard to use for people who do not need the modeling language to build their logpdf. Their implementations are most of the time monolithic, making it impossible to reuse parts of the algorithm to build custom kernels. BlackJAX solves both problems.

How does it work?

BlackJAX lets you build arbitrarily complex algorithms because it is built around a very general pattern. Everything that takes a state and returns a state is a transition kernel, and is implemented as:

new_state, info = kernel(rng_key, state)

Kernels are stateless functions that all follow the same API; the state and the information related to the transition are returned separately. They can thus be easily composed and exchanged. We specialize these kernels by closure instead of passing parameters, as the sketch below illustrates.
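
A minimal sketch, reusing the quickstart example above (the HMC parameters shown are illustrative):

# Parameters are baked in at construction time; every step function then
# exposes the same (rng_key, state) -> (new_state, info) signature.
nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)
hmc = blackjax.hmc(logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps=10)

state = nuts.init(initial_position)
state, info = nuts.step(rng_key, state)  # one NUTS transition
state, info = hmc.step(rng_key, state)   # same call pattern, different kernel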

Contributions

Please follow our short guide.

Citing BlackJAX

To cite this repository:

@misc{cabezas2024blackjax,
      title={BlackJAX: Composable {B}ayesian inference in {JAX}},
      author={Alberto Cabezas and Adrien Corenflos and Junpeng Lao and Rémi Louf},
      year={2024},
      eprint={2402.10797},
      archivePrefix={arXiv},
      primaryClass={cs.MS}
}

In the above BibTeX entry, names are in alphabetical order; the version number should be the last tag on the main branch.

Acknowledgements

Some details of the NUTS implementation were largely inspired by NumPyro's.

blackjax's People

Contributors

adriencorenflos, albcab, antotocar34, bkktimber, canyon289, ciguaran, colcarroll, dfm, elanmart, horaceg, howsiyu, juanitorduz, junpenglao, ksnxr, marcogorelli, miclegr, murphyk, oarriaga, paulscemama, prashjet, reubenharry, rlouf, samduffield, swapneelm, twhentschel, waynedw, weiyaw, xidulu, yayami3, zaxtax


blackjax's Issues

Add other momentum distributions

We currently only implement a normally-distributed momentum, but it could also follow a Laplace, Cauchy, or other distribution.

For that we would need to transform gaussian_euclidean to euclidean and pass the momentum distribution as an argument (default unit normal).
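
For illustration, a momentum generator for a Laplace-distributed momentum could look roughly like this (hypothetical sketch, not the actual BlackJAX interface; it assumes an array-valued position):

import jax
import jax.numpy as jnp

def laplace_momentum_generator(rng_key, position):
    # Draw a Laplace-distributed momentum with the same shape as the position;
    # BlackJAX currently only generates Gaussian momenta.
    return jax.random.laplace(rng_key, shape=jnp.shape(position))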

expose `target_accept` to `stan_warmup`?

For dealing with divergent transitions it can be useful to increase the target acceptance probability, e.g. in Stan using adapt_delta, or in NumPyro with target_accept_prob. Am I right in thinking that this isn't possible with your current API? The relevant argument seems to be set to a default value in the function,

def find_reasonable_step_size(
    ...
    target_accept: float = 0.65,
) -> float:

but this is not exposed to stan_warmup.

Is there a reason for this choice and/or another approach you recommend for dealing with divergences? Alternatively, if you agree to exposing target_accept to stan_warmup, I'd be happy to put in a small PR.

Thanks for the great library!

Return average acceptance probability for NUTS

We need the average acceptance probability for the step size adaptation in #16. Currently the warmup assumes that this information is contained in the acceptance_probability field of the NUTSInfo tuple. Of course this description is not completely accurate for NUTS. I see two options:

  1. Forget about semantics and name it acceptance_probability. It's not too hurtful in this case.
  2. Use dataclasses for the Info tuples using chex, which we may include in the requirements anyway. This would also allow us to use inheritance (see the sketch below).
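
A rough sketch of option 2, assuming chex is added to the requirements (the field names are illustrative):

import chex

@chex.dataclass
class NUTSInfo:
    # A chex dataclass instead of a NamedTuple, which would allow
    # inheritance across the different Info types.
    acceptance_probability: float
    is_divergent: bool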

Add ABC

Can be useful for people who have a simulator in JAX and want to do approximate Bayesian inference.

Add a pytest config to deactivate GPU by default

We shouldn't have to use GPUs to run unit tests, and I believe that many (if not all) developers have the GPU version of JAX installed by default, which makes running the tests slower or adds manual dev overhead. What do you think?
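
A minimal sketch of what this could look like, assuming a conftest.py at the repository root (the exact mechanism is open for discussion):

# conftest.py -- force JAX onto the CPU backend for the whole test suite
import jax

jax.config.update("jax_platform_name", "cpu")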

Add HMC Swindles

There are several algorithms in https://arxiv.org/abs/2001.05033 that should be simple to implement:

  • HMC-COUPLED is a kernel that runs two identical kernels with the same rng_key;
def hmc_coupled(rng_key, states):
    states, infos = jax.vmap(kernel, in_axes=(None, 0))(rng_key, states)
    return states, infos
  • HMC-ANTITHETIC is a kernel that runs two kernels with opposite momenta and identical rng_key.

Use `jupytext`-compatible markdown files instead of notebooks?

Notebooks don't work well with versioning, but the markdown files produced by jupytext do. I suggest we only version the markdown versions in the repository, and render the notebooks for the documentation.

Edit: We use MyST-nb to build the documentation from notebooks, and it does support notebooks in jupytext format. It is just a matter of converting the notebooks and configuring the build in CI.

Add SGLD

Starting from the SgMCMCJAX implementation.

Implement waste-free SMC

A natural next step after the tempered SMC method is to implement the waste-free SMC algorithm in https://arxiv.org/abs/2011.02328. The gist is fairly simple: to work well, tempered SMC may need to run K >> 1 iterations of a chain for each of the N particles, keeping only the last sample of each chain. This is quite statistically wasteful, and waste-free SMC proposes to reuse the samples generated along the chain to produce better estimates (in a sense this can be seen as the SMC analog of NUTS).

Add parallel tempering

Parallel tempering is a meta-algorithm that runs several chains in parallel at different "temperatures" and swaps them regularly. This samples multi-modal posteriors better than a single chain does.

The implementation should be generic and take any kernel with signature kernel(rng_key, state) -> new_state
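
As a rough illustration of the swap step only (hypothetical sketch; swap_step and its arguments are not part of BlackJAX), a swap between two neighbouring chains with inverse temperatures b0 and b1 could look like:

import jax
import jax.numpy as jnp

def swap_step(rng_key, logdensity_fn, positions, inv_temperatures):
    # Metropolis test for exchanging the states of two tempered chains
    # (array-valued positions assumed for simplicity).
    x0, x1 = positions
    b0, b1 = inv_temperatures
    log_accept = (b0 - b1) * (logdensity_fn(x1) - logdensity_fn(x0))
    do_swap = jnp.log(jax.random.uniform(rng_key)) < log_accept
    return jax.lax.cond(do_swap, lambda _: (x1, x0), lambda _: (x0, x1), None)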

Add community resources: contributing guide, code of conduct, governance?

I noticed there are already some issues labeled as "good first issue", but without a contributing guide (plus a code of conduct to establish a baseline on how to interact) they can't really be good first issues; at most they are good first issues for someone who has contributed to other projects but not yet to blackjax.

While I am at it: I have no idea whether there are unwritten rules about collaboration and decision making between the different projects and people involved, but it may also be worth adding a governance doc at some point, even if it is only a couple of sentences.

I will probably be able to review anything related to this on a day/week timescale but I don't think I can send a PR in the near future.

reuse of a prng key in expand_once()

In the function expand_once, a PRNG key is used multiple times.

def expand_once(loop_state):

In particular, the key rng_key generated on line 523 is used twice: in the call to trajectory_integrator (line 540) and as an element of the loop state returned on line 588. It is thus reused in the next iteration for splitting. As far as I understand JAX's RNG implementation, this kind of reuse must be avoided.
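
The usual fix is to split the key before each use, e.g. (illustrative, not the actual expand_once code):

rng_key, integrator_key = jax.random.split(rng_key)
# use integrator_key for trajectory_integrator and carry rng_key in the loop state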

Rename `rwmh` to `rmh`

The implementation is general and not limited to random walks, so @AdrienCorenflos was right in naming the algorithm mh (for Metropolis-Hastings).

As the story goes, the "Metropolis algorithm" was also developed by Arianna Rosenbluth, whose name is little known. I thus suggest we rename the algorithm rmh for Rosenbluth-Metropolis-Hastings.

I understand this goes against previous conventions, and it may need some getting used to for users. But if no one ever does it...

Write documentation

With almost 200 stars I think it is now time to set up some documentation. It should contain:

  • An explanation of the philosophy behind blackjax;
  • A high-level example;
  • A comprehensive API documentation, including the elementary modules;
  • An example that shows how to use the building blocks to e.g. plot trajectories;
  • Notebook examples.

Parameters returned by the warmup are off

See also #44 for a failure on a simple gaussian target. Issue was spotted while addressing #112.

Bug Description

In the Introduction.ipynb example, the kernel built with the parameters obtained from window adaptation produces only divergences because the adapted step size is too large.

Versions

Python implementation: CPython
Python version       : 3.9.6
IPython version      : 7.22.0

jax     : 0.2.20
jaxlib  : 0.1.71
blackjax: 0.2.x

Compiler    : GCC 11.1.0
OS          : Linux
Release     : 5.13.13-arch1-1
Machine     : x86_64
Processor   : 
CPU cores   : 24
Architecture: 64bit

HMC fails to 'draw' from standard gaussian target

blackjax currently fails to target the standard normal distribution with HMC. In my example, the empirical mean is close to 0 but the empirical variance is close to 0.5. Similar values arise when using NUTS, with and without Stan warmup.

running

import numpy as np
import jax
import jax.numpy as jnp
import blackjax.hmc as hmc
import matplotlib.pyplot as plt

potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()
initial_position = np.array([1.0,])
initial_state = hmc.new_state(initial_position, potential)
initial_state

inv_mass_matrix = 0.1 * jnp.ones_like(initial_position)
num_integration_steps=100
step_size=1e-2

hmc_kernel = hmc.kernel(
    potential,
    step_size=step_size,
    inverse_mass_matrix=inv_mass_matrix,
    num_integration_steps=num_integration_steps
)

hmc_kernel = jax.jit(hmc_kernel)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 50_000)

samples = states.position.block_until_ready()
print(np.mean(samples, axis=0))
print(np.var(samples, axis=0))
plt.plot(samples)
plt.show()

gives

[-0.00325776]
[0.47831735]

It seems the behaviour changed with 7be2822, in particular with the change in blackjax/inference/proposal.py.
Using p_accept = jnp.clip(jnp.exp(proposal.weight), a_max=1) instead, the above code gives

[-0.00980166]
[0.9992622]

Similarly, with the parent commit (i.e., c6f75e9),

import numpy as np
import jax
import jax.numpy as jnp
import blackjax.hmc as hmc
import matplotlib.pyplot as plt

potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()
initial_position = np.array([1.0,])
initial_state = hmc.new_state(initial_position, potential)
initial_state

inv_mass_matrix = 0.1 * jnp.ones_like(initial_position)
num_integration_steps=100
step_size=1e-2
params = hmc.HMCParameters(
    step_size=step_size,
    inv_mass_matrix=inv_mass_matrix,
    num_integration_steps=num_integration_steps
)
hmc_kernel = hmc.kernel(potential, params)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 50_000)

samples = states.position.block_until_ready()
print(np.mean(samples, axis=0))
print(np.var(samples, axis=0))
plt.plot(samples)
plt.show()

results in

[-0.00980166]
[0.9992622]

Make kernel factories return NamedTuples

Optax has this very nice design where their optimizer factories return the following NamedTuple:

class GradientTransformation(NamedTuple):
  """Optax transformations consists of a function pair: (initialise, update)."""
  init: TransformInitFn
  update: TransformUpdateFn

See the code. This would fit the design of our kernel factories quite well and simplify the API. I was thinking:

class SamplingAlgorithm(NamedTuple):
    init: Callable
    sample: Callable

hmc = blackjax.hmc(logprob_fn, **params)

state = hmc.init(position)
new_state, info = hmc.sample(rng_key, state)

Opening the issue to get your opinion on the design and the naming. Here we don't have an update since we keep every sample.

Relax jax and jaxlib version requirements?

It's currently a bit annoying to use BlackJAX with other JAX-dependent libraries, because the version requirements of BlackJAX are fixed to jax==0.2.7 and jaxlib==0.1.57. Maybe they can be relaxed a bit?

Cheers!

Add coupled multinomial HMC

"Multinomial HMC" should be easy to implement once the NUTS PR is merged since we formulated NUTS as trajectory sampling.

https://arxiv.org/abs/2104.05134

We only need to implement a linear_expansion function that samples a direction and runs the integrator for L steps. Implement the particular case L=1 directly without scan.

See https://arxiv.org/abs/1701.02434, page 40. Choosing a sample from the entire trajectory is supposedly more efficient in terms of ESS per gradient evaluation.

Add empirical HMC

https://arxiv.org/pdf/1810.04449.pdf

This requires:

  1. A jittered version of HMC (as in #12);
  2. A new tuning algorithm that returns an array of trajectory length;
  3. Dual Averaging;
  4. Mass matrix adaptation;
  5. The NUTS kernel to use with (1).

(1) is trivial to implement with BlackJAX, and (3), (4), and (5) are part of the NUTS sampler.
