Comments (8)
Is there any plan to implement this feature?
I would like to apply it to my custom JAX code.
from flash-attention-jax.
^0^ great. Thanks for your support! Take care.
from flash-attention-jax.
I am looking into this code more carefully, and it seems that the unneeded computation (the upper-triangular region in causal attention) is not excluded from the computation. (I don't expect the compiler to handle this aspect either.)
I think this is intentional, to keep the flash attention implementation easy to understand, but causal attention could be roughly 2x faster when the query and key lengths are equal and the process is compute-bound, since only the blocks on or below the diagonal actually need to be computed.
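As an illustration, here is a minimal sketch (my own example, not this repository's code) of a blockwise causal loop that skips the key/value blocks lying entirely above the diagonal; this is where the roughly 2x saving would come from when the query and key lengths match.

import jax
import jax.numpy as jnp

def blockwise_causal_attention(q, k, v, block_size=128):
    # q, k, v: (seq_len, dim); assumes seq_len is a multiple of block_size
    seq_len, dim = q.shape
    scale = dim ** -0.5
    out_blocks = []
    for start in range(0, seq_len, block_size):
        end = start + block_size
        q_blk = q[start:end] * scale
        # key/value blocks strictly after this query block lie entirely in the
        # upper triangle, so they are never loaded or multiplied
        k_ctx, v_ctx = k[:end], v[:end]
        sim = q_blk @ k_ctx.T
        # mask the remaining upper-triangular entries inside the diagonal block
        q_idx = jnp.arange(start, end)[:, None]
        k_idx = jnp.arange(end)[None, :]
        sim = jnp.where(q_idx >= k_idx, sim, -jnp.inf)
        out_blocks.append(jax.nn.softmax(sim, axis=-1) @ v_ctx)
    return jnp.concatenate(out_blocks)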
The issue is closed.
from flash-attention-jax.
yea, it should, as it is agnostic to how many leading dimensions there are (whether it is batch, heads, etc.)
from flash-attention-jax.
oh shoot, i never built it. believe at the time i thought vmap would suffice
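For reference, here is a minimal sketch (my own example, not the repository's code) of that approach: jax.vmap lifts a single-sequence attention function over the head and batch dimensions, so the core kernel stays agnostic to leading dimensions.

import jax
import jax.numpy as jnp

def attention(q, k, v):
    # single-sequence attention on (seq_len, dim) inputs
    sim = (q @ k.T) * (q.shape[-1] ** -0.5)
    return jax.nn.softmax(sim, axis=-1) @ v

# inner vmap maps over heads, outer vmap over batch,
# so the lifted function accepts (batch, heads, seq_len, dim)
batched_multihead_attention = jax.vmap(jax.vmap(attention))

q = k = v = jnp.ones((2, 8, 1024, 64))
print(batched_multihead_attention(q, k, v).shape)  # (2, 8, 1024, 64)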
from flash-attention-jax.
sure! I can add it tomorrow morning, California time
from flash-attention-jax.
provided I don't drink too much tonight :)
from flash-attention-jax.
@sh0416 ok it's done, you can test it by running
from flash_attention_jax import causal_attention, causal_flash_attention, value_and_grad_difference

diff, (dq_diff, dk_diff, dv_diff) = value_and_grad_difference(
    causal_attention,
    causal_flash_attention,
    seed = 42,
    add_key_mask = False
)

print('shows differences between normal and flash attention for output, dq, dk, dv')
print(f'o: {diff}')     # < 1e-4
print(f'dq: {dq_diff}') # < 1e-6
print(f'dk: {dk_diff}') # < 1e-6
print(f'dv: {dv_diff}') # < 1e-6
from flash-attention-jax.
Related Issues (11)
- Question about calculation of Q and transpose(K).
- Slower than non-flash attention HOT 1
- Reshape error in causal_flash_attention when sequence length is not a multiple of 1024
- Online Softmax from FlashAttention HOT 2
- can I work on making a flax attention function out of this repository? HOT 1
- batch & multihead support? HOT 3
- more general mask support HOT 1
- support for per-head scales for cosine sim attention HOT 6
- fix compatibility with jax transformations HOT 28
- Performance benchmarks? HOT 20