Skip to content

Commit 32e0e33

Browse files
authored
Support with context manager (#18)
* Add ExperimentConfig Signed-off-by: kerthcet <[email protected]> * Add context manager for quick start Signed-off-by: kerthcet <[email protected]> --------- Signed-off-by: kerthcet <[email protected]>
1 parent 6544cc8 commit 32e0e33

File tree

10 files changed

+410
-105
lines changed

10 files changed

+410
-105
lines changed

alphatrion/experiment/base.py

Lines changed: 188 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,212 @@
1-
from abc import ABC, abstractmethod
1+
import uuid
2+
from datetime import UTC, datetime
23

4+
from pydantic import BaseModel, Field, field_validator
5+
6+
from alphatrion.artifact.artifact import Artifact
7+
from alphatrion.metadata.sql_models import COMPLETED_STATUS, ExperimentStatus
38
from alphatrion.runtime.runtime import Runtime
49

510

6-
class Experiment(ABC):
7-
"""Base class for all experiments."""
11+
class CheckpointConfig(BaseModel):
12+
"""Configuration for a checkpoint."""
13+
14+
enabled: bool = Field(
15+
default=True,
16+
description="Whether to enable checkpointing. \
17+
Default is True. One exception is CraftExperiment, \
18+
which doesn't enable checkpoint by default.",
19+
)
20+
save_every_n_seconds: int = Field(
21+
default=300,
22+
description="Interval in seconds to save checkpoints. \
23+
Default is 300 seconds.",
24+
)
25+
save_every_n_steps: int = Field(
26+
default=0,
27+
description="Interval in steps to save checkpoints. \
28+
Default is 0 (disabled).",
29+
)
30+
save_best_only: bool = Field(
31+
default=True,
32+
description="Once a best result is found, it will be saved. Default is True. \
33+
Can be enabled together with save_every_n_steps/save_every_n_seconds.",
34+
)
35+
monitor_metric: str = Field(
36+
default=None,
37+
description="The metric to monitor for saving the best checkpoint. \
38+
Required if save_best_only is True.",
39+
)
40+
monitor_mode: str = Field(
41+
default="max",
42+
description="The mode for monitoring the metric. Can be 'max' or 'min'. \
43+
Default is 'max'.",
44+
)
45+
46+
@field_validator("monitor_metric")
47+
def metric_must_be_valid(cls, v, info):
48+
save_best_only = info.data.get("save_best_only")
49+
if save_best_only and v is None:
50+
raise ValueError("metric must be specified when save_best_only=True")
51+
return v
52+
53+
54+
class ExperimentConfig(BaseModel):
55+
"""Configuration for an experiment."""
856

9-
def __init__(self, runtime: Runtime):
57+
max_duration_seconds: int = Field(
58+
default=86400,
59+
description="Maximum duration in seconds for the experiment. \
60+
Default is 86400 seconds (1 day).",
61+
)
62+
max_retries: int = Field(
63+
default=0,
64+
description="Maximum number of retries for the experiment. \
65+
Default is 0 (no retries).",
66+
)
67+
checkpoint: CheckpointConfig = Field(
68+
default=CheckpointConfig(),
69+
description="Configuration for checkpointing.",
70+
)
71+
72+
73+
class Experiment:
74+
"""Base Experiment class."""
75+
76+
def __init__(self, runtime: Runtime, config: ExperimentConfig | None = None):
1077
self._runtime = runtime
78+
self._artifact = Artifact(runtime)
79+
self._config = config or ExperimentConfig()
80+
self._steps = 0
81+
self._start_at = None
82+
self._best_metric_value = None
83+
84+
@classmethod
85+
def run(
86+
cls,
87+
project_id: str,
88+
name: str | None = None,
89+
description: str | None = None,
90+
meta: dict | None = None,
91+
labels: dict | None = None,
92+
):
93+
runtime = Runtime(project_id=project_id)
94+
exp = cls(runtime=runtime)
95+
return RunContext(
96+
exp, name=name, description=description, meta=meta, labels=labels
97+
)
1198

12-
@abstractmethod
1399
def create(
14100
self,
15101
name: str,
16102
description: str | None = None,
17103
meta: dict | None = None,
18104
labels: dict | None = None,
19105
):
20-
raise NotImplementedError("Subclasses must implement this method.")
106+
exp_id = self._runtime._metadb.create_exp(
107+
name=name,
108+
description=description,
109+
project_id=self._runtime._project_id,
110+
meta=meta,
111+
labels=labels,
112+
)
21113

22-
@abstractmethod
23-
def delete(self, exp_id: int):
24-
raise NotImplementedError("Subclasses must implement this method.")
114+
return exp_id
25115

26-
@abstractmethod
27116
def get(self, exp_id: int):
28-
raise NotImplementedError("Subclasses must implement this method.")
117+
return self._runtime._metadb.get_exp(exp_id=exp_id)
29118

30-
@abstractmethod
31119
def list(self, page: int = 0, page_size: int = 10):
32-
raise NotImplementedError("Subclasses must implement this method.")
120+
return self._runtime._metadb.list_exps(
121+
project_id=self._runtime._project_id, page=page, page_size=page_size
122+
)
123+
124+
# TODO: delete related artifacts too. But for google artifact registry,
125+
# it seems not supported to delete a tag only.
126+
# See issue: https://github.com/InftyAI/alphatrion/issues/14
127+
def delete(self, exp_id: int):
128+
self._runtime._metadb.delete_exp(exp_id=exp_id)
33129

34-
@abstractmethod
130+
# Please provide all the labels to update, or it will overwrite the existing labels.
35131
def update_labels(self, exp_id: int, labels: dict):
36-
raise NotImplementedError("Subclasses must implement this method.")
132+
self._runtime._metadb.update_exp(exp_id=exp_id, labels=labels)
37133

38-
@abstractmethod
39-
def start(self, exp_id: int):
40-
raise NotImplementedError("Subclasses must implement this method.")
134+
def start(
135+
self,
136+
name: str | None = None,
137+
description: str | None = None,
138+
meta: dict | None = None,
139+
labels: dict | None = None,
140+
) -> int:
141+
if name is None:
142+
name = f"{uuid.uuid4()}"
143+
144+
exp_id = self.create(
145+
name=name,
146+
description=description,
147+
meta=meta,
148+
labels=labels,
149+
)
150+
self._runtime._metadb.update_exp(exp_id=exp_id, status=ExperimentStatus.RUNNING)
151+
self._start_at = datetime.now(UTC)
152+
return exp_id
153+
154+
def stop(self, exp_id: int, status: ExperimentStatus = ExperimentStatus.FINISHED):
155+
exp = self._runtime._metadb.get_exp(exp_id=exp_id)
156+
if exp is not None and exp.status not in COMPLETED_STATUS:
157+
duration = (datetime.now() - exp.created_at).total_seconds()
158+
self._runtime._metadb.update_exp(
159+
exp_id=exp_id, status=status, duration=duration
160+
)
161+
162+
def status(self, exp_id: int) -> ExperimentStatus:
163+
exp = self._runtime._metadb.get_exp(exp_id=exp_id)
164+
return exp.status
165+
166+
def reset(self):
167+
self._steps = 0
168+
self._start_at = None
169+
self._best_metric_value = None
170+
171+
# def save_checkpoint(
172+
# self,
173+
# exp_id: int,
174+
# files: list[str] | None = None,
175+
# folder: str | None = None,
176+
# version: str = "latest",
177+
# ):
178+
# exp = self._runtime._metadb.get_exp(exp_id=exp_id)
179+
# self._artifact.push(
180+
# experiment_name=exp.name, files=files, folder=folder, version=version
181+
# )
182+
183+
184+
class RunContext:
185+
"""A context manager for running experiments."""
186+
187+
def __init__(
188+
self,
189+
experiment: Experiment,
190+
name: str | None = None,
191+
description: str | None = None,
192+
meta: dict | None = None,
193+
labels: dict | None = None,
194+
):
195+
self._experiment = experiment
196+
self._exp_name = name
197+
self._description = description
198+
self._meta = meta
199+
self._labels = labels
41200

42-
@abstractmethod
43-
def stop(self, exp_id: int, status: str = "finished"):
44-
raise NotImplementedError("Subclasses must implement this method.")
201+
def __enter__(self):
202+
self._exp_id = self._experiment.start(
203+
name=self._exp_name,
204+
description=self._description,
205+
meta=self._meta,
206+
labels=self._labels,
207+
)
208+
return self._experiment
45209

46-
@abstractmethod
47-
def status(self, exp_id: int) -> str:
48-
raise NotImplementedError("Subclasses must implement this method.")
210+
def __exit__(self, exc_type, exc_val, exc_tb):
211+
self._experiment.stop(self._exp_id)
212+
self._experiment.reset()

alphatrion/experiment/craft_exp.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from alphatrion.experiment.base import Experiment
2+
from alphatrion.runtime.runtime import Runtime
3+
4+
5+
class CraftExperiment(Experiment):
6+
"""
7+
Craft experiment implementation.
8+
9+
This experiment class offers methods to manage the experiment lifecycle flexibly.
10+
Opposite to other experiment classes, you need to call all these methods yourself.
11+
"""
12+
13+
def __init__(self, runtime: Runtime):
14+
super().__init__(runtime)
15+
# Disable checkpointing by default for CraftExperiment
16+
self._config.checkpoint.enabled = False

alphatrion/experiment/custom_exp.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

alphatrion/metadata/sql.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def create_exp(
2222
description: str | None,
2323
meta: dict | None,
2424
labels: dict | None = None,
25-
):
25+
) -> int:
2626
session = self._session()
2727
new_exp = Experiment(
2828
name=name,
@@ -33,8 +33,12 @@ def create_exp(
3333
)
3434
session.add(new_exp)
3535
session.commit()
36+
37+
exp_id = new_exp.id
3638
session.close()
3739

40+
return exp_id
41+
3842
# Soft delete the experiment now. In the future, we may implement hard delete.
3943
def delete_exp(self, exp_id: int):
4044
session = self._session()

alphatrion/runtime/runtime.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,4 @@
77
class Runtime:
88
def __init__(self, project_id: str):
99
self._project_id = project_id
10-
# TODO: initialize the metadata database based on the URL.
1110
self._metadb = SQLStore(os.getenv(consts.METADATA_DB_URL), init_tables=True)

0 commit comments

Comments
 (0)