11import mindtorch
22from mindtorch import nn
3- import triton
4- import triton .language as tl
53
64from ..utils .context import get_context
75
86
97def store_kvcache (key : mindtorch .Tensor , value : mindtorch .Tensor , k_cache : mindtorch .Tensor , v_cache : mindtorch .Tensor , slot_mapping : mindtorch .Tensor ):
8+ # pylint: disable=undefined-variable
9+ # These are conditionally imported from flash_attn or other backends
1010 N , num_heads , head_dim = key .shape
1111 D = num_heads * head_dim
1212 assert key .stride (- 1 ) == 1 and value .stride (- 1 ) == 1
@@ -40,12 +40,15 @@ def forward(self, q: mindtorch.Tensor, k: mindtorch.Tensor, v: mindtorch.Tensor)
4040 if context .is_prefill :
4141 if context .block_tables is not None : # prefix cache
4242 k , v = k_cache , v_cache
43+ # pylint: disable=undefined-variable
4344 o = flash_attn_varlen_func (q , k , v ,
4445 max_seqlen_q = context .max_seqlen_q , cu_seqlens_q = context .cu_seqlens_q ,
4546 max_seqlen_k = context .max_seqlen_k , cu_seqlens_k = context .cu_seqlens_k ,
4647 softmax_scale = self .scale , causal = True , block_table = context .block_tables )
4748 else : # decode
48- o = flash_attn_with_kvcache (q .unsqueeze (1 ), k_cache , v_cache ,
49- cache_seqlens = context .context_lens , block_table = context .block_tables ,
49+ # flash_attn_with_kvcache is conditionally imported from flash_attn
50+ # pylint: disable=undefined-variable
51+ o = flash_attn_with_kvcache (q .unsqueeze (1 ), k_cache , v_cache , # noqa: F821
52+ cache_seqlens = context .context_lens , block_table = context .block_tables ,
5053 softmax_scale = self .scale , causal = True )
51- return o
54+ return o