The two code snippets are now equivalent.
With `register_node`:
- A node is registered by its name once and retrieved on subsequent calls.
- This is a workaround to emulate the `flax.linen.compact` behavior.
- This method enables some nodes to be defined at runtime.
```python
import jax

# Assumes `treeclass` (from pytreeclass) and a `Linear` layer are already imported.

@treeclass
class StackedLinear:
    def __init__(self, key):
        self.keys = jax.random.split(key, 3)

    def __call__(self, x):
        x = self.register_node(Linear(key=self.keys[0], in_dim=x.shape[-1], out_dim=128), name="l1")(x)
        x = jax.nn.tanh(x)
        x = self.register_node(Linear(key=self.keys[1], in_dim=128, out_dim=128), name="l2")(x)
        x = jax.nn.tanh(x)
        # x.shape[-1] is 128 here (output of l2), so l3 maps 128 -> 128.
        x = self.register_node(Linear(key=self.keys[2], in_dim=128, out_dim=x.shape[-1]), name="l3")(x)
        return x
```
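The "register once, retrieve afterwards" behavior can be illustrated without pytreeclass. The following is a minimal sketch in plain Python, assuming a name-keyed dictionary as the storage; `NodeRegistry`, its `register_node` method and the toy `Offset` node are hypothetical stand-ins, not pytreeclass's actual implementation.

```python
class NodeRegistry:
    """Hypothetical stand-in for an object exposing `register_node`."""

    def __init__(self):
        # Maps each node name to the node stored on its first registration.
        self._nodes = {}

    def register_node(self, node, *, name):
        # First call with `name`: keep the candidate node.
        # Later calls: discard the candidate and return the stored node,
        # so its parameters are created once and then reused.
        if name not in self._nodes:
            self._nodes[name] = node
        return self._nodes[name]


class Offset:
    """Toy node that adds a per-instance offset; counts how many were built."""

    built = 0

    def __init__(self):
        Offset.built += 1
        self.offset = Offset.built

    def __call__(self, x):
        return x + self.offset


registry = NodeRegistry()
a = registry.register_node(Offset(), name="l1")  # stored (offset 1)
b = registry.register_node(Offset(), name="l1")  # built (offset 2), then discarded
print(a is b, b(10), Offset.built)  # True 11 2
```

Because Python evaluates arguments eagerly, a call like `self.register_node(Linear(...), name="l1")` still constructs a candidate layer on every call; only the first one is kept.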
With explicitly declared fields:

```python
import jax

# Assumes `treeclass` (from pytreeclass) and a `Linear` layer are already imported.

@treeclass
class StackedLinear:
    l1: Linear
    l2: Linear
    l3: Linear

    def __init__(self, key, in_dim, out_dim):
        # Here in_dim and out_dim must be known at construction time.
        keys = jax.random.split(key, 3)
        self.l1 = Linear(key=keys[0], in_dim=in_dim, out_dim=128)
        self.l2 = Linear(key=keys[1], in_dim=128, out_dim=128)
        self.l3 = Linear(key=keys[2], in_dim=128, out_dim=out_dim)

    def __call__(self, x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)
        return x
```
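To make the equivalence concrete, here is a small pure-Python sketch that compares a stack whose first layer infers `in_dim` from the input at call time against one where every dimension is passed to the constructor. `Linear`, `LazyStack` and `EagerStack` are hypothetical stand-ins (lists instead of JAX arrays, seeded `random.Random` instead of `jax.random` keys). One deliberate difference from the snippets above: `LazyStack` passes a factory (a `lambda`), so each layer is only built the first time its name is seen.

```python
import math
import random


class Linear:
    """Hypothetical dense layer y = W x + b with weights from a seeded RNG."""

    def __init__(self, seed, in_dim, out_dim):
        rng = random.Random(seed)
        self.w = [[rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(out_dim)]
        self.b = [0.0] * out_dim

    def __call__(self, x):
        return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(self.w, self.b)]


def tanh(x):
    return [math.tanh(v) for v in x]


class LazyStack:
    """Builds layers on the first call, inferring in_dim from the input."""

    def __init__(self, seeds, hidden, out_dim):
        self.seeds, self.hidden, self.out_dim = seeds, hidden, out_dim
        self.nodes = {}

    def _node(self, name, make):
        # Call the factory only when `name` is seen for the first time.
        if name not in self.nodes:
            self.nodes[name] = make()
        return self.nodes[name]

    def __call__(self, x):
        in_dim = len(x)  # known only at call time
        x = tanh(self._node("l1", lambda: Linear(self.seeds[0], in_dim, self.hidden))(x))
        return self._node("l2", lambda: Linear(self.seeds[1], self.hidden, self.out_dim))(x)


class EagerStack:
    """Same layers, but every dimension must be given at construction time."""

    def __init__(self, seeds, in_dim, hidden, out_dim):
        self.l1 = Linear(seeds[0], in_dim, hidden)
        self.l2 = Linear(seeds[1], hidden, out_dim)

    def __call__(self, x):
        return self.l2(tanh(self.l1(x)))


x = [0.5, -1.0, 2.0]
lazy = LazyStack(seeds=(0, 1), hidden=4, out_dim=3)
eager = EagerStack(seeds=(0, 1), in_dim=3, hidden=4, out_dim=3)
print(lazy(x) == eager(x), len(lazy.nodes))  # True 2
```

With the same seeds both stacks hold identical weights, so they produce identical outputs; the lazy version simply defers choosing `in_dim` until the first input arrives.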
from pytreeclass.