69
69
from torchft .device_mesh import * # noqa: F401
70
70
from torchft .futures import context_timeout , stream_timeout
71
71
from torchft .multiprocessing import _MonitoredPipe
72
- from torchft .utils import get_stream_context
72
+ from torchft .utils import get_stream_context , record_event
73
73
from torchft .work import _DummyWork
74
74
75
75
if TYPE_CHECKING :
@@ -793,7 +793,7 @@ def abort(self) -> None:
793
793
794
794
def errored (self ) -> Optional [Exception ]:
795
795
# force a synchronization to ensure all work is complete
796
- torch .cuda .current_stream ().synchronize ()
796
+ torch .accelerator .current_stream ().synchronize ()
797
797
798
798
return self ._errored
799
799
@@ -1539,13 +1539,7 @@ def _worker(
1539
1539
1540
1540
# Register event on the stream that we can pass to the main
1541
1541
# process.
1542
- event = (
1543
- torch .accelerator .current_stream ().record_event (
1544
- torch .Event (interprocess = True )
1545
- )
1546
- if metadata .stream is not None
1547
- else None
1548
- )
1542
+ event = record_event () if metadata .stream is not None else None
1549
1543
1550
1544
req_pipe .send ((op_id , event ))
1551
1545
elif cmd == "del" :
@@ -1562,9 +1556,7 @@ def callback(fut: Future[object], metadata: _OpMetadata) -> None:
1562
1556
with metadata .set_stream ():
1563
1557
fut .wait ()
1564
1558
event = (
1565
- torch .accelerator .current_stream ().record_event (
1566
- torch .Event (interprocess = True )
1567
- )
1559
+ record_event ()
1568
1560
if metadata .stream is not None
1569
1561
else None
1570
1562
)
@@ -1660,13 +1652,7 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
1660
1652
stream_id = (
1661
1653
torch .accelerator .current_stream ().stream_id if is_accelerator else None
1662
1654
)
1663
- event = (
1664
- torch .accelerator .current_stream ().record_event (
1665
- torch .Event (interprocess = True )
1666
- )
1667
- if is_accelerator
1668
- else None
1669
- )
1655
+ event = record_event () if is_accelerator else None
1670
1656
1671
1657
op_id = self ._next_op_id
1672
1658
self ._next_op_id += 1
0 commit comments