Commit b9543b7

Author: wangxiaoxin-sherie (committed)

fix fia error.

Signed-off-by: wangxiaoxin-sherie <[email protected]>

1 parent bb1610d · commit b9543b7

File tree: 2 files changed, +12 / -41 lines changed


vllm_ascend/attention/attention_v1.py

Lines changed: 8 additions & 16 deletions
@@ -183,7 +183,6 @@ class AscendMetadataForDecode:
 class AscendMetadata:
     # **************************** Basic Properties ************************** #
     attn_mask: Optional[torch.Tensor] = None
-    fia_attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill


@@ -312,21 +311,18 @@ def build(
                                                          num_actual_tokens_pcp_padded]
         # slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         attn_mask = common_attn_metadata.attn_mask
-        fia_attn_mask = common_attn_metadata.fia_attn_mask
         attn_state = common_attn_metadata.attn_state
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs
                                                                        + 1]
         num_computed_tokens_cpu = (seq_lens - query_lens)

-        if attn_state == AscendAttentionState.DecodeOnly and \
-                common_attn_metadata.num_input_tokens > num_actual_tokens:
+        if common_attn_metadata.num_input_tokens > num_actual_tokens:
             padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
             seq_lens = torch.cat([
                 seq_lens,
-                torch.ones(padded_num_tokens,
-                           dtype=seq_lens.dtype,
-                           device=seq_lens.device)
+                torch.tensor([padded_num_tokens
+                              ]).to(seq_lens.device).to(seq_lens.dtype)
             ])
             block_table_padding = torch.zeros(
                 (padded_num_tokens, ) + block_table.shape[1:],

@@ -335,10 +331,8 @@ def build(
             block_table = torch.cat([block_table, block_table_padding], dim=0)
             query_start_loc_cpu = torch.cat([
                 query_start_loc_cpu,
-                torch.arange(query_start_loc_cpu[-1] + 1,
-                             query_start_loc_cpu[-1] + padded_num_tokens,
-                             dtype=query_start_loc_cpu.dtype,
-                             device=query_start_loc_cpu.device)
+                torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
+                    query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
             ])

         query_start_loc = query_start_loc_cpu.to(self.device,

@@ -471,7 +465,6 @@ def build(
             actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
-            fia_attn_mask=fia_attn_mask,
             attn_state=attn_state,
             num_prefills=num_prefills,
             num_decodes=num_decodes,
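
For orientation, here is a small, hedged sketch of the padding change in the build() hunks above. The identifiers seq_lens, query_start_loc_cpu and padded_num_tokens come from the diff; the concrete numbers are invented. The padding branch now runs whenever num_input_tokens exceeds num_actual_tokens (not only in the DecodeOnly state), and it appends a single dummy request covering all padded tokens instead of one single-token entry per padded token:

import torch

# Invented example: two real requests (5 and 7 tokens) plus 3 padded tokens,
# i.e. padded_num_tokens = num_input_tokens - num_actual_tokens = 3.
seq_lens = torch.tensor([5, 7], dtype=torch.int32)
query_start_loc_cpu = torch.tensor([0, 5, 12], dtype=torch.int32)
padded_num_tokens = 3
last = int(query_start_loc_cpu[-1])

# Removed behaviour: one single-token dummy entry per padded token.
seq_lens_old = torch.cat(
    [seq_lens, torch.ones(padded_num_tokens, dtype=seq_lens.dtype)])
qsl_old = torch.cat([
    query_start_loc_cpu,
    torch.arange(last + 1, last + padded_num_tokens,
                 dtype=query_start_loc_cpu.dtype)
])
print(seq_lens_old.tolist(), qsl_old.tolist())  # [5, 7, 1, 1, 1] [0, 5, 12, 13, 14]

# Added behaviour: one dummy entry that covers every padded token, matching
# the torch.tensor([...]).to(...) expressions introduced by this commit.
seq_lens_new = torch.cat(
    [seq_lens, torch.tensor([padded_num_tokens], dtype=seq_lens.dtype)])
qsl_new = torch.cat([
    query_start_loc_cpu,
    torch.tensor([last + padded_num_tokens], dtype=query_start_loc_cpu.dtype)
])
print(seq_lens_new.tolist(), qsl_new.tolist())  # [5, 7, 3] [0, 5, 12, 15]

One arithmetic observation, not stated in the diff: the removed torch.arange call has an exclusive end point, so its boundaries stop one token short of the padded total, whereas the new single boundary lands exactly at num_actual_tokens + padded_num_tokens.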
@@ -604,7 +597,6 @@ def full_graph_attention(self,
         actual_seq_lengths_kv = attn_metadata.seq_lens_list

         num_tokens = attn_metadata.query_start_loc_list[-1]
-        query = query[:num_tokens]
         graph_params = get_graph_params()
         query_start_loc = attn_metadata.query_start_loc_list
         # Prepare tensors for attention output

@@ -618,7 +610,7 @@ def full_graph_attention(self,
                 query=query,
                 key=key,
                 value=value,
-                atten_mask=attn_metadata.fia_attn_mask,
+                atten_mask=attn_metadata.attn_mask,
                 block_table=block_table,
                 input_layout="TND",
                 block_size=block_size,

@@ -641,7 +633,7 @@ def full_graph_attention(self,
             graph_params.attn_params[num_tokens].append(
                 (weak_ref_tensors(query), weak_ref_tensors(key),
                  weak_ref_tensors(value), weak_ref_tensors(block_table),
-                 weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
+                 weak_ref_tensors(attn_metadata.attn_mask), block_size,
                  actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
                  self.num_heads, self.scale, weak_ref_tensors(output),
                  weak_ref_tensors(softmax_lse)))

@@ -651,7 +643,7 @@ def full_graph_attention(self,
                 query=query,
                 key=key,
                 value=value,
-                atten_mask=attn_metadata.fia_attn_mask,
+                atten_mask=attn_metadata.attn_mask,
                 block_table=block_table,
                 input_layout="TND",
                 block_size=block_size,
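
As a side note on the full_graph_attention hunks above: both fused-attention call sites now take their mask from attn_metadata.attn_mask instead of a dedicated fia_attn_mask field. The sketch below only illustrates the kind of tensor that field carries. The 2048 x 2048 upper-triangular int8 layout is copied from the dummy mask built in model_runner_v1.py further down; for live batches the mask comes from attn_mask_builder.get_splitfuse_attn_mask(), whose contents this diff does not show, and the Ascend kernel invocation itself is elided.

import torch

# Hedged illustration of the shared causal mask now stored in attn_mask.
assigned_mask_dim = 2048
attn_mask = torch.triu(torch.ones(assigned_mask_dim, assigned_mask_dim),
                       diagonal=1).to(torch.int8)

# 1 marks a masked (future) position, 0 a visible one; the top-left 4 x 4
# corner is an ordinary causal mask:
print(attn_mask[:4, :4])
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int8)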

vllm_ascend/worker/model_runner_v1.py

Lines changed: 4 additions & 25 deletions
@@ -323,7 +323,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.attn_groups: list[list[AttentionGroup]] = []
         self.encoder_cache: Dict[str, torch.Tensor] = {}
         self.attn_mask = None
-        self.fia_attn_mask = None
         self.attn_state = None
         self.requests: Dict[str, CachedRequestState] = {}
         self.intermediate_tensors: Optional[IntermediateTensors] = None

@@ -984,23 +983,6 @@ def _make_attention_mask(self, seq_lens, position,
         # Pooling situation.
         if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
             return self.attn_mask_builder.get_pooling_mask(self.device)
-        # fia prefill situation.
-        if attn_state in [
-                AscendAttentionState.PrefillNoCache,
-                AscendAttentionState.PrefillCacheHit,
-                AscendAttentionState.ChunkedPrefill
-        ]:
-            return self.attn_mask_builder.get_splitfuse_attn_mask()
-
-        # Decode-only situation.
-        return None
-
-    def _make_fia_attention_mask(self) -> torch.Tensor:
-        # pcp situation.
-        if self.pcp_size > 1:
-            return None
-        if self.attn_mask_builder is None:
-            raise ValueError("Attn mask builder is None")
         return self.attn_mask_builder.get_splitfuse_attn_mask()

     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):

@@ -1581,7 +1563,6 @@ def _prepare_inputs(
         self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
                                                    position=positions_cpu,
                                                    attn_state=attn_state)
-        self.fia_attn_mask = self._make_fia_attention_mask()
         self.attn_state = attn_state  # type: ignore

         self.with_prefill = with_prefill

@@ -1806,7 +1787,6 @@ def _prepare_inputs(
             num_computed_tokens_cpu=num_computed_tokens_cpu,
             positions=self.positions,
             attn_mask=self.attn_mask,
-            fia_attn_mask=self.fia_attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
             is_only_prefill=bool(np.all(num_valid_tokens != 1)),

@@ -2729,10 +2709,10 @@ def _build_dummy_attn_metadata(
         self.query_lens = torch.from_numpy(num_scheduled_tokens)

         assigned_mask_dim = 2048
-        self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
-                                                   assigned_mask_dim),
-                                        diagonal=1).to(torch.int8).to(
-                                            self.device)
+        self.attn_mask = torch.triu(torch.ones(assigned_mask_dim,
+                                               assigned_mask_dim),
+                                    diagonal=1).to(torch.int8).to(
+                                        self.device)

         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])

@@ -2776,7 +2756,6 @@ def _build_dummy_attn_metadata(
             num_computed_tokens_cpu=num_computed_tokens_cpu,
             positions=self.positions,
             attn_mask=self.attn_mask,
-            fia_attn_mask=self.fia_attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
             max_query_len=max_query_len,
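
Read together with the first two hunks of this file, the runner-side mask selection collapses to the sketch below. It is a condensed, hedged rewrite of _make_attention_mask as it stands after the commit; attn_mask_builder, its two getters and the config attributes are taken from the diff, the standalone function form and the runner parameter are illustrative, and attn_state is kept in the signature even though it is no longer consulted.

from typing import Optional

import torch


def make_attention_mask_after_commit(runner, attn_state) -> Optional[torch.Tensor]:
    # CLS-pooling models keep their dedicated pooling mask; every other case
    # (prefill, chunked prefill, decode) now returns the shared splitfuse mask
    # instead of branching on attn_state, and the separate
    # _make_fia_attention_mask helper is gone.
    if (runner.model_config.runner_type == "pooling"
            and runner.model_config.pooler_config.pooling_type == "CLS"):
        return runner.attn_mask_builder.get_pooling_mask(runner.device)
    return runner.attn_mask_builder.get_splitfuse_attn_mask()

In _prepare_inputs the result is stored once in self.attn_mask and passed through to the metadata builder, which is why the fia_attn_mask plumbing in the hunks above could be removed.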
