4949
5050SLURM_JOB_DIRS = ".torchxslurmjobdirs"
5151
52+ DEFAULT_SLURM_VERSION : str = "1.0"
53+
5254SLURM_STATES : Mapping [str , AppState ] = {
5355 "BOOT_FAIL" : AppState .FAILED ,
5456 "CANCELLED" : AppState .CANCELLED ,
@@ -72,6 +74,45 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
7274 return SLURM_STATES .get (slurm_state , AppState .UNKNOWN )
7375
7476
77+ def _parse_slurm_version (version_str : str ) -> Tuple [int , int ]:
78+ """
79+ Parse Slurm version string (e.g., '24.05', '25.11.2') into (major, minor) tuple.
80+ Raises ValueError if parsing fails.
81+ """
82+ parts = version_str .split ("." )
83+ if len (parts ) < 2 :
84+ raise ValueError (
85+ f"Invalid Slurm version string: { version_str } . Expected format: '24.05' or '25.11.2'"
86+ )
87+
88+ try :
89+ major = int (parts [0 ])
90+ minor = int (parts [1 ])
91+ except (ValueError , IndexError ) as err :
92+ raise ValueError (
93+ f"Invalid Slurm version string: { version_str } . Expected format: '24.05' or '25.11.2'"
94+ ) from err
95+
96+ return (major , minor )
97+
98+
def _should_use_gpus_per_node_from_version(version_str: Optional[str]) -> bool:
    """
    Decide whether ``gpus-per-node`` should be used for GPU allocation
    based on a Slurm version string.

    Returns True only for versions strictly newer than 24.11; returns
    False for older versions, a missing/empty string, or anything that
    cannot be parsed as ``major.minor``.
    """
    if not version_str:
        return False

    fields = version_str.split(".")
    if len(fields) < 2:
        return False

    try:
        major = int(fields[0])
        minor = int(fields[1])
    except ValueError:
        return False

    # gpus-per-node only for versions strictly greater than 24.11;
    # tuple comparison encodes "major > 24 or (major == 24 and minor > 11)".
    return (major, minor) > (24, 11)
114+
115+
75116SBATCH_JOB_OPTIONS = {
76117 "comment" ,
77118 "mail-user" ,
@@ -81,6 +122,7 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
81122 "partition" ,
82123 "time" ,
83124 "constraint" ,
125+ "qos" ,
84126}
85127
86128log : logging .Logger = logging .getLogger (__name__ )
@@ -106,6 +148,8 @@ def _apply_app_id_env(s: str) -> str:
106148 "mail-user" : Optional [str ],
107149 "mail-type" : Optional [str ],
108150 "job_dir" : Optional [str ],
151+ "qos" : Optional [str ],
152+ "slurm_version" : Optional [str ],
109153 },
110154 total = False ,
111155)
@@ -126,7 +170,11 @@ class SlurmReplicaRequest:
126170
127171 @classmethod
128172 def from_role (
129- cls , name : str , role : Role , cfg : SlurmOpts , nomem : bool
173+ cls ,
174+ name : str ,
175+ role : Role ,
176+ cfg : SlurmOpts ,
177+ nomem : bool ,
130178 ) -> "SlurmReplicaRequest" :
131179 """
132180 ``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +197,12 @@ def from_role(
149197 if not nomem and resource .memMB > 0 :
150198 sbatch_opts .setdefault ("mem" , str (resource .memMB ))
151199 if resource .gpu > 0 :
152- sbatch_opts .setdefault ("gpus-per-task" , str (resource .gpu ))
200+ # Use smart GPU allocation based on Slurm version from config
201+ slurm_version = cfg .get ("slurm_version" )
202+ if _should_use_gpus_per_node_from_version (slurm_version ):
203+ sbatch_opts .setdefault ("gpus-per-node" , str (resource .gpu ))
204+ else :
205+ sbatch_opts .setdefault ("gpus-per-task" , str (resource .gpu ))
153206
154207 srun_opts = {
155208 "output" : f"slurm-{ macros .app_id } -{ name } .out" ,
@@ -378,6 +431,18 @@ def _run_opts(self) -> runopts:
378431 iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
379432 """ ,
380433 )
434+ opts .add (
435+ "qos" ,
436+ type_ = str ,
437+ help = "Quality of Service (QoS) to assign to the job." ,
438+ )
439+ opts .add (
440+ "slurm_version" ,
441+ type_ = str ,
442+ help = """Slurm version (e.g., '24.05', '25.11'). If version > 24.11,
443+ uses gpus-per-node instead of gpus-per-task for GPU allocation.
444+ """ ,
445+ )
381446 return opts
382447
383448 def schedule (self , dryrun_info : AppDryRunInfo [SlurmBatchRequest ]) -> str :
@@ -401,6 +466,55 @@ def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
401466
402467 return job_id
403468
469+ def _get_slurm_version (self ) -> str :
470+ """
471+ _get_slurm_version returns the Slurm version string (e.g., "24.05").
472+ Raises ValueError if version cannot be determined.
473+ """
474+ try :
475+ p = subprocess .run (
476+ ["sinfo" , "--version" ],
477+ stdout = subprocess .PIPE ,
478+ stderr = subprocess .PIPE ,
479+ )
480+ except FileNotFoundError :
481+ log .error (
482+ "Slurm is not available (sinfo command not found). "
483+ "Returning default 1.0 instead."
484+ )
485+
486+ return DEFAULT_SLURM_VERSION
487+
488+ if p .returncode != 0 :
489+ log .error (
490+ f"Failed to get Slurm version: { p .stderr .decode ('utf-8' ).strip ()} . "
491+ "Returning default 1.0 instead."
492+ )
493+
494+ return DEFAULT_SLURM_VERSION
495+
496+ output = p .stdout .decode ("utf-8" ).strip ().lower ()
497+ if not output .startswith ("slurm " ):
498+ log .error (
499+ f"Unexpected sinfo --version output format: { output } . "
500+ "Returning default 1.0 instead."
501+ )
502+
503+ return DEFAULT_SLURM_VERSION
504+
505+ # Remove "slurm " prefix and extract version (e.g., "24.05.4" -> "24.05")
506+ version_full = output .replace ("slurm" , "" ).strip ()
507+ version_parts = version_full .split ("." )
508+ if len (version_parts ) < 2 :
509+ log .error (
510+ f"Invalid Slurm version format: `{ version_full } `; "
511+ "Returning default 1.0 instead."
512+ )
513+
514+ return DEFAULT_SLURM_VERSION
515+
516+ return f"{ version_parts [0 ]} .{ version_parts [1 ]} "
517+
404518 def _partition_memmb (self , partition : Optional [str ]) -> Optional [int ]:
405519 """
406520 _partition_memmb returns the memory allocation for the given partition
@@ -441,6 +555,12 @@ def _submit_dryrun(
441555 partition = cfg .get ("partition" )
442556 assert partition is None or isinstance (partition , str ), "partition must be str"
443557
558+ # Create a new config with the resolved slurm version
559+ resolved_cfg = cfg .copy ()
560+ resolved_cfg ["slurm_version" ] = cfg .get (
561+ "slurm_version" , self ._get_slurm_version ()
562+ )
563+
444564 # check if the partition has at least 1GB memory, if we're not sure,
445565 # default to using memory allocations
446566 memmb = self ._partition_memmb (partition )
@@ -460,7 +580,7 @@ def _submit_dryrun(
460580 replicas [name ] = SlurmReplicaRequest .from_role (
461581 name ,
462582 replica_role ,
463- cfg ,
583+ resolved_cfg ,
464584 nomem = nomem ,
465585 )
466586 cmd = ["sbatch" , "--parsable" ]
0 commit comments