|
4 | 4 | # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause |
5 | 5 |
|
6 | 6 | import logging |
7 | | -import numpy as _np |
| 7 | +import numpy as np |
8 | 8 | import os |
9 | | -import tempfile |
10 | | -import shutil |
11 | 9 |
|
12 | | -from coremltools.converters.mil.backend.mil.helper import * |
13 | | -from coremltools.converters.mil.backend.backend_helper import _get_probability_var_for_classifier |
14 | 10 | from .passes import mil_passes |
15 | | -import coremltools.proto.MIL_pb2 as pm |
16 | | -from coremltools.converters.mil.mil import types |
17 | | -from coremltools.converters.mil.mil import Function |
| 11 | +from coremltools import _SPECIFICATION_VERSION_IOS_15 |
| 12 | +from coremltools.converters.mil.backend.mil.helper import ( |
| 13 | + cast_to_framework_io_dtype, |
| 14 | + create_file_value, |
| 15 | + create_immediate_value, |
| 16 | + create_list_scalarvalue, |
| 17 | + create_scalar_value, |
| 18 | + types_to_proto |
| 19 | +) |
| 20 | +from coremltools.converters.mil.backend.backend_helper import _get_probability_var_for_classifier |
| 21 | +from coremltools.converters.mil.mil import ( |
| 22 | + Builder as mb, |
| 23 | + Function, |
| 24 | + mil_list, |
| 25 | + types |
| 26 | +) |
18 | 27 | from coremltools.converters.mil.backend.nn.load import _set_optional_inputs |
| 28 | +from coremltools.converters.mil.input_types import ImageType, TensorType, EnumeratedShapes, RangeDim |
19 | 29 | from coremltools.converters.mil.mil.ops.registry import SSAOpRegistry |
20 | 30 | from coremltools.converters.mil.mil.types.symbolic import ( |
21 | 31 | any_symbolic, |
22 | 32 | any_variadic, |
23 | 33 | is_symbolic, |
24 | 34 | ) |
| 35 | +from coremltools.converters.mil.mil.types.type_mapping import types_int64 |
| 36 | +from coremltools.libmilstoragepython import _BlobStorageWriter as BlobWriter |
| 37 | +from coremltools.models.model import _WEIGHTS_FILE_NAME |
25 | 38 | from coremltools.models.neural_network.flexible_shape_utils import ( |
26 | | - NeuralNetworkImageSize, |
27 | | - NeuralNetworkImageSizeRange, |
28 | 39 | add_enumerated_image_sizes, |
29 | 40 | add_multiarray_ndshape_enumeration, |
| 41 | + NeuralNetworkImageSize, |
| 42 | + NeuralNetworkImageSizeRange, |
30 | 43 | set_multiarray_ndshape_range, |
31 | | - update_image_size_range, |
| 44 | + update_image_size_range |
| 45 | +) |
| 46 | +from coremltools.proto import ( |
| 47 | + FeatureTypes_pb2 as ft, |
| 48 | + MIL_pb2 as pm, |
| 49 | + Model_pb2 as ml |
32 | 50 | ) |
33 | 51 |
|
34 | | -from coremltools.libmilstoragepython import _BlobStorageWriter as BlobWriter |
35 | | - |
36 | | -import coremltools.proto.Model_pb2 as ml |
37 | | -import coremltools.proto.FeatureTypes_pb2 as ft |
38 | | -from coremltools.converters.mil.input_types import ImageType, TensorType, EnumeratedShapes, RangeDim |
39 | | -from coremltools.models.model import _WEIGHTS_FILE_NAME |
40 | | -from coremltools.converters.mil.mil import Builder as mb |
41 | | -from coremltools.converters.mil.mil import mil_list |
42 | | -from coremltools import _SPECIFICATION_VERSION_IOS_15 |
43 | 52 |
|
44 | 53 | def should_use_weight_file(val): |
45 | 54 | return ( |
46 | 55 | val is not None |
47 | | - and isinstance(val, (_np.ndarray, _np.generic)) |
| 56 | + and isinstance(val, (np.ndarray, np.generic)) |
48 | 57 | and val.size >= 10 |
49 | 58 | and val.dtype in ['float16', 'float32'] |
50 | 59 | ) |
@@ -97,7 +106,7 @@ def translate_generic_op(op, parameters, blob_writer, literal_params=[]): |
97 | 106 | blocks = None |
98 | 107 | if len(op.blocks) > 0: |
99 | 108 | blocks = [create_block(b, parameters, blob_writer) \ |
100 | | - for b in op.blocks] |
| 109 | + for b in op.blocks] |
101 | 110 |
|
102 | 111 | op_type = op.op_type |
103 | 112 | attr_dict = {} |
@@ -206,6 +215,9 @@ def _add_classify_op(prog, classifier_config): |
206 | 215 |
|
207 | 216 | # add the classify op now |
208 | 217 | with block: |
| 218 | + # cast the int label to np.int64 |
| 219 | + if isinstance(classes[0], int): |
| 220 | + classes = [np.int64(x) for x in classes] |
209 | 221 | classes_var = mb.const(val=mil_list(classes)) |
210 | 222 | out = mb.classify(probabilities=probability_var, classes=classes_var) |
211 | 223 |
|
@@ -344,7 +356,7 @@ def load(prog, weights_dir, resume_on_errors=False, **kwargs): |
344 | 356 | keytype, valtype = var.sym_type.T |
345 | 357 | if types.is_str(keytype): |
346 | 358 | output_feature_type.dictionaryType.stringKeyType.MergeFromString(b"") |
347 | | - elif (keytype == types_int64): |
| 359 | + elif (keytype == types.int64): |
348 | 360 | output_feature_type.dictionaryType.int64KeyType.MergeFromString(b"") |
349 | 361 | else: |
350 | 362 | raise ValueError("Dictionary key type not supported.") |
@@ -445,7 +457,6 @@ def load(prog, weights_dir, resume_on_errors=False, **kwargs): |
445 | 457 | model, input_name, lower_bounds=lb, upper_bounds=ub |
446 | 458 | ) |
447 | 459 |
|
448 | | - |
449 | 460 | # Set optional inputs |
450 | 461 | _set_optional_inputs(model, input_types) |
451 | 462 |
|
|
0 commit comments