@@ -85,21 +85,30 @@ def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")
 
 
-class ProcessGroupGloo(ProcessGroup):
+class ProcessGroupWrapper(ProcessGroup):
+    PG_CLASS: Type[BaseProcessGroup]
     """
-    This is a wrapper around ProcessGroupGloo with a reconfiguration argument.
+    This is a wrapper around any ProcessGroup with a reconfiguration method.
     """
 
-    def __init__(self) -> None:
+    def __init__(self, timeout: float = 60.0) -> None:
+        """
+        Args:
+            timeout: the timeout to use for the process group
+        """
         super().__init__(0, 1)
         self._pg = None
 
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        if self._pg is not None:
+            if hasattr(self._pg, "abort"):
+                self._pg.abort()
+            self._pg = None
+
         store = create_store(store_addr)
 
-        # TODO: set lower timeout
-        # pyre-fixme[16]: no attribute ProcessGroupGloo
-        self._pg = BaseProcessGroupGloo(store, rank, world_size)
+        # TODO: set global timeout
+        self._pg = self.PG_CLASS(store, rank, world_size)
 
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return self._pg.allreduce(tensors, opts)
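
For context, a minimal sketch of how a caller might exercise the new reconfiguration path. The store address format, the `AllreduceOptions` usage, and the `torchft.process_group` import path are assumptions about code outside this hunk:

```python
# Minimal usage sketch (assumptions: create_store() accepts a
# "host:port/prefix"-style TCP store address, and this file is importable
# as torchft.process_group).
import torch
from torch.distributed import AllreduceOptions
from torchft.process_group import ProcessGroupGloo

pg = ProcessGroupGloo()
pg.configure("localhost:29500/run_0", rank=0, world_size=1)

t = torch.ones(2)
pg.allreduce([t], AllreduceOptions()).wait()

# configure() tears down the previous backend PG (calling abort() when the
# backend provides it) and builds a fresh one from PG_CLASS via the new store.
pg.configure("localhost:29500/run_1", rank=0, world_size=1)
pg.allreduce([t], AllreduceOptions()).wait()
```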
@@ -118,10 +127,35 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._pg.size()
 
+
+class ProcessGroupGloo(ProcessGroupWrapper):
+    """
+    This is a reconfigurable version of ProcessGroupGloo.
+    """
+
+    PG_CLASS = BaseProcessGroupGloo
+
     def getBackendName(self) -> str:
         return "torchft-gloo"
 
 
+class ProcessGroupNCCL(ProcessGroupWrapper):
+    """
+    This is a reconfigurable version of ProcessGroupNCCL.
+
+    WARNING: this may result in deadlocks due to NCCL error handling. This is
+    provided for completeness but your mileage may vary.
+
+    TODO: verify shutdown correctness with latest NCCL. This currently calls
+    abort when reconfiguring; we need to ensure this is safe.
+    """
+
+    PG_CLASS = BaseProcessGroupNCCL
+
+    def getBackendName(self) -> str:
+        return "torchft-nccl"
+
+
 class DummyWork(dist._Work):
     def __init__(self, result):
         super().__init__()
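
The two subclasses show the extension pattern: a backend becomes reconfigurable by pointing `PG_CLASS` at its base implementation. A hypothetical sketch of a third wrapper, reusing Gloo so the constructor signature matches the `(store, rank, world_size)` call in `configure()` (the class and backend name below are illustrative, not part of this diff):

```python
# Hypothetical sketch: any base ProcessGroup constructed as
# PG_CLASS(store, rank, world_size) can be wrapped the same way.
class ProcessGroupGlooAlt(ProcessGroupWrapper):
    """A second Gloo-backed wrapper under a distinct backend name."""

    PG_CLASS = BaseProcessGroupGloo

    def getBackendName(self) -> str:
        return "torchft-gloo-alt"
```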