@@ -282,12 +282,11 @@ def _to_value_pattern(
282282 return x
283283 if isinstance (x , (int , float )):
284284 return Constant (x )
285- # TODO(rama): support lists of int/float
286- # if isinstance(x, list):
287- # if all(isinstance(i, (int, float)) for i in x):
288- # return Constant(x)
289- # raise ValueError("Only lists of int/float can be used as a ValuePattern")
290- # TODO(titaiwang): Could this be wrapped Constant?
285+ if isinstance (x , Sequence ):
286+ if all (isinstance (i , (int , float )) for i in x ):
287+ return Constant (x )
288+ raise ValueError ("Only lists of int/float can be used as a ValuePattern" )
289+
291290 raise TypeError (f"Cannot convert { type (x )} to ValuePattern" )
292291
293292
@@ -602,10 +601,13 @@ class Constant(ValuePattern):
602601 """Represents a pattern that matches against a scalar constant value."""
603602
604603 def __init__ (
605- self , value : int | float , rel_tol : float = 1e-5 , abs_tol : float = 1e-8
604+ self ,
605+ value : int | float | Sequence [int ] | Sequence [float ],
606+ rel_tol : float = 1e-5 ,
607+ abs_tol : float = 1e-8 ,
606608 ) -> None :
607609 super ().__init__ (None )
608- self ._value = value
610+ self ._value = list ( value ) if isinstance ( value , Sequence ) else value
609611 self ._rel_tol = rel_tol
610612 self ._abs_tol = abs_tol
611613
@@ -614,7 +616,7 @@ def clone(self, node_map: dict[NodePattern, NodePattern]) -> Constant:
614616 return Constant (self ._value , self ._rel_tol , self ._abs_tol )
615617
616618 @property
617- def value (self ) -> int | float :
619+ def value (self ) -> int | float | list [ int ] | list [ float ] :
618620 return self ._value
619621
620622 def matches (self , value : ir .Value , match : MatchResult ) -> MatchResult :
@@ -623,6 +625,24 @@ def matches(self, value: ir.Value, match: MatchResult) -> MatchResult:
623625 return match .fail (f"Value is not a constant, expecting { self .value } ." )
624626
625627 constant_value_numpy = constant_value .numpy ()
628+ if isinstance (self ._value , list ):
629+ if constant_value_numpy .shape != (len (self ._value ),):
630+ return match .fail (f"Value has mismatching shape, expecting ({ self .value } ,)." )
631+ if not all (
632+ math .isclose (
633+ constant_value_numpy .item (i ),
634+ self ._value [i ],
635+ rel_tol = self ._rel_tol ,
636+ abs_tol = self ._abs_tol ,
637+ )
638+ for i in range (len (self ._value ))
639+ ):
640+ return match .fail (
641+ f"Value mismatch: expected { self ._value } , got { constant_value_numpy } ."
642+ )
643+ return match
644+
645+ # Scalar constant case:
626646 # TODO (rama): allow users to specify shape requirement, if desired.
627647 if constant_value_numpy .size != 1 :
628648 return match .fail (f"Value is not a scalar, expecting { self .value } ." )
@@ -664,6 +684,20 @@ def visit(value_patterns: Sequence[ValuePattern | None]) -> None:
664684 return node_patterns
665685
666686
687+ def _add_backward_slice (node : NodePattern , backward_slice : set [NodePattern ]) -> None :
688+ """Adds all nodes in the backward slice of given node to the set `backward_slice`.
689+
690+ The backward slice of a node is the set of all nodes that are reachable from the node
691+ in a backward traversal from the given node.
692+ """
693+ if node in backward_slice :
694+ return
695+ backward_slice .add (node )
696+ for value_pattern in node .inputs :
697+ if isinstance (value_pattern , NodeOutputPattern ):
698+ _add_backward_slice (value_pattern .producer (), backward_slice )
699+
700+
667701class GraphPattern :
668702 """Represents a pattern that can be matched against a subgraph."""
669703
@@ -679,8 +713,10 @@ def __init__(
679713 raise ValueError ("GraphPattern must have at least one output" )
680714 self ._nodes = nodes # _nodes_in_pattern(outputs)
681715
682- # Check if all outputs are produced by the same node.
716+ # Determine the output nodes of the pattern. These are a minimal set of nodes
717+ # whose backward-slices cover the entire pattern.
683718 output_nodes : set [NodePattern ] = set ()
719+ covered : set [NodePattern ] = set ()
684720 for value_pattern in outputs :
685721 if not isinstance (value_pattern , ValuePattern ):
686722 raise TypeError (
@@ -691,7 +727,11 @@ def __init__(
691727 "Constant values are not allowed as graph pattern outputs."
692728 )
693729 if isinstance (value_pattern , NodeOutputPattern ):
694- output_nodes .add (value_pattern .producer ())
730+ candidate = value_pattern .producer ()
731+ if candidate not in covered :
732+ output_nodes .add (candidate )
733+ _add_backward_slice (candidate , covered )
734+
695735 self .output_nodes : list [NodePattern ] = list (output_nodes )
696736
697737 @property
@@ -924,20 +964,41 @@ def _match_constant(self, pattern_constant: Constant, value: ir.Value) -> bool:
924964 constant_value_numpy = constant_value .numpy ()
925965 except FileNotFoundError :
926966 return self .fail (f"Constant value of { value .name } not available." )
967+
968+ pattern_constant_value = pattern_constant ._value
969+
970+ if isinstance (pattern_constant_value , list ):
971+ expected_shape = (len (pattern_constant_value ),)
972+ if constant_value_numpy .shape != expected_shape :
973+ return self .fail (f"Value has mismatching shape, expecting { expected_shape } ." )
974+ if not all (
975+ math .isclose (
976+ constant_value_numpy .item (i ),
977+ pattern_constant_value [i ],
978+ rel_tol = pattern_constant ._rel_tol ,
979+ abs_tol = pattern_constant ._abs_tol ,
980+ )
981+ for i in range (len (pattern_constant_value ))
982+ ):
983+ return self .fail (
984+ f"Value mismatch: expected { pattern_constant_value } , got { constant_value_numpy } ."
985+ )
986+ return True
987+
927988 # TODO (rama): allow users to specify shape requirement, if desired.
928989 if constant_value_numpy .size != 1 :
929990 return self .fail (
930- f"Value { value .name } is not a scalar, expecting { pattern_constant . value } ." ,
991+ f"Value { value .name } is not a scalar, expecting { pattern_constant_value } ." ,
931992 )
932993
933994 if not math .isclose (
934995 constant_value_numpy .item (),
935- pattern_constant . _value ,
996+ pattern_constant_value ,
936997 rel_tol = pattern_constant ._rel_tol ,
937998 abs_tol = pattern_constant ._abs_tol ,
938999 ):
9391000 return self .fail (
940- f"Constant value mismatch: expected { pattern_constant . _value } , got { constant_value_numpy .item ()} ." ,
1001+ f"Constant value mismatch: expected { pattern_constant_value } , got { constant_value_numpy .item ()} ." ,
9411002 )
9421003
9431004 return True
@@ -1079,11 +1140,6 @@ def _match_single_output_node(
10791140 if not _valid_to_replace (match .nodes , output_values ):
10801141 return match .fail ("Matched nodes have other uses preventing replacement." )
10811142
1082- if len (node .outputs ) != pattern .num_outputs :
1083- return match .fail (
1084- f"Number of node outputs mismatch: expected { pattern .num_outputs } , got { len (node .outputs )} ."
1085- )
1086-
10871143 match .outputs .extend (output_values )
10881144 return match
10891145
0 commit comments