 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Tuple
+from typing import Optional, Tuple

 import torch

@@ -19,6 +19,46 @@ def _get_meta_flashmla(kv_seqlens, num_attention_heads):
     return tile_scheduler_metadata, num_splits


+def _get_meta_flashattn(
+    batch_size: int,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    num_heads_q: int,
+    num_heads_kv: int,
+    headdim: int,
+    cache_seqlens: torch.Tensor,
+    qkv_dtype=torch.bfloat16,
+    headdim_v=None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    page_size: Optional[int] = None,
+    causal=True,
+    window_size=(-1, -1),  # -1 means infinite context window
+    num_splits=0,
+):
+    """Get scheduler metadata for flash attn."""
+    from flash_attn_interface import get_scheduler_metadata
+
+    metadata = get_scheduler_metadata(
+        batch_size,
+        max_seqlen_q,
+        max_seqlen_k,
+        num_heads_q,
+        num_heads_kv,
+        headdim,
+        cache_seqlens,
+        qkv_dtype=qkv_dtype,
+        headdim_v=headdim_v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k_new=cu_seqlens_k_new,
+        page_size=page_size,
+        causal=causal,
+        window_size=window_size,
+        num_splits=num_splits,
+    )
+    return metadata
+
+
 class CudaOpsBackend(DefaultOpsBackend):
     """Cuda layer backend."""

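For reference, a minimal sketch of how the new helper might be invoked for a decode step; all shapes and values below are illustrative assumptions, not taken from lmdeploy:

    import torch

    # Illustrative decode-step shapes: 4 sequences, 1 new query token each,
    # paged KV cache with 64-token blocks (all values are assumptions).
    cache_seqlens = torch.tensor([100, 37, 512, 256], dtype=torch.int32, device='cuda')
    metadata = _get_meta_flashattn(
        batch_size=4,
        max_seqlen_q=1,
        max_seqlen_k=512,
        num_heads_q=32,
        num_heads_kv=8,
        headdim=128,
        cache_seqlens=cache_seqlens,
        qkv_dtype=torch.bfloat16,
        page_size=64,
        causal=True,
    )
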
@@ -121,6 +161,28 @@ def update_meta_flashmla(cls, attn_metadata, num_attention_heads):
         if attn_metadata.block_offsets.dtype != torch.int32:
             attn_metadata.block_offsets = attn_metadata.block_offsets.to(torch.int32)

+    @classmethod
+    def update_meta_flashattn(cls, attn_metadata, step_context):
+        """Build FA3 scheduler metadata for the decoding step."""
+        batch_size = attn_metadata.q_seqlens.size(0)
+        max_seqlen_q = step_context.input_ids.size(1) // batch_size
+        block_size = step_context.kv_caches[0][0].size(1)
+        window_size = (step_context.model_config.sliding_window,) * 2
+        scheduler_metadata = _get_meta_flashattn(
+            batch_size=batch_size,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=step_context.max_kv_seqlen,
+            num_heads_q=step_context.model_config.num_attention_heads,
+            num_heads_kv=step_context.model_config.num_key_value_heads,
+            headdim=step_context.model_config.head_dim,
+            cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32),
+            qkv_dtype=step_context.model_config.dtype,
+            page_size=block_size,
+            window_size=window_size,
+        )
+        attn_metadata.scheduler_metadata = scheduler_metadata
+        attn_metadata.max_kv_seqlen = step_context.max_kv_seqlen
+        return attn_metadata
+
     @classmethod
     def update_step_context(cls, step_context):
         """Update step context."""
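The metadata stashed on `attn_metadata.scheduler_metadata` is meant to be handed back to the FA3 kernel at call time. A hedged sketch, assuming FlashAttention-3's `flash_attn_with_kvcache` (which accepts a `scheduler_metadata` argument) and hypothetical tensor names `q`, `k_cache`, `v_cache`:

    from flash_attn_interface import flash_attn_with_kvcache

    # q, k_cache and v_cache are placeholders, not lmdeploy internals.
    # Reusing the precomputed metadata spares the FA3 kernel from
    # re-running its tile scheduler on every decoding step.
    out = flash_attn_with_kvcache(
        q,
        k_cache,
        v_cache,
        cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32),
        causal=True,
        scheduler_metadata=attn_metadata.scheduler_metadata,
    )
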
@@ -135,9 +197,10 @@ def update_step_context(cls, step_context):
         cu_seqlens_q = None
         cu_seqlens_k = None
         if use_flash_mla or use_flash_attn3:
-            cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
-            cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(kv_seqlens, dim=0, dtype=torch.int32), (1, 0))
             step_context.block_offsets = step_context.block_offsets.to(torch.int32)
+            if not step_context.is_decoding:
+                cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
+                cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(kv_seqlens, dim=0, dtype=torch.int32), (1, 0))

         if not step_context.is_decoding:
             kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens
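The `pad(cumsum(...), (1, 0))` idiom builds the zero-prefixed cumulative lengths (`cu_seqlens`) layout that varlen flash attention expects. A small worked example:

    import torch

    q_seqlens = torch.tensor([3, 5, 2])
    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
    # -> tensor([ 0,  3,  8, 10], dtype=torch.int32)
    # tokens of sequence i occupy rows [cu_seqlens_q[i], cu_seqlens_q[i + 1])
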
@@ -160,6 +223,10 @@ def update_step_context(cls, step_context):
             cls.update_meta_flashmla(attn_metadata,
                                      step_context.model_config.num_attention_heads * decode_query_len)

+        if use_flash_attn3:
+            if step_context.is_decoding:
+                attn_metadata = cls.update_meta_flashattn(attn_metadata, step_context)
+
         cross_seqlens = step_context.cross_seqlens
         cross_kv_seqlens = step_context.cross_kv_seqlens
         cross_attn_metadata = None
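A note on `window_size`: `update_meta_flashattn` derives it as `(sliding_window,) * 2`, and FA3 treats `(-1, -1)` as an unbounded window. A quick illustration, assuming the model config stores `-1` when no sliding window is configured:

    sliding_window = -1                    # assumed sentinel: no sliding window
    window_size = (sliding_window,) * 2    # (-1, -1) -> infinite context
    sliding_window = 4096                  # e.g. a 4k sliding-window model
    window_size = (sliding_window,) * 2    # (4096, 4096)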