|
| 1 | +(heading-target-checkers)= |
| 2 | +# Node and Value Level Checkers |
| 3 | + |
| 4 | +The pattern matching infrastructure supports custom validation logic at both the node and value levels through checker functions. These checkers allow for more sophisticated pattern matching by enabling additional constraints beyond basic operator and structure matching. |
| 5 | + |
| 6 | +## Value-Level Checkers |
| 7 | + |
| 8 | +Value-level checkers validate properties of specific values in the pattern. They are particularly useful for checking constants, shapes, or other value-specific properties. |
| 9 | + |
| 10 | +### Basic Usage |
| 11 | + |
| 12 | +A value checker is a function that takes a `MatchContext` and an `ir.Value`, and returns either a boolean or a `MatchResult`: |
| 13 | + |
| 14 | +```python |
| 15 | +def is_positive_constant(context, value: ir.Value): |
| 16 | + """Check if a value is a positive constant.""" |
| 17 | + if value.const_value is not None: |
| 18 | + # Get the numpy array from const_value |
| 19 | + numpy_array = value.const_value.numpy() |
| 20 | + |
| 21 | + # Check if it represents a single value and is positive |
| 22 | + if numpy_array.size != 1: |
| 23 | + return False |
| 24 | + |
| 25 | + return float(numpy_array.item()) > 0 |
| 26 | + |
| 27 | + return False |
| 28 | +``` |
| 29 | + |
| 30 | +You can use this checker directly in your pattern by passing the callable as an input: |
| 31 | + |
| 32 | +```python |
| 33 | +def add_pattern(op, x, y): |
| 34 | + # Use callable as input to create ValuePattern with checker |
| 35 | + return op.Add(is_positive_constant, y) |
| 36 | +``` |
| 37 | + |
| 38 | +This pattern will only match `Add` operations where the first input is a positive constant value. |
| 39 | + |
| 40 | +### Example Usage |
| 41 | + |
| 42 | +```python |
| 43 | +from onnxscript.rewriter import pattern |
| 44 | +from onnxscript import ir, optimizer |
| 45 | +import onnx |
| 46 | + |
| 47 | +# Create a model with different Add operations |
| 48 | +model_proto = onnx.parser.parse_model(""" |
| 49 | + <ir_version: 7, opset_import: [ "" : 17]> |
| 50 | + agraph (float[N] x, float[N] y) => (float[N] z1, float[N] z2, float[N] z3) |
| 51 | + { |
| 52 | + pos_const = Constant <value_float = 2.5> () |
| 53 | + neg_const = Constant <value_float = -1.5> () |
| 54 | + z1 = Add(x, y) # non-constant first parameter |
| 55 | + z2 = Add(pos_const, y) # positive constant first parameter |
| 56 | + z3 = Add(neg_const, y) # negative constant first parameter |
| 57 | + } |
| 58 | +""") |
| 59 | +model = ir.serde.deserialize_model(model_proto) |
| 60 | + |
| 61 | +# Apply constant propagation to set const_value fields |
| 62 | +optimizer.basic_constant_propagation(model.graph.all_nodes()) |
| 63 | + |
| 64 | +# Create the pattern with value checker |
| 65 | +rule_pattern = pattern.Pattern(add_pattern) |
| 66 | + |
| 67 | +# Test matching against different Add nodes |
| 68 | +add_nodes = [node for node in model.graph if node.op_type == "Add"] |
| 69 | + |
| 70 | +# Non-constant first parameter - will not match |
| 71 | +match_result = rule_pattern.match(model, model.graph, add_nodes[0]) |
| 72 | +print(f"Non-constant: {bool(match_result)}") # False |
| 73 | + |
| 74 | +# Positive constant first parameter - will match |
| 75 | +match_result = rule_pattern.match(model, model.graph, add_nodes[1]) |
| 76 | +print(f"Positive constant: {bool(match_result)}") # True |
| 77 | + |
| 78 | +# Negative constant first parameter - will not match |
| 79 | +match_result = rule_pattern.match(model, model.graph, add_nodes[2]) |
| 80 | +print(f"Negative constant: {bool(match_result)}") # False |
| 81 | +``` |
| 82 | + |
| 83 | +## Node-Level Checkers |
| 84 | + |
| 85 | +Node-level checkers validate properties of the operation nodes themselves, such as attributes, operation types, or other node-specific properties. |
| 86 | + |
| 87 | +### Basic Usage |
| 88 | + |
| 89 | +A node checker is a function that takes a `MatchContext` and an `ir.Node`, and returns either a boolean or a `MatchResult`: |
| 90 | + |
| 91 | +```python |
| 92 | +def shape_node_checker(context, node): |
| 93 | + """Check if a Shape operation has start attribute equal to 0.""" |
| 94 | + return node.attributes.get_int("start", 0) == 0 |
| 95 | +``` |
| 96 | + |
| 97 | +You can use this checker by passing it to the `_check` parameter of an operation: |
| 98 | + |
| 99 | +```python |
| 100 | +def shape_pattern(op, x): |
| 101 | + return op.Shape(x, _check=shape_node_checker) |
| 102 | +``` |
| 103 | + |
| 104 | +This pattern will only match `Shape` operations where the `start` attribute is 0 (or not present, as the default is 0). |
| 105 | + |
| 106 | +### Example Usage |
| 107 | + |
| 108 | +```python |
| 109 | +from onnxscript.rewriter import pattern |
| 110 | +from onnxscript import ir |
| 111 | +import onnx |
| 112 | + |
| 113 | +# Create a model with different Shape operations |
| 114 | +model_proto = onnx.parser.parse_model(""" |
| 115 | + <ir_version: 7, opset_import: [ "" : 17]> |
| 116 | + agraph (float[N, M] x) => (int64[2] z1, int64[2] z2, int64[1] z3) |
| 117 | + { |
| 118 | + z1 = Shape(x) |
| 119 | + z2 = Shape <start: int = 0>(x) |
| 120 | + z3 = Shape <start: int = 1>(x) |
| 121 | + } |
| 122 | +""") |
| 123 | +model = ir.serde.deserialize_model(model_proto) |
| 124 | + |
| 125 | +# Create the pattern with node checker |
| 126 | +rule_pattern = pattern.Pattern(shape_pattern) |
| 127 | + |
| 128 | +# Test matching against different Shape nodes |
| 129 | +nodes = list(model.graph) |
| 130 | +shape_nodes = [node for node in nodes if node.op_type == "Shape"] |
| 131 | + |
| 132 | +# Shape without start attribute (default 0) - will match |
| 133 | +match_result = rule_pattern.match(model, model.graph, shape_nodes[0]) |
| 134 | +print(f"No start attr: {bool(match_result)}") # True |
| 135 | + |
| 136 | +# Shape with start=0 - will match |
| 137 | +match_result = rule_pattern.match(model, model.graph, shape_nodes[1]) |
| 138 | +print(f"Start=0: {bool(match_result)}") # True |
| 139 | + |
| 140 | +# Shape with start=1 - will not match |
| 141 | +match_result = rule_pattern.match(model, model.graph, shape_nodes[2]) |
| 142 | +print(f"Start=1: {bool(match_result)}") # False |
| 143 | +``` |
| 144 | + |
| 145 | +## Combining Checkers |
| 146 | + |
| 147 | +You can combine both node-level and value-level checkers in the same pattern for more sophisticated matching: |
| 148 | + |
| 149 | +```python |
| 150 | +def complex_pattern(op, x, y): |
| 151 | + # Value-level checker for first input |
| 152 | + validated_x = is_positive_constant |
| 153 | + # Node-level checker for the operation |
| 154 | + return op.Add(validated_x, y, _check=lambda ctx, node: len(node.attributes) == 0) |
| 155 | +``` |
| 156 | + |
| 157 | +This pattern will only match `Add` operations where: |
| 158 | +1. The first input is a positive constant (value-level check) |
| 159 | +2. The node has no custom attributes (node-level check) |
| 160 | + |
| 161 | +## Execution Timing and Limitations |
| 162 | + |
| 163 | +### When Checkers Are Called |
| 164 | + |
| 165 | +Node-level and value-level checkers are called **only at the end of the complete structural match**. This means: |
| 166 | + |
| 167 | +1. **Structural matching happens first**: The pattern matching engine first validates that the graph structure matches the pattern (correct operators, connections, etc.) |
| 168 | +2. **Checkers run after structural validation**: Only after the structural match succeeds do the node and value checkers execute |
| 169 | +3. **Order of execution**: Value-level checkers run first, followed by node-level checkers, and finally the pattern's condition function |
| 170 | + |
| 171 | +### Limitations with Pattern Disjunctions |
| 172 | + |
| 173 | +One important limitation of this design is that these checks don't compose well with pattern disjunctions (multiple alternative patterns). When searching among multiple value patterns: |
| 174 | + |
| 175 | +- **Only structural checking is performed initially**: If structural matching succeeds for the first alternative, other alternatives are not considered |
| 176 | +- **Checker failures don't trigger backtracking**: If a checker fails, the entire pattern match fails rather than trying the next alternative pattern |
| 177 | + |
| 178 | +This means you should be careful when designing patterns with multiple alternatives that rely on checkers, as the checker logic may prevent exploration of valid alternative matches. |
| 179 | + |
| 180 | +## Error Handling |
| 181 | + |
| 182 | +Checkers can return either: |
| 183 | +- `True`: Check passed, continue matching |
| 184 | +- `False`: Check failed, pattern does not match |
| 185 | +- `MatchResult`: More detailed result with potential failure reasons |
| 186 | + |
| 187 | +If a checker raises an exception, it will be caught and treated as a match failure, allowing patterns to fail gracefully when encountering unexpected conditions. |
0 commit comments