Skip to content

Commit 76e2e98

Browse files
committed
fix decoding kernel for deepseekv2
1 parent dde5d23 commit 76e2e98

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

lmdeploy/pytorch/kernels/cuda/pagedattention.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def _fwd_grouped_split_kernel(
121121
cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H)
122122
mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA
123123
mask_h = mask_h & (cur_head < num_heads_q)
124+
if BLOCK_H < kv_group_num:
125+
cur_kv_head = (cur_kv_head * HEAD_PER_CTA) // kv_group_num
124126

125127
q_seqlen = 1
126128
kv_seqlen = tl.load(KV_seqlens + cur_batch)
@@ -366,6 +368,8 @@ def _fwd_grouped_split_quant_kernel(
366368
cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H)
367369
mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA
368370
mask_h = mask_h & (cur_head < num_heads_q)
371+
if BLOCK_H < kv_group_num:
372+
cur_kv_head = (cur_kv_head * HEAD_PER_CTA) // kv_group_num
369373

370374
q_seqlen = 1
371375
kv_seqlen = tl.load(KV_seqlens + cur_batch)

tests/pytorch/kernel/test_paged_attention.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,8 @@ def conti_gt(self, gt, seq_lens):
244244

245245
@pytest.mark.parametrize('feat_dim', [48, 32], indirect=True)
246246
@pytest.mark.parametrize('feat_dim_v', [32], indirect=True)
247-
@pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(8, 2), (2, 2)],
247+
@pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2),
248+
(2, 2)],
248249
indirect=True)
249250
@pytest.mark.parametrize(['seq_lens', 'history_lens'],
250251
[([30, 50, 70, 90], [50, 40, 30, 20]),

0 commit comments

Comments
 (0)