
Comments (9)

patrick-kidger commented on May 27, 2024

Thank you for the issue! This was a fairly tricky one.

Ultimately I think this is a sort-of bug (or at least a questionable design decision) in jax.checkpoint. However, it seems to be something we can work around in Equinox, so I've opened patrick-kidger/equinox#694 to address it.

Can you give that branch a go on your actual (non-MWE) problem, and let me know if that fixes things? If so then I'll merge it.


patrick-kidger commented on May 27, 2024

Hmm, I think I agree that sounds like a plausible root cause.

I'm still looking at this, but FWIW I've managed to reduce it to this MWE. Curiously, the type of x0 seems to affect whether a crash is generated. Right now I'm not sure why that should be!

import jax.random as jr
import jax.numpy as jnp
import jax
import optimistix as optx

jax.config.update("jax_enable_x64", True)

CRASH = True  # True: x0 is a tuple (crashes); False: x0 is a jnp array (works)

def rf(x, g):
    # Residual for the root find; the root is at x = (0, g).
    return x[0], x[1] - g

def opt_2st_vec(g):
    if CRASH:
        x0 = (0.5, 0.5)
    else:
        x0 = jnp.array([0.5, 0.5])
    solver = optx.Newton(atol=1e-8, rtol=1e-8)
    solution = optx.root_find(rf, solver, x0, args=g)
    return solution.value[0]

def loss_fn(x):
    return jnp.sum(jax.vmap(opt_2st_vec)(x))

x = jr.uniform(jr.key(0), (128,))
jax.grad(loss_fn)(x)

I'll keep poking at this, but let me know if you find anything sooner than that.


patrick-kidger commented on May 27, 2024

Okay, got it! It looks like grad-of-vmap-of-(a linear_solve that we only use some of the outputs from) threaded the needle to hit a case we didn't handle correctly.

I've opened patrick-kidger/equinox#671 and patrick-kidger/lineax#84 to fix this. (Although the Lineax CI will fail as it can't see the updated Equinox PR.)
I'm hoping to do new Equinox and Lineax releases, including these fixes, in the next few days.
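
For reference, the failing pattern looks roughly like this (a minimal sketch with made-up shapes and matrix, not the exact reproduction):

import jax
import jax.numpy as jnp
import lineax as lx

def f(b):
    # Solve a small linear system, then use only part of the solution.
    A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
    sol = lx.linear_solve(lx.MatrixLinearOperator(A), b)
    return sol.value[0]

bs = jnp.ones((8, 2))
# grad-of-vmap of a function that discards part of the linear solve's output.
jax.grad(lambda v: jnp.sum(jax.vmap(f)(v)))(bs)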


FFroehlich commented on May 27, 2024

Fantastic, thanks for the quick fix & workaround.


johannahaffner commented on May 27, 2024

Hi @patrick-kidger and @FFroehlich,

I might have a related issue. It persists even with the fixes in equinox@dev and lineax@vprim_transpose_symbolic_zeros.
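
(Those branches installed directly from GitHub, along the lines of the following, assuming both live on the main repositories:

pip install git+https://github.com/patrick-kidger/equinox.git@dev
pip install git+https://github.com/patrick-kidger/lineax.git@vprim_transpose_symbolic_zeros
)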

I'm vmapping a nonlinear solve (parameter estimation for ODEs across many individuals, each with their own parameter set).

I get "ValueError: Unexpected batch tracer. This operation cannot be vmap'd.", raised by _cannot_batch in equinox/internal/_nontraceable.py via jax.interpreters.batching. (The full traceback is very long.)

The error goes away if I use a for-loop, and it also goes away with a nonlinear solver that does not use gradients (Nelder-Mead).

I'm working on an MWE, starting by adapting yours from above, @patrick-kidger.

For added context: I have a nested hierarchical model composed of Equinox modules, and I now want to optimize the final layer (population level) to leverage JAX's SPMD capabilities.


johannahaffner commented on May 27, 2024

Here comes the MWE.

import jax.random as jr
import jax.numpy as jnp
import jax
import optimistix as optx
import equinox as eqx
import diffrax as dfx

jax.config.update("jax_enable_x64", True)

GRAD = True  # True: gradient-based solver (Levenberg-Marquardt); False: Nelder-Mead
VMAP = True  # True: vmap over individuals; False: Python for-loop

def dydt(t, y, args):
    k = args
    return -k * y

class Individual(eqx.Module):
    term: dfx.ODETerm
    solver: dfx.Tsit5
    y0: float
    t0: int
    t1: int
    dt0: float
    saveat: dfx.SaveAt
    
    def __init__(self, ode_system, y0):
        self.term = dfx.ODETerm(ode_system)
        self.solver = dfx.Tsit5()

        self.y0 = y0
        self.t0 = 0
        self.t1 = 10
        self.dt0 = 0.01
        self.saveat = dfx.SaveAt(ts=jnp.arange(self.t0, self.t1, self.dt0))

    def simulate(self, args):
        sol = dfx.diffeqsolve(
            self.term, 
            self.solver, 
            self.t0, 
            self.t1, 
            self.dt0, 
            self.y0, 
            args=args, 
            saveat=self.saveat,
            adjoint=dfx.DirectAdjoint(),
        )
        return sol.ys

    def estimate_param(self, initial_param, ydata, solver):
        args = (self.simulate, ydata)

        def residuals(param, args):
            model, ydata = args
            yfit = model(param)
            res = ydata - yfit
            return res

        sol = optx.least_squares(
            residuals,
            solver, 
            initial_param,
            args=args,
        )
        return sol.value

m = Individual(dydt, 10.)

def generate_data(individual_model):  # Noise-free
    k0s = (0.3, 0.5, 0.7)  # Vary parameters
    ydata = []
    for k0 in k0s:
        y = individual_model.simulate(k0)
        ydata.append(y)
    return jnp.array(ydata)
    
data = generate_data(m)
initial_k0 = 0.5  # Starting point for all individuals

def run(initial_param, individual_model, individual_data):
    if GRAD:
        solver = optx.LevenbergMarquardt(rtol=1e-07, atol=1e-07)
    else: 
        solver = optx.NelderMead(rtol=1e-07, atol=1e-07)
    if VMAP:
        get_params = jax.vmap(individual_model.estimate_param, in_axes=(None, 0, None))
        params = get_params(initial_param, individual_data, solver)
    else:
        params = [individual_model.estimate_param(initial_param, y, solver) for y in individual_data]
    return params
            
params = run(initial_k0, m, data)
params

And this is how it behaves (with equinox@dev and lineax@vprim_transpose_symbolic_zeros).

If GRAD and VMAP are both True: ValueError: Unexpected batch tracer. This operation cannot be vmap'd.
The other three combinations work.


johannahaffner commented on May 27, 2024

A few postscripts:

  1. eqx.filter_vmap does not make a difference.

  2. I noticed that your example also uses a combination of vmap/diffeqsolve/least squares. Since you batch inside the residuals function, you have the composition grad(vmap(...)), which works. I have vmap(grad(...)), which does not. (I tried running it with and without jax.config.update('jax_disable_jit', True).) See the sketch after this list.

  3. Replacing Individual.simulate(...) with Individual.__call__(...) and defining a function estimate_param outside of the Individual class also does not change things. I had been wondering whether it was a problem that things were happening inside bound methods.

  4. The same also happens with BFGS. GradientDescent and NonlinearCG do not converge, so I can't judge them using this MWE. However, it does not happen with GaussNewton.

  5. I don't think the problem is in Lineax: GaussNewton works with QR, which is also the default linear solver for Levenberg-Marquardt.
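
To make point 2 concrete, here is a toy contrast of the two compositions (a minimal sketch with a made-up function f, unrelated to the solve above):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

xs = jnp.arange(3.0)
# grad of a vmapped function: differentiate after batching.
grad_of_vmap = jax.grad(lambda v: jnp.sum(jax.vmap(f)(v)))(xs)
# vmap of a gradient: batch an already-differentiated function.
vmap_of_grad = jax.vmap(jax.grad(f))(xs)
# Both give [0., 2., 4.] here, but they trace through JAX differently.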


johannahaffner commented on May 27, 2024

It works!

Thank you so much for taking a look at this, even during the Easter holidays. It is very much appreciated!

I want to add that I am new to the ecosystem and enjoy it very much, it is so well thought-through and documented. I hope I can start contributing something other than questions as I get to know it better :)


patrick-kidger commented on May 27, 2024

Awesome stuff, I'm glad to hear it! I hope you enjoy using the ecosystem. :)

On this basis I've just merged the fix, so it will appear in the next release of Equinox.

