
blackjax's Issues

Return average acceptance probability for NUTS

We need the average acceptance probability for the step size adaptation in #16. Currently the warmup assumes that this information is contained in the acceptance_probability field of the NUTSInfo tuple. Of course, that name is not completely accurate for NUTS. I see two options:

  1. Forget about semantics and name it acceptance_probability. It's not too hurtful in this case.
  2. Use chex dataclasses for the Info tuples; we may include chex in the requirements anyway. This would also allow us to use inheritance (see the sketch below).
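
A minimal sketch of option 2, assuming chex is added to the requirements. The field names are illustrative, not the current Info definitions, and whether chex's Mapping-based dataclasses play well with inheritance would need checking:

import chex
import jax.numpy as jnp

@chex.dataclass
class HMCInfo:
    # Fields shared by all HMC-family kernels.
    acceptance_probability: jnp.ndarray
    is_divergent: jnp.ndarray

@chex.dataclass
class NUTSInfo(HMCInfo):
    # NUTS-specific diagnostics on top of the shared fields.
    num_trajectory_expansions: jnp.ndarray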

Add parallel tempering

Parallel tempering is a meta-algorithm that runs several chains in parallel at different "temperatures" and swaps their states regularly. This explores multimodal posteriors better than a single chain does.

The implementation should be generic and take any kernel with signature kernel(rng_key, state) -> new_state
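
A rough, plain-Python sketch of what a generic step could look like. All names are hypothetical, states are assumed to be positions, potential_fn is the un-tempered potential, and the whole thing would need to be reworked to run under jit/scan:

import jax
import jax.numpy as jnp

def parallel_tempering_step(rng_key, states, kernels, betas, potential_fn):
    """Advance each tempered chain with its own kernel, then propose one swap."""
    n = len(states)
    keys = jax.random.split(rng_key, n + 2)

    # Advance each chain; every kernel has the signature kernel(rng_key, state).
    states = [kernel(key, state) for kernel, key, state in zip(kernels, keys[:n], states)]

    # Pick a random adjacent pair (i, i + 1) and compute the log swap probability
    # (beta_i - beta_{i+1}) * (U(x_i) - U(x_{i+1})).
    i = int(jax.random.randint(keys[n], (), 0, n - 1))
    log_p_swap = (betas[i] - betas[i + 1]) * (
        potential_fn(states[i]) - potential_fn(states[i + 1])
    )

    # Metropolis accept/reject for the swap.
    if jnp.log(jax.random.uniform(keys[n + 1])) < log_p_swap:
        states[i], states[i + 1] = states[i + 1], states[i]
    return states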

Rename `rwmh` to `rmh`

The implementation is general and not limited to random walks, so @AdrienCorenflos was right in naming the algorithm mh (for Metropolis-Hastings).

As the story goes, the "Metropolis algorithm" was also developed by Arianna Rosenbluth, whose name is little known. I thus suggest we rename the algorithm rmh for Rosenbluth-Metropolis-Hastings.

I understand this goes against previous conventions, and it may take some getting used to for users. But if no one ever does it...

Add other momentum distributions

We currently only implement a normally-distributed momentum, but it could also follow a Laplace, Cauchy, etc. distribution.

For that we would need to generalize gaussian_euclidean into a euclidean metric and pass the momentum distribution as an argument (defaulting to the unit normal).
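
A rough sketch of the proposal: euclidean takes a (sampler, logpdf) pair describing the momentum distribution and defaults to the unit normal. The name and signature are part of the suggestion, not the current API, and the mass-matrix scaling is omitted for brevity:

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

def euclidean(inverse_mass_matrix, momentum_distribution=None):
    """Build the momentum generator and kinetic energy for a Euclidean metric."""
    if momentum_distribution is None:
        # Default: unit normal momentum, i.e. the current gaussian_euclidean behaviour.
        momentum_distribution = (
            lambda rng_key, shape: jax.random.normal(rng_key, shape),
            lambda p: jnp.sum(stats.norm.logpdf(p)),
        )
    sample_momentum, momentum_logpdf = momentum_distribution

    def momentum_generator(rng_key, position):
        return sample_momentum(rng_key, jnp.shape(position))

    def kinetic_energy(momentum):
        # The kinetic energy is minus the log-density of the momentum distribution.
        return -momentum_logpdf(momentum)

    return momentum_generator, kinetic_energy

A Laplace momentum could then be passed as (lambda key, shape: jax.random.laplace(key, shape), lambda p: jnp.sum(stats.laplace.logpdf(p))).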

Use `jupytext`-compatible markdown files instead of notebooks?

Notebooks don't play well with version control, but the markdown files produced by jupytext do. I suggest we only version the markdown files in the repository and render the notebooks when building the documentation.

Edit: We use MyST-nb to build the documentation from notebooks, and it does support notebooks in jupytext format. It is just a matter of converting the notebooks and configuring the build in CI.

reuse of a prng key in expand_once()

In the function expand_once a PRNG key is used multiple times.

def expand_once(loop_state):

In particular, the key rng_key generated on line 523 is used twice: in the call to trajectory_integrator (line 540) and as an element of the loop state returned on line 588. It is thus reused for splitting in the next iteration. As far as I understand JAX's RNG implementation, this kind of reuse must be avoided.
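
The usual fix is to split the key before using it, so that only a fresh sub-key is carried into the next iteration. A minimal sketch of the pattern (the loop-state fields and the integrator call are placeholders, not the actual source):

import jax

def expand_once(loop_state):
    rng_key, trajectory_state = loop_state
    # Split first: consume one sub-key now and carry the other forward, so the
    # key stored in the loop state is never the one we just used.
    rng_key, integrator_key = jax.random.split(rng_key)
    trajectory_state = trajectory_integrator(integrator_key, trajectory_state)  # placeholder call
    return (rng_key, trajectory_state)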

HMC fails to 'draw' from standard gaussian target

blackjax currently fails to target the standard normal distribution with HMC. In my example, the empirical mean is close to 0 but the empirical variance is close to 0.5. Similar values arise with NUTS, with and without Stan warmup.

Running

import numpy as np
import jax
import jax.numpy as jnp
import blackjax.hmc as hmc
import matplotlib.pyplot as plt

potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()
initial_position = np.array([1.0,])
initial_state = hmc.new_state(initial_position, potential)
initial_state

inv_mass_matrix = 0.1 * jnp.ones_like(initial_position)
num_integration_steps=100
step_size=1e-2

hmc_kernel = hmc.kernel(
    potential,
    step_size=step_size,
    inverse_mass_matrix=inv_mass_matrix,
    num_integration_steps=num_integration_steps
)

hmc_kernel = jax.jit(hmc_kernel)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 50_000)

samples = states.position.block_until_ready()
print(np.mean(samples, axis=0))
print(np.var(samples, axis=0))
plt.plot(samples)
plt.show()

gives

[-0.00325776]
[0.47831735]

It seems the behaviour changed with 7be2822, in particular with the change in blackjax/inference/proposal.py.
Using p_accept = jnp.clip(jnp.exp(proposal.weight), a_max=1) instead, the above code gives

[-0.00980166]
[0.9992622]

Similarly, with the parent commit (i.e., c6f75e9),

import numpy as np
import jax
import jax.numpy as jnp
import blackjax.hmc as hmc
import matplotlib.pyplot as plt

potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()
initial_position = np.array([1.0,])
initial_state = hmc.new_state(initial_position, potential)
initial_state

inv_mass_matrix = 0.1 * jnp.ones_like(initial_position)
num_integration_steps=100
step_size=1e-2
params = hmc.HMCParameters(
    step_size=step_size,
    inv_mass_matrix=inv_mass_matrix,
    num_integration_steps=num_integration_steps
)
hmc_kernel = hmc.kernel(potential, params)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 50_000)

samples = states.position.block_until_ready()
print(np.mean(samples, axis=0))
print(np.var(samples, axis=0))
plt.plot(samples)
plt.show()

results in

[-0.00980166]
[0.9992622]

Add coupled multinomial HMC

"Multinomial HMC" should be easy to implement once the NUTS PR is merged since we formulated NUTS as trajectory sampling.

https://arxiv.org/abs/2104.05134

We only need to implement a linear_expansion function that samples a direction and runs the integrator for L steps. The particular case L=1 can be implemented directly, without scan.
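
A rough sketch of the computation; the helper names and the state layout (including the energy field) are assumptions for illustration, not the internal API:

import jax
import jax.numpy as jnp

def linear_expansion(rng_key, initial_state, integrator_step, num_steps):
    """Sample a direction, integrate num_steps steps, then draw one state from
    the whole trajectory with probability proportional to exp(-energy)."""
    key_direction, key_sample = jax.random.split(rng_key)
    direction = jnp.where(jax.random.bernoulli(key_direction), 1, -1)

    def one_step(state, _):
        state = integrator_step(state, direction)
        return state, state

    _, trajectory = jax.lax.scan(one_step, initial_state, jnp.arange(num_steps))

    # Multinomial sampling over the whole trajectory, weighted by exp(-H).
    weights = jax.nn.softmax(-trajectory.energy)
    idx = jax.random.choice(key_sample, num_steps, p=weights)
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], trajectory)

For L=1 the scan would of course be dropped and the integrator called once.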

See https://arxiv.org/abs/1701.02434, page 40. Choosing a sample from the entire trajectory is supposedly more efficient in terms of ESS per gradient evaluation.

Relax jax and jaxlib version requirements?

It's currently a bit annoying to use BlackJAX with other JAX-dependent libraries, because the version requirements of BlackJAX are fixed to jax==0.2.7 and jaxlib==0.1.57. Maybe they can be relaxed a bit?

Cheers!

Add community resources: contributing guide, code of conduct, governance?

I noticed there are already some issues labeled as "good first issue", but without a contributing guide (plus a code of conduct to establish a baseline for how to interact) they can't really be good first issues; at most they are good first issues for someone who has contributed to other projects but not yet to blackjax.

While I'm here: I have no idea whether there are unwritten rules about collaboration and decision making between the different projects and people involved, but it may also be worth adding a governance doc at some point, even if it's only a couple of sentences.

I will probably be able to review anything related to this on a day/week timescale, but I don't think I can send a PR in the near future.

Write documentation

With almost 200 stars I think it is now time to set up some documentation. It should contain:

  • An explanation of the philosophy behind blackjax;
  • A high-level example;
  • A comprehensive API documentation, including the elementary modules;
  • An example that shows how to use the building blocks to e.g. plot trajectories;
  • Notebook examples.

Add empirical HMC

https://arxiv.org/pdf/1810.04449.pdf

This requires:

  1. A jittered version of HMC (as in #12);
  2. A new tuning algorithm that returns an array of trajectory length;
  3. Dual Averaging;
  4. Mass matrix adaptation;
  5. The NUTS kernel to use with (1).

(1) is trivial to implement with BlackJAX and (3, 4, 5) are part of the NUTS sampler.
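
For (1), a rough sketch of trajectory-length jittering layered on top of the kernel factory used elsewhere in this tracker. It is not jit-friendly as written, since the kernel is rebuilt with a concrete length at every call, and trajectory_lengths would be the array produced by the tuning algorithm in (2):

import jax
import jax.numpy as jnp
import blackjax.hmc as hmc

def jittered_hmc_kernel(potential, step_size, inverse_mass_matrix, trajectory_lengths):
    """Draw the number of integration steps from an empirical array at each call."""
    def kernel(rng_key, state):
        key_length, key_step = jax.random.split(rng_key)
        num_steps = int(jax.random.choice(key_length, jnp.asarray(trajectory_lengths)))
        one_step = hmc.kernel(
            potential,
            step_size=step_size,
            inverse_mass_matrix=inverse_mass_matrix,
            num_integration_steps=num_steps,
        )
        return one_step(key_step, state)

    return kernel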

Add SGLD

Starting from the SgMCMCJAX implementation.
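
For reference, a minimal sketch of the SGLD update itself: a single Euler-Maruyama step on a stochastic estimate of the potential gradient. All names here are illustrative, and the SgMCMCJAX interface may look different:

import jax
import jax.numpy as jnp

def sgld_step(rng_key, position, grad_estimator, minibatch, step_size):
    """theta <- theta - (eps / 2) * grad_U_hat(theta) + N(0, eps)."""
    grad = grad_estimator(position, minibatch)  # minibatch estimate of grad U
    noise = jax.random.normal(rng_key, jnp.shape(position))
    return position - 0.5 * step_size * grad + jnp.sqrt(step_size) * noise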

Implement waste-free SMC

A natural next step after the tempered SMC method is to implement the waste-free SMC algorithm in https://arxiv.org/abs/2011.02328. The gist is fairly simple: to work well, tempered SMC may require running K >> 1 iterations of a chain for each of the N particles, keeping only the last sample of each chain. This is statistically wasteful, and waste-free SMC proposes to reuse the samples generated along the chain to build better estimates (in a sense this can be seen as the SMC analog of NUTS).

Parameters returned by the warmup are off

See also #44 for a failure on a simple Gaussian target. The issue was spotted while addressing #112.

Bug Description

In the Introduction.ipynb example, the kernel built with the parameters returned by window adaptation produces only divergent transitions because the adapted step size is too large.

Versions

Python implementation: CPython
Python version       : 3.9.6
IPython version      : 7.22.0

jax     : 0.2.20
jaxlib  : 0.1.71
blackjax: 0.2.x

Compiler    : GCC 11.1.0
OS          : Linux
Release     : 5.13.13-arch1-1
Machine     : x86_64
Processor   : 
CPU cores   : 24
Architecture: 64bit

Add HMC Swindles

There are several algorithms in https://arxiv.org/abs/2001.05033 that should be simple to implement:

  • HMC-COUPLED is a kernel that runs two identical kernels with the same rng_key;
def hmc_coupled(rng_key, states):
    states, infos = jax.vmap(kernel, in_axes=(None, 0))(rng_key, states)
    return states, infos
  • HMC-ANTITHETIC is a kernel that runs two kernels with opposite momenta and identical rng_key.

expose `target_accept` to `stan_warmup`?

For dealing with divergent transitions it can be useful to increase the target acceptance probability, e.g. via adapt_delta in Stan or target_accept_prob in NumPyro. Am I right in thinking that this isn't possible with your current API? The relevant argument seems to be set to a default value in the function,

def find_reasonable_step_size(
    ...
    target_accept: float = 0.65,
) -> float:

but this is not exposed to stan_warmup.

Is there a reason for this choice, and/or is there another approach you recommend for dealing with divergences? Alternatively, if you agree to expose target_accept to stan_warmup, I'd be happy to put in a small PR.

Thanks for the great library!

Make kernel factories return NamedTuples

Optax has this very nice design where their optimizer factories return the following NamedTuple:

class GradientTransformation(NamedTuple):
  """Optax transformations consists of a function pair: (initialise, update)."""
  init: TransformInitFn
  update: TransformUpdateFn

See the code. This would fit the design of our kernel factories quite well and simplify the API. I was thinking:

class SamplingAlgorithm(NamedTuple):
    init: Callable
    sample: Callable

hmc = blackjax.hmc(logprob_fn, **params)

state = hmc.init(position)
new_state, info = hmc.sample(rng_key, state)

Opening the issue to get your opinion on the design and the naming. Here we don't have an update function since we keep every sample.
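
For concreteness, a sketch of what a factory could look like internally, reusing the module-level new_state and kernel functions that appear elsewhere in this tracker and the SamplingAlgorithm tuple proposed above. The wrapper itself is hypothetical:

import blackjax.hmc as hmc

def make_hmc(logprob_fn, **params):
    """Wrap the existing init/kernel pair into the proposed NamedTuple."""
    potential_fn = lambda x: -logprob_fn(x)

    def init(position):
        return hmc.new_state(position, potential_fn)

    kernel = hmc.kernel(potential_fn, **params)

    def sample(rng_key, state):
        return kernel(rng_key, state)

    return SamplingAlgorithm(init=init, sample=sample)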

Add ABC

Can be useful for people who have a simulator in JAX and want to do approximate Bayesian inference.

Add a pytest config to deactivate GPU by default

We shouldn't have to use GPUs to run the unit tests, and I believe that many (if not all) users have the GPU version of JAX installed by default, which makes running the tests slower or adds manual dev overhead. What do you think?
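
A minimal conftest.py sketch, assuming we are happy to force CPU unconditionally for the whole test session:

# conftest.py
import jax

# Run the test suite on CPU even when a GPU-enabled jaxlib is installed.
jax.config.update("jax_platform_name", "cpu")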
