
Commit c7e231e

train_ddp, process_group: fixes so CUDA works e2e (#5)
1 parent: 5d2e55f

4 files changed (+250, -131 lines)

torchft/manager.py

Lines changed: 23 additions & 3 deletions

@@ -197,6 +197,9 @@ def step(self) -> None:
         if not self._use_async_quorum:
             self._quorum_future.result()

+            # eagerly apply pending state_dict so we can run the forwards pass
+            self._apply_pending_state_dict()
+
         # we are forcing healing at the beginning so we're in a good state
         # and don't need to zero_grad
         self._healing = False
@@ -236,14 +239,27 @@ def _async_quorum(self) -> None:
             primary_client = ManagerClient(address, timeout=self._timeout)
             checkpoint_server_address = primary_client.checkpoint_address(self._rank)

-            state_dict = CheckpointServer.load_from_address(checkpoint_server_address)
-            self._load_state_dict(state_dict["user"])
-            self.load_state_dict(state_dict["torchft"])
+            self._state_dict = CheckpointServer.load_from_address(
+                checkpoint_server_address
+            )
+            self.load_state_dict(self._state_dict["torchft"])
+            # we apply the user state dict only when safe from the main thread

             # This isn't strictly needed as loading the state_dict above should
             # restore the correct step but it makes writing tests simpler.
             self._step = max_step

+    def _apply_pending_state_dict(self) -> None:
+        assert self._healing, "must be in healing state"
+
+        # synchronize on future
+        self._quorum_future.result()
+
+        assert self._state_dict is not None, "checkpoint was not staged"
+
+        self._load_state_dict(self._state_dict["user"])
+        self._state_dict = None
+
     def should_commit(self) -> bool:
         for work in self._pending_work:
             # check at the beginning of since .wait() may trigger errors
@@ -256,6 +272,10 @@ def should_commit(self) -> bool:

         self._pending_work = []

+        # apply state_dict if healing
+        if self._healing:
+            self._apply_pending_state_dict()
+
         enough_replicas = self._participating_replicas >= self._min_replica_size
         local_should_commit = enough_replicas and not self._errored
         should_commit = self._client.should_commit(
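
Taken together, the manager.py changes split checkpoint handling in two: the background quorum thread only stages the checkpoint in self._state_dict, and the main thread applies it at a safe point, either eagerly in step() or from should_commit() while healing. Below is a small standalone sketch of that stage-then-apply pattern; TinyManager and its methods are illustrative stand-ins, not torchft APIs.

from concurrent.futures import ThreadPoolExecutor


class TinyManager:
    """Stage a checkpoint on a background thread, apply it on the main thread."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._quorum_future = None
        self._state_dict = None

    def step(self) -> None:
        # background "quorum": fetch and stage the checkpoint, but don't apply it
        self._quorum_future = self._executor.submit(self._stage)

    def _stage(self) -> None:
        self._state_dict = {"user": {"weight": 1.0}}  # stand-in for a real checkpoint

    def apply_pending_state_dict(self) -> dict:
        self._quorum_future.result()  # synchronize with the background thread
        assert self._state_dict is not None, "checkpoint was not staged"
        user_state, self._state_dict = self._state_dict["user"], None
        return user_state  # the main thread applies this at a safe point


manager = TinyManager()
manager.step()
print(manager.apply_pending_state_dict())  # {'weight': 1.0}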

torchft/process_group.py

Lines changed: 94 additions & 7 deletions

@@ -8,6 +8,7 @@
 import logging
 from typing import Type, List, Optional, Callable, Tuple
 from datetime import timedelta
+import threading

 from torch.futures import Future
 from torch.distributed import (
@@ -26,6 +27,11 @@

 logger = logging.getLogger(__name__)

+# TODO: use non strings which are cheaper
+_QUEUE_CLOSE = "queue_close"
+_FUTURE_RESULT = "fut_result"
+_FUTURE_EXCEPTION = "fut_exception"
+

 def _get(queue: mp.Queue, timeout) -> object:
     v = queue.get(timeout=timeout)
@@ -208,9 +214,17 @@ def getBackendName(self):


 class BabyWork(Work):
-    def __init__(self, tx: mp.Queue, rx: mp.Queue, op_id: int, timeout: float):
+    def __init__(
+        self,
+        pg: "ProcessGroupBaby",
+        tx: mp.Queue,
+        rx: mp.Queue,
+        op_id: int,
+        timeout: float,
+    ):
         super().__init__()

+        self._pg = pg
         self._tx = tx
         self._rx = rx
         self._op_id = op_id
@@ -221,6 +235,9 @@ def wait(self) -> bool:
         assert _get(self._rx, self._timeout) == self._op_id
         return True

+    def get_future(self) -> Future:
+        return self._pg._get_future(self._op_id)
+

 class BabyWorkNCCL(BabyWork):
     def wait(self) -> bool:
@@ -255,6 +272,8 @@ def __init__(self, timeout: float = 60.0) -> None:
         self._p = None
         self._tx = None
         self._rx = None
+        self._future_queue = None
+        self._future_thread = None

         self._timeout = timeout

@@ -264,20 +283,46 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

         self._world_size = world_size

+        if self._tx is not None:
+            self._tx.close()
+        if self._rx is not None:
+            self._rx.close()
+        if self._future_queue is not None:
+            self._future_queue.put(_QUEUE_CLOSE)
+            self._future_queue.close()
+
         ctx = mp.get_context("spawn")
         self._tx = ctx.Queue()
         self._rx = ctx.Queue()

+        # futures need thread to fire callbacks
+        self._future_queue = ctx.Queue()
+        # this lock needs to be held when manipulating _futures
+        self._futures_lock = threading.Lock()
+        self._futures = {}
+        self._future_thread = threading.Thread(
+            target=self._future_handler,
+            args=(self._future_queue,),
+            daemon=True,
+        )
+        self._future_thread.start()
+
         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx),
+            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
             daemon=True,
         )
         self._p.start()

     @classmethod
     def _worker(
-        cls, store_addr: str, rank: int, world_size: int, rx: mp.Queue, tx: mp.Queue
+        cls,
+        store_addr: str,
+        rank: int,
+        world_size: int,
+        rx: mp.Queue,
+        tx: mp.Queue,
+        future_queue: mp.Queue,
     ) -> None:
         try:
             store = create_store(store_addr)
@@ -291,15 +336,28 @@ def _worker(
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func, args, kwargs = op[1:]
-                    work[next_op_id] = getattr(pg, func)(*args, **kwargs)
+                    func_name, args, kwargs = op[1:]
+                    fn = getattr(pg, func_name)
+                    work[next_op_id] = fn(*args, **kwargs)
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id = op[1]
                     work[op_id].wait()
                     del work[op_id]
                     tx.put(op_id)
+                elif cmd == "future":
+                    op_id = op[1]
+
+                    def callback(fut: Future):
+                        try:
+                            fut.wait()
+                            future_queue.put((op_id, _FUTURE_RESULT, None))
+                        except Exception as e:
+                            future_queue.put((op_id, _FUTURE_EXCEPTION, e))
+
+                    work[op_id].get_future().add_done_callback(callback)
+                    tx.put(op_id)
                 elif cmd == "synchronize":
                     # CUDA only, use events instead of waiting on CPU
                     op_id = op[1]
@@ -322,12 +380,41 @@ def _worker(
             logger.exception("worker errored")
             tx.put(e)

+    def _future_handler(self, future_queue: mp.Queue) -> None:
+        try:
+            while True:
+                cmd = future_queue.get()
+                if cmd == _QUEUE_CLOSE:
+                    break
+                op_id, mode, data = cmd
+                with self._futures_lock:
+                    fut = self._futures[op_id]
+                    del self._futures[op_id]
+                if mode == _FUTURE_RESULT:
+                    fut.set_result(data)
+                elif mode == _FUTURE_EXCEPTION:
+                    fut.set_exception(data)
+                else:
+                    raise ValueError(f"unknown mode {mode}")
+        except Exception as e:
+            logger.exception(f"got unexpected error in future handler: {e}")
+
+    def _get_future(self, op_id: int) -> Future:
+        with self._futures_lock:
+            fut = Future()
+            self._futures[op_id] = fut
+        self._tx.put(("future", op_id), timeout=self._timeout)
+
+        assert _get(self._rx, self._timeout) == op_id
+        # TODO: return correct tensor instead of None
+        return fut
+
     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
         self._tx.put(("func", func, args, kwargs), timeout=self._timeout)
         op_id = _get(self._rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
         return self.WORK_CLASS(
-            tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout
+            pg=self, tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout
         )

     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
@@ -366,7 +453,7 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """

-    PG_CLASS = BaseProcessGroupGloo
+    PG_CLASS = BaseProcessGroupNCCL
     WORK_CLASS = BabyWorkNCCL

     def getBackendName(self):
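
The new get_future() support works by pairing each op with a Future registered in the parent process: the spawned worker reports completion over future_queue, and _future_handler runs on a dedicated thread to resolve the matching Future. Below is a standalone sketch of that queue-plus-handler-thread pattern (not torchft code); a local queue.Queue stands in for the spawned multiprocessing queue used by ProcessGroupBaby.

import queue
import threading

from torch.futures import Future

_QUEUE_CLOSE = "queue_close"
_FUTURE_RESULT = "fut_result"
_FUTURE_EXCEPTION = "fut_exception"

futures: dict[int, Future] = {}
futures_lock = threading.Lock()
future_queue: queue.Queue = queue.Queue()


def future_handler(q: queue.Queue) -> None:
    # drain (op_id, mode, data) tuples and resolve the registered Future
    while True:
        cmd = q.get()
        if cmd == _QUEUE_CLOSE:
            break
        op_id, mode, data = cmd
        with futures_lock:
            fut = futures.pop(op_id)
        if mode == _FUTURE_RESULT:
            fut.set_result(data)
        elif mode == _FUTURE_EXCEPTION:
            fut.set_exception(data)


handler = threading.Thread(target=future_handler, args=(future_queue,), daemon=True)
handler.start()

# register a Future for op 0, then pretend the worker finished that op
with futures_lock:
    fut = Future()
    futures[0] = fut
future_queue.put((0, _FUTURE_RESULT, None))
fut.wait()  # returns once the handler thread calls set_result

future_queue.put(_QUEUE_CLOSE)  # shut the handler thread down
handler.join()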

torchft/process_group_test.py

Lines changed: 20 additions & 13 deletions

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 from unittest import TestCase, skipUnless
+from concurrent.futures import ThreadPoolExecutor

 import torch
 from torch.distributed import TCPStore, ReduceOp
@@ -37,6 +38,7 @@ def test_gloo(self) -> None:

         a_work = pg.allreduce([at], ReduceOp.SUM)
         a_work.wait()
+        a_work.get_future().wait()

         m = nn.Linear(3, 4)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -58,6 +60,7 @@ def test_nccl(self) -> None:
         at = torch.tensor([2], device=device)
         a_work = pg.allreduce([at], ReduceOp.SUM)
         a_work.wait()
+        a_work.get_future().wait()

         m = nn.Linear(3, 4).to(device)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -95,7 +98,9 @@ def test_baby_gloo(self) -> None:
         b_work = b.allreduce([bt], ReduceOp.SUM)

         a_work.wait()
-        b_work.wait()
+        fut = b_work.get_future()
+
+        fut.wait()

         torch.testing.assert_close(at, bt)

@@ -113,23 +118,25 @@ def test_baby_nccl(self) -> None:

         store_addr = f"localhost:{store.port}/prefix"

-        device = "cuda"
+        def run(rank: int) -> None:
+            a = ProcessGroupBabyNCCL()
+            a.configure(store_addr, rank, 2)

-        a = ProcessGroupBabyNCCL()
-        b = ProcessGroupBabyNCCL()
+            self.assertEqual(a.size(), 2)

-        a.configure(store_addr, 0, 2)
-        b.configure(store_addr, 1, 2)
+            at = torch.tensor([rank + 1], device=f"cuda:{rank}")

-        self.assertEqual(a.size(), 2)
+            a_work = a.allreduce([at], ReduceOp.SUM)
+            return at, a_work

-        at = torch.tensor([1], device=device)
-        bt = torch.tensor([2], device=device)
+        with ThreadPoolExecutor(max_workers=2) as executor:
+            a_fut = executor.submit(run, 0)
+            b_fut = executor.submit(run, 1)

-        a_work = a.allreduce([at], ReduceOp.SUM)
-        b_work = b.allreduce([bt], ReduceOp.SUM)
+            at, a_work = a_fut.result()
+            bt, b_work = b_fut.result()

         a_work.wait()
-        b_work.wait()
+        b_work.get_future().wait()

-        torch.testing.assert_close(at, bt)
+        torch.testing.assert_close(at.cpu(), bt.cpu())
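
For reference, a minimal usage sketch of the subprocess-backed process group outside the test suite, waiting on the new get_future() path instead of wait(). ProcessGroupBabyGloo is assumed to be the Gloo counterpart exercised by test_baby_gloo above (only ProcessGroupBabyNCCL appears in this diff), and the store address format mirrors the tests.

import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ProcessGroupBabyGloo

store = TCPStore("localhost", 0, is_master=True, wait_for_workers=False)
store_addr = f"localhost:{store.port}/prefix"

pg = ProcessGroupBabyGloo()
pg.configure(store_addr, rank=0, world_size=1)

t = torch.tensor([3])
work = pg.allreduce([t], ReduceOp.SUM)
work.get_future().wait()  # resolved by the parent-side future handler thread
print(t)                  # unchanged with world_size=1: tensor([3])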
