Skip to content

Commit 94fb24f

Browse files
authored
Record names of contributing values in the constant folding pass (#2575)
Record names of contributing values in the constant folding pass to the newly created output as metadata, so that downstream users like Olive can use the info for further manipulations. This is useful for Olive to identify transposed lora weights in the graph. --------- Signed-off-by: Justin Chu <[email protected]>
1 parent df8f706 commit 94fb24f

File tree

3 files changed

+36
-8
lines changed

3 files changed

+36
-8
lines changed

docs/api/optimizer.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,4 @@
1515
optimizer.inline
1616
optimizer.basic_constant_propagation
1717
optimizer.fold_constants
18-
optimizer.remove_unused_nodes
1918
```

onnxscript/optimizer/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,8 @@
1919

2020
import onnxscript.optimizer._constant_folding as constant_folding
2121
from onnxscript import ir
22-
from onnxscript.optimizer._constant_folding import (
23-
basic_constant_propagation,
24-
)
25-
from onnxscript.optimizer._constant_folding import (
26-
fold_constants as fold_constants_ir,
27-
)
22+
from onnxscript.optimizer._constant_folding import basic_constant_propagation
23+
from onnxscript.optimizer._constant_folding import fold_constants as fold_constants_ir
2824
from onnxscript.optimizer._optimizer import optimize_ir
2925

3026
_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)

onnxscript/optimizer/_constant_folding.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55

66
from __future__ import annotations
77

8+
__all__ = [
9+
"basic_constant_propagation",
10+
"fold_constants",
11+
"FoldConstantsPass",
12+
"FOLDED_FROM_KEY",
13+
]
14+
815
import dataclasses
916
import logging
1017
import math
@@ -23,6 +30,9 @@
2330

2431
DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 512 * 512
2532

33+
# Key used to store the metadata
34+
FOLDED_FROM_KEY = "pkg.onnxscript.optimizer.folded_from"
35+
2636

2737
_NON_DETERMINISTIC_OPS = frozenset(
2838
{
@@ -914,6 +924,24 @@ def merge_dims(dim1, dim2):
914924
return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
915925

916926

927+
def _record_contributing_values(original_node: ir.Node, replacement: Replacement) -> None:
928+
"""Record the set of original input values that contributed to the constant-folded outputs."""
929+
folded_from: set[str] = set()
930+
for input in original_node.inputs:
931+
if input is None:
932+
continue
933+
folded_from.update(input.meta.get(FOLDED_FROM_KEY, set()))
934+
assert input.name is not None
935+
folded_from.add(input.name)
936+
937+
for new_output in replacement.new_outputs:
938+
if new_output is None:
939+
continue
940+
new_output.meta[FOLDED_FROM_KEY] = folded_from
941+
# Store the string representation of the set to metadata_props to persist it across serialization
942+
new_output.metadata_props[FOLDED_FROM_KEY] = repr(sorted(folded_from))
943+
944+
917945
class FoldConstantsPass(ir.passes.InPlacePass):
918946
"""A pass that folds constant expressions in the model.
919947
@@ -1203,9 +1231,14 @@ def convert(av):
12031231
)
12041232
return None
12051233

1206-
def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None:
1234+
def replace_node(
1235+
self, node: ir.Node, replacement: Replacement, root: ir.Graph | ir.Function
1236+
) -> None:
12071237
logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)
12081238

1239+
# Record the names of the values that has contributed to the replacement
1240+
_record_contributing_values(node, replacement)
1241+
12091242
ir.convenience.replace_nodes_and_values(
12101243
root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
12111244
)

0 commit comments

Comments
 (0)