
ring-attention's Introduction

ring-attention

Ring Attention leverages blockwise computation of self-attention across multiple GPUs and enables training and inference with sequences that would be too long for a single device.

This repository contains notebooks, experiments and a collection of links to papers and other material related to Ring Attention.
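
For intuition, here is a minimal single-process sketch of the blockwise idea (a simulation, not this repository's implementation): each simulated rank owns one query shard, the key/value shards are rotated around the "ring", and the partial results are merged via their log-sum-exp, as a multi-GPU run would do with send/recv between neighbouring devices.

```python
import torch

def ring_attention_sim(q, k, v, world_size=4):
    """Single-process simulation of ring attention.

    q, k, v: [seq_len, dim]. Each simulated rank owns one query shard and one
    key/value shard; KV shards "travel" around the ring and the per-shard
    partial results are merged using their log-sum-exp (lse).
    """
    scale = q.shape[-1] ** -0.5
    q_shards = q.chunk(world_size)
    kv_shards = list(zip(k.chunk(world_size), v.chunk(world_size)))

    outputs = []
    for rank in range(world_size):
        q_r = q_shards[rank]
        out = torch.zeros_like(q_r)
        lse = torch.full((q_r.shape[0], 1), float("-inf"))
        for step in range(world_size):
            # in a real run this shard would be received from the previous rank
            k_j, v_j = kv_shards[(rank + step) % world_size]
            scores = (q_r @ k_j.T) * scale
            lse_j = torch.logsumexp(scores, dim=-1, keepdim=True)
            out_j = torch.softmax(scores, dim=-1) @ v_j
            new_lse = torch.logaddexp(lse, lse_j)
            out = out * torch.exp(lse - new_lse) + out_j * torch.exp(lse_j - new_lse)
            lse = new_lse
        outputs.append(out)
    return torch.cat(outputs)

# sanity check against standard softmax attention
q, k, v = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(ring_attention_sim(q, k, v), ref, atol=1e-4)
```

In an actual multi-GPU run the inner loop's indexing is replaced by P2P communication of the KV shards, which can be overlapped with the blockwise computation.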

Weekly Meeting

Every Sunday at 5 PM UTC we meet in the "General" voice channel of the CUDA MODE Discord server. You can contact us asynchronously at any time in the #ring-attention channel.

Research / Material

Notebooks

Development References

How to contribute

Contact us on the CUDA MODE Discord server: https://discord.gg/cudamode. PRs are welcome (please create an issue first).

ring-attention's People

Contributors

andreaskoepf, ericauld, lancerts, melvinebenezer


ring-attention's Issues

[info] test results for ring-flash-attention

Results from the dual RTX A5000 box for ring-flash-attention

test_qkvpacked_func

##############################
forward:
##############################
out: max 4.3125, mean 0.04052734375
lse: max 8.985279083251953, mean 7.747061729431152
out diff:
[0] max 0.0, mean 0.0
[1] max 0.0009765625, mean 4.863739013671875e-05
lse diff:
[0] max 0.0, mean 0.0
[1] max 1.9073486328125e-06, mean 2.862522592295136e-07
##############################
backward:
##############################
load_dq:
[0] max 2.34375, mean 0.0537109375
[1] max 0.32421875, mean 0.0247802734375
dq diff:
[0] max 0.0009765625, mean 4.5693013817071915e-09
[1] max 0.001953125, mean 4.363059997558594e-05
load_dk:
[0] max 3.328125, mean 0.050537109375
[1] max 0.2294921875, mean 0.011962890625
dk diff:
[0] max 0.015625, mean 8.0108642578125e-05
[1] max 0.00048828125, mean 5.692243576049805e-06
load_dv:
[0] max 3.921875, mean 0.052978515625
[1] max 0.1904296875, mean 0.0120849609375
dv diff:
[0] max 0.015625, mean 8.153915405273438e-05
[1] max 0.00048828125, mean 6.938353180885315e-08

#######################################################

test_varlen_qkvpacked_func

##############################
forward:
##############################
out: max 3.296875, mean 0.057373046875
out diff:
[0] max 0.0, mean 0.0
[1] max 0.00390625, mean 6.961822509765625e-05
lse: max 5.656599521636963, mean 4.309932708740234
lse diff:
[0] max 0.0, mean 0.0
[1] max 4.76837158203125e-07, mean 1.7325083945252118e-07
lse: max 7.727584362030029, mean 6.5278754234313965
lse diff:
[0] max 0.0, mean 0.0
[1] max 9.5367431640625e-07, mean 1.954694113237565e-07
lse: max 8.730499267578125, mean 7.501138687133789
lse diff:
[0] max 0.0, mean 0.0
[1] max 1.9073486328125e-06, mean 2.5631595690356335e-07
##############################
backward:
##############################
load_dq:
[0] max 3.0625, mean 0.07177734375
[1] max 1.0859375, mean 0.035400390625
dq diff:
[0] max 0.00048828125, mean 2.051820047199726e-09
[1] max 0.00390625, mean 6.389617919921875e-05
load_dk:
[0] max 3.484375, mean 0.0693359375
[1] max 1.0390625, mean 0.0169677734375
dk diff:
[0] max 0.015625, mean 0.00011157989501953125
[1] max 0.00390625, mean 1.0073184967041016e-05
load_dv:
[0] max 5.9375, mean 0.07373046875
[1] max 0.94921875, mean 0.0172119140625
dv diff:
[0] max 0.03125, mean 0.00011348724365234375
[1] max 0.00048828125, mean 6.379559636116028e-08
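
For reference, the diff values above appear to be elementwise absolute differences against a single-device flash-attention result, reported per rank ([0]/[1] presumably being the two GPUs of the box). A minimal sketch of how such a metric is typically computed:

```python
import torch

def report_diff(name, ref, out):
    # elementwise absolute difference, summarized as max / mean
    diff = (ref - out).abs()
    print(f"{name} diff: max {diff.max().item()}, mean {diff.mean().item()}")

# illustrative use with stand-in tensors
ref = torch.randn(2, 1024, 8, 64)
out = ref + 1e-3 * torch.randn_like(ref)
report_diff("out", ref, out)
```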

Compare ring-flash-attention & ring-attention-pytorch

lucidrains and zhuzilin have been working hard over the last few days and have completed the two ring-attention implementations named in the title (ring-attention-pytorch and ring-flash-attention, respectively).

Create a test setup that verifies correctness and compares the performance of both solutions.

Phil decided to use a custom Triton kernel. Find out why this kernel is used and whether it is indeed faster than the CUDA flash-attention 2 kernel.

Please write a short report of your findings, either as a markdown file or an ipynb.
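
A hedged sketch of what the single-process part of such a test setup could look like; fn_a and fn_b are placeholders for wrappers around the two implementations, and a real comparison would additionally have to launch under torchrun with identical inputs on every rank:

```python
import torch

def benchmark_ms(fn, *args, warmup=5, iters=20):
    """Mean runtime in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

def compare(fn_a, fn_b, *args, atol=2e-3):
    """Check that two attention implementations agree, then report their speed."""
    out_a, out_b = fn_a(*args), fn_b(*args)
    print("allclose:", torch.allclose(out_a, out_b, atol=atol))
    print(f"fn_a: {benchmark_ms(fn_a, *args):.3f} ms, fn_b: {benchmark_ms(fn_b, *args):.3f} ms")
```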

Analyze overlapped P2P memory transfer and computing

Create an ipynb that analyzes, in PyTorch, peer-to-peer memory transfer (between two GPUs) running in parallel with computation. The dummy computation could, for example, be some larger matmuls in a loop. Create a notebooks folder and place the file there.

The goal is to demonstrate that memory transfer and computation can (to some degree) run overlapped.

Quote from the ring-attention paper:
"If the computation time exceeds the time required for transferring key-value blocks, this results in no additional communication cost. This overlapping mechanism applies to both forward and backward passes of our approach since the same operations and techniques can be used"

The performance comparison between flash attn and ring flash attn

attn_output = ring_flash_attn_qkvpacked_func(qkv, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal)

While it is acceptable that ring flash attn lags behind flash attn in terms of speed, I am wondering whether ring flash attn could outperform flash attn in terms of memory use.

It has been observed that ring flash attn can underperform flash attn in both respects (zhuzilin/ring-flash-attention#23), which does not seem reasonable.
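
One way to make the memory side of the comparison concrete is to record the peak allocation per rank. A hedged sketch (attn_fn and qkv are placeholders, and qkv is assumed to be created with requires_grad=True):

```python
import torch

def peak_memory_mb(attn_fn, qkv):
    """Peak allocated CUDA memory (MB) for one forward + backward of attn_fn."""
    torch.cuda.reset_peak_memory_stats()
    out = attn_fn(qkv)
    out.sum().backward()          # include backward-pass allocations
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20
```

Note that a per-device memory advantage for ring flash attn would only be expected when the sequence is actually sharded across ranks, so that each rank holds just its local KV shard plus the block currently in flight.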

Extend educational naive flash-attn impl to allow partial kv-block processing (create naive ring-attn)

Extend the naive flash-attn notebook to allow block-wise processing of only a fraction of the KV blocks at a time, i.e. pass in and out the state required to continue updating the outputs for the current queries (e.g. the running block max, the current sum, etc.).

Using the new function, create a small test that shows that all values of the split processing are allclose() to the same computation done with classic dot-product attention (see naive_attn() in the notebook linked above).

Store the generated ipynb file in the notebooks folder of this repo.
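
A hedged sketch of what such a state-carrying block update could look like (naive_attn below is a local stand-in for the function in the notebook):

```python
import torch

def flash_block_update(q, k_blk, v_blk, state=None):
    """Process one KV block for queries q, carrying the online-softmax state.

    state = (acc, row_max, row_sum); pass the returned state into the next
    call, and compute acc / row_sum after the last block.
    """
    scale = q.shape[-1] ** -0.5
    if state is None:
        acc = torch.zeros_like(q)
        row_max = torch.full((q.shape[0], 1), float("-inf"))
        row_sum = torch.zeros(q.shape[0], 1)
    else:
        acc, row_max, row_sum = state

    scores = (q @ k_blk.T) * scale
    new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
    corr = torch.exp(row_max - new_max)       # rescale the old accumulators
    p = torch.exp(scores - new_max)
    acc = acc * corr + p @ v_blk
    row_sum = row_sum * corr + p.sum(dim=-1, keepdim=True)
    return acc, new_max, row_sum

def naive_attn(q, k, v):
    return torch.softmax((q @ k.T) * q.shape[-1] ** -0.5, dim=-1) @ v

# split the KV sequence into 4 blocks and compare with classic attention
q, k, v = (torch.randn(128, 64) for _ in range(3))
state = None
for k_blk, v_blk in zip(k.chunk(4), v.chunk(4)):
    state = flash_block_update(q, k_blk, v_blk, state)
acc, _, row_sum = state
assert torch.allclose(acc / row_sum, naive_attn(q, k, v), atol=1e-4)
```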

Updating `qkv` after padding

Hi,

In this block, the query, key, and value are updated via _upad_input. Then, in line 658, shouldn't the qkv be packed again?

if attention_mask is not None:
    batch_size = query_states.shape[0]
    query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
        query_states, key_states, value_states, attention_mask, query_length
    )
    # cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    # max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
    # attn_output_unpad = flash_attn_varlen_qkvpacked_func(
    attn_output_unpad = ring_flash_attn_varlen_qkvpacked_func(
        qkv,
        cu_seq_lens,
        max_seq_lens,
        dropout_p=dropout,
        softmax_scale=softmax_scale,
        causal=causal,
    )
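
For illustration, a hypothetical sketch of how the packing could be redone after _upad_input, assuming the unpadded tensors have shape [total_tokens, num_heads, head_dim] (the concrete shapes below are made up):

```python
import torch

total_tokens, num_heads, head_dim = 1024, 8, 64
query_states = torch.randn(total_tokens, num_heads, head_dim)
key_states = torch.randn(total_tokens, num_heads, head_dim)
value_states = torch.randn(total_tokens, num_heads, head_dim)

# rebuild the packed tensor expected by the *_varlen_qkvpacked_func variants
qkv = torch.stack([query_states, key_states, value_states], dim=1)
assert qkv.shape == (total_tokens, 3, num_heads, head_dim)
```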

[info] flash attention benchmark

Results from the dual RTX A5000 system

causal=False, headdim=64, batch_size=32, seqlen=512

Flash2 fwd: 84.08 TFLOPs/s, bwd: 52.88 TFLOPs/s, fwd + bwd: 59.15 TFLOPs/s
Pytorch fwd: 14.52 TFLOPs/s, bwd: 17.06 TFLOPs/s, fwd + bwd: 16.25 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=False, headdim=64, batch_size=16, seqlen=1024

Flash2 fwd: 81.02 TFLOPs/s, bwd: 62.54 TFLOPs/s, fwd + bwd: 66.90 TFLOPs/s
Pytorch fwd: 16.72 TFLOPs/s, bwd: 19.12 TFLOPs/s, fwd + bwd: 18.36 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=False, headdim=64, batch_size=8, seqlen=2048

Flash2 fwd: 81.31 TFLOPs/s, bwd: 70.07 TFLOPs/s, fwd + bwd: 72.95 TFLOPs/s
Pytorch fwd: 15.50 TFLOPs/s, bwd: 18.70 TFLOPs/s, fwd + bwd: 17.66 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=False, headdim=64, batch_size=4, seqlen=4096

Flash2 fwd: 81.69 TFLOPs/s, bwd: 74.80 TFLOPs/s, fwd + bwd: 76.64 TFLOPs/s
Pytorch fwd: 18.56 TFLOPs/s, bwd: 19.67 TFLOPs/s, fwd + bwd: 19.34 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=False, headdim=64, batch_size=2, seqlen=8192

Flash2 fwd: 81.86 TFLOPs/s, bwd: 77.42 TFLOPs/s, fwd + bwd: 78.64 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=False, headdim=64, batch_size=1, seqlen=16384

Flash2 fwd: 82.60 TFLOPs/s, bwd: 78.50 TFLOPs/s, fwd + bwd: 79.63 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=False, headdim=128, batch_size=32, seqlen=512

Flash2 fwd: 82.91 TFLOPs/s, bwd: 49.25 TFLOPs/s, fwd + bwd: 55.71 TFLOPs/s
Pytorch fwd: 20.51 TFLOPs/s, bwd: 26.73 TFLOPs/s, fwd + bwd: 24.60 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=False, headdim=128, batch_size=16, seqlen=1024

Flash2 fwd: 79.48 TFLOPs/s, bwd: 57.66 TFLOPs/s, fwd + bwd: 62.57 TFLOPs/s
Pytorch fwd: 25.90 TFLOPs/s, bwd: 32.16 TFLOPs/s, fwd + bwd: 30.08 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=False, headdim=128, batch_size=8, seqlen=2048

Flash2 fwd: 80.54 TFLOPs/s, bwd: 64.37 TFLOPs/s, fwd + bwd: 68.29 TFLOPs/s
Pytorch fwd: 26.50 TFLOPs/s, bwd: 33.92 TFLOPs/s, fwd + bwd: 31.41 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=False, headdim=128, batch_size=4, seqlen=4096

Flash2 fwd: 82.49 TFLOPs/s, bwd: 68.40 TFLOPs/s, fwd + bwd: 71.91 TFLOPs/s
Pytorch fwd: 31.77 TFLOPs/s, bwd: 35.66 TFLOPs/s, fwd + bwd: 34.46 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=False, headdim=128, batch_size=2, seqlen=8192

Flash2 fwd: 83.24 TFLOPs/s, bwd: 70.70 TFLOPs/s, fwd + bwd: 73.88 TFLOPs/s
Pytorch fwd: 32.55 TFLOPs/s, bwd: 36.49 TFLOPs/s, fwd + bwd: 35.27 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=False, headdim=128, batch_size=1, seqlen=16384

Flash2 fwd: 83.51 TFLOPs/s, bwd: 70.94 TFLOPs/s, fwd + bwd: 74.13 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=True, headdim=64, batch_size=32, seqlen=512

Flash2 fwd: 51.81 TFLOPs/s, bwd: 36.22 TFLOPs/s, fwd + bwd: 39.62 TFLOPs/s
Pytorch fwd: 5.24 TFLOPs/s, bwd: 8.53 TFLOPs/s, fwd + bwd: 7.23 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=True, headdim=64, batch_size=16, seqlen=1024

Flash2 fwd: 68.11 TFLOPs/s, bwd: 46.40 TFLOPs/s, fwd + bwd: 51.05 TFLOPs/s
Pytorch fwd: 5.43 TFLOPs/s, bwd: 9.60 TFLOPs/s, fwd + bwd: 7.87 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=True, headdim=64, batch_size=8, seqlen=2048

Flash2 fwd: 70.29 TFLOPs/s, bwd: 59.55 TFLOPs/s, fwd + bwd: 62.27 TFLOPs/s
Pytorch fwd: 5.41 TFLOPs/s, bwd: 9.38 TFLOPs/s, fwd + bwd: 7.76 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=True, headdim=64, batch_size=4, seqlen=4096

Flash2 fwd: 74.57 TFLOPs/s, bwd: 65.41 TFLOPs/s, fwd + bwd: 67.79 TFLOPs/s
Pytorch fwd: 5.60 TFLOPs/s, bwd: 9.81 TFLOPs/s, fwd + bwd: 8.08 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=True, headdim=64, batch_size=2, seqlen=8192

Flash2 fwd: 75.38 TFLOPs/s, bwd: 71.20 TFLOPs/s, fwd + bwd: 72.35 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=True, headdim=64, batch_size=1, seqlen=16384

Flash2 fwd: 75.68 TFLOPs/s, bwd: 73.99 TFLOPs/s, fwd + bwd: 74.46 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=True, headdim=128, batch_size=32, seqlen=512

Flash2 fwd: 59.04 TFLOPs/s, bwd: 34.96 TFLOPs/s, fwd + bwd: 39.57 TFLOPs/s
Pytorch fwd: 7.99 TFLOPs/s, bwd: 13.42 TFLOPs/s, fwd + bwd: 11.24 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=True, headdim=128, batch_size=16, seqlen=1024

Flash2 fwd: 67.22 TFLOPs/s, bwd: 45.18 TFLOPs/s, fwd + bwd: 49.85 TFLOPs/s
Pytorch fwd: 9.14 TFLOPs/s, bwd: 16.29 TFLOPs/s, fwd + bwd: 13.31 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=True, headdim=128, batch_size=8, seqlen=2048

Flash2 fwd: 67.44 TFLOPs/s, bwd: 54.87 TFLOPs/s, fwd + bwd: 57.96 TFLOPs/s
Pytorch fwd: 9.58 TFLOPs/s, bwd: 17.10 TFLOPs/s, fwd + bwd: 13.97 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=True, headdim=128, batch_size=4, seqlen=4096

Flash2 fwd: 68.65 TFLOPs/s, bwd: 62.42 TFLOPs/s, fwd + bwd: 64.08 TFLOPs/s
Pytorch fwd: 10.09 TFLOPs/s, bwd: 18.08 TFLOPs/s, fwd + bwd: 14.75 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=True, headdim=128, batch_size=2, seqlen=8192

Flash2 fwd: 68.06 TFLOPs/s, bwd: 67.21 TFLOPs/s, fwd + bwd: 67.45 TFLOPs/s
Pytorch fwd: 10.03 TFLOPs/s, bwd: 18.42 TFLOPs/s, fwd + bwd: 14.87 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s

causal=True, headdim=128, batch_size=1, seqlen=16384

Flash2 fwd: 65.36 TFLOPs/s, bwd: 69.11 TFLOPs/s, fwd + bwd: 68.00 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
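
For context, throughput numbers like the above are conventionally obtained by counting attention FLOPs as 4 · batch · seqlen² · nheads · headdim (halved for causal masking) and dividing by the measured runtime. A hedged sketch (the head count, timing, and backward multipliers in the example are illustrative):

```python
def attention_tflops_per_s(batch, seqlen, nheads, headdim, time_s, causal=False, mode="fwd"):
    """Convert a measured runtime into TFLOPs/s using the usual attention FLOP
    count: two matmuls (QK^T and PV), i.e. 4 * headdim FLOPs per query/key pair."""
    flops = 4 * batch * seqlen ** 2 * nheads * headdim
    if causal:
        flops //= 2
    if mode == "bwd":
        flops *= 2.5      # backward is commonly counted as 2.5x the forward
    elif mode == "fwd_bwd":
        flops *= 3.5
    return flops / time_s / 1e12

# e.g. a hypothetical 0.1 ms forward pass at batch=32, seqlen=512, 16 heads, headdim=64
print(f"{attention_tflops_per_s(32, 512, 16, 64, 1e-4):.2f} TFLOPs/s")
```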
