Skip to content

Commit 7297115

Browse files
author
jax authors
committed
Merge pull request #10546 from jakevdp:unravel-indices
PiperOrigin-RevId: 446553390
2 parents a8c6742 + 3c2d2b2 commit 7297115

File tree

2 files changed

+31
-27
lines changed

2 files changed

+31
-27
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -808,22 +808,22 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'):
808808

809809
_UNRAVEL_INDEX_DOC = """\
810810
Unlike numpy's implementation of unravel_index, negative indices are accepted
811-
and out-of-bounds indices are clipped.
811+
and out-of-bounds indices are clipped into the valid range.
812812
"""
813813

814814
@_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
815815
def unravel_index(indices, shape):
816816
_check_arraylike("unravel_index", indices)
817-
sizes = append(array(shape), 1)
818-
cumulative_sizes = cumprod(sizes[::-1])[::-1]
819-
total_size = cumulative_sizes[0]
820-
# Clip so raveling and unraveling an oob index will not change the behavior
821-
clipped_indices = clip(indices, -total_size, total_size - 1)
822-
# Add enough trailing dims to avoid conflict with clipped_indices
823-
cumulative_sizes = expand_dims(cumulative_sizes, range(1, 1 + _ndim(indices)))
824-
clipped_indices = expand_dims(clipped_indices, axis=0)
825-
idx = clipped_indices % cumulative_sizes[:-1] // cumulative_sizes[1:]
826-
return tuple(idx)
817+
shape = atleast_1d(shape)
818+
if shape.ndim != 1:
819+
raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")
820+
out_indices = [None] * len(shape)
821+
for i, s in reversed(list(enumerate(shape))):
822+
indices, out_indices[i] = divmod(indices, s)
823+
oob_pos = indices > 0
824+
oob_neg = indices < -1
825+
return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i))
826+
for s, i in zip(shape, out_indices))
827827

828828
@_wraps(np.resize)
829829
@partial(jit, static_argnames=('new_shape',))

tests/lax_numpy_test.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4216,22 +4216,26 @@ def jnp_fun(a, c):
42164216
else:
42174217
self._CompileAndCheck(jnp_fun, args_maker)
42184218

4219-
@parameterized.parameters(
4220-
(0, (2, 1, 3)),
4221-
(5, (2, 1, 3)),
4222-
(0, ()),
4223-
(np.array([0, 1, 2]), (2, 2)),
4224-
(np.array([[[0, 1], [2, 3]]]), (2, 2)))
4225-
def testUnravelIndex(self, flat_index, shape):
4226-
args_maker = lambda: (flat_index, shape)
4227-
np_fun = jtu.with_jax_dtype_defaults(np.unravel_index, use_defaults=not hasattr(flat_index, 'dtype'))
4228-
self._CheckAgainstNumpy(np_fun, jnp.unravel_index, args_maker)
4229-
self._CompileAndCheck(jnp.unravel_index, args_maker)
4230-
4231-
def testUnravelIndexOOB(self):
4232-
self.assertEqual(jnp.unravel_index(2, (2,)), (1,))
4233-
self.assertEqual(jnp.unravel_index(-2, (2, 1, 3,)), (1, 0, 1))
4234-
self.assertEqual(jnp.unravel_index(-3, (2,)), (0,))
4219+
@parameterized.named_parameters(jtu.cases_from_list(
4220+
{"testcase_name": "_shape={}_idx={}".format(shape,
4221+
jtu.format_shape_dtype_string(idx_shape, dtype)),
4222+
"shape": shape, "idx_shape": idx_shape, "dtype": dtype}
4223+
for shape in nonempty_nonscalar_array_shapes
4224+
for dtype in int_dtypes
4225+
for idx_shape in all_shapes))
4226+
def testUnravelIndex(self, shape, idx_shape, dtype):
4227+
size = prod(shape)
4228+
rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3)
4229+
4230+
def np_fun(index, shape):
4231+
# Adjust out-of-bounds behavior to match jax's documented behavior.
4232+
index = np.clip(index, -size, size - 1)
4233+
index = np.where(index < 0, index + size, index)
4234+
return np.unravel_index(index, shape)
4235+
jnp_fun = jnp.unravel_index
4236+
args_maker = lambda: [rng(idx_shape, dtype), shape]
4237+
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
4238+
self._CompileAndCheck(jnp_fun, args_maker)
42354239

42364240
def testAstype(self):
42374241
rng = self.rng()

0 commit comments

Comments
 (0)