@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -12,6 +13,7 @@
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
 from vllm.utils.import_utils import has_triton_kernels
 
@@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
         gating_output, topk, sm_first=not renormalize
     )
 
+    output = torch.empty_like(hidden_states)
+
     return triton_kernel_fused_experts(
-        None,
+        output,
         hidden_states,
         w1,
         w2,
         routing_data,
         gather_idx,
         scatter_idx,
+        topk=topk,
         activation=activation,
         quant_config=quant_config,
         apply_router_weight_on_input=apply_router_weight_on_input,
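The wrapper now allocates `output` up front and passes it down instead of `None`, so the fused kernel can write its result in place. A minimal sketch of this out-parameter pattern in plain PyTorch (`fused_op` is a made-up stand-in, not a function from this diff):

```python
import torch

def fused_op(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Write the result into the caller-provided buffer rather than
    # allocating a fresh tensor inside the kernel wrapper.
    torch.mul(x, 2.0, out=out)
    return out

x = torch.randn(8, 16)
output = torch.empty_like(x)  # mirrors output = torch.empty_like(hidden_states)
result = fused_op(output, x)
assert result is output  # no trailing copy back into the caller's buffer
```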
@@ -113,13 +118,15 @@ def triton_kernel_fused_experts(
     routing_data,  # RoutingData
     gather_indx,  # GatherIndx
     scatter_indx,  # ScatterIndx
+    topk: int,
     activation: str = "silu",
     quant_config: FusedMoEQuantConfig | None = None,
     swiglu_alpha: float = 1.702,
     swiglu_limit: float = 7.0,
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
+    intermediate_cache: torch.Tensor | None = None,
     a1q_scale: torch.Tensor | None = None,
 ) -> torch.Tensor:
     if quant_config is None:
@@ -131,22 +138,38 @@ def triton_kernel_fused_experts(
     assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
 
     # Shape check, only check non-mxfp4
+    assert hidden_states.ndim == 2
     assert hidden_states.shape[-1] == w1.shape[-2]
     assert w2.shape[-1] == w1.shape[1]
 
+    batch_dim = 1
+    M, K = hidden_states.shape[-2:]
     E, _, N = w1.shape
 
     if global_num_experts == -1:
         global_num_experts = E
 
+    if intermediate_cache is None:
+        intermediate_cache = torch.empty(
+            (batch_dim, M * topk, N // 2),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+    # Add batch_dim to output buffer because matmul_ogs expects 3D output
+    intermediate_cache = _resize_cache(
+        intermediate_cache, (batch_dim, M * topk, N // 2)
+    )
+    output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
+
     act = FusedActivation(
         FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
         (swiglu_alpha, swiglu_limit),
         2,
     )
     gammas = routing_data.gate_scal if routing_data else None
 
-    intermediate_cache1 = matmul_ogs(
+    matmul_ogs(
         hidden_states,
         w1,
         quant_config.w1_bias,
@@ -155,10 +178,11 @@ def triton_kernel_fused_experts(
         precision_config=quant_config.w1_precision,
         gammas=gammas if apply_router_weight_on_input else None,
         fused_activation=act,
+        y=intermediate_cache,
     )
 
-    intermediate_cache3 = matmul_ogs(
-        intermediate_cache1,
+    matmul_ogs(
+        intermediate_cache.view(M * topk, N // 2),
         w2,
         quant_config.w2_bias,
         routing_data,
@@ -167,7 +191,8 @@ def triton_kernel_fused_experts(
         gammas=None if apply_router_weight_on_input else gammas,
         y=output_tensor,
     )
-    return intermediate_cache3
+    output_tensor = output_tensor.view(M, K)
+    return output_tensor
 
 
 def make_routing_data(
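The hunks above thread both `matmul_ogs` calls through preallocated storage: the intermediate cache and the output are viewed as 3D (hence `batch_dim`) because, per the in-code comment, `matmul_ogs` expects a 3D output, then viewed back to 2D where the next consumer needs plain rows. A minimal sketch of that view bookkeeping; `resize_cache_sketch` is a hypothetical stand-in for vLLM's `_resize_cache`, assumed here to reinterpret a large-enough contiguous workspace under a new shape without allocating:

```python
import math
import torch

def resize_cache_sketch(x: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    # Reinterpret a preallocated, contiguous workspace as `shape`.
    n = math.prod(shape)
    assert x.numel() >= n, "workspace too small"
    return x.flatten()[:n].view(shape)

M, K, N, topk, batch_dim = 4, 32, 64, 2, 1  # illustrative sizes
workspace = torch.empty(batch_dim * M * topk * (N // 2))
output = torch.empty(M, K)

# 3D view for the first matmul's fused-swiglu result ...
cache = resize_cache_sketch(workspace, (batch_dim, M * topk, N // 2))
# ... flattened to 2D rows as input to the second matmul ...
cache_2d = cache.view(M * topk, N // 2)
# ... and the 3D output viewed back to (M, K) for the caller.
out_3d = resize_cache_sketch(output, (batch_dim, M, K))
final = out_3d.view(M, K)
assert final.data_ptr() == output.data_ptr()  # same storage throughout
```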
@@ -221,6 +246,42 @@ def __init__(self, quant_config: FusedMoEQuantConfig):
     def supports_expert_map(self) -> bool:
         return True
 
+    def moe_problem_size(
+        self,
+        a1: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+    ) -> tuple[int, int, int, int, int]:
+        """
+        Extract the MoE problem size from the given tensor arguments:
+        - a1: The hidden states, input to the MoE layer.
+        - w1: The first set of expert weights.
+        - w2: The second set of expert weights.
+        - topk_ids: The topk ids.
+        Note: extracting the problem shape from the weight and activation
+        tensors is not obvious. It needs to be done this way specifically
+        due to subtle issues with particular kernels, e.g. the int4 kernels
+        divide the trailing dimension by two, so it is not "correct" to
+        extract N or K from the trailing dimension of w1 or w2. Similarly,
+        some kernels transpose the weights, so this needs to be kept in mind.
+        Note: this implementation covers most cases. However, experts that
+        require a specialized implementation, like MarlinExperts, are free
+        to override this function.
+        """
+        assert w1.dim() == 3 and w2.dim() == 3
+        E, _, N = w1.size()
+        K = a1.size(-1)
+
+        assert a1.dim() == 2
+        assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
+        M = a1.size(0)
+
+        assert topk_ids.dim() == 2
+        topk = topk_ids.size(1)
+
+        return E, M, N, K, topk
+
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         # Weight application and reduction happens in the fused_experts kernel.
         return TopKWeightAndReduceNoOP()
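As the docstring notes, N and K are deliberately not read off the trailing dimensions of the weights. A quick sanity check of the extraction logic on dummy tensors (shapes below are illustrative; the logic restates the method body so it runs standalone):

```python
import torch

E, M, N, K, topk = 8, 16, 256, 128, 2
a1 = torch.randn(M, K)                     # hidden states
w1 = torch.randn(E, K, N)                  # first set of expert weights
w2 = torch.randn(E, N // 2, K)             # second set of expert weights
topk_ids = torch.randint(0, E, (M, topk))

# Mirrors moe_problem_size: E and N come from w1, M and K from a1,
# and topk from topk_ids -- never from the weights' trailing dims.
assert w1.dim() == 3 and w2.dim() == 3
e, _, n = w1.size()
k = a1.size(-1)
assert a1.dim() == 2 and topk_ids.size(0) == a1.size(0)
m = a1.size(0)
tk = topk_ids.size(1)
assert (e, m, n, k, tk) == (E, M, N, K, topk)
```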
@@ -263,8 +324,8 @@ def workspace_shapes(
         expert_tokens_meta: mk.ExpertTokensMetadata | None,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         # workspaces are allocated inside the kernel
-        workspace1 = (M, K)
-        workspace2 = (0, 0)
+        workspace1 = (0, 0)
+        workspace2 = (M * topk, N // 2)
         output = (M, K)
         return (workspace1, workspace2, output)
 
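`workspace2` must now hold the fused-swiglu output of the first matmul, one row per (token, expert) pair, which is why `topk` feeds into the shape. Plugging in illustrative numbers (not taken from the diff):

```python
M, topk, N, K = 256, 4, 2048, 1024  # illustrative sizes, not from the diff
workspace1 = (0, 0)                 # no separate buffer needed here anymore
workspace2 = (M * topk, N // 2)     # swiglu rows: (1024, 1024) in this case
output = (M, K)                     # final result, one row per input token
assert workspace2 == (1024, 1024)
```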
@@ -297,20 +358,21 @@ def apply(
             topk_ids, topk_weights, local_num_experts
         )
 
-        experts_output = triton_kernel_fused_experts(
-            None,
+        topk = topk_ids.size(1)
+        triton_kernel_fused_experts(
+            output,
             hidden_states,
             w1,
             w2,
             routing_data,
             gather_indx,
             scatter_indx,
+            topk=topk,
             activation=activation,
             quant_config=self.quant_config,
             apply_router_weight_on_input=False,
             global_num_experts=local_num_experts,
             expert_map=None,  # applied already
+            intermediate_cache=workspace2,
             a1q_scale=a1q_scale,
         )
-
-        output.copy_(experts_output, non_blocking=True)
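With `output` passed in, `intermediate_cache=workspace2` reusing framework-provided workspace, and the second `matmul_ogs` landing in `y=output_tensor`, the old trailing `output.copy_(experts_output, non_blocking=True)` became a redundant device copy and is dropped. A small check of the view aliasing this relies on (illustrative shapes; `fill_` stands in for the kernel's write):

```python
import torch

M, K = 8, 16
output = torch.empty(M, K)

# The kernel works on a 3D view of the caller's buffer ...
out_3d = output.view(1, M, K)
out_3d.fill_(1.0)  # stands in for matmul_ogs writing y=output_tensor

# ... so the caller's 2D buffer already holds the result; no copy_ needed.
assert out_3d.data_ptr() == output.data_ptr()
assert torch.all(output == 1.0)
```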