blackjax-devs / blackjax
BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
Home Page: https://blackjax-devs.github.io/blackjax/
License: Apache License 2.0
mcx/inference/integrators.py
We need the average acceptance probability for the step size adaptation in #16. Currently the warmup assumes that the information is contained in the field acceptance_probability of the NUTSInfo tuple. Of course this description is not completely accurate for NUTS. I see two options:
- Keep the name acceptance_probability; the inaccuracy is not too hurtful in this case.
- Define the Info tuples using chex, which we may include in the requirements anyway. This would also allow us to use inheritance.

Parallel tempering is a meta-algorithm that runs several chains in parallel with different "temperatures" and swaps them regularly. This samples multi-modal posteriors better than a single chain does.
The implementation should be generic and take any kernel with the signature kernel(rng_key, state) -> new_state.
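For illustration, a minimal sketch of the swap step under these assumptions (the helper name swap_step, the betas array of inverse temperatures, and the shared potential argument are all assumptions, not actual BlackJAX API):

import jax
import jax.numpy as jnp

def swap_step(rng_key, positions, betas, potential, i):
    # Propose swapping chains i and i+1; accept with the usual exchange
    # probability min(1, exp((beta_i - beta_j) * (U(x_i) - U(x_j)))).
    u_i, u_j = potential(positions[i]), potential(positions[i + 1])
    log_p_accept = (betas[i] - betas[i + 1]) * (u_i - u_j)
    do_swap = jnp.log(jax.random.uniform(rng_key)) < log_p_accept
    swapped = positions.at[i].set(positions[i + 1]).at[i + 1].set(positions[i])
    return jnp.where(do_swap, swapped, positions)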
The implementation is general and not limited to random walks, so @AdrienCorenflos was right in naming the algorithm mh (for Metropolis-Hastings).
As the story goes, the "Metropolis algorithm" was also developed by Arianna Rosenbluth, whose name is little known. I thus suggest we rename the algorithm rmh, for Rosenbluth-Metropolis-Hastings.
I understand this goes against previous conventions, and it may need some getting used to for users. But if no one ever does it...
We currently only implement a normally-distributed momentum distribution, but it could also be distributed according to a Laplace, Cauchy, or other distribution.
For that we would need to transform gaussian_euclidean into euclidean and pass the momentum distribution as an argument (default unit normal); see the sketch below.
https://arxiv.org/pdf/1212.4693.pdf
See https://arxiv.org/abs/1910.06243 for an explicit integrator.
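A rough sketch of what the generalized factory could look like (the names euclidean, momentum_sample_fn and momentum_logpdf_fn are assumptions for illustration, not the actual API):

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

def euclidean(momentum_sample_fn=None, momentum_logpdf_fn=None):
    # Default to the unit normal momentum distribution.
    if momentum_sample_fn is None:
        momentum_sample_fn = lambda rng_key, shape: jax.random.normal(rng_key, shape)
    if momentum_logpdf_fn is None:
        momentum_logpdf_fn = lambda p: stats.norm.logpdf(p).sum()

    def momentum_generator(rng_key, position):
        return momentum_sample_fn(rng_key, jnp.shape(position))

    def kinetic_energy(momentum):
        # The kinetic energy is the negative log-density of the momentum.
        return -momentum_logpdf_fn(momentum)

    return momentum_generator, kinetic_energy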
Notebooks don't work well with versioning, but the markdown files produced by jupytext do. I suggest we only version the markdown versions in the repository, and render the notebooks for the documentation.
Edit: We use MyST-nb to build the documentation from notebooks, and it does support notebooks in jupytext format. It is just a matter of converting the notebooks and configuring the build in CI.
It would be very useful to support SG-MCMC methods, such as SGLD, SG-HMC, etc., so we can tackle large-N problems.
Perhaps similar to this codebase:
https://github.com/jeremiecoullon/SGMCMCJax
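For reference, the SGLD update itself is only a few lines. A minimal sketch, where grad_estimator is a hypothetical minibatch estimator of the log-posterior gradient:

import jax
import jax.numpy as jnp

def sgld_step(rng_key, position, minibatch, grad_estimator, step_size):
    # Follow a stochastic gradient of the log-posterior and inject
    # Gaussian noise with variance 2 * step_size.
    grad = grad_estimator(position, minibatch)
    noise = jax.random.normal(rng_key, jnp.shape(position))
    return position + step_size * grad + jnp.sqrt(2 * step_size) * noise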
In the function expand_once a PRNG key is used multiple times.
blackjax/blackjax/inference/trajectory.py
Line 507 in 6938da9
In particular, the key rng_key generated in line 523 is used twice: in the call to trajectory_integrator (line 540) and as an element of the loop state returned in line 588. Thus, it is reused in the next iteration for splitting. As far as I understand JAX's RNG implementation, this kind of reuse must be avoided.
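For illustration, the standard pattern that avoids reuse (a generic sketch, not the actual trajectory.py code):

import jax

rng_key = jax.random.PRNGKey(0)
# Split before every use: consume subkey, carry rng_key forward.
rng_key, subkey = jax.random.split(rng_key)
sample = jax.random.normal(subkey)
# Only the never-consumed rng_key should be stored in the loop state
# and split again on the next iteration.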
Should make testing pmap etc. easier, among many other improvements: https://github.com/deepmind/chex
BlackJAX currently fails to target the standard normal distribution with HMC. In my example, the empirical mean is close to 0 but the empirical variance is close to 0.5. Similar values arise when employing NUTS with and without Stan warmup.
Running
import numpy as np
import jax
import jax.numpy as jnp
import blackjax.hmc as hmc
import matplotlib.pyplot as plt
potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()
initial_position = np.array([1.0,])
initial_state = hmc.new_state(initial_position, potential)
initial_state
inv_mass_matrix = 0.1 * jnp.ones_like(initial_position)
num_integration_steps = 100
step_size = 1e-2
hmc_kernel = hmc.kernel(
    potential,
    step_size=step_size,
    inverse_mass_matrix=inv_mass_matrix,
    num_integration_steps=num_integration_steps,
)
hmc_kernel = jax.jit(hmc_kernel)
def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states
rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 50_000)
samples = states.position.block_until_ready()
print(np.mean(samples, axis=0))
print(np.var(samples, axis=0))
plt.plot(samples)
plt.show()
gives
[-0.00325776]
[0.47831735]
It seems the behaviour changed with 7be2822, in particular with the change in blackjax/inference/proposal.py.
Using p_accept = jnp.clip(jnp.exp(proposal.weight), a_max=1) (i.e. clipping the acceptance probability at 1), the above code gives
[-0.00980166]
[0.9992622]
Similarly, with the parent commit (i.e., c6f75e9),
import numpy as np
import jax
import jax.numpy as jnp
import blackjax.hmc as hmc
import matplotlib.pyplot as plt
potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()
initial_position = np.array([1.0,])
initial_state = hmc.new_state(initial_position, potential)
initial_state
inv_mass_matrix = 0.1 * jnp.ones_like(initial_position)
num_integration_steps = 100
step_size = 1e-2
params = hmc.HMCParameters(
    step_size=step_size,
    inv_mass_matrix=inv_mass_matrix,
    num_integration_steps=num_integration_steps,
)
hmc_kernel = hmc.kernel(potential, params)
def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states
rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 50_000)
samples = states.position.block_until_ready()
print(np.mean(samples, axis=0))
print(np.var(samples, axis=0))
plt.plot(samples)
plt.show()
results in
[-0.00980166]
[0.9992622]
In cell 19 of Introduction.ipynb you create kernel using Stan warmup, but in cell 20 you use the old nuts_kernel, which was created earlier without warmup.
https://github.com/pymc-devs/pymc3/blob/master/pymc3/sampling_jax.py should serve as a good blueprint.
"Multinomial HMC" should be easy to implement once the NUTS PR is merged since we formulated NUTS as trajectory sampling.
https://arxiv.org/abs/2104.05134
We only need to implement a linear_expansion function that samples a direction and runs the integrator for L steps. Implement the particular case L=1 directly, without scan (see the sketch below).
See https://arxiv.org/abs/1701.02434 page 40. Choosing a sample from the entire trajectory supposedly is more efficient in terms of ESS/gradient evaluation.
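A possible shape for that function (a sketch only; the integrator signature and names are assumptions):

import jax
import jax.numpy as jnp

def linear_expansion(rng_key, integrator, initial_state, step_size, num_steps):
    # Sample a direction uniformly in {-1, +1} and run the integrator
    # num_steps times in that direction.
    direction = jnp.where(jax.random.bernoulli(rng_key), 1.0, -1.0)

    def one_step(state, _):
        state = integrator(state, direction * step_size)
        return state, state

    # The L = 1 case could call the integrator once directly and skip scan.
    last_state, trajectory = jax.lax.scan(
        one_step, initial_state, jnp.arange(num_steps)
    )
    return last_state, trajectory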
It's currently a bit annoying to use BlackJAX with other JAX-dependent libraries, because the version requirements of BlackJAX are pinned to jax==0.2.7 and jaxlib==0.1.57. Maybe they can be relaxed a bit?
Cheers!
Ideally implement these two different papers.
They both build on https://rss.onlinelibrary.wiley.com/doi/full/10.1111/rssb.12336 and can be used to compute unbiased estimates of expectations. I don't think these methods are implemented in any Python library (I don't think they are in Stan either?) and they would be a cool showcase.
See #87.
Following a suggestion by @rlouf, I'm splitting the coupled HMC commit a bit.
This is to tag the RWMH method.
I noticed there are already some issues labeled as "good first issue", but without a contributing guide (plus a code of conduct to establish a baseline on how to interact) they can't really be good first issues; at most they can be good first issues for someone who has contributed to other projects but not yet to blackjax.
While I'm here: I have no idea if there are any unwritten rules about collaboration and decision making between the different projects and people involved, but it may also be worth adding a governance doc at some point, even if it's only a couple of sentences.
I will probably be able to review anything related to this on a day/week timescale but I don't think I can send a PR in the near future.
With almost 200 stars I think it is now time to set up some documentation. It should contain:
https://arxiv.org/pdf/1810.04449.pdf
This requires:
(1) is trivial to implement with BlackJAX, and (3, 4, 5) are part of the NUTS sampler.
Starting from the SGMCMCJax implementation.
A natural next step after the tempered SMC method is to implement the waste-free SMC algorithm in https://arxiv.org/abs/2011.02328. The gist is fairly simple: in order to work, tempered SMC may require running K >> 1 iterations of a chain for each of the N particles, keeping only the last sample of each chain. This is quite statistically wasteful, and waste-free SMC proposes to reuse the samples generated along the chain to build better estimates (in a sense this can be seen as the SMC analogue of NUTS).
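A sketch of the mutation step under that idea (names are hypothetical; kernel follows the kernel(rng_key, state) -> (state, info) convention used elsewhere in the library):

import jax

def waste_free_mutation(rng_key, kernel, initial_particles, num_steps):
    # Run a chain of num_steps kernel iterations from each starting
    # particle and keep every intermediate state instead of only the
    # last, so M starting particles yield M * num_steps samples.
    def one_step(state, key):
        state, _ = kernel(key, state)
        return state, state

    def run_chain(chain_key, init):
        keys = jax.random.split(chain_key, num_steps)
        _, states = jax.lax.scan(one_step, init, keys)
        return states

    chain_keys = jax.random.split(rng_key, initial_particles.shape[0])
    return jax.vmap(run_chain)(chain_keys, initial_particles)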
Implement tempered SMC for sampling in BlackJax
The interface between this library and the rest of the world should be log_prior, log_likelihood and log_posterior. potential is too specific to the HMC family.
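For illustration, a possible adapter between the two conventions (the toy log-densities here are made up; the HMC-family kernels could keep working in potential space internally):

import jax.scipy.stats as stats

# Toy log-densities for illustration only.
logprior_fn = lambda x: stats.norm.logpdf(x, loc=0.0, scale=10.0).sum()
loglikelihood_fn = lambda x: stats.norm.logpdf(1.0, loc=x, scale=1.0).sum()

def logposterior_fn(position):
    return logprior_fn(position) + loglikelihood_fn(position)

# The potential used by HMC-family kernels is the negative log-posterior.
potential_fn = lambda position: -logposterior_fn(position)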
Are the version pins for the deps in requirements-dev.txt necessary? Can we remove them?
See also #44 for a failure on a simple gaussian target. Issue was spotted while addressing #112.
In the Introduction.ipynb example, the kernel built using the parameters obtained with the window adaptation produces only divergences, because the adapted step size is too large.
Python implementation: CPython
Python version : 3.9.6
IPython version : 7.22.0
jax : 0.2.20
jaxlib : 0.1.71
blackjax: 0.2.x
Compiler : GCC 11.1.0
OS : Linux
Release : 5.13.13-arch1-1
Machine : x86_64
Processor :
CPU cores : 24
Architecture: 64bit
The variable names logprob_fn and potential_fn are used interchangeably in different places. However, I think we chose to work in the potential space, am I right? The occurrences should be corrected.
@rlouf do you know of a way to catch (at a minimum) the docstring errors in the commit checks?
blackjax/blackjax/inference/rwmh/base.py
Line 60 in fe10807
There are several algorithms in https://arxiv.org/abs/2001.05033 that should be simple to implement, e.g. advancing the coupled chains with a shared rng_key by broadcasting it over the vmapped kernel:

def hmc_coupled(rng_key, states):
    # vmap over the coupled states while broadcasting a single rng_key.
    states, infos = jax.vmap(kernel, in_axes=(None, 0))(rng_key, states)
    return states, infos

For dealing with divergent transitions it can be useful to increase the target acceptance probability, e.g. using adapt_delta in Stan or target_accept_prob in NumPyro. Am I right in thinking that this isn't possible with your current API? The relevant argument seems to be set to a default value in the function,
def find_reasonable_step_size(
    ...
    target_accept: float = 0.65,
) -> float:
but this is not exposed to stan_warmup.
Is there a reason for this choice and/or another approach you recommend for dealing with divergences? Alternatively, if you agree to exposing target_accept to stan_warmup, I'd be happy to put in a small PR.
Thanks for the great library!
Optax has this very nice design where their optimizer factories return the following NamedTuple:
class GradientTransformation(NamedTuple):
    """Optax transformations consists of a function pair: (initialise, update)."""
    init: TransformInitFn
    update: TransformUpdateFn
See the code. This would fit the design of our kernel factories quite well and simplify the API. I was thinking:
class SamplingAlgorithm(NamedTuple):
    init: Callable
    sample: Callable

hmc = blackjax.hmc(logprob_fn, **params)
state = hmc.init(position)
new_state, info = hmc.sample(rng_key, state)
Opening the issue to get your opinion on the design and the naming. Here we don't have an update since we keep every sample.
Can be useful for people who have a simulator in JAX and want to do approximate Bayesian inference.
As this forces users to explicitly use these names when defining the kernel factory.
We shouldn't have to use GPUs to run unit tests, and I believe that many (if not all) users have the GPU version of JAX installed by default, which makes running the tests slower or adds manual dev overhead. What do you think?
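For reference, one way to keep the test suite on CPU regardless of the installed jaxlib (the flag is real JAX configuration; where to put it is just a suggestion):

import jax

# Force JAX onto the CPU backend for the tests, even when a
# GPU-enabled jaxlib is installed.
jax.config.update("jax_platform_name", "cpu")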