@@ -31,6 +31,7 @@ def _flatten_kv_cache(
3131 stride_vos : tl .constexpr ,
3232 stride_vod : tl .constexpr ,
3333 stride_boff ,
34+ OUT_SIZE : tl .constexpr ,
3435 HEAD_DIM_K : tl .constexpr ,
3536 HEAD_DIM_V : tl .constexpr ,
3637 BLOCK_BS : tl .constexpr ,
@@ -42,7 +43,13 @@ def _flatten_kv_cache(
4243 batch_id = tl .program_id (1 )
4344 head_id = tl .program_id (2 )
4445
46+ num_batches = tl .num_programs (1 )
47+
4548 seqlen = tl .load (seqlens_ptr + batch_id )
49+ start_loc = tl .load (start_loc_ptr + batch_id )
50+ # fill last block to prevent attention nan
51+ if batch_id == num_batches - 1 :
52+ seqlen = OUT_SIZE - start_loc
4653 if page_id * BLOCK_BS >= seqlen :
4754 return
4855
@@ -117,6 +124,7 @@ def _flatten_kv_cache_quant(
117124 stride_vod : tl .constexpr ,
118125 stride_boff ,
119126 quant_policy : tl .constexpr ,
127+ OUT_SIZE : tl .constexpr ,
120128 HEAD_DIM_K : tl .constexpr ,
121129 HEAD_DIM_V : tl .constexpr ,
122130 BLOCK_BS : tl .constexpr ,
@@ -128,11 +136,15 @@ def _flatten_kv_cache_quant(
128136 batch_id = tl .program_id (1 )
129137 head_id = tl .program_id (2 )
130138
139+ num_batches = tl .num_programs (1 )
140+
131141 seqlen = tl .load (seqlens_ptr + batch_id )
142+ start_loc = tl .load (start_loc_ptr + batch_id )
143+ if batch_id == num_batches - 1 :
144+ seqlen = OUT_SIZE - start_loc
132145 if page_id * BLOCK_BS >= seqlen :
133146 return
134147
135- start_loc = tl .load (start_loc_ptr + batch_id )
136148 b_off = tl .load (block_offsets_ptr + batch_id * stride_boff + page_id )
137149
138150 offs_bs = tl .arange (0 , BLOCK_BS )
@@ -258,6 +270,7 @@ def flatten_kv_cache(k_caches: Tensor,
258270 stride_vos = v_states .stride (1 ),
259271 stride_vod = v_states .stride (2 ),
260272 stride_boff = block_offsets .stride (0 ),
273+ OUT_SIZE = out_size ,
261274 HEAD_DIM_K = k_head_dim ,
262275 HEAD_DIM_V = v_head_dim ,
263276 BLOCK_BS = BLOCK_BS ,
@@ -299,6 +312,7 @@ def flatten_kv_cache(k_caches: Tensor,
299312 stride_vod = v_states .stride (2 ),
300313 stride_boff = block_offsets .stride (0 ),
301314 quant_policy = quant_policy ,
315+ OUT_SIZE = out_size ,
302316 HEAD_DIM_K = k_head_dim ,
303317 HEAD_DIM_V = v_head_dim ,
304318 BLOCK_BS = BLOCK_BS ,
0 commit comments