Skip to content

Commit f0a4061

Browse files
authored
process_group: set timeout for TCPStore client connect (#145)
1 parent 73a6f78 commit f0a4061

File tree

2 files changed

+39
-13
lines changed

2 files changed

+39
-13
lines changed

torchft/process_group.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
T = TypeVar("T")
8585

8686

87-
def create_store_client(store_addr: str) -> Store:
87+
def create_store_client(store_addr: str, timeout: timedelta) -> Store:
8888
"""
8989
Creates a PrefixStore(TCPStore(...)) client from an address in the format:
9090
@@ -100,6 +100,7 @@ def create_store_client(store_addr: str) -> Store:
100100
port=int(port),
101101
is_master=False,
102102
wait_for_workers=False,
103+
timeout=timeout,
103104
)
104105
store = PrefixStore(prefix, store)
105106
return store
@@ -350,11 +351,20 @@ def __repr__(self) -> str:
350351
class ProcessGroupWrapper(ProcessGroup):
351352
"""
352353
This is a wrapper around any ProcessGroup with a reconfiguration method.
354+
355+
Args:
356+
timeout: timeout for reconfiguration for TCPStore
357+
pg: optional ProcessGroup to use, if None a new one will be created
353358
"""
354359

355-
def __init__(self, pg: Optional[ProcessGroup] = None) -> None:
360+
def __init__(
361+
self,
362+
timeout: timedelta = timedelta(seconds=60),
363+
pg: Optional[ProcessGroup] = None,
364+
) -> None:
356365
super().__init__(0, 1)
357366
self._pg: Optional[BaseProcessGroup] = pg
367+
self._timeout = timeout
358368

359369
def configure(self, store_addr: str, rank: int, world_size: int) -> None:
360370
pg = self._pg
@@ -365,7 +375,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
365375
# abort if already initialized
366376
self.abort()
367377

368-
store = create_store_client(store_addr)
378+
store = create_store_client(store_addr, timeout=self._timeout)
369379

370380
self._pg = self._create_pg(store, rank, world_size)
371381

@@ -511,10 +521,6 @@ class ProcessGroupGloo(ProcessGroupWrapper):
511521
This is a reconfigurable version of ProcessGroupGloo.
512522
"""
513523

514-
def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
515-
super().__init__()
516-
self._timeout = timeout
517-
518524
def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
519525
pg = BaseProcessGroup(store, rank, world_size)
520526
pg._set_default_backend(ProcessGroup.BackendType.GLOO)
@@ -648,8 +654,7 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
648654
"""
649655

650656
def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
651-
super().__init__()
652-
self._timeout = timeout
657+
super().__init__(timeout)
653658
self._use_abort: bool = torch.cuda.nccl.version() >= (2, 25)
654659

655660
def _opts_hook(self, opts: T) -> T:
@@ -877,7 +882,7 @@ class ErrorSwallowingProcessGroupWrapper(ProcessGroupWrapper):
877882
"""
878883

879884
def __init__(self, pg: ProcessGroup) -> None:
880-
super().__init__(pg)
885+
super().__init__(pg=pg)
881886

882887
self._error: Optional[Exception] = None
883888

@@ -958,7 +963,7 @@ class ManagedProcessGroup(ProcessGroupWrapper):
958963
"""
959964

960965
def __init__(self, manager: "Manager") -> None:
961-
super().__init__(manager._pg)
966+
super().__init__(pg=manager._pg)
962967

963968
self._manager = manager
964969

@@ -1195,7 +1200,11 @@ def _worker(
11951200
if curr_device >= 0 and torch.cuda.is_available():
11961201
torch.cuda.set_device(curr_device)
11971202

1198-
store = create_store_client(store_addr)
1203+
store = create_store_client(
1204+
store_addr,
1205+
# default TCPStore timeout is 5 minutes
1206+
timeout=timedelta(minutes=5),
1207+
)
11991208

12001209
try:
12011210
pg = cls._create_pg(store, rank, world_size)

torchft/process_group_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,23 @@ def test_nccl_apis(self) -> None:
552552

553553
torch.cuda.synchronize()
554554

555+
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
556+
@skipUnless(
557+
torch.cuda.is_available(),
558+
"needs CUDA",
559+
)
560+
def test_nccl_init_timeout(self) -> None:
561+
store = TCPStore(
562+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
563+
)
564+
store_addr = f"localhost:{store.port}/prefix"
565+
del store
566+
567+
pg = ProcessGroupNCCL(timeout=timedelta(seconds=0.01))
568+
569+
with self.assertRaisesRegex(RuntimeError, "timed out after 10ms"):
570+
pg.configure(store_addr, 0, 2)
571+
555572
def test_baby_gloo_timeout(self) -> None:
556573
store = TCPStore(
557574
host_name="localhost", port=0, is_master=True, wait_for_workers=False
@@ -710,7 +727,7 @@ def test_functional_collectives(self) -> None:
710727

711728
def test_process_group_wrapper(self) -> None:
712729
pg = ProcessGroupDummy(0, 1)
713-
wrapper = ProcessGroupWrapper(pg)
730+
wrapper = ProcessGroupWrapper(pg=pg)
714731
self.assertIs(wrapper.parent, pg)
715732

716733
wrapper.configure("addr", 0, 1)

0 commit comments

Comments
 (0)