Skip to content

Commit 647979b

Browse files
committed
use global rank for flight recorder
1 parent 8579601 commit 647979b

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

torchft/process_group.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,6 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
738738
self._use_abort: bool = torch.cuda.nccl.version() >= (2, 25)
739739

740740
self._errored: Optional[Exception] = None
741-
self._rank: int = 0
742741

743742
NONBLOCKING_TIMEOUT_ENV = "TORCH_NCCL_NONBLOCKING_TIMEOUT"
744743
if NONBLOCKING_TIMEOUT_ENV not in os.environ:
@@ -788,7 +787,6 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
788787
from torch.distributed import ProcessGroupNCCL as BaseProcessGroupNCCL
789788

790789
self._errored = None
791-
self._rank = rank
792790

793791
# pyre-fixme[16]: no attribute ProcessGroupNCCL
794792
opts = BaseProcessGroupNCCL.Options()
@@ -811,8 +809,8 @@ def abort(self) -> None:
811809
# We need to set the error before aborting to ensure that errored()
812810
# returns the error correctly when NCCL abort fires and unblocks the
813811
# stream.
814-
if os.environ.get("TORCHFT_TRIGGER_FR_ON_ABORT", "true") == "true":
815-
trigger_nccl_fr_trace_through_pipe(self._rank)
812+
if os.environ.get("TORCHFT_TRIGGER_FR_ON_ABORT", "false") == "true":
813+
trigger_nccl_fr_trace_through_pipe(dist.get_rank())
816814
self._errored = RuntimeError("aborted")
817815

818816
super().abort()

0 commit comments

Comments
 (0)