Skip to content

Commit ef18cc0

Browse files
authored
not create new symbol with immutable shape convolution. (#1867)
1 parent 7c71c53 commit ef18cc0

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

coremltools/converters/mil/mil/ops/defs/_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,10 @@ def spatial_dimensions_out_shape(
252252
# * `effective_ks` (effective kernel size, determined from kernel size + dilations) cannot be symbolic
253253
# * strides cannot be symbolic
254254
if is_symbolic(input_shape[r]):
255-
out_shape.append(get_new_symbol())
255+
if not is_symbolic(pad[r]) and pad[r] - effective_ks[r] == -1 and strides[r] == 1:
256+
out_shape.append(input_shape[r])
257+
else:
258+
out_shape.append(get_new_symbol())
256259
else:
257260
out_dim = 0
258261
if not ceil_mode:

coremltools/converters/mil/mil/ops/tests/test_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
55

66
import numpy as np
7+
from coremltools.converters.mil import get_new_symbol
78

89
from coremltools.converters.mil.mil.ops.defs._utils import (
910
aggregated_pad, effective_kernel, spatial_dimensions_out_shape)
@@ -260,3 +261,15 @@ def test_same_padding_shape_dilation_2(self):
260261

261262
expected = [5, 5]
262263
np.testing.assert_equal(actual, expected)
264+
265+
def test_symbolic_custom_pad(self):
266+
input_shape = (get_new_symbol(), get_new_symbol())
267+
actual = spatial_dimensions_out_shape(
268+
pad_type="custom",
269+
input_shape=input_shape,
270+
kernel_shape=(1, 1),
271+
strides=(1, 1),
272+
dilations=(1, 1),
273+
custom_pad=(0, 0, 0, 0),
274+
)
275+
np.testing.assert_equal(actual, input_shape)

0 commit comments

Comments
 (0)