From cd8ea57c7edbcd1714bd4e4a404cec990940ccc8 Mon Sep 17 00:00:00 2001 From: Bhagirath Mehta Date: Thu, 12 Jun 2025 17:43:12 +0000 Subject: [PATCH 1/2] Merge metadata props when fusing --- onnxscript/rewriter/_rewrite_rule.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 3e910edd52..ee4f4bec82 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -531,6 +531,15 @@ def _apply_to_graph_or_function( f = ir.Function(domain, name, overload, graph=graph, attributes=()) model.functions[f.identifier()] = f + # If we are fusing nodes, update the metadata props of the new node(s) + if delta.match.nodes and delta.new_nodes: + # Concatenate metadata props from all original nodes + fused_metadata_props = "Fused: " + "\t\n".join( + n.metadata_props for n in delta.match.nodes if getattr(n, "metadata_props", None) + ) + # Assign to all new nodes (or just the first, depending on your policy) + delta.new_nodes[0].metadata_props += fused_metadata_props + if verbose: name = f"{rule.name}: " if rule.name else "" print(f"----{name}Matched Nodes----") From eff274c705466525a2b74a07b74a0fc3b8b12139 Mon Sep 17 00:00:00 2001 From: Bhagirath Mehta Date: Thu, 19 Jun 2025 17:16:44 +0000 Subject: [PATCH 2/2] Merge attribute dictionaries --- onnxscript/rewriter/_rewrite_rule.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index ee4f4bec82..a797be745f 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -531,14 +531,22 @@ def _apply_to_graph_or_function( f = ir.Function(domain, name, overload, graph=graph, attributes=()) model.functions[f.identifier()] = f - # If we are fusing nodes, update the metadata props of the new node(s) + # If we are fusing nodes, update the docstring of the new node(s) + attributes = ["namespace", "pkg.torch.onnx.class_hierarchy", "pkg.torch.onnx.fx_node", "pkg.torch.onnx.name_scopes", "pkg.torch.onnx.stack_trace"] if delta.match.nodes and delta.new_nodes: - # Concatenate metadata props from all original nodes - fused_metadata_props = "Fused: " + "\t\n".join( - n.metadata_props for n in delta.match.nodes if getattr(n, "metadata_props", None) - ) - # Assign to all new nodes (or just the first, depending on your policy) - delta.new_nodes[0].metadata_props += fused_metadata_props + # Concatenate docstrings from all original nodes + for attribute in attributes: + fused_attribute = "\n".join( + n.metadata_props[attribute] for n in delta.match.nodes if getattr(n, "metadata_props", None) and attribute in n.metadata_props + ) + if fused_attribute.strip(): + fused_attribute = "Fused from nodes with following attributes: " + fused_attribute + for node in delta.new_nodes: + # Assign to all new nodes + if attribute in node.metadata_props: + node.metadata_props[attribute] += f"\n{fused_attribute}" + else: + node.metadata_props[attribute] = fused_attribute if verbose: name = f"{rule.name}: " if rule.name else ""