Comments (19)
Germany, Beijing, San Francisco
only in open source (and science)
@lucidrains thanks a lot for your hard work & very interesting that you used a custom Triton kernel! :-)
@zhuzilin awesome work, we'll organize a little hack session today at 19:00 UTC on the cuda-mode Discord to hack on your impl (do some testing, benchmarking, and discuss the best comms options for single node and multi node, etc.) - just fyi https://x.com/neurosp1ke/status/1760558683136589983
oh... sorry, I took a day off and missed all the notifications from GitHub....
@zhuzilin i think my version is working too now, with a modified forward flash attention kernel to minimize ring passes
thanks for sharing your repo as a proof of concept!
@lucidrains What is the issue you're referring to with the backward pass?
> What am I missing, and what is the utility of returning the LSE?

The returned log-sum-exp is what allows flash attention to be applied in a blockwise manner (without it, it wouldn't be possible to use flash-attn to implement ring-attn). See ring_flash_attn/utils.py#L19-L21.
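(For context, a minimal sketch of the blockwise combine this enables; the function name and exact form here are illustrative, not the actual code in ring_flash_attn/utils.py.)

```python
import torch

def update_out_and_lse(out, lse, block_out, block_lse):
    # Fold one block's partial attention output into the running
    # accumulator via the log-sum-exp trick, so key/value blocks can be
    # processed (and ring-passed) one at a time without materializing
    # the full softmax. Assumes lse/block_lse broadcast against out.
    new_lse = lse + torch.log1p(torch.exp(block_lse - lse))
    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return out, new_lse
```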
@zhuzilin hey Zilin! this looks like a good start, and what I intended to do at the very end! i was imagining that the ring communication could be done within CUDA using IPC? (however, I am far from a CUDA expert, so I could be wrong and it may not be possible) Are you planning on upstreaming the finalized implementation to Tri Dao's official flash attention repository? That would be a big contribution!
@zhuzilin if you do embark on the pull request, the minimal features would be the ring IPC, the ability to specify the maximum number of ring passes (as I believe they must have curriculum-learned the local attention up to full global, or mixed local and global using variable ring passes throughout the transformer), and finally, if you have the bandwidth, specialized masking logic for striped autoregressive attention to balance the workload (a sketch of the striped layout follows below)
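(A minimal, hypothetical sketch of the striped layout mentioned above, assuming the sequence is dim 1; the function and variable names are made up for illustration.)

```python
import torch

def stripe_shard(x: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Striped sharding along the sequence dim: rank r keeps tokens
    # r, r + world_size, r + 2 * world_size, ... Under a causal mask,
    # every rank then holds a mix of early and late positions, so the
    # work skipped by masking is spread evenly around the ring instead
    # of piling up on the last ranks.
    return x[:, rank::world_size]
```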
![Screen Shot 2024-02-21 at 7 11 49 AM](image unavailable: expired GitHub attachment link)
thank you! 🚀 ❤️
@zhuzilin actually, after looking into CUDA IPC stuff, your approach may be the best for now
> Are you planning on upstreaming the finalized implementation to Tri Dao's official flash attention repository?

I'll draft an issue on the flash attention repo to see if they have interest in upstreaming (or designing a better version) in the official repo :)

> after looking into CUDA IPC stuff, your approach may be the best for now

yeah, using NCCL-based p2p communication would be at least an easier way to implement, with acceptable performance.
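(For the record, a minimal sketch of what the NCCL-based p2p exchange could look like with torch.distributed; this is one plausible shape, not the repo's actual code.)

```python
import torch
import torch.distributed as dist

def ring_send_recv(t: torch.Tensor) -> torch.Tensor:
    # Send this rank's block to the next rank and receive the previous
    # rank's block; batching the isend/irecv lets NCCL issue them
    # together so the ring can't deadlock.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    recv = torch.empty_like(t)
    ops = [
        dist.P2POp(dist.isend, t, (rank + 1) % world_size),
        dist.P2POp(dist.irecv, recv, (rank - 1) % world_size),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv
```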
i also wanted to do some LOTR references, but one meme is enough
@andreaskoepf thanks! seems like there's still an issue with the backward pass, but i'll leave it to someone or some team to fix. yup, i think the forward requires the keys and values to be iterated on the outer loop (to save on extraneous ring passes), so the reduced outputs, row maxes, and lse need to be stored and passed back in on the next ring pass (see the sketch below). but i could be wrong and there may be a simpler way
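(My reading of the above, as a hedged sketch: q stays resident while (k, v) rotate around the ring, and the partial (out, lse) accumulator is carried between passes. `flash_block` and `ring_send_recv` are stand-ins for the real per-block kernel and comms primitive, not actual APIs.)

```python
import torch

def ring_attention_forward(q, k, v, flash_block, ring_send_recv, ring_passes):
    # Outer loop over ring passes: each pass computes flash attention
    # for the currently held (k, v) block, folds the partial result
    # into the running accumulator with the LSE combine shown earlier,
    # then rotates (k, v) to the next rank. Assumes lse broadcasts
    # against out.
    out, lse = None, None
    for _ in range(ring_passes):
        block_out, block_lse = flash_block(q, k, v)
        if out is None:
            out, lse = block_out, block_lse
        else:
            new_lse = lse + torch.log1p(torch.exp(block_lse - lse))
            out = (torch.exp(lse - new_lse) * out
                   + torch.exp(block_lse - new_lse) * block_out)
            lse = new_lse
        k, v = ring_send_recv(k), ring_send_recv(v)
    return out, lse
```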
it isn't correct, probably something small with regards to how i'm using the flash attention api
feel free to submit a PR, i likely won't be able to get to this as i'll be running around bay area meeting people next month
@ericauld ah, good news, the cuda backwards actually yielded the right gradients (full attention, no causal or key padding mask). it is my naive version that is broken
alright, i guess it is safe to remove the wip
@lucidrains Knowing the LSE doesn't actually help you compute the backward pass for softmax though, correct? The derivative of LSE is softmax, not the other way around. What am I missing, and what is the utility of returning the LSE?
Ah, alright. That's what I'm missing. Makes sense :)
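(To make the resolution concrete: a minimal, illustrative snippet showing that the forward pass only needs to save the LSE, because the softmax probabilities the backward pass needs can be recomputed from the raw scores as `p = exp(scores - lse)`. Tensor names and shapes here are made up for the example.)

```python
import torch

# Illustrative shapes: (batch, heads, q_len, k_len) attention scores.
scores = torch.randn(1, 2, 4, 4)
lse = torch.logsumexp(scores, dim=-1)      # what the forward pass saves
p = torch.exp(scores - lse.unsqueeze(-1))  # softmax recovered from LSE
# The recomputed probabilities match softmax, so the backward pass can
# form dV, dK, dQ without ever storing the full attention matrix.
assert torch.allclose(p, torch.softmax(scores, dim=-1), atol=1e-6)
```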