2020import threading
2121from abc import ABC
2222from datetime import timedelta
23- from typing import Callable , List , Optional , Tuple , Type
23+ from typing import Callable , List , Optional , Tuple , Type , TYPE_CHECKING
2424
2525import torch
2626import torch .distributed as dist
4444
4545from torch .futures import Future
4646
47+ if TYPE_CHECKING :
48+ from torchft .manager import Manager
49+
4750logger = logging .getLogger (__name__ )
4851
4952# TODO: use non strings which are cheaper
@@ -177,18 +180,25 @@ def unregister(self) -> None:
177180 """
178181 dist .destroy_process_group (self )
179182
183+ def __repr__ (self ) -> str :
184+ return f"{ self .__class__ .__name__ } ()"
185+
180186
181187class ProcessGroupWrapper (ProcessGroup ):
182188 PG_CLASS : Type [BaseProcessGroup ]
183189 """
184190 This is a wrapper around any ProcessGroup with a reconfiguration method.
185191 """
186192
187- def __init__ (self ) -> None :
193+ def __init__ (self , pg : Optional [ ProcessGroup ] = None ) -> None :
188194 super ().__init__ (0 , 1 )
189- self ._pg = None
195+ self ._pg = pg
190196
191197 def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
198+ if isinstance (self ._pg , ProcessGroup ):
199+ self ._pg .configure (store_addr , rank , world_size )
200+ return
201+
192202 if self ._pg is not None :
193203 if hasattr (self ._pg , "abort" ):
194204 self ._pg .abort ()
@@ -216,6 +226,12 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
216226 def size (self ) -> int :
217227 return self ._pg .size ()
218228
229+ def parent (self ) -> ProcessGroup :
230+ return self ._pg
231+
232+ def __repr__ (self ) -> str :
233+ return f"{ self .__class__ .__name__ } (pg={ self ._pg } )"
234+
219235
220236class ProcessGroupGloo (ProcessGroupWrapper ):
221237 """
@@ -252,7 +268,7 @@ def __init__(self, result):
252268 self .future_ = torch .futures .Future ()
253269 self .future_ .set_result (result )
254270
255- def wait (self , timeout ):
271+ def wait (self , timeout = None ):
256272 return True
257273
258274 def get_future (self ):
@@ -278,6 +294,10 @@ def __init__(self, rank: int, world: int) -> None:
278294 self .wait_count = 0
279295 self .get_future_count = 0
280296 self ._work = []
297+ self .configure_count = 0
298+
299+ def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
300+ self .configure_count += 1
281301
282302 def broadcast (self , tensor_list , opts ):
283303 res = _DummyWork (tensor_list )
@@ -304,6 +324,102 @@ def getBackendName(self):
304324 return "torchft-dummy"
305325
306326
327+ class _ErrorSwallowingWork (Work ):
328+ def __init__ (
329+ self ,
330+ pg : "ErrorSwallowingProcessGroup" ,
331+ work : Work ,
332+ default_result : object ,
333+ ):
334+ super ().__init__ ()
335+
336+ self ._pg = pg
337+ self ._work = work
338+ self ._default_result = default_result
339+
340+ def wait (self , timeout = None ) -> bool :
341+ try :
342+ self ._work .wait ()
343+ except Exception as e :
344+ self ._pg .report_error (e )
345+
346+ return True
347+
348+ def get_future (self ) -> Future :
349+ fut = self ._work .get_future ()
350+
351+ # schedule error handling as a continuation on the Future
352+ def callback (
353+ fut : torch .futures .Future [List [torch .Tensor ]],
354+ ) -> torch .futures .Future [torch .Tensor ]:
355+ try :
356+ return fut .value ()
357+ except Exception as e :
358+ logger .exception (f"got exception in future -- skipping remaining: { e } " )
359+ self ._pg .report_error (e )
360+ return self ._default_result
361+
362+ fut = fut .then (callback )
363+ return fut
364+
365+
class ErrorSwallowingProcessGroupWrapper(ProcessGroupWrapper):
    """
    This is a wrapper around any ProcessGroup that will swallow errors and
    return dummy results on error.

    This is intended to allow handling errors outside of the training loop to
    avoid having to modify modeling code to support error handling.

    After an error occurs all future operations will be skipped until the
    process group is reconfigured via ``configure``.
    """

    def __init__(self, pg: "ProcessGroup") -> None:
        super().__init__(pg)

        # Last reported error; while set, collectives short-circuit to
        # dummy results until configure() clears it.
        self._error: Optional[Exception] = None

    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
        """Reconfigure the wrapped group; clears any previously reported error."""
        self._error = None

        super().configure(store_addr, rank, world_size)

    def report_error(self, e: Exception) -> None:
        """
        Report an error to this process group. This will cause all future
        operations to be skipped until the process group is reconfigured via
        ``configure``.

        Args:
            e: exception to report
        """
        self._error = e

    def error(self) -> Optional[Exception]:
        """
        Returns the error that was reported to this process group.

        Returns:
            exception that was reported
        """
        return self._error

    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> "Work":
        """Allreduce that degrades to a no-op dummy once an error is seen."""
        # Skip the collective entirely if a prior operation already failed.
        if self._error is not None:
            return _DummyWork(tensors)

        try:
            # Wrap the real work so its failures are swallowed and reported.
            return _ErrorSwallowingWork(
                self,
                super().allreduce(tensors, opts),
                tensors,
            )
        except Exception as e:
            self.report_error(e)
            return _DummyWork(tensors)
422+
307423class _BabyWork (Work ):
308424 def __init__ (
309425 self ,
0 commit comments