
Commit d596ec7

amirafzali authored and facebook-github-bot committed
add failure injector for monarch training script (#270)
Summary: This introduces a failure injector with 5 failure modes:

- SEGFAULT: Triggers a SIGSEGV on the process
- DEADLOCK: Deadlocks the GIL, resulting in ProcessGroupNCCL timeout and terminal failure
- KILL_PROC: Immediately kills the process with a non-zero exit code
- COMMS: Forcefully aborts the ProcessGroup and NCCL communicator
- KILL_SLURM: Kills a random replica's SLURM job

It can be enabled with the flag `--with-failures`, and it runs asynchronously every 120 seconds.

Pull Request resolved: #270

Reviewed By: tushar00jain

Differential Revision: D83601242

Pulled By: amirafzali

fbshipit-source-id: b26a7b6349a9d46c2a4331a70b8b65cdcc600a35
1 parent 302fd39 commit d596ec7
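
The new `utils/failure.py` module itself is not among the hunks shown below; only its import appears (`from utils.failure import Failure, FailureActor, FailureController`). As a rough orientation, here is a minimal sketch of the shape that import and the summary imply: an enum with the five modes plus an illustrative `trigger` helper. The `trigger` name and the exact trigger mechanics are assumptions for illustration, not the actual implementation from this PR.

```python
# Hypothetical sketch only: utils/failure.py is not part of the diff shown on this page.
# The enum members mirror the five modes listed in the commit summary.
import os
import signal
from enum import Enum, auto


class Failure(Enum):
    SEGFAULT = auto()    # trigger a SIGSEGV in the trainer process
    DEADLOCK = auto()    # hold the GIL so ProcessGroupNCCL times out
    KILL_PROC = auto()   # kill the process immediately with a non-zero exit code
    COMMS = auto()       # forcefully abort the ProcessGroup and NCCL communicator
    KILL_SLURM = auto()  # kill a random replica's SLURM job (handled by the controller)


def trigger(failure: Failure) -> None:
    """Illustrative helper (assumed, not from the PR): raise the failure in-process."""
    if failure is Failure.SEGFAULT:
        os.kill(os.getpid(), signal.SIGSEGV)
    elif failure is Failure.KILL_PROC:
        os._exit(1)
    elif failure is Failure.DEADLOCK:
        # A real GIL deadlock needs a native-level hold on the interpreter; omitted here.
        ...
    elif failure is Failure.COMMS:
        # Aborting NCCL requires a handle on the live ProcessGroup; omitted here.
        ...
    elif failure is Failure.KILL_SLURM:
        # Not process-local: the orchestration layer kills the replica's SLURM job.
        ...
```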

File tree: 5 files changed, +237, -58 lines


examples/monarch/README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -35,6 +35,7 @@ You can also override the resource configuration manually:
 - TrainingActor: Individual trainer processes
 - ReplicaActor: Manages groups of trainers
 - OrchestrationManager: Top-level orchestration and failure recovery
+- FailureController: Optional, periodically injects random failures into trainer processes

 ##### FAILURE RECOVERY
 - Automatic retry with configurable delays (PER_ATTEMPT_DELAY)
```

examples/monarch/__init__.py

Whitespace-only changes.

examples/monarch/train_distributed.py

Lines changed: 100 additions & 58 deletions
```diff
@@ -24,54 +24,61 @@
 from torchtitan.config import ConfigManager, JobConfig
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.train import Trainer
+from utils.failure import Failure, FailureActor, FailureController


 # ==== Allocation boilerplate - much of this will be upstreamed into Monarch ====
 class MonarchSlurm:
     # Cluster Configuration - update these values for your specific cluster
-    machine: str = "aws_g5.12xlarge"
-    machine_memory: int = 186777
+    machine: str = "gpu.xlarge"
+    machine_memory: int = 2062607
     job_name_prefix: str = "monarch-torchft"

-    job_handles: Dict[str, str] = {}
+    def __init__(self):
+        self.job_handles: Dict[str, str] = {}
+        atexit.register(self.kill_jobs)

-    @classmethod
-    def get_config(cls, mesh_name: str, nodes_per_mesh: int) -> Config:
+    def get_config(self, mesh_name: str, nodes_per_mesh: int) -> Config:
         mesh = [f"{mesh_name}:{nodes_per_mesh}:{MonarchSlurm.machine}"]
-        appdef = hyperactor.host_mesh(meshes=mesh)
+        # to enable relative import of utils on actors
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        env = {"PYTHONPATH": current_dir}
+
+        appdef = hyperactor.host_mesh(meshes=mesh, env=env)

         for role in appdef.roles:
             role.resource.memMB = MonarchSlurm.machine_memory

         return Config(scheduler="slurm", appdef=appdef)

-    @classmethod
-    async def get_or_create_job(cls, mesh_name: str, nodes_per_mesh: int = 1) -> None:
-        config = cls.get_config(mesh_name, nodes_per_mesh)
+    async def get_or_create_job(self, mesh_name: str, nodes_per_mesh: int = 1) -> None:
+        config = self.get_config(mesh_name, nodes_per_mesh)
         job_name = f"{MonarchSlurm.job_name_prefix}-{mesh_name}"
         server_spec = await commands.get_or_create(job_name, config, force_restart=True)
-        cls.job_handles[mesh_name] = server_spec.name
+        self.job_handles[mesh_name] = server_spec.name

-    @classmethod
-    def kill_jobs(cls):
-        for mesh_name, job_handle in cls.job_handles.items():
-            try:
-                logger.info(f"Destroying job for mesh {mesh_name}")
-                commands.kill(f"slurm:///{job_handle}")
-            except Exception as e:
-                logger.warning(f"Failed to destroy job for {mesh_name}: {e}")
+    def kill_jobs(self):
+        for mesh_name in self.job_handles.keys():
+            self.kill_job(mesh_name)
+
+    def kill_job(self, mesh_name: str):
+        try:
+            job_handle = self.job_handles[mesh_name]
+            logger.info(f"Destroying job for mesh {mesh_name}")
+            commands.kill(f"slurm:///{job_handle}")
+        except Exception as e:
+            logger.warning(f"Failed to destroy job for {mesh_name}: {e}")

-    @classmethod
     def proc_mesh(
-        cls,
+        self,
         mesh_name: str,
         num_hosts: int = 1,
         num_gpus: int = 8,
     ) -> ProcMesh:
         allocator = RemoteAllocator(
             world_id=MonarchSlurm.job_name_prefix,
             initializer=TorchXRemoteAllocInitializer(
-                f"slurm:///{cls.job_handles[mesh_name]}"
+                f"slurm:///{self.job_handles[mesh_name]}"
             ),
         )
         alloc = allocator.allocate(
@@ -94,7 +101,7 @@ def start_lighthouse(self) -> str:
         from torchft.coordination import LighthouseServer

         self.lighthouse = LighthouseServer(
-            bind="[::]:0", min_replicas=1, join_timeout_ms=10000
+            bind="[::]:0", min_replicas=1, join_timeout_ms=60000
         )
         return self.lighthouse.address()

@@ -140,6 +147,7 @@ class JobSpec:
     replica_count: int
     hosts_per_replica: int
     gpus_per_node: int
+    with_failures: bool
     lighthouse_address: str = ""


@@ -154,16 +162,15 @@ class Replica:
 # This does not currently benefit from being an actor, but will once
 # Monarch supervision APIs are fleshed out.
 class ReplicaActor(Actor):
-    def __init__(
-        self,
-        spec: JobSpec,
-        replica_id: int,
-    ) -> None:
+    def __init__(self, spec: JobSpec, replica_id: int, scheduler: MonarchSlurm) -> None:
         self.spec = deepcopy(spec)
         self.replica_id = replica_id

         self.uid = f"[replica_{replica_id}]"
         self.spec.job_config.fault_tolerance.replica_id = self.replica_id
+        self.scheduler = scheduler
+
+        self.failure_actors: FailureActor | None = None

     @endpoint
     async def start_replica(self) -> None:
@@ -172,14 +179,12 @@ async def start_replica(self) -> None:

         trainers_proc_mesh: ProcMesh | None = None
         try:
-            trainers_proc_mesh = MonarchSlurm.proc_mesh(
+            trainers_proc_mesh = self.scheduler.proc_mesh(
                 f"replica_{self.replica_id}",
                 self.spec.hosts_per_replica,
                 self.spec.gpus_per_node,
             )
-            await trainers_proc_mesh.logging_option(
-                stream_to_client=True, aggregate_window_sec=None
-            )
+            await trainers_proc_mesh.logging_option(stream_to_client=True)
             await setup_env_for_distributed(trainers_proc_mesh)

             training_actors = trainers_proc_mesh.spawn(
@@ -189,6 +194,10 @@ async def start_replica(self) -> None:
                 self.replica_id,
             )

+            self.failure_actors = trainers_proc_mesh.spawn(
+                "failure_actors", FailureActor
+            )
+
             logger.info(f"{self.uid} Starting trainers")
             await training_actors.start_training.call(self.spec.lighthouse_address)
             await trainers_proc_mesh.stop()
@@ -197,13 +206,29 @@ async def start_replica(self) -> None:
             await trainers_proc_mesh.stop()
             raise e

+    @endpoint
+    async def inject_failure(self, failure_type: Failure):
+        if self.failure_actors:
+            try:
+                logger.info(
+                    f"{self.uid} Injecting failure ({failure_type}) into random trainer"
+                )
+
+                await self.failure_actors.fail.choose(failure_type)
+            except Exception as e:
+                error_msg = f"{self.uid} Injected failure: {e}"
+                logger.error(error_msg)
+        else:
+            error_msg = f"{self.uid} No failure actors available"
+            logger.error(error_msg)
+

 # delay before re-creating proc mesh on existing job. change as needed.
-PROC_ATTEMPT_DELAY = 10
+PROC_ATTEMPT_DELAY = 0
 # proc attempts before getting a new scheduler allocation. change as needed.
-PROC_ATTEMPTS = 2
+PROC_ATTEMPTS = 4
 # attempts before failing training on replica. change as needed.
-MAX_ATTEMPT = PROC_ATTEMPTS * 2
+MAX_ATTEMPT = PROC_ATTEMPTS * 4


 class OrchestrationManager:
@@ -213,32 +238,41 @@ def __init__(self, spec: JobSpec) -> None:
         self.lighthouse_actor: LighthouseActor | None = None
         self.lighthouse_mesh: ProcMesh | None = None

+        self.scheduler = MonarchSlurm()
+
     async def start_training(self) -> None:
         logger.info(
             f"[Controller] Creating training system with {self.spec.replica_count} replicas"
         )

         for replica_id in range(self.spec.replica_count):
-            await MonarchSlurm.get_or_create_job(
+            await self.scheduler.get_or_create_job(
                 f"replica_{replica_id}", self.spec.hosts_per_replica
             )

         mesh_futures = {}
         for i in range(self.spec.replica_count):
             mesh_futures[i] = asyncio.create_task(self._run_replica(i, 0))

+        failure_future = None
+        if self.spec.with_failures:
+            failure_future = asyncio.create_task(
+                FailureController.execute_failures(self.replicas, self.scheduler)
+            )
+
         await asyncio.gather(*mesh_futures.values(), return_exceptions=True)

+        if failure_future:
+            failure_future.cancel()
+
     async def start_lighthouse(self) -> None:
         if self.spec.remote_lighthouse:
-            await MonarchSlurm.get_or_create_job("lighthouse")
-            self.lighthouse_mesh = MonarchSlurm.proc_mesh("lighthouse", num_gpus=1)
+            await self.scheduler.get_or_create_job("lighthouse")
+            self.lighthouse_mesh = self.scheduler.proc_mesh("lighthouse", num_gpus=1)
         else:
             self.lighthouse_mesh = this_host().spawn_procs({"gpus": 1})

-        await self.lighthouse_mesh.logging_option(
-            stream_to_client=True, aggregate_window_sec=None
-        )
+        await self.lighthouse_mesh.logging_option(stream_to_client=True)
         self.lighthouse_actor = self.lighthouse_mesh.spawn(
             "lighthouse_actor", LighthouseActor
         )
@@ -274,7 +308,8 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> No
             logger.info(
                 f"[Controller] Replica {replica_id} has failed {attempt_number} times. Getting new allocation."
             )
-            await MonarchSlurm.get_or_create_job(
+            self.scheduler.kill_job(f"replica_{replica_id}")
+            await self.scheduler.get_or_create_job(
                 f"replica_{replica_id}", self.spec.hosts_per_replica
             )
         delay = 0 if not attempt_number else PROC_ATTEMPT_DELAY
@@ -287,10 +322,7 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> No
         await replica_proc_mesh.logging_option(aggregate_window_sec=None)

         replica_actor = replica_proc_mesh.spawn(
-            "replica_actor",
-            ReplicaActor,
-            self.spec,
-            replica_id,
+            "replica_actor", ReplicaActor, self.spec, replica_id, self.scheduler
         )

         replica = Replica(replica_id, replica_proc_mesh, replica_actor, attempt_number)
@@ -301,8 +333,8 @@ async def _teardown(self, replica_id: int) -> None:
         try:
             replica = self.replicas[replica_id]
             await replica.proc_mesh.stop()
-            del replica.proc_mesh
             del self.replicas[replica_id]
+            del replica.proc_mesh
         except Exception as e:
             logger.error(f"[Controller] Failed to _teardown replica {replica_id}: {e}")

@@ -339,20 +371,25 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--model-config",
         type=str,
-        default=os.path.join(script_dir, "debug_model.toml"),
-        help=f"Path to model configuration file (default: {os.path.join(script_dir, 'debug_model.toml')})",
+        default="debug_model.toml",
+        help=f"Relative path to model configuration file (default: {os.path.join(script_dir, 'debug_model.toml')})",
     )
     parser.add_argument(
         "--dataset-path",
         type=str,
-        default=os.path.join(script_dir, "c4_test"),
-        help=f"Path to training dataset (default: {os.path.join(script_dir, 'c4_test')})",
+        default="c4_test",
+        help=f"Relative path to training dataset (default: {os.path.join(script_dir, 'c4_test')})",
     )
     parser.add_argument(
         "--tokenizer-path",
         type=str,
-        default=os.path.join(script_dir, "tokenizer"),
-        help=f"Path to tokenizer (default: {os.path.join(script_dir, 'tokenizer')})",
+        default="debug_tokenizer",
+        help=f"Relative path to tokenizer (default: {os.path.join(script_dir, 'debug_tokenizer')})",
+    )
+    parser.add_argument(
+        "--with-failures",
+        action="store_true",
+        help="Enable the failure injector utility (default: False)",
     )

     return parser.parse_args()
@@ -362,32 +399,37 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
     data_parallel_shard_degree = args.gpu_per_node * args.host_per_replica

     output_path = "./outputs"
-    training_dataset = "c4_test"
+    training_dataset = args.dataset_path.split("/")[-1]

+    script_dir = os.path.dirname(os.path.abspath(__file__))
     default_args = [
         "--job.config_file",
-        args.model_config,
+        os.path.join(script_dir, args.model_config),
         "--model.tokenizer_path",
-        args.tokenizer_path,
+        os.path.join(script_dir, args.tokenizer_path),
         "--comm.trace_buf_size",
         "0",
         "--metrics.log_freq",
         "1",
         "--fault_tolerance.enable",
         "--fault_tolerance.group_size",
         str(args.replica_count),
+        "--fault_tolerance.process_group",
+        "nccl",
+        "--fault_tolerance.process_group_timeout_ms",
+        "60000",
         "--parallelism.data_parallel_shard_degree",
         str(data_parallel_shard_degree),
         "--activation_checkpoint.mode",
         "full",
         "--comm.train_timeout_seconds",
-        "60",
+        "300",
         "--training.steps",
         str(args.training_steps),
         "--training.dataset",
         training_dataset,
         "--training.dataset_path",
-        args.dataset_path,
+        os.path.join(script_dir, args.dataset_path),
         "--job.dump_folder",
         output_path,
         "--metrics.enable_tensorboard",
@@ -402,6 +444,7 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
         replica_count=args.replica_count,
         hosts_per_replica=args.host_per_replica,
         gpus_per_node=args.gpu_per_node,
+        with_failures=args.with_failures,
     )


@@ -414,7 +457,6 @@ async def main() -> None:
     args = parse_args()
     job_spec = make_job_spec(args)

-    atexit.register(MonarchSlurm.kill_jobs)
     orchestrator = OrchestrationManager(job_spec)
     try:
         await orchestrator.start_lighthouse()
```

examples/monarch/utils/__init__.py

Whitespace-only changes.
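
Tying the pieces above together: `OrchestrationManager.start_training` launches `FailureController.execute_failures(self.replicas, self.scheduler)` as an asyncio task, each `ReplicaActor` exposes an `inject_failure` endpoint that forwards the chosen mode to one of its `FailureActor`s, and the summary states the injector fires every 120 seconds. Below is a rough sketch of such a loop, assuming the `Failure` enum from the earlier sketch is available; `inject` and `kill_slurm_job` are hypothetical stand-ins for the real replica endpoint call and `MonarchSlurm.kill_job`.

```python
# Rough sketch of the periodic failure loop described in the summary; not the
# actual FailureController implementation from this PR.
import asyncio
import random

FAILURE_INTERVAL_SECONDS = 120  # cadence stated in the commit summary


async def execute_failures(replicas, scheduler, inject, kill_slurm_job):
    """Every interval, pick a random replica and a random failure mode."""
    while True:  # cancelled from outside, mirroring failure_future.cancel() above
        await asyncio.sleep(FAILURE_INTERVAL_SECONDS)
        if not replicas:
            continue
        replica_id = random.choice(list(replicas.keys()))
        failure = random.choice(list(Failure))
        if failure is Failure.KILL_SLURM:
            # KILL_SLURM targets the replica's whole SLURM allocation via the scheduler.
            kill_slurm_job(scheduler, f"replica_{replica_id}")
        else:
            # Other modes are forwarded to a random trainer through the replica actor.
            await inject(replicas[replica_id], failure)
```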
