diff --git a/csrc/core/batch_invariant.hpp b/csrc/core/batch_invariant.hpp
index e769e1a25ac0..fffe96b86857 100644
--- a/csrc/core/batch_invariant.hpp
+++ b/csrc/core/batch_invariant.hpp
@@ -5,11 +5,11 @@
 namespace vllm {
 
-// vllm_kernel_override_batch_invariant(); returns true
-// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1
-inline bool vllm_kernel_override_batch_invariant() {
+// vllm_is_batch_invariant(); returns true
+// if env VLLM_BATCH_INVARIANT=1
+inline bool vllm_is_batch_invariant() {
   static bool cached = []() {
-    std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT";
+    std::string env_key = "VLLM_BATCH_INVARIANT";
     const char* val = std::getenv(env_key.c_str());
     return (val && std::atoi(val) != 0) ? 1 : 0;
   }();
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index aa7927f09cbb..7e8ec9937e54 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -426,7 +426,7 @@ void fused_add_rms_norm(torch::Tensor& input,  // [..., hidden_size]
       wt_ptr % req_alignment_bytes == 0;
   bool offsets_are_multiple_of_vector_width =
       hidden_size % vector_width == 0 && input_stride % vector_width == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
   if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
       !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
@@ -474,7 +474,7 @@ void poly_norm(torch::Tensor& out,  // [..., hidden_size]
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
   bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
   if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
     LAUNCH_FUSED_POLY_NORM(8);
   } else {
diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu
index 7f9a0bccdd34..64d14429f938 100644
--- a/csrc/layernorm_quant_kernels.cu
+++ b/csrc/layernorm_quant_kernels.cu
@@ -254,7 +254,7 @@ void fused_add_rms_norm_static_fp8_quant(
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
   bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
   if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
       !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
diff --git a/tests/v1/e2e/test_async_sched_and_preempt.py b/tests/v1/e2e/test_async_sched_and_preempt.py
index 0f7ccb35a757..bc93a4c8c697 100644
--- a/tests/v1/e2e/test_async_sched_and_preempt.py
+++ b/tests/v1/e2e/test_async_sched_and_preempt.py
@@ -39,7 +39,7 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-        # m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")
+        # m.setenv("VLLM_BATCH_INVARIANT", "1")
 
         outputs: list[tuple[str, list]] = []
         for test_preemption in [False, True]:
diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py
index 6fe7c42df283..bb685e8509e7 100644
--- a/tests/v1/generation/test_batch_invariance.py
+++ b/tests/v1/generation/test_batch_invariance.py
@@ -14,14 +14,14 @@
 @pytest.fixture(autouse=True)
 def enable_batch_invariant_mode():
     """Automatically enable batch invariant kernel overrides for all tests."""
-    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
-    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"
+    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
+    os.environ["VLLM_BATCH_INVARIANT"] = "1"
     yield
     # Restore original value after test
     if old_value is None:
-        os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
+        os.environ.pop("VLLM_BATCH_INVARIANT", None)
     else:
-        os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
+        os.environ["VLLM_BATCH_INVARIANT"] = old_value
 
 
 def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
@@ -236,10 +236,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
     # For batch invariance, disable custom all-reduce to ensure deterministic
     # all-reduce operations (custom all-reduce may not be deterministic)
     from vllm.model_executor.layers.batch_invariant import (
-        vllm_kernel_override_batch_invariant,
+        vllm_is_batch_invariant,
     )
 
-    disable_custom_ar = vllm_kernel_override_batch_invariant()
+    disable_custom_ar = vllm_is_batch_invariant()
 
     if disable_custom_ar:
         print(f"\n{'=' * 80}")
@@ -509,8 +509,8 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
     os.environ["VLLM_ATTENTION_BACKEND"] = backend
 
     # CRITICAL: Disable batch invariance for this test
-    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
-    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0"
+    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
+    os.environ["VLLM_BATCH_INVARIANT"] = "0"
 
     try:
         seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
@@ -702,9 +702,9 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
     finally:
         # Restore original value
         if old_value is None:
-            os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
+            os.environ.pop("VLLM_BATCH_INVARIANT", None)
         else:
-            os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
+            os.environ["VLLM_BATCH_INVARIANT"] = old_value
 
 
 @pytest.mark.skipif(
@@ -740,10 +740,10 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     from vllm.model_executor.layers.batch_invariant import (
-        vllm_kernel_override_batch_invariant,
+        vllm_is_batch_invariant,
     )
 
-    disable_custom_ar = vllm_kernel_override_batch_invariant()
+    disable_custom_ar = vllm_is_batch_invariant()
 
     if disable_custom_ar:
         print(f"\n{'=' * 80}")
diff --git a/vllm/config/model.py b/vllm/config/model.py
index ebad9bfb9c90..b572967d364c 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -21,7 +21,7 @@
 from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
@@ -423,7 +423,7 @@ def __post_init__(
         video_pruning_rate: float | None,
     ) -> None:
         # Enable batch invariance settings if requested
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.enforce_eager = True
 
         # Set the default seed to 0 in V1.
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index 9b0634ad2ac9..953aa1a147de 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -15,7 +15,7 @@
 from vllm.config.utils import config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.utils import cuda_device_count_stateless, get_open_ports_list
@@ -565,7 +565,7 @@ def _verify_args(self) -> Self:
         from vllm.executor.executor_base import ExecutorBase
 
         # Enable batch invariance settings if requested
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.disable_custom_all_reduce = True
 
         if (
diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py
index a3eef87b451f..2eb3ce2976d2 100644
--- a/vllm/distributed/device_communicators/all_reduce_utils.py
+++ b/vllm/distributed/device_communicators/all_reduce_utils.py
@@ -20,7 +20,7 @@
 from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.utils import cuda_device_count_stateless, update_environment_variables
 
@@ -74,7 +74,7 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
         is_symmetric_memory_enabled,
     )
 
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
        return False
 
     if not is_symmetric_memory_enabled():
diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py
index f214c013bd3b..74d6fb40c83b 100644
--- a/vllm/distributed/device_communicators/symm_mem.py
+++ b/vllm/distributed/device_communicators/symm_mem.py
@@ -10,7 +10,7 @@
 )
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 
@@ -103,7 +103,7 @@ def __init__(
             return
         self.force_multimem = force_multimem
         self.disabled = False
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.disabled = True
 
     def should_use_symm_mem(self, inp: torch.Tensor):
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 029605aed502..582e4aae78fc 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -741,8 +741,8 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)
 
 
-def vllm_kernel_override_batch_invariant():
-    env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
+def vllm_is_batch_invariant():
+    env_key = "VLLM_BATCH_INVARIANT"
     is_overridden = False
     val = os.getenv(env_key, "0")
     try:
@@ -797,7 +797,7 @@ def override_envs_for_invariance():
 
 
 def init_batch_invariance():
     # this will hit all the csrc overrides as well
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         override_envs_for_invariance()
         enable_batch_invariant_mode()
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 256f4964b654..42724f2ff3c0 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -16,7 +16,7 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG,
@@ -841,7 +841,7 @@ def get_moe_configs(
     """
 
     # Avoid optimizing for the batch invariant case. Use default config
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         return None
 
     # First look up if an optimized configuration is available in the configs
@@ -976,7 +976,7 @@ def get_default_config(
     dtype: str | None,
     block_shape: list[int] | None = None,
 ) -> dict[str, int]:
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": 64,
@@ -1136,7 +1136,7 @@ def fused_topk_bias(
     ) + e_score_correction_bias.unsqueeze(0)
 
     # For batch invariance, use sorted=True to ensure deterministic expert selection
-    use_sorted = vllm_kernel_override_batch_invariant()
+    use_sorted = vllm_is_batch_invariant()
     topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
     topk_weights = scores.gather(1, topk_indices)
     if renormalize:
@@ -1200,7 +1200,7 @@ def grouped_topk(
     )  # [n, n_group]
 
     # For batch invariance, use sorted=True to ensure deterministic expert selection
-    use_sorted = vllm_kernel_override_batch_invariant()
+    use_sorted = vllm_is_batch_invariant()
     group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
         1
     ]  # [n, top_k_group]
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index a689bc7be00f..85ead0d81059 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -10,7 +10,7 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.batch_invariant import (
     rms_norm_batch_invariant,
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
@@ -25,7 +25,7 @@ def rms_norm(
 ) -> torch.Tensor:
     from vllm import _custom_ops as ops
 
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         return rms_norm_batch_invariant(x, weight, variance_epsilon)
     out = torch.empty_like(x)
     ops.rms_norm(
@@ -45,7 +45,7 @@ def fused_add_rms_norm(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     from vllm import _custom_ops as ops
 
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         return rms_norm_batch_invariant(
             x + residual, weight, variance_epsilon
         ), x + residual
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 4edb55d816cf..bfd8fd7b9f7c 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -15,7 +15,7 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.fused_moe import (
     FusedMoE,
@@ -356,7 +356,7 @@ def __init__(self, quant_config: Fp8Config):
         # Disable marlin for rocm
         if current_platform.is_rocm():
             self.use_marlin = False
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.use_marlin = False
 
         self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
@@ -540,7 +540,7 @@ def apply(
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
         # If batch invariant mode is enabled, dequantize and use BF16 compute
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             # Dequantize FP8 weights to BF16
             weight_fp8 = layer.weight.to(torch.bfloat16)
             weight_scale = layer.weight_scale.to(torch.bfloat16)
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 6811860a34b0..8affde914782 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -35,7 +35,7 @@
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
@@ -308,7 +308,7 @@ def build(
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
 
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             max_num_splits = 1
 
         def schedule(
@@ -484,7 +484,7 @@ def __init__(
         self.attn_type = attn_type
         self.vllm_flash_attn_version = get_flash_attn_version()
         # Cache the batch invariant result for use in forward passes
-        self.batch_invariant_enabled = vllm_kernel_override_batch_invariant()
+        self.batch_invariant_enabled = vllm_is_batch_invariant()
 
         if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
             raise NotImplementedError(
@@ -963,7 +963,7 @@ def cascade_attention(
         # s_aux is incorporated into prefix_lse inside the GPU kernel,
         # enabling its effect during the final attention merge.
         s_aux=s_aux,
-        num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
+        num_splits=1 if vllm_is_batch_invariant() else 0,
     )
 
     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@@ -988,7 +988,7 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
-        num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
+        num_splits=1 if vllm_is_batch_invariant() else 0,
     )
 
     # Merge prefix and suffix outputs, and store the result in output.
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 34225602f025..cd54b964c41f 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -25,7 +25,7 @@
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
@@ -291,7 +291,7 @@ def __init__(
         self._prefill_wrapper = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode (general shape)
 
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.decode_fixed_split_size = 2048
             self.prefill_fixed_split_size = 4096
             self.disable_split_kv = True
@@ -404,7 +404,7 @@ def __init__(
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
             buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
-            if vllm_kernel_override_batch_invariant():
+            if vllm_is_batch_invariant():
                 buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
             self._workspace_buffer = torch.zeros(
                 buffer_size, dtype=torch.uint8, device=self.device
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 902872bb25b3..29884700d9a4 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -26,7 +26,7 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.utils import cdiv, is_torch_equal_or_newer
 from vllm.v1.attention.backends.utils import (
@@ -863,7 +863,7 @@ def get_kernel_options(
     kernel_options: dict[str, int | bool] = {
         "FORCE_USE_FLEX_ATTENTION": True,
     }
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         kernel_options["BLOCK_M"] = 16
         kernel_options["BLOCK_N"] = 16
         kernel_options["IS_DIVISIBLE"] = False
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 1d4e3e4cfe22..51a9032f4269 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -212,7 +212,7 @@
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -1283,7 +1283,7 @@ def _flash_attn_varlen_diff_headdims(
             # ROCm leverages the upstream flash_attn, which takes a parameter
             # called "return_attn_probs" instead of return_softmax_lse
             kwargs["return_attn_probs"] = return_softmax_lse
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             kwargs["num_splits"] = 1
 
         attn_out = self.flash_attn_varlen_func(
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 3e404d50ee7c..71f5473bc9de 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -19,7 +19,7 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -110,7 +110,7 @@ def __init__(
         # pre-allocated during capture.
         self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.max_num_splits = 1
 
     def _schedule_decode(
@@ -181,7 +181,7 @@ def _build_decode(
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
 
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             max_num_splits = 1
 
         metadata = FlashAttnMLADecodeMetadata(
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index fc8fb34afb18..34d3c8ee1ba2 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -15,7 +15,7 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -234,7 +234,7 @@ def _forward_decode(
         tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
         num_splits = attn_metadata.decode.num_splits
 
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             device = q.device
             dtype = torch.int32
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index d3524020bc7f..781f77e96319 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -14,7 +14,7 @@
 from vllm.attention.ops.triton_flash_attention import triton_attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
@@ -163,7 +163,7 @@ def _forward_decode(
         lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
 
         # For batch invariance, use only 1 split to ensure deterministic reduction
-        num_kv_splits = 1 if vllm_kernel_override_batch_invariant() else 4
+        num_kv_splits = 1 if vllm_is_batch_invariant() else 4
 
         # TODO(lucas) Allocate ahead of time
         attn_logits = torch.empty(
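
Reviewer note (not part of the diff): the sketch below is a minimal, simplified illustration of the renamed helper and of the call-site pattern this change touches. It is an assumption-laden summary, not the actual implementation; the real Python and C++ versions live in vllm/model_executor/layers/batch_invariant.py and csrc/core/batch_invariant.hpp and handle parsing and caching more carefully.

# Illustrative sketch only -- simplified from the call sites in this diff.
import os


def vllm_is_batch_invariant() -> bool:
    # Renamed from vllm_kernel_override_batch_invariant(); reads the renamed
    # env var VLLM_BATCH_INVARIANT (was VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT).
    try:
        return int(os.getenv("VLLM_BATCH_INVARIANT", "0")) != 0
    except ValueError:
        return False


# Typical call-site pattern from the diff: choose deterministic kernel
# parameters when the flag is set.
num_splits = 1 if vllm_is_batch_invariant() else 0  # single split for deterministic reduction
use_sorted = vllm_is_batch_invariant()  # deterministic top-k expert selection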