
Commit a608410

setup stream dependencies inside work wrapper
Summary:
- extend the work wrapper object to also do the division post allreduce
- add an API to block_current_stream on the work wrapper so it can be used for HSDP (a short usage sketch of the new calling convention follows below)
1 parent b746582 commit a608410
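
In practice, the caller-facing change is that the future can no longer be consumed straight from allreduce(); synchronize() has to run first so the stream dependencies (and the division by the participant count) are set up inside the returned work wrapper. A minimal sketch of the new calling convention, assuming an already-constructed torchft Manager is passed in (names here are illustrative, mirroring the call sites updated below):

import torch


def fault_tolerant_average(manager: "Manager", grad: torch.Tensor) -> torch.Tensor:
    # Before this commit the wrapper's future was pre-chained, so callers
    # could fetch it directly from allreduce(). Now synchronize() must run
    # first: it establishes the stream dependency and schedules the division
    # by the number of participants inside the work wrapper.
    work = manager.allreduce(grad)
    work.synchronize()
    fut = work.get_future()
    fut.wait()
    return grad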

File tree

5 files changed: +124 −49 lines changed

- torchft/ddp.py
- torchft/local_sgd.py
- torchft/manager.py
- torchft/manager_test.py
- torchft/work.py

torchft/ddp.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def _comm_hook(
         state: "Manager", bucket: dist.GradBucket
     ) -> torch.futures.Future[torch.Tensor]:
         work = state.allreduce(bucket.buffer())
+        work.synchronize()
         return work.get_future()
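
For context, _comm_hook is a standard DDP gradient-bucket communication hook whose state parameter is the torchft Manager. A hedged sketch of equivalent wiring on a stock DDP module (the ddp_model and manager names are illustrative; torchft's own DistributedDataParallel wrapper performs an equivalent registration internally):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def comm_hook(
    state: "Manager", bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    # Fault tolerant allreduce through the Manager; synchronize() sets up the
    # stream dependency (and the division by the participant count) before the
    # future is handed back to DDP's reducer.
    work = state.allreduce(bucket.buffer())
    work.synchronize()
    return work.get_future()


# Illustrative registration on a stock DDP module:
# ddp_model = DDP(module)
# ddp_model.register_comm_hook(manager, comm_hook)
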
torchft/local_sgd.py

Lines changed: 1 addition & 0 deletions
@@ -524,6 +524,7 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                    flat_buffer[pack_offset : pack_offset + numel].view_as(t)
                )

+            work.synchronize()
             fut = work.get_future()
             fut.add_done_callback(callback)
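
The callback in this hunk copies slices of the reduced flat buffer back into the original tensors once the future completes; a minimal sketch of that unpack pattern chained behind the managed work (function and variable names here are illustrative, not the exact local_sgd code):

from typing import List

import torch


def unpack_flat_buffer(flat_buffer: torch.Tensor, tensors: List[torch.Tensor]) -> None:
    # Copy each tensor's slice of the reduced flat buffer back in place.
    offset = 0
    for t in tensors:
        numel = t.numel()
        t.copy_(flat_buffer[offset : offset + numel].view_as(t))
        offset += numel


# Chained behind the managed allreduce: set up stream dependencies first,
# then register the unpack as a done-callback on the returned future.
# work = manager.allreduce(flat_buffer)
# work.synchronize()
# work.get_future().add_done_callback(lambda fut: unpack_flat_buffer(flat_buffer, tensors))
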
torchft/manager.py

Lines changed: 120 additions & 33 deletions
@@ -35,16 +35,28 @@
 from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    TypeAlias,
+    TypeVar,
+    Union,
+    cast,
+)

 import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
-from torchft.work import _DummyWork, _WorkWrapper
+from torchft.work import _DummyWork

 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup

@@ -344,7 +356,11 @@ def shutdown(self, wait: bool = True) -> None:
         self._executor.shutdown(wait=wait)

     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+    def allreduce(
+        self,
+        tensor: torch.Tensor,
+        should_quantize: bool = False,
+    ) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.

@@ -382,37 +398,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
-                work.wait()
-
-            fut = work.get_future()
-
-            stream: Optional[torch.cuda.Stream] = (
-                torch.cuda.current_stream() if torch.cuda.is_available() else None
-            )
-
-            # schedule grad normalization as a continuation
-            # on the Future
-            @torch.profiler.record_function("torchft::manager::allreduce::callback")
-            def callback(
-                fut: torch.futures.Future[List[torch.Tensor]],
-            ) -> torch.Tensor:
-                nonlocal tensor, stream, num_participants
-
-                # change the stream to avoid making the callback stream
-                # dependent on process group stream running the allreduce
-                with torch.cuda.stream(stream) if stream is not None else nullcontext():
-                    # Setup stream dependency
-                    fut.wait()
-                    fut.value()
-                    tensor /= num_participants

-                    return tensor
-
-            fut = fut.then(callback)
-
-            fut = self.wrap_future(fut, tensor)
-
-            return _WorkWrapper(work, fut)
+            return _ManagedWork(work, self, tensor, num_participants)

         except Exception as e:
             self._logger.exception(

@@ -932,3 +919,103 @@ def warn(self, msg: str) -> None:

     def exception(self, msg: str) -> None:
         self._logger.exception(f"{self.prefix()} {msg}")
+
+
+class _ManagedWork(dist._Work):
+    def __init__(
+        self,
+        work: dist._Work,
+        manager: Manager,
+        tensor: torch.Tensor,
+        num_participants: int,
+    ) -> None:
+        super().__init__()
+        self._manager = manager
+        self._work = work
+        self._tensor = tensor
+        self._num_participants = num_participants
+        self._fut: Union[
+            torch.futures.Future[torch.Tensor], torch.futures.Future[None]
+        ] = work.get_future()
+
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
+        self._is_set_future_callback_called = False
+
+    def _set_future_callback(
+        self,
+    ) -> None:
+        if self._is_set_future_callback_called:
+            return
+
+        # schedule grad normalization as a continuation
+        # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> torch.Tensor:
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with (
+                torch.cuda.stream(self._stream)
+                if self._stream is not None
+                else nullcontext()
+            ):
+                # Setup stream dependency
+                fut.wait()
+                self._tensor /= self._num_participants
+
+                return self._tensor
+
+        fut = self._fut
+        fut = fut.then(callback)
+        fut = self._manager.wrap_future(fut, self._tensor)
+        self._fut = fut
+
+        self._is_set_future_callback_called = True
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.wait()
+
+        self._set_future_callback()
+
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._fut.wait()
+
+        return True
+
+    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.block_current_stream()
+
+        self._set_future_callback()
+
+    def synchronize(self) -> None:
+        if torch.cuda.is_available():
+            self.block_current_stream()
+        else:
+            # No stream dependencies need to be set
+            self._set_future_callback()
+
+    def get_future(
+        self,
+    ) -> Union[torch.futures.Future[torch.Tensor], torch.futures.Future[None]]:
+        assert (
+            self._is_set_future_callback_called
+        ), "getting the future without calling synchronize() is unsafe"
+        return self._fut
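
The wrapper gives callers two ways to consume the result, per the hunk above: wait() blocks the host on both the collective and the chained future, while synchronize()/block_current_stream() only make the current CUDA stream depend on the collective (the HSDP-style path mentioned in the summary). A hedged sketch of the two paths, with manager and grad as stand-in names:

import torch


def blocking_path(manager: "Manager", grad: torch.Tensor) -> torch.Tensor:
    work = manager.allreduce(grad)
    # wait() waits on the underlying collective, installs the normalization
    # callback, then waits on the chained future; grad is averaged on return.
    work.wait()
    return grad


def stream_only_path(manager: "Manager", grad: torch.Tensor) -> torch.futures.Future:
    work = manager.allreduce(grad)
    # On CUDA this calls block_current_stream() on the underlying work, so
    # later kernels on the current stream see the averaged gradient without
    # a host-side sync; on CPU it just installs the normalization callback.
    work.synchronize()
    # get_future() asserts that synchronize()/wait() ran first.
    return work.get_future()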

torchft/manager_test.py

Lines changed: 2 additions & 0 deletions
@@ -588,6 +588,7 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:

         self.assertTrue(manager.is_participating())
         work = manager.allreduce(torch.tensor([1.0]))
+        work.synchronize()
         fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([1.0 / 5]))
@@ -596,6 +597,7 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager._healing = True
         self.assertFalse(manager.is_participating())
         work = manager.allreduce(torch.tensor([1.0]))
+        work.synchronize()
         fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([0.0]))

torchft/work.py

Lines changed: 0 additions & 16 deletions
@@ -18,19 +18,3 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:

     def get_future(self) -> torch.futures.Future[object]:
         return self.future_
-
-
-class _WorkWrapper(dist._Work):
-    def __init__(
-        self, work: dist._Work, fut: torch.futures.Future[torch.Tensor]
-    ) -> None:
-        super().__init__()
-        self._work = work
-        self._fut = fut
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._fut.wait()
-        return True
-
-    def get_future(self) -> torch.futures.Future[torch.Tensor]:
-        return self._fut
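
For context, _DummyWork (which manager.py still imports after this change) represents an already-completed collective, and its get_future() simply returns the pre-set future_ shown in the context lines above. A minimal sketch of that shape, labeled as an illustration rather than the exact torchft source:

from datetime import timedelta
from typing import Optional

import torch
import torch.distributed as dist


class _DummyWorkSketch(dist._Work):
    """Illustrative stand-in for torchft's _DummyWork: a Work that is done on creation."""

    def __init__(self, result: object) -> None:
        super().__init__()
        self.future_: torch.futures.Future[object] = torch.futures.Future()
        self.future_.set_result(result)

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        # Nothing to wait for; the result was set eagerly.
        return True

    def get_future(self) -> torch.futures.Future[object]:
        return self.future_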
