Use of jax.device_put vs jax.lax.with_sharding_constraint
#31105
with_sharding_constraint can be used in eager mode too, but there it's just an identity jit. Under a jit, with_sharding_constraint is a strict constraint that the compiler has to respect.
Usually the recommendation is to use device_put outside of jit.
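A minimal sketch of that split, assuming a recent JAX (the thread mentions 0.7.0): device_put places an array from eager code, and with_sharding_constraint pins the sharding of a value inside a jitted function. The one-axis mesh named "x", the array shape, the function name scale_and_constrain, and the XLA_FLAGS line (which only simulates four CPU devices so the sharding is visible on one machine) are illustrative assumptions, not taken from the thread.

```python
import os

# Assumption: simulate 4 CPU devices so sharding is visible on a single
# machine. This flag must be set before jax is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-axis mesh over all available devices (the axis name "x" is arbitrary).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
row_sharding = NamedSharding(mesh, P("x", None))  # split rows over "x"

n = len(devices)
x = jnp.arange(2 * n * 4, dtype=jnp.float32).reshape(2 * n, 4)

# Eager code: use device_put to place the array across the mesh.
x_sharded = jax.device_put(x, row_sharding)


@jax.jit
def scale_and_constrain(a):
    b = a * 2.0
    # Inside jit: constrain the sharding of this intermediate/output value.
    return jax.lax.with_sharding_constraint(b, row_sharding)


y = scale_and_constrain(x_sharded)
print(y.sharding)
print([s.data.shape for s in y.addressable_shards])
```

Each device should end up holding a (2, 4) block of rows, both for the device_put result and for the jitted output.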
Thanks for the reply, but some things are still not clear to me. When using … My understanding is that …
As the question says, the documentation does not make clear what the different uses of these two functions are. First, some context. The question stems from @mjo22's effort to rewrite the auto-parallelisation tutorial for equinox and from my own attempt at implementing the code in the tutorial, which led me to notice a duplicate use of equinox's filter_shard function.

Essentially, the issue is that, as of 0.7.0, JAX recommends using device_put to shard arrays in eager mode and lax.with_sharding_constraint inside jitted functions. equinox.filter_shard uses the latter, and while testing I found that it works just fine outside jitted functions.

The first question is: is there any downside to using a lax function such as this one in eager mode? My understanding is that there isn't in this case, and that the only difference between the two is that device_put forces a particular sharding layout on the arrays, which the compiler must stick to, while with_sharding_constraint only provides a suggestion, which the compiler needs to consider but is otherwise free to override when optimizing the sharding layout.

If this is true, then how could we achieve the same result in eager mode, that is, before passing the arrays to jitted functions? Is the only way to use in_shardings with Layout.AUTO?

This is important, as it will influence how the equinox filter functions are rewritten, especially eqx.filter_jit.
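To make the eager-mode part of the question concrete, here is a sketch comparing device_put with an eager with_sharding_constraint call, plus jax.jit's in_shardings/out_shardings arguments, which pin shardings at the function boundary. It assumes a recent JAX, a hypothetical one-axis mesh "x", illustrative shapes, and simulated CPU devices via XLA_FLAGS. It only checks where the data ends up; it does not show whether the compiler treats either call as strict or advisory, and it does not touch Layout.AUTO.

```python
import os

# Assumption: simulate 4 CPU devices; must be set before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
row_sharding = NamedSharding(mesh, P("x", None))  # split rows over "x"

n = len(devices)
x = jnp.ones((2 * n, 3), dtype=jnp.float32)

# Eager mode, option 1: device_put, the approach recommended for eager code.
a = jax.device_put(x, row_sharding)

# Eager mode, option 2: with_sharding_constraint outside jit. Per the reply in
# this thread, this effectively runs as an identity jit with that sharding.
b = jax.lax.with_sharding_constraint(x, row_sharding)

# jit can also pin shardings at the function boundary via in_shardings /
# out_shardings; `a` is already laid out as row_sharding, so it is accepted.
add_one = jax.jit(
    lambda t: t + 1.0, in_shardings=row_sharding, out_shardings=row_sharding
)
c = add_one(a)


def shard_shapes(arr):
    """Hypothetical helper: per-device shard shapes, sorted for comparison."""
    return sorted(s.data.shape for s in arr.addressable_shards)


print(shard_shapes(a), shard_shapes(b), shard_shapes(c))
```

In this sketch all three arrays end up split into one (2, 3) block per device, so for placement alone the eager with_sharding_constraint call and device_put agree.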