
Comments (4)

jakevdp commented on June 9, 2024

Hi - thanks for the question! It looks like you're hitting a variant of this issue: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

Expressed in terms of where rather than masking, your code is roughly equivalent to this:

def f(x, pre_filter=False):
    delta = jnp.array((0, 1)) * x
    if pre_filter:
        return jnp.sum(jnp.where(delta > 0, delta, 0) ** 0.5)
    else:
        return jnp.sum(jnp.where(delta > 0, delta**0.5, 0))

In the first case, you only apply the square root to positive values. In the second case, you apply the square root to negative values, which generate NaNs. Outside autodiff, the filter works the same each way. But inside autodiff, the autodiff rule must consider the contributions of both filtered and non-filtered values to the gradient. You can read a more complete description of this at the link above.
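For concreteness, here is a quick check (a sketch; it assumes the f defined just above is in scope and uses x = 0.0):

import jax
import jax.numpy as jnp

# pre_filter=False: sqrt is differentiated at the masked zero entries first,
# so the backward pass multiplies a zero cotangent by a non-finite derivative.
print(jax.grad(f)(0.0))                                # nan
# pre_filter=True: the mask is applied before the sqrt, so the masked entries
# never reach the sqrt differentiation rule.
print(jax.grad(lambda x: f(x, pre_filter=True))(0.0))  # 0.0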

Does that help answer your question?


agudym commented on June 9, 2024

Thanks so much for the swift response!

That's definitely a feature, not a bug, and it's already described in detail (for TensorFlow too), my bad! The most relevant comprehensive info (just refreshing the topic :) ):
#1052 (comment)
https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf

Summarizing all of the above, my understanding is the following. The problem is:

import jax
import jax.numpy as jnp

Const = 1.0  # some constant stub

def f(x):  # our "intuition" of how it SHOULD work
    if x > 0:
        return x**0.5  # sqrt(x)
    else:
        return Const
jax.grad(f)(0.0)
# 0 - and it really works CORRECTLY, i.e. we get "d Const / d x = 0" at x = 0

However, with jnp.where (or similar) we get a surprise:

def f(x):  # x.size = 1
    return jnp.where(x > 0, x**0.5, Const)
jax.grad(f)(0.0)
# nan - boom!

So the workaround is to avoid NaNs in ANY "execution branch", even if it's supposed to be filtered out later by jnp.where (or similar):

def f(x):  # x.size = 1
    return jnp.where(x > 0, jnp.where(x > 0, x, 0)**0.5, Const)
jax.grad(f)(0.0)
# 0 - OK, i.e. we get "d Const / d x = 0" at x = 0

Something like that...
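For reference, the FAQ entry linked above writes the same double-where trick with a nonzero filler inside the inner where, so that nothing non-finite is produced even as an intermediate value (a sketch; the filler 1.0 is arbitrary, any value that keeps sqrt and its derivative finite works):

def f(x):
    safe_x = jnp.where(x > 0, x, 1.0)  # nonzero filler keeps sqrt and its derivative finite
    return jnp.where(x > 0, safe_x**0.5, Const)
jax.grad(f)(0.0)
# 0 - same result, and no inf/nan appears anywhere in the backward pass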

p.s. Frankly speaking, even after reading the docs it's still not obvious WHY it is so. The fact that "0 * nan = nan" isn't really a "theoretical limitation" behind the problem, imho. So I'm considering the problem to be an implementation "feature" (maybe a known issue?). Or maybe a special NOTE with config.update("jax_debug_nans", True) could appear if that happens after jnp.where?
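For what it's worth, the flag mentioned above already exists and can be enabled today; a minimal sketch (it should make the computation raise roughly where a nan is first produced, rather than silently propagating it):

import jax
jax.config.update("jax_debug_nans", True)  # existing debug flag
# With it enabled, the nan-producing example above errors out during the
# backward pass, which at least points at the offending operation.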


mattjj commented on June 9, 2024

Or maybe a special NOTE with config.update("jax_debug_nans", True) could appear if that happens after jnp.where?

Interesting idea! @jakevdp wdyt?

The fact that "0 * nan = nan" isn't really a "theoretical limitation" behind the problem, imho.

That is the fundamental root of the issue, but maybe we can connect the dots more concretely:

  1. the vjp of any function lambda x:T: ... (think of T as including the shape) must be a function which produces a value of type T
  2. the entry of a vjp value corresponding to an input or intermediate for which the value doesn't affect the output must be zero
  3. the vjp of f = lambda y:f32[2]: y[1] is lambda zbar:f32[]: jnp.zeros(2, 'f32').at[1].set(zbar) (has to be a dense array because of Claim 1, and has to be a dense array of zeros because of Claim 2)
  4. the vjp of g = lambda x:f32[]: c:f32[2] * x for any constant c is lambda ybar:f32[2]: (ybar * c:f32[2]).sum(), where the sum arises from the broadcast
  5. the vjp of the composed function lambda x: f(g(x)) is correct no matter the value of the constant c so long as 0 * x = 0 (zero scaling) for all possible array entries x and 0 + x = x (zero vector) for all possible array entries x, but if we had a nan value in c[0] then the vjp of the composition is incorrect

Indeed the vjp of just g by itself always produces a nan value if c[0] is nan, but it's not clear that that's a problem because there's a nan in the output of g. It's only when we compose it with f, which drops the nan from the output, that it's clear things are really going wrong: just having nans in intermediates, not outputs, can break VJPs.
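Concretely, a minimal sketch of Claims 1-5 (the constant c with a nan in its first entry is just an illustration):

import jax
import jax.numpy as jnp

c = jnp.array([jnp.nan, 2.0])    # c[0] is nan, c[1] is an ordinary value

g = lambda x: c * x              # f32[]  -> f32[2]
f = lambda y: y[1]               # f32[2] -> f32[], drops the nan entry

print(f(g(3.0)))                         # 6.0 - the nan never reaches the output
print(jax.grad(lambda x: f(g(x)))(3.0))  # nan - the zero cotangent for y[0] meets c[0] = nan, and 0 * nan = nan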

The last claim is really what we mean when we say the root of the issue is allowing some value x for which x * 0 != 0. If we didn't have such values, then this problem couldn't arise!

We might be able to fix this by changing Claim 1, basically by letting cotangents be sparse array types even when the primals are dense. But we've never gone down that path.

What do you think?


agudym commented on June 9, 2024

Interesting, thanks so much for the clarifications!

Let's check that I've got your point :)

$f(\boldsymbol{g}(x)) \in \mathbb{R}$ - some functions composition, with scalar $x \in \mathbb{R}$ input,

$\boldsymbol{g}(x) = (g_1(x), g_2(x))^T \in \mathbb{R}^{2 \times 1}$ - intermediate vector value,

with following derivatives:

$\frac{\partial\ f}{\partial\ \boldsymbol{g}}=\left (\frac{\partial\ f}{\partial\ g_1}, \frac{\partial\ f}{\partial\ g_2}\right ) \in \mathbb{R}^{1 \times 2}$ - gradient of scalar $f$, with corresponding vector-jacobian-product of the form: $\upsilon_g \left (\frac{\partial\ f}{\partial\ g_1}, \frac{\partial\ f}{\partial\ g_2}\right ) \in \mathbb{R}^{1 \times 2}$ with $\upsilon_g \in \mathbb{R}$

$\frac{\partial\ \boldsymbol{g}}{\partial\ x}=\left (\frac{\partial\ g_1}{\partial\ x}, \frac{\partial\ g_2}{\partial\ x}\right )^T \in \mathbb{R}^{2 \times 1}$ - gradient of vector $\boldsymbol{g}$, with corresponding vector-jacobian-product of the form: $\boldsymbol{\upsilon_f}^T \left (\frac{\partial\ g_1}{\partial\ x}, \frac{\partial\ g_2}{\partial\ x}\right )^T \in \mathbb{R}$ with $\boldsymbol{\upsilon_f} = (\upsilon_f^1 , \upsilon_f^2)^T \in \mathbb{R}^{2 \times 1}$

and finally

$\frac{\partial\ f}{\partial\ x} = \frac{\partial\ f}{\partial\ \boldsymbol{g}} \frac{\partial\ \boldsymbol{g}}{\partial\ x} = \frac{\partial\ f}{\partial\ g_1} \frac{\partial\ g_1}{\partial\ x} + \frac{\partial\ f}{\partial\ g_2} \frac{\partial\ g_2}{\partial\ x} \in \mathbb{R}$

I hope the notation is OK and that I've understood your clarifications correctly, so going straight to the point: the problem in the ORIGINAL EXAMPLE is that

$\frac{\partial\ g_1}{\partial\ x} = \frac{\partial\ \sqrt{x \cdot 0}}{\partial\ x} = \mathrm{nan}$

is being multiplied by the derivative of our "filter function" $f(x)=g_2(x)$ (or even $f(x)=0 \cdot g_1(x) + g_2(x)$):

$\frac{\partial\ f}{\partial\ g_1} = 0$ (because $f$ doesn't depend on $g_1$), resulting in $\frac{\partial\ f}{\partial\ x} = 0 \cdot \mathrm{nan} + \ldots = \mathrm{nan}$.
Is that correct?

In other words, simplifying all of the above, one may think of jnp.where as a kind of "ReLU" function, for example:

$\mathrm{where}(y > 0,\ y,\ 0) = f(y) = (y\ \text{if}\ y > 0\ \text{else}\ 0) = (y\ \text{if}\ y > 0\ \text{else}\ 0 \cdot y)$

and when differentiating $f(y) = f(\sqrt{x}) = (\sqrt{x}\ \text{if}\ \sqrt{x} > 0\ \text{else}\ 0) = (\sqrt{x}\ \text{if}\ \sqrt{x} > 0\ \text{else}\ 0 \cdot \sqrt{x})$ we still end up with the "classic" chain rule:

$\frac{\partial\ f}{\partial\ x} = \frac{\partial\ f}{\partial\ y} \frac{\partial\ \sqrt{x}}{\partial\ x} = (1\ \text{if}\ y > 0\ \text{else}\ 0) \cdot \frac{\partial\ \sqrt{x}}{\partial\ x}$

and finally, evaluating the above expression at $x = 0$, we get $0 \cdot \mathrm{nan}$.
Does this make sense?
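A tiny numeric check of this chain-rule story (a sketch; note that at exactly x = 0 the sqrt derivative is inf rather than nan, but 0 * inf is nan as well):

import jax
import jax.numpy as jnp

print(jnp.array(0.0) * jnp.nan)                                # nan: 0 * nan != 0
print(jax.grad(lambda x: jnp.where(x > 0, x**0.5, 0.0))(0.0))  # nan: the zero cotangent of the masked branch meets the inf sqrt derivative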

So my NAIVE thought about a "theoretical limitation" was based on the assumption that the current implementation of jnp.where is, under the hood, something like this pseudocode for the bound native code:

for (size_t i = 0; i < y.size(); ++i)
    if (y[i].value > 0)  // y[i] is sort of a dual number
        ...              // take the derivative of y[i]
    else
        ...              // ignore y[i], use the constant value instead

where we would drop the "nan execution branch" with an if, rather than with a multiplication by 0.
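A quick way to see that this is not what actually happens: jnp.where traces both operands as dense computations and only then selects between them, so the sqrt branch is present in the program regardless of the mask (a sketch; the exact primitive names in the printed jaxpr may vary by JAX version):

import jax
import jax.numpy as jnp

# Both branches, including x**0.5, appear as ops in the traced program;
# a select-style primitive merely picks between the already-computed values.
print(jax.make_jaxpr(lambda x: jnp.where(x > 0, x**0.5, 0.0))(2.0))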

p.s. Merry Xmas Everybody!

