@@ -92,6 +92,7 @@ def spmd(
92
92
h : str = "gpu.small" ,
93
93
j : str = "1x1" ,
94
94
env : Optional [Dict [str , str ]] = None ,
95
+ metadata : Optional [Dict [str , str ]] = None ,
95
96
max_retries : int = 0 ,
96
97
mounts : Optional [List [str ]] = None ,
97
98
debug : bool = False ,
@@ -131,6 +132,7 @@ def spmd(
131
132
h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
132
133
j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
133
134
env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
135
+ metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
134
136
max_retries: the number of scheduler retries allowed
135
137
mounts: (for docker based runs only) mounts to mount into the worker environment/container
136
138
(ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
@@ -150,6 +152,7 @@ def spmd(
150
152
h = h ,
151
153
j = str (StructuredJArgument .parse_from (h , j )),
152
154
env = env ,
155
+ metadata = metadata ,
153
156
max_retries = max_retries ,
154
157
mounts = mounts ,
155
158
debug = debug ,
@@ -168,6 +171,7 @@ def ddp(
168
171
memMB : int = 1024 ,
169
172
j : str = "1x2" ,
170
173
env : Optional [Dict [str , str ]] = None ,
174
+ metadata : Optional [Dict [str , str ]] = None ,
171
175
max_retries : int = 0 ,
172
176
rdzv_port : int = 29500 ,
173
177
rdzv_backend : str = "c10d" ,
@@ -201,6 +205,7 @@ def ddp(
201
205
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
202
206
j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
203
207
env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
208
+ metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
204
209
max_retries: the number of scheduler retries allowed
205
210
rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
206
211
Only takes effect when running multi-node. When running single node, this parameter
@@ -237,8 +242,8 @@ def ddp(
237
242
# use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
238
243
rdzv_endpoint = _noquote (f"$${{{ macros .rank0_env } :=localhost}}:{ rdzv_port } " )
239
244
240
- if env is None :
241
- env = {}
245
+ env = env or {}
246
+ metadata = metadata or {}
242
247
243
248
argname = StructuredNameArgument .parse_from (
244
249
name = name ,
@@ -299,6 +304,7 @@ def ddp(
299
304
mounts = specs .parse_mounts (mounts ) if mounts else [],
300
305
)
301
306
],
307
+ metadata = metadata ,
302
308
)
303
309
304
310
0 commit comments