@@ -205,6 +205,7 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params,
   const uint_fastdiv group_size_fastdiv(group_size);
   constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16;
   constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16;
+
   uint32_t cta_tile_q_p = 0;
   int64_t unpacked_qo_len = qo_len * group_size;
   if (unpacked_qo_len > 64 && HEAD_DIM_VO < 256) {
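
This first hunk only inserts a blank line ahead of the prefill tile-size selection; that selection then feeds the new `DISPATCH_CTA_TILE_Q` dispatch in the hunk below. For orientation, here is a minimal standalone sketch of the selection idea. The `unpacked_qo_len > 64 && HEAD_DIM_VO < 256` guard comes from the hunk above; the helper name `pick_cta_tile_q_p` and the fallback tiers are illustrative assumptions, not the file's actual logic.

```cpp
#include <cstdint>

// Hypothetical sketch: choose the prefill CTA tile size from the "unpacked"
// query length (qo_len * group_size). Long prefills amortize best with a wide
// 128-row tile; short, decode-like batches waste fewer padded rows on a
// narrower tile.
inline uint32_t pick_cta_tile_q_p(int64_t unpacked_qo_len, uint32_t head_dim_vo) {
  if (unpacked_qo_len > 64 && head_dim_vo < 256) return 128;  // guard from the diff
  return unpacked_qo_len > 16 ? 64 : 16;  // illustrative fallback tiers
}
```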
@@ -268,183 +269,185 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params,
        NUM_MMA_Q_D * NUM_WARPS_Q_D) /
       (2 * NUM_WARPS_KV_D);
 
-  // DISPATCH_CTA_TILE_Q(cta_tile_q_p, CTA_TILE_Q_P, {
-  constexpr size_t CTA_TILE_Q_P = 128;
-  constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P);
-  constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P);
-  constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P);
-
-  using DTypeQKAccum_P =
-      typename std::conditional<USE_FP16_QK_REDUCTION && std::is_same_v<DTypeQ_P, half>, half,
-                                float>::type;
-
-  // we expect each SM to execute two threadblocks
-  // TODO(Zihao): fix the following computation
-  const int num_ctas_per_sm_p =
-      max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1;
-  const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p;
-
-  constexpr uint32_t max_num_mma_kv_reg_p =
-      (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama &&
-       !USE_FP16_QK_REDUCTION)
-          ? 2
-          : (8 / NUM_MMA_Q_P);
-  // TODO(Zihao): fix the following computation
-  const uint32_t max_num_mma_kv_smem_p =
-      (max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) -
-       NUM_MMA_Q_P * NUM_WARPS_Q_P) /
-      (2 * NUM_WARPS_KV_P);
-
-  // control NUM_MMA_KV for maximum warp occupancy
-  DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, {
-    using KTraits_P = KernelTraits<MASK_MODE_P, CTA_TILE_Q_P, NUM_MMA_Q_P, NUM_MMA_KV_P,
-                                   NUM_MMA_D_QK, NUM_MMA_D_VO, NUM_WARPS_Q_P, NUM_WARPS_KV_P,
-                                   POS_ENCODING_MODE, DTypeQ_P, DTypeKV_P, DTypeO_P, DTypeQKAccum_P,
-                                   typename PrefillParams::IdType, PrefillAttentionVariant>;
-
-    if constexpr (KTraits_P::IsInvalid()) {
-      // Invalid configuration, skip
-      std::ostringstream err_msg;
-      err_msg << "FlashInfer Internal Error: Invalid configuration: NUM_MMA_Q=" << NUM_MMA_Q_P
-              << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
-              << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P
-              << " NUM_WARPS_KV=" << NUM_WARPS_KV_P
-              << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
-                 " and report the issue to the developers.";
-      FLASHINFER_ERROR(err_msg.str());
-    } else {
-      // Decode stuff
-      // TODO: Is there a way to avoid this nested dispatch?
-      DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, {
-        using KTraits_D =
-            KernelTraits<MASK_MODE_D, CTA_TILE_Q_D, NUM_MMA_Q_D, NUM_MMA_KV_D, NUM_MMA_D_QK,
-                         NUM_MMA_D_VO, NUM_WARPS_Q_D, NUM_WARPS_KV_D, POS_ENCODING_MODE, DTypeQ_D,
-                         DTypeKV_D, DTypeO_D, DTypeQKAccum_D, typename DecodeParams::IdType,
-                         DecodeAttentionVariant>;
-        if constexpr (KTraits_D::IsInvalid()) {
-          // Invalid configuration, skip
-          std::ostringstream err_msg;
-          err_msg << "FlashInfer Internal Error: Invalid configuration: NUM_MMA_Q=" << NUM_MMA_Q_D
-                  << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
-                  << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D
-                  << " NUM_WARPS_KV=" << NUM_WARPS_KV_D
-                  << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
-                     " and report the issue to the developers.";
-          FLASHINFER_ERROR(err_msg.str());
-        } else {
-          // End decode stuff
-          constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE;
-          size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage);
-          size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage);
-
-          auto kernel =
-              PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, true, PrefillParams, DecodeParams>;
-          // Prefill: decide num_splits for split-kv
-          int num_blocks_per_sm = 0;
-          int num_sm = 0;
-          FLASHINFER_CUDA_CALL(
-              cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
-          FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-              &num_blocks_per_sm, kernel, num_threads_p, smem_size_p));
-          uint32_t max_num_kv_chunks =
-              (num_blocks_per_sm * num_sm) /
-              (num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q));
-          uint32_t num_chunks;
-          if (max_num_kv_chunks > 0) {
-            uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
-            num_chunks = ceil_div(kv_len, chunk_size);
+  DISPATCH_CTA_TILE_Q(cta_tile_q_p, CTA_TILE_Q_P, {
+    constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P);
+    constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P);
+    constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P);
+
+    using DTypeQKAccum_P =
+        typename std::conditional<USE_FP16_QK_REDUCTION && std::is_same_v<DTypeQ_P, half>, half,
+                                  float>::type;
+
+    // we expect each SM to execute two threadblocks
+    // TODO(Zihao): fix the following computation
+    const int num_ctas_per_sm_p =
+        max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1;
+    const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p;
+
+    constexpr uint32_t max_num_mma_kv_reg_p =
+        (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 &&
+         POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION)
+            ? 2
+            : (8 / NUM_MMA_Q_P);
+    // TODO(Zihao): fix the following computation
+    const uint32_t max_num_mma_kv_smem_p =
+        (max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) -
+         NUM_MMA_Q_P * NUM_WARPS_Q_P) /
+        (2 * NUM_WARPS_KV_P);
+
+    // control NUM_MMA_KV for maximum warp occupancy
+    DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, {
+      using KTraits_P =
+          KernelTraits<MASK_MODE_P, CTA_TILE_Q_P, NUM_MMA_Q_P, NUM_MMA_KV_P, NUM_MMA_D_QK,
+                       NUM_MMA_D_VO, NUM_WARPS_Q_P, NUM_WARPS_KV_P, POS_ENCODING_MODE, DTypeQ_P,
+                       DTypeKV_P, DTypeO_P, DTypeQKAccum_P, typename PrefillParams::IdType,
+                       PrefillAttentionVariant>;
+
+      if constexpr (KTraits_P::IsInvalid()) {
+        // Invalid configuration, skip
+        std::ostringstream err_msg;
+        err_msg << "FlashInfer Internal Error: Invalid configuration: NUM_MMA_Q=" << NUM_MMA_Q_P
+                << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
+                << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P
+                << " NUM_WARPS_KV=" << NUM_WARPS_KV_P
+                << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
+                   " and report the issue to the developers.";
+        FLASHINFER_ERROR(err_msg.str());
+      } else {
+        // Decode stuff
+        // TODO: Is there a way to avoid this nested dispatch?
+        DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, {
+          using KTraits_D =
+              KernelTraits<MASK_MODE_D, CTA_TILE_Q_D, NUM_MMA_Q_D, NUM_MMA_KV_D, NUM_MMA_D_QK,
+                           NUM_MMA_D_VO, NUM_WARPS_Q_D, NUM_WARPS_KV_D, POS_ENCODING_MODE, DTypeQ_D,
+                           DTypeKV_D, DTypeO_D, DTypeQKAccum_D, typename DecodeParams::IdType,
+                           DecodeAttentionVariant>;
+          if constexpr (KTraits_D::IsInvalid()) {
+            // Invalid configuration, skip
+            std::ostringstream err_msg;
+            err_msg
+                << "FlashInfer Internal Error: Invalid configuration: NUM_MMA_Q=" << NUM_MMA_Q_D
+                << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
+                << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D
+                << " NUM_WARPS_KV=" << NUM_WARPS_KV_D
+                << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
+                   " and report the issue to the developers.";
+            FLASHINFER_ERROR(err_msg.str());
           } else {
-            num_chunks = 0;
-          }
+            // End decode stuff
+            constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE;
+            size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage);
+            size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage);
 
-          // Setup new prefill params if (not) split
-          auto o_p = prefill_params.o;
-          auto lse_p = prefill_params.lse;
-          float* tmp_lse = (float*)(tmp_p + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO);
-          if (num_chunks <= 1 || tmp_p == nullptr) {
-            // Enough parallelism, do not split-kv
-            prefill_params.partition_kv = 0;
-            kernel = PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, false, PrefillParams,
-                                                DecodeParams>;
-          } else {
-            // Use cooperative groups to increase occupancy
-            prefill_params.partition_kv = num_chunks;
-            prefill_params.o = tmp_p;
-            prefill_params.lse = tmp_lse;
-            kernel =
+            auto kernel =
                 PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, true, PrefillParams, DecodeParams>;
-          }
+            // Prefill: decide num_splits for split-kv
+            int num_blocks_per_sm = 0;
+            int num_sm = 0;
+            FLASHINFER_CUDA_CALL(
+                cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
+            FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                &num_blocks_per_sm, kernel, num_threads_p, smem_size_p));
+            uint32_t max_num_kv_chunks =
+                (num_blocks_per_sm * num_sm) /
+                (num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q));
+            uint32_t num_chunks;
+            if (max_num_kv_chunks > 0) {
+              uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
+              num_chunks = ceil_div(kv_len, chunk_size);
+            } else {
+              num_chunks = 0;
+            }
 
-          // Setup new decode params if (not) split
-          auto o_d = decode_params.o;
-          auto lse_d = decode_params.lse;
-          if (tmp_v == nullptr) {
-            // do not partition kv
-            decode_params.partition_kv = false;
-          } else {
-            decode_params.partition_kv = true;
-            decode_params.o = tmp_v;
-            decode_params.lse = tmp_s;
-          }
-          uint32_t xsize = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q);
-          int nblks_p(xsize * (prefill_params.partition_kv ? prefill_params.partition_kv : 1) *
-                      num_kv_heads);
-          int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P);
-
-          int nblks_d(padded_batch_size_d * 1 * num_kv_heads);
-          int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D);
-
-          // ******* Select final combined sizes here *******/
-          size_t smem_size = max(smem_size_p, smem_size_d);
-          int nblks = nblks_p + nblks_d;
-          int nthrs = max(nthrs_p, nthrs_d);
-
-          // printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d,
-          // smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d,
-          // nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d, nthrs);
-          // ************************************************/
-
-          static int* tbAssign = nullptr;
-          if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2));
-          cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2));
-
-          // Setup kernel arguments
-          void* args[] = {(void*)&xsize, (void*)&prefill_params, (void*)&decode_params,
-                          (void*)&tbAssign};
-          FLASHINFER_CUDA_CALL(
-              cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-          // Launch kernel
-          FLASHINFER_CUDA_CALL(
-              cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
-
-          // Post-kernel stuff for split-kv prefill
-          if (!(num_chunks <= 1 || tmp_p == nullptr)) {
-            if constexpr (PrefillAttentionVariant::use_softmax) {
-              FLASHINFER_CUDA_CALL(MergeStates(tmp_p, tmp_lse, o_p, lse_p, num_chunks, qo_len,
-                                               num_qo_heads, HEAD_DIM_VO, stream));
+            // Setup new prefill params if (not) split
+            auto o_p = prefill_params.o;
+            auto lse_p = prefill_params.lse;
+            float* tmp_lse = (float*)(tmp_p + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO);
+            if (num_chunks <= 1 || tmp_p == nullptr) {
+              // Enough parallelism, do not split-kv
+              prefill_params.partition_kv = 0;
+              kernel = PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, false, PrefillParams,
+                                                  DecodeParams>;
             } else {
-              FLASHINFER_CUDA_CALL(
-                  AttentionSum(tmp_p, o_p, num_chunks, qo_len, num_qo_heads, HEAD_DIM_VO, stream));
+              // Use cooperative groups to increase occupancy
+              prefill_params.partition_kv = num_chunks;
+              prefill_params.o = tmp_p;
+              prefill_params.lse = tmp_lse;
+              kernel = PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, true, PrefillParams,
+                                                  DecodeParams>;
             }
-          }
-          // Post-kernel stuff for split-kv decode
-          if (tmp_v != nullptr) {
-            if constexpr (DecodeAttentionVariant::use_softmax) {
-              FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
-                  tmp_v, tmp_s, decode_params.merge_indptr, o_d, lse_d,
-                  decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads,
-                  HEAD_DIM_VO, stream));
+
+            // Setup new decode params if (not) split
+            auto o_d = decode_params.o;
+            auto lse_d = decode_params.lse;
+            if (tmp_v == nullptr) {
+              // do not partition kv
+              decode_params.partition_kv = false;
             } else {
-              FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
-                  tmp_v, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows,
-                  decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream));
+              decode_params.partition_kv = true;
+              decode_params.o = tmp_v;
+              decode_params.lse = tmp_s;
+            }
+            uint32_t xsize = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q);
+            int nblks_p(xsize * (prefill_params.partition_kv ? prefill_params.partition_kv : 1) *
+                        num_kv_heads);
+            int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P);
+
+            int nblks_d(padded_batch_size_d * 1 * num_kv_heads);
+            int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D);
+
+            // ******* Select final combined sizes here *******/
+            size_t smem_size = max(smem_size_p, smem_size_d);
+            int nblks = nblks_p + nblks_d;
+            int nthrs = max(nthrs_p, nthrs_d);
+
+            // printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d,
+            // smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d,
+            // nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d,
+            // nthrs);
+            // ************************************************/
+
+            static int* tbAssign = nullptr;
+            if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2));
+            cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2));
+
+            // Setup kernel arguments
+            void* args[] = {(void*)&xsize, (void*)&prefill_params, (void*)&decode_params,
+                            (void*)&tbAssign};
+            FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
+                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+            // Launch kernel
+            FLASHINFER_CUDA_CALL(
+                cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+
+            // Post-kernel stuff for split-kv prefill
+            if (!(num_chunks <= 1 || tmp_p == nullptr)) {
+              if constexpr (PrefillAttentionVariant::use_softmax) {
+                FLASHINFER_CUDA_CALL(MergeStates(tmp_p, tmp_lse, o_p, lse_p, num_chunks, qo_len,
+                                                 num_qo_heads, HEAD_DIM_VO, stream));
+              } else {
+                FLASHINFER_CUDA_CALL(AttentionSum(tmp_p, o_p, num_chunks, qo_len, num_qo_heads,
+                                                  HEAD_DIM_VO, stream));
+              }
+            }
+            // Post-kernel stuff for split-kv decode
+            if (tmp_v != nullptr) {
+              if constexpr (DecodeAttentionVariant::use_softmax) {
+                FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
+                    tmp_v, tmp_s, decode_params.merge_indptr, o_d, lse_d,
+                    decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads,
+                    HEAD_DIM_VO, stream));
+              } else {
+                FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
+                    tmp_v, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows,
+                    decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream));
+              }
             }
           }
-        }
-      });
-    }
+        });
+      }
+    });
   });
-  // });
   return cudaSuccess;
 }
 
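
The substantive change in this commit: the hardcoded `constexpr size_t CTA_TILE_Q_P = 128;` (with the macro call scaffolding commented out around it) becomes a live `DISPATCH_CTA_TILE_Q` invocation, so the prefill path is compiled for whichever tile size `cta_tile_q_p` selects at runtime rather than for 128 only; the configuration block simply moves one indentation level deeper inside the macro body. As a sketch of how such a runtime-to-compile-time dispatch macro is commonly structured (the exact case set and error handling here are assumptions, not necessarily FlashInfer's macro):

```cpp
// Hypothetical expansion pattern: map a runtime tile size onto a constexpr
// constant that can parameterize KernelTraits as a template argument.
#define DISPATCH_CTA_TILE_Q(runtime_val, CONST_NAME, ...) \
  switch (runtime_val) {                                  \
    case 128: {                                           \
      constexpr uint32_t CONST_NAME = 128;                \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    case 64: {                                            \
      constexpr uint32_t CONST_NAME = 64;                 \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    case 16: {                                            \
      constexpr uint32_t CONST_NAME = 16;                 \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    default:                                              \
      FLASHINFER_ERROR("unsupported CTA_TILE_Q");         \
  }
```

Each case instantiates the enclosed block with a different `constexpr` value, which is why `NUM_WARPS_Q_P`, `NUM_MMA_Q_P`, and `KTraits_P` can all remain compile-time quantities.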
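One step worth unpacking is the split-kv sizing the re-indented block performs right after querying occupancy: `cudaOccupancyMaxActiveBlocksPerMultiprocessor` reports how many blocks of the fused kernel fit per SM, the grid already owes one block per (KV head, query tile) pair, and only the leftover capacity is spent splitting the KV cache, with chunk sizes floored at 256 tokens. A self-contained sketch of that arithmetic (the helper name `num_kv_chunks` is ours; `ceil_div` is assumed to be ordinary round-up division, as used throughout the diff):

```cpp
#include <algorithm>
#include <cstdint>

inline uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Mirrors the diff's sizing: leftover grid capacity after covering all
// (kv_head, query-tile) pairs becomes KV chunks of at least 256 tokens.
inline uint32_t num_kv_chunks(uint32_t resident_blocks,  // num_blocks_per_sm * num_sm
                              uint32_t num_kv_heads, uint32_t qo_tiles, uint32_t kv_len) {
  uint32_t max_chunks = resident_blocks / (num_kv_heads * qo_tiles);
  if (max_chunks == 0) return 0;  // grid already saturated: do not split
  uint32_t chunk_size = std::max(ceil_div(kv_len, max_chunks), 256u);
  return ceil_div(kv_len, chunk_size);
}
```

For example, 132 SMs with 2 resident blocks each, 8 KV heads, and 4 query tiles leave 264 / 32 = 8 chunks; a 16384-token KV cache then splits into 8 chunks of 2048 tokens. When a split happens, per-chunk partial outputs land in `tmp_p`/`tmp_v` and are reduced afterwards by `MergeStates` / `VariableLengthMergeStates` (softmax-aware, LSE-weighted merges), which is why the launch path saves the original `o` and `lse` pointers before overwriting the params.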