Skip to content

Commit 565b9bb

Browse files
authored
Add support for log_params (#25)
* Defiine the public APIs Signed-off-by: kerthcet <[email protected]> * Move labels to meta Signed-off-by: kerthcet <[email protected]> * Add support to log params Signed-off-by: kerthcet <[email protected]> --------- Signed-off-by: kerthcet <[email protected]>
1 parent 9130e4f commit 565b9bb

File tree

13 files changed

+60
-69
lines changed

13 files changed

+60
-69
lines changed

alphatrion/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
from alphatrion.experiment.craft_exp import CraftExperiment as CraftExperiment
2-
from alphatrion.observe.observe import log_artifact as log_artifact
3-
from alphatrion.runtime.runtime import init as init
1+
from alphatrion.experiment.craft_exp import CraftExperiment
2+
from alphatrion.observe.observe import log_artifact, log_params
3+
from alphatrion.runtime.runtime import init
4+
5+
__all__ = ["log_artifact", "log_params", "CraftExperiment", "init"]

alphatrion/experiment/base.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,13 @@ def run(
9797
name: str | None = None,
9898
description: str | None = None,
9999
meta: dict | None = None,
100-
labels: dict | None = None,
101100
):
102101
"""
103102
:param project_id: the project ID to run the experiment under
104103
:param name: the name of the experiment. If not provided,
105104
a UUID will be generated.
106105
:param description: the description of the experiment
107106
:param meta: the metadata of the experiment
108-
:param labels: the labels of the experiment
109107
:param artifact_insecure: whether to use insecure connection to the
110108
artifact registry. Default is False.
111109
@@ -114,15 +112,17 @@ def run(
114112

115113
exp = Experiment(config=config)
116114
return RunContext(
117-
exp, name=name, description=description, meta=meta, labels=labels
115+
exp,
116+
name=name,
117+
description=description,
118+
meta=meta,
118119
)
119120

120121
def create(
121122
self,
122123
name: str,
123124
description: str | None = None,
124125
meta: dict | None = None,
125-
labels: dict | None = None,
126126
status: ExperimentStatus = ExperimentStatus.PENDING,
127127
) -> int:
128128
"""
@@ -135,7 +135,6 @@ def create(
135135
description=description,
136136
project_id=self._runtime._project_id,
137137
meta=meta,
138-
labels=labels,
139138
status=status,
140139
)
141140

@@ -162,24 +161,29 @@ def delete(self, exp_id: int):
162161
self._runtime._metadb.delete_exp(exp_id=exp_id)
163162
self._artifact.delete(experiment_name=exp.name, versions=tags)
164163

165-
# Please provide all the labels to update, or it will overwrite the existing labels.
166-
def update_labels(self, exp_id: int, labels: dict):
167-
self._runtime._metadb.update_exp(exp_id=exp_id, labels=labels)
164+
# Please provide all the tags to update, or it will overwrite the existing tags.
165+
def update_tags(self, exp_id: int, tags: dict):
166+
exp = self.get(exp_id)
167+
if exp is None:
168+
return
169+
170+
if exp.meta is None:
171+
exp.meta = {}
172+
173+
exp.meta["tags"] = tags
174+
self._runtime._metadb.update_exp(exp_id=exp_id, meta=exp.meta)
168175

169-
# start with save the
170176
def _start(
171177
self,
172178
name: str | None = None,
173179
description: str | None = None,
174180
meta: dict | None = None,
175-
labels: dict | None = None,
176181
) -> int:
177182
"""
178183
:param name: the name of the experiment. If not provided,
179184
a UUID will be generated.
180185
:param description: the description of the experiment
181186
:param meta: the metadata of the experiment
182-
:param labels: the labels of the experiment
183187
184188
:return: the experiment ID
185189
"""
@@ -191,8 +195,7 @@ def _start(
191195
name=name,
192196
description=description,
193197
meta=meta,
194-
labels=labels,
195-
status=ExperimentStatus.RUNNING
198+
status=ExperimentStatus.RUNNING,
196199
)
197200

198201
return exp_id
@@ -226,20 +229,17 @@ def __init__(
226229
name: str | None = None,
227230
description: str | None = None,
228231
meta: dict | None = None,
229-
labels: dict | None = None,
230232
):
231233
self._experiment = experiment
232234
self._exp_name = name
233235
self._description = description
234236
self._meta = meta
235-
self._labels = labels
236237

237238
def __enter__(self):
238239
exp_id = self._experiment._start(
239240
name=self._exp_name,
240241
description=self._description,
241242
meta=self._meta,
242-
labels=self._labels,
243243
)
244244

245245
# Set the current experiment ID in the runtime

alphatrion/metadata/sql.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ def create_exp(
2121
project_id: str,
2222
description: str | None,
2323
meta: dict | None,
24-
labels: dict | None = None,
2524
status: ExperimentStatus = ExperimentStatus.PENDING,
2625
) -> int:
2726
session = self._session()
@@ -30,7 +29,6 @@ def create_exp(
3029
description=description,
3130
project_id=project_id,
3231
meta=meta,
33-
labels=labels,
3432
status=status,
3533
)
3634
session.add(new_exp)
@@ -98,15 +96,13 @@ def create_model(
9896
version: str = "latest",
9997
description: str | None = None,
10098
meta: dict | None = None,
101-
labels: dict | None = None,
10299
):
103100
session = self._session()
104101
new_model = Model(
105102
name=name,
106103
version=version,
107104
description=description,
108105
meta=meta,
109-
labels=labels,
110106
)
111107
session.add(new_model)
112108
session.commit()

alphatrion/metadata/sql_models.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ class Experiment(Base):
2929
Enum(ExperimentStatus), nullable=False, default=ExperimentStatus.PENDING
3030
)
3131
meta = Column(JSON, nullable=True, comment="Additional metadata for the experiment")
32-
labels = Column(JSON, nullable=True, comment="Labels for the experiment")
32+
# Let's start with simple approach here, it the params are too large,
33+
# we can move them to a separate table.
34+
params = Column(JSON, nullable=True, comment="Parameters for the experiment")
3335
duration = Column(Integer, default=0, comment="Duration in seconds")
3436

3537
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
@@ -47,10 +49,18 @@ class Model(Base):
4749
version = Column(String, nullable=False)
4850
description = Column(String, nullable=True)
4951
meta = Column(JSON, nullable=True, comment="Additional metadata for the model")
50-
labels = Column(JSON, nullable=True, comment="Labels for the model")
5152

5253
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
5354
updated_at = Column(
5455
DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)
5556
)
5657
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")
58+
59+
60+
class Params(Base):
61+
__tablename__ = "params"
62+
63+
id = Column(Integer, primary_key=True)
64+
experiment_id = Column(Integer, nullable=False)
65+
params = Column(JSON, nullable=False, comment="Parameters for the experiment")
66+
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")

alphatrion/model/model.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,11 @@ def create(
1010
name: str,
1111
description: str | None = None,
1212
meta: dict | None = None,
13-
labels: dict | None = None,
1413
):
1514
self._runtime._metadb.create_model(
1615
name=name,
1716
description=description,
1817
meta=meta,
19-
labels=labels,
2018
)
2119

2220
def update(self, model_id: int, **kwargs):

alphatrion/observe/observe.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ def log_artifact(
3535
runtime._artifact.push(experiment_name=exp.name, paths=paths, version=version)
3636

3737

38-
# def log_params(exp_id: int, params: dict):
39-
# runtime = global_runtime()
40-
# if runtime is None:
41-
# raise RuntimeError("Runtime is not initialized. Please call init() first.")
38+
def log_params(params: dict):
39+
runtime = global_runtime()
40+
if runtime is None:
41+
raise RuntimeError("Runtime is not initialized. Please call init() first.")
4242

43-
# runtime._metadb.log_params(exp_id=exp_id, params=params)
43+
runtime._metadb.update_exp(exp_id=runtime._current_exp_id, params=params)

tests/integration/artifact/test_artifact.py renamed to tests/integration/test_artifact.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,13 @@ def test_push_with_folder(artifact):
6666
assert "v1" not in tags
6767

6868

69-
def test_save_checkpoint():
69+
def test_log_artifact():
7070
init(project_id="test_project", artifact_insecure=True)
7171

7272
with Experiment.run(
7373
name="context_exp",
7474
description="Context manager test",
7575
meta={"key": "value"},
76-
labels={"type": "unit"},
7776
) as exp:
7877
with tempfile.TemporaryDirectory() as tmpdir:
7978
os.chdir(tmpdir)

tests/integration/test_params.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
import alphatrion as at
3+
4+
5+
def test_log_params():
6+
at.init(project_id="test_project", artifact_insecure=True)
7+
8+
with at.CraftExperiment.run(name="test_experiment") as exp:
9+
params = {"param1": 0.1, "param2": "value2", "param3": 3}
10+
at.log_params(params=params)
11+
12+
new_exp = exp._runtime._metadb.get_exp(exp_id=exp._runtime._current_exp_id)
13+
assert new_exp is not None
14+
assert new_exp.params == params

tests/integration/test_sdk.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

tests/unit/experiment/test_base_exp.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,18 @@ def exp():
1313

1414

1515
def test_experiment_crud(exp):
16-
id = exp.create("test_exp", "A test experiment", {"foo": "bar"}, {"env": "test"})
16+
id = exp.create("test_exp", "A test experiment")
1717
exp1 = exp.get(id)
1818
assert exp1 is not None
1919
assert exp1.name == "test_exp"
2020
assert exp1.description == "A test experiment"
21-
assert exp1.meta == {"foo": "bar"}
2221
assert exp1.status == ExperimentStatus.PENDING
2322
assert exp1.duration == 0
2423
assert len(exp.list_paginated()) == 1
2524

26-
exp.update_labels(id, {"env": "prod"})
25+
exp.update_tags(id, {"env": "prod"})
2726
exp1 = exp.get(id)
28-
assert exp1.labels == {"env": "prod"}
27+
assert exp1.meta["tags"] == {"env": "prod"}
2928

3029

3130
def test_experiment_start(exp):

0 commit comments

Comments
 (0)