Skip to content

Commit b336c7a

Browse files
fa4 support (#101)
Signed-off-by: Yongye Zhu <[email protected]> Co-authored-by: Lucas Wilkinson <[email protected]>
1 parent a893712 commit b336c7a

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

vllm_flash_attn/flash_attn_interface.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@
2525
FA3_UNAVAILABLE_REASON = str(e)
2626
FA3_AVAILABLE = False
2727

28+
try:
29+
from flash_attn.cute.interface import _flash_attn_fwd # noqa: F401
30+
FA4_UNAVAILABLE_REASON = None
31+
FA4_AVAILABLE = True
32+
except ImportError as e:
33+
FA4_UNAVAILABLE_REASON = str(e)
34+
FA4_AVAILABLE = False
35+
2836
# isort: on
2937

3038
DEFAULT_FA_VERSION = 2
@@ -49,20 +57,32 @@ def _is_fa3_supported(device = None) -> Tuple[bool, Optional[str]]:
4957
" excluding 8.6 and 8.9 and Blackwell archs (>=10)"
5058
return True, None
5159

60+
def _is_fa4_supported(device = None) -> Tuple[bool, Optional[str]]:
    """Report whether FlashAttention 4 can be used on ``device``.

    Args:
        device: Forwarded unchanged to
            ``torch.cuda.get_device_capability``.

    Returns:
        ``(True, None)`` when FA4 is usable; otherwise ``(False, reason)``
        where ``reason`` is a human-readable explanation.
    """
    if not FA4_AVAILABLE:
        # The FA4 import probe failed at module load; surface its message.
        return False, f"FA4 is unavailable due to: {FA4_UNAVAILABLE_REASON}"

    # Only the major compute-capability number is inspected.
    major = torch.cuda.get_device_capability(device)[0]
    if major != 10:
        return False, \
            "FA4 is only supported on devices with compute capability == 10"
    return True, None
67+
5268
def is_fa_version_supported(fa_version: int, device = None) -> bool:
    """Return whether the given FlashAttention version can be used.

    Args:
        fa_version: FlashAttention major version; must be 2, 3 or 4.
        device: Forwarded unchanged to the per-version support check.

    Returns:
        The first element of the matching ``_is_faN_supported(device)``
        result (presumably a ``bool``, as annotated on the FA3/FA4 helpers).

    Raises:
        AssertionError: If ``fa_version`` is not 2, 3 or 4 (only while
            assertions are enabled).
    """
    assert fa_version in [2, 3, 4], f"Unsupported FA version: {fa_version}"
    if fa_version == 2:
        return _is_fa2_supported(device)[0]
    elif fa_version == 3:
        return _is_fa3_supported(device)[0]
    elif fa_version == 4:
        return _is_fa4_supported(device)[0]
    # Reachable only when assertions are disabled (``python -O``): report an
    # unknown version as unsupported instead of implicitly returning None.
    return False
5876

5977
def fa_version_unsupported_reason(fa_version: int, device = None) \
        -> Optional[str]:
    """Return why the given FlashAttention version cannot be used.

    Args:
        fa_version: FlashAttention major version; must be 2, 3 or 4.
        device: Forwarded unchanged to the per-version support check.

    Returns:
        The second element of the matching ``_is_faN_supported(device)``
        result: ``None`` when the version is supported, otherwise a
        human-readable reason string.

    Raises:
        AssertionError: If ``fa_version`` is not 2, 3 or 4 (only while
            assertions are enabled).
    """
    assert fa_version in [2, 3, 4], f"Unsupported FA version: {fa_version}"
    if fa_version == 2:
        return _is_fa2_supported(device)[1]
    elif fa_version == 3:
        return _is_fa3_supported(device)[1]
    elif fa_version == 4:
        return _is_fa4_supported(device)[1]
    # Reachable only when assertions are disabled (``python -O``): an implicit
    # None here would read as "no reason, i.e. supported", so say why instead.
    return f"Unsupported FA version: {fa_version}"
6686

6787
#
6888
# For vLLM we only care about `flash_attn_varlen_func` and

0 commit comments

Comments
 (0)