chex's Introduction

Chex

Chex is a library of utilities for helping to write reliable JAX code.

This includes utils to help:

  • Instrument your code (e.g. assertions, warnings)
  • Debug (e.g. transforming pmaps into vmaps within a context manager).
  • Test JAX code across many variants (e.g. jitted vs non-jitted).

Installation

You can install the latest released version of Chex from PyPI via:

pip install chex

or you can install the latest development version from GitHub:

pip install git+https://github.com/deepmind/chex.git

Modules Overview

Dataclass (dataclass.py)

Dataclasses are a popular construct introduced in Python 3.7 that make it easy to specify typed data structures with minimal boilerplate code. They are not, however, compatible with JAX and dm-tree out of the box.

In Chex we provide a JAX-friendly dataclass implementation reusing python dataclasses.

The Chex implementation registers dataclasses as internal PyTree nodes to ensure compatibility with JAX data structures.

In addition, we provide a class wrapper that exposes dataclasses as collections.abc.Mapping descendants, which allows dm-tree methods to process them (e.g. flatten and unflatten them) like ordinary Python dictionaries. See the @mappable_dataclass docstring for more details.

Example:

import chex
import jax
import jax.numpy as jnp
import tree  # dm-tree

@chex.dataclass
class Parameters:
  x: chex.ArrayDevice
  y: chex.ArrayDevice

parameters = Parameters(
    x=jnp.ones((2, 2)),
    y=jnp.ones((1, 2)),
)

# Dataclasses can be treated as JAX pytrees
jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)

# and as mappings by dm-tree
tree.flatten(parameters)

NOTE: Unlike standard Python 3.7 dataclasses, Chex dataclasses cannot be constructed using positional arguments. They support construction arguments provided in the same format as the Python dict constructor. If necessary, dataclasses can be converted to and from tuples with the to_tuple and from_tuple methods.

parameters = Parameters(
    jnp.ones((2, 2)),
    jnp.ones((1, 2)),
)
# ValueError: Mappable dataclass constructor doesn't support positional args.

Assertions (asserts.py)

One limitation of PyType annotations for JAX is that they do not support the specification of DeviceArray ranks, shapes or dtypes. Chex includes a number of functions that allow flexible and concise specification of these properties.

E.g. suppose you want to ensure that all tensors t1, t2, t3 have the same shape, and that tensors t4, t5 have rank 2 and (3 or 4), respectively.

chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])

More examples:

from chex import assert_shape, assert_rank, ...

assert_shape(x, (2, 3))                # x has shape (2, 3)
assert_shape([x, y], [(), (2,3)])      # x is scalar and y has shape (2, 3)

assert_rank(x, 0)                      # x is scalar
assert_rank([x, y], [0, 2])            # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2})            # x and y are scalar OR rank-2 arrays

assert_type(x, int)                    # x has type `int` (x can be an array)
assert_type([x, y], [int, float])      # x has type `int` and y has type `float`

assert_equal_shape([x, y, z])          # x, y, and z have equal shapes

assert_trees_all_close(tree_x, tree_y) # values and structure of trees match
assert_tree_all_finite(tree_x)         # all tree_x leaves are finite

assert_devices_available(2, 'gpu')     # 2 GPUs available
assert_tpu_available()                 # at least 1 TPU available

assert_numerical_grads(f, (x, y), j)   # f^{(j)}(x, y) matches numerical grads

See asserts.py documentation to find all supported assertions.

If you cannot find a specific assertion, please consider making a pull request or opening an issue on the bug tracker.

Optional Arguments

All chex assertions support the following optional kwargs for manipulating the emitted exception messages:

  • custom_message: A string to include in the emitted exception messages.
  • include_default_message: Whether to include the default Chex message into the emitted exception messages.
  • exception_type: An exception type to use. AssertionError by default.

For example, the following code:

dataset = load_dataset()
params = init_params()
for i in range(num_steps):
  params = update_params(params, dataset.sample())
  chex.assert_tree_all_finite(params,
                              custom_message=f'Failed at iteration {i}.',
                              exception_type=ValueError)

will raise a ValueError that includes the step number when params get polluted with non-finite values (NaNs or infinities).

Static and Value (aka Runtime) Assertions

Chex divides all assertions into 2 classes: static and value assertions.

  1. static assertions do not depend on concrete tensor values. Examples: assert_shape, assert_trees_all_equal_dtypes, assert_max_traces.

  2. value assertions require access to tensor values, which are not available during JAX tracing (see How JAX primitives work); such assertions therefore need special treatment in jitted code.

To enable value assertions in a jitted function, decorate it with the chex.chexify() wrapper. Example:

  import chex
  import jax
  import jax.numpy as jnp

  @chex.chexify
  @jax.jit
  def logp1_abs_safe(x: chex.Array) -> chex.Array:
    chex.assert_tree_all_finite(x)
    return jnp.log(jnp.abs(x) + 1)

  logp1_abs_safe(jnp.ones(2))  # OK
  logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS (in async mode)

  # The error will be raised either at the next line OR at the next
  # `logp1_abs_safe` call. See the docs for more detail on async mode.
  logp1_abs_safe.wait_checks()  # Wait for the (async) computation to complete.

See the chex.chexify() docstring for more detail.

JAX Tracing Assertions

JAX re-traces a jitted function every time the structure, shapes, or dtypes of its arguments change. This is often inadvertent and leads to a significant performance drop that is hard to debug. The @chex.assert_max_traces decorator asserts that the function is not re-traced more than n times during program execution.

The global trace counter can be cleared by calling chex.clear_trace_counter(). This function can be used to isolate unit tests relying on @chex.assert_max_traces.

Examples:

  import chex
  import jax
  import jax.numpy as jnp

  @jax.jit
  @chex.assert_max_traces(n=1)
  def fn_sum_jitted(x, y):
    return x + y

  fn_sum_jitted(jnp.zeros(3), jnp.zeros(3))  # tracing for the 1st time - OK
  fn_sum_jitted(jnp.zeros([6, 7]), jnp.zeros([6, 7]))  # AssertionError!

Can be used with jax.pmap() as well:

  def fn_sub(x, y):
    return x - y

  fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))

See the How JAX primitives work section of the JAX documentation for more information about tracing.

Warnings (warnings.py)

In addition to hard assertions, Chex also offers utilities to add common warnings, such as specific types of deprecation warnings.

Test variants (variants.py)

JAX relies extensively on code transformation and compilation, which can make it hard to ensure that code is properly tested. For instance, testing a Python function that uses JAX without jitting it does not cover the code path that is actually executed when it is jitted, and that path also differs depending on whether the code is jitted for CPU, GPU, or TPU. This has been a source of obscure, hard-to-catch bugs in which XLA changes led to undesirable behaviours that manifested only under one specific code transformation.

Variants make it easy to ensure that unit tests cover different ‘variations’ of a function, by providing a simple decorator that can be used to repeat any test under all (or a subset) of the relevant code transformations.

E.g. suppose you want to test the output of a function fn with or without jit. You can use chex.variants to run the test with both the jitted and non-jitted version of the function by simply decorating a test method with @chex.variants, and then using self.variant(fn) in place of fn in the body of the test.

import chex

def fn(x, y):
  return x + y
...

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    var_fn = self.variant(fn)
    self.assertEqual(fn(1, 2), 3)
    self.assertEqual(var_fn(1, 2), fn(1, 2))

If you define the function in the test method, you may also use self.variant as a decorator in the function definition. For example:

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(1, 2), 3)

Example of parameterized test:

import chex
from absl.testing import parameterized

# Could also be:
#  `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):`
#  `class ExampleParameterizedTest(chex.TestCase):`
class ExampleParameterizedTest(parameterized.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  @parameterized.named_parameters(
      ('case_positive', 1, 2, 3),
      ('case_negative', -1, -2, -3),
  )
  def test(self, arg_1, arg_2, expected):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(arg_1, arg_2), expected)

Chex currently supports the following variants:

  • with_jit -- applies jax.jit() transformation to the function.
  • without_jit -- uses the function as is, i.e. identity transformation.
  • with_device -- places all arguments (except specified in ignore_argnums argument) into device memory before applying the function.
  • without_device -- places all arguments in RAM before applying the function.
  • with_pmap -- applies jax.pmap() transformation to the function (see notes below).

See documentation in variants.py for more details on the supported variants. More examples can be found in variants_test.py.

Variants notes

  • Test classes that use @chex.variants must inherit from chex.TestCase (or any other base class that unrolls test generators within TestCase, e.g. absl.testing.parameterized.TestCase).

  • [jax.vmap] All variants can be applied to a vmapped function; please see an example in variants_test.py (test_vmapped_fn_named_params and test_pmap_vmapped_fn).

  • [@chex.all_variants] You can get all supported variants by using the decorator @chex.all_variants.

  • [with_pmap variant] jax.pmap(fn) performs a parallel map of fn onto multiple devices. Since most tests run in a single-device environment (i.e. with access to a single CPU or GPU), in which case jax.pmap is functionally equivalent to jax.jit, the with_pmap variant is skipped by default (although it works fine with a single device). The section Faking multi-device test environments below describes a way to properly test fn if it is supposed to be used in multi-device environments (TPUs or multiple CPUs/GPUs). To disable skipping with_pmap variants in the single-device case, add --chex_skip_pmap_variant_if_single_device=false to your test command.

Fakes (fake.py)

Debugging in JAX is made more difficult by code transformations such as jit and pmap, which introduce optimizations that make code hard to inspect and trace. It can also be difficult to disable those transformations during debugging as they can be called at several places in the underlying code. Chex provides tools to globally replace jax.jit with a no-op transformation and jax.pmap with a (non-parallel) jax.vmap, in order to more easily debug code in a single-device context.

For example, you can use Chex to fake pmap and have it replaced with a vmap. This can be achieved by wrapping your code with a context manager:

with chex.fake_pmap():
  @jax.pmap
  def fn(inputs):
    ...

  # Function will be vmapped over inputs
  fn(inputs)

The same functionality can also be invoked with start and stop:

fake_pmap = chex.fake_pmap()
fake_pmap.start()
... your jax code ...
fake_pmap.stop()

In addition, you can fake a real multi-device test environment with a multi-threaded CPU. See section Faking multi-device test environments for more details.

See documentation in fake.py and examples in fake_test.py for more details.

Faking multi-device test environments

In situations where you do not have easy access to multiple devices, you can still test parallel computation using single-device multi-threading.

In particular, one can force XLA to use a single CPU's threads as separate devices, i.e. to fake a real multi-device environment with a multi-threaded one. The two setups are theoretically equivalent from XLA's perspective because they expose the same interface and use identical abstractions.

Chex has a flag, chex_n_cpu_devices, that specifies the number of CPU threads to use as XLA devices.

To set up a multi-threaded XLA environment for absl tests, define a setUpModule function in your test module:

def setUpModule():
  chex.set_n_cpu_devices()

Now you can launch your test with python test.py --chex_n_cpu_devices=N to run it in multi-device mode. Note that all tests within a module will have access to N devices.

More examples can be found in variants_test.py, fake_test.py and fake_set_n_cpu_devices_test.py.

Using named dimension sizes

Chex comes with a small utility that allows you to package a collection of dimension sizes into a single object. The basic idea is:

dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
chex.assert_shape(arr, dims['BTE'])

String lookups are translated into integer tuples. For instance, if batch_size == 3, sequence_len == 5 and embedding_dim == 7, then

dims['BTE'] == (3, 5, 7)
dims['B'] == (3,)
dims['TTBEE'] == (5, 5, 3, 7, 7)
...

You can also assign dimension sizes dynamically as follows:

dims['XY'] = some_matrix.shape
dims.Z = 13

For more examples, see chex.Dimensions documentation.

Citing Chex

This repository is part of the DeepMind JAX Ecosystem. To cite Chex, please use the DeepMind JAX Ecosystem citation.

chex's People

Contributors

copybara-github, dependabot[bot], dhgarrette, graingert, hamzamerzic, hawkinsp, hbq1, jblespiau, jeffdonahue, kristianholsheimer, lkhphuc, malcolmreynolds, mbrukman, mjwillson, mtthss, pfackeldey, ravichandraa-google, rchen152, rishabhkabra, sauravmaheshkar, stompchicken, superbobry, suryabhupa, tomhennigan, tomwardio, tttc3, yashk2810, yilei, yueshengys, zafarali

chex's Issues

Analog to `flax.struct.dataclass`

Hi chex team,

Is there any potential for something like flax.struct.dataclass in chex?

Basically a kind of dataclass that can mark static arguments.

Two other variants are jax_dataclasses and simple-pytree, though an official chex one would be cool, with all the benefits of being part of chex.

Thank you!

Consider supporting static attributes in chex.dataclass

from jax import jit
from jax.lax import scan
from tjax import IntegralNumeric, RealNumeric
from tjax.dataclasses import dataclass, field
import chex

def f(carry, _):
  return carry + 1.0, None

@jit
def do_scan(c):
  final, _ = scan(f, c.x, None, c.y)
  return final

@dataclass
class C:
  x: RealNumeric
  y: IntegralNumeric = field(static=True)

print(do_scan(C(1.0, 10)))  # works

@chex.dataclass
class D:
  x: RealNumeric
  y: IntegralNumeric

print(do_scan(D(x=1.0, y=10)))  # fails

`assert_tree_shape_prefix` requires tuple instead of `Sequence`

Hi,

Thanks for sharing this nice package! I especially like all the assertions for pytrees.
However, I came across the following inconsistency in the documentation:
Current situation
According to the docs, the shape_prefix argument of assert_tree_shape_prefix is of type Sequence[int].
However, when I pass a sequence such as a list (instead of a tuple)

import chex
import jax.numpy as jnp

mytree = {'a': jnp.array([[[1], [2]]])}
chex.assert_tree_shape_prefix(mytree, shape_prefix=[1, 2])  # AssertionError!

The assertion raises an exception:

AssertionError: [Chex] Assertion assert_tree_shape_prefix failed: Tree leaf 'a' has a shape prefix different from expected: (1, 2) != [1, 2].

The error can simply be fixed by using a tuple instead:

chex.assert_tree_shape_prefix(mytree, shape_prefix=(1, 2))  # OK!

I think this is not the only inconsistent function, but I did not check for others.

Desired situation
Ideally, I would like the function to behave as documented, so that I can also pass a list. Why? I think being able to use square brackets (i.e., a list) right after a closing parenthesis helps readability.

I would be interested to hear your opinion and I am happy to contribute a pull request.

Keep up the good work!

Hylke

Specify non-pytree node dataclass fields

Hi,

Thanks for making this awesome library!

Is it possible to exclude certain fields of a chex.dataclass definition from the pytree? This is a feature supported in flax https://flax.readthedocs.io/en/latest/_modules/flax/struct.html#dataclass
which I found to be quite useful when defining data classes with fields (such as JAX functions) that shouldn't be mapped over with dm-tree or jax.tree_map. I am not sure if this is supported out of the box by chex at the moment but is something that I hope would be part of chex.

Dataclass breaks is_leaf function for jax.tree_map

is_leaf is handled correctly with NamedTuple

from typing import Any, NamedTuple

import jax
import jax.numpy as jnp

class NT(NamedTuple):
  a: Any
  b: Any

class Histogram(NamedTuple):
  hist: jnp.ndarray
  bins: jnp.ndarray

def is_leaf(n):
  print(n)
  return isinstance(n, Histogram)

jax.tree_util.tree_map(lambda x: x, NT(a=Histogram(1, 2), b=Histogram(4, 3)), is_leaf=is_leaf)

Output (as expected)

NT(a=Histogram(hist=1, bins=2), b=Histogram(hist=4, bins=3))
Histogram(hist=1, bins=2)
Histogram(hist=4, bins=3)

Does not work with chex.dataclass

import chex

@chex.dataclass
class DC:
  a: Any
  b: Any

jax.tree_util.tree_map(lambda x: x, DC(a=Histogram(1, 2), b=Histogram(4, 3)), is_leaf=is_leaf)

Actual output

DC(a=Histogram(hist=1, bins=2), b=Histogram(hist=4, bins=3))
1
2
4
3

Using variants with pytest

Hi,

First of all, thank you for this very useful library!

I have a JAX project in which I have already implemented my tests using pytest. However, the possibilities that chex.variants offers are too nice to ignore. At the same time, I would rather not rewrite all my tests.

Is there a way to reconcile pytest and chex?

Thank you again for all the work!
Best,

chex.variants(with_pmap=True) ignores `static_argnames`

The _with_pmap function accepts static_argnums as a parameter, but not static_argnames. This is inconsistent with other variants, such as with_jit and with_device. Crucially, it prevents testing methods that require arguments to be passed by name (e.g., Distrax's Distribution.sample()).

More generally, it would be best if all variants accepted the same parameters where possible (i.e., where not specific to a single variant), and I would suggest checking all keys in **unused_kwargs against a list of allowed parameters (i.e., the union of the parameters of all variant functions) to prevent silent errors due to, e.g., misspellings.

ImportError: cannot import name 'DeviceLocalLayout' from 'jax._src.layout' (/opt/conda/lib/python3.10/site-packages/jax/_src/layout.py)

/tmp/ipykernel_34/2874194604.py:15: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display
from IPython.core.display import display, HTML

ImportError Traceback (most recent call last)
Cell In[27], line 20
18 # Import model definition from big_vision
19 from big_vision.models.proj.paligemma import paligemma
---> 20 from big_vision.trainers.proj.paligemma import predict_fns
22 # Import big vision utilities
23 import big_vision.datasets.jsonl

File /kaggle/working/big_vision_repo/big_vision/trainers/proj/paligemma/predict_fns.py:20
17 import functools
19 from big_vision.pp import registry
---> 20 import big_vision.utils as u
21 import einops
22 import jax

File /kaggle/working/big_vision_repo/big_vision/utils.py:38
36 import flax.jax_utils as flax_utils
37 import jax
---> 38 from jax.experimental.array_serialization import serialization as array_serial
39 import jax.numpy as jnp
40 import ml_collections as mlc

File /opt/conda/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py:36
34 from jax._src import sharding
35 from jax._src import sharding_impls
---> 36 from jax._src.layout import Layout, DeviceLocalLayout as DLL
37 from jax._src import typing
38 from jax._src import util

ImportError: cannot import name 'DeviceLocalLayout' from 'jax._src.layout' (/opt/conda/lib/python3.10/site-packages/jax/_src/layout.py)

Numpy Conflict

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.82 requires numpy>=1.25.0, but you have numpy 1.24.3 which is incompatible.
Successfully installed numpy-1.24.3

Trying to get Diffusers setup - https://huggingface.co/docs/transformers/installation

ModuleNotFoundError: No module named 'jax.numpy'

On installing the latest version of Chex directly from GitHub, the following error pops up when importing chex:

File ".../anaconda3/lib/python3.8/site-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "...anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File ".../anaconda3/lib/python3.8/site-packages/chex/_src/asserts_internal.py", line 34, in <module>
    from chex._src import pytypes
  File ".../lib/python3.8/site-packages/chex/_src/pytypes.py", line 19, in <module>
    import jax.numpy as jnp
ModuleNotFoundError: No module named 'jax.numpy'

Jax: v0.3.34
jaxlib: v0.1.69

CpuDevice no longer in jax

Hello,

Seems like the newest version of jax (0.3.7) removed some classes that are used here in chex. Should chex upper bound the jax version? I see this conflicting code is not currently on the main branch -- alternatively, maybe a new release can be made?

jax-ml/jax#10326

Strange gap in version

@hbq1 The previous release version was 0.1.8 and the current one is 0.1.81, not 0.1.9 as one would expect. Why is there a gap between versions? Might it be a typo/mistake by any chance?

Test files are in binary package distribution

Test files are not filtered properly. The issue is that setuptools.find_packages finds packages, not modules, while tests are organized as separate modules. To mitigate the issue, test files should be filtered out manually as follows. This patch was created and tested on chex v0.1.86.

--- a/setup.py	2024-03-19 12:58:11.000000000 +0300
+++ b/setup.py	2024-04-17 15:19:33.593293889 +0300
@@ -15,8 +15,11 @@
 """Install script for setuptools."""
 
 import os
-from setuptools import find_packages
-from setuptools import setup
+from pathlib import Path
+
+from setuptools import find_packages, setup
+from setuptools.command.build_py import build_py as _build_py
+
 
 _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
 
@@ -40,6 +43,15 @@
     ]
 
 
+class build_py(_build_py):
+
+    def find_package_modules(self, package, package_dir):
+        modules = super().find_package_modules(package, package_dir)
+        return [(pkg, mod, file)
+                for pkg, mod, file in modules
+                if not Path(file).match('**/*_test.py')]
+
+
 setup(
     name='chex',
     version=_get_version(),
@@ -51,7 +63,7 @@
     long_description_content_type='text/markdown',
     author_email='[email protected]',
     keywords='jax testing debugging python machine learning',
-    packages=find_packages(exclude=['*_test.py']),
+    packages=find_packages(),
     install_requires=_parse_requirements(
         os.path.join(_CURRENT_DIR, 'requirements', 'requirements.txt')),
     tests_require=_parse_requirements(
@@ -73,4 +85,5 @@
         'Topic :: Software Development :: Testing :: Unit',
         'Topic :: Software Development :: Libraries :: Python Modules',
     ],
+    cmdclass={'build_py': build_py},
 )

How to version?

Hi there!
Please tell me how to get the version of chex installed on my Ubuntu (18.0) machine.

Fake contexts by calling .start() not working

Hi, I tried both the latest GitHub version and the latest PyPI version, but in neither does calling .start() on a fake context work (it did work as a context manager!).
Here are some screenshots of my problem with fake_pmap and fake_jit:

[Screenshots: 2020-09-28 at 18.27.11 and 18.27.20; images not available]

What could be the cause of my problem? Thanks

Chex dataclass defaulting mappable_dataclass=True

To start with, thanks for open sourcing your work on Chex, it's a great tooling library for building robust Jax applications!

As I was upgrading to the latest release 0.0.3, I noticed quite a few of my tests breaking. It turns out that the default option mappable_dataclass=True in chex.dataclass breaks the usual interface of dataclasses (which is clearly expected when reading the code documentation!)

I guess that from the perspective of DeepMind usage, it makes sense to enable this option by default. But from an external user's point of view, it is rather surprising to have a dataclass decorator not behaving like a dataclass. I think it would be great to make it clear in the library README that this option needs to be turned off to get the full dataclass behaviour (or to turn it off by default).

TypeError: non-default argument 'value' follows default argument

I'm trying to inherit from a dataclass with an optional id in the parent, however, I am seeing TypeError: non-default argument 'value' follows default argument. This is the minimal code to reproduce.

from typing import Optional
import chex

@chex.dataclass
class Base:
    idx: Optional[int] = None

@chex.dataclass
class Derived(Base):
    value: int

This is an issue with the standard dataclasses library rather than with chex.dataclass; however, according to https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses#answer-69822584, it can be fixed by setting kw_only=True.

`chex.dataclass` wrapper causes type error: Expected no arguments to dataclass constructor

Applying the chex.dataclass wrapper to a class yields the following error:

import chex

@chex.dataclass(frozen=True)
class Class:
    x: int

Class(4)
$ pyright --version
pyright 1.1.337
$ python3 --version
Python 3.11.6
$ python3 -c "import chex; print(chex.__version__)"
0.1.85
$ pyright test.py
/Users/user/Desktop/test.py
  /Users/user/Desktop/test.py:7:1 - error: Expected no arguments to "Class" constructor (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations 

microsoft/pyright#6536 (comment)

This is a bug in the chex library. The chex.dataclass decorator has no type annotations despite the fact that the package contains a "py.typed" marker file.

microsoft/pyright#6536 (comment)

I recommend looking at the stdlib dataclass class in the typeshed dataclass.pyi stub.

https://github.com/python/typeshed/blob/main/stdlib/dataclasses.pyi

`AssertsChexifyTest.test_uninspected_checks` test failure

I'm seeing the following test failure when running the test suite:

============================= test session starts ==============================
platform linux -- Python 3.10.7, pytest-7.1.3, pluggy-1.0.0
rootdir: /build/source
collected 548 items                                                            

chex/chex_test.py .                                                      [  0%]
chex/_src/asserts_chexify_test.py ......F.....                           [  2%]
chex/_src/asserts_internal_test.py .s.s.........                         [  4%]
chex/_src/asserts_test.py ..s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s. [ 13%]
s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s..................... [ 26%]
........................................................................ [ 39%]
........................................................................ [ 52%]
.................................                                        [ 58%]
chex/_src/dataclass_test.py ...........................................  [ 66%]
chex/_src/dimensions_test.py .................                           [ 69%]
chex/_src/fake_set_n_cpu_devices_test.py s                               [ 69%]
chex/_src/fake_test.py ................................                  [ 75%]
chex/_src/restrict_backends_test.py ssssssssss                           [ 77%]
chex/_src/variants_test.py .....................s....s............s....s [ 85%]
..........................................................ssssssssssssss [ 98%]
sssssss                                                                  [100%]

=================================== FAILURES ===================================
__________________ AssertsChexifyTest.test_uninspected_checks __________________

self = <chex._src.asserts_chexify_test.AssertsChexifyTest testMethod=test_uninspected_checks>

    def test_uninspected_checks(self):
    
      @jax.jit
      def _pos_sum(x):
        chex_value_assert_positive(x, custom_message='err_label')
        return x.sum()
    
      invalid_x = -jnp.ones(3)
      chexify_async(_pos_sum)(invalid_x)  # async error
    
>     with self.assertRaisesRegex(AssertionError, 'err_label'):
E     AssertionError: AssertionError not raised

chex/_src/asserts_chexify_test.py:179: AssertionError
------------------------------ Captured log call -------------------------------
WARNING  absl:asserts_chexify.py:57 [Chex] Some of chexify assetion statuses were not inspected due to async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html). Consider calling `chex.block_until_chexify_assertions_complete()` at the end of computations that rely on jitted chex assetions.
=============================== warnings summary ===============================
chex/_src/asserts_chexify_test.py: 12 warnings
  /build/source/chex/_src/asserts_chexify_test.py:58: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
    return jnp.all(jnp.array([(x > 0).all() for x in jax.tree_leaves(tree)]))

chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_static_assertion__with_jit
chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_static_assertion__without_jit
  /build/source/chex/_src/asserts_chexify_test.py:86: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
    return sum(x.sum() for x in jax.tree_leaves(tree))

chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
  /nix/store/4y9j6xdkgqwkdx5ki508l175smcjgs9l-python3.10-pytest-7.1.3/lib/python3.10/site-packages/_pytest/unraisableexception.py:78: PytestUnraisableExceptionWarning: Exception ignored in atexit callback: <function _check_if_hanging_assertions at 0x7ffddfe66d40>
  
  Traceback (most recent call last):
    File "/build/source/chex/_src/asserts_chexify.py", line 32, in _check_error
      checkify.check_error(err)
    File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 476, in check_error
      return assert_p.bind(err, code, payload, msgs=error.msgs)
    File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 328, in bind
      return self.bind_with_trace(find_top_trace(args), args, params)
    File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 331, in bind_with_trace
      out = trace.process_primitive(self, map(trace.full_raise, args), params)
    File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 698, in process_primitive
      return primitive.impl(*tracers, **params)
    File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 483, in assert_impl
      raise_error(Error(err, code, msgs, payload))
    File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 123, in raise_error
      raise ValueError(err)
  ValueError: [Chex] chexify assertion failed [err_label] [failed at /build/source/chex/_src/asserts_chexify_test.py:173] (check failed at /build/source/chex/_src/asserts_internal.py:229 (_chex_assert_fn))
  
  During handling of the above exception, another exception occurred:
  
  Traceback (most recent call last):
    File "/build/source/chex/_src/asserts_chexify.py", line 62, in _check_if_hanging_assertions
      block_until_chexify_assertions_complete()
    File "/build/source/chex/_src/asserts_chexify.py", line 51, in block_until_chexify_assertions_complete
      wait_fn()
    File "/build/source/chex/_src/asserts_chexify.py", line 180, in _wait_checks
      _check_error(async_check_futures.popleft().result(async_timeout))
    File "/build/source/chex/_src/asserts_chexify.py", line 40, in _check_error
      raise AssertionError(msg)  # pylint:disable=raise-missing-from
  AssertionError: [Chex] chexify assertion failed [err_label] [failed at /build/source/chex/_src/asserts_chexify_test.py:173] 
  
    warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))

chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
  /build/source/chex/_src/asserts_chexify_test.py:52: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
    if not all((x > 0).all() for x in jax.tree_leaves(tree)):

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
====== 1 failed, 461 passed, 86 skipped, 20 warnings in 84.47s (0:01:24) =======
error: builder for '/nix/store/f9icjsb9pbz4p8qpsyhp9gq1fvjvwwhz-python3.10-chex-0.1.5.drv' failed with exit code 1;

I'm using

`isinstance(None, (int, float, chex.Array))` raises error since `chex==0.1.7`

Hello,

Thanks for the useful package. I am hitting an error when using the latest chex. See reproduction instructions below.

pip install chex==0.1.5 && python -c "import chex; print(isinstance(None, (int, float, chex.Array)))"
False
pip install chex==0.1.7 && python -c "import chex; print(isinstance(None, (int, float, chex.Array)))"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/lib/python3.8/typing.py", line 769, in __instancecheck__
    return self.__subclasscheck__(type(obj))
  File "/usr/lib/python3.8/typing.py", line 777, in __subclasscheck__
    raise TypeError("Subscripted generics cannot be used with"
TypeError: Subscripted generics cannot be used with class and instance checks

This causes an error in optax when using inject_hyperparams, which essentially performs the same isinstance(..., (int, float, chex.Array)) check:

https://github.com/deepmind/optax/blob/04768d252911d6af4d4d36361930ccd0a54f9160/optax/_src/schedule.py#L589
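Until chex.Array can be used with isinstance again, one hedged workaround is to flatten the Union into the plain classes it contains before calling isinstance. The sketch below assumes the Union members are plain classes or subscripted generics (for which it substitutes the origin class, e.g. list for List[float]); isinstance_types and ArrayLike are hypothetical names, and ArrayLike only stands in for chex.Array.

```python
import types
import typing

# typing.Union plus, on Python 3.10+, the `X | Y` union type.
_UNION_TYPES = (typing.Union,) + ((types.UnionType,) if hasattr(types, "UnionType") else ())


def isinstance_types(tp):
    """Hypothetical helper: flatten a Union into a tuple of classes usable with isinstance."""
    origin = typing.get_origin(tp)
    if origin in _UNION_TYPES:
        flat = ()
        for arg in typing.get_args(tp):
            flat += isinstance_types(arg)
        return flat
    if origin is not None:
        # Subscripted generic such as List[float]: check against its origin class.
        return (origin,)
    return (tp,)


# Hypothetical stand-in for chex.Array: a Union containing a subscripted generic.
ArrayLike = typing.Union[int, float, typing.List[float]]

print(isinstance(None, (int, float) + isinstance_types(ArrayLike)))  # False
```

Members such as typing.Any or Literal[...] would still be rejected by isinstance, so the helper only covers the simple cases described above.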

Chex dataclass throws an exception in Python 3.9

$ python --version
Python 3.9.1
In [1]: import chex

In [2]: @chex.dataclass
   ...: class Parameters:
   ...:   x: chex.ArrayDevice
   ...:   y: chex.ArrayDevice
   ...:
   ...: parameters = Parameters(
   ...:     x=jnp.ones((2, 2)),
   ...:     y=jnp.ones((1, 2)),
   ...: )
   ...:
   ...: # Dataclasses can be treated as JAX pytrees
   ...: jax.tree_map(lambda x: 2.0 * x, parameters)
   ...:
   ...: # and as mappings by dm-tree
   ...: tree.flatten(parameters)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-3461a2700932> in <module>
      1 @chex.dataclass
----> 2 class Parameters:
      3   x: chex.ArrayDevice
      4   y: chex.ArrayDevice
      5

~/.virtualenvs/science/lib/python3.9/site-packages/chex/_src/dataclass.py in dataclass(cls, init, repr, eq, order, unsafe_hash, frozen, mappable_dataclass, restricted_inheritance)
    104   if cls is None:
    105     return dcls
--> 106   return dcls(cls)
    107
    108

~/.virtualenvs/science/lib/python3.9/site-packages/chex/_src/dataclass.py in __call__(self, cls)
    147
    148     if self.mappable_dataclass:
--> 149       dcls = mappable_dataclass(dcls, self.restricted_inheritance)
    150
    151     def _from_tuple(args):

~/.virtualenvs/science/lib/python3.9/site-packages/chex/_src/dataclass.py in mappable_dataclass(cls, restricted_inheritance)
     81   if cls.__bases__ == (object,):
     82     # `collections.Mapping` is incompatible with `object`
---> 83     cls.__bases__ = (collections.Mapping,)
     84   else:
     85     cls.__bases__ += (collections.Mapping,)

TypeError: __bases__ assignment: 'Mapping' deallocator differs from 'object'
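The traceback points at the in-place assignment to cls.__bases__, which CPython rejects when swapping object for Mapping. A minimal sketch of an alternative that avoids mutating __bases__ is to create a subclass that inherits from both the dataclass and collections.abc.Mapping. mappable is a hypothetical helper, and this is not necessarily how chex fixed the issue.

```python
import collections.abc
import dataclasses


def mappable(cls):
    """Hypothetical helper: return a subclass of dataclass `cls` that is also a Mapping.

    Inheriting from collections.abc.Mapping at class-creation time avoids
    reassigning cls.__bases__ after the class already exists.
    """

    class MappableDataclass(cls, collections.abc.Mapping):
        def __getitem__(self, key):
            names = [f.name for f in dataclasses.fields(self)]
            if key not in names:
                raise KeyError(key)
            return getattr(self, key)

        def __iter__(self):
            return iter(f.name for f in dataclasses.fields(self))

        def __len__(self):
            return len(dataclasses.fields(self))

    # Keep the original name so repr() and error messages look the same.
    MappableDataclass.__name__ = cls.__name__
    MappableDataclass.__qualname__ = cls.__qualname__
    return MappableDataclass


@mappable
@dataclasses.dataclass
class Parameters:
    x: float
    y: float


params = Parameters(x=1.0, y=2.0)
```

The dataclass comes first in the MRO, so its generated __init__, __repr__ and __eq__ take precedence over the Mapping mixin methods.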

Allow for nested chex.chexify

Hello, I have a dilemma with chexify - consider the following code:

import chex
import jax
import jax.numpy as jnp
import pytest

# If this is not commented out, the second test will fail
# If this is commented out, the first test will fail
@chex.chexify
@jax.jit
def log_safe(x: jax.Array) -> jax.Array:
    chex.assert_trees_all_equal(x > 0, jnp.ones_like(x, dtype=bool))
    return jnp.log(x)

@chex.chexify
@jax.jit
def combo_safe(x: jax.Array) -> jax.Array:
    chex.assert_trees_all_equal(x != 1, jnp.ones_like(x, dtype=bool))
    return log_safe(x) / (x - 1)


def test_log_safe() -> None:
    x = jnp.array([1.0, 2.0, 3.0, -1.0])
    with pytest.raises(Exception):
        log_safe(x)
        log_safe.wait_checks()

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    assert jnp.array_equal(log_safe(x), jnp.log(x))
    log_safe.wait_checks()

def test_combo_safe() -> None:
    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    with pytest.raises(Exception):
        combo_safe(x)
        combo_safe.wait_checks()

    x = jnp.array([2.0, 3.0, 4.0, 5.0])
    assert jnp.array_equal(combo_safe(x), jnp.log(x) / (x - 1))
    combo_safe.wait_checks()

If I comment out the first chexify the test_log_safe test will fail with RuntimeError: Value assertions can only be called from functions wrapped with @chex.chexify. See the docs. which makes sense to me. However, once I add the decorator back in, the second test fails with RuntimeError: Nested @chexify wrapping is disallowed. Make sure that you only wrap the function at the outermost level.

A hack for this simple scenario would be to make two versions of the function: a log_safe without the chexify decorator, and log_safe_test = chex.chexify(log_safe), and to call only log_safe_test in my tests. However, that solution is pretty clumsy, especially if I have many such scenarios: in a codebase that is fully end-to-end JAX, every function except the outermost one would need this hack. Would it be possible to allow nested chex.chexify, where inner applications of the decorator either do nothing or just raise a warning?

chex.disable_asserts() is ignored by chex.assert_max_traces

The chex.assert_max_traces decorator still raises an AssertionError when chex is configured to disable assertions with chex.disable_asserts().

Code to reproduce:

import jax
import jax.numpy as jnp
import chex
chex.disable_asserts()

@jax.jit
@chex.assert_max_traces(n=1)
def f(x):
    return x

chex.assert_equal_shape(jnp.zeros((1)), jnp.zeros((2,))) # correctly ignored

f(jnp.zeros((1,)))
f(jnp.zeros((2,))) # AssertionError: [Chex] Function 'f' is traced > 1 times!

Support wrapping functools.partial() objects

Example:

chex.chexify(functools.partial(fn, foo='bar'))

Error:

AttributeError: 'functools.partial' object has no attribute '__name__'
WARNING:absl:[Chex] Some of chexify assetion statuses were not inspected due to async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html). Consider calling `chex.block_until_chexify_assertions_complete()` at the end of computations that rely on jitted chex assetions.
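The error indicates that chexify reads the wrapped callable's __name__, which functools.partial objects lack. A hedged workaround, assuming __name__ is the only missing attribute chexify needs, is to copy the underlying function's metadata onto the partial with functools.update_wrapper. named_partial is a hypothetical helper.

```python
import functools


def named_partial(fn, /, *args, **kwargs):
    """Hypothetical helper: a functools.partial that also carries fn's __name__, __doc__, etc."""
    p = functools.partial(fn, *args, **kwargs)
    # partial objects accept attribute assignment, so update_wrapper can copy
    # __module__, __name__, __qualname__ and __doc__, and set __wrapped__.
    return functools.update_wrapper(p, fn)
```

With this, the call from the issue would become chex.chexify(named_partial(fn, foo='bar')).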

Typing issue with chex.dataclass

Static type checkers like pyright and mypy infer that a chex.dataclass-decorated class has a constructor with no parameters.

Example:

@chex.dataclass(frozen=True)
class Foo:
    a: int
    b: int

However, Foo() is not a valid call: a legitimate call would be something like Foo(a=1, b=2), which contradicts the static type checker's analysis.

Compare this with the built-in dataclass, for which type checkers infer a constructor that takes a and b.
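PEP 681's dataclass_transform marker is the standard way to tell static type checkers that a decorator synthesizes __init__ from the annotated fields. The sketch applies it to a hypothetical my_dataclass decorator that simply delegates to dataclasses.dataclass; it assumes typing_extensions is installed on Python versions before 3.11, and it is not chex's implementation.

```python
import dataclasses

try:
    from typing import dataclass_transform  # Python 3.11+
except ImportError:
    from typing_extensions import dataclass_transform  # backport, assumed installed


@dataclass_transform()
def my_dataclass(cls=None, /, *, frozen: bool = False):
    """Hypothetical custom dataclass decorator, usable with or without arguments.

    dataclass_transform lets type checkers treat decorated classes like
    standard dataclasses, so Foo(a=1, b=2) type-checks and Foo() is flagged.
    """

    def wrap(c):
        return dataclasses.dataclass(c, frozen=frozen)

    return wrap if cls is None else wrap(cls)


@my_dataclass(frozen=True)
class Foo:
    a: int
    b: int
```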

Better error report for max traces exceeded

In my experience, when chex reports that the maximum number of traces was exceeded, it's usually because I passed arguments with different shapes or dtypes to the function. Could chex report such inconsistencies?

e.g.,

AssertionError: [Chex] Function '_wrapper' is traced > 1 times!
Difference in input shapes. Last time variable `x` traced with shape "(10, 1)", this time traced with shape "(9, 1)".
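A minimal sketch of the requested report, using a hypothetical max_traces_with_diff decorator (not chex's assert_max_traces): it records each argument's (shape, dtype) on every call and, once the limit is exceeded, lists the arguments whose spec differs from the previous call.

```python
import functools
import inspect


def _spec(value):
    """(shape, dtype) summary of one argument; non-array values fall back to their type name."""
    return (getattr(value, "shape", None), str(getattr(value, "dtype", type(value).__name__)))


def max_traces_with_diff(n):
    """Hypothetical decorator: fail after n calls and report which argument specs changed."""

    def decorator(fn):
        sig = inspect.signature(fn)
        history = []  # one {arg_name: (shape, dtype)} dict per recorded call

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            specs = {name: _spec(v) for name, v in bound.arguments.items()}
            if len(history) >= n:
                last = history[-1]
                diffs = [
                    f"`{name}`: last traced with {last.get(name)}, this time with {spec}"
                    for name, spec in specs.items()
                    if last.get(name) != spec
                ]
                raise AssertionError(
                    f"Function '{fn.__name__}' is traced > {n} times!\n" + "\n".join(diffs))
            history.append(specs)
            return fn(*args, **kwargs)

        return wrapper

    return decorator
```

To use it with JAX, place @jax.jit above @max_traces_with_diff(n=1): the wrapper then runs only while JAX traces the function, so the recorded shapes and dtypes are those of the tracers.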

chex.Dimensions API enhancement

I would like to propose an API enhancement that allows chex.Dimensions to be used inside function annotations. If there is interest, I'd like to contribute. Example below:

dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
def foo(arr: chex.Array):
     chex.assert_shape(arr, dims['BTE'])
     # fn logic

### turns into ###

def foo(arr: chex.Array(dims['BTE'])): # behind the scenes assert on function call
     # fn logic

This is particularly useful for dataclasses e.g.

dims = chex.Dimensions(B=batch_size, T=rollout_len)

# asserts are run on instantiation
class TimeStep:
     q_values: chex.Array(dims['BT']) 
     discounts: chex.Array(dims['BT']) 
     rewards: chex.Array(dims['BT']) 

Pros:

  • reduces clutter that asserts can add
  • allows users to see the shape expected by a function or class in the editor (e.g. in the VS Code hover popup)
    • example: using RLax, in order to know what shape is expected for each arg in a loss fn you need to either look at source code or wait for fn call to raise an assert

Cons:

  • increased API complexity
  • ...?
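Something close to this proposal can be sketched today with typing.Annotated, which attaches metadata to a type hint without changing the type. Shape and check_shapes below are hypothetical, not chex API; the sizes (B, T, E) stand in for what chex.Dimensions(...)['BTE'] would return, and the check assumes every annotated argument exposes a .shape attribute.

```python
import functools
import inspect
import typing
from typing import Annotated, Any


class Shape(tuple):
    """Hypothetical marker carrying an expected array shape inside Annotated[...]."""


def check_shapes(fn):
    """Hypothetical decorator: assert shapes declared via Annotated[..., Shape(...)] on each call."""
    hints = typing.get_type_hints(fn, include_extras=True)
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        for name, value in sig.bind(*args, **kwargs).arguments.items():
            for meta in getattr(hints.get(name), "__metadata__", ()):
                if isinstance(meta, Shape) and tuple(value.shape) != tuple(meta):
                    raise AssertionError(
                        f"{name}: expected shape {tuple(meta)}, got {tuple(value.shape)}")
        return fn(*args, **kwargs)

    return wrapper


B, T, E = 2, 3, 4  # hypothetical sizes standing in for dims['BTE']


@check_shapes
def foo(arr: Annotated[Any, Shape((B, T, E))]):
    return arr
```

Because the shape lives in the annotation, editors that display signatures would show it, which is the main benefit listed in the proposal.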

post_init error in inherited dataclass

When one dataclass inherits from another, Chex's dataclass fails when the subclass's __post_init__ calls super().__post_init__(). This works with Python's standard-library dataclasses module.

A minimum working example is

from chex import dataclass as dataclass

@dataclass
class ChexBase:
    a : int 

    def __post_init__(self):
        self.b = self.a + 1

@dataclass
class ChexSub(ChexBase):
    a: int 

    def __post_init__(self):
        super().__post_init__()
        self.c = self.a + 2

temp = ChexSub(a = 1)
temp.b

Importing dataclass from dataclasses runs without error and returns 2, as expected.

Environment

  • Chex version 0.1.5
  • Ubuntu 20.04
  • Python 3.9

conda package dependencies

Apologies if this is the wrong place to post this.

The problem is that the dependencies of the chex package on conda-forge are incorrect: there is a typo in the required jax version.

conda search chex==0.1.7 --channel conda-forge --info

returns

chex 0.1.7 pyhd8ed1ab_0
-----------------------
file name   : chex-0.1.7-pyhd8ed1ab_0.conda
name        : chex
version     : 0.1.7
build       : pyhd8ed1ab_0
build number: 0
size        : 70 KB
license     : Apache-2.0
subdir      : noarch
url         : https://conda.anaconda.org/conda-forge/noarch/chex-0.1.7-pyhd8ed1ab_0.conda
md5         : 7d643a09cac375aab18872f92db3b78c
timestamp   : 2023-03-27 14:01:47 UTC
dependencies: 
  - absl-py >=0.9.0
  - dm-tree >=0.1.5
  - jax >=0.1.55
  - jaxlib >=0.1.37
  - numpy >=1.18.0
  - python >=3.6
  - toolz >=0.9.0
  - typing_extensions >=4.2.0

But the jax version should be >=0.4.6

Next release

chex v0.1.83 fails with the latest jax (0.4.19) because of its use of jax.core.Shape.
This issue has been fixed on the master branch of chex in this commit.

Could you please make a release so as to ship this fix ?

Thank you very much !

Tuple not recognized as a valid chex.ArrayTree

import chex


def print_pytree(pytree: chex.ArrayTree):
  print(pytree)


def main(argv):
  # does not work
  # tree_x = (-1.0, {"k1": 1.0, "k2": (1.0, 1.0)}, 1.0)

  # works
  tree_x = [-1.0, {"k1": 1.0, "k2": (1.0, 1.0)}, 1.0]
  print_pytree(tree_x)

`jax.random.key` tree comparison results in a `ZeroDivision` warning

See below:

import jax
import chex

chex.assert_trees_all_equal(jax.random.key(0), jax.random.key(0))
> RuntimeWarning: divide by zero encountered in equal val = comparison(x, y)

# Runs fine
chex.assert_trees_all_equal(jax.random.PRNGKey(0), jax.random.PRNGKey(0))

The problem is not jax per se, since key comparison works:

jax.random.key(0) == jax.random.key(1)
> Array(False, dtype=bool)

jax.random.key(0) == jax.random.key(0)
> Array(True, dtype=bool)
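Until the tree comparison handles typed keys, a hedged workaround is to compare the raw key data instead of the typed key arrays. The sketch assumes a JAX version that provides jax.random.key and jax.random.key_data; keys_equal is a hypothetical helper. The same idea can presumably be applied with chex by passing jax.random.key_data(...) of each key to chex.assert_trees_all_equal.

```python
import jax
import jax.numpy as jnp


def keys_equal(k1, k2) -> bool:
    """Hypothetical helper: compare typed PRNG keys through their underlying uint32 key data."""
    return bool(jnp.array_equal(jax.random.key_data(k1), jax.random.key_data(k2)))
```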

Mypy index type error with `chex.dataclass`

According to the docs, by default a class wrapped with chex.dataclass can be indexed, because the dataclass becomes compatible with collections.abc.Mapping (because mappable_dataclass=True).
However, mypy doesn't seem to understand this. For example:

import chex

@chex.dataclass
class Container:
    foo: float

c = Container(foo=1.)
d = c.foo  # OK.
e = c['foo']  # error: Value of type "Container" is not indexable  [index]

Looking at the code, it seems that this is related to methods such as __getitem__ that are added dynamically with setattr which mypy doesn't recognise.
Any ideas how to work around this, apart from (i) explicitly silencing the error or (ii) using a different method of accessing the variables?

Keep up the good work!
Hylke
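One hedged way to make mypy aware of a method that is only added at runtime via setattr is to declare it inside an `if TYPE_CHECKING:` block, which type checkers read but the interpreter skips. The sketch uses a standard dataclass and a hypothetical add_getitem decorator to mimic a setattr-based wrapper like chex's; that the same block behaves identically inside a chex.dataclass is an assumption.

```python
import dataclasses
from typing import TYPE_CHECKING, Any


def add_getitem(cls):
    """Hypothetical stand-in for a wrapper that adds __getitem__ at runtime via setattr."""
    setattr(cls, "__getitem__", lambda self, key: getattr(self, key))
    return cls


@add_getitem
@dataclasses.dataclass
class Container:
    foo: float

    if TYPE_CHECKING:
        # Only static type checkers see this declaration; at runtime the
        # method comes from add_getitem, and mypy now accepts c['foo'].
        def __getitem__(self, key: str) -> Any:
            return getattr(self, key)


c = Container(foo=1.0)
e = c["foo"]
```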

AttributeError: module 'jax' has no attribute '_src'

Trying to import optax fails with AttributeError: module 'jax' has no attribute '_src' for jax versions > 0.3.17.

optax version == 0.1.3
chex version == 0.1.3

In [1]: import optax
/home/penn/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/pytypes.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.
  PyTreeDef = type(jax.tree_structure(None))
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Input In [1], in <cell line: 1>()
----> 1 import optax

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/__init__.py:17, in <module>
      1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import experimental
     18 from optax._src.alias import adabelief
     19 from optax._src.alias import adafactor

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/experimental/__init__.py:20, in <module>
      1 # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Experimental features in Optax.
     16 
     17 Features may be removed or modified at any time.
     18 """
---> 20 from optax._src.experimental.complex_valued import split_real_and_imaginary
     21 from optax._src.experimental.complex_valued import SplitRealAndImaginaryState

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/_src/experimental/complex_valued.py:32, in <module>
     15 """Complex-valued optimization.
     16 
     17 When using `split_real_and_imaginary` to wrap an optimizer, we split the complex
   (...)
     27 See details at https://github.com/deepmind/optax/issues/196
     28 """
     30 from typing import NamedTuple, Union
---> 32 import chex
     33 import jax
     34 import jax.numpy as jnp

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/__init__.py:17, in <module>
      1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Chex: Testing made fun, in JAX!"""
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_comparator
     19 from chex._src.asserts import assert_axis_dimension_gt

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/asserts.py:26, in <module>
     23 import unittest
     24 from unittest import mock
---> 26 from chex._src import asserts_internal as _ai
     27 from chex._src import pytypes
     28 import jax

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/asserts_internal.py:32, in <module>
     29 from typing import Any, Sequence, Union, Callable, Optional, Set, Tuple, Type
     31 from absl import logging
---> 32 from chex._src import pytypes
     33 import jax
     34 import jax.numpy as jnp

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/pytypes.py:44, in <module>
     40 Device = jax.lib.xla_extension.Device
     42 ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
---> 44 ArrayDType = jax._src.numpy.lax_numpy._ScalarMeta

AttributeError: module 'jax' has no attribute '_src'

No attribute 'KeyArray' when importing chex

Hi,
When I try to import chex, I got the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chuaraym/cvpr/lib/python3.8/site-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "/home/chuaraym/cvpr/lib/python3.8/site-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File "/home/chuaraym/cvpr/lib/python3.8/site-packages/chex/_src/asserts_internal.py", line 32, in <module>
    from chex._src import pytypes
  File "/home/chuaraym/cvpr/lib/python3.8/site-packages/chex/_src/pytypes.py", line 36, in <module>
    PRNGKey = jax.random.KeyArray
AttributeError: module 'jax.random' has no attribute 'KeyArray'

I am using jax version 0.2.14 and jaxlib 0.1.68. Did they cause the error?

without_jit=True for already jitted functions

In most JAX-based implementations, jit is almost always included. Basically, if there is no reason not to use it, people will try to take advantage of its speedup.

I noticed that @chex.variants(with_jit=True, without_jit=True) is a great way to assert the same behavior for both execution paths, as long as the variant is derived from a non-jitted function.

In the following example, I would expect to see "Tracing fn" four times in total: three times for the non-jitted variants and once for the initial jit compilation. In reality, test_variant_pre_jitted() is executed twice with the jitted fn, resulting in only two tracer outputs.

@chex.variants(with_jit=True, without_jit=True)
def test_variant_pre_jitted(self):
  @jit
  def fn(x, y):
    print("Tracing fn")
    return x + y

  var_fn = self.variant(fn)
  self.assertEqual(var_fn(1, 2), 3)
  self.assertEqual(var_fn(3, 4), 7)
  self.assertEqual(var_fn(5, 6), 11)

Of course, omitting @jit will lead to the expected behavior. However, when more complex implementations already make use of jit, variants do not make sense anymore, sadly.

My case is the latter and I only see the option of implementing a model-wide use_jit flag so that I can derive variants from non-jitted code. However, this makes the whole idea of variants rather obsolete altogether.

I'm aware this could well be a limitation of JAX and jit itself rather than chex. In that case, I think an error when jitted code is passed to variant() would make this more transparent.

fake_pmap_and_jit has a confusing interface

I spent quite some time figuring out why code in a large codebase was so slow, only to find out that jit was disabled throughout the entire project. This was because the main function was called as follows:

with chex.fake_pmap_and_jit(FLAGS.debug):
  main()

While at first sight it appears that this disables both pmap and jit when the debug flag is set, it in fact only makes pmap patching depend on the flag, and always disables jit!

The reason is that fake_pmap_and_jit takes two positional arguments, which control pmap and jit patching respectively, and both default to True. The argument names are somewhat cryptic to me as well: enable_pmap_patching and enable_jit_patching, which in effect disable these JAX transformations.

Given these observations, I think the situation would improve if the signature would be:

def fake_pmap_and_jit(*, disable_pmap: bool = True, disable_jit: bool = True)

My code above would then look like this:

with chex.fake_pmap_and_jit(disable_pmap=FLAGS.debug):
  main()

Which shows clearly we are not setting disable_jit, so we would rewrite this to:

with chex.fake_pmap_and_jit(disable_pmap=FLAGS.debug, disable_jit=FLAGS.debug):
  main()

DeprecationWarning for importing toolz

  /home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/toolz/compatibility.py:2: DeprecationWarning: The toolz.compatibility module is no longer needed in Python 3 and has been deprecated. Please import these utilities directly from the standard library. This module will be removed in a future release.
    warnings.warn("The toolz.compatibility module is no longer "

[REQ] Conda recipe

Hi,
I'm the lead developer of NetKet, an established machine learning / quantum physics package.

We have recently finished rewriting our core to be based on Jax (and flax), and recently released a beta version.
Since many physicists seem to use anaconda, we would also like to update our conda recipe.
However, since we depend on optax (and therefore on Chex), we would need Chex to have a Conda recipe.

Is that something you'd consider? I am willing to volunteer some work to help you.

I tried creating a recipe starting from your pypi source distribution, but that is problematic because you don't bundle your requirements.txt file, which is required to run setup.py.
I could create a recipe from the tag tarballs on GitHub, but that sometimes prevents the conda packages from auto-updating the recipe for later releases.

Dill pickling chex.dataclass blows the stack

Chex dataclasses can be pickled with pickle, but not with dill:

import dill
import pickle

import chex


@chex.dataclass
class Point:
  x: float
  y: float


# Works fine.
pickle.dumps(Point(x=1.0, y=2.0))

# Generates `RecursionError: maximum recursion depth exceeded`
dill.dumps(Point(x=1.0, y=2.0))

Missing package dependency typing-extensions

Hi,

The latest version (chex==0.1.7) is missing the typing-extensions dependency.
Problem
For example, the following python code

from chex import dataclass

raises the error

  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.11/dist-packages/chex/__init__.py", line 72, in <module>
    from chex._src.dataclass import dataclass
  File "/usr/local/lib/python3.11/dist-packages/chex/_src/dataclass.py", line 23, in <module>
    from typing_extensions import dataclass_transform  # pytype: disable=not-supported-yet
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'typing_extensions'

on ubuntu 23.04 (with Python 3.11)

Solution
I found that installing typing-extensions resolves the error:

pip3 install typing-extensions

Reproduce
The following code reproduces the error:

docker run -it ubuntu:23.04 \
  bash -c \
  "apt update && apt install --assume-yes python3-pip && pip3 install --break-system-packages chex && python3 -c 'from chex import dataclass'"

Breaking for jax 0.4.24

Can you update chex to work with the newest version of jax?

jax 0.4.24 removed the following from jax.random: PRNGKeyArray, KeyArray,
default_prng_impl, threefry_2x32, threefry2x32_key, threefry2x32_p, rbg_key, and unsafe_rbg_key.

This completely breaks importing chex and makes many packages that depend on it unusable.

Error with Pydantic

Hello!
I'm interested in using pydantic's recursive constructor / asdict functionality, but jax.jit-ed functions give the following error:

Argument '_Pydantic_OptimConfig_93971134241088(.. SOMETHING HERE...)' of type <class 'pydantic.dataclasses._Pydantic_OptimConfig_93971134241088'> is not a valid JAX type.
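jax.jit only accepts arguments that are arrays or registered pytrees, which is what this error reports. A hedged sketch of the usual remedy is to register the config class as a pytree with jax.tree_util.register_pytree_node. A standard dataclass named OptimConfig (with hypothetical fields) stands in for the pydantic class here; whether registration composes cleanly with pydantic's generated dataclasses is an assumption, since unflattening calls the constructor with traced values that validation may reject.

```python
import dataclasses

import jax


@dataclasses.dataclass
class OptimConfig:  # hypothetical stand-in for the pydantic dataclass
    learning_rate: float
    momentum: float


def _flatten(cfg):
    # Children are the leaves JAX traces; no static auxiliary data is needed.
    return (cfg.learning_rate, cfg.momentum), None


def _unflatten(aux, children):
    return OptimConfig(*children)


jax.tree_util.register_pytree_node(OptimConfig, _flatten, _unflatten)


@jax.jit
def scaled_lr(cfg: OptimConfig):
    return cfg.learning_rate * cfg.momentum
```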
