Skip to content

Commit f04ac56

Browse files
committed
add failure injector for monarch script
1 parent af5a4ff commit f04ac56

File tree

5 files changed

+216
-49
lines changed

5 files changed

+216
-49
lines changed

examples/monarch/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ You can also override the resource configuration manually:
3535
- TrainingActor: Individual trainer processes
3636
- ReplicaActor: Manages groups of trainers
3737
- OrchestrationManager: Top-level orchestration and failure recovery
38+
- FailureController: Optional, periodically injects random failures into trainer processes
3839

3940
##### FAILURE RECOVERY
4041
- Automatic retry with configurable delays (PER_ATTEMPT_DELAY)

examples/monarch/__init__.py

Whitespace-only changes.

examples/monarch/train_distributed.py

Lines changed: 91 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -24,54 +24,61 @@
2424
from torchtitan.config import ConfigManager, JobConfig
2525
from torchtitan.tools.logging import init_logger, logger
2626
from torchtitan.train import Trainer
27+
from utils.failure import Failure, FailureActor, FailureController
2728

2829

2930
# ==== Allocation boilerplate - much of this will be upstreamed into Monarch ====
3031
class MonarchSlurm:
3132
# Cluster Configuration - update these values for your specific cluster
32-
machine: str = "aws_g5.12xlarge"
33-
machine_memory: int = 186777
33+
machine: str = "gpu.xlarge"
34+
machine_memory: int = 2062607
3435
job_name_prefix: str = "monarch-torchft"
3536

36-
job_handles: Dict[str, str] = {}
37+
def __init__(self):
38+
self.job_handles: Dict[str, str] = {}
39+
atexit.register(self.kill_jobs)
3740

38-
@classmethod
39-
def get_config(cls, mesh_name: str, nodes_per_mesh: int) -> Config:
41+
def get_config(self, mesh_name: str, nodes_per_mesh: int) -> Config:
4042
mesh = [f"{mesh_name}:{nodes_per_mesh}:{MonarchSlurm.machine}"]
41-
appdef = hyperactor.host_mesh(meshes=mesh)
43+
# to enable relative import of utils on actors
44+
current_dir = os.path.dirname(os.path.abspath(__file__))
45+
env = {"PYTHONPATH": current_dir}
46+
47+
appdef = hyperactor.host_mesh(meshes=mesh, env=env)
4248

4349
for role in appdef.roles:
4450
role.resource.memMB = MonarchSlurm.machine_memory
4551

4652
return Config(scheduler="slurm", appdef=appdef)
4753

48-
@classmethod
49-
async def get_or_create_job(cls, mesh_name: str, nodes_per_mesh: int = 1) -> None:
50-
config = cls.get_config(mesh_name, nodes_per_mesh)
54+
async def get_or_create_job(self, mesh_name: str, nodes_per_mesh: int = 1) -> None:
55+
config = self.get_config(mesh_name, nodes_per_mesh)
5156
job_name = f"{MonarchSlurm.job_name_prefix}-{mesh_name}"
5257
server_spec = await commands.get_or_create(job_name, config, force_restart=True)
53-
cls.job_handles[mesh_name] = server_spec.name
58+
self.job_handles[mesh_name] = server_spec.name
5459

55-
@classmethod
56-
def kill_jobs(cls):
57-
for mesh_name, job_handle in cls.job_handles.items():
58-
try:
59-
logger.info(f"Destroying job for mesh {mesh_name}")
60-
commands.kill(f"slurm:///{job_handle}")
61-
except Exception as e:
62-
logger.warning(f"Failed to destroy job for {mesh_name}: {e}")
60+
def kill_jobs(self):
61+
for mesh_name in self.job_handles.keys():
62+
self.kill_job(mesh_name)
63+
64+
def kill_job(self, mesh_name: str):
65+
try:
66+
job_handle = self.job_handles[mesh_name]
67+
logger.info(f"Destroying job for mesh {mesh_name}")
68+
commands.kill(f"slurm:///{job_handle}")
69+
except Exception as e:
70+
logger.warning(f"Failed to destroy job for {mesh_name}: {e}")
6371

64-
@classmethod
6572
def proc_mesh(
66-
cls,
73+
self,
6774
mesh_name: str,
6875
num_hosts: int = 1,
6976
num_gpus: int = 8,
7077
) -> ProcMesh:
7178
allocator = RemoteAllocator(
7279
world_id=MonarchSlurm.job_name_prefix,
7380
initializer=TorchXRemoteAllocInitializer(
74-
f"slurm:///{cls.job_handles[mesh_name]}"
81+
f"slurm:///{self.job_handles[mesh_name]}"
7582
),
7683
)
7784
alloc = allocator.allocate(
@@ -140,6 +147,7 @@ class JobSpec:
140147
replica_count: int
141148
hosts_per_replica: int
142149
gpus_per_node: int
150+
with_failures: bool
143151
lighthouse_address: str = ""
144152

145153

@@ -154,16 +162,15 @@ class Replica:
154162
# This does not currently benefit from being an actor, but will once
155163
# Monarch supervision APIs are fleshed out.
156164
class ReplicaActor(Actor):
157-
def __init__(
158-
self,
159-
spec: JobSpec,
160-
replica_id: int,
161-
) -> None:
165+
def __init__(self, spec: JobSpec, replica_id: int, scheduler: MonarchSlurm) -> None:
162166
self.spec = deepcopy(spec)
163167
self.replica_id = replica_id
164168

165169
self.uid = f"[replica_{replica_id}]"
166170
self.spec.job_config.fault_tolerance.replica_id = self.replica_id
171+
self.scheduler = scheduler
172+
173+
self.failure_actors: FailureActor | None = None
167174

168175
@endpoint
169176
async def start_replica(self) -> None:
@@ -172,7 +179,7 @@ async def start_replica(self) -> None:
172179

173180
trainers_proc_mesh: ProcMesh | None = None
174181
try:
175-
trainers_proc_mesh = MonarchSlurm.proc_mesh(
182+
trainers_proc_mesh = self.scheduler.proc_mesh(
176183
f"replica_{self.replica_id}",
177184
self.spec.hosts_per_replica,
178185
self.spec.gpus_per_node,
@@ -189,6 +196,10 @@ async def start_replica(self) -> None:
189196
self.replica_id,
190197
)
191198

199+
self.failure_actors = trainers_proc_mesh.spawn(
200+
"failure_actors", FailureActor
201+
)
202+
192203
logger.info(f"{self.uid} Starting trainers")
193204
await training_actors.start_training.call(self.spec.lighthouse_address)
194205
await trainers_proc_mesh.stop()
@@ -197,13 +208,29 @@ async def start_replica(self) -> None:
197208
await trainers_proc_mesh.stop()
198209
raise e
199210

211+
@endpoint
212+
async def inject_failure(self, failure_type: Failure):
213+
if self.failure_actors:
214+
try:
215+
logger.info(
216+
f"{self.uid} Injecting failure ({failure_type}) into random trainer"
217+
)
218+
219+
await self.failure_actors.fail.choose(failure_type)
220+
except Exception as e:
221+
error_msg = f"{self.uid} Injected failure: {e}"
222+
logger.error(error_msg)
223+
else:
224+
error_msg = f"{self.uid} No failure actors available"
225+
logger.error(error_msg)
226+
200227

201228
# delay before re-creating proc mesh on existing job. change as needed.
202-
PROC_ATTEMPT_DELAY = 10
229+
PROC_ATTEMPT_DELAY = 0
203230
# proc attempts before getting a new scheduler allocation. change as needed.
204-
PROC_ATTEMPTS = 2
231+
PROC_ATTEMPTS = 3
205232
# attempts before failing training on replica. change as needed.
206-
MAX_ATTEMPT = PROC_ATTEMPTS * 2
233+
MAX_ATTEMPT = PROC_ATTEMPTS * 3
207234

208235

209236
class OrchestrationManager:
@@ -213,26 +240,37 @@ def __init__(self, spec: JobSpec) -> None:
213240
self.lighthouse_actor: LighthouseActor | None = None
214241
self.lighthouse_mesh: ProcMesh | None = None
215242

243+
self.scheduler = MonarchSlurm()
244+
216245
async def start_training(self) -> None:
217246
logger.info(
218247
f"[Controller] Creating training system with {self.spec.replica_count} replicas"
219248
)
220249

221250
for replica_id in range(self.spec.replica_count):
222-
await MonarchSlurm.get_or_create_job(
251+
await self.scheduler.get_or_create_job(
223252
f"replica_{replica_id}", self.spec.hosts_per_replica
224253
)
225254

226255
mesh_futures = {}
227256
for i in range(self.spec.replica_count):
228257
mesh_futures[i] = asyncio.create_task(self._run_replica(i, 0))
229258

259+
failure_future = None
260+
if self.spec.with_failures:
261+
failure_future = asyncio.create_task(
262+
FailureController.execute_failures(self.replicas, self.scheduler)
263+
)
264+
230265
await asyncio.gather(*mesh_futures.values(), return_exceptions=True)
231266

267+
if failure_future:
268+
failure_future.cancel()
269+
232270
async def start_lighthouse(self) -> None:
233271
if self.spec.remote_lighthouse:
234-
await MonarchSlurm.get_or_create_job("lighthouse")
235-
self.lighthouse_mesh = MonarchSlurm.proc_mesh("lighthouse", num_gpus=1)
272+
await self.scheduler.get_or_create_job("lighthouse")
273+
self.lighthouse_mesh = self.scheduler.proc_mesh("lighthouse", num_gpus=1)
236274
else:
237275
self.lighthouse_mesh = this_host().spawn_procs({"gpus": 1})
238276

@@ -274,7 +312,8 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> No
274312
logger.info(
275313
f"[Controller] Replica {replica_id} has failed {attempt_number} times. Getting new allocation."
276314
)
277-
await MonarchSlurm.get_or_create_job(
315+
self.scheduler.kill_job(f"replica_{replica_id}")
316+
await self.scheduler.get_or_create_job(
278317
f"replica_{replica_id}", self.spec.hosts_per_replica
279318
)
280319
delay = 0 if not attempt_number else PROC_ATTEMPT_DELAY
@@ -287,10 +326,7 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> No
287326
await replica_proc_mesh.logging_option(aggregate_window_sec=None)
288327

289328
replica_actor = replica_proc_mesh.spawn(
290-
"replica_actor",
291-
ReplicaActor,
292-
self.spec,
293-
replica_id,
329+
"replica_actor", ReplicaActor, self.spec, replica_id, self.scheduler
294330
)
295331

296332
replica = Replica(replica_id, replica_proc_mesh, replica_actor, attempt_number)
@@ -339,20 +375,25 @@ def parse_args() -> argparse.Namespace:
339375
parser.add_argument(
340376
"--model-config",
341377
type=str,
342-
default=os.path.join(script_dir, "debug_model.toml"),
343-
help=f"Path to model configuration file (default: {os.path.join(script_dir, 'debug_model.toml')})",
378+
default="debug_model.toml",
379+
help=f"Relative path to model configuration file (default: {os.path.join(script_dir, 'debug_model.toml')})",
344380
)
345381
parser.add_argument(
346382
"--dataset-path",
347383
type=str,
348-
default=os.path.join(script_dir, "c4_test"),
349-
help=f"Path to training dataset (default: {os.path.join(script_dir, 'c4_test')})",
384+
default="c4_test",
385+
help=f"Relative path to training dataset (default: {os.path.join(script_dir, 'c4_test')})",
350386
)
351387
parser.add_argument(
352388
"--tokenizer-path",
353389
type=str,
354-
default=os.path.join(script_dir, "tokenizer"),
355-
help=f"Path to tokenizer (default: {os.path.join(script_dir, 'tokenizer')})",
390+
default="debug_tokenizer",
391+
help=f"Relative path to tokenizer (default: {os.path.join(script_dir, 'debug_tokenizer')})",
392+
)
393+
parser.add_argument(
394+
"--with-failures",
395+
action="store_true",
396+
help="Enable the failure injector utility (default: False)",
356397
)
357398

358399
return parser.parse_args()
@@ -362,13 +403,14 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
362403
data_parallel_shard_degree = args.gpu_per_node * args.host_per_replica
363404

364405
output_path = "./outputs"
365-
training_dataset = "c4_test"
406+
training_dataset = args.dataset_path.split("/")[-1]
366407

408+
script_dir = os.path.dirname(os.path.abspath(__file__))
367409
default_args = [
368410
"--job.config_file",
369-
args.model_config,
411+
os.path.join(script_dir, args.model_config),
370412
"--model.tokenizer_path",
371-
args.tokenizer_path,
413+
os.path.join(script_dir, args.tokenizer_path),
372414
"--comm.trace_buf_size",
373415
"0",
374416
"--metrics.log_freq",
@@ -387,7 +429,7 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
387429
"--training.dataset",
388430
training_dataset,
389431
"--training.dataset_path",
390-
args.dataset_path,
432+
os.path.join(script_dir, args.dataset_path),
391433
"--job.dump_folder",
392434
output_path,
393435
"--metrics.enable_tensorboard",
@@ -402,6 +444,7 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
402444
replica_count=args.replica_count,
403445
hosts_per_replica=args.host_per_replica,
404446
gpus_per_node=args.gpu_per_node,
447+
with_failures=args.with_failures,
405448
)
406449

407450

@@ -414,7 +457,6 @@ async def main() -> None:
414457
args = parse_args()
415458
job_spec = make_job_spec(args)
416459

417-
atexit.register(MonarchSlurm.kill_jobs)
418460
orchestrator = OrchestrationManager(job_spec)
419461
try:
420462
await orchestrator.start_lighthouse()

examples/monarch/utils/__init__.py

Whitespace-only changes.

0 commit comments

Comments (0)