
Commit 55a6668

AKKamath and Aditya K Kamath authored
bugfix: Fix POD JIT bugs (flashinfer-ai#971)
Removes AOT header files from the JIT compilation flow and corrects some typos so that JIT compilation of POD no longer fails.

---------

Co-authored-by: Aditya K Kamath <[email protected]>
1 parent 61e049a commit 55a6668

File tree

4 files changed: 172 additions, 173 deletions

- csrc/pod.cu
- csrc/pod_customize_config.jinja
- csrc/pod_kernel_inst.jinja
- include/flashinfer/attention/pod.cuh


csrc/pod.cu

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@
 #include <flashinfer/pos_enc.cuh>
 #include <optional>
 
-#include "aot_extension_utils.h"
 #include "pod_config.inc"
 #include "pytorch_conversion_utils.h"
 #include "pytorch_extension_utils.h"

csrc/pod_customize_config.jinja

Lines changed: 2 additions & 3 deletions
@@ -35,9 +35,8 @@ using DecodeParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
 
 #define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK,               \
                          USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, ...) \
-  DISPATCH_mask_mode(mask_mode_p, MASK_MODE_P, [&] {                                           \
-    return DISPATCH_mask_mode(mask_mode_d, MASK_MODE_D, [&] {                                  \
+  DISPATCH_MASK_MODE(mask_mode_p, MASK_MODE_P, {                                               \
+    DISPATCH_MASK_MODE(mask_mode_d, MASK_MODE_D, {                                             \
       __VA_ARGS__();                                                                           \
-      return true;                                                                             \
     });                                                                                        \
   });
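The old macro body called the AOT-style DISPATCH_mask_mode helper, which takes a lambda and expects it to end with return true;, whereas the JIT-generated config headers provide a statement-form DISPATCH_MASK_MODE that simply expands its body once per mask mode. A minimal sketch of what such a statement-form dispatch macro can look like, assuming MaskMode enumerators kNone and kCausal; this is an illustration, not FlashInfer's actual definition:

// Sketch only: the real DISPATCH_MASK_MODE comes from the JIT-generated
// headers; the enumerators and error path below are assumptions.
#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...)   \
  switch (mask_mode) {                                  \
    case MaskMode::kNone: {                             \
      constexpr MaskMode MASK_MODE = MaskMode::kNone;   \
      __VA_ARGS__                                       \
      break;                                            \
    }                                                   \
    case MaskMode::kCausal: {                           \
      constexpr MaskMode MASK_MODE = MaskMode::kCausal; \
      __VA_ARGS__                                       \
      break;                                            \
    }                                                   \
    default:                                            \
      FLASHINFER_ERROR("Unsupported mask mode");        \
  }

With this form, the __VA_ARGS__(); call in DISPATCH_context runs inside each case and no return value is needed, which is why the return true; line is deleted.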

csrc/pod_kernel_inst.jinja

Lines changed: 0 additions & 2 deletions
@@ -10,8 +10,6 @@
 
 #include "pytorch_conversion_utils.h"
 #include "pytorch_extension_utils.h"
-#include "aot_default_additional_params.h"
-#include "aot_extension_utils.h"
 
 #include "pod_config.inc"
 

include/flashinfer/attention/pod.cuh

Lines changed: 170 additions & 167 deletions
@@ -205,6 +205,7 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params,
   const uint_fastdiv group_size_fastdiv(group_size);
   constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16;
   constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16;
+
   uint32_t cta_tile_q_p = 0;
   int64_t unpacked_qo_len = qo_len * group_size;
   if (unpacked_qo_len > 64 && HEAD_DIM_VO < 256) {
@@ -268,183 +269,185 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params,
        NUM_MMA_Q_D * NUM_WARPS_Q_D) /
       (2 * NUM_WARPS_KV_D);
 
-  // DISPATCH_CTA_TILE_Q(cta_tile_q_p, CTA_TILE_Q_P, {
-  constexpr size_t CTA_TILE_Q_P = 128;
-  constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P);
-  constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P);
-  constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P);
-
-  using DTypeQKAccum_P =
-      typename std::conditional<USE_FP16_QK_REDUCTION && std::is_same_v<DTypeQ_P, half>, half,
-                                float>::type;
-
-  // we expect each sm execute two threadblocks
-  // TODO(Zihao): fix the following computation
-  const int num_ctas_per_sm_p =
-      max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1;
-  const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p;
-
-  constexpr uint32_t max_num_mma_kv_reg_p =
-      (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama &&
-       !USE_FP16_QK_REDUCTION)
-          ? 2
-          : (8 / NUM_MMA_Q_P);
-  // TODO(Zihao): fix the following computation
-  const uint32_t max_num_mma_kv_smem_p =
-      (max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) -
-       NUM_MMA_Q_P * NUM_WARPS_Q_P) /
-      (2 * NUM_WARPS_KV_P);
-
-  // control NUM_MMA_KV for maximum warp occupancy
-  DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, {
-    using KTraits_P = KernelTraits<MASK_MODE_P, CTA_TILE_Q_P, NUM_MMA_Q_P, NUM_MMA_KV_P,
-                                   NUM_MMA_D_QK, NUM_MMA_D_VO, NUM_WARPS_Q_P, NUM_WARPS_KV_P,
-                                   POS_ENCODING_MODE, DTypeQ_P, DTypeKV_P, DTypeO_P, DTypeQKAccum_P,
-                                   typename PrefillParams::IdType, PrefillAttentionVariant>;
-
-    if constexpr (KTraits_P::IsInvalid()) {
-      // Invalid configuration, skip
-      std::ostringstream err_msg;
-      err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P
-              << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
-              << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P
-              << " NUM_WARPS_KV=" << NUM_WARPS_KV_P
-              << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
-                 " and report the issue to the developers.";
-      FLASHINFER_ERROR(err_msg.str());
-    } else {
-      // Decode stuff
-      // TODO: Is there a way to avoid this nested dispatch?
-      DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, {
-        using KTraits_D =
-            KernelTraits<MASK_MODE_D, CTA_TILE_Q_D, NUM_MMA_Q_D, NUM_MMA_KV_D, NUM_MMA_D_QK,
-                         NUM_MMA_D_VO, NUM_WARPS_Q_D, NUM_WARPS_KV_D, POS_ENCODING_MODE, DTypeQ_D,
-                         DTypeKV_D, DTypeO_D, DTypeQKAccum_D, typename DecodeParams::IdType,
-                         DecodeAttentionVariant>;
-        if constexpr (KTraits_D::IsInvalid()) {
-          // Invalid configuration, skip
-          std::ostringstream err_msg;
-          err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D
-                  << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
-                  << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D
-                  << " NUM_WARPS_KV=" << NUM_WARPS_KV_D
-                  << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
-                     " and report the issue to the developers.";
-          FLASHINFER_ERROR(err_msg.str());
-        } else {
-          // End decode stuff
-          constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE;
-          size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage);
-          size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage);
-
-          auto kernel =
-              PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, true, PrefillParams, DecodeParams>;
-          // Prefill: decide num_splits for split-kv
-          int num_blocks_per_sm = 0;
-          int num_sm = 0;
-          FLASHINFER_CUDA_CALL(
-              cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
-          FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-              &num_blocks_per_sm, kernel, num_threads_p, smem_size_p));
-          uint32_t max_num_kv_chunks =
-              (num_blocks_per_sm * num_sm) /
-              (num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q));
-          uint32_t num_chunks;
-          if (max_num_kv_chunks > 0) {
-            uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
-            num_chunks = ceil_div(kv_len, chunk_size);
-          } else {
-            num_chunks = 0;
-          }
-
-          // Setup new prefill params if (not) split
-          auto o_p = prefill_params.o;
-          auto lse_p = prefill_params.lse;
-          float* tmp_lse = (float*)(tmp_p + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO);
-          if (num_chunks <= 1 || tmp_p == nullptr) {
-            // Enough parallelism, do not split-kv
-            prefill_params.partition_kv = 0;
-            kernel = PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, false, PrefillParams,
-                                                DecodeParams>;
-          } else {
-            // Use cooperative groups to increase occupancy
-            prefill_params.partition_kv = num_chunks;
-            prefill_params.o = tmp_p;
-            prefill_params.lse = tmp_lse;
-            kernel =
-                PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, true, PrefillParams, DecodeParams>;
-          }
-
-          // Setup new decode params if (not) split
-          auto o_d = decode_params.o;
-          auto lse_d = decode_params.lse;
-          if (tmp_v == nullptr) {
-            // do not partition kv
-            decode_params.partition_kv = false;
-          } else {
-            decode_params.partition_kv = true;
-            decode_params.o = tmp_v;
-            decode_params.lse = tmp_s;
-          }
-          uint32_t xsize = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q);
-          int nblks_p(xsize * (prefill_params.partition_kv ? prefill_params.partition_kv : 1) *
-                      num_kv_heads);
-          int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P);
-
-          int nblks_d(padded_batch_size_d * 1 * num_kv_heads);
-          int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D);
-
-          // ******* Select final combined sizes here ******* /
-          size_t smem_size = max(smem_size_p, smem_size_d);
-          int nblks = nblks_p + nblks_d;
-          int nthrs = max(nthrs_p, nthrs_d);
-
-          // printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d,
-          // smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d,
-          // nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d, nthrs);
-          // ************************************************ /
-
-          static int* tbAssign = nullptr;
-          if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2));
-          cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2));
-
-          // Setup kernel arguments
-          void* args[] = {(void*)&xsize, (void*)&prefill_params, (void*)&decode_params,
-                          (void*)&tbAssign};
-          FLASHINFER_CUDA_CALL(
-              cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-          // Launch kernel
-          FLASHINFER_CUDA_CALL(
-              cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
-
-          // Post-kernel stuff for split-kv prefill
-          if (!(num_chunks <= 1 || tmp_p == nullptr)) {
-            if constexpr (PrefillAttentionVariant::use_softmax) {
-              FLASHINFER_CUDA_CALL(MergeStates(tmp_p, tmp_lse, o_p, lse_p, num_chunks, qo_len,
-                                               num_qo_heads, HEAD_DIM_VO, stream));
-            } else {
-              FLASHINFER_CUDA_CALL(
-                  AttentionSum(tmp_p, o_p, num_chunks, qo_len, num_qo_heads, HEAD_DIM_VO, stream));
-            }
-          }
-          // Post-kernel stuff for split-kv decode
-          if (tmp_v != nullptr) {
-            if constexpr (DecodeAttentionVariant::use_softmax) {
-              FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
-                  tmp_v, tmp_s, decode_params.merge_indptr, o_d, lse_d,
-                  decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads,
-                  HEAD_DIM_VO, stream));
-            } else {
-              FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
-                  tmp_v, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows,
-                  decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream));
-            }
-          }
-        }
-      });
-    }
-  });
-  //});
+  DISPATCH_CTA_TILE_Q(cta_tile_q_p, CTA_TILE_Q_P, {
+    constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P);
+    constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P);
+    constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P);
+
+    using DTypeQKAccum_P =
+        typename std::conditional<USE_FP16_QK_REDUCTION && std::is_same_v<DTypeQ_P, half>, half,
+                                  float>::type;
+
+    // we expect each sm execute two threadblocks
+    // TODO(Zihao): fix the following computation
+    const int num_ctas_per_sm_p =
+        max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1;
+    const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p;
+
+    constexpr uint32_t max_num_mma_kv_reg_p =
+        (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 &&
+         POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION)
+            ? 2
+            : (8 / NUM_MMA_Q_P);
+    // TODO(Zihao): fix the following computation
+    const uint32_t max_num_mma_kv_smem_p =
+        (max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) -
+         NUM_MMA_Q_P * NUM_WARPS_Q_P) /
+        (2 * NUM_WARPS_KV_P);
+
+    // control NUM_MMA_KV for maximum warp occupancy
+    DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, {
+      using KTraits_P =
+          KernelTraits<MASK_MODE_P, CTA_TILE_Q_P, NUM_MMA_Q_P, NUM_MMA_KV_P, NUM_MMA_D_QK,
+                       NUM_MMA_D_VO, NUM_WARPS_Q_P, NUM_WARPS_KV_P, POS_ENCODING_MODE, DTypeQ_P,
+                       DTypeKV_P, DTypeO_P, DTypeQKAccum_P, typename PrefillParams::IdType,
+                       PrefillAttentionVariant>;
+
+      if constexpr (KTraits_P::IsInvalid()) {
+        // Invalid configuration, skip
+        std::ostringstream err_msg;
+        err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P
+                << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
+                << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P
+                << " NUM_WARPS_KV=" << NUM_WARPS_KV_P
+                << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
+                   " and report the issue to the developers.";
+        FLASHINFER_ERROR(err_msg.str());
+      } else {
+        // Decode stuff
+        // TODO: Is there a way to avoid this nested dispatch?
+        DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, {
+          using KTraits_D =
+              KernelTraits<MASK_MODE_D, CTA_TILE_Q_D, NUM_MMA_Q_D, NUM_MMA_KV_D, NUM_MMA_D_QK,
+                           NUM_MMA_D_VO, NUM_WARPS_Q_D, NUM_WARPS_KV_D, POS_ENCODING_MODE, DTypeQ_D,
+                           DTypeKV_D, DTypeO_D, DTypeQKAccum_D, typename DecodeParams::IdType,
+                           DecodeAttentionVariant>;
+          if constexpr (KTraits_D::IsInvalid()) {
+            // Invalid configuration, skip
+            std::ostringstream err_msg;
+            err_msg
+                << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D
+                << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
+                << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D
+                << " NUM_WARPS_KV=" << NUM_WARPS_KV_D
+                << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
+                   " and report the issue to the developers.";
+            FLASHINFER_ERROR(err_msg.str());
+          } else {
+            // End decode stuff
+            constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE;
+            size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage);
+            size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage);
+
+            auto kernel =
+                PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, true, PrefillParams, DecodeParams>;
+            // Prefill: decide num_splits for split-kv
+            int num_blocks_per_sm = 0;
+            int num_sm = 0;
+            FLASHINFER_CUDA_CALL(
+                cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
+            FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                &num_blocks_per_sm, kernel, num_threads_p, smem_size_p));
+            uint32_t max_num_kv_chunks =
+                (num_blocks_per_sm * num_sm) /
+                (num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q));
+            uint32_t num_chunks;
+            if (max_num_kv_chunks > 0) {
+              uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
+              num_chunks = ceil_div(kv_len, chunk_size);
+            } else {
+              num_chunks = 0;
+            }
+
+            // Setup new prefill params if (not) split
+            auto o_p = prefill_params.o;
+            auto lse_p = prefill_params.lse;
+            float* tmp_lse = (float*)(tmp_p + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO);
+            if (num_chunks <= 1 || tmp_p == nullptr) {
+              // Enough parallelism, do not split-kv
+              prefill_params.partition_kv = 0;
+              kernel = PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, false, PrefillParams,
+                                                  DecodeParams>;
+            } else {
+              // Use cooperative groups to increase occupancy
+              prefill_params.partition_kv = num_chunks;
+              prefill_params.o = tmp_p;
+              prefill_params.lse = tmp_lse;
+              kernel = PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, true, PrefillParams,
+                                                  DecodeParams>;
+            }
+
+            // Setup new decode params if (not) split
+            auto o_d = decode_params.o;
+            auto lse_d = decode_params.lse;
+            if (tmp_v == nullptr) {
+              // do not partition kv
+              decode_params.partition_kv = false;
+            } else {
+              decode_params.partition_kv = true;
+              decode_params.o = tmp_v;
+              decode_params.lse = tmp_s;
+            }
+            uint32_t xsize = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q);
+            int nblks_p(xsize * (prefill_params.partition_kv ? prefill_params.partition_kv : 1) *
+                        num_kv_heads);
+            int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P);
+
+            int nblks_d(padded_batch_size_d * 1 * num_kv_heads);
+            int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D);
+
+            // ******* Select final combined sizes here ******* /
+            size_t smem_size = max(smem_size_p, smem_size_d);
+            int nblks = nblks_p + nblks_d;
+            int nthrs = max(nthrs_p, nthrs_d);
+
+            // printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d,
+            // smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d,
+            // nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d,
+            // nthrs);
+            // ************************************************ /
+
+            static int* tbAssign = nullptr;
+            if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2));
+            cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2));
+
+            // Setup kernel arguments
+            void* args[] = {(void*)&xsize, (void*)&prefill_params, (void*)&decode_params,
+                            (void*)&tbAssign};
+            FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
+                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+            // Launch kernel
+            FLASHINFER_CUDA_CALL(
+                cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+
+            // Post-kernel stuff for split-kv prefill
+            if (!(num_chunks <= 1 || tmp_p == nullptr)) {
+              if constexpr (PrefillAttentionVariant::use_softmax) {
+                FLASHINFER_CUDA_CALL(MergeStates(tmp_p, tmp_lse, o_p, lse_p, num_chunks, qo_len,
+                                                 num_qo_heads, HEAD_DIM_VO, stream));
+              } else {
+                FLASHINFER_CUDA_CALL(AttentionSum(tmp_p, o_p, num_chunks, qo_len, num_qo_heads,
+                                                  HEAD_DIM_VO, stream));
+              }
+            }
+            // Post-kernel stuff for split-kv decode
+            if (tmp_v != nullptr) {
+              if constexpr (DecodeAttentionVariant::use_softmax) {
+                FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
+                    tmp_v, tmp_s, decode_params.merge_indptr, o_d, lse_d,
+                    decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads,
+                    HEAD_DIM_VO, stream));
+              } else {
+                FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
+                    tmp_v, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows,
+                    decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream));
+              }
+            }
+          }
+        });
+      }
+    });
+  });
   return cudaSuccess;
 }
 
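The substantive change in pod.cuh is that the prefill CTA tile size is no longer hardcoded to CTA_TILE_Q_P = 128: the whole dispatch body is wrapped back inside DISPATCH_CTA_TILE_Q(cta_tile_q_p, CTA_TILE_Q_P, { ... }), so the tile size selected at runtime from unpacked_qo_len is lowered to a compile-time constant before get_num_warps_q, get_num_mma_q, and the KernelTraits instantiation consume it. A minimal sketch of that dispatch pattern, with the supported tile sizes (16/64/128) and the macro body assumed for illustration rather than taken from FlashInfer's headers:

// Sketch of a switch-based dispatch from a runtime tile size to a constexpr
// constant; the case values and error handling here are assumptions.
#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) \
  switch (cta_tile_q) {                                  \
    case 128: {                                          \
      constexpr uint32_t CTA_TILE_Q = 128;               \
      __VA_ARGS__                                        \
      break;                                             \
    }                                                    \
    case 64: {                                           \
      constexpr uint32_t CTA_TILE_Q = 64;                \
      __VA_ARGS__                                        \
      break;                                             \
    }                                                    \
    case 16: {                                           \
      constexpr uint32_t CTA_TILE_Q = 16;                \
      __VA_ARGS__                                        \
      break;                                             \
    }                                                    \
    default:                                             \
      FLASHINFER_ERROR("Unsupported CTA_TILE_Q");        \
  }

Each case instantiates its own copy of the nested NUM_MMA_KV dispatch and kernel, so restoring the macro effectively re-enables the smaller-tile configurations that the hardcoded 128 had bypassed.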
