@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Optional, Tuple
+
 from vllm import envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -12,6 +14,12 @@
 
     reshape_and_cache_flash = ops.reshape_and_cache_flash
     from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
+    from vllm.vllm_flash_attn.flash_attn_interface import (
+        FA2_AVAILABLE,
+        FA2_UNAVAILABLE_REASON,
+        FA3_AVAILABLE,
+        FA3_UNAVAILABLE_REASON,
+    )
 elif current_platform.is_xpu():
     from vllm._ipex_ops import ipex_ops as ops
 
@@ -20,18 +28,63 @@
     get_scheduler_metadata = ops.get_scheduler_metadata
 
 
+# Functions copied from vllm/vllm_flash_attn/flash_attn_interface.py.
+# Modified to use current_platform.get_device_capability() instead of
+# torch.cuda.get_device_capability(device), because the former does not
+# initialize CUDA.
+def _is_fa2_supported(device=None) -> Tuple[bool, Optional[str]]:
+    if not FA2_AVAILABLE:
+        return False, f"FA2 is unavailable due to: {FA2_UNAVAILABLE_REASON}"
+    device_capability = current_platform.get_device_capability()
+    if device_capability.major < 8:
+        return (
+            False,
+            "FA2 is only supported on devices with compute capability >= 8",
+        )
+    return True, None
+
+
+def _is_fa3_supported(device=None) -> Tuple[bool, Optional[str]]:
+    if not FA3_AVAILABLE:
+        return False, f"FA3 is unavailable due to: {FA3_UNAVAILABLE_REASON}"
+    device_capability = current_platform.get_device_capability()
+    if (
+        device_capability.major < 8
+        or device_capability.major >= 10
+        or device_capability == (8, 6)
+        or device_capability == (8, 9)
+    ):
+        return (
+            False,
+            "FA3 is only supported on devices with compute capability >= 8,"
+            " excluding 8.6 and 8.9 and Blackwell archs (>=10)",
+        )
+    return True, None
+
+
+def is_fa_version_supported(fa_version: int, device=None) -> bool:
+    assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
+    if fa_version == 2:
+        return _is_fa2_supported(device)[0]
+    elif fa_version == 3:
+        return _is_fa3_supported(device)[0]
+
+
+def fa_version_unsupported_reason(fa_version: int, device=None) -> Optional[str]:
+    assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
+    if fa_version == 2:
+        return _is_fa2_supported(device)[1]
+    elif fa_version == 3:
+        return _is_fa3_supported(device)[1]
+
+
 def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
     # import here to avoid circular dependencies
     from vllm.platforms import current_platform
 
     if current_platform.is_xpu():
         return 2
     try:
-        from vllm.vllm_flash_attn.flash_attn_interface import (
-            fa_version_unsupported_reason,
-            is_fa_version_supported,
-        )
-
         device_capability = current_platform.get_device_capability()
 
         assert device_capability is not None
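
For reference, a minimal usage sketch of the helpers this diff copies in, showing how a caller would pick the newest supported FlashAttention version and surface the reason when none is available. This is not part of the diff above, and the import path is an assumption; the diff does not show which module these functions live in.

# Hypothetical usage sketch; the module path below is assumed,
# not confirmed by this diff.
from vllm.attention.utils.fa_utils import (
    fa_version_unsupported_reason,
    is_fa_version_supported,
)

# Prefer FA3 where supported, fall back to FA2, otherwise report why
# each version was rejected.
if is_fa_version_supported(3):
    fa_version = 3
elif is_fa_version_supported(2):
    fa_version = 2
else:
    raise RuntimeError(
        "No FlashAttention version is supported on this device: "
        f"FA3: {fa_version_unsupported_reason(3)}; "
        f"FA2: {fa_version_unsupported_reason(2)}"
    )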