|
38 | 38 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
|
39 | 39 |
|
40 | 40 | import torch
|
| 41 | +import torch.distributed as dist |
41 | 42 | from torch.distributed import ReduceOp, TCPStore
|
42 | 43 | from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
|
43 | 44 |
|
44 | 45 | from torchft._torchft import ManagerClient, ManagerServer
|
45 | 46 | from torchft.checkpointing import CheckpointTransport, HTTPTransport
|
46 | 47 | from torchft.futures import future_timeout
|
47 | | -from torchft.work import _DummyWork, _WorkWrapper |
| 48 | +from torchft.work import _DummyWork |
48 | 49 |
|
49 | 50 | if TYPE_CHECKING:
|
50 | 51 | from torchft.process_group import ProcessGroup
|
|
74 | 75 | QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
|
75 | 76 |
|
76 | 77 | T = TypeVar("T")
|
| 78 | +type AllReduceCallback = Callable[[torch.futures.Future[torch.Tensor]], None] |
77 | 79 |
|
78 | 80 |
|
79 | 81 | def get_timeout(
|
@@ -350,7 +352,12 @@ def shutdown(self, wait: bool = True) -> None:
|
350 | 352 | self._executor.shutdown(wait=wait)
|
351 | 353 |
|
352 | 354 | @torch.profiler.record_function("torchft::manager::allreduce")
|
353 | | - def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work: |
| 355 | + def allreduce( |
| 356 | + self, |
| 357 | + tensor: torch.Tensor, |
| 358 | + should_quantize: bool = False, |
| 359 | + callback: Optional[AllReduceCallback] = None, |
| 360 | + ) -> Work: |
354 | 361 | """
|
355 | 362 | Fault tolerant allreduce the tensor and return a Future that will be completed when
|
356 | 363 | the tensor is ready.
|
@@ -388,37 +395,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
|
388 | 395 | )
|
389 | 396 | else:
|
390 | 397 | work = self._pg.allreduce([tensor], ReduceOp.SUM)
|
391 | | - work.block_current_stream() |
392 | | - |
393 | | - fut = work.get_future() |
394 | | - |
395 | | - stream: Optional[torch.cuda.Stream] = ( |
396 | | - torch.cuda.current_stream() if torch.cuda.is_available() else None |
397 | | - ) |
398 | | - |
399 | | - # schedule grad normalization as a continuation |
400 | | - # on the Future |
401 | | - @torch.profiler.record_function("torchft::manager::allreduce::callback") |
402 | | - def callback( |
403 | | - fut: torch.futures.Future[List[torch.Tensor]], |
404 | | - ) -> torch.Tensor: |
405 | | - nonlocal tensor, stream, num_participants |
406 | | - |
407 | | - # change the stream to avoid making the callback stream |
408 | | - # dependent on process group stream running the allreduce |
409 | | - with torch.cuda.stream(stream) if stream is not None else nullcontext(): |
410 | | - # Setup stream dependency |
411 | | - fut.wait() |
412 | | - fut.value() |
413 | | - tensor /= num_participants |
414 | | - |
415 | | - return tensor |
416 | | - |
417 | | - fut = fut.then(callback) |
418 | | - |
419 | | - fut = self.wrap_future(fut, tensor) |
420 | 398 |
421 | | - return _WorkWrapper(work, fut) |
| 399 | + return _WorkWrapper(work, self, tensor, num_participants, callback) |
422 | 400 |
|
423 | 401 | except Exception as e:
|
424 | 402 | self._logger.exception(
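For orientation, a minimal usage sketch of the reworked call above; `average_gradients` and `on_reduced` are hypothetical names, and only the `callback=` keyword comes from this diff:

import torch

from torchft.manager import Manager


def average_gradients(manager: Manager, grad: torch.Tensor) -> None:
    # Hypothetical helper: the optional callback added in this diff runs as a
    # continuation once the tensor has been summed and divided by the number
    # of participants.
    def on_reduced(fut: torch.futures.Future[torch.Tensor]) -> None:
        pass  # e.g. record a metric once the averaged gradient is ready

    work = manager.allreduce(grad, callback=on_reduced)
    work.wait()  # waits for the allreduce and schedules the continuations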
|
@@ -938,3 +916,77 @@ def warn(self, msg: str) -> None:
|
938 | 916 |
|
939 | 917 | def exception(self, msg: str) -> None:
|
940 | 918 | self._logger.exception(f"{self.prefix()} {msg}")
|
| 919 | + |
| 920 | + |
| 921 | +class _WorkWrapper(dist._Work): |
| 922 | + def __init__( |
| 923 | + self, |
| 924 | + work: dist._Work, |
| 925 | + manager: Manager, |
| 926 | + tensor: torch.Tensor, |
| 927 | + num_participants: int, |
| 928 | + callback: Optional[AllReduceCallback], |
| 929 | + ) -> None: |
| 930 | + super().__init__() |
| 931 | + self._manager = manager |
| 932 | + self._work = work |
| 933 | + self._tensor = tensor |
| 934 | + self._num_participants = num_participants |
| 935 | + self._callback = callback |
| 936 | + |
| 937 | + self._stream: Optional[torch.cuda.Stream] = ( |
| 938 | + torch.cuda.current_stream() if torch.cuda.is_available() else None |
| 939 | + ) |
| 940 | + |
| 941 | + def _set_future_callback( |
| 942 | + self, |
| 943 | + ) -> None: |
| 944 | + # schedule grad normalization as a continuation |
| 945 | + # on the Future |
| 946 | + @torch.profiler.record_function("torchft::manager::allreduce::callback") |
| 947 | + def callback( |
| 948 | + fut: torch.futures.Future[List[torch.Tensor]], |
| 949 | + ) -> torch.Tensor: |
| 950 | + # change the stream to avoid making the callback stream |
| 951 | + # dependent on process group stream running the allreduce |
| 952 | + with ( |
| 953 | + torch.cuda.stream(self._stream) |
| 954 | + if self._stream is not None |
| 955 | + else nullcontext() |
| 956 | + ): |
| 957 | + # Setup stream dependency |
| 958 | + fut.wait() |
| 959 | + self._tensor /= self._num_participants |
| 960 | + |
| 961 | + return self._tensor |
| 962 | + |
| 963 | + fut = self._work.get_future() |
| 964 | + fut = fut.then(callback) |
| 965 | + fut = self._manager.wrap_future(fut, self._tensor) |
| 966 | + fut = fut.then(self._callback) if self._callback else fut |
| 967 | + |
| 968 | + def wait(self, timeout: Optional[timedelta] = None) -> bool: |
| 969 | + with ( |
| 970 | + torch.cuda.stream(self._stream) |
| 971 | + if self._stream is not None |
| 972 | + else nullcontext() |
| 973 | + ): |
| 974 | + self._work.wait() |
| 975 | + |
| 976 | + self._set_future_callback() |
| 977 | + |
| 978 | + return True |
| 979 | + |
| 980 | + def block_current_stream(self, timeout: Optional[timedelta] = None) -> None: |
| 981 | + with ( |
| 982 | + torch.cuda.stream(self._stream) |
| 983 | + if self._stream is not None |
| 984 | + else nullcontext() |
| 985 | + ): |
| 986 | + self._work.block_current_stream() |
| 987 | + |
| 988 | + self._set_future_callback() |
| 989 | + |
| 990 | + def get_future(self) -> torch.futures.Future[torch.Tensor]: |
| 991 | + self.block_current_stream() |
| 992 | + return self._work.get_future() |