From 326ade7cb1f0997f434add3c329b85331ac097ca Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 27 Dec 2024 17:46:02 -0800 Subject: [PATCH] Update double squeeze rewrite rule --- onnxscript/rewriter/llama_rule_sets.py | 54 ++++++++------------- onnxscript/rewriter/llama_rule_sets_test.py | 23 +++++++++ 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py index faf81eeb73..a6b24b7141 100644 --- a/onnxscript/rewriter/llama_rule_sets.py +++ b/onnxscript/rewriter/llama_rule_sets.py @@ -4,10 +4,10 @@ from typing import ClassVar -import numpy as np import onnx.numpy_helper import onnxscript.ir as ir +import onnxscript.rewriter._ir_utils as ir_utils import onnxscript.rewriter.no_op as no_op import onnxscript.rewriter.pattern as orp @@ -230,42 +230,37 @@ class UnsqueezeUnsqueeze(orp.RewriteRuleAsClass): def pattern(cls, op, x, axes1, axes2): return op.Unsqueeze(op.Unsqueeze(x, axes1), axes2) - @classmethod - def _combine_axes(cls, axes1: np.ndarray, axes2: np.ndarray) -> np.ndarray: - """Combines two single axes into one tensor of two axes.""" - if axes1[0] < axes2[0]: - return np.hstack([axes1, axes2]) - return np.hstack([axes2, axes1 + 1]).astype(np.int64) - @classmethod def rewrite(cls, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value): - assert axes1.const_value is not None - assert axes2.const_value is not None - axes = cls._combine_axes(axes1.const_value.numpy(), axes2.const_value.numpy()) - return op.Unsqueeze(x, op.Constant(value=onnx.numpy_helper.from_array(axes))) + v1 = ir_utils.get_singleton_value(axes1) + v2 = ir_utils.get_singleton_value(axes2) + axes = [v1, v2] if v1 < v2 else [v2, v1 + 1] + return op.Unsqueeze(x, op.Constant(value=ir.tensor(axes, dtype=ir.DataType.INT64))) @classmethod def check(cls, context, x, axes1, axes2) -> bool: del context # Unused del x # Unused - if axes1.const_value is None or axes2.const_value is None: - return False - - v1 = axes1.const_value.numpy() - v2 = axes2.const_value.numpy() - if not v1.shape or not v2.shape: - return False - if v1.shape[0] != 1 or v2.shape[0] != 1: - # Implemented later if needed. + # Currently restricted to single element positive axis + v1 = ir_utils.get_singleton_value(axes1) + v2 = ir_utils.get_singleton_value(axes2) + if v1 is None or v2 is None: return False - if v1.min() < 0: + if (v1 < 0) or (v2 < 0): return False - if v2.min() < 0: - return False - return True +cast_cast_rule = orp.make_rewrite_rule_from_class(CastCast) +cast_identity_rule = orp.make_rewrite_rule_from_class(CastIdentity) +expand_identity_rule = orp.make_rewrite_rule_from_class(ExpandIdentity) +reshape_reshape_rule = orp.make_rewrite_rule_from_class(ReshapeReshape) +slice_split_rule = orp.make_rewrite_rule_from_class(SlicesSplit, True) +transpose_identity_rule = orp.make_rewrite_rule_from_class(TransposeIdentity) +transpose_transpose_rule = orp.make_rewrite_rule_from_class(TransposeTranspose) +unsqueeze_unsqueeze_rule = orp.make_rewrite_rule_from_class(UnsqueezeUnsqueeze) + + def llama_p0_rule_set() -> orp.RewriteRuleSet: """Returns a set of rules which should be applied before any other one as they usually remove unnecessary computation @@ -274,15 +269,6 @@ def llama_p0_rule_set() -> orp.RewriteRuleSet: Returns: RewriteRuleSet """ - cast_cast_rule = orp.make_rewrite_rule_from_class(CastCast) - cast_identity_rule = orp.make_rewrite_rule_from_class(CastIdentity) - expand_identity_rule = orp.make_rewrite_rule_from_class(ExpandIdentity) - reshape_reshape_rule = orp.make_rewrite_rule_from_class(ReshapeReshape) - slice_split_rule = orp.make_rewrite_rule_from_class(SlicesSplit, True) - transpose_identity_rule = orp.make_rewrite_rule_from_class(TransposeIdentity) - transpose_transpose_rule = orp.make_rewrite_rule_from_class(TransposeTranspose) - unsqueeze_unsqueeze_rule = orp.make_rewrite_rule_from_class(UnsqueezeUnsqueeze) - return orp.RewriteRuleSet( [ no_op.mul_by_1_rule, diff --git a/onnxscript/rewriter/llama_rule_sets_test.py b/onnxscript/rewriter/llama_rule_sets_test.py index 2415130c70..0d430760f4 100644 --- a/onnxscript/rewriter/llama_rule_sets_test.py +++ b/onnxscript/rewriter/llama_rule_sets_test.py @@ -309,6 +309,29 @@ def test_llama_p0_rule_set_expand_identity( opset_imports=[onnx.helper.make_opsetid("", 18)], ), ), + ( + "double_unsqueezes_3", + _make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Unsqueeze", ["X", "axes1"], ["Xu"]), + onnx.helper.make_node("Unsqueeze", ["Xu", "axes2"], ["Y"]), + ], + "name", + [onnx.helper.make_tensor_value_info("X", FLOAT, [3])], + [onnx.helper.make_tensor_value_info("Y", FLOAT, [1, 3, 1])], + [ + onnx.numpy_helper.from_array( + np.array(0, dtype=np.int64), name="axes1" + ), + onnx.numpy_helper.from_array( + np.array(1, dtype=np.int64), name="axes2" + ), + ], + ), + opset_imports=[onnx.helper.make_opsetid("", 18)], + ), + ), ] ) def test_llama_p0_rule_set_unsqueeze_unsqueeze(self, _: str, model: ir.Model):