 
 from vllm.compilation.helion.benchmark import DistributedKernelBenchmark
 from vllm.compilation.helion.custom_op import HelionCustomOp
+from vllm.compilation.helion.register import register_kernel
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 
@@ -188,8 +189,10 @@ def copy_engine_all_reduce_w_progress(
 
 
 # Create a custom op wrapper for fake tensor support
+# TODO(gmagogsfm): remove this custom op registration when torch.compile
+# and make_fx support it
 @torch.library.custom_op(
-    "my_helion_lib::copy_engine_all_reduce_w_progress",
+    "vllm_helion::copy_engine_all_reduce_w_progress",
     mutates_args=("output", "progress"),  # output and progress tensors are mutated
     device_types="cuda",
 )
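The hunk above moves the op into a `vllm_helion` namespace, and the new TODO records that the registration only exists because `torch.compile`/`make_fx` cannot yet trace the raw call. For reference, a minimal sketch of the same `torch.library.custom_op` + `register_fake` pattern, assuming PyTorch >= 2.4; the `demo_lib::scale_into` op name and its body are illustrative, not the vLLM kernel:

import torch

# A mutating custom op: declare the in-place writes via mutates_args so the
# compiler can order the mutation correctly.
@torch.library.custom_op(
    "demo_lib::scale_into",
    mutates_args=("out",),
    device_types="cuda",
)
def scale_into(out: torch.Tensor, x: torch.Tensor, s: float) -> None:
    out.copy_(x * s)  # stand-in for the real CUDA work

@scale_into.register_fake
def _(out: torch.Tensor, x: torch.Tensor, s: float) -> None:
    # Fake impl: no computation. A mutating op with no outputs returns None,
    # which is enough for make_fx/torch.compile to trace through the call.
    return None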
@@ -231,7 +234,36 @@ def copy_engine_all_reduce_w_progress_fake(
 
 # Only define the Helion kernel if Helion is available
 if HELION_AVAILABLE:
-    # Pure Helion kernel for autotuning - this has the autotune method
+
+    def _allreduce_add_rmsnorm_fake(
+        allreduce_buf: torch.Tensor,
+        residual: torch.Tensor,
+        rms_gamma: torch.Tensor,
+        progress: torch.Tensor,
+        rms_eps: float,
+        SPLITS_PER_RANK: int,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Custom fake implementation for allreduce_add_rmsnorm.
+
+        Shape contract:
+        - allreduce_buf: [M, K]
+        - residual: [M, K]
+        - rms_gamma: [K]
+        - progress: [SPLITS_PER_RANK]
+        - returns: tuple of (normalized_output, updated_residual) both [M, K]
+        """
+        M, K = allreduce_buf.size()
+        out = torch.empty(
+            [M, K], dtype=allreduce_buf.dtype, device=allreduce_buf.device
+        )
+        residual_out = torch.empty(
+            [M, K], dtype=allreduce_buf.dtype, device=allreduce_buf.device
+        )
+        return out, residual_out
+
+    # Apply @register_kernel to the actual Helion kernel
+    @register_kernel("allreduce_add_rmsnorm", fake_impl=_allreduce_add_rmsnorm_fake)
     @helion.kernel(
         autotune_baseline_atol=0.0,
         autotune_baseline_rtol=0.0,
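The fake implementation is now a standalone function handed to `register_kernel` via `fake_impl=`, instead of living inside a `register_fake` block. What it must guarantee is only the shape/dtype contract, with no kernel launch. A rough check of that contract under fake tracing, assuming a CUDA-enabled PyTorch build; the sizes here are invented for illustration:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Run the fake impl under FakeTensorMode: no memory is allocated and no
# kernel runs, but shapes and dtypes must come out right.
with FakeTensorMode():
    buf = torch.empty(128, 4096, dtype=torch.bfloat16, device="cuda")
    res = torch.empty(128, 4096, dtype=torch.bfloat16, device="cuda")
    gamma = torch.empty(4096, dtype=torch.bfloat16, device="cuda")
    prog = torch.zeros(4, dtype=torch.int32, device="cuda")
    out, res_out = _allreduce_add_rmsnorm_fake(buf, res, gamma, prog, 1e-6, 4)
    assert out.shape == res_out.shape == (128, 4096)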
@@ -273,7 +305,7 @@ def copy_engine_all_reduce_w_progress_fake(
         ),
         static_shapes=True,
     )
-    def _allreduce_add_rmsnorm_pure_helion_kernel(
+    def allreduce_add_rmsnorm(
         allreduce_buf: torch.Tensor,
         residual: torch.Tensor,
         rms_gamma: torch.Tensor,
@@ -343,70 +375,6 @@ def _allreduce_add_rmsnorm_pure_helion_kernel(
 
         return out, residual_out
 
-    # PyTorch custom op wrapper - calls the pure Helion kernel
-    @torch.library.custom_op(
-        "my_helion_lib::allreduce_add_rmsnorm",
-        mutates_args=(),
-        device_types="cuda",
-    )
-    def _allreduce_add_rmsnorm_helion_kernel(
-        allreduce_buf: torch.Tensor,
-        residual: torch.Tensor,
-        rms_gamma: torch.Tensor,
-        progress: torch.Tensor,
-        rms_eps: float,
-        SPLITS_PER_RANK: int,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        PyTorch custom op wrapper for Helion AllReduce+Add+RMSNorm kernel.
-
-        Operation: RMSNorm(AllReduce(input) + residual), returns both normalized
-        and residual
-
-        Args:
-            allreduce_buf: Buffer being filled by AllReduce [M, K]
-            residual: Residual tensor to add [M, K]
-            rms_gamma: RMSNorm gamma weights [K]
-            progress: Progress tracking tensor [SPLITS_PER_RANK]
-            rms_eps: Epsilon for numerical stability
-            SPLITS_PER_RANK: Number of splits per rank
-
-        Returns:
-            Tuple of (normalized_output, updated_residual) both [M, K]
-        """
-        return _allreduce_add_rmsnorm_pure_helion_kernel(
-            allreduce_buf, residual, rms_gamma, progress, rms_eps, SPLITS_PER_RANK
-        )
-
-    @_allreduce_add_rmsnorm_helion_kernel.register_fake
-    def _allreduce_add_rmsnorm_helion_kernel_fake(
-        allreduce_buf: torch.Tensor,
-        residual: torch.Tensor,
-        rms_gamma: torch.Tensor,
-        progress: torch.Tensor,
-        rms_eps: float,
-        SPLITS_PER_RANK: int,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Fake/meta implementation for allreduce_add_rmsnorm Helion kernel.
-        Defines the input/output shape relationship without actual computation.
-
-        Shape contract:
-        - allreduce_buf: [M, K]
-        - residual: [M, K]
-        - rms_gamma: [K]
-        - progress: [SPLITS_PER_RANK]
-        - returns: tuple of (normalized_output, updated_residual) both [M, K]
-        """
-        M, K = allreduce_buf.size()
-        out = torch.empty(
-            [M, K], dtype=allreduce_buf.dtype, device=allreduce_buf.device
-        )
-        residual_out = torch.empty(
-            [M, K], dtype=allreduce_buf.dtype, device=allreduce_buf.device
-        )
-        return out, residual_out
-
 
 def helion_allreduce_add_rmsnorm(
     input_shared: torch.Tensor,
@@ -462,12 +430,12 @@ def helion_allreduce_add_rmsnorm(
     )
 
     # Perform AllReduce with progress tracking (custom op handles fake mode and symmetric memory conversion)
-    torch.ops.my_helion_lib.copy_engine_all_reduce_w_progress(
+    torch.ops.vllm_helion.copy_engine_all_reduce_w_progress(
         allreduce_out, input_shared, progress, splits_per_rank
     )
 
     # Call the Helion kernel for Add + RMSNorm
-    norm_out, residual_out = torch.ops.my_helion_lib.allreduce_add_rmsnorm(
+    norm_out, residual_out = allreduce_add_rmsnorm(
         allreduce_out,
         residual,
         rms_gamma,
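The call site no longer goes through `torch.ops`: the `@register_kernel` decorator is expected to handle the custom-op plumbing and fake tracing itself, so callers invoke `allreduce_add_rmsnorm` directly. A hedged sketch of what a decorator along these lines could do; the real `vllm.compilation.helion.register` implementation may differ. Only the `_helion_kernel` attribute is grounded in this diff (see the `helion_kernel` property below):

def register_kernel(name: str, fake_impl=None):
    # Hypothetical shape of the decorator, for orientation only.
    def deco(kernel):
        def wrapper(*args, **kwargs):
            return kernel(*args, **kwargs)  # eager path: run the Helion kernel
        wrapper._helion_kernel = kernel  # raw kernel, kept for autotuning
        wrapper._fake_impl = fake_impl   # shape-only impl for fake tracing
        wrapper.__name__ = name
        return wrapper
    return deco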
@@ -662,9 +630,9 @@ def get_best_config(
             splits_match = key_splits == splits
 
             if distance < best_distance or (
-                distance == best_distance and splits_match and (
-                    best_match is None or not best_match[2]
-                )
+                distance == best_distance
+                and splits_match
+                and (best_match is None or not best_match[2])
             ):
                 best_match = (size, key, splits_match)
                 best_distance = distance
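This hunk only reflows the tie-breaking condition; the rule is unchanged: a candidate wins when it is strictly closer, or when it ties on distance, its split count matches, and the incumbent's does not. A self-contained illustration with invented values (simplified to a 2-tuple, so `best_match[1]` plays the role of `best_match[2]`):

best_match, best_distance = None, float("inf")
for size, splits_match in [(1024, False), (2048, True)]:  # invented candidates
    distance = abs(size - 1536)  # both are 512 away from the target
    if distance < best_distance or (
        distance == best_distance
        and splits_match
        and (best_match is None or not best_match[1])
    ):
        best_match, best_distance = (size, splits_match), distance
print(best_match)  # (2048, True): the splits-matching candidate wins the tie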
@@ -688,7 +656,7 @@ def get_best_config(
     def helion_kernel(self):
         """The Helion kernel function for autotuning."""
         if HELION_AVAILABLE:
-            return _allreduce_add_rmsnorm_pure_helion_kernel
+            return allreduce_add_rmsnorm._helion_kernel
         return None
 
 
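The property now unwraps the decorated function via `._helion_kernel`, so the autotuner still sees the raw `@helion.kernel` object (which, per the comment removed earlier in this diff, is the one carrying the autotune method) rather than the `register_kernel` wrapper. A hypothetical usage sketch:

# 'args' is a placeholder for real example inputs, not an API from this diff.
raw = allreduce_add_rmsnorm._helion_kernel
best = raw.autotune(args)  # autotune against representative shapes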