
Commit 578ac03

Add register_kernel decorator
Signed-off-by: Yanan Cao <[email protected]>
1 parent b4444ee commit 578ac03

9 files changed: +303 additions, -188 deletions


tests/compile/distributed/test_fusion_all_reduce.py

Lines changed: 6 additions & 4 deletions
@@ -48,7 +48,7 @@
     if HELION_OP_AVAILABLE:
         import torch

-        _ = torch.ops.my_helion_lib.allreduce_add_rmsnorm  # Will raise if not registered
+        _ = torch.ops.vllm_helion.allreduce_add_rmsnorm  # Will raise if not registered
 except (ImportError, AttributeError):
     HELION_OP_AVAILABLE = False

@@ -89,7 +89,7 @@ def ops_in_model_before(self):

     def ops_in_model_after(self):
         if self.use_helion:
-            return [torch.ops.my_helion_lib.allreduce_add_rmsnorm.default]
+            return [torch.ops.vllm_helion.allreduce_add_rmsnorm.default]
         return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


@@ -144,7 +144,7 @@ def forward(self, hidden_states):

     def ops_in_model_after(self):
         if self.use_helion:
-            return [torch.ops.my_helion_lib.allreduce_add_rmsnorm.default]
+            return [torch.ops.vllm_helion.allreduce_add_rmsnorm.default]
         return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

     def ops_in_model_before(self):
@@ -161,7 +161,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6, use_helion=False):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
-        self.use_helion = use_helion  # Not used for FP4 model, but accept for consistency
+        self.use_helion = (
+            use_helion  # Not used for FP4 model, but accept for consistency
+        )
         self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]

         self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
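
The probe in the first hunk leans on the fact that attribute access on torch.ops raises AttributeError for ops that were never registered. A minimal standalone sketch of the same pattern:

import torch

try:
    # Attribute access fails with AttributeError when the op was never
    # registered (e.g. Helion not installed), so the flag degrades
    # gracefully instead of crashing at import time.
    _ = torch.ops.vllm_helion.allreduce_add_rmsnorm
    HELION_OP_AVAILABLE = True
except AttributeError:
    HELION_OP_AVAILABLE = False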

tests/compile/test_silu_mul_quant_fusion.py

Lines changed: 17 additions & 10 deletions
@@ -40,14 +40,14 @@

 # Check if Helion torch.ops are available
 try:
-    from vllm.compilation.helion.silu_mul_fp8 import SiluMulFp8Helion
-
-    # Check if the op is available - this will be True if Helion is installed and enabled
-    HELION_OP_AVAILABLE = SiluMulFp8Helion.is_helion_available()
+    # Import to trigger registration
     # Try to access the torch.ops to verify it's registered
-    if HELION_OP_AVAILABLE:
-        import torch
-        _ = torch.ops.my_helion_lib.silu_mul_fp8  # Will raise if not registered
+    import torch
+
+    from vllm.compilation.helion.silu_mul_fp8 import silu_mul_fp8
+
+    _ = torch.ops.vllm_helion.silu_mul_fp8  # Will raise if not registered
+    HELION_OP_AVAILABLE = True
 except (ImportError, AttributeError):
     HELION_OP_AVAILABLE = False

@@ -100,7 +100,7 @@ def ops_in_model_before(self):

     def ops_in_model_after(self):
         if self.use_helion:
-            return [torch.ops.my_helion_lib.silu_mul_fp8]
+            return [torch.ops.vllm_helion.silu_mul_fp8]
         return [FUSED_OPS[kFp8StaticTensorSym]]


@@ -155,7 +155,11 @@ def ops_in_model_after(self):
 @pytest.mark.parametrize(
     "model_class, enable_quant_fp8_custom_op, cuda_force_torch, use_helion",
     # Test FP8 model with both Helion and non-Helion
-    list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False], [False, True]))
+    list(
+        itertools.product(
+            [TestSiluMulFp8QuantModel], [True, False], [True, False], [False, True]
+        )
+    )
     # Test NVFP4 model only without Helion (use_helion must be False)
     + [(TestSiluMulNvfp4QuantModel, False, False, False)],
 )
@@ -209,7 +213,10 @@ def test_fusion_silu_and_mul_quant(
     passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
     backend = TestBackend(*passes)
     model = model_class(
-        hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x, use_helion=use_helion
+        hidden_size=hidden_size,
+        cuda_force_torch=cuda_force_torch,
+        x=x,
+        use_helion=use_helion,
     )

     # First dimension dynamic
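
The reflowed itertools.product call is purely cosmetic; the parameter matrix is unchanged. A self-contained check of its size (the model class is stood in by a string so the snippet runs on its own):

import itertools

cases = list(
    itertools.product(
        ["TestSiluMulFp8QuantModel"], [True, False], [True, False], [False, True]
    )
)
# 1 model class x 2 quant-op flags x 2 cuda_force_torch flags x 2 use_helion flags
assert len(cases) == 8
# The single appended NVFP4 tuple brings the total to 9 parametrized cases.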

vllm/compilation/activation_quant_fusion.py

Lines changed: 1 addition & 2 deletions
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from abc import ABC, abstractmethod
-from contextlib import suppress

 import torch
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -120,7 +119,7 @@ def replacement(
         # This encapsulates all the enable/disable logic in one place
         if self.helion_op is not None and self.helion_op.enabled():
             # Call the Helion CustomOp's forward method
-            # This will internally call torch.ops.my_helion_lib.silu_mul_fp8
+            # This will internally call the decorated Helion kernel directly
             return self.helion_op.forward_helion(input, scale)
         else:
             d = input.shape[-1] // 2
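
The `d = input.shape[-1] // 2` in the fallback branch is the standard gate/up split used by SiLU-and-mul. For reference, the core math of that eager path, as a sketch with the quantization step omitted (not the file's actual helper):

import torch
import torch.nn.functional as F


def silu_and_mul_reference(input: torch.Tensor) -> torch.Tensor:
    # Split the last dimension [..., 2*d] into halves and combine them as
    # SiLU(gate) * up, matching vLLM's SiluAndMul semantics.
    d = input.shape[-1] // 2
    return F.silu(input[..., :d]) * input[..., d:]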

vllm/compilation/helion/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@

 from vllm.compilation.helion.benchmark import KernelBenchmark
 from vllm.compilation.helion.custom_op import HelionCustomOp
+from vllm.compilation.helion.register import register_kernel

 # Automatically import all kernel modules to trigger registration
 # This allows new kernels to be added without modifying this file
@@ -47,5 +48,5 @@
 __all__ = [
     "HelionCustomOp",
     "KernelBenchmark",
+    "register_kernel",
 ] + sorted(_helion_ops.keys())
-
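
The auto-import loop the comment refers to sits outside this hunk. A common way to implement such registration-on-import, shown here only as an illustrative sketch (the real loop and its skip list in __init__.py may differ):

import importlib
import pkgutil

# Import every sibling module except the infrastructure ones, so each kernel
# module's @register_kernel decorator runs at package import time.
for _module in pkgutil.iter_modules(__path__):
    if _module.name in ("benchmark", "custom_op", "register"):
        continue
    importlib.import_module(f"{__name__}.{_module.name}")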

vllm/compilation/helion/allreduce_add_rmsnorm.py

Lines changed: 41 additions & 73 deletions
@@ -17,6 +17,7 @@

 from vllm.compilation.helion.benchmark import DistributedKernelBenchmark
 from vllm.compilation.helion.custom_op import HelionCustomOp
+from vllm.compilation.helion.register import register_kernel
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp

@@ -188,8 +189,10 @@ def copy_engine_all_reduce_w_progress(


 # Create a custom op wrapper for fake tensor support
+# TODO(gmagogsfm): remove this custom op registration when torch.compile
+# and make_fx support it
 @torch.library.custom_op(
-    "my_helion_lib::copy_engine_all_reduce_w_progress",
+    "vllm_helion::copy_engine_all_reduce_w_progress",
     mutates_args=("output", "progress"),  # output and progress tensors are mutated
     device_types="cuda",
 )
@@ -231,7 +234,36 @@ def copy_engine_all_reduce_w_progress_fake(

 # Only define the Helion kernel if Helion is available
 if HELION_AVAILABLE:
-    # Pure Helion kernel for autotuning - this has the autotune method
+
+    def _allreduce_add_rmsnorm_fake(
+        allreduce_buf: torch.Tensor,
+        residual: torch.Tensor,
+        rms_gamma: torch.Tensor,
+        progress: torch.Tensor,
+        rms_eps: float,
+        SPLITS_PER_RANK: int,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Custom fake implementation for allreduce_add_rmsnorm.
+
+        Shape contract:
+        - allreduce_buf: [M, K]
+        - residual: [M, K]
+        - rms_gamma: [K]
+        - progress: [SPLITS_PER_RANK]
+        - returns: tuple of (normalized_output, updated_residual) both [M, K]
+        """
+        M, K = allreduce_buf.size()
+        out = torch.empty(
+            [M, K], dtype=allreduce_buf.dtype, device=allreduce_buf.device
+        )
+        residual_out = torch.empty(
+            [M, K], dtype=allreduce_buf.dtype, device=allreduce_buf.device
+        )
+        return out, residual_out
+
+    # Apply @register_kernel to the actual Helion kernel
+    @register_kernel("allreduce_add_rmsnorm", fake_impl=_allreduce_add_rmsnorm_fake)
     @helion.kernel(
         autotune_baseline_atol=0.0,
         autotune_baseline_rtol=0.0,
@@ -273,7 +305,7 @@
         ),
         static_shapes=True,
     )
-    def _allreduce_add_rmsnorm_pure_helion_kernel(
+    def allreduce_add_rmsnorm(
         allreduce_buf: torch.Tensor,
         residual: torch.Tensor,
         rms_gamma: torch.Tensor,
@@ -343,70 +375,6 @@ def _allreduce_add_rmsnorm_pure_helion_kernel(

         return out, residual_out

-    # PyTorch custom op wrapper - calls the pure Helion kernel
-    @torch.library.custom_op(
-        "my_helion_lib::allreduce_add_rmsnorm",
-        mutates_args=(),
-        device_types="cuda",
-    )
-    def _allreduce_add_rmsnorm_helion_kernel(
-        allreduce_buf: torch.Tensor,
-        residual: torch.Tensor,
-        rms_gamma: torch.Tensor,
-        progress: torch.Tensor,
-        rms_eps: float,
-        SPLITS_PER_RANK: int,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        PyTorch custom op wrapper for Helion AllReduce+Add+RMSNorm kernel.
-
-        Operation: RMSNorm(AllReduce(input) + residual), returns both normalized
-        and residual
-
-        Args:
-            allreduce_buf: Buffer being filled by AllReduce [M, K]
-            residual: Residual tensor to add [M, K]
-            rms_gamma: RMSNorm gamma weights [K]
-            progress: Progress tracking tensor [SPLITS_PER_RANK]
-            rms_eps: Epsilon for numerical stability
-            SPLITS_PER_RANK: Number of splits per rank
-
-        Returns:
-            Tuple of (normalized_output, updated_residual) both [M, K]
-        """
-        return _allreduce_add_rmsnorm_pure_helion_kernel(
-            allreduce_buf, residual, rms_gamma, progress, rms_eps, SPLITS_PER_RANK
-        )
-
-    @_allreduce_add_rmsnorm_helion_kernel.register_fake
-    def _allreduce_add_rmsnorm_helion_kernel_fake(
-        allreduce_buf: torch.Tensor,
-        residual: torch.Tensor,
-        rms_gamma: torch.Tensor,
-        progress: torch.Tensor,
-        rms_eps: float,
-        SPLITS_PER_RANK: int,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Fake/meta implementation for allreduce_add_rmsnorm Helion kernel.
-        Defines the input/output shape relationship without actual computation.
-
-        Shape contract:
-        - allreduce_buf: [M, K]
-        - residual: [M, K]
-        - rms_gamma: [K]
-        - progress: [SPLITS_PER_RANK]
-        - returns: tuple of (normalized_output, updated_residual) both [M, K]
-        """
-        M, K = allreduce_buf.size()
-        out = torch.empty(
-            [M, K], dtype=allreduce_buf.dtype, device=allreduce_buf.device
-        )
-        residual_out = torch.empty(
-            [M, K], dtype=allreduce_buf.dtype, device=allreduce_buf.device
-        )
-        return out, residual_out
-

 def helion_allreduce_add_rmsnorm(
     input_shared: torch.Tensor,
@@ -462,12 +430,12 @@ def helion_allreduce_add_rmsnorm(
     )

     # Perform AllReduce with progress tracking (custom op handles fake mode and symmetric memory conversion)
-    torch.ops.my_helion_lib.copy_engine_all_reduce_w_progress(
+    torch.ops.vllm_helion.copy_engine_all_reduce_w_progress(
         allreduce_out, input_shared, progress, splits_per_rank
     )

     # Call the Helion kernel for Add + RMSNorm
-    norm_out, residual_out = torch.ops.my_helion_lib.allreduce_add_rmsnorm(
+    norm_out, residual_out = allreduce_add_rmsnorm(
         allreduce_out,
         residual,
         rms_gamma,
@@ -662,9 +630,9 @@
         splits_match = key_splits == splits

         if distance < best_distance or (
-            distance == best_distance and splits_match and (
-                best_match is None or not best_match[2]
-            )
+            distance == best_distance
+            and splits_match
+            and (best_match is None or not best_match[2])
         ):
             best_match = (size, key, splits_match)
             best_distance = distance
@@ -688,7 +656,7 @@
     def helion_kernel(self):
         """The Helion kernel function for autotuning."""
         if HELION_AVAILABLE:
-            return _allreduce_add_rmsnorm_pure_helion_kernel
+            return allreduce_add_rmsnorm._helion_kernel
         return None
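
register.py itself is not part of this diff, so the decorator body is not shown. Judging from the call site (a vllm_helion:: op name, a fake_impl argument, and the autotuner reading allreduce_add_rmsnorm._helion_kernel), @register_kernel plausibly generates the same boilerplate this commit deletes by hand. A hedged sketch of that expansion for this one kernel; every detail below is inferred, not the actual implementation:

import torch


def _register_allreduce_add_rmsnorm(kernel, fake_impl):
    # Mirror the deleted hand-written wrapper: expose the Helion kernel as
    # torch.ops.vllm_helion.allreduce_add_rmsnorm, attach a fake impl for
    # tracing, and keep a handle to the raw kernel for autotuning.
    @torch.library.custom_op(
        "vllm_helion::allreduce_add_rmsnorm",
        mutates_args=(),
        device_types="cuda",
    )
    def op(
        allreduce_buf: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        progress: torch.Tensor,
        rms_eps: float,
        SPLITS_PER_RANK: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return kernel(
            allreduce_buf, residual, rms_gamma, progress, rms_eps, SPLITS_PER_RANK
        )

    op.register_fake(fake_impl)
    op._helion_kernel = kernel  # read back by the helion_kernel property above
    return op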

vllm/compilation/helion/custom_op.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ class HelionCustomOp(CustomOp):
         @CustomOp.register("my_helion_op")
         class MyHelionOp(HelionCustomOp):
             def forward_helion(self, x):
-                return torch.ops.my_helion_lib.my_op(x)
+                return torch.ops.vllm_helion.my_op(x)

     Checking if an op is enabled:
         # Class method (call on the class)
