|
5 | 5 |
|
6 | 6 | from __future__ import annotations
|
7 | 7 |
|
| 8 | +__all__ = [ |
| 9 | + "basic_constant_propagation", |
| 10 | + "fold_constants", |
| 11 | + "FoldConstantsPass", |
| 12 | + "FOLDED_FROM_KEY", |
| 13 | +] |
| 14 | + |
8 | 15 | import dataclasses
|
9 | 16 | import logging
|
10 | 17 | import math
|
|
23 | 30 |
|
24 | 31 | DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 512 * 512
|
25 | 32 |
|
| 33 | +# Key used to store the metadata |
| 34 | +FOLDED_FROM_KEY = "pkg.onnxscript.optimizer.folded_from" |
| 35 | + |
26 | 36 |
|
27 | 37 | _NON_DETERMINISTIC_OPS = frozenset(
|
28 | 38 | {
|
@@ -914,6 +924,24 @@ def merge_dims(dim1, dim2):
|
914 | 924 | return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
|
915 | 925 |
|
916 | 926 |
|
| 927 | +def _record_contributing_values(original_node: ir.Node, replacement: Replacement) -> None: |
| 928 | + """Record the set of original input values that contributed to the constant-folded outputs.""" |
| 929 | + folded_from: set[str] = set() |
| 930 | + for input in original_node.inputs: |
| 931 | + if input is None: |
| 932 | + continue |
| 933 | + folded_from.update(input.meta.get(FOLDED_FROM_KEY, set())) |
| 934 | + assert input.name is not None |
| 935 | + folded_from.add(input.name) |
| 936 | + |
| 937 | + for new_output in replacement.new_outputs: |
| 938 | + if new_output is None: |
| 939 | + continue |
| 940 | + new_output.meta[FOLDED_FROM_KEY] = folded_from |
| 941 | + # Store the string representation of the set to metadata_props to persist it across serialization |
| 942 | + new_output.metadata_props[FOLDED_FROM_KEY] = repr(sorted(folded_from)) |
| 943 | + |
| 944 | + |
917 | 945 | class FoldConstantsPass(ir.passes.InPlacePass):
|
918 | 946 | """A pass that folds constant expressions in the model.
|
919 | 947 |
|
@@ -1203,9 +1231,14 @@ def convert(av):
|
1203 | 1231 | )
|
1204 | 1232 | return None
|
1205 | 1233 |
|
1206 |
| - def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None: |
| 1234 | + def replace_node( |
| 1235 | + self, node: ir.Node, replacement: Replacement, root: ir.Graph | ir.Function |
| 1236 | + ) -> None: |
1207 | 1237 | logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)
|
1208 | 1238 |
|
| 1239 | + # Record the names of the values that has contributed to the replacement |
| 1240 | + _record_contributing_values(node, replacement) |
| 1241 | + |
1209 | 1242 | ir.convenience.replace_nodes_and_values(
|
1210 | 1243 | root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
|
1211 | 1244 | )
|
|
0 commit comments