
Commit 217f0d0

setup stream dependencies inside work wrapper
Summary:
- extend the work wrapper object to also do the division post allreduce
- add an API to block_current_stream on the work wrapper so it can be used for HSDP
1 parent d650c7a commit 217f0d0
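The caller-facing contract this introduces (a minimal sketch, not code from this commit): Manager.allreduce() now returns a work object that owns the post-allreduce division, and call sites must invoke synchronize() before get_future(), as the updated torchft/ddp.py, torchft/local_sgd.py and torchft/manager_test.py below do. Assuming `manager` is an already-constructed torchft Manager and `allreduce_mean` is a hypothetical helper:

import torch

from torchft.manager import Manager


def allreduce_mean(manager: Manager, grad: torch.Tensor) -> torch.futures.Future:
    # Fault tolerant allreduce; the division by the participant count is now
    # deferred into the returned work object rather than done eagerly.
    work = manager.allreduce(grad)

    # Sets up the stream dependency (block_current_stream on CUDA) and chains
    # the normalization callback onto the future.
    work.synchronize()

    # Safe only after synchronize(); _ManagedWork.get_future() asserts this.
    return work.get_future()

Per the summary, block_current_stream() is also exposed on the work wrapper so HSDP can establish the same stream dependency without a blocking wait().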

5 files changed: +124 additions, −49 deletions


torchft/ddp.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def _comm_hook(
         state: "Manager", bucket: dist.GradBucket
     ) -> torch.futures.Future[torch.Tensor]:
         work = state.allreduce(bucket.buffer())
+        work.synchronize()
         return work.get_future()

torchft/local_sgd.py

Lines changed: 1 addition & 0 deletions
@@ -535,6 +535,7 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                         flat_buffer[pack_offset : pack_offset + numel].view_as(t)
                     )
 
+            work.synchronize()
            fut = work.get_future()
            fut.add_done_callback(callback)

torchft/manager.py

Lines changed: 120 additions & 33 deletions
@@ -35,17 +35,29 @@
 from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    TypeAlias,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
-from torchft.work import _DummyWork, _WorkWrapper
+from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -371,7 +383,11 @@ def shutdown(self, wait: bool = True) -> None:
         self._executor.shutdown(wait=wait)
 
     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+    def allreduce(
+        self,
+        tensor: torch.Tensor,
+        should_quantize: bool = False,
+    ) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.
@@ -409,37 +425,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
-                work.wait()
-
-                fut = work.get_future()
-
-                stream: Optional[torch.cuda.Stream] = (
-                    torch.cuda.current_stream() if torch.cuda.is_available() else None
-                )
-
-                # schedule grad normalization as a continuation
-                # on the Future
-                @torch.profiler.record_function("torchft::manager::allreduce::callback")
-                def callback(
-                    fut: torch.futures.Future[List[torch.Tensor]],
-                ) -> torch.Tensor:
-                    nonlocal tensor, stream, num_participants
-
-                    # change the stream to avoid making the callback stream
-                    # dependent on process group stream running the allreduce
-                    with torch.cuda.stream(stream) if stream is not None else nullcontext():
-                        # Setup stream dependency
-                        fut.wait()
-                        fut.value()
-                        tensor /= num_participants
 
-                    return tensor
-
-                fut = fut.then(callback)
-
-                fut = self.wrap_future(fut, tensor)
-
-                return _WorkWrapper(work, fut)
+            return _ManagedWork(work, self, tensor, num_participants)
 
         except Exception as e:
             self._logger.exception(
@@ -962,3 +949,103 @@ def warn(self, msg: str) -> None:
 
     def exception(self, msg: str) -> None:
         self._logger.exception(f"{self.prefix()} {msg}")
+
+
+class _ManagedWork(dist._Work):
+    def __init__(
+        self,
+        work: dist._Work,
+        manager: Manager,
+        tensor: torch.Tensor,
+        num_participants: int,
+    ) -> None:
+        super().__init__()
+        self._manager = manager
+        self._work = work
+        self._tensor = tensor
+        self._num_participants = num_participants
+        self._fut: Union[
+            torch.futures.Future[torch.Tensor], torch.futures.Future[None]
+        ] = work.get_future()
+
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
+        self._is_set_future_callback_called = False
+
+    def _set_future_callback(
+        self,
+    ) -> None:
+        if self._is_set_future_callback_called:
+            return
+
+        # schedule grad normalization as a continuation
+        # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> torch.Tensor:
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with (
+                torch.cuda.stream(self._stream)
+                if self._stream is not None
+                else nullcontext()
+            ):
+                # Setup stream dependency
+                fut.wait()
+                self._tensor /= self._num_participants
+
+            return self._tensor
+
+        fut = self._fut
+        fut = fut.then(callback)
+        fut = self._manager.wrap_future(fut, self._tensor)
+        self._fut = fut
+
+        self._is_set_future_callback_called = True
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.wait()
+
+        self._set_future_callback()
+
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._fut.wait()
+
+        return True
+
+    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.block_current_stream()
+
+        self._set_future_callback()
+
+    def synchronize(self) -> None:
+        if torch.cuda.is_available():
+            self.block_current_stream()
+        else:
+            # No stream dependencies need to be set
+            self._set_future_callback()
+
+    def get_future(
+        self,
+    ) -> Union[torch.futures.Future[torch.Tensor], torch.futures.Future[None]]:
+        assert (
+            self._is_set_future_callback_called
+        ), "getting the future without calling synchronize() is unsafe"
+        return self._fut
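For context (an illustrative, CPU-only sketch, not part of the diff): _set_future_callback() uses the standard torch.futures continuation pattern, chaining the division by the participant count onto the work's future instead of performing it eagerly after wait(). The pre-completed future and the num_participants value below are stand-ins for the real process-group state:

import torch

num_participants = 5
tensor = torch.tensor([5.0])

# Stand-in for the future a c10d Work would hand back from get_future().
fut: torch.futures.Future = torch.futures.Future()
fut.set_result([tensor])


def callback(f: torch.futures.Future) -> torch.Tensor:
    # grad normalization as a continuation on the Future
    f.wait()
    tensor.div_(num_participants)
    return tensor


normalized = fut.then(callback)
print(normalized.value())  # tensor([1.])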

torchft/manager_test.py

Lines changed: 2 additions & 0 deletions
@@ -588,6 +588,7 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
 
         self.assertTrue(manager.is_participating())
         work = manager.allreduce(torch.tensor([1.0]))
+        work.synchronize()
         fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([1.0 / 5]))
@@ -596,6 +597,7 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager._healing = True
         self.assertFalse(manager.is_participating())
         work = manager.allreduce(torch.tensor([1.0]))
+        work.synchronize()
         fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([0.0]))

torchft/work.py

Lines changed: 0 additions & 16 deletions
@@ -18,19 +18,3 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
 
     def get_future(self) -> torch.futures.Future[object]:
         return self.future_
-
-
-class _WorkWrapper(dist._Work):
-    def __init__(
-        self, work: dist._Work, fut: torch.futures.Future[torch.Tensor]
-    ) -> None:
-        super().__init__()
-        self._work = work
-        self._fut = fut
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._fut.wait()
-        return True
-
-    def get_future(self) -> torch.futures.Future[torch.Tensor]:
-        return self._fut
