diff --git a/torchft/manager.py b/torchft/manager.py
index 85af6235..0697bd4d 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -445,6 +445,8 @@ def _async_quorum(
         quorum_timeout: timedelta,
         curr_device: int,
     ) -> None:
+        torch.multiprocessing._set_thread_name("torchft_quorum")
+
         if curr_device >= 0 and torch.cuda.is_available():
             torch.cuda.set_device(curr_device)
         quorum = self._client._quorum(
@@ -605,6 +607,9 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
 
         self._pending_work = []
 
+        if err := self._pg.errored():
+            self.report_error(err)
+
         # apply state_dict if healing
         if self._healing:
             self._apply_pending_state_dict()
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index 05793e1b..fb134967 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -42,6 +42,8 @@ def _create_manager(
         timeout: timedelta = timedelta(seconds=10),
     ) -> Manager:
         pg = create_autospec(ProcessGroup)
+        pg.errored.return_value = None
+
         self.store = TCPStore(
             host_name="localhost", port=0, is_master=True, wait_for_workers=False
         )
@@ -408,6 +410,39 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
         manager.allreduce(torch.tensor([1.0])).wait()
         self.assertTrue(manager.should_commit())
 
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_pg_errored(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager()
+        client_mock().should_commit = mock_should_commit
+
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 1
+        quorum.max_rank = 1
+        quorum.max_world_size = 2
+        quorum.heal = False
+
+        client_mock()._quorum.return_value = quorum
+
+        self.assertEqual(manager._quorum_id, -1)
+        self.assertEqual(manager.current_step(), 0)
+
+        manager.start_quorum()
+
+        injected_failure = RuntimeError("injected failure")
+
+        # pyre-ignore[16]: _pg is mocked
+        manager._pg.errored.return_value = injected_failure
+
+        self.assertFalse(manager.should_commit())
+        self.assertEqual(manager._errored, injected_failure)
+        # pyre-ignore[16]: _pg is mocked
+        self.assertEqual(manager._pg.errored.call_count, 1)
+
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
         # test active and spares
diff --git a/torchft/process_group.py b/torchft/process_group.py
index 2b13593e..5ae5ce85 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -344,6 +344,12 @@ def shutdown(self) -> None:
         """
         pass
 
+    def errored(self) -> Optional[Exception]:
+        """
+        Whether an async error occurred that requires reconfiguration.
+        """
+        return None
+
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
 
@@ -657,6 +663,8 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
         super().__init__(timeout)
         self._use_abort: bool = torch.cuda.nccl.version() >= (2, 25)
 
+        self._errored: Optional[Exception] = None
+
     def _opts_hook(self, opts: T) -> T:
         if not self._use_abort:
             return opts
@@ -679,6 +687,8 @@ def _wrap_work(self, work: Work, opts: object) -> Work:
         return _WorkCUDATimeout(self, work, timeout)
 
     def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        self._errored = None
+
         pg = BaseProcessGroup(store, rank, world_size)
         pg._set_default_backend(ProcessGroup.BackendType.NCCL)
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
@@ -689,6 +699,18 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
         )
         return pg
 
+    def abort(self) -> None:
+        super().abort()
+
+        self._errored = RuntimeError("aborted")
+
+    def errored(self) -> Optional[Exception]:
+        pg = self._pg
+        if pg is not None:
+            pg._wait_for_pending_works()
+
+        return self._errored
+
     def getBackendName(self) -> str:
         return "torchft-nccl"
 
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index 5a072fe5..59b05fb8 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -921,6 +921,8 @@ def _run_with_resiliency(self, collective: str, device: str = "cpu") -> None:
         def worker(pg: ProcessGroup, rank: int, dev: str) -> str:
             if dev == "cuda":
                 torch.cuda.set_device(rank)
+                # Use a separate stream to avoid deadlocks between threads.
+                torch.cuda.set_stream(torch.cuda.Stream())
 
             fault_rank = self.WORLD_SIZE - 1
             test = _COLLECTIVE_TO_FUNC[collective]
@@ -952,6 +954,10 @@ def worker(pg: ProcessGroup, rank: int, dev: str) -> str:
                     test(pg, rank, t1.clone())
                     raise RuntimeError("no error")
 
+            if err := pg.errored():
+                with self.assertRaisesRegex(RuntimeError, "aborted"):
+                    raise err
+
             return f"Rank{rank} final success."
 
         # run in parallel