@@ -6,7 +6,7 @@
 import collections
 import gc
 import os
-from typing import Optional, Text, Union
+from typing import List, Optional, Text, Union

 from coremltools import (
     _LOWEST_ALLOWED_SPECIFICATION_VERSION_FOR_MILPROGRAM,
@@ -23,8 +23,11 @@
 from coremltools.converters.mil.converter import mil_convert
 from coremltools.converters.mil.input_types import (
     ClassifierConfig,
+    EnumeratedShapes,
     ImageType,
     InputType,
+    RangeDim,
+    Shape,
     TensorType,
 )
 from coremltools.converters.mil.mil import Program, types
@@ -395,7 +398,7 @@ def skip_real_div_ops(op):

          pipeline = ct.PassPipeline()
          pipeline.remove_passes({"common::fuse_conv_batchnorm"})
-         ct.convert(model, pass_pipeline=pipeline)
+         mlmodel = ct.convert(model, pass_pipeline=pipeline)

     * To avoid folding too-large ``const`` ops that lead to a large model, set pass option
       as shown in the following example:
@@ -404,7 +407,34 @@ def skip_real_div_ops(op):

          pipeline = ct.PassPipeline()
          pipeline.set_options("common::const_elimination", {"skip_const_by_size": "1e6"})
-         ct.convert(model, pass_pipeline=pipeline)
+         mlmodel = ct.convert(model, pass_pipeline=pipeline)
+
+    We also provide a set of predefined pass pipelines that you can use directly.
+
+    * To avoid running all graph passes, you can use:
+
+      .. sourcecode:: python
+
+         mlmodel = ct.convert(model, pass_pipeline=ct.PassPipeline.EMPTY)
+
+    * To run only the cleanup graph passes, such as constant_elimination,
+      dead_code_elimination, and so on, you can use:
+
+      .. sourcecode:: python
+
+         mlmodel = ct.convert(model, pass_pipeline=ct.PassPipeline.CLEANUP)
+
+    * To convert a source model with sparse weights to a sparse format Core ML model, you can use:
+
+      .. sourcecode:: python
+
+         mlmodel = ct.convert(model, pass_pipeline=ct.PassPipeline.DEFAULT_PRUNING)
+
+    * To convert a source model with palettized weights to a compressed format Core ML model, you can use:
+
+      .. sourcecode:: python
+
+         mlmodel = ct.convert(model, pass_pipeline=ct.PassPipeline.DEFAULT_PALETTIZATION)

     Returns
     -------
@@ -463,9 +493,17 @@ def skip_real_div_ops(op):
                                      outputs_as_tensor_or_image_types,
                                      outputs)
     exact_target = _determine_target(convert_to, minimum_deployment_target)
-    _validate_conversion_arguments(model, exact_source, inputs, outputs_as_tensor_or_image_types,
-                                   classifier_config, compute_precision,
-                                   exact_target, minimum_deployment_target)
+    _validate_conversion_arguments(
+        model,
+        exact_source,
+        exact_target,
+        inputs,
+        outputs_as_tensor_or_image_types,
+        classifier_config,
+        compute_precision,
+        exact_target,
+        minimum_deployment_target,
+    )

     if pass_pipeline is None:
         pass_pipeline = PassPipeline()
@@ -504,6 +542,12 @@ def skip_real_div_ops(op):
         main_pipeline=pass_pipeline,
     )

+    if exact_target == "mlprogram" and mlmodel._input_has_infinite_upper_bound():
+        raise ValueError(
+            "For mlprogram, inputs with infinite upper_bound is not allowed. Please set upper_bound"
+            ' to a positive value in "RangeDim()" for the "inputs" param in ct.convert().'
+        )
+
     if exact_target == 'milinternal':
         return mlmodel  # Returns the MIL program

@@ -539,7 +583,7 @@ def _need_fp16_cast_pass(
     raise ValueError(f"Invalid value of the argument 'compute_precision': {compute_precision}")


-def _set_default_specification_version(target):
+def _set_default_specification_version(target) -> Optional[AvailableTarget]:
     if target == "neuralnetwork":
         return _LOWEST_ALLOWED_SPECIFICATION_VERSION_FOR_NEURALNETWORK
     elif target == "mlprogram":
@@ -625,18 +669,20 @@ def _validate_outputs_argument(outputs):
     return output_names, outputs


-def _validate_conversion_arguments(model,
-                                   exact_source,
-                                   inputs,
-                                   outputs,
-                                   classifier_config,
-                                   compute_precision,
-                                   convert_to,
-                                   minimum_deployment_target,
-                                   ):
+def _validate_conversion_arguments(
+    model,
+    exact_source,
+    exact_target,
+    inputs,
+    outputs,
+    classifier_config,
+    compute_precision,
+    convert_to,
+    minimum_deployment_target,
+):
     """
     Validate and process model, inputs, classifier_config based on
-    `exact_source` (which cannot be `auto`)
+    `exact_source` (which cannot be `auto`) and `exact_target`.
     """

     def raise_if_duplicated(input_list):
@@ -672,10 +718,10 @@ def _flatten_list(_inputs):

     # get flattened inputs
     flat_inputs = _flatten_list(inputs)
-    for t in flat_inputs:
-        if not isinstance(t, InputType):
+    for flat_input in flat_inputs:
+        if not isinstance(flat_input, InputType):
             raise ValueError("inputs must be a list of type ct.TensorType or ct.ImageType")
-        if t.dtype == types.fp16:
+        if flat_input.dtype == types.fp16:
             if not (
                 minimum_deployment_target is not None
                 and minimum_deployment_target >= AvailableTarget.iOS16
@@ -685,6 +731,24 @@ def _flatten_list(_inputs):
                     "target >= iOS16/macOS13/watchOS9/tvOS16"
                 )

+    if exact_target == "mlprogram":
+        err_msg_infinite_bound = (
+            "For mlprogram, inputs with infinite upper_bound is not allowed. Please set upper_bound"
+            ' to a positive value in "RangeDim()" for the "inputs" param in ct.convert().'
+        )
+        if inputs is not None:
+            for flat_input in _flatten_list(inputs):
+                tensor_shapes: List[Optional[Shape]] = (
+                    flat_input.shape.shapes
+                    if isinstance(flat_input.shape, EnumeratedShapes)
+                    else [flat_input.shape]
+                )
+                for tensor_shape in tensor_shapes:
+                    if tensor_shape is not None:
+                        for shape in tensor_shape.shape:
+                            if isinstance(shape, RangeDim) and shape.upper_bound < 0:
+                                raise ValueError(err_msg_infinite_bound)
+
     if outputs is not None:
         for t in outputs:
             if t.dtype == types.fp16:
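
Note (not part of the diff): the following is a minimal usage sketch of the requirement this change enforces. It assumes a hypothetical traced PyTorch model named ``traced_model`` and an input named ``"image"``. With ``convert_to="mlprogram"``, every flexible dimension now needs a finite ``upper_bound`` in ``RangeDim``; leaving the default ``upper_bound=-1`` (unbounded) triggers the ValueError added in this commit.

.. sourcecode:: python

   import coremltools as ct

   # Flexible input: batch and channels fixed, height and width allowed to vary
   # between 64 and 512. The finite upper bounds satisfy the new mlprogram check.
   flexible_shape = ct.Shape(
       shape=(
           1,
           3,
           ct.RangeDim(lower_bound=64, upper_bound=512),
           ct.RangeDim(lower_bound=64, upper_bound=512),
       )
   )

   # "traced_model" stands in for a torch.jit.trace output (hypothetical here).
   mlmodel = ct.convert(
       traced_model,
       inputs=[ct.TensorType(name="image", shape=flexible_shape)],
       convert_to="mlprogram",
   )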