
Commit 43ef175

manager_integ_tests: added multi rank recovery
1 parent a52d746 commit 43ef175

File tree

2 files changed: +123 -32 lines changed

  torchft/manager.py
  torchft/manager_integ_test.py

torchft/manager.py

Lines changed: 6 additions & 3 deletions
@@ -33,7 +33,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
+from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar
 
 import torch
 from torch.distributed import ReduceOp, TCPStore
@@ -374,16 +374,19 @@ def _async_quorum(self) -> None:
             self._participating_rank = None
 
         if quorum_id != self._quorum_id:
-            logger.info(f"{replica_rank=} reconfiguring for quorum_id {quorum_id}")
             store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
+
+            logger.info(
+                f"{replica_rank=} reconfiguring for {quorum_id=} {store_prefixed_addr=}"
+            )
             # We use the replica rank and world as we want all replicas in the PG.
             self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
             self._quorum_id = quorum_id
 
         # See manager.rs for healing conditions
         if heal:
             self._healing = True
-            logger.info(f"{replica_rank}= healing required")
+            logger.info(f"{replica_rank=} healing required")
 
         logger.info(f"fetching checkpoint server address from {address}")
         primary_client = ManagerClient(address, timeout=self._timeout)
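Note: the reconfigure path now builds and logs a store prefix that includes both the quorum id and the local rank, so keys written to the shared TCPStore by different quorums and ranks cannot collide. A minimal sketch of the prefix scheme in plain Python, just mirroring the f-string above (the example values are assumptions, not torchft API):

    store_address = "localhost:29500"  # assumed example value
    rank = 1

    # Each new quorum gets its own namespace on the shared store.
    for quorum_id in (7, 8):
        store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{rank}"
        print(store_prefixed_addr)
    # localhost:29500/torchft/7/1
    # localhost:29500/torchft/8/1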

torchft/manager_integ_test.py

Lines changed: 117 additions & 29 deletions
@@ -1,6 +1,7 @@
-from concurrent.futures import ThreadPoolExecutor, as_completed
+import threading
+from concurrent.futures import as_completed, ThreadPoolExecutor
 from contextlib import ExitStack
-from typing import Dict, Set, Tuple
+from typing import Dict, List, Set, Tuple
 from unittest import TestCase
 
 import torch
@@ -32,32 +33,74 @@ class InjectedFailure(Exception):
 
 class FailureInjector:
     def __init__(self) -> None:
-        self._failures: Set[int] = set()
+        self._lock = threading.Lock()
+        self._failures: Set[Tuple[int, int]] = set()
         self.count = 0
 
-    def fail_at(self, step: int) -> "FailureInjector":
-        self._failures.add(step)
-        return self
+    def fail_at(self, rank: int, step: int) -> "FailureInjector":
+        with self._lock:
+            self._failures.add((rank, step))
+            return self
 
-    def check(self, step: int) -> None:
-        if step in self._failures:
-            self.count += 1
-            self._failures.remove(step)
-            print(f"injecting failure {step=}")
-            raise InjectedFailure(f"injected failure {step=}")
+    def check(self, rank: int, step: int) -> None:
+        with self._lock:
+            key = (rank, step)
+            if key in self._failures:
+                self.count += 1
+                self._failures.remove(key)
+                print(f"injecting failure {rank=} {step=}")
+                raise InjectedFailure(f"injected failure {rank=} {step=}")
+
+
+def replica_main(
+    replica_id: int,
+    lighthouse_address: str,
+    failure_injector: FailureInjector,
+    world_size: int,
+) -> List[Dict[str, Dict[str, object]]]:
+    store = dist.TCPStore(
+        host_name="localhost",
+        port=0,
+        is_master=True,
+        wait_for_workers=False,
+    )
+
+    with ThreadPoolExecutor(
+        max_workers=world_size, thread_name_prefix=f"replica{replica_id}"
+    ) as executor:
+        futures = []
+        for rank in range(world_size):
+            futures.append(
+                executor.submit(
+                    train_loop,
+                    replica_id,
+                    lighthouse_address,
+                    failure_injector=failure_injector,
+                    rank=rank,
+                    world_size=world_size,
+                    store_port=store.port,
+                )
+            )
+
+        return [fut.result() for fut in as_completed(futures)]
 
 
 def worker_manager(
     replica_id: int,
     lighthouse_address: str,
     failure_injector: FailureInjector,
     attempts: int = 3,
-) -> Dict[str, Dict[str, object]]:
+    world_size: int = 1,
+) -> List[Dict[str, Dict[str, object]]]:
+
     for i in range(attempts):
         try:
-            print(f"starting worker {replica_id} attempt {i}")
-            return train_loop(
-                replica_id, lighthouse_address, failure_injector=failure_injector
+            print(f"starting replica group {replica_id=} {world_size=} attempt {i}")
+            return replica_main(
+                replica_id,
+                lighthouse_address,
+                failure_injector=failure_injector,
+                world_size=world_size,
             )
         except InjectedFailure as e:
             print("got injected failure", i, e)
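Note: the new replica_main owns one TCPStore per replica group and fans out into world_size threads that each run train_loop against the shared store port. A rough standalone sketch of that fan-out pattern, with a stub worker standing in for train_loop (none of the names below are torchft API):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    import torch.distributed as dist


    def stub_worker(rank: int, world_size: int, store_port: int) -> str:
        # A real worker would build its ProcessGroup/Manager here, pointing
        # every rank of the group at the same store via store_port.
        return f"rank {rank}/{world_size} using store port {store_port}"


    def group_main(world_size: int = 2) -> list:
        # One store per replica group; only its port is handed to the ranks.
        store = dist.TCPStore(
            host_name="localhost",
            port=0,
            is_master=True,
            wait_for_workers=False,
        )
        with ThreadPoolExecutor(max_workers=world_size) as executor:
            futures = [
                executor.submit(stub_worker, rank, world_size, store.port)
                for rank in range(world_size)
            ]
            return [fut.result() for fut in as_completed(futures)]


    print(group_main())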
@@ -69,15 +112,14 @@ def worker_manager(
 
 
 def train_loop(
-    replica_id: int, lighthouse_address: str, failure_injector: FailureInjector
+    replica_id: int,
+    lighthouse_address: str,
+    failure_injector: FailureInjector,
+    rank: int,
+    world_size: int,
+    store_port: int,
 ) -> Dict[str, Dict[str, object]]:
     with ExitStack() as stack:
-        store = dist.TCPStore(
-            host_name="localhost",
-            port=0,
-            is_master=True,
-            wait_for_workers=False,
-        )
 
         def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
             m.load_state_dict(state_dict["model"])
@@ -89,6 +131,8 @@ def state_dict() -> Dict[str, Dict[str, object]]:
                 "optim": optimizer.state_dict(),
             }
 
+        print(f"worker {replica_id=} {rank=} {world_size=} starting")
+
         pg = ProcessGroupGloo()
         manager = Manager(
             pg=pg,
@@ -97,9 +141,9 @@ def state_dict() -> Dict[str, Dict[str, object]]:
             state_dict=state_dict,
             replica_id=str(replica_id),
             store_addr="localhost",
-            store_port=store.port,
-            rank=0,
-            world_size=1,
+            store_port=store_port,
+            rank=rank,
+            world_size=world_size,
             lighthouse_addr=lighthouse_address,
             port=19530 + replica_id,
         )
@@ -112,7 +156,9 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         criterion = nn.CrossEntropyLoss()
 
         while True:
-            print(f"worker {replica_id} starting step {manager.current_step()}")
+            print(
+                f"worker {replica_id=} {rank=} {world_size=} starting step {manager.current_step()}"
+            )
             inputs = torch.rand(2, 3)
             labels = torch.randint(4, (2,))
 
@@ -126,7 +172,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:
             if manager.current_step() >= 5:
                 break
 
-            failure_injector.check(manager.current_step())
+            failure_injector.check(rank, manager.current_step())
 
         # return state_dict so we can check consistency
         return state_dict()
@@ -173,7 +219,7 @@ def test_ddp_recovery(self) -> None:
 
         failure_injectors = [
             FailureInjector(),
-            FailureInjector().fail_at(2),
+            FailureInjector().fail_at(0, 2),
         ]
 
         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
@@ -200,3 +246,45 @@ def test_ddp_recovery(self) -> None:
             torch.testing.assert_close(state_dict, state_dicts[0])
 
         self.assertEqual(failure_injectors[1].count, 1)
+
+    def test_ddp_recovery_multi_rank(self) -> None:
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        world_size = 2
+        futures = []
+
+        failure_injectors = [
+            FailureInjector(),
+            FailureInjector().fail_at(0, 2).fail_at(1, 2),
+        ]
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id, failure_injector in zip(
+                range(num_replicas), failure_injectors
+            ):
+                futures.append(
+                    executor.submit(
+                        worker_manager,
+                        replica_id,
+                        lighthouse.address(),
+                        failure_injector=failure_injector,
+                        world_size=world_size,
+                    )
+                )
+
+            state_dicts = []
+
+            for fut in as_completed(futures):
+                try:
+                    state_dicts.append(fut.result())
+                except Exception as e:
+                    print(e)
+                    raise
+
+        lighthouse.shutdown()
+
+        for state_dict in state_dicts:
+            torch.testing.assert_close(state_dict, state_dicts[0])
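Note: in the new test, both ranks of the second replica group are told to fail at the same step, so the whole group goes down together and worker_manager's retry loop has to bring it back; that is the multi rank recovery being exercised, and the final assert checks the recovered group converges to the same state as the healthy one. Read in isolation (assuming the test module is importable from a torchft checkout), the failure schedule is just:

    from torchft.manager_integ_test import FailureInjector

    # Fail rank 0 and rank 1 of this replica group at step 2; the other
    # replica group (an empty FailureInjector) keeps running while this one
    # restarts and heals.
    injector = FailureInjector().fail_at(0, 2).fail_at(1, 2)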
