Comments (4)
Hi - thanks for the question! It looks like you're hitting a variant of this issue: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
Expressed in terms of `where` rather than masking, your code is roughly equivalent to this:
```python
def f(x, pre_filter=False):
  delta = jnp.array((0, 1)) * x
  if pre_filter:
    return jnp.sum(jnp.where(delta > 0, delta, 0) ** 0.5)
  else:
    return jnp.sum(jnp.where(delta > 0, delta ** 0.5, 0))
```
In the first case, you only apply the square root to positive values. In the second case, you apply the square root to non-positive values as well, which is what generates the NaNs. Outside autodiff, the filter works the same either way. But inside autodiff, the autodiff rule must consider the contributions of both filtered and non-filtered values to the gradient. You can read a more complete description of this at the link above.
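For concreteness (a small sketch of my own, not part of the original comment), evaluating the gradient of both variants at a point where `delta` has a non-positive entry shows the difference:

```python
import jax
import jax.numpy as jnp

def f(x, pre_filter=False):
  delta = jnp.array((0.0, 1.0)) * x
  if pre_filter:
    return jnp.sum(jnp.where(delta > 0, delta, 0) ** 0.5)
  else:
    return jnp.sum(jnp.where(delta > 0, delta ** 0.5, 0))

print(jax.grad(lambda x: f(x, pre_filter=True))(1.0))   # 0.5: the filtered entry contributes zero
print(jax.grad(lambda x: f(x, pre_filter=False))(1.0))  # nan: 0 * inf appears in the backward pass
```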
Does that help answer your question?
Thanks so much for the swift response!
That's definitely a feature, not a bug, already described in detail (for TensorFlow too), my bad! The most relevant comprehensive info (just refreshing the topic :) ):
#1052 (comment)
https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
Summarizing all of the above, my understanding is that the problem is the following:
```python
def f(x):  # our "intuition" for how it SHOULD work
  if x > 0:
    return x ** 0.5  # sqrt(x)
  else:
    return Const     # some stub

jax.grad(f)(0.0)
# 0 - and it really works CORRECTLY, i.e. we get "d Const / d x = 0" at x = 0
```
However, with `jnp.where` (or similar), we get "a surprise":
```python
def f(x):  # x.size = 1
  return jnp.where(x > 0, x ** 0.5, Const)

jax.grad(f)(0.0)
# nan - boom!
```
So the workaround is to avoid NaNs in ANY "execution branch" (even if it's supposed to be filtered out later by `jnp.where` or similar):
```python
def f(x):  # x.size = 1
  return jnp.where(x > 0, jnp.where(x > 0, x, 0) ** 0.5, Const)

jax.grad(f)(0.0)
# 0 - OK, i.e. we get "d Const / d x = 0" at x = 0
```
Looks like that...
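As a quick sanity check (my own sketch, with `Const = 1.0` as an arbitrary stand-in for the stub), the nested-`where` version gives finite gradients both at the boundary and inside the "good" branch:

```python
import jax
import jax.numpy as jnp

Const = 1.0  # arbitrary stand-in for the stub value above

def f_safe(x):
  return jnp.where(x > 0, jnp.where(x > 0, x, 0) ** 0.5, Const)

print(jax.grad(f_safe)(0.0))  # 0.0  - the filtered branch no longer poisons the gradient
print(jax.grad(f_safe)(4.0))  # 0.25 - matches d(sqrt(x))/dx = 1 / (2 * sqrt(x))
```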
P.S. Frankly speaking, even after reading the docs it's still not obvious WHY it is so. The mentioned fact that "0 * nan = nan" isn't really a "theoretical limitation" behind the problem, imho. So I'm considering the problem to be an implementation "feature" (maybe a known issue?). Or maybe a special NOTE could appear when `config.update("jax_debug_nans", True)` fires right after a `jnp.where`?
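For reference (my own sketch, not part of the original comment), this is how that flag is switched on today; it should turn the silent nan into a `FloatingPointError` that localizes the offending operation, though it doesn't say anything special about `jnp.where`:

```python
import jax
import jax.numpy as jnp
from jax import config

config.update("jax_debug_nans", True)

def f(x):
  return jnp.where(x > 0, x ** 0.5, 1.0)

# With the flag enabled, this should raise FloatingPointError instead of
# silently returning nan, pointing at the computation that produced it.
jax.grad(f)(0.0)
```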
> Or maybe a special NOTE could appear when `config.update("jax_debug_nans", True)` fires right after a `jnp.where`?
Interesting idea! @jakevdp wdyt?
> The mentioned fact that "0 * nan = nan" isn't really a "theoretical limitation" behind the problem, imho.
That is the fundamental root of the issue, but maybe we can connect the dots more concretely:
- the vjp of any function `lambda x:T: ...` (think of `T` as including the shape) must be a function which produces a value of type `T`
- the entry of a vjp value corresponding to an input or intermediate for which the value doesn't affect the output must be zero
- the vjp of `f = lambda y:f32[2]: y[1]` is `lambda zbar:f32[]: jnp.zeros(2, 'f32').at[1].set(zbar)` (has to be a dense array because of Claim 1, and has to be a dense array of zeros because of Claim 2)
- the vjp of `g = lambda x:f32[]: c:f32[2] * x` for any constant `c` is `lambda ybar:f32[2]: (ybar * c:f32[2]).sum()`, where the sum arises from the broadcast
- the vjp of the composed function `lambda x: f(g(x))` is correct no matter the value of the constant `c`, so long as `0 * x = 0` (zero scaling) for all possible array entries `x` and `0 + x = x` (zero vector) for all possible array entries `x`; but if we had a nan value in `c[0]` then the vjp of the composition is incorrect
Indeed the vjp of just `g` by itself always produces a nan value if `c[0]` is nan, but it's not clear that that's a problem, because there's a nan in the output of `g`. It's only when we compose it with `f`, which drops the nan from the output, that it's clear things are really going wrong: just having nans in intermediates, not outputs, can break VJPs.
The last claim is really what we mean when we say the root of the issue is allowing some value `x` for which `x * 0 != 0`. If we didn't have such values, then this problem couldn't arise!
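To make the last claim concrete with runnable code (my own illustration, not from the comment), here is the `f`/`g` composition with a nan planted in `c[0]`:

```python
import jax
import jax.numpy as jnp

c = jnp.array([jnp.nan, 2.0])   # constant with a nan in the entry the output never uses

f = lambda y: y[1]              # f : f32[2] -> f32[]
g = lambda x: c * x             # g : f32[]  -> f32[2]

# Mathematically d/dx f(g(x)) = c[1] = 2.0, since the output never looks at c[0].
print(jax.grad(lambda x: f(g(x)))(1.0))
# nan: the vjp of g computes (ybar * c).sum(), and the "dropped" entry gives 0 * nan = nan
```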
We might be able to fix this by changing Claim 1, basically by letting cotangents be sparse array types even when the primals are dense. But we've never gone down that path.
What do you think?
Interesting, thanks so much for the clarifications!
Let's check that I've got your point :)
Taking your example `f = lambda y: y[1]` and `g = lambda x: c * x` with a constant `c = (c_0, c_1)`, with the following derivatives: `df/dy = (0, 1)` and `dg/dx = c`, and finally `d/dx f(g(x)) = (0, 1) · c = 0 * c_0 + 1 * c_1`, which is `nan` whenever `c_0` is `nan`.
I hope the notation is OK, and that I've understood your clarifications correctly, so going straight to the point: the problem in the ORIGINAL EXAMPLE is that the infinite derivative of `x ** 0.5` at `x = 0` is being multiplied by the (zero) derivative of our "filter-function" `jnp.where`, giving `0 * inf = nan`.

Is that correct?
In other words, simplifying all the above stuff, one may assume that `jnp.where` is kinda a "ReLU"/step function, for example `f(x) = step(x) * x ** 0.5 + (1 - step(x)) * Const` with `step(x) = 1 if x > 0 else 0`, and having to differentiate it we get `f'(x) = step(x) * 0.5 * x ** -0.5 + ...`, and finally, evaluating the above expression at `x = 0`, we get `0 * inf = nan`.

Does this make sense?
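Numerically that reading checks out (my own sketch; `Const = 1.0` is an arbitrary stand-in): writing the original example with an explicit 0/1 mask instead of `jnp.where` produces the same nan:

```python
import jax
import jax.numpy as jnp

Const = 1.0  # arbitrary stand-in for the stub value

def f_mask(x):
  step = (x > 0).astype(x.dtype)               # "ReLU switch": 1 where x > 0, else 0
  return step * x ** 0.5 + (1.0 - step) * Const

print(jax.grad(f_mask)(0.0))  # nan: the chain rule forms step * d(x**0.5)/dx = 0 * inf
```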
So my NAIVE thought about a "theoretical limitation" was based on the assumption that the current implementation of `jnp.where` under the hood is something like the following bound code:

```cpp
for (size_t i = 0; i < y.size(); ++i) {
  if (y[i].value > 0) {   // y[i] is sort of a dual number
    // take the derivative of y[i]
  } else {
    // ignore y[i], consider another constant value
  }
}
```

where we drop the "nan execution branch" with an `if` rather than a 0-multiplication.
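For what it's worth (a small check of my own, not from the thread), the gradient of `jnp.where` on its own does behave like that `if`: a nan sitting in the untaken branch doesn't leak into the gradient. The nan in the original example shows up one step later, when the zero cotangent handed to the `x ** 0.5` branch is multiplied by the infinite derivative of the square root at 0:

```python
import jax
import jax.numpy as jnp

# A nan in the branch that isn't selected does not leak into the gradient
# of the selected operand:
print(jax.grad(lambda x: jnp.where(x > 0, x, jnp.nan))(1.0))   # 1.0
print(jax.grad(lambda x: jnp.where(x > 0, x, jnp.nan))(-1.0))  # 0.0

# The trouble starts when a branch has an infinite/nan *derivative* at the
# evaluation point, as with x ** 0.5 at x = 0:
print(jax.grad(lambda x: jnp.where(x > 0, x ** 0.5, 0.0))(0.0))  # nan
```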
p.s. Merry Xmas Everybody!