Skip to content

Commit 0bf5ca0

Browse files
Copilotgramalingam
andauthored
[Rewriter] Implement value/node level checkers for pattern matching infrastructure (#2459)
This PR extends the pattern matching infrastructure to support value/node level checkers as requested in the issue. The implementation allows for more sophisticated pattern matching by enabling custom validation logic at both the node and value levels. ## Key Changes ### 1. Extended Pattern IR Classes - **ValuePattern**: Added optional `_check` callable attribute via `check` keyword argument - **NodePattern**: Added optional `_check` callable attribute via `check` keyword argument - Both checkers take `(MatchContext, ir.Node/ir.Value)` and return `bool` or `MatchResult` ### 2. Enhanced Pattern Building - **_to_value_pattern**: Now accepts callable inputs, automatically creating `ValuePattern` with checker - **OpPatternBuilder.__call__**: Added `_check` parameter for node-level validation ### 3. Extended MatchResult - Added `node_bindings` property (similar to existing `value_bindings`) - Provides access to pattern node → matched node mappings ### 4. Enhanced Pattern Matching - **Pattern.match**: Now executes value/node level checks before condition function - Iterates through `node_bindings` and `value_bindings` to run associated checkers - Stops on first check failure with appropriate error handling ## Usage Examples ### Node-Level Checker ```python def validated_add_checker(context, node): """Only accept Add nodes with no custom attributes.""" return node.op_type == "Add" and len(node.attributes) == 0 def pattern(op, x, y): return op.Add(x, y, _check=validated_add_checker) ``` ### Value-Level Checker ```python def shape_checker(context, value): """Validate value has expected shape properties.""" return hasattr(value, 'type') and hasattr(value.type, 'shape') def pattern(op, x, y): validated_x = shape_checker # Converted to ValuePattern with checker return op.Add(validated_x, y) ``` ### Combined Checkers ```python def pattern(op, x, y): validated_x = value_checker # Value-level check return op.Add(validated_x, y, _check=node_checker) # Node-level check ``` ## Testing Added comprehensive test suite (`ValueNodeCheckersTest`) covering: - ✅ ValuePattern and NodePattern with checkers - ✅ _to_value_pattern with callable inputs - ✅ OpPatternBuilder with _check parameter - ✅ Pattern.match with successful node/value checkers - ✅ Pattern.match with failing checkers (proper error handling) - ✅ Backward compatibility with existing patterns All existing tests continue to pass, ensuring no breaking changes. Fixes #2458. <!-- START COPILOT CODING AGENT TIPS --> --- 💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click [here](https://survey.alchemer.com/s3/8343779/Copilot-Coding-agent) to start the survey. --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: gramalingam <[email protected]>
1 parent 2f147eb commit 0bf5ca0

File tree

7 files changed

+460
-13
lines changed

7 files changed

+460
-13
lines changed
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
(heading-target-checkers)=
2+
# Node and Value Level Checkers
3+
4+
The pattern matching infrastructure supports custom validation logic at both the node and value levels through checker functions. These checkers allow for more sophisticated pattern matching by enabling additional constraints beyond basic operator and structure matching.
5+
6+
## Value-Level Checkers
7+
8+
Value-level checkers validate properties of specific values in the pattern. They are particularly useful for checking constants, shapes, or other value-specific properties.
9+
10+
### Basic Usage
11+
12+
A value checker is a function that takes a `MatchContext` and an `ir.Value`, and returns either a boolean or a `MatchResult`:
13+
14+
```python
15+
def is_positive_constant(context, value: ir.Value):
16+
"""Check if a value is a positive constant."""
17+
if value.const_value is not None:
18+
# Get the numpy array from const_value
19+
numpy_array = value.const_value.numpy()
20+
21+
# Check if it represents a single value and is positive
22+
if numpy_array.size != 1:
23+
return False
24+
25+
return float(numpy_array.item()) > 0
26+
27+
return False
28+
```
29+
30+
You can use this checker directly in your pattern by passing the callable as an input:
31+
32+
```python
33+
def add_pattern(op, x, y):
34+
# Use callable as input to create ValuePattern with checker
35+
return op.Add(is_positive_constant, y)
36+
```
37+
38+
This pattern will only match `Add` operations where the first input is a positive constant value.
39+
40+
### Example Usage
41+
42+
```python
43+
from onnxscript.rewriter import pattern
44+
from onnxscript import ir, optimizer
45+
import onnx
46+
47+
# Create a model with different Add operations
48+
model_proto = onnx.parser.parse_model("""
49+
<ir_version: 7, opset_import: [ "" : 17]>
50+
agraph (float[N] x, float[N] y) => (float[N] z1, float[N] z2, float[N] z3)
51+
{
52+
pos_const = Constant <value_float = 2.5> ()
53+
neg_const = Constant <value_float = -1.5> ()
54+
z1 = Add(x, y) # non-constant first parameter
55+
z2 = Add(pos_const, y) # positive constant first parameter
56+
z3 = Add(neg_const, y) # negative constant first parameter
57+
}
58+
""")
59+
model = ir.serde.deserialize_model(model_proto)
60+
61+
# Apply constant propagation to set const_value fields
62+
optimizer.basic_constant_propagation(model.graph.all_nodes())
63+
64+
# Create the pattern with value checker
65+
rule_pattern = pattern.Pattern(add_pattern)
66+
67+
# Test matching against different Add nodes
68+
add_nodes = [node for node in model.graph if node.op_type == "Add"]
69+
70+
# Non-constant first parameter - will not match
71+
match_result = rule_pattern.match(model, model.graph, add_nodes[0])
72+
print(f"Non-constant: {bool(match_result)}") # False
73+
74+
# Positive constant first parameter - will match
75+
match_result = rule_pattern.match(model, model.graph, add_nodes[1])
76+
print(f"Positive constant: {bool(match_result)}") # True
77+
78+
# Negative constant first parameter - will not match
79+
match_result = rule_pattern.match(model, model.graph, add_nodes[2])
80+
print(f"Negative constant: {bool(match_result)}") # False
81+
```
82+
83+
## Node-Level Checkers
84+
85+
Node-level checkers validate properties of the operation nodes themselves, such as attributes, operation types, or other node-specific properties.
86+
87+
### Basic Usage
88+
89+
A node checker is a function that takes a `MatchContext` and an `ir.Node`, and returns either a boolean or a `MatchResult`:
90+
91+
```python
92+
def shape_node_checker(context, node):
93+
"""Check if a Shape operation has start attribute equal to 0."""
94+
return node.attributes.get_int("start", 0) == 0
95+
```
96+
97+
You can use this checker by passing it to the `_check` parameter of an operation:
98+
99+
```python
100+
def shape_pattern(op, x):
101+
return op.Shape(x, _check=shape_node_checker)
102+
```
103+
104+
This pattern will only match `Shape` operations where the `start` attribute is 0 (or not present, as the default is 0).
105+
106+
### Example Usage
107+
108+
```python
109+
from onnxscript.rewriter import pattern
110+
from onnxscript import ir
111+
import onnx
112+
113+
# Create a model with different Shape operations
114+
model_proto = onnx.parser.parse_model("""
115+
<ir_version: 7, opset_import: [ "" : 17]>
116+
agraph (float[N, M] x) => (int64[2] z1, int64[2] z2, int64[1] z3)
117+
{
118+
z1 = Shape(x)
119+
z2 = Shape <start: int = 0>(x)
120+
z3 = Shape <start: int = 1>(x)
121+
}
122+
""")
123+
model = ir.serde.deserialize_model(model_proto)
124+
125+
# Create the pattern with node checker
126+
rule_pattern = pattern.Pattern(shape_pattern)
127+
128+
# Test matching against different Shape nodes
129+
nodes = list(model.graph)
130+
shape_nodes = [node for node in nodes if node.op_type == "Shape"]
131+
132+
# Shape without start attribute (default 0) - will match
133+
match_result = rule_pattern.match(model, model.graph, shape_nodes[0])
134+
print(f"No start attr: {bool(match_result)}") # True
135+
136+
# Shape with start=0 - will match
137+
match_result = rule_pattern.match(model, model.graph, shape_nodes[1])
138+
print(f"Start=0: {bool(match_result)}") # True
139+
140+
# Shape with start=1 - will not match
141+
match_result = rule_pattern.match(model, model.graph, shape_nodes[2])
142+
print(f"Start=1: {bool(match_result)}") # False
143+
```
144+
145+
## Combining Checkers
146+
147+
You can combine both node-level and value-level checkers in the same pattern for more sophisticated matching:
148+
149+
```python
150+
def complex_pattern(op, x, y):
151+
# Value-level checker for first input
152+
validated_x = is_positive_constant
153+
# Node-level checker for the operation
154+
return op.Add(validated_x, y, _check=lambda ctx, node: len(node.attributes) == 0)
155+
```
156+
157+
This pattern will only match `Add` operations where:
158+
1. The first input is a positive constant (value-level check)
159+
2. The node has no custom attributes (node-level check)
160+
161+
## Execution Timing and Limitations
162+
163+
### When Checkers Are Called
164+
165+
Node-level and value-level checkers are called **only at the end of the complete structural match**. This means:
166+
167+
1. **Structural matching happens first**: The pattern matching engine first validates that the graph structure matches the pattern (correct operators, connections, etc.)
168+
2. **Checkers run after structural validation**: Only after the structural match succeeds do the node and value checkers execute
169+
3. **Order of execution**: Value-level checkers run first, followed by node-level checkers, and finally the pattern's condition function
170+
171+
### Limitations with Pattern Disjunctions
172+
173+
One important limitation of this design is that these checks don't compose well with pattern disjunctions (multiple alternative patterns). When searching among multiple value patterns:
174+
175+
- **Only structural checking is performed initially**: If structural matching succeeds for the first alternative, other alternatives are not considered
176+
- **Checker failures don't trigger backtracking**: If a checker fails, the entire pattern match fails rather than trying the next alternative pattern
177+
178+
This means you should be careful when designing patterns with multiple alternatives that rely on checkers, as the checker logic may prevent exploration of valid alternative matches.
179+
180+
## Error Handling
181+
182+
Checkers can return either:
183+
- `True`: Check passed, continue matching
184+
- `False`: Check failed, pattern does not match
185+
- `MatchResult`: More detailed result with potential failure reasons
186+
187+
If a checker raises an exception, it will be caught and treated as a match failure, allowing patterns to fail gracefully when encountering unexpected conditions.

docs/tutorial/rewriter/rewrite_patterns.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,6 @@ There are three main components needed when rewriting patterns in the graph:
2424

2525
```{include} commute.md
2626
```
27+
28+
```{include} node_value_checkers.md
29+
```

onnxscript/rewriter/_basics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,13 @@ def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]:
188188
raise ValueError("Value bindings can be accessed only at the top-level match.")
189189
return self._current_match.value_bindings
190190

191+
@property
192+
def node_bindings(self) -> dict[_pattern_ir.NodePattern, ir.Node]:
193+
"""Returns the bindings for the node variables."""
194+
if len(self._partial_matches) > 1:
195+
raise ValueError("Node bindings can be accessed only at the top-level match.")
196+
return self._current_match.node_bindings
197+
191198
@property
192199
def outputs(self) -> MutableSequence[ir.Value]:
193200
"""Returns the list of output values that matched the pattern."""

onnxscript/rewriter/_pattern_ir.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def __call__(
224224
_outputs: int | list[str | None] = 1,
225225
_allow_other_attributes: bool | None = None,
226226
_allow_other_inputs: bool | None = None,
227+
_check: Callable | None = None,
227228
**kwargs,
228229
):
229230
if _version is not None:
@@ -255,6 +256,7 @@ def __call__(
255256
_outputs,
256257
allow_other_attributes=_allow_other_attributes,
257258
allow_other_inputs=_allow_other_inputs,
259+
check=_check,
258260
)
259261
self.pattern_builder.add_node(node_pattern)
260262
output_values = node_pattern.outputs
@@ -266,7 +268,7 @@ def __call__(
266268

267269

268270
def _to_value_pattern(
269-
x: ValuePattern | int | float | None,
271+
x: ValuePattern | int | float | Callable | None,
270272
) -> ValuePattern | None:
271273
"""Promotes an input-value used to construct a NodePattern to a ValuePattern.
272274
@@ -282,6 +284,8 @@ def _to_value_pattern(
282284
explicitly write this as:
283285
::
284286
z = op.Add(x, op.Constant(0))
287+
288+
If a callable is provided, it will be converted to a ValuePattern with the callable as the check attribute.
285289
"""
286290
if x is None or isinstance(x, ValuePattern):
287291
return x
@@ -291,6 +295,8 @@ def _to_value_pattern(
291295
if all(isinstance(i, (int, float)) for i in x):
292296
return Constant(x)
293297
raise ValueError("Only lists of int/float can be used as a ValuePattern")
298+
if callable(x):
299+
return ValuePattern(None, check=x)
294300

295301
raise TypeError(f"Cannot convert {type(x)} to ValuePattern")
296302

@@ -314,19 +320,24 @@ class ValuePattern:
314320
operations, so that we can write patterns like `x + 1` and `1 + x`.
315321
"""
316322

317-
def __init__(self, name: str | None) -> None:
323+
def __init__(self, name: str | None, *, check: Callable | None = None) -> None:
318324
self._name = name
325+
self._check = check
319326
# Note: uses will be computed only when the full graph-pattern is constructed.
320327
self._uses: list[tuple[NodePattern, int]] = []
321328

322329
def clone(self, node_map: dict[NodePattern, NodePattern]) -> ValuePattern:
323330
del node_map
324-
return ValuePattern(self._name)
331+
return ValuePattern(self._name, check=self._check)
325332

326333
@property
327334
def name(self) -> str | None:
328335
return self._name
329336

337+
@property
338+
def check_method(self) -> Callable | None:
339+
return self._check
340+
330341
def producer(self) -> NodePattern | None:
331342
return None
332343

@@ -397,6 +408,7 @@ def __init__(
397408
*,
398409
allow_other_attributes: bool | None,
399410
allow_other_inputs: bool | None,
411+
check: Callable | None = None,
400412
):
401413
if allow_other_attributes is None:
402414
# Default behavior: allow other unmatched attributes in the node.
@@ -410,6 +422,7 @@ def __init__(
410422
self.attributes = attributes
411423
self.allow_other_attributes = allow_other_attributes
412424
self.allow_other_inputs = allow_other_inputs
425+
self._check = check
413426
# In the common case, domain and op are constants, which can be used to optimize matching.
414427
if isinstance(op, str) and isinstance(domain, StringConstantPattern):
415428
# TODO(rama): support overloaded operators.
@@ -445,6 +458,10 @@ def op_identifier(self) -> ir.OperatorIdentifier | None:
445458
def op_type(self) -> str:
446459
return str(self.op)
447460

461+
@property
462+
def check_method(self) -> Callable | None:
463+
return self._check
464+
448465
def matches(self, node: ir.Node, match: _basics.MatchResult) -> _basics.MatchResult:
449466
"""Matches the pattern represented by self against a node.
450467
@@ -498,6 +515,7 @@ def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePat
498515
outputs,
499516
allow_other_attributes=self.allow_other_attributes,
500517
allow_other_inputs=self.allow_other_inputs,
518+
check=self._check,
501519
)
502520
node_map[self] = copied
503521
return copied
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
import unittest
4+
5+
from onnxscript.rewriter import _pattern_ir
6+
7+
8+
class PatternIRTest(unittest.TestCase):
9+
"""Test _pattern_ir module functionality."""
10+
11+
def test_value_pattern_with_check(self):
12+
"""Test ValuePattern with check attribute."""
13+
14+
def value_checker(context, value):
15+
return True
16+
17+
# Test creating ValuePattern with check
18+
value_pattern = _pattern_ir.ValuePattern("test_value", check=value_checker)
19+
self.assertIs(value_pattern._check, value_checker)
20+
self.assertEqual(value_pattern.name, "test_value")
21+
22+
def test_node_pattern_with_check(self):
23+
"""Test NodePattern with check attribute."""
24+
25+
def node_checker(context, node):
26+
return True
27+
28+
# Test creating NodePattern with check
29+
domain_pattern = _pattern_ir.StringConstantPattern("")
30+
inputs = []
31+
attributes = {}
32+
outputs = ["output"]
33+
34+
node_pattern = _pattern_ir.NodePattern(
35+
domain_pattern,
36+
"Add",
37+
inputs,
38+
attributes,
39+
outputs,
40+
allow_other_attributes=True,
41+
allow_other_inputs=True,
42+
check=node_checker,
43+
)
44+
self.assertIs(node_pattern._check, node_checker)
45+
46+
def test_to_value_pattern_with_callable(self):
47+
"""Test _to_value_pattern function with callable input."""
48+
49+
def my_checker(context, value):
50+
return True
51+
52+
result = _pattern_ir._to_value_pattern(my_checker)
53+
self.assertIsInstance(result, _pattern_ir.ValuePattern)
54+
self.assertIs(result._check, my_checker)
55+
self.assertIsNone(result.name)
56+
57+
def test_op_pattern_builder_with_check(self):
58+
"""Test OpPatternBuilder with _check parameter."""
59+
60+
def node_checker(context, node):
61+
return True
62+
63+
# Create OpPatternBuilder and call with _check parameter
64+
opset_builder = _pattern_ir.OpsetPatternBuilder("")
65+
result = opset_builder.Add(None, None, _check=node_checker)
66+
67+
# The result should be a NodeOutputPattern, and its producer should have the check
68+
self.assertTrue(hasattr(result, "producer"))
69+
producer = result.producer()
70+
self.assertIsNotNone(producer)
71+
self.assertTrue(hasattr(producer, "_check"))
72+
self.assertIs(producer._check, node_checker)
73+
74+
75+
if __name__ == "__main__":
76+
unittest.main()

0 commit comments

Comments
 (0)