3 changes: 1 addition & 2 deletions tests/kernels/attention/test_attention_selector.py
@@ -16,8 +16,7 @@
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()

pass

# Define MLA and non-MLA backends separately
DEVICE_MLA_BACKENDS = {
2 changes: 1 addition & 1 deletion tests/kernels/attention/test_mha_attn.py
@@ -23,7 +23,7 @@
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()
pass
# Clear xformers availability cache
import vllm.attention.layer as layer_module

3 changes: 1 addition & 2 deletions tests/kernels/attention/test_rocm_attention_selector.py
@@ -12,8 +12,7 @@
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()

pass

@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
def test_selector(monkeypatch: pytest.MonkeyPatch):
3 changes: 1 addition & 2 deletions tests/v1/tpu/test_mha_attn.py
@@ -20,8 +20,7 @@
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()

pass

def ref_attention(
query: torch.Tensor,
95 changes: 48 additions & 47 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -790,55 +790,56 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
cache_dtype="auto",
)
parallel_config = ParallelConfig()
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
)

layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
layer_2 = "model.layers.2.mixer"
layer_3 = "model.layers.3.mixer"
layer_4 = "model.layers.4.mixer"
layer_5 = "model.layers.5.mixer"

with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
hf_config = vllm_config.model_config.hf_config
fwd_context = {}
for key in [layer_0, layer_1]:
fwd_context[key] = Attention(
num_heads=model_config.get_num_attention_heads(parallel_config),
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
scale=1.0,
prefix=key,
)
for key in [layer_2, layer_3, layer_4, layer_5]:
fwd_context[key] = MambaMixer2(
hidden_size=hf_config.hidden_size,
ssm_state_size=hf_config.mamba_d_state,
conv_kernel_size=hf_config.mamba_d_conv,
intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
use_conv_bias=hf_config.mamba_conv_bias,
use_bias=hf_config.mamba_proj_bias,
n_groups=hf_config.mamba_n_groups,
num_heads=hf_config.mamba_n_heads,
head_dim=hf_config.mamba_d_head,
rms_norm_eps=hf_config.rms_norm_eps,
activation=hf_config.hidden_act,
cache_config=cache_config,
model_config=model_config,
prefix=key,
)
# suppress var not used error
assert fwd_context is not None
vllm_ctx = vllm_config.compilation_config.static_forward_context

with monkeypatch.context() as m:
# Attention backend should be set before creating VllmConfig because
# VllmConfig will determine the kv block size based on the attention backend
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
)

layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
layer_2 = "model.layers.2.mixer"
layer_3 = "model.layers.3.mixer"
layer_4 = "model.layers.4.mixer"
layer_5 = "model.layers.5.mixer"

with set_current_vllm_config(vllm_config):
hf_config = vllm_config.model_config.hf_config
fwd_context = {}
for key in [layer_0, layer_1]:
fwd_context[key] = Attention(
num_heads=model_config.get_num_attention_heads(parallel_config),
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
scale=1.0,
prefix=key,
)
for key in [layer_2, layer_3, layer_4, layer_5]:
fwd_context[key] = MambaMixer2(
hidden_size=hf_config.hidden_size,
ssm_state_size=hf_config.mamba_d_state,
conv_kernel_size=hf_config.mamba_d_conv,
intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
use_conv_bias=hf_config.mamba_conv_bias,
use_bias=hf_config.mamba_proj_bias,
n_groups=hf_config.mamba_n_groups,
num_heads=hf_config.mamba_n_heads,
head_dim=hf_config.mamba_d_head,
rms_norm_eps=hf_config.rms_norm_eps,
activation=hf_config.hidden_act,
cache_config=cache_config,
model_config=model_config,
prefix=key,
)
# suppress var not used error
assert fwd_context is not None
vllm_ctx = vllm_config.compilation_config.static_forward_context


runner = GPUModelRunner(vllm_config, DEVICE)
kv_cache_spec = runner.get_kv_cache_spec()
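Note: the comment added in the new version states the invariant behind this reorder: VLLM_ATTENTION_BACKEND must be set before VllmConfig is constructed, because the config now derives the kv-cache block size from the selected backend. A minimal sketch of that ordering, assuming model_config, cache_config, scheduler_config, and parallel_config are built exactly as earlier in this test (this illustrates the ordering only, it is not the full test):

    import pytest
    from vllm.config import VllmConfig

    def test_backend_env_precedes_config(monkeypatch: pytest.MonkeyPatch) -> None:
        with monkeypatch.context() as m:
            # Select the backend first: VllmConfig reads VLLM_ATTENTION_BACKEND
            # while resolving the kv-cache block size, so setting it later is too late.
            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
            vllm_config = VllmConfig(
                model_config=model_config,  # assumed built as in the test above
                cache_config=cache_config,
                scheduler_config=scheduler_config,
                parallel_config=parallel_config,
            )
            assert vllm_config.cache_config.block_size is not None
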
3 changes: 0 additions & 3 deletions vllm/attention/selector.py
@@ -5,7 +5,6 @@
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import cache

import torch

@@ -145,7 +144,6 @@ def get_attn_backend(
)


@cache
def _cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
@@ -236,4 +234,3 @@ def global_force_attn_backend_context_manager(
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)
_cached_get_attn_backend.cache_clear()
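For context on the fixture changes in the test files above: _cached_get_attn_backend used to be wrapped in functools.cache, so tests had to call cache_clear() between cases; with the decorator removed here, those fixtures have nothing left to clear. A generic sketch of the old pattern (illustrative names, not the selector's real signature):

    from functools import cache

    @cache
    def select_backend(head_size: int, dtype: str) -> str:
        # The expensive resolution runs once per distinct argument tuple.
        print("resolving backend")
        return f"backend-{head_size}-{dtype}"

    select_backend(64, "bf16")    # resolves
    select_backend(64, "bf16")    # served from the cache, no print
    select_backend.cache_clear()  # what the old clear_cache fixtures did
    select_backend(64, "bf16")    # resolves again
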
63 changes: 58 additions & 5 deletions vllm/attention/utils/fa_utils.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Tuple

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -12,6 +14,12 @@

reshape_and_cache_flash = ops.reshape_and_cache_flash
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
from vllm.vllm_flash_attn.flash_attn_interface import (
FA2_AVAILABLE,
FA2_UNAVAILABLE_REASON,
FA3_AVAILABLE,
FA3_UNAVAILABLE_REASON,
)
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops

@@ -20,18 +28,63 @@
get_scheduler_metadata = ops.get_scheduler_metadata


# Functions copied from vllm/vllm_flash_attn/flash_attn_interface.py
# Modified to use current_platform.get_device_capability() instead of
# torch.cuda.get_device_capability(device) because current_platform.get_device_capability()
# does not initialize CUDA.

Check failure on line 32 in vllm/attention/utils/fa_utils.py (GitHub Actions / pre-commit): Ruff (E501) vllm/attention/utils/fa_utils.py:32:89: Line too long (91 > 88)
def _is_fa2_supported(device=None) -> Tuple[bool, Optional[str]]:
if not FA2_AVAILABLE:
return False, f"FA2 is unavaible due to: {FA2_UNAVAILABLE_REASON}"
device_capability = current_platform.get_device_capability()
if device_capability.major < 8:
return (
False,
"FA2 is only supported on devices with compute capability >= 8",
)
return True, None


def _is_fa3_supported(device=None) -> Tuple[bool, Optional[str]]:
if not FA3_AVAILABLE:
return False, f"FA3 is unavaible due to: {FA3_UNAVAILABLE_REASON}"
device_capability = current_platform.get_device_capability()
if (
device_capability.major < 8
or device_capability.major >= 10
or device_capability == (8, 6)
or device_capability == (8, 9)
):
return (
False,
"FA3 is only supported on devices with compute capability >= 8"
" excluding 8.6 and 8.9 and Blackwell archs (>=10)",
)
return True, None


Check failure on line 64 in vllm/attention/utils/fa_utils.py (GitHub Actions / pre-commit): Missing return statement [return] (reported 4 times)
def is_fa_version_supported(fa_version: int, device=None) -> bool:
assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
if fa_version == 2:
return _is_fa2_supported(device)[0]
elif fa_version == 3:
return _is_fa3_supported(device)[0]


Check failure on line 72 in vllm/attention/utils/fa_utils.py (GitHub Actions / pre-commit): Missing return statement [return] (reported 4 times)
def fa_version_unsupported_reason(fa_version: int, device=None) -> Optional[str]:
assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
if fa_version == 2:
return _is_fa2_supported(device)[1]
elif fa_version == 3:
return _is_fa3_supported(device)[1]


def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
# import here to avoid circular dependencies
from vllm.platforms import current_platform

if current_platform.is_xpu():
return 2
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason,
is_fa_version_supported,
)

device_capability = current_platform.get_device_capability()

assert device_capability is not None
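The pre-commit failures above point at the if/elif chains: the checker cannot tell that the assert makes fa_version == 3 the only remaining case, so it sees a code path with no return. One way to satisfy it (a sketch, not necessarily how this PR will be fixed) is to make the final branch unconditional:

    def is_fa_version_supported(fa_version: int, device=None) -> bool:
        assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
        if fa_version == 2:
            return _is_fa2_supported(device)[0]
        # After the assert, 3 is the only other possibility, so an unconditional
        # return gives the checker a value on every path.
        return _is_fa3_supported(device)[0]

    def fa_version_unsupported_reason(fa_version: int, device=None) -> Optional[str]:
        assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
        if fa_version == 2:
            return _is_fa2_supported(device)[1]
        return _is_fa3_supported(device)[1]
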
4 changes: 4 additions & 0 deletions vllm/config/vllm.py
@@ -880,6 +880,10 @@ def try_verify_and_update_config(self):
if architecture is None:
return

from vllm.model_executor.layers.config import AttentionConfig

AttentionConfig.verify_and_update_config(self)

from vllm.model_executor.models.config import (
MODELS_CONFIG_MAP,
HybridAttentionMambaModelConfig,
57 changes: 57 additions & 0 deletions vllm/model_executor/layers/config.py
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import lcm
from typing import TYPE_CHECKING

from vllm.attention.backends.abstract import MultipleOf
from vllm.attention.selector import get_attn_backend
from vllm.logger import init_logger
from vllm.model_executor.models.config import VerifyAndUpdateConfig

if TYPE_CHECKING:
from vllm.config import VllmConfig

logger = init_logger(__name__)


class AttentionConfig(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
Align cache_config.block_size with attention backend's minimum
supported kernel block size.

Args:
vllm_config: vLLM Config
"""
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
assert cache_config is not None

backend_cls = get_attn_backend(
head_size=model_config.get_head_size(),
dtype=model_config.dtype,
kv_cache_dtype=cache_config.cache_dtype,
block_size=cache_config.block_size
if cache_config.block_size is not None
else 16,
)

supported_sizes = backend_cls.get_supported_kernel_block_size()
supported_sizes = [
s.base if isinstance(s, MultipleOf) else s for s in supported_sizes
]
min_size = min(supported_sizes)
if cache_config.block_size is None:
new_block_size = min_size
else:
new_block_size = lcm(cache_config.block_size, min_size)
Collaborator: Prefer to raise an error if the user sets a block_size but the block_size is not supported by the attention backend it selects.

if cache_config.block_size is None or new_block_size != cache_config.block_size:
Collaborator: I think we don't need to add info-level logging if block_size is None and is initialized normally.

cache_config.block_size = new_block_size
logger.info(
"Setting attention block size to %d tokens "
"to align with %s attention backend's supported kernel block sizes.",
new_block_size,
backend_cls.get_name(),
)
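One possible shape for the stricter behaviour suggested in the inline comment above: raise when the user explicitly set a block_size the backend cannot serve, and only auto-pick when it was left unset. This is a sketch of the reviewer's idea, reusing the names already defined in this method (min_size, backend_cls, cache_config) and simplifying the support check to a multiple-of test; it is not the code in this PR:

    if cache_config.block_size is None:
        # Nothing requested: quietly pick the smallest supported kernel block size.
        cache_config.block_size = min_size
    elif cache_config.block_size % min_size != 0:
        # The user asked for a size the backend's kernels cannot serve.
        raise ValueError(
            f"block_size={cache_config.block_size} is not supported by the "
            f"{backend_cls.get_name()} attention backend; it must be a "
            f"multiple of {min_size}."
        )
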
6 changes: 5 additions & 1 deletion vllm/model_executor/models/config.py
@@ -356,7 +356,11 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
dtype=kv_cache_dtype,
).page_size_bytes
else:
kernel_block_alignment_size = 16
if cache_config.block_size is not None:
Collaborator: Is this part called before or after AttentionConfig.verify_and_update_config(self)? I think the kernel_block_alignment_size should be resolved from backend_cls.get_supported_kernel_block_size.

kernel_block_alignment_size = cache_config.block_size
else:
kernel_block_alignment_size = 16

if (
current_platform.is_device_capability(100)
and model_config.get_head_size() == 256
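A sketch of the resolution hinted at in the comment above: derive kernel_block_alignment_size from the backend's supported kernel block sizes instead of falling back to a hard-coded 16. Here backend_cls is assumed to be obtained the same way as in vllm/model_executor/layers/config.py; this illustrates the reviewer's idea, not this PR's code:

    from vllm.attention.backends.abstract import MultipleOf

    supported = backend_cls.get_supported_kernel_block_size()
    kernel_block_alignment_size = min(
        s.base if isinstance(s, MultipleOf) else s for s in supported
    )
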
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/flashinfer.py
@@ -173,7 +173,10 @@ def get_supported_kernel_block_size() -> list[int | MultipleOf]:
# Note: Not sure for all platforms,
# but on Blackwell, only support a page size of
# 16, 32, 64
return [16, 32, 64]
# TODO: 16 is temporarily removed because the TRT-LLM kernel has a bug when using 16.
# See https://github.com/flashinfer-ai/flashinfer/issues/1993
# for more details.
return [32, 64]

Collaborator: If the problem only exists in TRT-LLM, what about removing 16 only for TRT-LLM, like:

    if current_platform.is_device_capability(100):
        return [32, 64]
    else:
        return [16, 32, 64]

Member: Agreed, we should only override on Blackwell.


@classmethod
def validate_head_size(cls, head_size: int) -> None: