
Comments (24)

rlouf commented on May 31, 2024

Looks good! I think it would be nice to add a comparison of sampling time with pm.sample and pm.sampling_jax, using only one chain for a fair comparison.
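
For instance, a rough timing sketch along those lines (hedged: the draw counts are arbitrary, and sample_numpyro_nuts is assumed to be the pm.sampling_jax entry point in the PyMC3 version being used):

import time
import pymc3 as pm
import pymc3.sampling_jax  # not imported by default

with model:
    start = time.time()
    pm.sample(draws=1000, tune=1000, chains=1)
    print("pm.sample:", time.time() - start)

    start = time.time()
    pm.sampling_jax.sample_numpyro_nuts(draws=1000, tune=1000, chains=1)
    print("pm.sampling_jax:", time.time() - start)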


rlouf commented on May 31, 2024

No problem, I'll do it!

@ricardoV94 I was thinking I could add an aeppl example as well. Could be used to benchmark aehmc.


twiecki commented on May 31, 2024

@kc611 No, we just want to compile a PyMC3 model to JAX through Aesara's existing JAX backend and then pass that to blackjax. This should only require a couple of lines of code. You can take a look at https://github.com/pymc-devs/pymc3/blob/main/pymc3/sampling_jax.py#L157 to see how we do this for running a PyMC3 JAX model with numpyro's JAX samplers. The code here should look quite similar.
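
Roughly, the glue could look like the sketch below. This is only a hedged illustration: blackjax.nuts with an explicit step size and inverse mass matrix is assumed to be the available API, get_jaxified_logp is the PyMC v4 helper discussed later in this thread (assumed here to return the positive log-probability), and model.initial_point() / model.value_vars are the assumed v4 names for what v3 calls test_point / the free RVs.

import jax
import jax.numpy as jnp
import numpy as np
import blackjax
from pymc.sampling_jax import get_jaxified_logp  # PyMC v4 helper, see later in the thread

logp_fn = get_jaxified_logp(model)                 # JAX-callable logp of the model
ip = model.initial_point()                         # v3 equivalent: model.test_point
init_position = [ip[v.name] for v in model.value_vars]

# Placeholder NUTS parameters; in practice these would come from warmup/adaptation
n_dim = int(sum(np.size(x) for x in init_position))
nuts = blackjax.nuts(logp_fn, step_size=1e-2, inverse_mass_matrix=jnp.ones(n_dim))

state = nuts.init(init_position)
rng_key = jax.random.PRNGKey(0)
state, info = nuts.step(rng_key, state)            # one NUTS transition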


rlouf commented on May 31, 2024

This can be started now but will need to wait until #26 is merged to be completed.


rlouf commented on May 31, 2024

Could someone on the PyMC team take a stab at this? Judging from the sampling_jax.py file in the PyMC3 repo, generating a JAX-compatible logpdf does not seem straightforward for someone unfamiliar with the internals.


twiecki commented on May 31, 2024

@kc611 any interest in looking into this?


kc611 commented on May 31, 2024

Sure.

Just to be clear: by running BlackJAX on PyMC3 models, do we mean converting an existing pymc3.Model completely into its BlackJAX (or JAX?) variant, along with its inputs, similar to how we first approached the JAX conversion in PyMC? Or is there already some integration of BlackJAX with PyMC?


kc611 commented on May 31, 2024

Ah okay, got it.


kc611 commented on May 31, 2024

So, this is something that I came up with. The original model is from @twiecki's Radon hierarchical model notebook (from the PyMC docs).

https://colab.research.google.com/drive/1lUXPRynGCuusHLpBTnC1tPw-2zw_6NwG?usp=sharing

But I don't know if it's working as expected (for example, I haven't tried plotting the variables yet).

Note: this is built on PyMC v3 and Theano and will probably require changes when we release v4.


903124 commented on May 31, 2024

For PyMC v4 with the Aesara backend, here is the error:

from aesara.link.jax.dispatch import jax_funcify
from aesara.graph.fg import FunctionGraph

seed = jax.random.PRNGKey(1234)
chains = 1


fgraph = FunctionGraph(model.free_RVs, [model.logpt])
fns = jax_funcify(fgraph)
logp_fn_jax = fns[0]

rv_names = [rv.name for rv in model.free_RVs]
init_state = [model.test_point[rv_name] for rv_name in rv_names]
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)



MissingInputError                         Traceback (most recent call last)

<ipython-input-9-6e87cadc4860> in <module>()
      6 
      7 
----> 8 fgraph = FunctionGraph(model.free_RVs, [model.logpt])
      9 fns = jax_funcify(fgraph)
     10 logp_fn_jax = fns[0]

2 frames

/usr/local/lib/python3.7/dist-packages/aesara/graph/fg.py in import_node(self, apply_node, check, reason, import_missing)
    400                                 "for more information on this error."
    401                             )
--> 402                             raise MissingInputError(error_msg, variable=var)
    403 
    404         for node in new_nodes:

MissingInputError: Input 0 (TensorConstant{-inf}) of the graph (indices start from 0), used to compute InplaceDimShuffle{x}(TensorConstant{-inf}), was not provided and not given a value. Use the Aesara flag exception_verbosity='high', for more information on this error.
 
Backtrace when that variable is created:

  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-5-380e30b89547>", line 8, in <module>
    obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y)
  File "/usr/local/lib/python3.7/dist-packages/pymc3/distributions/distribution.py", line 122, in __new__
    return model.Var(name, dist, data, total_size, dims=dims)
  File "/usr/local/lib/python3.7/dist-packages/pymc3/model.py", line 1182, in Var
    model=self,
  File "/usr/local/lib/python3.7/dist-packages/pymc3/model.py", line 1831, in __init__
    self.logp_sum_unscaledt = distribution.logp_sum(data)
  File "/usr/local/lib/python3.7/dist-packages/pymc3/distributions/distribution.py", line 267, in logp_sum
    return tt.sum(self.logp(*args, **kwargs))
  File "/usr/local/lib/python3.7/dist-packages/pymc3/distributions/continuous.py", line 535, in logp
    return bound((-tau * (value - mu) ** 2 + tt.log(tau / np.pi / 2.0)) / 2.0, sigma > 0)
  File "/usr/local/lib/python3.7/dist-packages/pymc3/distributions/dist_math.py", line 82, in bound
    return tt.switch(alltrue(conditions), logp, -np.inf)


twiecki commented on May 31, 2024

Thanks for the report @903124. @ricardoV94 any idea?


ricardoV94 commented on May 31, 2024

model.logpt takes as inputs model.value_vars (plus shared variables), not model.free_RVs. But you don't need to specify the input variables when building a FunctionGraph. If you just pass the outputs with FunctionGraph(outputs=[model.logpt]), aesara will import the necessary input variables.


903124 commented on May 31, 2024

model.logpt takes as inputs model.value_vars (plus shared variables), not model.free_RVs. But you don't need to specify the input variables when building a FunctionGraph. If you just pass the outputs with FunctionGraph(outputs=[model.logpt]), aesara will import the necessary input variables.

The error changes to this:

from aesara.link.jax.dispatch import jax_funcify
from aesara.graph.fg import FunctionGraph

seed = jax.random.PRNGKey(1234)
chains = 1


fgraph = FunctionGraph(outputs= [model.logpt])
fns = jax_funcify(fgraph)
logp_fn_jax = fns[0]

rv_names = [rv.name for rv in model.free_RVs]
init_state = [model.test_point[rv_name] for rv_name in rv_names]
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)


WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

---------------------------------------------------------------------------

AttributeError                            Traceback (most recent call last)

<ipython-input-7-c0d3406b2bee> in <module>()
      6 
      7 
----> 8 fgraph = FunctionGraph(outputs= [model.logpt])
      9 fns = jax_funcify(fgraph)
     10 logp_fn_jax = fns[0]

3 frames

/usr/local/lib/python3.7/dist-packages/aesara/graph/fg.py in setup_node(self, node)
    212                 " the values must be tuples or lists."
    213             )
--> 214         if node.op.destroy_map and not all(
    215             isinstance(destroy, (list, tuple))
    216             for destroy in node.op.destroy_map.values()

AttributeError: 'DimShuffle' object has no attribute 'destroy_map'


ricardoV94 commented on May 31, 2024

Which version of Aesara are you using?


903124 commented on May 31, 2024

Right, I think I made a mistake when importing PyMC, but there is a different error now:


---------------------------------------------------------------------------

NotImplementedError                       Traceback (most recent call last)

<ipython-input-9-c0d3406b2bee> in <module>()
      7 
      8 fgraph = FunctionGraph(outputs= [model.logpt])
----> 9 fns = jax_funcify(fgraph)
     10 logp_fn_jax = fns[0]
     11 

4 frames

/usr/local/lib/python3.7/dist-packages/aesara/link/jax/dispatch.py in jax_funcify(op, node, storage_map, **kwargs)
    141 def jax_funcify(op, node=None, storage_map=None, **kwargs):
    142     """Create a JAX compatible function from an Aesara `Op`."""
--> 143     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
    144 
    145 

NotImplementedError: No JAX conversion for the given `Op`: TransformedVariable

Code to reproduce it: https://colab.research.google.com/drive/1IeZnP0BtUXuqucl9e3bHMBkIbszjVMld#scrollTo=mUtEBrO8sSVE


ricardoV94 commented on May 31, 2024

Still odd, perhaps due to the cloning of inputs? Can you try to call this function directly from pymc to get the jax logp graph?

https://github.com/pymc-devs/pymc/blob/e63cb5b0081adc79382c69862b8893e5ecad46a6/pymc/sampling_jax.py#L73
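
For reference, the direct call would be something like this (a minimal sketch, assuming the helper is importable from pymc.sampling_jax in the version being used):

from pymc.sampling_jax import get_jaxified_logp

logp_fn_jax = get_jaxified_logp(model)  # bypasses building the FunctionGraph by hand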


903124 commented on May 31, 2024

Still odd, perhaps due to the cloning of inputs? Can you try to call this function directly from pymc to get the jax logp graph?

https://github.com/pymc-devs/pymc/blob/e63cb5b0081adc79382c69862b8893e5ecad46a6/pymc/sampling_jax.py#L73

get_jaxified_logp(model) indeed does not give an error.


ricardoV94 commented on May 31, 2024

Still odd, perhaps due to the cloning of inputs? Can you try to call this function directly from pymc to get the jax logp graph?
https://github.com/pymc-devs/pymc/blob/e63cb5b0081adc79382c69862b8893e5ecad46a6/pymc/sampling_jax.py#L73

get_jaxified_logp(model) indeed does not give an error.

Is that good enough as a solution?


rlouf commented on May 31, 2024

Just to be clear: do we need to change anything in the example notebook?


903124 commented on May 31, 2024

To my understanding, get_jaxified_logp samples via PyMC, not BlackJAX, right?


ricardoV94 commented on May 31, 2024

To my understanding, get_jaxified_logp samples via PyMC, not BlackJAX, right?

That function does not do any sampling. It simply returns a JAX graph of the logp.
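
In other words, something like the following (a hedged sketch; model.initial_point() and model.value_vars are assumed to be the v4 API, and the sign convention of the returned value may depend on the PyMC version):

logp_fn_jax = get_jaxified_logp(model)
ip = model.initial_point()                      # v3 equivalent: model.test_point
point = [ip[v.name] for v in model.value_vars]  # one array per value variable
print(logp_fn_jax(point))                       # a single scalar, no sampling involved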


903124 commented on May 31, 2024

Then I think the code example needs to be modified slightly to use get_jaxified_logp instead of jax_funcify.
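
Concretely, the modified part of the snippet could look roughly like this (hedged: model.initial_point() and model.value_vars are the assumed v4 replacements for model.test_point and the free-RV names used above):

import jax
import numpy as np
from pymc.sampling_jax import get_jaxified_logp

chains = 1

logp_fn_jax = get_jaxified_logp(model)   # replaces the FunctionGraph/jax_funcify lines

initial_point = model.initial_point()
init_state = [initial_point[v.name] for v in model.value_vars]
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

The batched initial state could then be passed to a blackjax kernel, as in the sketch earlier in the thread.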


rlouf commented on May 31, 2024

Do you want to make the change and open a PR?


903124 commented on May 31, 2024

I don't know the internals well enough to fix it; maybe someone else can take a look.

