
Commit 18eefc2

[feature] Support W8A8 PD-Mix Quantization (#4235)
In PD-separated deployment scenarios:

* MoE layers use dynamic quantization exclusively.
* For the Attention module, Prefill (P) nodes use **dynamic** quantization, while Decode (D) nodes use **static** quantization.

In PD-mixed deployment scenarios:

* **All components fall back to dynamic quantization**, as it is difficult to distinguish between Prefill and Decode tokens.

___

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: SlightwindSec <[email protected]>
Signed-off-by: Slightwind <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent ff70613 commit 18eefc2
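
To make the routing described in the commit message concrete, here is a minimal sketch of the selection rule. Only `kv_transfer_config` and its `is_kv_consumer` flag come from the diff (see `w8a8_pdmix.py` below); the helper itself is illustrative and not part of the change.

```python
# Illustrative sketch only: mirrors the routing from the commit message.
def pick_linear_quant(kv_transfer_config) -> str:
    # PD-mixed deployment: no KV-transfer role is configured, so Prefill and
    # Decode tokens cannot be told apart -> dynamic quantization everywhere.
    if kv_transfer_config is None:
        return "W8A8 dynamic"
    # PD-separated deployment: Decode (D) nodes are KV consumers and use
    # static W8A8; Prefill (P) nodes stay on dynamic W8A8.
    return ("W8A8 static" if kv_transfer_config.is_kv_consumer
            else "W8A8 dynamic")
```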

6 files changed: +93 −7 lines changed


vllm_ascend/ops/fused_moe/fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -387,7 +387,7 @@ def forward_impl(self, hidden_states: torch.Tensor,

     def transpose_weight(self, loaded_weight, expert_data, shard_dim):
         # Ensure training and inference weight shapes match during RL weight updates
-        if (
+        if (len(loaded_weight.shape) >= 2 and len(expert_data.shape) >= 2 and \
             loaded_weight.shape[1] != expert_data.shape[1] and \
             loaded_weight.shape[0] != expert_data.shape[0]
         ):

vllm_ascend/ops/linear.py

Lines changed: 12 additions & 6 deletions
@@ -277,18 +277,20 @@ def __init__(
             weight_loader=(
                 self.weight_loader_v2 if self.quant_method.__class__.__name__
                 in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
+        bias_initialized_by_quant = ("bias" in self._parameters
+                                     and self._parameters["bias"] is not None)
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
                              "results can lead to incorrect results")

-        if bias:
+        if bias and not bias_initialized_by_quant:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=params_dtype))
             set_weight_attrs(self.bias, {
                 "output_dim": 0,
                 "weight_loader": self.weight_loader,
             })
-        else:
+        elif not bias and not bias_initialized_by_quant:
             self.register_parameter("bias", None)

         if self.custom_op is not None:

@@ -366,15 +368,17 @@ def __init__(
             weight_loader=(
                 self.weight_loader_v2 if self.quant_method.__class__.__name__
                 in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
-        if bias:
+        bias_initialized_by_quant = ("bias" in self._parameters
+                                     and self._parameters["bias"] is not None)
+        if bias and not bias_initialized_by_quant:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
                             dtype=params_dtype))
             set_weight_attrs(self.bias, {
                 "output_dim": 0,
                 "weight_loader": self.weight_loader,
             })
-        else:
+        elif not bias and not bias_initialized_by_quant:
             self.register_parameter("bias", None)

         if self.custom_op is not None:

@@ -445,14 +449,16 @@ def __init__(
                 self.params_dtype,
                 weight_loader=self.weight_loader)

-        if bias:
+        bias_initialized_by_quant = ("bias" in self._parameters
+                                     and self._parameters["bias"] is not None)
+        if bias and not bias_initialized_by_quant:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=self.params_dtype))
             set_weight_attrs(self.bias, {
                 "output_dim": 0,
                 "weight_loader": self.weight_loader,
             })
-        else:
+        elif not bias and not bias_initialized_by_quant:
             self.register_parameter("bias", None)

         if self.custom_op is not None:
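
The `bias_initialized_by_quant` guard exists because the W8A8 quant methods now emit a `"bias"` entry from `get_perchannel_param` (see `w8a8.py` and `w8a8_dynamic.py` below), which ends up registered on the layer before `__init__` reaches its own bias branch; without the guard, that parameter would be replaced by an uninitialized one or set back to `None`. A small standalone sketch of the pattern, using a hypothetical `TinyLinear` module rather than the real layer classes:

```python
import torch
from torch import nn


class TinyLinear(nn.Module):
    """Hypothetical module demonstrating the guard, not the real layer."""

    def __init__(self, output_size: int, bias: bool):
        super().__init__()
        # Pretend a quant method already registered a zero-initialized bias
        # during weight creation, as the new get_perchannel_param() does.
        self.register_parameter(
            "bias", nn.Parameter(torch.zeros(output_size)))
        bias_initialized_by_quant = ("bias" in self._parameters
                                     and self._parameters["bias"] is not None)
        if bias and not bias_initialized_by_quant:
            # Only create a fresh bias if the quant method did not.
            self.bias = nn.Parameter(torch.empty(output_size))
        elif not bias and not bias_initialized_by_quant:
            # Only clear the bias if the quant method did not provide one.
            self.register_parameter("bias", None)


# With the guard, the quant-provided bias survives even when bias=False.
layer = TinyLinear(output_size=8, bias=False)
assert layer.bias is not None and torch.all(layer.bias == 0)
```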

vllm_ascend/quantization/utils.py

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,8 @@
                    AscendW8A8LinearMethod)
 from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
                            AscendW8A8DynamicLinearMethod)
+from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
+                         AscendW8A8PDMixLinearMethod)

 ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
     "W4A8_DYNAMIC": {
@@ -30,6 +32,10 @@
         "linear": AscendW8A8DynamicLinearMethod,
         "moe": AscendW8A8DynamicFusedMoEMethod,
     },
+    "W8A8_MIX": {
+        "linear": AscendW8A8PDMixLinearMethod,
+        "moe": AscendW8A8PDMixFusedMoeMethod,
+    },
     "C8": {
         "attention": AscendC8KVCacheMethod,
     },
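
For reference, a hedged sketch of how the extended map can be consumed to resolve the new `W8A8_MIX` scheme; the lookup key and classes come from the diff, while the surrounding helper and call site are assumptions, not code from this commit:

```python
# Hedged sketch: resolving the new scheme from the method map.
from vllm_ascend.quantization.utils import ASCEND_QUANTIZATION_METHOD_MAP


def build_w8a8_mix_methods():
    methods = ASCEND_QUANTIZATION_METHOD_MAP["W8A8_MIX"]
    linear_method = methods["linear"]()  # AscendW8A8PDMixLinearMethod
    moe_method = methods["moe"]()        # AscendW8A8PDMixFusedMoeMethod
    return linear_method, moe_method
```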

vllm_ascend/quantization/w8a8.py

Lines changed: 2 additions & 0 deletions
@@ -87,6 +87,7 @@ def get_perchannel_param(
         params_dict["weight_offset"] = torch.empty(output_size,
                                                    1,
                                                    dtype=params_dtype)
+        params_dict["bias"] = torch.zeros(output_size, dtype=torch.float32)
         return params_dict

     def get_pergroup_param(self,
@@ -192,6 +193,7 @@ def process_weights_after_loading(self, layer):
             layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
+        layer.bias.data = layer.bias.data.to(layer.weight_scale.data.dtype)
         if getattr(layer, "ascend_quant_method",
                    "") == COMPRESSED_TENSORS_METHOD:
             deq_scale = layer.input_scale.data * layer.weight_scale.data

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,7 @@ def get_perchannel_param(
         params_dict["weight_offset"] = torch.empty(output_size,
                                                    1,
                                                    dtype=params_dtype)
+        params_dict["bias"] = torch.zeros(output_size, dtype=torch.float32)
         return params_dict

     def get_pergroup_param(self,
@@ -110,6 +111,7 @@ def process_weights_after_loading(self, layer):
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
+        layer.bias.data = layer.bias.data.to(layer.weight_scale.data.dtype)


 class AscendW8A8DynamicFusedMoEMethod:
vllm_ascend/quantization/w8a8_pdmix.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+from typing import Any, Dict, cast
+
+import torch
+from vllm.config import get_current_vllm_config
+
+from .w8a8 import AscendW8A8LinearMethod
+from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
+                           AscendW8A8DynamicLinearMethod)
+
+
+class AscendW8A8PDMixLinearMethod(AscendW8A8DynamicLinearMethod):
+
+    def __init__(self):
+        self.kv_transfer_config = get_current_vllm_config().kv_transfer_config
+        super().__init__()
+
+    @staticmethod
+    def apply(layer, x, bias=None, tp_rank=0):
+        if layer.is_kv_consumer:
+            return AscendW8A8LinearMethod.apply(layer, x, bias, tp_rank)
+        else:
+            return AscendW8A8DynamicLinearMethod.apply(layer, x, bias, tp_rank)
+
+    @staticmethod
+    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
+        return AscendW8A8LinearMethod.get_pertensor_param(params_dtype)
+
+    @staticmethod
+    def get_perchannel_param(
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        return AscendW8A8LinearMethod.get_perchannel_param(
+            output_size, params_dtype)
+
+    def process_weights_after_loading(self, layer):
+        AscendW8A8LinearMethod.process_weights_after_loading(
+            cast(AscendW8A8LinearMethod, self), layer)
+        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
+        layer.is_kv_consumer = self.kv_transfer_config is not None and self.kv_transfer_config.is_kv_consumer
+
+
+class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod):
+
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def get_dynamic_quant_param(num_experts: int,
+                                intermediate_size_per_partition: int,
+                                hidden_sizes: int,
+                                params_dtype: torch.dtype) -> Dict[str, Any]:
+        param_dict = AscendW8A8DynamicFusedMoEMethod.get_dynamic_quant_param(
+            num_experts, intermediate_size_per_partition, hidden_sizes,
+            params_dtype)
+        param_dict["w2_deq_scale"] = torch.empty(num_experts,
+                                                 hidden_sizes,
+                                                 dtype=torch.float32)
+        param_dict["w13_deq_scale"] = torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            dtype=torch.float32)
+        param_dict["w2_input_offset"] = torch.empty(num_experts,
+                                                    1,
+                                                    dtype=torch.int8)
+        param_dict["w13_input_offset"] = torch.empty(num_experts,
+                                                     1,
+                                                     dtype=torch.int8)
+
+        return param_dict
