From bbaf95e07d31c42027a1ff97de0c0d8af9f5d6d1 Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Wed, 27 Nov 2024 18:47:22 -0800
Subject: [PATCH 1/2] manager: added participant information

---
 torchft/manager.py      | 39 ++++++++++++++++++++++++++++++++-------
 torchft/manager_test.py | 10 +++++++---
 2 files changed, 39 insertions(+), 10 deletions(-)

diff --git a/torchft/manager.py b/torchft/manager.py
index ae2f15ae..a48a5ef3 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -32,7 +32,7 @@
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
-from typing import Dict, List, Optional
+from typing import Callable, Dict, List, Optional, TYPE_CHECKING
 
 import torch
 from torch.distributed import PrefixStore, ReduceOp, TCPStore, Work
@@ -42,6 +42,9 @@
 # pyre-fixme[21]: can't find rust module
 from torchft.torchft import Manager as _Manager, ManagerClient
 
+if TYPE_CHECKING:
+    from torchft.process_group import ProcessGroup
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 MANAGER_ADDR_KEY: str = "manager_addr"
@@ -58,9 +61,9 @@ class Manager:
 
     def __init__(
         self,
-        pg,
-        load_state_dict,
-        state_dict,
+        pg: "ProcessGroup",
+        load_state_dict: Callable[[object], None],
+        state_dict: Callable[[], object],
         min_replica_size: int,
         port: int = MANAGER_DEFAULT_PORT,
         use_async_quorum: bool = True,
@@ -182,8 +185,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso
 
         self._quorum_future.result()
 
-        if self._healing:
-            assert self._use_async_quorum
+        if not self.is_participating():
             grad.zero_()
 
         # TODO: increase timeout when waiting when healing
@@ -209,7 +211,7 @@ def callback(
                 self._errored = True
                 return grad
 
-            grad /= self._participating_replicas
+            grad /= self.num_participants()
 
             return grad
 
@@ -411,3 +413,26 @@ def batches_committed(self) -> int:
         the total number of batches committed
         """
         return self._batches_committed
+
+    def num_participants(self) -> int:
+        """
+        Get the number of participants in the current quorum.
+
+        This is the number of replicas participating in the current step.
+
+        Returns:
+            the number of participants in the current quorum
+        """
+        return self._participating_replicas
+
+    def is_participating(self) -> bool:
+        """
+        Get whether this replica is participating in the current quorum.
+
+        Returns:
+            whether this replica is participating in the current quorum
+        """
+        if self._healing:
+            assert self._use_async_quorum
+            return False
+        return True
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index 3d3e4a0b..53d882d2 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -5,15 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 
 from unittest import TestCase
-from unittest.mock import patch, create_autospec, MagicMock
+from unittest.mock import create_autospec, MagicMock, patch
 
 import torch
 from torch.distributed import TCPStore
-
-from torchft.torchft import ManagerClient
 from torchft.manager import Manager, MANAGER_ADDR_KEY
 from torchft.process_group import ProcessGroup
+from torchft.torchft import ManagerClient
+
 
 class TestManager(TestCase):
     def _create_manager(
@@ -129,6 +129,8 @@ def test_quorum_heal_sync(self, client_mock) -> None:
         manager.step()
         manager.allreduce_grad(torch.tensor([1.0])).wait()
         self.assertFalse(manager._healing)
+        self.assertTrue(manager.is_participating())
+        self.assertEqual(manager.num_participants(), 2)
 
         self.assertTrue(manager.should_commit())
         self.assertEqual(manager._quorum_id, 123)
@@ -164,6 +166,8 @@ def test_quorum_heal_async_not_enough_participants(self, client_mock) -> None:
         manager.step()
         manager._quorum_future.result()
         self.assertTrue(manager._healing)
+        self.assertFalse(manager.is_participating())
+        self.assertEqual(manager.num_participants(), 1)
 
         grad = torch.tensor([1.0])
         manager.allreduce_grad(grad).wait()

From 9c13c5f1c6902e4e3b75c6825b48f0bf107a8295 Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Wed, 27 Nov 2024 19:11:16 -0800
Subject: [PATCH 2/2] manager: error reporting APIs and numerics test

---
 torchft/manager.py      | 69 +++++++++++++++++++++++++++++++++--------
 torchft/manager_test.py | 39 ++++++++++++++++++++++-
 2 files changed, 94 insertions(+), 14 deletions(-)

diff --git a/torchft/manager.py b/torchft/manager.py
index a48a5ef3..ec3d3d04 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -178,7 +178,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso
         Returns:
             a Future that will be completed with the allreduced gradient
         """
-        if self._errored:
+        if self.errored():
             fut = torch.futures.Future()
             fut.set_result(grad)
             return fut
@@ -195,38 +195,81 @@
             work = self._pg.allreduce([grad], ReduceOp.SUM)
             fut = work.get_future()
 
-            # schedule error handling and grad normalization as a continuation
+            # schedule grad normalization as a continuation
             # on the Future
             def callback(
                 fut: torch.futures.Future[List[torch.Tensor]],
             ) -> torch.futures.Future[torch.Tensor]:
                 nonlocal grad
 
-                try:
-                    val = fut.value()
-                except Exception:
-                    logger.exception(
-                        "got exception in all reduce future -- skipping remaining"
-                    )
-                    self._errored = True
-                    return grad
+                fut.value()
 
                 grad /= self.num_participants()
 
                 return grad
 
             fut = fut.then(callback)
-            self._pending_work.append(fut)
+            fut = self.wrap_future(fut, grad)
             return fut
 
         except Exception as e:
-            logger.exception("got exception in all reduce -- skipping remaining")
-            self._errored = True
+            logger.exception(f"got exception in all reduce -- skipping remaining: {e}")
+            self.report_error()
 
             fut = torch.futures.Future()
             fut.set_result(grad)
             return fut
 
+    def report_error(self) -> None:
+        """
+        Report an error to the manager.
+
+        This will cause the manager to skip the current step and to be
+        reconfigured on the next step.
+
+        This should be called when an error occurs that leads to a corrupted
+        gradient that needs to be discarded.
+        """
+        self._errored = True
+
+    def errored(self) -> bool:
+        """
+        Get whether an error has occurred.
+
+        Returns:
+            whether an error has occurred
+        """
+        return self._errored
+
+    def wrap_future(self, fut: torch.futures.Future[object], default: object) -> torch.futures.Future[object]:
+        """
+        Wrap a Future, swallowing any errors that occur and reporting them to the manager.
+
+        If an error occurs, the Future will be completed with the default value.
+
+        Args:
+            fut: the Future to wrap
+            default: the default value to complete the Future with if an error occurs
+        """
+
+        # schedule error handling as a continuation
+        # on the Future
+        def callback(
+            fut: torch.futures.Future[object],
+        ) -> object:
+            nonlocal default
+
+            try:
+                return fut.value()
+            except Exception as e:
+                logger.exception(f"got exception in future -- skipping remaining: {e}")
+                self.report_error()
+                return default
+
+        fut = fut.then(callback)
+        self._pending_work.append(fut)
+        return fut
+
     def step(self) -> None:
         """
         .. note::
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index 53d882d2..6b350a4f 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -10,7 +10,7 @@
 import torch
 from torch.distributed import TCPStore
 from torchft.manager import Manager, MANAGER_ADDR_KEY
-from torchft.process_group import ProcessGroup
+from torchft.process_group import _DummyWork, ProcessGroup
 from torchft.torchft import ManagerClient
 
 
@@ -311,3 +311,40 @@ def test_allreduce_error(self, client_mock) -> None:
         manager.step()
         manager.allreduce_grad(torch.tensor([1.0])).wait()
         self.assertTrue(manager.should_commit())
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_manager_report_error(self, client_mock) -> None:
+        manager = self._create_manager()
+
+        self.assertFalse(manager.errored())
+        manager.report_error()
+        self.assertTrue(manager.errored())
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_manager_wrap_future(self, client_mock) -> None:
+        manager = self._create_manager()
+
+        self.assertFalse(manager.errored())
+
+        fut = torch.futures.Future()
+        wrapped_fut = manager.wrap_future(fut, 2)
+
+        fut.set_exception(RuntimeError("injected failure"))
+
+        self.assertEqual(wrapped_fut.value(), 2)
+        self.assertTrue(manager.errored())
+        self.assertEqual(manager._pending_work, [wrapped_fut])
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_manager_numerics(self, client_mock) -> None:
+        manager = self._create_manager()
+
+        manager._quorum_future = MagicMock()
+        manager._participating_replicas = 5
+        self.assertEqual(manager.num_participants(), 5)
+        manager._pg.allreduce.return_value = _DummyWork(None)
+
+        fut = torch.futures.Future()
+        fut = manager.allreduce_grad(torch.tensor([1.0]))
+        result = fut.value()
+        torch.testing.assert_close(result, torch.tensor([1.0 / 5]))
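
Usage sketch (illustrative, not part of either patch): the loop below shows how the participant and error-reporting APIs added above are intended to compose in a training loop. train_loop, model, optimizer, and dataloader are hypothetical placeholders and construction of the Manager is elided; only step(), allreduce_grad(), num_participants(), report_error(), errored(), and should_commit() are Manager APIs touched by this series. That should_commit() returns False after report_error() is inferred from the report_error() docstring and is not shown in this diff.

    import torch

    from torchft.manager import Manager


    def train_loop(
        manager: Manager,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        dataloader,
    ) -> None:
        for batch in dataloader:
            manager.step()  # join/refresh the quorum for this step

            loss = model(batch).mean()
            loss.backward()

            # allreduce_grad() wraps its Future via wrap_future(): a failed
            # collective completes with the unmodified local gradient and
            # records the error on the manager instead of raising.
            futs = [
                manager.allreduce_grad(p.grad)
                for p in model.parameters()
                if p.grad is not None
            ]
            for fut in futs:
                fut.wait()

            # report_error() can also be called directly for application-level
            # corruption (e.g. a NaN loss) so the step is discarded.
            if torch.isnan(loss):
                manager.report_error()

            # Assumed from the report_error() docstring: should_commit()
            # returns False once errored() is True, so the corrupted step is
            # skipped and the group is reconfigured on the next step().
            if manager.should_commit():
                optimizer.step()
            optimizer.zero_grad()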