
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
-                               get_ascend_device_type)
+                               enable_custom_op, get_ascend_device_type)
+
+
+def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+    return fusion and dynamic_eplb and enable_custom_op()


 def cumsum_group_list(group_list: torch.Tensor,
@@ -55,10 +59,10 @@ def cumsum_group_list(group_list: torch.Tensor,


 def quant_apply_mlp(hidden_states: torch.Tensor,
-                    w1: torch.Tensor,
-                    w1_scale: torch.Tensor,
-                    w2: torch.Tensor,
-                    w2_scale: torch.Tensor,
+                    w1: list[torch.Tensor],
+                    w1_scale: list[torch.Tensor],
+                    w2: list[torch.Tensor],
+                    w2_scale: list[torch.Tensor],
                     group_list: torch.Tensor,
                     group_list_type: int = 1,
                     dynamic_scale: torch.Tensor = None,
@@ -79,31 +83,42 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         quantized_hidden_states = hidden_states

     bias1, bias2 = None, None
-    _output_dtype = w2_scale.dtype
+    _output_dtype = w2_scale[0].dtype

     weight_prefetch_method = get_forward_context().weight_prefetch_method
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
             hidden_states)
     is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
     if w1_scale_bias is None and is_mc2:
-        if fusion and not dynamic_eplb:
+        if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+            # gmm1: gate_up_proj & act_fn: swiglu
+            hidden_states, swiglu_out_scale, _ = (
+                torch.ops._C_ascend.
+                grouped_matmul_swiglu_quant_weight_nz_tensor_list(
+                    x=hidden_states,
+                    weight=w1,
+                    weight_scale=w1_scale,
+                    x_scale=pertoken_scale,
+                    group_list=cumsum_group_list(group_list, group_list_type),
+                ))
+        elif fusion and not dynamic_eplb:
             # gmm1: gate_up_proj & act_fn: swiglu
             hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
                 x=hidden_states,
-                weight=w1,
+                weight=w1[0],
                 group_list=cumsum_group_list(group_list, group_list_type),
-                weight_scale=w1_scale,
+                weight_scale=w1_scale[0],
                 x_scale=pertoken_scale)
             if quantized_hidden_states is not None:
                 dispose_tensor(quantized_hidden_states)
         else:
-            if w1_scale.dtype != torch.float32:
-                w1_scale = w1_scale.to(torch.float32)
+            if w1_scale[0].dtype != torch.float32:
+                w1_scale[0] = w1_scale[0].to(torch.float32)
             # gmm1: gate_up_proj
             hidden_states = torch_npu.npu_grouped_matmul(
                 x=[hidden_states],
-                weight=[w1],
+                weight=w1,
                 split_item=3,
                 group_list_type=group_list_type,
                 group_type=0,
@@ -126,14 +141,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
-            weight=[w2],
-            scale=[w2_scale],
+            weight=w2,
+            scale=w2_scale,
             per_token_scale=[swiglu_out_scale],
             split_item=2,
             group_list_type=group_list_type,
             group_type=0,
             group_list=group_list,
-            output_dtype=w2_scale.dtype)[0]
+            output_dtype=w2_scale[0].dtype)[0]
     else:
         if w1_scale_bias is not None:
             if group_list_type == 0:
@@ -146,23 +161,36 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             # TODO w4a8 scene: dynamic acquisition of dtype in the future
             _output_dtype = torch.bfloat16

-        if fusion and not dynamic_eplb:
+        if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+            # gmm1: gate_up_proj & act_fn: swiglu
+            hidden_states, swiglu_out_scale, _ = (
+                torch.ops._C_ascend.
+                grouped_matmul_swiglu_quant_weight_nz_tensor_list(
+                    x=hidden_states,
+                    weight=w1,
+                    weight_scale=w1_scale,
+                    x_scale=pertoken_scale,
+                    group_list=cumsum_group_list(group_list, group_list_type),
+                    bias=bias1,
+                ))
+        elif fusion and not dynamic_eplb:
             # gmm1: gate_up_proj & act_fn: swiglu
             hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
                 x=hidden_states,
-                weight=w1,
+                weight=w1[0],
                 bias=bias1,
                 group_list=cumsum_group_list(group_list, group_list_type),
-                weight_scale=w1_scale,
+                weight_scale=w1_scale[0],
                 x_scale=pertoken_scale)
             if quantized_hidden_states is not None:
                 dispose_tensor(quantized_hidden_states)
         else:
+            w1_scale[0] = w1_scale[0].to(w2_scale[0].dtype)
             # gmm1: gate_up_proj
             hidden_states = torch_npu.npu_grouped_matmul(
                 x=[hidden_states],
-                weight=[w1],
-                scale=[w1_scale.to(w2_scale.dtype)],
+                weight=w1,
+                scale=w1_scale,
                 bias=bias1,
                 per_token_scale=[pertoken_scale],
                 split_item=2,
@@ -179,8 +207,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
-            weight=[w2],
-            scale=[w2_scale],
+            weight=w2,
+            scale=w2_scale,
             bias=bias2,
             per_token_scale=[swiglu_out_scale],
             split_item=2,
@@ -232,11 +260,11 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,


 def unified_apply_mlp(hidden_states: torch.Tensor,
-                      w1: torch.Tensor,
-                      w1_scale: torch.Tensor,
-                      w2: torch.Tensor,
-                      w2_scale: torch.Tensor,
+                      w1: torch.Tensor | list[torch.Tensor],
+                      w2: torch.Tensor | list[torch.Tensor],
                       group_list: torch.Tensor,
+                      w1_scale: Optional[list[torch.Tensor]] = None,
+                      w2_scale: Optional[list[torch.Tensor]] = None,
                       dynamic_scale: torch.Tensor = None,
                       group_list_type: int = 1,
                       w1_scale_bias: torch.Tensor = None,
@@ -247,6 +275,7 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
                       need_trans: bool = True,
                       dynamic_eplb: bool = False) -> torch.Tensor:
     if with_quant:
+        assert w1_scale is not None and w2_scale is not None
         return quant_apply_mlp(hidden_states=hidden_states,
                                w1=w1,
                                w1_scale=w1_scale,
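
Below the diff, a minimal sketch (not part of the patch) of the dispatch order that the new `_custom_gmm_swiglu_enabled` helper introduces for the gmm1 (gate_up_proj + swiglu) step: the custom `grouped_matmul_swiglu_quant_weight_nz_tensor_list` op is taken only when fusion and dynamic EPLB are both enabled and the custom op library is available; otherwise the existing `torch_npu` paths are kept. `custom_op_available` and `pick_gmm1_path` are hypothetical names used here so the sketch runs without the Ascend runtime.

```python
def _custom_gmm_swiglu_enabled(fusion: bool, dynamic_eplb: bool,
                               custom_op_available: bool) -> bool:
    # Same predicate as the new helper in the diff, with enable_custom_op()
    # replaced by an explicit flag so this runs off-device.
    return fusion and dynamic_eplb and custom_op_available


def pick_gmm1_path(fusion: bool, dynamic_eplb: bool,
                   custom_op_available: bool) -> str:
    # Mirrors the if/elif/else added to quant_apply_mlp for the gmm1 step.
    if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb, custom_op_available):
        return "torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list"
    if fusion and not dynamic_eplb:
        return "torch_npu.npu_grouped_matmul_swiglu_quant"
    return "torch_npu.npu_grouped_matmul (unfused gmm1)"


# Custom fused kernel only when fusion, dynamic EPLB and the custom op all hold.
assert pick_gmm1_path(True, True, True).endswith("tensor_list")
# Fusion without dynamic EPLB keeps the existing torch_npu fused kernel.
assert pick_gmm1_path(True, False, True) == "torch_npu.npu_grouped_matmul_swiglu_quant"
# Dynamic EPLB without the custom op falls back to the unfused grouped matmul.
assert pick_gmm1_path(True, True, False).startswith("torch_npu.npu_grouped_matmul (")
```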
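A second sketch, also not from the patch, showing one observable consequence of passing the scales as single-element lists in `quant_apply_mlp`: an in-place update such as `w1_scale[0] = w1_scale[0].to(torch.float32)` is visible to the caller, whereas the old `w1_scale = w1_scale.to(torch.float32)` only rebound a local name, so the cast did not persist between calls. Plain CPU tensors and the hypothetical helpers `cast_by_rebind` / `cast_in_list` stand in for the real code paths.

```python
import torch


def cast_by_rebind(scale: torch.Tensor) -> None:
    # Old form: rebinding the parameter only changes the local name,
    # so the caller's tensor keeps its original dtype.
    scale = scale.to(torch.float32)  # noqa: F841


def cast_in_list(scale: list[torch.Tensor]) -> None:
    # New form: mutating the shared list keeps the float32 copy,
    # so subsequent calls already see scale[0].dtype == torch.float32.
    scale[0] = scale[0].to(torch.float32)


w1_scale_tensor = torch.rand(8, 1024, dtype=torch.bfloat16)
cast_by_rebind(w1_scale_tensor)
assert w1_scale_tensor.dtype == torch.bfloat16  # unchanged by the call

w1_scale_list = [torch.rand(8, 1024, dtype=torch.bfloat16)]
cast_in_list(w1_scale_list)
assert w1_scale_list[0].dtype == torch.float32  # cast persists for the caller
```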