2 changes: 1 addition & 1 deletion alphatrion/log/log.py
@@ -49,7 +49,7 @@ async def log_params(params: dict):

# log_metrics is used to log a set of metrics at once,
# metric key must be string, value must be float.
-# If save_best_only is enabled in the trial config, and the metric is the best metric
+# If save_on_best is enabled in the trial config, and the metric is the best metric
# so far, the trial will checkpoint the current data.
async def log_metrics(metrics: dict[str, float]):
runtime = global_runtime()
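For context, a minimal sketch of the call the updated comment describes, assuming an initialized runtime and an active trial whose CheckpointConfig has save_on_best enabled (mirroring the integration tests further down):

await alpha.log_metrics({"accuracy": 0.95})  # keys must be str, values float
# If 0.95 is the best value of the monitored metric seen so far,
# the trial checkpoints the current data; otherwise nothing is saved.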
33 changes: 16 additions & 17 deletions alphatrion/trial/trial.py
@@ -3,7 +3,7 @@
import uuid
from datetime import UTC, datetime

-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, model_validator

from alphatrion.metadata.sql_models import COMPLETED_STATUS, TrialStatus
from alphatrion.runtime.runtime import global_runtime
@@ -21,16 +21,16 @@ class CheckpointConfig(BaseModel):
description="Whether to enable checkpointing. \
Default is False.",
)
-save_every_n_seconds: int | None = Field(
-default=None,
-description="Interval in seconds to save checkpoints. \
-Default is None.",
-)
-save_every_n_steps: int | None = Field(
-default=None,
-description="Interval in steps to save checkpoints. \
-Default is None.",
-)
+# save_every_n_seconds: int | None = Field(
+#     default=None,
+#     description="Interval in seconds to save checkpoints. \
+#         Default is None.",
+# )
+# save_every_n_steps: int | None = Field(
+#     default=None,
+#     description="Interval in steps to save checkpoints. \
+#         Default is None.",
+# )
save_on_best: bool = Field(
default=False,
description="Once a best result is found, it will be saved. \
@@ -52,12 +52,11 @@ class CheckpointConfig(BaseModel):
description="The path to save checkpoints. Default is 'checkpoints'.",
)

-@field_validator("monitor_metric")
-def metric_must_be_valid(cls, v, info):
-save_best_only = info.data.get("save_best_only")
-if save_best_only and v is None:
-raise ValueError("metric must be specified when save_best_only=True")
-return v
+@model_validator(mode="after")
+def metric_must_be_valid(self):
+if self.save_on_best and not self.monitor_metric:
+raise ValueError("monitor_metric must be specified when save_on_best=True")
+return self


class TrialConfig(BaseModel):
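A minimal sketch of how the new model-level validator behaves at construction time; the field names are taken from the diff above, and pydantic surfaces the ValueError raised in the validator as a ValidationError (which is why the new unit test below can assert ValueError):

from pydantic import ValidationError

from alphatrion.trial.trial import CheckpointConfig

# Valid: save_on_best together with a monitored metric.
cfg = CheckpointConfig(
    enabled=True,
    save_on_best=True,
    monitor_metric="accuracy",
    monitor_mode="max",
)

# Invalid: save_on_best without monitor_metric is rejected only after all fields
# are populated, which is what the switch to model_validator(mode="after") enables.
try:
    CheckpointConfig(enabled=True, save_on_best=True, monitor_mode="max")
except ValidationError as err:
    print(err)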
@@ -125,7 +125,7 @@ async def test_log_metrics():


@pytest.mark.asyncio
-async def test_log_metrics_with_save_best_only():
+async def test_log_metrics_with_save_on_max():
alpha.init(project_id="test_project", artifact_insecure=True)

async with alpha.CraftExperiment.run(
@@ -137,7 +137,7 @@ async def test_log_metrics_with_save_best_only():
os.chdir(tmpdir)

_ = exp.start_trial(
description="Trial with save_best_only",
description="Trial with save_on_best",
config=TrialConfig(
checkpoint=CheckpointConfig(
enabled=True,
@@ -176,3 +176,57 @@ async def test_log_metrics_with_save_best_only():
await alpha.log_metrics({"accuracy2": 0.98})
versions = exp._runtime._artifact.list_versions(exp.id)
assert len(versions) == 2


@pytest.mark.asyncio
async def test_log_metrics_with_save_on_min():
alpha.init(project_id="test_project", artifact_insecure=True)

async with alpha.CraftExperiment.run(
name="context_exp",
description="Context manager test",
meta={"key": "value"},
) as exp:
with tempfile.TemporaryDirectory() as tmpdir:
os.chdir(tmpdir)

_ = exp.start_trial(
description="Trial with save_on_best",
config=TrialConfig(
checkpoint=CheckpointConfig(
enabled=True,
path=tmpdir,
save_on_best=True,
monitor_metric="accuracy",
monitor_mode="min",
)
),
)

file1 = "file1.txt"
with open(file1, "w") as f:
f.write("This is file1.")

await alpha.log_metrics({"accuracy": 0.30})

versions = exp._runtime._artifact.list_versions(exp.id)
assert len(versions) == 1

# To avoid the same timestamp hash, we wait for 1 second
time.sleep(1)

await alpha.log_metrics({"accuracy": 0.58})
versions = exp._runtime._artifact.list_versions(exp.id)
assert len(versions) == 1

time.sleep(1)

await alpha.log_metrics({"accuracy": 0.21})
versions = exp._runtime._artifact.list_versions(exp.id)
assert len(versions) == 2

time.sleep(1)

await alpha.log_metrics({"accuracy2": 0.18})
versions = exp._runtime._artifact.list_versions(exp.id)
assert len(versions) == 2
45 changes: 44 additions & 1 deletion tests/unit/trial/test_trial.py
@@ -2,7 +2,50 @@
import unittest
from datetime import UTC, datetime, timedelta

-from alphatrion.trial.trial import Trial, TrialConfig
+from alphatrion.trial.trial import CheckpointConfig, Trial, TrialConfig


class TestCheckpointConfig(unittest.TestCase):
def test_invalid_monitor_metric(self):
test_cases = [
{
"name": "Valid metric with save_on_best True",
"config": {
"enabled": True,
"save_on_best": True,
"monitor_mode": "max",
"monitor_metric": "accuracy",
},
"error": False,
},
{
"name": "Invalid metric with save_on_best True",
"config": {
"enabled": True,
"save_on_best": True,
"monitor_mode": "max",
},
"error": True,
},
{
"name": "Valid metric with save_on_best False",
"config": {
"enabled": True,
"save_on_best": False,
"monitor_mode": "max",
"monitor_metric": "accuracy",
},
"error": False,
},
]

for case in test_cases:
with self.subTest(name=case["name"]):
if case["error"]:
with self.assertRaises(ValueError):
CheckpointConfig(**case["config"])
else:
_ = CheckpointConfig(**case["config"])


class TestTrial(unittest.IsolatedAsyncioTestCase):