18
18
import torch
19
19
import torch .distributed as dist
20
20
from torch import nn , optim
21
+ from torch .distributed .distributed_c10d import Work
21
22
from torch .distributed .tensor import DTensor
22
23
from torch .nn .parameter import Parameter
23
24
from torch .optim .optimizer import Optimizer
@@ -200,7 +201,7 @@ def __init__(
200
201
self ._outer_optimizer = outer_optimizer
201
202
202
203
# Stores handles for pending allreduce operations
203
- self ._allreduce_futures : list [torch . futures . Future [ torch . Tensor ] ] = []
204
+ self ._allreduce_work : list [Work ] = []
204
205
self ._stream : Optional [torch .cuda .Stream ] = (
205
206
torch .cuda .Stream () if torch .cuda .is_available () else None
206
207
)
@@ -368,15 +369,15 @@ def wait(self) -> None:
368
369
"""
369
370
Waits for the previously scheduled allreduce to finish
370
371
"""
371
- if len (self ._allreduce_futures ) == 0 :
372
+ if len (self ._allreduce_work ) == 0 :
372
373
return
373
374
374
375
if self ._stream is not None :
375
376
assert self ._stop_event is not None
376
377
self ._stop_event .synchronize ()
377
378
self ._stop_event = None
378
379
379
- self ._allreduce_futures = []
380
+ self ._allreduce_work = []
380
381
381
382
@torch .profiler .record_function ("torchft::local_sgd::prepare_sync" )
382
383
def prepare_sync (self ) -> None :
@@ -386,7 +387,7 @@ def prepare_sync(self) -> None:
386
387
"""
387
388
self ._save_grads ()
388
389
389
- assert len (self ._allreduce_futures ) == 0
390
+ assert len (self ._allreduce_work ) == 0
390
391
391
392
# Make sure tensors are available to `_stream`
392
393
if self ._stream is not None :
@@ -399,7 +400,7 @@ def prepare_sync(self) -> None:
399
400
):
400
401
self ._average_grads ()
401
402
402
- for work in self ._allreduce_futures :
403
+ for work in self ._allreduce_work :
403
404
work .wait ()
404
405
405
406
if self ._stream is not None :
@@ -413,7 +414,7 @@ def perform_sync(self) -> bool:
413
414
steps using the outer optimizer.
414
415
"""
415
416
# Waiting for an allreduce before it has been sent is currently not supported.
416
- assert len (self ._allreduce_futures ) > 0
417
+ assert len (self ._allreduce_work ) > 0
417
418
418
419
self .wait ()
419
420
@@ -467,7 +468,8 @@ def _allreduce_per_param(self) -> None:
467
468
work = self ._manager .allreduce (
468
469
self ._grads [name ], should_quantize = self .should_quantize
469
470
)
470
- self ._allreduce_futures .append (work )
471
+
472
+ self ._allreduce_work .append (work )
471
473
472
474
def _bucketize_and_allreduce (
473
475
self ,
@@ -522,8 +524,10 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
522
524
flat_buffer [pack_offset : pack_offset + numel ].view_as (t )
523
525
)
524
526
525
- work = work .then (callback )
526
- self ._allreduce_futures .append (work )
527
+ fut = work .get_future ()
528
+ fut = fut .then (callback )
529
+
530
+ self ._allreduce_work .append (work )
527
531
528
532
offset += chunk_size
529
533
0 commit comments