cuda-mode / ring-attention
ring-attention experiments
License: Apache License 2.0
ring-attention/ring-llama/modeling_llama.py
Line 669 in 65f904c
While it is acceptable for ring flash attention to lag behind flash attention in speed, I am wondering whether ring flash attention should outperform flash attention in terms of memory use.
It has been observed that ring flash attention can underperform flash attention in both aspects (zhuzilin/ring-flash-attention#23), which does not seem reasonable.
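A minimal sketch of how peak memory of the two implementations could be compared on identical inputs (`attn_fn` is a hypothetical stand-in for a flash-attn or ring-flash-attn call, not an API from either repo):

```python
# Minimal sketch: measure peak CUDA memory of one attention call + backward.
# `attn_fn` is a hypothetical stand-in for a flash-attn / ring-flash-attn call.
import torch

def peak_mem_mb(attn_fn, q, k, v):
    # q, k, v should have requires_grad=True so backward allocations are counted
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    out = attn_fn(q, k, v)
    out.sum().backward()              # include backward-pass allocations
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20
```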
lucidrains & zhuzilin have been working hard over the last few days and have completed the following two ring-attention implementations:
Create a test setup that verifies correctness and compares the performance of both solutions; a sketch of such a harness follows below.
Phil decided to use a custom Triton kernel. Find out why this kernel is used and whether it is indeed faster than the CUDA flash-attention 2 kernel.
Please generate a short report of your findings, either as a markdown file or an ipynb.
Please create a short markdown report about your findings.
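A minimal sketch of such a correctness-plus-timing harness (`impl_a` and `impl_b` are hypothetical callables wrapping the two implementations; this is not the repo's actual test code):

```python
# Minimal sketch: check that two attention implementations agree, then time each.
import time
import torch

def compare(impl_a, impl_b, q, k, v, iters=20):
    out_a, out_b = impl_a(q, k, v), impl_b(q, k, v)
    max_err = (out_a - out_b).abs().max().item()   # correctness: max abs difference
    times = []
    for impl in (impl_a, impl_b):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            impl(q, k, v)
        torch.cuda.synchronize()                    # wait for queued kernels
        times.append((time.perf_counter() - t0) / iters)
    return max_err, times
```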
Get a feeling for the implementations:
The DummyRingAttentionImpl.ipynb tests whether the ring attention result matches the result of the naive_attn() function. Currently the abs-error is still too high for the small matrices used.
Task: Find & fix bug, submit PR.
Hi,
In this block, the query, key, and value are being updated via _upad_input. Then, in line 658, shouldn't the qkv be packed again?
ring-attention/ring-llama/modeling_llama.py
Lines 647 to 664 in d7aa779
Create an ipynb that analyzes, in PyTorch, peer-to-peer memory transfer (between two GPUs) running in parallel with computation. The dummy computation could, for example, be some larger matmuls in a loop. Create a notebooks folder and place the file there.
The goal is to demonstrate that memory transfer and computation can (to some degree) run overlapped; a stream-overlap sketch follows the quote below.
Quote from the ring-attention paper:
"If the computation time exceeds the time required for transferring key-value blocks, this results in no additional communication cost. This overlapping mechanism applies to both forward and backward passes of our approach since the same operations and techniques can be used"
Extend the naive flash-attn notebook to allow block-wise processing of only a fraction of the blocks at a time, i.e. pass state in and out as required to continue updating the outputs for the current queries (e.g. store the block max, the current sum, etc.).
With the new function, create a small test showing that all values from the split processing are allclose() to the same computation done as classic dot-product attention (see naive_attn() in the notebook linked above); a sketch of the state-passing idea follows below.
Store the generated ipynb file in the notebooks folder of this repo.
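A minimal sketch of the state-passing idea (shapes and names are assumptions, not the notebook's exact code): each call consumes one key/value block and carries forward the running max m, running sum l, and the unnormalized partial output, which are exactly the quantities the task mentions; the final output is normalized by l once all blocks are consumed.

```python
# Minimal sketch: online-softmax attention processed one kv block at a time.
import torch

def attn_blockwise_step(q, k_blk, v_blk, state=None):
    """Consume one key/value block; return updated (out, m, l) state."""
    scale = q.shape[-1] ** -0.5
    s = (q @ k_blk.transpose(-2, -1)) * scale       # scores for this block
    m_blk = s.max(dim=-1, keepdim=True).values      # per-row max of this block
    if state is None:
        out, m, l = torch.zeros_like(q), m_blk, torch.zeros_like(m_blk)
    else:
        out, m, l = state
    m_new = torch.maximum(m, m_blk)                 # updated running max
    p = torch.exp(s - m_new)                        # block probs, rescaled
    alpha = torch.exp(m - m_new)                    # correction for the old state
    l = l * alpha + p.sum(dim=-1, keepdim=True)     # running softmax denominator
    out = out * alpha + p @ v_blk                   # unnormalized partial output
    return out, m_new, l

def attn_blockwise(q, k, v, n_blocks=4):
    state = None
    for k_blk, v_blk in zip(k.chunk(n_blocks, dim=-2), v.chunk(n_blocks, dim=-2)):
        state = attn_blockwise_step(q, k_blk, v_blk, state)
    out, m, l = state
    return out / l                                  # final normalization

# Test against classic dot-product attention:
q, k, v = (torch.randn(2, 128, 64) for _ in range(3))
ref = torch.softmax((q @ k.transpose(-2, -1)) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(attn_blockwise(q, k, v), ref, atol=1e-5)
```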
Thanks for the great work!
https://github.com/cuda-mode/ring-attention/blob/main/ring-llama/test.ipynb
I can load the model with LlamaRingFlashAttention and move it to the device, but I've seen
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.
when I run y = model.generate. What did I miss? Thanks in advance!
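The error message suggests the ring attention layers communicate via torch.distributed and the default process group was never created. A minimal sketch of the likely missing setup, assuming a torchrun launch (this is inferred from the error, not a confirmed fix from the repo):

```python
# Minimal sketch: initialize the default process group before generate().
# Assumes launch via `torchrun --nproc-per-node=2 script.py` (sets LOCAL_RANK etc.).
import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")      # ring attention passes kv blocks via NCCL
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# model = ...  # load the LlamaRingFlashAttention model, move to the local device
# y = model.generate(...)  # the default process group now exists

dist.destroy_process_group()
```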
Results from the dual RTX A5000 system:
Flash2 fwd: 84.08 TFLOPs/s, bwd: 52.88 TFLOPs/s, fwd + bwd: 59.15 TFLOPs/s
Pytorch fwd: 14.52 TFLOPs/s, bwd: 17.06 TFLOPs/s, fwd + bwd: 16.25 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 81.02 TFLOPs/s, bwd: 62.54 TFLOPs/s, fwd + bwd: 66.90 TFLOPs/s
Pytorch fwd: 16.72 TFLOPs/s, bwd: 19.12 TFLOPs/s, fwd + bwd: 18.36 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 81.31 TFLOPs/s, bwd: 70.07 TFLOPs/s, fwd + bwd: 72.95 TFLOPs/s
Pytorch fwd: 15.50 TFLOPs/s, bwd: 18.70 TFLOPs/s, fwd + bwd: 17.66 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 81.69 TFLOPs/s, bwd: 74.80 TFLOPs/s, fwd + bwd: 76.64 TFLOPs/s
Pytorch fwd: 18.56 TFLOPs/s, bwd: 19.67 TFLOPs/s, fwd + bwd: 19.34 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 81.86 TFLOPs/s, bwd: 77.42 TFLOPs/s, fwd + bwd: 78.64 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 82.60 TFLOPs/s, bwd: 78.50 TFLOPs/s, fwd + bwd: 79.63 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 82.91 TFLOPs/s, bwd: 49.25 TFLOPs/s, fwd + bwd: 55.71 TFLOPs/s
Pytorch fwd: 20.51 TFLOPs/s, bwd: 26.73 TFLOPs/s, fwd + bwd: 24.60 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 79.48 TFLOPs/s, bwd: 57.66 TFLOPs/s, fwd + bwd: 62.57 TFLOPs/s
Pytorch fwd: 25.90 TFLOPs/s, bwd: 32.16 TFLOPs/s, fwd + bwd: 30.08 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 80.54 TFLOPs/s, bwd: 64.37 TFLOPs/s, fwd + bwd: 68.29 TFLOPs/s
Pytorch fwd: 26.50 TFLOPs/s, bwd: 33.92 TFLOPs/s, fwd + bwd: 31.41 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 82.49 TFLOPs/s, bwd: 68.40 TFLOPs/s, fwd + bwd: 71.91 TFLOPs/s
Pytorch fwd: 31.77 TFLOPs/s, bwd: 35.66 TFLOPs/s, fwd + bwd: 34.46 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 83.24 TFLOPs/s, bwd: 70.70 TFLOPs/s, fwd + bwd: 73.88 TFLOPs/s
Pytorch fwd: 32.55 TFLOPs/s, bwd: 36.49 TFLOPs/s, fwd + bwd: 35.27 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 83.51 TFLOPs/s, bwd: 70.94 TFLOPs/s, fwd + bwd: 74.13 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 51.81 TFLOPs/s, bwd: 36.22 TFLOPs/s, fwd + bwd: 39.62 TFLOPs/s
Pytorch fwd: 5.24 TFLOPs/s, bwd: 8.53 TFLOPs/s, fwd + bwd: 7.23 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 68.11 TFLOPs/s, bwd: 46.40 TFLOPs/s, fwd + bwd: 51.05 TFLOPs/s
Pytorch fwd: 5.43 TFLOPs/s, bwd: 9.60 TFLOPs/s, fwd + bwd: 7.87 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 70.29 TFLOPs/s, bwd: 59.55 TFLOPs/s, fwd + bwd: 62.27 TFLOPs/s
Pytorch fwd: 5.41 TFLOPs/s, bwd: 9.38 TFLOPs/s, fwd + bwd: 7.76 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 74.57 TFLOPs/s, bwd: 65.41 TFLOPs/s, fwd + bwd: 67.79 TFLOPs/s
Pytorch fwd: 5.60 TFLOPs/s, bwd: 9.81 TFLOPs/s, fwd + bwd: 8.08 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 75.38 TFLOPs/s, bwd: 71.20 TFLOPs/s, fwd + bwd: 72.35 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 75.68 TFLOPs/s, bwd: 73.99 TFLOPs/s, fwd + bwd: 74.46 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 59.04 TFLOPs/s, bwd: 34.96 TFLOPs/s, fwd + bwd: 39.57 TFLOPs/s
Pytorch fwd: 7.99 TFLOPs/s, bwd: 13.42 TFLOPs/s, fwd + bwd: 11.24 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 67.22 TFLOPs/s, bwd: 45.18 TFLOPs/s, fwd + bwd: 49.85 TFLOPs/s
Pytorch fwd: 9.14 TFLOPs/s, bwd: 16.29 TFLOPs/s, fwd + bwd: 13.31 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 67.44 TFLOPs/s, bwd: 54.87 TFLOPs/s, fwd + bwd: 57.96 TFLOPs/s
Pytorch fwd: 9.58 TFLOPs/s, bwd: 17.10 TFLOPs/s, fwd + bwd: 13.97 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 68.65 TFLOPs/s, bwd: 62.42 TFLOPs/s, fwd + bwd: 64.08 TFLOPs/s
Pytorch fwd: 10.09 TFLOPs/s, bwd: 18.08 TFLOPs/s, fwd + bwd: 14.75 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 68.06 TFLOPs/s, bwd: 67.21 TFLOPs/s, fwd + bwd: 67.45 TFLOPs/s
Pytorch fwd: 10.03 TFLOPs/s, bwd: 18.42 TFLOPs/s, fwd + bwd: 14.87 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2 fwd: 65.36 TFLOPs/s, bwd: 69.11 TFLOPs/s, fwd + bwd: 68.00 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Results from the dual RTX A5000 box for ring-flash-attention:
##############################
forward:
##############################
out: max 4.3125, mean 0.04052734375
lse: max 8.985279083251953, mean 7.747061729431152
out diff:
[0] max 0.0, mean 0.0
[1] max 0.0009765625, mean 4.863739013671875e-05
lse diff:
[0] max 0.0, mean 0.0
[1] max 1.9073486328125e-06, mean 2.862522592295136e-07
##############################
backward:
##############################
load_dq:
[0] max 2.34375, mean 0.0537109375
[1] max 0.32421875, mean 0.0247802734375
dq diff:
[0] max 0.0009765625, mean 4.5693013817071915e-09
[1] max 0.001953125, mean 4.363059997558594e-05
load_dk:
[0] max 3.328125, mean 0.050537109375
[1] max 0.2294921875, mean 0.011962890625
dk diff:
[0] max 0.015625, mean 8.0108642578125e-05
[1] max 0.00048828125, mean 5.692243576049805e-06
load_dv:
[0] max 3.921875, mean 0.052978515625
[1] max 0.1904296875, mean 0.0120849609375
dv diff:
[0] max 0.015625, mean 8.153915405273438e-05
[1] max 0.00048828125, mean 6.938353180885315e-08
#######################################################
##############################
forward:
##############################
out: max 3.296875, mean 0.057373046875
out diff:
[0] max 0.0, mean 0.0
[1] max 0.00390625, mean 6.961822509765625e-05
lse: max 5.656599521636963, mean 4.309932708740234
lse diff:
[0] max 0.0, mean 0.0
[1] max 4.76837158203125e-07, mean 1.7325083945252118e-07
lse: max 7.727584362030029, mean 6.5278754234313965
lse diff:
[0] max 0.0, mean 0.0
[1] max 9.5367431640625e-07, mean 1.954694113237565e-07
lse: max 8.730499267578125, mean 7.501138687133789
lse diff:
[0] max 0.0, mean 0.0
[1] max 1.9073486328125e-06, mean 2.5631595690356335e-07
##############################
backward:
##############################
load_dq:
[0] max 3.0625, mean 0.07177734375
[1] max 1.0859375, mean 0.035400390625
dq diff:
[0] max 0.00048828125, mean 2.051820047199726e-09
[1] max 0.00390625, mean 6.389617919921875e-05
load_dk:
[0] max 3.484375, mean 0.0693359375
[1] max 1.0390625, mean 0.0169677734375
dk diff:
[0] max 0.015625, mean 0.00011157989501953125
[1] max 0.00390625, mean 1.0073184967041016e-05
load_dv:
[0] max 5.9375, mean 0.07373046875
[1] max 0.94921875, mean 0.0172119140625
dv diff:
[0] max 0.03125, mean 0.00011348724365234375
[1] max 0.00048828125, mean 6.379559636116028e-08