@@ -188,7 +188,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(
188188 It fuses the softmax, max and argmax into a single kernel.
189189
190190 Limitations:
191- 1) This implementation is intended for when the number of experts is a small power of 2.
191+ 1) This implementation is optimized for when the number of experts is a small power of 2.
192+ It also supports cases where the number of experts is a multiple of 64, which is still
193+ faster than computing softmax and topK separately (only tested on CUDA so far).
192194 2) This implementation assumes k is small, but will work for any k.
193195*/
194196
@@ -198,8 +200,6 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
198200 int * source_rows, const int k, const int start_expert, const int end_expert)
199201{
200202 // We begin by enforcing compile time assertions and setting up compile time constants.
201- static_assert (VPT == (VPT & -VPT), " VPT must be power of 2" );
202- static_assert (NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), " NUM_EXPERTS must be power of 2" );
203203 static_assert (BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), " BYTES_PER_LDG must be power of 2" );
204204 static_assert (BYTES_PER_LDG <= 16 , " BYTES_PER_LDG must be leq 16" );
205205
@@ -407,12 +407,10 @@ struct TopkConstants
407407};
408408} // namespace detail
409409
410- template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
410+ template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType>
411411void topkGatingSoftmaxLauncherHelper (const float * input, const bool * finished, float * output, IndType* indices,
412412 int * source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
413413{
414- static constexpr std::size_t MAX_BYTES_PER_LDG = 16 ;
415-
416414 static constexpr int BYTES_PER_LDG = MIN (MAX_BYTES_PER_LDG, sizeof (float ) * EXPERTS);
417415 using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
418416 static constexpr int VPT = Constants::VPT;
@@ -425,21 +423,12 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
425423 input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
426424}
427425
428- #define LAUNCH_SOFTMAX (NUM_EXPERTS, WARPS_PER_TB ) \
429- switch (warpSize ) { \
430- case 32 : \
431- topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32 >( \
432- gating_output, nullptr , topk_weights, topk_indices, \
433- token_expert_indices, num_tokens, topk, 0 , num_experts, stream); \
434- break ; \
435- case 64 : \
436- topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64 >( \
437- gating_output, nullptr , topk_weights, topk_indices, \
438- token_expert_indices, num_tokens, topk, 0 , num_experts, stream); \
439- break ; \
440- default : \
441- TORCH_CHECK (false , " Unsupported warp size: " , warpSize ); \
442- }
426+ #define LAUNCH_SOFTMAX (NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES ) \
427+ static_assert (WARP_SIZE == 32 || WARP_SIZE == 64 , \
428+ " Unsupported warp size. Only 32 and 64 are supported." ); \
429+ topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
430+ gating_output, nullptr , topk_weights, topk_indices, \
431+ token_expert_indices, num_tokens, topk, 0 , num_experts, stream);
443432
444433template <typename IndType>
445434void topkGatingSoftmaxKernelLauncher (
@@ -453,38 +442,62 @@ void topkGatingSoftmaxKernelLauncher(
453442 const int topk,
454443 cudaStream_t stream) {
455444 static constexpr int WARPS_PER_TB = 4 ;
456- auto warpSize = WARP_SIZE;
445+ static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16 ;
446+ static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8 ;
457447 switch (num_experts) {
458448 case 1 :
459- LAUNCH_SOFTMAX (1 , WARPS_PER_TB);
449+ LAUNCH_SOFTMAX (1 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
460450 break ;
461451 case 2 :
462- LAUNCH_SOFTMAX (2 , WARPS_PER_TB);
452+ LAUNCH_SOFTMAX (2 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
463453 break ;
464454 case 4 :
465- LAUNCH_SOFTMAX (4 , WARPS_PER_TB);
455+ LAUNCH_SOFTMAX (4 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
466456 break ;
467457 case 8 :
468- LAUNCH_SOFTMAX (8 , WARPS_PER_TB);
458+ LAUNCH_SOFTMAX (8 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
469459 break ;
470460 case 16 :
471- LAUNCH_SOFTMAX (16 , WARPS_PER_TB);
461+ LAUNCH_SOFTMAX (16 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
472462 break ;
473463 case 32 :
474- LAUNCH_SOFTMAX (32 , WARPS_PER_TB);
464+ LAUNCH_SOFTMAX (32 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
475465 break ;
476466 case 64 :
477- LAUNCH_SOFTMAX (64 , WARPS_PER_TB);
467+ LAUNCH_SOFTMAX (64 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
478468 break ;
479469 case 128 :
480- LAUNCH_SOFTMAX (128 , WARPS_PER_TB);
470+ LAUNCH_SOFTMAX (128 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2 );
481471 break ;
482472 case 256 :
483- LAUNCH_SOFTMAX (256 , WARPS_PER_TB);
473+ LAUNCH_SOFTMAX (256 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
474+ break ;
475+ case 512 :
476+ LAUNCH_SOFTMAX (512 , WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
477+ break ;
478+ // (CUDA only) Support multiples of 64 when num_experts is not a power of 2.
479+ // ROCm uses WARP_SIZE 64, so 8-byte loads won't fit for some values of num_experts;
480+ // alternatively, we can test 4-byte loads and enable this in the future.
481+ #ifndef USE_ROCM
482+ case 192 :
483+ LAUNCH_SOFTMAX (192 , WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
484484 break ;
485+ case 320 :
486+ LAUNCH_SOFTMAX (320 , WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
487+ break ;
488+ case 384 :
489+ LAUNCH_SOFTMAX (384 , WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
490+ break ;
491+ case 448 :
492+ LAUNCH_SOFTMAX (448 , WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
493+ break ;
494+ case 576 :
495+ LAUNCH_SOFTMAX (576 , WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
496+ break ;
497+ #endif
485498 default : {
486499 TORCH_CHECK (softmax_workspace != nullptr ,
487- " softmax_workspace must be provided for num_experts that are not a power of 2." );
500+ " softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64 ." );
488501 static constexpr int TPB = 256 ;
489502 moeSoftmax<TPB><<<num_tokens, TPB, 0 , stream>>> (
490503 gating_output, nullptr , softmax_workspace, num_experts);
0 commit comments