
Commit 824a9a3

ProcessGroupBabyNCCL: use CUDA events to avoid blocking the CPU thread
1 parent: 219677e

2 files changed: +103 -28 lines

torchft/process_group.py (74 additions, 28 deletions)
@@ -6,7 +6,7 @@
 
 from abc import ABC
 import logging
-from typing import Type, List, Optional
+from typing import Type, List, Optional, Callable, Tuple
 from datetime import timedelta
 
 from torch.futures import Future
@@ -192,22 +192,30 @@ def wait(self) -> bool:
         return True
 
 
+class BabyWorkNCCL(BabyWork):
+    def wait(self) -> bool:
+        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
+        op_id, event = _get(self._rx, self._timeout)
+        assert op_id == self._op_id
+        assert isinstance(event, torch.cuda.Event)
+
+        # Wait on Event makes the stream wait but not the CPU thread.
+        event.wait()
+
+        return True
+
+
 class ProcessGroupBaby(ProcessGroup):
     """
     This is a process group that runs the underlying process group in a
     subprocess. Since it's running in a subprocess all tensors need to be in
     shared memory or will be moved to shared memory. CUDA tensors are implicitly
     share able and don't need any changes.
-
-    If the child process is killed while an operation is running CUDA tensors
-    may leak in the current implementation.
-
-    For the NCCL backend, extra memory will be used by the subprocesses CUDA
-    context compared to running NCCL in the main process. This is typically
-    around ~1GB.
     """
 
     PG_CLASS: Type[BaseProcessGroup]
+    WORK_CLASS: Type[BabyWork] = BabyWork
 
     def __init__(self, timeout: float = 60.0) -> None:
         super().__init__(0, 1)
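The wait() added above leans on a CUDA property: waiting on an event orders the caller's current stream after it without blocking the Python thread. A minimal single-process sketch of that distinction (not part of this commit; assumes a CUDA device is available):

import torch

# Launch work on a side stream and record an event behind it.
side = torch.cuda.Stream()
with torch.cuda.stream(side):
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x  # queued on `side`
    done = torch.cuda.Event()
    done.record()  # records on `side`

# Stream-level sync: the current (default) stream will now run after `done`,
# but the CPU thread returns immediately.
done.wait()
z = y.sum()  # safe: ordered after the matmul

# CPU-level sync: only needed when the host must observe the result.
torch.cuda.synchronize()
print(z.item())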
@@ -220,6 +228,23 @@ def __init__(self, timeout: float = 60.0) -> None:
 
         self._timeout = timeout
 
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        if self._p is not None:
+            self._p.kill()
+
+        self._world_size = world_size
+
+        ctx = mp.get_context("spawn")
+        self._tx = ctx.Queue()
+        self._rx = ctx.Queue()
+
+        self._p = ctx.Process(
+            target=self._worker,
+            args=(store_addr, rank, world_size, self._tx, self._rx),
+            daemon=True,
+        )
+        self._p.start()
+
     @classmethod
     def _worker(
         cls, store_addr: str, rank: int, world_size: int, rx: mp.Queue, tx: mp.Queue
@@ -235,37 +260,45 @@ def _worker(
             while True:
                 op = rx.get()
                 cmd = op[0]
-                if cmd == "allreduce":
-                    work[next_op_id] = pg.allreduce(op[1], op[2])
+                if cmd == "func":
+                    func, args, kwargs = op[1:]
+                    work[next_op_id] = getattr(pg, func)(*args, **kwargs)
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id = op[1]
                     work[op_id].wait()
                     del work[op_id]
                     tx.put(op_id)
+                elif cmd == "synchronize":
+                    # CUDA only, use events instead of waiting on CPU
+                    op_id = op[1]
+
+                    # With WorkNCCL this makes the stream wait not the CPU when
+                    # no timeout is passed.
+                    work[op_id].wait()
+
+                    # Register event on the stream that we can pass to the main
+                    # process.
+                    event = torch.cuda.Event(interprocess=True)
+                    event.record()
+
+                    del work[op_id]
+                    tx.put((op_id, event))
                 else:
                     raise ValueError(f"unknown cmd: {cmd}")
+
         except Exception as e:
             logger.exception("worker errored")
            tx.put(e)
 
-    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
-        if self._p is not None:
-            self._p.kill()
-
-        self._world_size = world_size
-
-        ctx = mp.get_context("spawn")
-        self._tx = ctx.Queue()
-        self._rx = ctx.Queue()
-
-        self._p = ctx.Process(
-            target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx),
-            daemon=True,
+    def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._tx.put(("func", func, args, kwargs), timeout=self._timeout)
+        op_id = _get(self._rx, self._timeout)
+        assert isinstance(op_id, int), f"invalid return {op_id}"
+        return self.WORK_CLASS(
+            tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout
         )
-        self._p.start()
 
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         assert isinstance(tensors, list), "input must be list"
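The "synchronize" branch above creates the event with interprocess=True so its handle can be shipped to the parent over the queue. Below is a standalone, hypothetical sketch of that handoff, not part of this commit: it assumes torch.multiprocessing (whose picklers can move CUDA tensors and interprocess events through queues) and a single visible CUDA device; the worker and queue names are illustrative only.

import torch
import torch.multiprocessing as mp


def worker(to_worker: mp.Queue, to_parent: mp.Queue) -> None:
    t = to_worker.get()                   # CUDA tensor arrives via CUDA IPC
    t.mul_(2)                             # GPU work queued on the worker's stream
    done = torch.cuda.Event(interprocess=True)
    done.record()                         # marks completion of the kernel above
    to_parent.put(done)                   # event handle crosses the process boundary
    to_worker.get()                       # stay alive until the parent is finished


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    to_worker, to_parent = ctx.Queue(), ctx.Queue()
    p = ctx.Process(target=worker, args=(to_worker, to_parent), daemon=True)
    p.start()

    x = torch.ones(4, device="cuda")
    to_worker.put(x)

    done = to_parent.get()
    done.wait()                           # parent's current stream waits; CPU keeps going
    torch.cuda.synchronize()              # CPU-side barrier, only to read the value
    print(x)                              # expected: tensor([2., 2., 2., 2.], device='cuda:0')

    to_worker.put(None)                   # let the worker exit
    p.join()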
@@ -274,10 +307,7 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
             if not tensor.is_shared():
                 tensor.share_memory_()
 
-        self._tx.put(("allreduce", tensors, opts), timeout=self._timeout)
-        op_id = _get(self._rx, self._timeout)
-        assert isinstance(op_id, int), f"invalid return {op_id}"
-        return BabyWork(tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout)
+        return self._run_func("allreduce", tensors, opts)
 
     def size(self) -> int:
         return self._world_size
@@ -291,7 +321,23 @@ def getBackendName(self):
 
 
 class ProcessGroupBabyNCCL(ProcessGroupBaby):
+    """
+    This is a ProcessGroup that runs NCCL in a subprocess.
+
+    For the NCCL backend, extra memory will be used by the subprocesses CUDA
+    context compared to running NCCL in the main process. This is typically
+    around ~1GB.
+
+    The returned Work objects only synchronize on the cuda stream and not on the
+    CPU side. This works by passing CUDA Events between the processes. To do a
+    CPU synchronize, call torch.cuda.synchronize() after wait().
+
+    WARNING: If the child process is killed while an operation is running, CUDA
+    tensors may leak in the current PyTorch implementation. TODO fix
+    """
+
     PG_CLASS = BaseProcessGroupGloo
+    WORK_CLASS = BabyWorkNCCL
 
     def getBackendName(self):
         return "torchft-baby-nccl"

torchft/process_group_test.py (29 additions, 0 deletions)
@@ -13,6 +13,7 @@
 
 from torchft.process_group import (
     ProcessGroupBabyGloo,
+    ProcessGroupBabyNCCL,
     ProcessGroupGloo,
     ProcessGroupDummy,
     ProcessGroup,
@@ -71,3 +72,31 @@ def test_dummy(self) -> None:
         m = nn.Linear(3, 4)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
         m(torch.rand(2, 3))
+
+    def test_baby_nccl(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        device = "cuda"
+
+        a = ProcessGroupBabyNCCL()
+        b = ProcessGroupBabyNCCL()
+
+        a.configure(store_addr, 0, 2)
+        b.configure(store_addr, 1, 2)
+
+        self.assertEqual(a.size(), 2)
+
+        at = torch.tensor([1], device=device)
+        bt = torch.tensor([2], device=device)
+
+        a_work = a.allreduce([at], ReduceOp.SUM)
+        b_work = b.allreduce([bt], ReduceOp.SUM)
+
+        a_work.wait()
+        b_work.wait()
+
+        torch.testing.assert_close(at, bt)
