
Comments (5)

ASEM000 commented on August 18, 2024

TL;DR: as of version 0.8.0, use:

import pytreeclass as pytc
import jax


@pytc.autoinit
class Tree(pytc.TreeClass):
    frozen_a: int = pytc.field(on_getattr=[pytc.unfreeze], on_setattr=[pytc.freeze])

    def __call__(self, x):
        return self.frozen_a + x


tree = Tree(frozen_a=1)  # 1 is non-jaxtype
# can be used in jax transformations


@jax.jit
def f(tree, x):
    return tree(x)


print(f(tree, 1.0))  # 2.0
print(jax.grad(f)(tree, 1.0))  # Tree(frozen_a=#1)
print(jax.tree_util.tree_leaves(tree))  # []

More details about the freezing/unfreezing mechanism:

If you prefer manual masking, you can apply pytc.freeze to the value directly. You then have to pass is_leaf=pytc.is_frozen if you want to interact with this value using tree_map.

With this style, the end user only has to unmask before calling, while the masked values remain accessible through is_leaf=pytc.is_frozen.

You can do something like this:

Style 1: with no init body. Here callbacks is a list of functions applied to in_features before it is set on the instance.

import pytreeclass as pytc


class Tree(pytc.TreeClass):
    in_features: int = pytc.field(callbacks=[pytc.freeze])

Style 2: with init body

class Tree(pytc.TreeClass):
    def __init__(self, in_features: int):
        # Some logic using in_features
        # ...

        # Lastly you freeze it
        self.in_features = pytc.freeze(in_features)

    def __call__(self, x: float):
        return x * self.in_features

t1 = Tree(2)

@jax.value_and_grad
def jax_func(tree: Tree):
    tree = jax.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)
    return tree(1.0)

jax_func(t1)
# (2.0, Tree(in_features=#2)) # ->`#` is frozen marker

For background, an earlier version of pytreeclass had static field-like behaviour, but this has three problems:

1. Even if these fields are constants, marking them with static_field means you lose the ability to filter your models based on that always-non-trainable field using jax.tree_map.
2. .at uses jax.tree_map under the hood. If I let the user designate a permanently static field, the design becomes asymmetric. For example, if you mark a as a static field of model nn, then nn.a will work while nn.at['a'].get() will not work at all.
3. static_field leads to repetitive code, because you have to declare each field twice: once as a field and once inside the init. Something like this (from the equinox conv code):

class Conv(Module):
    """General N-dimensional convolution."""

    num_spatial_dims: int = static_field()
    weight: Array
    bias: Optional[Array]
    in_channels: int = static_field()
    out_channels: int = static_field()
    kernel_size: Tuple[int, ...] = static_field()
    stride: Tuple[int, ...] = static_field()
    padding: Tuple[Tuple[int, int], ...] = static_field()
    dilation: Tuple[int, ...] = static_field()
    groups: int = static_field()
    use_bias: bool = static_field()

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[int, Sequence[int], Sequence[Tuple[int, int]]] = 0,
        dilation: Union[int, Sequence[int]] = 1,
        groups: int = 1,
        use_bias: bool = True,
        *,
        key: PRNGKey,
        **kwargs,
    ):

This gets worse as you write more and more code.
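Problems 1 and 2 above come down to where a static field is stored: it lives in the tree definition rather than among the leaves, so tree_map never visits it. Here is a minimal sketch of that in plain JAX. The hand-registered `Linear` class is hypothetical (it is neither equinox's nor an earlier pytreeclass implementation), and `jax.tree_util.tree_map` is the long-form name of the `jax.tree_map` used in this thread.

```python
import jax.numpy as jnp
from jax import tree_util as jtu


@jtu.register_pytree_node_class
class Linear:
    """Hypothetical layer whose `in_features` is a permanently static field."""

    def __init__(self, weight, in_features):
        self.weight = weight
        self.in_features = in_features

    def tree_flatten(self):
        # Children (the leaves) come first, then the static auxiliary data.
        return (self.weight,), self.in_features

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


layer = Linear(jnp.ones(3), in_features=3)

# Only the weight is a leaf; the static field is hidden in the treedef.
leaves = jtu.tree_leaves(layer)

# tree_map therefore cannot read, filter on, or modify `in_features`.
zeroed = jtu.tree_map(lambda x: x * 0, layer)
```

In this sketch no `is_leaf` predicate can bring `in_features` back into view, because the value is never presented to tree_map as a node at all.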

Lastly, pytc.freeze just wraps a value in a pytree whose flattening rule yields no leaves. So you can use pytc.freeze on any pytree (no special treatment inside a TreeClass).
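The "no leaves" flattening rule can be sketched with a hand-registered wrapper. The `Frozen` class below is hypothetical and only illustrates the idea; it is an assumption about the mechanism, not the library's actual code. It also assumes the wrapped value is hashable, since it ends up in the tree definition.

```python
import jax
from jax import tree_util as jtu


@jtu.register_pytree_node_class
class Frozen:
    """Hypothetical wrapper: the value is stored as static aux data."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # No children, so flattening yields no leaves for this node.
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


# Works inside any pytree, here a plain tuple.
tree = (Frozen(1), 2.0)
leaves, treedef = jtu.tree_flatten(tree)  # the frozen value is not a leaf
restored = jtu.tree_unflatten(treedef, leaves)


@jax.jit
def add(t):
    # t[0].value is a plain Python int at trace time; t[1] is traced.
    return t[0].value + t[1]


result = add(tree)
```

Because the wrapped value travels in the treedef, a non-JAX type can cross the jit boundary without any static-field handling in the container class.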

This design eliminates static field logic during the flattening/unflattening of a tree, leading to faster flattening/unflattening for non-masked trees and simplifying the code.
Let me know if this answers your question.

from pytreeclass.

adam-hartshorne commented on August 18, 2024

So if I am understanding this correctly,

tree = jax.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)

needs to be called before any call to a pytc.TreeClass containing static / frozen variables. So if I have a class within a class within a class, all containing frozen variables (or a class that contains numerous other classes which use frozen variables), every call to a method of that class needs this unfreezing step.

That doesn't seem ideal, when you start to get much more complicated models or wish to build a library of functions (as every class would need to be "wrapped" to hide this from the user).


ASEM000 commented on August 18, 2024

For a deeply nested instance with frozen attributes all over the place, you only need to write it once (usually inside your loss function), something like this:

import pytreeclass as pytc
import jax


class A(pytc.TreeClass):
    a: int = pytc.freeze(1)
    b: float = 2.0

    def __call__(self, x):
        return self.a * x + self.b


class B(pytc.TreeClass):
    c: int = pytc.freeze(1)
    d: A = A()

    def __call__(self, x):
        return self.c * x + self.d(x)

b = B()
# B(c=#1, d=A(a=#1, b=2))

@jax.jit
@jax.value_and_grad
def loss_func(b: B):
    b = jax.tree_map(pytc.unfreeze, b, is_leaf=pytc.is_frozen)
    return b(1.0)


loss_func(b)
# (Array(4., dtype=float32, weak_type=True),
#  B(c=#1, d=A(a=#1, b=f32[](μ=1.00, σ=0.00, ∈[1.00,1.00]))))

For comparison, under the hood, equinox filter-decorated functions do something similar in two steps: first equinox splits the tree into trainable/non-trainable parts before the JAX boundary, then it combines them inside the JAX function on each call. The pytreeclass scheme should be faster because you only do one step.
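The split-then-combine pattern described above can be sketched in plain JAX. The helpers `partition`, `combine`, and `is_array` are hypothetical names written for this illustration; this is not equinox's implementation and says nothing about the timings below.

```python
import jax
import jax.numpy as jnp
from jax import tree_util as jtu


def is_array(x):
    return isinstance(x, jax.Array)


def partition(tree, pred):
    """Split into (matching, rest); the holes are filled with None."""
    matching = jtu.tree_map(lambda x: x if pred(x) else None, tree)
    rest = jtu.tree_map(lambda x: None if pred(x) else x, tree)
    return matching, rest


def _is_none(x):
    return x is None


def combine(first, second):
    """Merge two partitioned trees, taking whichever side is not None."""
    return jtu.tree_map(
        lambda a, b: b if a is None else a, first, second, is_leaf=_is_none
    )


params = {"weight": jnp.array(2.0), "in_features": 3}

# Step 1, outside the JAX transformation: split off the non-array parts.
dynamic, static = partition(params, is_array)


def loss(dynamic):
    # Step 2, inside the transformation: rebuild the full tree before use.
    p = combine(dynamic, static)
    return p["weight"] * p["in_features"]


grads = jax.grad(loss)(dynamic)
```

The gradient tree mirrors `dynamic`, so the non-array entry stays `None`. With the frozen-wrapper approach, only the unfreeze step inside the function is needed.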

import equinox as eqx
import jax
import pytreeclass as pytc
import jax.numpy as jnp


class TreeEqx(eqx.Module):
    a: int = eqx.static_field(default=1)
    b: jax.Array = jnp.array(1.)


class TreePyTC(pytc.TreeClass):
    a: int = pytc.freeze(1)
    b: jax.Array = jnp.array(1.)

tree = TreePyTC()

@jax.jit
def some_func(t):
    t = jax.tree_map(pytc.unfreeze, t, is_leaf=pytc.is_frozen)
    return t.a + t.b

%timeit some_func(tree)
# 12.1 µs ± 836 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

tree = TreeEqx()

@eqx.filter_jit
def some_func(t):
    return t.a + t.b

%timeit some_func(tree)
# 26.7 µs ± 5.86 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Let me know if you have any questions.


adam-hartshorne commented on August 18, 2024

Ah OK, I understand. That is obviously more manageable.

I wonder if adding a wrapper function/decorator to hide this from users might be useful? The equinox decorated functions are very useful in this respect of hiding the complexity away.

I can see occurrences where somebody might want to use your model and try a different loss function, or incorporate your model / NN into a pipeline of others and they don't realise this behaviour. The ability to wrap your model such that another user doesn't even need to think about this jax.tree_map(pytc.unfreeze, b, is_leaf=pytc.is_frozen) might prove helpful in stopping obvious mistakes.


ASEM000 commented on August 18, 2024

You are right; fortunately, it's easy to do just that.

import functools as ft


def unfreeze_func(func):
    @ft.wraps(func)
    def wrapper(tree, *a, **k):
        tree = jax.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)
        return func(tree, *a, **k)

    return wrapper


@jax.jit
@jax.value_and_grad
@unfreeze_func
def loss_func(b: B):
    return b(1.0)
