35 | 35 |  from contextlib import nullcontext
36 | 36 |  from datetime import timedelta
37 | 37 |  from enum import Enum
38 |    | -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
   | 38 | +from typing import (
   | 39 | +    TYPE_CHECKING,
   | 40 | +    Any,
   | 41 | +    Callable,
   | 42 | +    Dict,
   | 43 | +    List,
   | 44 | +    Optional,
   | 45 | +    TypeAlias,
   | 46 | +    TypeVar,
   | 47 | +    Union,
   | 48 | +    cast,
   | 49 | +)
39 | 50 |
40 | 51 |  import torch
   | 52 | +import torch.distributed as dist
41 | 53 |  from torch.distributed import ReduceOp, TCPStore
42 | 54 |  from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
43 | 55 |
44 | 56 |  from torchft._torchft import ManagerClient, ManagerServer
45 | 57 |  from torchft.checkpointing import CheckpointTransport, HTTPTransport
46 | 58 |  from torchft.checkpointing._rwlock import RWLock
47 | 59 |  from torchft.futures import future_timeout
48 |    | -from torchft.work import _DummyWork, _WorkWrapper
   | 60 | +from torchft.work import _DummyWork
49 | 61 |
50 | 62 |  if TYPE_CHECKING:
51 | 63 |      from torchft.process_group import ProcessGroup
@@ -371,7 +383,11 @@ def shutdown(self, wait: bool = True) -> None:
371 | 383 |          self._executor.shutdown(wait=wait)
372 | 384 |
373 | 385 |      @torch.profiler.record_function("torchft::manager::allreduce")
374 |     | -    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
    | 386 | +    def allreduce(
    | 387 | +        self,
    | 388 | +        tensor: torch.Tensor,
    | 389 | +        should_quantize: bool = False,
    | 390 | +    ) -> Work:
375 | 391 |          """
376 | 392 |          Fault tolerant allreduce the tensor and return a Future that will be completed when
377 | 393 |          the tensor is ready.
@@ -409,37 +425,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
409 | 425 |                  )
410 | 426 |              else:
411 | 427 |                  work = self._pg.allreduce([tensor], ReduceOp.SUM)
412 |     | -            work.wait()
413 |     | -
414 |     | -            fut = work.get_future()
415 |     | -
416 |     | -            stream: Optional[torch.cuda.Stream] = (
417 |     | -                torch.cuda.current_stream() if torch.cuda.is_available() else None
418 |     | -            )
419 |     | -
420 |     | -            # schedule grad normalization as a continuation
421 |     | -            # on the Future
422 |     | -            @torch.profiler.record_function("torchft::manager::allreduce::callback")
423 |     | -            def callback(
424 |     | -                fut: torch.futures.Future[List[torch.Tensor]],
425 |     | -            ) -> torch.Tensor:
426 |     | -                nonlocal tensor, stream, num_participants
427 |     | -
428 |     | -                # change the stream to avoid making the callback stream
429 |     | -                # dependent on process group stream running the allreduce
430 |     | -                with torch.cuda.stream(stream) if stream is not None else nullcontext():
431 |     | -                    # Setup stream dependency
432 |     | -                    fut.wait()
433 |     | -                    fut.value()
434 |     | -                    tensor /= num_participants
435 | 428 |
436 |     | -                return tensor
437 |     | -
438 |     | -            fut = fut.then(callback)
439 |     | -
440 |     | -            fut = self.wrap_future(fut, tensor)
441 |     | -
442 |     | -            return _WorkWrapper(work, fut)
    | 429 | +            return _ManagedWork(work, self, tensor, num_participants)
443 | 430 |
444 | 431 |          except Exception as e:
445 | 432 |              self._logger.exception(
@@ -962,3 +949,103 @@ def warn(self, msg: str) -> None:
962 |  949 |
963 |  950 |      def exception(self, msg: str) -> None:
964 |  951 |          self._logger.exception(f"{self.prefix()} {msg}")
    |  952 | +
    |  953 | +
    |  954 | +class _ManagedWork(dist._Work):
    |  955 | +    def __init__(
    |  956 | +        self,
    |  957 | +        work: dist._Work,
    |  958 | +        manager: Manager,
    |  959 | +        tensor: torch.Tensor,
    |  960 | +        num_participants: int,
    |  961 | +    ) -> None:
    |  962 | +        super().__init__()
    |  963 | +        self._manager = manager
    |  964 | +        self._work = work
    |  965 | +        self._tensor = tensor
    |  966 | +        self._num_participants = num_participants
    |  967 | +        self._fut: Union[
    |  968 | +            torch.futures.Future[torch.Tensor], torch.futures.Future[None]
    |  969 | +        ] = work.get_future()
    |  970 | +
    |  971 | +        self._stream: Optional[torch.cuda.Stream] = (
    |  972 | +            torch.cuda.current_stream() if torch.cuda.is_available() else None
    |  973 | +        )
    |  974 | +
    |  975 | +        self._is_set_future_callback_called = False
    |  976 | +
    |  977 | +    def _set_future_callback(
    |  978 | +        self,
    |  979 | +    ) -> None:
    |  980 | +        if self._is_set_future_callback_called:
    |  981 | +            return
    |  982 | +
    |  983 | +        # schedule grad normalization as a continuation
    |  984 | +        # on the Future
    |  985 | +        @torch.profiler.record_function("torchft::manager::allreduce::callback")
    |  986 | +        def callback(
    |  987 | +            fut: torch.futures.Future[List[torch.Tensor]],
    |  988 | +        ) -> torch.Tensor:
    |  989 | +            # change the stream to avoid making the callback stream
    |  990 | +            # dependent on process group stream running the allreduce
    |  991 | +            with (
    |  992 | +                torch.cuda.stream(self._stream)
    |  993 | +                if self._stream is not None
    |  994 | +                else nullcontext()
    |  995 | +            ):
    |  996 | +                # Setup stream dependency
    |  997 | +                fut.wait()
    |  998 | +                self._tensor /= self._num_participants
    |  999 | +
    | 1000 | +            return self._tensor
    | 1001 | +
    | 1002 | +        fut = self._fut
    | 1003 | +        fut = fut.then(callback)
    | 1004 | +        fut = self._manager.wrap_future(fut, self._tensor)
    | 1005 | +        self._fut = fut
    | 1006 | +
    | 1007 | +        self._is_set_future_callback_called = True
    | 1008 | +
    | 1009 | +    def wait(self, timeout: Optional[timedelta] = None) -> bool:
    | 1010 | +        with (
    | 1011 | +            torch.cuda.stream(self._stream)
    | 1012 | +            if self._stream is not None
    | 1013 | +            else nullcontext()
    | 1014 | +        ):
    | 1015 | +            self._work.wait()
    | 1016 | +
    | 1017 | +        self._set_future_callback()
    | 1018 | +
    | 1019 | +        with (
    | 1020 | +            torch.cuda.stream(self._stream)
    | 1021 | +            if self._stream is not None
    | 1022 | +            else nullcontext()
    | 1023 | +        ):
    | 1024 | +            self._fut.wait()
    | 1025 | +
    | 1026 | +        return True
    | 1027 | +
    | 1028 | +    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
    | 1029 | +        with (
    | 1030 | +            torch.cuda.stream(self._stream)
    | 1031 | +            if self._stream is not None
    | 1032 | +            else nullcontext()
    | 1033 | +        ):
    | 1034 | +            self._work.block_current_stream()
    | 1035 | +
    | 1036 | +        self._set_future_callback()
    | 1037 | +
    | 1038 | +    def synchronize(self) -> None:
    | 1039 | +        if torch.cuda.is_available():
    | 1040 | +            self.block_current_stream()
    | 1041 | +        else:
    | 1042 | +            # No stream dependencies need to be set
    | 1043 | +            self._set_future_callback()
    | 1044 | +
    | 1045 | +    def get_future(
    | 1046 | +        self,
    | 1047 | +    ) -> Union[torch.futures.Future[torch.Tensor], torch.futures.Future[None]]:
    | 1048 | +        assert (
    | 1049 | +            self._is_set_future_callback_called
    | 1050 | +        ), "getting the future without calling synchronize() is unsafe"
    | 1051 | +        return self._fut
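
Usage sketch (illustrative, not part of the commit): with this change, Manager.allreduce() returns a _ManagedWork whose gradient-normalization callback is only attached once wait(), synchronize(), or block_current_stream() is called, and get_future() asserts that this has already happened. The following caller-side sketch assumes an already-configured Manager named `manager` and a local gradient tensor `grad`; both names are hypothetical and not taken from this commit.

    # Hypothetical caller-side sketch; `manager` and `grad` are assumed to exist.
    work = manager.allreduce(grad)

    # Other computation can overlap here; the normalization callback has not
    # been attached yet and no stream dependency has been forced.

    # synchronize() (or wait()) attaches the callback that divides `grad` by
    # the number of participants; calling get_future() before that trips the
    # assert in _ManagedWork.get_future().
    work.synchronize()
    fut = work.get_future()

    # Block until the averaged gradient is actually available in `grad`.
    work.wait()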