oryx's People

Contributors

alimuldal, colcarroll, fehiepsi, femtomc, froystig, hawkinsp, mattjj, sharadmv, yashk2810

oryx's Issues

Bump `jaxlib` in `oryx`

I'd like to experiment with oryx to express a DSL in a library I'm working on, but I can't depend on oryx until its dependencies are bumped.

Import error under JAX 0.4.1

Version: jax-0.4.1 jaxlib-0.4.1

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[1], line 1
----> 1 from oryx.core.ppl import random_variable

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/__init__.py:16
      1 # Copyright 2020 The TensorFlow Probability Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ============================================================================
     15 """Oryx is a neural network mini-library built on top of Jax."""
---> 16 from oryx import bijectors
     17 from oryx import core
     18 from oryx import distributions

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/bijectors/__init__.py:19
     16 import inspect
     18 from tensorflow_probability.python.experimental.substrates import jax as tfp
---> 19 from oryx.bijectors import bijector_extensions
     21 tfb = tfp.bijectors
     23 _bijectors = {}

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/bijectors/bijector_extensions.py:28
     26 from six.moves import zip
     27 from tensorflow_probability.python.experimental.substrates import jax as tfp
---> 28 from oryx import core
     29 from oryx.core.interpreters import inverse
     31 safe_map = jax_util.safe_map

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/__init__.py:16
      1 # Copyright 2020 The TensorFlow Probability Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ============================================================================
     15 """Contains core logic for Oryx classes."""
---> 16 from oryx.core import ppl
     17 from oryx.core.interpreters.inverse import ildj
     18 from oryx.core.interpreters.inverse import ildj_registry

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/ppl/__init__.py:17
      1 # Copyright 2020 The TensorFlow Probability Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     14 # ============================================================================
     15 # Lint as: python3
     16 """Module for probabilistic programming features."""
---> 17 from oryx.core.ppl.transformations import conditional
     18 from oryx.core.ppl.transformations import graph_replace
     19 from oryx.core.ppl.transformations import intervene

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/ppl/transformations.py:204
    201 from jax import util as jax_util
    203 from oryx.core import primitive
--> 204 from oryx.core.interpreters import harvest
    205 from oryx.core.interpreters import log_prob as lp
    207 Program = Callable[..., Any]

File /opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/oryx/core/interpreters/harvest.py:136
    134 from jax.interpreters import ad
    135 from jax.interpreters import batching
--> 136 from jax.interpreters import masking
    137 from jax.interpreters import xla
    138 from jax.lib.xla_bridge import xla_client as xc

ImportError: cannot import name 'masking' from 'jax.interpreters' (/opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/jax/interpreters/__init__.py)

In [2]: import oryx
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[2], line 1
----> 1 import oryx

ImportError: cannot import name 'masking' from 'jax.interpreters' (/opt/miniconda3/envs/blackjax/lib/python3.10/site-packages/jax/interpreters/__init__.py)
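Both tracebacks end at the same line of `oryx/core/interpreters/harvest.py`, which imports `jax.interpreters.masking`; under jax 0.4.1 that module cannot be found. A minimal sketch for checking this before importing oryx, assuming you only want a yes/no answer (the helper `module_available` is hypothetical and not part of oryx or JAX):

```python
import importlib.util


def module_available(dotted_name: str) -> bool:
    """Return True if `dotted_name` can be located by the import system.

    find_spec imports the parent packages of a dotted name as a side effect;
    if any parent is missing or fails to import, report the module as
    unavailable instead of raising.
    """
    try:
        return importlib.util.find_spec(dotted_name) is not None
    except ImportError:
        return False
```

For example, `module_available("jax.interpreters.masking")` should return `False` in the environment shown above, matching the `ImportError`.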

Running tests

Hey! I'm new to poetry, so maybe I'm missing something, but I couldn't get bazel working, nor could I find any BUILD files. One thing that worked surprisingly well was pytest-cov, using

poetry run pytest --cov=oryx

This reports 95% test coverage, and all but 4 tests pass on my machine. The remaining 4 fail with

ValueError: compiling computation that requires 2 logical devices, but only 1 XLA devices are available...

It seems like it wouldn't be much work to set this up on CI (probably just using GitHub Actions and marking the failing tests as xfail). Let me know if there's interest and I can put a PR together (I'm not sure whether tests are already being run elsewhere!).
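If the goal is to run the 4 multi-device tests on a single-CPU CI machine rather than mark them xfail, one option (an assumption about what would suit this repo, not its current setup) is XLA's `--xla_force_host_platform_device_count` flag, which makes the CPU backend expose several virtual devices. It must be in `XLA_FLAGS` before JAX initialises its backends, so a `conftest.py` is a natural place for it. A minimal sketch:

```python
# conftest.py (hypothetical): let tests that need several XLA devices run on a
# single-CPU machine by asking the host platform for virtual devices.
import os

_FLAG = "--xla_force_host_platform_device_count=2"

# XLA reads XLA_FLAGS when JAX first initialises its backends, so this has to
# happen before jax is imported anywhere in the test session.
if _FLAG not in os.environ.get("XLA_FLAGS", ""):
    os.environ["XLA_FLAGS"] = (os.environ.get("XLA_FLAGS", "") + " " + _FLAG).strip()

import jax  # noqa: E402  (deliberately imported after setting the flag)

# Devices visible to JAX; on a CPU-only host this should now be 2.
N_DEVICES = jax.device_count()
```

Whether 2 devices is enough depends on what those tests shard over; the error message above only says they require 2 logical devices.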

Specify Module's PyTree-Representation for jit/grad separately. I.e. How to freeze state.variables

Disclaimer: I have not used oryx yet. Also, this is not an issue but rather a question for discussion.

Suppose I want to define a recurrent network whose initial hidden state is not a parameter, i.e. it should be visible to jax.jit but not to jax.grad. How can this be done?

E.g.

# syntax might be slightly wrong, think of it as pseudo-code
def network_def(x):
  s = state.variable(..., name="hidden-state")
  p = state.variable(..., name="parameters")
  s, y = f(s, p, x)
  state.assign(s, name="hidden-state")
  return y 

network = state.init(network_def)(x)

@jax.jit # <- this should "see" hidden-state
@jax.grad # <- this should not "see" hidden-state
def loss_fn(network, x, y):
  ...

Is there an elegant way of doing that?
Thank you!

Also, are all JAX transformations supported? The README mentions jit, grad and vmap. What about pmap, scan, and the others?
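As a rough illustration in plain JAX (this does not use oryx's `state` API, and the toy RNN below is hypothetical), one common pattern is to pass the hidden state as a separate argument: `jax.jit` traces every argument, while `jax.grad(..., argnums=0)` differentiates only with respect to the parameters. Returning the updated state through `has_aux=True` keeps it out of the gradient as well.

```python
import jax
import jax.numpy as jnp


def rnn_step(hidden, params, x):
    """One step of a toy Elman-style RNN (hypothetical stand-in for `f`)."""
    new_hidden = jnp.tanh(params["W"] @ hidden + params["U"] @ x)
    y = params["V"] @ new_hidden
    return new_hidden, y


def loss_fn(params, hidden, x, target):
    new_hidden, y = rnn_step(hidden, params, x)
    # The updated hidden state is returned as auxiliary data, not differentiated.
    return jnp.sum((y - target) ** 2), new_hidden


# jit sees every argument; grad differentiates only argnums=0 (params).
grad_fn = jax.jit(jax.grad(loss_fn, argnums=0, has_aux=True))

hidden_dim, in_dim, out_dim = 3, 2, 1
params = {
    "W": jnp.eye(hidden_dim),
    "U": jnp.ones((hidden_dim, in_dim)),
    "V": jnp.ones((out_dim, hidden_dim)),
}
hidden = jnp.zeros(hidden_dim)
x = jnp.ones(in_dim)
target = jnp.zeros(out_dim)

grads, new_hidden = grad_fn(params, hidden, x, target)
```

The gradient pytree has the same structure as `params` only; `hidden` never gets a cotangent, although `jit` still traces it.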

Incompatible with latest jax version

I get the following error when running on the latest jax release:
AttributeError: module 'jax.interpreters.partial_eval' has no attribute 'remat_call_p'

jax version: jax-0.3.25 jaxlib-0.3.25+cuda11.cudnn82

Pull harvest out into a separate pkg

Hi!

I’ve begun using the harvest transform for a number of small things; for the purposes of this conversation, let’s say one of them is a kind of runtime debugger.

I started with a version of harvest which was a little more restrictive than the one in Oryx. But now I want to apply my debugger to more complicated programs, with control flow primitives.

Now, this is about the limit of what I would be comfortable maintaining separately (judging by how much it touches JAX internals).

Is there any discussion of pulling harvest into JAX directly (e.g. similar to experimental.checkify)?

Colab Tutorial

While following the tutorial, this line raises an error on Colab:

unzip = oryx.core.unzip

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-3-a56f63161d11> in <module>
----> 1 unzip = oryx.core.unzip

AttributeError: module 'oryx.core' has no attribute 'unzip'

`log_prob` doesn't work for functions whose jaxpr has multiple equations.

Hi there!

I'm a student implementing a Variational Autoencoder with this library, and I came across a bug: a function transformed with ppl.log_prob produces wrong results when its jaxpr has more than one equation. This happens, for instance, when jit or nest is used, or when the name argument is passed to ppl.random_variable.

For instance, consider the following random variables:

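from oryx.core import ppl
import oryx.distributions as tfd

def f(rng):
    eps = ppl.random_variable(tfd.Normal(0., 1.))(rng)
    y = eps * 2
    return y

def g(rng):
    # Passing `name` wraps the sample in a nest primitive, adding an equation.
    eps = ppl.random_variable(tfd.Normal(0., 1.), name='name')(rng)
    y = eps * 2
    return y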

Here, the function f produces a jaxpr with only one equation, while g produces one with two equations (one that takes in rng and outputs a sample from Normal(0,1), and another that multiplies that sample by 2).

When transformed by ppl.log_prob, however, these functions produce different outputs:

ppl.log_prob(f)(0.0) 
# -1.6120857

ppl.log_prob(g)(0.0) 
# -2.305233

Here, the function f produces the correct result of log(1/2) + log(1/sqrt(2*pi)), but the function g adds the transformation term twice, producing log(1/2) + log(1/2) + log(1/sqrt(2*pi)).
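Both numbers can be reproduced by hand. With y = 2 * eps and eps ~ Normal(0, 1), the change-of-variables formula gives log p(y) = log N(y/2; 0, 1) + log(1/2). A small stdlib-only check (the function name is illustrative):

```python
import math


def expected_log_prob(y, scale=2.0):
    """Log density of y = scale * eps, eps ~ Normal(0, 1), by change of variables."""
    eps = y / scale
    log_normal = -0.5 * eps ** 2 - 0.5 * math.log(2 * math.pi)
    ildj = -math.log(abs(scale))  # log |d(y / scale) / dy|
    return log_normal + ildj


correct = expected_log_prob(0.0)           # value reported for ppl.log_prob(f)(0.0)
double_counted = correct + math.log(0.5)   # one extra ILDJ term, as reported for g
print(round(correct, 7), round(double_counted, 6))  # -1.6120857 -2.305233
```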

Since this problem only happens for log_prob, but not for inverse_and_ildj, I believe the issue is with the reducer function shown below:

  def reducer(env, eqn, curr_log_prob, new_log_prob):
    if (isinstance(curr_log_prob, FailedLogProb)
        or isinstance(new_log_prob, FailedLogProb)):
      # If `curr_log_prob` is `None` that means we were unable to compute
      # a log_prob elsewhere, so the propagate failed.
      return failed_log_prob
    if eqn.primitive in log_prob_registry and new_log_prob is None:
      # We are unable to compute a log_prob for this primitive.
      return failed_log_prob
    if new_log_prob is not None:
      cells = [env.read(var) for var in eqn.outvars]
      ildjs = sum([cell.ildj.sum() for cell in cells if cell.top()])
      return curr_log_prob + new_log_prob + ildjs
    return curr_log_prob

I believe the equations are processed in reverse order and the outvar ILDJs are counted twice. For example, the function g above has two equations:

  • Equation 1: eps = nest(random_variable(rng))
  • Equation 2: y = 2 * eps

First, equation 2 is processed and eps.ildj assumes the correct value of log(1/2).

Then equation 1 is processed. Its nest primitive triggers another call to propagate for the random_variable primitive, which in turn calls reducer. However, since eps.ildj already holds log(1/2), that value is added to the equation's state, which becomes log(1/2) + log(1/sqrt(2*pi)). I believe the correct value at this step would be only log(1/sqrt(2*pi)).

After equation 1 is processed, the results are aggregated by the reducer function. Although the state assigned to equation 1 already includes log(1/2) + log(1/sqrt(2*pi)), eps.ildj is added again at this step, which produces the wrong result.

If I replace the loop over eqn.outvars with one over eqn.invars, the issue is solved in this case, since in the nested calls to reducer cell.ildj is still undefined. I'm not entirely sure this won't cause problems in other cases, though. In any case, I've opened a pull request that changes this line and adds a small test case for this issue, in case this is indeed the cause of the bug.

expit of logistic variable gives log_prob 0.0

I'm trying to get the log_prob of an expit-transformed, logistic-distributed variable, but it always returns zero:

import jax.random
from oryx.core.ppl import random_variable, log_prob
from jax.scipy.special import expit
import oryx.distributions as tfd


def simple_sample(key):
    a = random_variable(tfd.Logistic(0., 1.))(key)
    return expit(a)


x = simple_sample(jax.random.PRNGKey(0))
print(x)  # 0.41845703
print(log_prob(simple_sample)(0.5))  # 0.0
print(log_prob(simple_sample)(x))  # 0.0

Versions:
jax-0.4.25
oryx-0.2.6

(Both an exp-transformed logistic variable and an expit-transformed normal variable seem to work, so there is something special about this combination.)
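One property that may explain why this combination is special: the standard logistic CDF is exactly expit, so if a ~ Logistic(0, 1) then expit(a) ~ Uniform(0, 1), whose log density is 0 everywhere on (0, 1). A quick change-of-variables check in plain Python (function names are illustrative; this only checks the math and says nothing about which code path oryx takes to get its value):

```python
import math


def logistic_log_pdf(x, loc=0.0, scale=1.0):
    """Log density of a Logistic(loc, scale) distribution."""
    z = (x - loc) / scale
    return -z - 2.0 * math.log1p(math.exp(-z)) - math.log(scale)


def expit_logistic_log_pdf(y):
    """Log density of Y = expit(X), X ~ Logistic(0, 1), via change of variables."""
    x = math.log(y / (1.0 - y))              # inverse of expit (logit)
    log_abs_jac = -math.log(y * (1.0 - y))   # log |d logit(y) / dy|
    return logistic_log_pdf(x) + log_abs_jac


for y in (0.1, 0.41845703, 0.5, 0.9):
    print(y, expit_logistic_log_pdf(y))  # all approximately 0.0
```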

Harvest does not seem to work within scan

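The following code, which uses Haiku,

import haiku as hk
import jax
import jax.numpy as jnp
from oryx.core.interpreters import harvest

def f(x):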
  y = hk.Linear(10)(x)
  y_hat = harvest.sow(jnp.zeros_like(y), tag='lqcd', name='y')
  return jax.nn.sigmoid(y + y_hat)


def _full_f(x, y):
  scan_f = hk.experimental.layer_stack(5)(f)
  scan_out = scan_f(x)
  y_hat = hk.Linear(y.shape[-1])(scan_out)
  return jnp.sum((y_hat - y) ** 2) / 2

# ff = _full_f
ff = harvest.call_and_reap(_full_f, tag='lqcd')
full_f = hk.transform(ff)
rng = hk.PRNGSequence(41231)
x = jax.random.normal(next(rng), [5, 10])
y = jax.random.normal(next(rng), [5, 3])
p = full_f.init(next(rng), x, y)
l = full_f.apply(p, next(rng), x, y)
print(l)

produces an error:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[5,10] wrapped in a HarvestTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Different traces at same level: Traced<ShapedArray(float32[5,10])>with<HarvestTrace(level=1/0)>, HarvestTrace(level=1/0)

It is unclear exactly why this is happening, since from reading the code it seems that harvest should also work with higher-order primitives.
