@@ -43,7 +43,9 @@ def is_constant_op(node: ir.Node) -> bool:
     )
 
 
-_DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT
+_DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024
+
+_DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT
 
 logger = logging.getLogger(__name__)
 
@@ -550,11 +552,16 @@ class ConstantFolder:
 
     def __init__(
         self,
+        *,
         external_data_folder: str,
-        do_shape_inference: bool,
+        shape_inference: bool,
+        input_size_limit: int,
+        output_size_limit: int,
     ) -> None:
         self._external_data_folder = external_data_folder
-        self._do_shape_inference = do_shape_inference
+        self._shape_inference = shape_inference
+        self._input_size_limit = input_size_limit
+        self._output_size_limit = output_size_limit
         self._init()
 
     def _init(self) -> None:
@@ -632,7 +639,7 @@ def new_constant(self, irvalue: ir.Value, value):
 
         irvalue.const_value = _convenience.tensor(value)
 
-        if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT:
+        if value.nbytes > self._output_size_limit:
             logger.info(
                 "Skip storing constant folded nvalue %s due to large size %s.",
                 irvalue.name,
@@ -667,7 +674,7 @@ def process_node(self, node: ir.Node):
         # TODO(rama): consider merging type/other info from both values
 
         # Do incremental shape inference
-        if self._do_shape_inference and not is_control_flow_op(node):
+        if self._shape_inference and not is_control_flow_op(node):
             self._do_inference(node)
 
         if node.domain not in self.opset_imports:
673680 if node .domain not in self .opset_imports :
@@ -696,6 +703,14 @@ def process_node(self, node: ir.Node):
         if any(x is None for x in input_values):
             return None
 
+        if any(input.size > self._input_size_limit for input in input_values):
+            if logger.isEnabledFor(logging.DEBUG):
+                input_sizes = [input.size for input in input_values]
+                logger.debug(
+                    f"Skipping constant folding for op {node.op_type} due to large input size: {input_sizes}"
+                )
+            return None
+
         # Filter out bfloat16 cases?
         def convert(av):
             if av.type == ir.AttributeType.TENSOR:
@@ -770,14 +785,18 @@ def fold_constants(
     external_data_folder: str = "",
     *,
     onnx_shape_inference: bool = False,
+    input_size_limit: int = _DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
+    output_size_limit: int = _DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
 ) -> bool:
     """
     Applies constant folding optimization to the model.
     Returns true iff the model was modified.
     """
     folder = ConstantFolder(
-        external_data_folder,
-        onnx_shape_inference,
+        external_data_folder=external_data_folder,
+        shape_inference=onnx_shape_inference,
+        input_size_limit=input_size_limit,
+        output_size_limit=output_size_limit,
     )
     folder.visit_model(model)
     for op in folder.counts:
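
Taken together, the hunks add two user-facing knobs to `fold_constants`: `input_size_limit`, compared against each input's `.size` (element count), and `output_size_limit`, compared against the folded tensor's `.nbytes` (bytes). Note the asymmetry in units, mirroring the `input.size` and `value.nbytes` checks above. A minimal usage sketch follows; the import paths, the IR (de)serialization helpers, and the file names are assumptions about the surrounding onnxscript package, not something this diff establishes:

```python
# Hypothetical driver code exercising the new keyword-only limits.
# Import paths and file names are assumptions, not part of this diff.
import onnx
from onnxscript import ir
from onnxscript.optimizer._constant_folding import fold_constants

proto = onnx.load("model.onnx")            # placeholder input file
model = ir.serde.deserialize_model(proto)  # convert to the IR the folder walks

modified = fold_constants(
    model,
    onnx_shape_inference=True,
    input_size_limit=4096,          # skip folding nodes with inputs > 4096 elements
    output_size_limit=1024 * 1024,  # don't store folded constants > 1 MiB
)
if modified:
    onnx.save(ir.serde.serialize_model(model), "model_folded.onnx")
```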