@@ -58,6 +58,7 @@
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
+from torchft.utils import get_stream_context, synchronize
 from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
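
Both new imports come from torchft.utils, whose implementation is not shown in this diff. A minimal sketch of what the two helpers plausibly look like, inferred from the call sites in the hunks below (the names match the imports; the dispatch logic is an assumption):

from contextlib import nullcontext
from typing import ContextManager, Optional

import torch


def get_stream_context(stream: Optional[torch.Stream]) -> ContextManager:
    # No-op when no stream was captured (e.g. CPU-only runs).
    if stream is None:
        return nullcontext()
    # Dispatch to the backend-specific stream context manager.
    if torch.cuda.is_available():
        return torch.cuda.stream(stream)
    if torch.xpu.is_available():
        return torch.xpu.stream(stream)
    return nullcontext()


def synchronize() -> None:
    # Thin wrapper over the device-agnostic synchronization entry point.
    torch.accelerator.synchronize()
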
@@ -276,12 +277,12 @@ def __init__(
         self._pg = pg
         self._manager: Optional[ManagerServer] = None
 
-        self._recovery_stream: Optional["torch.cuda.Stream"] = (
-            torch.cuda.Stream() if torch.cuda.is_available() else None
+        self._recovery_stream: Optional["torch.Stream"] = (
+            torch.Stream() if torch.accelerator.is_available() else None
         )
 
         # Used to synchronize recovery operation
-        self._recovery_event: Optional[torch.cuda.Event] = None
+        self._recovery_event: Optional[torch.Event] = None
 
         if self._group_rank == 0:
             if port is None:
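
torch.Stream and torch.Event are the device-agnostic counterparts of torch.cuda.Stream and torch.cuda.Event: they are created on whichever accelerator backend is active. A small illustration of the API this hunk moves to (not code from the PR):

import torch

if torch.accelerator.is_available():
    stream = torch.Stream()  # side stream on the current accelerator
    event = torch.Event()    # device-agnostic event
    event.record(stream)     # mark a point on the side stream
    event.synchronize()      # host blocks until that point is reached
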
@@ -414,7 +415,11 @@ def allreduce(
         # it later.
         if should_quantize and IS_TRITON_AVAILABLE:
             work = allreduce_quantized(
-                [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
+                [tensor],
+                ReduceOp.SUM,
+                self._pg,
+                # pyre-fixme[6]: Expected `Optional[streams.Stream]` but got `_C.Stream`
+                torch.accelerator.current_stream(),
             )
         else:
             work = self._pg.allreduce([tensor], ReduceOp.SUM)
@@ -482,8 +487,10 @@ def wrap_future(
 
         fut = future_timeout(fut, timeout or self._timeout)
 
-        stream: Optional[torch.cuda.Stream] = (
-            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        stream: Optional[torch.Stream] = (
+            torch.accelerator.current_stream()
+            if torch.accelerator.is_available()
+            else None
         )
 
         # schedule error handling as a continuation on the Future
@@ -492,7 +499,7 @@ def callback(
         ) -> T:
             nonlocal default, stream
 
-            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+            with get_stream_context(stream):
                 try:
                     return fut.value()
                 except Exception as e:
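
With the helper in place, the old conditional context-manager pattern and the new call are meant to be interchangeable; side by side (illustration only, assuming the sketch above):

from contextlib import nullcontext

import torch

stream = torch.Stream() if torch.accelerator.is_available() else None

# before: CUDA-only, repeated at every call site
with torch.cuda.stream(stream) if stream is not None else nullcontext():
    pass

# after: one device-agnostic helper
with get_stream_context(stream):
    pass
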
@@ -546,7 +553,9 @@ def start_quorum(
             shrink_only=shrink_only,
             quorum_timeout=timeout or self._quorum_timeout,
             curr_device=(
-                torch.cuda.current_device() if torch.cuda.is_available() else -1
+                torch.accelerator.current_device_index()
+                if torch.accelerator.is_available()
+                else -1
             ),
         )
         if not self._use_async_quorum:
@@ -582,8 +591,8 @@ def _async_quorum(
     ) -> None:
         torch.multiprocessing._set_thread_name("torchft_quorum")
 
-        if curr_device >= 0 and torch.cuda.is_available():
-            torch.cuda.set_device(curr_device)
+        if curr_device >= 0 and torch.accelerator.is_available():
+            torch.accelerator.set_device_index(curr_device)
 
         quorum = None
         with torch.profiler.record_function("torchft::manager::_client::_quorum"):
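
Taken together, the hunks in this file follow one consistent old-to-new mapping; collected here for reference:

# torch.cuda.is_available()                 -> torch.accelerator.is_available()
# torch.cuda.current_stream()               -> torch.accelerator.current_stream()
# torch.cuda.current_device()               -> torch.accelerator.current_device_index()
# torch.cuda.set_device(i)                  -> torch.accelerator.set_device_index(i)
# torch.cuda.Stream() / torch.cuda.Event()  -> torch.Stream() / torch.Event()
# torch.cuda.stream(s) ... nullcontext()    -> get_stream_context(s)  (torchft.utils)
# torch.cuda.current_stream().synchronize() -> synchronize()          (torchft.utils)
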
@@ -649,11 +658,7 @@ def _async_quorum(
         if allow_heal:
             # run recovery on the recovery stream if available
             recovery_stream = self._recovery_stream
-            with (
-                torch.cuda.stream(recovery_stream)
-                if recovery_stream is not None
-                else nullcontext()
-            ):
+            with get_stream_context(recovery_stream):
                 try:
                     if quorum.recover_dst_replica_ranks:
                         self._logger.info(
@@ -714,7 +719,7 @@ def _async_quorum(
                     self.report_error(e)
 
             self._recovery_event = (
-                torch.cuda.current_stream().record_event()
+                torch.accelerator.current_stream().record_event()
                 if recovery_stream is not None
                 else None
             )
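
The hunk above records the recovery event on whatever stream is current inside the recovery context, so the main stream can later wait on it. The general side-stream handoff pattern looks like this (a sketch reusing the assumed get_stream_context helper; not code from this PR):

import torch

main_stream = torch.accelerator.current_stream()
recovery_stream = torch.Stream()

with get_stream_context(recovery_stream):
    # ... enqueue recovery copies on recovery_stream here ...
    recovery_event = torch.accelerator.current_stream().record_event()

# the main stream must not consume recovered state before the event fires
main_stream.wait_event(recovery_event)
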
@@ -784,8 +789,8 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         with torch.profiler.record_function(
             "torchft::manager::should_commit::current_stream::synchronize"
         ):
-            if torch.cuda.is_available():
-                torch.cuda.current_stream().synchronize()
+            if torch.accelerator.is_available():
+                synchronize()
 
         if err := self._pg.errored():
             self.report_error(err)
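
One behavioral note: the removed torch.cuda.current_stream().synchronize() waits only for work queued on the current stream, while torch.accelerator.synchronize(), which the imported synchronize() helper presumably wraps, waits for all streams on the device. The two scopes differ when other streams carry concurrent work:

import torch

if torch.accelerator.is_available():
    torch.accelerator.current_stream().synchronize()  # one stream only
    torch.accelerator.synchronize()                   # entire device
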
@@ -1128,8 +1133,10 @@ def __init__(
 
         # The stream used to created the `Work` - we ensure all operations in the future
         # callback chain are executed on this stream
-        self._stream: Optional[torch.cuda.Stream] = (
-            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        self._stream: Optional[torch.Stream] = (
+            torch.accelerator.current_stream()
+            if torch.accelerator.is_available()
+            else None
         )
 
         # To ensure the future callback chain is only created once
@@ -1164,11 +1171,7 @@ def callback(
             nonlocal managed_fut, value
             # change the stream to avoid making the callback stream
             # dependent on process group stream running the allreduce
-            with (
-                torch.cuda.stream(self._stream)
-                if self._stream is not None
-                else nullcontext()
-            ):
+            with get_stream_context(self._stream):
                 # Setup stream dependency
                 fut.wait()
                 assert managed_fut._callback
@@ -1199,36 +1202,24 @@ def _assert_same_stream(self) -> None:
         This makes sure users of the API are aware about stream dependencies.
         """
         if self._stream is not None:
-            assert self._stream == torch.cuda.current_stream()
+            assert self._stream == torch.accelerator.current_stream()
 
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
         self._assert_same_stream()
 
-        with (
-            torch.cuda.stream(self._stream)
-            if self._stream is not None
-            else nullcontext()
-        ):
+        with get_stream_context(self._stream):
             self._work.wait()
             self._set_future_callback()
 
-        with (
-            torch.cuda.stream(self._stream)
-            if self._stream is not None
-            else nullcontext()
-        ):
+        with get_stream_context(self._stream):
             self._managed_fut_tail.wait()
 
         return True
 
     def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
         self._assert_same_stream()
 
-        with (
-            torch.cuda.stream(self._stream)
-            if self._stream is not None
-            else nullcontext()
-        ):
+        with get_stream_context(self._stream):
             self._work.block_current_stream()
 
             self._set_future_callback()
@@ -1238,6 +1229,8 @@ def synchronize(self) -> None:
 
         if torch.cuda.is_available():
             self.block_current_stream()
+        elif torch.xpu.is_available():
+            self._set_future_callback()
         else:
             # No stream dependencies need to be set
             self._set_future_callback()