This project has now been incorporated into GPJax.

My🌳


"Module pytrees" that cleanly handle parameter trainability and transformations for JAX models.

Installation

pip install mytree

Usage

Defining a model

  • Mark leaf attributes with param_field to set a default bijector and trainable status.
  • Unmarked leaf attributes default to an Identity bijector, with trainability set to True.

from mytree import Mytree, param_field, Softplus, Identity

class SimpleModel(Mytree):
    weight: float = param_field(bijector=Softplus, trainable=False)

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias  # Unmarked leaf attribute `bias` has an `Identity` bijector and trainability set to `True`.
    
    def __call__(self, test_point):
        return test_point * self.weight + self.bias

Dataclasses

Works seamlessly with the dataclasses.dataclass decorator!

from dataclasses import dataclass

@dataclass
class SimpleModel(Mytree):
    weight: float = param_field(bijector=Softplus, trainable=False)
    bias: float
    
    def __call__(self, test_point):
        return test_point * self.weight + self.bias
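A plausible reason param_field composes with @dataclass is that it can be a thin wrapper around dataclasses.field that stores its settings in the field metadata; the keys 'bijector', 'trainable' and 'pytree_node' used below match the meta output shown later in this README. The following is a minimal standard-library sketch under that assumption, not mytree's actual code: sketch_param_field, SketchModel, field_settings and the string bijector labels are all hypothetical.

```python
import dataclasses
from dataclasses import MISSING, dataclass, fields


# Hypothetical stand-in for param_field: a thin wrapper over dataclasses.field
# that records the bijector and trainability in the field's metadata.
def sketch_param_field(default=MISSING, *, bijector="Identity", trainable=True):
    return dataclasses.field(
        default=default,
        metadata={"bijector": bijector, "trainable": trainable, "pytree_node": True},
    )


@dataclass
class SketchModel:
    weight: float = sketch_param_field(bijector="Softplus", trainable=False)
    bias: float = 0.0  # unmarked: empty metadata, so the defaults apply


def field_settings(obj):
    """Return {name: (bijector, trainable)}, falling back to Identity / True."""
    return {
        f.name: (f.metadata.get("bijector", "Identity"), f.metadata.get("trainable", True))
        for f in fields(obj)
    }


print(field_settings(SketchModel(weight=1.0, bias=2.0)))
# {'weight': ('Softplus', False), 'bias': ('Identity', True)}
```

Because the settings live in ordinary field metadata, the dataclass machinery itself does not need to know about bijectors at all.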

Replacing values

Update values via replace, which returns a new instance and leaves the original unchanged.

model = SimpleModel(1.0, 2.0)
model.replace(weight=123.0)
SimpleModel(weight=123.0, bias=2.0)
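Conceptually, replace is a functional update: it hands back a modified copy, which is why the examples below still see weight=1.0 on model. A minimal sketch of that pattern with a plain class, assuming a shallow copy followed by an attribute update; the Point class is hypothetical and not part of mytree.

```python
import copy


class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def replace(self, **kwargs):
        # Shallow-copy the instance, then overwrite only the named attributes.
        new = copy.copy(self)
        for name, value in kwargs.items():
            if not hasattr(self, name):
                raise AttributeError(f"{type(self).__name__} has no field {name!r}")
            setattr(new, name, value)
        return new

    def __repr__(self):
        return f"Point(x={self.x}, y={self.y})"


p = Point(1.0, 2.0)
q = p.replace(x=123.0)
print(p)  # Point(x=1.0, y=2.0)  (original unchanged)
print(q)  # Point(x=123.0, y=2.0)
```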

Transformations 🤖

Applying transformations

Use constrain / unconstrain to return a Mytree with each parameter's bijector forward / inverse operation applied!

model.constrain()
model.unconstrain()
SimpleModel(weight=1.3132616, bias=2.0)
SimpleModel(weight=0.5413248, bias=2.0)
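For the Softplus bijector these numbers can be checked by hand: the forward map is softplus(u) = log(1 + exp(u)) and its inverse is log(exp(c) - 1), so constraining weight = 1.0 gives log(1 + e) and unconstraining gives log(e - 1). A plain-Python check, written with math.log1p and math.expm1 (equivalent forms of the same expressions); this is an independent verification, not mytree's implementation.

```python
import math


def softplus(u):
    # Forward bijector: maps any real number to a positive one.
    return math.log1p(math.exp(u))


def inverse_softplus(c):
    # Inverse bijector: maps a positive number back to the real line.
    return math.log(math.expm1(c))


print(softplus(1.0))          # about 1.3132617 (README shows the float32 value 1.3132616)
print(inverse_softplus(1.0))  # about 0.5413249 (README shows the float32 value 0.5413248)
print(inverse_softplus(softplus(0.3)))  # round trip recovers about 0.3
```

The small differences in the last printed digit come from the README values being 32-bit floats, while Python floats are 64-bit.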

Replacing transformations

Default transformations can be replaced on an instance via the replace_bijector method.

new = model.replace_bijector(weight=Identity)
new.constrain()
new.unconstrain()
SimpleModel(weight=1.0, bias=2.0)
SimpleModel(weight=1.0, bias=2.0)

And we see that weight is no longer transformed, since its bijector is now the Identity.
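Swapping a bijector only changes how that one field is mapped by constrain / unconstrain; the stored value itself is untouched. A minimal sketch of that bookkeeping, assuming a per-instance table of (forward, inverse) pairs that is copied on each update; TinyModel, IDENTITY and SOFTPLUS are hypothetical stand-ins, not mytree's classes.

```python
import copy
import math

# Hypothetical bijectors represented as (forward, inverse) pairs.
IDENTITY = (lambda u: u, lambda c: c)
SOFTPLUS = (lambda u: math.log1p(math.exp(u)), lambda c: math.log(math.expm1(c)))


class TinyModel:
    def __init__(self, weight, bias):
        self.weight, self.bias = weight, bias
        # Per-instance bijector table; defaults mirror the README's SimpleModel.
        self._bijectors = {"weight": SOFTPLUS, "bias": IDENTITY}

    def replace_bijector(self, **kwargs):
        # Return a copy with a new table; the original instance keeps its own.
        new = copy.copy(self)
        new._bijectors = {**self._bijectors, **kwargs}
        return new

    def _apply(self, which):
        # which = 0 applies forward maps, which = 1 applies inverse maps.
        new = copy.copy(self)
        for name, bij in self._bijectors.items():
            setattr(new, name, bij[which](getattr(self, name)))
        return new

    def constrain(self):
        return self._apply(0)

    def unconstrain(self):
        return self._apply(1)


model = TinyModel(1.0, 2.0)
new = model.replace_bijector(weight=IDENTITY)
print(model.constrain().weight)  # softplus(1.0), about 1.3133
print(new.constrain().weight)    # 1.0: Identity leaves weight unchanged
```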

Trainability 🚂

Applying trainability

Applying stop_gradient within the loss function prevents gradients from flowing to non-trainable parameters during forward- or reverse-mode automatic differentiation.

import jax

# Create simulated data.
n = 100
key = jax.random.PRNGKey(123)
x = jax.random.uniform(key, (n, ))
y = 3.0 * x + 2.0 + 1e-3 * jax.random.normal(key, (n, ))


# Define a mean-squared-error loss.
def loss(model: SimpleModel) -> float:
    model = model.stop_gradient()  # 🛑 Stop gradients for non-trainable leaves!
    return jax.numpy.sum((y - model(x))**2)

jax.grad(loss)(model)
SimpleModel(weight=0.0, bias=-188.37418)

As weight's trainability was set to False, its gradient is zero, as expected!

Replacing trainability

Default trainability status can be replaced via the replace_trainable method.

new = model.replace_trainable(weight=True)
jax.grad(loss)(new)
SimpleModel(weight=-121.42676, bias=-188.37418)

And we see that weight's gradient is no longer zero.
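To see why a frozen parameter gets exactly zero gradient, here is the same sum-of-squares loss with its gradients written out by hand in plain Python: dL/dw = -2 Σ x (y - w x - b) and dL/db = -2 Σ (y - w x - b). Zeroing the gradient of a non-trainable parameter mirrors the effect stop_gradient has; the helper mse_grads and the noiseless three-point dataset are hypothetical, chosen so the numbers come out exact.

```python
def mse_grads(w, b, xs, ys, trainable):
    """Analytic gradients of sum((y - (w*x + b))**2), zeroed for frozen params.

    `trainable` maps parameter name -> bool, mimicking the effect that
    stop_gradient has on non-trainable leaves.
    """
    residuals = [y - (w * x + b) for x, y in zip(xs, ys)]
    grad_w = -2.0 * sum(x * r for x, r in zip(xs, residuals))
    grad_b = -2.0 * sum(residuals)
    return {
        "weight": grad_w if trainable["weight"] else 0.0,
        "bias": grad_b if trainable["bias"] else 0.0,
    }


xs = [0.0, 0.5, 1.0]
ys = [3.0 * x + 2.0 for x in xs]  # same line as the README's data, without noise

print(mse_grads(1.0, 2.0, xs, ys, {"weight": False, "bias": True}))
# {'weight': 0.0, 'bias': -6.0}
print(mse_grads(1.0, 2.0, xs, ys, {"weight": True, "bias": True}))
# {'weight': -5.0, 'bias': -6.0}
```

Note that the bias gradient is identical in both calls: freezing one parameter does not change the gradient of the others.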

Metadata

Viewing field metadata

View field metadata pytree via meta.

from mytree import meta
meta(model)
SimpleModel(weight=({'bijector': Bijector(forward=<function <lambda> at 0x17a024e50>, inverse=<function <lambda> at 0x17a024430>), 'trainable': False, 'pytree_node': True}, 1.0), bias=({}, 2.0))

Or the metadata pytree leaves via meta_leaves.

from mytree import meta_leaves
meta_leaves(model)
[({}, 2.0),
 ({'bijector': Bijector(forward=<function <lambda> at 0x17a024e50>, inverse=<function <lambda> at 0x17a024430>),
  'trainable': False,
  'pytree_node': True}, 1.0)]

Note that this shows any metadata defined via dataclasses.field for the pytree leaves, so feel free to define your own.
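Because this is ordinary dataclasses.field metadata, a (meta, leaf) list like the one above can be reproduced with the standard library alone. A minimal sketch, assuming plain dataclasses rather than Mytree; sketch_meta_leaves, Annotated and the custom "units" key are hypothetical. Fields are sorted by name here, which happens to match the order in the meta_leaves output above.

```python
from dataclasses import dataclass, field, fields


@dataclass
class Annotated:
    # A custom "units" key sits alongside a "trainable" flag.
    weight: float = field(default=1.0, metadata={"trainable": False, "units": "kg"})
    bias: float = 2.0  # no metadata


def sketch_meta_leaves(obj):
    # Pair each field's metadata (as a plain dict) with its current value.
    return [
        (dict(fld.metadata), getattr(obj, fld.name))
        for fld in sorted(fields(obj), key=lambda fld: fld.name)
    ]


print(sketch_meta_leaves(Annotated()))
# [({}, 2.0), ({'trainable': False, 'units': 'kg'}, 1.0)]
```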

Applying field metadata

Leaf metadata can be applied via the meta_map function.

from mytree import meta_map

# The function passed to `meta_map` takes a single `(meta, leaf)` tuple as its argument!
def if_trainable_then_10(meta_leaf):
    meta, leaf = meta_leaf
    if meta.get("trainable", True):
        return 10.0
    else:
        return leaf

meta_map(if_trainable_then_10, model)
SimpleModel(weight=1.0, bias=10.0)

It is possible to define your own custom metadata and therefore your own metadata transformations in this vein.
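As an illustration of a custom metadata transformation, the sketch below rescales any field that carries a "scale" entry and leaves the rest alone. It assumes plain dataclasses and rebuilds the object with dataclasses.replace; sketch_meta_map, Measured, apply_scale and the "scale" key are hypothetical and not part of mytree.

```python
from dataclasses import dataclass, field, fields, replace


@dataclass
class Measured:
    # Hypothetical custom metadata key "scale".
    length: float = field(default=2.0, metadata={"scale": 100.0})
    mass: float = 3.0


def sketch_meta_map(fn, obj):
    # Apply fn to each (meta, leaf) pair, then rebuild a new dataclass instance.
    updates = {f.name: fn((dict(f.metadata), getattr(obj, f.name))) for f in fields(obj)}
    return replace(obj, **updates)


def apply_scale(meta_leaf):
    meta, leaf = meta_leaf
    return leaf * meta.get("scale", 1.0)


print(sketch_meta_map(apply_scale, Measured()))
# Measured(length=200.0, mass=3.0)
```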

Static fields

Since Mytree inherits from simple-pytree's Pytree, fields can be marked as static via simple_pytree's static_field.

import jax.tree_util as jtu
from simple_pytree import static_field

class StaticExample(Mytree):
    b: float = static_field()  # `b` is excluded from the pytree leaves.

    def __init__(self, a=1.0, b=2.0):
        self.a = a
        self.b = b

jtu.tree_leaves(StaticExample())
[1.0]
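A static field is kept out of the flattened leaves, so JAX transformations treat it as fixed structure rather than data. One way to picture the filtering, assuming the marker is a pytree_node=False flag in the field metadata (the param_field metadata shown earlier carries pytree_node: True); sketch_static_field, StaticSketch and sketch_tree_leaves are hypothetical and only mimic the behaviour shown above.

```python
from dataclasses import dataclass, field, fields


def sketch_static_field(**kwargs):
    # Hypothetical stand-in: flag a field as non-pytree via its metadata.
    return field(metadata={"pytree_node": False}, **kwargs)


@dataclass
class StaticSketch:
    a: float = 1.0
    b: float = sketch_static_field(default=2.0)


def sketch_tree_leaves(obj):
    # Only fields not flagged pytree_node=False are reported as leaves.
    return [getattr(obj, f.name) for f in fields(obj) if f.metadata.get("pytree_node", True)]


print(sketch_tree_leaves(StaticSketch()))  # [1.0]
```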

Performance ๐ŸŽ

Preliminary benchmarks can be found in: https://github.com/Daniel-Dodd/mytree/blob/master/benchmarks/benchmarks.ipynb
