
Triton-Version-of-Flash-Attention-with-Flexible-Masks

Flash Attention with Flexible Masks in Triton

  • Triton provides a flexible way to write customized kernels for user-defined operations, such as flash attention with flexible masks.
  • This repo provides a rough, prototype implementation, mostly based upon and adapted from these great works (to which most of the credit goes!).
  • flash_triton.py contains the kernels, and testing.py contains some simple testing code.
  • One additional feature is support for flexible masks:
    • For example, the attn_mask argument in sdpa, which the CUDA-version flash attention currently does not support (for example, see these issues, here and here).
    • The motivation is exactly the same as FlexAttention's.
    • For this extra feature, you provide two extra input tensors, eq and ek (with shapes [bs, H, Lq] and [bs, H, Lk]), and a score_func_mode indicating how the mask is computed from these extra inputs.
    • Currently, one example mode, SCORE_FUNC_MODE1_DOC, is implemented; it supports a document attention mask of extra_attn_mask = ((eq.unsqueeze(-1) >= 0) | (eq.unsqueeze(-1) == ek.unsqueeze(-2))) # [bs, H, Lq, Lk] (see here, and the reference sketch after this list).
    • Other flexible masking modes can be implemented in the same way.
    • Not yet implemented (and likely needing more effort): real block-sparse attention.
  • Testing was done with triton==3.0.0 and torch==2.4.0, on one A100-SXM4-40GB GPU.
  • See 2410_flash.pdf for a simple illustration of flash attention and some related results.
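
For reference, here is a minimal PyTorch sketch of the document mask that SCORE_FUNC_MODE1_DOC is meant to reproduce, built densely and fed to sdpa as a baseline. This is not taken from flash_triton.py; the tensor names, sizes, and the id convention are illustrative assumptions only.

```python
# Minimal PyTorch sketch (illustrative, not the repo's API): build the
# document attention mask from the extra inputs eq/ek and use it with
# sdpa as a dense-mask baseline.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
bs, H, Lq, Lk, D = 2, 4, 128, 128, 64
q = torch.randn(bs, H, Lq, D, device=device)
k = torch.randn(bs, H, Lk, D, device=device)
v = torch.randn(bs, H, Lk, D, device=device)

# Assumed id convention for illustration: negative values are document ids
# (such queries attend only within the same document), while non-negative
# values attend everywhere, matching the formula above.
ids = -torch.randint(1, 4, (bs, 1, Lq), device=device)
ids[..., :4] = 0                      # a few unrestricted tokens
eq = ids.expand(bs, H, Lq)            # [bs, H, Lq]
ek = ids.expand(bs, H, Lk)            # [bs, H, Lk] (Lq == Lk in this toy setup)

# The flexible mask, materialized densely here for reference only.
extra_attn_mask = (eq.unsqueeze(-1) >= 0) | (eq.unsqueeze(-1) == ek.unsqueeze(-2))  # [bs, H, Lq, Lk]

# Baseline attention with an explicit boolean mask (True = keep).
out_ref = F.scaled_dot_product_attention(q, k, v, attn_mask=extra_attn_mask)
```

The Triton kernel is intended to produce the same result without ever materializing the [bs, H, Lq, Lk] mask, since eq and ek can be read blockwise inside the kernel; see flash_triton.py and testing.py for the actual call signature and correctness checks.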
