diff --git a/torchft/_test/managed_work_test.py b/torchft/_test/managed_work_test.py
new file mode 100644
index 00000000..30002dd2
--- /dev/null
+++ b/torchft/_test/managed_work_test.py
@@ -0,0 +1,215 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import types
+import unittest
+from datetime import timedelta
+from typing import Callable, List, Optional
+
+import parameterized
+import torch
+from torch.distributed.distributed_c10d import Work
+from torch.futures import Future
+
+from torchft.manager import Manager, _ManagedWork
+
+
+class SimpleWork(Work):
+    """A simple implementation of torch.distributed.Work for testing."""
+
+    def __init__(self, tensors: List[torch.Tensor]) -> None:
+        super().__init__()
+        self._tensors = tensors
+        self._future: Future[List[torch.Tensor]] = torch.futures.Future()
+        self._is_completed: bool = False
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._is_completed = True
+        self._future.set_result(self._tensors)
+        return True
+
+    def get_future(self) -> Future[List[torch.Tensor]]:
+        return self._future
+
+
+class TestManagedWork(unittest.TestCase):
+    @parameterized.parameterized.expand(
+        [
+            ("cpu", torch.device("cpu")),
+            ("cuda", torch.device("cuda:0")),
+        ]
+    )
+    def test_callbacks_execute_after_wait(
+        self, name: str, device: torch.device
+    ) -> None:
+        """Test that callbacks are only executed after wait() is called."""
+        # Skip if CUDA is requested but not available
+        if device.type == "cuda" and not torch.cuda.is_available():
+            self.skipTest("CUDA not available")
+
+        # Create a tensor to work with
+        tensor = torch.ones(1, dtype=torch.float32, device=device)
+
+        # Create a simple work object
+        work = SimpleWork([tensor])
+
+        # Create a minimal manager object with just the wrap_future method
+        manager = Manager.__new__(Manager)  # Create instance without calling __init__
+        # We're using types.MethodType to attach a method to the manager instance
+        # This is just for testing purposes
+        manager.wrap_future = types.MethodType(  # type: ignore
+            lambda self, fut, default, timeout=None: fut, manager
+        )
+
+        # Create the managed work
+        managed_work = _ManagedWork(work, manager, [tensor])
+
+        # Track callback execution
+        callback_executed: bool = False
+
+        def callback(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
+            nonlocal callback_executed
+            callback_executed = True
+            # Multiply tensor by 2 to verify the callback ran
+            fut.value()[0].mul_(2)
+            return fut.value()
+
+        # Add the callback
+        managed_work.add_callback(callback)
+
+        # Verify callback hasn't executed yet
+        self.assertFalse(callback_executed)
+        self.assertEqual(tensor.item(), 1.0)
+
+        # Call wait() which should trigger the callback
+        managed_work.wait()
+
+        # Verify callback has executed
+        self.assertTrue(callback_executed)
+        self.assertEqual(tensor.item(), 2.0)
+
+    @parameterized.parameterized.expand(
+        [
+            ("cpu", torch.device("cpu")),
+            ("cuda", torch.device("cuda:0")),
+        ]
+    )
+    def test_multiple_callbacks_execute_in_order(
+        self, name: str, device: torch.device
+    ) -> None:
+        """Test that multiple callbacks are executed in the order they were added."""
+        # Skip if CUDA is requested but not available
+        if device.type == "cuda" and not torch.cuda.is_available():
+            self.skipTest("CUDA not available")
+
+        # Create a tensor to work with
+        tensor = torch.ones(1, dtype=torch.float32, device=device)
+
+        # Create a simple work object
+        work = SimpleWork([tensor])
+
+        # Create a minimal manager object with just the wrap_future method
+        manager = Manager.__new__(Manager)  # Create instance without calling __init__
+        manager.wrap_future = types.MethodType(  # type: ignore
+            lambda self, fut, default, timeout=None: fut, manager
+        )
+
+        # Create the managed work
+        managed_work = _ManagedWork(work, manager, [tensor])
+
+        # Track execution order
+        execution_order: List[int] = []
+
+        def callback1(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
+            execution_order.append(1)
+            fut.value()[0].add_(1)
+            return fut.value()
+
+        def callback2(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
+            execution_order.append(2)
+            fut.value()[0].add_(2)
+            return fut.value()
+
+        def callback3(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
+            execution_order.append(3)
+            fut.value()[0].add_(3)
+            return fut.value()
+
+        # Add callbacks
+        managed_work.add_callback(callback1)
+        managed_work.add_callback(callback2)
+        managed_work.add_callback(callback3)
+
+        # Verify no callbacks have executed yet
+        self.assertEqual(len(execution_order), 0)
+        self.assertEqual(tensor.item(), 1.0)
+
+        # Call wait() which should trigger the callbacks
+        managed_work.wait()
+
+        # Verify callbacks executed in order
+        self.assertEqual(execution_order, [1, 2, 3])
+
+        # Each callback adds to the tensor, so final value should be 1 + 1 + 2 + 3 = 7
+        self.assertEqual(tensor.item(), 7.0)
+
+    @parameterized.parameterized.expand(
+        [
+            ("cpu", torch.device("cpu")),
+            ("cuda", torch.device("cuda:0")),
+        ]
+    )
+    def test_future_then_api(self, name: str, device: torch.device) -> None:
+        """Test that the future's then API works correctly with ManagedWork."""
+        # Skip if CUDA is requested but not available
+        if device.type == "cuda" and not torch.cuda.is_available():
+            self.skipTest("CUDA not available")
+
+        # Create a tensor to work with
+        tensor = torch.ones(1, dtype=torch.float32, device=device)
+
+        # Create a simple work object
+        work = SimpleWork([tensor])
+
+        # Create a minimal manager object with just the wrap_future method
+        manager = Manager.__new__(Manager)  # Create instance without calling __init__
+        manager.wrap_future = types.MethodType(  # type: ignore
+            lambda self, fut, default, timeout=None: fut, manager
+        )
+
+        # Create the managed work
+        managed_work = _ManagedWork(work, manager, [tensor])
+
+        # Get the future
+        future = managed_work.get_future()
+
+        # Track callback execution
+        callback_executed: bool = False
+
+        def callback(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
+            nonlocal callback_executed
+            callback_executed = True
+            # Multiply tensor by 3 to verify the callback ran
+            fut.value()[0].mul_(3)
+            return fut.value()
+
+        # Use the then API
+        future.then(callback)
+
+        # Verify callback hasn't executed yet
+        self.assertFalse(callback_executed)
+        self.assertEqual(tensor.item(), 1.0)
+
+        # Call wait() which should trigger the callback
+        future.wait()
+
+        # Verify callback has executed
+        self.assertTrue(callback_executed)
+        self.assertEqual(tensor.item(), 3.0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchft/ddp.py b/torchft/ddp.py
index 494a9b13..2ad7260b 100644
--- a/torchft/ddp.py
+++ b/torchft/ddp.py
@@ -69,8 +69,22 @@ def _comm_hook(
         state: "Manager", bucket: dist.GradBucket
     ) -> torch.futures.Future[torch.Tensor]:
         work = state.allreduce(bucket.buffer())
-        work.synchronize()
-        return work.get_future()
+
+        result_fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
+
+        fut = work.get_future()
+
+        def callback(
+            tensors: torch.futures.Future[list[torch.Tensor]],
+        ) -> list[torch.Tensor]:
+            nonlocal result_fut
+            result_fut.set_result(tensors.value()[0])
+            return []
+
+        fut = fut.then(callback)
+
+        work.wait()
+        return result_fut
 
 
 class PureDistributedDataParallel(nn.Module):
diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index e92d4bd7..5129e670 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -530,7 +530,9 @@ def _bucketize_and_allreduce(
                 flat_buffer, should_quantize=self.should_quantize
             )
 
-            def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
+            def callback(
+                fut: torch.futures.Future[list[torch.Tensor]],
+            ) -> list[torch.Tensor]:
                 with torch.cuda.stream(self._stream) if self._stream else nullcontext():
                     nonlocal bucket_tensors, flat_buffer
                     # Setup stream dependency
@@ -540,9 +542,10 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                             flat_buffer[pack_offset : pack_offset + numel].view_as(t)
                         )
 
-            work.synchronize()
+                return []
+
             fut = work.get_future()
-            fut.add_done_callback(callback)
+            fut = fut.then(callback)
 
             self._allreduce_work.append(work)
 
diff --git a/torchft/manager.py b/torchft/manager.py
index ad9a0566..250b67bc 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -418,7 +418,19 @@ def allreduce(
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
 
-            return _ManagedWork(work, self, tensor, num_participants)
+            # schedule grad normalization as a continuation
+            # on the Future
+            @torch.profiler.record_function("torchft::manager::allreduce::callback")
+            def callback(
+                tensors: torch.futures.Future[list[torch.Tensor]],
+            ) -> list[torch.Tensor]:
+                nonlocal num_participants, tensor
+                tensor /= num_participants
+                return [tensor]
+
+            managed_work = _ManagedWork(work, self, [tensor])
+            managed_work.add_callback(callback)
+            return managed_work
 
         except Exception as e:
             self._logger.exception(
@@ -943,22 +955,56 @@ def exception(self, msg: str) -> None:
         self._logger.exception(f"{self.prefix()} {msg}")
 
 
+class _SimpleFuture(torch.futures.Future[list[torch.Tensor]]):
+    def __init__(self, tensors: list[torch.Tensor]) -> None:
+        super().__init__()
+        self._tensors = tensors
+
+    def value(self) -> list[torch.Tensor]:
+        return self._tensors
+
+
+class _ManagedFuture(torch.futures.Future[list[torch.Tensor]]):
+    def __init__(self, work: "_ManagedWork") -> None:
+        super().__init__()
+        self._work = work
+
+    def then(
+        self,
+        callback: Callable[[torch.futures.Future[list[torch.Tensor]]], torch.futures.S],
+    ) -> torch.futures.Future[torch.futures.S]:
+        self._work.add_callback(
+            cast(
+                Callable[
+                    [torch.futures.Future[list[torch.Tensor]]], list[torch.Tensor]
+                ],
+                callback,
+            )
+        )
+        return cast(torch.futures.Future[torch.futures.S], self)
+
+    def wait(self) -> List[torch.Tensor]:
+        self._work.wait()
+        return self._work._tensors
+
+    def value(self) -> list[torch.Tensor]:
+        self._work.wait()
+        return self._work._tensors
+
+
 class _ManagedWork(dist._Work):
     def __init__(
         self,
         work: dist._Work,
         manager: Manager,
-        tensor: torch.Tensor,
-        num_participants: int,
+        tensors: list[torch.Tensor],
     ) -> None:
         super().__init__()
 
         self._manager = manager
         self._work = work
-        self._tensor = tensor
-        self._num_participants = num_participants
-        self._fut: Union[
-            torch.futures.Future[torch.Tensor], torch.futures.Future[None]
-        ] = work.get_future()
+        self._tensors = tensors
+        self._fut: torch.futures.Future[list[torch.Tensor]] = work.get_future()
+        self._managed_fut = _ManagedFuture(self)
+
         self._stream: Optional[torch.cuda.Stream] = (
             torch.cuda.current_stream() if torch.cuda.is_available() else None
@@ -966,39 +1012,63 @@ def __init__(
 
         self._is_set_future_callback_called = False
 
+        self._callbacks: list[
+            Callable[[torch.futures.Future[list[torch.Tensor]]], list[torch.Tensor]]
+        ] = []
+
+    def add_callback(
+        self,
+        callback: Callable[
+            [torch.futures.Future[list[torch.Tensor]]], list[torch.Tensor]
+        ],
+    ) -> None:
+        self._callbacks.append(callback)
+
     def _set_future_callback(
         self,
     ) -> None:
         if self._is_set_future_callback_called:
             return
 
-        # schedule grad normalization as a continuation
-        # on the Future
-        @torch.profiler.record_function("torchft::manager::allreduce::callback")
-        def callback(
-            fut: torch.futures.Future[List[torch.Tensor]],
-        ) -> torch.Tensor:
-            # change the stream to avoid making the callback stream
-            # dependent on process group stream running the allreduce
-            with (
-                torch.cuda.stream(self._stream)
-                if self._stream is not None
-                else nullcontext()
-            ):
-                # Setup stream dependency
-                fut.wait()
-                self._tensor /= self._num_participants
+        while self._callbacks:
+            user_callback: Callable[
+                [torch.futures.Future[list[torch.Tensor]]], list[torch.Tensor]
+            ] = self._callbacks.pop(0)
+
+            def callback(
+                fut: torch.futures.Future[list[torch.Tensor]],
+            ) -> list[torch.Tensor]:
+                nonlocal user_callback
+                # change the stream to avoid making the callback stream
+                # dependent on process group stream running the allreduce
+                with (
+                    torch.cuda.stream(self._stream)
+                    if self._stream is not None
+                    else nullcontext()
+                ):
+                    # Setup stream dependency
+                    fut.wait()
+                    self._tensors = user_callback(
+                        cast(
+                            torch.futures.Future[list[torch.Tensor]],
+                            _SimpleFuture(self._tensors),
+                        )
+                    )
+                    return self._tensors
 
-            return self._tensor
+            self._fut = self._fut.then(callback)
 
-        fut = self._fut
-        fut = fut.then(callback)
-        fut = self._manager.wrap_future(fut, self._tensor)
-        self._fut = fut
+        self._fut = self._manager.wrap_future(self._fut, self._tensors)
 
         self._is_set_future_callback_called = True
 
+    def _assert_same_stream(self) -> None:
+        if self._stream is not None:
+            assert self._stream == torch.cuda.current_stream()
+
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._assert_same_stream()
+
         with (
             torch.cuda.stream(self._stream)
             if self._stream is not None
@@ -1018,6 +1088,8 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
         return True
 
     def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
+        self._assert_same_stream()
+
         with (
             torch.cuda.stream(self._stream)
             if self._stream is not None
@@ -1028,6 +1100,8 @@ def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
             self._set_future_callback()
 
     def synchronize(self) -> None:
+        self._assert_same_stream()
+
         if torch.cuda.is_available():
             self.block_current_stream()
         else:
@@ -1036,8 +1110,5 @@ def synchronize(self) -> None:
 
     def get_future(
         self,
-    ) -> Union[torch.futures.Future[torch.Tensor], torch.futures.Future[None]]:
-        assert (
-            self._is_set_future_callback_called
-        ), "getting the future without calling synchronize() is unsafe"
-        return self._fut
+    ) -> torch.futures.Future[list[torch.Tensor]]:
+        return self._managed_fut
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index ed2d11e8..e75d5dde 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -40,6 +40,7 @@
     ProcessGroupGloo,
 )
 
+logging.basicConfig(level=logging.INFO)
 logger: logging.Logger = logging.getLogger(__name__)
 
 INIT_LOCK: threading.Lock = threading.Lock()
@@ -638,3 +639,9 @@ def all_reduce_callback(
                 work.wait()
                 return t1
     return None
+
+
+if __name__ == "__main__":
+    import unittest
+
+    unittest.main()
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index 6960abce..c81453e0 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -591,19 +591,17 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
         self.assertTrue(manager.is_participating())
 
         work = manager.allreduce(torch.tensor([1.0]))
-        work.synchronize()
         fut = work.get_future()
         result = fut.value()
-        torch.testing.assert_close(result, torch.tensor([1.0 / 5]))
+        torch.testing.assert_close(result[0], torch.tensor([1.0 / 5]))
 
         # check healing numerics
         manager._healing = True
         self.assertFalse(manager.is_participating())
 
        work = manager.allreduce(torch.tensor([1.0]))
-        work.synchronize()
         fut = work.get_future()
         result = fut.value()
-        torch.testing.assert_close(result, torch.tensor([0.0]))
+        torch.testing.assert_close(result[0], torch.tensor([0.0]))
 
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None: