Skip to content

Commit 756ba4a

Browse files
committed
fix(utils): update method of getting ub size
Signed-off-by: zhoux77899 <[email protected]>
1 parent 717f8a0 commit 756ba4a

File tree

1 file changed: +7 -4 lines changed

vllm_ascend/ops/rotary_embedding.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,12 @@
3535

3636
@lru_cache(maxsize=128)
3737
def maybe_exceed_ub_size(q_n: int, k_n: int, dtype: torch.dtype,
38-
soc_version: AscendSocVersion) -> bool:
39-
if soc_version not in {AscendSocVersion.A2, AscendSocVersion.A3}:
38+
ascend_device_type: AscendDeviceType) -> bool:
39+
if ascend_device_type in {AscendDeviceType._910B, AscendDeviceType._910_93}:
40+
ub_size = 192 * 1024
41+
elif ascend_device_type in {AscendDeviceType._310P, AscendDeviceType._910_95}:
42+
ub_size = 248 * 1024
43+
else:
4044
logger.warning(
4145
"Cannot get correct UB size, may fail to run rotary_embedding")
4246
return False
@@ -46,7 +50,6 @@ def maybe_exceed_ub_size(q_n: int, k_n: int, dtype: torch.dtype,
4650
ub_required = (q_n + k_n) * 128 * cast_size * 2 + 128 * dtype_size * 4 + (
4751
q_n + k_n) * 128 * cast_size + (
4852
q_n + k_n) * 128 * cast_size * 2 + cast * 128 * 4 * 2
49-
ub_size = 192 * 1024
5053
return ub_required > ub_size
5154

5255

@@ -99,7 +102,7 @@ def _rope_forward_oot(
99102
q_n=query_head_num,
100103
k_n=key_head_num,
101104
dtype=query.dtype,
102-
soc_version=get_ascend_soc_version(),
105+
ascend_device_type=get_ascend_device_type(),
103106
):
104107
query = query.contiguous().view(1, query.shape[0], -1,
105108
self.head_size)

0 commit comments

Comments
 (0)