360 changes: 359 additions & 1 deletion media/overview.mmd.svg
4 changes: 4 additions & 0 deletions torchft/__init__.py
@@ -10,8 +10,10 @@
 from torchft.optim import OptimizerWrapper as Optimizer
 from torchft.process_group import (
     ProcessGroupBabyNCCL,
+    ProcessGroupBabyXCCL,
     ProcessGroupGloo,
     ProcessGroupNCCL,
+    ProcessGroupXCCL,
 )

 __all__ = (
@@ -20,6 +22,8 @@
     "Manager",
     "Optimizer",
     "ProcessGroupNCCL",
+    "ProcessGroupXCCL",
     "ProcessGroupBabyNCCL",
+    "ProcessGroupBabyXCCL",
     "ProcessGroupGloo",
 )
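For orientation (not part of the diff itself): with the new exports, callers can pick a process group that matches the local accelerator. A minimal sketch, assuming the default constructors work without arguments; the selection helper below is illustrative, not torchft API:

```python
import torch

from torchft import ProcessGroupGloo, ProcessGroupNCCL, ProcessGroupXCCL


def make_process_group():
    """Illustrative helper: match the collective backend to the hardware."""
    if torch.cuda.is_available():
        return ProcessGroupNCCL()  # NVIDIA GPUs
    if torch.xpu.is_available():
        return ProcessGroupXCCL()  # Intel XPUs, newly exported by this PR
    return ProcessGroupGloo()  # CPU fallback
```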
5 changes: 5 additions & 0 deletions torchft/diloco_regression_test.py
@@ -3,6 +3,7 @@
 import json
 import logging
 import os
+import sys
 import threading
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import ExitStack
@@ -307,6 +308,8 @@ def test_diloco_mocked_updates(
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")

         lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
         sync_every = 6
@@ -386,6 +389,8 @@ def test_diloco_mocked_failure_recovery(
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")

         lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
         sync_every = 6
12 changes: 8 additions & 4 deletions torchft/futures.py
@@ -11,6 +11,8 @@
 import torch
 from torch.futures import Future

+from torchft.utils import get_stream_context
+
 T = TypeVar("T")


@@ -161,12 +163,14 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
             handle,
         )

-        stream: Optional[torch.cuda.Stream] = (
-            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        stream: Optional[torch.Stream] = (
+            torch.accelerator.current_stream()
+            if torch.accelerator.is_available()
+            else None
         )

         def callback(fut: Future[T]) -> None:
-            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+            with get_stream_context(stream):
                 handle.cancel()
                 try:
                     timed_fut.set_result(fut.wait())
@@ -186,7 +190,7 @@ def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:

         loop = self._maybe_start_event_loop()

-        event: torch.cuda.Event = torch.cuda.Event()
+        event: torch.Event = torch.Event()
         event.record()

         def handler() -> None:
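futures.py now leans on a `get_stream_context` helper from torchft/utils.py, a file this diff does not show. Inferred from the call sites above, a plausible sketch (the real helper may differ in details):

```python
from contextlib import nullcontext
from typing import ContextManager, Optional

import torch


def get_stream_context(stream: Optional[torch.Stream]) -> ContextManager[object]:
    """Make `stream` current for the body of a `with` block, or do nothing.

    Replaces the repeated `torch.cuda.stream(stream) if stream is not None
    else nullcontext()` pattern with one device-agnostic call.
    """
    if stream is None:
        return nullcontext()
    # Dispatch to the module that owns the stream's device, e.g.
    # torch.cuda.stream(...) for CUDA or torch.xpu.stream(...) for XPU.
    backend = getattr(torch, stream.device.type)
    return backend.stream(stream)
```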
12 changes: 12 additions & 0 deletions torchft/local_sgd_integ_test.py
@@ -181,6 +181,8 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")

         lighthouse = LighthouseServer(
             bind="[::]:0",
@@ -239,6 +241,8 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")

         lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
         num_replicas = 2
@@ -290,6 +294,8 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")

         lighthouse = LighthouseServer(
             bind="[::]:0",
@@ -368,6 +374,8 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")

         lighthouse = LighthouseServer(
             bind="[::]:0",
@@ -441,6 +449,8 @@ def test_streaming_diloco_upscale(
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")

         lighthouse = LighthouseServer(
             bind="[::]:0",
@@ -515,6 +525,8 @@ def test_streaming_diloco_commit_failure(
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")

         lighthouse = LighthouseServer(
             bind="[::]:0",
75 changes: 34 additions & 41 deletions torchft/manager.py
@@ -58,6 +58,7 @@
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
+from torchft.utils import get_stream_context, synchronize
 from torchft.work import _DummyWork

 if TYPE_CHECKING:
@@ -276,12 +277,12 @@ def __init__(
         self._pg = pg
         self._manager: Optional[ManagerServer] = None

-        self._recovery_stream: Optional["torch.cuda.Stream"] = (
-            torch.cuda.Stream() if torch.cuda.is_available() else None
+        self._recovery_stream: Optional["torch.Stream"] = (
+            torch.Stream() if torch.accelerator.is_available() else None
         )

         # Used to synchronize recovery operation
-        self._recovery_event: Optional[torch.cuda.Event] = None
+        self._recovery_event: Optional[torch.Event] = None

         if self._group_rank == 0:
             if port is None:
@@ -414,7 +415,11 @@ def allreduce(
         # it later.
         if should_quantize and IS_TRITON_AVAILABLE:
             work = allreduce_quantized(
-                [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
+                [tensor],
+                ReduceOp.SUM,
+                self._pg,
+                # pyre-fixme[6]: Expected `Optional[streams.Stream]` but got `_C.Stream`
+                torch.accelerator.current_stream(),
             )
         else:
             work = self._pg.allreduce([tensor], ReduceOp.SUM)
@@ -482,8 +487,10 @@ def wrap_future(

         fut = future_timeout(fut, timeout or self._timeout)

-        stream: Optional[torch.cuda.Stream] = (
-            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        stream: Optional[torch.Stream] = (
+            torch.accelerator.current_stream()
+            if torch.accelerator.is_available()
+            else None
         )

         # schedule error handling as a continuation on the Future
@@ -492,7 +499,7 @@ def callback(
         ) -> T:
             nonlocal default, stream

-            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+            with get_stream_context(stream):
                 try:
                     return fut.value()
                 except Exception as e:
@@ -546,7 +553,9 @@ def start_quorum(
             shrink_only=shrink_only,
             quorum_timeout=timeout or self._quorum_timeout,
             curr_device=(
-                torch.cuda.current_device() if torch.cuda.is_available() else -1
+                torch.accelerator.current_device_index()
+                if torch.accelerator.is_available()
+                else -1
             ),
         )
         if not self._use_async_quorum:
@@ -582,8 +591,8 @@ def _async_quorum(
     ) -> None:
         torch.multiprocessing._set_thread_name("torchft_quorum")

-        if curr_device >= 0 and torch.cuda.is_available():
-            torch.cuda.set_device(curr_device)
+        if curr_device >= 0 and torch.accelerator.is_available():
+            torch.accelerator.set_device_index(curr_device)

         quorum = None
         with torch.profiler.record_function("torchft::manager::_client::_quorum"):
@@ -649,11 +658,7 @@ def _async_quorum(
             if allow_heal:
                 # run recovery on the recovery stream if available
                 recovery_stream = self._recovery_stream
-                with (
-                    torch.cuda.stream(recovery_stream)
-                    if recovery_stream is not None
-                    else nullcontext()
-                ):
+                with get_stream_context(recovery_stream):
                     try:
                         if quorum.recover_dst_replica_ranks:
                             self._logger.info(
@@ -714,7 +719,7 @@ def _async_quorum(
                         self.report_error(e)

                 self._recovery_event = (
-                    torch.cuda.current_stream().record_event()
+                    torch.accelerator.current_stream().record_event()
                     if recovery_stream is not None
                     else None
                 )
@@ -784,8 +789,8 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         with torch.profiler.record_function(
             "torchft::manager::should_commit::current_stream::synchronize"
         ):
-            if torch.cuda.is_available():
-                torch.cuda.current_stream().synchronize()
+            if torch.accelerator.is_available():
+                synchronize()

         if err := self._pg.errored():
             self.report_error(err)
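`should_commit` now calls a `synchronize` helper, also imported from torchft/utils.py and not shown in this diff. Inferred from the call site above, a plausible sketch:

```python
import torch


def synchronize() -> None:
    """Block the host until work queued on the current accelerator
    stream has completed.

    Device-agnostic stand-in for torch.cuda.current_stream().synchronize();
    inferred from its use in should_commit, so details may differ.
    """
    torch.accelerator.current_stream().synchronize()
```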
@@ -1128,8 +1133,10 @@ def __init__(

         # The stream used to created the `Work` - we ensure all operations in the future
         # callback chain are executed on this stream
-        self._stream: Optional[torch.cuda.Stream] = (
-            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        self._stream: Optional[torch.Stream] = (
+            torch.accelerator.current_stream()
+            if torch.accelerator.is_available()
+            else None
         )

         # To ensure the future callback chain is only created once
@@ -1164,11 +1171,7 @@ def callback(
             nonlocal managed_fut, value
             # change the stream to avoid making the callback stream
             # dependent on process group stream running the allreduce
-            with (
-                torch.cuda.stream(self._stream)
-                if self._stream is not None
-                else nullcontext()
-            ):
+            with get_stream_context(self._stream):
                 # Setup stream dependency
                 fut.wait()
                 assert managed_fut._callback
@@ -1199,36 +1202,24 @@ def _assert_same_stream(self) -> None:
         This makes sure users of the API are aware about stream dependencies.
         """
         if self._stream is not None:
-            assert self._stream == torch.cuda.current_stream()
+            assert self._stream == torch.accelerator.current_stream()

     def wait(self, timeout: Optional[timedelta] = None) -> bool:
         self._assert_same_stream()

-        with (
-            torch.cuda.stream(self._stream)
-            if self._stream is not None
-            else nullcontext()
-        ):
+        with get_stream_context(self._stream):
             self._work.wait()
             self._set_future_callback()

-        with (
-            torch.cuda.stream(self._stream)
-            if self._stream is not None
-            else nullcontext()
-        ):
+        with get_stream_context(self._stream):
             self._managed_fut_tail.wait()

         return True

     def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
         self._assert_same_stream()

-        with (
-            torch.cuda.stream(self._stream)
-            if self._stream is not None
-            else nullcontext()
-        ):
+        with get_stream_context(self._stream):
             self._work.block_current_stream()

         self._set_future_callback()
@@ -1238,6 +1229,8 @@ def synchronize(self) -> None:

         if torch.cuda.is_available():
             self.block_current_stream()
+        elif torch.xpu.is_available():
+            self._set_future_callback()
         else:
             # No stream dependencies need to be set
             self._set_future_callback()
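Taken together, the manager and futures changes replace CUDA-only calls with device-generic equivalents. A compact summary of the mapping this PR adopts, assuming a PyTorch build that ships the torch.accelerator API:

```python
import torch

# Old (CUDA-only)                       New (device-generic)
# torch.cuda.is_available()        ->   torch.accelerator.is_available()
# torch.cuda.current_device()      ->   torch.accelerator.current_device_index()
# torch.cuda.set_device(idx)       ->   torch.accelerator.set_device_index(idx)
# torch.cuda.current_stream()      ->   torch.accelerator.current_stream()
# torch.cuda.Stream() / .Event()   ->   torch.Stream() / torch.Event()
if torch.accelerator.is_available():
    stream = torch.accelerator.current_stream()
    event = torch.Event()
    event.record()  # records on the current stream, CUDA or XPU alike
    stream.synchronize()  # host-side wait on the queued work
```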