Skip to content

Commit 94b955c

Browse files
authored
Merge pull request #465 from bhushan23/master
Adding tf converter options into tf coremltools path
2 parents 05ae10f + 459d507 commit 94b955c

File tree

4 files changed

+150
-14
lines changed

4 files changed

+150
-14
lines changed

coremltools/converters/nnssa/coreml/ssa_converter.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy as np
22

3+
from six import string_types as _string_types
4+
35
from coremltools.models import datatypes
46
from coremltools.proto import NeuralNetwork_pb2
57
from coremltools.models.neural_network import NeuralNetworkBuilder
@@ -18,7 +20,6 @@
1820

1921
DEBUG = False
2022

21-
2223
def _is_scalar(type_):
2324
if type_ is None:
2425
return False
@@ -31,10 +32,19 @@ def ssa_convert(ssa,
3132
top_func='main',
3233
inputs=None,
3334
outputs=None,
35+
image_input_names=None,
36+
is_bgr=False,
37+
red_bias=0.0,
38+
green_bias=0.0,
39+
blue_bias=0.0,
40+
gray_bias=0.0,
41+
image_scale=1.0,
42+
class_labels=None,
43+
predicted_feature_name=None,
44+
predicted_probabilities_output='',
3445
add_custom_layers=False,
3546
custom_conversion_functions={},
36-
custom_shape_functions={}
37-
):
47+
custom_shape_functions={}):
3848
"""
3949
Convert NNSSA into CoreML spec.
4050
ssa : NetworkEnsemble
@@ -109,33 +119,82 @@ def ssa_convert(ssa,
109119
for f in list(ssa.functions.values()):
110120
check_connections(f.graph)
111121

122+
# Set classifier flag
123+
is_classifier = class_labels is not None
124+
neural_network_type = 'classifier' if is_classifier else None
125+
112126
converter = SSAConverter(ssa,
113127
top_func=top_func,
114128
inputs=inputs,
115129
outputs=outputs,
130+
neural_network_type=neural_network_type,
116131
add_custom_layers=add_custom_layers,
117132
custom_conversion_functions=custom_conversion_functions,
118133
custom_shape_functions=custom_shape_functions)
119134
converter.convert()
135+
136+
builder = converter._get_builder(func=top_func)
137+
# Add image input identifier
138+
if image_input_names is not None and isinstance(
139+
image_input_names, _string_types):
140+
image_input_names = [image_input_names]
141+
142+
# Add classifier classes (if applicable)
143+
if is_classifier:
144+
classes_in = class_labels
145+
if isinstance(classes_in, _string_types):
146+
import os
147+
if not os.path.isfile(classes_in):
148+
raise ValueError("Path to class labels (%s) does not exist." % \
149+
classes_in)
150+
with open(classes_in, 'r') as f:
151+
classes = f.read()
152+
classes = classes.splitlines()
153+
elif type(classes_in) is list: # list[int or str]
154+
classes = classes_in
155+
else:
156+
raise ValueError('Class labels must be a list of integers / strings,'\
157+
' or a file path')
158+
159+
if predicted_feature_name is not None:
160+
builder.set_class_labels(
161+
classes, predicted_feature_name=predicted_feature_name,
162+
prediction_blob=predicted_probabilities_output)
163+
else:
164+
builder.set_class_labels(classes)
165+
166+
image_format = ssa.get_image_format()
167+
# Set pre-processing parameters
168+
builder.set_pre_processing_parameters(image_input_names=image_input_names,
169+
is_bgr=is_bgr,
170+
red_bias=red_bias,
171+
green_bias=green_bias,
172+
blue_bias=blue_bias,
173+
gray_bias=gray_bias,
174+
image_scale=image_scale,
175+
image_format=image_format)
176+
120177
mlmodel_spec = converter.get_spec()
121178

179+
# MLModel passes
122180
mlmodel_passes = [remove_disconnected_constants]
123181
for p in mlmodel_passes:
124182
p(mlmodel_spec)
125183

184+
126185
if DEBUG:
127186
coremltools.models.utils.save_spec(mlmodel_spec, '/tmp/model_from_spec.mlmodel')
128187

129188
return mlmodel_spec
130189

131190

132191
class SSAConverter(object):
133-
134192
def __init__(self,
135193
net_ensemble, # type: NetworkEnsemble
136194
top_func='main', # type: str
137195
inputs=None, # type: List[str]
138196
outputs=None, # type: List[str]
197+
neural_network_type=None, # type: str
139198
add_custom_layers=False, # type: bool
140199
custom_conversion_functions={}, # type: Dict[Text, Any]
141200
custom_shape_functions={} # type: Dict[Text, Any]
@@ -213,10 +272,10 @@ def __init__(self,
213272
else:
214273
top_output_features = list(zip(top_output_names, [None] * len(top_output_names)))
215274

216-
self.top_builder = NeuralNetworkBuilder(
217-
input_features=top_input_features,
218-
output_features=top_output_features,
219-
disable_rank5_shape_mapping=True)
275+
self.top_builder = NeuralNetworkBuilder(input_features=top_input_features,
276+
output_features=top_output_features,
277+
disable_rank5_shape_mapping=True,
278+
mode=neural_network_type)
220279

221280
self.spec = self.top_builder.spec
222281

@@ -574,6 +633,7 @@ def _convert_transpose(self, node):
574633
raise ValueError('[SSAConverter] Cannot handle dynamic Transpose')
575634
dim = list(dim)
576635
builder = self._get_builder()
636+
577637
layer = builder.add_transpose(
578638
name=node.name, axes=dim, input_name=input_names[0], output_name=node.name)
579639

coremltools/converters/nnssa/nnssa.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,3 +430,19 @@ def _find_free_name(self, prefix):
430430
idx += 1
431431
else:
432432
return name
433+
434+
def get_image_format(self):
    """
    Return the data layout declared by the first node that specifies one.

    Walks every function graph in ``self.functions`` (in the dictionary's
    iteration order) and inspects each node's ``data_format`` attribute.
    The search stops at the first node whose ``data_format`` is one of the
    recognised values; nodes without the attribute, or with any other
    value, are skipped.

    Returns
    -------
    str or None
        ``'NHWC'`` if the first recognised value is ``'NHWC'`` or
        ``'NHWC_format_inserted'``, ``'NCHW'`` if it is ``'NCHW'``, and
        ``None`` if no node in any graph carries a recognised
        ``data_format``.
    """
    # Both spellings denote a channels-last layout and map to 'NHWC'.
    nhwc_formats = ('NHWC', 'NHWC_format_inserted')

    for function in self.functions.values():
        graph = function.graph
        for name in graph:
            # Look the attribute up once; a missing attribute yields None,
            # which matches neither layout and is skipped.
            data_format = graph[name].attr.get('data_format')
            if data_format in nhwc_formats:
                return 'NHWC'
            if data_format == 'NCHW':
                return 'NCHW'
    return None

coremltools/converters/tensorflow/_tf_converter.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,19 @@
77
import os.path
88
from ...models import MLModel
99

10-
1110
def convert(filename,
1211
inputs=None,
1312
outputs=None,
13+
image_input_names=None,
14+
is_bgr=False,
15+
red_bias=0.0,
16+
green_bias=0.0,
17+
blue_bias=0.0,
18+
gray_bias=0.0,
19+
image_scale=1.0,
20+
class_labels=None,
21+
predicted_feature_name=None,
22+
predicted_probabilities_output='',
1423
add_custom_layers=False, # type: bool
1524
custom_conversion_functions={}, # type: Dict[Text, Any]
1625
custom_shape_functions={}, # type: Dict[Text, Any]
@@ -20,7 +29,6 @@ def convert(filename,
2029

2130
if not filename.endswith('.pb'):
2231
raise ValueError('invalid input tf_model_path format, expecting TensorFlow frozen graph (.pb) model.')
23-
2432
# convert from TensorFlow to SSA
2533
try:
2634
from ..nnssa.frontend.tensorflow import load as frontend_load
@@ -35,10 +43,19 @@ def convert(filename,
3543
top_func='main',
3644
inputs=inputs,
3745
outputs=outputs,
46+
image_input_names=image_input_names,
47+
is_bgr=is_bgr,
48+
red_bias=red_bias,
49+
green_bias=green_bias,
50+
blue_bias=blue_bias,
51+
gray_bias=gray_bias,
52+
image_scale=image_scale,
53+
class_labels=class_labels,
54+
predicted_feature_name=predicted_feature_name,
55+
predicted_probabilities_output=predicted_probabilities_output,
3856
add_custom_layers=add_custom_layers,
3957
custom_conversion_functions=custom_conversion_functions,
40-
custom_shape_functions=custom_shape_functions
41-
)
58+
custom_shape_functions=custom_shape_functions)
4259
except ImportError as err:
4360
raise ImportError("Backend converter not found! Error message:\n%s" % err)
4461

coremltools/models/neural_network/builder.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3204,7 +3204,8 @@ def add_crop_resize(self, name, input_names, output_name, target_height=1, targe
32043204
return spec_layer
32053205

32063206
def set_pre_processing_parameters(self, image_input_names=None, is_bgr=False,
3207-
red_bias=0.0, green_bias=0.0, blue_bias=0.0, gray_bias=0.0, image_scale=1.0):
3207+
red_bias=0.0, green_bias=0.0, blue_bias=0.0, gray_bias=0.0, image_scale=1.0,
3208+
image_format='NCHW'):
32083209
"""
32093210
Add a pre-processing parameters layer to the neural network object.
32103211
@@ -3232,6 +3233,9 @@ def set_pre_processing_parameters(self, image_input_names=None, is_bgr=False,
32323233
32333234
image_scale: float or dict()
32343235
Value by which to scale the images.
3236+
3237+
image_format: str
3238+
Layout of the image input: either 'NCHW' or 'NHWC'.
32353239
32363240
See Also
32373241
--------
@@ -3241,6 +3245,9 @@ def set_pre_processing_parameters(self, image_input_names=None, is_bgr=False,
32413245
if not image_input_names:
32423246
return # nothing to do here
32433247

3248+
if image_format != 'NCHW' and image_format != 'NHWC':
3249+
raise ValueError("Input image format must be either 'NCHW' or 'NHWC'. Provided {}".format(image_format))
3250+
32443251
if not isinstance(is_bgr, dict):
32453252
is_bgr = dict.fromkeys(image_input_names, is_bgr)
32463253
if not isinstance(red_bias, dict):
@@ -3259,7 +3266,43 @@ def set_pre_processing_parameters(self, image_input_names=None, is_bgr=False,
32593266
if input_.name in image_input_names:
32603267
if input_.type.WhichOneof('Type') == 'multiArrayType':
32613268
array_shape = tuple(input_.type.multiArrayType.shape)
3262-
channels, height, width = array_shape
3269+
3270+
if len(array_shape) == 4:
3271+
input_indices = [0, 1, 2, 3] if image_format == 'NCHW' else [0, 3, 1, 2]
3272+
elif len(array_shape) == 3:
3273+
# Adding dummy index for 'batch' for compatibility
3274+
input_indices = [0, 0, 1, 2] if image_format == 'NCHW' else [0, 2, 0, 1]
3275+
else:
3276+
raise ValueError("Invalid input shape. Input of rank {}, but expecting input of either rank 3 or rank 4".format(len(array_shape)))
3277+
3278+
# Extract image shape depending on input format
3279+
_, channels, height, width = [array_shape[e] for e in input_indices]
3280+
3281+
if image_format == 'NHWC':
3282+
# If input format is 'NHWC', then add transpose
3283+
# after the input and replace all use of input
3284+
# with output of transpose
3285+
axes = [1, 2, 0]
3286+
if len(array_shape) == 4:
3287+
axes = [0, 2, 3, 1]
3288+
input_transpose = input_.name + '_to_nhwc'
3289+
transpose_layer = self.add_transpose(
3290+
name=input_transpose,
3291+
axes=axes,
3292+
input_name=input_.name,
3293+
output_name=input_transpose
3294+
)
3295+
layers = spec.neuralNetwork.layers
3296+
layers.insert(0, layers.pop())
3297+
for layer_ in layers:
3298+
for i in range(len(layer_.input)):
3299+
if layer_.name == input_transpose:
3300+
continue
3301+
if layer_.input[i] == input_.name:
3302+
layer_.input[i] = input_transpose
3303+
3304+
# TODO: Handle inputs that are not rank 3 or rank 4 accordingly,
3305+
# e.g. for a rank-2 input, squeeze the additional dimension in the case of a grayscale image
32633306
if channels == 1:
32643307
input_.type.imageType.colorSpace = _FeatureTypes_pb2.ImageFeatureType.ColorSpace.Value(
32653308
'GRAYSCALE')

0 commit comments

Comments
 (0)