Skip to content

Commit f18dadc

Browse files
authored
A couple of extensions to rewriter (#1912)
A couple of extensions to the rewriter, motivated by fusion optimization experimentation with SmoLLM. * Support list of constants in match-pattern. * One multi-output scenario is easy to handle with the single-output pattern-matcher (eg. defining a fusion rule for SkipNormalization): namely when the extra outputs are intermediate values used in the computation of the first value. Extend algorithm to handle this scenario using the efficient single-output matching-algorithm. An example for the second point is the following pattern: ```py def skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type): skip_sum = op.Add(input, skip) normalized = op.SimplifiedLayerNormalization( skip_sum, gamma, axis=-1, epsilon=epsilon, stash_type=stash_type, _domain="com.microsoft") return normalized, skip_sum ``` If we successfully find a match for `normalized` (which transitively finds a match for all of the pattern subgraph that leads up to `normalized`), we have also found a successful match for `skip_sum`, so no need for a multi-output match. (Will add test-cases later, as I work through the fusion optimizations I am experimenting with.)
1 parent 3016daa commit f18dadc

File tree

1 file changed

+75
-19
lines changed

1 file changed

+75
-19
lines changed

onnxscript/rewriter/pattern.py

Lines changed: 75 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,11 @@ def _to_value_pattern(
282282
return x
283283
if isinstance(x, (int, float)):
284284
return Constant(x)
285-
# TODO(rama): support lists of int/float
286-
# if isinstance(x, list):
287-
# if all(isinstance(i, (int, float)) for i in x):
288-
# return Constant(x)
289-
# raise ValueError("Only lists of int/float can be used as a ValuePattern")
290-
# TODO(titaiwang): Could this be wrapped Constant?
285+
if isinstance(x, Sequence):
286+
if all(isinstance(i, (int, float)) for i in x):
287+
return Constant(x)
288+
raise ValueError("Only lists of int/float can be used as a ValuePattern")
289+
291290
raise TypeError(f"Cannot convert {type(x)} to ValuePattern")
292291

293292

@@ -602,10 +601,13 @@ class Constant(ValuePattern):
602601
"""Represents a pattern that matches against a scalar constant value."""
603602

604603
def __init__(
605-
self, value: int | float, rel_tol: float = 1e-5, abs_tol: float = 1e-8
604+
self,
605+
value: int | float | Sequence[int] | Sequence[float],
606+
rel_tol: float = 1e-5,
607+
abs_tol: float = 1e-8,
606608
) -> None:
607609
super().__init__(None)
608-
self._value = value
610+
self._value = list(value) if isinstance(value, Sequence) else value
609611
self._rel_tol = rel_tol
610612
self._abs_tol = abs_tol
611613

@@ -614,7 +616,7 @@ def clone(self, node_map: dict[NodePattern, NodePattern]) -> Constant:
614616
return Constant(self._value, self._rel_tol, self._abs_tol)
615617

616618
@property
617-
def value(self) -> int | float:
619+
def value(self) -> int | float | list[int] | list[float]:
618620
return self._value
619621

620622
def matches(self, value: ir.Value, match: MatchResult) -> MatchResult:
@@ -623,6 +625,24 @@ def matches(self, value: ir.Value, match: MatchResult) -> MatchResult:
623625
return match.fail(f"Value is not a constant, expecting {self.value}.")
624626

625627
constant_value_numpy = constant_value.numpy()
628+
if isinstance(self._value, list):
629+
if constant_value_numpy.shape != (len(self._value),):
630+
return match.fail(f"Value has mismatching shape, expecting ({self.value},).")
631+
if not all(
632+
math.isclose(
633+
constant_value_numpy.item(i),
634+
self._value[i],
635+
rel_tol=self._rel_tol,
636+
abs_tol=self._abs_tol,
637+
)
638+
for i in range(len(self._value))
639+
):
640+
return match.fail(
641+
f"Value mismatch: expected {self._value}, got {constant_value_numpy}."
642+
)
643+
return match
644+
645+
# Scalar constant case:
626646
# TODO (rama): allow users to specify shape requirement, if desired.
627647
if constant_value_numpy.size != 1:
628648
return match.fail(f"Value is not a scalar, expecting {self.value}.")
@@ -664,6 +684,20 @@ def visit(value_patterns: Sequence[ValuePattern | None]) -> None:
664684
return node_patterns
665685

666686

687+
def _add_backward_slice(node: NodePattern, backward_slice: set[NodePattern]) -> None:
688+
"""Adds all nodes in the backward slice of given node to the set `backward_slice`.
689+
690+
The backward slice of a node is the set of all nodes that are reachable from the node
691+
in a backward traversal from the given node.
692+
"""
693+
if node in backward_slice:
694+
return
695+
backward_slice.add(node)
696+
for value_pattern in node.inputs:
697+
if isinstance(value_pattern, NodeOutputPattern):
698+
_add_backward_slice(value_pattern.producer(), backward_slice)
699+
700+
667701
class GraphPattern:
668702
"""Represents a pattern that can be matched against a subgraph."""
669703

@@ -679,8 +713,10 @@ def __init__(
679713
raise ValueError("GraphPattern must have at least one output")
680714
self._nodes = nodes # _nodes_in_pattern(outputs)
681715

682-
# Check if all outputs are produced by the same node.
716+
# Determine the output nodes of the pattern. These are a minimal set of nodes
717+
# whose backward-slices cover the entire pattern.
683718
output_nodes: set[NodePattern] = set()
719+
covered: set[NodePattern] = set()
684720
for value_pattern in outputs:
685721
if not isinstance(value_pattern, ValuePattern):
686722
raise TypeError(
@@ -691,7 +727,11 @@ def __init__(
691727
"Constant values are not allowed as graph pattern outputs."
692728
)
693729
if isinstance(value_pattern, NodeOutputPattern):
694-
output_nodes.add(value_pattern.producer())
730+
candidate = value_pattern.producer()
731+
if candidate not in covered:
732+
output_nodes.add(candidate)
733+
_add_backward_slice(candidate, covered)
734+
695735
self.output_nodes: list[NodePattern] = list(output_nodes)
696736

697737
@property
@@ -924,20 +964,41 @@ def _match_constant(self, pattern_constant: Constant, value: ir.Value) -> bool:
924964
constant_value_numpy = constant_value.numpy()
925965
except FileNotFoundError:
926966
return self.fail(f"Constant value of {value.name} not available.")
967+
968+
pattern_constant_value = pattern_constant._value
969+
970+
if isinstance(pattern_constant_value, list):
971+
expected_shape = (len(pattern_constant_value),)
972+
if constant_value_numpy.shape != expected_shape:
973+
return self.fail(f"Value has mismatching shape, expecting {expected_shape}.")
974+
if not all(
975+
math.isclose(
976+
constant_value_numpy.item(i),
977+
pattern_constant_value[i],
978+
rel_tol=pattern_constant._rel_tol,
979+
abs_tol=pattern_constant._abs_tol,
980+
)
981+
for i in range(len(pattern_constant_value))
982+
):
983+
return self.fail(
984+
f"Value mismatch: expected {pattern_constant_value}, got {constant_value_numpy}."
985+
)
986+
return True
987+
927988
# TODO (rama): allow users to specify shape requirement, if desired.
928989
if constant_value_numpy.size != 1:
929990
return self.fail(
930-
f"Value {value.name} is not a scalar, expecting {pattern_constant.value}.",
991+
f"Value {value.name} is not a scalar, expecting {pattern_constant_value}.",
931992
)
932993

933994
if not math.isclose(
934995
constant_value_numpy.item(),
935-
pattern_constant._value,
996+
pattern_constant_value,
936997
rel_tol=pattern_constant._rel_tol,
937998
abs_tol=pattern_constant._abs_tol,
938999
):
9391000
return self.fail(
940-
f"Constant value mismatch: expected {pattern_constant._value}, got {constant_value_numpy.item()}.",
1001+
f"Constant value mismatch: expected {pattern_constant_value}, got {constant_value_numpy.item()}.",
9411002
)
9421003

9431004
return True
@@ -1079,11 +1140,6 @@ def _match_single_output_node(
10791140
if not _valid_to_replace(match.nodes, output_values):
10801141
return match.fail("Matched nodes have other uses preventing replacement.")
10811142

1082-
if len(node.outputs) != pattern.num_outputs:
1083-
return match.fail(
1084-
f"Number of node outputs mismatch: expected {pattern.num_outputs}, got {len(node.outputs)}."
1085-
)
1086-
10871143
match.outputs.extend(output_values)
10881144
return match
10891145

0 commit comments

Comments
 (0)