hriebl commented on June 9, 2024

Thank you for your work, @rlouf!

wiep commented on June 9, 2024

Thank you very much for the fix ❤️

rlouf commented on June 9, 2024

Thank you for the bug report, and sorry for the inconvenience! I managed to reproduce your example and will investigate. @junpenglao, this happens with NUTS as well.

wiep commented on June 9, 2024

I just tried NUTS on the current main branch. I'll try to find time tonight to check whether NUTS was working at c6f75e9.

junpenglao commented on June 9, 2024

Thanks for the feedback. Yeah, we need to add a test based on the Monte Carlo central limit theorem for this.
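A minimal sketch of such a test, assuming the target's mean is known (zero for the standard normal used in this thread). Because MCMC draws are autocorrelated, it estimates the Monte Carlo standard error (MCSE) with non-overlapping batch means, and passes if the sample mean lies within `z` standard errors of the truth. `mcse_batch_means` and `clt_check` are hypothetical helper names, not BlackJAX functions, and `z=4.0` and `num_batches=50` are arbitrary choices.

```python
import jax
import jax.numpy as jnp


def mcse_batch_means(samples, num_batches=50):
    """Monte Carlo standard error of the mean, via non-overlapping batch means.

    Batching absorbs autocorrelation, provided each batch is much longer than
    the chain's autocorrelation time.
    """
    n = samples.shape[0] - samples.shape[0] % num_batches  # drop the remainder
    batches = samples[:n].reshape(num_batches, -1, *samples.shape[1:])
    batch_means = batches.mean(axis=1)
    return batch_means.std(axis=0, ddof=1) / jnp.sqrt(num_batches)


def clt_check(samples, true_mean, z=4.0):
    """True if every coordinate's sample mean is within z standard errors."""
    error = jnp.abs(samples.mean(axis=0) - true_mean)
    return bool(jnp.all(error < z * mcse_batch_means(samples)))


# i.i.d. draws stand in for a chain here; with a sampler, pass states.position.
samples = jax.random.normal(jax.random.PRNGKey(0), (50_000, 1))
print(clt_check(samples, 0.0))       # first moment: E[x] = 0
print(clt_check(samples ** 2, 1.0))  # second moment: E[x^2] = 1 for N(0, 1)
```

Applying the same check to `samples ** 2`, whose mean is the variance for a zero-mean target, is what would catch a kernel that gets the mean right but the variance wrong.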

wiep commented on June 9, 2024

I am fairly, but not 100%, sure that the example code is correct. I also didn't spend much time tuning the parameters.

NUTS seems to work on v1 (8209172), in the sense that the sample mean is close to zero and the sample variance is reasonably close to one (1.17, which might be due to poorly chosen NUTS parameters). At commit cfeffb1 I get different results from v1: the sample variance is now 0.33691323.

```python
import numpy as np
import jax
import jax.numpy as jnp
import blackjax.nuts as nuts
import matplotlib.pyplot as plt

# Potential energy = negative log-density of a standard normal target
potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()
initial_position = np.array([1.0])
initial_state = nuts.new_state(initial_position, potential)

# Hand-picked (not adapted) step size and identity inverse mass matrix
step_size = 0.1
params = nuts.NUTSParameters(
    step_size=step_size,
    inv_mass_matrix=jnp.ones_like(initial_position),
)
nuts_kernel = jax.jit(nuts.kernel(potential, params))


def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Apply the kernel num_samples times and stack the visited states."""
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)  # discard the per-step info
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states


rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, nuts_kernel, initial_state, 50_000)

samples = states.position.block_until_ready()
print(np.mean(samples, axis=0))  # expected to be close to 0
print(np.var(samples, axis=0))   # expected to be close to 1
plt.plot(samples)
plt.show()
```

| commit  | mean        | variance     |
|---------|-------------|--------------|
| 8209172 | [0.0027857] | [1.1750876]  |
| fe6ff69 | [0.0027857] | [1.1750876]  |
| cfeffb1 | [0.0067384] | [0.33691323] |

rlouf commented on June 9, 2024

I'll track this one down; it might be the reason why adaptation fails for the variance of a Gaussian target in #44. It must be the change in `generate_proposal`, since the bug affects both HMC and NUTS.

I will add the extra tests you were talking about, @junpenglao, along with the bug fix.
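Why a variance bug could break adaptation: Stan-style window adaptation estimates the diagonal inverse mass matrix from (a regularized version of) the variance of the warmup draws. If the adaptation in #44 follows that scheme (an assumption here, not something confirmed in this thread), a kernel that under-disperses its draws passes that error straight into the adapted metric. Below is a minimal sketch of the usual streaming (Welford) variance estimator, for illustration only; it is not BlackJAX's adaptation code, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def welford_init(dim):
    # (number of draws, running mean, running sum of squared deviations)
    return 0, jnp.zeros(dim), jnp.zeros(dim)


def welford_update(state, x):
    count, mean, m2 = state
    count = count + 1
    delta = x - mean
    mean = mean + delta / count
    m2 = m2 + delta * (x - mean)  # uses the updated mean
    return count, mean, m2


def welford_variance(state):
    count, _, m2 = state
    return m2 / (count - 1)  # unbiased sample variance per coordinate


# Hypothetical warmup draws from a kernel whose variance is 0.33 instead of 1
draws = jnp.sqrt(0.33) * jax.random.normal(jax.random.PRNGKey(1), (500, 1))
state = welford_init(1)
for x in draws:
    state = welford_update(state, x)
print(welford_variance(state))  # would become the diagonal inverse mass matrix
```

With these draws the estimate lands near 0.33 rather than 1, so the adapted metric would inherit the kernel's error.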

rlouf commented on June 9, 2024

I have found a bug which may explain what you observe (I will try to run the corrected code later tonight).

In the init function for the proposal, the position is passed to the kinetic energy instead of the momentum. The fact that we can pass both is a remnant of the code BlackJAX is based on, where I tried to implement the SoftAbs metric. I will also remove that possibility for now.
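To see where this kind of bug bites, here is a minimal, self-contained HMC step in JAX for a standard normal target, written for illustration only: it is not BlackJAX's code, and `hmc_step`, `sample` and the `buggy` flag are hypothetical names. With `buggy=True` the initial energy is computed by passing the position to the kinetic energy; otherwise the momentum is used. Since the Metropolis test compares the initial and final Hamiltonians, a wrong initial energy changes which proposals get accepted, and in general the chain no longer targets the intended distribution.

```python
import jax
import jax.numpy as jnp


def potential(q):
    # Standard normal target: U(q) = -log p(q), up to an additive constant
    return 0.5 * jnp.sum(q ** 2)


def kinetic_energy(p, inv_mass_matrix):
    # Diagonal metric: K(p) = p^T M^{-1} p / 2; it should only ever see momentum
    return 0.5 * jnp.sum(inv_mass_matrix * p ** 2)


def leapfrog(q, p, step_size, num_steps, inv_mass_matrix):
    grad_u = jax.grad(potential)

    def body(_, qp):
        q, p = qp
        p = p - 0.5 * step_size * grad_u(q)
        q = q + step_size * inv_mass_matrix * p
        p = p - 0.5 * step_size * grad_u(q)
        return q, p

    return jax.lax.fori_loop(0, num_steps, body, (q, p))


def hmc_step(rng_key, q, step_size, num_steps, inv_mass_matrix, buggy=False):
    key_p, key_u = jax.random.split(rng_key)
    # Momentum ~ N(0, M) with M = 1 / inv_mass_matrix
    p = jax.random.normal(key_p, q.shape) / jnp.sqrt(inv_mass_matrix)
    # The reported bug: the initial energy fed the *position* to the kinetic term
    kinetic_arg = q if buggy else p
    energy0 = potential(q) + kinetic_energy(kinetic_arg, inv_mass_matrix)
    q_new, p_new = leapfrog(q, p, step_size, num_steps, inv_mass_matrix)
    energy1 = potential(q_new) + kinetic_energy(p_new, inv_mass_matrix)
    # Metropolis test on the change in total energy
    accept = jnp.log(jax.random.uniform(key_u)) < energy0 - energy1
    return jnp.where(accept, q_new, q)


def sample(rng_key, q0, num_samples, buggy=False, step_size=0.1, num_steps=10):
    inv_mass_matrix = jnp.ones_like(q0)

    def one_step(q, key):
        q = hmc_step(key, q, step_size, num_steps, inv_mass_matrix, buggy)
        return q, q

    keys = jax.random.split(rng_key, num_samples)
    _, qs = jax.lax.scan(one_step, q0, keys)
    return qs


for buggy in (False, True):
    qs = sample(jax.random.PRNGKey(0), jnp.array([1.0]), 20_000, buggy=buggy)
    print("buggy" if buggy else "fixed", qs.mean(), qs.var())
```

The step size and number of leapfrog steps are arbitrary; comparing each printed variance against 1 shows the effect of the swapped argument without relying on BlackJAX internals.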

rlouf commented on June 9, 2024

The first commit on #47 fixes it. I will do a little refactoring, add the extra test @junpenglao suggested, and release a patch.

rlouf commented on June 9, 2024

I corrected the bug and released the fix in version 0.2.1, now available on PyPI. Thank you again for reporting!
