@@ -23,17 +23,33 @@ async def test_craft_experiment():
2323 assert exp1 .name == "context_exp"
2424 assert exp1 .description == "Context manager test"
2525
26- trial = await exp .start_trial (description = "First trial" )
27- trial1 = trial ._get ()
28- assert trial1 is not None
29- assert trial1 .description == "First trial"
26+ trial = exp .start_trial (description = "First trial" )
27+ trial_obj = trial ._get_obj ()
28+ assert trial_obj is not None
29+ assert trial_obj .description == "First trial"
3030
3131 trial .stop ()
3232
33- trial2 = trial ._get ()
33+ trial2 = trial ._get_obj ()
3434 assert trial2 .status == TrialStatus .FINISHED
3535
3636
37+ @pytest .mark .asyncio
38+ async def test_create_experiment_with_trial ():
39+ init (project_id = "test_project" , artifact_insecure = True )
40+
41+ trial_id = None
42+ async with CraftExperiment .run (name = "context_exp" ) as exp :
43+ async with exp .start_trial (description = "First trial" ) as trial :
44+ trial_obj = trial ._get_obj ()
45+ assert trial_obj is not None
46+ assert trial_obj .description == "First trial"
47+ trial_id = current_trial_id .get ()
48+
49+ trial_obj = exp ._runtime ._metadb .get_trial (trial_id = trial_id )
50+ assert trial_obj .status == TrialStatus .FINISHED
51+
52+
3753@pytest .mark .asyncio
3854async def test_craft_experiment_with_context ():
3955 init (project_id = "test_project" , artifact_insecure = True )
@@ -43,13 +59,13 @@ async def test_craft_experiment_with_context():
4359 description = "Context manager test" ,
4460 meta = {"key" : "value" },
4561 ) as exp :
46- trial = await exp .start_trial (
62+ trial = exp .start_trial (
4763 description = "First trial" , config = TrialConfig (max_duration_seconds = 2 )
4864 )
4965 await trial .wait_stopped ()
5066 assert trial .stopped ()
5167
52- trial = trial ._get ()
68+ trial = trial ._get_obj ()
5369 assert trial .status == TrialStatus .FINISHED
5470
5571
@@ -59,7 +75,7 @@ async def test_craft_experiment_with_multi_trials_in_parallel():
5975
6076 async def fake_work (exp : CraftExperiment ):
6177 duration = random .randint (1 , 5 )
62- trial = await exp .start_trial (
78+ trial = exp .start_trial (
6379 description = "First trial" , config = TrialConfig (max_duration_seconds = duration )
6480 )
6581 # double check current trial id.
@@ -70,7 +86,7 @@ async def fake_work(exp: CraftExperiment):
7086 # we don't reset the current trial id.
7187 assert trial .id == current_trial_id .get ()
7288
73- trial = trial ._get ()
89+ trial = trial ._get_obj ()
7490 assert trial .status == TrialStatus .FINISHED
7591
7692 async with CraftExperiment .run (
0 commit comments