1- from concurrent .futures import ThreadPoolExecutor , as_completed
1+ import threading
2+ from concurrent .futures import as_completed , ThreadPoolExecutor
23from contextlib import ExitStack
3- from typing import Dict , Set , Tuple
4+ from typing import Dict , List , Set , Tuple
45from unittest import TestCase
56
67import torch
@@ -32,32 +33,74 @@ class InjectedFailure(Exception):
3233
class FailureInjector:
    """Thread-safe registry of one-shot failures keyed by ``(rank, step)``.

    A registered pair fires at most once: the first matching ``check`` call
    consumes it, increments ``count`` and raises ``InjectedFailure``.
    """

    def __init__(self) -> None:
        # Taken by both fail_at() and check() around every access below.
        self._lock = threading.Lock()
        # Pending (rank, step) pairs that have not fired yet.
        self._failures: Set[Tuple[int, int]] = set()
        # Number of failures raised so far.
        self.count = 0

    def fail_at(self, rank: int, step: int) -> "FailureInjector":
        """Schedule a failure for ``rank`` at ``step``.

        Returns ``self`` so registrations can be chained.
        """
        with self._lock:
            self._failures.add((rank, step))
        return self

    def check(self, rank: int, step: int) -> None:
        """Raise ``InjectedFailure`` if ``(rank, step)`` is still pending.

        A pending pair is removed before raising, so the same pair never
        fires twice. Does nothing when the pair was not registered.
        """
        pending = (rank, step)
        with self._lock:
            if pending not in self._failures:
                return
            self.count += 1
            self._failures.remove(pending)
            print(f"injecting failure {rank=} {step=}")
            raise InjectedFailure(f"injected failure {rank=} {step=}")
53+
54+
def replica_main(
    replica_id: int,
    lighthouse_address: str,
    failure_injector: FailureInjector,
    world_size: int,
) -> List[Dict[str, Dict[str, object]]]:
    """Run one replica group as ``world_size`` ``train_loop`` threads.

    A single ``TCPStore`` is created here and its port is passed to every
    rank, so all ranks of the group share the same store.

    Returns:
        One ``train_loop`` result per rank, in completion order (the list is
        built from ``as_completed``, so it is not ordered by rank).

    Raises:
        Whatever a rank's ``train_loop`` raised, re-raised by ``result()``.
    """
    # port=0: the store's actual port is read back via ``store.port`` below
    # (presumably chosen by the OS -- confirm against the TCPStore docs).
    store = dist.TCPStore(
        host_name="localhost",
        port=0,
        is_master=True,
        wait_for_workers=False,
    )

    pool = ThreadPoolExecutor(
        max_workers=world_size, thread_name_prefix=f"replica{replica_id}"
    )
    with pool as executor:
        # One worker thread per rank, all sharing the injector and store port.
        rank_futures = [
            executor.submit(
                train_loop,
                replica_id,
                lighthouse_address,
                failure_injector=failure_injector,
                rank=rank,
                world_size=world_size,
                store_port=store.port,
            )
            for rank in range(world_size)
        ]

        return [done.result() for done in as_completed(rank_futures)]
4886
4987
5088def worker_manager (
5189 replica_id : int ,
5290 lighthouse_address : str ,
5391 failure_injector : FailureInjector ,
5492 attempts : int = 3 ,
55- ) -> Dict [str , Dict [str , object ]]:
93+ world_size : int = 1 ,
94+ ) -> List [Dict [str , Dict [str , object ]]]:
95+
5696 for i in range (attempts ):
5797 try :
58- print (f"starting worker { replica_id } attempt { i } " )
59- return train_loop (
60- replica_id , lighthouse_address , failure_injector = failure_injector
98+ print (f"starting replica group { replica_id = } { world_size = } attempt { i } " )
99+ return replica_main (
100+ replica_id ,
101+ lighthouse_address ,
102+ failure_injector = failure_injector ,
103+ world_size = world_size ,
61104 )
62105 except InjectedFailure as e :
63106 print ("got injected failure" , i , e )
@@ -69,15 +112,14 @@ def worker_manager(
69112
70113
71114def train_loop (
72- replica_id : int , lighthouse_address : str , failure_injector : FailureInjector
115+ replica_id : int ,
116+ lighthouse_address : str ,
117+ failure_injector : FailureInjector ,
118+ rank : int ,
119+ world_size : int ,
120+ store_port : int ,
73121) -> Dict [str , Dict [str , object ]]:
74122 with ExitStack () as stack :
75- store = dist .TCPStore (
76- host_name = "localhost" ,
77- port = 0 ,
78- is_master = True ,
79- wait_for_workers = False ,
80- )
81123
82124 def load_state_dict (state_dict : Dict [str , Dict [str , object ]]) -> None :
83125 m .load_state_dict (state_dict ["model" ])
@@ -89,6 +131,8 @@ def state_dict() -> Dict[str, Dict[str, object]]:
89131 "optim" : optimizer .state_dict (),
90132 }
91133
134+ print (f"worker { replica_id = } { rank = } { world_size = } starting" )
135+
92136 pg = ProcessGroupGloo ()
93137 manager = Manager (
94138 pg = pg ,
@@ -97,9 +141,9 @@ def state_dict() -> Dict[str, Dict[str, object]]:
97141 state_dict = state_dict ,
98142 replica_id = str (replica_id ),
99143 store_addr = "localhost" ,
100- store_port = store . port ,
101- rank = 0 ,
102- world_size = 1 ,
144+ store_port = store_port ,
145+ rank = rank ,
146+ world_size = world_size ,
103147 lighthouse_addr = lighthouse_address ,
104148 port = 19530 + replica_id ,
105149 )
@@ -112,7 +156,9 @@ def state_dict() -> Dict[str, Dict[str, object]]:
112156 criterion = nn .CrossEntropyLoss ()
113157
114158 while True :
115- print (f"worker { replica_id } starting step { manager .current_step ()} " )
159+ print (
160+ f"worker { replica_id = } { rank = } { world_size = } starting step { manager .current_step ()} "
161+ )
116162 inputs = torch .rand (2 , 3 )
117163 labels = torch .randint (4 , (2 ,))
118164
@@ -126,7 +172,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:
126172 if manager .current_step () >= 5 :
127173 break
128174
129- failure_injector .check (manager .current_step ())
175+ failure_injector .check (rank , manager .current_step ())
130176
131177 # return state_dict so we can check consistency
132178 return state_dict ()
@@ -173,7 +219,7 @@ def test_ddp_recovery(self) -> None:
173219
174220 failure_injectors = [
175221 FailureInjector (),
176- FailureInjector ().fail_at (2 ),
222+ FailureInjector ().fail_at (0 , 2 ),
177223 ]
178224
179225 with ThreadPoolExecutor (max_workers = num_replicas ) as executor :
@@ -200,3 +246,45 @@ def test_ddp_recovery(self) -> None:
200246 torch .testing .assert_close (state_dict , state_dicts [0 ])
201247
202248 self .assertEqual (failure_injectors [1 ].count , 1 )
249+
def test_ddp_recovery_multi_rank(self) -> None:
    """Check state agreement across two replica groups of two ranks each.

    Replica 1 has a failure registered for both of its ranks at step 2;
    replica 0 has none. Each replica runs under ``worker_manager`` with
    ``world_size=2``. After all replicas finish, every returned state dict
    collection must be close to the first one collected.
    """
    lighthouse = Lighthouse(
        bind="[::]:0",
        min_replicas=2,
    )
    num_replicas = 2
    world_size = 2
    futures = []

    failure_injectors = [
        FailureInjector(),
        FailureInjector().fail_at(0, 2).fail_at(1, 2),
    ]

    # try/finally so the lighthouse is shut down even when a replica raises;
    # otherwise a failing run would leave the server behind.
    try:
        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id, failure_injector in zip(
                range(num_replicas), failure_injectors
            ):
                futures.append(
                    executor.submit(
                        worker_manager,
                        replica_id,
                        lighthouse.address(),
                        failure_injector=failure_injector,
                        world_size=world_size,
                    )
                )

            state_dicts = []

            for fut in as_completed(futures):
                try:
                    state_dicts.append(fut.result())
                except Exception as e:
                    # Surface the worker error in the test log, then fail.
                    print(e)
                    raise
    finally:
        lighthouse.shutdown()

    for state_dict in state_dicts:
        torch.testing.assert_close(state_dict, state_dicts[0])
0 commit comments