Skip to content

Commit c270159

Browse files
authored
Merge pull request #466 from DawerG/dev/gdawer/bug_fixes_flexible_shapes
Integrate range based flexible shapes + bug fixes in Type Inference + SSA converter + Shaper
2 parents f341d6a + 2b7f447 commit c270159

File tree

7 files changed

+240
-116
lines changed

7 files changed

+240
-116
lines changed

coremltools/converters/nnssa/coreml/shapes.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def _slice_static(layer_spec, input_shapes):
3434
begin = 0 if params.beginMasks[idx] else begin_indices[idx]
3535
end = dim if params.endMasks[idx] else end_indices[idx]
3636
output_shape[idx] = (end - begin) // params.strides[idx]
37+
if (end - begin) % params.strides[idx] != 0:
38+
output_shape[idx] += 1
3739
return [output_shape]
3840

3941

@@ -45,6 +47,8 @@ def _slice_dynamic(layer_spec, input_shapes):
4547

4648

4749
def _squeeze(layer_spec, input_shapes):
50+
if layer_spec.squeeze.squeezeAll:
51+
return [[1]]
4852
axes = list(layer_spec.squeeze.axes)
4953
input_shape = input_shapes[0]
5054
rank = len(input_shape)
@@ -56,7 +60,7 @@ def _squeeze(layer_spec, input_shapes):
5660
for dim in range(rank):
5761
if dim not in axes:
5862
output_shape.append(input_shape[dim])
59-
elif input_shape[dim] != 1:
63+
elif input_shape[dim] > 0 and input_shape[dim] != 1:
6064
raise ValueError(
6165
'[Shaper] Cannot squeeze on index %d of shape %s' % (dim, str(input_shape)))
6266
return [output_shape] if output_shape else [[1]]
@@ -319,6 +323,9 @@ def _reduce_general(params, input_shapes):
319323
return [output_shape] if output_shape else [[1]]
320324

321325

326+
def _reduce_logsumexp(layer_spec, input_shapes):
327+
return _reduce_general(layer_spec.reduceLogSumExp, input_shapes)
328+
322329
def _reduce_prod(layer_spec, input_shapes):
323330
return _reduce_general(layer_spec.reduceProd, input_shapes)
324331

@@ -355,7 +362,7 @@ def _argmax(layer_spec, input_shapes):
355362

356363

357364
def _argmin(layer_spec, input_shapes):
358-
params = layer_spec.argMax
365+
params = layer_spec.argMin
359366
axis = params.axis
360367
keepdims = not params.removeDim
361368

@@ -427,7 +434,7 @@ def _reorganize_data(layer_spec, input_shapes):
427434
elif 'DepthToSpace' in layer_spec.name or 'BatchToSpaceND' in layer_spec.name:
428435
output_shape[2] *= block_size
429436
output_shape[3] *= block_size
430-
output_shape[1] = input_shape[2] // (block_size * block_size)
437+
output_shape[1] = input_shape[1] // (block_size * block_size)
431438
return [output_shape]
432439

433440

@@ -490,6 +497,7 @@ def _reorganize_data(layer_spec, input_shapes):
490497
'reduce': _reduce,
491498
'argMax': _argmax,
492499
'argMin': _argmin,
500+
'reduceLogSumExp': _reduce_logsumexp,
493501
'reduceProd': _reduce_prod,
494502
'reduceMean': _reduce_mean,
495503
'reduceSum': _reduce_sum,

0 commit comments

Comments
 (0)