@@ -183,7 +183,6 @@ class AscendMetadataForDecode:
 class AscendMetadata:
     # **************************** Basic Properties ************************** #
     attn_mask: Optional[torch.Tensor] = None
-    fia_attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
 
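The hunk above drops the separate fia_attn_mask field, so only the single attn_mask remains on the metadata (the later hunks reuse it in the fused-attention path). For orientation, a minimal sketch of the trimmed basic properties, assuming a stand-in enum; the real class carries many more fields than shown here:

```python
# Minimal sketch only: field names come from the hunk above; AscendAttentionState
# is a stand-in enum, and the real AscendMetadata has many additional fields.
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import torch


class AscendAttentionState(Enum):
    ChunkedPrefill = 0
    DecodeOnly = 1


@dataclass
class AscendMetadata:
    # Single attention mask, reused by every attention path after this change;
    # the separate fia_attn_mask field is gone.
    attn_mask: Optional[torch.Tensor] = None
    # Current state of this attention run.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
```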
@@ -312,21 +311,18 @@ def build(
                                                        num_actual_tokens_pcp_padded]
         # slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         attn_mask = common_attn_metadata.attn_mask
-        fia_attn_mask = common_attn_metadata.fia_attn_mask
         attn_state = common_attn_metadata.attn_state
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs
                                                                        + 1]
         num_computed_tokens_cpu = (seq_lens - query_lens)
 
-        if attn_state == AscendAttentionState.DecodeOnly and \
-                common_attn_metadata.num_input_tokens > num_actual_tokens:
+        if common_attn_metadata.num_input_tokens > num_actual_tokens:
             padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
             seq_lens = torch.cat([
                 seq_lens,
-                torch.ones(padded_num_tokens,
-                           dtype=seq_lens.dtype,
-                           device=seq_lens.device)
+                torch.tensor([padded_num_tokens
+                              ]).to(seq_lens.device).to(seq_lens.dtype)
             ])
             block_table_padding = torch.zeros(
                 (padded_num_tokens, ) + block_table.shape[1:],
@@ -335,10 +331,8 @@ def build(
             block_table = torch.cat([block_table, block_table_padding], dim=0)
             query_start_loc_cpu = torch.cat([
                 query_start_loc_cpu,
-                torch.arange(query_start_loc_cpu[-1] + 1,
-                             query_start_loc_cpu[-1] + padded_num_tokens,
-                             dtype=query_start_loc_cpu.dtype,
-                             device=query_start_loc_cpu.device)
+                torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
+                    query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
             ])
 
         query_start_loc = query_start_loc_cpu.to(self.device,
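After these two hunks the padding applies whenever num_input_tokens exceeds num_actual_tokens (no longer only in the DecodeOnly state), and the padded region is represented as one extra synthetic request rather than one entry per padded token. A standalone sketch of the resulting shapes, using toy tensors in place of the builder's real inputs:

```python
# Standalone sketch with toy tensors; names mirror the diff, values are made up.
import torch

num_actual_tokens = 5          # tokens actually scheduled this step
num_input_tokens = 8           # padded size expected by the captured graph
padded_num_tokens = num_input_tokens - num_actual_tokens   # 3

seq_lens = torch.tensor([3, 2], dtype=torch.int32)           # two real requests
query_start_loc_cpu = torch.tensor([0, 3, 5], dtype=torch.int32)
block_table = torch.zeros((2, 4), dtype=torch.int32)         # 2 requests x 4 blocks

# One synthetic "request" absorbs all padded tokens ...
seq_lens = torch.cat([
    seq_lens,
    torch.tensor([padded_num_tokens]).to(seq_lens.device).to(seq_lens.dtype)
])
# ... its block-table rows are zero-filled ...
block_table_padding = torch.zeros((padded_num_tokens, ) + block_table.shape[1:],
                                  dtype=block_table.dtype,
                                  device=block_table.device)
block_table = torch.cat([block_table, block_table_padding], dim=0)
# ... and query_start_loc grows by a single boundary that covers the padding.
query_start_loc_cpu = torch.cat([
    query_start_loc_cpu,
    torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
        query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
])

print(seq_lens.tolist())             # [3, 2, 3]
print(query_start_loc_cpu.tolist())  # [0, 3, 5, 8] -> last entry == num_input_tokens
print(block_table.shape)             # torch.Size([5, 4])
```

Note that the last query_start_loc entry now equals the padded token count, which is what lets the decode path below stop slicing the query.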
@@ -471,7 +465,6 @@ def build(
471465 actual_seq_lengths_q = query_start_loc_cpu [1 :].tolist (),
472466 slot_mapping = slot_mapping ,
473467 attn_mask = attn_mask ,
474- fia_attn_mask = fia_attn_mask ,
475468 attn_state = attn_state ,
476469 num_prefills = num_prefills ,
477470 num_decodes = num_decodes ,
@@ -604,7 +597,6 @@ def full_graph_attention(self,
604597 actual_seq_lengths_kv = attn_metadata .seq_lens_list
605598
606599 num_tokens = attn_metadata .query_start_loc_list [- 1 ]
607- query = query [:num_tokens ]
608600 graph_params = get_graph_params ()
609601 query_start_loc = attn_metadata .query_start_loc_list
610602 # Prepare tensors for attention output
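With query_start_loc extended on the build side to cover the padded tokens, query_start_loc_list[-1] already equals the padded token count, so slicing the query down to num_tokens is dropped here and the full padded query goes into the fused call. A toy illustration of that invariant (plain Python lists stand in for the metadata fields):

```python
# Toy illustration only: placeholder values, not the real metadata object.
num_input_tokens = 8                     # padded batch size
query_start_loc_list = [0, 3, 5, 8]      # built as in the earlier hunk
num_tokens = query_start_loc_list[-1]

# The last boundary already includes the padding, so for a query padded to
# num_input_tokens the old `query[:num_tokens]` slice would be a no-op.
assert num_tokens == num_input_tokens
```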
@@ -618,7 +610,7 @@ def full_graph_attention(self,
618610 query = query ,
619611 key = key ,
620612 value = value ,
621- atten_mask = attn_metadata .fia_attn_mask ,
613+ atten_mask = attn_metadata .attn_mask ,
622614 block_table = block_table ,
623615 input_layout = "TND" ,
624616 block_size = block_size ,
@@ -641,7 +633,7 @@ def full_graph_attention(self,
641633 graph_params .attn_params [num_tokens ].append (
642634 (weak_ref_tensors (query ), weak_ref_tensors (key ),
643635 weak_ref_tensors (value ), weak_ref_tensors (block_table ),
644- weak_ref_tensors (attn_metadata .fia_attn_mask ), block_size ,
636+ weak_ref_tensors (attn_metadata .attn_mask ), block_size ,
645637 actual_seq_lengths_kv , query_start_loc , self .num_kv_heads ,
646638 self .num_heads , self .scale , weak_ref_tensors (output ),
647639 weak_ref_tensors (softmax_lse )))
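For context on the capture path above: the tensors fed to the fused attention call (now including attn_mask instead of fia_attn_mask) are stashed per num_tokens, presumably so a later graph update or replay can re-bind them. A simplified stand-in for that bookkeeping, using Python's weakref in place of the project's weak_ref_tensors/get_graph_params helpers; only the pattern is illustrated, not the real implementation:

```python
# Simplified stand-in for the capture-side bookkeeping shown above.
import weakref
from collections import defaultdict

import torch

attn_params = defaultdict(list)   # num_tokens -> list of captured parameter tuples


def capture(num_tokens: int, query: torch.Tensor, key: torch.Tensor,
            value: torch.Tensor, attn_mask: torch.Tensor) -> None:
    # Store weak references so the captured bookkeeping does not keep the
    # tensors alive on its own.
    attn_params[num_tokens].append(
        (weakref.ref(query), weakref.ref(key), weakref.ref(value),
         weakref.ref(attn_mask)))


q = torch.randn(8, 4, 16)
k = torch.randn(8, 1, 16)
v = torch.randn(8, 1, 16)
mask = torch.ones(8, 8, dtype=torch.bool)
capture(8, q, k, v, mask)

# At update/replay time the strong tensors are recovered from the weak refs
# (a ref returns None if the original tensor has been freed in the meantime).
q_ref, *_ = attn_params[8][0]
assert q_ref() is q
```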
@@ -651,7 +643,7 @@ def full_graph_attention(self,
651643 query = query ,
652644 key = key ,
653645 value = value ,
654- atten_mask = attn_metadata .fia_attn_mask ,
646+ atten_mask = attn_metadata .attn_mask ,
655647 block_table = block_table ,
656648 input_layout = "TND" ,
657649 block_size = block_size ,