
Commit 3d08630

bwasti authored and lywa1998 committed
[small][batch invariance] Rename the env and internal flags to simplify usage (vllm-project#26855)
Signed-off-by: Bram Wasti <[email protected]>
1 parent: b399814 · commit: 3d08630

File tree

20 files changed (+61 / −61 lines)


csrc/core/batch_invariant.hpp

Lines changed: 4 additions & 4 deletions

@@ -5,11 +5,11 @@

 namespace vllm {

-// vllm_kernel_override_batch_invariant(); returns true
-// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1
-inline bool vllm_kernel_override_batch_invariant() {
+// vllm_is_batch_invariant(); returns true
+// if env VLLM_BATCH_INVARIANT=1
+inline bool vllm_is_batch_invariant() {
   static bool cached = []() {
-    std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT";
+    std::string env_key = "VLLM_BATCH_INVARIANT";
     const char* val = std::getenv(env_key.c_str());
     return (val && std::atoi(val) != 0) ? 1 : 0;
   }();
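
Note that the C++ helper caches the environment lookup in a function-local static, so the flag is read once per process and flipping VLLM_BATCH_INVARIANT afterwards has no effect. A minimal sketch of the same read-once pattern in Python; the helper name is illustrative, not part of the commit:

    import functools
    import os

    @functools.cache  # like the C++ function-local static: one env read per process
    def is_batch_invariant() -> bool:
        try:
            # Roughly matches the C++ atoi semantics: any non-zero integer enables the mode.
            return int(os.getenv("VLLM_BATCH_INVARIANT", "0")) != 0
        except ValueError:
            return False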

csrc/layernorm_kernels.cu

Lines changed: 2 additions & 2 deletions

@@ -426,7 +426,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
       wt_ptr % req_alignment_bytes == 0;
   bool offsets_are_multiple_of_vector_width =
       hidden_size % vector_width == 0 && input_stride % vector_width == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
   if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
       !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);

@@ -474,7 +474,7 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size]
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
   bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
   if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
     LAUNCH_FUSED_POLY_NORM(8);
   } else {
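
Both kernels follow the same pattern: when the flag is set, the wide vectorized launch is skipped even if alignment would allow it, presumably so the launch shape (and with it the floating-point reduction order) stays fixed regardless of batching. A hedged sketch of that dispatch decision; function and parameter names are illustrative:

    def pick_launch_width(ptrs_aligned: bool, offsets_ok: bool,
                          batch_invariant: bool) -> int:
        # Batch-invariant mode opts out of the 8-wide vectorized kernel
        # and falls back to the conservative launch path.
        if ptrs_aligned and offsets_ok and not batch_invariant:
            return 8
        return 1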

csrc/layernorm_quant_kernels.cu

Lines changed: 1 addition & 1 deletion

@@ -254,7 +254,7 @@ void fused_add_rms_norm_static_fp8_quant(
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
   bool ptrs_are_aligned =
       inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
   if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
       !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);

tests/v1/e2e/test_async_sched_and_preempt.py

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):

     with monkeypatch.context() as m:
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-        # m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")
+        # m.setenv("VLLM_BATCH_INVARIANT", "1")

         outputs: list[tuple[str, list]] = []
         for test_preemption in [False, True]:

tests/v1/generation/test_batch_invariance.py

Lines changed: 12 additions & 12 deletions

@@ -19,14 +19,14 @@
 @pytest.fixture(autouse=True)
 def enable_batch_invariant_mode():
     """Automatically enable batch invariant kernel overrides for all tests."""
-    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
-    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"
+    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
+    os.environ["VLLM_BATCH_INVARIANT"] = "1"
     yield
     # Restore original value after test
     if old_value is None:
-        os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
+        os.environ.pop("VLLM_BATCH_INVARIANT", None)
     else:
-        os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
+        os.environ["VLLM_BATCH_INVARIANT"] = old_value


 def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:

@@ -231,10 +231,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
     # For batch invariance, disable custom all-reduce to ensure deterministic
     # all-reduce operations (custom all-reduce may not be deterministic)
     from vllm.model_executor.layers.batch_invariant import (
-        vllm_kernel_override_batch_invariant,
+        vllm_is_batch_invariant,
     )

-    disable_custom_ar = vllm_kernel_override_batch_invariant()
+    disable_custom_ar = vllm_is_batch_invariant()

     if disable_custom_ar:
         print(f"\n{'=' * 80}")

@@ -494,8 +494,8 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
     os.environ["VLLM_ATTENTION_BACKEND"] = backend

     # CRITICAL: Disable batch invariance for this test
-    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
-    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0"
+    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
+    os.environ["VLLM_BATCH_INVARIANT"] = "0"

     try:
         seed = int(os.getenv("VLLM_TEST_SEED", "12345"))

@@ -687,9 +687,9 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
     finally:
         # Restore original value
         if old_value is None:
-            os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
+            os.environ.pop("VLLM_BATCH_INVARIANT", None)
         else:
-            os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
+            os.environ["VLLM_BATCH_INVARIANT"] = old_value


 @hopper_only

@@ -718,10 +718,10 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

     from vllm.model_executor.layers.batch_invariant import (
-        vllm_kernel_override_batch_invariant,
+        vllm_is_batch_invariant,
     )

-    disable_custom_ar = vllm_kernel_override_batch_invariant()
+    disable_custom_ar = vllm_is_batch_invariant()

     if disable_custom_ar:
         print(f"\n{'=' * 80}")
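
The autouse fixture above saves and restores the variable by hand. An equivalent sketch using pytest's built-in monkeypatch fixture, which undoes setenv automatically at teardown (an alternative pattern, not what this commit does):

    import pytest

    @pytest.fixture(autouse=True)
    def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
        monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
        yield  # monkeypatch restores the previous value after the test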

vllm/config/model.py

Lines changed: 2 additions & 2 deletions

@@ -21,7 +21,7 @@
 from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (

@@ -423,7 +423,7 @@ def __post_init__(
         video_pruning_rate: float | None,
     ) -> None:
         # Enable batch invariance settings if requested
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.enforce_eager = True

         # Set the default seed to 0 in V1.
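
Per the hunk above, setting the renamed flag forces eager execution in the model config. A hedged end-to-end usage sketch; the model name is illustrative, and the variable must be set before vLLM reads it:

    import os

    os.environ["VLLM_BATCH_INVARIANT"] = "1"  # set before vLLM initializes

    from vllm import LLM

    # With the flag on, __post_init__ forces enforce_eager = True (CUDA graphs off).
    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")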

vllm/config/parallel.py

Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,7 @@
 from vllm.config.utils import config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.utils import cuda_device_count_stateless, get_open_ports_list

@@ -565,7 +565,7 @@ def _verify_args(self) -> Self:
         from vllm.executor.executor_base import ExecutorBase

         # Enable batch invariance settings if requested
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.disable_custom_all_reduce = True

         if (

vllm/distributed/device_communicators/all_reduce_utils.py

Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@
 from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.utils import cuda_device_count_stateless, update_environment_variables

@@ -74,7 +74,7 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
         is_symmetric_memory_enabled,
     )

-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         return False

     if not is_symmetric_memory_enabled():

vllm/distributed/device_communicators/symm_mem.py

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
 )
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform

@@ -103,7 +103,7 @@ def __init__(
             return
         self.force_multimem = force_multimem
         self.disabled = False
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
            self.disabled = True

     def should_use_symm_mem(self, inp: torch.Tensor):

vllm/model_executor/layers/batch_invariant.py

Lines changed: 3 additions & 3 deletions

@@ -741,8 +741,8 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)


-def vllm_kernel_override_batch_invariant():
-    env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
+def vllm_is_batch_invariant():
+    env_key = "VLLM_BATCH_INVARIANT"
     is_overridden = False
     val = os.getenv(env_key, "0")
     try:

@@ -797,7 +797,7 @@ def override_envs_for_invariance():

 def init_batch_invariance():
     # this will hit all the csrc overrides as well
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         override_envs_for_invariance()
         enable_batch_invariant_mode()
