22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
44import argparse
5+ import inspect
56import json
67import os
78import time
89from contextlib import nullcontext
910from datetime import datetime
1011from itertools import product
1112from typing import Any , TypedDict
12- import inspect
1313
1414import ray
1515import torch
@@ -568,54 +568,58 @@ def _get_config_dtype_str_compatible(
568568 use_fp8_w8a8 : bool = False ,
569569 use_int8_w8a16 : bool = False ,
570570 use_int4_w4a16 : bool = False ,
571- ** kwargs
571+ ** kwargs ,
572572) -> str | None :
573573 """Multi-level import fallback for _get_config_dtype_str function."""
574574 try :
575575 from vllm .model_executor .layers .fused_moe .config import (
576- _get_config_dtype_str as _original_func
576+ _get_config_dtype_str as _original_func ,
577577 )
578+
578579 return _original_func (
579580 dtype ,
580581 use_fp8_w8a8 = use_fp8_w8a8 ,
581582 use_int8_w8a16 = use_int8_w8a16 ,
582583 use_int4_w4a16 = use_int4_w4a16 ,
583- ** kwargs
584+ ** kwargs ,
584585 )
585586 except ImportError :
586587 try :
587588 from vllm .model_executor .layers .fused_moe import (
588- _get_config_dtype_str as _original_func
589+ _get_config_dtype_str as _original_func ,
589590 )
591+
590592 return _original_func (
591593 dtype ,
592594 use_fp8_w8a8 = use_fp8_w8a8 ,
593595 use_int8_w8a16 = use_int8_w8a16 ,
594596 use_int4_w4a16 = use_int4_w4a16 ,
595- ** kwargs
597+ ** kwargs ,
596598 )
597599 except ImportError :
598600 try :
599601 from vllm .model_executor .layers .fused_moe .layer import (
600- _get_config_dtype_str as _original_func
602+ _get_config_dtype_str as _original_func ,
601603 )
604+
602605 return _original_func (
603606 dtype ,
604607 use_fp8_w8a8 = use_fp8_w8a8 ,
605608 use_int8_w8a16 = use_int8_w8a16 ,
606609 use_int4_w4a16 = use_int4_w4a16 ,
607- ** kwargs
610+ ** kwargs ,
608611 )
609612 except ImportError :
610613 try :
611614 from vllm .model_executor .layers .fused_moe import FusedMoE
612- if hasattr (FusedMoE , '_get_config_dtype_str' ):
613- return getattr (FusedMoE , '_get_config_dtype_str' )(
615+
616+ if hasattr (FusedMoE , "_get_config_dtype_str" ):
617+ return FusedMoE ._get_config_dtype_str (
614618 dtype ,
615619 use_fp8_w8a8 = use_fp8_w8a8 ,
616620 use_int8_w8a16 = use_int8_w8a16 ,
617621 use_int4_w4a16 = use_int4_w4a16 ,
618- ** kwargs
622+ ** kwargs ,
619623 )
620624 except ImportError :
621625 pass
@@ -632,35 +636,35 @@ def _get_config_dtype_str_compatible(
632636 return "float32"
633637 return None
634638
639+
635640def make_quant_config_compatible (
636641 quant_dtype , w1_scale , w2_scale , a1_scale , a2_scale , block_quant_shape
637642):
638643 """Compatible wrapper for FusedMoEQuantConfig.make() across vLLM versions."""
639- from vllm .model_executor .layers .fused_moe .config import FusedMoEQuantConfig
640644 if quant_dtype is None :
641645 return None
642646 param_combinations = [
643647 {
644- ' quant_dtype' : quant_dtype ,
645- ' w1_scale' : w1_scale ,
646- ' w2_scale' : w2_scale ,
647- ' a1_scale' : a1_scale ,
648- ' a2_scale' : a2_scale ,
649- ' block_quant_shape' : block_quant_shape ,
648+ " quant_dtype" : quant_dtype ,
649+ " w1_scale" : w1_scale ,
650+ " w2_scale" : w2_scale ,
651+ " a1_scale" : a1_scale ,
652+ " a2_scale" : a2_scale ,
653+ " block_quant_shape" : block_quant_shape ,
650654 },
651655 {
652- ' quant_dtype' : quant_dtype ,
653- ' w1_scale' : w1_scale ,
654- ' w2_scale' : w2_scale ,
655- ' a1_scale' : a1_scale ,
656- ' a2_scale' : a2_scale ,
656+ " quant_dtype" : quant_dtype ,
657+ " w1_scale" : w1_scale ,
658+ " w2_scale" : w2_scale ,
659+ " a1_scale" : a1_scale ,
660+ " a2_scale" : a2_scale ,
657661 },
658662 {
659- ' dtype' : quant_dtype ,
660- ' w1_scale' : w1_scale ,
661- ' w2_scale' : w2_scale ,
662- ' a1_scale' : a1_scale ,
663- ' a2_scale' : a2_scale ,
663+ " dtype" : quant_dtype ,
664+ " w1_scale" : w1_scale ,
665+ " w2_scale" : w2_scale ,
666+ " a1_scale" : a1_scale ,
667+ " a2_scale" : a2_scale ,
664668 },
665669 ]
666670 for params in param_combinations :
@@ -673,6 +677,7 @@ def make_quant_config_compatible(
673677 "Unable to create FusedMoEQuantConfig with any known parameter combination."
674678 )
675679
680+
676681def fused_experts_compatible (
677682 x ,
678683 w1 ,
@@ -685,14 +690,16 @@ def fused_experts_compatible(
685690):
686691 """Compatible wrapper for fused_experts function."""
687692 from vllm .model_executor .layers .fused_moe .fused_moe import fused_experts
693+
688694 sig = inspect .signature (fused_experts )
689- kwargs = {' inplace' : inplace }
690- if ' quant_config' in sig .parameters :
691- kwargs [' quant_config' ] = quant_config
692- if ' allow_deep_gemm' in sig .parameters :
693- kwargs [' allow_deep_gemm' ] = allow_deep_gemm
695+ kwargs = {" inplace" : inplace }
696+ if " quant_config" in sig .parameters :
697+ kwargs [" quant_config" ] = quant_config
698+ if " allow_deep_gemm" in sig .parameters :
699+ kwargs [" allow_deep_gemm" ] = allow_deep_gemm
694700 return fused_experts (x , w1 , w2 , topk_weights , topk_ids , ** kwargs )
695701
702+
696703def main (args : argparse .Namespace ):
697704 print (args )
698705
0 commit comments