Skip to content

Commit 04611dc

Browse files
authored
Support with function to start a new trial (#36)
* switch async to sync for start_trial Signed-off-by: kerthcet <[email protected]> * Support "with exp.start_trial()" Signed-off-by: kerthcet <[email protected]> --------- Signed-off-by: kerthcet <[email protected]>
1 parent e54d749 commit 04611dc

File tree

6 files changed

+50
-27
lines changed

6 files changed

+50
-27
lines changed

alphatrion/experiment/craft_exp.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def run(
4141

4242
return exp
4343

44-
async def start_trial(
44+
def start_trial(
4545
self,
4646
description: str | None = None,
4747
meta: dict | None = None,
@@ -50,6 +50,9 @@ async def start_trial(
5050
) -> Trial:
5151
"""
5252
start_trial starts a new trial in this experiment.
53+
You need to call trial.stop() to stop the trial for proper cleanup,
54+
unless it's a timeout trial. Or you can use 'async with exp.start_trial(...)'
55+
as trial, which will automatically stop the trial at the end of the context.
5356
5457
:params description: the description of the trial
5558
:params meta: the metadata of the trial
@@ -59,6 +62,6 @@ async def start_trial(
5962
"""
6063

6164
trial = Trial(exp_id=self._id, config=config)
62-
await trial._start(description=description, meta=meta, params=params)
65+
trial._start(description=description, meta=meta, params=params)
6366
self.register_trial(id=trial.id, instance=trial)
6467
return trial

alphatrion/trial/trial.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ def __init__(self, exp_id: int, config: TrialConfig | None = None):
100100
# like the metric max/min values.
101101
self._construct_meta()
102102

103+
async def __aenter__(self):
104+
return self
105+
106+
async def __aexit__(self, exc_type, exc_val, exc_tb):
107+
self.stop()
108+
103109
def _construct_meta(self):
104110
self._meta = dict()
105111

@@ -165,7 +171,7 @@ def stopped(self) -> bool:
165171
async def wait_stopped(self):
166172
await self._context.wait_cancelled()
167173

168-
async def _start(
174+
def _start(
169175
self,
170176
description: str | None = None,
171177
meta: dict | None = None,
@@ -182,7 +188,7 @@ async def _start(
182188
# We don't reset the trial id context var here, because
183189
# each trial runs in its own context.
184190
self._token = current_trial_id.set(self._id)
185-
await self._context.start()
191+
self._context.start()
186192
return self._id
187193

188194
@property
@@ -205,7 +211,7 @@ def _stop(self):
205211

206212
self._runtime.current_exp.unregister_trial(self._id)
207213

208-
def _get(self):
214+
def _get_obj(self):
209215
return self._runtime._metadb.get_trial(trial_id=self._id)
210216

211217
def increment_step(self) -> int:

alphatrion/utils/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self, cancel_func: Callable | None = None, timeout=None):
1313
self._cancel_func = cancel_func
1414
self._timeout = timeout
1515

16-
async def start(self):
16+
def start(self):
1717
# If timeout is None, it means no timeout is set.
1818
# If timeout is negative, it means already timed out.
1919
if self._timeout is not None:

tests/integration/test_log_functions.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ async def test_log_artifact():
1818
description="Context manager test",
1919
meta={"key": "value"},
2020
) as exp:
21-
trial = await exp.start_trial(description="First trial")
21+
trial = exp.start_trial(description="First trial")
2222

2323
exp_obj = exp._runtime._metadb.get_exp(exp_id=exp._id)
2424
assert exp_obj is not None
@@ -66,7 +66,7 @@ async def test_log_params():
6666
alpha.init(project_id="test_project", artifact_insecure=True)
6767

6868
async with alpha.CraftExperiment.run(name="test_experiment") as exp:
69-
trial = await exp.start_trial(description="First trial", params={"param1": 0.1})
69+
trial = exp.start_trial(description="First trial", params={"param1": 0.1})
7070

7171
new_trial = exp._runtime._metadb.get_trial(trial_id=trial.id)
7272
assert new_trial is not None
@@ -83,9 +83,7 @@ async def test_log_params():
8383

8484
trial.stop()
8585

86-
trial = await exp.start_trial(
87-
description="Second trial", params={"param1": 0.1}
88-
)
86+
trial = exp.start_trial(description="Second trial", params={"param1": 0.1})
8987
assert current_trial_id.get() == trial.id
9088
trial.stop()
9189

@@ -95,7 +93,7 @@ async def test_log_metrics():
9593
alpha.init(project_id="test_project", artifact_insecure=True)
9694

9795
async with alpha.CraftExperiment.run(name="test_experiment") as exp:
98-
trial = await exp.start_trial(description="First trial", params={"param1": 0.1})
96+
trial = exp.start_trial(description="First trial", params={"param1": 0.1})
9997

10098
new_trial = exp._runtime._metadb.get_trial(trial_id=trial._id)
10199
assert new_trial is not None
@@ -138,7 +136,7 @@ async def test_log_metrics_with_save_best_only():
138136
with tempfile.TemporaryDirectory() as tmpdir:
139137
os.chdir(tmpdir)
140138

141-
_ = await exp.start_trial(
139+
_ = exp.start_trial(
142140
description="Trial with save_best_only",
143141
config=TrialConfig(
144142
checkpoint=CheckpointConfig(

tests/unit/experiment/test_craft_exp.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,33 @@ async def test_craft_experiment():
2323
assert exp1.name == "context_exp"
2424
assert exp1.description == "Context manager test"
2525

26-
trial = await exp.start_trial(description="First trial")
27-
trial1 = trial._get()
28-
assert trial1 is not None
29-
assert trial1.description == "First trial"
26+
trial = exp.start_trial(description="First trial")
27+
trial_obj = trial._get_obj()
28+
assert trial_obj is not None
29+
assert trial_obj.description == "First trial"
3030

3131
trial.stop()
3232

33-
trial2 = trial._get()
33+
trial2 = trial._get_obj()
3434
assert trial2.status == TrialStatus.FINISHED
3535

3636

37+
@pytest.mark.asyncio
38+
async def test_create_experiment_with_trial():
39+
init(project_id="test_project", artifact_insecure=True)
40+
41+
trial_id = None
42+
async with CraftExperiment.run(name="context_exp") as exp:
43+
async with exp.start_trial(description="First trial") as trial:
44+
trial_obj = trial._get_obj()
45+
assert trial_obj is not None
46+
assert trial_obj.description == "First trial"
47+
trial_id = current_trial_id.get()
48+
49+
trial_obj = exp._runtime._metadb.get_trial(trial_id=trial_id)
50+
assert trial_obj.status == TrialStatus.FINISHED
51+
52+
3753
@pytest.mark.asyncio
3854
async def test_craft_experiment_with_context():
3955
init(project_id="test_project", artifact_insecure=True)
@@ -43,13 +59,13 @@ async def test_craft_experiment_with_context():
4359
description="Context manager test",
4460
meta={"key": "value"},
4561
) as exp:
46-
trial = await exp.start_trial(
62+
trial = exp.start_trial(
4763
description="First trial", config=TrialConfig(max_duration_seconds=2)
4864
)
4965
await trial.wait_stopped()
5066
assert trial.stopped()
5167

52-
trial = trial._get()
68+
trial = trial._get_obj()
5369
assert trial.status == TrialStatus.FINISHED
5470

5571

@@ -59,7 +75,7 @@ async def test_craft_experiment_with_multi_trials_in_parallel():
5975

6076
async def fake_work(exp: CraftExperiment):
6177
duration = random.randint(1, 5)
62-
trial = await exp.start_trial(
78+
trial = exp.start_trial(
6379
description="First trial", config=TrialConfig(max_duration_seconds=duration)
6480
)
6581
# double check current trial id.
@@ -70,7 +86,7 @@ async def fake_work(exp: CraftExperiment):
7086
# we don't reset the current trial id.
7187
assert trial.id == current_trial_id.get()
7288

73-
trial = trial._get()
89+
trial = trial._get_obj()
7490
assert trial.status == TrialStatus.FINISHED
7591

7692
async with CraftExperiment.run(

tests/unit/utils/test_context.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
@pytest.mark.asyncio
99
async def test_context_no_timeout():
1010
ctx = Context()
11-
await ctx.start()
11+
ctx.start()
1212
assert not ctx.cancelled()
1313
ctx.cancel()
1414
# double cancel should be no-op
@@ -20,7 +20,7 @@ async def test_context_no_timeout():
2020
@pytest.mark.asyncio
2121
async def test_context_with_timeout():
2222
ctx = Context(timeout=0.1)
23-
await ctx.start()
23+
ctx.start()
2424
assert not ctx.cancelled()
2525
await asyncio.sleep(0.2)
2626
assert ctx.cancelled()
@@ -30,7 +30,7 @@ async def test_context_with_timeout():
3030
@pytest.mark.asyncio
3131
async def test_context_manual_cancel():
3232
ctx = Context(timeout=10000)
33-
await ctx.start()
33+
ctx.start()
3434
assert not ctx.cancelled()
3535
ctx.cancel()
3636
assert ctx.cancelled()
@@ -40,7 +40,7 @@ async def test_context_manual_cancel():
4040
@pytest.mark.asyncio
4141
async def test_context_wait_cancelled():
4242
ctx = Context()
43-
await ctx.start()
43+
ctx.start()
4444

4545
async def waiter():
4646
await ctx.wait_cancelled()
@@ -58,7 +58,7 @@ async def waiter():
5858
@pytest.mark.asyncio
5959
async def test_context_multiple_waiters():
6060
ctx = Context()
61-
await ctx.start()
61+
ctx.start()
6262
results = []
6363

6464
async def waiter(idx):

0 commit comments

Comments
 (0)