torchft/manager.py (+5, -0)

@@ -445,6 +445,8 @@ def _async_quorum(
quorum_timeout: timedelta,
curr_device: int,
) -> None:
torch.multiprocessing._set_thread_name("torchft_quorum")

if curr_device >= 0 and torch.cuda.is_available():
torch.cuda.set_device(curr_device)
quorum = self._client._quorum(
@@ -605,6 +607,9 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:

self._pending_work = []

if err := self._pg.errored():
self.report_error(err)

# apply state_dict if healing
if self._healing:
self._apply_pending_state_dict()
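Note: a minimal sketch of how this check surfaces in a training step (assumed usage; `model`, `optimizer`, and `batch` are illustrative placeholders, not names from this PR):

```python
# Sketch only: assumes a configured torchft Manager named `manager`.
def train_step(manager, model, optimizer, batch):
    manager.start_quorum()

    loss = model(batch).sum()
    loss.backward()  # gradients are allreduced via manager.allreduce(...)

    # should_commit() now also checks self._pg.errored(); an async
    # process group failure is forwarded to report_error() and the
    # commit is refused.
    if manager.should_commit():
        optimizer.step()
    optimizer.zero_grad()
    # After a reported error, the next start_quorum() reconfigures the
    # process group and heals this replica if it fell behind.
```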
torchft/manager_test.py (+35, -0)

@@ -42,6 +42,8 @@ def _create_manager(
timeout: timedelta = timedelta(seconds=10),
) -> Manager:
pg = create_autospec(ProcessGroup)
pg.errored.return_value = None

self.store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
@@ -408,6 +410,39 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
manager.allreduce(torch.tensor([1.0])).wait()
self.assertTrue(manager.should_commit())

@patch("torchft.manager.ManagerClient", autospec=True)
def test_pg_errored(self, client_mock: MagicMock) -> None:
manager = self._create_manager()
client_mock().should_commit = mock_should_commit

quorum = QuorumResult()
quorum.quorum_id = 123
quorum.replica_rank = 1
quorum.replica_world_size = 2
quorum.recover_src_manager_address = "manager address"
quorum.store_address = f"localhost:{self.store.port}"
quorum.max_step = 1
quorum.max_rank = 1
quorum.max_world_size = 2
quorum.heal = False

client_mock()._quorum.return_value = quorum

self.assertEqual(manager._quorum_id, -1)
self.assertEqual(manager.current_step(), 0)

manager.start_quorum()

injected_failure = RuntimeError("injected failure")

# pyre-ignore[16]: _pg is mocked
manager._pg.errored.return_value = injected_failure

self.assertFalse(manager.should_commit())
self.assertEqual(manager._errored, injected_failure)
# pyre-ignore[16]: _pg is mocked
self.assertEqual(manager._pg.errored.call_count, 1)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
# test active and spares
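Note on the `pg.errored.return_value = None` default in `_create_manager` (my reading of the change, not stated in the PR): `create_autospec(ProcessGroup)` would otherwise return a truthy `MagicMock` from `errored()`, so every existing `should_commit()` test would report a phantom failure:

```python
pg = create_autospec(ProcessGroup)
# Pin the "no async error" baseline so only tests that explicitly
# inject a failure (like test_pg_errored) exercise the error path.
pg.errored.return_value = None
```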
torchft/process_group.py (+22, -0)

@@ -344,6 +344,12 @@ def shutdown(self) -> None:
"""
pass

def errored(self) -> Optional[Exception]:
"""
Whether an async error occurred that requires reconfiguration.
"""
return None

def __repr__(self) -> str:
return f"{self.__class__.__name__}()"

@@ -657,6 +663,8 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
super().__init__(timeout)
self._use_abort: bool = torch.cuda.nccl.version() >= (2, 25)

self._errored: Optional[Exception] = None

def _opts_hook(self, opts: T) -> T:
if not self._use_abort:
return opts
@@ -679,6 +687,8 @@ def _wrap_work(self, work: Work, opts: object) -> Work:
return _WorkCUDATimeout(self, work, timeout)

def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
self._errored = None

pg = BaseProcessGroup(store, rank, world_size)
pg._set_default_backend(ProcessGroup.BackendType.NCCL)
# pyre-fixme[16]: no attribute ProcessGroupNCCL
@@ -689,6 +699,18 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
)
return pg

def abort(self) -> None:
super().abort()

self._errored = RuntimeError("aborted")

def errored(self) -> Optional[Exception]:
pg = self._pg
if pg is not None:
pg._wait_for_pending_works()

return self._errored

def getBackendName(self) -> str:
return "torchft-nccl"

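For reference, a sketch of the `errored()` contract these additions define (assumed usage; the `configure()`/`allreduce()` signatures follow torchft's `ProcessGroup` wrapper API, and the store address is a placeholder):

```python
from datetime import timedelta

import torch
from torch.distributed import ReduceOp

from torchft.process_group import ProcessGroupNCCL

store_addr = "localhost:29500/prefix_0"  # placeholder
rank, world_size = 0, 2

pg = ProcessGroupNCCL(timeout=timedelta(seconds=60.0))
pg.configure(store_addr, rank, world_size)

t = torch.ones(4, device="cuda")
pg.allreduce([t], ReduceOp.SUM).wait()  # may still fail asynchronously

# errored() drains pending NCCL work first, so a late failure is not
# silently lost, then returns the recorded error; after a timeout
# triggered abort() that is RuntimeError("aborted").
if err := pg.errored():
    # _create_pg() resets _errored, so reconfiguring clears the state.
    pg.configure(store_addr, rank, world_size)
```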
torchft/process_group_test.py (+6, -0)

@@ -921,6 +921,8 @@ def _run_with_resiliency(self, collective: str, device: str = "cpu") -> None:
def worker(pg: ProcessGroup, rank: int, dev: str) -> str:
if dev == "cuda":
torch.cuda.set_device(rank)
# Use a separate stream to avoid deadlocks between threads.
torch.cuda.set_stream(torch.cuda.Stream())

fault_rank = self.WORLD_SIZE - 1
test = _COLLECTIVE_TO_FUNC[collective]
@@ -952,6 +954,10 @@ def worker(pg: ProcessGroup, rank: int, dev: str) -> str:
test(pg, rank, t1.clone())
raise RuntimeError("no error")

if err := pg.errored():
with self.assertRaisesRegex(RuntimeError, "aborted"):
raise err

return f"Rank{rank} final success."

# run in parallel
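The per-thread stream comment above is the key to running these fault tests with in-process rank threads; a sketch of the pattern (my summary, not code from this PR):

```python
import torch

def worker(rank: int) -> None:
    torch.cuda.set_device(rank)
    # Per the test comment: collectives issued from several in-process
    # rank threads can deadlock on the default stream, so each thread
    # runs on its own freshly created stream.
    torch.cuda.set_stream(torch.cuda.Stream())
    # ... run this rank's collectives on the private stream ...
```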