
Comments (4)

ASEM000 commented on August 18, 2024

Hello Adam, thank you for your question.

For background, a couple of libraries with similar ideas (a PyTorch-like API) predate Treex and Equinox, and, as you have seen, libraries with seemingly similar ideas postdate them; each of these libraries has its reason to exist. The landscape of neural network libraries is even more diverse, with Google and DeepMind alone maintaining several, including objax, flax, haiku, and oryx.nn. Each of these libraries represents a slightly different conceptual model, so as you delve deeper into this landscape, you will likely discover a variant that aligns with your specific needs and preferences.

Now, let me explain why PyTreeClass exists when Equinox/Treex/simple_pytree already exist; since you mentioned Equinox, I will focus mostly on it.

1- One of the core ideas of Equinox is filtered transformations, where you filter your pytrees at the function level, while in PyTreeClass you filter at the pytree level by masking. This is a deliberate decision, and it means PyTreeClass does not need automatic decorators like equinox.filter_{...} that parallel the JAX API.

I believe that mirroring an API can be a risky strategy (although it can be cleverly implemented, as in jax.numpy) because it can lead to confusion and errors due to inconsistent behaviour. Additionally, it requires meticulous maintenance to keep up with updates to the original API (you can see examples of filter_-related issues in the Equinox issue tab). Moreover, debugging can be more challenging because you need to understand beforehand which nodes are frozen and which are being trained. If not handled carefully, this approach can introduce bugs and unexpected behaviour when interacting with pure JAX or other libraries. For these reasons, I prefer a more explicit method through masking, where you can see which nodes are frozen before passing them to a function. This helps prevent unforeseen outcomes.
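To make the contrast concrete, here is a minimal sketch of pytree-level masking, assuming the pytc.freeze/pytc.unfreeze/pytc.is_frozen helpers described in the README; the frozen state is visible on the tree itself before any function sees it:

import jax
import pytreeclass as pytc

class Params(pytc.TreeClass):
    a: float = 1.0
    b: float = 2.0

params = Params()
# freeze `a` on the pytree itself, before any transformation runs
frozen = params.at["a"].apply(pytc.freeze)
print(frozen)  # the frozen leaf is visible when inspecting the tree

@jax.jit
def forward(tree):
    # unmask inside the function to get the plain values back
    tree = jax.tree_util.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)
    return tree.a + tree.b

print(forward(frozen))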

2- IMO, PyTreeClass has better functional tree manipulation (lens-like); you can easily do a couple of things:

  • Create masks with optax to selectively apply certain optimizers to the leaves you like (based on name or condition).
import optax
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    a: float = 1.0
    b: float = 2.0
    c: float = 3.0


tree = Tree()

# boolean masks selecting a single leaf each
a_mask = tree.at[...].set(False).at["a"].set(True)
b_mask = tree.at[...].set(False).at["b"].set(True)
c_mask = tree.at[...].set(False).at["c"].set(True)

optim = optax.chain(
    # update `a` with sgd of learning rate 1
    optax.masked(optax.sgd(learning_rate=1), a_mask),
    # update `b` with sgd of learning rate -1
    optax.masked(optax.sgd(learning_rate=-1), b_mask),
    # update `c` with sgd of learning rate 0
    optax.masked(optax.sgd(learning_rate=0), c_mask),
)
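A hedged usage sketch for the chained optimizer above (the gradients here are placeholders, not from a real loss):

import jax
import jax.numpy as jnp

params = Tree()
state = optim.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)  # placeholder gradients
updates, state = optim.update(grads, state)
params = optax.apply_updates(params, updates)
print(params)
# a: 1.0 -> 0.0 (lr=1), b: 2.0 -> 3.0 (lr=-1), c: 3.0 -> 3.0 (lr=0)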
  • You can use the functional method call to do a myriad of things, like adding new leaves after class instantiation (not possible in Equinox).
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    a: float = 1.0
    b: float = 2.0
    c: float = 3.0

    def add_leaf(self, name: str, value):
        setattr(self, name, value)

tree = Tree()
# Tree(a=1.0, b=2.0, c=3.0)

# functional method call: returns (method output, new tree)
_, tree_with_d = tree.at["add_leaf"]("d", 4.0)

tree_with_d
# Tree(a=1.0, b=2.0, c=3.0, d=4.0)
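Note that the original tree is untouched by the call above; the mutation happens on a new copy, which is what keeps the pytree immutable under JAX transformations:

print(tree)
# Tree(a=1.0, b=2.0, c=3.0)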

3- Debugging: all my visualization tools are geared towards debugging; you always have helpful information whenever you interact with trees. For deep and nested networks, I usually resort to the tree_diagram function with the depth argument to navigate the network.
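For instance, a minimal sketch, assuming the tree_diagram(tree, depth=...) signature from the docs:

import pytreeclass as pytc

class Inner(pytc.TreeClass):
    w: float = 1.0
    b: float = 0.0

class Outer(pytc.TreeClass):
    def __init__(self):
        self.layer1 = Inner()
        self.layer2 = Inner()

# summarize only the first level of nesting
print(pytc.tree_diagram(Outer(), depth=1))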

4- More advanced features, like tree_map_with_trace, let you filter based on the type path; this is useful if you want to freeze leaves with certain parent types (Dropout layer leaves, for example). This is a unique feature of PyTreeClass.
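A rough sketch of the idea, assuming the trace passed to the mapped function is a (name-path, type-path) pair; the exact trace layout may differ, so treat this as illustrative rather than the definitive API:

import pytreeclass as pytc

class Dropout(pytc.TreeClass):
    p: float = 0.5

class Net(pytc.TreeClass):
    def __init__(self):
        self.w = 1.0
        self.drop = Dropout()

def freeze_under_dropout(trace, leaf):
    names, types = trace  # assumed trace layout: (name path, type path)
    # freeze any leaf that has a Dropout layer among its parents
    return pytc.freeze(leaf) if Dropout in types else leaf

net = pytc.tree_map_with_trace(freeze_under_dropout, Net())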

5- Data model: PyTreeClass blends the idea of a pytree of arrays with the array itself (optionally, by passing leafwise=True):

import jax
import jax.numpy as jnp
import pytreeclass as pytc

class Tree(pytc.TreeClass, leafwise=True):
    a: int = 1
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

tree = Tree()

# broadcast `jnp.where` over the tree leaves
print(pytc.bcmap(jnp.where)(tree > 2, tree + 100, 0))
# Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])

print(tree.at[tree > 1].apply(lambda x: x + 100))
# Tree(a=1, b=(102.0, 103.0), c=[104. 105. 106.])
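Because leafwise=True also defines the math operators leafwise, arithmetic on the tree works directly:

print(tree + 100)
# Tree(a=101, b=(102.0, 103.0), c=[104. 105. 106.])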
6- Module design: this is where all the pytree libraries have their own flavour; I will focus on Equinox to explain my point, using the CNN example I found here:
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x

In Equinox, you need to declare your trainable parameters as type-hinted fields at the top of your class, so if you want the previous example to have nn.conv1 point to the first convolution layer, for example, then you have to do something like this:

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

class CNN(eqx.Module):
    conv1: eqx.nn.Conv2d
    pool1: eqx.nn.MaxPool2d
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3 = jax.random.split(key, 3)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.conv1 = eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1)
        self.pool1 = eqx.nn.MaxPool2d(kernel_size=2)
        self.linear1 = eqx.nn.Linear(1728, 512, key=key2)
        self.linear2 = eqx.nn.Linear(512, 10, key=key3)

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        x = self.conv1(x)
        x = self.pool1(x)
        x = jax.nn.relu(x)
        x = jnp.ravel(x)
        x = self.linear1(x)
        x = jax.nn.sigmoid(x)
        x = self.linear2(x)
        return jax.nn.log_softmax(x)

IMO, this is a repetitive design. The first example escapes this repetition by using a mutable container (a list) to wrap all the layers, but then you must use something like nn.layers[0] instead of nn.conv1 to fetch your first layer, which hurts ergonomics. Moreover, by doing so, you lose the immutability (try nn.layers.pop()) that is essential for correct behaviour under JAX. Another reason to avoid using a tuple/list as a layer container is that you miss out on the name of the layer/leaf, which can otherwise be accessed using jax.tree_util.tree_map_with_path.

In PyTreeClass, all attributes are leaves by default. If you want to filter out non-trainable parameters, use a mask, as shown in the README.
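A hedged sketch of that README recipe, assuming the pytc.freeze and pytc.is_nondiff helpers: freeze everything that is not differentiable before taking gradients.

import jax
import jax.numpy as jnp
import pytreeclass as pytc

class Linear(pytc.TreeClass):
    def __init__(self):
        self.weight = jnp.array([1.0, 2.0])
        self.name = "linear"  # non-differentiable metadata leaf

model = Linear()
# mask the non-differentiable leaves so grad only sees arrays
masked = jax.tree_util.tree_map(
    lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, model
)

@jax.grad
def loss(model):
    model = jax.tree_util.tree_map(pytc.unfreeze, model, is_leaf=pytc.is_frozen)
    return jnp.sum(model.weight ** 2)

print(loss(masked))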

7- Finally, I am a user of Equinox myself; I use Equinox's internal tools (equinox.internal), and I think my library must play nicely with others in the JAX ecosystem. This is why PyTreeClass gives no special treatment to non-PyTreeClass instances: you can use all these tools with any library you like (e.g., Flax/Equinox/Haiku).

So, for the CNN example, you can inherit all the PyTreeClass pros by doing something like this:

import equinox as eqx
import jax
import jax.numpy as jnp
import pytreeclass as pytc

class CNN(pytc.TreeClass):
    def __init__(self, key):
        key1, key2, key3 = jax.random.split(key, 3)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top. The layers are plain Equinox layers.
        self.conv1 = eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1)
        self.pool1 = eqx.nn.MaxPool2d(kernel_size=2)
        self.linear1 = eqx.nn.Linear(1728, 512, key=key2)
        self.linear2 = eqx.nn.Linear(512, 10, key=key3)

    def __call__(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = jax.nn.relu(x)
        x = jnp.ravel(x)
        x = self.linear1(x)
        x = jax.nn.sigmoid(x)
        x = self.linear2(x)
        return jax.nn.log_softmax(x)

With serket, you inherit the tools and mental model of PyTreeClass while being 100% compatible with other libraries, including Equinox. If you are a user of eqx.nn, you can use serket layers that do not exist in Equinox, like FFT convolution, inside your Equinox models if you like.
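A hedged sketch of that mix-and-match idea; I am assuming serket exposes an FFT convolution layer along the lines of sk.nn.FFTConv2D with a constructor mirroring eqx.nn.Conv2d (check the serket docs for the exact name and signature):

import equinox as eqx
import jax
import serket as sk

class Block(eqx.Module):
    conv: sk.nn.FFTConv2D  # a serket layer living inside an Equinox module

    def __init__(self, key):
        # assumed signature, mirroring the eqx.nn.Conv2d call above
        self.conv = sk.nn.FFTConv2D(1, 3, kernel_size=4, key=key)

    def __call__(self, x):
        return jax.nn.relu(self.conv(x))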

Let me know if this answers your question.


adam-hartshorne commented on August 18, 2024

Thank you for the extremely in-depth response. It will take me some time to consider all that has been stated, but my interest has definitely been piqued.

One other quick question: do you have any benchmarking of your implementation vs., say, Equinox for a range of uses? Obviously, I saw your charts for flatten/unflatten, which look very good. I wonder how it performs in terms of memory/speed for various common NN architectures (as I have found over the years, JAX can be very sensitive to small code changes when using things like vmap; this obviously comes down to how JAX/XLA optimisation is conducted).


ASEM000 commented on August 18, 2024

Except for flax.struct, I think most pytree libraries should behave similarly regarding memory/speed.
PyTreeClass is slightly faster because no logic (for static fields) runs when flattening/unflattening.
Check the README for benchmark links.
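If you want to check on your own trees, here is a minimal micro-benchmark sketch of the flatten/unflatten path (substitute any pytree instance you care about for the toy Tree below):

import timeit
import jax
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    a: float = 1.0
    b: float = 2.0

tree = Tree()
leaves, treedef = jax.tree_util.tree_flatten(tree)

# time the flatten and unflatten paths separately
print(timeit.timeit(lambda: jax.tree_util.tree_flatten(tree), number=100_000))
print(timeit.timeit(lambda: jax.tree_util.tree_unflatten(treedef, leaves), number=100_000))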


ASEM000 commented on August 18, 2024

For reference:
[1] Pytree-based implementations: one that predates Equinox/Treex (flax's PyTreeNode), and another that postdates them (pax).
[2] Equinox tree_at sample issue.
[3] filter_ inconsistent-behaviour sample issues: 1, 2.

