
jax-ipu-experimental's Introduction


🔴 Non-official experimental 🔴 JAX on Graphcore IPU


Install guide | Quickstart | IPU JAX on Paperspace | Documentation

🔴 ⚠️ Non-official experimental ⚠️ 🔴

This is a very thin fork of http://github.com/google/jax for Graphcore IPU. This package is provided by Graphcore Research for experimentation purposes only, not production (inference or training).

Features and limitations of experimental JAX on IPUs

The following features are supported:

  • Vanilla JAX API: there is no additional IPU-specific API, so any code written for IPUs is backward compatible with other backends (CPU/GPU/TPU);
  • JAX asynchronous dispatch on the IPU backend;
  • Multiple IPUs with collectives using pmap and (experimental) pjit;
  • Large coverage of JAX lax operators;
  • Support for JAX buffer donation to keep parameters in IPU SRAM.
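
Two of the features above, collectives under pmap and buffer donation, can be sketched with the standard JAX API alone. The snippet below is an illustration rather than code from this project: the function names are hypothetical, and it runs on whichever backend JAX picks by default. On an IPU machine the same code is assumed to work once the IPU backend is selected.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()


def normalize(x):
    # psum adds x across every device taking part in the pmap.
    total = jax.lax.psum(x, axis_name="devices")
    return x / total


# One program instance per local device, linked by the "devices" axis.
parallel_normalize = jax.pmap(normalize, axis_name="devices")

# The leading axis must match the number of devices: one row per device.
x = jnp.arange(1.0, n_devices + 1.0).reshape(n_devices, 1)
shares = parallel_normalize(x)
print(shares.sum())


def sgd_step(params, grads):
    # Hypothetical update rule, only here to produce a same-shaped output.
    return params - 0.1 * grads


# donate_argnums=0 lets JAX reuse the memory of `params` for the result.
# Backends without donation support ignore the hint and emit a warning.
sgd_step = jax.jit(sgd_step, donate_argnums=0)

params = jnp.ones(4)
grads = jnp.full(4, 2.0)
params = sgd_step(params, grads)  # rebind: the donated buffer may be invalid
print(params)
```

Rebinding `params` to the result matters: once a buffer has been donated, the old array should not be read again.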

Known limitations of the project:

  • No eager mode (every JAX call has to be compiled, loaded, and finally executed on the IPU device);
  • Generated IPU code can be larger than with the official Graphcore TensorFlow or PopTorch (limiting batch size or model size);
  • Multi-IPU collectives have topology restrictions (following the Graphcore GCL API);
  • Missing linear algebra operators;
  • Incomplete support for JAX random number generation on the IPU device;
  • Support for JAX infeeds and outfeeds is deactivated.
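
Because there is no eager mode, it usually pays to group work into a single jitted function so that one program is compiled and then reused. The sketch below uses only stock JAX and a hypothetical function name. It counts how often the function is traced, which shows when a new compilation is triggered. The cost of compiling and loading on an actual IPU is not measured here.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def fused_step(x):
    global trace_count
    # Python side effect: this line runs only while JAX traces the function,
    # not when the compiled program is executed.
    trace_count += 1
    y = jnp.tanh(x)
    return jnp.sum(y * y) + 1.0


a = fused_step(jnp.ones(8))    # first call: traced and compiled
b = fused_step(jnp.zeros(8))   # same shape and dtype: compiled program reused
c = fused_step(jnp.ones(16))   # new input shape: traced and compiled again
print("number of traces:", trace_count)
```

Keeping input shapes and dtypes stable across calls avoids recompilation.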

This is a research project, not an official Graphcore product. Expect bugs and sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!

Installation

The experimental JAX wheels require Ubuntu 20.04, Graphcore Poplar SDK 3.1 or 3.2, and Python 3.8, and can be installed as follows:

pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk310 -f https://graphcore-research.github.io/jax-experimental/wheels.html

For SDK 3.2, please change jaxlib version to jaxlib==0.3.15+ipu.sdk320.
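
After installing, a quick way to check that the IPU wheels (and not stock JAX) are the ones in the environment is to look for the `+ipu` local version label used in the pip command above. This is a minimal sketch using only the Python standard library. The helper names are hypothetical and not part of this project.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def is_ipu_build(version: Optional[str]) -> bool:
    """True if the version carries the '+ipu' local label of the IPU wheels."""
    return version is not None and "+ipu" in version


for package in ("jax", "jaxlib"):
    version = installed_version(package)
    status = "IPU build" if is_ipu_build(version) else "not an IPU build"
    print(package, version, status)
```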

Minimal example

The following example can be run on Graphcore IPU Paperspace (or on a non-IPU machine using the IPU emulator):

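from functools import partial
import jax
import numpy as np

# Compile the function for the IPU backend.
@partial(jax.jit, backend="ipu")
def ipu_function(data):
    return data**2 + 1

data = np.array([1, -2, 3], np.float32)
# The first call compiles the function, loads it on the IPU, then runs it.
output = ipu_function(data)
# Print the result and the device holding it.
print(output, output.device())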
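
Since the API is vanilla JAX, the same function can be written so that it runs on an IPU when one is available and falls back to the default backend otherwise. The sketch below is one possible way to do this, not the project's recommended pattern. `pick_device` is a hypothetical helper, and it assumes that `jax.devices("ipu")` lists IPU devices in this fork and raises `RuntimeError` in a JAX build without IPU support.

```python
import jax
import numpy as np


def pick_device(preferred="ipu"):
    """Return the first device of the preferred platform, else the default one."""
    try:
        return jax.devices(preferred)[0]
    except RuntimeError:
        # The preferred platform is unknown or failed to initialise.
        return jax.devices()[0]


@jax.jit
def square_plus_one(data):
    return data**2 + 1


device = pick_device()
# Placing the input on a device makes the jitted call run on that device.
data = jax.device_put(np.array([1, -2, 3], np.float32), device)
output = square_plus_one(data)
print(output, device.platform)
```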

JAX on IPU Paperspace notebooks

Additional JAX on IPU examples:

Useful JAX backend flags:

As is standard in JAX, these flags can be set on the config object imported with from jax.config import config.

| Flag | Description |
| --- | --- |
| config.FLAGS.jax_platform_name = 'ipu'/'cpu' | Configure default JAX backend. Useful for CPU initialization. |
| config.FLAGS.jax_ipu_use_model = True | Use IPU model emulator. |
| config.FLAGS.jax_ipu_model_num_tiles = 8 | Set the number of tiles in the IPU model. |
| config.FLAGS.jax_ipu_device_count = 2 | Set the number of IPUs visible in JAX. Can be any local IPU available. |
| config.FLAGS.jax_ipu_visible_devices = '0,1' | Set the specific collection of local IPUs to be visible in JAX. |

Alternatively, like other JAX flags, these can be set using environment variables (e.g. JAX_IPU_USE_MODEL, JAX_IPU_MODEL_NUM_TILES,...).
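
As a sketch of the environment-variable route, the snippet below requests the IPU model emulator with 8 tiles and then makes CPU the default backend for initialization. It rests on two assumptions. The first is that the values `"true"` and `"8"` are parsed like other JAX boolean and integer flags. The second is that the variables must be set before `import jax`, because JAX reads flag defaults at import time. Stock JAX ignores the two IPU variables, so the snippet also runs without the IPU wheels.

```python
import os

# Assumption: set these before importing JAX so the flag defaults pick them up.
os.environ["JAX_IPU_USE_MODEL"] = "true"      # use the IPU model emulator
os.environ["JAX_IPU_MODEL_NUM_TILES"] = "8"   # number of tiles in the IPU model

import jax

# Standard JAX flag: run on CPU by default, e.g. for parameter initialization.
jax.config.update("jax_platform_name", "cpu")
print(jax.default_backend())
```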

Useful PopVision environment variables:

  • Generate a PopVision Graph Analyser profile: POPLAR_ENGINE_OPTIONS='{"autoReport.all":"true", "debug.allowOutOfMemory":"true"}'
  • Generate a PopVision System Analyser profile: PVTI_OPTIONS='{"enable":"true", "directory":"./reports"}'
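
Both variables hold JSON, so they can also be built from Python dictionaries instead of hand-written strings, which avoids shell quoting mistakes. This is a minimal sketch. The helper `popvision_env` is hypothetical, the option keys are exactly those listed above, and the variables are assumed to need setting before the IPU program is compiled and run.

```python
import json
import os


def popvision_env(report_dir="./reports"):
    """Build the two PopVision environment variables as JSON strings."""
    engine_options = {
        "autoReport.all": "true",
        "debug.allowOutOfMemory": "true",
    }
    pvti_options = {"enable": "true", "directory": report_dir}
    return {
        "POPLAR_ENGINE_OPTIONS": json.dumps(engine_options),
        "PVTI_OPTIONS": json.dumps(pvti_options),
    }


# Export the variables for the current process and any child processes.
os.environ.update(popvision_env())
print(os.environ["POPLAR_ENGINE_OPTIONS"])
print(os.environ["PVTI_OPTIONS"])
```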

Documentation

License

The project remains licensed under the Apache License 2.0, with the following files unchanged:

The additional dependencies introduced for Graphcore IPU support are:
