Hyperiax: Tree traversals using JAX

Introduction

Hyperiax is a framework for tree traversal and computation on large-scale trees. Its primary purpose is to enable efficient message passing and execution of operations on large trees. It uses JAX for fast execution and automatic differentiation. Hyperiax is developed and maintained by the Center for Computational Evolutionary Morphometry (CCEM) at the University of Copenhagen (UCPH).

Hyperiax was initially designed for phylogenetic analysis of biological shape data, in particular for statistical inference with continuous-time stochastic processes along the edges of the tree. For this purpose, it is integrated with JAXGeometry, a computational differential geometry toolbox implemented in JAX. However, Hyperiax's message-passing system and operations are general and can easily be adapted to other contexts: with minor modifications, Hyperiax can be used in other applications that require fast computations over trees. The included examples cover such cases, including inference in Gaussian graphical models, computation of phylogenetic means, and recursive shape matching in binary trees.
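The message passing mentioned above can be illustrated without any framework. The following is a minimal sketch in plain NumPy, not Hyperiax's implementation or API: the tree is encoded as a hypothetical parent-index array in which the root is node 0 and every parent appears before its children. The upward pass fuses child values into their parent (here simply by summation), and the downward pass propagates values from the root towards the leaves by adding a per-node increment.

```python
import numpy as np

# Hypothetical tree encoding (not Hyperiax's data structure): node i has parent
# parents[i]; the root is node 0 with parent -1, and parents precede children.
parents = np.array([-1, 0, 0, 1, 1, 2, 2])

def children_of(parents):
    """Return one list of child indices per node."""
    kids = [[] for _ in parents]
    for node, parent in enumerate(parents):
        if parent >= 0:
            kids[parent].append(node)
    return kids

def up_pass(parents, leaf_values):
    """Upward pass: each internal node becomes the sum of its children's values."""
    kids = children_of(parents)
    values = np.array(leaf_values, dtype=float)  # internal entries get overwritten
    # Reverse index order visits every child before its parent.
    for node in reversed(range(len(parents))):
        if kids[node]:
            values[node] = values[kids[node]].sum()
    return values

def down_pass(parents, root_value, increments):
    """Downward pass: each node is its parent's value plus its own increment."""
    values = np.empty(len(parents))
    values[0] = root_value
    # Forward index order visits every parent before its children.
    for node in range(1, len(parents)):
        values[node] = values[parents[node]] + increments[node]
    return values

print(up_pass(parents, [0, 0, 0, 1, 1, 1, 1]))  # → [4. 2. 2. 1. 1. 1. 1.]
print(down_pass(parents, 0.0, np.arange(7.0)))   # → [0. 1. 2. 4. 5. 7. 8.]
```

Summing ones placed at the leaves counts the leaves below each node. Hyperiax follows the same up/down pattern but runs user-supplied per-node functions with JAX instead of Python loops; its executors also take a `batch_size` argument, presumably controlling how many nodes are processed together.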

Installation

# Install Hyperiax directly using pip
pip install hyperiax

# Install Hyperiax from the repository, for the newest version
pip install git+https://github.com/ComputationalEvolutionaryMorphometry/hyperiax.git

# Install Hyperiax for development
git clone git@github.com:ComputationalEvolutionaryMorphometry/hyperiax.git
# or (if you haven't set up ssh)
git clone https://github.com/ComputationalEvolutionaryMorphometry/hyperiax.git
# and then install by
pip install -e "hyperiax[dev]"
# and optionally
pip install -e "hyperiax[examples]"
# to install the dependencies for all the example notebooks

Code Examples

  • Set up a tree
import jax
import jax.numpy as jnp
import hyperiax

# Initialize a symmetric tree with a height of 4 and a degree of 3
tree = hyperiax.tree.builders.symmetric_tree(h=4, degree=3)
# Visualize the tree
tree.plot_tree()

# Initialize the node values and branch (edge) lengths with random example data
key = jax.random.PRNGKey(0)
# Split the key so that values, lengths and noise are drawn independently
key_values, key_lengths, key_noise = jax.random.split(key, 3)
# The first dimension should match the number of nodes in `tree` (adjust 16 as needed)
init_values = jax.random.normal(key_values, shape=(16, ))
init_lengths = jax.random.uniform(key_lengths, shape=(16, ))
# Assign the values and lengths by broadcasting
tree["value"] = init_values
tree["edge_length"] = init_lengths
# Initialize the noise within the tree
noisy_tree = hyperiax.tree.initializers.initialize_noise(tree, key_noise, shape=(1, ))
  • Define operations and executor
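# Define the function executed along each edge on the way down (top to bottom)
@jax.jit
def down(noise, edge_length, parent_value, **kwargs):    # example down function, insert your own one
    # Brownian-motion-style step: parent value plus noise scaled by sqrt(edge length)
    return {"value": jnp.sqrt(edge_length) * noise + parent_value}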

# Pass the children's values and edge lengths up to their parent
up = hyperiax.models.functional.pass_up('value', 'edge_length')

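@jax.jit
def fuse(child_value, child_edge_length, **kwargs):    # example fuse function, insert your own one
    # Inverse-edge-length weighted mean of the children's values, assuming
    # child_edge_length has shape (b, 1) and child_value has shape (b, d)
    child_edge_length_inv = 1. / child_edge_length
    res = jnp.einsum('bi,bd->d', child_edge_length_inv, child_value) / child_edge_length_inv.sum()
    return {"value": res}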
  • Run the simulation
# Wrap all the functions in one model
updown_model = hyperiax.models.lambdamodels.UpDownLambda(up, fuse, down)
# Define the executor and run it
exe = hyperiax.execution.DependencyTreeExecutor(updown_model, batch_size=5)
# Do the inference from bottom to top
inf_tree = exe.up(noisy_tree)
# Do the sampling from top to bottom
sample_tree = exe.down(noisy_tree)
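The two example operations can be checked in isolation. The sketch below is a plain NumPy re-creation, outside Hyperiax and without the `**kwargs` plumbing, of a Brownian-motion-style `down` step and an inverse-edge-length weighted mean in the spirit of `fuse`, applied to a hypothetical star tree with one root and three leaves. The shapes `(b, 1)` for child edge lengths and `(b, d)` for child values are assumptions read off the `einsum` in `fuse`, not documented Hyperiax conventions.

```python
import numpy as np

rng = np.random.default_rng(0)

def down(noise, edge_length, parent_value):
    # Brownian increment over an edge of length t has variance t, hence sqrt(t).
    return parent_value + np.sqrt(edge_length) * noise

def fuse(child_value, child_edge_length):
    # child_value: (b, d); child_edge_length: (b, 1) for b children (assumed shapes).
    weights = 1.0 / child_edge_length  # shorter edges get larger weights
    return np.einsum('bi,bd->d', weights, child_value) / weights.sum()

# Hypothetical star tree: one root with three leaf children, 2-dimensional values.
root_value = np.zeros(2)
edge_lengths = np.array([[0.5], [1.0], [2.0]])
noise = rng.standard_normal((3, 2))

leaf_values = down(noise, edge_lengths, root_value)  # sample the leaves (top to bottom)
estimate = fuse(leaf_values, edge_lengths)            # fuse the leaves into the root (bottom to top)
print(leaf_values.shape, estimate.shape)              # → (3, 2) (2,)
```

Because the weights are normalized by their sum, fusing children that all carry the same value returns that value unchanged, which is a quick sanity check for a custom fuse function.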

See Examples for more complete use cases.

Documentation

  • Getting Started: See Getting-Started
  • Guidance: See Wiki
  • Full API Documentation: See Hyperiax API

Todo

Contribution

Contributions, issues and feature requests are all welcome! Please read the contributing guidelines before contributing to the project.

Contact

If you experience problems or have technical questions, please open an issue. For questions related to the Hyperiax project or CCEM, please contact Stefan Sommer.

Contributors

frozaken, bookdiver, nickeopti, michaelsev, stefansommer, mortenapedersen
