|
2 | 2 | # Licensed under the MIT License. |
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -import logging |
6 | | -from typing import Any |
7 | | - |
8 | 5 | import onnx |
9 | | -import onnx.shape_inference |
10 | | - |
11 | | -from onnxscript import ir, rewriter |
12 | | -from onnxscript.optimizer import _constant_folding, _inliner |
13 | | -from onnxscript.optimizer.constant_folding import fold_constants |
14 | | -from onnxscript.optimizer.remove_unused import remove_unused_nodes |
15 | | -from onnxscript.optimizer.remove_unused_function import remove_unused_functions |
16 | | -from onnxscript.optimizer.simple_function_folding import ( |
17 | | - inline_functions_with_unused_outputs, |
18 | | - inline_simple_functions, |
19 | | -) |
20 | | -from onnxscript.rewriter import ( |
21 | | - broadcast_to_matmul, |
22 | | - cast_constant_of_shape, |
23 | | - gemm_to_matmul_add, |
24 | | - no_op, |
25 | | -) |
26 | | - |
logger = logging.getLogger(__name__)

# Pattern-rewrite rules passed as ``pattern_rewrite_rules`` to
# ``rewriter.rewrite`` by the optimization entry points in this module.
_DEFAULT_REWRITE_RULES = [
    *no_op.rules.rules,  # TODO: merge this rule into constant folding?
    *broadcast_to_matmul.rules.rules,
    gemm_to_matmul_add.rule,
    *cast_constant_of_shape.rules.rules,
]
35 | | - |
36 | | - |
def optimize(
    model: onnx.ModelProto,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    external_data_folder: str = "",
    **kwargs: Any,
) -> onnx.ModelProto:
    """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc.

    Args:
        model (onnx.ModelProto): The model to optimize.
        num_iterations (int, optional): Maximum number of optimization iterations to perform.
        onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model.
            Set this to False to turn off onnx shape inference, and rely on model carried shapes and types.
            This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries
            the symbolic shapes recorded from dynamo tracing.
        stop_if_no_change (bool, optional): Whether to stop early once constant folding
            reports no modification in an iteration.
        external_data_folder (str, optional): The folder to store external data.
        **kwargs: Additional keyword arguments. For BC purposes. Only the deprecated
            ``function_aware_folding`` is inspected (it triggers a warning); any other
            keyword is ignored.

    Returns:
        The optimized model. Several passes below are called only for their effect on
        the proto they receive, so the caller's ``model`` object may be modified too.
    """
    if kwargs.pop("function_aware_folding", None) is not None:
        logger.warning(
            "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. "
            "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. "
            "This would turn off incremental onnx shape inference and rely on model carried shapes and types. "
            "See 'onnx_shape_inference' for more details."
        )
    for iteration in range(num_iterations):
        if onnx_shape_inference:
            # 2 * 1024**3 bytes; presumably the protobuf 2 GiB message limit -- TODO confirm.
            if model.ByteSize() < 1024 * 1024 * 1024 * 2:
                # NOTE: strict mode is disabled because it crashes on the models
                # that have different shapes inferred from the model carried shapes.
                # The case can be found in:
                # https://github.com/microsoft/onnxscript/issues/1443
                model = onnx.shape_inference.infer_shapes(
                    model, check_type=True, strict_mode=False, data_prop=True
                )
            else:
                logger.warning(
                    "The model size is too large for full model shape inference. "
                    "Skipping this step."
                )

        inline_simple_functions(model)
        modified = fold_constants(
            model, external_data_folder, onnx_shape_inference=onnx_shape_inference
        )

        remove_unused_nodes(model)
        inline_simple_functions(model)
        model = remove_unused_functions(model)
        inline_functions_with_unused_outputs(model)
        # NOTE: These are the general rewrite rules.
        model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
        if stop_if_no_change and not modified:
            # ``iteration`` is zero-based, so the number of completed passes is one more.
            logger.debug("Stopping after %d iterations.", iteration + 1)
            break

    for node in model.graph.node:
        logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name)

    for function in model.functions:
        for node in function.node:
            logger.debug(
                "Function %s::%s node %s::%s name %s.",
                function.domain,
                function.name,
                node.domain,
                node.op_type,
                node.name,
            )

    return model
112 | | - |
113 | | - |
# Constant-folding size limits, bound here from ``_constant_folding`` under
# the same names.
_DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = _constant_folding._DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT

_DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = _constant_folding._DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT
121 | | - |
122 | 6 |
|
# optimize_ir: optimization entry point for an ir.Model (the rows marked "-").
# As written below, it calls _inliner.inline(model) once, then repeats
# num_iterations times: _constant_folding.fold_constants -> rewriter.rewrite with
# _DEFAULT_REWRITE_RULES -> remove_unused_nodes. It is annotated "-> None" and has
# no return statement. stop_if_no_change is deleted on entry and has no effect.
# NOTE(review): the "+import ..." rows in the middle of the docstring are unrelated
# added lines interleaved by the diff rendering; they are not part of this function.
123 | | -def optimize_ir( |
124 | | - model: ir.Model, |
125 | | - num_iterations: int = 2, |
126 | | - *, |
127 | | - onnx_shape_inference: bool = True, |
128 | | - stop_if_no_change: bool = True, |
129 | | - input_size_limit: int = _DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, |
130 | | - output_size_limit: int = _DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, |
131 | | -) -> None: |
132 | | - """Optimizes a model. |
| 7 | +import onnxscript.optimizer._legacy._optimizer as legacy_optimizer |
| 8 | +from onnxscript import ir |
| 9 | +from onnxscript.optimizer._constant_folding import basic_constant_propagation |
| 10 | +from onnxscript.optimizer._legacy.constant_folding import fold_constants |
| 11 | +from onnxscript.optimizer._optimizer import optimize_ir |
| 12 | +from onnxscript.optimizer._remove_unused import remove_unused_nodes |
133 | 13 |

134 | | - Args: |
135 | | - model: The model to be optimized. |
136 | | - num_iterations: Number of times the optimization loop is repeated. |
137 | | - onnx_shape_inference: Applies node-level shape-inference as part of optimization |
138 | | - input_size_limit: Will not apply constant folding to ops with any input of size |
139 | | - greater than this. Does not apply to special ops like Shape() and Size(). |
140 | | - output_size_limit: Will not rewrite any foldable-op into a Constant op if the size |
141 | | - of the output tensor is greater than this. |
142 | | - stop_if_no_change: Not supported currently (has no effect). Meant to stop the |
143 | | - outer optimization loop if no change is detected in one iteration. |
144 | | - """ |
145 | | - del stop_if_no_change # Looks like rewriter doesn't support this yet. |
146 | | - _inliner.inline(model) |
147 | | - for _ in range(num_iterations): |
148 | | - _constant_folding.fold_constants( |
149 | | - model, |
150 | | - onnx_shape_inference=onnx_shape_inference, |
151 | | - input_size_limit=input_size_limit, |
152 | | - output_size_limit=output_size_limit, |
153 | | - ) |
154 | | - rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES) |
155 | | - remove_unused_nodes(model) |
156 | 14 |
|
def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
    """Optimize ``model``, dispatching on its representation.

    An ``ir.Model`` is handed to ``optimize_ir``; any other object is handed to
    ``legacy_optimizer.optimize``. Positional and keyword arguments are forwarded
    unchanged, and whatever the selected optimizer returns is returned as-is.
    """
    run = optimize_ir if isinstance(model, ir.Model) else legacy_optimizer.optimize
    return run(model, *args, **kwargs)
157 | 20 |
|
# Module-level alias for _constant_folding.basic_constant_propagation (a "-" row);
# the "+from onnxscript.optimizer._constant_folding import basic_constant_propagation"
# row binds the same name on the added side.
158 | | -basic_constant_propagation = _constant_folding.basic_constant_propagation
159 | 21 |
|
160 | 22 | __all__ = [ |
161 | 23 | "fold_constants", |
|
0 commit comments