Commit f5888e9

[Intel GPU] Extending TorchFT to Support Intel GPU with XCCL Backend (#260)
* [RFC] [Intel GPU] Extending TorchFT to Support Intel GPU with XCCL Backend
* Review comments updated
* Fix UT failures due to record_event
* UT/Lint fix
* Add options in BaseProcessGroupXCCL
* Lint fixes
* UT failures fix
1 parent 2ef3b3a commit f5888e9
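
Summary: the commit replaces CUDA-specific torch.cuda calls throughout torchft with the device-generic torch.accelerator, torch.Stream, and torch.Event APIs, exports new ProcessGroupXCCL / ProcessGroupBabyXCCL process groups for Intel GPUs, routes stream scoping through a new torchft.utils.get_stream_context helper, and adds macOS skips to integration tests that are unreliable there.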

File tree

10 files changed: +720, -114 lines changed


media/overview.mmd.svg

Lines changed: 359 additions & 1 deletion
(rendered diff not shown)

torchft/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -10,8 +10,10 @@
 from torchft.optim import OptimizerWrapper as Optimizer
 from torchft.process_group import (
     ProcessGroupBabyNCCL,
+    ProcessGroupBabyXCCL,
     ProcessGroupGloo,
     ProcessGroupNCCL,
+    ProcessGroupXCCL,
 )
 
 __all__ = (
@@ -20,6 +22,8 @@
     "Manager",
     "Optimizer",
     "ProcessGroupNCCL",
+    "ProcessGroupXCCL",
     "ProcessGroupBabyNCCL",
+    "ProcessGroupBabyXCCL",
     "ProcessGroupGloo",
 )
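
The XCCL process groups are now importable from the top-level package. A hypothetical usage sketch follows; the constructors live in torchft/process_group.py, which is among the ten changed files but not shown in this excerpt, so any arguments here are assumptions:

# Hypothetical sketch: the XCCL exports are intended as Intel-GPU analogues
# of the NCCL variants. Constructor signatures are assumptions; the actual
# definitions are in torchft/process_group.py (not shown in this excerpt).
import torch

from torchft import ProcessGroupBabyXCCL, ProcessGroupXCCL

assert torch.xpu.is_available(), "this sketch assumes an Intel XPU device"

# Direct variant: collectives run in the training process over XCCL.
pg = ProcessGroupXCCL()

# "Baby" variant: collectives run in a managed subprocess for fault
# isolation, mirroring the ProcessGroupBabyNCCL pattern.
baby_pg = ProcessGroupBabyXCCL()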

torchft/diloco_regression_test.py

Lines changed: 5 additions & 0 deletions
@@ -3,6 +3,7 @@
 import json
 import logging
 import os
+import sys
 import threading
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import ExitStack
@@ -307,6 +308,8 @@ def test_diloco_mocked_updates(
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")
 
         lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
         sync_every = 6
@@ -386,6 +389,8 @@ def test_diloco_mocked_failure_recovery(
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")
 
         lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
         sync_every = 6

torchft/futures.py

Lines changed: 8 additions & 4 deletions
@@ -11,6 +11,8 @@
 import torch
 from torch.futures import Future
 
+from torchft.utils import get_stream_context
+
 T = TypeVar("T")
 
 
@@ -161,12 +163,14 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
             handle,
         )
 
-        stream: Optional[torch.cuda.Stream] = (
-            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        stream: Optional[torch.Stream] = (
+            torch.accelerator.current_stream()
+            if torch.accelerator.is_available()
+            else None
         )
 
         def callback(fut: Future[T]) -> None:
-            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+            with get_stream_context(stream):
                 handle.cancel()
                 try:
                     timed_fut.set_result(fut.wait())
@@ -186,7 +190,7 @@ def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
 
         loop = self._maybe_start_event_loop()
 
-        event: torch.cuda.Event = torch.cuda.Event()
+        event: torch.Event = torch.Event()
         event.record()
 
         def handler() -> None:
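
Both this file and torchft/manager.py now import get_stream_context (and manager.py also synchronize) from torchft.utils, one of the changed files not shown in this excerpt. A minimal sketch of what those helpers plausibly look like, assuming dispatch on the stream's device type:

# Plausible sketch of the torchft/utils.py helpers this commit references;
# the real implementation lives in a changed file not shown in this excerpt.
from contextlib import nullcontext
from typing import ContextManager, Optional

import torch


def get_stream_context(stream: Optional[torch.Stream]) -> ContextManager[None]:
    # Make `stream` the current stream for the calling thread, or no-op
    # when there is no accelerator stream (stream is None).
    if stream is None:
        return nullcontext()
    if stream.device.type == "cuda":
        return torch.cuda.stream(stream)
    if stream.device.type == "xpu":
        return torch.xpu.stream(stream)
    return nullcontext()


def synchronize() -> None:
    # Block the host until all queued work on the current accelerator
    # (CUDA or XPU) has completed.
    torch.accelerator.synchronize()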

torchft/local_sgd_integ_test.py

Lines changed: 12 additions & 0 deletions
@@ -181,6 +181,8 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")
 
         lighthouse = LighthouseServer(
             bind="[::]:0",
@@ -239,6 +241,8 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")
 
         lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
         num_replicas = 2
@@ -290,6 +294,8 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")
 
         lighthouse = LighthouseServer(
             bind="[::]:0",
@@ -368,6 +374,8 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")
 
         lighthouse = LighthouseServer(
             bind="[::]:0",
@@ -441,6 +449,8 @@ def test_streaming_diloco_upscale(
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")
 
         lighthouse = LighthouseServer(
             bind="[::]:0",
@@ -515,6 +525,8 @@ def test_streaming_diloco_commit_failure(
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
             self.skipTest("Not enough GPUs for CUDA test")
+        if sys.platform == "darwin":
+            self.skipTest("not reliable on mac")
 
         lighthouse = LighthouseServer(
             bind="[::]:0",

torchft/manager.py

Lines changed: 34 additions & 41 deletions
@@ -58,6 +58,7 @@
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
+from torchft.utils import get_stream_context, synchronize
 from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
@@ -276,12 +277,12 @@ def __init__(
         self._pg = pg
         self._manager: Optional[ManagerServer] = None
 
-        self._recovery_stream: Optional["torch.cuda.Stream"] = (
-            torch.cuda.Stream() if torch.cuda.is_available() else None
+        self._recovery_stream: Optional["torch.Stream"] = (
+            torch.Stream() if torch.accelerator.is_available() else None
         )
 
         # Used to synchronize recovery operation
-        self._recovery_event: Optional[torch.cuda.Event] = None
+        self._recovery_event: Optional[torch.Event] = None
 
         if self._group_rank == 0:
             if port is None:
@@ -414,7 +415,11 @@ def allreduce(
         # it later.
         if should_quantize and IS_TRITON_AVAILABLE:
             work = allreduce_quantized(
-                [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
+                [tensor],
+                ReduceOp.SUM,
+                self._pg,
+                # pyre-fixme[6]: Expected `Optional[streams.Stream]` but got `_C.Stream`
+                torch.accelerator.current_stream(),
             )
         else:
             work = self._pg.allreduce([tensor], ReduceOp.SUM)
@@ -482,8 +487,10 @@ def wrap_future(
 
         fut = future_timeout(fut, timeout or self._timeout)
 
-        stream: Optional[torch.cuda.Stream] = (
-            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        stream: Optional[torch.Stream] = (
+            torch.accelerator.current_stream()
+            if torch.accelerator.is_available()
+            else None
         )
 
         # schedule error handling as a continuation on the Future
@@ -492,7 +499,7 @@ def callback(
         ) -> T:
             nonlocal default, stream
 
-            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+            with get_stream_context(stream):
                 try:
                     return fut.value()
                 except Exception as e:
@@ -546,7 +553,9 @@ def start_quorum(
             shrink_only=shrink_only,
             quorum_timeout=timeout or self._quorum_timeout,
             curr_device=(
-                torch.cuda.current_device() if torch.cuda.is_available() else -1
+                torch.accelerator.current_device_index()
+                if torch.accelerator.is_available()
+                else -1
             ),
         )
         if not self._use_async_quorum:
@@ -582,8 +591,8 @@ def _async_quorum(
     ) -> None:
         torch.multiprocessing._set_thread_name("torchft_quorum")
 
-        if curr_device >= 0 and torch.cuda.is_available():
-            torch.cuda.set_device(curr_device)
+        if curr_device >= 0 and torch.accelerator.is_available():
+            torch.accelerator.set_device_index(curr_device)
 
         quorum = None
         with torch.profiler.record_function("torchft::manager::_client::_quorum"):
@@ -649,11 +658,7 @@
         if allow_heal:
            # run recovery on the recovery stream if available
            recovery_stream = self._recovery_stream
-           with (
-               torch.cuda.stream(recovery_stream)
-               if recovery_stream is not None
-               else nullcontext()
-           ):
+           with get_stream_context(recovery_stream):
                try:
                    if quorum.recover_dst_replica_ranks:
                        self._logger.info(
@@ -714,7 +719,7 @@
                    self.report_error(e)
 
            self._recovery_event = (
-               torch.cuda.current_stream().record_event()
+               torch.accelerator.current_stream().record_event()
                if recovery_stream is not None
                else None
            )
@@ -784,8 +789,8 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         with torch.profiler.record_function(
             "torchft::manager::should_commit::current_stream::synchronize"
         ):
-            if torch.cuda.is_available():
-                torch.cuda.current_stream().synchronize()
+            if torch.accelerator.is_available():
+                synchronize()
 
         if err := self._pg.errored():
             self.report_error(err)
@@ -1128,8 +1133,10 @@ def __init__(
 
         # The stream used to created the `Work` - we ensure all operations in the future
         # callback chain are executed on this stream
-        self._stream: Optional[torch.cuda.Stream] = (
-            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        self._stream: Optional[torch.Stream] = (
+            torch.accelerator.current_stream()
+            if torch.accelerator.is_available()
+            else None
         )
 
         # To ensure the future callback chain is only created once
@@ -1164,11 +1171,7 @@ def callback(
             nonlocal managed_fut, value
             # change the stream to avoid making the callback stream
             # dependent on process group stream running the allreduce
-            with (
-                torch.cuda.stream(self._stream)
-                if self._stream is not None
-                else nullcontext()
-            ):
+            with get_stream_context(self._stream):
                 # Setup stream dependency
                 fut.wait()
                 assert managed_fut._callback
@@ -1199,36 +1202,24 @@ def _assert_same_stream(self) -> None:
         This makes sure users of the API are aware about stream dependencies.
         """
         if self._stream is not None:
-            assert self._stream == torch.cuda.current_stream()
+            assert self._stream == torch.accelerator.current_stream()
 
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
         self._assert_same_stream()
 
-        with (
-            torch.cuda.stream(self._stream)
-            if self._stream is not None
-            else nullcontext()
-        ):
+        with get_stream_context(self._stream):
             self._work.wait()
             self._set_future_callback()
 
-        with (
-            torch.cuda.stream(self._stream)
-            if self._stream is not None
-            else nullcontext()
-        ):
+        with get_stream_context(self._stream):
             self._managed_fut_tail.wait()
 
         return True
 
     def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
         self._assert_same_stream()
 
-        with (
-            torch.cuda.stream(self._stream)
-            if self._stream is not None
-            else nullcontext()
-        ):
+        with get_stream_context(self._stream):
             self._work.block_current_stream()
 
         self._set_future_callback()
@@ -1238,6 +1229,8 @@ def synchronize(self) -> None:
 
         if torch.cuda.is_available():
             self.block_current_stream()
+        elif torch.xpu.is_available():
+            self._set_future_callback()
         else:
             # No stream dependencies need to be set
             self._set_future_callback()
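
Taken together, the manager.py changes are a mechanical migration from CUDA-only calls to their device-generic counterparts, plus one behavioral special case: in synchronize(), XPU takes the _set_future_callback() path rather than block_current_stream(). An illustration of the API correspondence (illustrative only, not part of the commit):

# Illustrative only: the CUDA-specific -> device-generic mapping this
# commit applies throughout manager.py and futures.py.
import torch

if torch.accelerator.is_available():  # True for CUDA and XPU alike
    idx = torch.accelerator.current_device_index()  # was torch.cuda.current_device()
    torch.accelerator.set_device_index(idx)         # was torch.cuda.set_device(idx)

    stream = torch.Stream()  # was torch.cuda.Stream(); binds to the current accelerator
    event = torch.Event()    # was torch.cuda.Event()
    event.record()           # records on the current stream, as before

    current = torch.accelerator.current_stream()  # was torch.cuda.current_stream()
    torch.accelerator.synchronize()               # device-wide sync, per the synchronize() helper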
