Skip to content

Commit 6a54038

Browse files
committed
Add allreduce_add_rmsnorm
Signed-off-by: Yanan Cao <[email protected]>
1 parent 0d2cdb1 commit 6a54038

File tree

3 files changed

+1311
-0
lines changed

3 files changed

+1311
-0
lines changed

tests/compile/distributed/test_fusion_all_reduce.py

Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,25 @@
3636
from ...utils import has_module_attribute, multi_gpu_test
3737
from ..backend import TestBackend
3838

39+
# Helion imports
40+
try:
41+
import torch.distributed._symmetric_memory as symm_mem
42+
from vllm.compilation.helion.allreduce_add_rmsnorm import (
43+
helion_allreduce_add_rmsnorm,
44+
)
45+
46+
HELION_AVAILABLE = True
47+
except ImportError:
48+
HELION_AVAILABLE = False
49+
50+
# FlashInfer imports for baseline comparison
51+
try:
52+
import flashinfer.comm as flashinfer_comm
53+
54+
FLASHINFER_AVAILABLE = True
55+
except ImportError:
56+
FLASHINFER_AVAILABLE = False
57+
3958

4059
class TestAllReduceRMSNormModel(torch.nn.Module):
4160
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
@@ -192,6 +211,33 @@ def ops_in_model_before(self):
192211
]
193212

194213

214+
class TestHelionAllReduceAddRMSNormModel(torch.nn.Module):
215+
"""Test model using Helion AllReduce + Add + RMSNorm fusion."""
216+
217+
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
218+
super().__init__()
219+
self.hidden_size = hidden_size
220+
self.token_num = token_num
221+
self.eps = eps
222+
self.norm = RMSNorm(hidden_size, eps)
223+
self.rms_gamma = self.norm.weight
224+
225+
def forward(self, input_shared, residual):
226+
"""
227+
Forward pass using Helion fused op.
228+
229+
Args:
230+
input_shared: Symmetric tensor to be all-reduced
231+
residual: Residual tensor to add
232+
233+
Returns:
234+
Tuple of (normalized_output, updated_residual)
235+
"""
236+
return helion_allreduce_add_rmsnorm(
237+
input_shared, residual, self.rms_gamma, self.eps, splits_per_rank=4
238+
)
239+
240+
195241
@multi_gpu_test(num_gpus=2)
196242
@pytest.mark.parametrize(
197243
"test_model, enable_quant_fp8_custom_op",
@@ -330,3 +376,276 @@ def all_reduce_fusion_pass_on_test_model(
330376
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
331377
backend.check_after_ops(model.ops_in_model_after())
332378
del all_reduce_fusion_pass
379+
380+
381+
@multi_gpu_test(num_gpus=2)
382+
@pytest.mark.parametrize("batch_size", [8, 16])
383+
@pytest.mark.parametrize("seq_len", [8, 16])
384+
@pytest.mark.parametrize("hidden_size", [64, 128])
385+
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
386+
@pytest.mark.parametrize("splits_per_rank", [2, 4])
387+
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
388+
@pytest.mark.skipif(not HELION_AVAILABLE, reason="Helion not available")
389+
@pytest.mark.skipif(not FLASHINFER_AVAILABLE, reason="FlashInfer not available")
390+
def test_helion_allreduce_add_rmsnorm(
391+
batch_size: int,
392+
seq_len: int,
393+
hidden_size: int,
394+
dtype: torch.dtype,
395+
splits_per_rank: int,
396+
):
397+
"""
398+
Test Helion AllReduce + Add + RMSNorm fusion.
399+
400+
This test validates:
401+
1. Numerical correctness against FlashInfer baseline
402+
2. Performance comparison against FlashInfer baseline
403+
404+
Args:
405+
batch_size: Batch size for the test
406+
seq_len: Sequence length for the test
407+
hidden_size: Hidden dimension size
408+
dtype: Data type (bfloat16 or float16)
409+
splits_per_rank: Number of splits for progressive AllReduce
410+
"""
411+
num_processes = 2
412+
413+
def run_torch_spawn(fn, nprocs):
414+
torch.multiprocessing.spawn(
415+
fn,
416+
args=(
417+
num_processes,
418+
batch_size,
419+
seq_len,
420+
hidden_size,
421+
dtype,
422+
splits_per_rank,
423+
),
424+
nprocs=nprocs,
425+
)
426+
427+
run_torch_spawn(helion_allreduce_add_rmsnorm_worker, num_processes)
428+
429+
430+
def helion_allreduce_add_rmsnorm_worker(
431+
local_rank: int,
432+
world_size: int,
433+
batch_size: int,
434+
seq_len: int,
435+
hidden_size: int,
436+
dtype: torch.dtype,
437+
splits_per_rank: int,
438+
):
439+
"""Worker function for testing Helion AllReduce + Add + RMSNorm."""
440+
import torch.distributed as dist
441+
442+
current_platform.seed_everything(0)
443+
444+
device = torch.device(f"cuda:{local_rank}")
445+
torch.cuda.set_device(device)
446+
torch.set_default_device(device)
447+
torch.set_default_dtype(dtype)
448+
449+
# Initialize distributed environment
450+
update_environment_variables(
451+
{
452+
"RANK": str(local_rank),
453+
"LOCAL_RANK": str(local_rank),
454+
"WORLD_SIZE": str(world_size),
455+
"MASTER_ADDR": "localhost",
456+
"MASTER_PORT": "12346", # Different port from other tests
457+
}
458+
)
459+
460+
init_distributed_environment()
461+
initialize_model_parallel(tensor_model_parallel_size=world_size)
462+
463+
token_num = batch_size * seq_len
464+
M, K = token_num, hidden_size
465+
rms_eps = 1e-6
466+
467+
# ========== Setup FlashInfer baseline ==========
468+
flashinfer_ipc_handles, flashinfer_workspace = (
469+
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
470+
tp_rank=local_rank,
471+
tp_size=world_size,
472+
max_token_num=M,
473+
hidden_dim=K,
474+
group=dist.group.WORLD,
475+
use_fp32_lamport=False,
476+
)
477+
)
478+
479+
# ========== Test Numerical Correctness ==========
480+
# Create test data (same seed across ranks for reproducibility)
481+
torch.manual_seed(42 + local_rank) # Different data per rank
482+
input_data = torch.randn(M, K, dtype=dtype, device=device)
483+
residual_data = torch.randn(M, K, dtype=dtype, device=device)
484+
rms_gamma = torch.ones(K, dtype=dtype, device=device)
485+
486+
# Run FlashInfer baseline
487+
input_baseline = symm_mem.empty(M, K, dtype=dtype, device=device)
488+
input_baseline.copy_(input_data)
489+
residual_baseline = residual_data.clone()
490+
491+
norm_out_baseline = input_baseline # FlashInfer operates in-place
492+
residual_out_baseline = residual_baseline
493+
494+
flashinfer_comm.trtllm_allreduce_fusion(
495+
allreduce_in=input_baseline,
496+
token_num=M,
497+
residual_in=residual_baseline,
498+
residual_out=residual_out_baseline,
499+
norm_out=norm_out_baseline,
500+
rms_gamma=rms_gamma,
501+
rms_eps=rms_eps,
502+
hidden_dim=K,
503+
workspace_ptrs=flashinfer_workspace,
504+
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
505+
allreduce_out=None,
506+
quant_out=None,
507+
scale_out=None,
508+
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
509+
scale_factor=None,
510+
use_oneshot=False,
511+
world_rank=local_rank,
512+
world_size=world_size,
513+
launch_with_pdl=True,
514+
trigger_completion_at_end=True,
515+
fp32_acc=True,
516+
)
517+
torch.cuda.synchronize()
518+
519+
# Run Helion
520+
input_helion = symm_mem.empty(M, K, dtype=dtype, device=device)
521+
input_helion.copy_(input_data)
522+
residual_helion = residual_data.clone()
523+
524+
norm_out_helion, residual_out_helion = helion_allreduce_add_rmsnorm(
525+
input_helion, residual_helion, rms_gamma, rms_eps, splits_per_rank
526+
)
527+
torch.cuda.synchronize()
528+
529+
# Compare results
530+
# Use relaxed tolerances for bfloat16 and for accumulated errors
531+
if dtype == torch.bfloat16:
532+
rtol, atol = 1e-2, 1e-2
533+
else:
534+
rtol, atol = 1e-3, 1e-3
535+
536+
# Check normalized output
537+
torch.testing.assert_close(
538+
norm_out_helion,
539+
norm_out_baseline,
540+
rtol=rtol,
541+
atol=atol,
542+
msg=f"Normalized output mismatch (rank={local_rank}, dtype={dtype})",
543+
)
544+
545+
# Check residual output
546+
torch.testing.assert_close(
547+
residual_out_helion,
548+
residual_out_baseline,
549+
rtol=rtol,
550+
atol=atol,
551+
msg=f"Residual output mismatch (rank={local_rank}, dtype={dtype})",
552+
)
553+
554+
if local_rank == 0:
555+
print(
556+
f"✓ Numerical correctness test passed "
557+
f"(M={M}, K={K}, dtype={dtype}, splits={splits_per_rank})"
558+
)
559+
560+
# ========== Performance Comparison ==========
561+
num_iterations = 20
562+
warmup = 5
563+
564+
def time_kernel(kernel_fn):
565+
# Warmup
566+
for _ in range(warmup):
567+
kernel_fn()
568+
torch.cuda.synchronize()
569+
570+
# Benchmark
571+
start_event = torch.cuda.Event(enable_timing=True)
572+
end_event = torch.cuda.Event(enable_timing=True)
573+
574+
start_event.record()
575+
for _ in range(num_iterations):
576+
kernel_fn()
577+
end_event.record()
578+
579+
torch.cuda.synchronize()
580+
581+
return start_event.elapsed_time(end_event) / num_iterations
582+
583+
# Benchmark FlashInfer
584+
input_baseline_perf = symm_mem.empty(M, K, dtype=dtype, device=device)
585+
residual_baseline_perf = torch.empty(M, K, dtype=dtype, device=device)
586+
input_data_perf = torch.randn(M, K, dtype=dtype, device=device)
587+
residual_data_perf = torch.randn(M, K, dtype=dtype, device=device)
588+
589+
def baseline_fn():
590+
input_baseline_perf.copy_(input_data_perf)
591+
residual_baseline_perf.copy_(residual_data_perf)
592+
593+
flashinfer_comm.trtllm_allreduce_fusion(
594+
allreduce_in=input_baseline_perf,
595+
token_num=M,
596+
residual_in=residual_baseline_perf,
597+
residual_out=residual_baseline_perf,
598+
norm_out=input_baseline_perf,
599+
rms_gamma=rms_gamma,
600+
rms_eps=rms_eps,
601+
hidden_dim=K,
602+
workspace_ptrs=flashinfer_workspace,
603+
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
604+
allreduce_out=None,
605+
quant_out=None,
606+
scale_out=None,
607+
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
608+
scale_factor=None,
609+
use_oneshot=False,
610+
world_rank=local_rank,
611+
world_size=world_size,
612+
launch_with_pdl=True,
613+
trigger_completion_at_end=True,
614+
fp32_acc=True,
615+
)
616+
617+
dist.barrier()
618+
baseline_time = time_kernel(baseline_fn)
619+
dist.barrier()
620+
621+
# Benchmark Helion
622+
input_helion_perf = symm_mem.empty(M, K, dtype=dtype, device=device)
623+
residual_helion_perf = torch.empty(M, K, dtype=dtype, device=device)
624+
625+
def helion_fn():
626+
input_helion_perf.copy_(input_data_perf)
627+
residual_helion_perf.copy_(residual_data_perf)
628+
629+
helion_allreduce_add_rmsnorm(
630+
input_helion_perf, residual_helion_perf, rms_gamma, rms_eps, splits_per_rank
631+
)
632+
633+
dist.barrier()
634+
helion_time = time_kernel(helion_fn)
635+
dist.barrier()
636+
637+
if local_rank == 0:
638+
speedup = baseline_time / helion_time
639+
print(f"✓ Performance comparison (M={M}, K={K}, dtype={dtype}):")
640+
print(f" FlashInfer: {baseline_time:.4f} ms")
641+
print(f" Helion: {helion_time:.4f} ms")
642+
print(f" Speedup: {speedup:.2f}x")
643+
644+
# Cleanup
645+
try:
646+
flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
647+
flashinfer_ipc_handles
648+
)
649+
except:
650+
pass
651+

0 commit comments

Comments
 (0)