Skip to content

Commit 8c565aa

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Error in jax.make_mesh if we detect multi-slice topology
PiperOrigin-RevId: 829634731
1 parent ab11905 commit 8c565aa

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

jax/_src/sharding_impls.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,11 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
11911191
mesh_devices = mesh_utils.create_device_mesh(
11921192
new_axis_shapes, devices,
11931193
allow_split_physical_axes=allow_split_physical_axes)
1194+
if (hasattr(mesh_devices.flat[0], 'slice_index') and
1195+
len({d.slice_index for d in mesh_devices.flat}) > 1):
1196+
raise ValueError(
1197+
'`jax.make_mesh` does not support multi-slice topologies. Please use'
1198+
' jax.experimental.mesh_utils.create_hybrid_device_mesh')
11941199
if axis_types is None:
11951200
if deprecations.is_accelerated('jax-make-mesh-default-explicit'):
11961201
axis_types = (mesh_lib.AxisType.Explicit,) * len(mesh_devices.shape)

0 commit comments

Comments
 (0)