32
32
from torchx .specs .finder import ComponentNotFoundException
33
33
from torchx .test .fixtures import TestWithTmpDir
34
34
from torchx .tracker .api import ENV_TORCHX_JOB_ID , ENV_TORCHX_PARENT_RUN_ID
35
+ from torchx .util .session import get_session_id
35
36
from torchx .util .types import none_throws
36
37
from torchx .workspace import WorkspaceMixin
37
38
@@ -51,7 +52,7 @@ def get_full_path(name: str) -> str:
51
52
return os .path .join (os .path .dirname (__file__ ), "resource" , name )
52
53
53
54
54
- @patch ("torchx.runner.api.log_event " )
55
+ @patch ("torchx.runner.events.record " )
55
56
class RunnerTest (TestWithTmpDir ):
56
57
def setUp (self ) -> None :
57
58
super ().setUp ()
@@ -104,7 +105,38 @@ def test_validate_invalid_replicas(self, _) -> None:
104
105
with self .assertRaises (ValueError ):
105
106
runner .run (app , scheduler = "local_dir" )
106
107
107
- def test_run (self , _ ) -> None :
108
+ @patch ("torchx.util.session.uuid" )
109
+ def test_session_id (self , uuid_mock : MagicMock , record_mock : MagicMock ) -> None :
110
+ uuid_mock .uuid4 .return_value = "test_session_id"
111
+ test_file = self .tmpdir / "test_file"
112
+
113
+ with self .get_runner () as runner :
114
+ self .assertEqual (1 , len (runner .scheduler_backends ()))
115
+ role = Role (
116
+ name = "touch" ,
117
+ image = str (self .tmpdir ),
118
+ resource = resource .SMALL ,
119
+ entrypoint = "touch.sh" ,
120
+ args = [str (test_file )],
121
+ )
122
+ app = AppDef ("name" , roles = [role ])
123
+
124
+ app_handle_1 = runner .run (app , scheduler = "local_dir" , cfg = self .cfg )
125
+ none_throws (runner .wait (app_handle_1 , wait_interval = 0.1 ))
126
+
127
+ app_handle_2 = runner .run (app , scheduler = "local_dir" , cfg = self .cfg )
128
+ none_throws (runner .wait (app_handle_2 , wait_interval = 0.1 ))
129
+
130
+ self .assertEqual (get_session_id (), "test_session_id" )
131
+ uuid_mock .uuid4 .assert_called_once ()
132
+ record_mock .assert_called ()
133
+ for i in range (record_mock .call_count ):
134
+ event = record_mock .call_args_list [i ].args [0 ]
135
+ self .assertEqual (event .session , "test_session_id" )
136
+
137
+ @patch ("torchx.util.session.uuid" )
138
+ def test_run (self , uuid_mock : MagicMock , _ ) -> None :
139
+ uuid_mock .uuid4 .return_value = "test_session_id"
108
140
test_file = self .tmpdir / "test_file"
109
141
110
142
with self .get_runner () as runner :
@@ -121,8 +153,11 @@ def test_run(self, _) -> None:
121
153
app_handle = runner .run (app , scheduler = "local_dir" , cfg = self .cfg )
122
154
app_status = none_throws (runner .wait (app_handle , wait_interval = 0.1 ))
123
155
self .assertEqual (AppState .SUCCEEDED , app_status .state )
156
+ self .assertEqual (get_session_id (), "test_session_id" )
124
157
125
- def test_dryrun (self , _ ) -> None :
158
+ @patch ("torchx.util.session.uuid" )
159
+ def test_dryrun (self , uuid_mock : MagicMock , _ ) -> None :
160
+ uuid_mock .uuid4 .return_value = "test_session_id"
126
161
scheduler_mock = MagicMock ()
127
162
scheduler_mock .run_opts .return_value .resolve .return_value = {
128
163
** self .cfg ,
@@ -145,6 +180,7 @@ def test_dryrun(self, _) -> None:
145
180
app , {** self .cfg , "foo" : "bar" }
146
181
)
147
182
scheduler_mock ._validate .assert_called_once ()
183
+ self .assertEqual (get_session_id (), "test_session_id" )
148
184
149
185
def test_dryrun_env_variables (self , _ ) -> None :
150
186
scheduler_mock = MagicMock ()
0 commit comments