
autobound's Introduction

AutoBound: Automatically Bounding Functions


AutoBound is a generalization of automatic differentiation. In addition to computing a Taylor polynomial approximation of a function, it computes upper and lower bounds that are guaranteed to hold over a user-specified trust region.

As an example, here are the quadratic upper and lower bounds AutoBound computes for the function f(x) = 1.5*exp(3*x) - 25*(x**2), centered at 0.5, and valid over the trust region [0, 1].

[Plot: example quadratic upper and lower bounds]

The code to compute the bounds shown in this plot looks like this (see quickstart):

import autobound.jax as ab
import jax.numpy as jnp

f = lambda x: 1.5*jnp.exp(3*x) - 25*x**2
x0 = .5
trust_region = (0, 1)
# Compute quadratic upper and lower bounds on f.
bounds = ab.taylor_bounds(f, max_degree=2)(x0, trust_region)
# bounds.upper(1) == 5.1283045 == f(1)
# bounds.lower(0) == 1.5 == f(0)
# bounds.coefficients == (0.47253323, -4.8324013, (-5.5549355, 28.287888))
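
As a quick sanity check, the bounds object returned by taylor_bounds can be evaluated at any point in the trust region. A minimal sketch (the grid and tolerance below are our own choices, not part of the API):

import numpy as np

# Verify that f(x) lies between the computed bounds across the trust region.
tol = 1e-4  # small slack for floating-point rounding
for x in np.linspace(0.0, 1.0, 11):
    assert bounds.lower(x) - tol <= f(x) <= bounds.upper(x) + tol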

These bounds have a variety of applications, such as computing learning rates that are guaranteed to decrease a loss function (see the safe-learning-rates notebook discussed in the issues below), and more!

Under the hood, AutoBound computes these bounds using an interval arithmetic variant of Taylor-mode automatic differentiation. Accordingly, the memory requirements are linear in the input dimension, and the method is only practical for functions with low-dimensional inputs. A reverse-mode algorithm that efficiently handles high-dimensional inputs is under development.
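
Concretely, the enclosure's leading coefficients are scalars and its final coefficient is an interval, as in bounds.coefficients above. The following sketch (our illustration, not the library's internals) shows how such a degree-2 enclosure yields the plotted upper and lower bounds:

def eval_quadratic_enclosure(coefficients, x0, x):
    # coefficients == (c0, c1, (c2_lo, c2_hi)), meaning that for all x in the
    # trust region, f(x) lies in c0 + c1*(x - x0) + [c2_lo, c2_hi]*(x - x0)**2.
    c0, c1, (c2_lo, c2_hi) = coefficients
    d = x - x0
    # Because d**2 >= 0, the interval endpoints give lower and upper bounds.
    return c0 + c1*d + c2_lo*d*d, c0 + c1*d + c2_hi*d*d

lower_at_1, upper_at_1 = eval_quadratic_enclosure(bounds.coefficients, 0.5, 1.0)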

A detailed description of the AutoBound algorithm can be found in the paper cited under "Citing AutoBound" below (arXiv:2212.11429).

Installation

Assuming you have installed pip, you can install this package directly from GitHub with

pip install git+https://github.com/google/autobound.git

or from PyPI with

pip install autobound

You may need to upgrade pip before running these commands.

Testing

To run unit tests, first install the packages the unit tests depend on with

pip install autobound[dev]

As above, you may need to install or upgrade pip before running this command.

Then, download the source code and run the tests using

git clone https://github.com/google/autobound.git
python3 -m pytest autobound

or

pip install -e git+https://github.com/google/autobound.git#egg=autobound
python3 -m pytest src/autobound

Limitations

The current code has a few limitations:

  • Only JAX-traceable functions can be automatically bounded.
  • Many JAX library functions are not yet supported. What is currently supported is enough to bound the squared-error loss of a multi-layer perceptron or convolutional neural network that uses the jax.nn.sigmoid, jax.nn.softplus, or jax.nn.swish activation functions.
  • To compute accurate bounds for deeper neural networks, you may need to use float64 rather than float32 (see the snippet after this list).
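
JAX defaults to float32; switching to float64 requires enabling x64 support before any arrays are created (this is standard JAX configuration, not autobound-specific):

import jax

# Make float64 the default dtype; must run before any arrays are created.
jax.config.update("jax_enable_x64", True)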

Citing AutoBound

To cite this repository:

@article{autobound2022,
  title={Automatically Bounding the Taylor Remainder Series: Tighter Bounds and New Applications},
  author={Streeter, Matthew and Dillon, Joshua V},
  journal={arXiv preprint arXiv:2212.11429},
  url = {http://github.com/google/autobound},
  year={2022}
}

This is not an officially supported Google product.

autobound's People

Contributors

jack-mcivor, mstreeter, wdhdev


autobound's Issues

Quickstart Notebook Error

Hi!

This seems to be an awesome tool for Taylor models, but I am having trouble using it from the start. I have tried following the quickstart notebook both in a Google Colab notebook and on my local Windows and Linux machines, but to no avail. Any ideas on what might be causing this?

import autobound.jax as ab
import jax.numpy as jnp

f = lambda x: 1.5 * jnp.exp(3 * x) - 25*x**2
x0 = .5
trust_region = (0, 1)

bounds = ab.taylor_bounds(f, max_degree=2)(x0, trust_region)
bounds.coefficients


TypeError                                 Traceback (most recent call last)
in <cell line: 8>()
      6 trust_region = (0, 1)
      7 # Compute quadratic upper and lower bounds on f.
----> 8 bounds = ab.taylor_bounds(f, max_degree=2)(x0, trust_region)
      9 bounds.coefficients  # == (f(x0), f'(x0), )

5 frames
/usr/local/lib/python3.10/dist-packages/autobound/jax/jaxpr_editor.py in vertex_to_var_or_literal(vertex)
    170   if vertex[0]:
    171     _, count, suffix, aval = vertex
--> 172     if count not in count_to_var:
    173       count_to_var[count] = jax.core.Var(count, suffix, aval)
    174     return count_to_var[count]

TypeError: Var.__init__() takes 3 positional arguments but 4 were given
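
A hedged reading of this traceback, not confirmed by the maintainers: autobound's jaxpr_editor constructs jax.core.Var with three arguments (count, suffix, aval), while the installed JAX's Var.__init__ accepts fewer, which suggests a JAX version incompatibility. The version-pinning workaround sketched under the "'Quickstart' notebook outputs an error" issue below may apply here as well.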

Implement trigonometric functions

Sine and cosine are not implemented, as you state in the paper:

For example, we do not currently have a way to compute sharp Taylor enclosures for periodic functions, such as sine or cosine.

I only skimmed the paper, but I take it that supporting periodic functions is a non-trivial extension?

Bug in Colab Notebook

Hi guys,

First of all, great work, really cool stuff.

I just had a look at this repo and the Colab notebook for safe learning rates. I am not sure if this is the right place to post this, since it is not really a problem with the package itself, but I think there is a bug in the saferate_optimizer in the aforementioned notebook.

def saferate_optimizer(loss, initial_max_eta: float = 1.):
  def init_fun(x0):
    return (x0, 0., initial_max_eta)

  def update_state(state):
    x, _, max_eta = state
    jax.tree_util.tree_map(lambda v: -v, jax.grad(loss)(x))
    safe_eta = safe_learning_rate(x, update_dir, max_eta)
    next_x = jax.tree_util.tree_map(
        lambda p, v: p + safe_eta*v, x, update_dir
    )
    next_max_eta = jnp.where(
        # If safe_eta is NaN, we cut the learning rate in half.
        jnp.logical_or(safe_eta < max_eta / 2, safe_eta > safe_eta),
        max_eta / 2,
        max_eta * 2
    )
    return (next_x, safe_eta, next_max_eta)

  def get_params(state):
    x, _, _ = state
    return x

  return init_fun, update_state, get_params

In the update_state function, the line jax.tree_util.tree_map(lambda v: -v, jax.grad(loss)(x)) is not assigned to anything; I assume it should be assigned to update_dir. Currently, update_dir is taken from an outer scope.

Additionally, just for readability, I would recommend adding loss as an explicit argument to the bound_loss function at the beginning of the notebook, as it also comes from an outer scope.
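
For concreteness, here is a sketch of update_state with the reported fix applied (this assumes the notebook's loss and safe_learning_rate functions are in scope; it is the reporter's suggestion, not an official patch):

def update_state(state):
    x, _, max_eta = state
    # Fix: bind the negated gradient to update_dir instead of discarding it.
    update_dir = jax.tree_util.tree_map(lambda v: -v, jax.grad(loss)(x))
    safe_eta = safe_learning_rate(x, update_dir, max_eta)
    next_x = jax.tree_util.tree_map(
        lambda p, v: p + safe_eta*v, x, update_dir
    )
    next_max_eta = jnp.where(
        # If safe_eta is NaN, we cut the learning rate in half.
        jnp.logical_or(safe_eta < max_eta / 2, safe_eta > safe_eta),
        max_eta / 2,
        max_eta * 2
    )
    return (next_x, safe_eta, next_max_eta)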

'Quickstart' notebook outputs an error

Thanks for the interesting work! When I tried to run your quickstart.ipynb notebook, I got an error message at the second code cell, where the first call to your library is made:

bounds = ab.taylor_bounds(f, max_degree=2)(x0, trust_region)

The error shows up both on the "colab" link I followed from your github repo (with jax 0.4.20), and on my machine (with jax 0.4.19). The error reads

AttributeError: module 'jax' has no attribute 'abstract_arrays'
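
jax.abstract_arrays was removed in more recent JAX releases, so this looks like a JAX version incompatibility rather than a bug in the notebook. A possible workaround, with the caveat that the exact compatible versions are unverified, is to pin older jax and jaxlib releases, e.g.

pip install "jax==0.4.10" "jaxlib==0.4.10"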

Find upper and lower bounds for a simple MLP function

How can I use AutoBound to compute upper and lower bounds on an MLP function?
I wanted to run this script:


import jax.numpy as jnp
import jax.nn
from jax import random
import autobound.jax as ab

def initialize_mlp_params(rng_key, input_dim, hidden_dim, output_dim):
    k1, k2, k3, k4 = random.split(rng_key, 4)
    weights_hidden1 = random.normal(k1, (input_dim, hidden_dim))
    biases_hidden1 = jnp.zeros(hidden_dim)
    weights_hidden2 = random.normal(k2, (hidden_dim, hidden_dim))
    biases_hidden2 = jnp.zeros(hidden_dim)
    weights_hidden3 = random.normal(k3, (hidden_dim, hidden_dim))
    biases_hidden3 = jnp.zeros(hidden_dim)
    weights_output = random.normal(k4, (hidden_dim, output_dim))
    biases_output = jnp.zeros(output_dim)
    return (weights_hidden1, biases_hidden1,
            weights_hidden2, biases_hidden2,
            weights_hidden3, biases_hidden3,
            weights_output, biases_output)

def mlp(params, x):
    (weights_hidden1, biases_hidden1,
     weights_hidden2, biases_hidden2,
     weights_hidden3, biases_hidden3,
     weights_output, biases_output) = params
    hidden_layer1 = jax.nn.softplus(jnp.dot(x, weights_hidden1) + biases_hidden1)
    hidden_layer2 = jax.nn.softplus(jnp.dot(hidden_layer1, weights_hidden2) + biases_hidden2)
    hidden_layer3 = jax.nn.softplus(jnp.dot(hidden_layer2, weights_hidden3) + biases_hidden3)
    return jnp.dot(hidden_layer3, weights_output) + biases_output

input_dim = 2
hidden_dim = 10
output_dim = 1
rng_key = random.PRNGKey(0)
params = initialize_mlp_params(rng_key, input_dim, hidden_dim, output_dim)
x0 = jnp.array([0.5, 0.5])
trust_region = (jnp.array([0, 0]), jnp.array([1, 1]))
mlp_lambda = lambda x: mlp(params, x)
bounds = ab.taylor_bounds(mlp_lambda, max_degree=2)(x0, trust_region)
bounds.coefficients


but I got this error:

TypeError                                 Traceback (most recent call last)
in <cell line: 47>()
     45 mlp_lambda = lambda x: mlp(params, x)
     46 # Use the mlp function directly in taylor_bounds
---> 47 bounds = ab.taylor_bounds(mlp_lambda, max_degree=2)(x0, trust_region)
     48 bounds.coefficients

14 frames
/usr/local/lib/python3.10/dist-packages/autobound/jax/jax_bound.py in bound_fun(x0, x_trust_region)
140 if fun is None:
141 raise NotImplementedError(eqn.primitive)
--> 142 outvar_enclosures = fun(*invar_intermediates, **eqn.params)
143 if len(eqn.outvars) == 1:
144 outvar_enclosures = (outvar_enclosures,)

/usr/local/lib/python3.10/dist-packages/autobound/jax/jax_bound.py in g(intermediate)
415 f = arithmetic.get_elementwise_fun(get_enclosure)
416 def g(intermediate):
--> 417 return f(intermediate.enclosure, intermediate.trust_region)
418 return g
419

/usr/local/lib/python3.10/dist-packages/autobound/enclosure_arithmetic.py in fun(arg_enclosure, arg_trust_region)
292 self.max_degree,
293 self.np_like)
--> 294 return self.compose_enclosures(elementwise_enclosure, arg_enclosure)
295 return fun
296

/usr/local/lib/python3.10/dist-packages/autobound/enclosure_arithmetic.py in compose_enclosures(self, elementwise_enclosure, arg_enclosure)
241 term = (coefficient,)
242 else:
--> 243 poly = self.power(arg_diff_enclosure, p)
244 term = tuple(
245 # The special-casing when i < p ensures that the TaylorEnclosure

/usr/local/lib/python3.10/dist-packages/autobound/enclosure_arithmetic.py in power(self, a, p)
330 np_like=self.np_like)
331 multiplicative_identity = self.np_like.ones_like(self.trust_region[0])
--> 332 result = polynomials.integer_power( # pytype: disable=wrong-arg-types
333 a,
334 p,

/usr/local/lib/python3.10/dist-packages/autobound/polynomials.py in integer_power(a, exponent, add, additive_identity, multiplicative_identity, term_product_coefficient, term_power_coefficient, scalar_product)
222 return c
223 output_degree = (len(a) - 1) * exponent
--> 224 return tuple(get_coeff(i) for i in range(1 + output_degree))
225
226

/usr/local/lib/python3.10/dist-packages/autobound/polynomials.py in <genexpr>(.0)
222 return c
223 output_degree = (len(a) - 1) * exponent
--> 224 return tuple(get_coeff(i) for i in range(1 + output_degree))
225
226

/usr/local/lib/python3.10/dist-packages/autobound/polynomials.py in get_coeff(i)
210 running_product_power = 0
211 for j, p_j in enumerate(p):
--> 212 running_product = term_product_coefficient(
213 running_product,
214 term_power_coefficient(a[j], j, p_j),

/usr/local/lib/python3.10/dist-packages/autobound/enclosure_arithmetic.py in _elementwise_term_product_coefficient(c0, c1, i, j, x_ndim, np_like)
443 return _pairwise_batched_multiply(u, v, ix_ndim, jx_ndim, np_like)
444 set_arithmetic = interval_arithmetic.IntervalArithmetic(np_like)
--> 445 return set_arithmetic.arbitrary_bilinear(c0, c1, product, assume_product=True)
446
447

/usr/local/lib/python3.10/dist-packages/autobound/interval_arithmetic.py in arbitrary_bilinear(self, a, b, bilinear, assume_product)
74 b_is_interval = isinstance(b, tuple)
75 if not a_is_interval and not b_is_interval:
---> 76 return bilinear(a, b)
77
78 if assume_product:

/usr/local/lib/python3.10/dist-packages/autobound/enclosure_arithmetic.py in product(u, v)
441 """Returns d such that <c0, zi> * <c1, zj> == <d, z**(i+j)>."""
442 def product(u, v):
--> 443 return _pairwise_batched_multiply(u, v, ix_ndim, jx_ndim, np_like)
444 set_arithmetic = interval_arithmetic.IntervalArithmetic(np_like)
445 return set_arithmetic.arbitrary_bilinear(c0, c1, product, assume_product=True)

/usr/local/lib/python3.10/dist-packages/autobound/enclosure_arithmetic.py in _pairwise_batched_multiply(u, v, p, q, np_like)
472 u = np_like.asarray(u)
473 v = np_like.asarray(v)
--> 474 return expand_multiple_dims(u, q) * expand_multiple_dims(v, p, v.ndim-q)
475
476

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in deferring_binary_op(self, other)
256 args = (other, self) if swap else (self, other)
257 if isinstance(other, _accepted_binop_types):
--> 258 return binary_op(*args)
259 if isinstance(other, rejected_binop_types):
260 raise TypeError(f"unsupported operand type(s) for {opchar}: "
[... skipping hidden 12 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/ufuncs.py in fn(x1, x2)
     95 def fn(x1, x2, /):
     96   x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
---> 97   return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
     98 fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
     99 fn = jit(fn, inline=True)
[... skipping hidden 7 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py in broadcasting_shape_rule(name, *avals)
1577 result_shape.append(non_1s[0])
1578 else:
-> 1579 raise TypeError(f'{name} got incompatible shapes for broadcasting: '
1580 f'{", ".join(map(str, map(tuple, shapes)))}.')
1581

TypeError: mul got incompatible shapes for broadcasting: (2, 1), (10, 2).
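
Not a confirmed fix, but worth noting: the Limitations section above says the supported use case is bounding the squared-error loss of an MLP, i.e. a scalar-valued function, whereas this mlp returns a vector. Wrapping the network in such a loss (with a hypothetical target y_target) may sidestep the shape error:

y_target = jnp.array([1.0])  # hypothetical target value
# Bound the scalar squared-error loss rather than the vector-valued MLP output.
loss = lambda x: jnp.sum((mlp(params, x) - y_target)**2)
loss_bounds = ab.taylor_bounds(loss, max_degree=2)(x0, trust_region)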

Support all JAX primitives

  • abs
  • acos
  • acosh
  • add
  • after_all
  • all_gather
  • all_to_all
  • and
  • approx_top_k
  • argmax
  • argmin
  • asin
  • asinh
  • atan
  • atan2
  • atanh
  • axis_index
  • bessel_i0e
  • bessel_i1e
  • bitcast_convert_type
  • broadcast_in_dim
  • cbrt
  • ceil
  • clamp
  • clz
  • complex
  • concatenate
  • cond
  • conj
  • conv_general_dilated
  • convert_element_type
  • copy
  • cos
  • cosh
  • create_token
  • cumlogsumexp
  • cummax
  • cummin
  • cumprod
  • cumsum
  • custom_linear_solve
  • device_put
  • digamma
  • div
  • dot_general
  • dynamic_slice
  • dynamic_update_slice
  • eq
  • erf
  • erf_inv
  • erfc
  • exp
  • expm1
  • fft
  • floor
  • gather
  • ge
  • gt
  • igamma
  • igamma_grad_a
  • igammac
  • imag
  • infeed
  • integer_pow
  • iota
  • is_finite
  • le
  • lgamma
  • log
  • log1p
  • logistic
  • lt
  • max
  • min
  • mul
  • ne
  • neg
  • nextafter
  • not
  • or
  • outfeed
  • pad
  • pmax
  • pmin
  • population_count
  • pow
  • ppermute
  • psum
  • random_gamma_grad
  • real
  • reduce
  • reduce_and
  • reduce_max
  • reduce_min
  • reduce_or
  • reduce_precision
  • reduce_prod
  • reduce_sum
  • reduce_window
  • reduce_window_max
  • reduce_window_min
  • reduce_window_sum
  • reduce_xor
  • regularized_incomplete_beta
  • rem
  • reshape
  • rev
  • rng_bit_generator
  • rng_uniform
  • round
  • rsqrt
  • scan
  • scatter
  • scatter-add
  • scatter-max
  • scatter-min
  • scatter-mul
  • select_and_gather_add
  • select_and_scatter
  • select_and_scatter_add
  • select_n
  • shift_left
  • shift_right_arithmetic
  • shift_right_logical
  • sign
  • sin
  • sinh
  • slice
  • sort
  • sqrt
  • squeeze
  • stop_gradient
  • sub
  • tan
  • tanh
  • top_k
  • transpose
  • while
  • xor
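
As an aside, a quick way to see which primitives a given function actually uses, and hence whether it falls within the supported set, is to inspect its jaxpr (standard JAX functionality, independent of autobound):

import jax
import jax.numpy as jnp

f = lambda x: 1.5 * jnp.exp(3 * x) - 25 * x**2
# Print the set of primitive names appearing in f's jaxpr.
jaxpr = jax.make_jaxpr(f)(0.5)
print({eqn.primitive.name for eqn in jaxpr.jaxpr.eqns})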

Multidimensional Example

It would be nice to expand the README to show how to use autobound with vector-valued inputs and interpret the output.
From looking into the code, I thought that

import autobound.jax as ab
import jax.numpy as jnp

f = lambda x: 1.5*jnp.exp(3.0*x[1]*x[0]) - 25.0*x[0]**2
trust_region = (jnp.array([0, 0]), jnp.array([1, 1]))
x0 = jnp.array([.5, .5])
bounds = ab.taylor_bounds(f, max_degree=2)(x0, trust_region)

should work, but that runs into issues with dynamic_slice not being implemented.
The following worked:

# Workaround: select components via dot products with constant one-hot
# vectors, which avoids the unsupported dynamic_slice primitive.
A = jnp.array([0, 1])  # selects x[1]
B = jnp.array([1, 0])  # selects x[0]
f = lambda x: 1.5*jnp.exp(3.0*jnp.dot(A, x)*jnp.dot(B, x)) - 25*jnp.dot(B, x)**2

Perhaps there is an intended, easier way I am missing?

Torch and TensorFlow support

Thanks to the authors and sponsors, this looks like it could be very useful for AI/ML.

Is there support for Torch and TensorFlow on the way?
