diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py
index e15ab911fc..d125eabc85 100644
--- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py
@@ -121,6 +121,8 @@ def _fwd_grouped_split_kernel(
     cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H)
     mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA
     mask_h = mask_h & (cur_head < num_heads_q)
+    if BLOCK_H < kv_group_num:
+        cur_kv_head = (cur_kv_head * HEAD_PER_CTA) // kv_group_num

     q_seqlen = 1
     kv_seqlen = tl.load(KV_seqlens + cur_batch)
@@ -366,6 +368,8 @@ def _fwd_grouped_split_quant_kernel(
     cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H)
     mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA
     mask_h = mask_h & (cur_head < num_heads_q)
+    if BLOCK_H < kv_group_num:
+        cur_kv_head = (cur_kv_head * HEAD_PER_CTA) // kv_group_num

     q_seqlen = 1
     kv_seqlen = tl.load(KV_seqlens + cur_batch)
diff --git a/tests/pytorch/kernel/test_paged_attention.py b/tests/pytorch/kernel/test_paged_attention.py
index 7f63b281c5..0ef0db7330 100644
--- a/tests/pytorch/kernel/test_paged_attention.py
+++ b/tests/pytorch/kernel/test_paged_attention.py
@@ -244,7 +244,8 @@ def conti_gt(self, gt, seq_lens):

     @pytest.mark.parametrize('feat_dim', [48, 32], indirect=True)
     @pytest.mark.parametrize('feat_dim_v', [32], indirect=True)
-    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(8, 2), (2, 2)],
+    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2),
+                                                              (2, 2)],
                              indirect=True)
     @pytest.mark.parametrize(['seq_lens', 'history_lens'],
                              [([30, 50, 70, 90], [50, 40, 30, 20]),