
jaxutils's Introduction


This project has now been incorporated into GPJax.


JaxUtils provides utility functions for the JaxGaussianProcesses ecosystem.

Contents

PyTree

Overview

jaxutils.PyTree is a mixin class for registering a Python class as a JAX PyTree. You would define your class as follows.

class MyClass(jaxutils.PyTree):
    ...

Example

import jaxutils

from jaxtyping import Float, Array

class Line(jaxutils.PyTree):
    def __init__(self, gradient: Float[Array, "1"], intercept: Float[Array, "1"]) -> None:
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x: Float[Array, "N"]) -> Float[Array, "N"]:
        return x * self.gradient + self.intercept
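
Because Line is registered as a PyTree, instances can be passed directly through JAX transformations. A small usage sketch, assuming the Line class defined above:

import jax
import jax.numpy as jnp

line = Line(gradient=jnp.array([2.0]), intercept=jnp.array([1.0]))

# The gradient and intercept arrays are visible to JAX's tree utilities...
print(jax.tree_util.tree_leaves(line))

# ...so instances can be passed through jit, grad, vmap, etc.
y = jax.jit(lambda l, x: l.y(x))(line, jnp.array([0.0, 1.0, 2.0]))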

Dataset

Overview

jaxutils.Dataset is a dataset abstraction. In future, we wish to extend this to heterotopic and isotopic data abstractions.

Example

import jaxutils
import jax.numpy as jnp

# Inputs
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Outputs
y = jnp.array([[7.0], [8.0], [9.0]])

# Dataset
D = jaxutils.Dataset(X=X, y=y)

print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
The number of datapoints is 3
The input dimension is 2
The output dimension is 1
The input data is [[1. 2.]
 [3. 4.]
 [5. 6.]]
The output data is [[7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False

You can also add datasets together to concatenate them.

# New inputs
X_new = jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]])

# New outputs
y_new = jnp.array([[7.0], [8.0], [9.0]])

# New dataset
D_new = jaxutils.Dataset(X=X_new, y=y_new)

# Concatenate the two datasets
D = D + D_new

print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
The number of datapoints is 6
The input dimension is 2
The output dimension is 1
The input data is [[1.  2. ]
 [3.  4. ]
 [5.  6. ]
 [1.5 2.5]
 [3.5 4.5]
 [5.5 6.5]]
The output data is [[7.]
 [8.]
 [9.]
 [7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False

jaxutils's People

Contributors

daniel-dodd, fonnesbeck, thomaspinder


jaxutils's Issues

dev: Train test split for a dataset

It would be nice to have a train/test split for the Dataset, akin to scikit-learn's train_test_split.

from typing import Tuple

import jax.random as jr
from jax.random import KeyArray
from jaxutils import Dataset

# Need to define this function
def train_test_split(data: Dataset, key: KeyArray, test_size: float, ...) -> Tuple[Dataset, Dataset]:
    ...

# Example usage:
data = Dataset(...)
key = jr.PRNGKey(42)
size = 0.3
train, test = train_test_split(data, key, size)
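
A minimal implementation sketch, assuming the Dataset attributes (X, y, n) shown in the examples above; the permutation-based split is illustrative and not part of the current API:

from typing import Tuple

import jax.random as jr
from jax.random import KeyArray
from jaxutils import Dataset

def train_test_split(data: Dataset, key: KeyArray, test_size: float) -> Tuple[Dataset, Dataset]:
    # Shuffle the indices, then split them according to the requested test fraction.
    indices = jr.permutation(key, data.n)
    n_test = int(round(data.n * test_size))
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    train = Dataset(X=data.X[train_idx], y=data.y[train_idx])
    test = Dataset(X=data.X[test_idx], y=data.y[test_idx])
    return train, test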

feat: Use PyTreeNode in place of PyTree

Feature Request

Right now, a custom PyTree object is used. This should be replaced with Flax's PyTreeNode. This would then be consumed by the Parameters object to enable more robust parameter handling.
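
For illustration, here is how the Line class from above might look with Flax's flax.struct.PyTreeNode; this is a sketch of the proposed swap, not existing jaxutils code:

import jax.numpy as jnp
from flax import struct

class Line(struct.PyTreeNode):
    # Fields are declared as dataclass attributes; instances are immutable
    # and automatically registered as JAX pytrees.
    gradient: jnp.ndarray
    intercept: jnp.ndarray

    def y(self, x: jnp.ndarray) -> jnp.ndarray:
        return x * self.gradient + self.intercept

line = Line(gradient=jnp.array([2.0]), intercept=jnp.array([1.0]))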

feat: Scaler for the `Dataset`.

It would be nice to have a Scaler object that scales the inputs and/or outputs of a jaxutils.Dataset, and that saves the mean and variance so that test inputs can be scaled later.

from jaxutils import PyTree

class Scaler(PyTree):
  ...

# call method scales data and "fits the scale transform"

train = jaxutils.Dataset(X=..., y=...)
test = jaxutils.Dataset(X=..., y=...)

scaler = Scaler(...)
scaled_train = scaler(train) # learn the transform and scale the training data
scaled_test = scaler(test) # scale the test data under the transform learnt from the training data
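
A minimal sketch of one possible Scaler that standardises the inputs only and stores the training mean and standard deviation; the attribute names and fitting behaviour are assumptions for illustration:

import jax.numpy as jnp
from jaxutils import Dataset, PyTree

class Scaler(PyTree):
    def __init__(self) -> None:
        self.mean = None
        self.std = None

    def __call__(self, data: Dataset) -> Dataset:
        # Fit on the first call; reuse the learnt statistics on subsequent calls.
        if self.mean is None:
            self.mean = jnp.mean(data.X, axis=0)
            self.std = jnp.std(data.X, axis=0)
        return Dataset(X=(data.X - self.mean) / self.std, y=data.y)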

feat: Use Benedict for better dictionary management

Feature Request

When handling complex dictionary structures, Benedict makes tasks such as indexing easier. We should transition all dictionaries to use Benedict.

A clear place where this will be helpful is in selecting and updating individual parameters. In Benedict, the syntax could be as simple as

def update_param_value(self, key: str, value: jax.Array) -> None:
    self.params[key] = value

In Benedict, the key could be kernel.lengthscale whereas with regular nested dictionaries, one would have to write something more complex to index with ['kernel']['lengthscale'].
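
A small illustration of the keypath syntax with python-benedict; the parameter dictionary here is made up for the example:

from benedict import benedict

# A nested parameter dictionary, as might arise for a GP model.
params = benedict({"kernel": {"lengthscale": 0.5, "variance": 1.0}})

# Keypath indexing replaces nested lookups such as params["kernel"]["lengthscale"].
params["kernel.lengthscale"] = 2.0
print(params["kernel.lengthscale"])  # 2.0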

dev: Consider `equinox.Module` (or similar) in place of `jaxutils.PyTree`

I just stumbled across jaxutils and spotted jaxutils.PyTree. I can see that this is based on Distrax's Jittable base class.

I wanted to give a heads-up that Distrax's approach has some performance and compatibility issues, so I'd really recommend against using it.

Equinox has an equinox.Module which accomplishes the same thing (registering a class as a pytree), and also automatically handles a lot of edge cases (e.g. bound methods are pytrees too, multiple inheritance works smoothly, good performance, pretty-printing, etc.). I realise I am advertising my own library here... but hopefully it's of interest!
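
For reference, the Line example from the PyTree section written against equinox.Module; a sketch of the suggested alternative, not existing jaxutils code:

import equinox as eqx
import jax.numpy as jnp

class Line(eqx.Module):
    # Declaring fields as dataclass attributes registers the class as a pytree.
    gradient: jnp.ndarray
    intercept: jnp.ndarray

    def y(self, x: jnp.ndarray) -> jnp.ndarray:
        return x * self.gradient + self.intercept

line = Line(gradient=jnp.array([2.0]), intercept=jnp.array([1.0]))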

feat: Dynamic data-set sizes

Feature Request

For some applications with GPs, like Bayesian optimisation, the dataset grows dynamically with time. Unfortunately, dynamic array sizes with JAX jit-compiled functions cause the computation to be re-compiled for every different buffer size. This means that the computation will take much longer than should be necessary...

In my own code I was able to work around the recompilation with dynamic shapes by using a fixed buffer and modifying the Gaussian process logic through a dynamic mask that treats all data at index i > t as independent of data at index j <= t in the kernel computation. One downside is, of course, that every iteration from t = 1, ..., n will incur a time and memory cost proportional to n. For most applications, however, the speed-up provided by jit makes this completely negligible.

I am not sure whether a solution already exists within gpjax as I'm still relatively new to this cool library :).

Describe Preferred Solution

I believe something like this can be implemented as follows, though I haven't yet tried; a rough sketch follows the list below.

  1. Inherit from gpx.Dataset to create a sub-class gpx.OnlineDataset(gpx.Dataset) with a new integer time_step variable, requiring the exact shape of the data buffer at initialisation.
  2. Add a method to add data to the buffer through jax.ops.
  3. Make a DynamicKernel class that wraps the standard kernel computation, along the lines of K(a, b, a_idx, b_idx, t), returning K(a, b) if a_idx <= b_idx <= t and otherwise int(a_idx == b_idx).
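
A rough, self-contained sketch of the masking idea on a plain kernel matrix, assuming a fixed-size padded buffer; the function names are made up and this is not wired into the gpjax kernel classes:

import jax.numpy as jnp

def rbf(a: jnp.ndarray, b: jnp.ndarray, lengthscale: float = 1.0) -> jnp.ndarray:
    # Standard RBF kernel evaluated on padded buffers of fixed shape [n_buffer, d].
    sq_dists = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dists / lengthscale**2)

def masked_gram(a: jnp.ndarray, b: jnp.ndarray, t: int) -> jnp.ndarray:
    # Entries involving an index beyond the current time step t fall back to the
    # identity pattern, so padded slots act as independent unit-variance points.
    K = rbf(a, b)
    idx_a = jnp.arange(a.shape[0])[:, None]
    idx_b = jnp.arange(b.shape[0])[None, :]
    active = (idx_a < t) & (idx_b < t)
    identity = (idx_a == idx_b).astype(K.dtype)
    return jnp.where(active, K, identity)

Because the buffer shape stays fixed, jit compiles this once and reuses the compiled computation as t grows.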

Describe Alternatives

NA

Related Code

Example of the jit recompilation based on the Documentation Regression notebook:

import gpjax as gpx
from jax import jit, random
from jax import numpy as jnp


n = 5

x = jnp.linspace(-1, 1, n)[..., None]
y = jnp.sin(x)[..., None]

xtest = jnp.linspace(-2, 2, 100)[..., None]


@jit
def gp_predict(xs, x_train, y_train):
    posterior = gpx.Prior(kernel=gpx.RBF()) * gpx.Gaussian(num_datapoints=len(x_train))
    
    params, *_ = gpx.initialise(
        posterior, random.PRNGKey(0), kernel={"lengthscale": jnp.array([0.5])}
    ).unpack()
    
    post_predictive = posterior(params, gpx.Dataset(X=x_train, y=y_train))
    out_dist = post_predictive(xs)
    
    return out_dist.mean(), out_dist.stddev()


# First call - compile
print('compile')
for i in range(len(x)):
    %time gp_predict(xtest, x[:i+1], y[:i+1])
print()


# Second call - use cached
print('jitted')    
for i in range(len(x)):
    %time gp_predict(xtest, x[:i+1], y[:i+1])

# Output
compile
CPU times: user 519 ms, sys: 1.64 ms, total: 521 ms
Wall time: 293 ms
CPU times: user 1.06 s, sys: 0 ns, total: 1.06 s
Wall time: 316 ms
CPU times: user 956 ms, sys: 17.9 ms, total: 974 ms
Wall time: 219 ms

jitted
CPU times: user 3.66 ms, sys: 443 µs, total: 4.11 ms
Wall time: 2.46 ms
CPU times: user 2.89 ms, sys: 348 µs, total: 3.23 ms
Wall time: 1.84 ms
CPU times: user 894 µs, sys: 0 ns, total: 894 µs
Wall time: 568 µs

Additional Context

Example issue on the JAX repository: google/jax#2521

If the feature request is approved, would you be willing to submit a PR?

When I have time available I can try to port my solution to the gpjax API, though I am still quite new to the library.

bug: Check dataset shapes when adding.

The __add__ method that concatenates two jaxutils.Datasets does not check that the two datasets' shapes are compatible. Although this will eventually error inside jnp.concatenate, it would be nice to have a function that checks the shapes and gives a clear error message to users.
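
A sketch of what such a check might look like, assuming the in_dim and out_dim attributes shown in the examples above; the helper name is hypothetical:

from jaxutils import Dataset

def _check_shapes_match(a: Dataset, b: Dataset) -> None:
    # Datasets can only be concatenated if their input and output dimensions agree.
    if a.in_dim != b.in_dim:
        raise ValueError(f"Input dimensions do not match: {a.in_dim} vs {b.in_dim}.")
    if a.out_dim != b.out_dim:
        raise ValueError(f"Output dimensions do not match: {a.out_dim} vs {b.out_dim}.")

This could be called at the top of __add__ before the underlying jnp.concatenate.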
