@@ -4216,22 +4216,26 @@ def jnp_fun(a, c):
42164216 else :
42174217 self ._CompileAndCheck (jnp_fun , args_maker )
42184218
4219- @parameterized .parameters (
4220- (0 , (2 , 1 , 3 )),
4221- (5 , (2 , 1 , 3 )),
4222- (0 , ()),
4223- (np .array ([0 , 1 , 2 ]), (2 , 2 )),
4224- (np .array ([[[0 , 1 ], [2 , 3 ]]]), (2 , 2 )))
4225- def testUnravelIndex (self , flat_index , shape ):
4226- args_maker = lambda : (flat_index , shape )
4227- np_fun = jtu .with_jax_dtype_defaults (np .unravel_index , use_defaults = not hasattr (flat_index , 'dtype' ))
4228- self ._CheckAgainstNumpy (np_fun , jnp .unravel_index , args_maker )
4229- self ._CompileAndCheck (jnp .unravel_index , args_maker )
4230-
4231- def testUnravelIndexOOB (self ):
4232- self .assertEqual (jnp .unravel_index (2 , (2 ,)), (1 ,))
4233- self .assertEqual (jnp .unravel_index (- 2 , (2 , 1 , 3 ,)), (1 , 0 , 1 ))
4234- self .assertEqual (jnp .unravel_index (- 3 , (2 ,)), (0 ,))
4219+ @parameterized .named_parameters (jtu .cases_from_list (
4220+ {"testcase_name" : "_shape={}_idx={}" .format (shape ,
4221+ jtu .format_shape_dtype_string (idx_shape , dtype )),
4222+ "shape" : shape , "idx_shape" : idx_shape , "dtype" : dtype }
4223+ for shape in nonempty_nonscalar_array_shapes
4224+ for dtype in int_dtypes
4225+ for idx_shape in all_shapes ))
4226+ def testUnravelIndex (self , shape , idx_shape , dtype ):
4227+ size = prod (shape )
4228+ rng = jtu .rand_int (self .rng (), low = - ((2 * size ) // 3 ), high = (2 * size ) // 3 )
4229+
4230+ def np_fun (index , shape ):
4231+ # Adjust out-of-bounds behavior to match jax's documented behavior.
4232+ index = np .clip (index , - size , size - 1 )
4233+ index = np .where (index < 0 , index + size , index )
4234+ return np .unravel_index (index , shape )
4235+ jnp_fun = jnp .unravel_index
4236+ args_maker = lambda : [rng (idx_shape , dtype ), shape ]
4237+ self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker )
4238+ self ._CompileAndCheck (jnp_fun , args_maker )
42354239
42364240 def testAstype (self ):
42374241 rng = self .rng ()
0 commit comments