Comments (24)
Looks good! I think it would be nice to add a comparison of sampling time with `pm.sample` and `pm.sampling_jax`, using only one chain for a fair comparison.
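A minimal sketch of such a benchmark harness. It assumes each sampler is wrapped in a zero-argument callable that runs a single chain to completion; the stand-in callables below are hypothetical placeholders, not PyMC or BlackJAX functions. Taking the best of a few repeats reduces timing noise and, for JAX-based samplers that reuse a compiled function across calls, leaves out the one-off JIT compilation cost. Report the first run separately if compile time matters for the comparison.

```python
import time
from typing import Callable, Dict


def benchmark_samplers(
    samplers: Dict[str, Callable[[], object]], repeats: int = 3
) -> Dict[str, float]:
    """Return the best wall-clock time (in seconds) of each single-chain sampler.

    Each value in `samplers` is a zero-argument callable that runs one chain
    to completion; its return value (the draws) is ignored here.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    timings = {}
    for name, run_one_chain in samplers.items():
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            run_one_chain()
            best = min(best, time.perf_counter() - start)
        timings[name] = best
    return timings


if __name__ == "__main__":
    # Hypothetical stand-ins: replace with callables that run pm.sample,
    # pm.sampling_jax and blackjax, each configured with a single chain.
    fake_samplers = {
        "fast": lambda: sum(range(10_000)),
        "slow": lambda: sum(range(1_000_000)),
    }
    for name, seconds in sorted(benchmark_samplers(fake_samplers).items()):
        print(f"{name}: {seconds:.4f} s")
```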
from blackjax.
No problem, I'll do it!
@ricardoV94 I was thinking I could add an `aeppl` example as well. It could be used to benchmark `aehmc`.
@kc611 No, we just want to compile a PyMC3 model to JAX through Aesara's existing JAX backend and then pass that to `blackjax`. This should only require a couple of lines of code. Take a look at https://github.com/pymc-devs/pymc3/blob/main/pymc3/sampling_jax.py#L157 to see how we do this when running a PyMC3 JAX model with NumPyro's JAX samplers; the BlackJAX version should look quite similar.
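The glue between the two can be sketched as a small adapter. This assumes (a hypothetical setup, not the exact contents of `sampling_jax.py`) that the compiled log-density takes one value per free variable as positional arguments and returns a sequence of outputs whose first entry is the scalar log-density, while a BlackJAX-style kernel calls a single `logdensity(position)` with the whole position passed as one pytree (here, a list). `make_logdensity` and `compiled_logp` are illustrative names only.

```python
import math
from typing import Callable, Sequence


def make_logdensity(
    compiled_logp: Callable[..., Sequence[float]],
) -> Callable[[Sequence[float]], float]:
    """Adapt a positional, multi-output logp function to `logdensity(position)`.

    Assumes `compiled_logp(*values)` takes one value per free variable and
    returns a sequence of outputs whose first entry is the scalar log-density.
    """

    def logdensity(position: Sequence[float]) -> float:
        # Unpack the position into positional arguments, keep the first output
        return compiled_logp(*position)[0]

    return logdensity


# Hypothetical compiled graph: independent N(0, 1) priors on `mu` and `tau`
def compiled_logp(mu: float, tau: float):
    norm = -0.5 * math.log(2 * math.pi)
    return ((norm - 0.5 * mu**2) + (norm - 0.5 * tau**2),)


logdensity = make_logdensity(compiled_logp)
print(logdensity([0.0, 0.0]))  # -log(2*pi), about -1.8379
```

With real JAX arrays the same wrapper applies unchanged, since it only unpacks the list and indexes the outputs.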
This can be started now but will need to wait until #26 is merged to be completed.
Could someone on the PyMC team take a stab at this? Looking at the `sampling_jax.py` file in the PyMC3 repo, generating a JAX-compatible logpdf does not seem straightforward for someone who is not familiar with the internals.
@kc611 any interest in looking into this?
Sure. Just to be clear: by running BlackJAX on PyMC3 models, we mean converting an existing `pymc3.Model` completely into its BlackJAX (or JAX?) variant, along with its inputs, similar to how we first approached the JAX conversion in PyMC, right? Or is there already some sort of integration of BlackJAX with PyMC?
Ah okay, got it.
So, this is something I came up with. The original model is from @twiecki's Radon hierarchical model notebook (from the PyMC docs):
https://colab.research.google.com/drive/1lUXPRynGCuusHLpBTnC1tPw-2zw_6NwG?usp=sharing
But I don't know if it's working as expected (I haven't tried plotting the variables yet).
Note: this is built on PyMC v3 and Theano and will probably require changes when we release v4.
For PyMC v4 with the Aesara backend, here is the error:
```python
import jax
import numpy as np
from aesara.graph.fg import FunctionGraph
from aesara.link.jax.dispatch import jax_funcify

# `model` is the PyMC model built in an earlier notebook cell (not shown)
seed = jax.random.PRNGKey(1234)
chains = 1

# This is the call that fails with the MissingInputError below
fgraph = FunctionGraph(model.free_RVs, [model.logpt])
fns = jax_funcify(fgraph)
logp_fn_jax = fns[0]

rv_names = [rv.name for rv in model.free_RVs]
init_state = [model.test_point[rv_name] for rv_name in rv_names]
# Add a leading chain axis of length `chains` to every initial value
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
```
```
MissingInputError Traceback (most recent call last)
<ipython-input-9-6e87cadc4860> in <module>()
6
7
----> 8 fgraph = FunctionGraph(model.free_RVs, [model.logpt])
9 fns = jax_funcify(fgraph)
10 logp_fn_jax = fns[0]
2 frames
/usr/local/lib/python3.7/dist-packages/aesara/graph/fg.py in import_node(self, apply_node, check, reason, import_missing)
400 "for more information on this error."
401 )
--> 402 raise MissingInputError(error_msg, variable=var)
403
404 for node in new_nodes:
MissingInputError: Input 0 (TensorConstant{-inf}) of the graph (indices start from 0), used to compute InplaceDimShuffle{x}(TensorConstant{-inf}), was not provided and not given a value. Use the Aesara flag exception_verbosity='high', for more information on this error.
Backtrace when that variable is created:
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-5-380e30b89547>", line 8, in <module>
    obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y)
  File "/usr/local/lib/python3.7/dist-packages/pymc3/distributions/distribution.py", line 122, in __new__
    return model.Var(name, dist, data, total_size, dims=dims)
  File "/usr/local/lib/python3.7/dist-packages/pymc3/model.py", line 1182, in Var
    model=self,
  File "/usr/local/lib/python3.7/dist-packages/pymc3/model.py", line 1831, in __init__
    self.logp_sum_unscaledt = distribution.logp_sum(data)
  File "/usr/local/lib/python3.7/dist-packages/pymc3/distributions/distribution.py", line 267, in logp_sum
    return tt.sum(self.logp(*args, **kwargs))
  File "/usr/local/lib/python3.7/dist-packages/pymc3/distributions/continuous.py", line 535, in logp
    return bound((-tau * (value - mu) ** 2 + tt.log(tau / np.pi / 2.0)) / 2.0, sigma > 0)
  File "/usr/local/lib/python3.7/dist-packages/pymc3/distributions/dist_math.py", line 82, in bound
    return tt.switch(alltrue(conditions), logp, -np.inf)
```
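For reference, the `init_state_batched` line in the snippet above gives every initial value a leading chain axis of length `chains`. A minimal NumPy sketch of the same operation, with a plain list comprehension standing in for `jax.tree_map` over the list of initial values (the example values and shapes are made up):

```python
import numpy as np

chains = 4
# Hypothetical initial values: one scalar and one vector-valued variable
init_state = [np.array(0.5), np.zeros(3)]

# x[None, ...] adds a leading axis of size 1; np.repeat copies it `chains` times
init_state_batched = [np.repeat(x[None, ...], chains, axis=0) for x in init_state]

print([x.shape for x in init_state_batched])  # [(4,), (4, 3)]
```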
Thanks for the report @903124. @ricardoV94 any idea?
`model.logpt` takes `model.value_vars` as inputs (plus shared variables), not `model.free_RVs`. But you don't need to specify the input variables when building a `FunctionGraph`. If you just pass the outputs with `FunctionGraph(outputs=[model.logpt])`, Aesara will import the necessary input variables.
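The idea that the inputs can be inferred from the outputs can be illustrated with a toy graph walk. This is not Aesara's implementation; `Var`, `Apply`, `collect_inputs` and `check_inputs` are made-up names. Starting from the outputs, follow each variable back to the operation that produced it and treat every non-constant variable without a producer as a graph input. A root variable that is neither a constant nor among the explicitly provided inputs is roughly the condition under which Aesara raises `MissingInputError`.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(eq=False)
class Var:
    name: str
    owner: Optional["Apply"] = None  # the operation that produced this variable
    is_constant: bool = False


@dataclass(eq=False)
class Apply:
    op: str
    inputs: List[Var] = field(default_factory=list)


def collect_inputs(outputs: List[Var]) -> List[Var]:
    """Return the non-constant root variables reachable from `outputs`."""
    found, seen, stack = [], set(), list(outputs)
    while stack:
        var = stack.pop()
        if id(var) in seen:
            continue
        seen.add(id(var))
        if var.owner is not None:
            stack.extend(var.owner.inputs)  # keep walking towards the roots
        elif not var.is_constant:
            found.append(var)  # a root that must be supplied by the caller
    return found


def check_inputs(inputs: List[Var], outputs: List[Var]) -> None:
    """Raise if a required root variable is missing from `inputs`."""
    provided = {id(v) for v in inputs}
    for var in collect_inputs(outputs):
        if id(var) not in provided:
            raise ValueError(f"missing input: {var.name}")


# Toy logp graph: logp = bound(mu_value, neg_inf); mu_rv is not used by logp
mu_value = Var("mu_value")
neg_inf = Var("neg_inf", is_constant=True)
mu_rv = Var("mu_rv")
logp = Var("logp", owner=Apply("bound", [mu_value, neg_inf]))

print([v.name for v in collect_inputs([logp])])  # ['mu_value']
check_inputs([mu_value], [logp])  # passes; check_inputs([mu_rv], [logp]) raises
```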
The error changed to this:
```python
import jax
import numpy as np
from aesara.graph.fg import FunctionGraph
from aesara.link.jax.dispatch import jax_funcify

seed = jax.random.PRNGKey(1234)
chains = 1

# Only the outputs are passed; FunctionGraph collects the required inputs itself
fgraph = FunctionGraph(outputs=[model.logpt])
fns = jax_funcify(fgraph)
logp_fn_jax = fns[0]

rv_names = [rv.name for rv in model.free_RVs]
init_state = [model.test_point[rv_name] for rv_name in rv_names]
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
```
```
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-7-c0d3406b2bee> in <module>()
6
7
----> 8 fgraph = FunctionGraph(outputs= [model.logpt])
9 fns = jax_funcify(fgraph)
10 logp_fn_jax = fns[0]
3 frames
/usr/local/lib/python3.7/dist-packages/aesara/graph/fg.py in setup_node(self, node)
212 " the values must be tuples or lists."
213 )
--> 214 if node.op.destroy_map and not all(
215 isinstance(destroy, (list, tuple))
216 for destroy in node.op.destroy_map.values()
AttributeError: 'DimShuffle' object has no attribute 'destroy_map'
```
Which version of Aesara are you using?
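Because the backtrace of the first error goes through `pymc3` and Theano-style `tt` calls while the snippet uses Aesara, it helps to report the exact installed versions of every library involved. A small stdlib-only helper, assuming the packages are installed under their PyPI distribution names (the list in the usage part is only an example). `importlib.metadata` needs Python 3.8+; on Python 3.7, as in the Colab traceback, the `importlib_metadata` backport provides the same functions.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    names = ["aesara", "pymc", "pymc3", "theano-pymc", "jax", "jaxlib", "blackjax"]
    for name, ver in installed_versions(names).items():
        print(f"{name}: {ver or 'not installed'}")
```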
Right, I think I made a mistake when importing pymc, but now there is a different error:
```
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
<ipython-input-9-c0d3406b2bee> in <module>()
7
8 fgraph = FunctionGraph(outputs= [model.logpt])
----> 9 fns = jax_funcify(fgraph)
10 logp_fn_jax = fns[0]
11
4 frames
/usr/local/lib/python3.7/dist-packages/aesara/link/jax/dispatch.py in jax_funcify(op, node, storage_map, **kwargs)
141 def jax_funcify(op, node=None, storage_map=None, **kwargs):
142 """Create a JAX compatible function from an Aesara `Op`."""
--> 143 raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
144
145
NotImplementedError: No JAX conversion for the given `Op`: TransformedVariable
```
Code to reproduce it: https://colab.research.google.com/drive/1IeZnP0BtUXuqucl9e3bHMBkIbszjVMld#scrollTo=mUtEBrO8sSVE
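The last error comes from the generic fallback visible in the traceback: `jax_funcify` has a base implementation that raises `NotImplementedError`, and conversions are looked up by the type of the `Op`. The toy below reproduces that pattern with `functools.singledispatch`, assuming a type-based dispatch like the one the traceback suggests; the classes are made up and only mimic the mechanism, in which an `Op` type with no registered conversion (here `TransformedVariable`) falls through to the generic case.

```python
import math
from functools import singledispatch


class Op:
    """Toy base class standing in for an Aesara `Op`."""


class Exp(Op):
    """Toy op type that has a registered conversion."""


class TransformedVariable(Op):
    """Toy op type that has no registered conversion."""


@singledispatch
def toy_funcify(op, **kwargs):
    # Generic fallback: reached for any type without a registered conversion
    raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")


@toy_funcify.register(Exp)
def _(op, **kwargs):
    return math.exp  # a real backend would return a JAX function here


print(toy_funcify(Exp())(0.0))  # 1.0
try:
    toy_funcify(TransformedVariable())
except NotImplementedError as err:
    print(err)
```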
Still odd, perhaps due to the cloning of inputs? Can you try to call this function directly from PyMC to get the JAX logp graph?
https://github.com/pymc-devs/pymc/blob/e63cb5b0081adc79382c69862b8893e5ecad46a6/pymc/sampling_jax.py#L73
`get_jaxified_logp(model)` indeed does not raise an error.
Is that good enough as a solution?
Just to be clear: do we need to change anything in the example notebook?
To my understanding, `get_jaxified_logp` samples via PyMC, not BlackJAX, right?
That function does not do any sampling. It simply returns a JAX graph of the logp.
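To illustrate the division of labour: the log-density function only evaluates log p(position), and a separate kernel uses it to produce draws. Below is a minimal random-walk Metropolis sampler in plain Python, standing in for a BlackJAX kernel; BlackJAX itself implements its samplers (HMC, NUTS and others) in JAX, and this is not its API. The target is a made-up two-dimensional standard normal.

```python
import math
import random
from typing import Callable, List, Sequence


def random_walk_metropolis(
    logdensity: Callable[[Sequence[float]], float],
    init_position: Sequence[float],
    num_samples: int,
    step_size: float = 1.0,
    seed: int = 0,
) -> List[List[float]]:
    """Draw `num_samples` positions with Gaussian random-walk proposals."""
    rng = random.Random(seed)
    position = list(init_position)
    current_logp = logdensity(position)
    samples = []
    for _ in range(num_samples):
        proposal = [x + step_size * rng.gauss(0.0, 1.0) for x in position]
        proposal_logp = logdensity(proposal)
        # Accept with probability min(1, p(proposal) / p(position));
        # 1 - random() lies in (0, 1], so the log is always defined
        if math.log(1.0 - rng.random()) < proposal_logp - current_logp:
            position, current_logp = proposal, proposal_logp
        samples.append(list(position))
    return samples


def logdensity(position: Sequence[float]) -> float:
    # Unnormalised log-density of independent N(0, 1) coordinates
    return -0.5 * sum(x * x for x in position)


draws = random_walk_metropolis(logdensity, [3.0, -3.0], num_samples=20_000)
kept = draws[2_000:]  # discard warm-up draws
means = [sum(d[i] for d in kept) / len(kept) for i in range(2)]
print(means)  # both should be close to 0
```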
Then I think the code example needs to be slightly modified to accept `jaxified_logp` instead of `jax_funcify`.
Do you want to make the change and open a PR?
I'm not good enough to fix it; maybe someone else can look at it.