
optimistix's People

Contributors

colcarroll, packquickly, patrick-kidger, randl

optimistix's Issues

Using OptaxMinimiser results in AttributeError

Using the OptaxMinimiser as solver results in " '_Closure' object has no attribute 'init' " whereas the BFGS solver runs without errors.

The objective function uses a custom equinox pytree.

correct name of the exception class that Equinox uses for runtime errors

Hi,

I'm using a couple of Equinox pytrees in my program, and in one case one is used together with Newton root finding from Optimistix. My larger code is a variant of gradient descent, and occasionally a data point is expected to have no solution to the root-finding problem. When setting up a try/except block, what is the correct name of the exception class that Equinox uses for runtime errors, or should it be something from Optimistix?

instance_of_acceleration = AccelerationPytree(l_pr, regime, kinetic_conservative, rot_dissapative, ld_dissapative, epd_dissapative_1, qe_conservative_1, epd_dissapative_2, epd_dissapative_3, epd_dissapative_4, qe_conservative_2, qe_conservative_3)

solver_root = optx.Newton(rtol=1e-8, atol=1e-8)
y0 = (jnp.array(0.1))
try:
    sol = optx.root_find(fn=time_root_from_distance, solver=solver_root (well_posed=False), y0=y0, args=instance_of_acceleration, options=dict(lower=0.), max_steps=20000, throw=False)
    Thv = sol.value
except eqx.exception_module.EqxRuntimeError. (WHAT GOES HERE?):
    # Set Thv to a default value or handle it accordingly
    Thv = 999.  # Replace with an appropriate default value or action
print(Thv)
return Thv

Please and thanks,
Tom

Classical Newton methods

Hi! Thanks for making this package; I'm finding it really helpful for some optimisation problems I'm solving. Is there a Newton method that uses jax.hessian or similar to minimise a function? I've seen BFGS, but for problems in fewer than 100 dimensions computing the Hessian is fairly fast, and Newton's locally quadratic convergence is something I'd really like. My own few-line implementation (below) supports this. I'm also wondering whether there are line searches to adapt the Newton step (e.g. backtracking, exact line search). Happy to help merge this in if you are interested.

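import jax
import jax.numpy as jnp

def newton_minimise(newton_loss_func, x, args, max_iter, tol):
    grad_fn = jax.grad(newton_loss_func)
    hess_fn = jax.hessian(newton_loss_func)
    g = grad_fn(x, args)  # compute g before the first convergence check
    new_losses = []
    n_iter = 0
    while n_iter < max_iter and jnp.linalg.norm(g) > tol:
        n_iter += 1
        H = hess_fn(x, args)
        step = jnp.linalg.solve(H, g)  # H^{-1} g without forming inv(H)

        # line search - to be upgraded to use a proper line search
        t_vals = jnp.linspace(-2, 0, 101)
        loss_vals = jax.vmap(lambda t: newton_loss_func(x + t * step, args))(t_vals)
        t = t_vals[jnp.argmin(loss_vals)]

        x = x + t * step
        g = grad_fn(x, args)
        new_losses.append(newton_loss_func(x, args))
    return x, new_losses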
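On the line-search question: a common alternative to a grid search over `t` is Armijo backtracking, which starts from the full Newton step and halves it until the loss decreases sufficiently. A minimal pure-JAX sketch, not an Optimistix API; `newton_backtracking`, `c` and `beta` are hypothetical names and values, and it assumes a 1-D `x` and a Hessian that stays positive definite (otherwise the Newton direction need not be a descent direction):

```python
import jax
import jax.numpy as jnp


def newton_backtracking(f, x0, max_iter=50, tol=1e-6, c=1e-4, beta=0.5, max_backtracks=30):
    """Damped Newton minimisation with an Armijo backtracking line search (sketch)."""
    grad_f = jax.grad(f)
    hess_f = jax.hessian(f)
    x = x0
    for _ in range(max_iter):
        g = grad_f(x)
        if jnp.linalg.norm(g) <= tol:
            break
        # Newton direction: solve H d = -g rather than forming inv(H).
        d = jnp.linalg.solve(hess_f(x), -g)
        fx = f(x)
        slope = jnp.vdot(g, d)  # negative when d is a descent direction
        t = 1.0
        # Shrink t until the Armijo sufficient-decrease condition holds.
        for _ in range(max_backtracks):
            if f(x + t * d) <= fx + c * t * slope:
                break
            t *= beta
        x = x + t * d
    return x


# Example: sum(exp(x) - 2x) is minimised at x = log(2) in every coordinate.
x_min = newton_backtracking(lambda x: jnp.sum(jnp.exp(x) - 2.0 * x), jnp.array([3.0, -1.0]))
print(x_min)
```

Starting from `-1.0`, the full Newton step overshoots to a higher loss, so the backtracking loop halves it once before accepting; that is the case the fixed grid in the snippet above handles by brute force.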

pytree output structure mismatch error in backprop during vmap

I am running into ValueError: pytree does not match out_structure errors when computing gradients for functions where optimistix is called via vmap. The errors disappear when replacing jax.vmap with an equivalent for loop. I have included a MWE bug_report.py which can switch between jax.vmap and for loops via the VMAP variable. My first impression is that the implicit solve during backprop gets passed the wrong (unbatched?) input vector.

Traceback (most recent call last):
  File ".../python/bug_report.py", line 74, in <module>
    loss, grads = loss_fn_w_grad(
  File ".../python/bug_report.py", line 56, in loss_fn
    output = batched_model(
  File ".../python/bug_report.py", line 44, in __call__
    return self.output_layer(opt_2st_vec(t))
  File ".../python/bug_report.py", line 22, in opt_2st_vec
    solution = optx.root_find(obj, solver, x0)
  File ".../venv/lib/python3.11/site-packages/optimistix/_root_find.py", line 227, in root_find
    return iterative_solve(
  File ".../venv/lib/python3.11/site-packages/optimistix/_iterate.py", line 346, in iterative_solve
    ) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
  File ".../venv/lib/python3.11/site-packages/optimistix/_adjoint.py", line 148, in apply
    return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)
  File ".../venv/lib/python3.11/site-packages/optimistix/_ad.py", line 72, in implicit_jvp
    root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: pytree does not match out_structure

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../python/bug_report.py", line 74, in <module>
    loss, grads = loss_fn_w_grad(
                  ^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/_ad.py", line 79, in __call__
    return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 413, in _vprim_transpose
    return transpose(cts, *inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 211, in _wrapper
    cts = rule(inputs, cts_out)
          ^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 272, in _linear_solve_transpose
    cts_vector, _, _ = eqxi.filter_primitive_bind(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 264, in filter_primitive_bind
    flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 299, in batch_rule
    out = _vprim_p.bind(
          ^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 337, in _vprim_abstract_eval
    outs = abstract_eval(*inputs, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 147, in _wrapper
    out = rule(*args)
          ^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 115, in _linear_solve_abstract_eval
    out = eqx.filter_eval_shape(
          ^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 86, in _linear_solve_impl
    out = solver.compute(state, vector, options)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 632, in compute
    solution, result, _ = solver.compute(state, vector, options)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lineax/_solver/lu.py", line 62, in compute
    vector = ravel_vector(vector, packed_structures)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lineax/_solver/misc.py", line 84, in ravel_vector
    raise ValueError("pytree does not match out_structure")
ValueError: pytree does not match out_structure
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

MWE:

import equinox as eqx
import jax.random as jr
import jax.numpy as jnp
import jax
import optimistix as optx

jax.config.update("jax_enable_x64", True)

VMAP = True


def rf(x, args, g):
    c = 1 - x[0] - x[1]
    f = -x[0] * jnp.exp(-g) - x[1]
    return f, c


def opt_2st_vec(g):
    x0 = (1 / 2, 1 / 2)
    obj = eqx.Partial(rf, g=g.squeeze())
    solver = optx.Newton(atol=1e-8, rtol=1e-8)
    solution = optx.root_find(obj, solver, x0)
    return jnp.expand_dims(solution.value[1], 0)


class Model(eqx.Module):
    input_layer: eqx.nn.Linear
    output_layer: eqx.nn.Linear

    def __init__(
        self,
        n_inputs,
        key,
    ):
        self.input_layer = eqx.nn.Linear(
            in_features=n_inputs, out_features=1, use_bias=False, key=key
        )
        self.output_layer = eqx.nn.Linear(
            in_features=1, out_features=1,  use_bias=True, key=key
        )

    def __call__(self, inputs):
        t = self.input_layer(inputs)
        return self.output_layer(opt_2st_vec(t))


def loss_fn(
    params,
    static,
    inputs_folding,
    target,
):
    model = eqx.combine(params, static)
    if VMAP:
        batched_model = jax.vmap(model)
        output = batched_model(
            inputs_folding,
        )
    else:
        output = jnp.array([
            model(inputs_folding[i])
            for i in range(inputs_folding.shape[0])
        ])
    loss = jnp.mean(jnp.abs(target - output[:, 0]))
    return loss

inputs = jr.uniform(jr.PRNGKey(0), (128, 10))
target = jr.uniform(jr.PRNGKey(0), (128,))

model = Model(inputs.shape[1], jr.PRNGKey(0))

params, static = eqx.partition(model, eqx.is_array)
loss_fn_w_grad = eqx.filter_value_and_grad(loss_fn)
loss, grads = loss_fn_w_grad(
    params,
    static,
    inputs,
    target,
)

package versions:

equinox==0.11.3
jax==0.4.25
jaxlib==0.4.25
lineax==0.0.4
optimistix==0.0.6

grad of vmap of function which wraps an optax solver occasionally fails

Hi,
I previously used the optx Newton root-finding algorithm, with a jnp.where to set a default value when the root finder couldn't find a solution. That successfully inserted the default value, but the program failed to compute the gradient whenever the default value was used.
I then moved to the optx minimiser wrapper for an Optax solver, minimising a function in place of the root-finding operation. This works very nicely, as it handles the steeper slopes that occur in my functions.
But then the outer parameters changed such that two long numbers were exactly the negative of each other at x64 precision. The point is not that I should buy a lottery ticket, but that I need a way to make grad work when the solver cannot find a solution.

Specifics: I use vmap to fill in the elements of a 1D array, calling a function once per element via vmap. That function includes the following code:

New code (returns the value y0, but grad crashes, when no solution exists):

optimizer_acc = optx.OptaxMinimiser(optax.adabelief(learning_rate=1e-2), rtol=1e-8, atol=1e-8)
y0 = (jnp.array(0.01))
sol = optx.minimise(fn=time_root_from_distance, solver = optimizer_acc, y0 = y0, args = instance_of_acceleration, options=dict(lower=0.),max_steps=10000, throw=False)

Old code (returns the default value, but crashes when grad is requested):

solver_root = optx.Newton(rtol=1e-5, atol=1e-4)
y0 = (jnp.array(0.01))
sol = optx.root_find(fn=time_root_from_distance, solver=solver_root, y0=y0, args=instance_of_acceleration, options=dict(lower=0.), max_steps=10000, throw=False)
Thv = jnp.where(sol.result == optx.RESULTS.successful, sol.value, 9999.)

Both work when they can find a solution. But when a solution does not exist or cannot be found, jax.grad(Objective) fails.

Not sure if this question is misplaced but any suggestions on an approach to return a grad not just a value when solution is absent would be appreciated.

Thanks,
Tom
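If the failure shows up as NaN gradients, one general JAX pitfall with the `jnp.where` default-value pattern is that the gradient is propagated through both branches, and a zero cotangent multiplied by an infinite or NaN derivative in the untaken branch is still NaN. The usual workaround is a "double where" that keeps unsafe inputs away from the untaken branch. Whether this is what goes wrong with the solver above is an assumption; the sketch uses `jnp.sqrt` purely as a stand-in, and `naive`/`safe` are hypothetical names:

```python
import jax
import jax.numpy as jnp


def naive(x):
    # Both branches are differentiated; at x = 0 the sqrt branch has an
    # infinite derivative, and 0 * inf = nan leaks into the gradient.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)


def safe(x):
    # "Double where": swap unsafe inputs for a harmless value first, so the
    # branch that is discarded never sees them.
    ok = x > 0
    x_safe = jnp.where(ok, x, 1.0)
    return jnp.where(ok, jnp.sqrt(x_safe), 0.0)


print(jax.grad(naive)(0.0), jax.grad(safe)(0.0))  # nan 0.0
```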

Using `optimistix` with an `equinox` model

Hi everyone, thanks for the great library and apologies in advance for this basic question.
I'm trying to find the true minimum of a small neural network, and I thought of using a solver from optimistix together with an equinox model. However, I haven't been able to make the two work together.

Here is a minimal snippet which fails:

import jax 
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

jax.config.update("jax_enable_x64", True)


X = jax.random.normal(jax.random.PRNGKey(0), (2000, 8))

@jax.vmap
def function(x):
    return x[0] + x[1]**2 + jnp.cos(x[2]) + jnp.sin(x[3]) + x[4]*x[5] + (x[6]*x[7])**3

y = function(X).reshape(-1, 1)

model = eqx.nn.MLP(in_size=8, out_size=1, width_size=4, depth=2, activation=jax.nn.silu, key=jax.random.PRNGKey(0))

static, params = eqx.partition(model, eqx.is_inexact_array)

def loss_fn(params, static, X, y):
    model = eqx.combine(params, static)
    return jnp.sum((jax.vmap(model)(X) - y)**2)

solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.minimise(loss_fn, solver, params)

I'm getting TypeError: Cannot determine dtype of <PjitFunction of <function silu at 0x742fde959300>>.

What am I doing wrong?
Thank you in advance.

Error in "optimistix/docs/examples/optimise_diffeq.ipynb"

Hi, I was trying to run the example given in "optimistix/docs/examples/optimise_diffeq.ipynb", but I am receiving the following error:

---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
__________________________________________________________________ Cell 4 line 9
5 pred_values = batch_solve(parameters, y0s, saveat)
6 return values - pred_values
----> 9 (y0s_0, values_0) = get_data()
10 y0s = jnp.array(y0s_0)
11 values = jnp.array(values_0)

__________________________________________________________________ Cell 4 line 1
9 saveat = dfx.SaveAt(ts=jnp.linspace(0, 30, 20))
10 batch_solve = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(None, 0, None)))
---> 11 values = batch_solve(true_parameters, y0s, saveat)
12 return y0s, values

[... skipping hidden 21 frame]

__________________________________________________________________ Cell 4 line 2
19 t1 = saveat.subs.ts[-1]
20 dt0 = 0.1
---> 21 sol = dfx.diffeqsolve(
22 term,
23 solver,
24 t0,
25 t1,
...
--> 305 raise ValueError("No arrays to thread error on to.")
306 dynamic_x = _error(dynamic_x, pred, index, msgs=msgs, on_error=on_error)
307 return combine(dynamic_x, static_x)

ValueError: No arrays to thread error on to.

I would appreciate any help fixing it.

Including user-defined Jacobian

Hi devs, this looks like a really nice library. I've been looking for a JAX-native root-finding method that supports vmap for some time. Currently I am using an external call to scipy.optimize.root together with the multiprocessing library, which is quite slow.

However, root finding with the Newton method in this library is even slower than that approach. I suspect this is because the Jacobian needs to be computed at each iteration. Is there a way for the user to supply an analytic Jacobian? Or could you point me in the right direction to implement this feature?

For reference, this is my MWE in case I am not doing things efficiently:

from jax import jit, jacfwd, vmap, random
import optimistix as optx

def fn(y, b):
    return (y-b)**2

M = 1024
key = random.PRNGKey(42)
key, key_ = random.split(key, 2)

y = random.normal(key, (M,))
b = random.normal(key_, (M,))
solver = optx.Newton(rtol=1e-5, atol=1e-5)  # not defined in the original snippet; tolerances are assumed
sol = optx.root_find(vmap(fn), solver, y, b)

New solvers

  • Anderson acceleration
  • LBFGS
  • Affine

Powell's (unconstrained) derivative-free optimisers:

  • UOBYQA
  • NEWUOA

On affine solvers: affine systems can be solved with a single linear solve. JAX can detect affine functions via

import jax
import jax.interpreters.partial_eval as pe

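def is_affine(f, *args, **kwargs):
    # f is affine iff its Jacobian does not depend on the inputs. This is a
    # syntactic, conservative check: jax.jacfwd differentiates w.r.t. the first
    # positional argument, and any input the Jacobian's jaxpr still reads counts as used.
    jaxpr = jax.make_jaxpr(jax.jacfwd(f))(*args, **kwargs)
    # Dead-code-eliminate the Jacobian's jaxpr and report which inputs survive.
    _, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, [True] * len(jaxpr.out_avals))
    return all(not x for x in used_inputs)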
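The "single linear solve" claim follows because an affine `f(x) = A x + b` has the constant Jacobian `A`, so one Newton step from any starting point lands exactly on the root: `x* = x0 - J^{-1} f(x0)`. A minimal sketch in plain JAX; `solve_affine` is a hypothetical helper, and it assumes `J` is square and nonsingular:

```python
import jax
import jax.numpy as jnp


def solve_affine(f, x0):
    # For affine f, J = df/dx is the same everywhere, so one Newton step is exact.
    J = jax.jacfwd(f)(x0)
    return x0 - jnp.linalg.solve(J, f(x0))


A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])
x_star = solve_affine(lambda x: A @ x + b, jnp.zeros(2))
print(x_star)  # approximately [-0.6, 0.8]
```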

Incompatibility of least_squares and custom_vjp

I'm running into some trouble applying optimistix.least_squares(fn, LevenbergMarquardt(...), x0) to certain problems. From the error message below, my understanding of the root cause is that forward-mode autodiff cannot be used on jax.custom_vjp. In my case I am using diffrax to solve an ODE within fn(...), which I think might be causing the problem.

Is my basic understanding correct? Are there specific constraints / assumptions that fn(...) must follow for optimistix.least_squares to work (e.g. cannot use jax.custom_vjp)? Is there any way around this?

The error I get is:

TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.

The full code to reproduce the error is below. By the way I get the same problem when trying to use jaxopt.LevenbergMarquardt on this problem.

# === imports === #
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import diffrax
from diffrax import ODETerm, Dopri5, SaveAt
from tqdm import trange
import optimistix
from optimistix import LevenbergMarquardt

# === functions defining flow field and residuals === #
def geodesic_vector_field(P):
    jacP = jax.jacobian(P)
    def vector_field(t, state, args):
        x, v = state
        Pdx = jacP(x)
        q1 = 0.5 * jnp.einsum("jki,j,k->i",Pdx, v, v)
        q2 = jnp.einsum("ilp,l,p->i", Pdx, v, v)
        dxdt = v
        dvdt = jnp.linalg.solve(P(x), q1 + q2)
        return (dxdt, dvdt)
    return vector_field

def exponential_map(x0, v0, term, solver):
    return diffrax.diffeqsolve(
        term, solver, t0=0, t1=1, dt0=0.1, y0=(x0, v0),
        saveat=SaveAt(t0=False, t1=True)
    ).ys[0].ravel()

def shooting_method_resids(x0, x1, term, solver):
    return jax.jit(
        lambda v0, args: (x1 - exponential_map(x0, v0, term, solver)).ravel()
    )

# === try solving the boundary value problem === #
term = ODETerm(geodesic_vector_field(lambda x: jnp.eye(2)))
solver = Dopri5()

optimistix.least_squares(
    shooting_method_resids(jnp.zeros(2), jnp.ones(2), term, solver),
    LevenbergMarquardt(1e-3, 1e-3),
    -1 * jnp.ones(2)
)

Qs: Mapping different solvers to leaves, parameter normalisation, different parameter scales

Hey, first up, another awesome package in the JAX ecosystem! I've been meaning to incorporate these kinds of solvers into my work for a long time, so thanks for making it easy 😛. This is partly a discussion post, as I am relatively unfamiliar with these algorithms. I have done my best to read the docs in detail, but feel free to point me to any external resources, as I would love to learn more.


Mapping solvers to different leaves

Is there a way to map different solvers to each leaf of a pytree? Let's say we know one parameter will be initialised in the smooth bowl of the loss landscape and can be solved with BFGS, but another parameter has a 'noisy' loss topology and is best tackled with a regular Optax optimiser. This is actually quite typical for the sort of optical models I work with, although not super common in general AFAIK.

It is simple to apply each of these algorithms one at a time to each pytree leaf with eqx.partition and eqx.combine. This approach works, but it can't optimise these leaves jointly, and it results in redundant computation of the model gradients, since the gradients from each evaluation could have been passed to both algorithms.

Now, I recognise that a 'joint' approach would pose a problem for algorithms like BFGS, since they would be trying to find the minimum of a landscape that changes as the other leaves are updated throughout the optimisation. I'd be curious what you think the right approach to this kind of problem is; maybe there are solvers designed for it? If not, what approach might you take? I'm very excited about the flexibility and extensibility of this software for building much better custom solvers for my rather niche set of optimisation problems.


Parameter normalisation during the solve loop

So during a gradient descent loop we commonly need to apply some normalisation/regularisation to our updated parameters to ensure they are 'physical'. An example would be normalising relative spectral weights to have a sum of 1 after the updates have been applied. I am wondering if there is a way to enforce these constraints during the solve. The simplest example case here would be preventing some values from being above some threshold.

I would guess this would likely be possible through a user-defined solver class, that applies the custom regularisation. If something like this is possible, how would it be implemented? From a crude look at the code it looks like this could be done within the step function of the AbstractGradientDescent class?
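One standard way to keep iterates "physical" is projected gradient descent: take the usual gradient step, then map the result back onto the feasible set. A pure-JAX sketch, not an Optimistix solver; `projected_gradient_descent`, `project`, the fixed learning rate and the threshold of `1.0` are all hypothetical:

```python
import jax
import jax.numpy as jnp


def projected_gradient_descent(f, x0, project, learning_rate=0.1, steps=200):
    grad_f = jax.jit(jax.grad(f))
    x = project(x0)
    for _ in range(steps):
        # Ordinary gradient step followed by projection onto the feasible set.
        x = project(x - learning_rate * grad_f(x))
    return x


# Example: keep every parameter at or below a threshold of 1.0.
c = jnp.array([2.0, 0.5, -3.0])
f = lambda x: jnp.sum((x - c) ** 2)
clip_to_threshold = lambda x: jnp.minimum(x, 1.0)
x_opt = projected_gradient_descent(f, jnp.zeros(3), clip_to_threshold)
print(x_opt)  # approximately [1.0, 0.5, -3.0]
```

For spectral weights that must sum to one, `project` would be a Euclidean projection onto the probability simplex; plain renormalisation is simpler but is not that projection.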


Parameters with large scale variation

So this one is more of an open discussion, rather than a specific question. It's very common for the models I work with to have vastly different scales (everything from 1e10 to 1e-10). This is a problem for these algorithms in general, so I was hoping to get your thoughts on what would be the right way to approach a solution.

There is the 'naive' solution where you apply a scaling to each parameter of the pytree before passing it into the minimisation function, and then inverting the scaling once inside the function. Now this works but is far from what I would consider ideal as it still requires a degree of prior knowledge of the model and sort of just kicks the tunable hyper-parameter from a learning rate into a scaling. Granted this is still going to be generally more robust, but I feel like there is something more elegant... I'm wondering if you have any thoughts or ideas about this!
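The "naive" rescaling can at least be written once, generically over pytrees: optimise `z = x / scales` leaf by leaf and let the objective see the physical `x`. A sketch; `rescaled`, the objective `fn` and the per-leaf `scales` pytree are hypothetical, and choosing `scales` still needs the prior knowledge the paragraph mentions:

```python
import jax
import jax.numpy as jnp


def rescaled(fn, scales):
    """Wrap fn(x, args) as fn_z(z, args) with x = z * scales, leaf by leaf."""

    def fn_z(z, args):
        x = jax.tree_util.tree_map(lambda zi, si: zi * si, z, scales)
        return fn(x, args)

    return fn_z


def fn(x, args):
    # Hypothetical objective whose optimum sits at x_big = 1e10, x_small = 2e-10.
    return (x["big"] / 1e10 - 1.0) ** 2 + (x["small"] / 1e-10 - 2.0) ** 2


scales = {"big": 1e10, "small": 1e-10}
fn_z = rescaled(fn, scales)

z0 = {"big": jnp.array(0.0), "small": jnp.array(0.0)}
# In z-space both gradient entries are O(1) (here -2 and -4); with respect to
# x they would be about -2e-10 and -4e10.
print(jax.grad(fn_z)(z0, None))
```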


Anyway thanks again for the excellent software and the help!

Improve `IndirectIterativeDual`

  • Add IndirectIterativeDual specific Newton safeguards (Conn, Gould, and Toint "Trust Region Methods" section 7.3)
  • Use Givens rotations to compute diff for different values of λ more efficiently (Conn, Gould, and Toint section 7.3 or Nocedal and Wright section 4.3). Depends on patrick-kidger/lineax#6
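For reference, the building block behind the second item: a Givens rotation chooses `c`, `s` so that `[[c, s], [-s, c]] @ [a, b] = [r, 0]`, zeroing one entry at a time. A minimal sketch of the rotation alone, not of the λ-update from the cited sections; `givens` is a hypothetical helper:

```python
import jax.numpy as jnp


def givens(a, b):
    """Return (c, s, r) with [[c, s], [-s, c]] @ [a, b] == [r, 0]."""
    r = jnp.hypot(a, b)
    # Guard r == 0: the identity rotation already leaves b at zero.
    safe_r = jnp.where(r == 0, 1.0, r)
    c = jnp.where(r == 0, 1.0, a / safe_r)
    s = jnp.where(r == 0, 0.0, b / safe_r)
    return c, s, r


c, s, r = givens(3.0, 4.0)
G = jnp.array([[c, s], [-s, c]])
print(G @ jnp.array([3.0, 4.0]))  # approximately [5, 0]
```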

Zero implicit gradients when using `ImplicitAdjoint` with CG solver

Hi @patrick-kidger and @packquickly,

I was trying to implement the following meta-learning example from jax-opt in optimistix: Few-shot Adaptation with Model Agnostic Meta-Learning. However, I ran into an issue with implicit differentiation through the inner loop. The example below runs well when using optx.RecursiveCheckpointAdjoint, but when I try to recreate the iMAML setup by using optx.ImplicitAdjoint with a CG solver with 20 steps, all the meta-gradients are zero and the meta-optimiser doesn't change anything during training. Could you please help me identify the issue with the code? It seems to be an implementation detail of implicit adjoints that differs between jax-opt and optimistix.

Here is an MWE:

import optimistix as optx
import equinox as eqx
import lineax as lx
import jax
import jax.random as jr
import jax.numpy as jnp
import optax

key = jr.PRNGKey(0)
model = eqx.nn.MLP(1, 1, 40, 2, key=key)

sine_target = lambda x: 1.0 * jnp.sin(x - 0.5) # Target function
x = jr.normal(key, (10, 1)) # Randomly drawn inputs for validation
y_true = sine_target(x)

opt = optx.OptaxMinimiser(optax.adam(1e-3, eps_root=1e-8), 1e-7, 1e-7)
params, static = eqx.partition(model, eqx.is_inexact_array)

def apply_model(params, x):
    model = eqx.combine(params, static)
    return jax.vmap(model)(x)

def loss_fn(params, args):
    y_pred = apply_model(params, x)
    loss = jnp.mean(jnp.square(y_pred - y_true))
    return loss, loss

def adapt_fn(params):
    sol = optx.minimise(loss_fn,
                        opt,
                        params,
                        None,
                        has_aux=True,
                        max_steps=2,
                        throw=False,
                        adjoint=optx.ImplicitAdjoint(lx.CG(1e-7, 1e-7, max_steps=10)),
                        tags=lx.positive_semidefinite_tag)
    return sol.aux # Return the final loss only

loss, grad = jax.value_and_grad(adapt_fn)(params)

print(f"Final loss: {loss:.5f}")
print(f"Gradient: {grad.layers[0].weight}")

Can't use Optimistix solvers with `eqx.Module`s and filtered transformations

Thanks very much for this library! Though I understand it's not the primary use case, I'd like to use optimistix with first-order gradient optimizers and standard neural nets to make use of the ability to vectorize optimizers. (Specifically, I'd like to train an ensemble like in equinox, but where each member of the ensemble is paired with a distinct optimizer.)

I run into an error when using optx.GradientDescent with an eqx.Module. Adapting some example code from this repo for a MWE:

import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx

N = K = 8
x = jnp.linspace(0, 1, N)[None, ...]
y = x**2

model = eqx.nn.MLP(
  in_size=N,
  out_size=N,
  width_size=K,
  depth=1,
  activation=jax.nn.relu,
  key=jax.random.PRNGKey(42),
)


@eqx.filter_jit
def loss(model, args):
  x, y = args
  pred_y = eqx.filter_vmap(model)(x)
  loss = jnp.mean((pred_y - y) ** 2)
  aux = None
  return loss, aux


optimizer = optx.GradientDescent(learning_rate=1e-1, rtol=1e-4, atol=1e-4)
options = None
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

init = eqx.filter_jit(eqx.Partial(optimizer.init, fn=loss, options=options, f_struct=f_struct,
                      aux_struct=aux_struct, tags=tags))
step = eqx.filter_jit(eqx.Partial(optimizer.step, fn=loss, options=options, tags=tags))
terminate = eqx.filter_jit(eqx.Partial(optimizer.terminate, fn=loss, options=options, tags=tags))
postprocess = eqx.filter_jit(eqx.Partial(optimizer.postprocess, fn=loss, options=options, tags=tags))

state = init(y=model, args=(x, y))
done, result = terminate(y=model, args=(x, y), state=state)

while not done:
  model, state, _ = step(y=model, args=(x, y), state=state)
  done, result = terminate(y=model, args=(x, y), state=state)
  print(f"Evaluating iteration with loss value {loss(model, (x, y))[0]}.")

if result != optx.RESULTS.successful:
  print("Failed!")

model, _, _ = postprocess(
  y=model,
  aux=None,
  args=(x, y),
  state=state,
  result=result,
)
print(f"Found solution with loss value {loss(model, (x, y))[0]}.")

gives me:

TypeError: Value <jax._src.custom_derivatives.custom_jvp object at 0x1022171d0> with type <class 'jax._src.custom_derivatives.custom_jvp'> is not a valid JAX type

at this line:

f_info_struct = jax.eval_shape(lambda: f_info)

which, if I understand correctly, is the result of jax.eval_shape hitting non-arrays. How can I filter for arrays in model, or is there a different recommended usage pattern here?

Pytree inputs for `rtol` and `atol` or custom termination condition?

So I would love to be able to pass in a pytree for both the rtol and atol values, in a similar vein to how you can set individual learning rates for each leaf in optax. This would make a lot of sense for most of my work which has pytree leaves with vastly different parameter scales.

Looking at the termination condition code, it looks like this hasn't been made an option because the values are applied to both the pytree leaf values ('y space') and the loss value ('f space').

From what I can tell, there would be two ways to get this behavior:

  1. Allow a custom termination condition.

    I don't think this is the best solution, as the Cauchy termination is already wrapped up with the input norm function/pytree.

  2. Allow pytree inputs for rtol and atol.

    I think this could be done relatively easily by allowing the f-space and y-space conditions to be individually specified via a tuple like this: (f-space rtol (float), y-space rtol (float or pytree)). This would preserve the present syntax, while also allowing users full freedom over the termination condition.

Anyway maybe there is a better way that I'm missing, but having this functionality is actually somewhat essential for using optimistix in my work in the long run, so let me know your thoughts!
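To make option 2 concrete, a Cauchy-style convergence test in plain JAX that accepts per-leaf `rtol`/`atol` pytrees for the y-space check and scalars for the f-space check. It is a hypothetical sketch of the proposal, not Optimistix's termination code; `leafwise_cauchy_converged` and the example values are invented:

```python
import jax
import jax.numpy as jnp


def leafwise_cauchy_converged(y, y_prev, f, f_prev, y_rtol, y_atol, f_rtol, f_atol):
    """Cauchy-style test: y_rtol/y_atol are pytrees shaped like y; f_rtol/f_atol are scalars."""

    def leaf_ok(a, b, rtol, atol):
        # Per-leaf tolerance on the change in y.
        return jnp.all(jnp.abs(a - b) < atol + rtol * jnp.abs(a))

    y_ok = jax.tree_util.tree_leaves(jax.tree_util.tree_map(leaf_ok, y, y_prev, y_rtol, y_atol))
    f_ok = jnp.abs(f - f_prev) < f_atol + f_rtol * jnp.abs(f)
    return jnp.logical_and(jnp.all(jnp.stack(y_ok)), f_ok)


# Hypothetical parameters on very different scales.
y = {"flux": jnp.array(1.0e6), "pos": jnp.array(1.0e-3)}
y_prev = {"flux": jnp.array(1.0e6 + 1.0), "pos": jnp.array(1.1e-3)}
y_rtol = {"flux": 1e-5, "pos": 1e-5}

tight = {"flux": 0.0, "pos": 1e-6}
loose = {"flux": 0.0, "pos": 1e-3}
print(leafwise_cauchy_converged(y, y_prev, 1.0, 1.0, y_rtol, tight, 1e-8, 1e-8))  # False
print(leafwise_cauchy_converged(y, y_prev, 1.0, 1.0, y_rtol, loose, 1e-8, 1e-8))  # True
```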

Issue with vmap `optx.least_squares`.

Hi,

I have issues vectorizing the optx.least_squares function (version 0.0.6) when directly vectorized using JAX's vmap function. This behavior occurs unless the sol.state and sol.result fields are removed from the Solution dataclass instance. Perhaps related to this commit? Somehow, vmap does not know that the jaxpr stuff should not be batched (i.e. pytree_node=False).

MWE

In the provided minimal working example (MWE), I attempt to vectorize the least-squares optimization using JAX's vmap function. The setup involves a quadratic residual function and the Levenberg-Marquardt solver.

The vectorization attempt fails when trying to return the full Solution object (sol) from the least_squares function. However, if only sol.value is returned, the vectorization succeeds. This suggests a compatibility issue between the full Solution instance structure and JAX's vmap operation.

import jax
import jax.numpy as jnp
import optimistix as optx

# Define a simple quadratic residual function
def residual_fn(params, *args):
    return params[0] * jnp.arange(10) ** 2 + params[1] * jnp.arange(10) + params[2]

# Initialize the Levenberg-Marquardt solver
solver = optx.LevenbergMarquardt(rtol=1e-5, atol=1e-7, norm=optx.rms_norm)

# Define the initial parameters
params_init = jnp.array([1.0, 2.0, 3.0])

# Define a function to perform least squares optimization
def least_squares(params_init):
    sol = optx.least_squares(residual_fn, solver, params_init, max_steps=100, throw=False)
    return sol  # throws an error --> returning sol.value does not... 

# Attempt to vectorize the least squares function
vmap_least_squares = jax.vmap(least_squares, in_axes=(0,), out_axes=0)

# Define a batch of initial parameters
batch_params_init = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# Attempt to perform batched least squares optimization
batch_params_final = vmap_least_squares(batch_params_init)

This produces the error:

Traceback (most recent call last):
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-7-521a622dc7fd>", line 27, in <module>
    batch_params_final = vmap_least_squares(batch_params_init)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/api.py", line 1258, in vmap_f
    out_flat = batching.batch(
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/linear_util.py", line 206, in call_wrapped
    ans = gen.send(ans)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 638, in _batch_inner
    out_vals = map(partial(from_elt, trace, axis_size), outs, out_dim_dests)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 270, in from_elt
    return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 1107, in matchaxis
    raise TypeError(f"Output from batched function {x!r} with type "
TypeError: Output from batched function { lambda a:f32[10] b:f32[10]; c:f32[3]. let
    d:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] c
    e:f32[] = squeeze[dimensions=(0,)] d
    f:f32[10] = mul e a
    g:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] c
    h:f32[] = squeeze[dimensions=(0,)] g
    i:f32[10] = mul h b
    j:f32[10] = add f i
    k:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] c
    l:f32[] = squeeze[dimensions=(0,)] k
    m:f32[10] = add j l
  in (m,) } with type <class 'jax._src.core.Jaxpr'> is not a valid JAX type

Non-finite values in the root function are not handled well

Running into NaN values during root finding in Optimistix raises an error rather than decreasing the step size. This should probably also be fixed in Lineax, which does not appear to check its inputs (the vector, and potentially the operator) for non-finite values.

Can be reproduced using the following code:

import jax
import jax.numpy as jnp
import jax.random as jrandom
import optimistix as optx
import equinox as eqx
import lineax as lx


class Model(eqx.Module):
    atol: float = eqx.static_field()
    rtol: float = eqx.static_field()
    maxsteps: int = eqx.static_field()

    def __init__(self):
        self.atol = 1e-6
        self.rtol = 1e-4
        self.maxsteps = int(1e5)

    def xdot(self, _, x, __, p, k):
        dxdt = jnp.array(
            [
                jnp.exp(p[2] - x[0] - x[1] + p[0]) - jnp.exp(p[0]),
                jnp.exp(k[0] - x[1] + p[1]) - jnp.exp(p[1]),
            ]
        )
        jax.debug.print(
            "t: {t}, x: {x}, dxdt: {dxdt}", t=_, x=x, dxdt=dxdt, ordered=True
        )
        return dxdt

    @jax.value_and_grad
    def loss(self, p):
        mapped_simulate = jax.vmap(
            self.simulate,
            in_axes=(None, 1, 1, 2),
        )
        n_repeat = 1

        k_range_ss = 5
        k_sss = (
            jrandom.uniform(jrandom.PRNGKey(1), shape=(1, n_repeat))
            * k_range_ss
            * 2
            - k_range_ss
        )
        y0s = jrandom.uniform(jrandom.PRNGKey(0), shape=(2, n_repeat))
        yms = jrandom.uniform(
            jrandom.PRNGKey(0),
            shape=(1, 2, n_repeat),
        )
        r = mapped_simulate(p, k_sss, y0s, yms)
        return jnp.sqrt(jnp.mean(jnp.square(r)))

    def simulate(
        self,
        p: jnp.ndarray,
        k_ss: jnp.ndarray,
        y0: jnp.ndarray,
        ym: jnp.ndarray,
    ):
        xdot_ss = eqx.Partial(self.xdot, 0.0, k=k_ss, p=p)

        solver_ss = optx.Newton(
            rtol=self.rtol,
            atol=self.atol,
            linear_solver=lx.AutoLinearSolver(well_posed=None),
        )

        sol_ss = optx.root_find(
            fn=xdot_ss,
            y0=y0,
            solver=solver_ss,
            max_steps=self.maxsteps,
        )

        return sol_ss.value - ym


if __name__ == "__main__":
    model = Model()
    p = jnp.array([-50, 50, 1.0])
    print(model.loss(p))

produces

Traceback (most recent call last):
  File ".../test.py", line 57, in loss
    r = mapped_simulate(p, k_sss, y0s, yms)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../test.py", line 75, in simulate
    sol_ss = optx.root_find(
             ^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.
If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.
If you *were* expecting this solver to work with this operator, then it may be because:
(a) the operator is singular, and your code has a bug; or
(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or
(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.
-------
This error occurred during the runtime of your JAX program. Setting the environment
variable `EQX_ON_ERROR=breakpoint` is usually the most useful way to debug such errors.
(This can be navigated using most of the usual commands for the Python debugger:
`u` and `d` to move through stack frames, the name of a variable to print its value,
etc.) See also `https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more
information.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

`BestSoFar...` intended behavior?

Hi Patrick,

I suppose the BestSoFar... solvers are working as intended.

import jax.numpy as jnp
import optimistix as optx

def f(x, _):
    return 1 / 2 * x

solv = optx.FixedPointIteration(10e-3, 10e-3)
sol = optx.fixed_point(f, solv, jnp.array(0.3), max_steps=10)
print(sol.value)


solv2 = optx.BestSoFarFixedPoint(solv)
sol2 = optx.fixed_point(f, solv2, jnp.array(0.3), max_steps=3)  # raises an error
print(sol2.value) 

solv2 raises an XlaRuntimeError. Would it be possible instead to return the best value computed so far (here, the one at step 2), even if the latest value of sol2 doesn't satisfy the termination condition?

If not, is there a way to get the best value computed so far without raising a runtime error?

Would an exhaustive grid search have a place in `optimistix`?

Hello! In my applications it is very common to optimize a function with an exhaustive grid search. This is because our loss functions are sharply peaked and poorly behaved within a manageable subset of parameter space, so exhaustive search (perhaps done in a clever way) is often best. I am planning to implement this in JAX, and I am wondering whether it is within the scope of optimistix. It seems to me that the library has tools that would be useful for this task.

A rough outline of the implementation I am imagining is the following:

  1. The grid: The grid would be represented as an arbitrary pytree whose leaves have leading "batch" dimensions. Namely, leaf $i$ is an array (or a pytree of arrays) with leading dimension(s) $N_i$. For $m$ leaves, the grid is then an $N_1 \times N_2 \times \cdots \times N_m$ Cartesian grid.
  2. The cost function: The cost function $f$ would take in a pytree with the same structure as the grid (plus additional arguments), evaluated at grid point $(i_1, i_2, \ldots, i_m)$. This function can return a single value of the cost function, or a grid of cost-function evaluations. This would allow grid searches that are more clever than a simple exhaustive search. For example, the simplest case in my field is to return a grid of cost-function evaluations computed through Fourier convolution, e.g. a search over the space of translations. In general, this sub-grid returned by $f$ would explore a region of parameter space unrelated to the $N_1 \times N_2 \times \cdots \times N_m$ grid.
  3. The solution: There would need to be a flexible API for how the results of the grid search are stored. In a simple case, one could store the result of the best cost-function evaluation at every grid point returned by $f$. In a more complicated case, one might want to use the grid search to marginalize away a portion of parameter space.

I'm not sure if this is within the scope of optimistix, and I would totally understand if it is not. If it were to be added to the library, I suppose it could be used as a method of ultra-last resort.
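A minimal sketch of the plain exhaustive case in JAX, assuming each grid leaf is an array whose leading axis enumerates candidate values and that the cost returns a scalar. `grid_search` is a hypothetical helper, not an Optimistix API: it evaluates the cost on the full Cartesian product with `jax.vmap` and returns the best point together with the grid of costs.

```python
import jax
import jax.numpy as jnp


def grid_search(cost_fn, grid, args=None):
    """Hypothetical helper: exhaustive search over the Cartesian product of grid leaves."""
    leaves, treedef = jax.tree_util.tree_flatten(grid)
    shape = tuple(leaf.shape[0] for leaf in leaves)
    # Index arrays for every grid point (N_1 x ... x N_m), flattened to 1-D.
    mesh = jnp.meshgrid(*[jnp.arange(n) for n in shape], indexing="ij")
    flat_idx = [m.reshape(-1) for m in mesh]

    def eval_at(*idx):
        point = jax.tree_util.tree_unflatten(
            treedef, [leaf[i] for leaf, i in zip(leaves, idx)]
        )
        return cost_fn(point, args)

    # Evaluate every grid point at once, then restore the grid shape.
    costs = jax.vmap(eval_at)(*flat_idx).reshape(shape)
    best = jnp.unravel_index(jnp.argmin(costs), shape)
    best_point = jax.tree_util.tree_unflatten(
        treedef, [leaf[i] for leaf, i in zip(leaves, best)]
    )
    return best_point, costs


values = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
grid = {"a": values, "b": values}
best, costs = grid_search(lambda p, _: (p["a"] - 1.0) ** 2 + (p["b"] + 1.0) ** 2, grid)
```

For large grids the single `jax.vmap` materialises every evaluation at once; `jax.lax.map` (or chunking) would trade speed for memory.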

Usage with 'vmap'

Hi, this looks like a very promising library. This bit in the docs caught my interest:

Unlike the SciPy implementation of Newton's method, the Optimistix version also works for vector-valued (or PyTree-valued) y.

Does this mean that the function passed to the root finder has to be vectorized in the traditional sense? Do the root finders here support functions which rely on vmap? I couldn't find anything in the docs about this. Thanks!

Efficient NewtonCG Implementation

Hi all, thanks for the phenomenal library. We're already using it in several statistical genetics methods in my group!

I've been porting some older code of mine to use optimistix, rather than hand-rolled inference procedures, and could use some advice. Currently, I am performing variational inference using a mix of closed-form updates for the variational parameters and gradient-based updates for some hyperparameters. It roughly works like this:

while True:
  eval_f = jax.value_and_grad(_infer, has_aux=True)
  ((value, var_params), gradient) = eval_f(hyper_param, var_params, data)
  hyper_param = hyper_param + learning_rate * gradient
  if converged:
    break

I'd like to retool the above so that it reports not only the current value, the aux values (i.e. the updated variational parameters), and the gradient with respect to the hyperparameters, but also returns an HVP (Hessian-vector product) function that could be used in a Newton-CG-like step in Optimistix. I know about the new minimise function, but what isn't clear is how to set things up so that it reports gradients and also returns an HVP function internally, without extra passes over the graph (i.e. one pass for the value and gradient, then another two for the HVP: forward + backward).

Is this doable? Apologies if this is somewhat nebulous--I'm happy to clarify.
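One plain-JAX way to get the value, the gradient, and a reusable Hessian-vector product from a single linearisation is `jax.linearize` applied to `jax.value_and_grad`. The sketch below uses a hypothetical quadratic `loss` in place of `_infer`, ignores the aux output, and feeds the HVP to `jax.scipy.sparse.linalg.cg` for one Newton-CG step; whether Optimistix's own solvers can consume such an HVP is not shown here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical objective standing in for _infer.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 1.0])


def loss(x):
    return 0.5 * x @ A @ x - b @ x


def value_grad_and_hvp(x):
    # Linearise value_and_grad once: yields value and gradient, plus a linear
    # map v -> (directional derivative of the value, H @ v).
    (value, grad), jvp_fn = jax.linearize(jax.value_and_grad(loss), x)
    hvp = lambda v: jvp_fn(v)[1]
    return value, grad, hvp


x0 = jnp.zeros(2)
value, grad, hvp = value_grad_and_hvp(x0)
# Newton-CG direction: solve H p = -grad using only Hessian-vector products.
step, _ = cg(hvp, -grad)
x1 = x0 + step
```

`jax.linearize` evaluates the primal and stores the linearisation once, so each `hvp(v)` call applies that linear map rather than rerunning `loss` from scratch.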

Extracting intermediate function values/ losses from the solve

Dear optimistix team,

First of all, thank you for your effort in developing optimistix. I have recently transitioned from JAXOpt, and I love it!

I was wondering if it is possible to extract the loss/function-value history from the optimistix solve. In the code example below, it is easy to evaluate the intermediate losses when using multi_step_solve, but that is much less efficient than the single_step_solve approach. Using jax.lax.scan would definitely perform better than a Python for loop, but I was wondering if there is a simpler way to extract this information in optimistix.

import jax.numpy as jnp
import optimistix as optx


def rastrigin(x, args):
    A = 10.0
    y = A * x.shape[0] + jnp.sum(x**2 - A * jnp.cos(2 * jnp.pi * x), axis=0)
    return y

# How can we extract the losses for a single_step_solve?
def single_step_solve(solver, y0):
    sol = optx.minimise(rastrigin, solver, max_steps=2_000, y0=y0, throw=False)
    return sol.value

def multi_step_solve(solver, y0):
    # This is much less efficient, but it's easy to extract losses
    current_sol = y0
    for i in range(2_000):
        current_sol = optx.minimise(rastrigin, solver, max_steps=1, y0=current_sol, throw=False).value
    return current_sol

Behavior of BFGS

Hi,

I'm not sure if this is the desired behavior, but BFGS stops immediately if f(y0) == 0. For example:

from optimistix import minimise, BFGS, rms_norm
import equinox as eqx
import jax.numpy as jnp


@eqx.filter_jit
def f(y, args):
    return jnp.sum(jnp.square(y))


N = 2


def f2(y, args):
    return f(y, args) - N


y0 = jnp.ones((N,))

res = minimise(
    f,
    BFGS(rtol=1e-13, atol=1e-13, norm=rms_norm),
    y0=y0,
    max_steps=1024,
)

print(res.stats["num_steps"])
assert jnp.allclose(res.value, jnp.zeros((N,)))

# f2(y0) == 0, so BFGS stops at the first step and the assert below fails.
res = minimise(
    f2,
    BFGS(rtol=1e-13, atol=1e-13, norm=rms_norm),
    y0=y0,
    max_steps=1024,
)

print(res.stats["num_steps"])
assert jnp.allclose(res.value, jnp.zeros((N,)))

This is because the termination condition is then already satisfied at the first step.
It's not really a big problem, since you just need to offset the cost by a constant, but it took me a while to figure out where the problem was coming from.
Maybe the initial value of f_info should be set to infinity, or maybe this should be mentioned in the docs?

BestSoFarMinimiser behavior

Not sure if this is a bug or not, but BestSoFarMinimiser appears not to check the last step of the wrapped solver:

import optimistix

solver = optimistix.BestSoFarMinimiser(optimistix.BFGS(rtol=1e-5, atol=1e-5))
ret = optimistix.minimise(lambda x, _: (x - 3.)**2, solver, 0.)
print(ret.value, ret.state.state.y_eval)

0.0 3.0

Will constrained optimization be supported?

Hi, thank you for the amazing library!

I was wondering whether minimizing with user-specified bounds, or algorithms like projected gradient descent, are supported.
If not, what best practice would you suggest for solving problems like argmin(norm(Ax + b)) s.t. x >= 0?

Thanks in advance!
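Pending built-in support, one option for a nonnegativity constraint is projected gradient descent written directly in JAX. The sketch below uses a hypothetical helper (not an Optimistix API) that minimises 0.5 * ||Ax + b||^2, which has the same minimiser as ||Ax + b||, with a fixed step 1/L where L = ||A||_2^2, and projects onto x >= 0 after every step with `jnp.maximum`.

```python
import jax
import jax.numpy as jnp


def nnls_projected_gradient(A, b, num_steps=500):
    # Hypothetical helper: minimises 0.5 * ||A x + b||^2 subject to x >= 0.
    # Step size 1/L, where L = ||A||_2^2 is the Lipschitz constant of the gradient.
    step = 1.0 / jnp.linalg.norm(A, ord=2) ** 2

    def loss(x):
        r = A @ x + b
        return 0.5 * jnp.dot(r, r)

    def body(x, _):
        x = x - step * jax.grad(loss)(x)
        return jnp.maximum(x, 0.0), None  # projection onto {x >= 0}

    x0 = jnp.zeros(A.shape[1])
    x, _ = jax.lax.scan(body, x0, None, length=num_steps)
    return x


A = jnp.eye(2)
b = jnp.array([-1.0, 2.0])
x = nnls_projected_gradient(A, b)
```

Another common approach is to reparameterise, e.g. x = softplus(z), and pass the unconstrained problem in z to `optx.minimise`; this only reaches the boundary x = 0 asymptotically.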

Incompatibility with jax 0.4.27

Running the basic example produces:

[image: error output]

when using the dependencies:

[image: dependency versions]

Switching to 0.4.26 for both jax and jaxlib removes the error.

TypeError

Excited to explore the library as always!

from typing import NamedTuple

import jax
import jax.numpy as jnp
import optimistix


class MiniData(NamedTuple):
    X: jax.Array
    Y: jax.Array

def loss_fn_per_obs(y, p):
    return jnp.where(y == 1.0, -jnp.log(p), -jnp.log(1 - p))

def fn(params, args):
    P =  jax.nn.sigmoid(args.X @ params)
    losses = jax.vmap(loss_fn_per_obs, in_axes=(0,0))(args.Y, P)
    return jnp.mean(losses)

init_params = jax.random.normal(jax.random.PRNGKey(0), shape=(19,1))
data = MiniData(X=jax.random.normal(jax.random.PRNGKey(1), shape=(100, 19)),
                Y= jax.random.normal(jax.random.PRNGKey(2), shape=(100, 1)))
solver = optimistix.NonlinearCG(rtol=0.01, atol=0.01)
optimistix.minimise(fn=fn, solver=solver, y0 = init_params, args=data, has_aux=False)

I am running into the following type error:

TypeError: linearize() got an unexpected keyword argument 'has_aux'
