Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions onnxscript/rewriter/rules/fusion/_rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@ class RotaryEmbedding23Fusion(pattern.RewriteRuleClassBase):
def __init__(self):
super().__init__(name="RotaryEmbedding23")

def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin
def pattern(self, op, x, freqs, start1, end1, start2, end2):
freqs_repeated = op.Concat(freqs, freqs, axis=-1)
cos = op.Cos(freqs_repeated)
sin = op.Sin(freqs_repeated)
cos_4d = op.Unsqueeze(cos, 1)
sin_4d = op.Unsqueeze(sin, 1)
return x * cos_4d + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin_4d

def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: # type: ignore[name-defined]
check_result = pattern.MatchResult()
Expand All @@ -59,8 +64,10 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult:
)
return check_result

def rewrite(self, op, x, cos, sin, **_):
def rewrite(self, op, x, freqs, **_):
num_heads = x.shape[1]
cos = op.Cos(freqs)
sin = op.Sin(freqs)
return op.RotaryEmbedding(
x,
cos,
Expand Down
49 changes: 49 additions & 0 deletions onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnx_ir as ir
from parameterized import parameterized

import onnxscript
from onnxscript.rewriter import onnx_fusions
from onnxscript.rewriter.models import _rotary_embedding_models
import onnxscript.rewriter.testing

class RotaryEmbeddingOnnxFusionTest(unittest.TestCase):
@parameterized.expand(
[
(
"test_case_1",
_rotary_embedding_models.test_case_1,
),
(
"test_case_2",
_rotary_embedding_models.test_case_2,
),
]
)
def test_rotary_embedding_fusion(self, _: str, test_data_constructor):
test = test_data_constructor()
for opset_version in [22, 23]:
model: ir.Model = test.get_onnx_model()
model.graph.opset_imports[""] = opset_version
model_proto = ir.serde.serialize_model(model)
onnxscript.optimizer.optimize(model)
onnx_fusions.fuse(model)
op_types = [n.op_type for n in model.graph]
if opset_version == 22:
self.assertNotIn("RotaryEmbedding", op_types)
else:
self.assertIn("RotaryEmbedding", op_types)
rewritten_model_proto = ir.serde.serialize_model(model)
inputs = test.get_ort_inputs()
onnxscript.rewriter.testing.assert_numerically_equal(
model_proto, rewritten_model_proto, args=inputs, use_reference=True
)


if __name__ == "__main__":
unittest.main()
Loading