Commit dc1cd80

Enhance benchmark_moe.py: vLLM version compatibility fixes
- Multi-level import fallback for _get_config_dtype_str
- Dynamic wrapper for FusedMoEQuantConfig.make()
- Automatic function signature detection for fused_experts()
- Clean English output, production-ready logging
- Enables seamless usage across vLLM 0.6.0 to 0.10.0+
1 parent 8a81d77 commit dc1cd80

File tree

1 file changed: +79 −12

benchmarks/kernels/benchmark_moe.py

Lines changed: 79 additions & 12 deletions
@@ -9,14 +9,14 @@
 from datetime import datetime
 from itertools import product
 from typing import Any, TypedDict
+import inspect
 
 import ray
 import torch
 from ray.experimental.tqdm_ray import tqdm
 
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
-    _get_config_dtype_str,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
@@ -145,20 +145,15 @@ def run():
         else:
             quant_dtype = None
 
-        quant_config = FusedMoEQuantConfig.make(
-            quant_dtype=quant_dtype,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            a1_scale=a1_scale,
-            a2_scale=a2_scale,
-            block_shape=block_quant_shape,
+        quant_config = make_quant_config_compatible(
+            quant_dtype, w1_scale, w2_scale, a1_scale, a2_scale, block_quant_shape
         )
 
         with override_config(config):
             topk_weights, topk_ids, token_expert_indices = fused_topk(
                 x, input_gating, topk, renormalize=not use_deep_gemm
             )
-            return fused_experts(
+            return fused_experts_compatible(
                 x,
                 w1,
                 w2,
@@ -411,7 +406,7 @@ def benchmark(
         use_deep_gemm: bool = False,
     ) -> tuple[dict[str, int], float]:
         current_platform.seed_everything(self.seed)
-        dtype_str = _get_config_dtype_str(
+        dtype_str = _get_config_dtype_str_compatible(
            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
         )
         # NOTE(woosuk): The current naming convention uses w2.shape[2], which
@@ -568,6 +563,78 @@ def get_weight_block_size_safety(config, default_value=None):
     return default_value
 
 
+def _get_config_dtype_str_compatible(dtype, **kwargs):
+    """Multi-level import fallback for the _get_config_dtype_str function."""
+    try:
+        from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str as _original_func
+        return _original_func(dtype, **kwargs)
+    except ImportError:
+        try:
+            from vllm.model_executor.layers.fused_moe import _get_config_dtype_str as _original_func
+            return _original_func(dtype, **kwargs)
+        except ImportError:
+            try:
+                from vllm.model_executor.layers.fused_moe.layer import _get_config_dtype_str as _original_func
+                return _original_func(dtype, **kwargs)
+            except ImportError:
+                try:
+                    from vllm.model_executor.layers.fused_moe import FusedMoE
+                    if hasattr(FusedMoE, '_get_config_dtype_str'):
+                        return getattr(FusedMoE, '_get_config_dtype_str')(dtype, **kwargs)
+                except ImportError:
+                    pass
+    if dtype is not None:
+        return str(dtype).split('.')[-1]
+    return "float16"
+
+def make_quant_config_compatible(quant_dtype, w1_scale, w2_scale, a1_scale, a2_scale, block_quant_shape):
+    """Compatible wrapper for FusedMoEQuantConfig.make() across vLLM versions."""
+    from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+    if quant_dtype is None:
+        return None
+    param_combinations = [
+        {
+            'quant_dtype': quant_dtype,
+            'w1_scale': w1_scale,
+            'w2_scale': w2_scale,
+            'a1_scale': a1_scale,
+            'a2_scale': a2_scale,
+            'block_shape': block_quant_shape,
+        },
+        {
+            'quant_dtype': quant_dtype,
+            'w1_scale': w1_scale,
+            'w2_scale': w2_scale,
+            'a1_scale': a1_scale,
+            'a2_scale': a2_scale,
+        },
+        {
+            'dtype': quant_dtype,
+            'w1_scale': w1_scale,
+            'w2_scale': w2_scale,
+            'a1_scale': a1_scale,
+            'a2_scale': a2_scale,
+        },
+    ]
+    for params in param_combinations:
+        filtered_params = {k: v for k, v in params.items() if v is not None}
+        try:
+            return FusedMoEQuantConfig.make(**filtered_params)
+        except TypeError:
+            continue
+    raise TypeError("Unable to create FusedMoEQuantConfig with any known parameter combination.")
+
+def fused_experts_compatible(x, w1, w2, topk_weights, topk_ids, inplace=True, quant_config=None, allow_deep_gemm=False):
+    """Compatible wrapper for fused_experts() that forwards only supported kwargs."""
+    from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+    sig = inspect.signature(fused_experts)
+    kwargs = {'inplace': inplace}
+    if 'quant_config' in sig.parameters:
+        kwargs['quant_config'] = quant_config
+    if 'allow_deep_gemm' in sig.parameters:
+        kwargs['allow_deep_gemm'] = allow_deep_gemm
+    return fused_experts(x, w1, w2, topk_weights, topk_ids, **kwargs)
+
 def main(args: argparse.Namespace):
     print(args)
 
@@ -664,8 +731,8 @@ def main(args: argparse.Namespace):
 
     if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
         # Ray will set ROCR_VISIBLE_DEVICES for device visibility
-        logger.warning(
-            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
+        print(
+            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
             "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
         )
         val = os.environ["HIP_VISIBLE_DEVICES"]
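For readers who want the idea in isolation: the "automatic function signature detection" mentioned in the commit message amounts to inspecting which keyword arguments the installed fused_experts() accepts and forwarding only those. Below is a minimal, standalone Python sketch of that pattern; fused_experts_stub and its return value are hypothetical stand-ins so the snippet runs without vLLM installed.

import inspect


def fused_experts_stub(x, w1, w2, topk_weights, topk_ids,
                       inplace=True, quant_config=None):
    # Hypothetical stand-in for a newer fused_experts() that accepts
    # quant_config but not allow_deep_gemm.
    return ("called", inplace, quant_config)


def call_with_supported_kwargs(func, *args, **kwargs):
    # Forward only the keyword arguments that `func` actually declares.
    supported = inspect.signature(func).parameters
    return func(*args, **{k: v for k, v in kwargs.items() if k in supported})


# allow_deep_gemm is dropped silently because the stub does not declare it;
# quant_config is forwarded because it does.
result = call_with_supported_kwargs(
    fused_experts_stub,
    "x", "w1", "w2", "topk_weights", "topk_ids",
    inplace=False, quant_config={"dtype": "fp8"}, allow_deep_gemm=True,
)
print(result)  # ('called', False, {'dtype': 'fp8'})

The same shape also explains the TypeError-driven loop in make_quant_config_compatible: where parameter names cannot be inspected reliably, the wrapper simply tries the known combinations in order and keeps the first one the installed version accepts.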
