|
1 | | -from abc import ABC, abstractmethod |
| 1 | +import uuid |
| 2 | +from datetime import UTC, datetime |
2 | 3 |
|
| 4 | +from pydantic import BaseModel, Field, field_validator |
| 5 | + |
| 6 | +from alphatrion.artifact.artifact import Artifact |
| 7 | +from alphatrion.metadata.sql_models import COMPLETED_STATUS, ExperimentStatus |
3 | 8 | from alphatrion.runtime.runtime import Runtime |
4 | 9 |
|
5 | 10 |
|
6 | | -class Experiment(ABC): |
7 | | - """Base class for all experiments.""" |
class CheckpointConfig(BaseModel):
    """Configuration for a checkpoint.

    Controls whether checkpoints are saved, how often (by wall-clock seconds
    and/or by step count), and whether only the best result, as judged by
    ``monitor_metric`` / ``monitor_mode``, is kept.
    """

    enabled: bool = Field(
        default=True,
        description=(
            "Whether to enable checkpointing. "
            "Default is True. One exception is CraftExperiment, "
            "which doesn't enable checkpoint by default."
        ),
    )
    save_every_n_seconds: int = Field(
        default=300,
        description=(
            "Interval in seconds to save checkpoints. Default is 300 seconds."
        ),
    )
    save_every_n_steps: int = Field(
        default=0,
        description=(
            "Interval in steps to save checkpoints. Default is 0 (disabled)."
        ),
    )
    save_best_only: bool = Field(
        default=True,
        description=(
            "Once a best result is found, it will be saved. Default is True. "
            "Can be enabled together with save_every_n_steps/save_every_n_seconds."
        ),
    )
    # Optional: the default is None, and None is a legitimate explicit value
    # whenever save_best_only is False.
    monitor_metric: str | None = Field(
        default=None,
        description=(
            "The metric to monitor for saving the best checkpoint. "
            "Required if save_best_only is True."
        ),
    )
    monitor_mode: str = Field(
        default="max",
        description=(
            "The mode for monitoring the metric. Can be 'max' or 'min'. "
            "Default is 'max'."
        ),
    )

    @field_validator("monitor_metric")
    @classmethod
    def metric_must_be_valid(cls, v, info):
        """Reject an explicit ``None`` metric when ``save_best_only`` is set.

        ``info.data`` holds the fields validated so far; ``save_best_only`` is
        declared before ``monitor_metric`` so it is available here.

        NOTE(review): pydantic does not run field validators on default
        values unless ``validate_default`` is enabled, so omitting
        ``monitor_metric`` entirely bypasses this check (which is what lets
        ``CheckpointConfig()`` be constructed) -- confirm this is intended.

        Raises:
            ValueError: if ``save_best_only`` is truthy and ``v`` is ``None``.
        """
        save_best_only = info.data.get("save_best_only")
        if save_best_only and v is None:
            raise ValueError("metric must be specified when save_best_only=True")
        return v
| 52 | + |
| 53 | + |
class ExperimentConfig(BaseModel):
    """Configuration for an experiment.

    Bundles the run limits (maximum duration and retry count) with the
    checkpointing configuration.
    """

    max_duration_seconds: int = Field(
        default=86400,
        description=(
            "Maximum duration in seconds for the experiment. "
            "Default is 86400 seconds (1 day)."
        ),
    )
    max_retries: int = Field(
        default=0,
        description=(
            "Maximum number of retries for the experiment. "
            "Default is 0 (no retries)."
        ),
    )
    # Build a fresh CheckpointConfig per ExperimentConfig instead of creating
    # a single instance at class-definition time.
    checkpoint: CheckpointConfig = Field(
        default_factory=CheckpointConfig,
        description="Configuration for checkpointing.",
    )
| 71 | + |
| 72 | + |
class Experiment:
    """Base Experiment class.

    Wraps the runtime's metadata store (``runtime._metadb``) to create, query,
    update, start and stop experiments, and keeps a small amount of per-run
    state (step counter, start time, best metric value) on the instance.
    """

    def __init__(self, runtime: Runtime, config: ExperimentConfig | None = None):
        """Bind the experiment to ``runtime``.

        Args:
            runtime: Runtime providing the metadata store and project id.
            config: Experiment configuration; a default ``ExperimentConfig``
                is used when omitted.
        """
        self._runtime = runtime
        self._artifact = Artifact(runtime)
        self._config = config or ExperimentConfig()
        self._steps = 0
        self._start_at = None
        self._best_metric_value = None

    @classmethod
    def run(
        cls,
        project_id: str,
        name: str | None = None,
        description: str | None = None,
        meta: dict | None = None,
        labels: dict | None = None,
    ):
        """Build a ``RunContext`` for a new experiment in ``project_id``.

        A ``Runtime`` and an experiment instance (with the default config) are
        created here; the experiment itself is only started when the returned
        context manager is entered.
        """
        runtime = Runtime(project_id=project_id)
        exp = cls(runtime=runtime)
        return RunContext(
            exp, name=name, description=description, meta=meta, labels=labels
        )

    def create(
        self,
        name: str,
        description: str | None = None,
        meta: dict | None = None,
        labels: dict | None = None,
    ):
        """Create an experiment record in the runtime's project.

        Returns:
            The id returned by the metadata store's ``create_exp``.
        """
        exp_id = self._runtime._metadb.create_exp(
            name=name,
            description=description,
            project_id=self._runtime._project_id,
            meta=meta,
            labels=labels,
        )

        return exp_id

    def get(self, exp_id: int):
        """Return the metadata store's record for ``exp_id``."""
        return self._runtime._metadb.get_exp(exp_id=exp_id)

    def list(self, page: int = 0, page_size: int = 10):
        """Return one page of experiments belonging to the runtime's project."""
        return self._runtime._metadb.list_exps(
            project_id=self._runtime._project_id, page=page, page_size=page_size
        )

    # TODO: delete related artifacts too. But for google artifact registry,
    # it seems not supported to delete a tag only.
    # See issue: https://github.com/InftyAI/alphatrion/issues/14
    def delete(self, exp_id: int):
        """Delete the experiment record for ``exp_id`` (artifacts are kept)."""
        self._runtime._metadb.delete_exp(exp_id=exp_id)

    def update_labels(self, exp_id: int, labels: dict):
        """Replace the labels of ``exp_id``.

        Please provide all the labels to update, or it will overwrite the
        existing labels.
        """
        self._runtime._metadb.update_exp(exp_id=exp_id, labels=labels)

    def start(
        self,
        name: str | None = None,
        description: str | None = None,
        meta: dict | None = None,
        labels: dict | None = None,
    ) -> int:
        """Create an experiment and mark it as running.

        A random UUID4 string is used as the name when ``name`` is ``None``.
        The start time is recorded on the instance as a UTC-aware datetime.

        Returns:
            The id of the newly created experiment.
        """
        if name is None:
            name = f"{uuid.uuid4()}"

        exp_id = self.create(
            name=name,
            description=description,
            meta=meta,
            labels=labels,
        )
        self._runtime._metadb.update_exp(exp_id=exp_id, status=ExperimentStatus.RUNNING)
        self._start_at = datetime.now(UTC)
        return exp_id

    def stop(self, exp_id: int, status: ExperimentStatus = ExperimentStatus.FINISHED):
        """Record the final status and duration of ``exp_id``.

        Does nothing when the experiment is not found or its status is
        already one of ``COMPLETED_STATUS``. The duration is the number of
        seconds elapsed since the record's ``created_at``.
        """
        exp = self._runtime._metadb.get_exp(exp_id=exp_id)
        if exp is None or exp.status in COMPLETED_STATUS:
            return

        # Take "now" with the same timezone-awareness as created_at:
        # subtracting an aware datetime from a naive one raises TypeError.
        # For a naive created_at this is exactly datetime.now().
        now = datetime.now(exp.created_at.tzinfo)
        duration = (now - exp.created_at).total_seconds()
        self._runtime._metadb.update_exp(
            exp_id=exp_id, status=status, duration=duration
        )

    def status(self, exp_id: int) -> ExperimentStatus:
        """Return the stored status of ``exp_id``.

        NOTE(review): unlike ``stop``, a missing experiment is not guarded
        here; if ``get_exp`` returns ``None`` this raises ``AttributeError``.
        """
        exp = self._runtime._metadb.get_exp(exp_id=exp_id)
        return exp.status

    def reset(self):
        """Clear the per-run state (step counter, start time, best metric)."""
        self._steps = 0
        self._start_at = None
        self._best_metric_value = None

    # def save_checkpoint(
    #     self,
    #     exp_id: int,
    #     files: list[str] | None = None,
    #     folder: str | None = None,
    #     version: str = "latest",
    # ):
    #     exp = self._runtime._metadb.get_exp(exp_id=exp_id)
    #     self._artifact.push(
    #         experiment_name=exp.name, files=files, folder=folder, version=version
    #     )
| 182 | + |
| 183 | + |
| 184 | +class RunContext: |
| 185 | + """A context manager for running experiments.""" |
| 186 | + |
| 187 | + def __init__( |
| 188 | + self, |
| 189 | + experiment: Experiment, |
| 190 | + name: str | None = None, |
| 191 | + description: str | None = None, |
| 192 | + meta: dict | None = None, |
| 193 | + labels: dict | None = None, |
| 194 | + ): |
| 195 | + self._experiment = experiment |
| 196 | + self._exp_name = name |
| 197 | + self._description = description |
| 198 | + self._meta = meta |
| 199 | + self._labels = labels |
41 | 200 |
|
42 | | - @abstractmethod |
43 | | - def stop(self, exp_id: int, status: str = "finished"): |
44 | | - raise NotImplementedError("Subclasses must implement this method.") |
| 201 | + def __enter__(self): |
| 202 | + self._exp_id = self._experiment.start( |
| 203 | + name=self._exp_name, |
| 204 | + description=self._description, |
| 205 | + meta=self._meta, |
| 206 | + labels=self._labels, |
| 207 | + ) |
| 208 | + return self._experiment |
45 | 209 |
|
46 | | - @abstractmethod |
47 | | - def status(self, exp_id: int) -> str: |
48 | | - raise NotImplementedError("Subclasses must implement this method.") |
| 210 | + def __exit__(self, exc_type, exc_val, exc_tb): |
| 211 | + self._experiment.stop(self._exp_id) |
| 212 | + self._experiment.reset() |
0 commit comments