Skip to content

Commit 230f1e5

Browse files
authored
Initial implementation of make_pipeline utility (#1803)
* initial implementation of make pipeline utility * Fix unit test * Set weights_dir parameter * minor cleanups * Fix typo * When mapping input shapes to output shapes, don't override shapes
1 parent 32b1ee0 commit 230f1e5

File tree

4 files changed

+181
-17
lines changed

4 files changed

+181
-17
lines changed

coremltools/models/model.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@
2222
from ..proto import Model_pb2 as _Model_pb2
2323
from .utils import (_MLMODEL_EXTENSION, _MLPACKAGE_AUTHOR_NAME,
2424
_MLPACKAGE_EXTENSION, _WEIGHTS_DIR_NAME, _create_mlpackage,
25-
_has_custom_layer, _is_macos, _macos_version)
26-
from .utils import load_spec as _load_spec
27-
from .utils import save_spec as _save_spec
25+
_has_custom_layer, _is_macos, _macos_version,
26+
load_spec as _load_spec, save_spec as _save_spec,
27+
)
2828

2929
if _HAS_TORCH:
30-
import torch
30+
import torch as _torch
3131

3232
if _HAS_TF_1 or _HAS_TF_2:
33-
import tensorflow as tf
33+
import tensorflow as _tf
3434

3535

3636
try:
@@ -258,7 +258,7 @@ def __init__(
258258
i.e. a spec object.
259259
260260
is_temp_package: bool
261-
Set to true if the input model package dir is temporary and can be deleted upon interpreter termination.
261+
Set to True if the input model package dir is temporary and can be deleted upon interpreter termination.
262262
263263
mil_program: coremltools.converters.mil.Program
264264
Set to the MIL program object, if available.
@@ -326,9 +326,8 @@ def cleanup(package_path):
326326
self.is_temp_package = False
327327
self.package_path = None
328328
self._weights_dir = None
329-
if mil_program is not None:
330-
if not isinstance(mil_program, _Program):
331-
raise ValueError("mil_program must be of type 'coremltools.converters.mil.Program'")
329+
if mil_program is not None and not isinstance(mil_program, _Program):
330+
raise ValueError('"mil_program" must be of type "coremltools.converters.mil.Program"')
332331
self._mil_program = mil_program
333332

334333
if isinstance(model, str):
@@ -342,8 +341,9 @@ def cleanup(package_path):
342341
model, compute_units, skip_model_load=skip_model_load,
343342
)
344343
elif isinstance(model, _Model_pb2.Model):
345-
if model.WhichOneof('Type') == "mlProgram":
346-
if weights_dir is None:
344+
model_type = model.WhichOneof('Type')
345+
if model_type in ("mlProgram", 'pipelineClassifier', 'pipelineRegressor', 'pipeline'):
346+
if model_type == "mlProgram" and weights_dir is None:
347347
raise Exception('MLModel of type mlProgram cannot be loaded just from the model spec object. '
348348
'It also needs the path to the weights file. Please provide that as well, '
349349
'using the \'weights_dir\' argument.')
@@ -443,6 +443,7 @@ def save(self, save_path: str):
443443
loaded_model = MLModel('my_model_file.mlmodel')
444444
"""
445445
save_path = _os.path.expanduser(save_path)
446+
446447
# Clean up existing file or directory.
447448
if _os.path.exists(save_path):
448449
if _os.path.isdir(save_path):
@@ -489,7 +490,7 @@ def predict(self, data):
489490
490491
Returns
491492
-------
492-
out: dict[str, value]
493+
dict[str, value]
493494
Predictions as a dictionary where each key is the output feature
494495
name.
495496
@@ -648,10 +649,10 @@ def _convert_tensor_to_numpy(self, input_dict):
648649
def convert(given_input):
649650
if isinstance(given_input, _numpy.ndarray):
650651
sanitized_input = given_input
651-
elif _HAS_TORCH and isinstance(given_input, torch.Tensor):
652+
elif _HAS_TORCH and isinstance(given_input, _torch.Tensor):
652653
sanitized_input = given_input.detach().numpy()
653-
elif (_HAS_TF_1 or _HAS_TF_2) and isinstance(given_input, tf.Tensor):
654-
sanitized_input = given_input.eval(session=tf.compat.v1.Session())
654+
elif (_HAS_TF_1 or _HAS_TF_2) and isinstance(given_input, _tf.Tensor):
655+
sanitized_input = given_input.eval(session=_tf.compat.v1.Session())
655656
else:
656657
sanitized_input = _numpy.array(given_input)
657658
return sanitized_input

coremltools/models/utils.py

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,26 @@
66
"""
77
Utilities for the entire package.
88
"""
9+
10+
from collections.abc import Iterable as _Iterable
11+
from functools import lru_cache as _lru_cache
912
import math as _math
1013
import os as _os
1114
import shutil as _shutil
1215
import subprocess as _subprocess
1316
import sys as _sys
1417
import tempfile as _tempfile
15-
import warnings as _warnings
16-
from functools import lru_cache as _lru_cache
1718
from typing import Optional as _Optional
19+
import warnings as _warnings
1820

1921
import numpy as _np
2022

23+
import coremltools as _ct
2124
from coremltools import ComputeUnit as _ComputeUnit
2225
from coremltools.converters.mil.mil.passes.name_sanitization_utils import \
2326
NameSanitizer as _NameSanitizer
2427
from coremltools.proto import Model_pb2 as _Model_pb2
28+
import coremltools.proto.MIL_pb2 as _mil_proto
2529

2630
from .._deps import _HAS_SCIPY
2731

@@ -1000,3 +1004,96 @@ def _convert_to_float(feature):
10001004
if spec.WhichOneof("Type") == "pipeline":
10011005
for model_spec in spec.pipeline.models:
10021006
convert_double_to_float_multiarray_type(model_spec)
1007+
1008+
1009+
def make_pipeline(*models):
1010+
"""
1011+
Makes a pipeline with the given models.
1012+
1013+
Parameters
1014+
----------
1015+
*models - two or more instances of ct.models.MLModel
1016+
1017+
Returns
1018+
-------
1019+
ct.models.MLModel
1020+
1021+
Examples
1022+
--------
1023+
my_model1 = ct.models.MLModel('/tmp/m1.mlpackage')
1024+
my_model2 = ct.models.MLModel('/tmp/m2.mlmodel')
1025+
1026+
my_pipeline_model = ct.utils.make_pipeline(my_model1, my_model2)
1027+
"""
1028+
1029+
def updateBlobFileName(proto_message, new_path):
1030+
if type(proto_message) == _mil_proto.Value:
1031+
# Value protobuf message. This is what might need to be updated.
1032+
if proto_message.WhichOneof('value') == 'blobFileValue':
1033+
assert proto_message.blobFileValue.fileName == "@model_path/weights/weight.bin"
1034+
proto_message.blobFileValue.fileName = new_path
1035+
elif hasattr(proto_message, 'ListFields'):
1036+
# Normal protobuf message
1037+
for f in proto_message.ListFields():
1038+
updateBlobFileName(f[1], new_path)
1039+
elif hasattr(proto_message, 'values'):
1040+
# Protobuf map
1041+
for v in proto_message.values():
1042+
updateBlobFileName(v, new_path)
1043+
elif isinstance(proto_message, _Iterable) and not isinstance(proto_message, str):
1044+
# Repeated protobuf message
1045+
for e in proto_message:
1046+
updateBlobFileName(e, new_path)
1047+
1048+
1049+
assert len(models) > 1
1050+
input_specs = list(map(lambda m: m.get_spec(), models))
1051+
1052+
pipeline_spec = _ct.proto.Model_pb2.Model()
1053+
pipeline_spec.specificationVersion = max(
1054+
map(lambda spec: spec.specificationVersion, input_specs)
1055+
)
1056+
1057+
# Set pipeline input
1058+
pipeline_spec.description.input.MergeFrom(
1059+
input_specs[0].description.input
1060+
)
1061+
1062+
# Set pipeline output
1063+
pipeline_spec.description.output.MergeFrom(
1064+
input_specs[-1].description.output
1065+
)
1066+
1067+
# Map input shapes to output shapes
1068+
var_name_to_type = {}
1069+
for i in range(len(input_specs) - 1):
1070+
for j in input_specs[i + 1].description.input:
1071+
var_name_to_type[j.name] = j.type
1072+
1073+
for j in input_specs[i].description.output:
1074+
# If shape is already present, don't override it
1075+
if j.type.WhichOneof('Type') == 'multiArrayType' and len(j.type.multiArrayType.shape) != 0:
1076+
continue
1077+
1078+
if j.name in var_name_to_type:
1079+
j.type.CopyFrom(var_name_to_type[j.name])
1080+
1081+
# Update each model's spec to have a unique weight filename
1082+
for i, cur_spec in enumerate(input_specs):
1083+
if cur_spec.WhichOneof("Type") == "mlProgram":
1084+
new_file_path = f"@model_path/weights/{i}-weight.bin"
1085+
updateBlobFileName(cur_spec.mlProgram, new_file_path)
1086+
pipeline_spec.pipeline.models.append(cur_spec)
1087+
1088+
mlpackage_path = _create_mlpackage(pipeline_spec)
1089+
dst = mlpackage_path + '/Data/' + _MLPACKAGE_AUTHOR_NAME + '/' + _WEIGHTS_DIR_NAME
1090+
_os.mkdir(dst)
1091+
1092+
# Copy and rename each model's weight file
1093+
for i, cur_model in enumerate(models):
1094+
if cur_model.weights_dir is not None:
1095+
weight_file_path = cur_model.weights_dir + "/" + _WEIGHTS_FILE_NAME
1096+
if _os.path.exists(weight_file_path):
1097+
_shutil.copyfile(weight_file_path, dst + f"/{i}-weight.bin")
1098+
1099+
return _ct.models.MLModel(pipeline_spec, weights_dir=dst)

coremltools/test/api/test_api_visibilities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def test_utils(self):
6060
"evaluate_classifier_with_probabilities",
6161
"evaluate_regressor",
6262
"evaluate_transformer",
63+
"make_pipeline",
6364
"load_spec",
6465
"rename_feature",
6566
"save_spec",

coremltools/test/pipeline/test_pipeline.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,17 @@
33
# Use of this source code is governed by a BSD-3-clause license that can be
44
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
55

6+
import itertools
7+
import tempfile
68
import unittest
79

10+
import numpy as np
11+
import pytest
12+
13+
import coremltools as ct
814
from coremltools._deps import _HAS_LIBSVM, _HAS_SKLEARN
15+
from coremltools.converters.mil.mil import Builder as mb
16+
from coremltools.converters.mil.mil import Function, Program
917
from coremltools.models.pipeline import PipelineClassifier, PipelineRegressor
1018

1119
if _HAS_SKLEARN:
@@ -210,3 +218,60 @@ def test_conversion_bad_inputs(self):
210218
with self.assertRaises(TypeError):
211219
model = OneHotEncoder()
212220
spec = converter.convert(model, "data", "out", "regressor")
221+
222+
223+
class TestMakePipeline:
224+
@staticmethod
225+
def _make_model(input_name, input_length,
226+
output_name, output_length,
227+
convert_to):
228+
229+
weight_tensor = np.arange(input_length * output_length, dtype='float32')
230+
weight_tensor = weight_tensor.reshape(output_length, input_length)
231+
232+
prog = Program()
233+
func_inputs = {input_name: mb.placeholder(shape=(input_length,))}
234+
with Function(func_inputs) as ssa_fun:
235+
input = ssa_fun.inputs[input_name]
236+
y = mb.linear(x=input, weight=weight_tensor, name=output_name)
237+
ssa_fun.set_outputs([y])
238+
prog.add_function("main", ssa_fun)
239+
240+
return ct.convert(prog, convert_to=convert_to)
241+
242+
243+
@staticmethod
244+
@pytest.mark.parametrize(
245+
"model1_backend, model2_backend",
246+
itertools.product(["mlprogram", "neuralnetwork"], ["mlprogram", "neuralnetwork"]),
247+
)
248+
def test_simple(model1_backend, model2_backend):
249+
# Create models
250+
m1 = TestMakePipeline._make_model("x", 20, "y1", 10, model1_backend)
251+
m2 = TestMakePipeline._make_model("y1", 10, "y2", 2, model2_backend)
252+
253+
# Get non-pipeline result
254+
x = np.random.rand(20)
255+
y1 = m1.predict({"x": x})["y1"]
256+
y2 = m2.predict({"y1": y1})
257+
258+
pipeline_model = ct.utils.make_pipeline(m1, m2)
259+
260+
y_pipeline = pipeline_model.predict({"x": x})
261+
np.testing.assert_allclose(y2["y2"], y_pipeline["y2"])
262+
263+
# Check save/load
264+
with tempfile.TemporaryDirectory() as save_dir:
265+
# Save pipeline
266+
save_path = save_dir + "/test.mlpackage"
267+
pipeline_model.save(save_path)
268+
269+
# Check loading from a mlpackage path
270+
p2 = ct.models.MLModel(save_path)
271+
y_pipeline = p2.predict({"x": x})
272+
np.testing.assert_allclose(y2["y2"], y_pipeline["y2"])
273+
274+
# Check loading from spec and weight dir
275+
p3 = ct.models.MLModel(p2.get_spec(), weights_dir=p2.weights_dir)
276+
y_pipeline = p3.predict({"x": x})
277+
np.testing.assert_allclose(y2["y2"], y_pipeline["y2"])

0 commit comments

Comments
 (0)