diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py b/onnxscript/rewriter/rules/fusion/_rms_normalization_test.py similarity index 53% rename from onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py rename to onnxscript/rewriter/rules/fusion/_rms_normalization_test.py index 22d6120da1..e70c4ec7a0 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py +++ b/onnxscript/rewriter/rules/fusion/_rms_normalization_test.py @@ -5,14 +5,12 @@ import unittest import onnx_ir as ir -from parameterized import parameterized import onnxscript -from onnxscript.rewriter import onnx_fusions -from onnxscript.rewriter.models import _rotary_embedding_models +from onnxscript.rewriter.rules.fusion import _rms_normalization -class OnnxFusionsTest(unittest.TestCase): +class RmsNormOnnxFusionsTest(unittest.TestCase): def test_rms_normalization_fusion(self): opset23 = onnxscript.values.Opset("", 23) @@ -34,34 +32,10 @@ def rms_norm_script(embedding, layernorm_weight): output_types=[onnxscript.FLOAT[128]], ) model = ir.serde.deserialize_model(rms_norm_model_proto) - onnx_fusions.fuse(model, debug=True) + count = _rms_normalization.fuse_rms_normalization(model) + self.assertEqual(count, 1) self.assertEqual(model.graph.node(-1).op_type, "RMSNormalization") - @parameterized.expand( - [ - ( - "test_case_1", - _rotary_embedding_models.test_case_1, - ), - ( - "test_case_2", - _rotary_embedding_models.test_case_2, - ), - ] - ) - def test_rotary_embedding_fusion(self, _: str, test_data_constructor): - test = test_data_constructor() - for opset_version in [22, 23]: - model: ir.Model = test.get_onnx_model() - model.graph.opset_imports[""] = opset_version - onnxscript.optimizer.optimize(model) - onnx_fusions.fuse(model) - op_types = [n.op_type for n in model.graph] - if opset_version == 22: - self.assertNotIn("RotaryEmbedding", op_types) - else: - self.assertIn("RotaryEmbedding", op_types) - if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py index 2009c6953f..524b6f4806 100644 --- a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py +++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py @@ -30,13 +30,34 @@ def _rotate_half_pattern(op, x, start1, end1, start2, end2): class RotaryEmbedding23Fusion(pattern.RewriteRuleClassBase): def __init__(self): - super().__init__(name="RotaryEmbedding23") + super().__init__(name="RotaryEmbedding23", remove_nodes=False) - def pattern(self, op, x, cos, sin, start1, end1, start2, end2): - return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin + def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2): + freqs_repeated = op.Concat(freqs, freqs, axis=-1) + cos = op.Cos(freqs_repeated) + sin = op.Sin(freqs_repeated) + cos_4d = op.Unsqueeze(cos, one1) + sin_4d = op.Unsqueeze(sin, one2) + return x * cos_4d + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin_4d - def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: # type: ignore[name-defined] + def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult: # type: ignore[name-defined] check_result = pattern.MatchResult() + + def is_one(val): + """Check if val is a 0/1 dimensional tensor with a single element equal to 1.""" + np_val = _ir_utils.get_numpy_value(val) + return ( + np_val is not None + and np_val.size == 1 + and np_val.ndim <= 1 + and np_val.item() == 1 + ) + + if not is_one(one1): + return check_result.fail("Unsqueeze axes is not [1]", one1) + if not is_one(one2): + return check_result.fail("Unsqueeze axes is not [1]", one2) + # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) if x is None or x.shape is None or len(x.shape) != 4: return check_result.fail("Input is not known to be a 4D tensor.", x) @@ -59,8 +80,10 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: ) return check_result - def rewrite(self, op, x, cos, sin, **_): + def rewrite(self, op, x, freqs, **_): num_heads = x.shape[1] + cos = op.Cos(freqs) + sin = op.Sin(freqs) return op.RotaryEmbedding( x, cos, diff --git a/onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py new file mode 100644 index 0000000000..b8ffe95cac --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnx +import onnx_ir as ir +from packaging.version import Version +from parameterized import parameterized + +import onnxscript +import onnxscript.rewriter.testing +from onnxscript.rewriter.models import _rotary_embedding_models +from onnxscript.rewriter.rules.fusion import _rotary_embedding + + +class RotaryEmbeddingOnnxFusionTest(unittest.TestCase): + @parameterized.expand( + [ + ( + "test_case_1", + _rotary_embedding_models.test_case_1, + ), + ( + "test_case_2", + _rotary_embedding_models.test_case_2, + ), + ] + ) + def test_rotary_embedding_fusion(self, _: str, test_data_constructor): + test = test_data_constructor() + model: ir.Model = test.get_onnx_model() + model.graph.opset_imports[""] = 23 + model_proto = ir.serde.serialize_model(model) + onnxscript.optimizer.optimize(model) + _rotary_embedding.fuse_rotary_embedding(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("RotaryEmbedding", op_types) + rewritten_model_proto = ir.serde.serialize_model(model) + inputs = test.get_ort_inputs() + + onnx_version = Version(onnx.__version__) + min_version = Version("1.19.1") + is_stable = not (onnx_version.is_devrelease or onnx_version.is_prerelease) + if onnx_version >= min_version and is_stable: + onnxscript.rewriter.testing.assert_numerically_equal( + model_proto, rewritten_model_proto, args=inputs, use_reference=True + ) + + +if __name__ == "__main__": + unittest.main()