diff --git a/torchft/process_group.py b/torchft/process_group.py index 4be8f51..4e10fba 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -738,7 +738,6 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None: self._use_abort: bool = torch.cuda.nccl.version() >= (2, 25) self._errored: Optional[Exception] = None - self._rank: int = 0 NONBLOCKING_TIMEOUT_ENV = "TORCH_NCCL_NONBLOCKING_TIMEOUT" if NONBLOCKING_TIMEOUT_ENV not in os.environ: @@ -788,7 +787,6 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro from torch.distributed import ProcessGroupNCCL as BaseProcessGroupNCCL self._errored = None - self._rank = rank # pyre-fixme[16]: no attribute ProcessGroupNCCL opts = BaseProcessGroupNCCL.Options() @@ -811,8 +809,8 @@ def abort(self) -> None: # We need to set the error before aborting to ensure that errored() # returns the error correctly when NCCL abort fires and unblocks the # stream. - if os.environ.get("TORCHFT_TRIGGER_FR_ON_ABORT", "true") == "true": - trigger_nccl_fr_trace_through_pipe(self._rank) + if os.environ.get("TORCHFT_TRIGGER_FR_ON_ABORT", "false") == "true": + trigger_nccl_fr_trace_through_pipe(dist.get_rank()) self._errored = RuntimeError("aborted") super().abort()