diff --git a/README.md b/README.md index aa67f592..61e37d7d 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Easy Per Step Fault Tolerance for PyTorch | Documentation | Poster | Design Doc - | + |
@@ -98,7 +98,7 @@ when using synchronous training.
You can start a lighthouse server by running:
```sh
-$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 1000
+$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
```
### Example Training Loop (DDP)
@@ -108,7 +108,7 @@ See [train_ddp.py](./train_ddp.py) for the full example.
Invoke with:
```sh
-$ TORCHFT_MANAGER_PORT=29512 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py
+$ TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py
```
train.py:
diff --git a/src/lighthouse.rs b/src/lighthouse.rs
index f151fbf9..643aef1b 100644
--- a/src/lighthouse.rs
+++ b/src/lighthouse.rs
@@ -77,7 +77,7 @@ pub struct LighthouseOpt {
#[structopt(
long = "join_timeout_ms",
default_value = "60000",
- help = "How long to wait for new replicas to join before considering a quorum"
+ help = "How long to wait for heartbeating stragglers to join before issuing quorum"
)]
pub join_timeout_ms: u64,
@@ -90,14 +90,14 @@ pub struct LighthouseOpt {
#[structopt(
long = "quorum_tick_ms",
default_value = "100",
- help = "How frequently to check for quorum when waiting for workers."
+ help = "How frequently to check for quorum when waiting for stragglers."
)]
pub quorum_tick_ms: u64,
#[structopt(
long = "heartbeat_timeout_ms",
default_value = "5000",
- help = "how long to wait for a heartbeat before considering a replica dead."
+ help = "How long to wait for a heartbeat before considering a replica dead."
)]
pub heartbeat_timeout_ms: u64,
}
diff --git a/train_ddp.py b/train_ddp.py
index 9ad9cc85..4bcc0297 100644
--- a/train_ddp.py
+++ b/train_ddp.py
@@ -7,6 +7,7 @@
import logging
import os
import sys
+from datetime import timedelta
import torch
import torch.nn.functional as F
@@ -70,7 +71,13 @@ def state_dict():
}
device = "cuda" if torch.cuda.is_available() else "cpu"
- pg = ProcessGroupBabyNCCL() if torch.cuda.is_available() else ProcessGroupGloo()
+ pg = (
+ ProcessGroupBabyNCCL(
+ timeout=timedelta(seconds=5),
+ )
+ if torch.cuda.is_available()
+ else ProcessGroupGloo(timeout=timedelta(seconds=5))
+ )
manager = Manager(
pg=pg,
@@ -78,6 +85,7 @@ def state_dict():
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=f"train_ddp_{REPLICA_GROUP_ID}",
+ timeout=timedelta(seconds=10),
)
class Net(nn.Module):