Commit 41829de

Lint fixes
committed
1 parent ff78244 commit 41829de

File tree: 5 files changed, +21 -12 lines


torchft/checkpointing/transport_test.py

Lines changed: 1 addition & 4 deletions
@@ -60,10 +60,7 @@ def run_multi_recovery_test(
     metadata: str = ""
 
     dist.init_process_group(
-        backend="gloo",
-        rank=0,
-        world_size=1,
-        store=dist.HashStore(),  # pyre-fixme[6]: Expected `Optional[Store]` but got `HashStore`
+        backend="gloo", rank=0, world_size=1, store=dist.HashStore()
     )
 
     def run(rank: int) -> CheckpointTransport[dict[str, object]]:
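Both this test and process_group_test.py below collapse the same call onto one line: a single-process Gloo group backed by an in-memory dist.HashStore, so no TCP rendezvous is needed, and the dropped pyre-fixme[6] suggests HashStore now type-checks against the store parameter. A standalone sketch of the pattern (the helper name init_single_process_pg is made up; the all_reduce at the end is only an illustration, not part of the commit):

import torch
import torch.distributed as dist

def init_single_process_pg() -> None:
    # One-rank Gloo group over an in-memory HashStore: no environment
    # variables or network rendezvous, which keeps unit tests hermetic.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo", rank=0, world_size=1, store=dist.HashStore()
        )

if __name__ == "__main__":
    init_single_process_pg()
    t = torch.ones(4)
    dist.all_reduce(t)  # with world_size=1 this is effectively a no-op
    print(t)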

torchft/manager.py

Lines changed: 4 additions & 2 deletions
@@ -58,7 +58,7 @@
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
-from torchft.utils import get_stream_context
+from torchft.utils import current_stream, get_stream_context
 from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
@@ -790,7 +790,9 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
             "torchft::manager::should_commit::current_stream::synchronize"
         ):
             if torch.accelerator.is_available():
-                torch.accelerator.current_stream().synchronize()
+
+                # pyre-fixme[16]: no attribute synchronize
+                current_stream().synchronize()
 
             if err := self._pg.errored():
                 self.report_error(err)
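The should_commit hunk keeps the existing ordering: flush the active accelerator stream, then ask the process group whether any asynchronous collective failed, and report that error. A self-contained sketch of the guarded-synchronize step, built only from the torch.accelerator calls already visible in this hunk (the helper name is made up for illustration):

import torch

def flush_accelerator_work() -> None:
    # Only touch the accelerator stream when one exists, so the same code
    # path also runs on CPU-only hosts.
    if torch.accelerator.is_available():
        torch.accelerator.current_stream().synchronize()

Synchronizing first matters because a failure inside queued stream work may only surface once that work has actually completed.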

torchft/process_group.py

Lines changed: 4 additions & 2 deletions
@@ -69,7 +69,7 @@
 from torchft.device_mesh import *  # noqa: F401
 from torchft.futures import context_timeout, stream_timeout
 from torchft.multiprocessing import _MonitoredPipe
-from torchft.utils import get_stream_context, record_event
+from torchft.utils import current_stream, get_stream_context, record_event
 from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
@@ -793,7 +793,8 @@ def abort(self) -> None:
 
     def errored(self) -> Optional[Exception]:
         # force a synchronization to ensure all work is complete
-        torch.accelerator.current_stream().synchronize()
+        # pyre-fixme[16]: no attribute synchronize
+        current_stream().synchronize()
 
         return self._errored
 
@@ -877,6 +878,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
 
         self._errored = None
 
+        # pyre-fixme[16]: no attribute ProcessGroupXCCL
         opts = BaseProcessGroupXCCL.Options()
         # opts.config.blocking = False
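errored() above returns a cached exception only after forcing the current stream to finish, so a failure raised by still-pending asynchronous work is captured before the health check answers. A minimal sketch of that shape (the class name is hypothetical and the is_available() guard is an addition; torchft's real process group wrapper carries far more state):

from typing import Optional

import torch

class ErrorCachingGroup:
    def __init__(self) -> None:
        # The wrapper records the first asynchronous failure it observes here.
        self._errored: Optional[Exception] = None

    def errored(self) -> Optional[Exception]:
        # Force a synchronization so any error from outstanding stream work
        # has surfaced before the cached value is returned.
        if torch.accelerator.is_available():
            torch.accelerator.current_stream().synchronize()
        return self._errored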

torchft/process_group_test.py

Lines changed: 1 addition & 4 deletions
@@ -58,10 +58,7 @@
 def dummy_init_pg() -> None:
     if not dist.is_initialized():
         dist.init_process_group(
-            backend="gloo",
-            rank=0,
-            world_size=1,
-            store=dist.HashStore(),  # pyre-fixme[6]: Expected `Optional[Store]` but got `HashStore`
+            backend="gloo", rank=0, world_size=1, store=dist.HashStore()
         )
 

torchft/utils.py

Lines changed: 11 additions & 0 deletions
@@ -54,3 +54,14 @@ def record_event() -> None:
         torch.xpu.current_stream().record_event(torch.xpu.Event())
     else:
         torch.cuda.current_stream().record_event(torch.cuda.Event(interprocess=True))
+
+
+def current_stream() -> None:
+    """
+    This function provides a unified way to get current stream across different
+    accelerator types (CUDA, XPU).
+    """
+    if torch.xpu.is_available():
+        torch.xpu.current_stream()
+    else:
+        torch.cuda.current_stream()
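As displayed, the new helper is annotated -> None and does not return the stream it looks up, which is why both call sites above need a pyre-fixme[16] before calling .synchronize() on the result. A variant that returns the stream would satisfy both the type checker and the call sites; a sketch, with the return annotation left off so it does not assume a particular minimum PyTorch version:

import torch

def current_stream():
    """
    Unified way to get the current stream across accelerator types (CUDA, XPU),
    mirroring the helper added in this commit but returning the stream so that
    callers can invoke .synchronize() on it directly.
    """
    if torch.xpu.is_available():
        return torch.xpu.current_stream()
    return torch.cuda.current_stream()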
