69 changes: 67 additions & 2 deletions torchx/schedulers/slurm_scheduler.py
@@ -18,6 +18,7 @@
import shlex
import subprocess
import tempfile
import warnings
from dataclasses import dataclass
from datetime import datetime
from subprocess import CalledProcessError, PIPE
@@ -72,6 +73,55 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)


def version() -> Tuple[int, int]:
"""
    Uses ``sinfo --version`` to get the Slurm version. If the command fails,
    the version is assumed to be ``slurm 24.05.8``.

    Returns:
        Slurm version as a ``(major, minor)`` tuple of ints.
"""

cmd = ["sinfo", "--version"]
try:
out = subprocess.check_output(cmd, stderr=PIPE, encoding="utf-8")
except (CalledProcessError, FileNotFoundError):
out = "slurm 24.05.8"
        warnings.warn(
            f"Error running `{shlex.join(cmd)}` to get the Slurm version. Are you "
            "running outside the cluster's login or head node? This typically "
            "happens in `--dryrun` mode. Assuming version `slurm 24.05.8`.",
            RuntimeWarning,
            stacklevel=2,
        )

    # ``sinfo --version`` prints output of the form "slurm 24.05.8"
    _, version_literal = out.split(" ", maxsplit=1)
    major, minor = [int(v) for v in version_literal.split(".")[:2]]

return (major, minor)
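
# Illustrative usage (hypothetical): "slurm 24.05.8" parses to (24, 5), so
# callers can gate features with a plain tuple comparison:
#
#     if version() >= (24, 11):
#         ...  # take the newer code path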


def _should_use_gpus_per_node_from_version() -> bool:
"""
    Determine whether to use ``gpus-per-node`` based on the automatically detected Slurm version.

Change Reference: https://fburl.com/sqwqzxn6
> select/linear - Reject jobs asking for GRES per job|socket|task or cpus|mem per GRES.

Returns:
        ``True`` for Slurm ``version >= 24.11``, ``False`` otherwise.
"""

slurm_24_11_0 = (24, 11)
slurm_version = version()

    # Tuples compare lexicographically: True when the major version is
    # greater, or the major is equal and the minor is greater or equal.
    return slurm_version >= slurm_24_11_0
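
# Sanity sketch (hypothetical versions): (24, 5) >= (24, 11) -> False,
# (24, 11) >= (24, 11) -> True, (25, 0) >= (24, 11) -> True.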


SBATCH_JOB_OPTIONS = {
"comment",
"mail-user",
@@ -81,6 +131,7 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
"partition",
"time",
"constraint",
"qos",
}

log: logging.Logger = logging.getLogger(__name__)
@@ -106,6 +157,7 @@ def _apply_app_id_env(s: str) -> str:
"mail-user": Optional[str],
"mail-type": Optional[str],
"job_dir": Optional[str],
"qos": Optional[str],
},
total=False,
)
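
# Example cfg (hypothetical values) using keys from ``SlurmOpts`` above:
#
#     cfg: SlurmOpts = {"partition": "train", "qos": "high"}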
@@ -126,7 +178,11 @@ class SlurmReplicaRequest:

@classmethod
def from_role(
        cls,
        name: str,
        role: Role,
        cfg: SlurmOpts,
        nomem: bool,
) -> "SlurmReplicaRequest":
"""
``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +205,11 @@ def from_role(
if not nomem and resource.memMB > 0:
sbatch_opts.setdefault("mem", str(resource.memMB))
if resource.gpu > 0:
sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
# Use smart GPU allocation based on automatically detected Slurm version
if _should_use_gpus_per_node_from_version():
sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
else:
sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))

srun_opts = {
"output": f"slurm-{macros.app_id}-{name}.out",
@@ -378,6 +438,11 @@ def _run_opts(self) -> runopts:
iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
""",
)
opts.add(
"qos",
type_=str,
help="Quality of Service (QoS) to assign to the job.",
)
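
        # Hypothetical usage: scheduler options are passed as cfg key-value
        # pairs on the CLI, e.g. ``torchx run -s slurm -cfg qos=high ...``,
        # which applies ``qos=high`` to the job-level sbatch options above.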
return opts

def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str: