@@ -80,10 +80,9 @@ def get_tensorstore_spec(ckpt_path: str):
8080 return spec
8181
8282
83- async def async_serialize (ckpt_path : str , gda : gda .GlobalDeviceArray ,
84- tensorstore_spec ):
83+ async def async_serialize (gda_inp : gda .GlobalDeviceArray , tensorstore_spec ):
8584 if not tensorstore_spec .get ('metadata' ):
86- tensorstore_spec ['metadata' ] = _get_metadata (gda )
85+ tensorstore_spec ['metadata' ] = _get_metadata (gda_inp )
8786
8887 t = await ts .open (
8988 ts .Spec (tensorstore_spec ),
@@ -97,19 +96,19 @@ async def _write_array(shard):
9796 if shard .replica_id == 0 :
9897 await t [shard .index ].write (shard .data )
9998
100- future_write_state = jax .tree_util .tree_map (_write_array , tuple (gda .local_shards ))
99+ future_write_state = jax .tree_util .tree_map (_write_array ,
100+ tuple (gda_inp .local_shards ))
101101 return await asyncio .gather (* future_write_state )
102102
103103
104- def run_serialization (ckpt_paths , gdas , tensorstore_specs ):
104+ def run_serialization (gdas , tensorstore_specs ):
105105 async def _run_serializer ():
106- future_writer = jax .tree_map (async_serialize , ckpt_paths , gdas ,
107- tensorstore_specs )
106+ future_writer = jax .tree_map (async_serialize , gdas , tensorstore_specs )
108107 return await asyncio .gather (* future_writer )
109108 asyncio .run (_run_serializer ())
110109
111110
112- async def async_deserialize (ckpt_path , mesh , mesh_axes , tensorstore_spec ):
111+ async def async_deserialize (mesh , mesh_axes , tensorstore_spec ):
113112 t = ts .open (ts .Spec (tensorstore_spec ), open = True ).result ()
114113
115114 async def cb (index ):
@@ -118,9 +117,9 @@ async def cb(index):
118117 return await create_async_gda_from_callback (t .shape , mesh , mesh_axes , cb )
119118
120119
121- def run_deserialization (ckpt_paths , global_meshes , mesh_axes , tensorstore_specs ):
120+ def run_deserialization (global_meshes , mesh_axes , tensorstore_specs ):
122121 async def _run_deserializer ():
123- future_gdas = jax .tree_map (async_deserialize , ckpt_paths , global_meshes ,
124- mesh_axes , tensorstore_specs )
122+ future_gdas = jax .tree_map (async_deserialize , global_meshes , mesh_axes ,
123+ tensorstore_specs )
125124 return await asyncio .gather (* future_gdas )
126125 return asyncio .run (_run_deserializer ())
0 commit comments