from torchtitan.config import ConfigManager, JobConfig
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer
+from utils.failure import Failure, FailureActor, FailureController


# ==== Allocation boilerplate - much of this will be upstreamed into Monarch ====
class MonarchSlurm:
    # Cluster Configuration - update these values for your specific cluster
-    machine: str = "aws_g5.12xlarge"
-    machine_memory: int = 186777
+    machine: str = "gpu.xlarge"
+    machine_memory: int = 2062607

    job_name_prefix: str = "monarch-torchft"

-    job_handles: Dict[str, str] = {}
+    def __init__(self):
+        self.job_handles: Dict[str, str] = {}
+        atexit.register(self.kill_jobs)

-    @classmethod
-    def get_config(cls, mesh_name: str, nodes_per_mesh: int) -> Config:
+    def get_config(self, mesh_name: str, nodes_per_mesh: int) -> Config:
        mesh = [f"{mesh_name}:{nodes_per_mesh}:{MonarchSlurm.machine}"]
-        appdef = hyperactor.host_mesh(meshes=mesh)
+        # to enable relative import of utils on actors
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        env = {"PYTHONPATH": current_dir}
+
+        appdef = hyperactor.host_mesh(meshes=mesh, env=env)

        for role in appdef.roles:
            role.resource.memMB = MonarchSlurm.machine_memory

        return Config(scheduler="slurm", appdef=appdef)

-    @classmethod
-    async def get_or_create_job(cls, mesh_name: str, nodes_per_mesh: int = 1) -> None:
-        config = cls.get_config(mesh_name, nodes_per_mesh)
+    async def get_or_create_job(self, mesh_name: str, nodes_per_mesh: int = 1) -> None:
+        config = self.get_config(mesh_name, nodes_per_mesh)
        job_name = f"{MonarchSlurm.job_name_prefix}-{mesh_name}"
        server_spec = await commands.get_or_create(job_name, config, force_restart=True)
-        cls.job_handles[mesh_name] = server_spec.name
+        self.job_handles[mesh_name] = server_spec.name

-    @classmethod
-    def kill_jobs(cls):
-        for mesh_name, job_handle in cls.job_handles.items():
-            try:
-                logger.info(f"Destroying job for mesh {mesh_name}")
-                commands.kill(f"slurm:///{job_handle}")
-            except Exception as e:
-                logger.warning(f"Failed to destroy job for {mesh_name}: {e}")
+    def kill_jobs(self):
+        for mesh_name in self.job_handles.keys():
+            self.kill_job(mesh_name)
+
+    def kill_job(self, mesh_name: str):
+        try:
+            job_handle = self.job_handles[mesh_name]
+            logger.info(f"Destroying job for mesh {mesh_name}")
+            commands.kill(f"slurm:///{job_handle}")
+        except Exception as e:
+            logger.warning(f"Failed to destroy job for {mesh_name}: {e}")

-    @classmethod
    def proc_mesh(
-        cls,
+        self,
        mesh_name: str,
        num_hosts: int = 1,
        num_gpus: int = 8,
    ) -> ProcMesh:
        allocator = RemoteAllocator(
            world_id=MonarchSlurm.job_name_prefix,
            initializer=TorchXRemoteAllocInitializer(
-                f"slurm:///{cls.job_handles[mesh_name]}"
+                f"slurm:///{self.job_handles[mesh_name]}"
            ),
        )
        alloc = allocator.allocate(
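Not part of the diff: a minimal usage sketch of the instance-based MonarchSlurm scheduler introduced above, assuming an async entry point and the same imports this file already uses. The mesh name "replica_0" is just an example.

import asyncio

async def demo() -> None:
    scheduler = MonarchSlurm()  # __init__ now registers kill_jobs with atexit
    await scheduler.get_or_create_job("replica_0", nodes_per_mesh=1)
    procs = scheduler.proc_mesh("replica_0", num_hosts=1, num_gpus=8)
    # ... spawn actors on `procs` and run work ...
    scheduler.kill_job("replica_0")  # explicit teardown; atexit covers anything left over

asyncio.run(demo())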
@@ -140,6 +147,7 @@ class JobSpec:
    replica_count: int
    hosts_per_replica: int
    gpus_per_node: int
+    with_failures: bool
    lighthouse_address: str = ""


@@ -154,16 +162,15 @@ class Replica:
# This does not currently benefit from being an actor, but will once
# Monarch supervision APIs are fleshed out.
class ReplicaActor(Actor):
-    def __init__(
-        self,
-        spec: JobSpec,
-        replica_id: int,
-    ) -> None:
+    def __init__(self, spec: JobSpec, replica_id: int, scheduler: MonarchSlurm) -> None:
        self.spec = deepcopy(spec)
        self.replica_id = replica_id

        self.uid = f"[replica_{replica_id}]"
        self.spec.job_config.fault_tolerance.replica_id = self.replica_id
+        self.scheduler = scheduler
+
+        self.failure_actors: FailureActor | None = None

    @endpoint
    async def start_replica(self) -> None:
@@ -172,7 +179,7 @@ async def start_replica(self) -> None:

        trainers_proc_mesh: ProcMesh | None = None
        try:
-            trainers_proc_mesh = MonarchSlurm.proc_mesh(
+            trainers_proc_mesh = self.scheduler.proc_mesh(
                f"replica_{self.replica_id}",
                self.spec.hosts_per_replica,
                self.spec.gpus_per_node,
@@ -189,6 +196,10 @@ async def start_replica(self) -> None:
                self.replica_id,
            )

+            self.failure_actors = trainers_proc_mesh.spawn(
+                "failure_actors", FailureActor
+            )
+
            logger.info(f"{self.uid} Starting trainers")
            await training_actors.start_training.call(self.spec.lighthouse_address)
            await trainers_proc_mesh.stop()
@@ -197,13 +208,29 @@ async def start_replica(self) -> None:
            await trainers_proc_mesh.stop()
            raise e

+    @endpoint
+    async def inject_failure(self, failure_type: Failure):
+        if self.failure_actors:
+            try:
+                logger.info(
+                    f"{self.uid} Injecting failure ({failure_type}) into random trainer"
+                )
+
+                await self.failure_actors.fail.choose(failure_type)
+            except Exception as e:
+                error_msg = f"{self.uid} Injected failure: {e}"
+                logger.error(error_msg)
+        else:
+            error_msg = f"{self.uid} No failure actors available"
+            logger.error(error_msg)
+

# delay before re-creating proc mesh on existing job. change as needed.
-PROC_ATTEMPT_DELAY = 10
+PROC_ATTEMPT_DELAY = 0
# proc attempts before getting a new scheduler allocation. change as needed.
-PROC_ATTEMPTS = 2
+PROC_ATTEMPTS = 3
# attempts before failing training on replica. change as needed.
-MAX_ATTEMPT = PROC_ATTEMPTS * 2
+MAX_ATTEMPT = PROC_ATTEMPTS * 3


class OrchestrationManager:
@@ -213,26 +240,37 @@ def __init__(self, spec: JobSpec) -> None:
        self.lighthouse_actor: LighthouseActor | None = None
        self.lighthouse_mesh: ProcMesh | None = None

+        self.scheduler = MonarchSlurm()
+
    async def start_training(self) -> None:
        logger.info(
            f"[Controller] Creating training system with {self.spec.replica_count} replicas"
        )

        for replica_id in range(self.spec.replica_count):
-            await MonarchSlurm.get_or_create_job(
+            await self.scheduler.get_or_create_job(
                f"replica_{replica_id}", self.spec.hosts_per_replica
            )

        mesh_futures = {}
        for i in range(self.spec.replica_count):
            mesh_futures[i] = asyncio.create_task(self._run_replica(i, 0))

+        failure_future = None
+        if self.spec.with_failures:
+            failure_future = asyncio.create_task(
+                FailureController.execute_failures(self.replicas, self.scheduler)
+            )
+
        await asyncio.gather(*mesh_futures.values(), return_exceptions=True)

+        if failure_future:
+            failure_future.cancel()
+
    async def start_lighthouse(self) -> None:
        if self.spec.remote_lighthouse:
-            await MonarchSlurm.get_or_create_job("lighthouse")
-            self.lighthouse_mesh = MonarchSlurm.proc_mesh("lighthouse", num_gpus=1)
+            await self.scheduler.get_or_create_job("lighthouse")
+            self.lighthouse_mesh = self.scheduler.proc_mesh("lighthouse", num_gpus=1)
        else:
            self.lighthouse_mesh = this_host().spawn_procs({"gpus": 1})

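The utils/failure.py module this PR imports is not included in these hunks. For orientation only, the calls above assume roughly the interface sketched below: a FailureActor with a fail endpoint (invoked via .choose() on one random trainer) and a FailureController.execute_failures(replicas, scheduler) coroutine that drives injection. Everything here is an assumption, not the real implementation: the enum members, the cadence, the kill mechanics, the monarch import path, and the replica field names are all placeholders.

import asyncio
import os
import random
import signal
from enum import Enum

from monarch.actor import Actor, endpoint  # assumed import path


class Failure(Enum):
    KILL_PROCESS = "kill_process"  # placeholder member; the real enum may differ


class FailureActor(Actor):
    @endpoint
    async def fail(self, failure_type: Failure) -> None:
        # Deliberately crash this trainer process; the caller expects the
        # endpoint call itself to raise once the process goes away.
        os.kill(os.getpid(), signal.SIGKILL)


class FailureController:
    @staticmethod
    async def execute_failures(replicas, scheduler) -> None:
        # Periodically pick a replica and ask it to inject a failure.
        # `scheduler` could be used for node-level failures; unused in this sketch.
        while True:
            await asyncio.sleep(60)  # placeholder cadence
            replica = random.choice(list(replicas.values()))  # assumes a dict of Replica
            await replica.actor.inject_failure.call(Failure.KILL_PROCESS)  # field name assumed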
@@ -274,7 +312,8 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> No
            logger.info(
                f"[Controller] Replica {replica_id} has failed {attempt_number} times. Getting new allocation."
            )
-            await MonarchSlurm.get_or_create_job(
+            self.scheduler.kill_job(f"replica_{replica_id}")
+            await self.scheduler.get_or_create_job(
                f"replica_{replica_id}", self.spec.hosts_per_replica
            )
        delay = 0 if not attempt_number else PROC_ATTEMPT_DELAY
@@ -287,10 +326,7 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> No
        await replica_proc_mesh.logging_option(aggregate_window_sec=None)

        replica_actor = replica_proc_mesh.spawn(
-            "replica_actor",
-            ReplicaActor,
-            self.spec,
-            replica_id,
+            "replica_actor", ReplicaActor, self.spec, replica_id, self.scheduler
        )

        replica = Replica(replica_id, replica_proc_mesh, replica_actor, attempt_number)
@@ -339,20 +375,25 @@ def parse_args() -> argparse.Namespace:
    parser.add_argument(
        "--model-config",
        type=str,
-        default=os.path.join(script_dir, "debug_model.toml"),
-        help=f"Path to model configuration file (default: {os.path.join(script_dir, 'debug_model.toml')})",
+        default="debug_model.toml",
+        help=f"Relative path to model configuration file (default: {os.path.join(script_dir, 'debug_model.toml')})",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
-        default=os.path.join(script_dir, "c4_test"),
-        help=f"Path to training dataset (default: {os.path.join(script_dir, 'c4_test')})",
+        default="c4_test",
+        help=f"Relative path to training dataset (default: {os.path.join(script_dir, 'c4_test')})",
    )
    parser.add_argument(
        "--tokenizer-path",
        type=str,
-        default=os.path.join(script_dir, "tokenizer"),
-        help=f"Path to tokenizer (default: {os.path.join(script_dir, 'tokenizer')})",
+        default="debug_tokenizer",
+        help=f"Relative path to tokenizer (default: {os.path.join(script_dir, 'debug_tokenizer')})",
+    )
+    parser.add_argument(
+        "--with-failures",
+        action="store_true",
+        help="Enable the failure injector utility (default: False)",
    )

    return parser.parse_args()
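Not part of the diff: a quick sanity check of the new flag and the relative-path defaults, reusing parse_args() above. It assumes the remaining arguments (replica count, hosts, GPUs per node), which this hunk does not show, all have defaults.

import sys

sys.argv = [sys.argv[0], "--with-failures"]  # simulate a command line
args = parse_args()
assert args.with_failures is True
assert args.model_config == "debug_model.toml"  # relative; joined with script_dir in make_job_spec
assert args.tokenizer_path == "debug_tokenizer"
assert args.dataset_path == "c4_test"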
@@ -362,13 +403,14 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
    data_parallel_shard_degree = args.gpu_per_node * args.host_per_replica

    output_path = "./outputs"
-    training_dataset = "c4_test"
+    training_dataset = args.dataset_path.split("/")[-1]

+    script_dir = os.path.dirname(os.path.abspath(__file__))
    default_args = [
        "--job.config_file",
-        args.model_config,
+        os.path.join(script_dir, args.model_config),
        "--model.tokenizer_path",
-        args.tokenizer_path,
+        os.path.join(script_dir, args.tokenizer_path),
        "--comm.trace_buf_size",
        "0",
        "--metrics.log_freq",
@@ -387,7 +429,7 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
        "--training.dataset",
        training_dataset,
        "--training.dataset_path",
-        args.dataset_path,
+        os.path.join(script_dir, args.dataset_path),
        "--job.dump_folder",
        output_path,
        "--metrics.enable_tensorboard",
@@ -402,6 +444,7 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
        replica_count=args.replica_count,
        hosts_per_replica=args.host_per_replica,
        gpus_per_node=args.gpu_per_node,
+        with_failures=args.with_failures,
    )


@@ -414,7 +457,6 @@ async def main() -> None:
    args = parse_args()
    job_spec = make_job_spec(args)

-    atexit.register(MonarchSlurm.kill_jobs)
    orchestrator = OrchestrationManager(job_spec)
    try:
        await orchestrator.start_lighthouse()
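The explicit atexit.register(MonarchSlurm.kill_jobs) is dropped from main() because the scheduler instance now registers its own hook in __init__ (first hunk). A standalone sketch of that pattern, using only the standard library and hypothetical names:

import atexit

class Scheduler:
    def __init__(self) -> None:
        self.handles: dict[str, str] = {}
        atexit.register(self.cleanup)  # runs at normal interpreter shutdown

    def cleanup(self) -> None:
        for name in self.handles:
            print(f"tearing down {name}")

scheduler = Scheduler()  # cleanup() fires automatically when the process exits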