penzai's Introduction

Penzai

盆 ("pen", tray) 栽 ("zai", planting) - an ancient Chinese art of forming trees and landscapes in miniature, also called penjing and an ancestor of the Japanese art of bonsai.

Penzai is a JAX library for writing models as legible, functional pytree data structures, along with tools for visualizing, modifying, and analyzing them. Penzai focuses on making it easy to do stuff with models after they have been trained, making it a great choice for research involving reverse-engineering or ablating model components, inspecting and probing internal activations, performing model surgery, debugging architectures, and more. (But if you just want to build and train a model, you can do that too!)

With Penzai, your neural networks could look like this:

Screenshot of the Gemma model in Penzai

Penzai is structured as a collection of modular tools, designed together but each useable independently:

  • penzai.nn (pz.nn): A declarative combinator-based neural network library and an alternative to other neural network libraries like Flax, Haiku, Keras, or Equinox, which exposes the full structure of your model's forward pass in the model pytree. This means you can see everything your model does by pretty printing it, and inject new runtime logic with jax.tree_util. Like Equinox, there's no magic: models are just callable pytrees under the hood.

  • penzai.treescope (pz.ts): A superpowered interactive Python pretty-printer, which works as a drop-in replacement for the ordinary IPython/Colab renderer. It's designed to help understand Penzai models and other deeply-nested JAX pytrees, with built-in support for visualizing arbitrary-dimensional NDArrays.

  • penzai.core.selectors (pz.select): A pytree swiss-army-knife, generalizing JAX's .at[...].set(...) syntax to arbitrary type-driven pytree traversals, and making it easy to do complex rewrites or on-the-fly patching of Penzai models and other data structures.

  • penzai.core.named_axes (pz.nx): A lightweight named axis system which lifts ordinary JAX functions to vectorize over named axes, and allows you to seamlessly switch between named and positional programming styles without having to learn a new array API.

  • penzai.data_effects (pz.de): An opt-in system for side arguments, random numbers, and state variables that is built on pytree traversal and puts you in control, without getting in the way of writing or using your model.
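To make the "callable pytrees" idea from the penzai.nn bullet concrete, here is a minimal sketch in plain JAX. It does not use Penzai's API: the `Linear` class is a hypothetical stand-in, and it only illustrates how an object registered as a pytree can be both called like a function and rewritten with jax.tree_util.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class Linear:
    """Hypothetical layer (not a Penzai class): a dataclass that is also a pytree."""

    weight: jax.Array
    bias: jax.Array

    def __call__(self, x):
        # The forward pass is an ordinary method on the data structure.
        return x @ self.weight + self.bias

    def tree_flatten(self):
        # Children are the array leaves; there is no static metadata here.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


layer = Linear(weight=jnp.ones((3, 2)), bias=jnp.zeros((2,)))

# Because the layer is a pytree, generic tree utilities can rewrite it,
# for example to produce an ablated copy with all parameters zeroed.
zeroed = jax.tree_util.tree_map(jnp.zeros_like, layer)

out = layer(jnp.ones((3,)))
zeroed_out = zeroed(jnp.ones((3,)))
leaves = jax.tree_util.tree_leaves(layer)
```

The original `layer` is left untouched; `tree_map` returns a new `Linear` whose leaves have been replaced, which is the same functional style the selectors bullet describes.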

Documentation on Penzai can be found at https://penzai.readthedocs.io.

Getting Started

If you haven't already installed JAX, you should do that first, since the installation process depends on your platform. You can find instructions in the JAX documentation. Afterward, you can install Penzai using

pip install penzai

and import it using

import penzai
from penzai import pz

(penzai.pz is an alias namespace, which makes it easier to reference common Penzai objects.)

When working in a Colab or IPython notebook, we recommend also configuring Penzai as the default pretty printer, and enabling some utilities for interactive use:

pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

# Optional: enables automatic array visualization
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

Here's how you could initialize and visualize a simple neural network:

import jax
from penzai.example_models import simple_mlp
mlp = pz.nn.initialize_parameters(
    simple_mlp.MLP.from_config([8, 32, 32, 8]),
    jax.random.key(42),
)

# Models and arrays are visualized automatically when you output them from a
# Colab/IPython notebook cell:
mlp

Here's how you could capture and extract the activations after the elementwise nonlinearities:

mlp_with_captured_activations = pz.de.CollectingSideOutputs.handling(
    pz.select(mlp)
    .at_instances_of(pz.nn.Elementwise)
    .insert_after(pz.de.TellIntermediate())
)

output, intermediates = mlp_with_captured_activations(
  pz.nx.ones({"features": 8})
)

To learn more about how to build and manipulate neural networks with Penzai, we recommend starting with the "How to Think in Penzai" tutorial, or one of the other tutorials in the Penzai documentation.


This is not an officially supported Google product.

penzai's People

Contributors

danieldjohnson

penzai's Issues

Customized `selector` functions for user-defined `Pytree` classes

Hi! Awesome package -- had a great sense of "I absolutely must use this" when I was looking at the documentation and the repository.

I have a library I'm developing for probabilistic programming, where I define custom nested Pytree classes. The "getter" and "setter" interfaces for these classes are not just like normal Pytree accessing (via fields) -- they are defined to ensure consistency with the rest of the semantics of the system.

I'd like to use penzai and penzai.treescope with these objects -- so I'm wondering if there's some way to overload the selectors to use my own interfaces, and have this reflected in e.g. the lambda functions to access pieces of data from the visuals?

Thanks -- and again, awesome package -- I'm totally willing to read your code, and try to figure out a solution to the above, if you're enthusiastic about it.

Keep axes in nmap

Is there any way to specify axes that you want to push into the nmap? Currently it feels like I need to go fully positional inside of a function.

I guess it would be something like

x = x.tag("foo", "bar")
vmap(f)(x.untag("foo"))

Where f can see bar?
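For context, the behavior being asked about resembles the following in plain JAX, shown here only as a sketch. It does not use Penzai's nmap or named arrays: the axis names "foo" and "bar" appear only in comments, and the array and function are made up for illustration.

```python
import jax
import jax.numpy as jnp

# A 2x3 array whose axes we think of as ("foo", "bar").
x = jnp.arange(6.0).reshape(2, 3)


def f(row):
    # Inside the mapped function only the "bar" axis is visible:
    # `row` has shape (3,), and "foo" has been mapped away.
    return row.sum()


# Map over axis 0 ("foo"); f is applied once per "foo" index.
result = jax.vmap(f, in_axes=0)(x)
```

Here `result` has one entry per "foo" index, each computed from a full "bar" slice.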

importing model from other jax-based libraries for weight visualization

I saw the release announcement for this library and it looks really cool. But the base layers I'm using in flax right now aren't implemented in the neural network library here. Is there any way to import my jax pytree that I created in my flax training run to visualize my weights, even if those layers don't have implementations in penzai.nn?

Synchronize sliders across array data in Pytree

https://penzai.readthedocs.io/en/stable/notebooks/treescope_arrayviz.html#slicing-and-scrubbing-with-sliders presents a great way to set up small visuals inside the printing of nested Pytree instances.

One thing I'd like to figure out: is it possible to synchronize the slider state across all arrays?

In general, my computations produce Pytrees whose interpretation will be invalid if a user looks at a visual where some of the sliders are mismatched on slider index.

I'm not sure if this is out of scope, or how this would work, but curious if there are any thoughts about this.

Penzai as a training framework

Hello, thank you for a great framework!

I am currently considering penzai for building a model trainer, as an alternative to frameworks like Flax. Would there be any limitations to using penzai for this purpose? For example, does using NamedArray incur zero overhead? I think it does, but I want to be sure before proceeding with development.

Thank you!

Auto unwrap for full reductions.

When you call full reductions like sum or any, you need to untag and then unwrap which is a bit painful. Seems safe to just auto unwrap there.

Default `pytree_dataclass` for inheritors of a common baseclass

I'd like to define a common base class which always explicitly uses @pz.pytree_dataclass for its inheritors.

Is there a convenient way to override this? E.g. I'm trying to avoid

Every non-abstract `Struct` must be explicitly registered as a dataclass
pytree using the decorator `penzai.pytree_dataclass`, so that readers of the
code can tell the class's semantics differ from that of an ordinary Python class.

for all the classes I'm defining in a library.
