Skip to content

Commit cd14f9c

Browse files
committed
Fix AdvancedSubtensor static shape with newaxis
1 parent 695574b commit cd14f9c

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

pytensor/tensor/subtensor.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2629,9 +2629,13 @@ def make_node(self, x, *indices):
26292629
advanced_indices = []
26302630
adv_group_axis = None
26312631
last_adv_group_axis = None
2632-
expanded_x_shape = tuple(
2633-
np.insert(np.array(x.type.shape, dtype=object), 1, new_axes)
2634-
)
2632+
if new_axes:
2633+
expanded_x_shape_list = list(x.type.shape)
2634+
for new_axis in new_axes:
2635+
expanded_x_shape_list.insert(new_axis, 1)
2636+
expanded_x_shape = tuple(expanded_x_shape_list)
2637+
else:
2638+
expanded_x_shape = x.type.shape
26352639
for i, (idx, dim_length) in enumerate(
26362640
zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)
26372641
):

tests/tensor/test_subtensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,6 +1856,7 @@ def test_static_shape(self):
18561856

18571857
assert x[idx1].type.shape == (10, None)
18581858
assert x[:, idx1].type.shape == (None, 10)
1859+
assert x[None, :, idx1].type.shape == (1, None, 10)
18591860
assert x[idx2, :5].type.shape == (3, None, None)
18601861
assert specify_shape(x, (None, 7))[idx2, :5].type.shape == (3, None, 5)
18611862
assert specify_shape(x, (None, 3))[idx2, :5].type.shape == (3, None, 3)

0 commit comments

Comments
 (0)