Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 20 additions & 34 deletions onnxscript/rewriter/llama_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

from typing import ClassVar

import numpy as np
import onnx.numpy_helper

import onnxscript.ir as ir
import onnxscript.rewriter._ir_utils as ir_utils
import onnxscript.rewriter.no_op as no_op
import onnxscript.rewriter.pattern as orp

Expand Down Expand Up @@ -230,42 +230,37 @@ class UnsqueezeUnsqueeze(orp.RewriteRuleAsClass):
def pattern(cls, op, x, axes1, axes2):
    """Describe the subgraph to match: an Unsqueeze fed by another Unsqueeze.

    Args:
        op: Pattern builder exposing the ``Unsqueeze`` operator.
        x: Input of the inner Unsqueeze.
        axes1: Axes input of the inner (first applied) Unsqueeze.
        axes2: Axes input of the outer (second applied) Unsqueeze.

    Returns:
        The pattern ``Unsqueeze(Unsqueeze(x, axes1), axes2)``.
    """
    inner = op.Unsqueeze(x, axes1)
    return op.Unsqueeze(inner, axes2)

@classmethod
def _combine_axes(cls, axes1: np.ndarray, axes2: np.ndarray) -> np.ndarray:
    """Combine two single axes into one tensor of two axes.

    ``axes1`` belongs to the Unsqueeze applied first and ``axes2`` to the one
    applied second, so ``axes2`` already indexes the tensor that contains the
    axis inserted by ``axes1``.

    Args:
        axes1: One-element array holding the axis of the first Unsqueeze.
        axes2: One-element array holding the axis of the second Unsqueeze.

    Returns:
        A two-element int64 array of axes, in increasing order, that a single
        Unsqueeze can use to produce the same result.
    """
    if axes1[0] < axes2[0]:
        # The second insertion happens after the first axis: nothing shifts.
        combined = np.hstack([axes1, axes2])
    else:
        # The second insertion lands at or before the first axis, which is
        # therefore pushed one position to the right.
        combined = np.hstack([axes2, axes1 + 1])
    # Always return int64 so the dtype does not depend on the branch taken.
    return combined.astype(np.int64)

@classmethod
def rewrite(cls, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
    """Replace two stacked single-axis Unsqueeze nodes with a single Unsqueeze.

    Relies on ``check`` having accepted the match, i.e. both axes inputs
    resolved to a single non-negative value.

    Args:
        op: Builder used to emit the replacement nodes.
        x: Input of the inner Unsqueeze.
        axes1: Axes input of the inner (first applied) Unsqueeze.
        axes2: Axes input of the outer (second applied) Unsqueeze.

    Returns:
        The output of one Unsqueeze on ``x`` whose axes come from a Constant
        holding a two-element INT64 tensor.
    """
    v1 = ir_utils.get_singleton_value(axes1)
    v2 = ir_utils.get_singleton_value(axes2)
    # axes2 indexes the tensor that already contains the axis inserted by
    # axes1: when the outer insertion lands at or before the inner one, the
    # inner axis is shifted one position to the right.
    axes = [v1, v2] if v1 < v2 else [v2, v1 + 1]
    return op.Unsqueeze(x, op.Constant(value=ir.tensor(axes, dtype=ir.DataType.INT64)))

@classmethod
def check(cls, context, x, axes1, axes2) -> bool:
    """Decide whether the matched pair of Unsqueeze nodes can be fused.

    Args:
        context: Match context (unused).
        x: Input of the inner Unsqueeze (unused).
        axes1: Axes input of the inner Unsqueeze.
        axes2: Axes input of the outer Unsqueeze.

    Returns:
        True only when both axes inputs resolve to a single, non-negative
        value; False otherwise.
    """
    del context  # Unused
    del x  # Unused
    # Currently restricted to a single non-negative axis on each Unsqueeze.
    # NOTE(review): get_singleton_value presumably returns None when the value
    # is not a constant holding exactly one element - confirm in _ir_utils.
    v1 = ir_utils.get_singleton_value(axes1)
    v2 = ir_utils.get_singleton_value(axes2)
    if v1 is None or v2 is None:
        return False
    if (v1 < 0) or (v2 < 0):
        return False
    return True


# One rewrite rule per rule class, each built with
# orp.make_rewrite_rule_from_class and bound to a name of its own.
# NOTE(review): the meaning of the extra positional `True` passed for
# SlicesSplit is not visible here - confirm against make_rewrite_rule_from_class.
cast_cast_rule = orp.make_rewrite_rule_from_class(CastCast)
cast_identity_rule = orp.make_rewrite_rule_from_class(CastIdentity)
expand_identity_rule = orp.make_rewrite_rule_from_class(ExpandIdentity)
reshape_reshape_rule = orp.make_rewrite_rule_from_class(ReshapeReshape)
slice_split_rule = orp.make_rewrite_rule_from_class(SlicesSplit, True)
transpose_identity_rule = orp.make_rewrite_rule_from_class(TransposeIdentity)
transpose_transpose_rule = orp.make_rewrite_rule_from_class(TransposeTranspose)
unsqueeze_unsqueeze_rule = orp.make_rewrite_rule_from_class(UnsqueezeUnsqueeze)


def llama_p0_rule_set() -> orp.RewriteRuleSet:
"""Returns a set of rules which should be applied
before any other one as they usually remove unnecessary computation
Expand All @@ -274,15 +269,6 @@ def llama_p0_rule_set() -> orp.RewriteRuleSet:
Returns:
RewriteRuleSet
"""
cast_cast_rule = orp.make_rewrite_rule_from_class(CastCast)
cast_identity_rule = orp.make_rewrite_rule_from_class(CastIdentity)
expand_identity_rule = orp.make_rewrite_rule_from_class(ExpandIdentity)
reshape_reshape_rule = orp.make_rewrite_rule_from_class(ReshapeReshape)
slice_split_rule = orp.make_rewrite_rule_from_class(SlicesSplit, True)
transpose_identity_rule = orp.make_rewrite_rule_from_class(TransposeIdentity)
transpose_transpose_rule = orp.make_rewrite_rule_from_class(TransposeTranspose)
unsqueeze_unsqueeze_rule = orp.make_rewrite_rule_from_class(UnsqueezeUnsqueeze)

return orp.RewriteRuleSet(
[
no_op.mul_by_1_rule,
Expand Down
23 changes: 23 additions & 0 deletions onnxscript/rewriter/llama_rule_sets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,29 @@ def test_llama_p0_rule_set_expand_identity(
opset_imports=[onnx.helper.make_opsetid("", 18)],
),
),
(
"double_unsqueezes_3",
_make_model(
onnx.helper.make_graph(
[
onnx.helper.make_node("Unsqueeze", ["X", "axes1"], ["Xu"]),
onnx.helper.make_node("Unsqueeze", ["Xu", "axes2"], ["Y"]),
],
"name",
[onnx.helper.make_tensor_value_info("X", FLOAT, [3])],
[onnx.helper.make_tensor_value_info("Y", FLOAT, [1, 3, 1])],
[
onnx.numpy_helper.from_array(
np.array(0, dtype=np.int64), name="axes1"
),
onnx.numpy_helper.from_array(
np.array(1, dtype=np.int64), name="axes2"
),
],
),
opset_imports=[onnx.helper.make_opsetid("", 18)],
),
),
]
)
def test_llama_p0_rule_set_unsqueeze_unsqueeze(self, _: str, model: ir.Model):
Expand Down
Loading