Commit f0f7faf

make checkpointing thread safe and deterministic
Summary:
- The regression tests (in upcoming changes) expect either no recovery at all or recovery at the very first step, and they validate the parameters at every step, so a recovery that happens non-deterministically makes that validation impossible. Fix this by copying the state dict before transferring it.
- Checkpointing also wasn't thread safe for the HTTP transport, so lock the model in the optimizer's pre-step hook and when transferring the checkpoint.
1 parent b746582 commit f0f7faf
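
How the pieces fit together, as a minimal standalone sketch (illustrative only: it uses a plain threading.Lock in place of torchft's RWLock, and _Trainer/_pre_step/_post_step/snapshot are hypothetical names). The optimizer's step pre-hook takes the lock before parameters are mutated, the post-hook releases it, and the checkpoint path takes the same lock and clones tensors so later steps can't change the transferred copy:

# Illustrative sketch -- not torchft code. torchft uses an RWLock plus
# Manager.disallow_state_dict_read()/allow_state_dict_read() instead of a plain Lock.
import threading
import torch

class _Trainer:
    def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer) -> None:
        self._model = model
        self._lock = threading.Lock()
        # Bracket every optimizer step so the checkpoint thread can't read mid-update.
        optimizer.register_step_pre_hook(self._pre_step)
        optimizer.register_step_post_hook(self._post_step)

    def _pre_step(self, _optim, _args, _kwargs) -> None:
        self._lock.acquire()

    def _post_step(self, _optim, _args, _kwargs) -> None:
        self._lock.release()

    def snapshot(self) -> dict:
        # Called from the checkpoint transport thread: wait out any in-flight step,
        # then clone so later steps don't mutate the transferred state.
        with self._lock:
            return {k: v.detach().clone() for k, v in self._model.state_dict().items()}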

4 files changed: +68 −6 lines changed

torchft/checkpointing/http_transport.py

Lines changed: 11 additions & 1 deletion
@@ -16,6 +16,7 @@
 from typing import Generator, List, Optional, TypeVar, cast

 import torch
+from torch.distributed.tensor import DTensor, distribute_tensor
 from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten

 from torchft.checkpointing._rwlock import RWLock
@@ -266,6 +267,15 @@ def recv_checkpoint(
     return tree_unflatten(values, spec)


+def _clone_cpu_tensor(tensor: torch.Tensor) -> torch.Tensor:
+    if isinstance(tensor, DTensor):
+        return distribute_tensor(
+            tensor.to_local().clone(), tensor.device_mesh, tensor.placements
+        )
+    else:
+        return tensor.clone()
+
+
 def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
     out = []
     for v in values:
@@ -278,7 +288,7 @@ def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
                 else:
                     out.append(v.cpu())
             else:
-                out.append(v)
+                out.append(_clone_cpu_tensor(v))
         else:
             out.append(v)
     return out
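
The new _clone_cpu_tensor call is what makes recovery deterministic: CPU tensors used to be appended to the checkpoint as-is, so the transferred state could still be mutated by in-place optimizer updates on the sender. A small illustration of the difference (plain tensors only; the DTensor branch additionally clones the local shard and re-wraps it with distribute_tensor):

# Illustrative only: shows why aliasing the live parameter breaks determinism.
import torch

param = torch.zeros(2)
aliased = param            # roughly what the old code kept for CPU tensors
cloned = param.clone()     # what _clone_cpu_tensor does for plain tensors

param.add_(1.0)            # a later in-place update on the sender

print(aliased)  # tensor([1., 1.]) -- drifted along with the live parameter
print(cloned)   # tensor([0., 0.]) -- stable snapshot, safe to transfer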

torchft/local_sgd.py

Lines changed: 22 additions & 0 deletions
@@ -86,6 +86,9 @@ def __init__(
         self._hooks: List[RemovableHandle] = []

     def __enter__(self) -> "LocalSGD":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
         # Add optimizer hook which increments the local step counter and syncs if necessary
         self._hooks.append(
             self._local_optimizer.register_step_post_hook(self._step_post_hook)
@@ -106,12 +109,20 @@ def __exit__(

         return False  # Propagate exceptions

+    def _step_pre_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to it thread safe
+        self._manager.disallow_state_dict_read()
+
     def _step_post_hook(
         self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
     ) -> None:
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
+        self._manager.allow_state_dict_read()
+
         self._local_step += 1
         if self._local_step >= self._sync_every:
             self.sync()
@@ -676,12 +687,21 @@ def _restore_parameters(self) -> None:
             fragment.restore_parameters()

     def __enter__(self) -> "DiLoCo":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
         # Add optimizer hook which increments the local step counter and syncs if necessary
         self._hooks.append(
             self._local_optimizer.register_step_post_hook(self._step_post_hook)
         )
         return self

+    def _step_pre_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to it thread safe
+        self._manager.disallow_state_dict_read()
+
     def __exit__(
         self,
         exc_type: Optional[Type[BaseException]],
@@ -716,6 +736,8 @@ def _step_post_hook(
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
+        self._manager.allow_state_dict_read()
+
         # We need to make sure all nodes send the same fragments in order.
         # This is to avoid deadlocking e.g.
         #
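
Both LocalSGD and DiLoCo rely on the standard PyTorch optimizer hook API: the new pre-hook fires before optimizer.step() mutates parameters and the existing post-hook fires after, so the pair brackets every step. A self-contained example of that mechanism (the toy model and print statements are placeholders for the lock calls above, not torchft code):

# Illustrative example of torch.optim's step hooks.
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def pre_step(_optim, _args, _kwargs) -> None:
    print("before step: disallow state dict reads here")

def post_step(_optim, _args, _kwargs) -> None:
    print("after step: allow state dict reads, maybe sync()")

handles = [
    optimizer.register_step_pre_hook(pre_step),
    optimizer.register_step_post_hook(post_step),
]

model(torch.randn(2, 4)).sum().backward()
optimizer.step()        # runs pre_step, the parameter update, then post_step

for handle in handles:  # mirrors the __exit__ cleanup of self._hooks
    handle.remove()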

torchft/local_sgd_integ_test.py

Lines changed: 8 additions & 0 deletions
@@ -36,6 +36,7 @@
 )

 logger: logging.Logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)


 def local_sgd_train_loop(
@@ -143,6 +144,7 @@ def assert_equal_global_state(
                 rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"],
                 rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"],
                 check_device=False,
+                msg=f"{step=} {i=}",
             )
         # Check all outer optimizers
         for i in range(
@@ -574,3 +576,9 @@ def test_streaming_diloco_commit_failure(
         self.assertEqual(
             event_injector.count[EventInjectorEvent.AllreduceFailure], 1
         )
+
+
+if __name__ == "__main__":
+    import unittest
+
+    unittest.main()

torchft/manager.py

Lines changed: 27 additions & 5 deletions
@@ -43,6 +43,7 @@

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
+from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
 from torchft.work import _DummyWork, _WorkWrapper

@@ -204,6 +205,9 @@ def __init__(
         self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
         self._user_state_dicts: Dict[str, Callable[[], object]] = {}

+        # Protects state dict
+        self._state_dict_lock = RWLock(timeout=timeout.total_seconds())
+
         if load_state_dict and state_dict:
             self.register_state_dict_fn("default", load_state_dict, state_dict)

@@ -312,6 +316,21 @@ def __init__(
         # first step is 1
         self._participating_replica_rank: Optional[int] = None
         self._participating_replica_world_size: int = 0
+        self._is_state_dict_read_allowed = True
+
+    def allow_state_dict_read(self) -> None:
+        if self._is_state_dict_read_allowed:
+            return
+
+        self._is_state_dict_read_allowed = True
+        self._state_dict_lock.w_release()
+
+    def disallow_state_dict_read(self) -> None:
+        if not self._is_state_dict_read_allowed:
+            return
+
+        self._is_state_dict_read_allowed = False
+        self._state_dict_lock.w_acquire()

     def register_state_dict_fn(
         self,
@@ -819,11 +838,14 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
         self._batches_committed = state_dict["batches_committed"]

     def _manager_state_dict(self) -> Dict[str, object]:
-        assert len(self._user_state_dicts) > 0, "user state_dict is not initialized."
-        return {
-            "user": {key: value() for key, value in self._user_state_dicts.items()},
-            "torchft": self.state_dict(),
-        }
+        with self._state_dict_lock.r_lock():
+            assert (
+                len(self._user_state_dicts) > 0
+            ), "user state_dict is not initialized."
+            return {
+                "user": {key: value() for key, value in self._user_state_dicts.items()},
+                "torchft": self.state_dict(),
+            }

     def state_dict(self) -> Dict[str, int]:
         """
