Jochastic: stochastically rounded operations between JAX tensors.

This repository contains a JAX software-based implementation of some stochastically rounded operations.

When encoding the weights of a neural network in low precision (such as bfloat16), one runs into stagnation problems: updates end up being too small relative to the precision of the encoding and are rounded away. This leads to weights becoming stuck and the model's accuracy being significantly reduced.

Stochastic arithmetic lets you perform the operations in such a way that the weights have a non-zero probability of being modified anyway. This avoids the stagnation problem (see figure 4 of "Revisiting BFloat16 Training") without increasing the memory usage (as might happen if one were using a compensated summation to solve the problem).
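
To make the stagnation concrete, here is a minimal sketch (the values are illustrative and not taken from this repository) in which a small update is added to a bfloat16 weight one hundred times. With round-to-nearest the weight never moves, while the same accumulation in float32 does:

```python
import jax.numpy as jnp

# bfloat16 keeps 8 bits of significand precision, so the gap between
# representable numbers just above 1.0 is 2**-7 = 0.0078125.
weight = jnp.asarray(1.0, dtype=jnp.bfloat16)
update = jnp.asarray(1e-3, dtype=jnp.bfloat16)

# The update is smaller than half that gap: round-to-nearest discards it
# at every step and the weight stays stuck.
stagnated = weight
for _ in range(100):
    stagnated = stagnated + update

# The same accumulation in float32 does move the weight.
reference = jnp.asarray(1.0, dtype=jnp.float32)
for _ in range(100):
    reference = reference + jnp.float32(1e-3)

print(f"bfloat16: {stagnated}, float32: {reference}")
```

Stochastic rounding targets exactly this situation: each addition would have a small probability of moving the weight to the next representable value, so the updates are preserved on average.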

The downside is that software-based stochastic arithmetic is significantly slower than normal floating-point arithmetic. It is thus viable for things like the weight update (when using the output of an Optax optimizer for example) but would not be appropriate in a hot loop.

Do not hesitate to submit an issue or a pull request if you need additional functionality!

Usage

This repository introduces the add and tree_add operations. They take a PRNGKey and two tensors (or two pytrees, respectively) to be added, and round the result up or down randomly:

import jax
import jax.numpy as jnp
import jochastic

# problem definition
size = 10
dtype = jnp.bfloat16
key = jax.random.PRNGKey(1993)

# deterministic addition
key, keyx, keyy = jax.random.split(key, num=3)
x = jax.random.normal(keyx, shape=(size,), dtype=dtype)
y = jax.random.normal(keyy, shape=(size,), dtype=dtype)
result = x + y
print(f"deterministic addition: {result}")

# stochastic addition
result_sto = jochastic.add(key, x, y)
print(f"stochastic addition: {result_sto} ({result_sto.dtype})")
difference = result - result_sto
print(f"difference: {difference}")

Both functions take an optional is_biased boolean parameter. If is_biased is True (the default), the random number generator is biased according to the relative error of the operation; otherwise, the result is rounded up half of the time on average.

Jitting the functions is left to the user's discretion (you will need to indicate that is_biased is static).
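
For pytrees, a keyed operation needs independent randomness for each leaf. The sketch below shows one way this could be arranged with jax.tree_util; it is an illustration, not this repository's tree_add. The names tree_apply_with_keys and noisy_add are hypothetical, and noisy_add is only a placeholder for a function with a (key, x, y) signature such as add:

```python
import jax
import jax.numpy as jnp


def tree_apply_with_keys(key, fn, tree_x, tree_y):
    """Call fn(leaf_key, x_leaf, y_leaf) on matching leaves, one fresh key per leaf."""
    leaves, treedef = jax.tree_util.tree_flatten(tree_x)
    # Split the parent key so that no two leaves reuse the same randomness.
    leaf_keys = jax.random.split(key, len(leaves))
    tree_keys = jax.tree_util.tree_unflatten(treedef, list(leaf_keys))
    return jax.tree_util.tree_map(fn, tree_keys, tree_x, tree_y)


def noisy_add(key, x, y):
    # Hypothetical placeholder for a keyed addition; it is NOT stochastic rounding.
    return x + y + 1e-3 * jax.random.normal(key, x.shape, dtype=x.dtype)


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
updates = {"w": jnp.full((2, 3), 0.5), "b": jnp.full((3,), 0.5)}
new_params = tree_apply_with_keys(jax.random.PRNGKey(0), noisy_add, params, updates)
print(jax.tree_util.tree_map(jnp.shape, new_params))
```

Splitting the key once per leaf keeps the random draws of different parameters independent, which matters when the same update rule is applied to every tensor of a model.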

NOTE: Very low precision (16-bit floating-point arithmetic or less) is extremely brittle. We recommend using higher precision locally (such as using 32-bit floating-point arithmetic to compute the optimizer's update), then casting down to 16 bits at summing / storage time (something that PyTorch does transparently when using its addcdiv in low precision). Both functions accept mixed-precision inputs (adding a high-precision number to a low-precision one), use that information for the rounding, then return an output in the lowest precision of their inputs (contrary to most casting conventions).

Implementation details

We use TwoSum to measure the numerical error introduced by the addition. Our tests show that it behaves as needed on bfloat16 (some edge cases might be invalid, leading to an inexact computation of the numerical error, but it is reliable enough for our purpose).

This and the nextafter function let us emulate various rounding modes in software (this is inspired by Verrou's backend).
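
The combination of TwoSum and nextafter can be sketched as follows. This is a simplified illustration written for this README, not the repository's actual code: it assumes both inputs share the same dtype, runs eagerly, and the names two_sum_error and stochastic_add_sketch are hypothetical.

```python
import jax
import jax.numpy as jnp


def two_sum_error(x, y, result):
    """Rounding error of result = x + y, via the TwoSum transformation."""
    # The order of operations matters: each step must be rounded on its own.
    y_virtual = result - x
    x_virtual = result - y_virtual
    return (x - x_virtual) + (y - y_virtual)


def stochastic_add_sketch(key, x, y, is_biased=True):
    """Illustrative stochastic addition of two arrays of the same dtype."""
    result = x + y  # round-to-nearest sum
    error = two_sum_error(x, y, result)  # approximately (exact sum) - result
    # Neighbouring float on the other side of the exact sum.
    direction = jnp.where(error > 0, jnp.inf, -jnp.inf).astype(result.dtype)
    alternative = jnp.nextafter(result, direction)
    if is_biased:
        # Pick the neighbour with probability |error| / (gap between candidates).
        gap = jnp.abs(alternative.astype(jnp.float32) - result.astype(jnp.float32))
        draw = jax.random.uniform(key, result.shape, dtype=jnp.float32)
        use_alternative = draw * gap < jnp.abs(error).astype(jnp.float32)
    else:
        # Unbiased coin flip between the two candidates.
        use_alternative = jax.random.bernoulli(key, 0.5, result.shape)
    # An exact sum (zero error) is never moved.
    use_alternative = use_alternative & (error != 0)
    return jnp.where(use_alternative, alternative, result)


key = jax.random.PRNGKey(0)
weights = jnp.full((20_000,), 1.0, dtype=jnp.bfloat16)
updates = jnp.full((20_000,), 1e-3, dtype=jnp.bfloat16)
rounded = stochastic_add_sketch(key, weights, updates)
as_float = rounded.astype(jnp.float32)
print(as_float.min(), as_float.max(), as_float.mean())
```

Each output is one of the two bfloat16 values surrounding the exact sum, and with is_biased=True their average stays close to the exact sum even though a single round-to-nearest addition would return 1.0 every time. Because is_biased drives a Python-level branch in this sketch, it would have to be marked as static (for example through the static_argnames argument of jax.jit) if the function were jitted.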

Crediting this work

You can use this BibTeX reference if you use Jochastic within a published work:

@misc{Jochastic,
  author = {Nestor, Demeure},
  title = {Jochastic: stochastically rounded operations between JAX tensors.},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/nestordemeure/jochastic}}
}

A PyTorch implementation called StochasTorch is also available.
