Checkpoints do not preserve AxisType.Explicit meshes #2408

@jkyl

Hello,

It seems that when a sharded array's mesh uses jax.sharding.AxisType.Explicit axis types, ocp's save/restore methods do not preserve this metadata. Here's a minimal repro:

import tempfile

import jax
import orbax.checkpoint as ocp
from jax.sharding import PartitionSpec as P


# Spoof 4 CPU devices.
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_num_cpu_devices", 4)


# 2D mesh over all devices.
explicit_mesh = jax.make_mesh(
    (2, 2),
    ("data", "model"),
    (jax.sharding.AxisType.Explicit, jax.sharding.AxisType.Explicit),
)


@jax.jit
def init() -> jax.Array:
    """Initialize a distributed array."""
    return jax.numpy.arange(4).reshape((2, 2), out_sharding=P("data", "model"))


def main():
    """Test that sharding is preserved across save and restore."""
    with (
        tempfile.TemporaryDirectory(delete=True) as temp_dir,
        jax.set_mesh(explicit_mesh)
    ):
        distributed_array = init()
        checkpointer = ocp.StandardCheckpointer()
        checkpointer.save(f"{temp_dir}/ocp_distributed_test", (distributed_array,))
        checkpointer.wait_until_finished()
        restored_array, = checkpointer.restore(f"{temp_dir}/ocp_distributed_test")
        assert \
            distributed_array.sharding == restored_array.sharding, \
            (distributed_array.sharding, restored_array.sharding)


if __name__ == "__main__":
    main()

This fails with:

AssertionError: (NamedSharding(mesh=Mesh('data': 2, 'model': 2, axis_types=(Explicit, Explicit)), spec=PartitionSpec('data', 'model'), memory_kind=device), NamedSharding(mesh=Mesh('data': 2, 'model': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('data', 'model'), memory_kind=device))

The original and restored arrays differ only in the axis_types of their mesh: the restored mesh uses Auto even though the original used Explicit.

This is a problem when restoring an explicitly sharded model from a checkpoint: a program that ordinarily runs under jax.set_mesh with explicit axes will fail if the restored parameters are not explicitly sharded.
