
Commit cc6e14d

【Hackathon 9th No.46】add test_fused_rotary_position_encoding (#3848)
* add test_fused_rotary_position_encoding
* add copyright header
* fix according to the review
1 parent 24180fb commit cc6e14d

File tree

1 file changed: +134 −0 lines changed

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding


class TestFusedRotaryPositionEncoding(unittest.TestCase):
    def setUp(self):
        paddle.set_device("gpu")
        np.random.seed(42)

    def _make_cos_sin_cache(self, max_position: int, rot_dim: int) -> np.ndarray:
        """Generate cos/sin cache."""
        assert rot_dim % 2 == 0, "rot_dim must be even"
        half_dim = rot_dim // 2
        inv_freq = 1.0 / (10000 ** (np.arange(0, half_dim).astype("float32") / half_dim))
        positions = np.arange(max_position, dtype="float32")
        freqs = np.outer(positions, inv_freq)  # [max_position, half_dim]
        cos_np = np.cos(freqs)
        sin_np = np.sin(freqs)
        return np.concatenate([cos_np, sin_np], axis=1).astype("float32")

    def _ref_rotary(self, query, key, position_ids, cos_sin_cache, head_size, is_neox):
        """Numpy reference implementation."""
        num_tokens, num_heads, _ = query.shape
        num_kv_heads = key.shape[1]
        rot_dim = cos_sin_cache.shape[1]
        embed_dim = rot_dim // 2

        query_ref = query.copy()
        key_ref = key.copy()

        for t in range(num_tokens):
            pos = position_ids[t]
            cos_ptr = cos_sin_cache[pos, :embed_dim]
            sin_ptr = cos_sin_cache[pos, embed_dim:]

            for h in range(num_heads):
                arr = query_ref[t, h]
                for i in range(embed_dim):
                    if is_neox:
                        x_idx, y_idx = i, embed_dim + i
                        cos, sin = cos_ptr[i], sin_ptr[i]
                    else:
                        x_idx, y_idx = 2 * i, 2 * i + 1
                        cos, sin = cos_ptr[i], sin_ptr[i]
                    x, y = arr[x_idx], arr[y_idx]
                    arr[x_idx] = x * cos - y * sin
                    arr[y_idx] = y * cos + x * sin

            for h in range(num_kv_heads):
                arr = key_ref[t, h]
                for i in range(embed_dim):
                    if is_neox:
                        x_idx, y_idx = i, embed_dim + i
                        cos, sin = cos_ptr[i], sin_ptr[i]
                    else:
                        x_idx, y_idx = 2 * i, 2 * i + 1
                        cos, sin = cos_ptr[i], sin_ptr[i]
                    x, y = arr[x_idx], arr[y_idx]
                    arr[x_idx] = x * cos - y * sin
                    arr[y_idx] = y * cos + x * sin

        return query_ref, key_ref

    def _run_op(
        self,
        query_np: np.ndarray,
        key_np: np.ndarray,
        position_ids_np: np.ndarray,
        cos_sin_cache_np: np.ndarray,
        head_size: int,
        is_neox: bool,
    ):
        """Run fused_rotary_position_encoding operator."""
        query = paddle.to_tensor(query_np, dtype="float32")
        key = paddle.to_tensor(key_np, dtype="float32")
        position_ids = paddle.to_tensor(position_ids_np, dtype="int32")
        cos_sin_cache = paddle.to_tensor(cos_sin_cache_np, dtype="float32")

        fused_rotary_position_encoding(query, key, position_ids, cos_sin_cache, head_size, is_neox)
        return query.numpy(), key.numpy()

    def _check_correctness(self, num_tokens, num_heads, num_kv_heads, head_size, rot_dim, is_neox):
        query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32")
        key_np = np.random.rand(num_tokens, num_kv_heads, head_size).astype("float32")
        position_ids_np = np.arange(num_tokens, dtype="int32")
        cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim)

        query_out, key_out = self._run_op(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox)
        query_ref, key_ref = self._ref_rotary(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox)

        np.testing.assert_allclose(query_out, query_ref, rtol=1e-5, atol=1e-6)
        np.testing.assert_allclose(key_out, key_ref, rtol=1e-5, atol=1e-6)

    def test_basic_case(self):
        self._check_correctness(num_tokens=4, num_heads=2, num_kv_heads=2, head_size=6, rot_dim=4, is_neox=False)

    def test_neox_mode(self):
        self._check_correctness(num_tokens=3, num_heads=2, num_kv_heads=2, head_size=8, rot_dim=8, is_neox=True)

    def test_large_num_tokens(self):
        self._check_correctness(num_tokens=10, num_heads=2, num_kv_heads=2, head_size=4, rot_dim=4, is_neox=False)

    def test_exceed_max_tokens(self):
        num_tokens, num_heads, head_size = 65537, 1, 4
        num_kv_heads, rot_dim = 1, 4
        query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32")
        key_np = np.random.rand(num_tokens, num_kv_heads, head_size).astype("float32")
        position_ids_np = np.arange(num_tokens, dtype="int32")
        cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim)

        with self.assertRaises(Exception):
            self._run_op(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False)


if __name__ == "__main__":
    unittest.main()
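
For readers unfamiliar with the two rotation layouts exercised above, the snippet below is a minimal standalone sketch (NumPy only, no GPU or FastDeploy build required) of how a single head vector is rotated in NEOX mode (paired across the two halves of the rotary dimensions) versus interleaved mode (adjacent pairs, often called GPT-J style). It mirrors the _ref_rotary logic in the test; the rotate_head helper and the sample values are illustrative only and are not part of the FastDeploy API.

import numpy as np

def rotate_head(vec: np.ndarray, cos: np.ndarray, sin: np.ndarray, is_neox: bool) -> np.ndarray:
    """Apply rotary position encoding to one head vector (illustrative sketch)."""
    out = vec.copy()
    embed_dim = cos.shape[0]  # rot_dim // 2
    for i in range(embed_dim):
        if is_neox:
            # NEOX layout: dimension i is paired with dimension embed_dim + i.
            x_idx, y_idx = i, embed_dim + i
        else:
            # Interleaved layout: adjacent dimensions 2i and 2i + 1 form a pair.
            x_idx, y_idx = 2 * i, 2 * i + 1
        x, y = out[x_idx], out[y_idx]
        out[x_idx] = x * cos[i] - y * sin[i]
        out[y_idx] = y * cos[i] + x * sin[i]
    return out

# Rotate a 4-dim head vector at position 3 with rot_dim = 4 (embed_dim = 2),
# using the same inverse-frequency schedule as _make_cos_sin_cache.
pos, half_dim = 3, 2
inv_freq = 1.0 / (10000 ** (np.arange(half_dim).astype("float32") / half_dim))
angles = pos * inv_freq
vec = np.arange(4, dtype="float32")
print(rotate_head(vec, np.cos(angles), np.sin(angles), is_neox=True))
print(rotate_head(vec, np.cos(angles), np.sin(angles), is_neox=False))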

0 commit comments
