Comments (28)

mattjj commented on May 29, 2024

This is roughly repeating what @dlwh just said, but I just figured it out and came back to explain: this use of custom_vjp is buggy in that the flash_attention_forward output needs to be a pair whose first element has the same type as the output of flash_attention. Yet the output of flash_attention contains three arrays, while the first element of the return value of flash_attention_forward contains only one.

There's a JAX bug in that this raised such a terrible error message, but the fundamental bug is in that use of custom_vjp.
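
For reference, here's a minimal sketch of that contract, with a toy function standing in for flash_attention (this is not the repo's code): the custom_vjp fwd rule must return a pair whose first element has exactly the same pytree structure as the output of the decorated function.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    # primal output is a pytree: (array, scalar), like flash_attention's
    # (out, (row_sum, row_max)) but simpler
    y = jnp.exp(x)
    return y, y.sum()

def f_fwd(x):
    y = jnp.exp(x)
    # the first element must match f's output pytree; returning just y here
    # would be the same kind of bug as in this issue
    return (y, y.sum()), y

def f_bwd(y, g):
    gy, gs = g                 # cotangents mirror the output pytree
    return (gy * y + gs * y,)  # d/dx exp(x) = exp(x), d/dx sum(exp(x)) = exp(x)

f.defvjp(f_fwd, f_bwd)

jax.grad(lambda x: f(x)[0].sum() + f(x)[1])(jnp.ones(3))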

mattjj commented on May 29, 2024

google/jax#12611 should improve the error message we got here! With the same repro (i.e. before the fix #7 was merged here), the error will be:

TypeError: Custom VJP fwd rule flash_attention_forward for function
flash_attention must produce a pair (list or tuple of length two) where the
first element represents the primal output (equal to the output of the
custom_vjp-decorated function flash_attention) and the second element
represents residuals (i.e. values stored from the forward pass for use on the
backward pass), but instead the fwd rule output's first element had
container/pytree structure:
    float32[3,16,5,19]
while the custom_vjp-decorated function flash_attention had output
container/pytree structure:
    (float32[3,16,5,19], (float32[3,16,5], float32[3,16,5])).

GallagherCommaJack commented on May 29, 2024

can confirm that this error also appears under jax.lax.scan

example here:

# q, k, v, and mask all carry a leading axis of length l that lax.scan maps over
q = jax.random.normal(keys[0], (l, b, lq, h, d))
k = jax.random.normal(keys[1], (l, b, lkv, h, d))
v = jax.random.normal(keys[2], (l, b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (l, b, lkv))


def scan_fn(carry, qkv):
    # qkv is one (q, k, v, mask) slice; accumulate the attention outputs
    out = flash_attention(*qkv)[0]
    carry += out
    return carry, out


@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(
        lambda q, k, v, mask: jnp.sum(
            jax.lax.scan(
                scan_fn,
                jnp.zeros_like(q[0]),
                (q, k, v, mask),
            )[0],
        )
    )(q, k, v, mask)


bench_flash_bwd(q, k, v, mask)

mattjj commented on May 29, 2024

Thanks for raising this! It looks like a JAX core bug most likely.

Could you provide a self-contained runnable repro, in particular including the import or definition for flash_attention? (Sorry, I'm not the developer of this repo, so I'm not familiar with that function.)

GallagherCommaJack commented on May 29, 2024
from flash_attention_jax import flash_attention

dlwh commented on May 29, 2024

I ran into this and failed to upstream a fix. The trick to fix it is basically to do this:

stanford-crfm/levanter@a2828ce#diff-658abe908dd5cd256efe9370e7ec2ae9fa2dcdca586a5f886940331e7b56dd09R129-R132

GallagherCommaJack commented on May 29, 2024

@dlwh looks like you also ran an autoformatter so there's a ton of other changes here - can you say a bit more about how you fixed it?

dlwh commented on May 29, 2024

Yeah sorry, the line linked is the key one. Basically just rename the method called "causal_flash_attention" to "_causal_flash_attention" and make causal_flash_attention return just the first result. Then make flash_attention_forward call _causal_flash_attention instead, and you're done.

@custom_vjp
def causal_flash_attention(q, k, v):
+    return _causal_flash_attention(q, k, v)[0]
+
+
+def _causal_flash_attention(q, k, v):
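
Here's a self-contained sketch of that pattern, with a toy tanh function standing in for the attention code (the names and body are just for illustration): the custom_vjp-decorated function delegates to an undecorated helper and returns only its first result, and the forward rule calls the same helper so it still gets the extra values to stash as residuals.

import jax
import jax.numpy as jnp

def _f(x):
    # undecorated helper, analogous to _causal_flash_attention: returns the
    # primal result plus extra values
    y = jnp.tanh(x)
    return y, (x, y)

@jax.custom_vjp
def f(x):
    # public entry point: returns only the primal result, so its output
    # pytree matches the first element of the fwd rule's pair
    return _f(x)[0]

def f_fwd(x):
    y, extras = _f(x)
    return y, extras              # (primal output, residuals)

def f_bwd(extras, g):
    x, y = extras
    return (g * (1 - y ** 2),)    # d/dx tanh(x) = 1 - tanh(x)^2

f.defvjp(f_fwd, f_bwd)

jax.grad(lambda x: f(x).sum())(jnp.ones(3))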

GallagherCommaJack commented on May 29, 2024

won't that make flash_attention always do causal masking? I'm using this in a context where that's not appropriate

dlwh commented on May 29, 2024

You'll need to make the analogous change to flash_attention then. As @mattjj said, it's really just a buggy use of custom_vjp. (Though, despite it not running, the code was otherwise correct according to my gradient testing!)

mattjj commented on May 29, 2024

Shall I send a PR fix to this repo (maybe you both could review it), and then separately fix the JAX error message? Or @dlwh do you want to send the fix to this repo?

dlwh commented on May 29, 2024

I can probably get to it tonight or tomorrow, but I'm about to go dark for several hours. Totally up to you!

mattjj commented on May 29, 2024

I'll take the first stab, and cc you!

GallagherCommaJack commented on May 29, 2024

so the relevant fix would be to replace

    return out, (q, k, v, key_mask, out, row_sum, row_max)

with

    return (out, (row_sum, row_max)), (q, k, v, key_mask, out, row_sum, row_max)

?

GallagherCommaJack commented on May 29, 2024

Interesting that this works under grad outside of scan and remat - probably it should fail under grad alone, without either of those?

mattjj commented on May 29, 2024

@GallagherCommaJack Yes, that'd work! It's probably the simplest fix, though we could also look at the call sites of flash_attention to see if some other organization would be more natural.

What's a repro for the behavior you're describing? I tried removing jax.checkpoint from the repro in the OP and I still got an error. That is, this still errors for me:

import jax
import jax.numpy as jnp

from flash_attention_jax import flash_attention


b = 3
lq = 16
lkv = 17
h = 5
d = 19
keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))

@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0]))(q)


bench_flash_bwd(q, k, v, mask)

mattjj commented on May 29, 2024

Ah, I think it was just a shape bug; if I set lq = lkv = 16 then I see what you mean.

I think by adding the better JAX error message I described, we'll catch this much earlier and get an error in both cases. I'll be sure to test both with and without checkpoint/scan.

mattjj commented on May 29, 2024

> Yes, that'd work!

Actually, I think it would not work, just because the callers expect only a single output there.

I think the issue here was that the custom_vjp-decorated function (i.e. the "primal function") didn't agree with the custom_vjp rule (i.e. their output types didn't agree in the way they should), but when we only use grad (possibly together with jit) we never actually run the primal function; we only run its forward rule. When grad is applied, we only actually run the primal function under a jax.checkpoint or lax.scan (or lax.cond, etc.); that's just because of a JAX implementation detail (these are "initial-style higher-order primitives") which is usually invisible, except apparently when there's a type error in a custom_vjp rule!
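
Here's a small sketch of that behavior with a toy custom_vjp function (again not the repo's code), reflecting the tracing behavior described above: under plain grad only the fwd/bwd rules are traced, while wrapping the call in jax.checkpoint (or putting it inside lax.scan) also traces the primal body, which is where a mismatched fwd rule would surface.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    # under grad, this primal body is only traced when an initial-style
    # transform (jax.checkpoint, lax.scan, ...) is also involved
    return jnp.sin(x)

def f_fwd(x):
    return jnp.sin(x), jnp.cos(x)     # (primal output, residuals)

def f_bwd(cos_x, g):
    return (g * cos_x,)

f.defvjp(f_fwd, f_bwd)

jax.grad(f)(1.0)                      # uses only f_fwd / f_bwd
jax.grad(jax.checkpoint(f))(1.0)      # traces the primal body of f as well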

GallagherCommaJack commented on May 29, 2024

With the fix it's working with lq = lkv under jax.checkpoint!
It still fails with lq != lkv, which I'm trying to debug now.

GallagherCommaJack commented on May 29, 2024

midjourney@f690412

GallagherCommaJack commented on May 29, 2024

the error with lq = 16; lkv = 17 is TypeError: add got incompatible shapes for broadcasting: (5, 3, 17, 19), (5, 3, 16, 19).

full backtrace:

TypeError                                 Traceback (most recent call last)
Cell In [5], line 22
     18 @jax.jit
     19 def bench_flash_bwd(q, k, v, mask):
     20     return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)
---> 22 bench_flash_bwd(q, k, v, mask)

    [... skipping hidden 14 frame]

Cell In [5], line 20, in bench_flash_bwd(q, k, v, mask)
     18 @jax.jit
     19 def bench_flash_bwd(q, k, v, mask):
---> 20     return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)

    [... skipping hidden 30 frame]

File ~/code/flash-attention-jax/flash_attention_jax/flash_attention.py:172, in flash_attention_backward(res, do)
    169     dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
    170     return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk
--> 172 (_, dk, dv), dq = lax.scan(chunk_scanner, init = (0, dk, dv), xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))
    174 dq = rearrange(dq, 'c n b h d -> b h (c n) d')
    175 dk, dv = map(lambda t: rearrange(t, 'n b h d -> b h n d'), (dk, dv))

    [... skipping hidden 11 frame]

File ~/code/flash-attention-jax/flash_attention_jax/flash_attention.py:170, in flash_attention_backward.<locals>.chunk_scanner(carries, _)
    167 do_chunk = lax.dynamic_slice(do, (chunk_idx, batch, heads, 0), slice_sizes = (chunk_sizes, batch, heads, do.shape[-1]))
    169 dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
--> 170 return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk

    [... skipping hidden 1 frame]

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4658, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
   4656 args = (other, self) if swap else (self, other)
   4657 if isinstance(other, _accepted_binop_types):
-> 4658   return binary_op(*args)
   4659 if isinstance(other, _rejected_binop_types):
   4660   raise TypeError(f"unsupported operand type(s) for {opchar}: "
   4661                   f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")

    [... skipping hidden 7 frame]

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py:84, in _maybe_bool_binop.<locals>.fn(x1, x2)
     82 def fn(x1, x2):
     83   x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
---> 84   return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)

    [... skipping hidden 7 frame]

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/lax/lax.py:1537, in broadcasting_shape_rule(name, *avals)
   1535       result_shape.append(non_1s[0])
   1536     else:
-> 1537       raise TypeError(f'{name} got incompatible shapes for broadcasting: '
   1538                       f'{", ".join(map(str, map(tuple, shapes)))}.')
   1540 return tuple(result_shape)

TypeError: add got incompatible shapes for broadcasting: (5, 3, 17, 19), (5, 3, 16, 19).

mattjj commented on May 29, 2024

It looks like one of chunk_idx + chunk_sizes, dk + dk_chunk, or dv + dv_chunk has a shape error in flash_attention_backward. (EDIT: I don't feel comfortable debugging that without learning what this code is actually doing, so hopefully someone who knows the code/algorithm can help!)

GallagherCommaJack commented on May 29, 2024

Debugging a bit, it looks like the issue is that dk has shape (h, b, lkv, d) while dk_chunk has shape (h, b, lq, d).

GallagherCommaJack commented on May 29, 2024

@lucidrains looks like there's an implicit assumption somewhere in here that lq == lkv in the backwards pass, in _query_chunk_flash_attention_backward

mattjj commented on May 29, 2024

@GallagherCommaJack the fix I proposed in #8 is different from the commit you sent, just FYI.

GallagherCommaJack commented on May 29, 2024

does that work with lq != lkv?

GallagherCommaJack commented on May 29, 2024

looks like it does not

mattjj commented on May 29, 2024

Indeed I think the shape issue is unrelated.
