jax-ml / oryx

Oryx is a library for probabilistic programming and deep learning built on top of JAX.
Home Page: https://tensorflow.org/probability/oryx
License: Apache License 2.0
I'd like to play around with oryx to express a DSL in a library I'm working on, but I can't depend on oryx until its dependencies are bumped.
Version: jax-0.4.1 jaxlib-0.4.1
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
Cell In[1], line 1
----> 1 from oryx.core.ppl import random_variable
File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/__init__.py:16
1 # Copyright 2020 The TensorFlow Probability Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
13 # limitations under the License.
14 # ============================================================================
15 """Oryx is a neural network mini-library built on top of Jax."""
---> 16 from oryx import bijectors
17 from oryx import core
18 from oryx import distributions
File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/bijectors/__init__.py:19
16 import inspect
18 from tensorflow_probability.python.experimental.substrates import jax as tfp
---> 19 from oryx.bijectors import bijector_extensions
21 tfb = tfp.bijectors
23 _bijectors = {}
File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/bijectors/bijector_extensions.py:28
26 from six.moves import zip
27 from tensorflow_probability.python.experimental.substrates import jax as tfp
---> 28 from oryx import core
29 from oryx.core.interpreters import inverse
31 safe_map = jax_util.safe_map
File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/__init__.py:16
1 # Copyright 2020 The TensorFlow Probability Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
13 # limitations under the License.
14 # ============================================================================
15 """Contains core logic for Oryx classes."""
---> 16 from oryx.core import ppl
17 from oryx.core.interpreters.inverse import ildj
18 from oryx.core.interpreters.inverse import ildj_registry
File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/ppl/__init__.py:17
1 # Copyright 2020 The TensorFlow Probability Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
14 # ============================================================================
15 # Lint as: python3
16 """Module for probabilistic programming features."""
---> 17 from oryx.core.ppl.transformations import conditional
18 from oryx.core.ppl.transformations import graph_replace
19 from oryx.core.ppl.transformations import intervene
File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/ppl/transformations.py:204
201 from jax import util as jax_util
203 from oryx.core import primitive
--> 204 from oryx.core.interpreters import harvest
205 from oryx.core.interpreters import log_prob as lp
207 Program = Callable[..., Any]
File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/interpreters/harvest.py:136
134 from jax.interpreters import ad
135 from jax.interpreters import batching
--> 136 from jax.interpreters import masking
137 from jax.interpreters import xla
138 from jax.lib.xla_bridge import xla_client as xc
ImportError: cannot import name 'masking' from 'jax.interpreters' (/opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/jax/interpreters/__init__.py)
In [2]: import oryx
(same ImportError traceback as above)
Hey! New to poetry, so maybe I'm missing something, but I couldn't get bazel working, nor could I find any BUILD files. One thing that did work surprisingly well was pytest-cov (docs): running poetry run pytest-cov lets me know there's 95% test coverage, and that all but 4 tests pass on my machine. Those remaining 4 fail with
ValueError: compiling computation that requires 2 logical devices, but only 1 XLA devices are available...
Seems like it wouldn't be too much work to set this up on CI if there's interest (probably just using GH actions and marking the failing tests as xfail, as sketched below). Let me know if there's interest and I can put a PR together (not sure if tests are already being run elsewhere!)
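For reference, the marking I have in mind is just a conditional xfail; a minimal sketch, with a hypothetical test name:

import jax
import pytest

# xfail any test that needs more devices than this machine exposes
needs_two_devices = pytest.mark.xfail(
    jax.device_count() < 2,
    reason='requires 2 logical devices but fewer XLA devices are available',
)

@needs_two_devices
def test_some_pmap_behavior():  # hypothetical test name
    ...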
Disclaimer: I have not used oryx yet. Further, this is not an issue but rather just a question/discussion.
Suppose I want to define some recurrent network whose initial hidden state is not a parameter, i.e. it should be exposed to jax.jit but not to jax.grad. How can this be done?
E.g.
# syntax might be slightly wrong, think of it as pseudo-code
def network_def(x):
    s = state.variable(..., name="hidden-state")
    p = state.variable(..., name="parameters")
    s, y = f(s, p, x)
    state.assign(s, name="hidden-state")
    return y

network = state.init(network_def)(x)

@jax.jit   # <- this should "see" hidden-state
@jax.grad  # <- this should not "see" hidden-state
def loss_fn(network, x, y):
    ...
Is there an elegant way of doing that?
Thank you!
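For reference, in plain JAX I would get this behaviour by threading the hidden state through as a separate argument and differentiating only with respect to the parameters; a minimal sketch (all names below are made up):

import jax
import jax.numpy as jnp

def cell(s, p, x):
    # toy recurrent cell: returns (new_state, output)
    s_new = jnp.tanh(x @ p['w_in'] + s @ p['w_rec'])
    return s_new, s_new

def loss_fn(p, s, x, y):
    s, y_hat = cell(s, p, x)
    return jnp.mean((y_hat - y) ** 2), s

# grad w.r.t. the parameters only (argnums=0); the state s is still a
# regular jit-traced input and comes back out via has_aux
grad_fn = jax.jit(jax.grad(loss_fn, argnums=0, has_aux=True))

key = jax.random.PRNGKey(0)
p = {'w_in': jax.random.normal(key, (4, 8)),
     'w_rec': jax.random.normal(key, (8, 8))}
grads, s_new = grad_fn(p, jnp.zeros(8), jnp.ones(4), jnp.ones(8))

What I'm hoping for is the same separation without manually plumbing s through every call.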
Also, are all jax transformations supported? The README mentions jit, grad, and vmap. What about pmap, scan (and all the others)?
I think Oryx and Blackjax are a very good match. Would you consider adding an example that uses Oryx as a modelling language to generate a log-prob function that Blackjax can sample from?
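Something along these lines is what I have in mind; a rough sketch, assuming the current Blackjax NUTS API (which may differ across versions):

import blackjax
import jax
import jax.numpy as jnp
from oryx.core.ppl import random_variable, log_prob
import oryx.distributions as tfd

def model(key):
    # trivial Oryx program: a single standard-normal variable
    return random_variable(tfd.Normal(0., 1.))(key)

logdensity_fn = log_prob(model)

nuts = blackjax.nuts(logdensity_fn, step_size=0.1,
                     inverse_mass_matrix=jnp.ones(1))
state = nuts.init(jnp.array(0.5))
state, info = jax.jit(nuts.step)(jax.random.PRNGKey(0), state)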
I get the following error when running on the latest jax release:
AttributeError: module 'jax.interpreters.partial_eval' has no attribute 'remat_call_p'
jax version: jax-0.3.25 jaxlib-0.3.25+cuda11.cudnn82
Hi!
I’ve begun to use the harvest transform for a bunch of little things — for the purposes of this conversation, let’s say one of these things is a type of runtime debugger.
I started with a version of harvest which was a little more restrictive than the one in Oryx. But now I want to apply my debugger to more complicated programs, with control flow primitives.
Now, this just about reaches the limit of what I would be comfortable maintaining separately (judging by the amount of internals it touches).
Is there any talk of pulling harvest into JAX directly (e.g. similar to jax.experimental.checkify)?
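To make the use case concrete, the debugger boils down to sowing intermediates and reaping them from the outside; a minimal sketch with illustrative names:

import jax.numpy as jnp
from oryx.core.interpreters import harvest

def fn(x):
    y = jnp.sin(x)
    # tag an intermediate so the 'debugger' can collect it from outside
    y = harvest.sow(y, tag='debug', name='y')
    return jnp.exp(y)

# reap collects everything sown under the 'debug' tag
intermediates = harvest.reap(fn, tag='debug')(1.0)  # {'y': 0.841...}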
Following the tutorial here, the following line raises an error on Colab.
unzip = oryx.core.unzip
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-3-a56f63161d11> in <module>
----> 1 unzip = oryx.core.unzip
AttributeError: module 'oryx.core' has no attribute 'unzip'
Hi all!
Will Oryx continue to be actively maintained? Are there maintainers who are hoping to continue working on the package?
Hi there!
I'm a student working on implementing a Variational Autoencoder using this library, and I came across a bug where a function transformed with ppl.log_prob produces wrong results when the function's jaxpr has more than one equation, which happens, for instance, when jit or nest are used, or when the name argument is passed to ppl.random_variable.
For instance, consider the following random variables:
def f(rng):
    eps = ppl.random_variable(tfd.Normal(0, 1))(rng)
    y = eps * 2
    return y

def g(rng):
    eps = ppl.random_variable(tfd.Normal(0, 1), name='name')(rng)
    y = eps * 2
    return y
Here, the function f produces a jaxpr with only one equation, while g produces one with two equations (one that takes in rng and outputs a sample from Normal(0, 1), and another that multiplies that sample by 2).
When transformed by ppl.log_prob, however, these functions produce different outputs:

ppl.log_prob(f)(0.0)
# -1.6120857
ppl.log_prob(g)(0.0)
# -2.305233
Here, the function f produces the correct result of log(1/2) + log(1/sqrt(2*pi)), but the function g adds the transformation term twice, producing log(1/2) + log(1/2) + log(1/sqrt(2*pi)).
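These numbers check out; a quick sanity check of the two expressions in plain numpy:

import numpy as np

# expected: log(1/2) + log(1/sqrt(2*pi))
print(np.log(0.5) + np.log(1 / np.sqrt(2 * np.pi)))      # -1.6120857
# observed: log(1/2) + log(1/2) + log(1/sqrt(2*pi))
print(2 * np.log(0.5) + np.log(1 / np.sqrt(2 * np.pi)))  # -2.305233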
Since this problem only happens for log_prob, but not for inverse_and_ildj, I believe the issue is with the reducer function shown below:
def reducer(env, eqn, curr_log_prob, new_log_prob):
  if (isinstance(curr_log_prob, FailedLogProb)
      or isinstance(new_log_prob, FailedLogProb)):
    # If `curr_log_prob` is `None` that means we were unable to compute
    # a log_prob elsewhere, so the propagate failed.
    return failed_log_prob
  if eqn.primitive in log_prob_registry and new_log_prob is None:
    # We are unable to compute a log_prob for this primitive.
    return failed_log_prob
  if new_log_prob is not None:
    cells = [env.read(var) for var in eqn.outvars]
    ildjs = sum([cell.ildj.sum() for cell in cells if cell.top()])
    return curr_log_prob + new_log_prob + ildjs
  return curr_log_prob
What I believe is happening here is that the equations are processed in reverse order, and the outvar ildjs are counted twice. For example, considering the function g above, we have two equations:

eps = nest(random_variable(rng))
y = 2 * eps

First, equation 2 is processed, and eps.ildj assumes the correct value of log(1/2).
Then, equation 1 is processed, which has a nest primitive that triggers another call to propagate for the random_variable primitive, which then calls reducer. However, since eps.ildj already has the value log(1/2), it gets added to the state of the equation, which becomes log(1/2) + log(1/sqrt(2*pi)). I believe that, in this step, the correct value would be only log(1/sqrt(2*pi)).
After equation 1 is processed, the results are aggregated using the reducer function. However, while the state assigned to equation 1 is already log(1/2) + log(1/sqrt(2*pi)), the value eps.ildj is added again in this step, which leads to the wrong result.
If I replace the loop over eqn.outvars with one over eqn.invars, the issue in this case is solved, since in the nested calls to reducer the cell.ildj will still be undefined. I'm not totally sure this won't lead to problems in other cases, though. In any case, I've opened a pull request that changes this line and also adds a small test case corresponding to this issue, in case this is indeed what's causing the bug.
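Concretely, the change amounts to reading the ildjs off the equation's inputs instead of its outputs:

# in reducer: read ildjs from the equation's inputs, not its outputs
cells = [env.read(var) for var in eqn.invars]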
I'm trying to get the log-prob of an expit-transformed logistic-distributed variable, but it always returns zero:
import jax.random
from jax.scipy.special import expit
from oryx.core.ppl import random_variable, log_prob
import oryx.distributions as tfd

def simple_sample(key):
    a = random_variable(tfd.Logistic(0., 1.))(key)
    return expit(a)

x = simple_sample(jax.random.PRNGKey(0))
print(x)  # 0.41845703
print(log_prob(simple_sample)(0.5))  # 0.0
print(log_prob(simple_sample)(x))  # 0.0
Versions: jax-0.4.25, oryx-0.2.6
(Both an exp-transformed logistic variable and an expit-transformed normal variable seem to work, so there is something special about this combination.)
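For reference, the two working control cases I compared against look like this:

import jax.numpy as jnp

def exp_logistic(key):
    a = random_variable(tfd.Logistic(0., 1.))(key)
    return jnp.exp(a)

def expit_normal(key):
    a = random_variable(tfd.Normal(0., 1.))(key)
    return expit(a)

print(log_prob(exp_logistic)(0.5))  # nonzero, as expected
print(log_prob(expit_normal)(0.5))  # nonzero, as expected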
The following code, using Haiku:

import jax
import jax.numpy as jnp
import haiku as hk
from oryx.core.interpreters import harvest

def f(x):
    y = hk.Linear(10)(x)
    y_hat = harvest.sow(jnp.zeros_like(y), tag='lqcd', name='y')
    return jax.nn.sigmoid(y + y_hat)

def _full_f(x, y):
    scan_f = hk.experimental.layer_stack(5)(f)
    scan_out = scan_f(x)
    y_hat = hk.Linear(y.shape[-1])(scan_out)
    return jnp.sum((y_hat - y) ** 2) / 2

# ff = _full_f
ff = harvest.call_and_reap(_full_f, tag='lqcd')
full_f = hk.transform(ff)

rng = hk.PRNGSequence(41231)
x = jax.random.normal(next(rng), [5, 10])
y = jax.random.normal(next(rng), [5, 3])
p = full_f.init(next(rng), x, y)
l = full_f.apply(p, next(rng), x, y)
print(l)
produces the following error:
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[5,10] wrapped in a HarvestTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
Detail: Different traces at same level: Traced<ShapedArray(float32[5,10])>with<HarvestTrace(level=1/0)>, HarvestTrace(level=1/0)
It is unclear why exactly this is happening, as reading the code it seems that harvest should work with higher-order primitives as well.
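One variation I haven't tried yet, but that might sidestep the leak, is applying harvest outside hk.transform instead of inside it; an untested sketch:

# untested: reap the transformed apply function rather than the raw one
full_f = hk.transform(_full_f)
p = full_f.init(next(rng), x, y)
l, sown = harvest.call_and_reap(full_f.apply, tag='lqcd')(p, next(rng), x, y)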