diff --git a/alphatrion/log/log.py b/alphatrion/log/log.py
index 2acbf8d..f8c5030 100644
--- a/alphatrion/log/log.py
+++ b/alphatrion/log/log.py
@@ -49,7 +49,7 @@ async def log_params(params: dict):
 
 # log_metrics is used to log a set of metrics at once,
 # metric key must be string, value must be float.
-# If save_best_only is enabled in the trial config, and the metric is the best metric
+# If save_on_best is enabled in the trial config, and the metric is the best metric
 # so far, the trial will checkpoint the current data.
 async def log_metrics(metrics: dict[str, float]):
     runtime = global_runtime()
diff --git a/alphatrion/trial/trial.py b/alphatrion/trial/trial.py
index 4a165a4..37ae632 100644
--- a/alphatrion/trial/trial.py
+++ b/alphatrion/trial/trial.py
@@ -3,7 +3,7 @@ import uuid
 
 from datetime import UTC, datetime
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, model_validator
 
 from alphatrion.metadata.sql_models import COMPLETED_STATUS, TrialStatus
 from alphatrion.runtime.runtime import global_runtime
@@ -21,16 +21,16 @@ class CheckpointConfig(BaseModel):
         description="Whether to enable checkpointing. \
             Default is False.",
     )
-    save_every_n_seconds: int | None = Field(
-        default=None,
-        description="Interval in seconds to save checkpoints. \
-            Default is None.",
-    )
-    save_every_n_steps: int | None = Field(
-        default=None,
-        description="Interval in steps to save checkpoints. \
-            Default is None.",
-    )
+    # save_every_n_seconds: int | None = Field(
+    #     default=None,
+    #     description="Interval in seconds to save checkpoints. \
+    #         Default is None.",
+    # )
+    # save_every_n_steps: int | None = Field(
+    #     default=None,
+    #     description="Interval in steps to save checkpoints. \
+    #         Default is None.",
+    # )
     save_on_best: bool = Field(
         default=False,
         description="Once a best result is found, it will be saved. \
             Default is False.",
     )
@@ -52,12 +52,11 @@ class CheckpointConfig(BaseModel):
         description="The path to save checkpoints. \
             Default is 'checkpoints'.",
     )
-    @field_validator("monitor_metric")
-    def metric_must_be_valid(cls, v, info):
-        save_best_only = info.data.get("save_best_only")
-        if save_best_only and v is None:
-            raise ValueError("metric must be specified when save_best_only=True")
-        return v
+    @model_validator(mode="after")
+    def metric_must_be_valid(self):
+        if self.save_on_best and not self.monitor_metric:
+            raise ValueError("monitor_metric must be specified when save_on_best=True")
+        return self
 
 
 class TrialConfig(BaseModel):
diff --git a/tests/integration/test_log_functions.py b/tests/integration/test_log.py
similarity index 75%
rename from tests/integration/test_log_functions.py
rename to tests/integration/test_log.py
index e22551e..0155ba8 100644
--- a/tests/integration/test_log_functions.py
+++ b/tests/integration/test_log.py
@@ -125,7 +125,7 @@ async def test_log_metrics():
 
 
 @pytest.mark.asyncio
-async def test_log_metrics_with_save_best_only():
+async def test_log_metrics_with_save_on_max():
     alpha.init(project_id="test_project", artifact_insecure=True)
 
     async with alpha.CraftExperiment.run(
@@ -137,7 +137,7 @@
             os.chdir(tmpdir)
 
             _ = exp.start_trial(
-                description="Trial with save_best_only",
+                description="Trial with save_on_best",
                 config=TrialConfig(
                     checkpoint=CheckpointConfig(
                         enabled=True,
@@ -176,3 +176,57 @@
             await alpha.log_metrics({"accuracy2": 0.98})
             versions = exp._runtime._artifact.list_versions(exp.id)
             assert len(versions) == 2
+
+
+@pytest.mark.asyncio
+async def test_log_metrics_with_save_on_min():
+    alpha.init(project_id="test_project", artifact_insecure=True)
+
+    async with alpha.CraftExperiment.run(
+        name="context_exp",
+        description="Context manager test",
+        meta={"key": "value"},
+    ) as exp:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            os.chdir(tmpdir)
+
+            _ = exp.start_trial(
+                description="Trial with save_on_best",
+                config=TrialConfig(
+                    checkpoint=CheckpointConfig(
+                        enabled=True,
+                        path=tmpdir,
+                        save_on_best=True,
+                        monitor_metric="accuracy",
+                        monitor_mode="min",
+                    )
+                ),
+            )
+
+            file1 = "file1.txt"
+            with open(file1, "w") as f:
+                f.write("This is file1.")
+
+            await alpha.log_metrics({"accuracy": 0.30})
+
+            versions = exp._runtime._artifact.list_versions(exp.id)
+            assert len(versions) == 1
+
+            # To avoid the same timestamp hash, we wait for 1 second
+            time.sleep(1)
+
+            await alpha.log_metrics({"accuracy": 0.58})
+            versions = exp._runtime._artifact.list_versions(exp.id)
+            assert len(versions) == 1
+
+            time.sleep(1)
+
+            await alpha.log_metrics({"accuracy": 0.21})
+            versions = exp._runtime._artifact.list_versions(exp.id)
+            assert len(versions) == 2
+
+            time.sleep(1)
+
+            await alpha.log_metrics({"accuracy2": 0.18})
+            versions = exp._runtime._artifact.list_versions(exp.id)
+            assert len(versions) == 2
diff --git a/tests/unit/trial/test_trial.py b/tests/unit/trial/test_trial.py
index 729f55a..37e77c5 100644
--- a/tests/unit/trial/test_trial.py
+++ b/tests/unit/trial/test_trial.py
@@ -2,7 +2,50 @@ import unittest
 
 from datetime import UTC, datetime, timedelta
 
-from alphatrion.trial.trial import Trial, TrialConfig
+from alphatrion.trial.trial import CheckpointConfig, Trial, TrialConfig
+
+
+class TestCheckpointConfig(unittest.TestCase):
+    def test_invalid_monitor_metric(self):
+        test_cases = [
+            {
+                "name": "Valid metric with save_on_best True",
+                "config": {
+                    "enabled": True,
+                    "save_on_best": True,
+                    "monitor_mode": "max",
+                    "monitor_metric": "accuracy",
+                },
+                "error": False,
+            },
+            {
+                "name": "Invalid metric with save_on_best True",
+                "config": {
+                    "enabled": True,
+                    "save_on_best": True,
+                    "monitor_mode": "max",
+                },
+                "error": True,
+            },
+            {
+                "name": "Valid metric with save_on_best False",
+                "config": {
+                    "enabled": True,
+                    "save_on_best": False,
+                    "monitor_mode": "max",
+                    "monitor_metric": "accuracy",
+                },
+                "error": False,
+            },
+        ]
+
+        for case in test_cases:
+            with self.subTest(name=case["name"]):
+                if case["error"]:
+                    with self.assertRaises(ValueError):
+                        CheckpointConfig(**case["config"])
+                else:
+                    _ = CheckpointConfig(**case["config"])
 
 
 class TestTrial(unittest.IsolatedAsyncioTestCase):
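
Reviewer note, not part of the diff: a minimal sketch of how the reworked validation behaves from the caller's side. It uses only the fields that appear in this change; the "loss" metric name is illustrative, and pydantic's ValidationError (a ValueError subclass, which is why the unit test's assertRaises(ValueError) passes) is the assumed failure mode.

from pydantic import ValidationError

from alphatrion.trial.trial import CheckpointConfig, TrialConfig

# Valid: save_on_best=True together with an explicit monitor_metric.
config = TrialConfig(
    checkpoint=CheckpointConfig(
        enabled=True,
        path="checkpoints",
        save_on_best=True,
        monitor_metric="loss",
        monitor_mode="min",
    )
)

# Invalid: save_on_best=True without monitor_metric is now rejected at
# construction time by the model_validator added in this change.
try:
    CheckpointConfig(enabled=True, save_on_best=True, monitor_mode="max")
except ValidationError as err:
    print(err)  # message includes "monitor_metric must be specified when save_on_best=True"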