
Commit 83957be

option 2 - call work.wait inside wrapped work
1 parent a77b84b commit 83957be

File tree

3 files changed (+89, -54 lines)


torchft/local_sgd.py

Lines changed: 5 additions & 6 deletions
@@ -521,10 +521,6 @@ def _bucketize_and_allreduce(
             pack_offset += numel
             flat_index += 1
 
-        work = self._manager.allreduce(
-            flat_buffer, should_quantize=self.should_quantize
-        )
-
         def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
             with torch.cuda.stream(self._stream) if self._stream else nullcontext():
                 nonlocal bucket_tensors, flat_buffer
@@ -535,8 +531,11 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                         flat_buffer[pack_offset : pack_offset + numel].view_as(t)
                     )
 
-        fut = work.get_future()
-        fut = fut.then(callback)
+        work = self._manager.allreduce(
+            flat_buffer,
+            should_quantize=self.should_quantize,
+            callback=callback,
+        )
 
         self._allreduce_work.append(work)
 

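Because the unpack callback is now handed to Manager.allreduce rather than chained onto the allreduce future at the call site, it only runs once the queued work is waited on. A minimal sketch of a consumer of self._allreduce_work; the loop below is illustrative and not part of this diff:

```python
# Illustrative only (not part of this diff): draining the queued allreduce
# work. With this commit, waiting on each wrapped work is what attaches and
# triggers the continuations: normalization, manager.wrap_future, and the
# unpack callback passed to allreduce() above.
for work in self._allreduce_work:
    work.wait()
self._allreduce_work.clear()
```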
torchft/manager.py

Lines changed: 84 additions & 32 deletions
@@ -38,13 +38,14 @@
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
 
 import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
-from torchft.work import _DummyWork, _WorkWrapper
+from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -74,6 +75,7 @@
 QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
 
 T = TypeVar("T")
+type AllReduceCallback = Callable[[torch.futures.Future[torch.Tensor]], None]
 
 
 def get_timeout(
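The new AllReduceCallback alias (a PEP 695 type statement, which requires Python 3.12+) describes a function that receives the completed future for the reduced tensor and returns nothing. A hypothetical conforming callback:

```python
import torch


# Hypothetical example of a function matching the new AllReduceCallback alias:
# it receives the completed future for the reduced tensor and returns None.
def on_allreduce_done(fut: torch.futures.Future[torch.Tensor]) -> None:
    reduced = fut.value()
    print(f"allreduce finished, shape={tuple(reduced.shape)}")
```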
@@ -350,7 +352,12 @@ def shutdown(self, wait: bool = True) -> None:
         self._executor.shutdown(wait=wait)
 
     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+    def allreduce(
+        self,
+        tensor: torch.Tensor,
+        should_quantize: bool = False,
+        callback: Optional[AllReduceCallback] = None,
+    ) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.
@@ -388,37 +395,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
-                work.block_current_stream()
-
-            fut = work.get_future()
-
-            stream: Optional[torch.cuda.Stream] = (
-                torch.cuda.current_stream() if torch.cuda.is_available() else None
-            )
-
-            # schedule grad normalization as a continuation
-            # on the Future
-            @torch.profiler.record_function("torchft::manager::allreduce::callback")
-            def callback(
-                fut: torch.futures.Future[List[torch.Tensor]],
-            ) -> torch.Tensor:
-                nonlocal tensor, stream, num_participants
-
-                # change the stream to avoid making the callback stream
-                # dependent on process group stream running the allreduce
-                with torch.cuda.stream(stream) if stream is not None else nullcontext():
-                    # Setup stream dependency
-                    fut.wait()
-                    fut.value()
-                    tensor /= num_participants
-
-                return tensor
-
-            fut = fut.then(callback)
-
-            fut = self.wrap_future(fut, tensor)
 
-            return _WorkWrapper(work, fut)
+            return _WorkWrapper(work, self, tensor, num_participants, callback)
 
         except Exception as e:
             self._logger.exception(
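After this hunk, allreduce no longer builds the continuation chain eagerly; it wraps the process group work and returns immediately. A hedged usage sketch of the new signature, assuming an already-initialized torchft Manager named manager:

```python
import torch

# Assumes `manager` is an initialized torchft Manager; the tensor and callback
# below are illustrative.
grad = torch.ones(10)


def on_reduced(fut: torch.futures.Future[torch.Tensor]) -> None:
    _ = fut.value()  # optional continuation supplied by the caller


work = manager.allreduce(grad)                       # existing call, unchanged
work = manager.allreduce(grad, callback=on_reduced)  # new optional parameter

# Nothing is chained onto the underlying future yet: normalization by the
# number of participants, manager.wrap_future error handling, and the optional
# callback are all attached once the work is waited on.
work.wait()
```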
@@ -938,3 +916,77 @@ def warn(self, msg: str) -> None:
 
     def exception(self, msg: str) -> None:
         self._logger.exception(f"{self.prefix()} {msg}")
+
+
+class _WorkWrapper(dist._Work):
+    def __init__(
+        self,
+        work: dist._Work,
+        manager: Manager,
+        tensor: torch.Tensor,
+        num_participants: int,
+        callback: Optional[AllReduceCallback],
+    ) -> None:
+        super().__init__()
+        self._manager = manager
+        self._work = work
+        self._tensor = tensor
+        self._num_participants = num_participants
+        self._callback = callback
+
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
+    def _set_future_callback(
+        self,
+    ) -> None:
+        # schedule grad normalization as a continuation
+        # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> torch.Tensor:
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with (
+                torch.cuda.stream(self._stream)
+                if self._stream is not None
+                else nullcontext()
+            ):
+                # Setup stream dependency
+                fut.wait()
+                self._tensor /= self._num_participants
+
+            return self._tensor
+
+        fut = self._work.get_future()
+        fut = fut.then(callback)
+        fut = self._manager.wrap_future(fut, self._tensor)
+        fut = fut.then(self._callback) if self._callback else fut
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.wait()
+
+        self._set_future_callback()
+
+        return True
+
+    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.block_current_stream()
+
+        self._set_future_callback()
+
+    def get_future(self) -> torch.futures.Future[torch.Tensor]:
+        self.block_current_stream()
+        return self._work.get_future()
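To summarize the new wrapper's control flow: wait() and block_current_stream() first delegate to the underlying work under the CUDA stream captured when the wrapper was constructed, then attach the continuation chain via _set_future_callback(); get_future() routes through block_current_stream() and returns the underlying work's future. A hedged trace of what the diff above implies for a caller, with illustrative names:

```python
# Hedged trace; `manager`, `t`, and `my_callback` are illustrative names.
work = manager.allreduce(t, callback=my_callback)
# The wrapper captures torch.cuda.current_stream() here (None on CPU).

work.wait()
# 1. self._work.wait() runs under the captured stream (nullcontext on CPU).
# 2. _set_future_callback() then chains, in order:
#      a. t /= num_participants, on the captured stream
#      b. manager.wrap_future(fut, t) for fault-tolerant error handling
#      c. my_callback, because one was passed to allreduce()
#
# get_future() instead goes through block_current_stream(), attaches the same
# chain, and returns the underlying work's future.
```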

torchft/work.py

Lines changed: 0 additions & 16 deletions
@@ -18,19 +18,3 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
 
     def get_future(self) -> torch.futures.Future[object]:
         return self.future_
-
-
-class _WorkWrapper(dist._Work):
-    def __init__(
-        self, work: dist._Work, fut: torch.futures.Future[torch.Tensor]
-    ) -> None:
-        super().__init__()
-        self._work = work
-        self._fut = fut
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._fut.wait()
-        return True
-
-    def get_future(self) -> torch.futures.Future[torch.Tensor]:
-        return self._fut
