Skip to content

Commit a4bb971

Browse files
committed
local_sgd: initial version of fault tolerant LocalSGD
1 parent a484e4f commit a4bb971

File tree

5 files changed

+283
-5
lines changed

5 files changed

+283
-5
lines changed

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ the entire training job.
1717
manager
1818
optim
1919
ddp
20+
local_sgd
2021
data
2122
checkpointing
2223
parameter_server

docs/source/local_sgd.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
.. automodule:: torchft.local_sgd
2+
:members:
3+
:undoc-members:
4+
:show-inheritance:

torchft/local_sgd.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
LocalSGD
9+
=========
10+
11+
This module implements a fault tolerant version of LocalSGD and related methods.
12+
"""
13+
14+
from typing import Any, Dict, Mapping, Optional
15+
16+
import torch
17+
from torch import nn
18+
19+
from torchft.manager import Manager
20+
21+
22+
class LocalSGD(nn.Module):
23+
"""
24+
LocalSGD is a model wrapper similar to DistributedDataParallel that
25+
implements the algorithm described in https://arxiv.org/pdf/1805.09767
26+
27+
This will synchronize the model parameters periodically in a fault tolerant
28+
way using a torchft Manager.
29+
30+
This expects you to call step() on every step of the training loop after
31+
the optimizer step. This will then call the allreduce on the gradients
32+
every sync_every steps.
33+
34+
To implement safe and fault tolerant, this requires a backup copy of the
35+
weights. By default these are stored in CPU memory. If any error occurs
36+
during the LocalSGD step, the step will be discarded and the model
37+
parameters will reset back to the last time LocalSGD synchronized.
38+
39+
The backup weights could be eliminated by relaxing the guarantee of exactly
40+
`sync_every` steps but that would diverge from the LocalSGD algorithm.
41+
DiLoCo also needs this backup copy to compute the delta.
42+
43+
TODO: add DiLoCo support
44+
45+
The torchft quorum is computed at the beginning of ``sync_every`` steps. If
46+
any error occurs, or a worker fails between syncs, ``sync_every`` steps will be
47+
discarded and a new quorum will be computed on the next step.
48+
49+
TODO: add a way via Manager to detect workers heartbeats failing early
50+
51+
If running in async mode, on a joining worker the first ``sync_every`` steps
52+
will discarded as the model will be recovering during that period. When
53+
using sync mode, the checkpoint will be restored prior to the first step.
54+
"""
55+
56+
def __init__(
57+
self,
58+
manager: Manager,
59+
model: nn.Module,
60+
sync_every: int,
61+
backup_device: Optional[torch.device] = None,
62+
) -> None:
63+
"""
64+
Args:
65+
manager: The manager to use.
66+
model: The model to wrap.
67+
sync_every: How often to sync the model weights.
68+
backup_device: The device to store the backup of the model parameters on. (default cpu)
69+
"""
70+
super().__init__()
71+
72+
self._manager = manager
73+
self._model = model
74+
self._local_step = 0
75+
self._started_step = False
76+
self._sync_every = sync_every
77+
assert sync_every >= 1, "sync_every must be greater than or equal to 1"
78+
79+
device = backup_device or torch.device("cpu")
80+
81+
self._backup_parameters: Dict[str, torch.Tensor] = {}
82+
83+
for name, p in self._model.named_parameters():
84+
t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
85+
if t.device == torch.device("cpu"):
86+
t = t.pin_memory()
87+
self._backup_parameters[name] = t
88+
89+
# Need to copy the parameters to the host to be safe if we are on the first step.
90+
self._save_parameters()
91+
92+
def _save_parameters(self) -> None:
93+
# TODO: consider running copy on a separate stream
94+
for name, p in self._model.named_parameters():
95+
self._backup_parameters[name].copy_(p.data, non_blocking=True)
96+
97+
def _restore_parameters(self) -> None:
98+
# TODO: consider running copy on a separate stream
99+
for name, p in self._model.named_parameters():
100+
p.data.copy_(self._backup_parameters[name], non_blocking=True)
101+
102+
# pyre-fixme[14]: support state_dict args
103+
def state_dict(self) -> Dict[str, object]:
104+
"""
105+
state_dict returns the state_dict from the last time LocalSGD
106+
synchronized and not the current weights.
107+
"""
108+
state_dict = self._model.state_dict()
109+
for name, p in self._backup_parameters.items():
110+
assert name in state_dict
111+
state_dict[name] = p
112+
return state_dict
113+
114+
def load_state_dict(
115+
self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
116+
) -> None:
117+
"""
118+
Loads the state dict to the model and the backup parameters.
119+
120+
This must be called while the model weights aren't being modified to
121+
avoid corrupting the backup weights.
122+
"""
123+
self._model.load_state_dict(state_dict, strict=strict, assign=assign)
124+
self._save_parameters()
125+
126+
def forward(self, *args: object, **kwargs: object) -> object:
127+
"""
128+
Run the model parameters.
129+
130+
This should be called before the optimizer step.
131+
132+
This will start the quorum and save the parameters if this is the first step.
133+
"""
134+
if self._local_step == 0:
135+
self._manager.start_quorum()
136+
137+
self._started_step = True
138+
139+
return self._model.forward(*args, **kwargs)
140+
141+
def step(self) -> None:
142+
"""
143+
This should be called after the optimizer step.
144+
145+
This will call the allreduce on the model weights every sync_every steps.
146+
If any errors occur it will restore to the weights from the previous sync.
147+
148+
``forward`` must be called before this function.
149+
"""
150+
assert self._started_step, "forward must be called before step"
151+
self._started_step = False
152+
153+
self._local_step += 1
154+
155+
if self._local_step >= self._sync_every:
156+
self._local_step = 0
157+
self._average()
158+
159+
if self._manager.should_commit():
160+
# save the parameters so we can restore from them later if necessary.
161+
self._save_parameters()
162+
else:
163+
# commit failed, restore from the backup parameters
164+
self._restore_parameters()
165+
166+
def _average(self) -> None:
167+
# TODO: do we need to broadcast buffers like DDP does?
168+
169+
works = []
170+
171+
for p in self._model.parameters():
172+
# TODO: bucketize parameters
173+
works.append(self._manager.allreduce_grad(p))
174+
175+
for work in works:
176+
work.wait()

torchft/local_sgd_test.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict
8+
from unittest import TestCase
9+
from unittest.mock import create_autospec
10+
11+
import torch
12+
from torch import nn, optim
13+
14+
from torchft.local_sgd import LocalSGD
15+
from torchft.manager import Manager
16+
17+
18+
class SimpleModel(nn.Module):
19+
def __init__(self) -> None:
20+
super().__init__()
21+
22+
self.model = nn.Sequential(
23+
nn.Linear(3, 4),
24+
nn.ReLU(),
25+
nn.Linear(4, 5),
26+
nn.Sigmoid(),
27+
)
28+
29+
def forward(self, x: torch.Tensor) -> torch.Tensor:
30+
return self.model(x)
31+
32+
33+
def _params_dict(m: torch.nn.Module) -> Dict[str, torch.Tensor]:
34+
return {name: p.data for name, p in m.named_parameters()}
35+
36+
37+
def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
38+
return {name: value.clone().detach() for name, value in state_dict.items()}
39+
40+
41+
class LocalSGDTest(TestCase):
42+
def test_local_sgd_healthy(self) -> None:
43+
base_m = SimpleModel()
44+
optimizer = optim.SGD(base_m.parameters())
45+
manager = create_autospec(Manager)
46+
47+
m = LocalSGD(manager, base_m, sync_every=2)
48+
49+
torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
50+
51+
inp = torch.rand(2, 3)
52+
53+
loss = m(inp).mean()
54+
loss.backward()
55+
optimizer.step()
56+
57+
m.step()
58+
self.assertEqual(m._local_step, 1)
59+
self.assertEqual(manager.start_quorum.call_count, 1)
60+
61+
loss = m(inp).mean()
62+
loss.backward()
63+
optimizer.step()
64+
65+
manager.should_commit.return_value = True
66+
m.step()
67+
self.assertEqual(m._local_step, 0)
68+
69+
torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
70+
self.assertEqual(manager.should_commit.call_count, 1)
71+
self.assertEqual(manager.allreduce_grad.call_count, 4)
72+
73+
def test_local_sgd_recovery(self) -> None:
74+
base_m = SimpleModel()
75+
optimizer = optim.SGD(base_m.parameters())
76+
manager = create_autospec(Manager)
77+
78+
m = LocalSGD(manager, base_m, sync_every=2)
79+
80+
torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
81+
og_state_dict = _copy_state_dict(base_m.state_dict())
82+
83+
inp = torch.rand(2, 3)
84+
85+
loss = m(inp).mean()
86+
loss.backward()
87+
optimizer.step()
88+
89+
m.step()
90+
self.assertEqual(m._local_step, 1)
91+
92+
state_dict = m.state_dict()
93+
torch.testing.assert_close(state_dict, m._backup_parameters)
94+
torch.testing.assert_close(state_dict, og_state_dict)
95+
96+
m.load_state_dict(state_dict)
97+
torch.testing.assert_close(_params_dict(base_m), state_dict)
98+
torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))

torchft/manager.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,11 @@ def start_quorum(self, allow_heal: bool = True) -> None:
324324
It's best practice to call this before the forwards pass of each step for
325325
performance as computing quorum may take some time.
326326
327-
If allow_heal is set, the manager will attempt to heal either
328-
synchronously before returning or asynchronously prior to any network
329-
calls.
330-
331327
Args:
332-
allow_heal: whether to allow healing at the beginning of the step
328+
allow_heal: (experimental) whether to allow healing at the beginning of the step
329+
If allow_heal is set, the manager will attempt to heal either
330+
synchronously before returning or asynchronously prior to any network
331+
calls. All replicas must pass the same value to allow_heal.
333332
"""
334333

335334
# wait for previous quorum to complete

0 commit comments

Comments
 (0)