1515from coremltools import ComputeUnit as _ComputeUnit
1616from coremltools import __version__ as _ct_version
1717from coremltools import _logger as logger
18- from coremltools ._deps import _HAS_TF_1 , _HAS_TF_2 , _HAS_TORCH
18+ from coremltools ._deps import _HAS_TF_1 , _HAS_TF_2 , _HAS_TORCH , _HAS_TORCH_EXPORT_API
1919from coremltools .converters ._profile_utils import _profile
2020from coremltools .converters .mil ._deployment_compatibility import (
2121 AvailableTarget ,
3636from coremltools .converters .mil .mil .passes .defs .quantization import FP16ComputePrecision
3737from coremltools .converters .mil .mil .passes .graph_pass import PassOption as _PassOption
3838from coremltools .converters .mil .mil .passes .pass_pipeline import PassPipeline
39- from coremltools .models import _METADATA_SOURCE , _METADATA_VERSION
39+ from coremltools .models import _METADATA_SOURCE , _METADATA_SOURCE_DIALECT , _METADATA_VERSION
4040from coremltools .models .utils import _MLPACKAGE_EXTENSION
4141
4242if _HAS_TF_1 :
5151if _HAS_TORCH :
5252 import torch
5353
54- from coremltools .converters .mil .frontend .torch .load import \
55- _torchscript_from_model as pytorch_load
54+ from coremltools .converters .mil .frontend .torch .load import (
55+ _torchscript_from_spec as try_load_torchscript ,
56+ )
57+
58+ if _HAS_TORCH_EXPORT_API :
59+ from torch .export import ExportedProgram
60+
5661
5762
5863@_profile
@@ -102,8 +107,12 @@ def convert(
102107
103108 * PyTorch
104109
105- - A `TorchScript <https://pytorch.org/docs/stable/jit.html>`_ object
106- - Path to a ``.pt`` file
110+ - TorchScript Models:
111+ - A `TorchScript <https://pytorch.org/docs/stable/jit.html>`_ object
112+ - Path to a ``.pt`` file
113+
114+ - Torch Exported Models:
115+ - A `ExportedProgram <https://pytorch.org/docs/stable/export.html#torch.export.ExportedProgram> ` object with `EDGE` dialect
107116
108117 source : str (optional)
109118
@@ -161,18 +170,23 @@ def convert(
161170 When ``inputs`` not provided or ``dtype`` not specified. The float 32 inputs defaults to float 16.
162171
163172 * PyTorch:
164- - The ``inputs`` parameter is required.
165- - Number of elements in ``inputs`` must match the number of inputs
166- of the PyTorch model.
167- - ``inputs`` may be a nested list or tuple.
168- - ``TensorType`` and ``ImageType`` must have the ``shape`` specified.
169- - If the ``name`` argument is specified with ``TensorType`` or
170- ``ImageType``, the converted Core ML model will have inputs with
171- the same name.
172- - If ``dtype`` is missing:
173- * For ``minimum_deployment_target <= ct.target.macOS12``, it defaults to float 32.
174- * For ``minimum_deployment_target >= ct.target.macOS13``, and with ``compute_precision`` in float 16 precision.
175- It defaults to float 16.
173+
174+ - TorchScript Models:
175+ - The ``inputs`` parameter is required.
176+ - Number of elements in ``inputs`` must match the number of inputs
177+ of the PyTorch model.
178+ - ``inputs`` may be a nested list or tuple.
179+ - ``TensorType`` and ``ImageType`` must have the ``shape`` specified.
180+ - If the ``name`` argument is specified with ``TensorType`` or
181+ ``ImageType``, the converted Core ML model will have inputs with
182+ the same name.
183+ - If ``dtype`` is missing:
184+ * For ``minimum_deployment_target <= ct.target.macOS12``, it defaults to float 32.
185+ * For ``minimum_deployment_target >= ct.target.macOS13``, and with ``compute_precision`` in float 16 precision.
186+ It defaults to float 16.
187+
188+ - Torch Exported Models:
189+ - The ``inputs`` parameter is not supported. ``inputs`` parameter is inferred from Torch ExportedProgram.
176190
177191 outputs : list of ``TensorType`` or ``ImageType`` (optional)
178192
@@ -218,13 +232,17 @@ def convert(
218232
219233 * PyTorch:
220234
221- - If specified, the length of the list must match the number of
222- outputs returned by the PyTorch model.
223- - If ``name`` is specified, it is applied to the output names of the
224- converted Core ML model.
225- - For ``minimum_deployment_target >= ct.target.macOS13``, and with ``compute_precision`` in float 16 precision.
226- If ``dtype`` not specified, the outputs inferred of type float 32
227- defaults to float 16.
235+ - TorchScript Models:
236+ - If specified, the length of the list must match the number of
237+ outputs returned by the PyTorch model.
238+ - If ``name`` is specified, it is applied to the output names of the
239+ converted Core ML model.
240+ - For ``minimum_deployment_target >= ct.target.macOS13``, and with ``compute_precision`` in float 16 precision.
241+ If ``dtype`` not specified, the outputs inferred of type float 32
242+ defaults to float 16.
243+
244+ - Torch Exported Models:
245+ - The ``outputs`` parameter is not supported. ``outputs`` parameter is inferred from Torch ExportedProgram.
228246
229247
230248 classifier_config : ClassifierConfig class (optional)
@@ -308,7 +326,7 @@ def convert(
308326 The above transform iterates through all the ops, looking at each op's
309327 inputs and outputs. If they are of type float 32, ``cast``
310328 ops are injected to convert those tensors (also known as `vars`) to
311- type float 16.
329+ type float 16. Similarly, int32 vars will also be cast to int16.
312330
313331 - ``coremltools.precision.FLOAT32`` enum: No transform is applied.
314332
@@ -489,15 +507,17 @@ def skip_real_div_ops(op):
489507
490508 PyTorch:
491509
492- >>> model = torchvision.models.mobilenet_v2()
493- >>> model.eval()
494- >>> example_input = torch.rand(1, 3, 256, 256)
495- >>> traced_model = torch.jit.trace(model, example_input)
510+ TorchScript Models:
496511
497- >>> input = ct.TensorType(name='input_name', shape=(1, 3, 256, 256))
498- >>> mlmodel = ct.convert(traced_model, inputs=[input])
499- >>> results = mlmodel.predict({"input": example_input.numpy()})
500- >>> print(results['1651']) # 1651 is the node name given by PyTorch's JIT
512+ >>> model = torchvision.models.mobilenet_v2()
513+ >>> model.eval()
514+ >>> example_input = torch.rand(1, 3, 256, 256)
515+ >>> traced_model = torch.jit.trace(model, example_input)
516+
517+ >>> input = ct.TensorType(name='input_name', shape=(1, 3, 256, 256))
518+ >>> mlmodel = ct.convert(traced_model, inputs=[input])
519+ >>> results = mlmodel.predict({"input": example_input.numpy()})
520+ >>> print(results['1651']) # 1651 is the node name given by PyTorch's JIT
501521
502522 See `Conversion Options <https://coremltools.readme.io/docs/neural-network-conversion>`_ for
503523 more advanced options.
@@ -508,6 +528,7 @@ def skip_real_div_ops(op):
508528 outputs_as_strings ,
509529 outputs_as_tensor_or_image_types ,
510530 outputs )
531+ source_dialect = _determine_source_dialect (model , exact_source )
511532 exact_target = _determine_target (convert_to , minimum_deployment_target )
512533 _validate_conversion_arguments (
513534 model ,
@@ -525,7 +546,7 @@ def skip_real_div_ops(op):
525546 if pass_pipeline is None :
526547 pass_pipeline = PassPipeline ()
527548 if not need_fp16_cast_pass :
528- pass_pipeline .remove_passes ({"common::add_fp16_cast" })
549+ pass_pipeline .remove_passes ({"common::add_fp16_cast" , "common::add_int16_cast" })
529550 if isinstance (compute_precision , FP16ComputePrecision ):
530551 # For backward compatibility with the `op_selector` param in FP16ComputePrecision.
531552 pass_pipeline ._pass_options ["common::add_fp16_cast" ] = [
@@ -584,7 +605,7 @@ def skip_real_div_ops(op):
584605
585606 gc .collect ()
586607
587- mlmodel = _record_build_metadata (mlmodel , exact_source )
608+ mlmodel = _record_build_metadata (mlmodel , exact_source , source_dialect = source_dialect )
588609
589610 return mlmodel
590611
@@ -819,16 +840,45 @@ def _flatten_list(_inputs):
819840 raise ValueError ("Input should be a list of TensorType or ImageType" )
820841
821842 elif exact_source == "pytorch" :
822- if inputs is None :
823- raise ValueError ('Expected argument for pytorch "inputs" not provided' )
843+ if _HAS_TORCH_EXPORT_API and isinstance (model , ExportedProgram ):
844+ if model .dialect != "EDGE" :
845+ raise NotImplementedError (
846+ f"Conversion for models with only EDGE dialect is supported/tested. Provided Dialect: { model .dialect } "
847+ )
824848
825- raise_if_duplicated (flat_inputs )
826- if inputs is not None and not all (
827- [isinstance (_input , InputType ) for _input in flat_inputs ]
828- ):
829- raise ValueError (
830- "Input should be a list/tuple (or nested lists/tuples) of TensorType or ImageType"
831- )
849+ # TODO: rdar://115845792 ([Executorch] Handle user provided inputs/outputs in the convert API)
850+ if inputs is not None :
851+ raise AssertionError ("'inputs' argument should be None for ExportedProgram" )
852+
853+ if outputs is not None :
854+ raise AssertionError ("'outputs' argument should be None for ExportedProgram" )
855+
856+ else :
857+ is_torch_load_successful = False
858+ try :
859+ try_load_torchscript (model )
860+ is_torch_load_successful = True
861+ except :
862+ pass
863+ if is_torch_load_successful :
864+ if inputs is None :
865+ raise ValueError (
866+ 'Expected argument "inputs" for TorchScript models not provided'
867+ )
868+
869+ raise_if_duplicated (flat_inputs )
870+ if inputs is not None and not all (
871+ [isinstance (_input , InputType ) for _input in flat_inputs ]
872+ ):
873+ raise ValueError (
874+ "Input should be a list/tuple (or nested lists/tuples) of TensorType or ImageType"
875+ )
876+ else :
877+ raise TypeError (
878+ "@model must either be a TorchScript object (or .pt or .pth file) or an ExportedProgram object (if using torch.export based API), received: {}" .format (
879+ type (model )
880+ )
881+ )
832882
833883 elif exact_source == "milinternal" :
834884 if not isinstance (model , Program ):
@@ -837,6 +887,19 @@ def _flatten_list(_inputs):
837887 )
838888
839889
890+ def _determine_source_dialect (model , exact_source ):
891+
892+ source_dialect = None
893+ if exact_source == "pytorch" :
894+
895+ if _HAS_TORCH_EXPORT_API and isinstance (model , ExportedProgram ):
896+ return f"TorchExport::{ model .dialect } "
897+ else :
898+ return "TorchScript"
899+
900+ return source_dialect
901+
902+
840903def _determine_source (model , source ,
841904 output_names ,
842905 outputs_as_tensor_or_image_types ,
@@ -875,9 +938,13 @@ def _determine_source(model, source,
875938 pass
876939
877940 if source == "auto" and _HAS_TORCH :
941+
942+ if _HAS_TORCH_EXPORT_API and isinstance (model , ExportedProgram ):
943+ return "pytorch"
944+
878945 is_torch_load_successful = False
879946 try :
880- pytorch_load (model )
947+ try_load_torchscript (model )
881948 is_torch_load_successful = True
882949 except :
883950 pass
@@ -953,6 +1020,12 @@ def _get_metadata_from_mlmodel(mlmodel):
9531020 src_pkg_version = mlmodel .user_defined_metadata [_METADATA_SOURCE ]
9541021 coremltools_version = mlmodel .user_defined_metadata [_METADATA_VERSION ]
9551022
1023+ src_dialect = (
1024+ None
1025+ if _METADATA_SOURCE_DIALECT not in mlmodel .user_defined_metadata
1026+ else mlmodel .user_defined_metadata [_METADATA_SOURCE_DIALECT ]
1027+ )
1028+
9561029 src_pkg_version_list = src_pkg_version .split ("==" )
9571030 if len (src_pkg_version_list ) == 0 :
9581031 src_pkg , pkg_ver = None , None
@@ -969,10 +1042,13 @@ def _get_metadata_from_mlmodel(mlmodel):
9691042 if src_pkg is not None and pkg_ver is not None :
9701043 build_info ['coremltools-component-' + src_pkg ] = str (pkg_ver )
9711044
1045+ if src_dialect is not None :
1046+ build_info ["coremltools-source-dialect" ] = src_dialect
1047+
9721048 return build_info
9731049
9741050
975- def _record_build_metadata (mlmodel , exact_source ):
1051+ def _record_build_metadata (mlmodel , exact_source , source_dialect = None ):
9761052 # recording metadata: coremltools version, source framework and version
9771053 if exact_source in {"tensorflow" , "tensorflow2" } and (_HAS_TF_1 or _HAS_TF_2 ):
9781054 src_pkg_version = "tensorflow=={0}" .format (tf .__version__ )
@@ -986,6 +1062,9 @@ def _record_build_metadata(mlmodel, exact_source):
9861062 mlmodel .user_defined_metadata [_METADATA_SOURCE ] = src_pkg_version
9871063 mlmodel .user_defined_metadata [_METADATA_VERSION ] = _ct_version
9881064
1065+ if source_dialect is not None :
1066+ mlmodel .user_defined_metadata [_METADATA_SOURCE_DIALECT ] = source_dialect
1067+
9891068 build_info = _get_metadata_from_mlmodel (mlmodel )
9901069
9911070 mlmodel ._set_build_info_mil_attributes (build_info )
0 commit comments