2525 FA3_UNAVAILABLE_REASON = str (e )
2626 FA3_AVAILABLE = False
2727
# Probe for the FA4 (CuTe) forward kernel. Only the import itself is guarded;
# the availability flags are set in the except/else arms.
try:
    from flash_attn.cute.interface import _flash_attn_fwd  # noqa: F401
except ImportError as e:
    # Keep the import error text so callers can report why FA4 is missing.
    FA4_AVAILABLE = False
    FA4_UNAVAILABLE_REASON = str(e)
else:
    FA4_AVAILABLE = True
    FA4_UNAVAILABLE_REASON = None
35+
2836# isort: on
2937
3038DEFAULT_FA_VERSION = 2
@@ -49,20 +57,32 @@ def _is_fa3_supported(device = None) -> Tuple[bool, Optional[str]]:
4957 " excluding 8.6 and 8.9 and Blackwell archs (>=10)"
5058 return True , None
5159
def _is_fa4_supported(device=None) -> Tuple[bool, Optional[str]]:
    """Check whether FlashAttention 4 can be used on ``device``.

    Args:
        device: Passed unchanged to ``torch.cuda.get_device_capability``.

    Returns:
        ``(True, None)`` when FA4 is usable, otherwise ``(False, reason)``
        where ``reason`` is a human-readable explanation.
    """
    # The FA4 import failed at module load; surface the recorded error text.
    if not FA4_AVAILABLE:
        return False, f"FA4 is unavailable due to: {FA4_UNAVAILABLE_REASON}"
    # Element 0 of the capability tuple is the major compute capability.
    if torch.cuda.get_device_capability(device)[0] != 10:
        return False, \
            "FA4 is only supported on devices with compute capability == 10"
    return True, None
67+
def is_fa_version_supported(fa_version: int, device=None) -> bool:
    """Return whether FlashAttention ``fa_version`` is usable on ``device``.

    Args:
        fa_version: FlashAttention version to probe; must be 2, 3 or 4.
        device: Forwarded unchanged to the per-version support check.

    Returns:
        The first element of the result returned by the matching
        ``_is_fa{2,3,4}_supported`` helper.

    Raises:
        AssertionError: If ``fa_version`` is not 2, 3 or 4.
    """
    # Raised explicitly instead of via ``assert`` so the validation is not
    # stripped under ``python -O``; with a bare assert an unknown version
    # would fall through every branch and silently return None.
    if fa_version not in (2, 3, 4):
        raise AssertionError(f"Unsupported FA version: {fa_version}")
    if fa_version == 2:
        return _is_fa2_supported(device)[0]
    if fa_version == 3:
        return _is_fa3_supported(device)[0]
    return _is_fa4_supported(device)[0]
5876
def fa_version_unsupported_reason(fa_version: int,
                                  device=None) -> Optional[str]:
    """Return why FlashAttention ``fa_version`` is unsupported on ``device``.

    Args:
        fa_version: FlashAttention version to probe; must be 2, 3 or 4.
        device: Forwarded unchanged to the per-version support check.

    Returns:
        The second element of the result returned by the matching
        ``_is_fa{2,3,4}_supported`` helper (for FA4 this is ``None`` when
        supported and a reason string otherwise).

    Raises:
        AssertionError: If ``fa_version`` is not 2, 3 or 4.
    """
    # Raised explicitly instead of via ``assert`` so the validation is not
    # stripped under ``python -O``; with a bare assert an unknown version
    # would return None, which is indistinguishable from "supported".
    if fa_version not in (2, 3, 4):
        raise AssertionError(f"Unsupported FA version: {fa_version}")
    if fa_version == 2:
        return _is_fa2_supported(device)[1]
    if fa_version == 3:
        return _is_fa3_supported(device)[1]
    return _is_fa4_supported(device)[1]
6686
6787#
6888# For vLLM we only care about `flash_attn_varlen_func` and
0 commit comments