Skip to content

Commit 01ef2f6

Browse files
committed
feat: add ulimits support to aws_batch (#1126)
1 parent b72ba03 commit 01ef2f6

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,46 @@
9999
TAG_TORCHX_USER = "torchx.pytorch.org/user"
100100

101101

102+
def parse_ulimits(ulimits_str: str) -> List[Dict[str, Any]]:
    """
    Parse ulimit string in format: name=nofile,softLimit=65536,hardLimit=65536
    Multiple ulimits separated by semicolons.

    Whitespace around keys, values and separators is ignored, and empty
    segments produced by stray ``;`` or ``,`` separators are skipped.

    Args:
        ulimits_str: the raw option string; an empty string yields no ulimits.

    Returns:
        One dict per ulimit with the keys ``name`` (str), ``softLimit`` (int)
        and ``hardLimit`` (int), in the order they appear in the input.

    Raises:
        ValueError: if an option key is unknown, an option has no value,
            a limit is not an integer, or a ulimit is missing one of
            ``name``, ``softLimit`` or ``hardLimit``.
    """
    if not ulimits_str:
        return []

    required_keys = ("name", "softLimit", "hardLimit")

    ulimits: List[Dict[str, Any]] = []
    for ulimit_str in ulimits_str.split(";"):
        if not ulimit_str.strip():
            continue

        ulimit: Dict[str, Any] = {}
        for opt in ulimit_str.split(","):
            if not opt.strip():
                # Tolerate stray/trailing commas the same way stray
                # semicolons are tolerated above.
                continue

            key, sep, val = opt.partition("=")
            key = key.strip()
            val = val.strip()

            if key not in required_keys:
                raise ValueError(f"Unknown ulimit option: {key}")
            if not sep or not val:
                # Without this check `name` would silently become "" and the
                # limits would fail with an opaque int() error.
                raise ValueError(
                    f"ulimit option '{key}' requires a value, got: {opt.strip()!r}"
                )

            if key == "name":
                ulimit[key] = val
            else:
                try:
                    # int() already maps "-1" to -1; no special case is needed.
                    ulimit[key] = int(val)
                except ValueError:
                    raise ValueError(
                        f"ulimit option '{key}' must be an integer, got: {val!r}"
                    ) from None

        for required in required_keys:
            if required not in ulimit:
                raise ValueError(f"ulimit must specify '{required}'")

        ulimits.append(ulimit)

    return ulimits
140+
141+
102142
if TYPE_CHECKING:
103143
from docker import DockerClient
104144

@@ -177,6 +217,7 @@ def _role_to_node_properties(
177217
privileged: bool = False,
178218
job_role_arn: Optional[str] = None,
179219
execution_role_arn: Optional[str] = None,
220+
ulimits: Optional[List[Dict[str, Any]]] = None,
180221
) -> Dict[str, object]:
181222
role.mounts += get_device_mounts(role.resource.devices)
182223

@@ -239,6 +280,7 @@ def _role_to_node_properties(
239280
"environment": [{"name": k, "value": v} for k, v in role.env.items()],
240281
"privileged": privileged,
241282
"resourceRequirements": resource_requirements_from_resource(role.resource),
283+
**({"ulimits": ulimits} if ulimits else {}),
242284
"linuxParameters": {
243285
# To support PyTorch dataloaders we need to set /dev/shm to larger
244286
# than the 64M default.
@@ -361,6 +403,7 @@ class AWSBatchOpts(TypedDict, total=False):
361403
priority: int
362404
job_role_arn: Optional[str]
363405
execution_role_arn: Optional[str]
406+
ulimits: Optional[str]
364407

365408

366409
class AWSBatchScheduler(
@@ -514,6 +557,7 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ
514557
privileged=cfg["privileged"],
515558
job_role_arn=cfg.get("job_role_arn"),
516559
execution_role_arn=cfg.get("execution_role_arn"),
560+
ulimits=parse_ulimits(cfg.get("ulimits") or ""),
517561
)
518562
)
519563
node_idx += role.num_replicas
@@ -599,6 +643,11 @@ def _run_opts(self) -> runopts:
599643
type_=str,
600644
help="The Amazon Resource Name (ARN) of the IAM role that the ECS agent can assume for AWS permissions.",
601645
)
646+
opts.add(
647+
"ulimits",
648+
type_=str,
649+
help="Ulimit settings in format: name=nofile,softLimit=65536,hardLimit=65536 (multiple separated by semicolons)",
650+
)
602651
return opts
603652

604653
def _get_job_id(self, app_id: str) -> Optional[str]:

torchx/schedulers/test/aws_batch_scheduler_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
AWSBatchScheduler,
2424
create_scheduler,
2525
ENV_TORCHX_ROLE_NAME,
26+
parse_ulimits,
2627
resource_from_resource_requirements,
2728
resource_requirements_from_resource,
2829
to_millis_since_epoch,
@@ -396,6 +397,48 @@ def test_resource_devices(self) -> None:
396397
],
397398
)
398399

400+
def test_role_to_node_properties_ulimits(self) -> None:
    """Ulimits given to ``_role_to_node_properties`` show up under ``container["ulimits"]``."""
    role = specs.Role(
        name="test",
        image="test:latest",
        entrypoint="test",
        args=["test"],
        resource=specs.Resource(cpu=1, memMB=1000, gpu=0),
    )
    # One bounded and one unlimited (-1) entry.
    expected_ulimits = [
        {"name": "nofile", "softLimit": 65536, "hardLimit": 65536},
        {"name": "memlock", "softLimit": -1, "hardLimit": -1},
    ]

    node_props = _role_to_node_properties(role, 0, ulimits=expected_ulimits)
    container = node_props["container"]

    self.assertEqual(container["ulimits"], expected_ulimits)
417+
418+
def test_parse_ulimits(self) -> None:
    """Check ``parse_ulimits`` on valid, empty and malformed option strings."""
    nofile = {"name": "nofile", "softLimit": 65536, "hardLimit": 65536}
    memlock = {"name": "memlock", "softLimit": -1, "hardLimit": -1}

    # Test single ulimit
    self.assertEqual(
        parse_ulimits("name=nofile,softLimit=65536,hardLimit=65536"),
        [nofile],
    )

    # Test multiple ulimits
    self.assertEqual(
        parse_ulimits(
            "name=nofile,softLimit=65536,hardLimit=65536;name=memlock,softLimit=-1,hardLimit=-1"
        ),
        [nofile, memlock],
    )

    # Test empty string
    self.assertEqual(parse_ulimits(""), [])

    # Test invalid formats: unknown key, each required field missing,
    # and a limit that is not an integer.
    invalid_inputs = [
        "invalid",
        "softLimit=65536,hardLimit=65536",
        "name=nofile,hardLimit=65536",
        "name=nofile,softLimit=65536",
        "name=nofile,softLimit=abc,hardLimit=65536",
    ]
    for bad in invalid_inputs:
        with self.subTest(ulimits=bad):
            with self.assertRaises(ValueError):
                parse_ulimits(bad)
441+
399442
def _mock_scheduler_running_job(self) -> AWSBatchScheduler:
400443
scheduler = AWSBatchScheduler(
401444
"test",

0 commit comments

Comments
 (0)