Commit 8e4bbfe ("lint and feedback")
1 parent: 20660dd

2 files changed: +24 -17 lines

examples/monarch/README.md (10 additions, 10 deletions)

```diff
@@ -1,23 +1,23 @@
-Monarch-TorchFT-TorchTitan Distributed Training Orchestrator
+### Monarch-TorchFT-TorchTitan Distributed Training Orchestrator

+#### Overview
 This script orchestrates fault-tolerant distributed training using TorchTitan and TorchMonarch
 frameworks. It manages multiple training replicas across SLURM-scheduled compute nodes
 with automatic failure recovery and TorchFT lighthouse coordination.

-PREREQUISITES:
+##### PREREQUISITES
 - Access to a SLURM cluster with GPU nodes
-- Environment with nightly TorchFT, TorchTitan, and Monarch libraries installed.
 - TorchTitan training configuration file in script directory (debug_model.toml)
 - A training dataset (c4_test) and tokenizer in script directory

-CONFIGURATION:
+##### CONFIGURATION
 Before running, update the cluster-specific constants:
 - MACHINE: TorchX named resource for your cluster (currently: "gpu.xlarge")
 - MACHINE_MEMORY: Memory per machine in MB (currently: 2062607)
 You can also override the resource configuration manually:
 - https://docs.pytorch.org/torchx/main/specs.html#resource

-USAGE:
+##### USAGE
 python train_distributed.py --help

 Basic usage with 2 replicas, each with 1 node and 8 GPUs:
@@ -30,21 +30,21 @@ USAGE:
 With remote TorchFT lighthouse:
 python train_distributed.py --remote-lighthouse

-KEY COMPONENTS:
+##### KEY COMPONENTS
 - LighthouseActor: Coordination server for fault tolerance
 - TrainingActor: Individual trainer processes
 - ReplicaActor: Manages groups of trainers
 - OrchestrationManager: Top-level orchestration and failure recovery

-FAILURE RECOVERY:
-- Automatic replica retry with configurable delays (PER_ATTEMPT_DELAY)
+##### FAILURE RECOVERY
+- Automatic retry with configurable delays (PER_ATTEMPT_DELAY)
 - New allocations after repeated failures (PROC_ATTEMPTS)
 - Maximum attempts per replica (MAX_ATTEMPT)

-OUTPUT:
+##### OUTPUT
 - Training outputs saved to ./outputs directory
 - Logs streamed from all distributed processes
 - TensorBoard metrics enabled by default

-CLEANUP:
+##### CLEANUP
 All SLURM jobs are automatically terminated at script completion.
```
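
The FAILURE RECOVERY section above names three tuning constants. A rough illustration of the retry policy they imply follows; the constant names come from the README, while the values and the control flow are assumptions, not code taken from train_distributed.py.

```python
# Illustrative retry loop for the policy described under FAILURE RECOVERY.
# Values and control flow are assumed; only the constant names come from the README.
import asyncio

PER_ATTEMPT_DELAY = 30  # seconds to wait between attempts (assumed value)
PROC_ATTEMPTS = 2       # failures tolerated before requesting a fresh allocation (assumed value)
MAX_ATTEMPT = 5         # hard cap on attempts per replica (assumed value)


async def run_with_retries(run_replica, reallocate, replica_id: int) -> None:
    for attempt in range(1, MAX_ATTEMPT + 1):
        try:
            await run_replica(replica_id)
            return
        except Exception:
            if attempt % PROC_ATTEMPTS == 0:
                # After repeated failures, tear down and request a new allocation.
                await reallocate(replica_id)
            await asyncio.sleep(PER_ATTEMPT_DELAY)
    raise RuntimeError(f"replica {replica_id} failed after {MAX_ATTEMPT} attempts")
```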

examples/monarch/train_distributed.py (14 additions, 7 deletions)

```diff
@@ -7,23 +7,20 @@

 import argparse
 import asyncio
+import atexit
 import os
-
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Dict
-import atexit

 import torch
-
 from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints, AllocSpec
 from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer
-from monarch.actor import Actor, current_rank, endpoint, ProcMesh, this_host
+from monarch.actor import Actor, ProcMesh, current_rank, endpoint, this_host
 from monarch.tools import commands
 from monarch.tools.components import hyperactor
 from monarch.tools.config import Config
 from monarch.utils import setup_env_for_distributed
-
 from torchtitan.config import ConfigManager, JobConfig
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.train import Trainer
@@ -73,7 +70,9 @@ def proc_mesh(
     ) -> ProcMesh:
         allocator = RemoteAllocator(
             world_id=MonarchSlurm.job_name_prefix,
-            initializer=TorchXRemoteAllocInitializer(f"slurm:///{cls.job_handles[mesh_name]}"),
+            initializer=TorchXRemoteAllocInitializer(
+                f"slurm:///{cls.job_handles[mesh_name]}"
+            ),
         )
         alloc = allocator.allocate(
             AllocSpec(AllocConstraints(), hosts=num_hosts, gpus=num_gpus)
```
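
For readers new to this API, a minimal sketch isolating the allocation pattern from the hunk above; the helper name and world id are illustrative, and the step that turns the allocation into a ProcMesh falls outside this hunk.

```python
# Sketch: the SLURM-backed allocation path shown in the hunk above, as a helper.
# RemoteAllocator, TorchXRemoteAllocInitializer, AllocSpec, and AllocConstraints are
# used as in the diff; the world id is a placeholder (the script uses
# MonarchSlurm.job_name_prefix), and the final allocation-to-ProcMesh conversion is
# not part of this hunk.
from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints, AllocSpec
from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer


def slurm_allocation(job_handle: str, num_hosts: int, num_gpus: int):
    allocator = RemoteAllocator(
        world_id="example_world",  # placeholder world id
        initializer=TorchXRemoteAllocInitializer(f"slurm:///{job_handle}"),
    )
    # One process per GPU on each host of the existing SLURM job.
    return allocator.allocate(
        AllocSpec(AllocConstraints(), hosts=num_hosts, gpus=num_gpus)
    )
```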

```diff
@@ -84,13 +83,16 @@ def proc_mesh(

 # ==== allocation boilerplate ====

+
 class LighthouseActor(Actor):
     def __init__(self) -> None:
         self.lighthouse = None

     @endpoint
     def start_lighthouse(self) -> str:
+        # inline import because of https://github.com/meta-pytorch/monarch/issues/804
         from torchft.coordination import LighthouseServer
+
         self.lighthouse = LighthouseServer(
             bind="[::]:0", min_replicas=1, join_timeout_ms=10000
         )
```
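
As a rough illustration of how an actor like this is typically brought up, the sketch below spawns a single LighthouseActor locally and reads back its address; the spawn shape and the `.call_one()` endpoint call follow Monarch's actor API, but the exact usage here is an assumption rather than code from this file.

```python
# Sketch: start one LighthouseActor on the local host and return its address.
# this_host().spawn_procs(...) and endpoint .call_one() are Monarch actor APIs;
# the per_host shape and variable names are assumptions for illustration only.
from monarch.actor import this_host


async def bring_up_lighthouse() -> str:
    procs = this_host().spawn_procs(per_host={"gpus": 1})  # one local process
    lighthouse = procs.spawn("lighthouse", LighthouseActor)
    # start_lighthouse returns the address trainers use to join the quorum.
    return await lighthouse.start_lighthouse.call_one()
```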

```diff
@@ -217,7 +219,9 @@ async def start_training(self) -> None:
         )

         for replica_id in range(self.spec.replica_count):
-            await MonarchSlurm.get_or_create_job(f"replica_{replica_id}", self.spec.hosts_per_replica)
+            await MonarchSlurm.get_or_create_job(
+                f"replica_{replica_id}", self.spec.hosts_per_replica
+            )

         mesh_futures = {}
         for i in range(self.spec.replica_count):
```
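
The loop above creates (or reuses) one SLURM job per replica before meshes are built. A small sketch of that fan-out pattern follows; the job and mesh factories are passed in as callables because the real helpers are not shown in this hunk.

```python
# Sketch: one job per replica, then build all replica meshes concurrently.
# Both callables stand in for script internals (MonarchSlurm.get_or_create_job and
# the per-replica proc_mesh construction); they are parameters here, not real names.
import asyncio
from typing import Awaitable, Callable, Dict


async def build_replica_meshes(
    replica_count: int,
    hosts_per_replica: int,
    get_or_create_job: Callable[[str, int], Awaitable[None]],
    make_mesh: Callable[[int], Awaitable[object]],
) -> Dict[int, object]:
    for replica_id in range(replica_count):
        await get_or_create_job(f"replica_{replica_id}", hosts_per_replica)
    # Kick off mesh construction for every replica, then wait for all of them.
    tasks = {i: asyncio.create_task(make_mesh(i)) for i in range(replica_count)}
    return {i: await t for i, t in tasks.items()}
```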

```diff
@@ -305,6 +309,7 @@ async def _teardown(self, replica_id: int) -> None:

 # === CLI / CONFIG === #

+
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
         description="Monarch-TorchFT Distributed Training Example"
@@ -398,6 +403,8 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
         hosts_per_replica=args.host_per_replica,
         gpus_per_node=args.gpu_per_node,
     )
+
+
 # === CLI / CONFIG === #

```
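
For orientation, the fields referenced across these hunks (replica_count, hosts_per_replica, gpus_per_node) suggest a job spec along the lines below; this is a sketch shaped only by the visible hunks, and the real JobSpec in the script may carry additional fields.

```python
# Sketch of a job spec limited to the fields visible in this diff; the actual
# JobSpec dataclass in train_distributed.py may differ.
from dataclasses import dataclass


@dataclass
class JobSpecSketch:
    replica_count: int
    hosts_per_replica: int
    gpus_per_node: int


# Matches the README's basic example: 2 replicas, 1 host each, 8 GPUs per node.
spec = JobSpecSketch(replica_count=2, hosts_per_replica=1, gpus_per_node=8)
```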
