74 changes: 62 additions & 12 deletions checkpoint/orbax/checkpoint/_src/multihost/multislice.py
@@ -136,9 +136,9 @@ def in_replica(
  )


-@functools.partial(jax.jit, static_argnums=(0, 1, 2))
-def fake_zero_data(sharding, shape, dtype=jnp.float32) -> jax.Array:
-  x = jnp.zeros(shape, dtype=dtype)
+@functools.partial(jax.jit, static_argnums=0)
+def fake_zero_data(sharding, x):
+  x = jnp.zeros_like(x)
  return jax.lax.with_sharding_constraint(x, sharding)
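A minimal sketch (not part of the diff) of how the reworked helper is called: the sharding stays a static argument, while the array itself is now traced, so shape and dtype are picked up from the input via jnp.zeros_like instead of being passed as static arguments. The single-host mesh below is made up for the example.

import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=0)
def fake_zero_data(sharding, x):
  x = jnp.zeros_like(x)
  return jax.lax.with_sharding_constraint(x, sharding)


# Hypothetical single-host setup, for illustration only.
mesh = jax.sharding.Mesh(jax.devices()[:1], ('data',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
x = jax.device_put(jnp.arange(8, dtype=jnp.float32), sharding)
zeros = fake_zero_data(sharding, x)  # zeros with x's shape, dtype and sharding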


@@ -334,10 +334,61 @@ def broadcast_one_replica_to_all(
    - pytree with broadcasted data
    - number of broadcasts performed.
  """
  num_replicas = global_mesh.devices.shape[replica_axis_index]
  replica_axis_name = global_mesh.axis_names[replica_axis_index]

  if memory_limit_bytes is None:
    memory_limit_bytes = get_available_memory(in_tree, memory_scaling_factor)
    logging.info('Using available memory of %d bytes.', memory_limit_bytes)

  # Set replica_axis to be 0, regardless of its actual value.
  def globalize_single_replica_arrays(inp):
    sharding = inp.sharding
    if not isinstance(sharding, jax.sharding.NamedSharding):
      raise ValueError(
          'Must provide input arrays with NamedSharding. '
          f'Got {type(sharding)} instead.'
      )
    if not is_source:
      inp = fake_zero_data(sharding, inp)
    inp = jnp.expand_dims(inp, axis=0)

    num_slices = slice_count()
    if num_slices != num_replicas:
      logging.info(
          'num_slices: %d != num_replicas: %d', num_slices, num_replicas
      )
      in_spec = jax.sharding.PartitionSpec(
          'replication',
          *sharding.spec,
      )
      global_shape = (num_slices,) + inp.shape[1:]
      slice_global_mesh = jax.sharding.Mesh(
          global_mesh.devices.reshape((
              num_slices,
              *(
                  -1 if i == replica_axis_index else n
                  for i, n in enumerate(global_mesh.devices.shape)
              ),
          )),
          ('replication', *global_mesh.axis_names),
      )
      global_sharding = jax.sharding.NamedSharding(slice_global_mesh, in_spec)
    else:
      assert replica_axis_name not in sharding.spec, (
          f'Replica axis name {replica_axis_name} already exists in'
          f' sharding.spec {sharding.spec}'
      )
      in_spec = jax.sharding.PartitionSpec(
          replica_axis_name,
          *sharding.spec,
      )
      global_shape = (num_replicas,) + inp.shape[1:]
      global_sharding = jax.sharding.NamedSharding(global_mesh, in_spec)
    return jax.make_array_from_single_device_arrays(
        global_shape, global_sharding, [s.data for s in inp.addressable_shards]
    )

  tree_len = len(in_tree)
  start = 0
  out_tree = []
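A small sketch (not part of the diff) of the device-grid reshape used above when slice_count() differs from the number of replicas: a leading 'replication' dimension is split out while the axis at replica_axis_index absorbs the remainder. The 4x4 stand-in grid and axis names are made up for the example.

import numpy as np

# Stand-in for global_mesh.devices: axes ('data', 'model'), replicas on 'data'.
replica_axis_index = 0
num_slices = 2
devices = np.arange(16).reshape(4, 4)

reshaped = devices.reshape((
    num_slices,
    *(
        -1 if i == replica_axis_index else n
        for i, n in enumerate(devices.shape)
    ),
))
print(reshaped.shape)  # (2, 2, 4) -> mesh axes ('replication', 'data', 'model')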
@@ -365,20 +416,19 @@ def broadcast_one_replica_to_all(
      end += 1
    subtree = tuple(subtree)
    num_broadcasts += 1
-    globalized_sharded_subtree = jax.tree.map(
-        functools.partial(
-            _globalize_single_replica_arrays,
-            global_mesh=global_mesh,
-            replica_axis_index=replica_axis_index,
-            is_source=is_source,
+    out_sharding = jax.tree.map(
+        lambda x: jax.sharding.NamedSharding(
+            global_mesh, jax.sharding.PartitionSpec(*x.sharding.spec)
        ),
        subtree,
    )
+    in_tree_sharded = jax.tree.map(globalize_single_replica_arrays, subtree)
    # Delete immediately to conserve memory.
    jax.tree.map(lambda x: x.delete(), subtree)
-    out_subtree = _merge_globalized_replicas(
-        globalized_sharded_subtree, global_mesh
-    )
+    out_subtree = jax.jit(
+        lambda tree: jax.tree.map(functools.partial(jnp.sum, axis=0), tree),
+        out_shardings=out_sharding,
+    )(in_tree_sharded)
    out_tree.extend(out_subtree)
    jax.block_until_ready(out_subtree)
    start = end
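The merge step above is expressed as a plain sum over the leading replica axis: non-source replicas contributed zeros (via fake_zero_data), so summing axis 0 under jit, with out_shardings restoring the per-array sharding, reproduces the source data on every replica. A toy single-device sketch of that idea follows; the two-row stack and names are illustrative only.

import functools

import jax
import jax.numpy as jnp

source = jnp.arange(4.0)
stacked = jnp.stack([source, jnp.zeros_like(source)])  # row 0 plays the source replica

mesh = jax.sharding.Mesh(jax.devices()[:1], ('data',))
out_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

broadcast = jax.jit(
    functools.partial(jnp.sum, axis=0), out_shardings=out_sharding
)(stacked)
# broadcast now equals `source`, whichever row held the real data.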
@@ -3,10 +3,9 @@ suite_name: "Multislice Broadcast Benchmark"

mesh_config:
  mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
-  # ICI: Within a slice.
-  ici_parallelism: {"fsdp": 16, "data": 16}
-  # DCN: Across slices.
-  dcn_parallelism: {"data": 2}
+  ici_parallelism: {"fsdp": 32, "tensor": 1, "data": 2}
+  dcn_parallelism: {"data": 2} # num_slices on the axis at replica_axis_index
+  allow_split_physical_axes: true

checkpoint_config:
  spec:
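For reference, the new within-slice layout multiplies out to fsdp 32 × tensor 1 × data 2 = 64 devices per slice, while the dcn "data": 2 axis (the replica axis) spans two slices, so the benchmark presumably targets 128 devices overall; the previous fsdp 16 × data 16 layout implied 256 devices per slice. These totals are inferred from the parallelism products, not stated in the config.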