Universal Neural Functionals

This repository contains the code for constructing UNFs, from the paper Universal Neural Functionals. UNFs are architectures that process the weights of other neural networks while remaining equivariant or invariant to the permutation symmetries of weight space. Unlike NFNs (neural functional networks), UNFs can ingest weights from any architecture.

Equivalently, UNFs are equivariant architectures for processing any collection of tensors, under a group action in which a shared set of permutations permutes specified axes of each tensor.
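For reference, one way to write this action and the two properties is shown below. The notation ($W^{(m)}$, $p^{(m)}_j$, $f$, $g$) is introduced here for illustration and is not taken from the paper: the input is a collection $W = (W^{(1)}, \dots, W^{(M)})$, tensor $W^{(m)}$ has rank $k$, and its axis $j$ is permuted by $\sigma_{p^{(m)}_j}$.

$$
(\sigma \cdot W^{(m)})_{i_1, \dots, i_k} = W^{(m)}_{\sigma_{p^{(m)}_1}^{-1}(i_1), \dots, \sigma_{p^{(m)}_k}^{-1}(i_k)}
$$

$$
f(\sigma \cdot W) = \sigma \cdot f(W) \quad \text{(equivariance)}, \qquad g(\sigma \cdot W) = g(W) \quad \text{(invariance)}
$$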

The codebase requires JAX for its core functionality and Flax for the example (other JAX neural network libraries are likely compatible as well). See example.py for usage.

High-level usage

The perm_spec is what tells our library the permutation symmetries it should be equivariant to. For example, suppose you have a collection of weight tensors corresponding to a simple MLP:

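import jax.numpy as jnp

# Placeholder arrays: only the shapes matter here.
params = {
    "params": {
        "Dense_0": {
            "kernel": jnp.zeros((784, 512)),
            "bias": jnp.zeros((512,)),
        },
        "Dense_1": {
            "kernel": jnp.zeros((512, 10)),
            "bias": jnp.zeros((10,)),
        },
    }
}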

We can describe the permutation symmetry of this network as follows (assume the input and output neurons are also permutable).

  • The weight tensors can be permuted by $\sigma=(\sigma_0, \sigma_1, \sigma_2) \in S_{784} \times S_{512} \times S_{10}$.
  • $\sigma_0$ permutes the first dimension of params["params"]["Dense_0"]["kernel"].
  • $\sigma_1$ permutes the second dimension of params["params"]["Dense_0"]["kernel"], the vector params["params"]["Dense_0"]["bias"], and the first dimension of params["params"]["Dense_1"]["kernel"].
  • $\sigma_2$ permutes the second dimension of params["params"]["Dense_1"]["kernel"] and the vector params["params"]["Dense_1"]["bias"].
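To see why $\sigma_1$ is a genuine symmetry, the sketch below reorders the 512 hidden neurons consistently across all three tensors it touches and checks that the network's outputs are unchanged. `mlp_apply` and `permute_hidden` are hypothetical helpers written for this illustration, assuming a ReLU hidden layer; they are not part of the library.

```python
import jax
import jax.numpy as jnp


def mlp_apply(params, x):
    """Forward pass for the params dict above (assumes a ReLU hidden layer)."""
    p = params["params"]
    h = jax.nn.relu(x @ p["Dense_0"]["kernel"] + p["Dense_0"]["bias"])
    return h @ p["Dense_1"]["kernel"] + p["Dense_1"]["bias"]


def permute_hidden(params, perm):
    """Apply sigma_1 to every axis it acts on (new[i] = old[perm[i]])."""
    p = params["params"]
    return {
        "params": {
            "Dense_0": {
                "kernel": p["Dense_0"]["kernel"][:, perm],  # second dimension
                "bias": p["Dense_0"]["bias"][perm],
            },
            "Dense_1": {
                "kernel": p["Dense_1"]["kernel"][perm, :],  # first dimension
                "bias": p["Dense_1"]["bias"],  # not touched by sigma_1
            },
        }
    }


keys = jax.random.split(jax.random.PRNGKey(0), 6)
params = {
    "params": {
        "Dense_0": {
            "kernel": jax.random.normal(keys[0], (784, 512)) / jnp.sqrt(784.0),
            "bias": jax.random.normal(keys[1], (512,)),
        },
        "Dense_1": {
            "kernel": jax.random.normal(keys[2], (512, 10)) / jnp.sqrt(512.0),
            "bias": jax.random.normal(keys[3], (10,)),
        },
    }
}
x = jax.random.normal(keys[4], (8, 784))
perm = jax.random.permutation(keys[5], 512)

out = mlp_apply(params, x)
out_perm = mlp_apply(permute_hidden(params, perm), x)
# Expected to be at float32 rounding level: the function itself is unchanged.
print(jnp.max(jnp.abs(out - out_perm)))
```

Permuting with $\sigma_0$ or $\sigma_2$ instead would permute the network's inputs or outputs rather than leave the function fixed, which is why treating them as symmetries is an assumption, as stated above.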

Next we assign each permutation an integer index, $(\sigma_0, \sigma_1, \sigma_2) \mapsto (0, 1, 2)$, and define the permutation specification, which gives, for each tensor, the index of the permutation acting on each of its axes:

perm_spec = {
    "params": {
        "Dense_0": {
            "kernel": (0, 1),
            "bias": (1,)
        },
        "Dense_1": {
            "kernel": (1, 2),
            "bias": (2,)
        }
    }
}
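A perm_spec carries enough information to apply any $\sigma$ to the whole collection and to check that all axes sharing an index have the same length. The helpers `perm_sizes` and `apply_perm_spec` below are hypothetical, written to sketch the semantics described above; they are not the library's API.

```python
import jax
import jax.numpy as jnp

perm_spec = {
    "params": {
        "Dense_0": {"kernel": (0, 1), "bias": (1,)},
        "Dense_1": {"kernel": (1, 2), "bias": (2,)},
    }
}


def perm_sizes(tree, spec, sizes=None):
    """Hypothetical helper: map each permutation index to its axis length.

    Raises ValueError if two axes sharing an index have different lengths.
    """
    sizes = {} if sizes is None else sizes
    if isinstance(spec, dict):
        for key in spec:
            perm_sizes(tree[key], spec[key], sizes)
        return sizes
    for axis, p in enumerate(spec):
        n = tree.shape[axis]
        if sizes.setdefault(p, n) != n:
            raise ValueError(f"index {p} used for axes of length {sizes[p]} and {n}")
    return sizes


def apply_perm_spec(tree, spec, perms):
    """Hypothetical helper: reorder axis j of each leaf by perms[spec[j]].

    Uses the convention new[i] = old[perm[i]] along each permuted axis.
    """
    if isinstance(spec, dict):
        return {key: apply_perm_spec(tree[key], spec[key], perms) for key in spec}
    out = tree
    for axis, p in enumerate(spec):
        out = jnp.take(out, perms[p], axis=axis)
    return out


keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "params": {
        "Dense_0": {
            "kernel": jax.random.normal(keys[0], (784, 512)),
            "bias": jax.random.normal(keys[1], (512,)),
        },
        "Dense_1": {
            "kernel": jax.random.normal(keys[2], (512, 10)),
            "bias": jax.random.normal(keys[3], (10,)),
        },
    }
}

sizes = perm_sizes(params, perm_spec)
print(sizes)  # {0: 784, 1: 512, 2: 10}

# Draw one random permutation per index and apply all of them at once.
perm_keys = jax.random.split(jax.random.PRNGKey(1), len(sizes))
perms = tuple(jax.random.permutation(perm_keys[p], sizes[p]) for p in range(len(sizes)))
permuted = apply_perm_spec(params, perm_spec, perms)
```

Because the same `perms[1]` is used wherever index 1 appears, the hidden neurons stay aligned across Dense_0 and Dense_1 after the permutation.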

Notice that nothing requires the input to be a collection of weight tensors. This library processes any collection of tensors if you give it a description of the permutation symmetries.
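For instance, a graph with one scalar per node and an adjacency matrix could be described by reusing index 0 on both axes of the adjacency matrix. This spec is a hypothetical example, and we assume the library accepts the same index on several axes of one tensor. The sketch defines a hypothetical `apply_perm_spec` helper that reorders each axis by the permutation its spec index names, then checks that a simple neighbour-sum is equivariant under this action.

```python
import jax
import jax.numpy as jnp


def apply_perm_spec(tree, spec, perms):
    """Hypothetical helper: reorder axis j of each leaf by perms[spec[j]]."""
    if isinstance(spec, dict):
        return {key: apply_perm_spec(tree[key], spec[key], perms) for key in spec}
    out = tree
    for axis, p in enumerate(spec):
        out = jnp.take(out, perms[p], axis=axis)
    return out


# Hypothetical spec: one permutation (index 0) reorders the nodes, so it acts
# on node_values and on BOTH axes of the adjacency matrix.
graph_spec = {"node_values": (0,), "adjacency": (0, 0)}

n = 6
kv, ka, kp = jax.random.split(jax.random.PRNGKey(0), 3)
graph = {
    "node_values": jax.random.normal(kv, (n,)),
    "adjacency": (jax.random.uniform(ka, (n, n)) > 0.5).astype(jnp.float32),
}
perms = (jax.random.permutation(kp, n),)


def neighbour_sum(g):
    """A simple equivariant map: each node's new value sums its neighbours' values."""
    return {"node_values": g["adjacency"] @ g["node_values"], "adjacency": g["adjacency"]}


# Equivariance: permuting then mapping equals mapping then permuting.
lhs = neighbour_sum(apply_perm_spec(graph, graph_spec, perms))
rhs = apply_perm_spec(neighbour_sum(graph), graph_spec, perms)
print(jnp.allclose(lhs["node_values"], rhs["node_values"], atol=1e-5))  # True
```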

Contributors

allanyangzhou
