Skip to content

Commit 32131d0

Browse files
author
jax authors
committed
Merge pull request #22897 from jakevdp:bool-indexing
PiperOrigin-RevId: 660444193
2 parents 6fc57c0 + b45f0fe commit 32131d0

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8340,7 +8340,10 @@ def _expand_bool_indices(idx, shape):
83408340
i_shape = _shape(i)
83418341
start = len(out) + ellipsis_offset - newaxis_offset
83428342
expected_shape = shape[start: start + _ndim(i)]
8343-
if i_shape != expected_shape:
8343+
if len(i_shape) != len(expected_shape):
8344+
raise IndexError(f"too many boolean indices at index {dim_number}: got mask of shape "
8345+
f"{i_shape}, but only {len(expected_shape)} dimensions remain.")
8346+
if not all(s1 in (0, s2) for s1, s2 in zip(i_shape, expected_shape)):
83448347
raise IndexError("boolean index did not match shape of indexed array in index "
83458348
f"{dim_number}: got {i_shape}, expected {expected_shape}")
83468349
out.extend(np.where(i))

tests/lax_numpy_indexing_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,23 @@ def testNontrivialBooleanIndexing(self):
10301030
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
10311031
self._CompileAndCheck(jnp_fun, args_maker)
10321032

1033+
@parameterized.parameters(
1034+
[(3,), (0,)],
1035+
[(3, 4), (0,)],
1036+
[(3, 4), (0, 4)],
1037+
[(3, 4), (3, 0)],
1038+
[(3, 4, 5), (3, 0)],
1039+
)
1040+
def testEmptyBooleanIndexing(self, x_shape, m_shape):
1041+
# Regression test for https://github.com/google/jax/issues/22886
1042+
rng = jtu.rand_default(self.rng())
1043+
args_maker = lambda: [rng(x_shape, np.int32), np.empty(m_shape, dtype=bool)]
1044+
1045+
np_fun = lambda x, m: np.asarray(x)[np.asarray(m)]
1046+
jnp_fun = lambda x, m: jnp.asarray(x)[jnp.asarray(m)]
1047+
1048+
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
1049+
10331050
@jtu.sample_product(
10341051
shape=[(2, 3, 4, 5)],
10351052
idx=[

0 commit comments

Comments
 (0)