38 | 38 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
39 | 39 |
40 | 40 | import torch
   | 41 | +import torch.distributed as dist
41 | 42 | from torch.distributed import ReduceOp, TCPStore
42 | 43 | from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
43 | 44 |
44 | 45 | from torchft._torchft import ManagerClient, ManagerServer
45 | 46 | from torchft.checkpointing import CheckpointTransport, HTTPTransport
46 | 47 | from torchft.futures import future_timeout
47 |    | -from torchft.work import _DummyWork, _WorkWrapper
   | 48 | +from torchft.work import _DummyWork
48 | 49 |
49 | 50 | if TYPE_CHECKING:
50 | 51 |     from torchft.process_group import ProcessGroup
@@ -382,37 +383,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
382 | 383 |                 )
383 | 384 |             else:
384 | 385 |                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
385 |     | -                work.wait()
386 | 386 |
387 |     | -            fut = work.get_future()
388 |     | -
389 |     | -            stream: Optional[torch.cuda.Stream] = (
390 |     | -                torch.cuda.current_stream() if torch.cuda.is_available() else None
391 |     | -            )
392 |     | -
393 |     | -            # schedule grad normalization as a continuation
394 |     | -            # on the Future
395 |     | -            @torch.profiler.record_function("torchft::manager::allreduce::callback")
396 |     | -            def callback(
397 |     | -                fut: torch.futures.Future[List[torch.Tensor]],
398 |     | -            ) -> torch.Tensor:
399 |     | -                nonlocal tensor, stream, num_participants
400 |     | -
401 |     | -                # change the stream to avoid making the callback stream
402 |     | -                # dependent on process group stream running the allreduce
403 |     | -                with torch.cuda.stream(stream) if stream is not None else nullcontext():
404 |     | -                    # Setup stream dependency
405 |     | -                    fut.wait()
406 |     | -                    fut.value()
407 |     | -                    tensor /= num_participants
408 |     | -
409 |     | -                return tensor
410 |     | -
411 |     | -            fut = fut.then(callback)
412 |     | -
413 |     | -            fut = self.wrap_future(fut, tensor)
414 |     | -
415 |     | -            return _WorkWrapper(work, fut)
    | 387 | +            return _WorkWrapper(work, self, tensor, num_participants)
416 | 388 |
417 | 389 |         except Exception as e:
418 | 390 |             self._logger.exception(
@@ -932,3 +904,59 @@ def warn(self, msg: str) -> None:
932 | 904 |
933 | 905 |     def exception(self, msg: str) -> None:
934 | 906 |         self._logger.exception(f"{self.prefix()} {msg}")
    | 907 | +
    | 908 | +
    | 909 | +class _WorkWrapper(dist._Work):
    | 910 | +    def __init__(
    | 911 | +        self,
    | 912 | +        work: dist._Work,
    | 913 | +        manager: Manager,
    | 914 | +        tensor: torch.Tensor,
    | 915 | +        num_participants: int,
    | 916 | +    ) -> None:
    | 917 | +        super().__init__()
    | 918 | +        self._manager = manager
    | 919 | +        self._work = work
    | 920 | +        self._tensor = tensor
    | 921 | +        self._num_participants = num_participants
    | 922 | +
    | 923 | +        self._fut: torch.futures.Future[torch.Tensor] = self._work.get_future()
    | 924 | +        self._stream: Optional[torch.cuda.Stream] = (
    | 925 | +            torch.cuda.current_stream() if torch.cuda.is_available() else None
    | 926 | +        )
    | 927 | +
    | 928 | +    def wait(self, timeout: Optional[timedelta] = None) -> bool:
    | 929 | +        with (
    | 930 | +            torch.cuda.stream(self._stream)
    | 931 | +            if self._stream is not None
    | 932 | +            else nullcontext()
    | 933 | +        ):
    | 934 | +            self._work.wait()
    | 935 | +
    | 936 | +        # schedule grad normalization as a continuation
    | 937 | +        # on the Future
    | 938 | +        @torch.profiler.record_function("torchft::manager::allreduce::callback")
    | 939 | +        def callback(
    | 940 | +            fut: torch.futures.Future[List[torch.Tensor]],
    | 941 | +        ) -> torch.Tensor:
    | 942 | +            # change the stream to avoid making the callback stream
    | 943 | +            # dependent on process group stream running the allreduce
    | 944 | +            with (
    | 945 | +                torch.cuda.stream(self._stream)
    | 946 | +                if self._stream is not None
    | 947 | +                else nullcontext()
    | 948 | +            ):
    | 949 | +                # Setup stream dependency
    | 950 | +                fut.wait()
    | 951 | +                self._tensor /= self._num_participants
    | 952 | +
    | 953 | +            return self._tensor
    | 954 | +
    | 955 | +        self._fut = self._fut.then(callback)
    | 956 | +        self._fut = self._manager.wrap_future(self._fut, self._tensor)
    | 957 | +
    | 958 | +        return True
    | 959 | +
    | 960 | +    def get_future(self) -> torch.futures.Future[torch.Tensor]:
    | 961 | +        self.wait()
    | 962 | +        return self._fut
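For context, here is a rough caller-side sketch of what this change means in practice. It is not code from this commit: `manager` is assumed to be an already-initialized torchft Manager and `grad` is a placeholder gradient tensor. With this diff, allreduce() hands back the _WorkWrapper immediately instead of calling work.wait() itself, so the division by the number of participants and the wrap_future() error handling only run once the caller waits on the returned work.

import torch

# Placeholder gradient tensor; in a real loop this would come from backward().
grad = torch.randn(1024, device="cuda" if torch.cuda.is_available() else "cpu")

# allreduce() now returns right away: the underlying collective may still be
# in flight and `grad` has not been normalized yet.
work = manager.allreduce(grad)

# ... other compute can overlap with the collective here ...

# wait() (or get_future()) is where _WorkWrapper waits on the underlying
# allreduce, divides `grad` by num_participants in the chained callback, and
# wraps the future with the manager's fault-tolerant error handling.
work.wait()

Deferring the normalization into wait()/get_future() keeps allreduce() itself non-blocking, which lets callers overlap the collective with other work before paying for the gradient averaging and the wrap_future() bookkeeping.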