@@ -47,7 +47,7 @@ __global__ void prepare_varlen_num_blocks_kernel(
         int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static,
         cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod,
         int* const tile_count_semaphore,
-        int* const num_m_blocks_ptr,
+        int* const prepare_seqlen_q_ptr,
         int* const num_splits_dynamic_ptr,
         int* const varlen_batch_idx_ptr,
         // int* const num_n_blocks_ptr,
@@ -78,7 +78,7 @@ __global__ void prepare_varlen_num_blocks_kernel(
 
     int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
 
-    auto get_num_m_blocks = [&](int batch_idx) {
+    auto get_num_m_blocks_and_seqlen = [&](int batch_idx) {
         int seqlen;
         if (seqused_q) {
             seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0;
@@ -91,7 +91,8 @@ __global__ void prepare_varlen_num_blocks_kernel(
         }
         if (packgqa) { seqlen *= qhead_per_khead; }
         return batch_idx < num_batch && lane < kNumBatchPerWarp
-            ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0;
+            ? cute::make_tuple(blockm_divmod.div(seqlen + blockm_divmod.divisor - 1), seqlen)
+            : cute::make_tuple(0, 0);
     };
 
     auto get_num_n_blocks = [&](int batch_idx) {
@@ -124,7 +125,10 @@ __global__ void prepare_varlen_num_blocks_kernel(
     int batch_cta_idx_offset = int(blockIdx.x) * 992;
     int bidb_start = batch_cta_idx_offset + kNumBatchPerWarp * warp_idx;
     int batch_idx = lane + bidb_start;
-    int num_m_blocks = get_num_m_blocks(batch_idx);
+    // int num_m_blocks = get_num_m_blocks(batch_idx);
+    auto seqlen_q_info = get_num_m_blocks_and_seqlen(batch_idx);
+    int num_m_blocks = cute::get<0>(seqlen_q_info);
+    int seqlen_q = cute::get<1>(seqlen_q_info);
     int num_n_blocks = get_num_n_blocks(batch_idx);
 
     auto get_nheads_in_l2 = [&](int n_blocks) {
@@ -165,47 +169,35 @@ __global__ void prepare_varlen_num_blocks_kernel(
             num_n_blocks = INT_MIN; // sort last
         } else if (is_causal) {
             // sort by shortest member to process
-            num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor;
+            num_n_blocks = num_n_blocks * blockn_divmod.divisor - seqlen_q;
         }
         int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread
-        batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx);
-
-        // if (threadIdx.x == 0) {
-        //     printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n",
-        //         batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w);
-        // } __syncthreads();
+        batch_coords[0] = make_int4(num_n_blocks, seqlen_q, num_splits_dynamic, batch_idx);
 
         // Sort batches by num_n_blocks in descending order
         BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp<int4>());
 
-        // if (threadIdx.x == 0) {
-        //     printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n",
-        //         batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w);
-        // } __syncthreads();
-
         if (is_causal) {
             // reset value to num_n_blocks
-            batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor);
+            batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y);
         }
 
         // When sorting, we re-index some metadata by 'virtual batch index'
         // and also store the vbidx -> bidx mapping.
         // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx]
         // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx]
-        // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx]
+        // 3. prepare_seqlen_q_ptr: virtual_batch_idx -> seqlen_q[batch_idx] * (packgqa ? qhead_per_khead : 1)
         // 4. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx
         batch_idx = batch_cta_idx_offset + threadIdx.x;
         if (batch_idx < num_batch && threadIdx.x < 992) {
-            // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1);
             if (num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(batch_coords[0].x, 1)); }
-            num_m_blocks_ptr[batch_idx] = batch_coords[0].y;
+            prepare_seqlen_q_ptr[batch_idx] = batch_coords[0].y;
             if (num_splits_dynamic_ptr) { num_splits_dynamic_ptr[batch_idx] = batch_coords[0].z; }
             varlen_batch_idx_ptr[batch_idx] = batch_coords[0].w;
         }
     } else {
         if (batch_idx < num_batch && lane < kNumBatchPerWarp) {
-            num_m_blocks_ptr[batch_idx] = num_m_blocks;
-            // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1);
+            prepare_seqlen_q_ptr[batch_idx] = seqlen_q;
             if (num_splits_dynamic_ptr) { num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; }
             if (num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); }
             // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic);
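Note on the causal sort key in the hunk above: for causal batches, the key stored in `batch_coords[0].x` is `num_n_blocks * blockN - seqlen_q` (total KV extent minus query length), so batches with the least remaining work sort last; switching the subtrahend from `num_m_blocks * blockM` to the exact `seqlen_q` orders batches at sequence-length rather than rounded-up block granularity. After the sort, `blockn_divmod.div(key + seqlen_q)` recovers `num_n_blocks` exactly, since `key + seqlen_q` is a multiple of `blockN` by construction. A minimal host-side sketch checking that round trip, with plain integer division as a stand-in for `cutlass::FastDivmod` and an arbitrary `kBlockN`:

```cpp
// Sketch only: verifies the encode/decode round trip of the causal sort key.
#include <cassert>
#include <cstdio>

int main() {
    const int kBlockN = 112;  // arbitrary stand-in for blockn_divmod.divisor
    for (int seqlen_q = 0; seqlen_q <= 2048; ++seqlen_q) {
        for (int seqlen_k = 1; seqlen_k <= 4096; seqlen_k += 37) {
            int num_n_blocks = (seqlen_k + kBlockN - 1) / kBlockN;  // ceil_div
            // Encode (as in the kernel): smaller key == less causal work, sorts last.
            int key = num_n_blocks * kBlockN - seqlen_q;
            // Decode (blockn_divmod.div in the kernel): exact, because
            // key + seqlen_q == num_n_blocks * kBlockN by construction.
            assert((key + seqlen_q) / kBlockN == num_n_blocks);
        }
    }
    std::printf("causal sort-key round trip OK\n");
    return 0;
}
```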
@@ -236,7 +228,7 @@ void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bo
         params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits,
         cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN),
         params.tile_count_semaphore,
-        params.num_m_blocks_ptr,
+        params.prepare_seqlen_q_ptr,
         params.num_splits_dynamic_ptr,
         params.varlen_batch_idx_ptr,
         // params.num_n_blocks_ptr,
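Since the kernel now writes the raw `seqlen_q` (pre-scaled by `qhead_per_khead` under `packgqa`) instead of a block count, a downstream consumer has to derive `num_m_blocks` itself and, in the sorted path, translate virtual batch indices back through `varlen_batch_idx_ptr`. The sketch below is hypothetical: the function name, `kBlockM` parameter, and signature are illustrative only and not the scheduler's actual interface shown in this diff.

```cpp
// Illustrative only: shows how the re-indexed outputs fit together.
__device__ void example_consume(int virtual_batch_idx, int kBlockM,
                                int const* prepare_seqlen_q_ptr,
                                int const* varlen_batch_idx_ptr,
                                int const* num_splits_dynamic_ptr) {
    // seqlen_q for this virtual slot (already *= qhead_per_khead if packgqa).
    int seqlen_q = prepare_seqlen_q_ptr[virtual_batch_idx];
    // The block count is no longer stored; recover it with a ceil-divide,
    // matching blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) above.
    int num_m_blocks = (seqlen_q + kBlockM - 1) / kBlockM;
    // Map the virtual (sorted) index back to the original batch index.
    int batch_idx = varlen_batch_idx_ptr[virtual_batch_idx];
    int num_splits = num_splits_dynamic_ptr ? num_splits_dynamic_ptr[virtual_batch_idx] : 1;
    // ... schedule (batch_idx, num_m_blocks, num_splits) tiles here ...
    (void)batch_idx; (void)num_m_blocks; (void)num_splits;
}
```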