From ac1020f806414b3302c8af0d624cdf6b7593c165 Mon Sep 17 00:00:00 2001
From: Vadim Gimpelson
Date: Tue, 4 Nov 2025 11:54:32 +0400
Subject: [PATCH 1/6] adjust block size according to attn supported kernel sizes

Signed-off-by: Vadim Gimpelson
---
 vllm/config/vllm.py                      |  4 ++
 vllm/model_executor/layers/config.py     | 62 ++++++++++++++++++++++++
 vllm/model_executor/models/config.py     |  6 ++-
 vllm/v1/attention/backends/flashinfer.py |  5 +-
 4 files changed, 75 insertions(+), 2 deletions(-)
 create mode 100644 vllm/model_executor/layers/config.py

diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index ee91cb0ef5c3..60d5ba9aa907 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -880,6 +880,10 @@ def try_verify_and_update_config(self):
         if architecture is None:
             return
 
+        from vllm.model_executor.layers.config import AttentionConfig
+
+        AttentionConfig.verify_and_update_config(self)
+
         from vllm.model_executor.models.config import (
             MODELS_CONFIG_MAP,
             HybridAttentionMambaModelConfig,
diff --git a/vllm/model_executor/layers/config.py b/vllm/model_executor/layers/config.py
new file mode 100644
index 000000000000..49e7c542388b
--- /dev/null
+++ b/vllm/model_executor/layers/config.py
@@ -0,0 +1,62 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from math import lcm
+from typing import TYPE_CHECKING
+
+from vllm.attention.backends.abstract import MultipleOf
+from vllm.attention.selector import get_attn_backend
+from vllm.logger import init_logger
+from vllm.model_executor.models.config import VerifyAndUpdateConfig
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+
+logger = init_logger(__name__)
+
+
+class AttentionConfig(VerifyAndUpdateConfig):
+    @classmethod
+    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        """
+        Align cache_config.block_size with attention backend's minimum
+        supported kernel block size.
+
+        Args:
+            vllm_config: vLLM Config
+        """
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        assert cache_config is not None
+
+        backend_cls = get_attn_backend(
+            head_size=model_config.get_head_size(),
+            dtype=model_config.dtype,
+            kv_cache_dtype=cache_config.cache_dtype,
+            block_size=cache_config.block_size
+            if cache_config.block_size is not None
+            else 16,
+        )
+        # For now enable it for FlashInfer only.
+        # Other backends need debugging.
+        # TODO: enable it for all backends.
+        if backend_cls.get_name() != "FLASHINFER":
+            return
+
+        supported_sizes = backend_cls.get_supported_kernel_block_size()
+        supported_sizes = [
+            s.base if isinstance(s, MultipleOf) else s for s in supported_sizes
+        ]
+        min_size = min(supported_sizes)
+        if cache_config.block_size is None:
+            new_block_size = min_size
+        else:
+            new_block_size = lcm(cache_config.block_size, min_size)
+
+        if new_block_size != cache_config.block_size:
+            cache_config.block_size = new_block_size
+            logger.info(
+                "Setting attention block size to %d tokens "
+                "to align with %s attention backend's supported kernel block sizes.",
+                new_block_size,
+                backend_cls.get_name(),
+            )
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 33fa06fe0e9b..89d156d2621a 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -356,7 +356,11 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 dtype=kv_cache_dtype,
             ).page_size_bytes
         else:
-            kernel_block_alignment_size = 16
+            if cache_config.block_size is not None:
+                kernel_block_alignment_size = cache_config.block_size
+            else:
+                kernel_block_alignment_size = 16
+
         if (
             current_platform.is_device_capability(100)
             and model_config.get_head_size() == 256
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index ddc63b902dff..ab1eab4834f1 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -173,7 +173,10 @@ def get_supported_kernel_block_size() -> list[int | MultipleOf]:
         # Note: Not sure for all platforms,
         # but on Blackwell, only support a page size of
        # 16, 32, 64
-        return [16, 32, 64]
+        # TODO: 16 is temporarily removed because the TRT-LLM kernel has a bug when using 16.
+        # See https://github.com/flashinfer-ai/flashinfer/issues/1993
+        # for more details.
+        return [32, 64]
 
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:

From 3c284e299a578ccc8cc7c646fd2d70a04864079c Mon Sep 17 00:00:00 2001
From: Vadim Gimpelson
Date: Tue, 4 Nov 2025 17:37:46 +0400
Subject: [PATCH 2/6] fixes

Signed-off-by: Vadim Gimpelson
---
 vllm/attention/selector.py           | 3 ---
 vllm/model_executor/layers/config.py | 2 +-
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 9c26a8d40eda..273766549fc3 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -5,7 +5,6 @@
 from collections.abc import Generator
 from contextlib import contextmanager
 from dataclasses import dataclass
-from functools import cache
 
 import torch
 
@@ -145,7 +144,6 @@ def get_attn_backend(
     )
 
 
-@cache
 def _cached_get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
@@ -236,4 +234,3 @@ def global_force_attn_backend_context_manager(
     finally:
         # Revert the original global backend override, if any
         global_force_attn_backend(original_value)
-        _cached_get_attn_backend.cache_clear()
diff --git a/vllm/model_executor/layers/config.py b/vllm/model_executor/layers/config.py
index 49e7c542388b..cdb7abea361e 100644
--- a/vllm/model_executor/layers/config.py
+++ b/vllm/model_executor/layers/config.py
@@ -52,7 +52,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         else:
             new_block_size = lcm(cache_config.block_size, min_size)
 
-        if new_block_size != cache_config.block_size:
+        if cache_config.block_size is None or new_block_size != cache_config.block_size:
             cache_config.block_size = new_block_size
             logger.info(
                 "Setting attention block size to %d tokens "

From 2221dd148283b4604b5421e4722056ad7c9a1cf7 Mon Sep 17 00:00:00 2001
From: Vadim Gimpelson
Date: Tue, 4 Nov 2025 18:54:28 +0400
Subject: [PATCH 3/6] fixes

Signed-off-by: Vadim Gimpelson
---
 tests/kernels/attention/test_attention_selector.py      | 3 +--
 tests/kernels/attention/test_mha_attn.py                | 2 +-
 tests/kernels/attention/test_rocm_attention_selector.py | 3 +--
 tests/v1/tpu/test_mha_attn.py                           | 3 +--
 4 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index 8149ce7672cd..ca41d912e8ce 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -16,8 +16,7 @@
 @pytest.fixture(autouse=True)
 def clear_cache():
     """Clear lru cache to ensure each test case runs without caching."""
-    _cached_get_attn_backend.cache_clear()
-
+    pass
 
 # Define MLA and non-MLA backends separately
 DEVICE_MLA_BACKENDS = {
diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index 14d1618bca3c..9c711c9d2c69 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -23,7 +23,7 @@
 @pytest.fixture(autouse=True)
 def clear_cache():
     """Clear lru cache to ensure each test case runs without caching."""
-    _cached_get_attn_backend.cache_clear()
+    pass
 
     # Clear xformers availability cache
     import vllm.attention.layer as layer_module
diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py
index 9b7fb664956c..cfa719ae11aa 100644
--- a/tests/kernels/attention/test_rocm_attention_selector.py
+++ b/tests/kernels/attention/test_rocm_attention_selector.py
@@ -12,8 +12,7 @@
 @pytest.fixture(autouse=True)
 def clear_cache():
     """Clear lru cache to ensure each test case runs without caching."""
-    _cached_get_attn_backend.cache_clear()
-
+    pass
 
 @pytest.mark.skip(reason="Skipped for now. Should be revisited.")
 def test_selector(monkeypatch: pytest.MonkeyPatch):
diff --git a/tests/v1/tpu/test_mha_attn.py b/tests/v1/tpu/test_mha_attn.py
index 5debdf85bea8..d7f39b5f6db0 100644
--- a/tests/v1/tpu/test_mha_attn.py
+++ b/tests/v1/tpu/test_mha_attn.py
@@ -20,8 +20,7 @@
 @pytest.fixture(autouse=True)
 def clear_cache():
     """Clear lru cache to ensure each test case runs without caching."""
-    _cached_get_attn_backend.cache_clear()
-
+    pass
 
 def ref_attention(
     query: torch.Tensor,

From 3a5181483a14067d661916c9c3dd7b732e278451 Mon Sep 17 00:00:00 2001
From: Vadim Gimpelson
Date: Wed, 5 Nov 2025 04:27:24 +0400
Subject: [PATCH 4/6] avoid cuda initialisation during AttentionConfig.verify_and_update_config call

Signed-off-by: Vadim Gimpelson
---
 vllm/attention/utils/fa_utils.py | 63 +++++++++++++++++++++++++++++---
 1 file changed, 58 insertions(+), 5 deletions(-)

diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py
index adb9b08a6573..8a84d1f62a10 100644
--- a/vllm/attention/utils/fa_utils.py
+++ b/vllm/attention/utils/fa_utils.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Optional, Tuple
+
 from vllm import envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -12,6 +14,12 @@
 
     reshape_and_cache_flash = ops.reshape_and_cache_flash
     from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
+    from vllm.vllm_flash_attn.flash_attn_interface import (
+        FA2_AVAILABLE,
+        FA2_UNAVAILABLE_REASON,
+        FA3_AVAILABLE,
+        FA3_UNAVAILABLE_REASON,
+    )
 elif current_platform.is_xpu():
     from vllm._ipex_ops import ipex_ops as ops
 
@@ -20,6 +28,56 @@
     get_scheduler_metadata = ops.get_scheduler_metadata
 
 
+# Functions copied from vllm/vllm_flash_attn/flash_attn_interface.py
+# Modified to use current_platform.get_device_capability() instead of
+# torch.cuda.get_device_capability(device) because current_platform.get_device_capability()
+# does not initialize CUDA.
+def _is_fa2_supported(device=None) -> Tuple[bool, Optional[str]]:
+    if not FA2_AVAILABLE:
+        return False, f"FA2 is unavailable due to: {FA2_UNAVAILABLE_REASON}"
+    device_capability = current_platform.get_device_capability()
+    if device_capability.major < 8:
+        return (
+            False,
+            "FA2 is only supported on devices with compute capability >= 8",
+        )
+    return True, None
+
+
+def _is_fa3_supported(device=None) -> Tuple[bool, Optional[str]]:
+    if not FA3_AVAILABLE:
+        return False, f"FA3 is unavailable due to: {FA3_UNAVAILABLE_REASON}"
+    device_capability = current_platform.get_device_capability()
+    if (
+        device_capability.major < 8
+        or device_capability.major >= 10
+        or device_capability == (8, 6)
+        or device_capability == (8, 9)
+    ):
+        return (
+            False,
+            "FA3 is only supported on devices with compute capability >= 8"
+            " excluding 8.6 and 8.9 and Blackwell archs (>=10)",
+        )
+    return True, None
+
+
+def is_fa_version_supported(fa_version: int, device=None) -> bool:
+    assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
+    if fa_version == 2:
+        return _is_fa2_supported(device)[0]
+    elif fa_version == 3:
+        return _is_fa3_supported(device)[0]
+
+
+def fa_version_unsupported_reason(fa_version: int, device=None) -> Optional[str]:
+    assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
+    if fa_version == 2:
+        return _is_fa2_supported(device)[1]
+    elif fa_version == 3:
+        return _is_fa3_supported(device)[1]
+
+
 def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
     # import here to avoid circular dependencies
     from vllm.platforms import current_platform
@@ -27,10 +85,5 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
     if current_platform.is_xpu():
         return 2
     try:
-        from vllm.vllm_flash_attn.flash_attn_interface import (
-            fa_version_unsupported_reason,
-            is_fa_version_supported,
-        )
-
         device_capability = current_platform.get_device_capability()
         assert device_capability is not None

From 6e4d374a6b0ffa39ee62f732c077cf7b2cf49203 Mon Sep 17 00:00:00 2001
From: Vadim Gimpelson
Date: Wed, 5 Nov 2025 17:46:57 +0400
Subject: [PATCH 5/6] fix test

Signed-off-by: Vadim Gimpelson
---
 tests/v1/worker/test_gpu_model_runner.py | 95 ++++++++++++------------
 1 file changed, 48 insertions(+), 47 deletions(-)

diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index db0215511d32..f5583c2e1cd9 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -790,55 +790,56 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
         cache_dtype="auto",
     )
     parallel_config = ParallelConfig()
-    vllm_config = VllmConfig(
-        model_config=model_config,
-        cache_config=cache_config,
-        scheduler_config=scheduler_config,
-        parallel_config=parallel_config,
-    )
-
-    layer_0 = "model.layers.0.self_attn.attn"
-    layer_1 = "model.layers.1.self_attn.attn"
-    layer_2 = "model.layers.2.mixer"
-    layer_3 = "model.layers.3.mixer"
-    layer_4 = "model.layers.4.mixer"
-    layer_5 = "model.layers.5.mixer"
-
-    with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-        hf_config = vllm_config.model_config.hf_config
-        fwd_context = {}
-        for key in [layer_0, layer_1]:
-            fwd_context[key] = Attention(
-                num_heads=model_config.get_num_attention_heads(parallel_config),
-                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
-                head_size=model_config.get_head_size(),
-                scale=1.0,
-                prefix=key,
-            )
-        for key in [layer_2, layer_3, layer_4, layer_5]:
-            fwd_context[key] = MambaMixer2(
-                hidden_size=hf_config.hidden_size,
-                ssm_state_size=hf_config.mamba_d_state,
-                conv_kernel_size=hf_config.mamba_d_conv,
-                intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
-                use_conv_bias=hf_config.mamba_conv_bias,
-                use_bias=hf_config.mamba_proj_bias,
-                n_groups=hf_config.mamba_n_groups,
-                num_heads=hf_config.mamba_n_heads,
-                head_dim=hf_config.mamba_d_head,
-                rms_norm_eps=hf_config.rms_norm_eps,
-                activation=hf_config.hidden_act,
-                cache_config=cache_config,
-                model_config=model_config,
-                prefix=key,
-            )
-        # suppress var not used error
-        assert fwd_context is not None
-        vllm_ctx = vllm_config.compilation_config.static_forward_context
-
     with monkeypatch.context() as m:
+        # Attention backend should be set before creating VllmConfig because
+        # VllmConfig will determine the kv block size based on the attention backend
         m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        vllm_config = VllmConfig(
+            model_config=model_config,
+            cache_config=cache_config,
+            scheduler_config=scheduler_config,
+            parallel_config=parallel_config,
+        )
+
+        layer_0 = "model.layers.0.self_attn.attn"
+        layer_1 = "model.layers.1.self_attn.attn"
+        layer_2 = "model.layers.2.mixer"
+        layer_3 = "model.layers.3.mixer"
+        layer_4 = "model.layers.4.mixer"
+        layer_5 = "model.layers.5.mixer"
+
+        with set_current_vllm_config(vllm_config):
+            hf_config = vllm_config.model_config.hf_config
+            fwd_context = {}
+            for key in [layer_0, layer_1]:
+                fwd_context[key] = Attention(
+                    num_heads=model_config.get_num_attention_heads(parallel_config),
+                    num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+                    head_size=model_config.get_head_size(),
+                    scale=1.0,
+                    prefix=key,
+                )
+            for key in [layer_2, layer_3, layer_4, layer_5]:
+                fwd_context[key] = MambaMixer2(
+                    hidden_size=hf_config.hidden_size,
+                    ssm_state_size=hf_config.mamba_d_state,
+                    conv_kernel_size=hf_config.mamba_d_conv,
+                    intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
+                    use_conv_bias=hf_config.mamba_conv_bias,
+                    use_bias=hf_config.mamba_proj_bias,
+                    n_groups=hf_config.mamba_n_groups,
+                    num_heads=hf_config.mamba_n_heads,
+                    head_dim=hf_config.mamba_d_head,
+                    rms_norm_eps=hf_config.rms_norm_eps,
+                    activation=hf_config.hidden_act,
+                    cache_config=cache_config,
+                    model_config=model_config,
+                    prefix=key,
+                )
+            # suppress var not used error
+            assert fwd_context is not None
+            vllm_ctx = vllm_config.compilation_config.static_forward_context
+
         runner = GPUModelRunner(vllm_config, DEVICE)
         kv_cache_spec = runner.get_kv_cache_spec()
 

From ecd76fd444b2f06ebd0de28dc57bd261ea6f1fab Mon Sep 17 00:00:00 2001
From: Vadim Gimpelson
Date: Thu, 6 Nov 2025 04:12:25 +0400
Subject: [PATCH 6/6] try to apply to all backends (before was FI only)

Signed-off-by: Vadim Gimpelson
---
 vllm/model_executor/layers/config.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/vllm/model_executor/layers/config.py b/vllm/model_executor/layers/config.py
index cdb7abea361e..be76c6c3d4a9 100644
--- a/vllm/model_executor/layers/config.py
+++ b/vllm/model_executor/layers/config.py
@@ -36,11 +36,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             if cache_config.block_size is not None
             else 16,
         )
-        # For now enable it for FlashInfer only.
-        # Other backends need debugging.
-        # TODO: enable it for all backends.
-        if backend_cls.get_name() != "FLASHINFER":
-            return
 
         supported_sizes = backend_cls.get_supported_kernel_block_size()
         supported_sizes = [
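
For reference, the block-size alignment rule that PATCH 1/6 adds in AttentionConfig.verify_and_update_config (and PATCH 6/6 extends beyond FlashInfer) reduces to a small calculation. The sketch below is a standalone illustration of that rule, not vLLM code: align_block_size and the MultipleOf stand-in are hypothetical names that only mirror the behaviour of the patched method under those assumptions.

# Standalone sketch of the alignment rule from PATCH 1/6 (illustrative only).
from dataclasses import dataclass
from math import lcm


@dataclass(frozen=True)
class MultipleOf:
    """Stand-in for vLLM's MultipleOf marker: any multiple of `base` is supported."""

    base: int


def align_block_size(block_size, supported):
    # Reduce MultipleOf markers to their base value, then take the smallest
    # kernel block size the backend can handle.
    sizes = [s.base if isinstance(s, MultipleOf) else s for s in supported]
    min_size = min(sizes)
    # No block size configured: fall back to the backend minimum.
    if block_size is None:
        return min_size
    # Otherwise pick the smallest size divisible by both the configured
    # block size and the backend minimum.
    return lcm(block_size, min_size)


# With FlashInfer now reporting [32, 64] (PATCH 1/6), a configured block size
# of 16 is bumped to 32, while an already-aligned 64 is left unchanged.
assert align_block_size(16, [32, 64]) == 32
assert align_block_size(None, [32, 64]) == 32
assert align_block_size(64, [MultipleOf(16)]) == 64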
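
PATCH 4/6 copies the FA2/FA3 support checks into fa_utils.py so they rely on current_platform.get_device_capability() and therefore never initialise CUDA. Stripped of the vLLM imports, the capability gates behave as in the following sketch; the function names and the bare (major, minor) tuples here are illustrative simplifications, not the patched API.

# Illustrative reduction of the capability gates copied in PATCH 4/6.
def fa2_supported(capability):
    # FA2 only requires compute capability >= 8.0.
    return capability[0] >= 8


def fa3_supported(capability):
    # FA3 excludes pre-Ampere (< 8.0), capabilities 8.6 and 8.9,
    # and Blackwell (>= 10.x), mirroring the copied check.
    major = capability[0]
    return 8 <= major < 10 and capability not in {(8, 6), (8, 9)}


assert fa2_supported((9, 0)) and fa3_supported((9, 0))        # Hopper
assert fa2_supported((8, 9)) and not fa3_supported((8, 9))    # Ada
assert fa2_supported((10, 0)) and not fa3_supported((10, 0))  # Blackwell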