from abc import ABC
import logging
- from typing import Type, List, Optional
+ from typing import Type, List, Optional, Callable, Tuple
from datetime import timedelta

from torch.futures import Future
@@ -192,22 +192,30 @@ def wait(self) -> bool:
        return True


+ class BabyWorkNCCL(BabyWork):
+     def wait(self) -> bool:
+         self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
+         op_id, event = _get(self._rx, self._timeout)
+         assert op_id == self._op_id
+         assert isinstance(event, torch.cuda.Event)
+
+         # Wait on Event makes the stream wait but not the CPU thread.
+         event.wait()
+
+         return True
+
+
class ProcessGroupBaby(ProcessGroup):
    """
    This is a process group that runs the underlying process group in a
    subprocess. Since it's running in a subprocess all tensors need to be in
    shared memory or will be moved to shared memory. CUDA tensors are implicitly
    shareable and don't need any changes.

-     If the child process is killed while an operation is running CUDA tensors
-     may leak in the current implementation.
-
-     For the NCCL backend, extra memory will be used by the subprocesses CUDA
-     context compared to running NCCL in the main process. This is typically
-     around ~1GB.
    """

    PG_CLASS: Type[BaseProcessGroup]
+     WORK_CLASS: Type[BabyWork] = BabyWork

    def __init__(self, timeout: float = 60.0) -> None:
        super().__init__(0, 1)
@@ -220,6 +228,23 @@ def __init__(self, timeout: float = 60.0) -> None:

        self._timeout = timeout

+     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+         if self._p is not None:
+             self._p.kill()
+
+         self._world_size = world_size
+
+         ctx = mp.get_context("spawn")
+         self._tx = ctx.Queue()
+         self._rx = ctx.Queue()
+
+         self._p = ctx.Process(
+             target=self._worker,
+             args=(store_addr, rank, world_size, self._tx, self._rx),
+             daemon=True,
+         )
+         self._p.start()
+
    @classmethod
    def _worker(
        cls, store_addr: str, rank: int, world_size: int, rx: mp.Queue, tx: mp.Queue
@@ -235,37 +260,45 @@ def _worker(
            while True:
                op = rx.get()
                cmd = op[0]
-                 if cmd == "allreduce":
-                     work[next_op_id] = pg.allreduce(op[1], op[2])
+                 if cmd == "func":
+                     func, args, kwargs = op[1:]
+                     work[next_op_id] = getattr(pg, func)(*args, **kwargs)
                    tx.put(next_op_id)
                    next_op_id += 1
                elif cmd == "wait":
                    op_id = op[1]
                    work[op_id].wait()
                    del work[op_id]
                    tx.put(op_id)
+                 elif cmd == "synchronize":
+                     # CUDA only, use events instead of waiting on CPU
+                     op_id = op[1]
+
+                     # With WorkNCCL this makes the stream wait not the CPU when
+                     # no timeout is passed.
+                     work[op_id].wait()
+
+                     # Register event on the stream that we can pass to the main
+                     # process.
+                     event = torch.cuda.Event(interprocess=True)
+                     event.record()
+
+                     del work[op_id]
+                     tx.put((op_id, event))
                else:
                    raise ValueError(f"unknown cmd: {cmd}")
+
        except Exception as e:
            logger.exception("worker errored")
            tx.put(e)

-     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
-         if self._p is not None:
-             self._p.kill()
-
-         self._world_size = world_size
-
-         ctx = mp.get_context("spawn")
-         self._tx = ctx.Queue()
-         self._rx = ctx.Queue()
-
-         self._p = ctx.Process(
-             target=self._worker,
-             args=(store_addr, rank, world_size, self._tx, self._rx),
-             daemon=True,
+     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+         self._tx.put(("func", func, args, kwargs), timeout=self._timeout)
+         op_id = _get(self._rx, self._timeout)
+         assert isinstance(op_id, int), f"invalid return {op_id}"
+         return self.WORK_CLASS(
+             tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout
        )
-         self._p.start()

    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
        assert isinstance(tensors, list), "input must be list"
@@ -274,10 +307,7 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
            if not tensor.is_shared():
                tensor.share_memory_()

-         self._tx.put(("allreduce", tensors, opts), timeout=self._timeout)
-         op_id = _get(self._rx, self._timeout)
-         assert isinstance(op_id, int), f"invalid return {op_id}"
-         return BabyWork(tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout)
+         return self._run_func("allreduce", tensors, opts)

    def size(self) -> int:
        return self._world_size
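
(Aside, not part of the diff: a minimal standalone sketch of the shared-memory move that `allreduce()` performs above for non-shared CPU tensors before handing them to the worker subprocess. Shapes and values are illustrative only.)

```python
import torch

# CPU tensors start out in regular, process-private memory.
t = torch.zeros(4)
assert not t.is_shared()

# share_memory_() moves the underlying storage into shared memory in place,
# so the worker subprocess operates on the same storage the parent sees.
t.share_memory_()
assert t.is_shared()

# CUDA tensors skip this step: per the ProcessGroupBaby docstring, they are
# implicitly shareable across processes.
```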
@@ -291,7 +321,23 @@ def getBackendName(self):


class ProcessGroupBabyNCCL(ProcessGroupBaby):
+     """
+     This is a ProcessGroup that runs NCCL in a subprocess.
+
+     For the NCCL backend, extra memory will be used by the subprocess's CUDA
+     context compared to running NCCL in the main process. This is typically
+     around ~1GB.
+
+     The returned Work objects only synchronize on the CUDA stream and not on the
+     CPU side. This works by passing CUDA Events between the processes. To do a
+     CPU synchronize, call torch.cuda.synchronize() after wait().
+
+     WARNING: If the child process is killed while an operation is running, CUDA
+     tensors may leak in the current PyTorch implementation. TODO fix
+     """
+
    PG_CLASS = BaseProcessGroupGloo
+     WORK_CLASS = BabyWorkNCCL

    def getBackendName(self):
        return "torchft-baby-nccl"
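
(Aside, not part of the diff: a hypothetical usage sketch of `ProcessGroupBabyNCCL` as described by its docstring. The module path, store address format, and the `opts` value are assumptions; the diff only types `opts` as a bare `object` and does not show how `store_addr` is parsed.)

```python
import torch
from torch.distributed import AllreduceOptions

from torchft.process_group import ProcessGroupBabyNCCL  # assumed module path

pg = ProcessGroupBabyNCCL(timeout=60.0)
# configure() kills any existing worker and spawns a fresh subprocess that owns
# the real process group plus the tx/rx command queues.
pg.configure(store_addr="localhost:29500", rank=0, world_size=2)  # addr format assumed
# (a peer would run the same code with rank=1)

t = torch.ones(8, device="cuda")
work = pg.allreduce([t], AllreduceOptions())  # opts type assumed; diff only says `object`

# BabyWorkNCCL.wait() sends "synchronize" to the worker, receives an
# interprocess CUDA event back, and makes the current stream wait on it.
# The CPU thread is not blocked.
work.wait()

# Per the class docstring, a CPU-side sync still requires an explicit call.
torch.cuda.synchronize()
```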