Commit 6093c04

yashk2810 authored and Google-ML-Automation committed
Fix device_put's abstract_eval to return the correct sharding on the type when the sharding passed to it changes. Fixes #31793
PiperOrigin-RevId: 807429077
1 parent a29751e commit 6093c04
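
In user-facing terms, the fix makes the type reported for a device_put result track the sharding that was actually requested. A minimal sketch of the fixed behavior (array shape and device choice are illustrative, not from the commit):

    import jax
    import jax.numpy as jnp
    from jax.sharding import SingleDeviceSharding

    x = jnp.zeros(8)
    y = jax.device_put(x, SingleDeviceSharding(jax.devices()[0]))

    # The aval of y should reflect the sharding device_put was given, not
    # the sharding x started with; returning the stale sharding on the type
    # is the bug this commit fixes.
    print(jax.typeof(y).sharding)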

File tree: 3 files changed (+33, -15 lines)

jax/_src/api.py

Lines changed: 6 additions & 3 deletions
@@ -2719,13 +2719,16 @@ def device_put(
       assert not m and not d
       copy_semantics.append(dispatch.ArrayCopySemantics.ALWAYS_COPY)
 
-  x_avals = tuple(shaped_abstractify(i) for i in x_flat)
-  for aval, d in zip(x_avals, device_flat):
+  dst_avals = []
+  for xf, d in zip(x_flat, device_flat):
+    aval = shaped_abstractify(xf)
+    aval = dispatch.update_dp_aval(aval, d)
+    dst_avals.append(aval)
     _check_sharding(aval, d)
   if core.trace_state_clean():
     out_flat = dispatch._batched_device_put_impl(
         *x_flat, devices=device_flat, srcs=src_flat,  # type: ignore
-        copy_semantics=copy_semantics, x_avals=x_avals)
+        copy_semantics=copy_semantics, dst_avals=dst_avals)
   else:
     out_flat = dispatch.device_put_p.bind(
         *x_flat, devices=tuple(device_flat), srcs=tuple(src_flat),
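
Net effect in api.py: instead of abstracting every input in isolation, the eager path now builds one destination aval per input by running the result of shaped_abstractify through dispatch.update_dp_aval (defined in dispatch.py below), so _check_sharding validates, and _batched_device_put_impl receives, avals that already carry the destination's sharding and memory space. The keyword rename from x_avals to dst_avals reflects that shift.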

jax/_src/dispatch.py

Lines changed: 18 additions & 12 deletions
@@ -519,6 +519,7 @@ def _device_put_impl(
   if aval is None:
     try:
       aval = core.abstractify(x)
+      aval = update_dp_aval(aval, device)
     except TypeError as err:
       raise TypeError(
           f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
@@ -557,11 +558,11 @@ def _batched_device_put_impl(
     devices: Sequence[Device | Sharding | Format | None],
     srcs: Sequence[Device | Sharding | Format | None],
     copy_semantics: Sequence[ArrayCopySemantics],
-    x_avals: Sequence[core.ShapedArray | None]):
+    dst_avals: Sequence[core.ShapedArray | None]):
   ys = []
   dsa_indices, dsa_xs, dsa_shardings, dsa_copy_semantics = [], [], [], []
   for i, (x, device, src, cp, aval) in enumerate(
-      zip(xs, devices, srcs, copy_semantics, x_avals)):
+      zip(xs, devices, srcs, copy_semantics, dst_avals)):
     y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
     if isinstance(y, _DeferredShardArg):
       dsa_indices.append(i)
@@ -589,24 +590,29 @@ def batched_device_put_impl(
     copy_semantics: Sequence[ArrayCopySemantics]):
   return _batched_device_put_impl(
       *xs, devices=devices, srcs=srcs, copy_semantics=copy_semantics,
-      x_avals=[None] * len(devices))
+      dst_avals=[None] * len(devices))
 
 
 device_put_p = core.Primitive('device_put')
 device_put_p.multiple_results = True
 device_put_p.def_impl(batched_device_put_impl)
 
 
+def update_dp_aval(aval, d):
+  if not isinstance(aval, core.ShapedArray):
+    return aval
+  if isinstance(d, Sharding):
+    aval = (aval.update(sharding=aval.sharding.update(mesh=d.mesh.abstract_mesh))
+            if isinstance(d, NamedSharding) else aval.update(sharding=None))
+    if d.memory_kind is not None:
+      aval = aval.update(memory_space=core.mem_kind_to_space(d.memory_kind))
+    return aval
+  elif isinstance(d, core.MemorySpace):
+    return aval.update(memory_space=d)
+  return aval
+
 def _device_put_abstract_eval(*xs, devices, srcs, copy_semantics):
-  out = []
-  for x, d in zip(xs, devices):
-    if isinstance(d, Sharding) and d.memory_kind is not None:
-      out.append(x.update(memory_space=core.mem_kind_to_space(d.memory_kind)))
-    elif isinstance(d, core.MemorySpace):
-      out.append(x.update(memory_space=d))
-    else:
-      out.append(x)
-  return out
+  return [update_dp_aval(x, d) for x, d in zip(xs, devices)]
 device_put_p.def_abstract_eval(_device_put_abstract_eval)
 
 def _device_put_transpose(cts, *_, devices, srcs, copy_semantics):
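
The new update_dp_aval helper is now the single place that decides how a destination rewrites an aval: a NamedSharding keeps the partitioning but swaps in the destination mesh's abstract mesh, any other Sharding clears the tracked sharding, in both cases a memory kind (if present) becomes the aval's memory space, and a bare core.MemorySpace updates only the memory space. A hedged sketch of the observable effect, assuming at least two devices (the explicit-mesh construction mirrors the new test below, not anything in this commit):

    import jax
    import jax.numpy as jnp
    from jax.sharding import (AxisType, NamedSharding, PartitionSpec as P,
                              SingleDeviceSharding)

    # Illustrative explicit 1-D mesh over two devices.
    mesh = jax.make_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))

    a = jax.device_put(jnp.zeros(8), NamedSharding(mesh, P('x')))
    # NamedSharding destination: the type's sharding now lives on the
    # destination mesh's abstract counterpart.
    assert jax.typeof(a).sharding.mesh == mesh.abstract_mesh

    b = jax.device_put(a, SingleDeviceSharding(jax.devices()[0]))
    # Other Sharding destinations drop the tracked sharding, leaving an
    # empty abstract mesh on the type.
    assert jax.typeof(b).sharding.mesh.empty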

tests/pjit_test.py

Lines changed: 9 additions & 0 deletions
@@ -8894,6 +8894,15 @@ def model(kernel, inputs):
     self.assertEqual(out1.sharding, kernel.sharding)
     self.assertEqual(out2.sharding, inputs.sharding)
 
+  @jtu.with_explicit_mesh((2,), 'x')
+  def test_device_put_typeof(self, mesh):
+    array = jnp.zeros(8)
+    self.assertEqual(jax.typeof(array).sharding,
+                     NamedSharding(mesh.abstract_mesh, P(None)))
+
+    array = jax.device_put(array, SingleDeviceSharding(jax.devices()[0]))
+    self.assertTrue(jax.typeof(array).sharding.mesh.empty)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
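
The new test pins down both halves of the fix: under an explicit two-device mesh, a freshly created array's type carries a NamedSharding on the mesh's abstract mesh, and after a device_put to a SingleDeviceSharding the type's mesh is empty, confirming that abstract evaluation now tracks the destination sharding instead of carrying the old one forward (the regression reported in #31793).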
