Skip to content

Commit 0d9e270

Browse files
authored
Remove ExperimentConfig (#77)
Signed-off-by: kerthcet <[email protected]>
1 parent f0dcd0e commit 0d9e270

File tree

4 files changed

+7
-70
lines changed

4 files changed

+7
-70
lines changed

alphatrion/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from alphatrion.experiment.craft_exp import CraftExperiment, ExperimentConfig
1+
from alphatrion.experiment.craft_exp import CraftExperiment
22
from alphatrion.log.log import log_artifact, log_metrics, log_params
33
from alphatrion.runtime.runtime import init
44
from alphatrion.tracing.tracing import task, workflow
@@ -10,7 +10,6 @@
1010
"log_params",
1111
"log_metrics",
1212
"CraftExperiment",
13-
"ExperimentConfig",
1413
"Trial",
1514
"TrialConfig",
1615
"CheckpointConfig",

alphatrion/experiment/base.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,10 @@
22
from abc import ABC, abstractmethod
33
from dataclasses import dataclass
44

5-
from pydantic import BaseModel, Field
6-
75
from alphatrion.runtime.runtime import global_runtime
86
from alphatrion.trial import trial
97

108

11-
class ExperimentConfig(BaseModel):
12-
"""
13-
Configuration for Experiment.
14-
"""
15-
16-
max_runtime_seconds: int = Field(
17-
default=-1,
18-
description="Maximum runtime seconds for the experiment. \
19-
It will overwrite the trial timeout if both are set. \
20-
Default is -1 (no limit).",
21-
)
22-
23-
249
@dataclass
2510
class Experiment(ABC):
2611
"""
@@ -29,8 +14,7 @@ class Experiment(ABC):
2914

3015
__slots__ = ("_runtime", "_id", "_trials")
3116

32-
def __init__(self, config: ExperimentConfig | None = None):
33-
self._config = config or ExperimentConfig()
17+
def __init__(self):
3418
self._runtime = global_runtime()
3519
# All trials in this experiment, key is trial_id, value is Trial instance.
3620
self._trials = dict()

alphatrion/experiment/craft_exp.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from alphatrion.experiment.base import Experiment, ExperimentConfig
1+
from alphatrion.experiment.base import Experiment
22
from alphatrion.trial.trial import Trial, TrialConfig
33

44

@@ -10,23 +10,22 @@ class CraftExperiment(Experiment):
1010
Opposite to other experiment classes, you need to call all these methods yourself.
1111
"""
1212

13-
def __init__(self, config: ExperimentConfig | None = None):
14-
super().__init__(config=config)
13+
def __init__(self):
14+
super().__init__()
1515

1616
@classmethod
1717
def setup(
1818
cls,
1919
name: str,
2020
description: str | None = None,
2121
meta: dict | None = None,
22-
config: ExperimentConfig | None = None,
2322
) -> "CraftExperiment":
2423
"""
2524
Setup the experiment. If the name already exists in the same project,
2625
it will refer to the existing experiment instead of creating a new one.
2726
"""
2827

29-
exp = CraftExperiment(config=config)
28+
exp = CraftExperiment()
3029
exp_obj = exp._get_by_name(name=name, project_id=exp._runtime._project_id)
3130

3231
# If experiment with the same name exists in the project, use it.
@@ -63,15 +62,6 @@ def start_trial(
6362
:return: the Trial instance
6463
"""
6564

66-
config = config or TrialConfig()
67-
68-
if (
69-
self._config is not None
70-
and self._config.max_runtime_seconds > 0
71-
and config.max_runtime_seconds < 0
72-
):
73-
config.max_runtime_seconds = self._config.max_runtime_seconds
74-
7565
trial = Trial(exp_id=self._id, config=config)
7666
trial._start(name=name, description=description, meta=meta, params=params)
7767
self.register_trial(id=trial.id, instance=trial)

tests/unit/experiment/test_craft_exp.py

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77

8-
from alphatrion.experiment.craft_exp import CraftExperiment, ExperimentConfig
8+
from alphatrion.experiment.craft_exp import CraftExperiment
99
from alphatrion.metadata.sql_models import TrialStatus
1010
from alphatrion.runtime.runtime import global_runtime, init
1111
from alphatrion.trial.trial import Trial, TrialConfig, current_trial_id
@@ -194,39 +194,3 @@ async def fake_work():
194194
fake_work(),
195195
)
196196
print("All trials finished.")
197-
198-
199-
@pytest.mark.asyncio
200-
async def test_craft_experiment_with_timeout():
201-
init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
202-
203-
exp = CraftExperiment.setup(
204-
name="timeout_exp",
205-
config=ExperimentConfig(max_runtime_seconds=3),
206-
)
207-
208-
async with exp.start_trial(name="first-trial") as trial:
209-
await trial.wait()
210-
211-
trial_obj = trial._get_obj()
212-
assert trial_obj.status == TrialStatus.COMPLETED
213-
214-
215-
@pytest.mark.asyncio
216-
async def test_craft_experiment_with_timeout_overwrite():
217-
init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
218-
219-
exp = CraftExperiment.setup(
220-
name="timeout_exp",
221-
config=ExperimentConfig(max_runtime_seconds=3),
222-
)
223-
224-
start_time = datetime.now()
225-
async with exp.start_trial(
226-
name="first-trial", config=TrialConfig(max_runtime_seconds=1)
227-
) as trial:
228-
await trial.wait()
229-
assert datetime.now() - start_time < timedelta(seconds=3)
230-
231-
trial_obj = trial._get_obj()
232-
assert trial_obj.status == TrialStatus.COMPLETED

0 commit comments

Comments
 (0)