8
8
import logging
9
9
from typing import Type , List , Optional , Callable , Tuple
10
10
from datetime import timedelta
11
+ import threading
11
12
12
13
from torch .futures import Future
13
14
from torch .distributed import (
26
27
27
28
logger = logging .getLogger (__name__ )
28
29
30
+ # TODO: use non strings which are cheaper
31
+ _QUEUE_CLOSE = "queue_close"
32
+ _FUTURE_RESULT = "fut_result"
33
+ _FUTURE_EXCEPTION = "fut_exception"
34
+
29
35
30
36
def _get (queue : mp .Queue , timeout ) -> object :
31
37
v = queue .get (timeout = timeout )
@@ -208,9 +214,17 @@ def getBackendName(self):
208
214
209
215
210
216
class BabyWork (Work ):
211
- def __init__ (self , tx : mp .Queue , rx : mp .Queue , op_id : int , timeout : float ):
217
+ def __init__ (
218
+ self ,
219
+ pg : "ProcessGroupBaby" ,
220
+ tx : mp .Queue ,
221
+ rx : mp .Queue ,
222
+ op_id : int ,
223
+ timeout : float ,
224
+ ):
212
225
super ().__init__ ()
213
226
227
+ self ._pg = pg
214
228
self ._tx = tx
215
229
self ._rx = rx
216
230
self ._op_id = op_id
@@ -221,6 +235,9 @@ def wait(self) -> bool:
221
235
assert _get (self ._rx , self ._timeout ) == self ._op_id
222
236
return True
223
237
238
+ def get_future (self ) -> Future :
239
+ return self ._pg ._get_future (self ._op_id )
240
+
224
241
225
242
class BabyWorkNCCL (BabyWork ):
226
243
def wait (self ) -> bool :
@@ -255,6 +272,8 @@ def __init__(self, timeout: float = 60.0) -> None:
255
272
self ._p = None
256
273
self ._tx = None
257
274
self ._rx = None
275
+ self ._future_queue = None
276
+ self ._future_thread = None
258
277
259
278
self ._timeout = timeout
260
279
@@ -264,20 +283,46 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
264
283
265
284
self ._world_size = world_size
266
285
286
+ if self ._tx is not None :
287
+ self ._tx .close ()
288
+ if self ._rx is not None :
289
+ self ._rx .close ()
290
+ if self ._future_queue is not None :
291
+ self ._future_queue .put (_QUEUE_CLOSE )
292
+ self ._future_queue .close ()
293
+
267
294
ctx = mp .get_context ("spawn" )
268
295
self ._tx = ctx .Queue ()
269
296
self ._rx = ctx .Queue ()
270
297
298
+ # futures need thread to fire callbacks
299
+ self ._future_queue = ctx .Queue ()
300
+ # this lock needs to be held when manipulating _futures
301
+ self ._futures_lock = threading .Lock ()
302
+ self ._futures = {}
303
+ self ._future_thread = threading .Thread (
304
+ target = self ._future_handler ,
305
+ args = (self ._future_queue ,),
306
+ daemon = True ,
307
+ )
308
+ self ._future_thread .start ()
309
+
271
310
self ._p = ctx .Process (
272
311
target = self ._worker ,
273
- args = (store_addr , rank , world_size , self ._tx , self ._rx ),
312
+ args = (store_addr , rank , world_size , self ._tx , self ._rx , self . _future_queue ),
274
313
daemon = True ,
275
314
)
276
315
self ._p .start ()
277
316
278
317
@classmethod
279
318
def _worker (
280
- cls , store_addr : str , rank : int , world_size : int , rx : mp .Queue , tx : mp .Queue
319
+ cls ,
320
+ store_addr : str ,
321
+ rank : int ,
322
+ world_size : int ,
323
+ rx : mp .Queue ,
324
+ tx : mp .Queue ,
325
+ future_queue : mp .Queue ,
281
326
) -> None :
282
327
try :
283
328
store = create_store (store_addr )
@@ -291,15 +336,28 @@ def _worker(
291
336
op = rx .get ()
292
337
cmd = op [0 ]
293
338
if cmd == "func" :
294
- func , args , kwargs = op [1 :]
295
- work [next_op_id ] = getattr (pg , func )(* args , ** kwargs )
339
+ func_name , args , kwargs = op [1 :]
340
+ fn = getattr (pg , func_name )
341
+ work [next_op_id ] = fn (* args , ** kwargs )
296
342
tx .put (next_op_id )
297
343
next_op_id += 1
298
344
elif cmd == "wait" :
299
345
op_id = op [1 ]
300
346
work [op_id ].wait ()
301
347
del work [op_id ]
302
348
tx .put (op_id )
349
+ elif cmd == "future" :
350
+ op_id = op [1 ]
351
+
352
+ def callback (fut : Future ):
353
+ try :
354
+ fut .wait ()
355
+ future_queue .put ((op_id , _FUTURE_RESULT , None ))
356
+ except Exception as e :
357
+ future_queue .put ((op_id , _FUTURE_EXCEPTION , e ))
358
+
359
+ work [op_id ].get_future ().add_done_callback (callback )
360
+ tx .put (op_id )
303
361
elif cmd == "synchronize" :
304
362
# CUDA only, use events instead of waiting on CPU
305
363
op_id = op [1 ]
@@ -322,12 +380,41 @@ def _worker(
322
380
logger .exception ("worker errored" )
323
381
tx .put (e )
324
382
383
+ def _future_handler (self , future_queue : mp .Queue ) -> None :
384
+ try :
385
+ while True :
386
+ cmd = future_queue .get ()
387
+ if cmd == _QUEUE_CLOSE :
388
+ break
389
+ op_id , mode , data = cmd
390
+ with self ._futures_lock :
391
+ fut = self ._futures [op_id ]
392
+ del self ._futures [op_id ]
393
+ if mode == _FUTURE_RESULT :
394
+ fut .set_result (data )
395
+ elif mode == _FUTURE_EXCEPTION :
396
+ fut .set_exception (data )
397
+ else :
398
+ raise ValueError (f"unknown mode { mode } " )
399
+ except Exception as e :
400
+ logger .exception (f"got unexpected error in future handler: { e } " )
401
+
402
+ def _get_future (self , op_id : int ) -> Future :
403
+ with self ._futures_lock :
404
+ fut = Future ()
405
+ self ._futures [op_id ] = fut
406
+ self ._tx .put (("future" , op_id ), timeout = self ._timeout )
407
+
408
+ assert _get (self ._rx , self ._timeout ) == op_id
409
+ # TODO: return correct tensor instead of None
410
+ return fut
411
+
325
412
def _run_func (self , func : str , * args : object , ** kwargs : object ) -> Work :
326
413
self ._tx .put (("func" , func , args , kwargs ), timeout = self ._timeout )
327
414
op_id = _get (self ._rx , self ._timeout )
328
415
assert isinstance (op_id , int ), f"invalid return { op_id } "
329
416
return self .WORK_CLASS (
330
- tx = self ._tx , rx = self ._rx , op_id = op_id , timeout = self ._timeout
417
+ pg = self , tx = self ._tx , rx = self ._rx , op_id = op_id , timeout = self ._timeout
331
418
)
332
419
333
420
def allreduce (self , tensors : List [torch .Tensor ], opts : object ) -> Work :
@@ -366,7 +453,7 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
366
453
tensors may leak in the current PyTorch implementation. TODO fix
367
454
"""
368
455
369
- PG_CLASS = BaseProcessGroupGloo
456
+ PG_CLASS = BaseProcessGroupNCCL
370
457
WORK_CLASS = BabyWorkNCCL
371
458
372
459
def getBackendName (self ):
0 commit comments