import argparse
import asyncio
+import atexit
import os
-
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict
-import atexit

import torch
-
from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints, AllocSpec
from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer
-from monarch.actor import Actor, current_rank, endpoint, ProcMesh, this_host
+from monarch.actor import Actor, ProcMesh, current_rank, endpoint, this_host
from monarch.tools import commands
from monarch.tools.components import hyperactor
from monarch.tools.config import Config
from monarch.utils import setup_env_for_distributed
-
from torchtitan.config import ConfigManager, JobConfig
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer
@@ -73,7 +70,9 @@ def proc_mesh(
    ) -> ProcMesh:
        allocator = RemoteAllocator(
            world_id=MonarchSlurm.job_name_prefix,
-            initializer=TorchXRemoteAllocInitializer(f"slurm:///{cls.job_handles[mesh_name]}"),
+            initializer=TorchXRemoteAllocInitializer(
+                f"slurm:///{cls.job_handles[mesh_name]}"
+            ),
        )
        alloc = allocator.allocate(
            AllocSpec(AllocConstraints(), hosts=num_hosts, gpus=num_gpus)
@@ -84,13 +83,16 @@ def proc_mesh(

# ==== allocation boilerplate ====

+
class LighthouseActor(Actor):
    def __init__(self) -> None:
        self.lighthouse = None

    @endpoint
    def start_lighthouse(self) -> str:
+        # inline import because of https://github.com/meta-pytorch/monarch/issues/804
        from torchft.coordination import LighthouseServer
+
        self.lighthouse = LighthouseServer(
            bind="[::]:0", min_replicas=1, join_timeout_ms=10000
        )
@@ -217,7 +219,9 @@ async def start_training(self) -> None:
        )

        for replica_id in range(self.spec.replica_count):
-            await MonarchSlurm.get_or_create_job(f"replica_{replica_id}", self.spec.hosts_per_replica)
+            await MonarchSlurm.get_or_create_job(
+                f"replica_{replica_id}", self.spec.hosts_per_replica
+            )

        mesh_futures = {}
        for i in range(self.spec.replica_count):
@@ -305,6 +309,7 @@ async def _teardown(self, replica_id: int) -> None:

# === CLI / CONFIG === #

+
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Monarch-TorchFT Distributed Training Example"
@@ -398,6 +403,8 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
        hosts_per_replica=args.host_per_replica,
        gpus_per_node=args.gpu_per_node,
    )
+
+
# === CLI / CONFIG === #