8 changes: 4 additions & 4 deletions csrc/core/batch_invariant.hpp
@@ -5,11 +5,11 @@
 
 namespace vllm {
 
-// vllm_kernel_override_batch_invariant(); returns true
-// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1
-inline bool vllm_kernel_override_batch_invariant() {
+// vllm_is_batch_invariant(); returns true
+// if env VLLM_BATCH_INVARIANT=1
+inline bool vllm_is_batch_invariant() {
   static bool cached = []() {
-    std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT";
+    std::string env_key = "VLLM_BATCH_INVARIANT";
     const char* val = std::getenv(env_key.c_str());
     return (val && std::atoi(val) != 0) ? 1 : 0;
   }();
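Note: the renamed C++ helper above parses the environment variable once and caches the result in a static. A rough Python equivalent, purely for illustration (not code from this PR):

import os
from functools import lru_cache

@lru_cache(maxsize=1)
def is_batch_invariant() -> bool:
    # Parse VLLM_BATCH_INVARIANT once and cache it, mirroring the static
    # lambda initializer in csrc/core/batch_invariant.hpp.
    try:
        return int(os.getenv("VLLM_BATCH_INVARIANT", "0")) != 0
    except ValueError:
        return False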
4 changes: 2 additions & 2 deletions csrc/layernorm_kernels.cu
@@ -426,7 +426,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
                           wt_ptr % req_alignment_bytes == 0;
   bool offsets_are_multiple_of_vector_width =
       hidden_size % vector_width == 0 && input_stride % vector_width == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
   if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
       !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
@@ -474,7 +474,7 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size]
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
   bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
   if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
     LAUNCH_FUSED_POLY_NORM(8);
   } else {
2 changes: 1 addition & 1 deletion csrc/layernorm_quant_kernels.cu
@@ -254,7 +254,7 @@ void fused_add_rms_norm_static_fp8_quant(
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
   bool ptrs_are_aligned =
       inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
   if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
       !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
2 changes: 1 addition & 1 deletion tests/v1/e2e/test_async_sched_and_preempt.py
@@ -39,7 +39,7 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-        # m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")
+        # m.setenv("VLLM_BATCH_INVARIANT", "1")
 
         outputs: list[tuple[str, list]] = []
         for test_preemption in [False, True]:
24 changes: 12 additions & 12 deletions tests/v1/generation/test_batch_invariance.py
@@ -14,14 +14,14 @@
 @pytest.fixture(autouse=True)
 def enable_batch_invariant_mode():
     """Automatically enable batch invariant kernel overrides for all tests."""
-    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
-    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"
+    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
+    os.environ["VLLM_BATCH_INVARIANT"] = "1"
     yield
     # Restore original value after test
     if old_value is None:
-        os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
+        os.environ.pop("VLLM_BATCH_INVARIANT", None)
     else:
-        os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
+        os.environ["VLLM_BATCH_INVARIANT"] = old_value
 
 
 def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
@@ -236,10 +236,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
     # For batch invariance, disable custom all-reduce to ensure deterministic
     # all-reduce operations (custom all-reduce may not be deterministic)
     from vllm.model_executor.layers.batch_invariant import (
-        vllm_kernel_override_batch_invariant,
+        vllm_is_batch_invariant,
     )
 
-    disable_custom_ar = vllm_kernel_override_batch_invariant()
+    disable_custom_ar = vllm_is_batch_invariant()
 
     if disable_custom_ar:
         print(f"\n{'=' * 80}")
@@ -509,8 +509,8 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
     os.environ["VLLM_ATTENTION_BACKEND"] = backend
 
     # CRITICAL: Disable batch invariance for this test
-    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
-    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0"
+    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
+    os.environ["VLLM_BATCH_INVARIANT"] = "0"
 
     try:
         seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
@@ -702,9 +702,9 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
     finally:
         # Restore original value
         if old_value is None:
-            os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
+            os.environ.pop("VLLM_BATCH_INVARIANT", None)
         else:
-            os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
+            os.environ["VLLM_BATCH_INVARIANT"] = old_value
 
 
 @pytest.mark.skipif(
@@ -740,10 +740,10 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     from vllm.model_executor.layers.batch_invariant import (
-        vllm_kernel_override_batch_invariant,
+        vllm_is_batch_invariant,
     )
 
-    disable_custom_ar = vllm_kernel_override_batch_invariant()
+    disable_custom_ar = vllm_is_batch_invariant()
 
     if disable_custom_ar:
         print(f"\n{'=' * 80}")
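Note: the autouse fixture above saves and restores the raw environment variable by hand. An equivalent sketch using pytest's monkeypatch, which undoes the change automatically at teardown (illustrative only, not part of this diff):

import pytest

@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
    # monkeypatch.setenv is reverted automatically when the test finishes.
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
    yield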
4 changes: 2 additions & 2 deletions vllm/config/model.py
@@ -21,7 +21,7 @@
 from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
@@ -423,7 +423,7 @@ def __post_init__(
         video_pruning_rate: float | None,
     ) -> None:
         # Enable batch invariance settings if requested
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.enforce_eager = True
 
         # Set the default seed to 0 in V1.
4 changes: 2 additions & 2 deletions vllm/config/parallel.py
@@ -15,7 +15,7 @@
 from vllm.config.utils import config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.utils import cuda_device_count_stateless, get_open_ports_list
@@ -565,7 +565,7 @@ def _verify_args(self) -> Self:
         from vllm.executor.executor_base import ExecutorBase
 
         # Enable batch invariance settings if requested
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.disable_custom_all_reduce = True
 
         if (
4 changes: 2 additions & 2 deletions vllm/distributed/device_communicators/all_reduce_utils.py
@@ -20,7 +20,7 @@
 from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.utils import cuda_device_count_stateless, update_environment_variables
 
@@ -74,7 +74,7 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
         is_symmetric_memory_enabled,
     )
 
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         return False
 
     if not is_symmetric_memory_enabled():
4 changes: 2 additions & 2 deletions vllm/distributed/device_communicators/symm_mem.py
@@ -10,7 +10,7 @@
 )
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 
@@ -103,7 +103,7 @@ def __init__(
             return
         self.force_multimem = force_multimem
         self.disabled = False
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.disabled = True
 
     def should_use_symm_mem(self, inp: torch.Tensor):
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/batch_invariant.py
@@ -741,8 +741,8 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)
 
 
-def vllm_kernel_override_batch_invariant():
-    env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
+def vllm_is_batch_invariant():
+    env_key = "VLLM_BATCH_INVARIANT"
     is_overridden = False
     val = os.getenv(env_key, "0")
     try:
@@ -797,7 +797,7 @@ def override_envs_for_invariance():
 
 def init_batch_invariance():
     # this will hit all the csrc overrides as well
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         override_envs_for_invariance()
         enable_batch_invariant_mode()
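Note: the renamed flag is read from the environment during engine initialization, so it has to be set before the LLM is constructed. A minimal usage sketch under that assumption (model name and sampling settings are illustrative):

import os

# Formerly VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT.
os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
params = SamplingParams(temperature=0.0, max_tokens=16)
print(llm.generate(["The capital of France is"], params)[0].outputs[0].text)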
10 changes: 5 additions & 5 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -16,7 +16,7 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG,
@@ -841,7 +841,7 @@ def get_moe_configs(
     """
 
     # Avoid optimizing for the batch invariant case. Use default config
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         return None
 
     # First look up if an optimized configuration is available in the configs
@@ -976,7 +976,7 @@ def get_default_config(
     dtype: str | None,
     block_shape: list[int] | None = None,
 ) -> dict[str, int]:
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": 64,
@@ -1136,7 +1136,7 @@ def fused_topk_bias(
     ) + e_score_correction_bias.unsqueeze(0)
 
     # For batch invariance, use sorted=True to ensure deterministic expert selection
-    use_sorted = vllm_kernel_override_batch_invariant()
+    use_sorted = vllm_is_batch_invariant()
     topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
     topk_weights = scores.gather(1, topk_indices)
     if renormalize:
@@ -1200,7 +1200,7 @@ def grouped_topk(
     )  # [n, n_group]
 
     # For batch invariance, use sorted=True to ensure deterministic expert selection
-    use_sorted = vllm_kernel_override_batch_invariant()
+    use_sorted = vllm_is_batch_invariant()
     group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
         1
     ]  # [n, top_k_group]
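Note: as the comments in fused_topk_bias and grouped_topk say, sorted=True is what keeps expert selection deterministic: torch.topk only guarantees the order of the returned k elements when sorted=True. A small self-contained illustration with made-up scores:

import torch

scores = torch.tensor([[0.3, 0.9, 0.7, 0.1]])
# sorted=True returns the top-k indices in descending-score order, so the
# gather below sees the same ordering across runs and batch sizes.
topk_indices = torch.topk(scores, k=2, dim=-1, sorted=True)[1]
topk_weights = scores.gather(1, topk_indices)
print(topk_indices, topk_weights)  # tensor([[1, 2]]) tensor([[0.9000, 0.7000]])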
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/layernorm.py
@@ -10,7 +10,7 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.batch_invariant import (
     rms_norm_batch_invariant,
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
@@ -25,7 +25,7 @@ def rms_norm(
 ) -> torch.Tensor:
     from vllm import _custom_ops as ops
 
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         return rms_norm_batch_invariant(x, weight, variance_epsilon)
     out = torch.empty_like(x)
     ops.rms_norm(
@@ -45,7 +45,7 @@ def fused_add_rms_norm(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     from vllm import _custom_ops as ops
 
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         return rms_norm_batch_invariant(
             x + residual, weight, variance_epsilon
         ), x + residual
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -15,7 +15,7 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.fused_moe import (
     FusedMoE,
@@ -356,7 +356,7 @@ def __init__(self, quant_config: Fp8Config):
         # Disable marlin for rocm
         if current_platform.is_rocm():
             self.use_marlin = False
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.use_marlin = False
 
         self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
@@ -540,7 +540,7 @@ def apply(
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
         # If batch invariant mode is enabled, dequantize and use BF16 compute
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             # Dequantize FP8 weights to BF16
             weight_fp8 = layer.weight.to(torch.bfloat16)
             weight_scale = layer.weight_scale.to(torch.bfloat16)
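Note: in batch-invariant mode the FP8 linear path above falls back to BF16 compute. A condensed sketch of that dequantize-then-matmul fallback, assuming a per-tensor weight_scale (the real apply() also has to handle other scale layouts):

import torch
import torch.nn.functional as F

def fp8_linear_batch_invariant_sketch(x, weight_fp8, weight_scale, bias=None):
    # Dequantize the FP8 weight to BF16 and run a plain BF16 matmul,
    # trading the tuned FP8 kernels for reproducible results.
    weight = weight_fp8.to(torch.bfloat16) * weight_scale.to(torch.bfloat16)
    return F.linear(x.to(torch.bfloat16), weight, bias)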
10 changes: 5 additions & 5 deletions vllm/v1/attention/backends/flash_attn.py
@@ -35,7 +35,7 @@
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
@@ -308,7 +308,7 @@ def build(
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
 
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             max_num_splits = 1
 
         def schedule(
@@ -484,7 +484,7 @@ def __init__(
         self.attn_type = attn_type
         self.vllm_flash_attn_version = get_flash_attn_version()
         # Cache the batch invariant result for use in forward passes
-        self.batch_invariant_enabled = vllm_kernel_override_batch_invariant()
+        self.batch_invariant_enabled = vllm_is_batch_invariant()
 
         if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
             raise NotImplementedError(
@@ -963,7 +963,7 @@ def cascade_attention(
         # s_aux is incorporated into prefix_lse inside the GPU kernel,
         # enabling its effect during the final attention merge.
         s_aux=s_aux,
-        num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
+        num_splits=1 if vllm_is_batch_invariant() else 0,
     )
 
     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@@ -988,7 +988,7 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
-        num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
+        num_splits=1 if vllm_is_batch_invariant() else 0,
     )
 
     # Merge prefix and suffix outputs, and store the result in output.
6 changes: 3 additions & 3 deletions vllm/v1/attention/backends/flashinfer.py
@@ -25,7 +25,7 @@
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
@@ -291,7 +291,7 @@ def __init__(
         self._prefill_wrapper = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)
 
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.decode_fixed_split_size = 2048
             self.prefill_fixed_split_size = 4096
             self.disable_split_kv = True
@@ -404,7 +404,7 @@ def __init__(
     def _get_workspace_buffer(self):
        if self._workspace_buffer is None:
            buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
-            if vllm_kernel_override_batch_invariant():
+            if vllm_is_batch_invariant():
                buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
            self._workspace_buffer = torch.zeros(
                buffer_size, dtype=torch.uint8, device=self.device