|
58 | 58 | from torchft.checkpointing import CheckpointTransport, HTTPTransport
|
59 | 59 | from torchft.checkpointing._rwlock import RWLock
|
60 | 60 | from torchft.futures import future_timeout
|
| 61 | +from torchft.utils import get_stream_context |
61 | 62 | from torchft.work import _DummyWork
|
62 | 63 |
|
63 | 64 | if TYPE_CHECKING:
|
@@ -276,10 +277,9 @@ def __init__(
|
276 | 277 | self._pg = pg
|
277 | 278 | self._manager: Optional[ManagerServer] = None
|
278 | 279 |
|
279 |
| - if torch.accelerator.is_available(): |
280 |
| - self._recovery_stream: Optional["torch.Stream"] = torch.Stream() |
281 |
| - else: |
282 |
| - self._recovery_stream = None |
| 280 | + self._recovery_stream: Optional["torch.Stream"] = ( |
| 281 | + torch.Stream() if torch.accelerator.is_available() else None |
| 282 | + ) |
283 | 283 |
|
284 | 284 | # Used to synchronize recovery operation
|
285 | 285 | self._recovery_event: Optional[torch.Event] = None
|
@@ -414,13 +414,9 @@ def allreduce(
|
414 | 414 | # Run the allreduce async and save the work object so we can wait on
|
415 | 415 | # it later.
|
416 | 416 | if should_quantize and IS_TRITON_AVAILABLE:
|
417 |
| - if torch.accelerator.is_available(): |
418 |
| - current_stream = torch.accelerator.current_stream() |
419 |
| - else: |
420 |
| - current_stream = None |
421 |
| - |
| 417 | + # pyre-fixme[6]: Expected `Optional[streams.Stream]` but got `_C.Stream` |
422 | 418 | work = allreduce_quantized(
|
423 |
| - [tensor], ReduceOp.SUM, self._pg, current_stream |
| 419 | + [tensor], ReduceOp.SUM, self._pg, torch.accelerator.current_stream() |
424 | 420 | )
|
425 | 421 | else:
|
426 | 422 | work = self._pg.allreduce([tensor], ReduceOp.SUM)
|
@@ -488,27 +484,19 @@ def wrap_future(
|
488 | 484 |
|
489 | 485 | fut = future_timeout(fut, timeout or self._timeout)
|
490 | 486 |
|
491 |
| - if torch.accelerator.is_available(): |
492 |
| - stream = torch.accelerator.current_stream() |
493 |
| - else: |
494 |
| - stream = None |
| 487 | + stream: Optional[torch.Stream] = ( |
| 488 | + torch.accelerator.current_stream() |
| 489 | + if torch.accelerator.is_available() |
| 490 | + else None |
| 491 | + ) |
495 | 492 |
|
496 | 493 | # schedule error handling as a continuation on the Future
|
497 | 494 | def callback(
|
498 | 495 | fut: torch.futures.Future[T],
|
499 | 496 | ) -> T:
|
500 | 497 | nonlocal default, stream
|
501 | 498 |
|
502 |
| - if stream is not None: |
503 |
| - if torch.cuda.is_available(): |
504 |
| - context = torch.cuda.stream(stream) |
505 |
| - elif torch.xpu.is_available(): |
506 |
| - context = torch.xpu.stream(stream) |
507 |
| - else: |
508 |
| - context = nullcontext() |
509 |
| - else: |
510 |
| - context = nullcontext() |
511 |
| - with context: |
| 499 | + with get_stream_context(stream): |
512 | 500 | try:
|
513 | 501 | return fut.value()
|
514 | 502 | except Exception as e:
|
@@ -562,7 +550,9 @@ def start_quorum(
|
562 | 550 | shrink_only=shrink_only,
|
563 | 551 | quorum_timeout=timeout or self._quorum_timeout,
|
564 | 552 | curr_device=(
|
565 |
| - torch.accelerator.current_device_index() if torch.accelerator.is_available() else -1 |
| 553 | + torch.accelerator.current_device_index() |
| 554 | + if torch.accelerator.is_available() |
| 555 | + else -1 |
566 | 556 | ),
|
567 | 557 | )
|
568 | 558 | if not self._use_async_quorum:
|
@@ -598,11 +588,8 @@ def _async_quorum(
|
598 | 588 | ) -> None:
|
599 | 589 | torch.multiprocessing._set_thread_name("torchft_quorum")
|
600 | 590 |
|
601 |
| - if curr_device >= 0: |
602 |
| - if torch.cuda.is_available(): |
603 |
| - torch.cuda.set_device(curr_device) |
604 |
| - elif torch.xpu.is_available(): |
605 |
| - torch.xpu.set_device(curr_device) |
| 591 | + if curr_device >= 0 and torch.accelerator.is_available(): |
| 592 | + torch.accelerator.set_device_index(curr_device) |
606 | 593 |
|
607 | 594 | quorum = None
|
608 | 595 | with torch.profiler.record_function("torchft::manager::_client::_quorum"):
|
@@ -668,17 +655,7 @@ def _async_quorum(
|
668 | 655 | if allow_heal:
|
669 | 656 | # run recovery on the recovery stream if available
|
670 | 657 | recovery_stream = self._recovery_stream
|
671 |
| - if recovery_stream is not None: |
672 |
| - if torch.cuda.is_available(): |
673 |
| - stream_context = torch.cuda.stream(recovery_stream) |
674 |
| - elif torch.xpu.is_available(): |
675 |
| - stream_context = torch.xpu.stream(recovery_stream) |
676 |
| - else: |
677 |
| - stream_context = nullcontext() |
678 |
| - else: |
679 |
| - stream_context = nullcontext() |
680 |
| - |
681 |
| - with stream_context: |
| 658 | + with get_stream_context(recovery_stream): |
682 | 659 | try:
|
683 | 660 | if quorum.recover_dst_replica_ranks:
|
684 | 661 | self._logger.info(
|
@@ -1154,7 +1131,9 @@ def __init__(
|
1154 | 1131 | # The stream used to created the `Work` - we ensure all operations in the future
|
1155 | 1132 | # callback chain are executed on this stream
|
1156 | 1133 | self._stream: Optional[torch.Stream] = (
|
1157 |
| - torch.accelerator.current_stream() if torch.accelerator.is_available() else None |
| 1134 | + torch.accelerator.current_stream() |
| 1135 | + if torch.accelerator.is_available() |
| 1136 | + else None |
1158 | 1137 | )
|
1159 | 1138 |
|
1160 | 1139 | # To ensure the future callback chain is only created once
|
@@ -1189,16 +1168,7 @@ def callback(
|
1189 | 1168 | nonlocal managed_fut, value
|
1190 | 1169 | # change the stream to avoid making the callback stream
|
1191 | 1170 | # dependent on process group stream running the allreduce
|
1192 |
| - if self._stream is not None: |
1193 |
| - if torch.cuda.is_available(): |
1194 |
| - context = torch.cuda.stream(self._stream) |
1195 |
| - elif torch.xpu.is_available(): |
1196 |
| - context = torch.xpu.stream(self._stream) |
1197 |
| - else: |
1198 |
| - context = nullcontext() |
1199 |
| - else: |
1200 |
| - context = nullcontext() |
1201 |
| - with context: |
| 1171 | + with get_stream_context(self._stream): |
1202 | 1172 | # Setup stream dependency
|
1203 | 1173 | fut.wait()
|
1204 | 1174 | assert managed_fut._callback
|
@@ -1234,46 +1204,19 @@ def _assert_same_stream(self) -> None:
|
1234 | 1204 | def wait(self, timeout: Optional[timedelta] = None) -> bool:
|
1235 | 1205 | self._assert_same_stream()
|
1236 | 1206 |
|
1237 |
| - if self._stream is not None: |
1238 |
| - if torch.cuda.is_available(): |
1239 |
| - context = torch.cuda.stream(self._stream) |
1240 |
| - elif torch.xpu.is_available(): |
1241 |
| - context = torch.xpu.stream(self._stream) |
1242 |
| - else: |
1243 |
| - context = nullcontext() |
1244 |
| - else: |
1245 |
| - context = nullcontext() |
1246 |
| - with context: |
| 1207 | + with get_stream_context(self._stream): |
1247 | 1208 | self._work.wait()
|
1248 | 1209 | self._set_future_callback()
|
1249 | 1210 |
|
1250 |
| - if self._stream is not None: |
1251 |
| - if torch.cuda.is_available(): |
1252 |
| - context = torch.cuda.stream(self._stream) |
1253 |
| - elif torch.xpu.is_available(): |
1254 |
| - context = torch.xpu.stream(self._stream) |
1255 |
| - else: |
1256 |
| - context = nullcontext() |
1257 |
| - else: |
1258 |
| - context = nullcontext() |
1259 |
| - with context: |
| 1211 | + with get_stream_context(self._stream): |
1260 | 1212 | self._managed_fut_tail.wait()
|
1261 | 1213 |
|
1262 | 1214 | return True
|
1263 | 1215 |
|
1264 | 1216 | def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
|
1265 | 1217 | self._assert_same_stream()
|
1266 | 1218 |
|
1267 |
| - if self._stream is not None: |
1268 |
| - if torch.cuda.is_available(): |
1269 |
| - context = torch.cuda.stream(self._stream) |
1270 |
| - elif torch.xpu.is_available(): |
1271 |
| - context = torch.xpu.stream(self._stream) |
1272 |
| - else: |
1273 |
| - context = nullcontext() |
1274 |
| - else: |
1275 |
| - context = nullcontext() |
1276 |
| - with context: |
| 1219 | + with get_stream_context(self._stream): |
1277 | 1220 | self._work.block_current_stream()
|
1278 | 1221 |
|
1279 | 1222 | self._set_future_callback()
|
|
0 commit comments