
Commit 34dd841

fix sliding window
1 parent 74db002

2 files changed: +3 / -3 lines

2 files changed

+3
-3
lines changed

lmdeploy/pytorch/kernels/cuda/flashattention.py

Lines changed: 2 additions & 2 deletions

@@ -199,10 +199,10 @@ def _flash_prefill_fwd_kernel(
     loop_start = 0
     kv_min_loc = tl.zeros([BLOCK_M], dtype=tl.int32)
     if window_size > 0:
-        start_block_id = tl.maximum(history_len - window_size, 0) // BLOCK_N
+        start_block_id = tl.maximum(
+            history_len + start_m * BLOCK_M - window_size, 0) // BLOCK_N
         kv_min_loc = tl.maximum(history_len + offs_m - window_size, 0)
         loop_start = start_block_id * BLOCK_N
-        kv_start_loc += loop_start

     offs_dk = tl.arange(0, BLOCK_DK)
     mask_dk = offs_dk < head_dim_k
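
The old code derived the first KV block to visit from history_len alone, so every query block (start_m) began its KV loop at the same offset, even though the sliding window of a later query block starts further right. The fix computes the loop start per query block; dropping kv_start_loc += loop_start stops shifting the KV base pointer as well, which pairs with the new per-block loop start (the surrounding loop is not shown in this hunk). Below is a standalone sketch of the corrected bookkeeping; the block sizes and lengths are illustrative assumptions, not the kernel's tuned values.

# Illustrative sketch only: mirrors the loop-start math fixed by this
# commit; all constants here are hypothetical, not the kernel's.
BLOCK_M = 64        # query rows handled per program instance (assumed)
BLOCK_N = 64        # KV columns handled per inner-loop step (assumed)
window_size = 32    # sliding-window width
history_len = 90    # cached tokens preceding the new queries

def loop_start_old(start_m):
    # Before the fix: ignores which query block we are in, so every
    # block starts scanning KV at the same offset.
    start_block_id = max(history_len - window_size, 0) // BLOCK_N
    return start_block_id * BLOCK_N

def loop_start_new(start_m):
    # After the fix: the first query row of block `start_m` sits at
    # absolute position history_len + start_m * BLOCK_M, so its window
    # begins window_size tokens before that position.
    start_block_id = max(history_len + start_m * BLOCK_M - window_size,
                         0) // BLOCK_N
    return start_block_id * BLOCK_N

for start_m in range(3):
    print(start_m, loop_start_old(start_m), loop_start_new(start_m))
# -> 0 0 0 / 1 0 64 / 2 0 128: later query blocks now begin at the KV
#    block containing the left edge of their own window.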

tests/pytorch/kernel/test_flash_attention.py

Lines changed: 1 addition & 1 deletion

@@ -224,7 +224,7 @@ def window_gt(self, conti_q, conti_kv, q_seqlens, kv_seqlens, win_size):
     @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)],
                              indirect=True)
     @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [
-        ([30, 50, 70, 90], [50, 40, 30, 20]),
+        ([30, 50, 70, 90], [50, 40, 30, 90]),
     ],
                              indirect=True)
     @pytest.mark.parametrize('win_size', (32, ), indirect=True)
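
Under the block sizes assumed in the sketch above, the previous parameters (history_lens ending in 20 with win_size=32) never produced a sequence whose query blocks start their KV loops in different KV blocks, so the kernel bug went unexercised. Raising the last history length to 90, paired with the 90-token query sequence, does. A quick back-of-the-envelope check, again with hypothetical block sizes:

# Hypothetical check (block sizes are assumptions, not the kernel's):
# with history 90, window 32 and 90 new queries, consecutive query
# blocks see their earliest visible key in different KV blocks.
BLOCK_M, BLOCK_N, win = 64, 64, 32
q_len, history = 90, 90

for start_m in range((q_len + BLOCK_M - 1) // BLOCK_M):
    first_q = history + start_m * BLOCK_M     # absolute pos of row 0
    first_visible = max(first_q - win, 0)     # left edge of its window
    print(start_m, first_visible // BLOCK_N)  # KV block the loop starts at
# -> block 0 for start_m=0 but block 1 for start_m=1, so the per-block
#    loop start introduced by the kernel fix is actually exercised.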
