Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions alphatrion/experiment/craft_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def run(

return exp

async def start_trial(
def start_trial(
self,
description: str | None = None,
meta: dict | None = None,
Expand All @@ -50,6 +50,9 @@ async def start_trial(
) -> Trial:
"""
start_trial starts a new trial in this experiment.
You need to call trial.stop() to stop the trial for proper cleanup,
unless it's a timeout trial. Alternatively, use
'async with exp.start_trial(...) as trial', which automatically stops
the trial when the context exits.

:params description: the description of the trial
:params meta: the metadata of the trial
Expand All @@ -59,6 +62,6 @@ async def start_trial(
"""

trial = Trial(exp_id=self._id, config=config)
await trial._start(description=description, meta=meta, params=params)
trial._start(description=description, meta=meta, params=params)
self.register_trial(id=trial.id, instance=trial)
return trial
12 changes: 9 additions & 3 deletions alphatrion/trial/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ def __init__(self, exp_id: int, config: TrialConfig | None = None):
# like the metric max/min values.
self._construct_meta()

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
self.stop()

def _construct_meta(self):
self._meta = dict()

Expand Down Expand Up @@ -165,7 +171,7 @@ def stopped(self) -> bool:
async def wait_stopped(self):
await self._context.wait_cancelled()

async def _start(
def _start(
self,
description: str | None = None,
meta: dict | None = None,
Expand All @@ -182,7 +188,7 @@ async def _start(
# We don't reset the trial id context var here, because
# each trial runs in its own context.
self._token = current_trial_id.set(self._id)
await self._context.start()
self._context.start()
return self._id

@property
Expand All @@ -205,7 +211,7 @@ def _stop(self):

self._runtime.current_exp.unregister_trial(self._id)

def _get(self):
def _get_obj(self):
return self._runtime._metadb.get_trial(trial_id=self._id)

def increment_step(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion alphatrion/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, cancel_func: Callable | None = None, timeout=None):
self._cancel_func = cancel_func
self._timeout = timeout

async def start(self):
def start(self):
# If timeout is None, it means no timeout is set.
# If timeout is negative, it means already timed out.
if self._timeout is not None:
Expand Down
12 changes: 5 additions & 7 deletions tests/integration/test_log_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def test_log_artifact():
description="Context manager test",
meta={"key": "value"},
) as exp:
trial = await exp.start_trial(description="First trial")
trial = exp.start_trial(description="First trial")

exp_obj = exp._runtime._metadb.get_exp(exp_id=exp._id)
assert exp_obj is not None
Expand Down Expand Up @@ -66,7 +66,7 @@ async def test_log_params():
alpha.init(project_id="test_project", artifact_insecure=True)

async with alpha.CraftExperiment.run(name="test_experiment") as exp:
trial = await exp.start_trial(description="First trial", params={"param1": 0.1})
trial = exp.start_trial(description="First trial", params={"param1": 0.1})

new_trial = exp._runtime._metadb.get_trial(trial_id=trial.id)
assert new_trial is not None
Expand All @@ -83,9 +83,7 @@ async def test_log_params():

trial.stop()

trial = await exp.start_trial(
description="Second trial", params={"param1": 0.1}
)
trial = exp.start_trial(description="Second trial", params={"param1": 0.1})
assert current_trial_id.get() == trial.id
trial.stop()

Expand All @@ -95,7 +93,7 @@ async def test_log_metrics():
alpha.init(project_id="test_project", artifact_insecure=True)

async with alpha.CraftExperiment.run(name="test_experiment") as exp:
trial = await exp.start_trial(description="First trial", params={"param1": 0.1})
trial = exp.start_trial(description="First trial", params={"param1": 0.1})

new_trial = exp._runtime._metadb.get_trial(trial_id=trial._id)
assert new_trial is not None
Expand Down Expand Up @@ -138,7 +136,7 @@ async def test_log_metrics_with_save_best_only():
with tempfile.TemporaryDirectory() as tmpdir:
os.chdir(tmpdir)

_ = await exp.start_trial(
_ = exp.start_trial(
description="Trial with save_best_only",
config=TrialConfig(
checkpoint=CheckpointConfig(
Expand Down
34 changes: 25 additions & 9 deletions tests/unit/experiment/test_craft_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,33 @@ async def test_craft_experiment():
assert exp1.name == "context_exp"
assert exp1.description == "Context manager test"

trial = await exp.start_trial(description="First trial")
trial1 = trial._get()
assert trial1 is not None
assert trial1.description == "First trial"
trial = exp.start_trial(description="First trial")
trial_obj = trial._get_obj()
assert trial_obj is not None
assert trial_obj.description == "First trial"

trial.stop()

trial2 = trial._get()
trial2 = trial._get_obj()
assert trial2.status == TrialStatus.FINISHED


@pytest.mark.asyncio
async def test_create_experiment_with_trial():
init(project_id="test_project", artifact_insecure=True)

trial_id = None
async with CraftExperiment.run(name="context_exp") as exp:
async with exp.start_trial(description="First trial") as trial:
trial_obj = trial._get_obj()
assert trial_obj is not None
assert trial_obj.description == "First trial"
trial_id = current_trial_id.get()

trial_obj = exp._runtime._metadb.get_trial(trial_id=trial_id)
assert trial_obj.status == TrialStatus.FINISHED


@pytest.mark.asyncio
async def test_craft_experiment_with_context():
init(project_id="test_project", artifact_insecure=True)
Expand All @@ -43,13 +59,13 @@ async def test_craft_experiment_with_context():
description="Context manager test",
meta={"key": "value"},
) as exp:
trial = await exp.start_trial(
trial = exp.start_trial(
description="First trial", config=TrialConfig(max_duration_seconds=2)
)
await trial.wait_stopped()
assert trial.stopped()

trial = trial._get()
trial = trial._get_obj()
assert trial.status == TrialStatus.FINISHED


Expand All @@ -59,7 +75,7 @@ async def test_craft_experiment_with_multi_trials_in_parallel():

async def fake_work(exp: CraftExperiment):
duration = random.randint(1, 5)
trial = await exp.start_trial(
trial = exp.start_trial(
description="First trial", config=TrialConfig(max_duration_seconds=duration)
)
# double check current trial id.
Expand All @@ -70,7 +86,7 @@ async def fake_work(exp: CraftExperiment):
# we don't reset the current trial id.
assert trial.id == current_trial_id.get()

trial = trial._get()
trial = trial._get_obj()
assert trial.status == TrialStatus.FINISHED

async with CraftExperiment.run(
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/utils/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
@pytest.mark.asyncio
async def test_context_no_timeout():
ctx = Context()
await ctx.start()
ctx.start()
assert not ctx.cancelled()
ctx.cancel()
# a double cancel should be a no-op
Expand All @@ -20,7 +20,7 @@ async def test_context_no_timeout():
@pytest.mark.asyncio
async def test_context_with_timeout():
ctx = Context(timeout=0.1)
await ctx.start()
ctx.start()
assert not ctx.cancelled()
await asyncio.sleep(0.2)
assert ctx.cancelled()
Expand All @@ -30,7 +30,7 @@ async def test_context_with_timeout():
@pytest.mark.asyncio
async def test_context_manual_cancel():
ctx = Context(timeout=10000)
await ctx.start()
ctx.start()
assert not ctx.cancelled()
ctx.cancel()
assert ctx.cancelled()
Expand All @@ -40,7 +40,7 @@ async def test_context_manual_cancel():
@pytest.mark.asyncio
async def test_context_wait_cancelled():
ctx = Context()
await ctx.start()
ctx.start()

async def waiter():
await ctx.wait_cancelled()
Expand All @@ -58,7 +58,7 @@ async def waiter():
@pytest.mark.asyncio
async def test_context_multiple_waiters():
ctx = Context()
await ctx.start()
ctx.start()
results = []

async def waiter(idx):
Expand Down
Loading