Skip to content

Commit

Permalink
Fix kernel cache miss and add RDNA configs (#246)
Browse files Browse the repository at this point in the history
* Fix kernel cache miss and add RDNA configs

- added Navi configurations (Related PR: ROCm/triton#640)
- resolved cache miss issue during flash attention calls by fixing max_seqlen_q/k to 0

* Remove Navi autotune configs for triton FP8 support
  • Loading branch information
hyoon1 authored Dec 6, 2024
1 parent 2b17421 commit 8663822
Showing 1 changed file with 135 additions and 52 deletions.
187 changes: 135 additions & 52 deletions vllm/attention/ops/triton_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import triton
import triton.language as tl

from vllm.utils import is_navi

torch_dtype: tl.constexpr = torch.float16


Expand Down Expand Up @@ -217,88 +219,80 @@ def _attn_fwd_inner(
return acc, l_i, m_i


@triton.autotune(
configs=[
def get_cdna_autotune_configs():
return [
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 64,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
'BLOCK_M': 256,
'BLOCK_N': 64,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8,
),
num_warps=8),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
'BLOCK_M': 128,
'BLOCK_N': 128,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=4,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
'BLOCK_M': 256,
'BLOCK_N': 128,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8,
),
num_warps=8),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 1,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=4,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": True,
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 3,
'PRE_LOAD_V': True
},
num_stages=1,
num_warps=4,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": False,
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 3,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=4,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 64,
"BLOCK_N": 64,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
'BLOCK_M': 64,
'BLOCK_N': 64,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8,
),
num_warps=8),
triton.Config(
{
"BLOCK_M": 32,
"BLOCK_N": 32,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
'BLOCK_M': 32,
'BLOCK_N': 32,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8,
),
num_warps=8),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
Expand All @@ -314,8 +308,93 @@ def _attn_fwd_inner(
# num_stages=1,
# num_warps=4,
# ),
],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'],
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']


def get_rdna_autotune_configs():
return [
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 32,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 32,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 16,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 16,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
# Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
# triton.Config(
# {
# 'BLOCK_M': 16,
# 'BLOCK_N': 16,
# 'waves_per_eu': 4,
# 'PRE_LOAD_V': False
# },
# num_stages=1,
# num_warps=2),
# triton.Config(
# {
# 'BLOCK_M': 16,
# 'BLOCK_N': 16,
# 'waves_per_eu': 2,
# 'PRE_LOAD_V': False
# },
# num_stages=1,
# num_warps=2),
# # Fall-back config.
# triton.Config(
# {
# 'BLOCK_M': 16,
# 'BLOCK_N': 16,
# 'waves_per_eu': 1,
# 'PRE_LOAD_V': False
# },
# num_stages=1,
# num_warps=2),
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']


def get_autotune_configs():
if is_navi():
return get_rdna_autotune_configs()
else:
return get_cdna_autotune_configs()


autotune_configs, autotune_keys = get_autotune_configs()


@triton.autotune(
configs=autotune_configs,
key=autotune_keys,
use_cuda_graph=True,
)
@triton.jit
def attn_fwd(
Expand Down Expand Up @@ -833,6 +912,10 @@ def check_and_convert(t, scale):
p_descale = 1.0 / p_scale
o_descale = 1.0 / o_scale

if is_navi():
max_seqlens_q = 0
max_seqlens_k = 0

attn_fwd[grid](
q,
k,
Expand Down

0 comments on commit 8663822

Please sign in to comment.