#include "../cuda_compat.h"
#include "../dispatch_utils.h"
+#include "core/math.hpp"

#define CEILDIV(x, y) (((x) + (y) - 1) / (y))

namespace vllm {
namespace moe {

+namespace batched_moe_align_block_size {
+
+// Note: num_threads must be 1024 to match the BlockScan width used in the
+// kernel below.
+static constexpr int32_t num_threads = 1024;
+static constexpr int32_t num_blocks = 1;
+__global__ void batched_moe_align_block_size_kernel(
+    int32_t const num_batches, int32_t const max_tokens_per_batch,
+    int32_t const block_size, int32_t const* __restrict__ batch_num_tokens,
+    int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids,
+    int32_t* __restrict__ num_tokens_post_pad) {
+  // TODO(varun): This is a naive implementation. Could be optimized.
+
+  size_t const batch_id = threadIdx.x;
+  size_t const stride = blockDim.x * gridDim.x;
+  int32_t const num_blocks_per_batch =
+      CEILDIV(max_tokens_per_batch, block_size);
+  int32_t const sorted_ids_size =
+      num_blocks_per_batch * num_batches * block_size;
+  int32_t const block_ids_size = sorted_ids_size / block_size;
+  int32_t const SENTINEL =
+      num_batches * max_tokens_per_batch;  // To denote invalid entries.
+  // Initialize sorted_ids with the sentinel value.
+  for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) {
+    sorted_ids[i] = SENTINEL;
+  }
+  // Initialize block_ids with -1 (no batch assigned).
+  for (size_t i = threadIdx.x; i < block_ids_size; i += stride) {
+    block_ids[i] = -1;
+  }
+
+  int32_t b_num_tokens = 0;
+  if (batch_id < num_batches) {
+    b_num_tokens = batch_num_tokens[batch_id];
+  }
+  int32_t const ceil_b_num_tokens =
+      CEILDIV(b_num_tokens, block_size) * block_size;
+
+  // Compute an exclusive prefix sum over the block-aligned token counts per
+  // batch.
+  using BlockScan = cub::BlockScan<int32_t, num_threads>;
+  __shared__ typename BlockScan::TempStorage temp_storage;
+  int cumsum_val;
+  BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val);
+  __syncthreads();
+
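+  // The exclusive prefix sum of the last batch plus its own padded count gives
+  // the total number of tokens after padding.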
+  bool const is_last_batch = batch_id == (num_batches - 1);
+  if (is_last_batch) {
+    *num_tokens_post_pad = cumsum_val + ceil_b_num_tokens;
+  }
+
+  if (batch_id < num_batches) {
+    int32_t const batch_offset = batch_id * max_tokens_per_batch;
+    for (size_t i = 0; i < b_num_tokens; ++i) {
+      sorted_ids[cumsum_val + i] = batch_offset + i;
+    }
+
+    int32_t const block_start = cumsum_val / block_size;
+    int32_t const b_num_blocks = ceil_b_num_tokens / block_size;
+    for (size_t i = 0; i < b_num_blocks; ++i) {
+      block_ids[block_start + i] = batch_id;
+    }
+  }
+}
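+
+// Illustrative example (values assumed for exposition): with num_batches = 2,
+// max_tokens_per_batch = 4, block_size = 2 and batch_num_tokens = {3, 1}, the
+// block-aligned counts are {4, 2} and their exclusive prefix sum is {0, 4}, so
+// the kernel writes
+//   sorted_ids           = {0, 1, 2, S, 4, S, S, S}  where S = SENTINEL = 8
+//   block_ids            = {0, 0, 1, -1}
+//   *num_tokens_post_pad = 6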
+}  // namespace batched_moe_align_block_size
+
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
    const scalar_t* __restrict__ topk_ids,
@@ -280,6 +345,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, |
  });
}

+void batched_moe_align_block_size(int64_t max_tokens_per_batch,
+                                  int64_t block_size,
+                                  torch::Tensor const& batch_num_tokens,
+                                  torch::Tensor sorted_ids,
+                                  torch::Tensor batch_ids,
+                                  torch::Tensor num_tokens_post_pad) {
+  namespace batched_kernel = vllm::moe::batched_moe_align_block_size;
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  int32_t const B = batch_num_tokens.size(0);
+  int32_t const num_blocks_per_batch =
+      round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size;
+  int32_t const num_blocks = num_blocks_per_batch * B;
+  int64_t const sorted_ids_size = num_blocks * block_size;
+
+  TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size);
+  TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size);
+  TORCH_CHECK(num_tokens_post_pad.size(0) == 1);
+  TORCH_CHECK(B <= batched_kernel::num_threads);
+
+  batched_kernel::batched_moe_align_block_size_kernel<<<
+      batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>(
+      B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr<int32_t>(),
+      sorted_ids.data_ptr<int32_t>(), batch_ids.data_ptr<int32_t>(),
+      num_tokens_post_pad.data_ptr<int32_t>());
+}
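+// Illustrative sizing (assumed values): for B = 2 batches with
+// max_tokens_per_batch = 4 and block_size = 2, the caller is expected to pass
+// sorted_ids with 8 int32 entries, batch_ids with 4 entries and
+// num_tokens_post_pad with a single entry, matching the checks above.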
+
void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
             torch::Tensor& output)  // [num_tokens, hidden_size]
{