Commit 4bf29d5

process_group: add inprocess ProcessGroupNCCL
1 parent f07c80c commit 4bf29d5

2 files changed: +72 -6 lines changed

torchft/process_group.py

Lines changed: 40 additions & 6 deletions
@@ -85,21 +85,30 @@ def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")


-class ProcessGroupGloo(ProcessGroup):
+class ProcessGroupWrapper(ProcessGroup):
+    PG_CLASS: Type[BaseProcessGroup]
     """
-    This is a wrapper around ProcessGroupGloo with a reconfiguration argument.
+    This is a wrapper around any ProcessGroup with a reconfiguration method.
     """

-    def __init__(self) -> None:
+    def __init__(self, timeout: float = 60.0) -> None:
+        """
+        Args:
+            timeout: the timeout to use for the process group
+        """
         super().__init__(0, 1)
         self._pg = None

     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        if self._pg is not None:
+            if hasattr(self._pg, "abort"):
+                self._pg.abort()
+            self._pg = None
+
         store = create_store(store_addr)

-        # TODO: set lower timeout
-        # pyre-fixme[16]: no attribute ProcessGroupGloo
-        self._pg = BaseProcessGroupGloo(store, rank, world_size)
+        # TODO: set global timeout
+        self._pg = self.PG_CLASS(store, rank, world_size)

     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return self._pg.allreduce(tensors, opts)

@@ -118,10 +127,35 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._pg.size()

+
+class ProcessGroupGloo(ProcessGroupWrapper):
+    """
+    This is a reconfigurable version of ProcessGroupGloo.
+    """
+
+    PG_CLASS = BaseProcessGroupGloo
+
     def getBackendName(self) -> str:
         return "torchft-gloo"


+class ProcessGroupNCCL(ProcessGroupWrapper):
+    """
+    This is a reconfigurable version of ProcessGroupNCCL.
+
+    WARNING: this may result in deadlocks due to NCCL error handling. This is
+    provided for completeness but your mileage may vary.
+
+    TODO: verify shutdown correctness with latest NCCL. This currently will call
+    abort when reconfiguring, we need to ensure this is safe.
+    """
+
+    PG_CLASS = BaseProcessGroupNCCL
+
+    def getBackendName(self) -> str:
+        return "torchft-nccl"
+
+
 class DummyWork(dist._Work):
     def __init__(self, result):
         super().__init__()
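
For context, here is a minimal usage sketch of the reconfigurable wrapper introduced above, using the Gloo variant since it does not require CUDA. The host:port/prefix store address format mirrors the test below; the run_0/run_1 prefixes and the tensor values are illustrative, not part of this commit.

import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ProcessGroupGloo

# Sketch only: single-rank lifecycle of the reconfigurable wrapper.
store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

pg = ProcessGroupGloo()
pg.configure(f"localhost:{store.port}/run_0", rank=0, world_size=1)

t = torch.tensor([3.0])
pg.allreduce([t], ReduceOp.SUM).wait()

# configure() tears down the previous backend (calling abort() when the
# backend provides it) and builds a fresh one against the new store prefix.
pg.configure(f"localhost:{store.port}/run_1", rank=0, world_size=1)
pg.allreduce([t], ReduceOp.SUM).wait()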

torchft/process_group_test.py

Lines changed: 32 additions & 0 deletions
@@ -15,6 +15,7 @@
     ProcessGroupBabyGloo,
     ProcessGroupBabyNCCL,
     ProcessGroupGloo,
+    ProcessGroupNCCL,
     ProcessGroupDummy,
     ProcessGroup,
 )

@@ -41,6 +42,37 @@ def test_gloo(self) -> None:
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
         m(torch.rand(2, 3))

+    @skipUnless(torch.cuda.is_available(), "needs CUDA")
+    def test_nccl(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        device = "cuda"
+
+        store_addr = f"localhost:{store.port}/prefix"
+        pg = ProcessGroupNCCL()
+        pg.configure(store_addr, 0, 1)
+
+        self.assertEqual(pg.size(), 1)
+
+        at = torch.tensor([2], device=device)
+        a_work = pg.allreduce([at], ReduceOp.SUM)
+        a_work.wait()
+
+        m = nn.Linear(3, 4).to(device)
+        m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
+        m(torch.rand(2, 3, device=device))
+
+        # reconfigure
+        store_addr = f"localhost:{store.port}/prefix2"
+        pg.configure(store_addr, 0, 1)
+
+        at = torch.tensor([2], device=device)
+        a_work = pg.allreduce([at], ReduceOp.SUM)
+        a_work.wait()
+
+        torch.cuda.synchronize()
+
     def test_baby_gloo(self) -> None:
         store = TCPStore(
             host_name="localhost", port=0, is_master=True, wait_for_workers=False
