@@ -85,21 +85,26 @@ def getBackendName(self) -> str:
85
85
raise NotImplementedError ("not implemented" )
86
86
87
87
88
- class ProcessGroupGloo (ProcessGroup ):
88
+ class ProcessGroupWrapper (ProcessGroup ):
89
+ PG_CLASS : Type [BaseProcessGroup ]
89
90
"""
90
- This is a wrapper around ProcessGroupGloo with a reconfiguration argument .
91
+ This is a wrapper around any ProcessGroup with a reconfiguration method .
91
92
"""
92
93
93
94
def __init__ (self ) -> None :
94
95
super ().__init__ (0 , 1 )
95
96
self ._pg = None
96
97
97
98
def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
99
+ if self ._pg is not None :
100
+ if hasattr (self ._pg , "abort" ):
101
+ self ._pg .abort ()
102
+ self ._pg = None
103
+
98
104
store = create_store (store_addr )
99
105
100
- # TODO: set lower timeout
101
- # pyre-fixme[16]: no attribute ProcessGroupGloo
102
- self ._pg = BaseProcessGroupGloo (store , rank , world_size )
106
+ # TODO: set global timeout
107
+ self ._pg = self .PG_CLASS (store , rank , world_size )
103
108
104
109
def allreduce (self , tensors : List [torch .Tensor ], opts : object ) -> Work :
105
110
return self ._pg .allreduce (tensors , opts )
@@ -118,10 +123,35 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
118
123
def size (self ) -> int :
119
124
return self ._pg .size ()
120
125
126
+
127
+ class ProcessGroupGloo (ProcessGroupWrapper ):
128
+ """
129
+ This is a reconfigurable version of ProcessGroupGloo.
130
+ """
131
+
132
+ PG_CLASS = BaseProcessGroupGloo
133
+
121
134
def getBackendName (self ) -> str :
122
135
return "torchft-gloo"
123
136
124
137
138
+ class ProcessGroupNCCL (ProcessGroupWrapper ):
139
+ """
140
+ This is a reconfigurable version of ProcessGroupNCCL.
141
+
142
+ WARNING: this may result in deadlocks due to NCCL error handling. This is
143
+ provided for completeness but your mileage may vary.
144
+
145
+ TODO: verify shutdown correctness with latest NCCL. This currently will call
146
+ abort when reconfiguring, we need to ensure this is safe.
147
+ """
148
+
149
+ PG_CLASS = BaseProcessGroupNCCL
150
+
151
+ def getBackendName (self ) -> str :
152
+ return "torchft-nccl"
153
+
154
+
125
155
class DummyWork (dist ._Work ):
126
156
def __init__ (self , result ):
127
157
super ().__init__ ()
0 commit comments