8
8
from typer .testing import CliRunner
9
9
10
10
import oumi
11
+ from oumi .cli .alias import AliasType
11
12
from oumi .cli .cli_utils import CONTEXT_ALLOW_EXTRA_ARGS
12
13
from oumi .cli .launch import cancel , down , status , stop , up , which
13
14
from oumi .cli .launch import run as launcher_run
@@ -32,6 +33,12 @@ def mock_fetch():
32
33
yield m_fetch
33
34
34
35
36
+ @pytest .fixture
37
+ def mock_alias ():
38
+ with patch ("oumi.cli.launch.try_get_config_name_for_alias" ) as try_alias :
39
+ yield try_alias
40
+
41
+
35
42
runner = CliRunner ()
36
43
37
44
@@ -202,6 +209,52 @@ def test_launch_up_job(
202
209
assert logger .level == logging .DEBUG
203
210
204
211
212
+ def test_launch_up_job_with_alias (
213
+ app , mock_launcher , mock_pool , mock_version , mock_confirm , mock_fetch , mock_alias
214
+ ):
215
+ with tempfile .TemporaryDirectory () as output_temp_dir :
216
+ train_yaml_path = str (pathlib .Path (output_temp_dir ) / "train.yaml" )
217
+ config : TrainingConfig = _create_training_config ()
218
+ config .to_yaml (train_yaml_path )
219
+ job_yaml_path = str (pathlib .Path (output_temp_dir ) / "job.yaml" )
220
+ mock_alias .return_value = job_yaml_path
221
+ job_config = _create_job_config (train_yaml_path )
222
+ job_config .to_yaml (job_yaml_path )
223
+ mock_launcher .JobConfig = JobConfig
224
+ mock_cluster = Mock ()
225
+ job_status = JobStatus (
226
+ id = "job_id" ,
227
+ cluster = "cluster_id" ,
228
+ name = "job_name" ,
229
+ status = "running" ,
230
+ metadata = "" ,
231
+ done = False ,
232
+ )
233
+ mock_launcher .up .return_value = (mock_cluster , job_status )
234
+ mock_cluster .get_job .return_value = job_status = JobStatus (
235
+ id = "job_id" ,
236
+ cluster = "cluster_id" ,
237
+ name = "job_name" ,
238
+ status = "done" ,
239
+ metadata = "" ,
240
+ done = True ,
241
+ )
242
+ _ = runner .invoke (
243
+ app ,
244
+ [
245
+ "up" ,
246
+ "--config" ,
247
+ "some_alias" ,
248
+ "--log-level" ,
249
+ "DEBUG" ,
250
+ ],
251
+ )
252
+ mock_fetch .assert_called_once_with (job_yaml_path )
253
+ mock_cluster .get_job .assert_has_calls ([call ("job_id" )])
254
+ mock_alias .assert_called_once_with ("some_alias" , AliasType .JOB )
255
+ assert logger .level == logging .DEBUG
256
+
257
+
205
258
def test_launch_up_job_dev_confirm (
206
259
app , mock_launcher , mock_pool , mock_version , mock_confirm , mock_git_root , mock_fetch
207
260
):
@@ -560,6 +613,62 @@ def test_launch_run_job(
560
613
assert logger .level == logging .CRITICAL
561
614
562
615
616
+ def test_launch_run_job_with_alias (
617
+ app , mock_launcher , mock_pool , mock_version , mock_confirm , mock_fetch , mock_alias
618
+ ):
619
+ with tempfile .TemporaryDirectory () as output_temp_dir :
620
+ train_yaml_path = str (pathlib .Path (output_temp_dir ) / "train.yaml" )
621
+ config : TrainingConfig = _create_training_config ()
622
+ config .to_yaml (train_yaml_path )
623
+ job_yaml_path = str (pathlib .Path (output_temp_dir ) / "job.yaml" )
624
+ mock_alias .return_value = job_yaml_path
625
+ job_config = _create_job_config (train_yaml_path )
626
+ job_config .to_yaml (job_yaml_path )
627
+ mock_launcher .JobConfig = JobConfig
628
+ mock_cluster = Mock ()
629
+ job_status = JobStatus (
630
+ id = "job_id" ,
631
+ cluster = "cluster_id" ,
632
+ name = "job_name" ,
633
+ status = "running" ,
634
+ metadata = "" ,
635
+ done = False ,
636
+ )
637
+ mock_cloud = Mock ()
638
+ mock_launcher .run .return_value = job_status
639
+ mock_launcher .get_cloud .side_effect = [mock_cloud , mock_cloud ]
640
+ mock_cloud .get_cluster .side_effect = [mock_cluster , mock_cluster ]
641
+ mock_cluster .get_job .return_value = job_status = JobStatus (
642
+ id = "job_id" ,
643
+ cluster = "cluster_id" ,
644
+ name = "job_name" ,
645
+ status = "done" ,
646
+ metadata = "" ,
647
+ done = True ,
648
+ )
649
+ _ = runner .invoke (
650
+ app ,
651
+ [
652
+ "run" ,
653
+ "--config" ,
654
+ "some_alias" ,
655
+ "--cluster" ,
656
+ "cluster_id" ,
657
+ "-log" ,
658
+ "CRITICAL" ,
659
+ ],
660
+ )
661
+ mock_cluster .get_job .assert_has_calls ([call ("job_id" ), call ("job_id" )])
662
+ mock_launcher .run .assert_called_once_with (job_config , "cluster_id" )
663
+ mock_launcher .get_cloud .assert_has_calls ([call ("aws" ), call ("aws" )])
664
+ mock_cloud .get_cluster .assert_has_calls (
665
+ [call ("cluster_id" ), call ("cluster_id" )]
666
+ )
667
+ mock_fetch .assert_called_once_with (job_yaml_path )
668
+ mock_alias .assert_called_once_with ("some_alias" , AliasType .JOB )
669
+ assert logger .level == logging .CRITICAL
670
+
671
+
563
672
def test_launch_run_job_dev_confirm (
564
673
app , mock_launcher , mock_pool , mock_version , mock_confirm , mock_git_root , mock_fetch
565
674
):
0 commit comments