ProcessGroupBabyNCCL: support multiple streams and use event on start #91
@@ -19,17 +19,18 @@
import logging
import queue
import threading
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
@@ -58,9 +59,9 @@
    BroadcastOptions,
    ReduceOp,
    Work,
    _world,
)
from torch.futures import Future
from torch.utils._pytree import tree_any

if TYPE_CHECKING:
    from torchft.manager import Manager
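The newly imported `tree_any` is what powers the CUDA detection added further down. A minimal sketch of its behavior (the nested structure is made up for illustration and requires a CUDA device):

```python
import torch
from torch.utils._pytree import tree_any

# tree_any applies the predicate to every leaf of a nested pytree
# (lists, tuples, dicts, ...) and returns True if any leaf matches.
args = ([torch.zeros(2), {"x": torch.zeros(2, device="cuda")}],)
has_cuda = tree_any(
    lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, args
)
print(has_cuda)  # True: one nested tensor lives on a CUDA device
```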
@@ -586,29 +587,52 @@ def __init__(
        self._timeout = timeout

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        self._pg._assert_alive()

        self._tx.put(("wait", self._op_id), timeout=self._timeout)
        assert _get(self._rx, self._timeout) == self._op_id
        op_id, event = cast(
            Tuple[int, Optional[torch.cuda.Event]],
            _get(self._rx, timeout or self._timeout),
        )
        assert op_id == self._op_id
        if event is not None:
            event.wait()
        return True

    def synchronize(self) -> None:
        # TODO: No one seems to use this and NCCL wait already only waits the
        # stream and is non-blocking on the CPU side so no real need for a
        # separate call.
        raise NotImplementedError("not implemented")

    def get_future(self) -> Future[object]:
        return self._pg._get_future(self._op_id)

    def __del__(self) -> None:
        self._tx.put(("del", self._op_id), timeout=self._timeout)


class _BabyWorkNCCL(_BabyWork):
    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
        # pyre-fixme[23]: unable to unpack into 2 values
        op_id, event = _get(self._rx, self._timeout)
        assert op_id == self._op_id
        assert isinstance(event, torch.cuda.Event)

        # Wait on Event makes the stream wait but not the CPU thread.
        event.wait()

        return True


def _is_any_cuda(obj: object) -> bool:
    """
    Returns true if any of the tensors in the object are CUDA tensors.

    Supports lists, tuples, dicts, and tensors.
    """
    return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, obj)


@dataclass
class _OpMetadata:
    work: Work
    stream: Optional[torch.cuda.Stream]

    @contextmanager
    def set_stream(self) -> Generator[None, None, None]:
        if self.stream is not None:
            with torch.cuda.stream(self.stream):
                yield
        else:
            yield


class ProcessGroupBaby(ProcessGroup):
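The reworked `wait` hinges on one CUDA property: `Event.wait()` makes the current *stream* wait, not the CPU thread. The PR additionally creates its events with `interprocess=True` so they can cross the process boundary; the single-process sketch below (illustrative names and shapes) shows just the ordering semantics:

```python
import torch

producer = torch.cuda.Stream()
consumer = torch.cuda.Stream()

with torch.cuda.stream(producer):
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x
    done = torch.cuda.Event()
    done.record()  # marks completion of all work queued on `producer`

with torch.cuda.stream(consumer):
    done.wait()  # `consumer` now waits for the matmul; CPU returns immediately
    z = y.sum()  # safe: ordered after the producer's work

torch.cuda.synchronize()
print(z.item())
```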
@@ -617,11 +641,8 @@ class ProcessGroupBaby(ProcessGroup):
    subprocess. Since it's running in a subprocess all tensors need to be in
    shared memory or will be moved to shared memory. CUDA tensors are implicitly
    shareable and don't need any changes.
    """

    WORK_CLASS: Type[_BabyWork] = _BabyWork

    def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
        super().__init__(0, 1)
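The shared-memory requirement in the docstring applies to CPU tensors; `Tensor.share_memory_()` is the in-place move PyTorch provides for that. A tiny standalone illustration (not torchft code):

```python
import torch

t = torch.zeros(4)
print(t.is_shared())  # False: private memory
t.share_memory_()     # moves the storage into shared memory, in place
print(t.is_shared())  # True: now usable across spawned subprocesses
```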
@@ -679,7 +700,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
        self._p = ctx.Process(
            target=self._worker,
            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
            args=(
                store_addr,
                rank,
                world_size,
                self._tx,
                self._rx,
                self._future_queue,
            ),
            daemon=True,
        )
        self._p.start()
@@ -716,23 +744,70 @@ def _worker(
                return
            tx.put(None)

            work = {}
            streams: Dict[str, torch.cuda.Stream] = {}
            work: Dict[int, _OpMetadata] = {}
            next_op_id: int = 0

            while True:
                op = rx.get()
                cmd = op[0]
                if cmd == "func":
                    func_name, args, kwargs = op[1:]
                    args = _PickleSafeOptions.unsafe_args(args)
                    fn = getattr(pg, func_name)
                    work[next_op_id] = fn(*args, **kwargs)
                    func_name, args, kwargs, stream_device, stream_id, event = op[1:]

                    # To avoid potential deadlocks we need to preserve the
                    # stream/synchronization behavior of the parent process.
                    # We allocate one Stream per stream_id to make sure that we
                    # don't accidentally introduce cross stream synchronization
                    # points.
                    if stream_id is not None:
                        stream_key = f"{stream_device}/{stream_id}"
                        if stream_key not in streams:
                            streams[stream_key] = torch.cuda.Stream(
                                device=stream_device
                            )

> Review comment: Are we going to have zombie streams?
>
> Reply: My understanding is streams are specific to the CUDA context/process, so this will be cleaned up just fine when it gets killed.

                        stream = streams[stream_key]
                    else:
                        stream = None

                    with (
                        torch.cuda.stream(stream)
                        if stream is not None
                        else nullcontext()
                    ):
                        # Make the stream wait on the cuda event to make sure we
                        # don't start the operation until the tensor is ready.
                        if event is not None:
                            event.wait()

                        args = _PickleSafeOptions.unsafe_args(args)
                        fn = getattr(pg, func_name)
                        work[next_op_id] = _OpMetadata(
                            work=fn(*args, **kwargs),
                            stream=stream,
                        )
                    tx.put(next_op_id)
                    next_op_id += 1
                elif cmd == "wait":
                    op_id: int = op[1]
                    work[op_id].wait()
                    tx.put(op_id)

                    metadata = work[op_id]

                    with metadata.set_stream():
                        # With WorkNCCL this makes the stream wait not the CPU when
                        # no timeout is passed.
                        metadata.work.wait()

                        # Register event on the stream that we can pass to the main
                        # process.
                        event = (
                            torch.cuda.current_stream().record_event(
                                torch.cuda.Event(interprocess=True)
                            )
                            if metadata.stream is not None
                            else None
                        )

                    tx.put((op_id, event))
                elif cmd == "del":
                    op_id: int = op[1]
                    del work[op_id]
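The `stream_key` cache above gives each parent-side stream its own dedicated worker-side stream, so operations issued from different parent streams never serialize against each other in the subprocess. A hedged sketch of just the keying logic (simplified names, not the actual torchft API):

```python
import torch

worker_streams: dict = {}

def worker_stream_for(device: torch.device, stream_id: int) -> torch.cuda.Stream:
    # Lazily allocate one worker stream per parent (device, stream_id) pair.
    key = f"{device}/{stream_id}"
    if key not in worker_streams:
        worker_streams[key] = torch.cuda.Stream(device=device)
    return worker_streams[key]

s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
a = worker_stream_for(s1.device, s1.stream_id)
b = worker_stream_for(s2.device, s2.stream_id)
assert a is worker_stream_for(s1.device, s1.stream_id)  # cached on reuse
assert a is not b  # distinct parent streams stay independent
```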
@@ -746,23 +821,8 @@ def callback(fut: Future[object]) -> None:
                    except Exception as e:
                        future_queue.put((op_id, _FUTURE_EXCEPTION, e))

                    work[op_id].get_future().add_done_callback(callback)
                    work[op_id].work.get_future().add_done_callback(callback)
                    tx.put(op_id)
                elif cmd == "synchronize":
                    # CUDA only, use events instead of waiting on CPU
                    op_id = op[1]

                    # With WorkNCCL this makes the stream wait not the CPU when
                    # no timeout is passed.
                    work[op_id].wait()

                    # Register event on the stream that we can pass to the main
                    # process.
                    event = torch.cuda.Event(interprocess=True)
                    event.record()

                    del work[op_id]
                    tx.put((op_id, event))
                elif cmd == "num_active_work":
                    tx.put(len(work))
                else:
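The `add_done_callback` pattern above is how completions (and exceptions) flow back through `future_queue` to the parent. The underlying `torch.futures.Future` semantics, in a standalone sketch unrelated to any process group:

```python
from torch.futures import Future

fut: Future[int] = Future()

def callback(f: Future[int]) -> None:
    # Runs once the future completes; value() re-raises any stored error.
    print("result:", f.value())

fut.add_done_callback(callback)
fut.set_result(42)  # completes the future and fires the callback
```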
@@ -771,6 +831,7 @@ def callback(fut: Future[object]) -> None:
        except Exception as e:
            logger.exception("worker errored")
            tx.put(e)
            raise

    def _future_handler(self, future_queue: mp.Queue) -> None:
        try:
@@ -792,6 +853,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
            logger.exception(f"got unexpected error in future handler: {e}")

    def _get_future(self, op_id: int) -> Future[object]:
        self._assert_alive()

        with self._futures_lock:
            fut = Future()  # pyre-fixme[29]: is not a function
            self._futures[op_id] = fut
@@ -804,22 +867,52 @@ def _get_future(self, op_id: int) -> Future[object]:
        return fut

    def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
        self._assert_alive()

        rx = self._rx
        tx = self._tx
        assert rx is not None
        assert tx is not None

        is_cuda = _is_any_cuda(args)

        stream_device = torch.cuda.current_stream().device if is_cuda else None
        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
        event = (
            torch.cuda.current_stream().record_event(
                torch.cuda.Event(interprocess=True)
            )
            if is_cuda
            else None
        )

        tx.put(
            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
            (
                "func",
                func,
                _PickleSafeOptions.safe_args(args),
                kwargs,
                stream_device,
                stream_id,
                event,
            ),
            timeout=self._timeout,
        )

        op_id = _get(rx, self._timeout)
        assert isinstance(op_id, int), f"invalid return {op_id}"

        return self.WORK_CLASS(
            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
        )
        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)

    def _assert_alive(self) -> None:
        """
        Assert that the process group is alive. This is used to ensure that
        operations are not performed on a dead process group and any errors are surfaced.
        """
        p = self._p
        assert p is not None
        if not p.is_alive():
            raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")

    def allreduce(
        self,
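On the parent side, `_run_func` now does the sending half of the handshake: detect CUDA arguments, record an interprocess event on the current stream, and ship `(stream_device, stream_id, event)` with the op. A condensed extraction of that logic (`_prep_cuda_handoff` is a hypothetical helper, not part of torchft):

```python
import torch
from torch.utils._pytree import tree_any

def _prep_cuda_handoff(args: object):
    """Return (stream_device, stream_id, event) for CUDA args, else Nones."""
    is_cuda = tree_any(
        lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, args
    )
    if not is_cuda:
        return None, None, None
    stream = torch.cuda.current_stream()
    # The event marks all work queued so far on this stream; the worker
    # waits on it so the collective never reads half-written tensors.
    event = stream.record_event(torch.cuda.Event(interprocess=True))
    return stream.device, stream.stream_id, event
```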
@@ -952,8 +1045,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
    tensors may leak in the current PyTorch implementation. TODO fix
    """

    WORK_CLASS = _BabyWorkNCCL

    @classmethod
    def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
        # pyre-fixme[16]: no attribute ProcessGroupNCCL