from torchtitan.config import ConfigManager, JobConfig
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer
+from utils.failure import Failure, FailureActor, FailureController


# ==== Allocation boilerplate - much of this will be upstreamed into Monarch ====
class MonarchSlurm:
    # Cluster Configuration - update these values for your specific cluster
-    machine: str = "aws_g5.12xlarge"
-    machine_memory: int = 186777
+    machine: str = "gpu.xlarge"
+    machine_memory: int = 2062607

    job_name_prefix: str = "monarch-torchft"

-    job_handles: Dict[str, str] = {}
+    def __init__(self):
+        self.job_handles: Dict[str, str] = {}
+        atexit.register(self.kill_jobs)

-    @classmethod
-    def get_config(cls, mesh_name: str, nodes_per_mesh: int) -> Config:
+    def get_config(self, mesh_name: str, nodes_per_mesh: int) -> Config:
        mesh = [f"{mesh_name}:{nodes_per_mesh}:{MonarchSlurm.machine}"]
-        appdef = hyperactor.host_mesh(meshes=mesh)
+        # to enable relative import of utils on actors
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        env = {"PYTHONPATH": current_dir}
+
+        appdef = hyperactor.host_mesh(meshes=mesh, env=env)

        for role in appdef.roles:
            role.resource.memMB = MonarchSlurm.machine_memory

        return Config(scheduler="slurm", appdef=appdef)

-    @classmethod
-    async def get_or_create_job(cls, mesh_name: str, nodes_per_mesh: int = 1) -> None:
-        config = cls.get_config(mesh_name, nodes_per_mesh)
+    async def get_or_create_job(self, mesh_name: str, nodes_per_mesh: int = 1) -> None:
+        config = self.get_config(mesh_name, nodes_per_mesh)
        job_name = f"{MonarchSlurm.job_name_prefix}-{mesh_name}"
        server_spec = await commands.get_or_create(job_name, config, force_restart=True)
-        cls.job_handles[mesh_name] = server_spec.name
+        self.job_handles[mesh_name] = server_spec.name

-    @classmethod
-    def kill_jobs(cls):
-        for mesh_name, job_handle in cls.job_handles.items():
-            try:
-                logger.info(f"Destroying job for mesh {mesh_name}")
-                commands.kill(f"slurm:///{job_handle}")
-            except Exception as e:
-                logger.warning(f"Failed to destroy job for {mesh_name}: {e}")
+    def kill_jobs(self):
+        for mesh_name in self.job_handles.keys():
+            self.kill_job(mesh_name)
+
+    def kill_job(self, mesh_name: str):
+        try:
+            job_handle = self.job_handles[mesh_name]
+            logger.info(f"Destroying job for mesh {mesh_name}")
+            commands.kill(f"slurm:///{job_handle}")
+        except Exception as e:
+            logger.warning(f"Failed to destroy job for {mesh_name}: {e}")

-    @classmethod
    def proc_mesh(
-        cls,
+        self,
        mesh_name: str,
        num_hosts: int = 1,
        num_gpus: int = 8,
    ) -> ProcMesh:
        allocator = RemoteAllocator(
            world_id=MonarchSlurm.job_name_prefix,
            initializer=TorchXRemoteAllocInitializer(
-                f"slurm:///{cls.job_handles[mesh_name]}"
+                f"slurm:///{self.job_handles[mesh_name]}"
            ),
        )
        alloc = allocator.allocate(
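
For orientation, here is a minimal usage sketch of the now instance-based MonarchSlurm helper. It relies only on the methods shown in the hunk above; the mesh name "replica_0" and the two-host sizing are illustrative, not part of the commit.

import asyncio


async def demo() -> None:
    scheduler = MonarchSlurm()

    # Reserve a two-node Slurm allocation for this mesh (get_or_create uses
    # force_restart=True, so any stale job with the same name is replaced),
    # then carve a proc mesh out of it: 2 hosts x 8 GPUs = 16 processes.
    await scheduler.get_or_create_job("replica_0", nodes_per_mesh=2)
    mesh = scheduler.proc_mesh("replica_0", num_hosts=2, num_gpus=8)

    # Actors would be spawned onto `mesh` here, as ReplicaActor does below.

    await mesh.stop()
    # No explicit job cleanup is needed: the atexit hook registered in
    # __init__ calls kill_jobs() when the process exits.


asyncio.run(demo())
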
@@ -94,7 +101,7 @@ def start_lighthouse(self) -> str:
        from torchft.coordination import LighthouseServer

        self.lighthouse = LighthouseServer(
-            bind="[::]:0", min_replicas=1, join_timeout_ms=10000
+            bind="[::]:0", min_replicas=1, join_timeout_ms=60000
        )
        return self.lighthouse.address()

@@ -140,6 +147,7 @@ class JobSpec:
    replica_count: int
    hosts_per_replica: int
    gpus_per_node: int
+    with_failures: bool
    lighthouse_address: str = ""


@@ -154,16 +162,15 @@ class Replica:
# This does not currently benefit from being an actor, but will once
# Monarch supervision APIs are fleshed out.
class ReplicaActor(Actor):
-    def __init__(
-        self,
-        spec: JobSpec,
-        replica_id: int,
-    ) -> None:
+    def __init__(self, spec: JobSpec, replica_id: int, scheduler: MonarchSlurm) -> None:
        self.spec = deepcopy(spec)
        self.replica_id = replica_id

        self.uid = f"[replica_{replica_id}]"
        self.spec.job_config.fault_tolerance.replica_id = self.replica_id
+        self.scheduler = scheduler
+
+        self.failure_actors: FailureActor | None = None

    @endpoint
    async def start_replica(self) -> None:
@@ -172,14 +179,12 @@ async def start_replica(self) -> None:

        trainers_proc_mesh: ProcMesh | None = None
        try:
-            trainers_proc_mesh = MonarchSlurm.proc_mesh(
+            trainers_proc_mesh = self.scheduler.proc_mesh(
                f"replica_{self.replica_id}",
                self.spec.hosts_per_replica,
                self.spec.gpus_per_node,
            )
-            await trainers_proc_mesh.logging_option(
-                stream_to_client=True, aggregate_window_sec=None
-            )
+            await trainers_proc_mesh.logging_option(stream_to_client=True)
            await setup_env_for_distributed(trainers_proc_mesh)

            training_actors = trainers_proc_mesh.spawn(
@@ -189,6 +194,10 @@ async def start_replica(self) -> None:
                self.replica_id,
            )

+            self.failure_actors = trainers_proc_mesh.spawn(
+                "failure_actors", FailureActor
+            )
+
            logger.info(f"{self.uid} Starting trainers")
            await training_actors.start_training.call(self.spec.lighthouse_address)
            await trainers_proc_mesh.stop()
@@ -197,13 +206,29 @@ async def start_replica(self) -> None:
            await trainers_proc_mesh.stop()
            raise e

+    @endpoint
+    async def inject_failure(self, failure_type: Failure):
+        if self.failure_actors:
+            try:
+                logger.info(
+                    f"{self.uid} Injecting failure ({failure_type}) into random trainer"
+                )
+
+                await self.failure_actors.fail.choose(failure_type)
+            except Exception as e:
+                error_msg = f"{self.uid} Injected failure: {e}"
+                logger.error(error_msg)
+        else:
+            error_msg = f"{self.uid} No failure actors available"
+            logger.error(error_msg)
+

# delay before re-creating proc mesh on existing job. change as needed.
-PROC_ATTEMPT_DELAY = 10
+PROC_ATTEMPT_DELAY = 0
# proc attempts before getting a new scheduler allocation. change as needed.
-PROC_ATTEMPTS = 2
+PROC_ATTEMPTS = 4
# attempts before failing training on replica. change as needed.
-MAX_ATTEMPT = PROC_ATTEMPTS * 2
+MAX_ATTEMPT = PROC_ATTEMPTS * 4


class OrchestrationManager:
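
Two notes on the hunk above. The retry budget is now PROC_ATTEMPTS = 4 proc-mesh attempts per Slurm allocation and MAX_ATTEMPT = 4 * 4 = 16 attempts before training on a replica is abandoned, with no delay between attempts. The new inject_failure endpoint uses .choose() to route the failure to a single randomly picked trainer process. The FailureActor it calls lives in the tutorial's utils/failure.py, which is not part of this diff; purely to illustrate the shape of such an actor, a minimal sketch could look like the following (the import path, the Failure values, and the failure behaviors are assumptions, not the tutorial's actual code).

import os
import signal
from enum import Enum

# Assumed import path for Monarch's actor primitives, matching the
# Actor/endpoint usage elsewhere in this file.
from monarch.actor import Actor, endpoint


class Failure(Enum):
    # Hypothetical failure kinds; the real enum in utils/failure.py may differ.
    EXCEPTION = "exception"
    PROCESS_KILL = "process_kill"


class FailureActor(Actor):
    """One instance per trainer process; `fail.choose(...)` hits one at random."""

    @endpoint
    async def fail(self, failure_type: Failure) -> None:
        if failure_type == Failure.PROCESS_KILL:
            # Simulate a hard crash of the trainer process.
            os.kill(os.getpid(), signal.SIGKILL)
        # Simulate a recoverable training error.
        raise RuntimeError(f"injected failure: {failure_type}")
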
@@ -213,32 +238,41 @@ def __init__(self, spec: JobSpec) -> None:
        self.lighthouse_actor: LighthouseActor | None = None
        self.lighthouse_mesh: ProcMesh | None = None

+        self.scheduler = MonarchSlurm()
+
    async def start_training(self) -> None:
        logger.info(
            f"[Controller] Creating training system with {self.spec.replica_count} replicas"
        )

        for replica_id in range(self.spec.replica_count):
-            await MonarchSlurm.get_or_create_job(
+            await self.scheduler.get_or_create_job(
                f"replica_{replica_id}", self.spec.hosts_per_replica
            )

        mesh_futures = {}
        for i in range(self.spec.replica_count):
            mesh_futures[i] = asyncio.create_task(self._run_replica(i, 0))

+        failure_future = None
+        if self.spec.with_failures:
+            failure_future = asyncio.create_task(
+                FailureController.execute_failures(self.replicas, self.scheduler)
+            )
+
        await asyncio.gather(*mesh_futures.values(), return_exceptions=True)

+        if failure_future:
+            failure_future.cancel()
+
    async def start_lighthouse(self) -> None:
        if self.spec.remote_lighthouse:
-            await MonarchSlurm.get_or_create_job("lighthouse")
-            self.lighthouse_mesh = MonarchSlurm.proc_mesh("lighthouse", num_gpus=1)
+            await self.scheduler.get_or_create_job("lighthouse")
+            self.lighthouse_mesh = self.scheduler.proc_mesh("lighthouse", num_gpus=1)
        else:
            self.lighthouse_mesh = this_host().spawn_procs({"gpus": 1})

-        await self.lighthouse_mesh.logging_option(
-            stream_to_client=True, aggregate_window_sec=None
-        )
+        await self.lighthouse_mesh.logging_option(stream_to_client=True)
        self.lighthouse_actor = self.lighthouse_mesh.spawn(
            "lighthouse_actor", LighthouseActor
        )
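
The controller-side half of failure injection is the FailureController.execute_failures task scheduled above and cancelled once training completes. Its implementation also lives in utils/failure.py and is not shown in this diff; as a rough sketch of what such a loop might do, assuming it periodically picks a replica and drives the inject_failure endpoint added earlier (the interval, the Replica attribute names, and the use of the scheduler handle are all guesses):

import asyncio
import random

from utils.failure import Failure  # the tutorial's enum, imported at the top of this file


class FailureController:
    @staticmethod
    async def execute_failures(replicas: dict, scheduler: "MonarchSlurm") -> None:
        while True:
            await asyncio.sleep(120)  # hypothetical injection interval
            if not replicas:
                continue
            replica = random.choice(list(replicas.values()))
            # Drive the ReplicaActor endpoint added in this diff; the `actor`
            # attribute name on Replica is an assumption.
            await replica.actor.inject_failure.choose(random.choice(list(Failure)))
            # A host-level failure could instead be simulated via the scheduler
            # handle, e.g. scheduler.kill_job(f"replica_{replica.replica_id}").
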
@@ -274,7 +308,8 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> No
            logger.info(
                f"[Controller] Replica {replica_id} has failed {attempt_number} times. Getting new allocation."
            )
-            await MonarchSlurm.get_or_create_job(
+            self.scheduler.kill_job(f"replica_{replica_id}")
+            await self.scheduler.get_or_create_job(
                f"replica_{replica_id}", self.spec.hosts_per_replica
            )
        delay = 0 if not attempt_number else PROC_ATTEMPT_DELAY
@@ -287,10 +322,7 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> No
        await replica_proc_mesh.logging_option(aggregate_window_sec=None)

        replica_actor = replica_proc_mesh.spawn(
-            "replica_actor",
-            ReplicaActor,
-            self.spec,
-            replica_id,
+            "replica_actor", ReplicaActor, self.spec, replica_id, self.scheduler
        )

        replica = Replica(replica_id, replica_proc_mesh, replica_actor, attempt_number)
@@ -301,8 +333,8 @@ async def _teardown(self, replica_id: int) -> None:
        try:
            replica = self.replicas[replica_id]
            await replica.proc_mesh.stop()
-            del replica.proc_mesh
            del self.replicas[replica_id]
+            del replica.proc_mesh
        except Exception as e:
            logger.error(f"[Controller] Failed to _teardown replica {replica_id}: {e}")

@@ -339,20 +371,25 @@ def parse_args() -> argparse.Namespace:
    parser.add_argument(
        "--model-config",
        type=str,
-        default=os.path.join(script_dir, "debug_model.toml"),
-        help=f"Path to model configuration file (default: {os.path.join(script_dir, 'debug_model.toml')})",
+        default="debug_model.toml",
+        help=f"Relative path to model configuration file (default: {os.path.join(script_dir, 'debug_model.toml')})",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
-        default=os.path.join(script_dir, "c4_test"),
-        help=f"Path to training dataset (default: {os.path.join(script_dir, 'c4_test')})",
+        default="c4_test",
+        help=f"Relative path to training dataset (default: {os.path.join(script_dir, 'c4_test')})",
    )
    parser.add_argument(
        "--tokenizer-path",
        type=str,
-        default=os.path.join(script_dir, "tokenizer"),
-        help=f"Path to tokenizer (default: {os.path.join(script_dir, 'tokenizer')})",
+        default="debug_tokenizer",
+        help=f"Relative path to tokenizer (default: {os.path.join(script_dir, 'debug_tokenizer')})",
+    )
+    parser.add_argument(
+        "--with-failures",
+        action="store_true",
+        help="Enable the failure injector utility (default: False)",
    )

    return parser.parse_args()
@@ -362,32 +399,37 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
    data_parallel_shard_degree = args.gpu_per_node * args.host_per_replica

    output_path = "./outputs"
-    training_dataset = "c4_test"
+    training_dataset = args.dataset_path.split("/")[-1]

+    script_dir = os.path.dirname(os.path.abspath(__file__))
    default_args = [
        "--job.config_file",
-        args.model_config,
+        os.path.join(script_dir, args.model_config),
        "--model.tokenizer_path",
-        args.tokenizer_path,
+        os.path.join(script_dir, args.tokenizer_path),
        "--comm.trace_buf_size",
        "0",
        "--metrics.log_freq",
        "1",
        "--fault_tolerance.enable",
        "--fault_tolerance.group_size",
        str(args.replica_count),
+        "--fault_tolerance.process_group",
+        "nccl",
+        "--fault_tolerance.process_group_timeout_ms",
+        "60000",
        "--parallelism.data_parallel_shard_degree",
        str(data_parallel_shard_degree),
        "--activation_checkpoint.mode",
        "full",
        "--comm.train_timeout_seconds",
-        "60",
+        "300",
        "--training.steps",
        str(args.training_steps),
        "--training.dataset",
        training_dataset,
        "--training.dataset_path",
-        args.dataset_path,
+        os.path.join(script_dir, args.dataset_path),
        "--job.dump_folder",
        output_path,
        "--metrics.enable_tensorboard",
@@ -402,6 +444,7 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
        replica_count=args.replica_count,
        hosts_per_replica=args.host_per_replica,
        gpus_per_node=args.gpu_per_node,
+        with_failures=args.with_failures,
    )


@@ -414,7 +457,6 @@ async def main() -> None:
    args = parse_args()
    job_spec = make_job_spec(args)

-    atexit.register(MonarchSlurm.kill_jobs)
    orchestrator = OrchestrationManager(job_spec)
    try:
        await orchestrator.start_lighthouse()