 
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
-                               get_ascend_device_type)
+                               enable_custom_op, get_ascend_device_type)
+
+
+def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+    return fusion and dynamic_eplb and enable_custom_op()
 
 
 def cumsum_group_list(group_list: torch.Tensor,
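Note: both new custom-op branches below feed the kernel a cumulative group_list via cumsum_group_list. A minimal sketch of the conversion that helper is assumed to perform when group_list_type is 1 (per-expert token counts); the sketch name is illustrative and the real helper in this file may cover more cases:

    import torch

    def cumsum_group_list_sketch(group_list: torch.Tensor,
                                 group_list_type: int = 1) -> torch.Tensor:
        # group_list_type == 1: entries are per-expert token counts, so the
        # grouped matmul kernel needs their running sum (end offsets).
        if group_list_type == 1:
            return torch.cumsum(group_list, dim=0)
        # group_list_type == 0: entries are assumed to be cumulative already.
        return group_list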
@@ -55,10 +59,10 @@ def cumsum_group_list(group_list: torch.Tensor,
 
 
 def quant_apply_mlp(hidden_states: torch.Tensor,
-                    w1: torch.Tensor,
-                    w1_scale: torch.Tensor,
-                    w2: torch.Tensor,
-                    w2_scale: torch.Tensor,
+                    w1: list[torch.Tensor],
+                    w1_scale: list[torch.Tensor],
+                    w2: list[torch.Tensor],
+                    w2_scale: list[torch.Tensor],
                     group_list: torch.Tensor,
                     group_list_type: int = 1,
                     dynamic_scale: torch.Tensor = None,
@@ -79,31 +83,42 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         quantized_hidden_states = hidden_states
 
     bias1, bias2 = None, None
-    _output_dtype = w2_scale.dtype
+    _output_dtype = w2_scale[0].dtype
 
     weight_prefetch_method = get_forward_context().weight_prefetch_method
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
             hidden_states)
     is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
     if w1_scale_bias is None and is_mc2:
-        if fusion and not dynamic_eplb:
+        if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+            # gmm1: gate_up_proj & act_fn: swiglu
+            hidden_states, swiglu_out_scale, _ = (
+                torch.ops._C_ascend.
+                grouped_matmul_swiglu_quant_weight_nz_tensor_list(
+                    x=hidden_states,
+                    weight=w1,
+                    weight_scale=w1_scale,
+                    x_scale=pertoken_scale,
+                    group_list=cumsum_group_list(group_list, group_list_type),
+                ))
+        elif fusion and not dynamic_eplb:
             # gmm1: gate_up_proj & act_fn: swiglu
             hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
                 x=hidden_states,
-                weight=w1,
+                weight=w1[0],
                 group_list=cumsum_group_list(group_list, group_list_type),
-                weight_scale=w1_scale,
+                weight_scale=w1_scale[0],
                 x_scale=pertoken_scale)
             if quantized_hidden_states is not None:
                 dispose_tensor(quantized_hidden_states)
         else:
-            if w1_scale.dtype != torch.float32:
-                w1_scale = w1_scale.to(torch.float32)
+            if w1_scale[0].dtype != torch.float32:
+                w1_scale[0] = w1_scale[0].to(torch.float32)
             # gmm1: gate_up_proj
             hidden_states = torch_npu.npu_grouped_matmul(
                 x=[hidden_states],
-                weight=[w1],
+                weight=w1,
                 split_item=3,
                 group_list_type=group_list_type,
                 group_type=0,
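With this hunk, the gmm1 (gate_up_proj) stage of the MC2 path dispatches across three branches. A condensed, illustrative outline of the ordering (not the literal code; argument plumbing omitted):

    if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
        # custom fused op: grouped matmul + swiglu + quant, consumes the
        # weight/scale tensor lists (w1, w1_scale) directly
        ...
    elif fusion and not dynamic_eplb:
        # torch_npu fused kernel, takes a single weight tensor (w1[0])
        ...
    else:
        # unfused fallback: plain npu_grouped_matmul; the activation is
        # handled by code outside this hunk
        ...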
@@ -128,14 +143,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
-            weight=[w2],
-            scale=[w2_scale],
+            weight=w2,
+            scale=w2_scale,
             per_token_scale=[swiglu_out_scale],
             split_item=2,
             group_list_type=group_list_type,
             group_type=0,
             group_list=group_list,
-            output_dtype=w2_scale.dtype)[0]
+            output_dtype=w2_scale[0].dtype)[0]
     else:
         if w1_scale_bias is not None:
             if group_list_type == 0:
@@ -148,23 +163,36 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             # TODO w4a8 scene: dynamic acquisition of dtype in the future
             _output_dtype = torch.bfloat16
 
-        if fusion and not dynamic_eplb:
+        if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+            # gmm1: gate_up_proj & act_fn: swiglu
+            hidden_states, swiglu_out_scale, _ = (
+                torch.ops._C_ascend.
+                grouped_matmul_swiglu_quant_weight_nz_tensor_list(
+                    x=hidden_states,
+                    weight=w1,
+                    weight_scale=w1_scale,
+                    x_scale=pertoken_scale,
+                    group_list=cumsum_group_list(group_list, group_list_type),
+                    bias=bias1,
+                ))
+        elif fusion and not dynamic_eplb:
             # gmm1: gate_up_proj & act_fn: swiglu
             hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
                 x=hidden_states,
-                weight=w1,
+                weight=w1[0],
                 bias=bias1,
                 group_list=cumsum_group_list(group_list, group_list_type),
-                weight_scale=w1_scale,
+                weight_scale=w1_scale[0],
                 x_scale=pertoken_scale)
             if quantized_hidden_states is not None:
                 dispose_tensor(quantized_hidden_states)
         else:
+            w1_scale[0] = w1_scale[0].to(w2_scale[0].dtype)
             # gmm1: gate_up_proj
             hidden_states = torch_npu.npu_grouped_matmul(
                 x=[hidden_states],
-                weight=[w1],
-                scale=[w1_scale.to(w2_scale.dtype)],
+                weight=w1,
+                scale=w1_scale,
                 bias=bias1,
                 per_token_scale=[pertoken_scale],
                 split_item=2,
@@ -181,8 +209,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
-            weight=[w2],
-            scale=[w2_scale],
+            weight=w2,
+            scale=w2_scale,
             bias=bias2,
             per_token_scale=[swiglu_out_scale],
             split_item=2,
@@ -234,11 +262,11 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
 
 
 def unified_apply_mlp(hidden_states: torch.Tensor,
-                      w1: torch.Tensor,
-                      w1_scale: torch.Tensor,
-                      w2: torch.Tensor,
-                      w2_scale: torch.Tensor,
+                      w1: torch.Tensor | list[torch.Tensor],
+                      w2: torch.Tensor | list[torch.Tensor],
                       group_list: torch.Tensor,
+                      w1_scale: Optional[list[torch.Tensor]] = None,
+                      w2_scale: Optional[list[torch.Tensor]] = None,
                       dynamic_scale: torch.Tensor = None,
                       group_list_type: int = 1,
                       w1_scale_bias: torch.Tensor = None,
@@ -249,6 +277,7 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
                       need_trans: bool = True,
                       dynamic_eplb: bool = False) -> torch.Tensor:
     if with_quant:
+        assert w1_scale is not None and w2_scale is not None
         return quant_apply_mlp(hidden_states=hidden_states,
                                w1=w1,
                                w1_scale=w1_scale,
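With the widened unified_apply_mlp signature, quantized callers are expected to pass weights and scales as tensor lists. A hypothetical call site under that assumption (variable names and the extra keyword arguments are illustrative, not taken from this diff):

    output = unified_apply_mlp(
        hidden_states=hidden_states,
        w1=[w13_weight],                # gate_up_proj weight wrapped in a list
        w2=[w2_weight],                 # down_proj weight wrapped in a list
        group_list=expert_token_counts,
        w1_scale=[w13_weight_scale],
        w2_scale=[w2_weight_scale],
        group_list_type=1,
        with_quant=True)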