Commit 3a51814

avoid CUDA initialisation during AttentionConfig.verify_and_update_config call
Signed-off-by: Vadim Gimpelson <[email protected]>
1 parent 2221dd1 commit 3a51814

File tree

1 file changed: +58 -5 lines changed


vllm/attention/utils/fa_utils.py

Lines changed: 58 additions & 5 deletions
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Optional, Tuple
+
 from vllm import envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -12,6 +14,12 @@
 
     reshape_and_cache_flash = ops.reshape_and_cache_flash
     from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
+    from vllm.vllm_flash_attn.flash_attn_interface import (
+        FA2_AVAILABLE,
+        FA2_UNAVAILABLE_REASON,
+        FA3_AVAILABLE,
+        FA3_UNAVAILABLE_REASON,
+    )
 elif current_platform.is_xpu():
     from vllm._ipex_ops import ipex_ops as ops
 
@@ -20,18 +28,63 @@
     get_scheduler_metadata = ops.get_scheduler_metadata
 
 
+# Functions copied from vllm/vllm_flash_attn/flash_attn_interface.py.
+# Modified to use current_platform.get_device_capability() instead of
+# torch.cuda.get_device_capability(device), because the former does not
+# initialize CUDA.
+def _is_fa2_supported(device=None) -> Tuple[bool, Optional[str]]:
+    if not FA2_AVAILABLE:
+        return False, f"FA2 is unavailable due to: {FA2_UNAVAILABLE_REASON}"
+    device_capability = current_platform.get_device_capability()
+    if device_capability.major < 8:
+        return (
+            False,
+            "FA2 is only supported on devices with compute capability >= 8",
+        )
+    return True, None
+
+
+def _is_fa3_supported(device=None) -> Tuple[bool, Optional[str]]:
+    if not FA3_AVAILABLE:
+        return False, f"FA3 is unavailable due to: {FA3_UNAVAILABLE_REASON}"
+    device_capability = current_platform.get_device_capability()
+    if (
+        device_capability.major < 8
+        or device_capability.major >= 10
+        or device_capability == (8, 6)
+        or device_capability == (8, 9)
+    ):
+        return (
+            False,
+            "FA3 is only supported on devices with compute capability >= 8"
+            " excluding 8.6 and 8.9 and Blackwell archs (>=10)",
+        )
+    return True, None
+
+
+def is_fa_version_supported(fa_version: int, device=None) -> bool:
+    assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
+    if fa_version == 2:
+        return _is_fa2_supported(device)[0]
+    elif fa_version == 3:
+        return _is_fa3_supported(device)[0]
+
+
+def fa_version_unsupported_reason(fa_version: int, device=None) -> Optional[str]:
+    assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
+    if fa_version == 2:
+        return _is_fa2_supported(device)[1]
+    elif fa_version == 3:
+        return _is_fa3_supported(device)[1]
+
+
 def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
     # import here to avoid circular dependencies
     from vllm.platforms import current_platform
 
     if current_platform.is_xpu():
         return 2
     try:
-        from vllm.vllm_flash_attn.flash_attn_interface import (
-            fa_version_unsupported_reason,
-            is_fa_version_supported,
-        )
-
         device_capability = current_platform.get_device_capability()
 
         assert device_capability is not None
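
The rationale in the copied-function comment can be seen directly. A minimal sketch, not part of the commit, assuming a CUDA-capable host with vLLM installed (device index 0 is an arbitrary example): torch.cuda.get_device_capability() lazily initializes a CUDA context as a side effect, while current_platform.get_device_capability() answers the same question without touching the CUDA runtime (via NVML on NVML-backed setups).

import torch

from vllm.platforms import current_platform

# Capability query that does not create a CUDA context in this process.
cap = current_platform.get_device_capability()
print(cap.major, cap.minor)
print(torch.cuda.is_initialized())  # expected: still False

# The call the upstream helpers use: initializes CUDA lazily.
print(torch.cuda.get_device_capability(0))
print(torch.cuda.is_initialized())  # now True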

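A short usage sketch of the helpers this commit copies into fa_utils. The call site is hypothetical (inside vLLM they are driven from get_flash_attn_version), and the import assumes a CUDA platform, where FA2_AVAILABLE and its siblings are bound at module import:

from vllm.attention.utils.fa_utils import (
    fa_version_unsupported_reason,
    is_fa_version_supported,
)

# Probe both FlashAttention versions without initializing CUDA.
for fa_version in (2, 3):
    if is_fa_version_supported(fa_version):
        print(f"FA{fa_version}: supported")
    else:
        print(f"FA{fa_version}: {fa_version_unsupported_reason(fa_version)}")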