Merged
@@ -5,14 +5,12 @@
 import unittest
 
 import onnx_ir as ir
-from parameterized import parameterized
 
 import onnxscript
-from onnxscript.rewriter import onnx_fusions
-from onnxscript.rewriter.models import _rotary_embedding_models
+from onnxscript.rewriter.rules.fusion import _rms_normalization
 
 
-class OnnxFusionsTest(unittest.TestCase):
+class RmsNormOnnxFusionsTest(unittest.TestCase):
     def test_rms_normalization_fusion(self):
         opset23 = onnxscript.values.Opset("", 23)
 
@@ -34,34 +32,10 @@ def rms_norm_script(embedding, layernorm_weight):
             output_types=[onnxscript.FLOAT[128]],
         )
         model = ir.serde.deserialize_model(rms_norm_model_proto)
-        onnx_fusions.fuse(model, debug=True)
+        count = _rms_normalization.fuse_rms_normalization(model)
+        self.assertEqual(count, 1)
         self.assertEqual(model.graph.node(-1).op_type, "RMSNormalization")
 
-    @parameterized.expand(
-        [
-            (
-                "test_case_1",
-                _rotary_embedding_models.test_case_1,
-            ),
-            (
-                "test_case_2",
-                _rotary_embedding_models.test_case_2,
-            ),
-        ]
-    )
-    def test_rotary_embedding_fusion(self, _: str, test_data_constructor):
-        test = test_data_constructor()
-        for opset_version in [22, 23]:
-            model: ir.Model = test.get_onnx_model()
-            model.graph.opset_imports[""] = opset_version
-            onnxscript.optimizer.optimize(model)
-            onnx_fusions.fuse(model)
-            op_types = [n.op_type for n in model.graph]
-            if opset_version == 22:
-                self.assertNotIn("RotaryEmbedding", op_types)
-            else:
-                self.assertIn("RotaryEmbedding", op_types)
-
 
 if __name__ == "__main__":
     unittest.main()
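For context: the fusion this test now calls directly, fuse_rms_normalization, collapses the unfused RMS-normalization chain (square, ReduceMean, add epsilon, Sqrt, divide, scale by weight) into the single RMSNormalization op available from opset 23, which is what the node(-1).op_type assertion checks. A minimal numpy sketch of that unfused computation follows; the names and the epsilon value are illustrative and not taken from the collapsed rms_norm_script above.

import numpy as np

def rms_norm_reference(embedding, layernorm_weight, eps=1e-6):
    # Unfused chain: Pow(x, 2) -> ReduceMean -> Add(eps) -> Sqrt -> Div -> Mul(weight).
    # eps and the function/argument names are illustrative, not the test's actual values.
    mean_square = np.mean(np.asarray(embedding, dtype=np.float32) ** 2, axis=-1, keepdims=True)
    normalized = embedding / np.sqrt(mean_square + eps)
    return normalized * layernorm_weight  # equivalent to one opset-23 RMSNormalization node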
33 changes: 28 additions & 5 deletions onnxscript/rewriter/rules/fusion/_rotary_embedding.py
@@ -30,13 +30,34 @@ def _rotate_half_pattern(op, x, start1, end1, start2, end2):
 
 class RotaryEmbedding23Fusion(pattern.RewriteRuleClassBase):
     def __init__(self):
-        super().__init__(name="RotaryEmbedding23")
+        super().__init__(name="RotaryEmbedding23", remove_nodes=False)
 
-    def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
-        return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin
+    def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2):
+        freqs_repeated = op.Concat(freqs, freqs, axis=-1)
+        cos = op.Cos(freqs_repeated)
+        sin = op.Sin(freqs_repeated)
+        cos_4d = op.Unsqueeze(cos, one1)
+        sin_4d = op.Unsqueeze(sin, one2)
+        return x * cos_4d + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin_4d
 
-    def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult:  # type: ignore[name-defined]
+    def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult:  # type: ignore[name-defined]
         check_result = pattern.MatchResult()
+
+        def is_one(val):
+            """Check if val is a 0/1 dimensional tensor with a single element equal to 1."""
+            np_val = _ir_utils.get_numpy_value(val)
+            return (
+                np_val is not None
+                and np_val.size == 1
+                and np_val.ndim <= 1
+                and np_val.item() == 1
+            )
+
+        if not is_one(one1):
+            return check_result.fail("Unsqueeze axes is not [1]", one1)
+        if not is_one(one2):
+            return check_result.fail("Unsqueeze axes is not [1]", one2)
+
         # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads)
         if x is None or x.shape is None or len(x.shape) != 4:
             return check_result.fail("Input is not known to be a 4D tensor.", x)
@@ -59,8 +80,10 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult:
         )
         return check_result
 
-    def rewrite(self, op, x, cos, sin, **_):
+    def rewrite(self, op, x, freqs, **_):
         num_heads = x.shape[1]
+        cos = op.Cos(freqs)
+        sin = op.Sin(freqs)
         return op.RotaryEmbedding(
             x,
             cos,
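To make the revised pattern concrete, here is a minimal numpy sketch of the subgraph it now matches: the frequency tensor is concatenated with itself, passed through Cos and Sin, unsqueezed on axis 1 so the tables broadcast over the heads dimension, and combined with the half-rotated input. The shapes and the [-x2, x1] rotate-half convention are assumptions; the rule's actual Slice bounds live in _rotate_half_pattern, outside this hunk.

import numpy as np

def rotate_half(x):
    # Assumed convention: split the last (head_size) axis in half and rotate to [-x2, x1].
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def unfused_rotary(x, freqs):
    # Assumed shapes: x is [batch, num_heads, seq_len, head_size], freqs is [batch, seq_len, head_size // 2].
    freqs_repeated = np.concatenate([freqs, freqs], axis=-1)  # op.Concat(freqs, freqs, axis=-1)
    cos = np.expand_dims(np.cos(freqs_repeated), axis=1)      # op.Cos -> op.Unsqueeze(cos, one1)
    sin = np.expand_dims(np.sin(freqs_repeated), axis=1)      # op.Sin -> op.Unsqueeze(sin, one2)
    return x * cos + rotate_half(x) * sin

This is also why check now insists that both Unsqueeze axes constants equal 1 (the is_one helper): the cos/sin tables must broadcast across the num_heads dimension of the 4D input. The rewrite recomputes Cos and Sin from freqs without the Concat, presumably because the fused RotaryEmbedding op takes half-sized cos/sin caches.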
53 changes: 53 additions & 0 deletions onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py
@@ -0,0 +1,53 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+import onnx
+import onnx_ir as ir
+from packaging.version import Version
+from parameterized import parameterized
+
+import onnxscript
+import onnxscript.rewriter.testing
+from onnxscript.rewriter.models import _rotary_embedding_models
+from onnxscript.rewriter.rules.fusion import _rotary_embedding
+
+
+class RotaryEmbeddingOnnxFusionTest(unittest.TestCase):
+    @parameterized.expand(
+        [
+            (
+                "test_case_1",
+                _rotary_embedding_models.test_case_1,
+            ),
+            (
+                "test_case_2",
+                _rotary_embedding_models.test_case_2,
+            ),
+        ]
+    )
+    def test_rotary_embedding_fusion(self, _: str, test_data_constructor):
+        test = test_data_constructor()
+        model: ir.Model = test.get_onnx_model()
+        model.graph.opset_imports[""] = 23
+        model_proto = ir.serde.serialize_model(model)
+        onnxscript.optimizer.optimize(model)
+        _rotary_embedding.fuse_rotary_embedding(model)
+        op_types = [n.op_type for n in model.graph]
+        self.assertIn("RotaryEmbedding", op_types)
+        rewritten_model_proto = ir.serde.serialize_model(model)
+        inputs = test.get_ort_inputs()
+
+        onnx_version = Version(onnx.__version__)
+        min_version = Version("1.19.1")
+        is_stable = not (onnx_version.is_devrelease or onnx_version.is_prerelease)
+        if onnx_version >= min_version and is_stable:

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable (Error)
Local variable 'is_stable' may be used before it is initialized.
+            onnxscript.rewriter.testing.assert_numerically_equal(
+                model_proto, rewritten_model_proto, args=inputs, use_reference=True
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
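The numeric comparison above is gated on the installed onnx version. A standalone sketch of that gate is below; the 1.19.1 floor is taken from the test, but the reason for it (presumably reference-evaluator support for the fused op, given use_reference=True) is an assumption.

import onnx
from packaging.version import Version

def can_verify_numerically():
    # Only run the reference-based comparison on stable onnx releases at or above 1.19.1.
    # The 1.19.1 threshold mirrors the test; the rationale is assumed, not stated in the PR.
    onnx_version = Version(onnx.__version__)
    is_stable = not (onnx_version.is_devrelease or onnx_version.is_prerelease)
    return is_stable and onnx_version >= Version("1.19.1")

Keeping the whole gate in one helper may also sidestep the CodeQL "potentially uninitialized local variable" warning flagged above.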