9 | 9 | from datetime import datetime |
10 | 10 | from itertools import product |
11 | 11 | from typing import Any, TypedDict |
| 12 | +import inspect |
12 | 13 |
13 | 14 | import ray |
14 | 15 | import torch |
15 | 16 | from ray.experimental.tqdm_ray import tqdm |
16 | 17 |
17 | 18 | from vllm.model_executor.layers.fused_moe.config import ( |
18 | 19 | FusedMoEQuantConfig, |
19 | | - _get_config_dtype_str, |
20 | 20 | ) |
21 | 21 | from vllm.model_executor.layers.fused_moe.fused_moe import * |
22 | 22 | from vllm.platforms import current_platform |
@@ -145,20 +145,15 @@ def run(): |
145 | 145 | else: |
146 | 146 | quant_dtype = None |
147 | 147 |
148 | | - quant_config = FusedMoEQuantConfig.make( |
149 | | - quant_dtype=quant_dtype, |
150 | | - w1_scale=w1_scale, |
151 | | - w2_scale=w2_scale, |
152 | | - a1_scale=a1_scale, |
153 | | - a2_scale=a2_scale, |
154 | | - block_shape=block_quant_shape, |
| 148 | + quant_config = make_quant_config_compatible( |
| 149 | + quant_dtype, w1_scale, w2_scale, a1_scale, a2_scale, block_quant_shape |
155 | 150 | ) |
156 | 151 |
157 | 152 | with override_config(config): |
158 | 153 | topk_weights, topk_ids, token_expert_indices = fused_topk( |
159 | 154 | x, input_gating, topk, renormalize=not use_deep_gemm |
160 | 155 | ) |
161 | | - return fused_experts( |
| 156 | + return fused_experts_compatible( |
162 | 157 | x, |
163 | 158 | w1, |
164 | 159 | w2, |
@@ -411,7 +406,7 @@ def benchmark( |
411 | 406 | use_deep_gemm: bool = False, |
412 | 407 | ) -> tuple[dict[str, int], float]: |
413 | 408 | current_platform.seed_everything(self.seed) |
414 | | - dtype_str = _get_config_dtype_str( |
| 409 | + dtype_str = _get_config_dtype_str_compatible( |
415 | 410 | dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 |
416 | 411 | ) |
417 | 412 | # NOTE(woosuk): The current naming convention uses w2.shape[2], which |
@@ -568,6 +563,78 @@ def get_weight_block_size_safety(config, default_value=None): |
568 | 563 | return default_value |
569 | 564 |
570 | 565 |
| 566 | +def _get_config_dtype_str_compatible(dtype, **quant_flags): |
| 567 | +    """Multi-level import fallback for the upstream _get_config_dtype_str helper.""" |
| 568 | +    try: |
| 569 | +        from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str as _original_func |
| 570 | +        return _original_func(dtype, **quant_flags) |
| 571 | +    except ImportError: |
| 572 | +        try: |
| 573 | +            from vllm.model_executor.layers.fused_moe import _get_config_dtype_str as _original_func |
| 574 | +            return _original_func(dtype, **quant_flags) |
| 575 | +        except ImportError: |
| 576 | +            try: |
| 577 | +                from vllm.model_executor.layers.fused_moe.layer import _get_config_dtype_str as _original_func |
| 578 | +                return _original_func(dtype, **quant_flags) |
| 579 | +            except ImportError: |
| 580 | +                pass |
| 581 | +    # Last resort: mirror the dtype tags used in tuned-config file names. |
| 582 | +    if quant_flags.get("use_fp8_w8a8"): |
| 583 | +        return "fp8_w8a8" |
| 584 | +    if quant_flags.get("use_int8_w8a16"): |
| 585 | +        return "int8_w8a16" |
| 586 | +    if dtype == torch.float: |
| 587 | +        return "float32" |
| 588 | +    return None |
| 589 | + |
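For context, a minimal usage sketch of this fallback helper as the benchmark invokes it; the dtype and flag values below are illustrative, not taken from the PR:

```python
import torch

# Illustrative call mirroring benchmark(): the model dtype plus the
# quantization flags being tuned (values here are hypothetical).
dtype_str = _get_config_dtype_str_compatible(
    torch.float16, use_int8_w8a16=False, use_fp8_w8a8=True
)
# Expected to resolve the upstream helper and yield the "fp8_w8a8" tag used in
# tuned-config file names; the local fallback mirrors that naming.
```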
| 590 | +def make_quant_config_compatible(quant_dtype, w1_scale, w2_scale, a1_scale, a2_scale, block_quant_shape): |
| 591 | + """Compatible wrapper for FusedMoEQuantConfig.make() across vLLM versions.""" |
| 592 | + from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig |
| 593 | + if quant_dtype is None: |
| 594 | + return None |
| 595 | + param_combinations = [ |
| 596 | + { |
| 597 | + 'quant_dtype': quant_dtype, |
| 598 | + 'w1_scale': w1_scale, |
| 599 | + 'w2_scale': w2_scale, |
| 600 | + 'a1_scale': a1_scale, |
| 601 | + 'a2_scale': a2_scale, |
| 602 | +            'block_shape': block_quant_shape, |
| 603 | + }, |
| 604 | + { |
| 605 | + 'quant_dtype': quant_dtype, |
| 606 | + 'w1_scale': w1_scale, |
| 607 | + 'w2_scale': w2_scale, |
| 608 | + 'a1_scale': a1_scale, |
| 609 | + 'a2_scale': a2_scale, |
| 610 | + }, |
| 611 | + { |
| 612 | + 'dtype': quant_dtype, |
| 613 | + 'w1_scale': w1_scale, |
| 614 | + 'w2_scale': w2_scale, |
| 615 | + 'a1_scale': a1_scale, |
| 616 | + 'a2_scale': a2_scale, |
| 617 | + }, |
| 618 | + ] |
| 619 | + for params in param_combinations: |
| 620 | + filtered_params = {k: v for k, v in params.items() if v is not None} |
| 621 | + try: |
| 622 | + return FusedMoEQuantConfig.make(**filtered_params) |
| 623 | + except TypeError: |
| 624 | + continue |
| 625 | + raise TypeError("Unable to create FusedMoEQuantConfig with any known parameter combination.") |
| 626 | + |
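The try-each-signature idea generalizes beyond this one config class. Below is a self-contained toy sketch of the same pattern; the factory is a stand-in, not the real FusedMoEQuantConfig.make:

```python
def call_with_first_matching_kwargs(factory, candidates):
    """Return factory(**kwargs) for the first candidate whose keywords fit."""
    last_err = None
    for kwargs in candidates:
        try:
            return factory(**kwargs)
        except TypeError as err:  # signature mismatch: try the next spelling
            last_err = err
    raise TypeError(f"no compatible signature found: {last_err}")

def make_config(dtype=None, block_shape=None):  # toy stand-in factory
    return {"dtype": dtype, "block_shape": block_shape}

# The old spelling ('quant_dtype') raises TypeError, so the new one is used.
cfg = call_with_first_matching_kwargs(
    make_config, [{"quant_dtype": "fp8"}, {"dtype": "fp8"}]
)
assert cfg == {"dtype": "fp8", "block_shape": None}
```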
| 627 | +def fused_experts_compatible(x, w1, w2, topk_weights, topk_ids, inplace=True, quant_config=None, allow_deep_gemm=False): |
| 628 | + """Compatible wrapper for fused_experts function.""" |
| 629 | + from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts |
| 630 | + sig = inspect.signature(fused_experts) |
| 631 | + kwargs = {'inplace': inplace} |
| 632 | + if 'quant_config' in sig.parameters: |
| 633 | + kwargs['quant_config'] = quant_config |
| 634 | + if 'allow_deep_gemm' in sig.parameters: |
| 635 | + kwargs['allow_deep_gemm'] = allow_deep_gemm |
| 636 | + return fused_experts(x, w1, w2, topk_weights, topk_ids, **kwargs) |
| 637 | + |
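Similarly, the inspect-based keyword filtering can be shown in isolation; the kernel below is a stand-in for fused_experts, purely for illustration:

```python
import inspect

def old_style_kernel(x, w1, w2, inplace=False):  # pre-quant_config signature
    return x

def call_adaptively(kernel, x, w1, w2, **optional_kwargs):
    """Forward only the keyword arguments the target kernel accepts."""
    accepted = inspect.signature(kernel).parameters
    kwargs = {k: v for k, v in optional_kwargs.items() if k in accepted}
    return kernel(x, w1, w2, **kwargs)

# quant_config is silently dropped for the old kernel; inplace is forwarded.
out = call_adaptively(
    old_style_kernel, 1.0, None, None, inplace=True, quant_config=None
)
assert out == 1.0
```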
571 | 638 | def main(args: argparse.Namespace): |
572 | 639 | print(args) |
573 | 640 |
@@ -664,8 +731,8 @@ def main(args: argparse.Namespace): |
664 | 731 |
665 | 732 | if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ: |
666 | 733 | # Ray will set ROCR_VISIBLE_DEVICES for device visibility |
667 | | - logger.warning( |
668 | | - "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility." |
| 734 | + print( |
| 735 | + "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. " |
669 | 736 | "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES." |
670 | 737 | ) |
671 | 738 | val = os.environ["HIP_VISIBLE_DEVICES"] |