Skip to content

Commit dddf0c2

Browse files
authored
Fix Onnx 23 Rotary Fusion (#2576)
Fix Onnx 23 Rotary Fusion

---------

Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 168fd8a commit dddf0c2

File tree

3 files changed

+85
-35
lines changed

3 files changed

+85
-35
lines changed

onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py renamed to onnxscript/rewriter/rules/fusion/_rms_normalization_test.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@
55
import unittest
66

77
import onnx_ir as ir
8-
from parameterized import parameterized
98

109
import onnxscript
11-
from onnxscript.rewriter import onnx_fusions
12-
from onnxscript.rewriter.models import _rotary_embedding_models
10+
from onnxscript.rewriter.rules.fusion import _rms_normalization
1311

1412

15-
class OnnxFusionsTest(unittest.TestCase):
13+
class RmsNormOnnxFusionsTest(unittest.TestCase):
1614
def test_rms_normalization_fusion(self):
1715
opset23 = onnxscript.values.Opset("", 23)
1816

@@ -34,34 +32,10 @@ def rms_norm_script(embedding, layernorm_weight):
3432
output_types=[onnxscript.FLOAT[128]],
3533
)
3634
model = ir.serde.deserialize_model(rms_norm_model_proto)
37-
onnx_fusions.fuse(model, debug=True)
35+
count = _rms_normalization.fuse_rms_normalization(model)
36+
self.assertEqual(count, 1)
3837
self.assertEqual(model.graph.node(-1).op_type, "RMSNormalization")
3938

40-
@parameterized.expand(
41-
[
42-
(
43-
"test_case_1",
44-
_rotary_embedding_models.test_case_1,
45-
),
46-
(
47-
"test_case_2",
48-
_rotary_embedding_models.test_case_2,
49-
),
50-
]
51-
)
52-
def test_rotary_embedding_fusion(self, _: str, test_data_constructor):
53-
test = test_data_constructor()
54-
for opset_version in [22, 23]:
55-
model: ir.Model = test.get_onnx_model()
56-
model.graph.opset_imports[""] = opset_version
57-
onnxscript.optimizer.optimize(model)
58-
onnx_fusions.fuse(model)
59-
op_types = [n.op_type for n in model.graph]
60-
if opset_version == 22:
61-
self.assertNotIn("RotaryEmbedding", op_types)
62-
else:
63-
self.assertIn("RotaryEmbedding", op_types)
64-
6539

6640
if __name__ == "__main__":
6741
unittest.main()

onnxscript/rewriter/rules/fusion/_rotary_embedding.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,34 @@ def _rotate_half_pattern(op, x, start1, end1, start2, end2):
3030

3131
class RotaryEmbedding23Fusion(pattern.RewriteRuleClassBase):
3232
def __init__(self):
33-
super().__init__(name="RotaryEmbedding23")
33+
super().__init__(name="RotaryEmbedding23", remove_nodes=False)
3434

35-
def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
36-
return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin
35+
def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2):
36+
freqs_repeated = op.Concat(freqs, freqs, axis=-1)
37+
cos = op.Cos(freqs_repeated)
38+
sin = op.Sin(freqs_repeated)
39+
cos_4d = op.Unsqueeze(cos, one1)
40+
sin_4d = op.Unsqueeze(sin, one2)
41+
return x * cos_4d + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin_4d
3742

38-
def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: # type: ignore[name-defined]
43+
def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult: # type: ignore[name-defined]
3944
check_result = pattern.MatchResult()
45+
46+
def is_one(val):
47+
"""Check if val is a 0/1 dimensional tensor with a single element equal to 1."""
48+
np_val = _ir_utils.get_numpy_value(val)
49+
return (
50+
np_val is not None
51+
and np_val.size == 1
52+
and np_val.ndim <= 1
53+
and np_val.item() == 1
54+
)
55+
56+
if not is_one(one1):
57+
return check_result.fail("Unsqueeze axes is not [1]", one1)
58+
if not is_one(one2):
59+
return check_result.fail("Unsqueeze axes is not [1]", one2)
60+
4061
# x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads)
4162
if x is None or x.shape is None or len(x.shape) != 4:
4263
return check_result.fail("Input is not known to be a 4D tensor.", x)
@@ -59,8 +80,10 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult:
5980
)
6081
return check_result
6182

62-
def rewrite(self, op, x, cos, sin, **_):
83+
def rewrite(self, op, x, freqs, **_):
6384
num_heads = x.shape[1]
85+
cos = op.Cos(freqs)
86+
sin = op.Sin(freqs)
6487
return op.RotaryEmbedding(
6588
x,
6689
cos,
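(New file — file name missing from this page; the content below is the unit test module for onnxscript/rewriter/rules/fusion/_rotary_embedding.py)

Lines changed: 53 additions & 0 deletions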
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import unittest
6+
7+
import onnx
8+
import onnx_ir as ir
9+
from packaging.version import Version
10+
from parameterized import parameterized
11+
12+
import onnxscript
13+
import onnxscript.rewriter.testing
14+
from onnxscript.rewriter.models import _rotary_embedding_models
15+
from onnxscript.rewriter.rules.fusion import _rotary_embedding
16+
17+
18+
class RotaryEmbeddingOnnxFusionTest(unittest.TestCase):
19+
@parameterized.expand(
20+
[
21+
(
22+
"test_case_1",
23+
_rotary_embedding_models.test_case_1,
24+
),
25+
(
26+
"test_case_2",
27+
_rotary_embedding_models.test_case_2,
28+
),
29+
]
30+
)
31+
def test_rotary_embedding_fusion(self, _: str, test_data_constructor):
32+
test = test_data_constructor()
33+
model: ir.Model = test.get_onnx_model()
34+
model.graph.opset_imports[""] = 23
35+
model_proto = ir.serde.serialize_model(model)
36+
onnxscript.optimizer.optimize(model)
37+
_rotary_embedding.fuse_rotary_embedding(model)
38+
op_types = [n.op_type for n in model.graph]
39+
self.assertIn("RotaryEmbedding", op_types)
40+
rewritten_model_proto = ir.serde.serialize_model(model)
41+
inputs = test.get_ort_inputs()
42+
43+
onnx_version = Version(onnx.__version__)
44+
min_version = Version("1.19.1")
45+
is_stable = not (onnx_version.is_devrelease or onnx_version.is_prerelease)
46+
if onnx_version >= min_version and is_stable:
47+
onnxscript.rewriter.testing.assert_numerically_equal(
48+
model_proto, rewritten_model_proto, args=inputs, use_reference=True
49+
)
50+
51+
52+
if __name__ == "__main__":
53+
unittest.main()

0 commit comments

Comments
 (0)