Skip to content

Commit 1429367

Browse files
committed
feat: add metadata parameters to dist/spmd components (#1037)
1 parent 1e3df20 commit 1429367

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

torchx/components/dist.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def spmd(
9292
h: str = "gpu.small",
9393
j: str = "1x1",
9494
env: Optional[Dict[str, str]] = None,
95+
metadata: Optional[Dict[str, str]] = None,
9596
max_retries: int = 0,
9697
mounts: Optional[List[str]] = None,
9798
debug: bool = False,
@@ -131,6 +132,7 @@ def spmd(
131132
h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
132133
j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
133134
env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
135+
metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
134136
max_retries: the number of scheduler retries allowed
135137
mounts: (for docker based runs only) mounts to mount into the worker environment/container
136138
(ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
@@ -150,6 +152,7 @@ def spmd(
150152
h=h,
151153
j=str(StructuredJArgument.parse_from(h, j)),
152154
env=env,
155+
metadata=metadata,
153156
max_retries=max_retries,
154157
mounts=mounts,
155158
debug=debug,
@@ -168,6 +171,7 @@ def ddp(
168171
memMB: int = 1024,
169172
j: str = "1x2",
170173
env: Optional[Dict[str, str]] = None,
174+
metadata: Optional[Dict[str, str]] = None,
171175
max_retries: int = 0,
172176
rdzv_port: int = 29500,
173177
rdzv_backend: str = "c10d",
@@ -201,6 +205,7 @@ def ddp(
201205
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
202206
j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
203207
env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
208+
metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
204209
max_retries: the number of scheduler retries allowed
205210
rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
206211
Only takes effect when running multi-node. When running single node, this parameter
@@ -237,8 +242,8 @@ def ddp(
237242
# use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
238243
rdzv_endpoint = _noquote(f"$${{{macros.rank0_env}:=localhost}}:{rdzv_port}")
239244

240-
if env is None:
241-
env = {}
245+
env = env or {}
246+
metadata = metadata or {}
242247

243248
argname = StructuredNameArgument.parse_from(
244249
name=name,
@@ -299,6 +304,7 @@ def ddp(
299304
mounts=specs.parse_mounts(mounts) if mounts else [],
300305
)
301306
],
307+
metadata=metadata,
302308
)
303309

304310

torchx/components/test/dist_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ def test_ddp_debug(self) -> None:
4040
for k, v in _TORCH_DEBUG_FLAGS.items():
4141
self.assertEqual(env[k], v)
4242

43+
def test_ddp_metadata(self) -> None:
44+
metadata = {"key": "value"}
45+
app = ddp(script="foo.py", metadata=metadata)
46+
for k, v in metadata.items():
47+
self.assertEqual(app.metadata[k], v)
48+
self.assertEqual(len(metadata), len(app.metadata))
49+
4350
def test_ddp_rdzv_backend_static(self) -> None:
4451
rdzv_conf = "join_timeout=600,close_timeout=600,timeout=600"
4552
app = ddp(script="foo.py", rdzv_backend="static", rdzv_conf=rdzv_conf)
@@ -55,6 +62,13 @@ def test_validate_spmd(self) -> None:
5562

5663
self.validate(dist, "ddp")
5764

65+
def test_spmd_metadata(self) -> None:
66+
metadata = {"key": "value"}
67+
app = spmd(script="foo.py", metadata=metadata)
68+
for k, v in metadata.items():
69+
self.assertEqual(app.metadata[k], v)
70+
self.assertEqual(len(metadata), len(app.metadata))
71+
5872
def test_spmd_call_by_module_or_script_no_name(self) -> None:
5973
appdef = spmd(script="foo/bar.py")
6074
self.assertEqual("bar", appdef.name)

0 commit comments

Comments
 (0)