Hello,
It seems that when using jax.sharding.AxisType.Explicit axis types in a sharded array's mesh, ocp's save/restore methods do not preserve this metadata. Here's a minimal repro:
import tempfile

import jax
import orbax.checkpoint as ocp
from jax.sharding import PartitionSpec as P

# Spoof 4 CPU devices.
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_num_cpu_devices", 4)

# 2D mesh over all devices.
explicit_mesh = jax.make_mesh(
    (2, 2),
    ("data", "model"),
    (jax.sharding.AxisType.Explicit, jax.sharding.AxisType.Explicit),
)


@jax.jit
def init() -> jax.Array:
    """Initialize a distributed array."""
    return jax.numpy.arange(4).reshape((2, 2), out_sharding=P("data", "model"))


def main():
    """Test that sharding is preserved across save and restore."""
    with (
        tempfile.TemporaryDirectory(delete=True) as temp_dir,
        jax.set_mesh(explicit_mesh)
    ):
        distributed_array = init()
        checkpointer = ocp.StandardCheckpointer()
        checkpointer.save(f"{temp_dir}/ocp_distributed_test", (distributed_array,))
        checkpointer.wait_until_finished()
        restored_array, = checkpointer.restore(f"{temp_dir}/ocp_distributed_test")
        assert \
            distributed_array.sharding == restored_array.sharding, \
            (distributed_array.sharding, restored_array.sharding)


if __name__ == "__main__":
    main()
This fails with:
AssertionError: (NamedSharding(mesh=Mesh('data': 2, 'model': 2, axis_types=(Explicit, Explicit)), spec=PartitionSpec('data', 'model'), memory_kind=device), NamedSharding(mesh=Mesh('data': 2, 'model': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('data', 'model'), memory_kind=device))
The original and restored arrays differ in the axis_types of their mesh: the restored version is set to Auto even though the original was Explicit.
This is problematic when restoring an explicitly-sharded model from a checkpoint. If the program ordinarily runs under jax.set_mesh with explicit axes, it will fail if the parameters are not explicitly sharded.