Commit 53d7f1f
[Kernel] Use pre-allocated output buffer for triton kernel fused_experts (vllm-project#29219)
Signed-off-by: Xin Yang <[email protected]>
1 parent c5ee430 commit 53d7f1f
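The commit replaces the pattern where triton_kernel_fused_experts allocated its own result tensor and the caller then copied it into the destination buffer; callers now pre-allocate the output and the kernel writes into it directly. A minimal sketch of the new calling pattern, with fused_experts_sketch standing in for the real kernel:

import torch

def fused_experts_sketch(output: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
    # Stand-in for triton_kernel_fused_experts: the result is written directly
    # into the caller-provided `output` buffer instead of a fresh allocation.
    output.copy_(hidden_states * 2.0)  # placeholder for the real MoE computation
    return output

hidden_states = torch.randn(4, 8)

# Old pattern: kernel allocates, caller copies (extra allocation + extra copy).
# New pattern, as in this commit: caller pre-allocates the destination buffer.
output = torch.empty_like(hidden_states)
fused_experts_sketch(output, hidden_states)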

1 file changed, +73 -11 lines

vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py (73 additions, 11 deletions)
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -12,6 +13,7 @@
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
 from vllm.utils.import_utils import has_triton_kernels
 
@@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
         gating_output, topk, sm_first=not renormalize
     )
 
+    output = torch.empty_like(hidden_states)
+
     return triton_kernel_fused_experts(
-        None,
+        output,
         hidden_states,
         w1,
         w2,
         routing_data,
         gather_idx,
         scatter_idx,
+        topk=topk,
         activation=activation,
         quant_config=quant_config,
         apply_router_weight_on_input=apply_router_weight_on_input,
@@ -113,13 +118,15 @@ def triton_kernel_fused_experts(
     routing_data,  # RoutingData
     gather_indx,  # GatherIndx
     scatter_indx,  # ScatterIndx
+    topk: int,
     activation: str = "silu",
     quant_config: FusedMoEQuantConfig | None = None,
     swiglu_alpha: float = 1.702,
     swiglu_limit: float = 7.0,
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
+    intermediate_cache: torch.Tensor | None = None,
     a1q_scale: torch.Tensor | None = None,
 ) -> torch.Tensor:
     if quant_config is None:
@@ -131,22 +138,38 @@ def triton_kernel_fused_experts(
     assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
 
     # Shape check, only check non-mxfp4
+    assert hidden_states.ndim == 2
     assert hidden_states.shape[-1] == w1.shape[-2]
     assert w2.shape[-1] == w1.shape[1]
 
+    batch_dim = 1
+    M, K = hidden_states.shape[-2:]
     E, _, N = w1.shape
 
     if global_num_experts == -1:
         global_num_experts = E
 
+    if intermediate_cache is None:
+        intermediate_cache = torch.empty(
+            (batch_dim, M * topk, N // 2),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+    # Add batch_dim to output buffer because matmul_ogs expects 3D output
+    intermediate_cache = _resize_cache(
+        intermediate_cache, (batch_dim, M * topk, N // 2)
+    )
+    output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
+
     act = FusedActivation(
         FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
         (swiglu_alpha, swiglu_limit),
         2,
     )
     gammas = routing_data.gate_scal if routing_data else None
 
-    intermediate_cache1 = matmul_ogs(
+    matmul_ogs(
         hidden_states,
         w1,
         quant_config.w1_bias,
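Both the intermediate cache and the output buffer above are passed through _resize_cache to obtain the 3D shapes matmul_ogs expects. A functionally similar helper, written here as an assumption about its behavior rather than as the actual vLLM implementation, could look like:

import math
import torch

def resize_cache_sketch(x: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    # Reinterpret the leading elements of an existing buffer under a new
    # shape without allocating new memory (assumed behavior of _resize_cache).
    numel = math.prod(shape)
    assert x.numel() >= numel, "buffer too small for requested view"
    return x.flatten()[:numel].view(*shape)

buf = torch.empty(64, 128)
view3d = resize_cache_sketch(buf, (1, 64, 128))   # same storage, new shape
assert view3d.data_ptr() == buf.data_ptr()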
@@ -155,10 +178,11 @@ def triton_kernel_fused_experts(
         precision_config=quant_config.w1_precision,
         gammas=gammas if apply_router_weight_on_input else None,
         fused_activation=act,
+        y=intermediate_cache,
     )
 
-    intermediate_cache3 = matmul_ogs(
-        intermediate_cache1,
+    matmul_ogs(
+        intermediate_cache.view(M * topk, N // 2),
         w2,
         quant_config.w2_bias,
         routing_data,
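Sizing note: the intermediate buffer holds one row per routed token, M * topk rows in total, and the fused SwiGLU halves the first matmul's N output features to N // 2, which is why the same (M * topk, N // 2) shape appears in the view above and in workspace_shapes below. A quick check with illustrative sizes, not values from this diff:

import torch

M, K, N, topk = 16, 128, 256, 4   # illustrative sizes only
batch_dim = 1                     # matmul_ogs expects a 3D output buffer

# First matmul emits N features per routed token; the fused SwiGLU pairs
# them up (gate and value), leaving N // 2 features per routed token.
intermediate = torch.empty((batch_dim, M * topk, N // 2))
output = torch.empty((batch_dim, M, K))

assert intermediate.shape == (1, 64, 128)
assert output.shape == (1, 16, 128)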
@@ -167,7 +191,8 @@ def triton_kernel_fused_experts(
         gammas=None if apply_router_weight_on_input else gammas,
         y=output_tensor,
     )
-    return intermediate_cache3
+    output_tensor = output_tensor.view(M, K)
+    return output_tensor
 
 
 def make_routing_data(
@@ -221,6 +246,42 @@ def __init__(self, quant_config: FusedMoEQuantConfig):
     def supports_expert_map(self) -> bool:
         return True
 
+    def moe_problem_size(
+        self,
+        a1: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+    ) -> tuple[int, int, int, int, int]:
+        """
+        Extract the MoE problem size from the given tensor arguments:
+        - a: The hidden states, input to the MoE layer.
+        - w1: The first set of expert weights.
+        - w2: The second set of expert weights.
+        - topk_ids: The topk ids.
+        Note: extracting the problem shape from the weight and activation
+        tensors is not obvious. It needs to be done this way specifically
+        due to subtle issues with particular kernels, e.g. the int4 kernels
+        divide the trailing dimension by two, so it's not "correct" to
+        extract N or K from the trailing dimension of w1 or w2. Similarly,
+        some kernels transpose the weights, so this needs to be kept in mind.
+        Note: This implementation covers most cases. However, if experts
+        require a specialized implementation, like MarlinExperts, they are free
+        to override this function.
+        """
+        assert w1.dim() == 3 and w2.dim() == 3
+        E, _, N = w1.size()
+        K = a1.size(-1)
+
+        assert a1.dim() == 2
+        assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
+        M = a1.size(0)
+
+        assert topk_ids.dim() == 2
+        topk = topk_ids.size(1)
+
+        return E, M, N, K, topk
+
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         # Weight application and reduction happens in the fused_experts kernel.
         return TopKWeightAndReduceNoOP()
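For reference, the shape conventions moe_problem_size relies on can be exercised with toy tensors; the sizes and the swiglu-halved w2 layout below are illustrative assumptions, not values from this change:

import torch

E, M, N, K, topk = 8, 16, 256, 128, 4    # toy sizes
a1 = torch.randn(M, K)                   # hidden states
w1 = torch.randn(E, K, N)                # first expert weights: (E, K, N)
w2 = torch.randn(E, N // 2, K)           # second expert weights end in K
topk_ids = torch.randint(0, E, (M, topk))

# Mirrors moe_problem_size(): E and N come from w1, K and M from a1,
# and topk from the trailing dimension of topk_ids.
assert (w1.size(0), w1.size(2)) == (E, N)
assert (a1.size(0), a1.size(-1)) == (M, K)
assert topk_ids.size(1) == topk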
@@ -263,8 +324,8 @@ def workspace_shapes(
         expert_tokens_meta: mk.ExpertTokensMetadata | None,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         # workspace are allocated inside the kernel
-        workspace1 = (M, K)
-        workspace2 = (0, 0)
+        workspace1 = (0, 0)
+        workspace2 = (M * topk, N // 2)
         output = (M, K)
         return (workspace1, workspace2, output)

@@ -297,20 +358,21 @@ def apply(
             topk_ids, topk_weights, local_num_experts
         )
 
-        experts_output = triton_kernel_fused_experts(
-            None,
+        topk = topk_ids.size(1)
+        triton_kernel_fused_experts(
+            output,
             hidden_states,
             w1,
             w2,
             routing_data,
             gather_indx,
             scatter_indx,
+            topk=topk,
             activation=activation,
             quant_config=self.quant_config,
             apply_router_weight_on_input=False,
             global_num_experts=local_num_experts,
             expert_map=None,  # applied already
+            intermediate_cache=workspace2,
             a1q_scale=a1q_scale,
         )
-
-        output.copy_(experts_output, non_blocking=True)
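Together with the workspace_shapes change, the modular-kernel path now hands the kernel both buffers it needs: workspace2 is reused as intermediate_cache and output receives the second matmul result directly, so the trailing output.copy_ is no longer required. A rough, simplified sketch of that flow; apply_sketch is an illustrative stand-in, not the class in this file:

import torch

def apply_sketch(output: torch.Tensor, hidden_states: torch.Tensor,
                 workspace2: torch.Tensor) -> None:
    # Illustrative stand-in: the first matmul + SwiGLU writes into the
    # intermediate workspace, the second matmul writes straight into `output`.
    intermediate = workspace2.view(1, *workspace2.shape)
    intermediate.zero_()             # placeholder for matmul_ogs #1 + SwiGLU
    output.copy_(hidden_states)      # placeholder for matmul_ogs #2

M, K, N, topk = 16, 128, 256, 4                  # toy sizes
hidden_states = torch.randn(M, K)
workspace2 = torch.empty(M * topk, N // 2)       # matches workspace_shapes()
output = torch.empty(M, K)
apply_sketch(output, hidden_states, workspace2)  # no output.copy_ afterwards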
