Skip to content

Commit 2edf48a

Browse files
authored
Ensure shapes passed to mb.fill are int (#1850)
Ensure shapes passed to `mb.fill` are `int`, with test cases.
1 parent 973b361 commit 2edf48a

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3594,6 +3594,8 @@ def _make_fill_op(size, val, name):
     assert val is not None
     if isinstance(size, list):
         size = mb.concat(values=size, axis=0)
+    if types.is_float(size.dtype):
+        size = mb.cast(x=size, dtype="int32")
     fill = mb.fill(shape=size, value=val, name=name)
     return fill


coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3531,6 +3531,29 @@ def forward(self, x):
             shape, FullStaticModel().eval(), backend=backend, compute_unit=compute_unit
         )

+    @pytest.mark.parametrize(
+        "compute_unit, backend, shape_val",
+        itertools.product(
+            compute_units,
+            backends,
+            [
+                [(1,), 0.0],
+                [(2, 3), 3.1415],
+                [(1, 1, 2, 5, 1), -2.0],
+            ],
+        ),
+    )
+    def test_full_scalar(self, compute_unit, backend, shape_val):
+        shape, val = shape_val
+
+        class FullScalarModel(nn.Module):
+            def forward(self, x):
+                return x / torch.full([], fill_value=val)
+
+        self.run_compare_torch(
+            shape, FullScalarModel().eval(), backend=backend, compute_unit=compute_unit
+        )
+
     @pytest.mark.parametrize(
         "compute_unit, backend, shape_val",
         itertools.product(

0 commit comments

Comments
 (0)