
Commit c74a1b0

tushar00jain and d4l3k authored
fix managed pg allreduce (#249)
Summary: managed pg allreduce should just call manager's allreduce
Co-authored-by: Tristan Rice <[email protected]>
1 parent fc15d58 commit c74a1b0

3 files changed: +11 -47 lines changed

torchft/manager.py

Lines changed: 5 additions & 3 deletions
@@ -381,6 +381,7 @@ def allreduce(
         self,
         tensor: torch.Tensor,
         should_quantize: bool = False,
+        reduce_op: ReduceOp = ReduceOp.SUM,
     ) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
@@ -416,13 +417,13 @@ def allreduce(
         if should_quantize and IS_TRITON_AVAILABLE:
             work = allreduce_quantized(
                 [tensor],
-                ReduceOp.SUM,
+                reduce_op,
                 self._pg,
                 # pyre-fixme[6]: Expected `Optional[streams.Stream]` but got `_C.Stream`
                 torch.accelerator.current_stream(),
             )
         else:
-            work = self._pg.allreduce([tensor], ReduceOp.SUM)
+            work = self._pg.allreduce([tensor], reduce_op)

         # schedule grad normalization as a continuation
         # on the Future
@@ -431,7 +432,8 @@ def callback(
             fut: torch.futures.Future[list[torch.Tensor]],
         ) -> torch.Tensor:
             nonlocal tensor
-            tensor /= num_participants
+            if reduce_op == ReduceOp.SUM:
+                tensor /= num_participants
             return tensor

         managed_work = _ManagedWork(self, work, tensor)
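
With the new reduce_op argument, Manager.allreduce is no longer hard-wired to an averaged SUM: the divide-by-num_participants normalization in the continuation now only runs when reduce_op is ReduceOp.SUM. The sketch below is not part of this commit; it is a minimal illustration assuming an already-initialized torchft Manager, and the helper name reduce_grad and its variable names are hypothetical.

import torch
from torch.distributed import ReduceOp

from torchft.manager import Manager


def reduce_grad(manager: Manager, grad: torch.Tensor, use_max: bool = False) -> torch.Tensor:
    if use_max:
        # reduce_op != ReduceOp.SUM, so the divide-by-num_participants
        # normalization in the callback is skipped.
        work = manager.allreduce(grad, reduce_op=ReduceOp.MAX)
    else:
        # Default behavior: SUM followed by division by the number of participants.
        work = manager.allreduce(grad)
    work.wait()
    return grad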

torchft/process_group.py

Lines changed: 5 additions & 39 deletions
@@ -1204,30 +1204,6 @@ def callback(
         return work


-class _ManagedWork(Work):
-    def __init__(self, manager: "Manager", work: Work, default_result: object) -> None:
-        super().__init__()
-
-        self._manager = manager
-        self._work = work
-        self._default_result = default_result
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        try:
-            if self._work is not None:
-                if timeout is not None:
-                    self._work.wait(timeout)
-                else:
-                    self._work.wait()
-        except Exception as e:
-            self._manager.report_error(e)
-
-        return True
-
-    def get_future(self) -> Future[object]:
-        return self._manager.wrap_future(self._work.get_future(), self._default_result)
-
-
 class ManagedProcessGroup(ProcessGroupWrapper):
     """
     This is a wrapper around any ProcessGroup that is managed by a torchft
@@ -1247,23 +1223,13 @@ def __init__(self, manager: "Manager") -> None:
         self._manager = manager

     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
-        # Ensure we have a valid quorum and are configured before trying to do
-        # any work.
-        self._manager.wait_quorum()
+        if isinstance(opts, ReduceOp):
+            return self._manager.allreduce(tensors, reduce_op=opts)

-        if self._manager.errored() is not None:
-            return _DummyWork(tensors)
-        try:
-            work = super().allreduce(tensors, opts)
-        except Exception as e:
-            self._manager.report_error(e)
-            return _DummyWork(tensors)
+        if isinstance(opts, AllreduceOptions):
+            return self._manager.allreduce(tensors, reduce_op=opts.reduceOp)

-        return _ManagedWork(
-            self._manager,
-            work,
-            tensors,
-        )
+        assert False, "unreachable"

     def size(self) -> int:
         return self._manager.num_participants()
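
ManagedProcessGroup.allreduce now simply forwards to the manager's allreduce, taking the reduction op either from a bare ReduceOp or from AllreduceOptions.reduceOp. The following call sketch is not part of this commit; it assumes an already-configured ManagedProcessGroup, a hypothetical helper name allreduce_both_ways, and that AllreduceOptions is importable from torch.distributed.distributed_c10d.

from typing import List

import torch
from torch.distributed import ReduceOp
from torch.distributed.distributed_c10d import AllreduceOptions

from torchft.process_group import ManagedProcessGroup


def allreduce_both_ways(pg: ManagedProcessGroup, tensors: List[torch.Tensor]) -> None:
    # Passing a ReduceOp directly: forwarded as reduce_op=opts.
    pg.allreduce(tensors, ReduceOp.SUM).wait()

    # Passing AllreduceOptions: opts.reduceOp is forwarded instead.
    opts = AllreduceOptions()
    opts.reduceOp = ReduceOp.SUM
    pg.allreduce(tensors, opts).wait()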

torchft/process_group_test.py

Lines changed: 1 addition & 5 deletions
@@ -48,7 +48,6 @@
     ProcessGroupNCCL,
     ProcessGroupWrapper,
     _ErrorSwallowingWork,
-    _ManagedWork,
     extend_device_mesh,
     ft_init_device_mesh,
 )
@@ -810,11 +809,8 @@ def test_managed_process_group(self) -> None:
         self.assertEqual(pg.size(), 123)

         works = _test_pg(pg)
-        self.assertIsInstance(list(works.values())[0], _ManagedWork)

-        self.assertEqual(manager.report_error.call_count, 0)
-        self.assertEqual(manager.wrap_future.call_count, 2)
-        self.assertEqual(manager.wait_quorum.call_count, 2)
+        self.assertEqual(manager.allreduce.call_count, 2)


 class DeviceMeshTest(TestCase):
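
The slimmed-down test only has to verify that the wrapper delegates to the manager, which a mocked Manager and call_count can express directly. The standalone sketch below is not the repository's test code; every name in it is illustrative of that assertion pattern.

from unittest.mock import MagicMock

import torch
from torch.distributed import ReduceOp

# A MagicMock stands in for the torchft Manager.
manager = MagicMock()

# Two allreduce calls, as ManagedProcessGroup.allreduce would now issue them.
manager.allreduce([torch.zeros(3)], reduce_op=ReduceOp.SUM)
manager.allreduce([torch.zeros(3)], reduce_op=ReduceOp.SUM)

# call_count records how many times the delegation happened.
assert manager.allreduce.call_count == 2
assert manager.allreduce.call_args.kwargs["reduce_op"] == ReduceOp.SUM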
