8484T = TypeVar ("T" )
8585
8686
87- def create_store_client (store_addr : str ) -> Store :
87+ def create_store_client (store_addr : str , timeout : timedelta ) -> Store :
8888 """
8989 Creates a PrefixStore(TCPStore(...)) client from an address in the format:
9090
@@ -100,6 +100,7 @@ def create_store_client(store_addr: str) -> Store:
100100 port = int (port ),
101101 is_master = False ,
102102 wait_for_workers = False ,
103+ timeout = timeout ,
103104 )
104105 store = PrefixStore (prefix , store )
105106 return store
@@ -350,11 +351,20 @@ def __repr__(self) -> str:
350351class ProcessGroupWrapper (ProcessGroup ):
351352 """
352353 This is a wrapper around any ProcessGroup with a reconfiguration method.
354+
355+ Args:
356+ timeout: timeout for reconfiguration for TCPStore
357+ pg: optional ProcessGroup to use, if None a new one will be created
353358 """
354359
355- def __init__ (self , pg : Optional [ProcessGroup ] = None ) -> None :
360+ def __init__ (
361+ self ,
362+ timeout : timedelta = timedelta (seconds = 60 ),
363+ pg : Optional [ProcessGroup ] = None ,
364+ ) -> None :
356365 super ().__init__ (0 , 1 )
357366 self ._pg : Optional [BaseProcessGroup ] = pg
367+ self ._timeout = timeout
358368
359369 def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
360370 pg = self ._pg
@@ -365,7 +375,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
365375 # abort if already initialized
366376 self .abort ()
367377
368- store = create_store_client (store_addr )
378+ store = create_store_client (store_addr , timeout = self . _timeout )
369379
370380 self ._pg = self ._create_pg (store , rank , world_size )
371381
@@ -511,10 +521,6 @@ class ProcessGroupGloo(ProcessGroupWrapper):
511521 This is a reconfigurable version of ProcessGroupGloo.
512522 """
513523
514- def __init__ (self , timeout : timedelta = timedelta (seconds = 60.0 )) -> None :
515- super ().__init__ ()
516- self ._timeout = timeout
517-
518524 def _create_pg (self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
519525 pg = BaseProcessGroup (store , rank , world_size )
520526 pg ._set_default_backend (ProcessGroup .BackendType .GLOO )
@@ -648,8 +654,7 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
648654 """
649655
650656 def __init__ (self , timeout : timedelta = timedelta (seconds = 60.0 )) -> None :
651- super ().__init__ ()
652- self ._timeout = timeout
657+ super ().__init__ (timeout )
653658 self ._use_abort : bool = torch .cuda .nccl .version () >= (2 , 25 )
654659
655660 def _opts_hook (self , opts : T ) -> T :
@@ -877,7 +882,7 @@ class ErrorSwallowingProcessGroupWrapper(ProcessGroupWrapper):
877882 """
878883
879884 def __init__ (self , pg : ProcessGroup ) -> None :
880- super ().__init__ (pg )
885+ super ().__init__ (pg = pg )
881886
882887 self ._error : Optional [Exception ] = None
883888
@@ -958,7 +963,7 @@ class ManagedProcessGroup(ProcessGroupWrapper):
958963 """
959964
960965 def __init__ (self , manager : "Manager" ) -> None :
961- super ().__init__ (manager ._pg )
966+ super ().__init__ (pg = manager ._pg )
962967
963968 self ._manager = manager
964969
@@ -1195,7 +1200,11 @@ def _worker(
11951200 if curr_device >= 0 and torch .cuda .is_available ():
11961201 torch .cuda .set_device (curr_device )
11971202
1198- store = create_store_client (store_addr )
1203+ store = create_store_client (
1204+ store_addr ,
1205+ # default TCPStore timeout is 5 minutes
1206+ timeout = timedelta (minutes = 5 ),
1207+ )
11991208
12001209 try :
12011210 pg = cls ._create_pg (store , rank , world_size )
0 commit comments