1717"""
1818
1919import logging
20+ import queue
2021import threading
2122from abc import ABC
2223from datetime import timedelta
23- from typing import TYPE_CHECKING , Dict , List , Optional , Type
24+ from typing import TYPE_CHECKING , Dict , List , Optional , Type , Union
2425
2526import torch
2627import torch .distributed as dist
5354_FUTURE_EXCEPTION = "fut_exception"
5455
5556
def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
    """
    Gets an item from a queue with a timeout. If the timeout is exceeded then
    a TimeoutError is raised.

    If an exception is returned from the queue then it is raised.

    Args:
        q: queue to get from
        timeout: how long to wait, either as a number of seconds or as a
            timedelta

    Returns:
        the item fetched from the queue

    Raises:
        TimeoutError: if no item was available before the timeout expired
        Exception: the fetched item itself, when it is an Exception instance
    """
    # Normalize to float seconds, which is what Queue.get() expects.
    if isinstance(timeout, timedelta):
        timeout = timeout.total_seconds()
    try:
        v = q.get(timeout=timeout)
    except queue.Empty as e:
        # Surface the empty-queue condition as a timeout, keeping the cause.
        raise TimeoutError(f"queue.get() timed out after {timeout} seconds") from e
    # An Exception instance on the queue is re-raised rather than returned.
    if isinstance(v, Exception):
        raise v
    return v
@@ -95,6 +111,9 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
95111 Every time this is called it must be provided with a unique prefixed
96112 store address. I.e. localhost:1234/my/prefix/1
97113
114+ This function will block until the underlying ProcessGroup is created.
115+ If an error occurs this will throw.
116+
98117 Args:
99118 store_addr: address of the store to use
100119 rank: rank of this process
@@ -187,7 +206,6 @@ def __repr__(self) -> str:
187206
188207
189208class ProcessGroupWrapper (ProcessGroup ):
190- PG_CLASS : Type [BaseProcessGroup ] # pyre-fixme[13]: never initialized
191209 """
192210 This is a wrapper around any ProcessGroup with a reconfiguration method.
193211 """
@@ -209,9 +227,10 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
209227
210228 store = create_store_client (store_addr )
211229
212- # TODO: set global timeout
213- # pyre-fixme[20]: expects argument options
214- self ._pg = self .PG_CLASS (store , rank , world_size )
230+ self ._pg = self ._create_pg (store , rank , world_size )
231+
232+ def _create_pg (self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
233+ raise NotImplementedError ("not implemented" )
215234
216235 def allreduce (self , tensors : List [torch .Tensor ], opts : object ) -> Work :
217236 return self .parent .allreduce (tensors , opts )
@@ -244,9 +263,13 @@ class ProcessGroupGloo(ProcessGroupWrapper):
244263 This is a reconfigurable version of ProcessGroupGloo.
245264 """
246265
247- PG_CLASS : Type [BaseProcessGroup ] = (
248- BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
249- )
266+ def __init__ (self , timeout : timedelta = timedelta (seconds = 60.0 )) -> None :
267+ super ().__init__ ()
268+ self ._timeout = timeout
269+
270+ def _create_pg (self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
271+ # pyre-fixme[16]: no attribute ProcessGroupGloo
272+ return BaseProcessGroupGloo (store , rank , world_size , self ._timeout )
250273
251274 def getBackendName (self ) -> str :
252275 return "torchft-gloo"
@@ -263,9 +286,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
263286 abort when reconfiguring, we need to ensure this is safe.
264287 """
265288
266- PG_CLASS : Type [ BaseProcessGroup ] = (
267- BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
268- )
289+ def _create_pg ( self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
290+ # pyre-fixme[16]: no attribute ProcessGroupNCCL
291+ return BaseProcessGroupNCCL ( store , rank , world_size )
269292
270293 def getBackendName (self ) -> str :
271294 return "torchft-nccl"
@@ -546,10 +569,9 @@ class ProcessGroupBaby(ProcessGroup):
546569
547570 """
548571
549- PG_CLASS : Type [BaseProcessGroup ] # pyre-fixme[13]: never initialized
550572 WORK_CLASS : Type [_BabyWork ] = _BabyWork
551573
552- def __init__ (self , timeout : float = 60.0 ) -> None :
574+ def __init__ (self , timeout : Union [ float , timedelta ] = 60.0 ) -> None :
553575 super ().__init__ (0 , 1 )
554576
555577 self ._world_size = - 1
@@ -562,7 +584,10 @@ def __init__(self, timeout: float = 60.0) -> None:
562584 self ._futures : Dict [int , Future [object ]] = {}
563585 self ._futures_lock = threading .Lock ()
564586
565- self ._timeout = timeout
587+ if isinstance (timeout , timedelta ):
588+ timeout = timeout .total_seconds ()
589+
590+ self ._timeout : float = timeout
566591
567592 def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
568593 if self ._p is not None :
@@ -581,7 +606,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
581606
582607 ctx = mp .get_context ("spawn" )
583608 self ._tx = ctx .Queue ()
584- self ._rx = ctx .Queue ()
609+ self ._rx = rx = ctx .Queue ()
585610
586611 # futures need thread to fire callbacks
587612 self ._future_queue = ctx .Queue ()
@@ -602,6 +627,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
602627 )
603628 self ._p .start ()
604629
630+ # fetch the status of the PG init
631+ # if an exception was returned _get will throw
632+ assert _get (rx , self ._timeout ) is None
633+
634+ @classmethod
635+ def _create_pg (cls , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
636+ """
637+ This is a class method to avoid pickling the class.
638+ """
639+ raise NotImplementedError ("not implemented" )
640+
605641 @classmethod
606642 def _worker (
607643 cls ,
@@ -615,8 +651,13 @@ def _worker(
615651 try :
616652 store = create_store_client (store_addr )
617653
618- # pyre-fixme[20]: expects argument options
619- pg = cls .PG_CLASS (store , rank , world_size )
654+ try :
655+ pg = cls ._create_pg (store , rank , world_size )
656+ except Exception as e :
657+ logger .exception (f"got exception in worker: { e } " )
658+ tx .put (e )
659+ return
660+ tx .put (None )
620661
621662 work = {}
622663 next_op_id : int = 0
@@ -737,9 +778,10 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
737778 ProcessGroupBabyNCCL.
738779 """
739780
740- PG_CLASS : Type [BaseProcessGroup ] = (
741- BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
742- )
781+ @classmethod
782+ def _create_pg (cls , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
783+ # pyre-fixme[16]: no attribute ProcessGroupGloo
784+ return BaseProcessGroupGloo (store , rank , world_size )
743785
744786 def getBackendName (self ) -> str :
745787 return "torchft-baby-gloo"
@@ -761,11 +803,13 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
761803 tensors may leak in the current PyTorch implementation. TODO fix
762804 """
763805
764- PG_CLASS : Type [BaseProcessGroup ] = (
765- BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
766- )
767806 WORK_CLASS = _BabyWorkNCCL
768807
808+ @classmethod
809+ def _create_pg (cls , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
810+ # pyre-fixme[16]: no attribute ProcessGroupNCCL
811+ return BaseProcessGroupNCCL (store , rank , world_size )
812+
769813 def getBackendName (self ) -> str :
770814 return "torchft-baby-nccl"
771815
0 commit comments