
Comments (19)

lucidrains commented on July 28, 2024

Germany, Beijing, San Francisco

only in open source (and science)


andreaskoepf commented on July 28, 2024

@lucidrains thanks a lot for your hard work & very interesting that you used a custom triton kernel! :-)


lucidrains commented on July 28, 2024

[meme image]


andreaskoepf commented on July 28, 2024

@zhuzilin awesome work, we'll organize a little hack today at 19:00 UTC on the cuda-mode Discord to hack on your impl (do some testing, benchmarking, and discussion about the best comms options for single node and multi node, etc.) - just FYI https://x.com/neurosp1ke/status/1760558683136589983


zhuzilin commented on July 28, 2024

oh... sorry, I took a day off and missed all the notifications from GitHub...


lucidrains commented on July 28, 2024

@zhuzilin i think my version is working too now, with a modified forward flash attention kernel to minimize ring passes

thanks for sharing your repo for proof of concept!


ericauld commented on July 28, 2024

@lucidrains What is the issue you're referring to with the backward pass?


andreaskoepf commented on July 28, 2024

> What am I missing, and what is the utility of returning the LSE?

The returned log-sum-exp is what allows flash-attention to be applied in a blockwise manner (i.e. without it, it wouldn't be possible to use flash-attn to implement ring-attn). See ring_flash_attn/utils.py#L19-L21.
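
For readers following along, here is a minimal sketch of the idea (not the actual ring_flash_attn/utils.py code): the per-row LSE is exactly what lets two partial attention results over disjoint key/value blocks be merged into the full result.

```python
import torch

def merge_attn_blocks(out1, lse1, out2, lse2):
    # out*: [batch, heads, seq, dim] partial attention outputs over disjoint kv blocks
    # lse*: [batch, heads, seq] per-row log-sum-exp of the softmax denominator
    lse = torch.logaddexp(lse1, lse2)           # combined denominator in log space
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)    # rescale weight for block 1
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)    # rescale weight for block 2
    return w1 * out1 + w2 * out2, lse
```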


lucidrains commented on July 28, 2024

@zhuzilin hey Zilin! this looks like a good start, and what I intended to do at the very end! i was imagining that the ring communication could be done within CUDA using IPC? (however, I am far from a CUDA expert, so I could be wrong and it may not be possible) Are you planning on upstreaming the finalized implementation to Tri Dao's official flash attention repository? That would be a big contribution!


lucidrains commented on July 28, 2024

@zhuzilin if you do embark on the pull request, the minimal features would be the ring IPC, the ability to specify a maximum number of ring passes (as I believe they must have curriculum-learned from local attention up to full global, or mixed local and global using variable ring passes throughout the transformer), and finally, if you have the bandwidth, specialized masking logic for striped autoregressive attention to balance the workload
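
A hypothetical sketch of the striped sharding idea referenced here (the helper names are made up, and it assumes the sequence length is divisible by the world size): tokens are dealt out round-robin so each rank's queries are spread across the whole sequence, which roughly equalizes the causal-mask workload on every ring pass.

```python
import torch

def stripe_shard(x, rank, world_size):
    # x: [batch, seq, dim]; token t is assigned to rank t % world_size
    return x[:, rank::world_size]

def unstripe(shards):
    # shards: per-rank tensors of shape [batch, seq // world_size, dim]
    stacked = torch.stack(shards, dim=2)     # [batch, seq // ws, ws, dim]
    b, n, w, d = stacked.shape
    return stacked.reshape(b, n * w, d)      # interleave back into the original token order
```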


lucidrains commented on July 28, 2024

[screenshot]

thank you! 🚀 ❤️


lucidrains commented on July 28, 2024

@zhuzilin actually, after looking into CUDA IPC stuff, your approach may be the best for now


zhuzilin commented on July 28, 2024

> Are you planning on upstreaming the finalized implementation to Tri Dao's official flash attention repository?

I'll draft an issue in the flash attention repo to see if they have interest in upstreaming it (or designing a better version) in the official repo :)

> after looking into CUDA IPC stuff, your approach may be the best for now

yeah, using NCCL-based p2p communication would at least be an easier way to implement, with acceptable performance.
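
A rough sketch of what such an NCCL-backed point-to-point ring pass could look like with torch.distributed (assumes an initialized NCCL process group and equally sized blocks on every rank; not the actual implementation):

```python
import torch
import torch.distributed as dist

def ring_pass(block: torch.Tensor) -> torch.Tensor:
    # send our kv block to the next rank while receiving the previous rank's block
    rank, world = dist.get_rank(), dist.get_world_size()
    recv_buf = torch.empty_like(block)
    ops = [
        dist.P2POp(dist.isend, block, (rank + 1) % world),
        dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world),
    ]
    # issue send and recv together so the ring of paired ops cannot deadlock
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_buf
```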


lucidrains commented on July 28, 2024

i also wanted to do some LOTR references, but one meme is enough


lucidrains commented on July 28, 2024

@andreaskoepf thanks! seems like there's still an issue with the backwards pass, but i'll leave it to someone or some team to fix. yup, i think the forwards requires the keys and values to be iterated in the outer loop (to save on extraneous ring passes), so the reduced outputs, row maxes, and lse need to be stored and passed back in on the next ring pass. but i could be wrong and there may be a simpler way
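
A simplified sketch of that forward structure, carrying the output and log-sum-exp accumulators across ring steps and rescaling them online (flash_attn_step and ring_pass are stand-ins for the real kernel and comm call, not actual APIs; row maxes are folded into the lse here):

```python
import torch

def ring_attention_forward(q, k, v, ring_size, flash_attn_step, ring_pass):
    out = torch.zeros_like(q)
    lse = torch.full(q.shape[:-1], float("-inf"), device=q.device)
    for _ in range(ring_size):
        block_out, block_lse = flash_attn_step(q, k, v)   # partial result for the current kv block
        new_lse = torch.logaddexp(lse, block_lse)         # online softmax denominator update
        out = out * torch.exp(lse - new_lse).unsqueeze(-1) \
            + block_out * torch.exp(block_lse - new_lse).unsqueeze(-1)
        lse = new_lse
        k, v = ring_pass(k), ring_pass(v)                 # rotate kv blocks to the next rank
    return out, lse
```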


lucidrains commented on July 28, 2024

it isn't correct, probably something small with regard to how i'm using the flash attention api

feel free to submit a PR, i likely won't be able to get to this as i'll be running around the bay area meeting people next month


lucidrains commented on July 28, 2024

@ericauld ah, good news, the cuda backwards actually yielded the right gradients (full attention, no causal or key padding mask). it is my naive version that is broken

alright, i guess it is safe to remove the wip


apaz-cli commented on July 28, 2024

@lucidrains Knowing the LSE doesn't actually help you compute the backwards pass for softmax though, correct? The derivative of LSE is softmax, not the other way around. What am I missing, and what is the utility of returning the LSE?


apaz-cli commented on July 28, 2024

Ah, alright. That's what I'm missing. Makes sense :)

