Comments (10)
Thank you for your work, @rlouf!
from blackjax.
Thank you very much for the fix ❤️
Thank you for the bug report and sorry for the inconvenience! I managed to reproduce your example and will investigate. @junpenglao, this happens for NUTS as well.
I just tried NUTS on the current main branch. I'll try to find time tonight to check whether NUTS was working at c6f75e9.
Thanks for the feedback. Yeah, we need to add a test based on the Monte Carlo central limit theorem for this.
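A Monte Carlo CLT check could look roughly like the sketch below. It assumes that the sample mean of f(x) over a chain is approximately normal around E[f(x)], with a standard error estimated by non-overlapping batch means (which tolerates autocorrelated MCMC draws). The function names `batch_means_mcse` and `passes_clt_check`, the 50 batches, and the 4-standard-error tolerance are illustrative choices, not part of BlackJAX. Checking f(x) = x² as well as f(x) = x matters here, because a sampler can recover the mean while getting the variance wrong.

```python
import numpy as np


def batch_means_mcse(values, num_batches=50):
    """Monte Carlo standard error of the mean of `values`, estimated with
    non-overlapping batch means (hypothetical helper, for illustration)."""
    values = np.asarray(values, dtype=float).ravel()
    batch_size = len(values) // num_batches
    trimmed = values[: batch_size * num_batches]
    batch_means = trimmed.reshape(num_batches, batch_size).mean(axis=1)
    # Std. dev. of the batch means over sqrt(#batches) estimates the MCSE.
    return batch_means.std(ddof=1) / np.sqrt(num_batches)


def passes_clt_check(values, expected, z=4.0, num_batches=50):
    """True if the sample mean of `values` lies within `z` Monte Carlo
    standard errors of `expected` (hypothetical helper)."""
    values = np.asarray(values, dtype=float).ravel()
    mcse = batch_means_mcse(values, num_batches)
    return abs(values.mean() - expected) < z * mcse


# Stand-in for sampler output: i.i.d. draws from the N(0, 1) target.
rng = np.random.default_rng(0)
good = rng.normal(0.0, 1.0, size=50_000)
# Under-dispersed draws, mimicking a sampler that gets the variance wrong.
bad = rng.normal(0.0, np.sqrt(0.34), size=50_000)

# E[x] = 0 and E[x^2] = 1 under the standard normal target.
print(passes_clt_check(good, 0.0), passes_clt_check(good ** 2, 1.0))
print(passes_clt_check(bad, 0.0), passes_clt_check(bad ** 2, 1.0))
```

In a real test, `good` and `bad` would be replaced by the positions returned by the HMC or NUTS kernel on a standard normal target.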
I am pretty sure, but not 100% sure, that the example code is correct. I also didn't spend much time trying to optimize the parameters.
NUTS seems to work on v1 (8209172) in the sense that the sample mean is close to zero and the sample variance is reasonably close to unity (1.17, but this might be due to poorly tuned NUTS parameters). At commit cfeffb1, I get different results from v1: the sample variance is now 0.33691323.
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm
import blackjax.nuts as nuts
import matplotlib.pyplot as plt

# Standard normal target; the potential is the negative log-density.
potential = lambda x: -norm.logpdf(x, loc=0.0, scale=1.0).squeeze()
initial_position = np.array([1.0])
initial_state = nuts.new_state(initial_position, potential)

step_size = 0.1
params = nuts.NUTSParameters(
    step_size=step_size,
    inv_mass_matrix=jnp.ones_like(initial_position),
)
nuts_kernel = jax.jit(nuts.kernel(potential, params))


def inference_loop(rng_key, kernel, initial_state, num_samples):
    # Run the kernel num_samples times, keeping every intermediate state.
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states


rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, nuts_kernel, initial_state, 50_000)
samples = states.position.block_until_ready()
# For N(0, 1) we expect a mean close to 0 and a variance close to 1.
print(np.mean(samples, axis=0))
print(np.var(samples, axis=0))
plt.plot(samples)
plt.show()
| commit | sample mean | sample variance |
|---|---|---|
| 8209172 (v1) | 0.0027857 | 1.1750876 |
| fe6ff69 | 0.0027857 | 1.1750876 |
| cfeffb1 | 0.0067384 | 0.33691323 |
I'll track this one down; it might be the reason why adaptation fails to recover the variance of a Gaussian target in #44. It must be the change in `generate_proposal`, since the bug affects both HMC and NUTS.
I will add the extra tests you were talking about, @junpenglao, together with the bug fix.
I have found a bug which may explain what you observe (I will try to run the corrected code later tonight).
In the `init` function for the proposal, `position` is passed to the kinetic energy instead of `momentum`. The fact that we can pass both is a remnant of the code on which BlackJAX is based, where I tried to implement the SoftAbs metric. I will also remove that possibility for now.
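To illustrate why evaluating the kinetic energy at the wrong argument matters, here is a minimal, self-contained HMC sketch in JAX. It is not BlackJAX's actual code: the `buggy` flag, the standard normal target, unit diagonal mass matrix, and the step size and number of leapfrog steps are all assumptions chosen for illustration. The Metropolis correction compares the total energy H = U(q) + K(p) before and after the leapfrog trajectory; with `buggy=True` the initial energy is computed as U(q) + K(q), which for this target is just q² regardless of the momentum drawn, so the acceptance probability is wrong and there is no guarantee the chain still targets N(0, 1).

```python
import jax
import jax.numpy as jnp


def potential(q):
    # Standard normal target: U(q) = -log N(q | 0, 1) up to a constant.
    return 0.5 * jnp.sum(q ** 2)


def kinetic_energy(p, inv_mass_matrix):
    # Euclidean kinetic energy with a diagonal inverse mass matrix.
    return 0.5 * jnp.sum(inv_mass_matrix * p ** 2)


def leapfrog(q, p, step_size, num_steps, inv_mass_matrix):
    grad_u = jax.grad(potential)

    def body(_, qp):
        q, p = qp
        p = p - 0.5 * step_size * grad_u(q)
        q = q + step_size * inv_mass_matrix * p
        p = p - 0.5 * step_size * grad_u(q)
        return q, p

    return jax.lax.fori_loop(0, num_steps, body, (q, p))


def hmc_step(rng_key, q, step_size, num_steps, inv_mass_matrix, buggy=False):
    key_p, key_u = jax.random.split(rng_key)
    # Momentum p ~ N(0, M); for diagonal M its std is 1 / sqrt(inv_mass_matrix).
    p = jax.random.normal(key_p, q.shape) / jnp.sqrt(inv_mass_matrix)
    # Hypothetical reproduction of the bug: kinetic energy evaluated at the
    # position instead of the momentum when initialising the proposal.
    k_arg = q if buggy else p
    energy0 = potential(q) + kinetic_energy(k_arg, inv_mass_matrix)
    q_new, p_new = leapfrog(q, p, step_size, num_steps, inv_mass_matrix)
    energy1 = potential(q_new) + kinetic_energy(p_new, inv_mass_matrix)
    # Metropolis correction: accept with probability min(1, exp(energy0 - energy1)).
    accept = jnp.log(jax.random.uniform(key_u)) < energy0 - energy1
    return jnp.where(accept, q_new, q)


def sample(rng_key, q0, num_samples, buggy=False):
    inv_mass_matrix = jnp.ones_like(q0)

    def one_step(q, key):
        q = hmc_step(key, q, 0.15, 10, inv_mass_matrix, buggy)
        return q, q

    keys = jax.random.split(rng_key, num_samples)
    _, qs = jax.lax.scan(one_step, q0, keys)
    return qs


qs = sample(jax.random.PRNGKey(0), jnp.array([1.0]), 20_000)
print(jnp.mean(qs), jnp.var(qs))
```

With `buggy=False` the sample mean and variance should land near 0 and 1. Running `sample(..., buggy=True)` shows how far the broken acceptance step drifts, though this sketch makes no claim about reproducing the exact 0.33691323 figure reported earlier in this thread.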
The first commit on #47 fixes it. I will do a little refactoring, add an extra test as @junpenglao suggested, and release a patch.
I fixed the bug and released the patch as 0.2.1, now available on PyPI. Thank you again for reporting!