diff --git a/README.md b/README.md index aa67f592..61e37d7d 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Easy Per Step Fault Tolerance for PyTorch | Documentation | Poster | Design Doc - | + |

PyPI - Version @@ -98,7 +98,7 @@ when using synchronous training. You can start a lighthouse server by running: ```sh -$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 1000 +$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 ``` ### Example Training Loop (DDP) @@ -108,7 +108,7 @@ See [train_ddp.py](./train_ddp.py) for the full example. Invoke with: ```sh -$ TORCHFT_MANAGER_PORT=29512 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py +$ TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py ``` train.py: diff --git a/src/lighthouse.rs b/src/lighthouse.rs index f151fbf9..643aef1b 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -77,7 +77,7 @@ pub struct LighthouseOpt { #[structopt( long = "join_timeout_ms", default_value = "60000", - help = "How long to wait for new replicas to join before considering a quorum" + help = "How long to wait for heartbeating stragglers to join before issuing quorum" )] pub join_timeout_ms: u64, @@ -90,14 +90,14 @@ pub struct LighthouseOpt { #[structopt( long = "quorum_tick_ms", default_value = "100", - help = "How frequently to check for quorum when waiting for workers." + help = "How frequently to check for quorum when waiting for stragglers." )] pub quorum_tick_ms: u64, #[structopt( long = "heartbeat_timeout_ms", default_value = "5000", - help = "how long to wait for a heartbeat before considering a replica dead." + help = "How long to wait for a heartbeat before considering a replica dead." )] pub heartbeat_timeout_ms: u64, } diff --git a/train_ddp.py b/train_ddp.py index 9ad9cc85..4bcc0297 100644 --- a/train_ddp.py +++ b/train_ddp.py @@ -7,6 +7,7 @@ import logging import os import sys +from datetime import timedelta import torch import torch.nn.functional as F @@ -70,7 +71,13 @@ def state_dict(): } device = "cuda" if torch.cuda.is_available() else "cpu" - pg = ProcessGroupBabyNCCL() if torch.cuda.is_available() else ProcessGroupGloo() + pg = ( + ProcessGroupBabyNCCL( + timeout=timedelta(seconds=5), + ) + if torch.cuda.is_available() + else ProcessGroupGloo(timeout=timedelta(seconds=5)) + ) manager = Manager( pg=pg, @@ -78,6 +85,7 @@ def state_dict(): load_state_dict=load_state_dict, state_dict=state_dict, replica_id=f"train_ddp_{REPLICA_GROUP_ID}", + timeout=timedelta(seconds=10), ) class Net(nn.Module):