JAX is changing the default jax.pmap implementation #32412
danielsuo announced in Announcements
tl;dr
Starting with JAX 0.8.0, the default implementation of `jax.pmap` will be based on `jax.jit` and `jax.shard_map`. The new implementation is not a perfect replacement for the original, so we've published documentation for this change to help users who run into trouble. This change makes `jax.pmap` integrate well with JAX shardings and simplifies the implementation.

Help! Fix me now!
IMPORTANT: Reverting to the old implementation is not a permanent fix. Until January 15, 2026, it will be possible to temporarily use the old version of `jax.pmap` by doing one of the following:

- Setting the shell environment variable `JAX_PMAP_SHMAP_MERGE` to something false-like (e.g., `0`);
- Setting the boolean flag `--jax_pmap_shmap_merge` to something false-like, if your code parses flags with `absl-py`;
- Adding the statement `jax.config.update("jax_pmap_shmap_merge", False)` to your main file, or anywhere before you call `jax.pmap`.

NOTE: If you rely on one of these workarounds, please file a bug and tag @danielsuo with a reproducer so we can resolve the issue as quickly as possible under the new `jax.pmap`.

How do I know I'm broken? What are some examples and fixes?
Please see the documentation for this change here. It walks through a number of typical issues that can come up and how to resolve them.