@@ -519,6 +519,7 @@ def _device_put_impl(
519519 if aval is None :
520520 try :
521521 aval = core .abstractify (x )
522+ aval = update_dp_aval (aval , device )
522523 except TypeError as err :
523524 raise TypeError (
524525 f"Argument '{ x } ' of type { type (x )} is not a valid JAX type" ) from err
@@ -557,11 +558,11 @@ def _batched_device_put_impl(
557558 devices : Sequence [Device | Sharding | Format | None ],
558559 srcs : Sequence [Device | Sharding | Format | None ],
559560 copy_semantics : Sequence [ArrayCopySemantics ],
560- x_avals : Sequence [core .ShapedArray | None ]):
561+ dst_avals : Sequence [core .ShapedArray | None ]):
561562 ys = []
562563 dsa_indices , dsa_xs , dsa_shardings , dsa_copy_semantics = [], [], [], []
563564 for i , (x , device , src , cp , aval ) in enumerate (
564- zip (xs , devices , srcs , copy_semantics , x_avals )):
565+ zip (xs , devices , srcs , copy_semantics , dst_avals )):
565566 y = _device_put_impl (x , device = device , src = src , copy = cp , aval = aval )
566567 if isinstance (y , _DeferredShardArg ):
567568 dsa_indices .append (i )
@@ -589,24 +590,29 @@ def batched_device_put_impl(
589590 copy_semantics : Sequence [ArrayCopySemantics ]):
590591 return _batched_device_put_impl (
591592 * xs , devices = devices , srcs = srcs , copy_semantics = copy_semantics ,
592- x_avals = [None ] * len (devices ))
593+ dst_avals = [None ] * len (devices ))
593594
594595
595596device_put_p = core .Primitive ('device_put' )
596597device_put_p .multiple_results = True
597598device_put_p .def_impl (batched_device_put_impl )
598599
599600
601+ def update_dp_aval (aval , d ):
602+ if not isinstance (aval , core .ShapedArray ):
603+ return aval
604+ if isinstance (d , Sharding ):
605+ aval = (aval .update (sharding = aval .sharding .update (mesh = d .mesh .abstract_mesh ))
606+ if isinstance (d , NamedSharding ) else aval .update (sharding = None ))
607+ if d .memory_kind is not None :
608+ aval = aval .update (memory_space = core .mem_kind_to_space (d .memory_kind ))
609+ return aval
610+ elif isinstance (d , core .MemorySpace ):
611+ return aval .update (memory_space = d )
612+ return aval
613+
600614def _device_put_abstract_eval (* xs , devices , srcs , copy_semantics ):
601- out = []
602- for x , d in zip (xs , devices ):
603- if isinstance (d , Sharding ) and d .memory_kind is not None :
604- out .append (x .update (memory_space = core .mem_kind_to_space (d .memory_kind )))
605- elif isinstance (d , core .MemorySpace ):
606- out .append (x .update (memory_space = d ))
607- else :
608- out .append (x )
609- return out
615+ return [update_dp_aval (x , d ) for x , d in zip (xs , devices )]
610616device_put_p .def_abstract_eval (_device_put_abstract_eval )
611617
612618def _device_put_transpose (cts , * _ , devices , srcs , copy_semantics ):
0 commit comments