Skip to content

Commit 3acbd44

Browse files
yashk2810jax authors
authored andcommitted
Remove isinstance checks
PiperOrigin-RevId: 425745786
1 parent dcca99b commit 3acbd44

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

jax/experimental/global_device_array.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _canonicalize_mesh_axes(mesh_axes):
5454
return pspec
5555

5656
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
57-
mesh_axes: MeshAxes) -> Tuple[pxla.Index, ...]:
57+
mesh_axes: MeshAxes) -> Tuple[Index, ...]:
5858
# Import here to avoid cyclic import error when importing gda in pjit.py.
5959
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
6060

@@ -66,11 +66,7 @@ def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
6666
sharding_spec = pxla.mesh_sharding_specs(
6767
global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
6868
indices = pxla.spec_to_indices(global_shape, sharding_spec)
69-
for index in indices:
70-
assert isinstance(index, tuple)
71-
for idx in index:
72-
assert isinstance(idx, slice)
73-
return indices
69+
return indices # type: ignore
7470

7571

7672
@_convert_list_args_to_tuple

0 commit comments

Comments
 (0)