Skip to content

Commit ef24e85

Browse files
committed
fill last block
1 parent be975f5 commit ef24e85

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def _flatten_kv_cache(
3131
stride_vos: tl.constexpr,
3232
stride_vod: tl.constexpr,
3333
stride_boff,
34+
OUT_SIZE: tl.constexpr,
3435
HEAD_DIM_K: tl.constexpr,
3536
HEAD_DIM_V: tl.constexpr,
3637
BLOCK_BS: tl.constexpr,
@@ -42,7 +43,13 @@ def _flatten_kv_cache(
4243
batch_id = tl.program_id(1)
4344
head_id = tl.program_id(2)
4445

46+
num_batches = tl.num_programs(1)
47+
4548
seqlen = tl.load(seqlens_ptr + batch_id)
49+
start_loc = tl.load(start_loc_ptr + batch_id)
50+
# fill last block to prevent attention nan
51+
if batch_id == num_batches - 1:
52+
seqlen = OUT_SIZE - start_loc
4653
if page_id * BLOCK_BS >= seqlen:
4754
return
4855

@@ -117,6 +124,7 @@ def _flatten_kv_cache_quant(
117124
stride_vod: tl.constexpr,
118125
stride_boff,
119126
quant_policy: tl.constexpr,
127+
OUT_SIZE: tl.constexpr,
120128
HEAD_DIM_K: tl.constexpr,
121129
HEAD_DIM_V: tl.constexpr,
122130
BLOCK_BS: tl.constexpr,
@@ -128,11 +136,15 @@ def _flatten_kv_cache_quant(
128136
batch_id = tl.program_id(1)
129137
head_id = tl.program_id(2)
130138

139+
num_batches = tl.num_programs(1)
140+
131141
seqlen = tl.load(seqlens_ptr + batch_id)
142+
start_loc = tl.load(start_loc_ptr + batch_id)
143+
if batch_id == num_batches - 1:
144+
seqlen = OUT_SIZE - start_loc
132145
if page_id * BLOCK_BS >= seqlen:
133146
return
134147

135-
start_loc = tl.load(start_loc_ptr + batch_id)
136148
b_off = tl.load(block_offsets_ptr + batch_id * stride_boff + page_id)
137149

138150
offs_bs = tl.arange(0, BLOCK_BS)
@@ -258,6 +270,7 @@ def flatten_kv_cache(k_caches: Tensor,
258270
stride_vos=v_states.stride(1),
259271
stride_vod=v_states.stride(2),
260272
stride_boff=block_offsets.stride(0),
273+
OUT_SIZE=out_size,
261274
HEAD_DIM_K=k_head_dim,
262275
HEAD_DIM_V=v_head_dim,
263276
BLOCK_BS=BLOCK_BS,
@@ -299,6 +312,7 @@ def flatten_kv_cache(k_caches: Tensor,
299312
stride_vod=v_states.stride(2),
300313
stride_boff=block_offsets.stride(0),
301314
quant_policy=quant_policy,
315+
OUT_SIZE=out_size,
302316
HEAD_DIM_K=k_head_dim,
303317
HEAD_DIM_V=v_head_dim,
304318
BLOCK_BS=BLOCK_BS,

0 commit comments

Comments (0)