Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions alphatrion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from alphatrion.experiment.craft_exp import CraftExperiment as CraftExperiment
from alphatrion.observe.observe import log_artifact as log_artifact
from alphatrion.runtime.runtime import init as init
from alphatrion.experiment.craft_exp import CraftExperiment
from alphatrion.observe.observe import log_artifact, log_params
from alphatrion.runtime.runtime import init

__all__ = ["log_artifact", "log_params", "CraftExperiment", "init"]
32 changes: 16 additions & 16 deletions alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,13 @@ def run(
name: str | None = None,
description: str | None = None,
meta: dict | None = None,
labels: dict | None = None,
):
"""
:param project_id: the project ID to run the experiment under
:param name: the name of the experiment. If not provided,
a UUID will be generated.
:param description: the description of the experiment
:param meta: the metadata of the experiment
:param labels: the labels of the experiment
:param artifact_insecure: whether to use insecure connection to the
artifact registry. Default is False.

Expand All @@ -114,15 +112,17 @@ def run(

exp = Experiment(config=config)
return RunContext(
exp, name=name, description=description, meta=meta, labels=labels
exp,
name=name,
description=description,
meta=meta,
)

def create(
self,
name: str,
description: str | None = None,
meta: dict | None = None,
labels: dict | None = None,
status: ExperimentStatus = ExperimentStatus.PENDING,
) -> int:
"""
Expand All @@ -135,7 +135,6 @@ def create(
description=description,
project_id=self._runtime._project_id,
meta=meta,
labels=labels,
status=status,
)

Expand All @@ -162,24 +161,29 @@ def delete(self, exp_id: int):
self._runtime._metadb.delete_exp(exp_id=exp_id)
self._artifact.delete(experiment_name=exp.name, versions=tags)

# Please provide all the labels to update, or it will overwrite the existing labels.
def update_labels(self, exp_id: int, labels: dict):
self._runtime._metadb.update_exp(exp_id=exp_id, labels=labels)
# Please provide all the tags to update, or it will overwrite the existing tags.
def update_tags(self, exp_id: int, tags: dict):
exp = self.get(exp_id)
if exp is None:
return

if exp.meta is None:
exp.meta = {}

exp.meta["tags"] = tags
self._runtime._metadb.update_exp(exp_id=exp_id, meta=exp.meta)

# start with save the
def _start(
self,
name: str | None = None,
description: str | None = None,
meta: dict | None = None,
labels: dict | None = None,
) -> int:
"""
:param name: the name of the experiment. If not provided,
a UUID will be generated.
:param description: the description of the experiment
:param meta: the metadata of the experiment
:param labels: the labels of the experiment

:return: the experiment ID
"""
Expand All @@ -191,8 +195,7 @@ def _start(
name=name,
description=description,
meta=meta,
labels=labels,
status=ExperimentStatus.RUNNING
status=ExperimentStatus.RUNNING,
)

return exp_id
Expand Down Expand Up @@ -226,20 +229,17 @@ def __init__(
name: str | None = None,
description: str | None = None,
meta: dict | None = None,
labels: dict | None = None,
):
self._experiment = experiment
self._exp_name = name
self._description = description
self._meta = meta
self._labels = labels

def __enter__(self):
exp_id = self._experiment._start(
name=self._exp_name,
description=self._description,
meta=self._meta,
labels=self._labels,
)

# Set the current experiment ID in the runtime
Expand Down
4 changes: 0 additions & 4 deletions alphatrion/metadata/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def create_exp(
project_id: str,
description: str | None,
meta: dict | None,
labels: dict | None = None,
status: ExperimentStatus = ExperimentStatus.PENDING,
) -> int:
session = self._session()
Expand All @@ -30,7 +29,6 @@ def create_exp(
description=description,
project_id=project_id,
meta=meta,
labels=labels,
status=status,
)
session.add(new_exp)
Expand Down Expand Up @@ -98,15 +96,13 @@ def create_model(
version: str = "latest",
description: str | None = None,
meta: dict | None = None,
labels: dict | None = None,
):
session = self._session()
new_model = Model(
name=name,
version=version,
description=description,
meta=meta,
labels=labels,
)
session.add(new_model)
session.commit()
Expand Down
14 changes: 12 additions & 2 deletions alphatrion/metadata/sql_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ class Experiment(Base):
Enum(ExperimentStatus), nullable=False, default=ExperimentStatus.PENDING
)
meta = Column(JSON, nullable=True, comment="Additional metadata for the experiment")
labels = Column(JSON, nullable=True, comment="Labels for the experiment")
# Let's start with simple approach here, it the params are too large,
# we can move them to a separate table.
params = Column(JSON, nullable=True, comment="Parameters for the experiment")
duration = Column(Integer, default=0, comment="Duration in seconds")

created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
Expand All @@ -47,10 +49,18 @@ class Model(Base):
version = Column(String, nullable=False)
description = Column(String, nullable=True)
meta = Column(JSON, nullable=True, comment="Additional metadata for the model")
labels = Column(JSON, nullable=True, comment="Labels for the model")

created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)
)
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")


class Params(Base):
__tablename__ = "params"

id = Column(Integer, primary_key=True)
experiment_id = Column(Integer, nullable=False)
params = Column(JSON, nullable=False, comment="Parameters for the experiment")
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")
2 changes: 0 additions & 2 deletions alphatrion/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ def create(
name: str,
description: str | None = None,
meta: dict | None = None,
labels: dict | None = None,
):
self._runtime._metadb.create_model(
name=name,
description=description,
meta=meta,
labels=labels,
)

def update(self, model_id: int, **kwargs):
Expand Down
10 changes: 5 additions & 5 deletions alphatrion/observe/observe.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def log_artifact(
runtime._artifact.push(experiment_name=exp.name, paths=paths, version=version)


# def log_params(exp_id: int, params: dict):
# runtime = global_runtime()
# if runtime is None:
# raise RuntimeError("Runtime is not initialized. Please call init() first.")
def log_params(params: dict):
runtime = global_runtime()
if runtime is None:
raise RuntimeError("Runtime is not initialized. Please call init() first.")

# runtime._metadb.log_params(exp_id=exp_id, params=params)
runtime._metadb.update_exp(exp_id=runtime._current_exp_id, params=params)
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,13 @@ def test_push_with_folder(artifact):
assert "v1" not in tags


def test_save_checkpoint():
def test_log_artifact():
init(project_id="test_project", artifact_insecure=True)

with Experiment.run(
name="context_exp",
description="Context manager test",
meta={"key": "value"},
labels={"type": "unit"},
) as exp:
with tempfile.TemporaryDirectory() as tmpdir:
os.chdir(tmpdir)
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/test_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

import alphatrion as at


def test_log_params():
at.init(project_id="test_project", artifact_insecure=True)

with at.CraftExperiment.run(name="test_experiment") as exp:
params = {"param1": 0.1, "param2": "value2", "param3": 3}
at.log_params(params=params)

new_exp = exp._runtime._metadb.get_exp(exp_id=exp._runtime._current_exp_id)
assert new_exp is not None
assert new_exp.params == params
26 changes: 0 additions & 26 deletions tests/integration/test_sdk.py

This file was deleted.

7 changes: 3 additions & 4 deletions tests/unit/experiment/test_base_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,18 @@ def exp():


def test_experiment_crud(exp):
id = exp.create("test_exp", "A test experiment", {"foo": "bar"}, {"env": "test"})
id = exp.create("test_exp", "A test experiment")
exp1 = exp.get(id)
assert exp1 is not None
assert exp1.name == "test_exp"
assert exp1.description == "A test experiment"
assert exp1.meta == {"foo": "bar"}
assert exp1.status == ExperimentStatus.PENDING
assert exp1.duration == 0
assert len(exp.list_paginated()) == 1

exp.update_labels(id, {"env": "prod"})
exp.update_tags(id, {"env": "prod"})
exp1 = exp.get(id)
assert exp1.labels == {"env": "prod"}
assert exp1.meta["tags"] == {"env": "prod"}


def test_experiment_start(exp):
Expand Down
1 change: 0 additions & 1 deletion tests/unit/experiment/test_craft_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def test_craft_experiment():
name="context_exp",
description="Context manager test",
meta={"key": "value"},
labels={"type": "unit"},
) as exp:
id = exp._runtime._current_exp_id
exp1 = exp.get(id)
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ def model():


def test_model(model):
model.create("test_model", "A test model", {"foo": "bar"}, {"env": "test"})
model.create("test_model", "A test model", {"tags": {"foo": "bar"}})
model1 = model.get(1)
assert model1 is not None
assert model1.name == "test_model"
assert model1.description == "A test model"
assert model1.meta == {"foo": "bar"}
assert model1.meta == {"tags": {"foo": "bar"}}

model.update(1, labels={"env": "prod"})
model.update(1, meta={"tags": {"foo": "fuz"}})
model1 = model.get(1)
assert model1.labels == {"env": "prod"}
assert model1.meta == {"tags": {"foo": "fuz"}}

models = model.list()
assert len(models) == 1
Expand Down
Empty file.
Loading