@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.

+import functools
 from dataclasses import dataclass
 from typing import Literal

@@ -20,8 +21,8 @@
     assert torch.ops.flash_attn_3 is not None
     use_fa3 = True
 except Exception:
-    logger.warning('For higher performance, please install FlashAttention-3 '
-                   'https://github.com/Dao-AILab/flash-attention')
+    logger.debug('For higher performance, please install FlashAttention-3 '
+                 'https://github.com/Dao-AILab/flash-attention')


 @dataclass
@@ -221,6 +222,15 @@ def forward(
         return attn_output


+@functools.lru_cache
+def use_fa3_warning():
+    if use_fa3:
+        return True
+    logger.warning('For higher performance, please install FlashAttention-3 '
+                   'https://github.com/Dao-AILab/flash-attention')
+    return False
+
+
 class FlashMLAImpl(TritonAttentionImpl):

     def __init__(
@@ -255,6 +265,7 @@ def __init__(
         from lmdeploy.pytorch.kernels.cuda import flash_mla_fwd
         self.flash_mla_fwd = flash_mla_fwd
         assert num_kv_heads == 1, 'MLA requires num kv heads equal to 1'
+        use_fa3_warning()

     def forward(
         self,
@@ -515,6 +526,14 @@ def forward(
         return attn_output


+@functools.lru_cache
+def _enable_fa3(alibi: bool, learnable_sink: bool, block_sparse_size: int):
+    enable = not alibi and not learnable_sink and block_sparse_size == 1
+    if enable and not use_fa3_warning():
+        enable = False
+    return enable
+
+
 class TritonAttentionBuilder(AttentionBuilder[TritonAttentionMetadata]):
     """Triton attention builder."""

@@ -535,8 +554,9 @@ def build(
         **kwargs,
     ) -> TritonAttentionImpl:
         """build."""
-        enable_fa3 = use_fa3 and not alibi and not learnable_sink and block_sparse_size == 1
+        enable_fa3 = _enable_fa3(alibi, learnable_sink, block_sparse_size)
         if use_flash_mla is True:
+            logger.debug('Build FlashMLAImpl Attention')
             return FlashMLAImpl(num_heads,
                                 head_size,
                                 scale=scale,
@@ -548,6 +568,7 @@ def build(
                                 causal=causal,
                                 **kwargs)
         elif enable_fa3:
+            logger.debug('Build FA3Impl Attention')
             return FA3Impl(num_heads,
                            head_size,
                            scale=scale,
@@ -559,6 +580,7 @@ def build(
                            causal=causal,
                            **kwargs)
         else:
+            logger.debug('Build TritonAttentionImpl Attention')
             return TritonAttentionImpl(num_heads,
                                        head_size,
                                        scale=scale,
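
Side note on the caching pattern used above: decorating a zero-argument function with `functools.lru_cache` turns it into a call-once memo, so the FlashAttention-3 install hint is logged at most once per process instead of on every `build()` call. Below is a minimal standalone sketch of that idea, not code from this commit; the `fa3_available` flag, the logger setup, and the function names are illustrative stand-ins.

```python
import functools
import logging

logger = logging.getLogger('cached_warning_demo')
logging.basicConfig(level=logging.DEBUG)

fa3_available = False  # stand-in for the module-level `use_fa3` flag


@functools.lru_cache
def warn_if_fa3_missing() -> bool:
    """Return True if FA3 is usable; otherwise warn once and return False."""
    if fa3_available:
        return True
    logger.warning('For higher performance, please install FlashAttention-3 '
                   'https://github.com/Dao-AILab/flash-attention')
    return False


@functools.lru_cache
def enable_fa3(alibi: bool, learnable_sink: bool, block_sparse_size: int) -> bool:
    """FA3 is only worth enabling when no alibi / learnable-sink / block-sparse
    features are requested and the kernel is actually available; the result is
    cached per argument combination."""
    wanted = not alibi and not learnable_sink and block_sparse_size == 1
    return wanted and warn_if_fa3_missing()


if __name__ == '__main__':
    # The warning is emitted on the first call only; later calls hit the cache.
    for _ in range(3):
        enable_fa3(alibi=False, learnable_sink=False, block_sparse_size=1)
```

Caching the gate per argument tuple likewise avoids recomputing it (and re-triggering the availability check) for every attention layer that gets built.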