
Commit e8722e6

add more tests (#37)
Signed-off-by: kerthcet <[email protected]>
1 parent 04611dc commit e8722e6

File tree: 4 files changed (+117 −21 lines)

alphatrion/log/log.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ async def log_params(params: dict):
 
 # log_metrics is used to log a set of metrics at once,
 # metric key must be string, value must be float.
-# If save_best_only is enabled in the trial config, and the metric is the best metric
+# If save_on_best is enabled in the trial config, and the metric is the best metric
 # so far, the trial will checkpoint the current data.
 async def log_metrics(metrics: dict[str, float]):
     runtime = global_runtime()
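For orientation, a minimal usage sketch of the behavior this comment describes, pieced together from the integration tests added in this commit. The `alphatrion as alpha` import alias, the `demo` coroutine, and the project/experiment names are illustrative assumptions; only the checkpoint-on-best behavior itself is what the tests below exercise.

import alphatrion as alpha  # assumed alias; the tests refer to the package as "alpha"
from alphatrion.trial.trial import CheckpointConfig, TrialConfig

async def demo():
    alpha.init(project_id="demo_project", artifact_insecure=True)
    async with alpha.CraftExperiment.run(name="demo_exp", description="save_on_best demo") as exp:
        exp.start_trial(
            description="Trial with save_on_best",
            config=TrialConfig(
                checkpoint=CheckpointConfig(
                    enabled=True,
                    save_on_best=True,
                    monitor_metric="accuracy",
                    monitor_mode="max",
                )
            ),
        )
        # A new best "accuracy" value checkpoints the current data ...
        await alpha.log_metrics({"accuracy": 0.90})
        # ... while a worse value does not create a new checkpoint.
        await alpha.log_metrics({"accuracy": 0.85})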

alphatrion/trial/trial.py

Lines changed: 16 additions & 17 deletions
@@ -3,7 +3,7 @@
 import uuid
 from datetime import UTC, datetime
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, model_validator
 
 from alphatrion.metadata.sql_models import COMPLETED_STATUS, TrialStatus
 from alphatrion.runtime.runtime import global_runtime
@@ -21,16 +21,16 @@ class CheckpointConfig(BaseModel):
         description="Whether to enable checkpointing. \
             Default is False.",
     )
-    save_every_n_seconds: int | None = Field(
-        default=None,
-        description="Interval in seconds to save checkpoints. \
-            Default is None.",
-    )
-    save_every_n_steps: int | None = Field(
-        default=None,
-        description="Interval in steps to save checkpoints. \
-            Default is None.",
-    )
+    # save_every_n_seconds: int | None = Field(
+    #     default=None,
+    #     description="Interval in seconds to save checkpoints. \
+    #         Default is None.",
+    # )
+    # save_every_n_steps: int | None = Field(
+    #     default=None,
+    #     description="Interval in steps to save checkpoints. \
+    #         Default is None.",
+    # )
     save_on_best: bool = Field(
         default=False,
         description="Once a best result is found, it will be saved. \
@@ -52,12 +52,11 @@ class CheckpointConfig(BaseModel):
         description="The path to save checkpoints. Default is 'checkpoints'.",
     )
 
-    @field_validator("monitor_metric")
-    def metric_must_be_valid(cls, v, info):
-        save_best_only = info.data.get("save_best_only")
-        if save_best_only and v is None:
-            raise ValueError("metric must be specified when save_best_only=True")
-        return v
+    @model_validator(mode="after")
+    def metric_must_be_valid(self):
+        if self.save_on_best and not self.monitor_metric:
+            raise ValueError("monitor_metric must be specified when save_on_best=True")
+        return self
 
 
 class TrialConfig(BaseModel):
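Moving from a per-field `field_validator` to `model_validator(mode="after")` lets the check run after the whole model is constructed, so it reads `save_on_best` directly instead of depending on field declaration order, and it now uses the renamed flag. A small sketch of the resulting behavior, mirroring the cases covered by the new unit test further down:

from alphatrion.trial.trial import CheckpointConfig

# Accepted: save_on_best=True together with a monitored metric.
CheckpointConfig(enabled=True, save_on_best=True, monitor_mode="max", monitor_metric="accuracy")

# Raises ValueError: save_on_best=True without a monitor_metric.
CheckpointConfig(enabled=True, save_on_best=True, monitor_mode="max")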

tests/integration/test_log_functions.py renamed to tests/integration/test_log.py

Lines changed: 56 additions & 2 deletions
@@ -125,7 +125,7 @@ async def test_log_metrics():
 
 
 @pytest.mark.asyncio
-async def test_log_metrics_with_save_best_only():
+async def test_log_metrics_with_save_on_max():
     alpha.init(project_id="test_project", artifact_insecure=True)
 
     async with alpha.CraftExperiment.run(
@@ -137,7 +137,7 @@ async def test_log_metrics_with_save_best_only():
             os.chdir(tmpdir)
 
             _ = exp.start_trial(
-                description="Trial with save_best_only",
+                description="Trial with save_on_best",
                 config=TrialConfig(
                     checkpoint=CheckpointConfig(
                         enabled=True,
@@ -176,3 +176,57 @@ async def test_log_metrics_with_save_best_only():
             await alpha.log_metrics({"accuracy2": 0.98})
             versions = exp._runtime._artifact.list_versions(exp.id)
             assert len(versions) == 2
+
+
+@pytest.mark.asyncio
+async def test_log_metrics_with_save_on_min():
+    alpha.init(project_id="test_project", artifact_insecure=True)
+
+    async with alpha.CraftExperiment.run(
+        name="context_exp",
+        description="Context manager test",
+        meta={"key": "value"},
+    ) as exp:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            os.chdir(tmpdir)
+
+            _ = exp.start_trial(
+                description="Trial with save_on_best",
+                config=TrialConfig(
+                    checkpoint=CheckpointConfig(
+                        enabled=True,
+                        path=tmpdir,
+                        save_on_best=True,
+                        monitor_metric="accuracy",
+                        monitor_mode="min",
+                    )
+                ),
+            )
+
+            file1 = "file1.txt"
+            with open(file1, "w") as f:
+                f.write("This is file1.")
+
+            await alpha.log_metrics({"accuracy": 0.30})
+
+            versions = exp._runtime._artifact.list_versions(exp.id)
+            assert len(versions) == 1
+
+            # To avoid the same timestamp hash, we wait for 1 second
+            time.sleep(1)
+
+            await alpha.log_metrics({"accuracy": 0.58})
+            versions = exp._runtime._artifact.list_versions(exp.id)
+            assert len(versions) == 1
+
+            time.sleep(1)
+
+            await alpha.log_metrics({"accuracy": 0.21})
+            versions = exp._runtime._artifact.list_versions(exp.id)
+            assert len(versions) == 2
+
+            time.sleep(1)
+
+            await alpha.log_metrics({"accuracy2": 0.18})
+            versions = exp._runtime._artifact.list_versions(exp.id)
+            assert len(versions) == 2

tests/unit/trial/test_trial.py

Lines changed: 44 additions & 1 deletion
@@ -2,7 +2,50 @@
 import unittest
 from datetime import UTC, datetime, timedelta
 
-from alphatrion.trial.trial import Trial, TrialConfig
+from alphatrion.trial.trial import CheckpointConfig, Trial, TrialConfig
+
+
+class TestCheckpointConfig(unittest.TestCase):
+    def test_invalid_monitor_metric(self):
+        test_cases = [
+            {
+                "name": "Valid metric with save_on_best True",
+                "config": {
+                    "enabled": True,
+                    "save_on_best": True,
+                    "monitor_mode": "max",
+                    "monitor_metric": "accuracy",
+                },
+                "error": False,
+            },
+            {
+                "name": "Invalid metric with save_on_best True",
+                "config": {
+                    "enabled": True,
+                    "save_on_best": True,
+                    "monitor_mode": "max",
+                },
+                "error": True,
+            },
+            {
+                "name": "Valid metric with save_on_best False",
+                "config": {
+                    "enabled": True,
+                    "save_on_best": False,
+                    "monitor_mode": "max",
+                    "monitor_metric": "accuracy",
+                },
+                "error": False,
+            },
+        ]
+
+        for case in test_cases:
+            with self.subTest(name=case["name"]):
+                if case["error"]:
+                    with self.assertRaises(ValueError):
+                        CheckpointConfig(**case["config"])
+                else:
+                    _ = CheckpointConfig(**case["config"])
 
 
 class TestTrial(unittest.IsolatedAsyncioTestCase):
