36 | 36 | from ...utils import has_module_attribute, multi_gpu_test |
37 | 37 | from ..backend import TestBackend |
38 | 38 | |
| 39 | +# Helion imports |
| 40 | +try: |
| 41 | + import torch.distributed._symmetric_memory as symm_mem |
| 42 | + from vllm.compilation.helion.allreduce_add_rmsnorm import ( |
| 43 | + helion_allreduce_add_rmsnorm, |
| 44 | + ) |
| 45 | + |
| 46 | + HELION_AVAILABLE = True |
| 47 | +except ImportError: |
| 48 | + HELION_AVAILABLE = False |
| 49 | + |
| 50 | +# FlashInfer imports for baseline comparison |
| 51 | +try: |
| 52 | + import flashinfer.comm as flashinfer_comm |
| 53 | + |
| 54 | + FLASHINFER_AVAILABLE = True |
| 55 | +except ImportError: |
| 56 | + FLASHINFER_AVAILABLE = False |
| 57 | + |
39 | 58 | |
40 | 59 | class TestAllReduceRMSNormModel(torch.nn.Module): |
41 | 60 | def __init__(self, hidden_size=16, token_num=16, eps=1e-6): |
@@ -192,6 +211,33 @@ def ops_in_model_before(self): |
192 | 211 | ] |
193 | 212 | |
194 | 213 | |
| 214 | +class TestHelionAllReduceAddRMSNormModel(torch.nn.Module): |
| 215 | + """Test model using Helion AllReduce + Add + RMSNorm fusion.""" |
| 216 | + |
| 217 | + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): |
| 218 | + super().__init__() |
| 219 | + self.hidden_size = hidden_size |
| 220 | + self.token_num = token_num |
| 221 | + self.eps = eps |
| 222 | + self.norm = RMSNorm(hidden_size, eps) |
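| | + # The fused op takes gamma explicitly, so share the RMSNorm module's weight. |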
| 223 | + self.rms_gamma = self.norm.weight |
| 224 | + |
| 225 | + def forward(self, input_shared, residual): |
| 226 | + """ |
| 227 | + Forward pass using Helion fused op. |
| 228 | + |
| 229 | + Args: |
| 230 | + input_shared: Symmetric tensor to be all-reduced |
| 231 | + residual: Residual tensor to add |
| 232 | + |
| 233 | + Returns: |
| 234 | + Tuple of (normalized_output, updated_residual) |
| 235 | + """ |
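| | + # input_shared must be allocated in symmetric memory; the test below |
| | + # creates it with symm_mem.empty(). |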
| 236 | + return helion_allreduce_add_rmsnorm( |
| 237 | + input_shared, residual, self.rms_gamma, self.eps, splits_per_rank=4 |
| 238 | + ) |
| 239 | + |
| 240 | + |
195 | 241 | @multi_gpu_test(num_gpus=2) |
196 | 242 | @pytest.mark.parametrize( |
197 | 243 | "test_model, enable_quant_fp8_custom_op", |
@@ -330,3 +376,276 @@ def all_reduce_fusion_pass_on_test_model( |
330 | 376 | backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) |
331 | 377 | backend.check_after_ops(model.ops_in_model_after()) |
332 | 378 | del all_reduce_fusion_pass |
| 379 | + |
| 380 | + |
| 381 | +@multi_gpu_test(num_gpus=2) |
| 382 | +@pytest.mark.parametrize("batch_size", [8, 16]) |
| 383 | +@pytest.mark.parametrize("seq_len", [8, 16]) |
| 384 | +@pytest.mark.parametrize("hidden_size", [64, 128]) |
| 385 | +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) |
| 386 | +@pytest.mark.parametrize("splits_per_rank", [2, 4]) |
| 387 | +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") |
| 388 | +@pytest.mark.skipif(not HELION_AVAILABLE, reason="Helion not available") |
| 389 | +@pytest.mark.skipif(not FLASHINFER_AVAILABLE, reason="FlashInfer not available") |
| 390 | +def test_helion_allreduce_add_rmsnorm( |
| 391 | + batch_size: int, |
| 392 | + seq_len: int, |
| 393 | + hidden_size: int, |
| 394 | + dtype: torch.dtype, |
| 395 | + splits_per_rank: int, |
| 396 | +): |
| 397 | + """ |
| 398 | + Test Helion AllReduce + Add + RMSNorm fusion. |
| 399 | + |
| 400 | + This test validates: |
| 401 | + 1. Numerical correctness against FlashInfer baseline |
| 402 | + 2. Performance comparison against FlashInfer baseline |
| 403 | + |
| 404 | + Args: |
| 405 | + batch_size: Batch size for the test |
| 406 | + seq_len: Sequence length for the test |
| 407 | + hidden_size: Hidden dimension size |
| 408 | + dtype: Data type (bfloat16 or float16) |
| 409 | + splits_per_rank: Number of splits for progressive AllReduce |
| 410 | + """ |
| 411 | + num_processes = 2 |
| 412 | + |
| 413 | + def run_torch_spawn(fn, nprocs): |
| 414 | + torch.multiprocessing.spawn( |
| 415 | + fn, |
| 416 | + args=( |
| 417 | + num_processes, |
| 418 | + batch_size, |
| 419 | + seq_len, |
| 420 | + hidden_size, |
| 421 | + dtype, |
| 422 | + splits_per_rank, |
| 423 | + ), |
| 424 | + nprocs=nprocs, |
| 425 | + ) |
| 426 | + |
| 427 | + run_torch_spawn(helion_allreduce_add_rmsnorm_worker, num_processes) |
| 428 | + |
| 429 | + |
| 430 | +def helion_allreduce_add_rmsnorm_worker( |
| 431 | + local_rank: int, |
| 432 | + world_size: int, |
| 433 | + batch_size: int, |
| 434 | + seq_len: int, |
| 435 | + hidden_size: int, |
| 436 | + dtype: torch.dtype, |
| 437 | + splits_per_rank: int, |
| 438 | +): |
| 439 | + """Worker function for testing Helion AllReduce + Add + RMSNorm.""" |
| 440 | + import torch.distributed as dist |
| 441 | + |
| 442 | + current_platform.seed_everything(0) |
| 443 | + |
| 444 | + device = torch.device(f"cuda:{local_rank}") |
| 445 | + torch.cuda.set_device(device) |
| 446 | + torch.set_default_device(device) |
| 447 | + torch.set_default_dtype(dtype) |
| 448 | + |
| 449 | + # Initialize distributed environment |
| 450 | + update_environment_variables( |
| 451 | + { |
| 452 | + "RANK": str(local_rank), |
| 453 | + "LOCAL_RANK": str(local_rank), |
| 454 | + "WORLD_SIZE": str(world_size), |
| 455 | + "MASTER_ADDR": "localhost", |
| 456 | + "MASTER_PORT": "12346", # Different port from other tests |
| 457 | + } |
| 458 | + ) |
| 459 | + |
| 460 | + init_distributed_environment() |
| 461 | + initialize_model_parallel(tensor_model_parallel_size=world_size) |
| 462 | + |
| 463 | + token_num = batch_size * seq_len |
| 464 | + M, K = token_num, hidden_size |
| 465 | + rms_eps = 1e-6 |
| 466 | + |
| 467 | + # ========== Setup FlashInfer baseline ========== |
| 468 | + flashinfer_ipc_handles, flashinfer_workspace = ( |
| 469 | + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( |
| 470 | + tp_rank=local_rank, |
| 471 | + tp_size=world_size, |
| 472 | + max_token_num=M, |
| 473 | + hidden_dim=K, |
| 474 | + group=dist.group.WORLD, |
| 475 | + use_fp32_lamport=False, |
| 476 | + ) |
| 477 | + ) |
| 478 | + |
| 479 | + # ========== Test Numerical Correctness ========== |
| 480 | + # Create test data (different seed per rank so each rank reduces distinct values) |
| 481 | + torch.manual_seed(42 + local_rank) |
| 482 | + input_data = torch.randn(M, K, dtype=dtype, device=device) |
| 483 | + residual_data = torch.randn(M, K, dtype=dtype, device=device) |
| 484 | + rms_gamma = torch.ones(K, dtype=dtype, device=device) |
| 485 | + |
| 486 | + # Run FlashInfer baseline |
| 487 | + input_baseline = symm_mem.empty(M, K, dtype=dtype, device=device) |
| 488 | + input_baseline.copy_(input_data) |
| 489 | + residual_baseline = residual_data.clone() |
| 490 | + |
| 491 | + norm_out_baseline = input_baseline # FlashInfer operates in-place |
| 492 | + residual_out_baseline = residual_baseline |
| 493 | + |
| 494 | + flashinfer_comm.trtllm_allreduce_fusion( |
| 495 | + allreduce_in=input_baseline, |
| 496 | + token_num=M, |
| 497 | + residual_in=residual_baseline, |
| 498 | + residual_out=residual_out_baseline, |
| 499 | + norm_out=norm_out_baseline, |
| 500 | + rms_gamma=rms_gamma, |
| 501 | + rms_eps=rms_eps, |
| 502 | + hidden_dim=K, |
| 503 | + workspace_ptrs=flashinfer_workspace, |
| 504 | + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, |
| 505 | + allreduce_out=None, |
| 506 | + quant_out=None, |
| 507 | + scale_out=None, |
| 508 | + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, |
| 509 | + scale_factor=None, |
| 510 | + use_oneshot=False, |
| 511 | + world_rank=local_rank, |
| 512 | + world_size=world_size, |
| 513 | + launch_with_pdl=True, |
| 514 | + trigger_completion_at_end=True, |
| 515 | + fp32_acc=True, |
| 516 | + ) |
| 517 | + torch.cuda.synchronize() |
| 518 | + |
| 519 | + # Run Helion |
| 520 | + input_helion = symm_mem.empty(M, K, dtype=dtype, device=device) |
| 521 | + input_helion.copy_(input_data) |
| 522 | + residual_helion = residual_data.clone() |
| 523 | + |
| 524 | + norm_out_helion, residual_out_helion = helion_allreduce_add_rmsnorm( |
| 525 | + input_helion, residual_helion, rms_gamma, rms_eps, splits_per_rank |
| 526 | + ) |
| 527 | + torch.cuda.synchronize() |
| 528 | + |
| 529 | + # Compare results |
| 530 | + # Use relaxed tolerances for bfloat16 and for accumulated errors |
| 531 | + if dtype == torch.bfloat16: |
| 532 | + rtol, atol = 1e-2, 1e-2 |
| 533 | + else: |
| 534 | + rtol, atol = 1e-3, 1e-3 |
| 535 | + |
| 536 | + # Check normalized output |
| 537 | + torch.testing.assert_close( |
| 538 | + norm_out_helion, |
| 539 | + norm_out_baseline, |
| 540 | + rtol=rtol, |
| 541 | + atol=atol, |
| 542 | + msg=f"Normalized output mismatch (rank={local_rank}, dtype={dtype})", |
| 543 | + ) |
| 544 | + |
| 545 | + # Check residual output |
| 546 | + torch.testing.assert_close( |
| 547 | + residual_out_helion, |
| 548 | + residual_out_baseline, |
| 549 | + rtol=rtol, |
| 550 | + atol=atol, |
| 551 | + msg=f"Residual output mismatch (rank={local_rank}, dtype={dtype})", |
| 552 | + ) |
| 553 | + |
| 554 | + if local_rank == 0: |
| 555 | + print( |
| 556 | + "✓ Numerical correctness test passed " |
| 557 | + f"(M={M}, K={K}, dtype={dtype}, splits={splits_per_rank})" |
| 558 | + ) |
| 559 | + |
| 560 | + # ========== Performance Comparison ========== |
| 561 | + num_iterations = 20 |
| 562 | + warmup = 5 |
| 563 | + |
| 564 | + def time_kernel(kernel_fn): |
| 565 | + # Warmup |
| 566 | + for _ in range(warmup): |
| 567 | + kernel_fn() |
| 568 | + torch.cuda.synchronize() |
| 569 | + |
| 570 | + # Benchmark |
| 571 | + start_event = torch.cuda.Event(enable_timing=True) |
| 572 | + end_event = torch.cuda.Event(enable_timing=True) |
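| | + # CUDA events record timestamps on the GPU stream, so elapsed_time |
| | + # measures kernel execution rather than host-side launch overhead. |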
| 573 | + |
| 574 | + start_event.record() |
| 575 | + for _ in range(num_iterations): |
| 576 | + kernel_fn() |
| 577 | + end_event.record() |
| 578 | + |
| 579 | + torch.cuda.synchronize() |
| 580 | + |
| 581 | + return start_event.elapsed_time(end_event) / num_iterations |
| 582 | + |
| 583 | + # Benchmark FlashInfer |
| 584 | + input_baseline_perf = symm_mem.empty(M, K, dtype=dtype, device=device) |
| 585 | + residual_baseline_perf = torch.empty(M, K, dtype=dtype, device=device) |
| 586 | + input_data_perf = torch.randn(M, K, dtype=dtype, device=device) |
| 587 | + residual_data_perf = torch.randn(M, K, dtype=dtype, device=device) |
| 588 | + |
| 589 | + def baseline_fn(): |
| 590 | + input_baseline_perf.copy_(input_data_perf) |
| 591 | + residual_baseline_perf.copy_(residual_data_perf) |
| 592 | + |
| 593 | + flashinfer_comm.trtllm_allreduce_fusion( |
| 594 | + allreduce_in=input_baseline_perf, |
| 595 | + token_num=M, |
| 596 | + residual_in=residual_baseline_perf, |
| 597 | + residual_out=residual_baseline_perf, |
| 598 | + norm_out=input_baseline_perf, |
| 599 | + rms_gamma=rms_gamma, |
| 600 | + rms_eps=rms_eps, |
| 601 | + hidden_dim=K, |
| 602 | + workspace_ptrs=flashinfer_workspace, |
| 603 | + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, |
| 604 | + allreduce_out=None, |
| 605 | + quant_out=None, |
| 606 | + scale_out=None, |
| 607 | + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, |
| 608 | + scale_factor=None, |
| 609 | + use_oneshot=False, |
| 610 | + world_rank=local_rank, |
| 611 | + world_size=world_size, |
| 612 | + launch_with_pdl=True, |
| 613 | + trigger_completion_at_end=True, |
| 614 | + fp32_acc=True, |
| 615 | + ) |
| 616 | + |
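| | + # Barriers keep both ranks entering and leaving the timed region |
| | + # together, since the kernels being timed are collectives. |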
| 617 | + dist.barrier() |
| 618 | + baseline_time = time_kernel(baseline_fn) |
| 619 | + dist.barrier() |
| 620 | + |
| 621 | + # Benchmark Helion |
| 622 | + input_helion_perf = symm_mem.empty(M, K, dtype=dtype, device=device) |
| 623 | + residual_helion_perf = torch.empty(M, K, dtype=dtype, device=device) |
| 624 | + |
| 625 | + def helion_fn(): |
| 626 | + input_helion_perf.copy_(input_data_perf) |
| 627 | + residual_helion_perf.copy_(residual_data_perf) |
| 628 | + |
| 629 | + helion_allreduce_add_rmsnorm( |
| 630 | + input_helion_perf, residual_helion_perf, rms_gamma, rms_eps, splits_per_rank |
| 631 | + ) |
| 632 | + |
| 633 | + dist.barrier() |
| 634 | + helion_time = time_kernel(helion_fn) |
| 635 | + dist.barrier() |
| 636 | + |
| 637 | + if local_rank == 0: |
| 638 | + speedup = baseline_time / helion_time |
| 639 | + print(f"✓ Performance comparison (M={M}, K={K}, dtype={dtype}):") |
| 640 | + print(f" FlashInfer: {baseline_time:.4f} ms") |
| 641 | + print(f" Helion: {helion_time:.4f} ms") |
| 642 | + print(f" Speedup: {speedup:.2f}x") |
| 643 | + |
| 644 | + # Cleanup |
| 645 | + try: |
| 646 | + flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion( |
| 647 | + flashinfer_ipc_handles |
| 648 | + ) |
| 649 | + except Exception: |
| 650 | + pass |
| 651 | + |