
Commit 6e4dd8e

sunchendong authored and sunchendd committed

[Bugfix] Fix the Eagle3 inference failure issue.

Signed-off-by: sunchendd <[email protected]>
1 parent 96c3623 commit 6e4dd8e

File tree

3 files changed, +38 -10 lines changed


vllm_ascend/attention/attention_mask.py

Lines changed: 30 additions & 1 deletion

@@ -86,7 +86,36 @@ def get_splitfuse_attn_mask(
         dtype: torch.dtype = None,
         device: torch.device = None,
     ) -> torch.Tensor:
-        return self.chunked_prefill_attn_mask
+        cann_version = getattr(torch.version, "cann", "")
+        target_device = device or self.device
+        use_chunked_mask = (seq_lens is None or position is None
+                            or dtype is None or cann_version.startswith("8.3"))
+
+        if use_chunked_mask:
+            if target_device is None:
+                raise ValueError(
+                    "splitfuse_attn_mask requires device when using chunked mask"
+                )
+
+            return self.chunked_prefill_attn_mask.to(target_device,
+                                                     non_blocking=True)
+
+        if dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError(
+                "splitfuse_attn_mask now only supports bf16 and fp16")
+        if target_device is None:
+            raise ValueError(
+                "splitfuse_attn_mask requires device for non-chunked mask")
+        max_seq_len = seq_lens.max().item() if seq_lens.numel() > 0 else 0
+        self._update_attn_cache(max_seq_len, dtype)
+        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+        # is not the same. Fix this in the future when kernel is ready.
+        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
+        attn_mask = torch.index_select(self.attn_mask_cache,
+                                       dim=0,
+                                       index=position)[:, :max_seq_len]
+        attn_mask *= mask_scale_factor
+        return attn_mask.contiguous().to(target_device, non_blocking=True)
 
     def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
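For context, a minimal, self-contained sketch of the non-chunked path added above: keep a cached triangular mask, select the rows matching each token's position, truncate to the batch's max sequence length, and scale by a dtype-dependent factor. The helper names, the cache construction, and the scale value are illustrative assumptions, not the actual AttentionMaskBuilder implementation.

# Illustrative sketch only; build_mask_cache, splitfuse_mask_sketch and the
# scale value are assumptions and are not part of vllm-ascend.
import torch

def build_mask_cache(size: int, dtype: torch.dtype) -> torch.Tensor:
    # Ones strictly above the diagonal mark future tokens to be masked out.
    return torch.triu(torch.ones(size, size, dtype=dtype), diagonal=1)

def splitfuse_mask_sketch(position: torch.Tensor, seq_lens: torch.Tensor,
                          dtype: torch.dtype) -> torch.Tensor:
    if dtype not in (torch.float16, torch.bfloat16):
        raise ValueError("sketch only supports bf16 and fp16")
    max_seq_len = seq_lens.max().item() if seq_lens.numel() > 0 else 0
    cache_size = max(max_seq_len, int(position.max().item()) + 1)
    cache = build_mask_cache(cache_size, dtype)
    mask_scale_factor = -10000.0  # assumed additive-mask value, not the real factor
    # Pick one mask row per scheduled token, keyed by its absolute position.
    attn_mask = torch.index_select(cache, dim=0, index=position)[:, :max_seq_len]
    return (attn_mask * mask_scale_factor).contiguous()

# Example: three scheduled tokens at positions 2, 3 and 4 of a sequence of length 5.
mask = splitfuse_mask_sketch(torch.tensor([2, 3, 4]), torch.tensor([5]), torch.float16)
print(mask.shape)  # torch.Size([3, 5])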

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 2 additions & 4 deletions

@@ -72,7 +72,7 @@ def __init__(self,
             dtype=torch.int32)
         attn_mask_len = self.vllm_config.model_config.max_model_len
         self.attn_mask_builder = AttentionMaskBuilder(
-            attn_mask_len, self.vllm_config.model_config.dtype)
+            attn_mask_len, self.vllm_config.model_config.dtype, device=device)
 
     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
@@ -422,9 +422,7 @@ def _propose(
 
         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
-        attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
-            seq_lens, target_positions, self.vllm_config.model_config.dtype,
-            self.device)
+        attn_mask = self.runner.attn_mask
 
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=cu_num_tokens.to(device),
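The second hunk drops the per-propose mask construction and reuses the mask already built by the model runner. Below is a minimal sketch of that sharing pattern; Runner and DraftProposer are hypothetical stand-ins, and only the reuse-via-self.runner.attn_mask idea comes from the diff.

# Hypothetical classes illustrating the "build once in the runner, reuse in the
# drafter" pattern; not the vllm-ascend API.
from typing import Optional
import torch

class Runner:
    def __init__(self, device: torch.device):
        self.device = device
        self.attn_mask: Optional[torch.Tensor] = None

    def prepare_step(self, max_seq_len: int) -> None:
        # Build the step's causal mask once; downstream components reuse it.
        self.attn_mask = torch.triu(
            torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1
        ).to(self.device)

class DraftProposer:
    def __init__(self, runner: Runner):
        self.runner = runner

    def propose(self) -> torch.Tensor:
        # Mirrors `attn_mask = self.runner.attn_mask`: reuse the runner's mask
        # rather than computing a second, possibly inconsistent one.
        assert self.runner.attn_mask is not None
        return self.runner.attn_mask

runner = Runner(torch.device("cpu"))
runner.prepare_step(max_seq_len=4)
print(DraftProposer(runner).propose().shape)  # torch.Size([4, 4])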

vllm_ascend/worker/model_runner_v1.py

Lines changed: 6 additions & 5 deletions

@@ -904,8 +904,8 @@ def _make_attention_mask(self, seq_lens, position,
                 max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
-            return self.attn_mask_builder.get_attn_mask(
-                128, self.dtype, self.device)
+            return self.attn_mask_builder.get_splitfuse_attn_mask().to(
+                torch.bool)
         # Decode-only situation.
         else:
             return None
@@ -1587,10 +1587,11 @@ def _build_attn_state(self, num_reqs, num_scheduled_tokens,
             attn_state = AscendAttentionState.SpecDecoding
         # Speculative decoding.
         elif np.all(num_valid_tokens == 1):
-            if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
-                attn_state = AscendAttentionState.SpecDecoding
-            else:
-                attn_state = AscendAttentionState.ChunkedPrefill
+            if self.drafter and self.drafter.name in (SpecDcodeType.EAGLE,
+                                                      SpecDcodeType.EAGLE3):
+                attn_state = AscendAttentionState.ChunkedPrefill
+            else:
+                attn_state = AscendAttentionState.SpecDecoding
         # splitfuse
         elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
             attn_state = AscendAttentionState.ChunkedPrefill
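The second hunk in this file inverts the branch for the all-one-valid-token case: an EAGLE/EAGLE3 drafter now maps to ChunkedPrefill, while other speculative-decoding setups keep SpecDecoding. Here is a standalone sketch of that selection, with hypothetical enums standing in for SpecDcodeType and AscendAttentionState.

# Standalone sketch of the flipped branch; the enums and field names are
# assumptions, not the vllm-ascend implementation.
from enum import Enum, auto

class AttentionState(Enum):
    CHUNKED_PREFILL = auto()
    SPEC_DECODING = auto()

class DrafterName(Enum):
    EAGLE = auto()
    EAGLE3 = auto()
    MTP = auto()

def pick_attn_state(drafter_name) -> AttentionState:
    # Mirrors: if self.drafter and self.drafter.name in (EAGLE, EAGLE3) -> ChunkedPrefill
    if drafter_name in (DrafterName.EAGLE, DrafterName.EAGLE3):
        return AttentionState.CHUNKED_PREFILL
    return AttentionState.SPEC_DECODING

assert pick_attn_state(DrafterName.EAGLE3) is AttentionState.CHUNKED_PREFILL
assert pick_attn_state(DrafterName.MTP) is AttentionState.SPEC_DECODING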
