Skip to content

Commit e7957f4

Browse files
committed
Fix the Eagle3 inference failure issue. 1、Hardcoding version 8.3 for judgment is not conducive to maintenance. It is recommended to define the version condition as a constant 2、Magic number: The value 0 for max_seq_len should be defined as a constant
Signed-off-by:sunchendd <[email protected]>
1 parent 6e4dd8e commit e7957f4

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

vllm_ascend/attention/attention_mask.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
import torch
1616

1717

18+
MIN_CANN_VERSION_FOR_OPTIMIZED_MASK = "8.3"
19+
DEFAULT_MAX_SEQ_LEN = 0
20+
21+
1822
def _generate_attn_mask(max_seq_len, dtype):
1923
# Construct lower triangle matrix.
2024
mask_flag = torch.ones((max_seq_len, max_seq_len),
@@ -88,8 +92,9 @@ def get_splitfuse_attn_mask(
8892
) -> torch.Tensor:
8993
cann_version = getattr(torch.version, "cann", "")
9094
target_device = device or self.device
91-
use_chunked_mask = (seq_lens is None or position is None
92-
or dtype is None or cann_version.startswith("8.3"))
95+
use_chunked_mask = (
96+
seq_lens is None or position is None or dtype is None
97+
or cann_version.startswith(MIN_CANN_VERSION_FOR_OPTIMIZED_MASK))
9398

9499
if use_chunked_mask:
95100
if target_device is None:
@@ -106,7 +111,8 @@ def get_splitfuse_attn_mask(
106111
if target_device is None:
107112
raise ValueError(
108113
"splitfuse_attn_mask requires device for non-chunked mask")
109-
max_seq_len = seq_lens.max().item() if seq_lens.numel() > 0 else 0
114+
max_seq_len = (seq_lens.max().item()
115+
if seq_lens.numel() > 0 else DEFAULT_MAX_SEQ_LEN)
110116
self._update_attn_cache(max_seq_len, dtype)
111117
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
112118
# is not the same. Fix this in the future when kernel is ready.

0 commit comments

Comments
 (0)