Skip to content

Commit 54781ca

Browse files
authored
Add alias functionality for train, launch, evaluate, and infer in the oumi CLI (#1618)
1 parent bd4892f commit 54781ca

File tree

10 files changed

+289
-5
lines changed

10 files changed

+289
-5
lines changed

src/oumi/cli/alias.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2025 - Oumi
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from enum import Enum
16+
17+
from oumi.utils.logging import logger
18+
19+
20+
class AliasType(str, Enum):
21+
"""The type of configs we support with aliases."""
22+
23+
TRAIN = "train"
24+
EVAL = "eval"
25+
INFER = "infer"
26+
JOB = "job"
27+
28+
29+
_ALIASES: dict[str, dict[AliasType, str]] = {
30+
"llama4-scout": {
31+
AliasType.TRAIN: "oumi://configs/recipes/llama4/sft/scout_base_full/train.yaml",
32+
AliasType.JOB: "oumi://configs/recipes/llama4/sft/scout_base_full/train.yaml",
33+
},
34+
"llama4-scout-instruct-lora": {
35+
AliasType.TRAIN: "oumi://configs/recipes/llama4/sft/scout_instruct_lora/train.yaml",
36+
},
37+
"llama4-scout-instruct-qlora": {
38+
AliasType.TRAIN: "oumi://configs/recipes/llama4/sft/scout_instruct_qlora/train.yaml",
39+
},
40+
"llama4-scout-instruct": {
41+
AliasType.TRAIN: "oumi://configs/recipes/llama4/sft/scout_instruct_full/train.yaml",
42+
AliasType.INFER: "oumi://configs/recipes/llama4/inference/scout_instruct_infer.yaml",
43+
AliasType.JOB: "oumi://configs/recipes/llama4/sft/scout_instruct_full/gcp_job.yaml",
44+
AliasType.EVAL: "oumi://configs/recipes/llama4/evaluation/scout_instruct_eval.yaml",
45+
},
46+
"llama4-maverick": {
47+
AliasType.INFER: "oumi://configs/recipes/llama4/inference/maverick_instruct_together_infer.yaml",
48+
},
49+
}
50+
51+
52+
def try_get_config_name_for_alias(
53+
alias: str,
54+
alias_type: AliasType,
55+
) -> str:
56+
"""Gets the config path for a given alias.
57+
58+
This function resolves the config path for a given alias and alias type.
59+
If the alias is not found, the original alias is returned.
60+
61+
Args:
62+
alias (str): The alias to resolve.
63+
alias_type (AliasType): The type of config to resolve.
64+
65+
Returns:
66+
str: The resolved config path (or the original alias if not found).
67+
"""
68+
if alias in _ALIASES and alias_type in _ALIASES[alias]:
69+
config_path = _ALIASES[alias][alias_type]
70+
logger.info(f"Resolved alias '{alias}' to '{config_path}'")
71+
return config_path
72+
return alias

src/oumi/cli/evaluate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from rich.table import Table
1919

2020
import oumi.cli.cli_utils as cli_utils
21+
from oumi.cli.alias import AliasType, try_get_config_name_for_alias
2122
from oumi.utils.logging import logger
2223

2324

@@ -42,7 +43,7 @@ def evaluate(
4243

4344
config = str(
4445
cli_utils.resolve_and_fetch_config(
45-
config,
46+
try_get_config_name_for_alias(config, AliasType.EVAL),
4647
)
4748
)
4849
with cli_utils.CONSOLE.status(

src/oumi/cli/infer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from rich.table import Table
2020

2121
import oumi.cli.cli_utils as cli_utils
22+
from oumi.cli.alias import AliasType, try_get_config_name_for_alias
2223
from oumi.utils.logging import logger
2324

2425
_DEFAULT_CLI_PDF_DPI: Final[int] = 200
@@ -79,7 +80,7 @@ def infer(
7980

8081
config = str(
8182
cli_utils.resolve_and_fetch_config(
82-
config,
83+
try_get_config_name_for_alias(config, AliasType.INFER),
8384
)
8485
)
8586

src/oumi/cli/launch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from rich.text import Text
2626

2727
import oumi.cli.cli_utils as cli_utils
28+
from oumi.cli.alias import AliasType, try_get_config_name_for_alias
2829
from oumi.utils.git_utils import get_git_root_dir
2930
from oumi.utils.logging import logger
3031
from oumi.utils.version_utils import is_dev_build
@@ -324,7 +325,7 @@ def run(
324325

325326
config = str(
326327
cli_utils.resolve_and_fetch_config(
327-
config,
328+
try_get_config_name_for_alias(config, AliasType.JOB),
328329
)
329330
)
330331

@@ -490,7 +491,7 @@ def up(
490491

491492
config = str(
492493
cli_utils.resolve_and_fetch_config(
493-
config,
494+
try_get_config_name_for_alias(config, AliasType.JOB),
494495
)
495496
)
496497

src/oumi/cli/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import typer
1818

1919
import oumi.cli.cli_utils as cli_utils
20+
from oumi.cli.alias import AliasType, try_get_config_name_for_alias
2021
from oumi.utils.logging import logger
2122

2223

@@ -41,7 +42,7 @@ def train(
4142

4243
config = str(
4344
cli_utils.resolve_and_fetch_config(
44-
config,
45+
try_get_config_name_for_alias(config, AliasType.TRAIN),
4546
)
4647
)
4748
with cli_utils.CONSOLE.status(

tests/unit/cli/test_cli_alias.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from oumi.cli.alias import _ALIASES, AliasType, try_get_config_name_for_alias
2+
3+
4+
def test_alias_all_entries():
5+
for alias in _ALIASES:
6+
for alias_type in _ALIASES[alias]:
7+
config_path = try_get_config_name_for_alias(alias, alias_type)
8+
assert config_path == _ALIASES[alias][alias_type], (
9+
f"Alias '{alias}' with type '{alias_type}' did not resolve correctly."
10+
f" Expected: {config_path}, Actual: {_ALIASES[alias][alias_type]}"
11+
)
12+
13+
14+
def test_alias_not_found():
15+
alias = "non_existent_alias"
16+
alias_type = AliasType.TRAIN
17+
config_path = try_get_config_name_for_alias(alias, alias_type)
18+
assert (
19+
config_path == alias
20+
), f"Expected the original alias '{alias}' to be returned."
21+
22+
23+
def test_alias_type_not_found():
24+
alias = "llama4-scout"
25+
config_path = try_get_config_name_for_alias(alias, AliasType.EVAL)
26+
assert (
27+
config_path == alias
28+
), f"Expected the original alias '{alias}' to be returned."

tests/unit/cli/test_cli_evaluate.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typer.testing import CliRunner
99

1010
import oumi
11+
import oumi.cli.alias
1112
from oumi.cli.cli_utils import CONTEXT_ALLOW_EXTRA_ARGS
1213
from oumi.cli.evaluate import evaluate
1314
from oumi.core.configs import (
@@ -26,6 +27,12 @@ def mock_fetch():
2627
yield m_fetch
2728

2829

30+
@pytest.fixture
31+
def mock_alias():
32+
with patch("oumi.cli.evaluate.try_get_config_name_for_alias") as try_alias:
33+
yield try_alias
34+
35+
2936
def _create_eval_config() -> EvaluationConfig:
3037
return EvaluationConfig(
3138
output_dir="output/dir",
@@ -69,6 +76,21 @@ def test_evaluate_runs(app, mock_evaluate):
6976
mock_evaluate.assert_has_calls([call(config)])
7077

7178

79+
def test_evaluate_calls_alias(app, mock_evaluate, mock_alias):
80+
with tempfile.TemporaryDirectory() as output_temp_dir:
81+
yaml_path = str(Path(output_temp_dir) / "eval.yaml")
82+
mock_alias.return_value = yaml_path
83+
config: EvaluationConfig = _create_eval_config()
84+
config.to_yaml(yaml_path)
85+
_ = runner.invoke(app, ["--config", "an_alias"])
86+
mock_alias.assert_has_calls(
87+
[
88+
call("an_alias", oumi.cli.alias.AliasType.EVAL),
89+
]
90+
)
91+
mock_evaluate.assert_has_calls([call(config)])
92+
93+
7294
def test_evaluate_unparsable_metrics(app, mock_evaluate):
7395
with tempfile.TemporaryDirectory() as output_temp_dir:
7496
mock_evaluate.return_value = [

tests/unit/cli/test_cli_infer.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typer.testing import CliRunner
1111

1212
import oumi
13+
from oumi.cli.alias import AliasType
1314
from oumi.cli.cli_utils import CONTEXT_ALLOW_EXTRA_ARGS
1415
from oumi.cli.infer import infer
1516
from oumi.core.configs import (
@@ -29,6 +30,12 @@ def mock_fetch():
2930
yield m_fetch
3031

3132

33+
@pytest.fixture
34+
def mock_alias():
35+
with patch("oumi.cli.infer.try_get_config_name_for_alias") as try_alias:
36+
yield try_alias
37+
38+
3239
def _create_inference_config() -> InferenceConfig:
3340
return InferenceConfig(
3441
model=ModelParams(
@@ -75,6 +82,19 @@ def test_infer_runs(app, mock_infer, mock_infer_interactive):
7582
)
7683

7784

85+
def test_infer_with_alias_runs(app, mock_infer, mock_infer_interactive, mock_alias):
86+
with tempfile.TemporaryDirectory() as output_temp_dir:
87+
yaml_path = str(Path(output_temp_dir) / "infer.yaml")
88+
mock_alias.return_value = yaml_path
89+
config: InferenceConfig = _create_inference_config()
90+
config.to_yaml(yaml_path)
91+
_ = runner.invoke(app, ["-i", "--config", "random_alias"])
92+
mock_alias.assert_called_once_with("random_alias", AliasType.INFER)
93+
mock_infer_interactive.assert_has_calls(
94+
[call(config, input_image_bytes=None, system_prompt=None)]
95+
)
96+
97+
7898
def test_infer_runs_interactive_by_default(app, mock_infer, mock_infer_interactive):
7999
with tempfile.TemporaryDirectory() as output_temp_dir:
80100
yaml_path = str(Path(output_temp_dir) / "infer.yaml")

tests/unit/cli/test_cli_launch.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typer.testing import CliRunner
99

1010
import oumi
11+
from oumi.cli.alias import AliasType
1112
from oumi.cli.cli_utils import CONTEXT_ALLOW_EXTRA_ARGS
1213
from oumi.cli.launch import cancel, down, status, stop, up, which
1314
from oumi.cli.launch import run as launcher_run
@@ -32,6 +33,12 @@ def mock_fetch():
3233
yield m_fetch
3334

3435

36+
@pytest.fixture
37+
def mock_alias():
38+
with patch("oumi.cli.launch.try_get_config_name_for_alias") as try_alias:
39+
yield try_alias
40+
41+
3542
runner = CliRunner()
3643

3744

@@ -202,6 +209,52 @@ def test_launch_up_job(
202209
assert logger.level == logging.DEBUG
203210

204211

212+
def test_launch_up_job_with_alias(
213+
app, mock_launcher, mock_pool, mock_version, mock_confirm, mock_fetch, mock_alias
214+
):
215+
with tempfile.TemporaryDirectory() as output_temp_dir:
216+
train_yaml_path = str(pathlib.Path(output_temp_dir) / "train.yaml")
217+
config: TrainingConfig = _create_training_config()
218+
config.to_yaml(train_yaml_path)
219+
job_yaml_path = str(pathlib.Path(output_temp_dir) / "job.yaml")
220+
mock_alias.return_value = job_yaml_path
221+
job_config = _create_job_config(train_yaml_path)
222+
job_config.to_yaml(job_yaml_path)
223+
mock_launcher.JobConfig = JobConfig
224+
mock_cluster = Mock()
225+
job_status = JobStatus(
226+
id="job_id",
227+
cluster="cluster_id",
228+
name="job_name",
229+
status="running",
230+
metadata="",
231+
done=False,
232+
)
233+
mock_launcher.up.return_value = (mock_cluster, job_status)
234+
mock_cluster.get_job.return_value = job_status = JobStatus(
235+
id="job_id",
236+
cluster="cluster_id",
237+
name="job_name",
238+
status="done",
239+
metadata="",
240+
done=True,
241+
)
242+
_ = runner.invoke(
243+
app,
244+
[
245+
"up",
246+
"--config",
247+
"some_alias",
248+
"--log-level",
249+
"DEBUG",
250+
],
251+
)
252+
mock_fetch.assert_called_once_with(job_yaml_path)
253+
mock_cluster.get_job.assert_has_calls([call("job_id")])
254+
mock_alias.assert_called_once_with("some_alias", AliasType.JOB)
255+
assert logger.level == logging.DEBUG
256+
257+
205258
def test_launch_up_job_dev_confirm(
206259
app, mock_launcher, mock_pool, mock_version, mock_confirm, mock_git_root, mock_fetch
207260
):
@@ -560,6 +613,62 @@ def test_launch_run_job(
560613
assert logger.level == logging.CRITICAL
561614

562615

616+
def test_launch_run_job_with_alias(
617+
app, mock_launcher, mock_pool, mock_version, mock_confirm, mock_fetch, mock_alias
618+
):
619+
with tempfile.TemporaryDirectory() as output_temp_dir:
620+
train_yaml_path = str(pathlib.Path(output_temp_dir) / "train.yaml")
621+
config: TrainingConfig = _create_training_config()
622+
config.to_yaml(train_yaml_path)
623+
job_yaml_path = str(pathlib.Path(output_temp_dir) / "job.yaml")
624+
mock_alias.return_value = job_yaml_path
625+
job_config = _create_job_config(train_yaml_path)
626+
job_config.to_yaml(job_yaml_path)
627+
mock_launcher.JobConfig = JobConfig
628+
mock_cluster = Mock()
629+
job_status = JobStatus(
630+
id="job_id",
631+
cluster="cluster_id",
632+
name="job_name",
633+
status="running",
634+
metadata="",
635+
done=False,
636+
)
637+
mock_cloud = Mock()
638+
mock_launcher.run.return_value = job_status
639+
mock_launcher.get_cloud.side_effect = [mock_cloud, mock_cloud]
640+
mock_cloud.get_cluster.side_effect = [mock_cluster, mock_cluster]
641+
mock_cluster.get_job.return_value = job_status = JobStatus(
642+
id="job_id",
643+
cluster="cluster_id",
644+
name="job_name",
645+
status="done",
646+
metadata="",
647+
done=True,
648+
)
649+
_ = runner.invoke(
650+
app,
651+
[
652+
"run",
653+
"--config",
654+
"some_alias",
655+
"--cluster",
656+
"cluster_id",
657+
"-log",
658+
"CRITICAL",
659+
],
660+
)
661+
mock_cluster.get_job.assert_has_calls([call("job_id"), call("job_id")])
662+
mock_launcher.run.assert_called_once_with(job_config, "cluster_id")
663+
mock_launcher.get_cloud.assert_has_calls([call("aws"), call("aws")])
664+
mock_cloud.get_cluster.assert_has_calls(
665+
[call("cluster_id"), call("cluster_id")]
666+
)
667+
mock_fetch.assert_called_once_with(job_yaml_path)
668+
mock_alias.assert_called_once_with("some_alias", AliasType.JOB)
669+
assert logger.level == logging.CRITICAL
670+
671+
563672
def test_launch_run_job_dev_confirm(
564673
app, mock_launcher, mock_pool, mock_version, mock_confirm, mock_git_root, mock_fetch
565674
):

0 commit comments

Comments
 (0)