Commit cfd3816

Use input size limits for constant folding
1 parent ed28222 commit cfd3816

File tree

onnxscript/optimizer/__init__.py
onnxscript/optimizer/_constant_folding.py

2 files changed: +43 -8 lines changed

onnxscript/optimizer/__init__.py

Lines changed: 17 additions & 1 deletion
@@ -111,17 +111,33 @@ def optimize(
     return model
 
 
+_DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = (
+    _constant_folding._DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT
+)
+
+_DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = (
+    _constant_folding._DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT
+)
+
+
 def optimize_ir(
     model: ir.Model,
     num_iterations: int = 2,
     *,
     onnx_shape_inference: bool = True,
     stop_if_no_change: bool = True,
+    input_size_limit: int = _DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
+    output_size_limit: int = _DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
 ) -> None:
     del stop_if_no_change  # Looks like rewriter doesn't support this yet.
     _inliner.inline(model)
     for _ in range(num_iterations):
-        _constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference)
+        _constant_folding.fold_constants(
+            model,
+            onnx_shape_inference=onnx_shape_inference,
+            input_size_limit=input_size_limit,
+            output_size_limit=output_size_limit,
+        )
         rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
         remove_unused_nodes(model)
 
onnxscript/optimizer/_constant_folding.py

Lines changed: 26 additions & 7 deletions
@@ -43,7 +43,9 @@ def is_constant_op(node: ir.Node) -> bool:
     )
 
 
-_DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT
+_DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024
+
+_DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT
 
 logger = logging.getLogger(__name__)
 
@@ -550,11 +552,16 @@ class ConstantFolder:
 
     def __init__(
         self,
+        *,
         external_data_folder: str,
-        do_shape_inference: bool,
+        shape_inference: bool,
+        input_size_limit: int,
+        output_size_limit: int,
     ) -> None:
         self._external_data_folder = external_data_folder
-        self._do_shape_inference = do_shape_inference
+        self._shape_inference = shape_inference
+        self._input_size_limit = input_size_limit
+        self._output_size_limit = output_size_limit
         self._init()
 
     def _init(self) -> None:
@@ -632,7 +639,7 @@ def new_constant(self, irvalue: ir.Value, value):
 
         irvalue.const_value = _convenience.tensor(value)
 
-        if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT:
+        if value.nbytes > self._output_size_limit:
             logger.info(
                 "Skip storing constant folded nvalue %s due to large size %s.",
                 irvalue.name,
@@ -667,7 +674,7 @@ def process_node(self, node: ir.Node):
             # TODO(rama): consider merging type/other info from both values
 
         # Do incremental shape inference
-        if self._do_shape_inference and not is_control_flow_op(node):
+        if self._shape_inference and not is_control_flow_op(node):
             self._do_inference(node)
 
         if node.domain not in self.opset_imports:
@@ -696,6 +703,14 @@ def process_node(self, node: ir.Node):
         if any(x is None for x in input_values):
             return None
 
+        if any(input.size > self._input_size_limit for input in input_values):
+            if logger.isEnabledFor(logging.DEBUG):
+                input_sizes = [input.size for input in input_values]
+                logger.debug(
+                    f"Skipping constant folding for op {node.op_type} due to large input size: {input_sizes}"
+                )
+            return None
+
         # Filter out bfloat16 cases?
         def convert(av):
             if av.type == ir.AttributeType.TENSOR:
@@ -770,14 +785,18 @@ def fold_constants(
     external_data_folder: str = "",
     *,
     onnx_shape_inference: bool = False,
+    input_size_limit: int = _DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
+    output_size_limit: int = _DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
 ) -> bool:
     """
     Applies constant folding optimization to the model.
     Returns true iff the model was modified.
     """
     folder = ConstantFolder(
-        external_data_folder,
-        onnx_shape_inference,
+        external_data_folder=external_data_folder,
+        shape_inference=onnx_shape_inference,
+        input_size_limit=input_size_limit,
+        output_size_limit=output_size_limit,
     )
     folder.visit_model(model)
     for op in folder.counts:
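A similar hedged sketch for calling the lower-level helper directly. Note the asymmetry visible in the diff: the input limit is compared against each input's size (element count), while the output limit is compared against value.nbytes. The concrete limit values below are illustrative:

    from onnxscript.optimizer import _constant_folding

    # 'model' is an ir.Model, e.g. obtained as in the sketch above.
    modified = _constant_folding.fold_constants(
        model,
        onnx_shape_inference=False,
        input_size_limit=1024,          # default introduced by this commit
        output_size_limit=1024 * 1024,  # illustrative byte cap for stored outputs
    )
    if modified:
        print("constant folding changed the model")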
