Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions onnxscript/rewriter/_rewrite_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,10 @@ def copy_value(value: ir.Value | None) -> ir.Value | None:

def copy_attr_value(attr: ir.Attr) -> ir.Attr:
if attr.is_ref():
# No need to support this currently, as rewriting inside a function is
# not used, as it has several challenges.
raise NotImplementedError("RefAttr not supported.")
# RefAttr objects are immutable and can be shared as-is.
# The referenced attribute parameter will be handled separately
# when the function signature is created.
return attr
if attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}:
# No need to support this currently, as rewriting control-flow constructs
# is not used and has several challenges.
Expand Down
10 changes: 8 additions & 2 deletions onnxscript/rewriter/generic_pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,10 +552,16 @@ def transpose_transpose_check(op, **_) -> bool:
def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_):
perm0 = XT.producer().attributes.get("perm")
if perm0 is not None:
perm0 = perm0.value # TODO(rama): handle RefAttr
if perm0.is_ref():
# Cannot optimize when attribute is a reference, as we don't have the concrete value
return None
perm0 = perm0.value
perm1 = Y.producer().attributes.get("perm")
if perm1 is not None:
perm1 = perm1.value # TODO(rama): handle RefAttr
if perm1.is_ref():
# Cannot optimize when attribute is a reference, as we don't have the concrete value
return None
perm1 = perm1.value
if perm0 is None and perm1 is None:
return op.Identity(X)
if perm0 is None:
Expand Down
67 changes: 67 additions & 0 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,73 @@ def test_model1(x: FLOAT[16, 32], y: FLOAT[16, 32]) -> FLOAT[16, 32]:
rule.apply_to_model(model)
self.assertEqual([x.op_type for x in model.graph], ["ReluPlus"])

def test_rewrite_rule_with_ref_attr(self):
"""Test that rewrite rules handle RefAttr correctly in as_function mode."""

# Create a pattern that matches Transpose nodes
def transpose_pattern(op, x):
return op.Transpose(x, _outputs=["result"])

def replacement(op, x, result: ir.Value):
return op.Identity(x)

# This should work with as_function=True even when encountering RefAttr
rule = pattern.RewriteRule(transpose_pattern, replacement, as_function=True)

# Create a model with a node that has a RefAttr
input_val = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT))
output_val = ir.Value(name="y", type=ir.TensorType(ir.DataType.FLOAT))

# Create a transpose node with a ref attribute
perm_ref_attr = ir.RefAttr("perm", "axis_param", ir.AttributeType.INTS)

transpose_node = ir.Node(
domain="",
op_type="Transpose",
inputs=[input_val],
outputs=[output_val],
attributes={"perm": perm_ref_attr}
)

# Create graph and model
graph = ir.Graph(
inputs=[input_val],
outputs=[output_val],
nodes=[transpose_node]
)

model = ir.Model(graph=graph, ir_version=8)

# Verify the original node has RefAttr
original_node = model.graph[0]
self.assertEqual(original_node.op_type, "Transpose")
perm_attr = original_node.attributes["perm"]
self.assertTrue(perm_attr.is_ref())
self.assertEqual(perm_attr.ref_attr_name, "axis_param")

# Apply the rewrite rule - this should not fail
count = rule.apply_to_model(model)
self.assertEqual(count, 1)

# Verify the result: main graph should have function call, function should exist
self.assertEqual(len(model.graph), 1)
call_node = model.graph[0]
self.assertEqual(call_node.op_type, "Identity") # Function name becomes op_type

# Verify function was created
self.assertEqual(len(model.functions), 1)
func = list(model.functions.values())[0]
self.assertEqual(func.name, "Identity")

# Verify function contains the original Transpose with RefAttr preserved
func_nodes = list(func)
self.assertEqual(len(func_nodes), 1)
func_transpose = func_nodes[0]
self.assertEqual(func_transpose.op_type, "Transpose")
func_perm_attr = func_transpose.attributes["perm"]
self.assertTrue(func_perm_attr.is_ref())
self.assertEqual(func_perm_attr.ref_attr_name, "axis_param")


class PatternBuilderTest(unittest.TestCase):
def test_pattern_builder_context(self):
Expand Down