
Commit bc67696 (1 parent: 18eefc2)

Authored by 845473182白永斌 and co-authors
[EPLB][Ops] Integrate grouped_matmul_swiglu_quant_weight_nz_tensor_list operator into dynamic EPLB (#4216)
### What this PR does / why we need it?

Integrate the `grouped_matmul_swiglu_quant_weight_nz_tensor_list` operator into dynamic EPLB so that expert weights can be passed as list-type parameters. This PR also modifies the model-loading logic for the dynamic-EPLB scenario. The operator itself was introduced in #3804.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

```
vllm serve /home/weight/DeepSeek-V3.1_w8a8mix_mtp \
  --max_num_seqs 8 \
  --max-model-len 8192 \
  --max-num-batched-tokens 16384 \
  --tensor-parallel-size 8 \
  --data-parallel-size 2 \
  --enable-expert-parallel \
  --served-model-name ds_r1 \
  --enable-auto-tool-choice \
  --tool-call-parser hermes \
  --no-enable-prefix-caching \
  --port 8999 \
  --quantization "ascend" \
  --gpu-memory-utilization 0.85 \
  --trust-remote-code \
  --compilation_config '{"cudagraph_capture_sizes":[1,2,4,8,16,32]}' \
  --additional-config='{"dynamic_eplb":true, "num_iterations_eplb_update":100, "num_wait_worker_iterations":100}'
```

Input/output lengths: 2k / 2k.

This PR:

![fusion](https://github.com/user-attachments/assets/f8657813-0c02-42f4-8396-d99e730f48cd)

Baseline:

![baseline](https://github.com/user-attachments/assets/e1323a78-af26-4523-820c-e20e5642a38e)

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: 白永斌 <[email protected]>
Signed-off-by: 欧派果奶我还要 <[email protected]>
Co-authored-by: 白永斌 <[email protected]>
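As context for the diffs below, here is a minimal PyTorch-only sketch (illustrative names and shapes, not the PR's real modules) of the layout change that dynamic EPLB relies on: keeping one tensor per local expert instead of a single stacked `[num_experts, ...]` parameter lets a rebalance step overwrite one expert's weights in place, and the per-expert list is the form the new list-based operator accepts.

```python
import torch

num_experts, hidden_size, intermediate_size = 4, 8, 16

# Stacked layout (non-EPLB path): one contiguous tensor holds all local experts.
w13_stacked = torch.randn(num_experts, hidden_size, 2 * intermediate_size)

# List layout (dynamic-EPLB path): one independent tensor per local expert,
# which is the parameter form the list-based grouped-matmul operator consumes.
w13_list = [w.clone() for w in w13_stacked.unbind(dim=0)]

# A rebalance that moves a different expert onto this rank only overwrites a
# single list entry in place; the other experts' tensors are untouched.
incoming_expert = torch.randn(hidden_size, 2 * intermediate_size)
w13_list[2].copy_(incoming_expert)
```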

File tree: 6 files changed (+139, -50 lines)

tests/ut/ops/test_moe_comm_method.py (2 additions, 2 deletions)

```diff
@@ -226,8 +226,8 @@ def test_fused_experts_method(self, mock_unified_apply_mlp,
         w2 = w2.contiguous()

         result = comm_impl.fused_experts(hidden_states=hidden_states,
-                                         w1=w1,
-                                         w2=w2,
+                                         w1=[w1],
+                                         w2=[w2],
                                          topk_weights=topk_weights,
                                          topk_ids=topk_ids,
                                          activation="silu")
```

vllm_ascend/eplb/adaptor/vllm_adaptor.py (37 additions, 10 deletions)

```diff
@@ -44,11 +44,22 @@ def __init__(self, model, **args):
         self.init_redundancy_expert = get_ascend_config(
         ).init_redundancy_expert

+        for i in range(self.num_dense_layers,
+                       self.model.config.num_hidden_layers):
+            self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = \
+                self.model.model.layers[i].mlp.experts.w13_weight_list
+            self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = \
+                self.model.model.layers[i].mlp.experts.w2_weight_list
+            self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = \
+                self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
+            self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \
+                self.model.model.layers[i].mlp.experts.w2_weight_scale_list
         # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
         if self.model.quant_config is not None:
             self.expert_weight_names = [
-                "w13_weight", "w2_weight", "w13_weight_scale",
-                "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
+                "w13_weight_list", "w2_weight_list",
+                "w13_weight_scale_fp32_list", "w13_weight_offset",
+                "w2_weight_scale_list", "w2_weight_offset"
             ]
         else:
             self.expert_weight_names = ["w13_weight", "w2_weight"]
@@ -84,9 +95,14 @@ def init_buffer_tensor(self, num_buffer_tensor):
             for name in self.expert_weight_names:
                 complete_name = "model.layers." + str(
                     self.num_dense_layers) + ".mlp.experts." + name
-                expert_tensor = self.param_dict[complete_name].data[0]
-                if name in ["w13_weight", "w2_weight"]:
+                if name in [
+                        "w13_weight_list", "w2_weight_list",
+                        "w13_weight_scale_fp32_list", "w2_weight_scale_list"
+                ]:
+                    expert_tensor = self.param_dict[complete_name][0]
                     expert_tensor = expert_tensor.clone()
+                else:
+                    expert_tensor = self.param_dict[complete_name][0].data[0]
                 buffer_tensor = torch.empty_like(expert_tensor)
                 self.buffer_tensor_list[buffer_id].append(buffer_tensor)

@@ -97,12 +113,23 @@ def init_expert_param_per_layer(self):
             layer_idx = self.num_dense_layers + moe_layer_id
             self.expert_param_per_layer[layer_idx] = list()
             for local_expert_id in range(num_local_expert):
-                self.expert_param_per_layer[layer_idx].append([
-                    self.param_dict["model.layers." + str(layer_idx) +
-                                    ".mlp.experts." +
-                                    name].data[local_expert_id]
-                    for name in self.expert_weight_names
-                ])
+                per_expert_param = list()
+                for name in self.expert_weight_names:
+                    if name in [
+                            "w13_weight_list", "w2_weight_list",
+                            "w13_weight_scale_fp32_list",
+                            "w2_weight_scale_list"
+                    ]:
+                        per_expert_param.append(
+                            self.param_dict["model.layers." + str(layer_idx) +
+                                            ".mlp.experts." +
+                                            name][local_expert_id])
+                    else:
+                        per_expert_param.append(
+                            self.param_dict["model.layers." + str(layer_idx) +
+                                            ".mlp.experts." +
+                                            name][0].data[local_expert_id])
+                self.expert_param_per_layer[layer_idx].append(per_expert_param)

     def get_rank_expert_workload(self) -> torch.Tensor:
         self.moe_load = self.model.get_all_moe_loads()
```
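The adaptor above distinguishes list-type entries (already one tensor per expert) from the remaining stacked parameters when sizing its staging buffers. A hedged standalone sketch of that idea follows (hypothetical helper and shapes, not the adaptor class itself):

```python
import torch

LIST_TYPE_NAMES = {
    "w13_weight_list", "w2_weight_list",
    "w13_weight_scale_fp32_list", "w2_weight_scale_list",
}

def make_buffer(param_dict: dict, name: str) -> torch.Tensor:
    """Allocate a staging buffer shaped like a single expert's weight."""
    if name in LIST_TYPE_NAMES:
        # List-type entry: element 0 is already one expert's tensor.
        template = param_dict[name][0]
    else:
        # Stacked parameter: index expert 0 along the expert dimension.
        template = param_dict[name].data[0]
    return torch.empty_like(template)

# Illustrative usage with made-up shapes.
params = {
    "w2_weight_list": [torch.randn(16, 8) for _ in range(4)],
    "w2_weight_offset": torch.nn.Parameter(torch.zeros(4, 16)),
}
buffers = {name: make_buffer(params, name) for name in params}
assert buffers["w2_weight_list"].shape == (16, 8)
assert buffers["w2_weight_offset"].shape == (16,)
```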

vllm_ascend/ops/fused_moe/moe_comm_method.py (4 additions, 4 deletions)

```diff
@@ -83,8 +83,8 @@ def finalize(self,
     def fused_experts(
         self,
         hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
+        w1: torch.Tensor | list[torch.Tensor],
+        w2: torch.Tensor | list[torch.Tensor],
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
@@ -93,8 +93,8 @@ def fused_experts(
         use_int4_w4a8: bool = False,
         global_num_experts: Optional[int] = None,
         expert_map: Optional[torch.Tensor] = None,
-        w1_scale: Optional[torch.Tensor] = None,
-        w2_scale: Optional[torch.Tensor] = None,
+        w1_scale: Optional[list[torch.Tensor]] = None,
+        w2_scale: Optional[list[torch.Tensor]] = None,
        w1_scale_bias: torch.Tensor = None,
        w2_scale_bias: torch.Tensor = None,
        # For TorchAir graph
```
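With the widened signature, `w1`/`w2` may be either a single stacked tensor or a per-expert list, and the updated call sites wrap legacy tensors in a one-element list. A small hypothetical helper (not part of the PR) showing the normalization convention this implies:

```python
import torch

def as_weight_list(w: torch.Tensor | list[torch.Tensor]) -> list[torch.Tensor]:
    """Normalize a weight argument to the list form used downstream."""
    return w if isinstance(w, list) else [w]

stacked = torch.randn(4, 8, 32)                      # classic call site: one stacked tensor
per_expert = list(torch.randn(4, 8, 32).unbind(0))   # dynamic-EPLB call site: per-expert list

wrapped = as_weight_list(stacked)
assert len(wrapped) == 1 and wrapped[0] is stacked
assert as_weight_list(per_expert) is per_expert
```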

vllm_ascend/ops/fused_moe/moe_mlp.py (55 additions, 26 deletions)

```diff
@@ -23,7 +23,11 @@

 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
-                               get_ascend_device_type)
+                               enable_custom_op, get_ascend_device_type)
+
+
+def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+    return fusion and dynamic_eplb and enable_custom_op()


 def cumsum_group_list(group_list: torch.Tensor,
@@ -55,10 +59,10 @@ def cumsum_group_list(group_list: torch.Tensor,


 def quant_apply_mlp(hidden_states: torch.Tensor,
-                    w1: torch.Tensor,
-                    w1_scale: torch.Tensor,
-                    w2: torch.Tensor,
-                    w2_scale: torch.Tensor,
+                    w1: list[torch.Tensor],
+                    w1_scale: list[torch.Tensor],
+                    w2: list[torch.Tensor],
+                    w2_scale: list[torch.Tensor],
                     group_list: torch.Tensor,
                     group_list_type: int = 1,
                     dynamic_scale: torch.Tensor = None,
@@ -79,31 +83,42 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         quantized_hidden_states = hidden_states

     bias1, bias2 = None, None
-    _output_dtype = w2_scale.dtype
+    _output_dtype = w2_scale[0].dtype

     weight_prefetch_method = get_forward_context().weight_prefetch_method
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
             hidden_states)
     is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
     if w1_scale_bias is None and is_mc2:
-        if fusion and not dynamic_eplb:
+        if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+            # gmm1: gate_up_proj & act_fn: swiglu
+            hidden_states, swiglu_out_scale, _ = (
+                torch.ops._C_ascend.
+                grouped_matmul_swiglu_quant_weight_nz_tensor_list(
+                    x=hidden_states,
+                    weight=w1,
+                    weight_scale=w1_scale,
+                    x_scale=pertoken_scale,
+                    group_list=cumsum_group_list(group_list, group_list_type),
+                ))
+        elif fusion and not dynamic_eplb:
             # gmm1: gate_up_proj & act_fn: swiglu
             hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
                 x=hidden_states,
-                weight=w1,
+                weight=w1[0],
                 group_list=cumsum_group_list(group_list, group_list_type),
-                weight_scale=w1_scale,
+                weight_scale=w1_scale[0],
                 x_scale=pertoken_scale)
             if quantized_hidden_states is not None:
                 dispose_tensor(quantized_hidden_states)
         else:
-            if w1_scale.dtype != torch.float32:
-                w1_scale = w1_scale.to(torch.float32)
+            if w1_scale[0].dtype != torch.float32:
+                w1_scale[0] = w1_scale[0].to(torch.float32)
             # gmm1: gate_up_proj
             hidden_states = torch_npu.npu_grouped_matmul(
                 x=[hidden_states],
-                weight=[w1],
+                weight=w1,
                 split_item=3,
                 group_list_type=group_list_type,
                 group_type=0,
@@ -126,14 +141,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
-            weight=[w2],
-            scale=[w2_scale],
+            weight=w2,
+            scale=w2_scale,
             per_token_scale=[swiglu_out_scale],
             split_item=2,
             group_list_type=group_list_type,
             group_type=0,
             group_list=group_list,
-            output_dtype=w2_scale.dtype)[0]
+            output_dtype=w2_scale[0].dtype)[0]
     else:
         if w1_scale_bias is not None:
             if group_list_type == 0:
@@ -146,23 +161,36 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             # TODO w4a8 scene: dynamic acquisition of dtype in the future
             _output_dtype = torch.bfloat16

-        if fusion and not dynamic_eplb:
+        if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+            # gmm1: gate_up_proj & act_fn: swiglu
+            hidden_states, swiglu_out_scale, _ = (
+                torch.ops._C_ascend.
+                grouped_matmul_swiglu_quant_weight_nz_tensor_list(
+                    x=hidden_states,
+                    weight=w1,
+                    weight_scale=w1_scale,
+                    x_scale=pertoken_scale,
+                    group_list=cumsum_group_list(group_list, group_list_type),
+                    bias=bias1,
+                ))
+        elif fusion and not dynamic_eplb:
             # gmm1: gate_up_proj & act_fn: swiglu
             hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
                 x=hidden_states,
-                weight=w1,
+                weight=w1[0],
                 bias=bias1,
                 group_list=cumsum_group_list(group_list, group_list_type),
-                weight_scale=w1_scale,
+                weight_scale=w1_scale[0],
                 x_scale=pertoken_scale)
             if quantized_hidden_states is not None:
                 dispose_tensor(quantized_hidden_states)
         else:
+            w1_scale[0] = w1_scale[0].to(w2_scale[0].dtype)
             # gmm1: gate_up_proj
             hidden_states = torch_npu.npu_grouped_matmul(
                 x=[hidden_states],
-                weight=[w1],
-                scale=[w1_scale.to(w2_scale.dtype)],
+                weight=w1,
+                scale=w1_scale,
                 bias=bias1,
                 per_token_scale=[pertoken_scale],
                 split_item=2,
@@ -179,8 +207,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
-            weight=[w2],
-            scale=[w2_scale],
+            weight=w2,
+            scale=w2_scale,
             bias=bias2,
             per_token_scale=[swiglu_out_scale],
             split_item=2,
@@ -232,11 +260,11 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,


 def unified_apply_mlp(hidden_states: torch.Tensor,
-                      w1: torch.Tensor,
-                      w1_scale: torch.Tensor,
-                      w2: torch.Tensor,
-                      w2_scale: torch.Tensor,
+                      w1: torch.Tensor | list[torch.Tensor],
+                      w2: torch.Tensor | list[torch.Tensor],
                       group_list: torch.Tensor,
+                      w1_scale: Optional[list[torch.Tensor]] = None,
+                      w2_scale: Optional[list[torch.Tensor]] = None,
                       dynamic_scale: torch.Tensor = None,
                       group_list_type: int = 1,
                       w1_scale_bias: torch.Tensor = None,
@@ -247,6 +275,7 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
                       need_trans: bool = True,
                       dynamic_eplb: bool = False) -> torch.Tensor:
     if with_quant:
+        assert w1_scale is not None and w2_scale is not None
         return quant_apply_mlp(hidden_states=hidden_states,
                                w1=w1,
                                w1_scale=w1_scale,
```
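The gate_up_proj ("gmm1") stage now has three branches, chosen by `fusion`, `dynamic_eplb`, and whether the custom Ascend op is registered. A pure-Python sketch of that dispatch order (stand-in return strings; the real branches call `torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list`, `torch_npu.npu_grouped_matmul_swiglu_quant`, and `torch_npu.npu_grouped_matmul`):

```python
def _custom_gmm_swiglu_enabled(fusion: bool, dynamic_eplb: bool,
                               custom_op_available: bool) -> bool:
    # Mirrors the helper added above, with the enable_custom_op() check made an
    # explicit argument so this sketch stays self-contained.
    return fusion and dynamic_eplb and custom_op_available

def select_gmm1_path(fusion: bool, dynamic_eplb: bool,
                     custom_op_available: bool) -> str:
    if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb, custom_op_available):
        return "fused gmm1+swiglu on list weights (torch.ops._C_ascend custom op)"
    if fusion and not dynamic_eplb:
        return "fused gmm1+swiglu on stacked weight (npu_grouped_matmul_swiglu_quant)"
    return "separate gmm1 then swiglu (npu_grouped_matmul)"

for flags in [(True, True, True), (True, False, True),
              (True, True, False), (False, True, True)]:
    print(flags, "->", select_gmm1_path(*flags))
```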

vllm_ascend/quantization/w4a8_dynamic.py (4 additions, 4 deletions)

```diff
@@ -379,10 +379,10 @@ def apply(
         moe_comm_method = get_forward_context().moe_comm_method
         return moe_comm_method.fused_experts(
             hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
+            w1=[layer.w13_weight],
+            w2=[layer.w2_weight],
+            w1_scale=[layer.w13_weight_scale],
+            w2_scale=[layer.w2_weight_scale],
             w1_scale_bias=layer.w13_scale_bias,
             w2_scale_bias=layer.w2_scale_bias,
             topk_weights=topk_weights,
```

vllm_ascend/quantization/w8a8_dynamic.py (37 additions, 4 deletions)

```diff
@@ -236,13 +236,24 @@ def apply(
         topk_weights = topk_weights.to(self.in_dtype)

         moe_comm_method = get_forward_context().moe_comm_method
+        if self.dynamic_eplb:
+            w1 = layer.w13_weight_list
+            w1_scale = layer.w13_weight_scale_fp32_list
+            w2 = layer.w2_weight_list
+            w2_scale = layer.w2_weight_scale_list
+        else:
+            w1 = [layer.w13_weight]
+            w1_scale = [layer.w13_weight_scale_fp32]
+            w2 = [layer.w2_weight]
+            w2_scale = [layer.w2_weight_scale]
+
         return moe_comm_method.fused_experts(
             hidden_states=x,
             pertoken_scale=pertoken_scale,
-            w1=layer.w13_weight,
-            w1_scale=layer.w13_weight_scale_fp32,
-            w2=layer.w2_weight,
-            w2_scale=layer.w2_weight_scale,
+            w1=w1,
+            w1_scale=w1_scale,
+            w2=w2,
+            w2_scale=w2_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             use_int8_w8a8=True,
@@ -274,3 +285,25 @@ def process_weights_after_loading(self, layer):
             layer.w2_weight_scale.data.shape[0], -1)
         layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
             layer.w2_weight_offset.data.shape[0], -1)
+        if self.dynamic_eplb:
+            layer.w13_weight_list = [
+                weight.clone()
+                for weight in layer.w13_weight.data.unbind(dim=0)
+            ]
+            layer.w2_weight_list = [
+                weight.clone() for weight in layer.w2_weight.data.unbind(dim=0)
+            ]
+            layer.w13_weight_scale_fp32_list = [
+                weight.clone()
+                for weight in layer.w13_weight_scale.data.unbind(dim=0)
+            ]
+            layer.w2_weight_scale_list = [
+                weight.clone()
+                for weight in layer.w2_weight_scale.data.unbind(dim=0)
+            ]
+            del layer.w13_weight
+            del layer.w2_weight
+            del layer.w13_weight_scale
+            del layer.w13_weight_scale_fp32
+            del layer.w2_weight_scale
+            torch.npu.empty_cache()
```
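A standalone sketch (plain PyTorch, hypothetical function name and shapes) of the splitting step that `process_weights_after_loading` performs on the dynamic-EPLB path, and why `.clone()` matters before the stacked tensors are deleted:

```python
import torch

def split_stacked_weight(stacked: torch.Tensor) -> list[torch.Tensor]:
    """Split a [num_experts, ...] stacked weight into per-expert tensors.

    unbind(dim=0) alone returns views that still share the stacked storage;
    cloning gives every expert its own buffer, so the stacked parameter can be
    deleted afterwards and a single expert can later be overwritten in place
    without aliasing its neighbours.
    """
    return [expert.clone() for expert in stacked.unbind(dim=0)]

stacked_w2 = torch.randn(4, 16, 8)      # illustrative shape, not the real one
w2_list = split_stacked_weight(stacked_w2)
del stacked_w2                          # mirrors `del layer.w2_weight` above
# On NPU, torch.npu.empty_cache() then returns the freed blocks to the allocator.
assert all(w.shape == (16, 8) for w in w2_list)
```

Dropping the stacked `w13_weight`/`w2_weight` parameters and their scales right after the split keeps only the per-expert copies resident, which is what the deletions and cache flush at the end of the diff accomplish.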
