Skip to content

Commit dcca99b

Browse files
yashk2810jax authors
authored andcommitted
Remove path from the serde API as tspec encompasses those things.
PiperOrigin-RevId: 425727733
1 parent 4e47de6 commit dcca99b

File tree

2 files changed

+13
-15
lines changed

2 files changed

+13
-15
lines changed

jax/experimental/gda_serialization/serialization.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,9 @@ def get_tensorstore_spec(ckpt_path: str):
8080
return spec
8181

8282

83-
async def async_serialize(ckpt_path: str, gda: gda.GlobalDeviceArray,
84-
tensorstore_spec):
83+
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec):
8584
if not tensorstore_spec.get('metadata'):
86-
tensorstore_spec['metadata'] = _get_metadata(gda)
85+
tensorstore_spec['metadata'] = _get_metadata(gda_inp)
8786

8887
t = await ts.open(
8988
ts.Spec(tensorstore_spec),
@@ -97,19 +96,19 @@ async def _write_array(shard):
9796
if shard.replica_id == 0:
9897
await t[shard.index].write(shard.data)
9998

100-
future_write_state = jax.tree_util.tree_map(_write_array, tuple(gda.local_shards))
99+
future_write_state = jax.tree_util.tree_map(_write_array,
100+
tuple(gda_inp.local_shards))
101101
return await asyncio.gather(*future_write_state)
102102

103103

104-
def run_serialization(ckpt_paths, gdas, tensorstore_specs):
104+
def run_serialization(gdas, tensorstore_specs):
105105
async def _run_serializer():
106-
future_writer = jax.tree_map(async_serialize, ckpt_paths, gdas,
107-
tensorstore_specs)
106+
future_writer = jax.tree_map(async_serialize, gdas, tensorstore_specs)
108107
return await asyncio.gather(*future_writer)
109108
asyncio.run(_run_serializer())
110109

111110

112-
async def async_deserialize(ckpt_path, mesh, mesh_axes, tensorstore_spec):
111+
async def async_deserialize(mesh, mesh_axes, tensorstore_spec):
113112
t = ts.open(ts.Spec(tensorstore_spec), open=True).result()
114113

115114
async def cb(index):
@@ -118,9 +117,9 @@ async def cb(index):
118117
return await create_async_gda_from_callback(t.shape, mesh, mesh_axes, cb)
119118

120119

121-
def run_deserialization(ckpt_paths, global_meshes, mesh_axes, tensorstore_specs):
120+
def run_deserialization(global_meshes, mesh_axes, tensorstore_specs):
122121
async def _run_deserializer():
123-
future_gdas = jax.tree_map(async_deserialize, ckpt_paths, global_meshes,
124-
mesh_axes, tensorstore_specs)
122+
future_gdas = jax.tree_map(async_deserialize, global_meshes, mesh_axes,
123+
tensorstore_specs)
125124
return await asyncio.gather(*future_gdas)
126125
return asyncio.run(_run_deserializer())

jax/experimental/gda_serialization/serialization_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,10 @@ def cb2(index):
6464
ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2)]
6565
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
6666

67-
serialization.run_serialization(ckpt_paths, [gda1, gda2], tspecs)
67+
serialization.run_serialization([gda1, gda2], tspecs)
6868

69-
m1, m2 = serialization.run_deserialization(ckpt_paths,
70-
[global_mesh, global_mesh],
71-
[mesh_axes, ['x']], tspecs)
69+
m1, m2 = serialization.run_deserialization(
70+
[global_mesh, global_mesh], [mesh_axes, ['x']], tspecs)
7271

7372
self.assertArraysEqual(m1.local_shards[0].data.to_py(),
7473
np.array([[0], [2]]))

0 commit comments

Comments
 (0)