Skip to content

Commit b32d1f9

Browse files
author
Orbax Authors
committed
Separate sub-item temporary path class from default temporary path class.
PiperOrigin-RevId: 715363908
1 parent 2a7e309 commit b32d1f9

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ def _get_item_temporary_directory(
637637
) -> atomicity_types.TemporaryPath:
638638
temporary_path_class = (
639639
self._temporary_path_class
640-
or atomicity_defaults.get_default_temporary_path_class(directory)
640+
or atomicity_defaults.get_item_default_temporary_path_class(directory)
641641
)
642642
tmp_item_dir = temporary_path_class.from_final(
643643
self._get_item_directory(directory, item_name),

checkpoint/orbax/checkpoint/_src/path/atomicity_defaults.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,20 @@
2626
from orbax.checkpoint._src.path import step as step_lib
2727

2828

29+
def get_item_default_temporary_path_class(
30+
final_path: epath.Path,
31+
) -> Type[atomicity_types.TemporaryPath]:
32+
"""Returns the default temporary path class for a given sub-item path."""
33+
if step_lib.is_gcs_path(final_path):
34+
return atomicity.CommitFileTemporaryPath
35+
else:
36+
return atomicity.AtomicRenameTemporaryPath
37+
38+
2939
def get_default_temporary_path_class(
3040
final_path: epath.Path,
3141
) -> Type[atomicity_types.TemporaryPath]:
42+
"""Returns the default temporary path class for a given checkpoint path."""
3243
if step_lib.is_gcs_path(final_path):
3344
return atomicity.CommitFileTemporaryPath
3445
else:

0 commit comments

Comments
 (0)