patrick-kidger / optimistix
Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
License: Apache License 2.0
Using the OptaxMinimiser as solver results in " '_Closure' object has no attribute 'init' " whereas the BFGS solver runs without errors.
The objective function uses a custom equinox pytree.
Hi,
I'm using a couple Equinox pytrees in my program and in one case it is used in conjunction with Newton root finding from Optimistix. My larger code is a gradient descent variation and occasionally a data point will be expected to not have a solution to the root finding algorithm. In trying to set up try: except: , what is the correct name of the exception class that Equinox uses for runtime errors, or should it be something from optimistix?
instance_of_acceleration = AccelerationPytree(l_pr, regime, kinetic_conservative, rot_dissapative, ld_dissapative, epd_dissapative_1, qe_conservative_1, epd_dissapative_2, epd_dissapative_3, epd_dissapative_4, qe_conservative_2, qe_conservative_3)
solver_root = optx.Newton(rtol=1e-8, atol=1e-8)
y0 = (jnp.array(0.1))
try:
    sol = optx.root_find(fn=time_root_from_distance, solver=solver_root (well_posed=False), y0=y0, args=instance_of_acceleration, options=dict(lower=0.), max_steps=20000, throw=False)
    Thv = sol.value
except eqx.exception_module.EqxRuntimeError. (WHAT GOES HERE?):
    # Set Thv to a default value or handle it accordingly
    Thv = 999.  # Replace with an appropriate default value or action
print(Thv)
return Thv
Please and thanks,
Tom
Hi! Thanks for making this package, I'm finding it really helpful for some optimisation problems I'm solving. I was wondering if there is a Newton method that uses jax.hessian or similar to minimise a function? I've seen BFGS, but for problems in <100 dims computing the Hessian is fairly fast, and guaranteed quadratic convergence is something I'd really like. My own few-line implementation (below) supports this. I'm also wondering if there are line searches to adapt the Newton step (e.g. backtracking, exact line search). Happy to help merge this in if you are interested.
while n_iter < max_iter and np.linalg.norm(g) > tol:
    n_iter += 1
    H = jax.hessian(lambda x: newton_loss_func(x, args))(x)
    g = jax.grad(lambda x: newton_loss_func(x, args))(x)
    # line search - to be upgraded to use a proper line search
    t_vals = np.linspace(-2, 0, 101)
    loss_vals = jax.vmap(lambda t: newton_loss_func(x + t*np.linalg.inv(H)@g, args))(t_vals)
    t = t_vals[np.argmin(loss_vals)]
    x = x + t*np.linalg.inv(H)@g
    new_losses.append(newton_loss_func(x, args))
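For comparison, here is a minimal pure-JAX sketch of the same idea with a backtracking (Armijo) line search in place of the grid search. It is not an optimistix API: newton_minimise and example_loss are hypothetical names, the Hessian is assumed to be positive definite near the iterates, and a linear solve replaces the explicit inverse.

```python
import jax
import jax.numpy as jnp


def newton_minimise(loss, x0, tol=1e-5, max_iter=50):
    """Newton's method with Armijo backtracking (assumes a positive-definite Hessian)."""
    grad_fn = jax.grad(loss)
    hess_fn = jax.hessian(loss)
    x = x0
    for _ in range(max_iter):
        g = grad_fn(x)
        if jnp.linalg.norm(g) <= tol:
            break
        H = hess_fn(x)
        step = jnp.linalg.solve(H, -g)  # Newton direction, no explicit inverse
        f0 = loss(x)
        slope = jnp.vdot(g, step)  # directional derivative along the step
        t = 1.0
        # Halve the step until the sufficient-decrease condition holds.
        while loss(x + t * step) > f0 + 1e-4 * t * slope and t > 1e-8:
            t *= 0.5
        x = x + t * step
    return x


# Hypothetical convex test problem with its minimum at x = 0.
def example_loss(x):
    return jnp.sum(jnp.exp(x) - x)


x_min = newton_minimise(example_loss, jnp.array([1.0, -2.0]))
print(x_min)
```

The Python-level loop keeps the sketch readable; a jit-compatible version would need jax.lax.while_loop for both loops.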
I am running into ValueError: pytree does not match out_structure errors when computing gradients for functions where optimistix is called via vmap. The errors disappear when replacing jax.vmap with an equivalent for loop. I have included a MWE bug_report.py which can switch between jax.vmap and for loops via the VMAP variable. My first impression is that the implicit solve during backprop gets passed the wrong (unbatched?) input vector.
Traceback (most recent call last):
File ".../python/bug_report.py", line 74, in <module>
loss, grads = loss_fn_w_grad(
File ".../python/bug_report.py", line 56, in loss_fn
output = batched_model(
File ".../python/bug_report.py", line 44, in __call__
return self.output_layer(opt_2st_vec(t))
File ".../python/bug_report.py", line 22, in opt_2st_vec
solution = optx.root_find(obj, solver, x0)
File ".../venv/lib/python3.11/site-packages/optimistix/_root_find.py", line 227, in root_find
return iterative_solve(
File ".../venv/lib/python3.11/site-packages/optimistix/_iterate.py", line 346, in iterative_solve
) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
File ".../venv/lib/python3.11/site-packages/optimistix/_adjoint.py", line 148, in apply
return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)
File ".../venv/lib/python3.11/site-packages/optimistix/_ad.py", line 72, in implicit_jvp
root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: pytree does not match out_structure
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File ".../python/bug_report.py", line 74, in <module>
loss, grads = loss_fn_w_grad(
^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/_ad.py", line 79, in __call__
return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 413, in _vprim_transpose
return transpose(cts, *inputs)
^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 211, in _wrapper
cts = rule(inputs, cts_out)
^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 272, in _linear_solve_transpose
cts_vector, _, _ = eqxi.filter_primitive_bind(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 264, in filter_primitive_bind
flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 299, in batch_rule
out = _vprim_p.bind(
^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 337, in _vprim_abstract_eval
outs = abstract_eval(*inputs, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 147, in _wrapper
out = rule(*args)
^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 115, in _linear_solve_abstract_eval
out = eqx.filter_eval_shape(
^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 86, in _linear_solve_impl
out = solver.compute(state, vector, options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 632, in compute
solution, result, _ = solver.compute(state, vector, options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/lineax/_solver/lu.py", line 62, in compute
vector = ravel_vector(vector, packed_structures)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/lineax/_solver/misc.py", line 84, in ravel_vector
raise ValueError("pytree does not match out_structure")
ValueError: pytree does not match out_structure
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
MWE:
import equinox as eqx
import jax.random as jr
import jax.numpy as jnp
import jax
import optimistix as optx

jax.config.update("jax_enable_x64", True)
VMAP = True


def rf(x, args, g):
    c = 1 - x[0] - x[1]
    f = -x[0] * jnp.exp(-g) - x[1]
    return f, c


def opt_2st_vec(g):
    x0 = (1 / 2, 1 / 2)
    obj = eqx.Partial(rf, g=g.squeeze())
    solver = optx.Newton(atol=1e-8, rtol=1e-8)
    solution = optx.root_find(obj, solver, x0)
    return jnp.expand_dims(solution.value[1], 0)


class Model(eqx.Module):
    input_layer: eqx.nn.Linear
    output_layer: eqx.nn.Linear

    def __init__(
        self,
        n_inputs,
        key,
    ):
        self.input_layer = eqx.nn.Linear(
            in_features=n_inputs, out_features=1, use_bias=False, key=key
        )
        self.output_layer = eqx.nn.Linear(
            in_features=1, out_features=1, use_bias=True, key=key
        )

    def __call__(self, inputs):
        t = self.input_layer(inputs)
        return self.output_layer(opt_2st_vec(t))


def loss_fn(
    params,
    static,
    inputs_folding,
    target,
):
    model = eqx.combine(params, static)
    if VMAP:
        batched_model = jax.vmap(model)
        output = batched_model(
            inputs_folding,
        )
    else:
        output = jnp.array([
            model(inputs_folding[i])
            for i in range(inputs_folding.shape[0])
        ])
    loss = jnp.mean(jnp.abs(target - output[:, 0]))
    return loss


inputs = jr.uniform(jr.PRNGKey(0), (128, 10))
target = jr.uniform(jr.PRNGKey(0), (128,))
model = Model(inputs.shape[1], jr.PRNGKey(0))
params, static = eqx.partition(model, eqx.is_array)
loss_fn_w_grad = eqx.filter_value_and_grad(loss_fn)
loss, grads = loss_fn_w_grad(
    params,
    static,
    inputs,
    target,
)
package versions:
equinox==0.11.3
jax==0.4.25
jaxlib==0.4.25
lineax==0.0.4
optimistix==0.0.6
Hi,
I previously had the optx Newton root-finding algorithm in operation, with a jnp.where setting a default value when the root finder couldn't find a solution. It inserted the default value correctly, but the program then failed to compute the gradient whenever the default value was used.
I ended up moving to the optx minimiser wrapper for an optax solver, minimising a function in place of the root-finding operation, and this works very nicely as it handles the more extreme slopes that occur in my functions.
But then the outer parameters were changed such that two long numbers equalled the negative of each other at x64 precision. The point is not that I need to buy a lottery ticket, but that I need a way to make grad work when the solver cannot find a solution.
Specifics: I use vmap to fill out the elements of a 1D array, calling a function once for each element of the array. That function includes the following code:
new code returns value = y0 and grad = crash when a solution does not exist:
optimizer_acc = optx.OptaxMinimiser(optax.adabelief(learning_rate=1e-2), rtol=1e-8, atol=1e-8)
y0 = (jnp.array(0.01))
sol = optx.minimise(fn=time_root_from_distance, solver = optimizer_acc, y0 = y0, args = instance_of_acceleration, options=dict(lower=0.),max_steps=10000, throw=False)
old code returns default value but crashes when grad requested:
solver_root = optx.Newton(rtol=1e-5, atol=1e-4)
y0 = (jnp.array(0.01))
sol = optx.root_find(fn=time_root_from_distance, solver = solver_root, y0 = y0, args = instance_of_acceleration, options=dict(lower=0.),max_steps=10000, throw=False)
Thv = jnp.where(sol.result == optx.RESULTS.successful, sol.value, 9999.)
Both work when they can find a solution. But when a solution does not exist or cannot be found, jax.grad(Objective) fails.
Not sure if this question is misplaced, but any suggestions on an approach that returns a grad, not just a value, when a solution is absent would be appreciated.
Thanks,
Tom
Hi everyone, thanks for the great library and apologies in advance for this basic question.
I'm trying to find the true minimum of a small neural network, and I thought of using a solver from optimistix together with an equinox model. However, I haven't been able to make the two work together.
Here is a minimal snippet which fails:
import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx
jax.config.update("jax_enable_x64", True)
X = jax.random.normal(jax.random.PRNGKey(0), (2000, 8))
@jax.vmap
def function(x):
    return x[0] + x[1]**2 + jnp.cos(x[2]) + jnp.sin(x[3]) + x[4]*x[5] + (x[6]*x[7])**3
y = function(X).reshape(-1, 1)
model = eqx.nn.MLP(in_size=8, out_size=1, width_size=4, depth=2, activation=jax.nn.silu, key=jax.random.PRNGKey(0))
static, params = eqx.partition(model, eqx.is_inexact_array)
def loss_fn(params, static, X, y):
    model = eqx.combine(params, static)
    return jnp.sum((jax.vmap(model)(X) - y)**2)
solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.minimise(loss_fn, solver, params)
I'm getting TypeError: Cannot determine dtype of <PjitFunction of <function silu at 0x742fde959300>>.
What am I doing wrong?
Thank you in advance.
Right now these have been implemented as standalone solvers.
This would allow us to use all our line searches and descents with the nonlinear CG approximate Hessian. See Conjugate Gradient Methods with Inexact Line Search by Shanno.
Hi, I was trying to run the example given in "optimistix/docs/examples /optimise_diffeq.ipynb". For some reason I am receiving error "---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
__________________________________________________________________ Cell 4 line 9
5 pred_values = batch_solve(parameters, y0s, saveat)
6 return values - pred_values
----> 9 (y0s_0, values_0) = get_data()
10 y0s = jnp.array(y0s_0)
11 values = jnp.array(values_0)
__________________________________________________________________ Cell 4 line 1
9 saveat = dfx.SaveAt(ts=jnp.linspace(0, 30, 20))
10 batch_solve = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(None, 0, None)))
---> 11 values = batch_solve(true_parameters, y0s, saveat)
12 return y0s, values
[... skipping hidden 21 frame]
__________________________________________________________________ Cell 4 line 2
19 t1 = saveat.subs.ts[-1]
20 dt0 = 0.1
---> 21 sol = dfx.diffeqsolve(
22 term,
23 solver,
24 t0,
25 t1,
...
--> 305 raise ValueError("No arrays to thread error on to.")
306 dynamic_x = _error(dynamic_x, pred, index, msgs=msgs, on_error=on_error)
307 return combine(dynamic_x, static_x)
ValueError: No arrays to thread error on to.". I would appreciate if I could get help to fix it.
Hi devs, looks like a really nice library. I've been looking for a JAX-native root-finding method that supports vmap for some time. Currently I am using an external call to scipy.optimize.root together with the multiprocessing library, which is quite slow.
The runtime for root finding using the Newton method in this library is slower than the above method though - I suspect this is because the Jacobian needs to be calculated at each iteration. Is there a way for the user to supply an analytic Jacobian? Or could you point me in the right direction to implement this feature?
For reference, this is my MWE in case I am not doing things efficiently:
from jax import jit, jacfwd, vmap, random
import optimistix as optx
def fn(y, b):
    return (y-b)**2
M = 1024
key = random.PRNGKey(42)
key, key_ = random.split(key, 2)
y = random.normal(key, (M,))
b = random.normal(key_, (M,))
sol = optx.root_find(vmap(fn), solver, y, b)
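On supplying an analytic Jacobian: one option that needs no solver-specific hook is to attach a jax.custom_jvp rule to the function, so that any forward-mode differentiation of it uses the hand-written derivative. The sketch below is plain JAX and only checks the rule with jax.jacfwd; it assumes, without confirming it here, that the solver obtains its Jacobian through forward-mode autodiff, and it makes no claim about speed. Note also that the snippet above never defines solver.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def fn(y, b):
    return (y - b) ** 2


@fn.defjvp
def fn_jvp(primals, tangents):
    y, b = primals
    y_dot, b_dot = tangents
    out = fn(y, b)
    # Hand-written derivative: d/dy = 2 (y - b), d/db = -2 (y - b).
    out_dot = 2.0 * (y - b) * (y_dot - b_dot)
    return out, out_dot


y = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([0.5, 0.5, 0.5])

# jacfwd now routes through fn_jvp instead of tracing the body of fn.
jac = jax.jacfwd(fn)(y, b)
print(jac)
```

Separately, (y - b)**2 has a double root at y = b, where the derivative vanishes, so Newton's method is expected to converge only linearly on this particular test function.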
Powell's (unconstrained) derivative free optimisers:
On affine solvers: such systems can be handled with a single linear solve. JAX can detect affine functions via
import jax
import jax.interpreters.partial_eval as pe
def is_affine(f, *args, **kwargs):
    jaxpr = jax.make_jaxpr(jax.jacfwd(f))(*args, **kwargs)
    _, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, [True] * len(jaxpr.out_avals))
    return all(not x for x in used_inputs)
I'm running into some trouble applying optimistix.least_squares(fn, LevenbergMarquardt(...), x0) to certain problems. From the error message below, my understanding of the root cause is that forward-mode autodiff cannot be used on jax.custom_vjp. In my case I am using diffrax to solve an ODE within fn(...), which I think might be causing the problem.
Is my basic understanding correct? Are there specific constraints / assumptions that fn(...) must follow for optimistix.least_squares to work (e.g. cannot use jax.custom_vjp)? Is there any way around this?
The error I get is:
TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
The full code to reproduce the error is below. By the way I get the same problem when trying to use jaxopt.LevenbergMarquardt on this problem.
# === imports === #
import jax
import jax.numpy as jnp
import diffrax
from diffrax import ODETerm, Dopri5, SaveAt
from tqdm import trange
import optimistix
from optimistix import LevenbergMarquardt

# The `jax.config` submodule import was removed in newer JAX; use the attribute instead.
jax.config.update("jax_enable_x64", True)

# === functions defining flow field and residuals === #
def geodesic_vector_field(P):
    jacP = jax.jacobian(P)
    def vector_field(t, state, args):
        x, v = state
        Pdx = jacP(x)
        q1 = 0.5 * jnp.einsum("jki,j,k->i",Pdx, v, v)
        q2 = jnp.einsum("ilp,l,p->i", Pdx, v, v)
        dxdt = v
        dvdt = jnp.linalg.solve(P(x), q1 + q2)
        return (dxdt, dvdt)
    return vector_field

def exponential_map(x0, v0, term, solver):
    return diffrax.diffeqsolve(
        term, solver, t0=0, t1=1, dt0=0.1, y0=(x0, v0),
        saveat=SaveAt(t0=False, t1=True)
    ).ys[0].ravel()

def shooting_method_resids(x0, x1, term, solver):
    return jax.jit(
        lambda v0, args: (x1 - exponential_map(x0, v0, term, solver)).ravel()
    )

# === try solving the boundary value problem === #
term = ODETerm(geodesic_vector_field(lambda x: jnp.eye(2)))
solver = Dopri5()
optimistix.least_squares(
    shooting_method_resids(jnp.zeros(2), jnp.ones(2), term, solver),
    LevenbergMarquardt(1e-3, 1e-3),
    -1 * jnp.ones(2)
)
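The underlying restriction can be reproduced in a few lines of plain JAX, independently of diffrax or optimistix. In this sketch the toy function f is hypothetical: reverse mode works through its custom_vjp rule, while jax.jvp raises the same TypeError as quoted above. For the diffrax case, a commonly suggested workaround is to pass an adjoint that does not rely on custom_vjp, such as adjoint=diffrax.DirectAdjoint(), to diffeqsolve; treat that as a suggestion to check against the diffrax documentation rather than a verified fix for this script.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    return jnp.sin(x)


def f_fwd(x):
    return jnp.sin(x), x  # save x as the residual


def f_bwd(x, g):
    return (g * jnp.cos(x),)


f.defvjp(f_fwd, f_bwd)

# Reverse mode uses the custom rule and works.
reverse_ok = jax.grad(f)(1.0)

# Forward mode has no rule to use, so JAX refuses.
try:
    jax.jvp(f, (1.0,), (1.0,))
    forward_failed = False
except TypeError as err:
    forward_failed = True
    print(err)
```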
Hey, first up another awesome package in the JAX ecosystem! I've been meaning to incorporate these kinds of solvers in my work for a long time, so thanks for making it easy 😛. This is partly a discussion post as I am relatively unfamiliar with these algorithms. I have done my best to parse the docs in detail, but feel free to point me to any external resources as I would love to learn more.
Mapping solvers to different leaves
Is there a way that we can map different solvers to each leaf of a pytree? Let's say we know one parameter will be initialised in the smooth bowl of the loss space and can be solved with BFGS, but the other parameter has a 'noisy' loss topology and is best tackled with a regular optax optimiser. This is actually quite typical for the sort of optical models I work with, although not super common in general AFAIK.
It is simple to apply each of these algorithms one at a time to each pytree leaf with eqx.partition and eqx.combine. This approach works but can't 'jointly' optimise these leaves and would result in redundant calculation of the model gradients, since the grads from each evaluation could be passed to both algorithms.
Now I recognise that a 'joint' approach would pose a problem for algorithms like BFGS, since it would be trying to find the minimum of a dynamic topology that changes as the other leaves are updated throughout the optimisation. I would be curious as to what you think might be the right approach to this kind of problem; maybe there are solvers designed for it? If not, what approach might you take? I'm very excited about the flexibility and extensibility of this software to be able to build out much better custom solvers for my rather niche set of optimisation problems.
Parameter normalisation during the solve loop
So during a gradient descent loop we commonly need to apply some normalisation/regularisation to our updated parameters to ensure they are 'physical'. An example would be normalising relative spectral weights to have a sum of 1 after the updates have been applied. I am wondering if there is a way to enforce these constraints during the solve. The simplest example case here would be preventing some values from being above some threshold.
I would guess this would likely be possible through a user-defined solver class that applies the custom regularisation. If something like this is possible, how would it be implemented? From a crude look at the code it looks like this could be done within the step function of the AbstractGradientDescent class?
Parameters with large scale variation
So this one is more of an open discussion, rather than a specific question. It's very common for the models I work with to have vastly different scales (everything from 1e10 to 1e-10). This is a problem for these algorithms in general, so I was hoping to get your thoughts on what would be the right way to approach a solution.
There is the 'naive' solution where you apply a scaling to each parameter of the pytree before passing it into the minimisation function, and then inverting the scaling once inside the function. Now this works but is far from what I would consider ideal as it still requires a degree of prior knowledge of the model and sort of just kicks the tunable hyper-parameter from a learning rate into a scaling. Granted this is still going to be generally more robust, but I feel like there is something more elegant... I'm wondering if you have any thoughts or ideas about this!
Anyway thanks again for the excellent software and the help!
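For concreteness, the 'naive' rescaling described above can be sketched in plain JAX as a wrapper around a loss with the fn(y, args) signature. Everything here is illustrative: with_scaling, the toy loss, and the scale values are hypothetical, and the scales still have to be chosen by hand, which is exactly the drawback raised in the post.

```python
import jax
import jax.numpy as jnp


def with_scaling(loss, scales):
    """Wrap loss(y, args) so the optimiser works on z = y / scale instead of y."""
    def scaled_loss(z, args):
        # Undo the scaling inside the function, leaf by leaf.
        y = jax.tree_util.tree_map(lambda zi, si: zi * si, z, scales)
        return loss(y, args)
    return scaled_loss


# Hypothetical toy loss whose minimum sits at a = 1e6, b = 1e-6.
def loss(y, args):
    return (y["a"] / 1e6 - 1.0) ** 2 + (y["b"] / 1e-6 - 1.0) ** 2


scales = {"a": 1e6, "b": 1e-6}
y = {"a": jnp.array(2e6), "b": jnp.array(2e-6)}
z = {"a": jnp.array(2.0), "b": jnp.array(2.0)}  # the same point, rescaled

raw_grads = jax.grad(loss)(y, None)
scaled_grads = jax.grad(with_scaling(loss, scales))(z, None)
print(raw_grads, scaled_grads)
```

In the original coordinates the two gradient components differ by many orders of magnitude, while in the rescaled coordinates they are of comparable size, which is what makes a single learning rate or tolerance workable.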
IndirectIterativeDual
specific Newton safeguards (Conn, Gould, and Toint "Trust Region Methods" section 7.3)diff
for different values of λ more efficiently (Conn, Gould, and Toint section 7.3 or Nocedal Wright section 4.3.) Depends on patrick-kidger/lineax#6
Hi @patrick-kidger and @packquickly,
I was trying to implement the following meta-learning example from jax-opt in optimistix: Few-shot Adaptation with Model Agnostic Meta-Learning. However, I ran into an issue with implicit differentiation through the inner loop. The below example runs well when using optx.RecursiveCheckpointAdjoint, but when I try to recreate the iMAML setup by putting optx.ImplicitAdjoint with a CG solver with 20 steps, all the meta-gradients are zero, and the meta-optimiser doesn't change at all in the training. Could you please help me identify the issue with the code? It seems to be an implementation detail for implicit adjoints that differs between jax-opt and optimistix.
Here is an MWE:
import optimistix as optx
import equinox as eqx
import lineax as lx
import jax
import jax.random as jr
import jax.numpy as jnp
import optax
key = jr.PRNGKey(0)
model = eqx.nn.MLP(1, 1, 40, 2, key=key)
sine_target = lambda x: 1.0 * jnp.sin(x - 0.5) # Target function
x = jr.normal(key, (10, 1)) # Randomly drawn inputs for validation
y_true = sine_target(x)
opt = optx.OptaxMinimiser(optax.adam(1e-3, eps_root=1e-8), 1e-7, 1e-7)
params, static = eqx.partition(model, eqx.is_inexact_array)
def apply_model(params, x):
    model = eqx.combine(params, static)
    return jax.vmap(model)(x)
def loss_fn(params, args):
    y_pred = apply_model(params, x)
    loss = jnp.mean(jnp.square(y_pred - y_true))
    return loss, loss
def adapt_fn(params):
    sol = optx.minimise(loss_fn,
                        opt,
                        params,
                        None,
                        has_aux=True,
                        max_steps=2,
                        throw=False,
                        adjoint=optx.ImplicitAdjoint(lx.CG(1e-7, 1e-7, max_steps=10)),
                        tags=lx.positive_semidefinite_tag)
    return sol.aux # Return the final loss only
loss, grad = jax.value_and_grad(adapt_fn)(params)
print(f"Final loss: {loss:.5f}")
print(f"Gradient: {grad.layers[0].weight}")
Hi Patrick,
solver = optx.Bisection(rtol=1e-19, atol=1e-15, lower = jnp.array(0.), upper = jnp.array(64000) )
y0 = (jnp.array(0.01))
sol = optx.root_find(fn=speed_curve_function, solver = solver, y0 = y0, args = root_find_pytree)
Thv = sol.value
I've tried inserting upper and lower several different ways.
Thanks,
Tom
Thanks very much for this library! Though I understand it's not the primary use case, I'd like to use optimistix with first-order gradient optimizers and standard neural nets to make use of the ability to vectorize optimizers. (Specifically, I'd like to train an ensemble like in equinox, but where each member of the ensemble is paired with a distinct optimizer.)
I run into an error when using optx.GradientDescent with an eqx.Module. Adapting some example code from this repo for a MWE:
import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx
N = K = 8
x = jnp.linspace(0, 1, N)[None, ...]
y = x**2
model = eqx.nn.MLP(
    in_size=N,
    out_size=N,
    width_size=K,
    depth=1,
    activation=jax.nn.relu,
    key=jax.random.PRNGKey(42),
)
@eqx.filter_jit
def loss(model, args):
    x, y = args
    pred_y = eqx.filter_vmap(model)(x)
    loss = jnp.mean((pred_y - y) ** 2)
    aux = None
    return loss, aux
optimizer = optx.GradientDescent(learning_rate=1e-1, rtol=1e-4, atol=1e-4)
options = None
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()
init = eqx.filter_jit(eqx.Partial(optimizer.init, fn=loss, options=options, f_struct=f_struct,
                                  aux_struct=aux_struct, tags=tags))
step = eqx.filter_jit(eqx.Partial(optimizer.step, fn=loss, options=options, tags=tags))
terminate = eqx.filter_jit(eqx.Partial(optimizer.terminate, fn=loss, options=options, tags=tags))
postprocess = eqx.filter_jit(eqx.Partial(optimizer.postprocess, fn=loss, options=options, tags=tags))
state = init(y=model, args=(x, y))
done, result = terminate(y=model, args=(x, y), state=state)
while not done:
    model, state, _ = step(y=model, args=(x, y), state=state)
    done, result = terminate(y=model, args=(x, y), state=state)
    print(f"Evaluating iteration with loss value {loss(model, (x, y))[0]}.")
if result != optx.RESULTS.successful:
    print("Failed!")
model, _, _ = postprocess(
    y=model,
    aux=None,
    args=(x, y),
    state=state,
    result=result,
)
print(f"Found solution with loss value {loss(model, (x, y))[0]}.")
gives me:
TypeError: Value <jax._src.custom_derivatives.custom_jvp object at 0x1022171d0> with type <class 'jax._src.custom_derivatives.custom_jvp'> is not a valid JAX type
at this line:
which, if I understand correctly, is the result of jax.eval_shape hitting non-arrays. How can I filter for arrays in model, or is there a different recommended usage pattern here?
So I would love to be able to pass in a pytree for both the rtol and atol values, in a similar vein to how you can set individual learning rates for each leaf in optax. This would make a lot of sense for most of my work, which has pytree leaves with vastly different parameter scales.
Looking at the termination condition code, it looks like this hasn't been made an option because the values are applied to both the pytree leaf values ('y space') and the loss value ('f space').
From what I can tell there would be two ways to get this behavior:
Allow custom termination condition.
I don't think this is the best solution as the Cauchy termination is already wrapped up with the input norm function/pytree.
Allow pytree inputs for rtol and atol.
I think this could be done relatively easily by allowing the f-space and y-space conditions to be individually specified via a tuple like this: (f-space rtol (float), y-space rtol (float or pytree)). This would preserve the present syntax, while also allowing users full freedom over the termination condition.
Anyway, maybe there is a better way that I'm missing, but having this functionality is actually somewhat essential for using optimistix in my work in the long run, so let me know your thoughts!
Hi,
I have issues vectorizing the optx.least_squares function (version 0.0.6) when directly vectorized using JAX's vmap function. This behavior occurs unless the sol.state and sol.result fields are removed from the Solution dataclass instance. Perhaps related to this commit? Somehow, vmap does not know that the jaxpr stuff should not be batched (i.e. pytree_node=False).
In the provided Minimum Working Example (MWE), I attempt to vectorize the least squares optimization using JAX's vmap function. The process involves a quadratic residual function and the Levenberg-Marquardt solver.
The vectorization attempt fails when trying to return the full Solution object (sol) from the least_squares function. However, if only sol.value is returned, the vectorization succeeds. This suggests a compatibility issue between the full Solution instance structure and JAX's vmap operation.
import jax
import jax.numpy as jnp
import optimistix as optx
# Define a simple quadratic residual function
def residual_fn(params, *args):
    return params[0] * jnp.arange(10) ** 2 + params[1] * jnp.arange(10) + params[2]
# Initialize the Levenberg-Marquardt solver
solver = optx.LevenbergMarquardt(rtol=1e-5, atol=1e-7, norm=optx.rms_norm)
# Define the initial parameters
params_init = jnp.array([1.0, 2.0, 3.0])
# Define a function to perform least squares optimization
def least_squares(params_init):
    sol = optx.least_squares(residual_fn, solver, params_init, max_steps=100, throw=False)
    return sol # throws an error --> returning sol.value does not...
# Attempt to vectorize the least squares function
vmap_least_squares = jax.vmap(least_squares, in_axes=(0,), out_axes=0)
# Define a batch of initial parameters
batch_params_init = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
# Attempt to perform batched least squares optimization
batch_params_final = vmap_least_squares(batch_params_init)
This produces the error:
Traceback (most recent call last):
File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-7-521a622dc7fd>", line 27, in <module>
batch_params_final = vmap_least_squares(batch_params_init)
File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/api.py", line 1258, in vmap_f
out_flat = batching.batch(
File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/linear_util.py", line 206, in call_wrapped
ans = gen.send(ans)
File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 638, in _batch_inner
out_vals = map(partial(from_elt, trace, axis_size), outs, out_dim_dests)
File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 270, in from_elt
return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 1107, in matchaxis
raise TypeError(f"Output from batched function {x!r} with type "
TypeError: Output from batched function { lambda a:f32[10] b:f32[10]; c:f32[3]. let
d:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] c
e:f32[] = squeeze[dimensions=(0,)] d
f:f32[10] = mul e a
g:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] c
h:f32[] = squeeze[dimensions=(0,)] g
i:f32[10] = mul h b
j:f32[10] = add f i
k:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] c
l:f32[] = squeeze[dimensions=(0,)] k
m:f32[10] = add j l
in (m,) } with type <class 'jax._src.core.Jaxpr'> is not a valid JAX type
Running into NaN values during root-finding in optimistix raises an error rather than decreasing the step size. This should probably also be fixed in lineax, which does not appear to check inputs (vector, and potentially operators) for non-finiteness.
Can be reproduced using the following code:
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optimistix as optx
import equinox as eqx
import lineax as lx
class Model(eqx.Module):
    atol: float = eqx.static_field()
    rtol: float = eqx.static_field()
    maxsteps: int = eqx.static_field()

    def __init__(self):
        self.atol = 1e-6
        self.rtol = 1e-4
        self.maxsteps = int(1e5)

    def xdot(self, _, x, __, p, k):
        dxdt = jnp.array(
            [
                jnp.exp(p[2] - x[0] - x[1] + p[0]) - jnp.exp(p[0]),
                jnp.exp(k[0] - x[1] + p[1]) - jnp.exp(p[1]),
            ]
        )
        jax.debug.print(
            "t: {t}, x: {x}, dxdt: {dxdt}", t=_, x=x, dxdt=dxdt, ordered=True
        )
        return dxdt

    @jax.value_and_grad
    def loss(self, p):
        mapped_simulate = jax.vmap(
            self.simulate,
            in_axes=(None, 1, 1, 2),
        )
        n_repeat = 1
        k_range_ss = 5
        k_sss = (
            jrandom.uniform(jrandom.PRNGKey(1), shape=(1, n_repeat))
            * k_range_ss
            * 2
            - k_range_ss
        )
        y0s = jrandom.uniform(jrandom.PRNGKey(0), shape=(2, n_repeat))
        yms = jrandom.uniform(
            jrandom.PRNGKey(0),
            shape=(1, 2, n_repeat),
        )
        r = mapped_simulate(p, k_sss, y0s, yms)
        return jnp.sqrt(jnp.mean(jnp.square(r)))

    def simulate(
        self,
        p: jnp.ndarray,
        ts: jnp.ndarray,
        k_ss: jnp.ndarray,
        k_sim: jnp.ndarray,
        y0: jnp.ndarray,
        ym: jnp.ndarray,
    ):
        xdot_ss = eqx.Partial(self.xdot, 0.0, k=k_ss, p=p)
        solver_ss = optx.Newton(
            rtol=self.rtol,
            atol=self.atol,
            linear_solver=lx.AutoLinearSolver(well_posed=None),
        )
        sol_ss = optx.root_find(
            fn=xdot_ss,
            y0=y0,
            solver=solver_ss,
            max_steps=self.maxsteps,
        )
        return sol_ss.value - ym


if __name__ == "__main__":
    model = Model()
    p = jnp.array([-50, 50, 1.0])
    print(model.loss(p))
produces
Traceback (most recent call last):
File ".../test.py", line 57, in loss
r = mapped_simulate(p, k_sss, y0s, yms)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../test.py", line 75, in simulate
sol_ss = optx.root_find(
^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.
If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.
If you *were* expecting this solver to work with this operator, then it may be because:
(a) the operator is singular, and your code has a bug; or
(b) the operator was nearly singular (i.e. it had a high condition number:
`jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
numerical instability issues; or
(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
that is does not actually satisfy.
-------
This error occurred during the runtime of your JAX program. Setting the environment
variable `EQX_ON_ERROR=breakpoint` is usually the most useful way to debug such errors.
(This can be navigated using most of the usual commands for the Python debugger:
`u` and `d` to move through stack frames, the name of a variable to print its value,
etc.) See also `https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more
information.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
As in the paper "The Geometry of Nonlinear Least Squares, with applications to Sloppy Models and Optimization" by Transtrum, Machta, and Sethna.
I see you have docs, but didn't see a link.
Hi Patrick,
I suppose that the BestSoFar... solvers are working as intended.
import jax.numpy as jnp
import optimistix as optx
def f(x, _):
    return 1 / 2 * x
solv = optx.FixedPointIteration(10e-3, 10e-3)
sol = optx.fixed_point(f, solv, jnp.array(0.3), max_steps=10)
print(sol.value)
solv2 = optx.BestSoFarFixedPoint(solv)
sol2 = optx.fixed_point(f, solv2, jnp.array(0.3), max_steps=3) # shoots out error
print(sol2.value)
solv2 returns an XlaRuntimeError, but would it be possible instead to return the best-so-far value computed (here it would be at max_steps=2) even if the latest value of sol2 doesn't satisfy the termination condition?
If not, would there be a way to get the best-so-far value computed without throwing a runtime error?
Hello! In my applications it is very common to optimize a function with an exhaustive grid search method. This is because our loss functions are sharply peaked and poorly behaved in a manageable subset of parameter space, so it is often best to do an exhaustive search (perhaps in a clever way). I am planning on implementing this in JAX, and I am wondering if this is within the scope of optimistix. It seems to me that there are tools in the library that would be useful for this task.
A rough outline of the implementation I am imagining is the following:
I'm not sure if this is within the scope of optimistix, and I would totally understand if it is not. If it were to be added to the library, I suppose it could be used as a method of ultra-last resort.
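For reference, the core of an exhaustive search is small in plain JAX. The following is only a sketch under stated assumptions, not an Optimistix feature: the names grid_search and loss, the grid resolution, and the toy objective are all hypothetical. It evaluates the objective at every grid point with jax.vmap and keeps the minimiser.

```python
import jax
import jax.numpy as jnp


def loss(params):
    # Hypothetical sharply peaked objective with its minimum at (0.3, -0.7).
    target = jnp.array([0.3, -0.7])
    return -jnp.exp(-100.0 * jnp.sum((params - target) ** 2))


def grid_search(fn, axes):
    # Build every combination of per-dimension grid values: shape (n_points, n_dims).
    mesh = jnp.meshgrid(*axes, indexing="ij")
    points = jnp.stack([m.ravel() for m in mesh], axis=-1)
    # Evaluate the objective at all grid points in one vectorised call.
    values = jax.vmap(fn)(points)
    best = jnp.argmin(values)
    return points[best], values[best]


axes = [jnp.linspace(-1.0, 1.0, 21), jnp.linspace(-1.0, 1.0, 21)]
best_point, best_value = grid_search(loss, axes)
```

The grid point returned this way could then serve as the starting value y0 for a local solver, which is one way to make the search "clever" without refining the grid everywhere.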
Hi, looks like a very promising library, this bit in the docs got me interested:
Unlike the SciPy implementation of Newton's method, the Optimistix version also works for vector-valued (or PyTree-valued) y.
Does this mean that the function passed to the root finder has to be vectorized in the traditional sense? Do the root finders here support functions which rely on vmap? I couldn't find anything in the docs about this. Thanks!
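To make the distinction concrete: "vector-valued" in the quoted sentence refers to y itself being an array, which is separate from batching over many problems. The sketch below is a hand-written Newton iteration in plain JAX, not Optimistix's implementation; newton_root, fn, and the fixed step count are assumptions made for illustration. The function is written for a single unbatched y, and jax.vmap is applied around the whole solve.

```python
import jax
import jax.numpy as jnp


def newton_root(fn, y0, args, num_steps=20):
    # Plain Newton iteration for a vector-valued y: solve J(y) dy = fn(y), then y <- y - dy.
    def step(_, y):
        residual = fn(y, args)
        jac = jax.jacfwd(fn)(y, args)
        return y - jnp.linalg.solve(jac, residual)

    return jax.lax.fori_loop(0, num_steps, step, y0)


def fn(y, a):
    # Hypothetical system whose positive root is y = sqrt(a), written for one unbatched y.
    return y**2 - a


y0 = jnp.ones(2)
batch_of_args = jnp.array([[4.0, 9.0], [16.0, 25.0]])
# Batch over the leading axis of the arguments only; the initial guess y0 is shared.
roots = jax.vmap(lambda a: newton_root(fn, y0, a))(batch_of_args)
```

Here each row of batch_of_args is an independent two-dimensional root-finding problem, so the function itself never needs to handle a batch axis.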
Hi all, thanks for the phenomenal library. We're already using it in several statistical genetics methods in my group!
I've been porting over some older code of mine to use optimistix, rather than hand-rolled inference procedures and could use some advice. Currently, I am performing some variational inference using a mix of closed-form updates for variational parameters, as well as gradient-based updates for some hyperparameters. It -roughly- works like,
while True:
    eval_f = jax.value_and_grad(_infer, has_aux=True)
    ((value, var_params), gradient) = eval_f(hyper_param, var_params, data)
    hyper_param = hyper_param + learning_rate * gradient
    if converged:
        break
I'd -like- to retool the above to not only report the current value, aux values (i.e. updated variational parameters), and gradient wrt hyper param, but also return a -hvp- function that could be used in a Newton-CG-like step in Optimistix. I know of the new minimise function, but what isn't clear is how to set up the scenario to not only report gradients, but also return a hvp function internally without having to take two additional passes over the graph (i.e. once for value and grad, another two for hvp => forward + backward).
Is this doable? Apologies if this is somewhat nebulous--I'm happy to clarify.
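In plain JAX, one way to get the value, the gradient, and a reusable Hessian-vector product from a single trace is to linearise the value-and-gradient function. The sketch below shows only that JAX mechanism; objective, x, and v are hypothetical stand-ins, the aux output is left out for brevity, and how to plug this into an Optimistix solver is not covered here.

```python
import jax
import jax.numpy as jnp


def objective(x):
    # Hypothetical scalar objective; its Hessian is diag(6 * x).
    return jnp.sum(x**3)


x = jnp.array([1.0, 2.0])
# Linearise value_and_grad once: the primal outputs are (value, gradient), and the
# returned function maps a tangent v to (directional derivative, Hessian @ v).
(value, gradient), linearised = jax.linearize(jax.value_and_grad(objective), x)

v = jnp.array([1.0, 1.0])
directional_derivative, hvp = linearised(v)
```

The linearised function can be called repeatedly with different v, for example inside a conjugate-gradient loop, without re-evaluating the objective at x.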
Dear optimistix team,
First of all, thank you for your effort in developing optimistix. I have recently transitioned from JAXOpt, and I love it!
I was wondering if it is possible to extract the loss/function value history from the Optimistix solve? In the code example below, it is easy to evaluate the intermediate losses when using the multi_step_solve method, but it is much less efficient than the single_step_solve approach. Using a jax.lax.scan would definitely improve the performance over using a for loop, but I was wondering if there is a simpler method to extract this information in optimistix.
def rastrigin(x, args):
    A = 10.0
    y = A * x.shape[0] + jnp.sum(x**2 - A * jnp.cos(2 * jnp.pi * x), axis=0)
    return y

# How can we extract the losses for a single_step_solve?
def single_step_solve(solver, y0):
    sol = optx.minimise(rastrigin, solver, max_steps=2_000, y0=y0, throw=False)
    return sol.value

def multi_step_solve(solver, y0):
    # This is much less efficient, but it's easy to extract losses
    current_sol = y0
    for i in range(2_000):
        current_sol = optx.minimise(rastrigin, solver, max_steps=1, y0=current_sol, throw=False).value
    return current_sol
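The jax.lax.scan variant mentioned above can be sketched as follows. This is an illustration of the scan pattern only: the update rule is plain gradient descent, used as a stand-in and not an Optimistix solver, and solve_with_history, num_steps, and learning_rate are hypothetical names and values.

```python
import jax
import jax.numpy as jnp


def rastrigin(x, args):
    A = 10.0
    return A * x.shape[0] + jnp.sum(x**2 - A * jnp.cos(2 * jnp.pi * x), axis=0)


def solve_with_history(y0, num_steps=200, learning_rate=1e-3):
    # Stand-in update rule: plain gradient descent, NOT an Optimistix solver.
    value_and_grad = jax.value_and_grad(rastrigin)

    def step(y, _):
        loss, grad = value_and_grad(y, None)
        # scan stacks the second return value across steps into the loss history.
        return y - learning_rate * grad, loss

    return jax.lax.scan(step, y0, None, length=num_steps)


y_final, losses = solve_with_history(jnp.array([0.3, -0.2]))
```

Each entry of losses is the objective value at the start of the corresponding step, so the whole history comes out of one compiled loop rather than 2,000 separate solves.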
Hi,
I'm not sure if this is the desired behavior, but BFGS directly stops if f(y0) == 0. As an example:
from optimistix import minimise, BFGS, rms_norm
import equinox as eqx
import jax.numpy as jnp
@eqx.filter_jit
def f(y, args):
    return jnp.sum(jnp.square(y))

N = 2

def f2(y, args):
    return f(y, args) - N

y0 = jnp.ones((N,))
res = minimise(
    f,
    BFGS(rtol=1e-13, atol=1e-13, norm=rms_norm),
    y0=y0,
    max_steps=1024,
)
print(res.stats["num_steps"])
assert jnp.allclose(res.value, jnp.zeros((N,)))
res = minimise(
    f2,
    BFGS(rtol=1e-13, atol=1e-13, norm=rms_norm),
    y0=y0,
    max_steps=1024,
)
print(res.stats["num_steps"])
assert jnp.allclose(res.value, jnp.zeros((N,)))
This is because the termination condition is then already satisfied at the first step.
It's not really a big problem, since you just need to offset the cost by a constant, but it took me a while to figure out where the problem was coming from.
Maybe the initial value for f_info should be set to infinity, or maybe this should be mentioned in the docs?
Not sure if this is a bug or not, but BestSoFarMinimiser appears to not check the last step of the wrapped solver:
solver = optimistix.BestSoFarMinimiser(optimistix.BFGS(rtol=1e-5, atol=1e-5))
ret = optimistix.minimise(lambda x, _: (x - 3.)**2, solver, 0.)
print(ret.value, ret.state.state.y_eval)
0.0 3.0
Hi, thank you for the amazing library!
I was wondering whether minimizing with user-specified bounds, or algorithms like projected gradient descent, are supported?
If not, what best practice would you suggest for solving problems like argmin(norm(Ax+b)) s.t. x >= 0?
Thanks in advance!
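For the specific problem above, projected gradient descent can be written directly in JAX. The sketch below is a hand-rolled illustration, not an Optimistix API: projected_gradient_descent, the fixed step size, and the iteration count are assumptions, and the squared norm is minimised in place of the norm because it has the same minimiser and is smooth.

```python
import jax
import jax.numpy as jnp


def projected_gradient_descent(A, b, x0, step_size, num_steps=500):
    def objective(x):
        # Squared norm: same minimiser as norm(Ax + b), but differentiable everywhere.
        return 0.5 * jnp.sum((A @ x + b) ** 2)

    def step(_, x):
        x = x - step_size * jax.grad(objective)(x)
        # Projection onto the feasible set {x : x >= 0}.
        return jnp.maximum(x, 0.0)

    return jax.lax.fori_loop(0, num_steps, step, x0)


A = jnp.eye(2)
b = jnp.array([-1.0, 2.0])
x = projected_gradient_descent(A, b, jnp.ones(2), step_size=0.5)
```

In this toy case the unconstrained minimiser is (1, -2), so the constraint is active in the second coordinate and the iteration settles at (1, 0). For a general A the step size would need to be chosen relative to the largest eigenvalue of A^T A.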
Excited to explore the library as always!
from typing import NamedTuple
import jax
import jax.numpy as jnp
import optimistix

class MiniData(NamedTuple):
    X: jax.Array
    Y: jax.Array

def loss_fn_per_obs(y, p):
    return jnp.where(y == 1.0, -jnp.log(p), -jnp.log(1 - p))

def fn(params, args):
    P = jax.nn.sigmoid(args.X @ params)
    losses = jax.vmap(loss_fn_per_obs, in_axes=(0, 0))(args.Y, P)
    return jnp.mean(losses)

init_params = jax.random.normal(jax.random.PRNGKey(0), shape=(19, 1))
data = MiniData(X=jax.random.normal(jax.random.PRNGKey(1), shape=(100, 19)),
                Y=jax.random.normal(jax.random.PRNGKey(2), shape=(100, 1)))
solver = optimistix.NonlinearCG(rtol=0.01, atol=0.01)
optimistix.minimise(fn=fn, solver=solver, y0=init_params, args=data, has_aux=False)
I am running into the following type error:
TypeError: linearize() got an unexpected keyword argument 'has_aux'