Commit 4162c4a

option 2 - call work.wait inside wrapped work

1 parent 91207a2 · commit 4162c4a
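This change defers completion handling for the allreduce: previously, Manager.allreduce blocked on work.wait() and chained the gradient-normalization callback before returning; now that logic lives in a _WorkWrapper defined in torchft/manager.py and runs only when the caller invokes wait() (or get_future(), which calls wait() first) on the returned work handle.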

File tree

2 files changed: +59 -46 lines

torchft/manager.py

Lines changed: 59 additions & 31 deletions
@@ -38,13 +38,14 @@
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
 
 import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
-from torchft.work import _DummyWork, _WorkWrapper
+from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -382,37 +383,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
-                work.wait()
 
-                fut = work.get_future()
-
-                stream: Optional[torch.cuda.Stream] = (
-                    torch.cuda.current_stream() if torch.cuda.is_available() else None
-                )
-
-                # schedule grad normalization as a continuation
-                # on the Future
-                @torch.profiler.record_function("torchft::manager::allreduce::callback")
-                def callback(
-                    fut: torch.futures.Future[List[torch.Tensor]],
-                ) -> torch.Tensor:
-                    nonlocal tensor, stream, num_participants
-
-                    # change the stream to avoid making the callback stream
-                    # dependent on process group stream running the allreduce
-                    with torch.cuda.stream(stream) if stream is not None else nullcontext():
-                        # Setup stream dependency
-                        fut.wait()
-                        fut.value()
-                        tensor /= num_participants
-
-                    return tensor
-
-                fut = fut.then(callback)
-
-                fut = self.wrap_future(fut, tensor)
-
-                return _WorkWrapper(work, fut)
+                return _WorkWrapper(work, self, tensor, num_participants)
 
         except Exception as e:
             self._logger.exception(
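The normalization the wrapper defers is the usual sum-then-divide averaging: the process group's ReduceOp.SUM leaves the sum of the per-replica tensors in `tensor`, and dividing by num_participants turns it into a mean. A minimal standalone sketch of that arithmetic (plain PyTorch, no torchft; the gradient values are invented for illustration):

    import torch

    num_participants = 4
    # stand-ins for the per-replica gradients the allreduce would sum
    grads = [torch.ones(2) * i for i in range(num_participants)]

    tensor = torch.stack(grads).sum(dim=0)  # what ReduceOp.SUM produces: [6., 6.]
    tensor /= num_participants              # the deferred callback's division

    assert torch.allclose(tensor, torch.tensor([1.5, 1.5]))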
@@ -932,3 +904,59 @@ def warn(self, msg: str) -> None:
 
     def exception(self, msg: str) -> None:
         self._logger.exception(f"{self.prefix()} {msg}")
+
+
+class _WorkWrapper(dist._Work):
+    def __init__(
+        self,
+        work: dist._Work,
+        manager: Manager,
+        tensor: torch.Tensor,
+        num_participants: int,
+    ) -> None:
+        super().__init__()
+        self._manager = manager
+        self._work = work
+        self._tensor = tensor
+        self._num_participants = num_participants
+
+        self._fut: torch.futures.Future[torch.Tensor] = self._work.get_future()
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.wait()
+
+        # schedule grad normalization as a continuation
+        # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> torch.Tensor:
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with (
+                torch.cuda.stream(self._stream)
+                if self._stream is not None
+                else nullcontext()
+            ):
+                # Setup stream dependency
+                fut.wait()
+                self._tensor /= self._num_participants
+
+            return self._tensor
+
+        self._fut = self._fut.then(callback)
+        self._fut = self._manager.wrap_future(self._fut, self._tensor)
+
+        return True
+
+    def get_future(self) -> torch.futures.Future[torch.Tensor]:
+        self.wait()
+        return self._fut
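From the caller's side the returned object still looks like a torch.distributed work handle, but the heavy lifting now happens at wait() time. A hypothetical usage sketch (`manager` and `grad` are stand-in names, not identifiers from this diff):

    # manager: a torchft Manager; grad: a gradient tensor to average
    work = manager.allreduce(grad)   # returns a _WorkWrapper without blocking

    # ... overlap other computation with the in-flight allreduce ...

    work.wait()               # waits on the collective, then chains the
                              # normalization callback and manager.wrap_future
    fut = work.get_future()   # calls wait() itself before returning the Future
    grad = fut.wait()         # grad summed across replicas, divided by num_participants

Note that get_future() calls self.wait() before handing out the future, so the continuation is scheduled by the time any downstream code attaches to it.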

torchft/work.py

Lines changed: 0 additions & 15 deletions
@@ -19,18 +19,3 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
 
     def get_future(self) -> torch.futures.Future[object]:
         return self.future_
-
-
-class _WorkWrapper(dist._Work):
-    def __init__(
-        self, work: dist._Work, fut: torch.futures.Future[torch.Tensor]
-    ) -> None:
-        super().__init__()
-        self._work = work
-        self._fut = fut
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        return True
-
-    def get_future(self) -> torch.futures.Future[torch.Tensor]:
-        return self._fut
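The deleted wrapper in work.py was a thin shell: by the time it was constructed, allreduce had already waited on the work and chained the callback, so its wait() could unconditionally return True. Option 2 inverts that division of labor, which is why the new wrapper needs references to the manager, the tensor, and the participant count instead of a pre-built future.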
