From 66971ddf45f94d59114a29d2f9adfa3d9d2d8777 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 14 Jun 2025 03:41:07 +0000 Subject: [PATCH 1/3] Initial plan for issue From e18c2ed9323bed0a00016fe0e41a831da86f72c0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 14 Jun 2025 03:51:10 +0000 Subject: [PATCH 2/3] Add test reproducing RefAttr issue in rewrite rules Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- test_attribute_refs.py | 80 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 test_attribute_refs.py diff --git a/test_attribute_refs.py b/test_attribute_refs.py new file mode 100644 index 0000000000..1192717d37 --- /dev/null +++ b/test_attribute_refs.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 + +"""Test for attribute references in rewrite rules.""" + +import unittest +import onnx +import onnxscript.ir as ir +import onnxscript.rewriter.pattern as pattern + + +class TestAttributeRefs(unittest.TestCase): + def test_rewrite_rule_with_attribute_ref_fails_in_copy(self): + """Test that rewrite rules fail when trying to extract function with RefAttr.""" + + # Create a pattern that matches Transpose + def transpose_pattern(op, x): + return op.Transpose(x, _outputs=["result"]) + + def replacement(op, x, result: ir.Value): + return op.Identity(x) + + # This will trigger the _copy_for_function issue when as_function=True + rule = pattern.RewriteRule(transpose_pattern, replacement, as_function=True) + + # Create a simple model manually using the IR + input_val = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT)) + output_val = ir.Value(name="y", type=ir.TensorType(ir.DataType.FLOAT)) + + # Create a transpose node with a ref attribute + perm_ref_attr = ir.RefAttr("perm", "axis_param", ir.AttributeType.INTS) + + transpose_node = ir.Node( + domain="", + op_type="Transpose", + inputs=[input_val], + outputs=[output_val], + attributes={"perm": perm_ref_attr} + ) + + # Create graph + graph = ir.Graph( + inputs=[input_val], + outputs=[output_val], + nodes=[transpose_node] + ) + + # Create model + model = ir.Model( + graph=graph, + ir_version=8 + ) + + print("Graph nodes:") + for node in model.graph: + print(f" Node: {node.op_type}") + for attr_name, attr in node.attributes.items(): + print(f" Attribute {attr_name}: {attr}, is_ref: {attr.is_ref()}") + if attr.is_ref(): + print(f" References: {attr.ref_attr_name}") + + try: + # This should trigger the NotImplementedError in _copy_for_function + count = rule.apply_to_model(model) + print(f"Unexpected success: Rewrite applied {count} times") + return False + except NotImplementedError as e: + print(f"Expected NotImplementedError: {e}") + # This confirms the issue exists + return True + except Exception as e: + print(f"Unexpected error: {e}") + import traceback + traceback.print_exc() + return False + + +if __name__ == "__main__": + test = TestAttributeRefs() + success = test.test_rewrite_rule_with_attribute_ref_fails_in_copy() + print(f"Test result: {'PASS' if success else 'FAIL'}") \ No newline at end of file From 50d86f16a4be3ebd861974491c19511a9a8878fb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 14 Jun 2025 03:58:25 +0000 Subject: [PATCH 3/3] Implement RefAttr support in rewrite rules Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxscript/rewriter/_rewrite_rule.py | 7 +- onnxscript/rewriter/generic_pattern_test.py | 10 ++- onnxscript/rewriter/pattern_test.py | 67 +++++++++++++++++ test_attribute_refs.py | 80 --------------------- 4 files changed, 79 insertions(+), 85 deletions(-) delete mode 100644 test_attribute_refs.py diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 3e910edd52..cc6f93c321 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -371,9 +371,10 @@ def copy_value(value: ir.Value | None) -> ir.Value | None: def copy_attr_value(attr: ir.Attr) -> ir.Attr: if attr.is_ref(): - # No need to support this currently, as rewriting inside a function is - # not used, as it has several challenges. - raise NotImplementedError("RefAttr not supported.") + # RefAttr objects are immutable and can be shared as-is. + # The referenced attribute parameter will be handled separately + # when the function signature is created. + return attr if attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}: # No need to support this currently, as rewriting control-flow constructs # is not used and has several challenges. diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py index dadaf5e8bb..7395ebec77 100644 --- a/onnxscript/rewriter/generic_pattern_test.py +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -552,10 +552,16 @@ def transpose_transpose_check(op, **_) -> bool: def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_): perm0 = XT.producer().attributes.get("perm") if perm0 is not None: - perm0 = perm0.value # TODO(rama): handle RefAttr + if perm0.is_ref(): + # Cannot optimize when attribute is a reference, as we don't have the concrete value + return None + perm0 = perm0.value perm1 = Y.producer().attributes.get("perm") if perm1 is not None: - perm1 = perm1.value # TODO(rama): handle RefAttr + if perm1.is_ref(): + # Cannot optimize when attribute is a reference, as we don't have the concrete value + return None + perm1 = perm1.value if perm0 is None and perm1 is None: return op.Identity(X) if perm0 is None: diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 6706eea193..236232a0d5 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -779,6 +779,73 @@ def test_model1(x: FLOAT[16, 32], y: FLOAT[16, 32]) -> FLOAT[16, 32]: rule.apply_to_model(model) self.assertEqual([x.op_type for x in model.graph], ["ReluPlus"]) + def test_rewrite_rule_with_ref_attr(self): + """Test that rewrite rules handle RefAttr correctly in as_function mode.""" + + # Create a pattern that matches Transpose nodes + def transpose_pattern(op, x): + return op.Transpose(x, _outputs=["result"]) + + def replacement(op, x, result: ir.Value): + return op.Identity(x) + + # This should work with as_function=True even when encountering RefAttr + rule = pattern.RewriteRule(transpose_pattern, replacement, as_function=True) + + # Create a model with a node that has a RefAttr + input_val = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT)) + output_val = ir.Value(name="y", type=ir.TensorType(ir.DataType.FLOAT)) + + # Create a transpose node with a ref attribute + perm_ref_attr = ir.RefAttr("perm", "axis_param", ir.AttributeType.INTS) + + transpose_node = ir.Node( + domain="", + op_type="Transpose", + inputs=[input_val], + outputs=[output_val], + attributes={"perm": perm_ref_attr} + ) + + # Create graph and model + graph = ir.Graph( + inputs=[input_val], + outputs=[output_val], + nodes=[transpose_node] + ) + + model = ir.Model(graph=graph, ir_version=8) + + # Verify the original node has RefAttr + original_node = model.graph[0] + self.assertEqual(original_node.op_type, "Transpose") + perm_attr = original_node.attributes["perm"] + self.assertTrue(perm_attr.is_ref()) + self.assertEqual(perm_attr.ref_attr_name, "axis_param") + + # Apply the rewrite rule - this should not fail + count = rule.apply_to_model(model) + self.assertEqual(count, 1) + + # Verify the result: main graph should have function call, function should exist + self.assertEqual(len(model.graph), 1) + call_node = model.graph[0] + self.assertEqual(call_node.op_type, "Identity") # Function name becomes op_type + + # Verify function was created + self.assertEqual(len(model.functions), 1) + func = list(model.functions.values())[0] + self.assertEqual(func.name, "Identity") + + # Verify function contains the original Transpose with RefAttr preserved + func_nodes = list(func) + self.assertEqual(len(func_nodes), 1) + func_transpose = func_nodes[0] + self.assertEqual(func_transpose.op_type, "Transpose") + func_perm_attr = func_transpose.attributes["perm"] + self.assertTrue(func_perm_attr.is_ref()) + self.assertEqual(func_perm_attr.ref_attr_name, "axis_param") + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self): diff --git a/test_attribute_refs.py b/test_attribute_refs.py deleted file mode 100644 index 1192717d37..0000000000 --- a/test_attribute_refs.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python3 - -"""Test for attribute references in rewrite rules.""" - -import unittest -import onnx -import onnxscript.ir as ir -import onnxscript.rewriter.pattern as pattern - - -class TestAttributeRefs(unittest.TestCase): - def test_rewrite_rule_with_attribute_ref_fails_in_copy(self): - """Test that rewrite rules fail when trying to extract function with RefAttr.""" - - # Create a pattern that matches Transpose - def transpose_pattern(op, x): - return op.Transpose(x, _outputs=["result"]) - - def replacement(op, x, result: ir.Value): - return op.Identity(x) - - # This will trigger the _copy_for_function issue when as_function=True - rule = pattern.RewriteRule(transpose_pattern, replacement, as_function=True) - - # Create a simple model manually using the IR - input_val = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT)) - output_val = ir.Value(name="y", type=ir.TensorType(ir.DataType.FLOAT)) - - # Create a transpose node with a ref attribute - perm_ref_attr = ir.RefAttr("perm", "axis_param", ir.AttributeType.INTS) - - transpose_node = ir.Node( - domain="", - op_type="Transpose", - inputs=[input_val], - outputs=[output_val], - attributes={"perm": perm_ref_attr} - ) - - # Create graph - graph = ir.Graph( - inputs=[input_val], - outputs=[output_val], - nodes=[transpose_node] - ) - - # Create model - model = ir.Model( - graph=graph, - ir_version=8 - ) - - print("Graph nodes:") - for node in model.graph: - print(f" Node: {node.op_type}") - for attr_name, attr in node.attributes.items(): - print(f" Attribute {attr_name}: {attr}, is_ref: {attr.is_ref()}") - if attr.is_ref(): - print(f" References: {attr.ref_attr_name}") - - try: - # This should trigger the NotImplementedError in _copy_for_function - count = rule.apply_to_model(model) - print(f"Unexpected success: Rewrite applied {count} times") - return False - except NotImplementedError as e: - print(f"Expected NotImplementedError: {e}") - # This confirms the issue exists - return True - except Exception as e: - print(f"Unexpected error: {e}") - import traceback - traceback.print_exc() - return False - - -if __name__ == "__main__": - test = TestAttributeRefs() - success = test.test_rewrite_rule_with_attribute_ref_fails_in_copy() - print(f"Test result: {'PASS' if success else 'FAIL'}") \ No newline at end of file