Skip to content

Commit a712b82

Browse files
committed
Fix UT failures due to record_event
1 parent f21ea28 commit a712b82

File tree

2 files changed

+18
-19
lines changed

2 files changed

+18
-19
lines changed

torchft/process_group.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from torchft.device_mesh import * # noqa: F401
7070
from torchft.futures import context_timeout, stream_timeout
7171
from torchft.multiprocessing import _MonitoredPipe
72-
from torchft.utils import get_stream_context
72+
from torchft.utils import get_stream_context, record_event
7373
from torchft.work import _DummyWork
7474

7575
if TYPE_CHECKING:
@@ -793,7 +793,7 @@ def abort(self) -> None:
793793

794794
def errored(self) -> Optional[Exception]:
795795
# force a synchronization to ensure all work is complete
796-
torch.cuda.current_stream().synchronize()
796+
torch.accelerator.current_stream().synchronize()
797797

798798
return self._errored
799799

@@ -1539,13 +1539,7 @@ def _worker(
15391539

15401540
# Register event on the stream that we can pass to the main
15411541
# process.
1542-
event = (
1543-
torch.accelerator.current_stream().record_event(
1544-
torch.Event(interprocess=True)
1545-
)
1546-
if metadata.stream is not None
1547-
else None
1548-
)
1542+
event = record_event() if metadata.stream is not None else None
15491543

15501544
req_pipe.send((op_id, event))
15511545
elif cmd == "del":
@@ -1562,9 +1556,7 @@ def callback(fut: Future[object], metadata: _OpMetadata) -> None:
15621556
with metadata.set_stream():
15631557
fut.wait()
15641558
event = (
1565-
torch.accelerator.current_stream().record_event(
1566-
torch.Event(interprocess=True)
1567-
)
1559+
record_event()
15681560
if metadata.stream is not None
15691561
else None
15701562
)
@@ -1660,13 +1652,7 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
16601652
stream_id = (
16611653
torch.accelerator.current_stream().stream_id if is_accelerator else None
16621654
)
1663-
event = (
1664-
torch.accelerator.current_stream().record_event(
1665-
torch.Event(interprocess=True)
1666-
)
1667-
if is_accelerator
1668-
else None
1669-
)
1655+
event = record_event() if is_accelerator else None
16701656

16711657
op_id = self._next_op_id
16721658
self._next_op_id += 1

torchft/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,16 @@ def get_stream_context(
4141
return nullcontext()
4242
else:
4343
return nullcontext()
44+
45+
46+
def record_event() -> None:
47+
"""
48+
Record an event in the current stream.
49+
50+
This function provides a unified way to record events across different
51+
accelerator types (CUDA, XPU).
52+
"""
53+
if torch.xpu.is_available():
54+
torch.xpu.current_stream().record_event(torch.xpu.Event(interprocess=True))
55+
else:
56+
torch.cuda.current_stream().record_event(torch.cuda.Event(interprocess=True))

0 commit comments

Comments
 (0)