Skip to content

Commit 9c059ce

Browse files
authored
add logging (#265)
1 parent f8035dd commit 9c059ce

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

torchft/futures.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import os
23
import queue
34
import sys
45
import threading
@@ -15,6 +16,8 @@
1516

1617
T = TypeVar("T")
1718

19+
WATCHDOG_TIMEOUT_SEC = "TORCHFT_WATCHDOG_TIMEOUT_SEC"
20+
1821

1922
class _TimerHandle:
2023
def __init__(self) -> None:
@@ -61,7 +64,9 @@ def __init__(self) -> None:
6164

6265
# Give this much time to the `_event_loop_thread` to confirm that
6366
# it is not stuck
64-
self._watchdog_interval = timedelta(seconds=30)
67+
self._watchdog_interval = timedelta(
68+
seconds=int(os.environ.get(WATCHDOG_TIMEOUT_SEC, "30"))
69+
)
6570

6671
# This queue is used to delete events on the main thread as cudaEventDestroy
6772
# can block if the CUDA queue is full.

torchft/local_sgd.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,7 @@ def _step_post_hook(
762762
# can be overrepresented.
763763
self._manager.start_quorum()
764764
fragment = self._current_fragment()
765+
logger.info(f"Preparing fragment={fragment} step={self._local_step}")
765766
self._fragments[fragment].prepare_sync()
766767

767768
if self._local_step < self._sync_every:
@@ -770,6 +771,9 @@ def _step_post_hook(
770771
if self._local_step == self._sync_every:
771772
# Time to sync a fragment
772773
fragment = self._current_fragment()
774+
logger.info(
775+
f"Syncing fragment={fragment} step={self._local_step} manager_step={self._manager.current_step()}"
776+
)
773777
self._fragments[fragment].perform_sync()
774778

775779
# If the allreduce truly failed, we'll keep retrying this fragment.

0 commit comments

Comments
 (0)