
Flash Attention - Jax

Implementation of Flash Attention in Jax. It will likely not be as performant as the official CUDA version, given the lack of fine-grained memory management, but it is useful for educational purposes and for seeing how clever the XLA compiler is (or is not).

Install

$ pip install flash-attention-jax

Usage

from jax import random
from flash_attention_jax import flash_attention

rng_key = random.PRNGKey(42)

q = random.normal(rng_key, (1, 2, 131072, 512))  # (batch, heads, seq, dim)
k = random.normal(rng_key, (1, 2, 131072, 512))
v = random.normal(rng_key, (1, 2, 131072, 512))
mask = random.randint(rng_key, (1, 131072,), 0, 2) # (batch, seq)

out, _ = flash_attention(q, k, v, mask)

out.shape  # (1, 2, 131072, 512) - (batch, heads, seq, dim)

Quick sanity check

from flash_attention_jax import plain_attention, flash_attention, value_and_grad_difference

diff, (dq_diff, dk_diff, dv_diff) = value_and_grad_difference(
    plain_attention,
    flash_attention,
    seed = 42
)

print('shows differences between normal and flash attention for output, dq, dk, dv')
print(f'o: {diff}')       # < 1e-4
print(f'dq: {dq_diff}')   # < 1e-6
print(f'dk: {dk_diff}')   # < 1e-6
print(f'dv: {dv_diff}')   # < 1e-6

Autoregressive Flash Attention - GPT-like decoder attention

from jax import random
from flash_attention_jax import causal_flash_attention

rng_key = random.PRNGKey(42)

q = random.normal(rng_key, (131072, 512))
k = random.normal(rng_key, (131072, 512))
v = random.normal(rng_key, (131072, 512))

out, _ = causal_flash_attention(q, k, v)

out.shape  # (131072, 512)

Todo

  • leading dimensions for causal flash attention variant

  • figure out issue with jit and static argnums

  • comment with references to paper algorithms and explanations

  • make sure it can work with one-headed keys / values, as in PaLM (see the sketch below)
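
A minimal sketch of one possible approach (an assumption, not the planned implementation): broadcast the single key / value head across the query heads before calling flash_attention, matching the (batch, heads, seq, dim) interface above.

import jax.numpy as jnp
from jax import random
from flash_attention_jax import flash_attention

rng_key = random.PRNGKey(42)
batch, heads, seq, dim = 1, 8, 1024, 64

q = random.normal(rng_key, (batch, heads, seq, dim))
k = random.normal(rng_key, (batch, 1, seq, dim))   # single key head
v = random.normal(rng_key, (batch, 1, seq, dim))   # single value head
mask = jnp.ones((batch, seq), dtype=bool)

k, v = (jnp.broadcast_to(t, (batch, heads, seq, dim)) for t in (k, v))

out, _ = flash_attention(q, k, v, mask)  # (batch, heads, seq, dim)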

Citations

@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R{\'e}},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}

@article{Rabe2021SelfattentionDN,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory},
    author  = {Markus N. Rabe and Charles Staats},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2112.05682}
}


flash-attention-jax's Issues

support for per-head scales for cosine sim attention

Usually with cosine-sim models I'd train with learned per-head scales for the attention logits. I suppose I can get this by multiplying q and k by sqrt(scales) before the dot product, but that's probably less stable.
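
A minimal sketch of that workaround, assuming the 4-D (batch, heads, seq, dim) interface from the README above; the per-head scale values here are placeholders. Folding the scales in works because (q * sqrt(s)) @ (k * sqrt(s))^T == s * (q @ k^T) for each head.

import jax.numpy as jnp
from jax import random
from flash_attention_jax import flash_attention

rng_key = random.PRNGKey(0)
batch, heads, seq, dim = 1, 8, 1024, 64

q = random.normal(rng_key, (batch, heads, seq, dim))
k = random.normal(rng_key, (batch, heads, seq, dim))
v = random.normal(rng_key, (batch, heads, seq, dim))
mask = jnp.ones((batch, seq), dtype=bool)

scales = jnp.full((heads,), 10.)                      # placeholder learned per-head scales
sqrt_scales = jnp.sqrt(scales)[None, :, None, None]   # broadcast to (1, heads, 1, 1)

out, _ = flash_attention(q * sqrt_scales, k * sqrt_scales, v, mask)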

Question about calculation of Q and transpose(K).

Thanks for your effort in putting together this great project.

In normal attention, the input to the softmax is matmul(Q, K_T), with shape (batch, num_heads, q_len, k_len).
The attention mask is triangular (its full shape would be q_len x k_len),
so matmul(Q, K_T) is masked with that attention mask.

However, comparing with the original attention algorithm flow, I don't understand how matmul(q_chunk, transposed k_chunk) ends up producing the masked input to the softmax in the code lines below.

attn_weights = einsum('i ... d, j ... d -> i ... j', q_scaled, k_chunk)
key_mask_chunk = rearrange(key_mask_chunk, 'j b -> 1 b 1 j')
attn_weights = jnp.where(key_mask_chunk, attn_weights, MASK_VALUE)

Can you explain this in detail?
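
A minimal sketch (not the repository's code) of why the chunked masking matches the original flow: each key chunk's logits are masked with the corresponding slice of the 1-D key mask, and concatenating the chunks along the key axis reproduces masking the full matmul(Q, K_T) along its key dimension.

import jax.numpy as jnp

MASK_VALUE = -1e10

q = jnp.ones((128, 64))            # (q_len, dim), single head for clarity
k = jnp.ones((256, 64))            # (k_len, dim)
key_mask = jnp.arange(256) < 200   # (k_len,), 1-D key padding mask

# full attention logits, masked along the key dimension
full = jnp.where(key_mask[None, :], q @ k.T, MASK_VALUE)

# chunked logits, each chunk masked with its slice of the key mask
chunks = []
for start in range(0, 256, 64):
    k_chunk = k[start:start + 64]
    mask_chunk = key_mask[start:start + 64]
    attn_weights = jnp.einsum('i d, j d -> i j', q, k_chunk)
    chunks.append(jnp.where(mask_chunk[None, :], attn_weights, MASK_VALUE))

assert jnp.allclose(full, jnp.concatenate(chunks, axis=-1))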

batch & multihead support?

How hard would it be to add support for leading dimensions (e.g. for batching and multiple heads)?
In my experience, vmap is often less performant than batching by hand (a vmap-based workaround is sketched below).
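
For reference, a vmap-based workaround sketch, assuming the (seq, dim) interface of causal_flash_attention shown in the README above; as noted, this may well be slower than native support for leading dimensions.

import jax
from jax import random
from flash_attention_jax import causal_flash_attention

single = lambda q, k, v: causal_flash_attention(q, k, v)[0]  # drop the auxiliary output
per_head = jax.vmap(single)          # maps over a leading heads axis: (heads, seq, dim)
batched = jax.vmap(per_head)         # maps over a leading batch axis: (batch, heads, seq, dim)

rng_key = random.PRNGKey(0)
q = random.normal(rng_key, (2, 4, 1024, 64))
k = random.normal(rng_key, (2, 4, 1024, 64))
v = random.normal(rng_key, (2, 4, 1024, 64))

out = batched(q, k, v)               # (2, 4, 1024, 64)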

can I work on making a flax attention function out of this repository?

Hi lucidrains!

I wanted to use flash attention in one of my projects: a transformer model that works on sequences as long as 2400 with a batch size of 1000. The original attention does not fit in memory for me, so I wanted to use flash attention and found your implementation.

However, I found that I cannot simply pass your attention implementation to flax.MultiHeadDotProductAttention as its attention_fn, because that function needs to be multi-headed and to accept mask, dropout_rate, etc.

I was wondering if I could use your flash attention building block and add the required capabilities on top of it. I am not familiar with the flash attention implementation, but I am familiar with jax and flax; is this doable without understanding the underlying flash attention? If you think it is possible, I can work on it and then open a pull request.
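
A rough sketch of such an adapter, under stated assumptions rather than a working implementation: it assumes flax passes query / key / value to attention_fn as (batch, seq, num_heads, head_dim), that bias and dropout can be ignored for a first pass, and that flax's attention mask can be collapsed into the (batch, kv_seq) key mask this library expects. The exact calling convention would need to be checked against the flax source.

import jax.numpy as jnp
from flash_attention_jax import flash_attention

def flash_attention_fn(query, key, value, bias=None, mask=None, **unused_kwargs):
    # assumed layout: (batch, seq, num_heads, head_dim) -> (batch, heads, seq, dim)
    q, k, v = (jnp.swapaxes(t, 1, 2) for t in (query, key, value))

    if mask is None:
        key_mask = jnp.ones((key.shape[0], key.shape[1]), dtype=bool)   # (batch, kv_seq)
    else:
        # collapse whatever mask flax provides down to a per-key mask (a lossy assumption)
        key_mask = jnp.any(mask, axis=tuple(range(1, mask.ndim - 1)))

    out, _ = flash_attention(q, k, v, key_mask)
    return jnp.swapaxes(out, 1, 2)   # back to (batch, seq, heads, dim)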

Reshape error in causal_flash_attention when sequence length is not a multiple of 1024

First off, thanks for writing this. It's been a substantial improvement, even if the hand-written CUDA kernels would have been better.

I've discovered a bug with certain sequence lengths. For example, with 1025 you get TypeError: reshape total size must be unchanged, got new_sizes (1025, 256, 64) for shape (2, 1024, 256, 64)., with a traceback pointing to causal_flash_attention.py:96, which is this line: out = out.reshape(q_len, bh, v_dim). As far as I can tell, the problem occurs whenever the sequence length is greater than 1024 and not a multiple of 1024.

Repro:

import jax.numpy as jnp
from flash_attention_jax import causal_flash_attention

q = k = v = jnp.ones((1, 1, 1025, 16), dtype=jnp.float32)
_ = causal_flash_attention(q, k, v)

That fails; changing 1025 to 1024 works fine.
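
A possible workaround sketch (an assumption, not a fix in the library), following the 4-D shapes from the repro above and assuming the output keeps the same layout: pad the sequence up to the next multiple of 1024, run causal_flash_attention, then slice the output back. Because the attention is causal, padding appended after the real tokens cannot affect the earlier positions.

import jax.numpy as jnp
from flash_attention_jax import causal_flash_attention

q = k = v = jnp.ones((1, 1, 1025, 16), dtype=jnp.float32)

seq_len = q.shape[-2]
pad = -seq_len % 1024                            # amount needed to reach the next multiple of 1024
pad_widths = ((0, 0), (0, 0), (0, pad), (0, 0))
q_p, k_p, v_p = (jnp.pad(t, pad_widths) for t in (q, k, v))

out, _ = causal_flash_attention(q_p, k_p, v_p)
out = out[..., :seq_len, :]                      # drop the padded positions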

more general mask support

The general case of attention is (using jaxtyping annotations):

q: Float["lq d"]
k: Float["lkv d"]
v: Float["lkv o"]
mask: Bool["lq lkv"]

returns: Float["lq o"]

but it looks like right now this library only supports a one-dimensional (per-key) mask?
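
As a reference for the requested semantics (a plain-attention sketch, not this library's API), here is attention with a full (lq, lkv) boolean mask matching the shapes above.

import jax
import jax.numpy as jnp

def general_masked_attention(q, k, v, mask):
    # q: (lq, d), k: (lkv, d), v: (lkv, o), mask: (lq, lkv) -> returns (lq, o)
    scale = q.shape[-1] ** -0.5
    logits = (q * scale) @ k.T                   # (lq, lkv)
    logits = jnp.where(mask, logits, -1e10)      # mask out disallowed query/key pairs
    return jax.nn.softmax(logits, axis=-1) @ v   # (lq, o)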

fix compatibility with jax transformations

It is currently impossible to use flash_attention within a function that uses gradient checkpointing (jax.checkpoint).

Minimal example to reproduce:

import jax
import jax.numpy as jnp
from flash_attention_jax import flash_attention

keys = jax.random.split(jax.random.PRNGKey(0), 4)  # any PRNG keys; not shown in the original report

b = 3
lq = 16
lkv = 17
h = 5
d = 19
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))

@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)

fails with error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb Cell 9 in <cell line: 1>()
----> [1](vscode-notebook-cell://ssh-remote%2Bjackcpu/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb#X11sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0) get_ipython().run_line_magic('timeit', 'bench_flash_bwd(q, k, v, mask).block_until_ready()')

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/interactiveshell.py:2305, in InteractiveShell.run_line_magic(self, magic_name, line, _stack_depth)
   2303     kwargs['local_ns'] = self.get_local_scope(stack_depth)
   2304 with self.builtin_trap:
-> 2305     result = fn(*args, **kwargs)
   2306 return result

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/magics/execution.py:1162, in ExecutionMagics.timeit(self, line, cell, local_ns)
   1160 for index in range(0, 10):
   1161     number = 10 ** index
-> 1162     time_number = timer.timeit(number)
   1163     if time_number >= 0.2:
   1164         break

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/magics/execution.py:156, in Timer.timeit(self, number)
    154 gc.disable()
    155 try:
--> 156     timing = self.inner(it, self.timer)
    157 finally:
    158     if gcold:

File <magic-timeit>:1, in inner(_it, _timer)

    [... skipping hidden 14 frame]

/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb Cell 9 in bench_flash_bwd(q, k, v, mask)
      [1](vscode-notebook-cell://ssh-remote%2Bjackcpu/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb#X11sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0) @jax.jit
      [2](vscode-notebook-cell://ssh-remote%2Bjackcpu/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb#X11sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1) def bench_flash_bwd(q, k, v, mask):
----> [3](vscode-notebook-cell://ssh-remote%2Bjackcpu/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb#X11sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2)     return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0]), policy=jax.checkpoint_policies.everything_saveable))(q)

    [... skipping hidden 25 frame]

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/util.py:48, in safe_map(f, *args)
     46 n = len(args[0])
     47 for arg in args[1:]:
---> 48   assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
     49 return list(map(f, *args))

AssertionError: length mismatch: [3, 1]
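
For contrast, a minimal sketch of the same gradient without jax.checkpoint, reusing the q, k, v, mask and imports from the repro above; the report implies this non-checkpointed path works.

@jax.jit
def flash_bwd_no_checkpoint(q, k, v, mask):
    # same gradient as above, minus the jax.checkpoint wrapper
    return jax.grad(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0]))(q)

flash_bwd_no_checkpoint(q, k, v, mask)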

Slower than non-flash attention

I tried to compare this implementation with a no bells-and-whistles implementation:

import time
import jax
import jax.numpy as jnp
import numpy as np
from flash_attention_jax import flash_attention

import jax.random

@jax.jit
def jax_attention(q, k, v):
  n_seq = q.shape[-2]
  logits = jnp.matmul(q, k)
  mask = jnp.tril(jnp.ones((1, 1, n_seq, n_seq), dtype=q.dtype))
  mask = jnp.broadcast_to(mask, logits.shape)
  logits = jnp.where(mask, logits, float('-inf'))
  ref_qk = jax.nn.softmax(logits)
  return jnp.matmul(ref_qk, v)


BATCH, N_HEADS, N_CTX, D_HEAD = 8, 64, 2000, 64


def bench_jax_flash(batch, heads, seq_len, d_model):
    shape = (batch, heads, seq_len, d_model,)
    q_jax = jnp.ones(shape, dtype=jnp.float16)
    k_jax = jnp.ones(shape, dtype=jnp.float16)
    v_jax = jnp.ones(shape, dtype=jnp.float16)
    mask = jnp.ones((batch, seq_len), dtype=jnp.int_)

    # warmup
    print('Warming up...')
    flash_attention(q_jax, k_jax, v_jax, mask)[0].block_until_ready()
    flash_attention(q_jax, k_jax, v_jax, mask)[0].block_until_ready()

    print('Benchmarking...')
    t1 = time.time()
    num_runs = 100
    for _ in range(num_runs):
        flash_attention(q_jax, k_jax, v_jax, mask)[0].block_until_ready()
    estimate_ms = 1000 * (time.time() - t1) / num_runs
    return estimate_ms

print('Flash Jax implementation:')
print(bench_jax_flash(batch=BATCH, heads=N_HEADS, seq_len=N_CTX, d_model=D_HEAD))

def bench_jax(batch, heads, seq_len, d_model):
    q_jax = jnp.ones((batch, heads, seq_len, d_model), dtype=jnp.float16)
    k_jax = jnp.ones((batch, heads, d_model, seq_len), dtype=jnp.float16)
    v_jax = jnp.ones((batch, heads, seq_len, d_model), dtype=jnp.float16)
    # warmup
    print('Warming up...')
    jax_attention(q_jax, k_jax, v_jax).block_until_ready()
    jax_attention(q_jax, k_jax, v_jax).block_until_ready()

    print('Benchmarking...')
    t1 = time.time()
    num_runs = 100
    for _ in range(num_runs):
        jax_attention(q_jax, k_jax, v_jax).block_until_ready()
    estimate_ms = 1000 * (time.time() - t1) / num_runs
    return estimate_ms

print('Jax implementation:')
print(bench_jax(batch=BATCH, heads=N_HEADS, seq_len=N_CTX, d_model=D_HEAD))

Output:

Flash Jax implementation:
Warming up...
Benchmarking...
51.73063278198242
Jax implementation:
Warming up...
Benchmarking...
32.94001817703247

Performance benchmarks?

Are there any benchmark results yet? I'm looking forward to performance comparisons with the original attention and with the official torch + CUDA implementation.
