@@ -3,6 +3,7 @@
 from __future__ import annotations

 import logging
+from typing import Callable

 import onnx_ir as ir
 import onnx_ir.passes.common as common_passes
@@ -21,6 +22,7 @@ def optimize_ir(
     stop_if_no_change: bool = True,
     input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
     output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+    should_fold: Callable[[ir.Node], bool | None] = lambda node: None,
     inline: bool = True,
 ) -> None:
     """Optimizes a model.
@@ -29,11 +31,15 @@ def optimize_ir(
         model: The model to be optimized.
         num_iterations: Number of times the optimization loop is repeated.
         onnx_shape_inference: Applies node-level shape-inference as part of optimization
+        stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
         input_size_limit: Will not apply constant folding to ops with any input of size
             greater than this. Does not apply to special ops like Shape() and Size().
         output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
             of the output tensor is greater than this.
-        stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
+        should_fold: An optional function that takes a node and decides whether it
+            should be considered for constant folding.
+            Return True or False to force or skip folding of that particular node,
+            or None to fall back to the default folding rules.
         inline: If True, inlines all functions in the model.
     """
     passes = [
@@ -43,6 +49,7 @@ def optimize_ir(
             shape_inference=onnx_shape_inference,
             input_size_limit=input_size_limit,
             output_size_limit=output_size_limit,
+            should_fold=should_fold,
         ),
         rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES),
         common_passes.RemoveUnusedNodesPass(),
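
A minimal usage sketch of the new should_fold hook, not part of the diff above. The import path onnxscript.optimizer, the file names, and the per-op policy are assumptions for illustration; only optimize_ir and its should_fold keyword come from this change.

# Hedged example of the should_fold callback added in this change.
import onnx_ir as ir

from onnxscript import optimizer  # assumed module exposing optimize_ir


def keep_quantization_nodes(node: ir.Node) -> bool | None:
    # Illustrative policy: never fold quantize/dequantize nodes,
    # and return None to defer to the default folding rules otherwise.
    if node.op_type in {"QuantizeLinear", "DequantizeLinear"}:
        return False
    return None


model = ir.load("model.onnx")  # hypothetical input file
optimizer.optimize_ir(model, should_fold=keep_quantization_nodes)
ir.save(model, "model_optimized.onnx")  # hypothetical output file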