@@ -37,18 +37,18 @@ def forward(self, q: mindtorch.Tensor, k: mindtorch.Tensor, v: mindtorch.Tensor)
3737 k_cache , v_cache = self .k_cache , self .v_cache
3838 if k_cache .numel () and v_cache .numel ():
3939 store_kvcache (k , v , k_cache , v_cache , context .slot_mapping )
40- if context .is_prefill :
41- if context .block_tables is not None : # prefix cache
42- k , v = k_cache , v_cache
43- # pylint: disable=undefined-variable
44- o = flash_attn_varlen_func (q , k , v ,
45- max_seqlen_q = context .max_seqlen_q , cu_seqlens_q = context .cu_seqlens_q ,
46- max_seqlen_k = context .max_seqlen_k , cu_seqlens_k = context .cu_seqlens_k ,
47- softmax_scale = self .scale , causal = True , block_table = context .block_tables )
48- else : # decode
49- # flash_attn_with_kvcache is conditionally imported from flash_attn
50- # pylint: disable=undefined-variable
51- o = flash_attn_with_kvcache (q .unsqueeze (1 ), k_cache , v_cache , # noqa: F821
52- cache_seqlens = context .context_lens , block_table = context .block_tables ,
53- softmax_scale = self .scale , causal = True )
54- return o
40+ # if context.is_prefill:
41+ # if context.block_tables is not None: # prefix cache
42+ # k, v = k_cache, v_cache
43+ # # pylint: disable=undefined-variable
44+ # o = flash_attn_varlen_func(q, k, v,
45+ # max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
46+ # max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
47+ # softmax_scale=self.scale, causal=True, block_table=context.block_tables)
48+ # else: # decode
49+ # # flash_attn_with_kvcache is conditionally imported from flash_attn
50+ # # pylint: disable=undefined-variable
51+ # o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, # noqa: F821
52+ # cache_seqlens=context.context_lens, block_table=context.block_tables,
53+ # softmax_scale=self.scale, causal=True)
54+ # return o
0 commit comments