- Triton provides a flexible way to write customized kernels for user-defined operations, such as flash attention with flexible masks.
- This repo provides a rough, prototypical implementation, mostly based upon and adapted from these great works (to which most of the credit belongs!):
- Triton-kernels: https://github.com/triton-lang/kernels/blob/main/kernels/flash_attention.py
- (This kernel uses nested loops by default and is slow in the backward pass; see here.)
- FlagAttention: https://github.com/FlagOpen/FlagAttention/blob/main/src/flag_attn/flash.py
- (This kernel uses two separate kernels for dk/dv and dq in the backward pass and seems to be much faster; see here.)
- `flash_triton.py` contains the kernels and `testing.py` contains some simple testing code.
- One additional feature is support for flexible masks:
- For example, the `attn_mask` argument in sdpa, which the CUDA-version flash attention currently does not support (for example, see these issues, here and here); a minimal sdpa sketch is shown below.
- Exactly the same motivation as FlexAttention.
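As an illustration, here is a minimal plain-PyTorch sketch of that `attn_mask` argument; the shapes and the sliding-window mask are illustrative choices, not from this repo:

```python
import torch
import torch.nn.functional as F

# Minimal sketch of sdpa's attn_mask argument: a boolean mask where True
# marks positions that may be attended. The flash-attention backend cannot
# take such a custom mask, so PyTorch serves this call with another backend.
bs, H, L, D = 2, 4, 16, 64
q, k, v = (torch.randn(bs, H, L, D) for _ in range(3))

# An arbitrary "flexible" mask: a sliding window of radius 4 around each query.
idx = torch.arange(L)
attn_mask = (idx.unsqueeze(-1) - idx.unsqueeze(-2)).abs() <= 4  # [L, L], bool

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 16, 64])
```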
- For this extra feature, you can provide two extra input tensors, `eq` and `ek` (which have the shapes [bs, H, Lq] and [bs, H, Lk]), and a `score_func_mode` indicating how the mask is calculated from these extra inputs.
- Currently, an example mode `SCORE_FUNC_MODE1_DOC` is implemented, which supports a document attention mask of `extra_attn_mask = ((eq.unsqueeze(-1) >= 0) | (eq.unsqueeze(-1) == ek.unsqueeze(-2)))  # [bs, H, Lq, Lk]` (see here); a plain-PyTorch sketch of this kind of mask is shown below.
- Other flexible masking modes can be implemented in a similar way.
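Here is a plain-PyTorch sketch of building a document mask from per-token ids and feeding it to sdpa as an eager reference; it mirrors the idea behind `SCORE_FUNC_MODE1_DOC` (the kernel's exact formula is the one quoted above), and the commented kernel call at the end is a hypothetical signature, not this repo's confirmed API:

```python
import torch
import torch.nn.functional as F

# Document mask built from per-token document ids. Tensor names and shapes
# follow the README; the simple same-document rule here illustrates the idea,
# while the kernel's exact formula is the one quoted above.
bs, H, Lq, Lk = 2, 4, 8, 8
doc_ids = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2])        # three packed documents
eq = doc_ids.view(1, 1, Lq).expand(bs, H, Lq)           # [bs, H, Lq]
ek = doc_ids.view(1, 1, Lk).expand(bs, H, Lk)           # [bs, H, Lk]

# Each query token may only attend to keys from the same document.
extra_attn_mask = eq.unsqueeze(-1) == ek.unsqueeze(-2)  # [bs, H, Lq, Lk], bool

# Eager reference that the Triton kernel aims to reproduce without ever
# materializing the full [Lq, Lk] mask:
D = 64
q = torch.randn(bs, H, Lq, D)
k = torch.randn(bs, H, Lk, D)
v = torch.randn(bs, H, Lk, D)
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=extra_attn_mask)

# Hypothetical call into this repo's kernel (assumed signature, may differ):
# out = flash_triton.attention(q, k, v, eq=eq, ek=ek,
#                              score_func_mode=SCORE_FUNC_MODE1_DOC)
```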
- Things not implemented that might need more effort: real block-sparse attention.
- Testing was done with `triton==3.0.0` and `torch==2.4.0` on one `A100-SXM4-40GB` GPU.
- See `2410_flash.pdf` for a simple illustration of flash attention and some related results.
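For reference, here is a minimal sketch of the kind of correctness check `testing.py` presumably performs, comparing the kernel against an eager sdpa reference; the kernel import and call are hypothetical placeholders, not this repo's confirmed API:

```python
import torch
import torch.nn.functional as F

# Sketch of a correctness check against an eager sdpa reference under the
# same document mask. The kernel import/signature below is a hypothetical
# placeholder based on the file names in this repo.
# from flash_triton import attention, SCORE_FUNC_MODE1_DOC

bs, H, L, D = 2, 4, 128, 64
q, k, v = (torch.randn(bs, H, L, D, device="cuda", dtype=torch.float16)
           for _ in range(3))
doc_ids = torch.arange(L, device="cuda") // 32           # four packed documents
eq = doc_ids.view(1, 1, L).expand(bs, H, L)              # [bs, H, Lq]
ek = doc_ids.view(1, 1, L).expand(bs, H, L)              # [bs, H, Lk]
mask = eq.unsqueeze(-1) == ek.unsqueeze(-2)              # [bs, H, Lq, Lk]

ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
# out = attention(q, k, v, eq=eq, ek=ek, score_func_mode=SCORE_FUNC_MODE1_DOC)
# torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```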