Comments (4)
Hey there! Thanks for the issue.
Would you be able to condense your code down to a single MWE? (Preferably around 20 lines of code.) For example we probably don't need the details of your training loop, the fact that it's batched, etc. Moreover I'm afraid this code won't run -- if nothing else, it currently doesn't have any import statements.
from optimistix.
Hi @patrick-kidger,
I updated the original post with a condensed MWE.
In the above code, when I use `optx.RecursiveCheckpointAdjoint()`, I am able to recover the correct gradients. However, when I use `optx.ImplicitAdjoint` with a CG linear solver, the gradients are all exactly zero. To be fair, I was not expecting this to work out of the box because (1) `adapt_fn` does not find the exact solution to the inner optimisation problem, and (2) even for a small network, this seems to be a rather difficult calculation. However, the JAXopt example I shared above indicates that the gradients can be calculated correctly in a similar scenario using CG.
Because of this, I started wondering whether there is a fundamental difference in how implicit adjoints are calculated in the two packages. My instinct is that the mismatch might have to do with the handling of higher-order terms, but I am curious to hear your opinion and whether it is something that can be quickly patched.
So in a minimisation problem, the solution stays constant as you perturb the initial parameters. Regardless of where you start, you should expect to converge to the same solution! So in fact a zero gradient is exactly what is expected. (Imagine finding `argmin_x x^2`. It doesn't matter whether you start at `x = 1` or `x = 1.1`; either way your output will be `x = 0`.)
The fact that you get a nonzero gradient via `RecursiveCheckpointAdjoint` will be because you are taking so few steps that you are not actually converging to the minimum at all. (In the above example, you might only converge as far as `x = 0.5` or `x = 0.6`.) So I think for a meta-learning use case, `RecursiveCheckpointAdjoint` is probably actually the correct thing to be doing!
The fact that JAXopt appears to do otherwise is possibly a bug in JAXopt. (?)
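This few-steps effect is easy to see in plain JAX, without Optimistix at all. A minimal sketch (the learning rate and step counts are arbitrary choices for illustration): differentiating through a truncated, unrolled gradient descent on `x**2` gives a nonzero gradient with respect to the starting point, while a fully converged solve gives (numerically) zero.

```python
import jax

def solve(x0, num_steps, lr=0.1):
    # Unrolled gradient descent on f(x) = x**2, whose true argmin is x = 0.
    x = x0
    for _ in range(num_steps):
        x = x - lr * 2.0 * x  # gradient of x**2 is 2x
    return x

# Few inner steps: the output still depends on x0, so the gradient is nonzero.
# Analytically it is (1 - 2*lr)**num_steps = 0.8**3 = 0.512.
g_few = jax.grad(solve)(1.0, 3)

# Many inner steps: we have converged to the argmin, and the gradient vanishes.
g_many = jax.grad(solve)(1.0, 200)
```

With few steps the derivative is `(1 - 2*lr)**num_steps`, which is exactly why an unrolled adjoint through a truncated solve is nonzero, while the converged solution is independent of `x0`.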
That aside, some comments on your implementation:
- I would recommend against a `CG` solver if you can help it. This is a fairly numerically inaccurate / unstable solver.
- Other than the discussion above, you also appear to be computing a gradient through `sol.aux`. In its current meaning this is the `aux` from the final step of the solve. This means gradients through this are actually not defined mathematically, since it is an internal detail of the solver and unrelated to the implicit function theorem! So (a) don't try to use this in gradient calculations when using `ImplicitAdjoint`, but also (b) I can still totally see that this is a footgun without guardrails, and you've prompted us to think about whether there's a way to adjust this into something better.
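To make point (a) concrete, one generic JAX pattern is to block gradients through such auxiliary outputs with `jax.lax.stop_gradient` before they enter an outer loss. A hypothetical sketch with dummy arrays (not Optimistix's actual API; `outer_loss` and the arrays are made up for illustration):

```python
import jax
import jax.numpy as jnp

def outer_loss(value, aux):
    # `aux` stands in for solver internals whose gradients are undefined:
    # use its value, but do not differentiate through it.
    aux = jax.lax.stop_gradient(aux)
    return jnp.sum(value ** 2) + jnp.sum(aux)

value = jnp.array([1.0, 2.0])
aux = jnp.array([3.0, 4.0])
g_value, g_aux = jax.grad(outer_loss, argnums=(0, 1))(value, aux)
# Gradients flow through `value` as usual; `g_aux` is all zeros.
```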
Hi Patrick,
Thank you for your thorough response. It is true that the solution to a minimisation problem is independent of the initial parameters and should lead to zero gradients. As you noted, the gradients from `RecursiveCheckpointAdjoint` are non-zero in the above MWE because we take very few optimisation steps. I set that number low to emulate a typical bi-level meta-learning setup where, in the inner loop, we do not fully optimise the model parameters for each task but rather take just a few steps of optimisation. This is because the goal is not to find the true optimum for any single task but rather to optimise the initialisation such that the model can quickly adapt to related tasks. In this case, the gradients through the inner loop must be non-zero for the initialisation to evolve across outer-loop iterations. Also, iMAML adds a special regularising term to ensure the meta-gradients do not vanish for larger numbers of inner steps.
So there is no bug in JAXopt; it only appeared so because of my incomplete explanation. The iMAML paper that I am following uses CG because it avoids forming a Hessian matrix explicitly. It also seems that CG-like iterative solvers are used quite extensively within JAXopt; as far as I understand, this again has to do with their matrix-free nature. Optimistix, on the other hand, seems to use direct solvers by default.
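The proximal term is what restores a non-trivial meta-gradient. A minimal pure-JAX sketch of the iMAML-style calculation (the quadratic `task_loss`, the value of `lam`, the closed-form inner solve, and `outer_loss` are all toy assumptions, not the actual MWE): the regulariser `lam/2 * ||phi - theta||^2` makes the inner solution depend on `theta`, and CG solves the implicit-function-theorem linear system matrix-free via Hessian-vector products.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

lam = 1.0  # strength of the proximal term (arbitrary toy value)

def task_loss(phi):
    return jnp.sum(phi ** 2)  # stand-in for the inner task loss

# Inner problem: argmin_phi task_loss(phi) + lam/2 * ||phi - theta||^2.
# For this quadratic it has the closed form phi* = lam * theta / (2 + lam),
# which visibly depends on theta -- unlike the unregularised argmin.
def inner_solve(theta):
    return lam * theta / (2.0 + lam)

def outer_loss(phi):
    return jnp.sum((phi - 1.0) ** 2)

theta = jnp.array([1.0, -2.0])
phi_star = inner_solve(theta)
g_outer = jax.grad(outer_loss)(phi_star)

# Implicit function theorem: d(phi*)/d(theta) = lam * (H + lam*I)^{-1},
# with H the Hessian of task_loss at phi*. CG solves (H + lam*I) v = g_outer
# while touching H only through Hessian-vector products -- no Hessian matrix
# is ever materialised.
def matvec(v):
    hvp = jax.jvp(jax.grad(task_loss), (phi_star,), (v,))[1]
    return hvp + lam * v

v, _ = cg(matvec, g_outer)
meta_grad = lam * v

# Cross-check by differentiating straight through the closed-form solve.
direct = jax.grad(lambda t: outer_loss(inner_solve(t)))(theta)
```

Here the meta-gradient is non-zero precisely because of the proximal term; with `lam -> 0` the inner solution collapses back to the `theta`-independent argmin and the gradient vanishes again.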