
Commit ba7cca0

Varun Sundar Rabindranath (varun-sundar-rabindranath) authored and committed
[GPTOSS][DP/EP][Marlin] Enable GPTOSS Batched DP/EP using Marlin kernels (vllm-project#25997)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent 93bbbd2 commit ba7cca0

File tree: 12 files changed, +1153 -314 lines changed

csrc/moe/moe_align_sum_kernels.cu

Lines changed: 92 additions & 0 deletions
@@ -8,12 +8,77 @@
 
 #include "../cuda_compat.h"
 #include "../dispatch_utils.h"
+#include "core/math.hpp"
 
 #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
 
 namespace vllm {
 namespace moe {
 
+namespace batched_moe_align_block_size {
+
+// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
+static constexpr int32_t num_threads = 1024;
+static constexpr int32_t num_blocks = 1;
+__global__ void batched_moe_align_block_size_kernel(
+    int32_t const num_batches, int32_t const max_tokens_per_batch,
+    int32_t const block_size, int32_t const* __restrict__ batch_num_tokens,
+    int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids,
+    int32_t* __restrict__ num_tokens_post_pad) {
+  // TODO(varun): This is a naive implementation. Could be optimized.
+
+  size_t const batch_id = threadIdx.x;
+  size_t const stride = blockDim.x * gridDim.x;
+  int32_t const num_blocks_per_batch =
+      CEILDIV(max_tokens_per_batch, block_size);
+  int32_t const sorted_ids_size =
+      num_blocks_per_batch * num_batches * block_size;
+  int32_t const block_ids_size = sorted_ids_size / block_size;
+  int32_t const SENTINEL =
+      num_batches * max_tokens_per_batch;  // To denote invalid entries.
+  // Initialize sorted_ids
+  for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) {
+    sorted_ids[i] = SENTINEL;
+  }
+  // Initialize block_ids with -1
+  for (size_t i = threadIdx.x; i < block_ids_size; i += stride) {
+    block_ids[i] = -1;
+  }
+
+  int32_t b_num_tokens = 0;
+  if (batch_id < num_batches) {
+    b_num_tokens = batch_num_tokens[batch_id];
+  }
+  int32_t const ceil_b_num_tokens =
+      CEILDIV(b_num_tokens, block_size) * block_size;
+
+  // Compute prefix sum over the padded token counts of all batches
+  using BlockScan = cub::BlockScan<int32_t, 1024>;
+  __shared__ typename BlockScan::TempStorage temp_storage;
+  int cumsum_val;
+  BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val);
+  __syncthreads();
+
+  bool const is_last_batch = batch_id == (num_batches - 1);
+  if (is_last_batch) {
+    *num_tokens_post_pad = cumsum_val + ceil_b_num_tokens;
+  }
+
+  if (batch_id < num_batches) {
+    int32_t const batch_offset = batch_id * max_tokens_per_batch;
+    for (size_t i = 0; i < b_num_tokens; ++i) {
+      sorted_ids[cumsum_val + i] = batch_offset + i;
+    }
+
+    int32_t const block_start = cumsum_val / block_size;
+    int32_t const num_blocks = ceil_b_num_tokens / block_size;
+    for (size_t i = 0; i < num_blocks; ++i) {
+      block_ids[block_start + i] = batch_id;
+    }
+  }
+}
+}  // namespace batched_moe_align_block_size
+
 template <typename scalar_t>
 __global__ void moe_align_block_size_kernel(
     const scalar_t* __restrict__ topk_ids,
@@ -280,6 +345,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   });
 }
 
+void batched_moe_align_block_size(int64_t max_tokens_per_batch,
+                                  int64_t block_size,
+                                  torch::Tensor const& batch_num_tokens,
+                                  torch::Tensor sorted_ids,
+                                  torch::Tensor batch_ids,
+                                  torch::Tensor num_tokens_post_pad) {
+  namespace batched_kernel = vllm::moe::batched_moe_align_block_size;
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  int32_t const B = batch_num_tokens.size(0);
+  int32_t const num_blocks_per_batch =
+      round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size;
+  int32_t const num_blocks = num_blocks_per_batch * B;
+  int64_t const sorted_ids_size = num_blocks * block_size;
+
+  TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size);
+  TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size);
+  TORCH_CHECK(num_tokens_post_pad.size(0) == 1);
+  TORCH_CHECK(B <= batched_kernel::num_threads);
+
+  batched_kernel::batched_moe_align_block_size_kernel<<<
+      batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>(
+      B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr<int32_t>(),
+      sorted_ids.data_ptr<int32_t>(), batch_ids.data_ptr<int32_t>(),
+      num_tokens_post_pad.data_ptr<int32_t>());
+}
+
 void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
              torch::Tensor& output)  // [num_tokens, hidden_size]
 {
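To make the layout produced by `batched_moe_align_block_size_kernel` concrete, here is a small pure-Python reference of the same computation. This is an illustration written for this page, not vLLM code; the names simply mirror the CUDA kernel above.

```python
# Pure-Python sketch of what the kernel writes into sorted_ids, block_ids and
# num_tokens_post_pad (illustration only, not part of vLLM).
def batched_moe_align_block_size_ref(max_tokens_per_batch, block_size,
                                     batch_num_tokens):
    def ceildiv(a, b):
        return -(-a // b)

    num_batches = len(batch_num_tokens)
    blocks_per_batch = ceildiv(max_tokens_per_batch, block_size)
    sorted_ids_size = blocks_per_batch * num_batches * block_size
    sentinel = num_batches * max_tokens_per_batch  # marks padding slots

    sorted_ids = [sentinel] * sorted_ids_size
    block_ids = [-1] * (sorted_ids_size // block_size)

    cumsum = 0  # running offset, padded to a multiple of block_size per batch
    for b, n in enumerate(batch_num_tokens):
        batch_offset = b * max_tokens_per_batch
        for i in range(n):
            sorted_ids[cumsum + i] = batch_offset + i
        ceil_n = ceildiv(n, block_size) * block_size
        for blk in range(ceil_n // block_size):
            block_ids[cumsum // block_size + blk] = b
        cumsum += ceil_n
    return sorted_ids, block_ids, cumsum  # cumsum == num_tokens_post_pad

# Example: batch_num_tokens=[3, 5], max_tokens_per_batch=8, block_size=4
# -> sorted_ids = [0, 1, 2, 16,  8, 9, 10, 11,  12, 16, 16, 16,  16, ...]
#    block_ids  = [0, 1, 1, -1]; num_tokens_post_pad = 12
```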

csrc/moe/moe_ops.h

Lines changed: 8 additions & 0 deletions
@@ -12,6 +12,14 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           int64_t block_size, torch::Tensor sorted_token_ids,
                           torch::Tensor experts_ids,
                           torch::Tensor num_tokens_post_pad);
+
+void batched_moe_align_block_size(int64_t max_tokens_per_batch,
+                                  int64_t block_size,
+                                  torch::Tensor const& expert_num_tokens,
+                                  torch::Tensor sorted_ids,
+                                  torch::Tensor expert_ids,
+                                  torch::Tensor num_tokens_post_pad);
+
 #ifndef USE_ROCM
 torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor b_qweight, torch::Tensor b_scales,

csrc/moe/torch_bindings.cpp

Lines changed: 11 additions & 0 deletions
@@ -22,6 +22,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       " Tensor! num_tokens_post_pad) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
+  // Aligning the number of tokens to be processed by each expert such
+  // that it is divisible by the block size, but for the batched case.
+  m.def(
+      "batched_moe_align_block_size(int max_tokens_per_batch,"
+      " int block_size, Tensor expert_num_tokens,"
+      " Tensor! sorted_token_ids,"
+      " Tensor! experts_ids,"
+      " Tensor! num_tokens_post_pad) -> ()");
+  m.impl("batched_moe_align_block_size", torch::kCUDA,
+         &batched_moe_align_block_size);
+
 #ifndef USE_ROCM
   m.def(
       "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "

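The new binding can then be driven from Python. Below is a minimal usage sketch, assuming the MoE extension is registered under `torch.ops._moe_C` (the namespace vLLM uses for its other MoE ops); the buffer shapes are sized to satisfy the `TORCH_CHECK`s in the C++ wrapper above, and the input values are made up for illustration.

```python
import torch

# Hypothetical inputs: 2 batches, at most 8 tokens per batch, block size 4.
max_tokens_per_batch, block_size = 8, 4
expert_num_tokens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")

B = expert_num_tokens.numel()  # kernel requires B <= 1024 (one thread block)
blocks_per_batch = (max_tokens_per_batch + block_size - 1) // block_size
sorted_ids_size = B * blocks_per_batch * block_size

sorted_token_ids = torch.empty(sorted_ids_size, dtype=torch.int32,
                               device="cuda")
experts_ids = torch.empty(sorted_ids_size // block_size, dtype=torch.int32,
                          device="cuda")
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")

# Results are written in place into the three output buffers above.
torch.ops._moe_C.batched_moe_align_block_size(
    max_tokens_per_batch, block_size, expert_num_tokens,
    sorted_token_ids, experts_ids, num_tokens_post_pad)
```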
docs/design/moe_kernel_features.md

Lines changed: 5 additions & 5 deletions
@@ -92,8 +92,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
 | flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
 | gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
 | deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
-| marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] |
-| marlin experts | standard | N/A | N/A | silu,</br>swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] |
+| marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
+| marlin experts | standard,</br>batched | N/A | N/A | silu,</br>swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
 | trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
 | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
 | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
@@ -115,6 +115,6 @@ The following table shows "families" of modular kernels that are intended to wor
 
 | backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
 |----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
-| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
-| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`|
-| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
+| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
+| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts`|
+| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
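The composition these table rows describe is: a batched `FusedMoEPrepareAndFinalize` backend (pplx or DeepEP low-latency) handles dispatch and combine across DP/EP ranks, and a batched `FusedMoEPermuteExpertsUnpermute` implementation such as the new `BatchedMarlinExperts` runs the expert GEMMs in between. The sketch below is illustrative only; it uses stand-in classes rather than the real vLLM constructors, whose arguments are configuration-specific.

```python
# Illustrative stand-ins for the modular-kernel roles in the table above;
# the real implementations live in vllm.model_executor.layers.fused_moe.
class PrepareAndFinalize:            # e.g. PplxPrepareAndFinalize
    def prepare(self, tokens):       # dispatch tokens to expert ranks
        return tokens

    def finalize(self, expert_out):  # combine expert outputs back per token
        return expert_out


class BatchedExperts:                # e.g. BatchedMarlinExperts
    def apply(self, dispatched):     # batched expert GEMMs (Marlin kernels)
        return dispatched


def modular_moe(prep: PrepareAndFinalize, experts: BatchedExperts, tokens):
    # The modular kernel runs prepare -> experts -> finalize, which is why a
    # batched prepare/finalize backend can pair with any batched experts
    # implementation from the same row of the table.
    return prep.finalize(experts.apply(prep.prepare(tokens)))
```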
