19
19
TorchxEvent ,
20
20
)
21
21
22
+ SESSION_ID = "123"
23
+
22
24
23
25
class TorchxEventLibTest (unittest .TestCase ):
24
26
def assert_event (
@@ -44,14 +46,14 @@ def test_get_or_create_logger(self, logging_handler_mock: MagicMock) -> None:
44
46
def test_event_created (self ) -> None :
45
47
test_metadata = {"test_key" : "test_value" }
46
48
event = TorchxEvent (
47
- session = "test_session" ,
49
+ session = SESSION_ID ,
48
50
scheduler = "test_scheduler" ,
49
51
api = "test_api" ,
50
52
app_image = "test_app_image" ,
51
53
app_metadata = test_metadata ,
52
54
workspace = "test_workspace" ,
53
55
)
54
- self .assertEqual ("test_session" , event .session )
56
+ self .assertEqual (SESSION_ID , event .session )
55
57
self .assertEqual ("test_scheduler" , event .scheduler )
56
58
self .assertEqual ("test_api" , event .api )
57
59
self .assertEqual ("test_app_image" , event .app_image )
@@ -76,6 +78,7 @@ def test_event_deser(self) -> None:
76
78
77
79
78
80
@patch ("torchx.runner.events.record" )
81
+ @patch ("torchx.runner.events.get_session_id_or_create_new" )
79
82
class LogEventTest (unittest .TestCase ):
80
83
def assert_torchx_event (self , expected : TorchxEvent , actual : TorchxEvent ) -> None :
81
84
self .assertEqual (expected .session , actual .session )
@@ -86,7 +89,10 @@ def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> Non
86
89
self .assertEqual (expected .workspace , actual .workspace )
87
90
self .assertEqual (expected .app_metadata , actual .app_metadata )
88
91
89
- def test_create_context (self , _ ) -> None :
92
+ def test_create_context (
93
+ self , get_session_id_or_create_new_mock : MagicMock , record_mock : MagicMock
94
+ ) -> None :
95
+ get_session_id_or_create_new_mock .return_value = SESSION_ID
90
96
test_dict = {"test_key" : "test_value" }
91
97
cfg = json .dumps (test_dict )
92
98
context = log_event (
@@ -99,7 +105,7 @@ def test_create_context(self, _) -> None:
99
105
workspace = "test_workspace" ,
100
106
)
101
107
expected_torchx_event = TorchxEvent (
102
- "test_app_id" ,
108
+ SESSION_ID ,
103
109
"local" ,
104
110
"test_call" ,
105
111
"test_app_id" ,
@@ -111,7 +117,10 @@ def test_create_context(self, _) -> None:
111
117
112
118
self .assert_torchx_event (expected_torchx_event , context ._torchx_event )
113
119
114
- def test_record_event (self , record_mock : MagicMock ) -> None :
120
+ def test_record_event (
121
+ self , get_session_id_or_create_new_mock : MagicMock , record_mock : MagicMock
122
+ ) -> None :
123
+ get_session_id_or_create_new_mock .return_value = SESSION_ID
115
124
test_dict = {"test_key" : "test_value" }
116
125
cfg = json .dumps (test_dict )
117
126
with log_event (
@@ -126,7 +135,7 @@ def test_record_event(self, record_mock: MagicMock) -> None:
126
135
pass
127
136
128
137
expected_torchx_event = TorchxEvent (
129
- "test_app_id" ,
138
+ SESSION_ID ,
130
139
"local" ,
131
140
"test_call" ,
132
141
"test_app_id" ,
@@ -139,7 +148,9 @@ def test_record_event(self, record_mock: MagicMock) -> None:
139
148
)
140
149
self .assert_torchx_event (expected_torchx_event , ctx ._torchx_event )
141
150
142
- def test_record_event_with_exception (self , record_mock : MagicMock ) -> None :
151
+ def test_record_event_with_exception (
152
+ self , get_session_id_or_create_new_mock : MagicMock , record_mock : MagicMock
153
+ ) -> None :
143
154
cfg = json .dumps ({"test_key" : "test_value" })
144
155
with self .assertRaises (RuntimeError ):
145
156
with log_event ("test_call" , "local" , "test_app_id" , cfg ) as ctx :
0 commit comments