Commit 9c7ccdc

Apply ruff formatting fixes
- Fix import ordering (move inspect to correct position)
- Add trailing commas in function parameters and dictionaries
- Standardize quotes to double quotes
- Add proper blank line separations
- Simplify getattr call to direct attribute access

Resolves ruff-check and ruff-format pre-commit failures

Signed-off-by: Alfred <[email protected]>
1 parent 9698a28 commit 9c7ccdc
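
Note: one of the listed fixes goes beyond pure formatting. Replacing getattr(FusedMoE, '_get_config_dtype_str') with a direct attribute access is what ruff's bugbear rule B009 asks for, since getattr with a constant attribute name is equivalent to, and slower than, a plain attribute lookup. A minimal self-contained illustration of that rewrite, using a stand-in class rather than vLLM's FusedMoE:

class Widget:
    """Stand-in for a class exposing a known helper, e.g. FusedMoE."""

    @staticmethod
    def _get_config_dtype_str(dtype):
        return str(dtype)


# Flagged by ruff (B009): getattr with a constant attribute name.
before = getattr(Widget, "_get_config_dtype_str")("float16")

# Equivalent direct access, as applied in this commit.
after = Widget._get_config_dtype_str("float16")

assert before == after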

File tree: 1 file changed, +40 −33 lines

benchmarks/kernels/benchmark_moe.py (40 additions, 33 deletions)
@@ -2,14 +2,14 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import argparse
+import inspect
 import json
 import os
 import time
 from contextlib import nullcontext
 from datetime import datetime
 from itertools import product
 from typing import Any, TypedDict
-import inspect
 
 import ray
 import torch
@@ -568,54 +568,58 @@ def _get_config_dtype_str_compatible(
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
-    **kwargs
+    **kwargs,
 ) -> str | None:
     """Multi-level import fallback for _get_config_dtype_str function."""
     try:
         from vllm.model_executor.layers.fused_moe.config import (
-            _get_config_dtype_str as _original_func
+            _get_config_dtype_str as _original_func,
         )
+
         return _original_func(
             dtype,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             use_int4_w4a16=use_int4_w4a16,
-            **kwargs
+            **kwargs,
         )
     except ImportError:
         try:
             from vllm.model_executor.layers.fused_moe import (
-                _get_config_dtype_str as _original_func
+                _get_config_dtype_str as _original_func,
             )
+
             return _original_func(
                 dtype,
                 use_fp8_w8a8=use_fp8_w8a8,
                 use_int8_w8a16=use_int8_w8a16,
                 use_int4_w4a16=use_int4_w4a16,
-                **kwargs
+                **kwargs,
             )
         except ImportError:
             try:
                 from vllm.model_executor.layers.fused_moe.layer import (
-                    _get_config_dtype_str as _original_func
+                    _get_config_dtype_str as _original_func,
                 )
+
                 return _original_func(
                     dtype,
                     use_fp8_w8a8=use_fp8_w8a8,
                     use_int8_w8a16=use_int8_w8a16,
                     use_int4_w4a16=use_int4_w4a16,
-                    **kwargs
+                    **kwargs,
                 )
             except ImportError:
                 try:
                     from vllm.model_executor.layers.fused_moe import FusedMoE
-                    if hasattr(FusedMoE, '_get_config_dtype_str'):
-                        return getattr(FusedMoE, '_get_config_dtype_str')(
+
+                    if hasattr(FusedMoE, "_get_config_dtype_str"):
+                        return FusedMoE._get_config_dtype_str(
                             dtype,
                             use_fp8_w8a8=use_fp8_w8a8,
                             use_int8_w8a16=use_int8_w8a16,
                             use_int4_w4a16=use_int4_w4a16,
-                            **kwargs
+                            **kwargs,
                         )
                 except ImportError:
                     pass
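
Note: the hunk above keeps the benchmark working while _get_config_dtype_str migrates between vLLM modules: each nested try attempts the next known location, and only an ImportError falls through to the next candidate. The same behavior can be expressed without nesting; a sketch of that flattened pattern (module paths copied from the diff, loop structure is illustrative, not the commit's code):

import importlib

_CANDIDATE_MODULES = (
    "vllm.model_executor.layers.fused_moe.config",
    "vllm.model_executor.layers.fused_moe",
    "vllm.model_executor.layers.fused_moe.layer",
)


def _resolve_config_dtype_str():
    """Return _get_config_dtype_str from the first module that provides it."""
    for path in _CANDIDATE_MODULES:
        try:
            module = importlib.import_module(path)
        except ImportError:
            continue
        # getattr with a default is a probe, not the B009 pattern.
        func = getattr(module, "_get_config_dtype_str", None)
        if func is not None:
            return func
    # No module provides it; the caller falls back to a local
    # implementation (see the next hunk's return "float32" / return None).
    return None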
@@ -632,35 +636,35 @@ def _get_config_dtype_str_compatible(
         return "float32"
     return None
 
+
 def make_quant_config_compatible(
     quant_dtype, w1_scale, w2_scale, a1_scale, a2_scale, block_quant_shape
 ):
     """Compatible wrapper for FusedMoEQuantConfig.make() across vLLM versions."""
-    from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
     if quant_dtype is None:
         return None
     param_combinations = [
         {
-            'quant_dtype': quant_dtype,
-            'w1_scale': w1_scale,
-            'w2_scale': w2_scale,
-            'a1_scale': a1_scale,
-            'a2_scale': a2_scale,
-            'block_quant_shape': block_quant_shape,
+            "quant_dtype": quant_dtype,
+            "w1_scale": w1_scale,
+            "w2_scale": w2_scale,
+            "a1_scale": a1_scale,
+            "a2_scale": a2_scale,
+            "block_quant_shape": block_quant_shape,
         },
         {
-            'quant_dtype': quant_dtype,
-            'w1_scale': w1_scale,
-            'w2_scale': w2_scale,
-            'a1_scale': a1_scale,
-            'a2_scale': a2_scale,
+            "quant_dtype": quant_dtype,
+            "w1_scale": w1_scale,
+            "w2_scale": w2_scale,
+            "a1_scale": a1_scale,
+            "a2_scale": a2_scale,
         },
         {
-            'dtype': quant_dtype,
-            'w1_scale': w1_scale,
-            'w2_scale': w2_scale,
-            'a1_scale': a1_scale,
-            'a2_scale': a2_scale,
+            "dtype": quant_dtype,
+            "w1_scale": w1_scale,
+            "w2_scale": w2_scale,
+            "a1_scale": a1_scale,
+            "a2_scale": a2_scale,
         },
     ]
     for params in param_combinations:
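
Note: the three dictionaries differ only in which keywords they pass (block_quant_shape present or absent, quant_dtype versus dtype), matching FusedMoEQuantConfig.make() signatures from different vLLM releases. The loop body that consumes them is collapsed between these hunks, but the probing pattern presumably looks like this sketch (the factory argument is hypothetical; the raise message is taken from the next hunk):

def _make_with_first_matching_signature(factory, param_combinations):
    """Try each kwargs dict until one matches factory's signature."""
    for params in param_combinations:
        try:
            return factory(**params)
        except TypeError:
            # Unexpected or missing keyword: wrong vLLM version, try next.
            continue
    raise ValueError(
        "Unable to create FusedMoEQuantConfig with any known parameter combination."
    )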
@@ -673,6 +677,7 @@ def make_quant_config_compatible(
         "Unable to create FusedMoEQuantConfig with any known parameter combination."
     )
 
+
 def fused_experts_compatible(
     x,
     w1,
@@ -685,14 +690,16 @@ def fused_experts_compatible(
 ):
     """Compatible wrapper for fused_experts function."""
     from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+
     sig = inspect.signature(fused_experts)
-    kwargs = {'inplace': inplace}
-    if 'quant_config' in sig.parameters:
-        kwargs['quant_config'] = quant_config
-    if 'allow_deep_gemm' in sig.parameters:
-        kwargs['allow_deep_gemm'] = allow_deep_gemm
+    kwargs = {"inplace": inplace}
+    if "quant_config" in sig.parameters:
+        kwargs["quant_config"] = quant_config
+    if "allow_deep_gemm" in sig.parameters:
+        kwargs["allow_deep_gemm"] = allow_deep_gemm
     return fused_experts(x, w1, w2, topk_weights, topk_ids, **kwargs)
 
+
 def main(args: argparse.Namespace):
     print(args)
 
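Note: fused_experts_compatible probes the installed fused_experts with inspect.signature and forwards quant_config and allow_deep_gemm only when the signature accepts them. The same technique in a self-contained, runnable form (toy target function; only the parameter names come from the diff):

import inspect


def call_with_supported_kwargs(func, *args, **candidate_kwargs):
    """Forward only the keyword arguments that func's signature accepts."""
    accepted = inspect.signature(func).parameters
    kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted}
    return func(*args, **kwargs)


def _toy_fused_experts(x, w1, w2, *, inplace=False):  # no quant_config here
    return (x, w1, w2, inplace)


# quant_config is dropped because the toy signature does not accept it.
result = call_with_supported_kwargs(
    _toy_fused_experts, 1, 2, 3, inplace=True, quant_config=None
)
assert result == (1, 2, 3, True)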