Commit af9a1f2

Change to storing seqlen_q so the combine kernel can do coalesced reads of the virtual batch metadata
1 parent 34ccbb7 commit af9a1f2
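
Why this helps: with batch sorting enabled, the combine kernel previously resolved the actual batch index through varlen_batch_idx_ptr and then gathered cu_seqlens_q / seqused_q at that permuted position, so neighbouring threads hit scattered addresses. Storing seqlen_q per virtual batch lets each lane read one contiguous int instead. A minimal sketch of the two access patterns (the pointer names follow this commit; the helper functions themselves are illustrative, not the actual kernel code):

// Sketch only: illustrates the read pattern this commit targets.
// prepare_seqlen_q_ptr, varlen_batch_idx_ptr, seqused_q follow the diff;
// these __device__ helpers are hypothetical stand-ins.

// Before: each lane resolves virtual -> actual batch index, then gathers the
// sequence length at a permuted position. Consecutive lanes load from
// scattered addresses, so the second load does not coalesce.
__device__ int seqlen_q_via_gather(int const* varlen_batch_idx_ptr,
                                   int const* seqused_q, int vbidb) {
    int actual_bidb = varlen_batch_idx_ptr[vbidb];  // contiguous read
    return seqused_q[actual_bidb];                  // scattered gather
}

// After: the prepare kernel already wrote seqlen_q indexed by *virtual* batch,
// so lane i of a warp reads element (vbidb_base + i) -- one coalesced load.
__device__ int seqlen_q_direct(int const* prepare_seqlen_q_ptr, int vbidb) {
    return prepare_seqlen_q_ptr[vbidb];
}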

File tree

7 files changed: +47 -47 lines changed

hopper/flash.h
hopper/flash_api.cpp
hopper/flash_fwd_combine_kernel.h
hopper/flash_fwd_combine_launch_template.h
hopper/flash_fwd_launch_template.h
hopper/flash_prepare_scheduler.cu
hopper/tile_scheduler.hpp


hopper/flash.h

Lines changed: 2 additions & 1 deletion

@@ -151,7 +151,8 @@ struct Flash_fwd_params : public Qkv_params {
     bool pack_gqa;

     int * __restrict__ tile_count_semaphore;
-    int * __restrict__ num_m_blocks_ptr;
+    // int * __restrict__ num_m_blocks_ptr;
+    int * __restrict__ prepare_seqlen_q_ptr;
     // int * __restrict__ num_n_blocks_ptr;
     int * __restrict__ num_splits_dynamic_ptr;
     int * __restrict__ varlen_batch_idx_ptr; // virtual -> actual

hopper/flash_api.cpp

Lines changed: 4 additions & 4 deletions

@@ -656,8 +656,8 @@ mha_fwd_get_scheduler_metadata(
     tile_count_semaphore = torch::empty(
         {int(scheduler_needs_semaphore) + tile_count_semaphore_offset},
         opts.dtype(torch::kInt32));
-    // ORDER: {num_m_blocks, num_splits_dynamic, varlen_batch_idx, num_nheads_in_l2}
-    params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr<int>() : nullptr;
+    // ORDER: {prepare_seqlen_q, num_splits_dynamic, varlen_batch_idx, num_nheads_in_l2}
+    params.prepare_seqlen_q_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr<int>() : nullptr;
     params.num_splits_dynamic_ptr = use_prepare_varlen && use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + b_rounded : nullptr;
     params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr<int>() + sort_offset : nullptr;
     // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr<int>() + head_swizzle_offset : nullptr;

@@ -1058,8 +1058,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
     if (scheduler_needs_semaphore && !use_prepare_varlen) {
         tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing
     }
-    // ORDER: {num_m_blocks, num_splits_dynamic, varlen_batch_idx, num_nheads_in_l2}
-    params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr<int>() : nullptr;
+    // ORDER: {prepare_seqlen_q, num_splits_dynamic, varlen_batch_idx, num_nheads_in_l2}
+    params.prepare_seqlen_q_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr<int>() : nullptr;
     params.num_splits_dynamic_ptr = use_prepare_varlen && use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + b_rounded : nullptr;
     params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr<int>() + sort_offset : nullptr;
     // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr<int>() + head_swizzle_offset : nullptr;

hopper/flash_fwd_combine_kernel.h

Lines changed: 19 additions & 10 deletions

@@ -209,10 +209,12 @@ class FlashAttnFwdCombine {
         int seqlen_q;
         int total_q;
         int num_heads;
+        int num_heads_kv;
         int dv;
+        bool pack_gqa;
         int const* cu_seqlens_q;
         int const* seqused_q;
-        int const* varlen_batch_idx_ptr;
+        int const* prepare_seqlen_q_ptr;
     };

     struct StaticTileScheduler {

@@ -257,12 +259,14 @@ class FlashAttnFwdCombine {
     };

     struct Params {
-        int b;
-        int num_heads;
+        int const b;
+        int const num_heads;
+        int const num_heads_kv;
+        bool const pack_gqa;
         int const* const cu_seqlens_q;
         int const* const seqused_q;
+        int const* const prepare_seqlen_q_ptr;
         SchedulingAlgo algo;
-        int const* const varlen_batch_idx_ptr = nullptr;
     };

     SharedStorage& shared_storage;

@@ -286,10 +290,12 @@ class FlashAttnFwdCombine {
         return {
             args.b,
             args.num_heads,
+            args.num_heads_kv,
+            args.pack_gqa,
             args.cu_seqlens_q,
             args.seqused_q,
-            choose_scheduling_algo(args),
-            args.varlen_batch_idx_ptr
+            args.prepare_seqlen_q_ptr,
+            choose_scheduling_algo(args)
         };
     }

@@ -315,7 +321,6 @@ class FlashAttnFwdCombine {
     }

     CUTE_DEVICE BlockCoord get_block_coord_linearized_m_and_batch(Params const& params) {
-        int num_heads = params.num_heads;
         int curr_tile_id = blockIdx.x;

         // Scan through the batches find the batch that contains the current

@@ -338,9 +343,13 @@ class FlashAttnFwdCombine {

         auto get_num_m_blocks = [&](int bidb) {
             if (bidb >= params.b) return 0;
-            int actual_bidb = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[bidb] : bidb;
-            flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{actual_bidb, 0, params.cu_seqlens_q, params.seqused_q};
-            return cute::ceil_div(seqlen_info.seqlen * num_heads, Int<kBlockM>{}());
+            if (params.prepare_seqlen_q_ptr) {
+                int length = params.prepare_seqlen_q_ptr[bidb] * (!params.pack_gqa ? params.num_heads : params.num_heads_kv);
+                return cute::ceil_div(length, Int<kBlockM>{});
+            } else {
+                flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, 0, params.cu_seqlens_q, params.seqused_q};
+                return cute::ceil_div(seqlen_info.seqlen * params.num_heads, Int<kBlockM>{});
+            }
         };

         // Cumulative number of blocks for the next 31 batches
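
One detail worth noting in the new get_num_m_blocks path: the prepare kernel stores seqlen_q already multiplied by qhead_per_khead when pack_gqa is on, so the combine side multiplies by num_heads_kv in that case and by num_heads otherwise; either way the product covers seqlen_q * num_heads rows. A small host-side check of that arithmetic with made-up sizes (seqlen_q = 100, 8 query heads, 2 KV heads, kBlockM = 128; none of these values come from the commit):

// Standalone check (hypothetical sizes) that both GQA-packing modes yield the
// same number of M blocks for the combine scheduler.
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int kBlockM = 128;
    const int seqlen_q = 100, num_heads = 8, num_heads_kv = 2;
    const int qhead_per_khead = num_heads / num_heads_kv;    // 4

    // pack_gqa = false: the prepare kernel stores the raw seqlen_q,
    // and the combine side multiplies by num_heads.
    int stored_plain  = seqlen_q;
    int blocks_plain  = ceil_div(stored_plain * num_heads, kBlockM);

    // pack_gqa = true: the prepare kernel stores seqlen_q * qhead_per_khead,
    // and the combine side multiplies by num_heads_kv instead.
    int stored_packed = seqlen_q * qhead_per_khead;
    int blocks_packed = ceil_div(stored_packed * num_heads_kv, kBlockM);

    // Both cover seqlen_q * num_heads = 800 rows -> ceil(800 / 128) = 7 blocks.
    printf("plain: %d, packed: %d\n", blocks_plain, blocks_packed);
    return 0;
}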

hopper/flash_fwd_combine_launch_template.h

Lines changed: 2 additions & 3 deletions

@@ -40,9 +40,8 @@ void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool e
    };

    typename CombineKernel::SchedulerArguments scheduler_args {
-        params.b, params.seqlen_q, params.total_q, params.h, params.dv,
-        params.cu_seqlens_q, params.seqused_q,
-        params.varlen_batch_idx_ptr
+        params.b, params.seqlen_q, params.total_q, params.h, params.h_k, params.dv, params.pack_gqa,
+        params.cu_seqlens_q, params.seqused_q, params.prepare_seqlen_q_ptr
    };

    typename CombineKernel::Params kernel_params = {

hopper/flash_fwd_launch_template.h

Lines changed: 1 addition & 1 deletion

@@ -157,7 +157,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
        params.seqlen_k, params.d, params.dv, sizeof(Element),
        params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,
        params.num_splits_dynamic_ptr,
-        params.num_m_blocks_ptr,
+        params.prepare_seqlen_q_ptr,
        params.varlen_batch_idx_ptr,
        params.num_nheads_in_l2_ptr
    };

hopper/flash_prepare_scheduler.cu

Lines changed: 15 additions & 23 deletions

@@ -47,7 +47,7 @@ __global__ void prepare_varlen_num_blocks_kernel(
    int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static,
    cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod,
    int* const tile_count_semaphore,
-   int* const num_m_blocks_ptr,
+   int* const prepare_seqlen_q_ptr,
    int* const num_splits_dynamic_ptr,
    int* const varlen_batch_idx_ptr,
    // int* const num_n_blocks_ptr,

@@ -78,7 +78,7 @@

    int lane = threadIdx.x % cutlass::NumThreadsPerWarp;

-   auto get_num_m_blocks = [&](int batch_idx) {
+   auto get_num_m_blocks_and_seqlen = [&](int batch_idx) {
        int seqlen;
        if (seqused_q) {
            seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0;

@@ -91,7 +91,8 @@
        }
        if(packgqa) { seqlen *= qhead_per_khead; }
        return batch_idx < num_batch && lane < kNumBatchPerWarp
-           ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0;
+           ? cute::make_tuple(blockm_divmod.div(seqlen + blockm_divmod.divisor - 1), seqlen)
+           : cute::make_tuple(0, 0);
    };

    auto get_num_n_blocks = [&](int batch_idx) {

@@ -124,7 +125,10 @@
    int batch_cta_idx_offset = int(blockIdx.x) * 992;
    int bidb_start = batch_cta_idx_offset + kNumBatchPerWarp * warp_idx;
    int batch_idx = lane + bidb_start;
-   int num_m_blocks = get_num_m_blocks(batch_idx);
+   // int num_m_blocks = get_num_m_blocks(batch_idx);
+   auto seqlen_q_info = get_num_m_blocks_and_seqlen(batch_idx);
+   int num_m_blocks = cute::get<0>(seqlen_q_info);
+   int seqlen_q = cute::get<1>(seqlen_q_info);
    int num_n_blocks = get_num_n_blocks(batch_idx);

    auto get_nheads_in_l2 = [&](int n_blocks) {

@@ -165,47 +169,35 @@
            num_n_blocks = INT_MIN; // sort last
        } else if (is_causal) {
            // sort by shortest member to process
-           num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor;
+           num_n_blocks = num_n_blocks * blockn_divmod.divisor - seqlen_q;
        }
        int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread
-       batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx);
-
-       // if (threadIdx.x == 0) {
-       //     printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n",
-       //            batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w);
-       // } __syncthreads();
+       batch_coords[0] = make_int4(num_n_blocks, seqlen_q, num_splits_dynamic, batch_idx);

        // Sort batches by num_n_blocks in descending order
        BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp<int4>());

-       // if (threadIdx.x == 0) {
-       //     printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n",
-       //            batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w);
-       // } __syncthreads();
-
        if (is_causal) {
            // reset value to num_n_blocks
-           batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor);
+           batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y);
        }

        // When sorting, we re-index some metadata by 'virtual batch index'
        // and also store the vbidx -> bidx mapping.
        // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx]
        // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx]
-       // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx]
+       // 3. prepare_seqlen_q_ptr: virtual_batch_idx -> seqlen_q[batch_idx] * (packgqa ? qhead_per_khead : 1)
        // 4. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx
        batch_idx = batch_cta_idx_offset + threadIdx.x;
        if (batch_idx < num_batch && threadIdx.x < 992) {
-           // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1);
            if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(batch_coords[0].x, 1)); }
-           num_m_blocks_ptr[batch_idx] = batch_coords[0].y;
+           prepare_seqlen_q_ptr[batch_idx] = batch_coords[0].y;
            if(num_splits_dynamic_ptr) { num_splits_dynamic_ptr[batch_idx] = batch_coords[0].z; }
            varlen_batch_idx_ptr[batch_idx] = batch_coords[0].w;
        }
    } else {
        if (batch_idx < num_batch && lane < kNumBatchPerWarp) {
-           num_m_blocks_ptr[batch_idx] = num_m_blocks;
-           // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1);
+           prepare_seqlen_q_ptr[batch_idx] = seqlen_q;
            if(num_splits_dynamic_ptr) { num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; }
            if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); }
            // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic);

@@ -236,7 +228,7 @@ void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bo
        params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits,
        cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN),
        params.tile_count_semaphore,
-       params.num_m_blocks_ptr,
+       params.prepare_seqlen_q_ptr,
        params.num_splits_dynamic_ptr,
        params.varlen_batch_idx_ptr,
        // params.num_n_blocks_ptr,
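
Because batch_coords[0].y now carries seqlen_q rather than num_m_blocks, the causal sort key subtracts the exact query length, and the post-sort recovery adds it back before dividing by blockN; both the old and the new form recover num_n_blocks exactly. A small worked check with assumed tile sizes (blockM = blockN = 128, seqlen_q = 300, seqlen_k = 1000; the numbers are illustrative, not from the commit):

// Worked check of the old vs. new causal sort key and its exact recovery.
#include <cassert>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int blockM = 128, blockN = 128;
    const int seqlen_q = 300, seqlen_k = 1000;
    const int num_m_blocks = ceil_div(seqlen_q, blockM);   // 3
    const int num_n_blocks = ceil_div(seqlen_k, blockN);   // 8

    // Old key: K blocks minus Q blocks, both converted to element units.
    const int key_old = num_n_blocks * blockN - num_m_blocks * blockM;  // 640
    // New key: K blocks in elements minus the exact seqlen_q stored in .y.
    const int key_new = num_n_blocks * blockN - seqlen_q;               // 724

    // After BlockMergeSort, both variants recover num_n_blocks exactly.
    assert((key_old + num_m_blocks * blockM) / blockN == num_n_blocks);
    assert((key_new + seqlen_q) / blockN == num_n_blocks);
    return 0;
}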

hopper/tile_scheduler.hpp

Lines changed: 4 additions & 5 deletions

@@ -25,7 +25,7 @@ struct TileSchedulerArguments {
    int const* const cu_seqlens = nullptr;
    int const* const seqused = nullptr;
    int const* const num_splits_dynamic_ptr = nullptr;
-   int const* const num_m_blocks_ptr = nullptr;
+   int const* const prepare_seqlen_q_ptr = nullptr;
    int const* const varlen_batch_idx_ptr = nullptr;
    // int const* const num_n_blocks_ptr = nullptr;
    int const* const num_nheads_in_l2_ptr = nullptr;

@@ -385,7 +385,7 @@ class VarlenDynamicPersistentTileScheduler {
        int const* const cu_seqlens;
        int const* const seqused;
        int const* const num_splits_dynamic_ptr;
-       int const* const num_m_blocks_ptr;
+       int const* const prepare_seqlen_q_ptr;
        int const* const varlen_batch_idx_ptr;
        // int const* const num_n_blocks_ptr;
        int const* const num_nheads_in_l2_ptr;

@@ -408,7 +408,7 @@
        cutlass::FastDivmod(!Split ? 1 : args.num_splits),
        args.tile_count_semaphore, args.cu_seqlens, args.seqused,
        args.num_splits_dynamic_ptr,
-       args.num_m_blocks_ptr,
+       args.prepare_seqlen_q_ptr,
        args.varlen_batch_idx_ptr,
        // aras.num_n_blocks_ptr,
        args.num_nheads_in_l2_ptr};

@@ -470,7 +470,7 @@
        int batch_idx = lane + bidb_start;
        if constexpr (Prepared) {
            return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1
-               ? params.num_m_blocks_ptr[batch_idx] : 0;
+               ? cute::ceil_div(params.prepare_seqlen_q_ptr[batch_idx], kBlockM) : 0;
        } else {
            int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead);
            if (seqlen > kBlockM) {

@@ -487,7 +487,6 @@
            }
            return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1
                ? cute::ceil_div(seqlen, kBlockM) : 0;
-               // ? params.num_m_blocks_ptr[batch_idx] : 0;
        }
    };

