
folx - Forward Laplacian for JAX

This submodule implements the forward laplacian from https://arxiv.org/abs/2307.08214. It is implemented as a custom interpreter for Jaxprs.

Install

Either clone the repository and install it locally via

poetry install

or

pip install .

or install the published package with pip via

pip install folx

Example

For simple usage, one can decorate any function with forward_laplacian.

import numpy as np
from folx import forward_laplacian

def f(x):
    return (x**2).sum()

fwd_f = forward_laplacian(f)
result = fwd_f(np.arange(3, dtype=float))
result.x                     # f(x) = 0 + 1 + 4 = 5
result.jacobian.dense_array  # J_f(x) = [0, 2, 4]
result.laplacian             # tr(H_f(x)) = 6
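As a sanity check that does not need folx, the three expected values can be reproduced with plain JAX autodiff. This is only a dense reference computation (it materializes the full Hessian), not how folx computes them.

```python
import jax
import jax.numpy as jnp


def f(x):
    return (x ** 2).sum()


x = jnp.arange(3, dtype=jnp.float32)
value = f(x)                        # f(x) = 0 + 1 + 4
jac = jax.grad(f)(x)                # gradient 2x
lap = jnp.trace(jax.hessian(f)(x))  # the Hessian is 2I, so its trace is 2 * 3
print(value, jac, lap)
```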

Introduction

To avoid writing custom wrappers for every JAX operation, the forward laplacian is implemented as a custom interpreter for jaxprs. This means that if you have a function

class Fn(Protocol):
    def __call__(self, *args: PyTree[Array]) -> PyTree[Array]:
        ...

the resulting function will have the signature:

class LaplacianFn(Protocol):
    def __call__(self, *args: PyTree[Array]) -> PyTree[FwdLaplArray]:
        ...

where FwdLaplArray is a triplet of

FwdLaplArray.x          # jax.Array:   f(x), shape f(x).shape
FwdLaplArray.jacobian   # FwdJacobian: J_f(x)
FwdLaplArray.laplacian  # jax.Array:   tr(H_f(x)), shape f(x).shape

The jacobian is implemented by a custom class because the forward laplacian supports automatic sparsity. To get the full jacobian:

FwdLaplArray.jacobian.dense_array  # jax.Array of shape (*f(x).shape, x.size)
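For comparison, plain JAX's jax.jacfwd returns a jacobian of shape (*f(x).shape, *x.shape); flattening the input axes gives the (*f(x).shape, x.size) layout described above. The sketch below assumes, without having checked folx's source, that the input axes are flattened in row-major order; the function f is only an illustration.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Illustrative function with a non-scalar input and output.
    return jnp.tanh(x) * x.sum()


x = jnp.arange(6.0).reshape(2, 3) / 10
y = f(x)
full = jax.jacfwd(f)(x)                  # shape (*y.shape, *x.shape) = (2, 3, 2, 3)
dense = full.reshape(*y.shape, x.size)   # shape (*y.shape, x.size) = (2, 3, 6)
```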

Implementation idea

The idea is to rely on the original function and autodifferentiation to propagate FwdLaplArray forward instead of the regular jax.Array. The rules for updating FwdLaplArray are described by the pseudocode:

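x  # FwdLaplArray
y = FwdLaplArray(
    x=f(x.x),
    jacobian=jvp(f, (x.x,), (x.jacobian,)),
    laplacian=tr_vhv(f, x.jacobian) + jvp(f, (x.x,), (x.laplacian,)),
)
# jvp(...) denotes only the tangent output; for the jacobian it is applied once per input direction.
# tr_vhv(f, J) sums v^T H_f(x.x) v over the input directions v stored in J = x.jacobian.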
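The pseudocode can be turned into a small, runnable pure-JAX sketch for flat arrays. The names second_dir_deriv and fwd_lapl_step are hypothetical helpers, not folx API, and the sketch ignores sparsity and pytrees. It assumes the incoming jacobian has shape (m, n), one column per original input direction, so the second-order term is the sum over columns v of v^T H_f v. The result is checked against a dense jax.hessian.

```python
import jax
import jax.numpy as jnp


def second_dir_deriv(f, x, v):
    """v^T H_{f_k}(x) v for every output component k, via two nested jvps."""
    first = lambda z: jax.jvp(f, (z,), (v,))[1]  # J_f(z) v
    return jax.jvp(first, (x,), (v,))[1]          # d/dt [J_f(x + t v) v] at t = 0


def fwd_lapl_step(f, x, jac, lap):
    """Push (value, jacobian, laplacian) through f.

    x: (m,) values; jac: (m, n) jacobian w.r.t. the n original inputs;
    lap: (m,) laplacian of each entry of x.
    """
    push = lambda v: jax.jvp(f, (x,), (v,))[1]              # J_f(x) v
    new_jac = jax.vmap(push, in_axes=1, out_axes=1)(jac)     # J_f(x) @ jac
    # Second-order term: sum of v^T H_f v over the columns v of jac.
    hess_term = jax.vmap(lambda v: second_dir_deriv(f, x, v), in_axes=1)(jac).sum(axis=0)
    new_lap = hess_term + push(lap)                          # plus J_f(x) @ lap
    return f(x), new_jac, new_lap


def mix(z):
    return jnp.stack([z[0] * z[1], z[1] * z[2] ** 2])


u = jnp.array([0.1, -0.3, 0.5])
# Seed: the input itself has an identity jacobian and zero laplacian.
x, jac, lap = u, jnp.eye(u.size), jnp.zeros(u.size)
x, jac, lap = fwd_lapl_step(jnp.sin, x, jac, lap)
x, jac, lap = fwd_lapl_step(mix, x, jac, lap)

# Dense reference values for the composed function.
composed = lambda u: mix(jnp.sin(u))
ref_jac = jax.jacfwd(composed)(u)
ref_lap = jnp.trace(jax.hessian(composed)(u), axis1=1, axis2=2)
```

The Hessian is never materialized; each step only needs forward-mode jvps, which is the point of propagating the triplet forward.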

Implementation

When you call the function returned by forward_laplacian(fn), we first use jax.make_jaxpr to obtain the jaxpr for fn. Instead of using the standard evaluation pipeline, we then evaluate it with a custom interpreter that replaces every operation so that it propagates FwdLaplArray forward instead of a regular jax.Array.
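The interpreter pattern itself can be sketched with public JAX pieces, following the approach in JAX's documentation on writing custom jaxpr interpreters. This minimal version only evaluates each equation with primitive.bind and records the primitive names; the comment marks the spot where a forward-Laplacian interpreter would instead dispatch to a rule acting on FwdLaplArray inputs. eval_jaxpr_logged is a hypothetical name, and the import fallback for Literal is an assumption about where different JAX versions expose it.

```python
import jax
import jax.numpy as jnp

try:  # newer JAX versions expose jaxpr types under jax.extend.core
    from jax.extend.core import Literal
except ImportError:  # older JAX versions
    from jax.core import Literal


def eval_jaxpr_logged(closed_jaxpr, *args):
    """Evaluate a ClosedJaxpr equation by equation, recording each primitive."""
    jaxpr = closed_jaxpr.jaxpr
    env = {}
    seen = []

    def read(var):
        # Literals carry their value inline; everything else lives in env.
        return var.val if isinstance(var, Literal) else env[var]

    for var, val in zip(jaxpr.constvars, closed_jaxpr.consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        seen.append(eqn.primitive.name)
        # Plain evaluation. A forward-Laplacian interpreter would look up a rule
        # for eqn.primitive here and apply it to (value, jacobian, laplacian).
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val

    return [read(v) for v in jaxpr.outvars], seen


def f(x):
    return (x ** 2).sum()


x = jnp.arange(3.0)
closed = jax.make_jaxpr(f)(x)
(out,), prims = eval_jaxpr_logged(closed, x)
print(out, prims)
```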

Package structure

The general structure of the package is

  • interpreter.py contains the jaxpr evaluation and the exported function decorator.
  • wrapper.py contains the decorator that maps a function taking jax.Arrays to one that accepts FwdLaplArrays instead.
  • wrapped_functions.py contains a registry of predefined functions as well as utility functions to add new functions to the registry.
  • jvp.py contains logic for jacobian-vector products.
  • hessian.py contains logic for tr(JHJ^T).
  • custom_hessian.py contains special-case logic for tr(JHJ^T).
  • api.py contains general interfaces shared in the package.
  • operators.py contains a forward laplacian operator as well as alternatives.
  • utils.py contains several small utility functions.
  • tree_utils.py contains several utility functions for PyTrees.
  • vmap.py contains a batched vmap implementation that reduces memory usage by processing a batch sequentially in chunks.
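The chunked-vmap idea listed for vmap.py can be sketched generically: vectorize within fixed-size chunks and loop over the chunks with jax.lax.map, so peak memory scales with the chunk size rather than the batch size. chunked_vmap is a hypothetical helper that handles a single array argument and output; folx's implementation may differ. Recent JAX versions also offer a batch_size argument on jax.lax.map for a similar purpose.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(f, chunk_size):
    """vmap f over the leading axis, but only chunk_size rows at a time."""
    batched = jax.vmap(f)

    def wrapped(x):
        n = x.shape[0]
        num_chunks = -(-n // chunk_size)  # ceiling division
        pad = num_chunks * chunk_size - n
        # Pad so the batch splits evenly; padded rows are dropped at the end.
        x = jnp.concatenate([x, jnp.zeros((pad,) + x.shape[1:], x.dtype)])
        chunks = x.reshape((num_chunks, chunk_size) + x.shape[1:])
        out = jax.lax.map(batched, chunks)  # loops over the chunks sequentially
        return out.reshape((num_chunks * chunk_size,) + out.shape[2:])[:n]

    return wrapped


def row_energy(r):
    return jnp.sum(r ** 2)


x = jnp.arange(10.0).reshape(5, 2)
out = chunked_vmap(row_energy, chunk_size=2)(x)
```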

Function Annotations

There is a default interpreter that simply applies the rules outlined above, but if additional information about a function is available (e.g., that it applies elementwise, like jnp.tanh), we can do better. These additional annotations are stored in the _LAPLACE_FN_REGISTRY of wrapped_functions.py. Specifically, to augment a function fn to accept FwdLaplArray instead of a regular jax.Array, we wrap it with wrap_forward_laplacian from wrapper.py:

wrap_forward_laplacian(jnp.tanh, in_axes=())

In this case, we annotate the function to be applied elementwise, i.e., () indicates that none of the axes are relevant for the function.
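To see why the elementwise annotation helps: for an elementwise map, each output depends only on its own input, so the Hessian is diagonal and the second-order term collapses to f''(x_i) times the squared norm of row i of the incoming jacobian, with no nested jvps needed. A minimal sketch of that special case follows; elementwise_step is a hypothetical helper, not folx's rule, and it is checked against a dense jax.hessian.

```python
import jax
import jax.numpy as jnp


def elementwise_step(f, x, jac, lap):
    """Forward-Laplacian step for a scalar function f applied elementwise.

    x: (m,) values; jac: (m, n) jacobian; lap: (m,) laplacian.
    """
    d1 = jax.vmap(jax.grad(f))(x)            # f'(x_i)
    d2 = jax.vmap(jax.grad(jax.grad(f)))(x)  # f''(x_i)
    new_jac = d1[:, None] * jac              # scale row i by f'(x_i)
    new_lap = d2 * jnp.sum(jac ** 2, axis=1) + d1 * lap
    return jax.vmap(f)(x), new_jac, new_lap


A = jnp.array([[1.0, 2.0, 0.0], [0.5, -1.0, 3.0]])
u = jnp.array([0.2, -0.1, 0.3])
# x = A u is linear in u: its jacobian is A and its laplacian is zero.
y, jac, lap = elementwise_step(jnp.tanh, A @ u, A, jnp.zeros(2))

ref_lap = jnp.trace(jax.hessian(lambda u: jnp.tanh(A @ u))(u), axis1=1, axis2=2)
```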

If we know nothing about which axes are essential, we pass None (the default value) to mark all axes as important, e.g.,

wrap_forward_laplacian(jnp.sum, in_axes=None, flags=FunctionFlags.LINEAR)

However, in this case we know that summation is a linear operation. This information is useful for fast Hessian computations, since the Hessian of a linear function is zero.
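For a linear function the Hessian term vanishes entirely and J_f(x) v = f(v), so f can be applied directly to the jacobian columns and to the laplacian. The sketch assumes f is truly linear (not merely affine, since a bias would be wrongly added to the laplacian); linear_step is a hypothetical helper, not folx API.

```python
import jax
import jax.numpy as jnp


def linear_step(f, x, jac, lap):
    """Forward-Laplacian step for a linear f: no Hessian term, only f itself.

    x: (m,) values; jac: (m, n) jacobian; lap: (m,) laplacian.
    """
    new_jac = jax.vmap(f, in_axes=1, out_axes=-1)(jac)  # f applied to each column
    return f(x), new_jac, f(lap)


u = jnp.array([0.3, -0.7, 1.2])
# x = sin(u) elementwise: jacobian diag(cos u), laplacian -sin(u).
x, jac, lap = jnp.sin(u), jnp.diag(jnp.cos(u)), -jnp.sin(u)
y, jac_y, lap_y = linear_step(jnp.sum, x, jac, lap)

ref_lap = jnp.trace(jax.hessian(lambda u: jnp.sum(jnp.sin(u)))(u))
```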

If you want to add rules for a function to the registry, you can do the following:

import jax
from folx import register_function, wrap_forward_laplacian

register_function(jax.lax.cos_p, wrap_forward_laplacian(jax.lax.cos, in_axes=()))
# Now the tracer knows that the cosine function is applied elementwise.

We can do even more by defining custom rules:

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from folx import forward_laplacian, register_function, wrap_forward_laplacian

# The jit is important: it makes f show up as its own named call in the traced jaxpr.
@jax.jit
def f(x):
    return x

# Define a custom jacobian-hessian-jacobian product rule, i.e., the tr(JHJ^T) term.
# Here it simply returns 10 for every entry.
def custom_jac_hessian_jac(args, extra_args, merge, materialize_idx):
    return jtu.tree_map(lambda x: jnp.full_like(x, 10), args.x)

# Make sure to register the rule under the same name as the jitted function above.
register_function("f", wrap_forward_laplacian(f, custom_jac_hessian_jac=custom_jac_hessian_jac))

@forward_laplacian
def g(x):
    return f(x)

g(jnp.ones(())).laplacian  # 10

Sparsity

Sparsity is detected at compile time. This avoids expensive index computations at runtime and enables efficient reductions. However, it rules out dynamic indexing: if indices are data-dependent, we simply fall back to full (dense) jacobians.

Because we know a lot about the sparsity structure a priori (e.g., that we are only sparse in one dimension), we use custom sparse operations that are more efficient than relying on JAX's default BCOO format (further, at the time of writing, the support for jax.experimental.sparse is quite poor). The sparse data format is implemented in FwdJacobian in api.py. Instead of storing a dense (m, n) array for a function f: R^n -> R^m, we store only the non-zero data in an (m, d) array, where d < n is the maximum number of inputs any single output depends on. To recreate the full (m, n) array from the (m, d) array, we additionally keep track of the column indices in an (m, d) integer array mask with 0 <= mask_ij < n.
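A small illustration of the (m, d) data plus mask format, assuming mask holds the column index of each stored value as described; to_dense is a hypothetical helper, not FwdJacobian's API. Each output of x[:-1] * x[1:] depends on only two inputs, so d = 2.

```python
import jax
import jax.numpy as jnp


def to_dense(data, mask, n):
    """Rebuild an (m, n) jacobian from (m, d) non-zero values and column indices."""
    m, d = data.shape
    rows = jnp.broadcast_to(jnp.arange(m)[:, None], (m, d))
    # Scatter-add, so repeated column indices within a row accumulate.
    return jnp.zeros((m, n), data.dtype).at[rows, mask].add(data)


def f(x):
    return x[:-1] * x[1:]  # output i depends only on inputs i and i + 1


x = jnp.array([1.0, 2.0, 3.0])
data = jnp.stack([x[1:], x[:-1]], axis=1)                     # d/dx_i and d/dx_{i+1} of x_i * x_{i+1}
mask = jnp.stack([jnp.arange(2), jnp.arange(1, 3)], axis=1)   # [[0, 1], [1, 2]]
dense = to_dense(data, mask, x.size)                          # [[2, 1, 0], [0, 3, 2]]
```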

Masks are treated as compile time static and will be traced automatically. If the tracing is not possible, e.g., due to data dependent indexing, we will fall back to a dense implementation. These propagation rules are implemented in jvp.py.
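How masks propagate can be sketched for two simple cases: an elementwise function keeps each row's mask and only rescales its values, while a full sum merges the masks of all rows into a single row. These helpers are hypothetical and simplified (for example, they do not deduplicate indices, so d can grow past n); the actual rules in jvp.py may differ.

```python
import jax
import jax.numpy as jnp


def to_dense(data, mask, n):
    m, d = data.shape
    rows = jnp.broadcast_to(jnp.arange(m)[:, None], (m, d))
    return jnp.zeros((m, n), data.dtype).at[rows, mask].add(data)


def elementwise_sparse(fprime, x, data, mask):
    # Output i depends on the same inputs as x_i: keep the mask, scale the values.
    return fprime(x)[:, None] * data, mask


def sum_sparse(data, mask):
    # The sum depends on every input that any row depends on: concatenate the rows.
    # Duplicate indices are kept here and summed when densifying.
    return data.reshape(1, -1), mask.reshape(1, -1)


u = jnp.array([1.0, 2.0, 3.0])
h = u[:-1] * u[1:]
data = jnp.stack([u[1:], u[:-1]], axis=1)
mask = jnp.stack([jnp.arange(2), jnp.arange(1, 3)], axis=1)

data, mask = elementwise_sparse(jnp.cos, h, data, mask)  # sin(h), using sin' = cos
data, mask = sum_sparse(data, mask)                      # sum(sin(h))

ref = jax.jacfwd(lambda u: jnp.sum(jnp.sin(u[:-1] * u[1:])))(u)
```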

Omnistaging

Even if arrays do not depend on the initial input, they are typically still traced to better optimize the final program. This is called omnistaging. While this is generally beneficial, it prevents us from using such arrays for indexing, since tracers hide the actual data. So, when using sparsity, we want to eagerly compute all arrays that do not explicitly depend on the input, so that we can use them for index operations. While this is not documented, it can be accomplished by overwriting the global trace via:

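from jax import core

# Relies on JAX internals: core.new_main is not public API and may be missing in newer JAX versions.
with core.new_main(core.EvalTrace, dynamic=True):
    ...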
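As an aside, JAX's public jax.ensure_compile_time_eval context manager is a documented way to evaluate input-independent computations eagerly inside a traced function. This README does not say whether folx could use it in place of core.new_main; the sketch only shows the effect. The function reverse is illustrative: idx below is a concrete array even under jax.jit, so it can drive Python-side index logic.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def reverse(x):
    with jax.ensure_compile_time_eval():
        # Input-independent, so evaluated eagerly even under jit: idx is concrete.
        idx = jnp.arange(x.shape[0])[::-1]
        # Concrete values can feed Python-side bookkeeping. Outside this context,
        # idx would be a tracer and np.asarray(idx) would raise an error.
        idx_np = np.asarray(idx)
    return x[idx_np]


out = reverse(jnp.array([1.0, 2.0, 3.0]))
```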

Citation

If you find this work helpful, please consider citing it as

@software{folx2023gao,
  author = {Nicholas Gao and Jonas Köhler and Adam Foster},
  title = {folx - Forward Laplacian for JAX},
  url = {http://github.com/microsoft/folx},
  version = {0.2.0},
  year = {2023},
}

as well as the original forward laplacian:

@article{li2023forward,
  title={Forward Laplacian: A New Computational Framework for Neural Network-based Variational Monte Carlo},
  author={Li, Ruichen and Ye, Haotian and Jiang, Du and Wen, Xuelan and Wang, Chuwei and Li, Zhe and Li, Xiang and He, Di and Chen, Ji and Ren, Weiluo and Wang, Liwei},
  journal={arXiv preprint arXiv:2307.08214},
  year={2023}
}

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact [email protected] with any additional questions or comments.

Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies.
