Skip to content

Commit 727210b

Browse files
Copilotgramalingamjustinchubytitaiwangms
authored
Expose match functionality of rewrite-rule by extracting base classes (#2447)
This PR extracts the pattern matching functionality from rewrite rules into standalone base classes, allowing users to use pattern matching without needing the replacement functionality. ## Changes ### New Base Classes **`PatternImpl`**: Core pattern matching functionality - Encapsulates `_target_pattern`, `_matcher`, and `_condition_function` - Provides `match()` method that returns `MatchResult` or `None` - Can be used standalone for pattern matching without rewriting **`PatternBase`**: Base class for class-based pattern definition - Provides abstract `pattern()` method for defining patterns - Provides optional `check()` method for condition functions - Includes `create_pattern_impl()` method to generate `PatternImpl` instances ### Updated Classes **`RewriteRule`**: Now inherits from `PatternImpl` - Maintains all existing functionality - Gains access to standalone pattern matching capabilities - Uses inherited `match()` method in `try_rewrite()` **`RewriteRuleClassBase`**: Now inherits from `PatternBase` - Maintains all existing functionality - Gains access to pattern-only capabilities - Still provides `rule()` class method to create `RewriteRule` instances ## Usage Examples ### Standalone Pattern Matching ```python from onnxscript.rewriter import pattern # Define a pattern def identity_pattern(op, x): return op.Identity(x) # Create a pattern matcher (no replacement needed) pattern_matcher = pattern.PatternImpl(identity_pattern, name="IdentityMatcher") # Use it to check if a node matches the pattern match_result = pattern_matcher.match(model, graph, node) if match_result: print(f"Pattern matched! Found {len(match_result.nodes)} nodes") ``` ### Class-Based Pattern Definition ```python class MyPattern(pattern.PatternBase): def pattern(self, op, x): return op.Identity(x) def check(self, context, x): # Custom condition logic return pattern.MatchResult() # Create a pattern implementation my_pattern = MyPattern() pattern_impl = my_pattern.create_pattern_impl() ``` ### Existing Functionality Preserved ```python # RewriteRule still works exactly as before rule = pattern.RewriteRule(target_pattern, replacement_pattern) # But now it can also be used for just pattern matching match_result = rule.match(model, graph, node) # New capability count = rule.apply_to_model(model) # Existing functionality ``` ## Backward Compatibility All existing functionality is preserved. The changes are purely additive - existing code using `RewriteRule` and `RewriteRuleClassBase` will continue to work without modification. ## Testing - All existing tests pass (34/34 tests successful) - Added comprehensive test suite for new base classes - Created example demonstrating standalone pattern matching usage - Verified inheritance relationships work correctly Fixes #2446. <!-- START COPILOT CODING AGENT TIPS --> --- 💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more [Copilot coding agent tips](https://gh.io/copilot-coding-agent-tips) in the docs. --------- Signed-off-by: Ganesan Ramalingam <[email protected]> Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: gramalingam <[email protected]> Co-authored-by: G. Ramalingam <[email protected]> Co-authored-by: Justin Chu <[email protected]> Co-authored-by: Ti-Tai Wang <[email protected]> Co-authored-by: justinchuby <[email protected]>
1 parent 7517f2e commit 727210b

File tree

4 files changed

+627
-84
lines changed

4 files changed

+627
-84
lines changed
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Example demonstrating the new pattern matching functionality."""
4+
5+
import onnx.parser
6+
7+
from onnxscript import ir
8+
from onnxscript.rewriter import pattern
9+
10+
11+
def example_standalone_pattern_matching():
12+
"""Example showing how to use Pattern for standalone pattern matching."""
13+
14+
print("=== Standalone Pattern Matching Example ===")
15+
16+
# Define a pattern that matches Identity nodes
17+
def identity_pattern(op, x):
18+
return op.Identity(x)
19+
20+
# Create a Pattern for standalone pattern matching (no replacement)
21+
pattern_matcher = pattern.Pattern(identity_pattern, name="IdentityMatcher")
22+
23+
# Create a model with an Identity node
24+
model_proto = onnx.parser.parse_model(
25+
"""
26+
<ir_version: 7, opset_import: [ "" : 17]>
27+
agraph (float[N] x) => (float[N] z)
28+
{
29+
z = Identity(x)
30+
}
31+
"""
32+
)
33+
model = ir.serde.deserialize_model(model_proto)
34+
35+
# Find nodes to test pattern matching against
36+
for node in model.graph:
37+
print(f"Testing pattern against {node.op_type} node...")
38+
match_result = pattern_matcher.match(model, model.graph, node)
39+
40+
if match_result is not None:
41+
print(f" ✓ Pattern matched! Found {len(match_result.nodes)} nodes in match.")
42+
print(f" Matched node: {match_result.nodes[0].op_type}")
43+
else:
44+
print(f" ✗ Pattern did not match {node.op_type} node.")
45+
46+
47+
def example_class_based_pattern():
48+
"""Example showing how to use PatternBase for class-based pattern definition."""
49+
50+
print("\n=== Class-Based Pattern Example ===")
51+
52+
class IdentityPatternClass(pattern.PatternBase):
53+
"""A class-based pattern that matches Identity nodes."""
54+
55+
def pattern(self, op, x):
56+
return op.Identity(x)
57+
58+
def check(self, context, x):
59+
"""Custom condition - always succeeds for this example."""
60+
print(f" Checking condition for input: {x}")
61+
return pattern.MatchResult() # Always succeeds
62+
63+
# Create an instance of the pattern class
64+
identity_pattern_class = IdentityPatternClass(name="ClassBasedIdentity")
65+
66+
# The Pattern is created internally, we can use the pattern directly
67+
print(f"Created pattern matcher: {identity_pattern_class.name}")
68+
69+
# Use it directly with the match method
70+
model_proto = onnx.parser.parse_model(
71+
"""
72+
<ir_version: 7, opset_import: [ "" : 17]>
73+
agraph (float[N] x) => (float[N] z)
74+
{
75+
z = Identity(x)
76+
}
77+
"""
78+
)
79+
model = ir.serde.deserialize_model(model_proto)
80+
81+
for node in model.graph:
82+
if node.op_type == "Identity":
83+
print(f"Testing class-based pattern against {node.op_type} node...")
84+
match_result = identity_pattern_class.match(model, model.graph, node)
85+
86+
if match_result is not None:
87+
print(" ✓ Class-based pattern matched!")
88+
else:
89+
print(" ✗ Class-based pattern did not match.")
90+
91+
92+
def example_rewrite_rule_still_works():
93+
"""Example showing that existing RewriteRule functionality is preserved."""
94+
95+
print("\n=== Existing RewriteRule Still Works ===")
96+
97+
def identity_pattern(op, x):
98+
return op.Identity(x)
99+
100+
def identity_replacement(op, x):
101+
return op.Identity(x) # No-op replacement
102+
103+
# Create a RewriteRule (which now inherits from Pattern)
104+
rule = pattern.RewriteRule(identity_pattern, identity_replacement, name="IdentityRule")
105+
106+
print(f"Created rewrite rule: {rule.name}")
107+
print(f"Rule is also a Pattern: {isinstance(rule, pattern.Pattern)}")
108+
109+
# The rule can be used both for pattern matching and rewriting
110+
model_proto = onnx.parser.parse_model(
111+
"""
112+
<ir_version: 7, opset_import: [ "" : 17]>
113+
agraph (float[N] x) => (float[N] z)
114+
{
115+
z = Identity(x)
116+
}
117+
"""
118+
)
119+
model = ir.serde.deserialize_model(model_proto)
120+
121+
# Use it for just pattern matching (inherited from Pattern)
122+
for node in model.graph:
123+
if node.op_type == "Identity":
124+
print(f"Using RewriteRule for pattern matching on {node.op_type}...")
125+
match_result = rule.match(model, model.graph, node)
126+
127+
if match_result is not None:
128+
print(" ✓ RewriteRule matched as a pattern matcher!")
129+
130+
# Use it for rewriting (original functionality)
131+
print("Using RewriteRule for rewriting...")
132+
count = rule.apply_to_model(model)
133+
print(f" Applied rule {count} times")
134+
135+
136+
if __name__ == "__main__":
137+
example_standalone_pattern_matching()
138+
example_class_based_pattern()
139+
example_rewrite_rule_still_works()
140+
print("\n=== All Examples Completed ===")

0 commit comments

Comments
 (0)