
Commit 178ca16

Adopt inductor fusion and define quantization fusion pass (#4168)
### What this PR does / why we need it?

The main goal of this PR is to alleviate the high maintenance burden caused by model duplication during model optimization. Some of our optimized models diverge only slightly from vLLM's modeling code, yet still require rewriting several parts of the original, which adds maintenance burden to vllm-ascend. To solve that, we propose to leverage `torch.compile` and the inductor pattern matcher to automatically fuse the patterns we want to merge. For more details, refer to RFC #4239.

This PR fuses the `AddRMSNorm` and `Quant` operators, which improves the inference speed of models using `w8a8` quantization.

### Does this PR introduce _any_ user-facing change?

Yes, a new `additional_config` option is added.

### How was this patch tested?

```python
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The president of the United States is Mr.",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100,
                                     temperature=0.6,
                                     top_k=40,
                                     top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="/root/.cache/modelscope/hub/models/vllm-ascend/Qwen3-8B-W8A8",
        # enforce_eager=True,
        tensor_parallel_size=1,
        trust_remote_code=True,
        gpu_memory_utilization=0.7,
        quantization="ascend",
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```

```text
Prompt: 'The president of the United States is Mr.', Generated text: ' Trump. The president of the United States is Mr. Biden. Which of the following statements is correct? \n\nA. Mr. Trump is Mr. Biden. \nB. Mr. Trump is not Mr. Biden. \nC. The president of the United States is not Mr. Trump. \nD. The president of the United States is not Mr. Biden.\n\nThe question presents a contradiction: it states that "The president of the United States is Mr. Trump" and "The president of'
```

- vLLM version: 86e178f7c4d8c3b0eaf3c8e3f810a83f63b90e24
- vLLM main: vllm-project/vllm@86e178f

---------

Signed-off-by: Icey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
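For readers unfamiliar with the inductor pattern matcher, the sketch below shows the general shape of such a fusion: a "search" function describing the unfused op sequence and a "replacement" function describing the fused kernel are traced and registered, and every match in the FX graph is rewritten. This is a minimal illustration under assumptions, not the pass added by this PR (that lives in `vllm_ascend/compilation/passes/quant_fusion_pass.py`); in particular, the argument layout of `npu_add_rms_norm_quant` and the example shapes below are assumed.

```python
# Illustrative sketch only: the real fusion pass is
# vllm_ascend/compilation/passes/quant_fusion_pass.py. The argument order of
# npu_add_rms_norm_quant below is an assumption for demonstration.
import torch
import torch_npu  # noqa: F401  (registers the torch.ops.npu namespace)
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)

fusion_pass = PatternMatcherPass()


def search_pattern(x, residual, weight, scale, offset):
    # Unfused sequence: AddRMSNorm followed by a separate int8 quantize.
    norm_out, _, new_residual = torch.ops.npu.npu_add_rms_norm(
        x, residual, weight, 1e-6)
    quantized = torch.ops.npu.npu_quantize(norm_out, scale, offset,
                                           torch.qint8, -1, False)
    return quantized, new_residual


def replacement(x, residual, weight, scale, offset):
    # Fused kernel: add + RMSNorm + quantize in one op (argument order assumed).
    quantized, _, new_residual = torch.ops.npu.npu_add_rms_norm_quant(
        x, residual, weight, scale, offset, epsilon=1e-6)
    return quantized, new_residual


def register_add_rms_norm_quant_fusion(hidden_size: int = 64):
    # The example tensors are only used to trace the two functions above.
    example_inputs = [
        torch.empty(4, hidden_size, dtype=torch.float16, device="npu"),  # x
        torch.empty(4, hidden_size, dtype=torch.float16, device="npu"),  # residual
        torch.empty(hidden_size, dtype=torch.float16, device="npu"),     # weight
        torch.tensor([1.0], device="npu"),                               # scale
        torch.tensor([0.0], device="npu"),                               # offset
    ]
    register_replacement(search_pattern, replacement, example_inputs,
                         fwd_only, fusion_pass)
```

Once registered, calling `fusion_pass.apply(graph)` inside a custom compilation pass rewrites every matched occurrence. In the test added by this commit, the pass is wired into inductor through `compilation_config.inductor_compile_config["graph_fusion_manager"]`, and the FX graph is snapshotted before and after to verify the rewrite.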
1 parent c4a71fc commit 178ca16

File tree

13 files changed: +595 -269 lines changed
Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from copy import deepcopy
from typing import Any, Callable, List, Optional, Sequence

import pytest
import torch
import torch.fx as fx
import torch.nn as nn
import torch_npu
import vllm.config
from torch._inductor.decomposition import select_decomp_table
from vllm.compilation.fx_utils import OpOverload
from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config

from vllm_ascend.compilation.compiler_interface import compile_fx
from vllm_ascend.compilation.passes.quant_fusion_pass import \
    AddRMSNormQuantFusionPass


class TestModel(nn.Module):
    """
    A minimal test model that simulates the pattern:
    AddRMSNorm → Quantization
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.rms_norm_weight = nn.Parameter(
            torch.randn(hidden_size, device=device))
        self.quant_scale = torch.tensor([1.0], device=device)
        self.quant_offset = torch.tensor([0.0], device=device)

    def forward(self, x):
        """
        Forward pass:
        1. Perform npu_add_rms_norm
        2. Quantize the normalized output to int8
        Returns both quantized output and updated residual.
        """
        residual = torch.zeros_like(x)

        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
            x, residual, self.rms_norm_weight, self.eps)

        quantized_output = torch_npu.npu_quantize(norm_output,
                                                  self.quant_scale,
                                                  self.quant_offset,
                                                  torch.qint8, -1, False)

        return quantized_output, new_residual

    def ops_in_model_before(self) -> List[OpOverload]:
        """Return the list of expected operators BEFORE fusion."""
        return [
            torch.ops.npu.npu_add_rms_norm.default,
            torch.ops.npu.npu_quantize.default
        ]

    def ops_in_model_after(self) -> List[OpOverload]:
        """Return the list of expected operators AFTER successful fusion."""
        return [torch.ops.npu.npu_add_rms_norm_quant.default]


class TestBackend:
    """
    A custom compilation backend for testing operator fusion passes.
    It applies the AddRMSNormQuantFusionPass during graph compilation and
    records the FX graph before and after the transformation.
    """

    def __init__(self):
        vllm_config = get_current_vllm_config()
        compile_config = vllm_config.compilation_config
        self.custom_passes = [
            AddRMSNormQuantFusionPass(vllm_config=vllm_config)
        ]
        self.inductor_config = compile_config.inductor_compile_config
        self.inductor_config["graph_fusion_manager"] = self.post_pass

        # Placeholders to store FX graphs for verification
        self.graph_pre_pass = None
        self.graph_post_pass = None

    def post_pass(self,
                  graph: fx.Graph,
                  runtime_shape: int | None = None) -> fx.Graph:
        """
        Apply custom graph transformation passes.
        """
        self.graph_pre_pass = deepcopy(graph)
        for pass_ in self.custom_passes:
            pass_(graph)
        self.graph_post_pass = deepcopy(graph)
        return graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """
        Compile the FX graph using vLLM's Ascend compiler interface.
        Wraps the post-pass logic into the inner_compile callback.
        """

        def compile_inner(graph, example_inputs):
            current_pass_manager = compiler_config["graph_fusion_manager"]
            return current_pass_manager(graph, runtime_shape)

        decompositions = select_decomp_table()
        compiled_fn = compile_fx(
            graph=graph,
            example_inputs=example_inputs,
            inner_compile=compile_inner,
            decompositions=decompositions,
        )
        return compiled_fn, None

    def __call__(self, gm: fx.GraphModule, example_inputs: List[Any]):
        """
        Make the backend callable by torch.compile().
        Returns a compiled executable function.
        """
        compiled_fn, _ = self.compile(
            gm,
            example_inputs,
            compiler_config={"graph_fusion_manager": self.post_pass},
            runtime_shape=None,
            key=None,
        )
        return compiled_fn

    def find_nodes_by_target(self, graph: fx.GraphModule,
                             target: OpOverload) -> List[fx.Node]:
        """Helper to find all FX nodes that call a specific operator."""
        return [
            node for node in graph.graph.nodes
            if hasattr(node, 'target') and node.target == target
        ]

    def check_before_ops(self,
                         ops: Sequence[OpOverload],
                         fully_replaced: bool = True):
        """
        Verify that the original (unfused) operators exist before the pass
        and are fully removed afterward (if fully_replaced=True).
        """
        for op in ops:
            num_pre = len(self.find_nodes_by_target(self.graph_pre_pass, op))
            num_post = len(self.find_nodes_by_target(self.graph_post_pass, op))
            print(f"Op {op}: pre={num_pre}, post={num_post}")

            assert num_pre > 0, f"Op {op} not found in pre-pass graph"
            if fully_replaced:
                assert num_post == 0, f"Unexpected op {op} in post-pass graph: {num_post} nodes remain"

    def check_after_ops(self, ops: Sequence[OpOverload]):
        """Verify that the fused operator appears in the transformed graph."""
        for op in ops:
            num_post = len(self.find_nodes_by_target(self.graph_post_pass, op))
            print(f"Op {op}: post={num_post}")
            assert num_post > 0, f"Op {op} not found in post-pass graph"


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
                              num_tokens: int, eps: float):
    """
    End-to-end test for AddRMSNorm+Quantize fusion.
    Compares: Operator presence/absence before and after graph transformation
    """
    torch.set_default_dtype(dtype)
    torch.manual_seed(1)

    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))

    with vllm.config.set_current_vllm_config(vllm_config):
        backend = TestBackend()
        model = TestModel(hidden_size, eps, device="npu")
        model = model.to("npu")

        x = torch.rand(num_tokens,
                       hidden_size,
                       device="npu",
                       dtype=dtype,
                       requires_grad=False)

        result_unfused = model(x)
        print("Unfused result:", [t.shape for t in result_unfused])
        model_fused = torch.compile(model, backend=backend)
        result_fused = model_fused(x)
        print("Fused result:", [t.shape for t in result_fused])

        print("=== Checking operator fusion ===")
        backend.check_before_ops(model.ops_in_model_before())
        backend.check_after_ops(model.ops_in_model_after())
```
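The test above checks the fusion structurally rather than numerically: `check_before_ops` asserts that `npu_add_rms_norm` and `npu_quantize` appear in the graph captured before the pass and are gone afterwards, while `check_after_ops` asserts that the fused `npu_add_rms_norm_quant` node is present. The parametrization covers float16 and bfloat16 inputs with eps values of 1e-5 and 1e-6.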

0 commit comments