Code Monkey home page Code Monkey logo

Comments (1)

ASEM000 avatar ASEM000 commented on July 19, 2024

Comparison between PyTreeClass and the dataclass field-metadata-based approach for freezing tree leaves [Draft]

PyTreeClass Flax Equinox
import jax
import pytreeclass as pytc

@pytc.treeclass
class Test:
    """Toy model whose output is the sum of the entries of ``a``.

    The first entry is wrapped with ``pytc.freeze`` so that it is expected to
    stay fixed during training; the remaining two entries are trainable.
    """

    # Annotated as float: every entry is a float literal (the previous
    # ``tuple[int, ...]`` annotation did not match the stored values).
    a: tuple[float, ...] = (pytc.freeze(1.), 2., 3.)

    def __call__(self, x):
        # ``x`` is accepted for a model-like call signature but is not used.
        return sum(self.a)

@jax.grad
def loss_func(NN: Test, x: float):
    """Gradient of the squared error ``(NN(x) - x) ** 2`` with respect to ``NN``.

    ``jax.grad`` differentiates with respect to the first positional argument,
    so the returned value has the same pytree structure as ``NN``.
    """
    residual = NN(x) - x
    return residual ** 2

@jax.jit
def step_func(NN:Test, x:float):
    """Apply one gradient-descent step to ``NN`` with a fixed step size of 1e-3.

    ``x`` is passed through unchanged to ``loss_func``. Returns the updated model.
    """
    # Gradient tree of the loss with respect to NN (same structure as NN).
    dNN = loss_func(NN, x)
    # Relies on pytreeclass operator overloading for tree-level arithmetic.
    # NOTE(review): presumably frozen leaves are left untouched by this update;
    # the expected output printed after training shows the frozen 1.0 unchanged.
    NN -= 1e-3* dNN
    return NN

def train(NN: "Test | None" = None, epochs: int = 10_000):
    """Train a model for ``epochs`` gradient steps at the fixed input 10.0.

    Args:
        NN: Model to start training from. If ``None``, a fresh ``Test()`` is
            constructed. A model passed by the caller is now actually used
            instead of being silently discarded.
        epochs: Number of ``step_func`` updates to apply.

    Returns:
        The trained model.
    """
    if NN is None:
        NN = Test()
    for _ in range(epochs):
        NN = step_func(NN, 10.)
    return NN


# ``NN`` is not defined at module level, so construct the model explicitly.
print(train(Test()))
# Test(a=(#1.0, 3.9999967, 4.9999437))
import jax
import jax.tree_util as jtu
from flax import struct

@struct.dataclass
class Test:
    """Flax ``struct.dataclass`` version of the three-parameter toy model.

    ``a_0`` is declared with ``pytree_node=False``, so it is stored as static
    metadata rather than as a pytree leaf and does not receive gradient
    updates. ``a_1`` and ``a_2`` are ordinary (trainable) leaves.
    """

    a_0: float = struct.field(pytree_node=False, default=1.)  # frozen/static
    a_1: float = 2.
    a_2: float = 3.

    def __call__(self, x):
        # ``x`` is part of the call signature but does not affect the output.
        partial_sum = self.a_0 + self.a_1
        return partial_sum + self.a_2

@jax.grad
def loss_func(NN: Test, x: float):
    """Return d/dNN of the squared error ``(NN(x) - x) ** 2``.

    The result mirrors the pytree structure of ``NN``.
    """
    prediction = NN(x)
    return (prediction - x) ** 2

@jax.jit
def step_func(NN: Test, x: float):
    """Take one gradient-descent step (step size 1e-3) over the leaves of ``NN``."""
    grads = loss_func(NN, x)

    def _descend(param, grad):
        # Leaf-wise update; parameter names avoid shadowing the outer ``x``.
        return param - grad * 1e-3

    return jtu.tree_map(_descend, NN, grads)

def train(NN: "Test | None" = None, epochs: int = 10_000):
    """Train a model for ``epochs`` gradient steps at the fixed input 10.0.

    Args:
        NN: Model to start training from. If ``None``, a fresh ``Test()`` is
            constructed. A model passed by the caller is now actually used
            instead of being silently discarded.
        epochs: Number of ``step_func`` updates to apply.

    Returns:
        The trained model.
    """
    if NN is None:
        NN = Test()
    for _ in range(epochs):
        NN = step_func(NN, 10.)
    return NN


# ``NN`` is not defined at module level, so construct the model explicitly.
print(train(Test()))
# Test(a_0=1.0,
#      a_1=Array(3.9999967, dtype=float32, weak_type=True),
#      a_2=Array(4.9999437, dtype=float32, weak_type=True))
import jax
import jax.tree_util as jtu
import equinox as eqx

class Test(eqx.Module):
    """Equinox ``Module`` version of the three-parameter toy model.

    ``a_0`` is a static field, so it is not a pytree leaf and is not changed
    by the gradient step. ``a_1`` and ``a_2`` are trainable leaves.
    """

    # ``eqx.field(static=True)`` is the replacement for the deprecated
    # ``eqx.static_field``; extra keyword arguments go to ``dataclasses.field``.
    a_0: float = eqx.field(static=True, default=1.)
    a_1: float = 2.
    a_2: float = 3.

    def __call__(self, x):
        # ``x`` is part of the call signature but does not affect the output.
        return self.a_0 + self.a_1 + self.a_2

@jax.grad
def loss_func(NN: Test, x: float):
    """Squared-error loss ``(NN(x) - x) ** 2``, differentiated with respect to ``NN``."""
    error = NN(x) - x
    squared_error = error ** 2
    return squared_error

@jax.jit
def step_func(NN: Test, x: float):
    """Take one gradient-descent step (step size 1e-3) and return the new module."""
    return jtu.tree_map(
        lambda param, grad: param - grad * 1e-3,
        NN,
        loss_func(NN, x),
    )

def train(NN: "Test | None" = None, epochs: int = 10_000):
    """Train a model for ``epochs`` gradient steps at the fixed input 10.0.

    Args:
        NN: Model to start training from. If ``None``, a fresh ``Test()`` is
            constructed. A model passed by the caller is now actually used
            instead of being silently discarded.
        epochs: Number of ``step_func`` updates to apply.

    Returns:
        The trained model.
    """
    if NN is None:
        NN = Test()
    for _ in range(epochs):
        NN = step_func(NN, 10.)
    return NN

# ``NN`` is not defined at module level, so construct the model explicitly.
# Train once and reuse the result; calling ``train`` twice would repeat all
# of the training steps.
trained = train(Test())
print(trained)
# Test(a_0=1.0, a_1=f32[], a_2=f32[])

print(trained.a_1)
# 3.9999967

from pytreeclass.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.