@@ -13,7 +13,6 @@
 )
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
     compute_aligned_M,
-    deep_gemm_block_shape,
     deepgemm_moe_permute,
     deepgemm_unpermute_and_reduce,
 )
@@ -28,14 +27,17 @@
     per_token_group_quant_fp8,
 )
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    m_grouped_fp8_gemm_nt_contiguous,
+)
 from vllm.utils.functools import run_once

 logger = init_logger(__name__)


 def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
-    align = deep_gemm_block_shape()[0]
+    align = get_mk_alignment_for_contiguous_layout()[0]
     return align <= M and N % align == 0 and K % align == 0

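The diff treats `get_mk_alignment_for_contiguous_layout()` as a drop-in replacement for `deep_gemm_block_shape()`: a `[block_m, block_k]` pair whose first entry is the alignment DeepGEMM's grouped-contiguous layout requires. A minimal sketch of the shape check above, where the mock helper and the `[128, 128]` value are assumptions made only to keep the example runnable:

```python
# Hedged stand-in for the alignment helper; [128, 128] is assumed, not taken
# from this change.
def mock_get_mk_alignment_for_contiguous_layout() -> list[int]:
    return [128, 128]


def valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
    # M only has to reach the alignment; N and K must be exact multiples of it.
    align = mock_get_mk_alignment_for_contiguous_layout()[0]
    return align <= M and N % align == 0 and K % align == 0


print(valid_deep_gemm_shape(256, 4096, 7168))  # True
print(valid_deep_gemm_shape(64, 4096, 7168))   # False: M below the alignment
print(valid_deep_gemm_shape(256, 4100, 7168))  # False: N not a multiple of 128
```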
@@ -54,7 +56,7 @@ def _valid_deep_gemm(
     M = hidden_states.size(0)
     _, K, N = w2.size()

-    align = deep_gemm_block_shape()[0]
+    align = get_mk_alignment_for_contiguous_layout()[0]

     if not _valid_deep_gemm_shape(M, N, K):
         logger.debug_once(
@@ -124,7 +126,7 @@ def warmup_deepgemm_gg_contiguous_kernels(

     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"

-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]
     num_experts = w1.size(0)
     device = w1.device

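The warmup only needs `block_m`, the per-expert row alignment. A hedged sketch of the M-dimension sizing that alignment implies; the expert count and the 128 value are illustrative assumptions, not the function's actual allocation logic:

```python
# In a grouped-contiguous layout each expert owns a whole number of block_m-row
# chunks, so the smallest activation that touches every expert has
# num_experts * block_m rows.
def round_up(x: int, align: int) -> int:
    return (x + align - 1) // align * align


block_m = 128    # assumed value of get_mk_alignment_for_contiguous_layout()[0]
num_experts = 8  # hypothetical expert count

print(num_experts * block_m)   # 1024 rows for a minimal warmup activation
print(round_up(300, block_m))  # 384: a per-expert token count padded to block_m
```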
@@ -173,7 +175,7 @@ def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
 class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         super().__init__(quant_config)
-        assert quant_config.block_shape == deep_gemm_block_shape()
+        assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
         assert quant_config.quant_dtype == torch.float8_e4m3fn
         assert not quant_config.per_act_token_quant
         assert not quant_config.per_out_ch_quant
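The constructor's asserts spell out when `DeepGemmExperts` is usable: block-quantized FP8 whose block shape equals the DeepGEMM alignment, with per-token and per-output-channel quantization rejected. A hedged, self-contained restatement of those constraints; `FakeQuantConfig` is a hypothetical stand-in for `FusedMoEQuantConfig` and `[128, 128]` an assumed alignment:

```python
import torch
from dataclasses import dataclass


@dataclass
class FakeQuantConfig:
    # Hypothetical stand-in for FusedMoEQuantConfig, reduced to the fields the
    # asserts above touch.
    block_shape: list[int]
    quant_dtype: torch.dtype
    per_act_token_quant: bool = False
    per_out_ch_quant: bool = False


def check_deepgemm_compatible(cfg: FakeQuantConfig, mk_alignment: list[int]) -> None:
    assert cfg.block_shape == mk_alignment         # quant blocks match GEMM alignment
    assert cfg.quant_dtype == torch.float8_e4m3fn  # FP8 e4m3 only
    assert not cfg.per_act_token_quant             # block-wise, not per-token scales
    assert not cfg.per_out_ch_quant                # block-wise, not per-channel scales


check_deepgemm_compatible(
    FakeQuantConfig(block_shape=[128, 128], quant_dtype=torch.float8_e4m3fn),
    mk_alignment=[128, 128],
)
```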
@@ -255,7 +257,7 @@ def apply(
             M=topk_ids.size(0),
             num_topk=topk_ids.size(1),
             local_num_experts=local_num_experts,
-            alignment=deep_gemm_block_shape()[0],
+            alignment=get_mk_alignment_for_contiguous_layout()[0],
             expert_tokens_meta=expert_tokens_meta,
         )

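`compute_aligned_M` now takes its `alignment` from the same helper. The following is not vLLM's `compute_aligned_M`, just a hedged sketch of the rounding such an alignment implies, with made-up per-expert token counts:

```python
# Each expert's slice of the permuted activation must start on an `alignment`
# row boundary, so per-expert token counts get rounded up before summing.
def aligned_rows(expert_token_counts: list[int], alignment: int) -> int:
    def round_up(c: int) -> int:
        return (c + alignment - 1) // alignment * alignment

    return sum(round_up(c) for c in expert_token_counts)


# Hypothetical example: 3 local experts receiving 5, 130 and 0 tokens, with the
# alignment assumed to be 128.
print(aligned_rows([5, 130, 0], 128))  # 128 + 256 + 0 = 384
```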
@@ -364,7 +366,7 @@ def deep_gemm_moe_fp8(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
-        block_shape=deep_gemm_block_shape(),
+        block_shape=get_mk_alignment_for_contiguous_layout(),
     )

     fn = mk.FusedMoEModularKernel(
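Here the quant config's `block_shape` comes straight from the alignment helper, so the FP8 scaling groups line up with the GEMM's K alignment. A toy sketch of group-wise scale computation with an assumed 128-wide group, not vLLM's `per_token_group_quant_fp8`:

```python
import torch


def per_group_scales(x: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    # One scale per contiguous group of `group_size` elements along K, chosen so
    # the largest magnitude in each group maps to the FP8 e4m3 maximum (448).
    m, k = x.shape
    assert k % group_size == 0, "K must be a multiple of the group size"
    groups = x.view(m, k // group_size, group_size)
    return groups.abs().amax(dim=-1).float() / torch.finfo(torch.float8_e4m3fn).max


x = torch.randn(4, 256)
print(per_group_scales(x).shape)  # torch.Size([4, 2]): one scale per 128-wide group
```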