|  | 
| 19 | 19 | @pytest.fixture(autouse=True) | 
| 20 | 20 | def enable_batch_invariant_mode(): | 
| 21 | 21 |     """Automatically enable batch invariant kernel overrides for all tests.""" | 
| 22 |  | -    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT") | 
| 23 |  | -    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1" | 
|  | 22 | +    old_value = os.environ.get("VLLM_BATCH_INVARIANT") | 
|  | 23 | +    os.environ["VLLM_BATCH_INVARIANT"] = "1" | 
| 24 | 24 |     yield | 
| 25 | 25 |     # Restore original value after test | 
| 26 | 26 |     if old_value is None: | 
| 27 |  | -        os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None) | 
|  | 27 | +        os.environ.pop("VLLM_BATCH_INVARIANT", None) | 
| 28 | 28 |     else: | 
| 29 |  | -        os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value | 
|  | 29 | +        os.environ["VLLM_BATCH_INVARIANT"] = old_value | 
| 30 | 30 | 
 | 
| 31 | 31 | 
 | 
| 32 | 32 | def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: | 
| @@ -231,10 +231,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): | 
| 231 | 231 |     # For batch invariance, disable custom all-reduce to ensure deterministic | 
| 232 | 232 |     # all-reduce operations (custom all-reduce may not be deterministic) | 
| 233 | 233 |     from vllm.model_executor.layers.batch_invariant import ( | 
| 234 |  | -        vllm_kernel_override_batch_invariant, | 
|  | 234 | +        vllm_is_batch_invariant, | 
| 235 | 235 |     ) | 
| 236 | 236 | 
 | 
| 237 |  | -    disable_custom_ar = vllm_kernel_override_batch_invariant() | 
|  | 237 | +    disable_custom_ar = vllm_is_batch_invariant() | 
| 238 | 238 | 
 | 
| 239 | 239 |     if disable_custom_ar: | 
| 240 | 240 |         print(f"\n{'=' * 80}") | 
| @@ -494,8 +494,8 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): | 
| 494 | 494 |     os.environ["VLLM_ATTENTION_BACKEND"] = backend | 
| 495 | 495 | 
 | 
| 496 | 496 |     # CRITICAL: Disable batch invariance for this test | 
| 497 |  | -    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT") | 
| 498 |  | -    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0" | 
|  | 497 | +    old_value = os.environ.get("VLLM_BATCH_INVARIANT") | 
|  | 498 | +    os.environ["VLLM_BATCH_INVARIANT"] = "0" | 
| 499 | 499 | 
 | 
| 500 | 500 |     try: | 
| 501 | 501 |         seed = int(os.getenv("VLLM_TEST_SEED", "12345")) | 
| @@ -687,9 +687,9 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): | 
| 687 | 687 |     finally: | 
| 688 | 688 |         # Restore original value | 
| 689 | 689 |         if old_value is None: | 
| 690 |  | -            os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None) | 
|  | 690 | +            os.environ.pop("VLLM_BATCH_INVARIANT", None) | 
| 691 | 691 |         else: | 
| 692 |  | -            os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value | 
|  | 692 | +            os.environ["VLLM_BATCH_INVARIANT"] = old_value | 
| 693 | 693 | 
 | 
| 694 | 694 | 
 | 
| 695 | 695 | @hopper_only | 
| @@ -718,10 +718,10 @@ def test_decode_logprobs_match_prefill_logprobs(backend): | 
| 718 | 718 |     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) | 
| 719 | 719 | 
 | 
| 720 | 720 |     from vllm.model_executor.layers.batch_invariant import ( | 
| 721 |  | -        vllm_kernel_override_batch_invariant, | 
|  | 721 | +        vllm_is_batch_invariant, | 
| 722 | 722 |     ) | 
| 723 | 723 | 
 | 
| 724 |  | -    disable_custom_ar = vllm_kernel_override_batch_invariant() | 
|  | 724 | +    disable_custom_ar = vllm_is_batch_invariant() | 
| 725 | 725 | 
 | 
| 726 | 726 |     if disable_custom_ar: | 
| 727 | 727 |         print(f"\n{'=' * 80}") | 
|  | 
0 commit comments