
Commit 02a0cfc

Review comments updated
1 parent bc9ccdb commit 02a0cfc

5 files changed (+118, -242 lines)


torchft/futures.py

Lines changed: 6 additions & 11 deletions
@@ -11,6 +11,8 @@
 import torch
 from torch.futures import Future
 
+from torchft.utils import get_stream_context
+
 T = TypeVar("T")
 
 
@@ -162,20 +164,13 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
         )
 
         stream: Optional[torch.Stream] = (
-            torch.accelerator.current_stream() if torch.accelerator.is_available() else None
+            torch.accelerator.current_stream()
+            if torch.accelerator.is_available()
+            else None
         )
 
         def callback(fut: Future[T]) -> None:
-            if stream is not None:
-                if torch.cuda.is_available():
-                    context = torch.cuda.stream(stream)
-                elif torch.xpu.is_available():
-                    context = torch.xpu.stream(stream)
-                else:
-                    context = nullcontext()
-            else:
-                context = nullcontext()
-            with context:
+            with get_stream_context(stream):
                 handle.cancel()
                 try:
                     timed_fut.set_result(fut.wait())
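
Note: the `get_stream_context` helper imported above lives in torchft/utils.py, one of the five files touched by this commit, but its diff is not shown in this view. The sketch below is an assumption reconstructed from the inline branching it replaces at each call site: enter the backend-specific stream context when a stream is given, otherwise fall back to a no-op context.

from contextlib import AbstractContextManager, nullcontext
from typing import Optional

import torch


def get_stream_context(stream: Optional[torch.Stream]) -> AbstractContextManager:
    """Make `stream` current for the duration of the block, or do nothing.

    Assumed implementation; mirrors the cuda/xpu/nullcontext branching removed above.
    """
    if stream is None:
        return nullcontext()
    if torch.cuda.is_available():
        return torch.cuda.stream(stream)
    if torch.xpu.is_available():
        return torch.xpu.stream(stream)
    return nullcontext()

Centralizing the branching in one helper is what lets every call site in futures.py and manager.py collapse to a single `with get_stream_context(...)` line.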

torchft/manager.py

Lines changed: 25 additions & 82 deletions
@@ -58,6 +58,7 @@
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
+from torchft.utils import get_stream_context
 from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
@@ -276,10 +277,9 @@ def __init__(
         self._pg = pg
         self._manager: Optional[ManagerServer] = None
 
-        if torch.accelerator.is_available():
-            self._recovery_stream: Optional["torch.Stream"] = torch.Stream()
-        else:
-            self._recovery_stream = None
+        self._recovery_stream: Optional["torch.Stream"] = (
+            torch.Stream() if torch.accelerator.is_available() else None
+        )
 
         # Used to synchronize recovery operation
         self._recovery_event: Optional[torch.Event] = None
@@ -414,13 +414,9 @@ def allreduce(
         # Run the allreduce async and save the work object so we can wait on
         # it later.
         if should_quantize and IS_TRITON_AVAILABLE:
-            if torch.accelerator.is_available():
-                current_stream = torch.accelerator.current_stream()
-            else:
-                current_stream = None
-
+            # pyre-fixme[6]: Expected `Optional[streams.Stream]` but got `_C.Stream`
             work = allreduce_quantized(
-                [tensor], ReduceOp.SUM, self._pg, current_stream
+                [tensor], ReduceOp.SUM, self._pg, torch.accelerator.current_stream()
             )
         else:
             work = self._pg.allreduce([tensor], ReduceOp.SUM)
@@ -488,27 +484,19 @@ def wrap_future(
 
         fut = future_timeout(fut, timeout or self._timeout)
 
-        if torch.accelerator.is_available():
-            stream = torch.accelerator.current_stream()
-        else:
-            stream = None
+        stream: Optional[torch.Stream] = (
+            torch.accelerator.current_stream()
+            if torch.accelerator.is_available()
+            else None
+        )
 
         # schedule error handling as a continuation on the Future
         def callback(
             fut: torch.futures.Future[T],
         ) -> T:
             nonlocal default, stream
 
-            if stream is not None:
-                if torch.cuda.is_available():
-                    context = torch.cuda.stream(stream)
-                elif torch.xpu.is_available():
-                    context = torch.xpu.stream(stream)
-                else:
-                    context = nullcontext()
-            else:
-                context = nullcontext()
-            with context:
+            with get_stream_context(stream):
                 try:
                     return fut.value()
                 except Exception as e:
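
The pattern in the `wrap_future` hunk above is: capture the current accelerator stream when the callback is attached, then re-enter that stream inside the callback so the continuation stays ordered on the caller's stream. A standalone illustration of the same pattern follows; `attach_logging_callback` is a hypothetical helper, not part of torchft.

from typing import Optional

import torch
from torch.futures import Future

from torchft.utils import get_stream_context


def attach_logging_callback(fut: Future[int]) -> Future[int]:
    # Capture the caller's stream; None on hosts without an accelerator.
    stream: Optional[torch.Stream] = (
        torch.accelerator.current_stream()
        if torch.accelerator.is_available()
        else None
    )

    def callback(f: Future[int]) -> int:
        # Re-enter the captured stream (no-op when stream is None).
        with get_stream_context(stream):
            value = f.value()
            print(f"future completed with {value}")
            return value

    return fut.then(callback)
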
@@ -562,7 +550,9 @@ def start_quorum(
             shrink_only=shrink_only,
             quorum_timeout=timeout or self._quorum_timeout,
             curr_device=(
-                torch.accelerator.current_device_index() if torch.accelerator.is_available() else -1
+                torch.accelerator.current_device_index()
+                if torch.accelerator.is_available()
+                else -1
             ),
         )
         if not self._use_async_quorum:
@@ -598,11 +588,8 @@ def _async_quorum(
     ) -> None:
         torch.multiprocessing._set_thread_name("torchft_quorum")
 
-        if curr_device >= 0:
-            if torch.cuda.is_available():
-                torch.cuda.set_device(curr_device)
-            elif torch.xpu.is_available():
-                torch.xpu.set_device(curr_device)
+        if curr_device >= 0 and torch.accelerator.is_available():
+            torch.accelerator.set_device_index(curr_device)
 
         quorum = None
        with torch.profiler.record_function("torchft::manager::_client::_quorum"):
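
The hunk above also shows the device-selection side of the migration: the per-backend torch.cuda.set_device / torch.xpu.set_device calls become a single torch.accelerator.set_device_index call, guarded by torch.accelerator.is_available(). A minimal standalone sketch of that round trip, using a hypothetical pin_to_device helper and the same convention of -1 meaning "no accelerator":

import torch


def pin_to_device(curr_device: int) -> None:
    # Device-agnostic replacement for torch.cuda.set_device / torch.xpu.set_device.
    if curr_device >= 0 and torch.accelerator.is_available():
        torch.accelerator.set_device_index(curr_device)


# The caller records its device index the same way start_quorum does above.
curr_device = (
    torch.accelerator.current_device_index()
    if torch.accelerator.is_available()
    else -1
)
pin_to_device(curr_device)
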
@@ -668,17 +655,7 @@ def _async_quorum(
         if allow_heal:
             # run recovery on the recovery stream if available
             recovery_stream = self._recovery_stream
-            if recovery_stream is not None:
-                if torch.cuda.is_available():
-                    stream_context = torch.cuda.stream(recovery_stream)
-                elif torch.xpu.is_available():
-                    stream_context = torch.xpu.stream(recovery_stream)
-                else:
-                    stream_context = nullcontext()
-            else:
-                stream_context = nullcontext()
-
-            with stream_context:
+            with get_stream_context(recovery_stream):
                 try:
                     if quorum.recover_dst_replica_ranks:
                         self._logger.info(
@@ -1154,7 +1131,9 @@ def __init__(
         # The stream used to created the `Work` - we ensure all operations in the future
         # callback chain are executed on this stream
         self._stream: Optional[torch.Stream] = (
-            torch.accelerator.current_stream() if torch.accelerator.is_available() else None
+            torch.accelerator.current_stream()
+            if torch.accelerator.is_available()
+            else None
         )
 
         # To ensure the future callback chain is only created once
@@ -1189,16 +1168,7 @@ def callback(
             nonlocal managed_fut, value
             # change the stream to avoid making the callback stream
            # dependent on process group stream running the allreduce
-            if self._stream is not None:
-                if torch.cuda.is_available():
-                    context = torch.cuda.stream(self._stream)
-                elif torch.xpu.is_available():
-                    context = torch.xpu.stream(self._stream)
-                else:
-                    context = nullcontext()
-            else:
-                context = nullcontext()
-            with context:
+            with get_stream_context(self._stream):
                 # Setup stream dependency
                 fut.wait()
                 assert managed_fut._callback
@@ -1234,46 +1204,19 @@ def _assert_same_stream(self) -> None:
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
         self._assert_same_stream()
 
-        if self._stream is not None:
-            if torch.cuda.is_available():
-                context = torch.cuda.stream(self._stream)
-            elif torch.xpu.is_available():
-                context = torch.xpu.stream(self._stream)
-            else:
-                context = nullcontext()
-        else:
-            context = nullcontext()
-        with context:
+        with get_stream_context(self._stream):
            self._work.wait()
         self._set_future_callback()
 
-        if self._stream is not None:
-            if torch.cuda.is_available():
-                context = torch.cuda.stream(self._stream)
-            elif torch.xpu.is_available():
-                context = torch.xpu.stream(self._stream)
-            else:
-                context = nullcontext()
-        else:
-            context = nullcontext()
-        with context:
+        with get_stream_context(self._stream):
            self._managed_fut_tail.wait()
 
         return True
 
     def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
         self._assert_same_stream()
 
-        if self._stream is not None:
-            if torch.cuda.is_available():
-                context = torch.cuda.stream(self._stream)
-            elif torch.xpu.is_available():
-                context = torch.xpu.stream(self._stream)
-            else:
-                context = nullcontext()
-        else:
-            context = nullcontext()
-        with context:
+        with get_stream_context(self._stream):
             self._work.block_current_stream()
 
         self._set_future_callback()
