 from concurrent.futures import ThreadPoolExecutor, as_completed
+from contextlib import ExitStack
+from typing import Any, Dict, Set, Tuple
 from unittest import TestCase

 import torch
@@ -24,63 +26,108 @@ def forward(self, x):
         return self.model(x)


-def train_loop(replica_id: int, lighthouse_address: str) -> None:
-    store = dist.TCPStore(
-        host_name="localhost",
-        port=0,
-        is_master=True,
-        wait_for_workers=False,
-    )
-
-    def load_state_dict(state_dict):
-        m.load_state_dict(state_dict["model"])
-        optimizer.load_state_dict(state_dict["optim"])
-
-    def state_dict():
-        return {
-            "model": m.state_dict(),
-            "optim": optimizer.state_dict(),
-        }
-
-    pg = ProcessGroupGloo()
-    manager = Manager(
-        pg=pg,
-        min_replica_size=2,
-        load_state_dict=load_state_dict,
-        state_dict=state_dict,
-        replica_id=str(replica_id),
-        store_addr="localhost",
-        store_port=store.port,
-        rank=0,
-        world_size=1,
-        lighthouse_addr=lighthouse_address,
-        port=19530 + replica_id,
-    )
-    m = DistributedDataParallel(manager, MyModel())
-    optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
-    criterion = nn.CrossEntropyLoss()
-
-    while True:
-        inputs = torch.rand(2, 3)
-        labels = torch.randint(4, (2,))
-
-        optimizer.zero_grad()
-        out = m(inputs)
-        loss = criterion(out, labels)
-
-        loss.backward()
-        optimizer.step()
-
-        # TODO: assert weights are equal across replicas
-
-        if manager.current_step() >= 5:
-            break
-
-    manager.shutdown()
+class InjectedFailure(Exception):
+    pass
+
+
+class FailureInjector:
+    def __init__(self) -> None:
+        self._failures: Set[int] = set()
+        self.count = 0
+
+    def fail_at(self, step: int) -> "FailureInjector":
+        self._failures.add(step)
+        return self
+
+    def check(self, step: int) -> None:
+        if step in self._failures:
+            self.count += 1
+            self._failures.remove(step)
+            print(f"injecting failure {step=}")
+            raise InjectedFailure(f"injected failure {step=}")
+
+
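+# A minimal usage sketch of FailureInjector (illustrative, not part of the test):
+#
+#   injector = FailureInjector().fail_at(2)
+#   injector.check(1)  # no-op
+#   injector.check(2)  # raises InjectedFailure and increments injector.count
+#   injector.check(2)  # no-op on retry; the armed step was consumed
+
+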
+def worker_manager(
+    replica_id: int,
+    lighthouse_address: str,
+    failure_injector: FailureInjector,
+    attempts: int = 3,
+) -> Dict[str, Dict[str, Any]]:
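+    # Retry train_loop on injected failures, re-raising only once attempts are exhausted.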
+    for i in range(attempts):
+        try:
+            print(f"starting worker {replica_id} attempt {i}")
+            return train_loop(
+                replica_id, lighthouse_address, failure_injector=failure_injector
+            )
+        except InjectedFailure as e:
+            print("got injected failure", i, e)
+            if i == attempts - 1:
+                raise
+            continue
+
+
+def train_loop(
+    replica_id: int, lighthouse_address: str, failure_injector: FailureInjector
+) -> Dict[str, Dict[str, Any]]:
+    with ExitStack() as stack:
+        store = dist.TCPStore(
+            host_name="localhost",
+            port=0,
+            is_master=True,
+            wait_for_workers=False,
+        )
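+        # Each replica hosts its own ephemeral TCPStore (port=0); its port is handed to the Manager below.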
+
+        def load_state_dict(state_dict):
+            m.load_state_dict(state_dict["model"])
+            optimizer.load_state_dict(state_dict["optim"])
+
+        def state_dict():
+            return {
+                "model": m.state_dict(),
+                "optim": optimizer.state_dict(),
+            }
+
+        pg = ProcessGroupGloo()
+        manager = Manager(
+            pg=pg,
+            min_replica_size=2,
+            load_state_dict=load_state_dict,
+            state_dict=state_dict,
+            replica_id=str(replica_id),
+            store_addr="localhost",
+            store_port=store.port,
+            rank=0,
+            world_size=1,
+            lighthouse_addr=lighthouse_address,
+            port=19530 + replica_id,
+        )
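+        # Register shutdown with the ExitStack so the Manager is torn down on every exit path, including injected failures.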
+        stack.callback(manager.shutdown)
+
+        m = DistributedDataParallel(manager, MyModel())
+        optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
+        criterion = nn.CrossEntropyLoss()
+
+        while True:
+            print(f"worker {replica_id} starting step {manager.current_step()}")
+            inputs = torch.rand(2, 3)
+            labels = torch.randint(4, (2,))
+
+            optimizer.zero_grad()
+            out = m(inputs)
+            loss = criterion(out, labels)
+
+            loss.backward()
+            optimizer.step()
+
+            if manager.current_step() >= 5:
+                # return state_dict so we can check consistency
+                return state_dict()
+
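+            # Trip any failure armed for this step; this runs after the exit check, so steps >= 5 never raise.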
+            failure_injector.check(manager.current_step())


 class ManagerIntegTest(TestCase):
-    def test_ddp(self):
+    def test_ddp_healthy(self):
         lighthouse = Lighthouse(
             bind="[::]:0",
             min_replicas=2,
@@ -90,11 +137,60 @@ def test_ddp(self):

         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
             for replica_id in range(num_replicas):
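+                # No failures armed here: every replica takes the healthy path to step 5.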
+                failure_injector = FailureInjector()
+                futures.append(
+                    executor.submit(
+                        worker_manager,
+                        replica_id,
+                        lighthouse.address(),
+                        failure_injector=failure_injector,
+                    )
+                )
+
+        state_dicts = []
+
+        for fut in as_completed(futures):
+            state_dicts.append(fut.result())
+
+        lighthouse.shutdown()
+
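+        # Every replica should finish with identical model and optimizer state.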
+        for state_dict in state_dicts:
+            torch.testing.assert_close(state_dict, state_dicts[0])
+
+    def test_ddp_recovery(self):
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
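+        # Arm replica 1 to fail once at step 2; replica 0 runs with no injected failures.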
+        failure_injectors = [
+            FailureInjector(),
+            FailureInjector().fail_at(2),
+        ]
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id, failure_injector in zip(
+                range(num_replicas), failure_injectors
+            ):
                 futures.append(
-                    executor.submit(train_loop, replica_id, lighthouse.address())
+                    executor.submit(
+                        worker_manager,
+                        replica_id,
+                        lighthouse.address(),
+                        failure_injector=failure_injector,
+                    )
                 )

+        state_dicts = []
+
         for fut in as_completed(futures):
-            fut.result()
+            state_dicts.append(fut.result())

         lighthouse.shutdown()
+
+        for state_dict in state_dicts:
+            torch.testing.assert_close(state_dict, state_dicts[0])
+
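+        # Exactly one failure should have been injected and recovered from.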
+        self.assertEqual(failure_injectors[1].count, 1)