Skip to content

Commit 4578142

Browse files
authored
Cleanup optimizer (#1904)
Cleanup optimizer by moving older proto-based optimizations into a _legacy folder, renaming files to distinguish internal implementation files, and other minor restructuring.
1 parent 8fef233 commit 4578142

20 files changed

225 additions and 464 deletions (+225 / -464 lines changed)

.lintrunner.toml

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -46,12 +46,11 @@ exclude_patterns = [
4646
'onnxscript/onnx_types.py',
4747
'onnxscript/**/*_test.py', # Skip linting test files for speed
4848
'onnxscript/function_libs/torch_lib/ops/**', # Operators typing do not play well with mypy
49-
'onnxscript/optimizer/evaluator.py', # FIXME
50-
'onnxscript/optimizer/constant_folding.py', # FIXME
49+
'onnxscript/optimizer/_legacy/evaluator.py', # FIXME
50+
'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME
5151
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
5252
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
5353
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
54-
'onnxscript/optimizer/fold_constants_v0.py', # FIXME
5554
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
5655
'onnxscript/tools/function_unittest_producer.py', # FIXME
5756
'onnxscript/_legacy_ir/visitor.py', # FIXME

onnxscript/optimizer/__init__.py

Lines changed: 11 additions & 149 deletions
Original file line number | Diff line number | Diff line change
@@ -2,160 +2,22 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5-
import logging
6-
from typing import Any
7-
85
import onnx
9-
import onnx.shape_inference
10-
11-
from onnxscript import ir, rewriter
12-
from onnxscript.optimizer import _constant_folding, _inliner
13-
from onnxscript.optimizer.constant_folding import fold_constants
14-
from onnxscript.optimizer.remove_unused import remove_unused_nodes
15-
from onnxscript.optimizer.remove_unused_function import remove_unused_functions
16-
from onnxscript.optimizer.simple_function_folding import (
17-
inline_functions_with_unused_outputs,
18-
inline_simple_functions,
19-
)
20-
from onnxscript.rewriter import (
21-
broadcast_to_matmul,
22-
cast_constant_of_shape,
23-
gemm_to_matmul_add,
24-
no_op,
25-
)
26-
27-
logger = logging.getLogger(__name__)
28-
29-
_DEFAULT_REWRITE_RULES = [
30-
*no_op.rules.rules, # TODO: merge this rule into constant folding?
31-
*broadcast_to_matmul.rules.rules,
32-
gemm_to_matmul_add.rule,
33-
*cast_constant_of_shape.rules.rules,
34-
]
35-
36-
37-
def optimize(
38-
model: onnx.ModelProto,
39-
num_iterations: int = 2,
40-
*,
41-
onnx_shape_inference: bool = True,
42-
stop_if_no_change: bool = True,
43-
external_data_folder: str = "",
44-
**kwargs: Any,
45-
) -> onnx.ModelProto:
46-
"""Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc.
47-
48-
Args:
49-
model (onnx.ModelProto): The model to optimize.
50-
num_iterations (int, optional): Number of iterations to perform.
51-
onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model.
52-
Set this to False to turn off onnx shape inference, and rely on model carried shapes and types.
53-
This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries
54-
the symbolic shapes recorded from dynamo tracing.
55-
stop_if_no_change (bool, optional): Whether to stop if no change is detected.
56-
external_data_folder (str, optional): The folder to store external data.
57-
**kwargs: Additional keyword arguments. For BC purposes.
58-
"""
59-
if kwargs.pop("function_aware_folding", None) is not None:
60-
logger.warning(
61-
"'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. "
62-
"To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. "
63-
"This would turn off incremental onnx shape inference and rely on model carried shapes and types. "
64-
"See 'onnx_shape_inference' for more details."
65-
)
66-
for _ in range(num_iterations):
67-
if onnx_shape_inference:
68-
if model.ByteSize() < 1024 * 1024 * 1024 * 2:
69-
# NOTE: strict mode is disabled because it crashes on the models
70-
# that have different shapes inferred from the model carried shapes.
71-
# The case can be found in:
72-
# https://github.com/microsoft/onnxscript/issues/1443
73-
model = onnx.shape_inference.infer_shapes(
74-
model, check_type=True, strict_mode=False, data_prop=True
75-
)
76-
else:
77-
logger.warning(
78-
"The model size is too large for full model shape inference. "
79-
"Skipping this step."
80-
)
81-
82-
inline_simple_functions(model)
83-
modified = fold_constants(
84-
model, external_data_folder, onnx_shape_inference=onnx_shape_inference
85-
)
86-
87-
remove_unused_nodes(model)
88-
inline_simple_functions(model)
89-
model = remove_unused_functions(model)
90-
inline_functions_with_unused_outputs(model)
91-
# NOTE: This is general rewrite rules
92-
model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
93-
if stop_if_no_change and not modified:
94-
logger.debug("Stopping after %d iterations.", _)
95-
break
96-
97-
for node in model.graph.node:
98-
logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name)
99-
100-
for function in model.functions:
101-
for node in function.node:
102-
logger.debug(
103-
"Function %s::%s node %s::%s name %s.",
104-
function.domain,
105-
function.name,
106-
node.domain,
107-
node.op_type,
108-
node.name,
109-
)
110-
111-
return model
112-
113-
114-
_DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = (
115-
_constant_folding._DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT
116-
)
117-
118-
_DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = (
119-
_constant_folding._DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT
120-
)
121-
1226

123-
def optimize_ir(
124-
model: ir.Model,
125-
num_iterations: int = 2,
126-
*,
127-
onnx_shape_inference: bool = True,
128-
stop_if_no_change: bool = True,
129-
input_size_limit: int = _DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
130-
output_size_limit: int = _DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
131-
) -> None:
132-
"""Optimizes a model.
7+
import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
8+
from onnxscript import ir
9+
from onnxscript.optimizer._constant_folding import basic_constant_propagation
10+
from onnxscript.optimizer._legacy.constant_folding import fold_constants
11+
from onnxscript.optimizer._optimizer import optimize_ir
12+
from onnxscript.optimizer._remove_unused import remove_unused_nodes
13313

134-
Args:
135-
model: The model to be optimized.
136-
num_iterations: Number of times the optimization loop is repeated.
137-
onnx_shape_inference: Applies node-level shape-inference as part of optimization
138-
input_size_limit: Will not apply constant folding to ops with any input of size
139-
greater than this. Does not apply to special ops like Shape() and Size().
140-
output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
141-
of the output tensor is greater than this.
142-
stop_if_no_change: Not supported currently (has no effect). Meant to stop the
143-
outer optimization loop if no change is detected in one iteration.
144-
"""
145-
del stop_if_no_change # Looks like rewriter doesn't support this yet.
146-
_inliner.inline(model)
147-
for _ in range(num_iterations):
148-
_constant_folding.fold_constants(
149-
model,
150-
onnx_shape_inference=onnx_shape_inference,
151-
input_size_limit=input_size_limit,
152-
output_size_limit=output_size_limit,
153-
)
154-
rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
155-
remove_unused_nodes(model)
15614

15+
def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
16+
if isinstance(model, ir.Model):
17+
return optimize_ir(model, *args, **kwargs)
18+
else:
19+
return legacy_optimizer.optimize(model, *args, **kwargs)
15720

158-
basic_constant_propagation = _constant_folding.basic_constant_propagation
15921

16022
__all__ = [
16123
"fold_constants",

onnxscript/optimizer/_constant_folding.py

Lines changed: 18 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -17,20 +17,32 @@
1717

1818
import onnxscript.ir as ir
1919
import onnxscript.ir._convenience as _convenience
20-
import onnxscript.optimizer.constant_folding as constant_folding
2120
import onnxscript.rewriter.pattern as orp
2221
import onnxscript.utils.utils as utils
2322

23+
DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024
24+
25+
DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 1024 * 1024
26+
2427

2528
def is_control_flow_op(node: ir.Node) -> bool:
2629
graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
2730
return any(attr.type in graph_types for attr in node.attributes.values())
2831

2932

33+
non_deterministic_ops = frozenset(
34+
{
35+
"RandomUniform",
36+
"RandomNormal",
37+
"RandomUniformLike",
38+
"RandomNormalLike",
39+
"Multinomial",
40+
}
41+
)
42+
43+
3044
def is_non_deterministic_op(node: ir.Node) -> bool:
31-
return node.op_type in constant_folding.non_deterministic_ops and utils.is_onnx_domain(
32-
node.domain
33-
)
45+
return node.op_type in non_deterministic_ops and utils.is_onnx_domain(node.domain)
3446

3547

3648
def is_onnx_op(node: ir.Node, op_type: str) -> bool:
@@ -43,10 +55,6 @@ def is_constant_op(node: ir.Node) -> bool:
4355
)
4456

4557

46-
_DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024
47-
48-
_DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT
49-
5058
logger = logging.getLogger(__name__)
5159

5260
# "Standard" evaluators are used to perform constant-folding.
@@ -787,8 +795,8 @@ def fold_constants(
787795
external_data_folder: str = "",
788796
*,
789797
onnx_shape_inference: bool = False,
790-
input_size_limit: int = _DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
791-
output_size_limit: int = _DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
798+
input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
799+
output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
792800
) -> bool:
793801
"""
794802
Applies constant folding optimization to the model.

onnxscript/optimizer/constant_folding_test.py renamed to onnxscript/optimizer/_constant_folding_test.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,8 @@
88

99
import onnxscript.optimizer as optimizer
1010
from onnxscript.ir import serde
11-
from onnxscript.optimizer import _constant_folding, constant_folding
11+
from onnxscript.optimizer import _constant_folding
12+
from onnxscript.optimizer._legacy import constant_folding
1213

1314

1415
@parameterized.parameterized_class(("using_ir",), [(False,), (True,)])
onnxscript/optimizer/_legacy/_optimizer.py (new file)

Lines changed: 98 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import logging
6+
from typing import Any
7+
8+
import onnx
9+
import onnx.shape_inference
10+
11+
from onnxscript import rewriter
12+
from onnxscript.optimizer._legacy._simple_function_folding import (
13+
inline_functions_with_unused_outputs,
14+
inline_simple_functions,
15+
)
16+
from onnxscript.optimizer._legacy.constant_folding import fold_constants
17+
from onnxscript.optimizer._optimizer import _DEFAULT_REWRITE_RULES
18+
from onnxscript.optimizer._remove_unused import remove_unused_nodes
19+
from onnxscript.optimizer._remove_unused_function import remove_unused_functions
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
def optimize(
25+
model: onnx.ModelProto,
26+
num_iterations: int = 2,
27+
*,
28+
onnx_shape_inference: bool = True,
29+
stop_if_no_change: bool = True,
30+
external_data_folder: str = "",
31+
**kwargs: Any,
32+
) -> onnx.ModelProto:
33+
"""Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc.
34+
35+
Args:
36+
model (onnx.ModelProto): The model to optimize.
37+
num_iterations (int, optional): Number of iterations to perform.
38+
onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model.
39+
Set this to False to turn off onnx shape inference, and rely on model carried shapes and types.
40+
This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries
41+
the symbolic shapes recorded from dynamo tracing.
42+
stop_if_no_change (bool, optional): Whether to stop if no change is detected.
43+
external_data_folder (str, optional): The folder to store external data.
44+
**kwargs: Additional keyword arguments. For BC purposes.
45+
"""
46+
if kwargs.pop("function_aware_folding", None) is not None:
47+
logger.warning(
48+
"'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. "
49+
"To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. "
50+
"This would turn off incremental onnx shape inference and rely on model carried shapes and types. "
51+
"See 'onnx_shape_inference' for more details."
52+
)
53+
for _ in range(num_iterations):
54+
if onnx_shape_inference:
55+
if model.ByteSize() < 1024 * 1024 * 1024 * 2:
56+
# NOTE: strict mode is disabled because it crashes on the models
57+
# that have different shapes inferred from the model carried shapes.
58+
# The case can be found in:
59+
# https://github.com/microsoft/onnxscript/issues/1443
60+
model = onnx.shape_inference.infer_shapes(
61+
model, check_type=True, strict_mode=False, data_prop=True
62+
)
63+
else:
64+
logger.warning(
65+
"The model size is too large for full model shape inference. "
66+
"Skipping this step."
67+
)
68+
69+
inline_simple_functions(model)
70+
modified = fold_constants(
71+
model, external_data_folder, onnx_shape_inference=onnx_shape_inference
72+
)
73+
74+
remove_unused_nodes(model)
75+
inline_simple_functions(model)
76+
model = remove_unused_functions(model)
77+
inline_functions_with_unused_outputs(model)
78+
# NOTE: This is general rewrite rules
79+
model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
80+
if stop_if_no_change and not modified:
81+
logger.debug("Stopping after %d iterations.", _)
82+
break
83+
84+
for node in model.graph.node:
85+
logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name)
86+
87+
for function in model.functions:
88+
for node in function.node:
89+
logger.debug(
90+
"Function %s::%s node %s::%s name %s.",
91+
function.domain,
92+
function.name,
93+
node.domain,
94+
node.op_type,
95+
node.name,
96+
)
97+
98+
return model

onnxscript/optimizer/simple_function_folding.py renamed to onnxscript/optimizer/_legacy/_simple_function_folding.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111

1212
import onnxscript._legacy_ir as ir
1313
from onnxscript._legacy_ir import visitor
14-
from onnxscript.optimizer import remove_unused_proto
14+
from onnxscript.optimizer._legacy import _remove_unused_proto
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -168,7 +168,7 @@ def _find_nodes_with_any_unused_output(
168168
# All unused output means the node is not used at all.
169169
# Hence do not update used_values with the node's inputs.
170170
continue
171-
used_values |= remove_unused_proto.compute_used_in_node(node)
171+
used_values |= _remove_unused_proto.compute_used_in_node(node)
172172
return target_nodes
173173

174174
def visit_model(self, model: onnx.ModelProto) -> None:

0 commit comments

Comments
 (0)