Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions proto/torchft.proto
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ message QuorumMember {
string address = 2;
string store_address = 3;
int64 step = 4;
uint64 world_size = 5;
}

message Quorum {
Expand Down
7 changes: 7 additions & 0 deletions src/lighthouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ mod tests {
address: "".to_string(),
store_address: "".to_string(),
step: 1,
world_size: 1,
},
},
);
Expand Down Expand Up @@ -495,6 +496,7 @@ mod tests {
address: "".to_string(),
store_address: "".to_string(),
step: 1,
world_size: 1,
},
},
);
Expand All @@ -511,6 +513,7 @@ mod tests {
address: "".to_string(),
store_address: "".to_string(),
step: 1,
world_size: 1,
}],
created: Some(SystemTime::now().into()),
});
Expand Down Expand Up @@ -550,6 +553,7 @@ mod tests {
address: "".to_string(),
store_address: "".to_string(),
step: 10,
world_size: 1,
}),
});

Expand All @@ -568,12 +572,14 @@ mod tests {
address: "".to_string(),
store_address: "".to_string(),
step: 1,
world_size: 1,
}];
let b = vec![QuorumMember {
replica_id: "1".to_string(),
address: "changed".to_string(),
store_address: "changed".to_string(),
step: 1000,
world_size: 1,
}];

// replica_id is the same
Expand All @@ -584,6 +590,7 @@ mod tests {
address: "".to_string(),
store_address: "".to_string(),
step: 1,
world_size: 1,
}];
// replica_id changed
assert!(quorum_changed(&a, &c));
Expand Down
1 change: 1 addition & 0 deletions src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ impl ManagerService for Arc<Manager> {
address: self.address.clone(),
store_address: self.store_address.clone(),
step: req.step,
world_size: self.world_size,
}),
});

Expand Down
3 changes: 2 additions & 1 deletion templates/status.html
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ <h3>Previous Quorum</h3>
<b>{{ member.replica_id }}</b> <br/>
Step: {{ member.step }} <br/>
Manager: {{ member.address }} <br/>
TCPStore: {{ member.store_address }}
TCPStore: {{ member.store_address }} <br/>
World size: {{ member.world_size }} <br/>

<button hx-post="/replica/{{member.replica_id}}/kill"
hx-trigger="click">
Expand Down
7 changes: 5 additions & 2 deletions torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
store_addr: Optional[str] = None,
store_port: Optional[int] = None,
lighthouse_addr: Optional[str] = None,
replica_id: Optional[str] = None,
) -> None:
"""
Args:
Expand All @@ -62,7 +63,8 @@ def __init__(
world_size: the replica group local world size
store_addr: TCPStore address for this replica group
store_port: TCPStore port for this replica group
ligthouse_addr: if rank==0, the address of the lighthouse server
lighthouse_addr: if rank==0, the address of the lighthouse server
replica_id: if rank==0, the replica_id for this group
"""
self._load_state_dict = load_state_dict
self._state_dict = state_dict
Expand Down Expand Up @@ -99,7 +101,8 @@ def __init__(
bind = f"[::]:{port}"
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]

replica_id = str(uuid.uuid4())
if replica_id is None:
replica_id = str(uuid.uuid4())
# pyre-fixme[16]: can't find rust module
self._manager = _Manager(
replica_id=replica_id,
Expand Down
12 changes: 8 additions & 4 deletions train_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@


def main() -> None:
REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))

transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
Expand All @@ -40,8 +43,8 @@ def main() -> None:
# majority of groups will be available so few batches will be dropped.
sampler = DistributedSampler(
trainset,
replica_group=int(os.environ.get("REPLICA_GROUP_ID", 0)),
num_replica_groups=int(os.environ.get("NUM_REPLICA_GROUPS", 2)),
replica_group=REPLICA_GROUP_ID,
num_replica_groups=NUM_REPLICA_GROUPS,
rank=0,
# for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
num_replicas=1,
Expand All @@ -50,7 +53,7 @@ def main() -> None:
# This uses the torchdata StatefulDataLoader to be able to checkpoint and
# restore the per worker dataloader position.
trainloader = StatefulDataLoader(
trainset, batch_size=2, shuffle=True, num_workers=2
trainset, batch_size=64, shuffle=True, num_workers=2
)

def load_state_dict(state_dict):
Expand All @@ -68,9 +71,10 @@ def state_dict():

manager = Manager(
pg=pg,
min_replica_size=2,
min_replica_size=1,
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=f"train_ddp_{REPLICA_GROUP_ID}",
)

class Net(nn.Module):
Expand Down
Loading