Skip to content

Commit 929a7f2

Browse files
authored
Expose the should_fold option to optimize() (#2594)
Signed-off-by: Justin Chu <[email protected]>
1 parent 149d567 commit 929a7f2

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

onnxscript/optimizer/_optimizer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import logging
6+
from typing import Callable
67

78
import onnx_ir as ir
89
import onnx_ir.passes.common as common_passes
@@ -21,6 +22,7 @@ def optimize_ir(
2122
stop_if_no_change: bool = True,
2223
input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
2324
output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
25+
should_fold: Callable[[ir.Node], bool | None] = lambda node: None,
2426
inline: bool = True,
2527
) -> None:
2628
"""Optimizes a model.
@@ -29,11 +31,15 @@ def optimize_ir(
2931
model: The model to be optimized.
3032
num_iterations: Number of times the optimization loop is repeated.
3133
onnx_shape_inference: Applies node-level shape-inference as part of optimization
34+
stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
3235
input_size_limit: Will not apply constant folding to ops with any input of size
3336
greater than this. Does not apply to special ops like Shape() and Size().
3437
output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
3538
of the output tensor is greater than this.
36-
stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
39+
should_fold: An optional function that takes a node and returns True if
40+
the node should be considered for folding.
41+
The function should return True/False value to indicate if this particular
42+
node should be folded, or None to use the default folding rules.
3743
inline: If True, inlines all functions in the model.
3844
"""
3945
passes = [
@@ -43,6 +49,7 @@ def optimize_ir(
4349
shape_inference=onnx_shape_inference,
4450
input_size_limit=input_size_limit,
4551
output_size_limit=output_size_limit,
52+
should_fold=should_fold,
4653
),
4754
rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES),
4855
common_passes.RemoveUnusedNodesPass(),

0 commit comments

Comments
 (0)