Comments (5)

adarob commented on August 22, 2024

@hwchung27 @levskaya

from t5x.

levskaya commented on August 22, 2024

Hi, the custom_vjp approach is intentional: it lets us control exactly the precision, stability, and performance of the final compiled function. If you need forward-mode differentiation, I would recommend substituting the straightforward loss + z-loss definition and letting JAX take care of the forward-mode AD.

salayatana66 commented on August 22, 2024

I am not completely convinced. JVP vs. VJP should be about push-forward vs. pull-back, i.e. whether we apply the Jacobian J on the left or the right (tangent vs. cotangent vectors). Numerical stability should be about computing the map (params, inputs) -> J correctly. So if numerical stability is a problem with the VJP, it would be a problem with the JVP too, right? To make things more concrete, let's look at the Jacobian of the softmax (a square matrix, so push-forward and pull-back can be isomorphisms). Suppose that, through numerical instability, some rows are incorrectly computed as 0 (say J was supposed to be nonsingular but now turns out to be singular). Then with a custom VJP we would fix the issue when pulling back cotangent vectors, but we would still get the wrong answer when pushing forward with a JVP.
On a practical note, what would be a good place to modify the implementation? Implementing a custom_jvp instead of a custom_vjp, defining a new method in the model class, subclassing the model class?
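
For reference, the push-forward/pull-back relationship described above can be checked numerically. This is only a minimal sketch, not t5x code: `softmax` is a hand-written helper, and the input values and the vector `v` are arbitrary assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def softmax(x):
    # Hand-written softmax (illustrative helper, not the t5x implementation).
    return jnp.exp(x - logsumexp(x))


logits = jnp.array([1.0, 2.0, 3.0])  # arbitrary example input
v = jnp.array([1.0, 0.0, -1.0])      # arbitrary tangent / cotangent vector

# Full Jacobian J, materialized only for comparison.
J = jax.jacfwd(softmax)(logits)

# Push-forward (JVP): computes J @ v without building J.
_, jvp_out = jax.jvp(softmax, (logits,), (v,))

# Pull-back (VJP): computes v @ J without building J.
_, vjp_fn = jax.vjp(softmax, logits)
(vjp_out,) = vjp_fn(v)

jvp_matches = bool(jnp.allclose(jvp_out, J @ v, atol=1e-6))
vjp_matches = bool(jnp.allclose(vjp_out, v @ J, atol=1e-6))
```

Both products are derived from the same Jacobian. For softmax specifically, J = diag(p) - p pᵀ is symmetric, so the two results coincide for this function.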

levskaya commented on August 22, 2024

Sorry if I wasn't clear:

The situation is that we have a manually crafted VJP for the "cross-entropy + z-loss" loss function that we know from long experience is performant and stable for training with the BF16 half-precision type that we frequently use. That's why we hard-code the VJP this way. We don't particularly want to mess around with decomposing this into separate JVP + transpose rules (and I'm not even certain that JAX fully supports custom transpose rule definitions yet, though it will soon if not already).

If you want to get the custom VJP out of your way for forward-mode differentiation, etc., just use a copy of the original function without it:

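import jax
import jax.numpy as jnp
import jax.scipy.special  # makes jax.scipy.special.logsumexp available

def no_custom_vjp_xent_loss_fn(logits, targets, z_loss):
  # targets: label distribution (e.g. one-hot) with the same shape as logits.
  # z_loss: scalar coefficient on the squared log-partition term.
  logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
  log_softmax = logits - logits_sum
  loss = -jnp.sum(targets * log_softmax, axis=-1)
  log_z = jnp.squeeze(logits_sum, axis=-1)
  total_z_loss = z_loss * jax.lax.square(log_z)
  loss += total_z_loss
  return loss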

Our custom-VJP version is only really needed for training, not any extra analytical work you might want to do. The above is mathematically identical to the custom-VJP version, and JAX should be able to apply any mix of fwd/rev AD methods to it.
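
As an illustration of mixing forward- and reverse-mode AD on a plain loss like the one above, here is a minimal sketch. The names `plain_xent_loss` and `total_loss`, the example values, and the z-loss coefficient are all assumptions made for this example; they are not taken from t5x.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def plain_xent_loss(logits, targets, z_loss):
    # Cross-entropy plus z-loss, written without any custom_vjp.
    log_z = logsumexp(logits, axis=-1, keepdims=True)
    log_softmax = logits - log_z
    loss = -jnp.sum(targets * log_softmax, axis=-1)
    return loss + z_loss * jnp.squeeze(log_z, axis=-1) ** 2


# Arbitrary example data: 2 examples, 3 classes.
logits = jnp.array([[1.0, 2.0, 0.5], [0.1, -1.0, 2.0]])
targets = jax.nn.one_hot(jnp.array([1, 2]), 3)
z_loss = 1e-4  # illustrative coefficient


def total_loss(lg):
    # Scalar loss, summed over the batch.
    return jnp.sum(plain_xent_loss(lg, targets, z_loss))


tangent = jnp.array([[0.3, -0.2, 0.1], [0.0, 0.5, -0.4]])

# Forward mode: directional derivative of the loss along `tangent`.
loss_value, directional = jax.jvp(total_loss, (logits,), (tangent,))

# Reverse mode: full gradient with respect to the logits.
grad = jax.grad(total_loss)(logits)

# Forward-over-reverse: Hessian-vector product.
_, hvp = jax.jvp(jax.grad(total_loss), (logits,), (tangent,))
```

The directional derivative from `jax.jvp` should agree with the inner product of `grad` and `tangent`, which is a quick way to confirm that both modes see the same function.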

salayatana66 commented on August 22, 2024

Thanks, I'll factor this out as you suggested.
