@@ -85,21 +85,26 @@ def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")
 
 
-class ProcessGroupGloo(ProcessGroup):
+class ProcessGroupWrapper(ProcessGroup):
+    PG_CLASS: Type[BaseProcessGroup]
     """
-    This is a wrapper around ProcessGroupGloo with a reconfiguration argument.
+    This is a wrapper around any ProcessGroup with a reconfiguration method.
     """
 
     def __init__(self) -> None:
         super().__init__(0, 1)
         self._pg = None
 
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        if self._pg is not None:
+            if hasattr(self._pg, "abort"):
+                self._pg.abort()
+            self._pg = None
+
         store = create_store(store_addr)
 
-        # TODO: set lower timeout
-        # pyre-fixme[16]: no attribute ProcessGroupGloo
-        self._pg = BaseProcessGroupGloo(store, rank, world_size)
+        # TODO: set global timeout
+        self._pg = self.PG_CLASS(store, rank, world_size)
 
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return self._pg.allreduce(tensors, opts)
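
The reconfiguration flow above is the heart of the change: configure() tears down any existing process group (calling abort() when the backend provides it) and then builds a fresh one from PG_CLASS against a new store. A minimal usage sketch of that flow follows; it assumes the wrapper is importable as torchft.process_group.ProcessGroupGloo and that create_store() accepts a "host:port/prefix" style address, both of which are illustrative assumptions not established by this diff.

import torch
from torch.distributed.distributed_c10d import AllreduceOptions

# Assumed import path for the wrapper class in this diff.
from torchft.process_group import ProcessGroupGloo

pg = ProcessGroupGloo()

# First configuration: builds the underlying BaseProcessGroupGloo.
# In a real deployment, rank and world_size come from the current membership.
pg.configure("localhost:29500/run_0", rank=0, world_size=1)

t = torch.ones(4)
pg.allreduce([t], AllreduceOptions()).wait()

# After a failure or membership change, reconfigure in place: the old
# process group is aborted (if it supports abort) and a new one is built.
pg.configure("localhost:29500/run_1", rank=0, world_size=1)

Since allreduce() simply forwards to the wrapped group, any opts object accepted by the underlying backend passes through unchanged.
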
@@ -118,10 +123,35 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._pg.size()
 
+
+class ProcessGroupGloo(ProcessGroupWrapper):
+    """
+    This is a reconfigurable version of ProcessGroupGloo.
+    """
+
+    PG_CLASS = BaseProcessGroupGloo
+
     def getBackendName(self) -> str:
         return "torchft-gloo"
 
 
+class ProcessGroupNCCL(ProcessGroupWrapper):
+    """
+    This is a reconfigurable version of ProcessGroupNCCL.
+
+    WARNING: this may result in deadlocks due to NCCL error handling. This is
+    provided for completeness but your mileage may vary.
+
+    TODO: verify shutdown correctness with the latest NCCL. This currently
+    calls abort when reconfiguring; we need to ensure this is safe.
+    """
+
+    PG_CLASS = BaseProcessGroupNCCL
+
+    def getBackendName(self) -> str:
+        return "torchft-nccl"
+
+
 class DummyWork(dist._Work):
     def __init__(self, result):
         super().__init__()
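
The two subclasses added in this hunk establish the extension pattern: a concrete reconfigurable backend only needs to point PG_CLASS at a base ProcessGroup whose constructor takes (store, rank, world_size) and to report a backend name. As a sketch of that pattern only, a third backend in the same module could look like the following, where BaseProcessGroupUCC is a hypothetical stand-in that this PR does not provide.

class ProcessGroupUCC(ProcessGroupWrapper):
    """
    A reconfigurable version of a hypothetical UCC-backed process group.
    """

    # BaseProcessGroupUCC is a placeholder name, not introduced by this PR.
    PG_CLASS = BaseProcessGroupUCC

    def getBackendName(self) -> str:
        return "torchft-ucc"
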