Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ test-integration: lint
until docker exec postgres pg_isready -U at_user; do sleep 1; done; \
$(POETRY) run pytest tests/integration; \
'
.PHONY: test-all
test-all: test test-integration
3 changes: 3 additions & 0 deletions alphatrion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from alphatrion.experiment.craft_exp import CraftExperiment as CraftExperiment
from alphatrion.observe.observe import log_artifact as log_artifact
from alphatrion.runtime.runtime import init as init
12 changes: 6 additions & 6 deletions alphatrion/artifact/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
import oras.client

from alphatrion import consts
from alphatrion.runtime.runtime import Runtime

SUCCESS_CODE = 201


class Artifact:
def __init__(self, runtime: Runtime, insecure: bool = False):
self._runtime = runtime
def __init__(self, project_id: str, insecure: bool = False):
self._project_id = project_id
self._url = os.environ.get(consts.ARTIFACT_REGISTRY_URL)
self._url = self._url.replace("https://", "").replace("http://", "")
self._client = oras.client.OrasClient(
Expand Down Expand Up @@ -52,16 +51,17 @@ def push(
raise ValueError("No files to push.")

url = self._url if self._url.endswith("/") else f"{self._url}/"
target = f"{url}{self._runtime._project_id}/{experiment_name}:{version}"
target = f"{url}{self._project_id}/{experiment_name}:{version}"

try:
self._client.push(target, files=files_to_push)
except Exception as e:
raise RuntimeError("Failed to push artifacts") from e

# TODO: should we store it in the metadb instead?
def list_versions(self, experiment_name: str) -> list[str]:
url = self._url if self._url.endswith("/") else f"{self._url}/"
target = f"{url}{self._runtime._project_id}/{experiment_name}"
target = f"{url}{self._project_id}/{experiment_name}"
try:
tags = self._client.get_tags(target)
return tags
Expand All @@ -70,7 +70,7 @@ def list_versions(self, experiment_name: str) -> list[str]:

def delete(self, experiment_name: str, versions: str | list[str]):
url = self._url if self._url.endswith("/") else f"{self._url}/"
target = f"{url}{self._runtime._project_id}/{experiment_name}"
target = f"{url}{self._project_id}/{experiment_name}"

try:
self._client.delete_tags(target, tags=versions)
Expand Down
43 changes: 3 additions & 40 deletions alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

from pydantic import BaseModel, Field, field_validator

from alphatrion.artifact.artifact import Artifact
from alphatrion.metadata.sql_models import COMPLETED_STATUS, ExperimentStatus
from alphatrion.runtime.runtime import Runtime
from alphatrion.runtime.runtime import global_runtime


class CheckpointConfig(BaseModel):
Expand Down Expand Up @@ -75,9 +74,7 @@ class Experiment:

def __init__(
self,
runtime: Runtime,
config: ExperimentConfig | None = None,
artifact_insecure: bool = False,
):
"""
:param runtime: the Runtime instance
Expand All @@ -87,9 +84,8 @@ def __init__(
artifact registry. Default is False.
"""

self._runtime = runtime
self._artifact = Artifact(runtime, insecure=artifact_insecure)
self._config = config or ExperimentConfig()
self._runtime = global_runtime()

self._steps = 0
self._best_metric_value = None
Expand All @@ -100,13 +96,11 @@ def __init__(
@classmethod
def run(
cls,
project_id: str,
config: ExperimentConfig | None = None,
name: str | None = None,
description: str | None = None,
meta: dict | None = None,
labels: dict | None = None,
artifact_insecure: bool = False,
):
"""
:param project_id: the project ID to run the experiment under
Expand All @@ -121,12 +115,7 @@ def run(
:return: a context manager that yields an Experiment instance
"""

runtime = Runtime(project_id=project_id)
exp = Experiment(
runtime=runtime,
config=config,
artifact_insecure=artifact_insecure,
)
exp = Experiment(config=config)
return RunContext(
exp, name=name, description=description, meta=meta, labels=labels
)
Expand Down Expand Up @@ -234,32 +223,6 @@ def running_time(self) -> int:
return 0
return int((datetime.now(UTC) - self._start_at).total_seconds())

def log_artifact(
self,
exp_id: int,
paths: str | list[str],
version: str = "latest",
):
"""
Log artifacts (files) to the artifact registry.
:param exp_id: the experiment ID
:param paths: list of file paths to log.
Support one or multiple files or a folder.
If a folder is provided, all files in the folder will be logged.
Don't support nested folders currently.
Only files in the first level of the folder will be logged.
:param version: the version (tag) to log the files under
"""

if not paths:
raise ValueError("no files specified to log")

exp = self._runtime._metadb.get_exp(exp_id=exp_id)
if exp is None:
raise ValueError(f"Experiment with id {exp_id} does not exist.")

self._artifact.push(experiment_name=exp.name, paths=paths, version=version)


class RunContext:
"""A context manager for running experiments."""
Expand Down
9 changes: 4 additions & 5 deletions alphatrion/experiment/craft_exp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from alphatrion.experiment.base import Experiment
from alphatrion.runtime.runtime import Runtime
from alphatrion.experiment.base import Experiment, ExperimentConfig


class CraftExperiment(Experiment):
Expand All @@ -10,7 +9,7 @@ class CraftExperiment(Experiment):
Opposite to other experiment classes, you need to call all these methods yourself.
"""

def __init__(self, runtime: Runtime):
super().__init__(runtime)
# Disable checkpointing by default for CraftExperiment
def __init__(self, config: ExperimentConfig | None = None):
super().__init__(config=config)
# Disable auto-checkpointing by default for CraftExperiment
self._config.checkpoint.enabled = False
Empty file added alphatrion/observe/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions alphatrion/observe/observe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from alphatrion.runtime.runtime import global_runtime


def log_artifact(
    exp_id: int,
    paths: str | list[str],
    version: str = "latest",
):
    """
    Push files to the artifact registry on behalf of an experiment.

    The experiment is looked up by ID in the metadata store of the global
    runtime, and its name is used as the experiment name for the push.

    :param exp_id: ID of the experiment the artifacts belong to
    :param paths: a single path or a list of paths to push.
                  One or multiple files, or a folder, are supported.
                  For a folder, only the files in its first level are
                  pushed; nested folders are not supported currently.
    :param version: the version (tag) the files are pushed under
    :raises ValueError: if ``paths`` is empty, or no experiment with
                        ``exp_id`` exists
    :raises RuntimeError: if the global runtime has not been initialized
    """
    # Reject an empty request before touching the runtime at all.
    if not paths:
        raise ValueError("no files specified to log")

    if (rt := global_runtime()) is None:
        raise RuntimeError("Runtime is not initialized. Please call init() first.")

    # The registry target is keyed by experiment name, so resolve the ID first.
    if (record := rt._metadb.get_exp(exp_id=exp_id)) is None:
        raise ValueError(f"Experiment with id {exp_id} does not exist.")

    rt._artifact.push(experiment_name=record.name, paths=paths, version=version)


# def log_params(exp_id: int, params: dict):
# runtime = global_runtime()
# if runtime is None:
# raise RuntimeError("Runtime is not initialized. Please call init() first.")

# runtime._metadb.log_params(exp_id=exp_id, params=params)
26 changes: 25 additions & 1 deletion alphatrion/runtime/runtime.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
# ruff: noqa: PLW0603
import os

from alphatrion import consts
from alphatrion.artifact.artifact import Artifact
from alphatrion.metadata.sql import SQLStore

__RUNTIME__ = None


def init(project_id: str, artifact_insecure: bool = False):
    """
    Set up the process-wide AlphaTrion runtime.

    A new ``Runtime`` is built and stored in the module-level
    ``__RUNTIME__`` slot, replacing whatever was stored there before.

    :param project_id: the project ID the runtime is created for
    :param artifact_insecure: whether to use an insecure connection to the
                              artifact registry
    """
    global __RUNTIME__

    runtime = Runtime(project_id=project_id, artifact_insecure=artifact_insecure)
    __RUNTIME__ = runtime


def global_runtime() -> "Runtime | None":
    """
    Return the runtime stored by ``init()``.

    :return: the current value of the module-level ``__RUNTIME__``, which is
             ``None`` until ``init()`` has been called
    """
    return __RUNTIME__


# Runtime contains all kinds of clients, e.g., metadb client, artifact client, etc.
class Runtime:
def __init__(self, project_id: str):
def __init__(self, project_id: str, artifact_insecure: bool = False):
self._project_id = project_id
self._metadb = SQLStore(os.getenv(consts.METADATA_DB_URL), init_tables=True)
self._artifact = Artifact(
project_id=self._project_id, insecure=artifact_insecure
)
36 changes: 20 additions & 16 deletions tests/integration/artifact/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,22 @@

import pytest

from alphatrion.artifact.artifact import Artifact
from alphatrion.experiment.base import Experiment
from alphatrion.runtime.runtime import Runtime
from alphatrion.observe.observe import log_artifact
from alphatrion.runtime.runtime import global_runtime, init


@pytest.fixture
def artifact():
# We use a local registry for testing, it doesn't mean
# it will always successfully with cloud registries.
# We may need e2e tests for that.
runtime = Runtime(project_id="test_project")
artifact = Artifact(runtime=runtime, insecure=True)
init(project_id="test_project", artifact_insecure=True)
artifact = global_runtime()._artifact

yield artifact


def test_push_with_files(artifact):
# Create a temporary directory with some files
init(project_id="test_project", artifact_insecure=True)

with tempfile.TemporaryDirectory() as tmpdir:
os.chdir(tmpdir)

Expand All @@ -45,6 +44,8 @@ def test_push_with_files(artifact):


def test_push_with_folder(artifact):
init(project_id="test_project", artifact_insecure=True)

with tempfile.TemporaryDirectory() as tmpdir:
os.chdir(tmpdir)

Expand All @@ -66,13 +67,13 @@ def test_push_with_folder(artifact):


def test_save_checkpoint():
init(project_id="test_project", artifact_insecure=True)

with Experiment.run(
project_id="test_project",
name="context_exp",
description="Context manager test",
meta={"key": "value"},
labels={"type": "unit"},
artifact_insecure=True,
) as exp:
with tempfile.TemporaryDirectory() as tmpdir:
os.chdir(tmpdir)
Expand All @@ -81,18 +82,21 @@ def test_save_checkpoint():
with open(file1, "w") as f:
f.write("This is file1.")

exp.log_artifact(1, paths="file1.txt", version="v1")
versions = exp._artifact.list_versions("context_exp")
log_artifact(1, paths="file1.txt", version="v1")
versions = exp._runtime._artifact.list_versions("context_exp")
assert "v1" in versions

with open("file1.txt", "w") as f:
f.write("This is modified file1.")

# push folder instead
exp.log_artifact(1, paths=["file1.txt"], version="v2")
versions = exp._artifact.list_versions("context_exp")
log_artifact(1, paths=["file1.txt"], version="v2")
versions = exp._runtime._artifact.list_versions("context_exp")
assert "v2" in versions

exp._artifact.delete(experiment_name="context_exp", versions=["v1", "v2"])
versions = exp._artifact.list_versions("context_exp")
exp._runtime._artifact.delete(
experiment_name="context_exp",
versions=["v1", "v2"],
)
versions = exp._runtime._artifact.list_versions("context_exp")
assert len(versions) == 0
26 changes: 26 additions & 0 deletions tests/integration/test_sdk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import tempfile

import alphatrion as at


def test_sdk():
    """
    Exercise the top-level SDK entry points together.

    Initializes the runtime, runs a ``CraftExperiment`` as a context manager,
    logs one file as an artifact and checks that the pushed version is
    listed by the artifact client.
    """
    at.init(project_id="test_project", artifact_insecure=True)

    with at.CraftExperiment.run(
        name="craft_exp",
        description="test description",
        meta={"key": "value"},
        labels={"type": "unit"},
    ) as exp:
        # Remember a directory to return to, so the process is not left
        # inside the temporary directory once it has been removed.
        try:
            start_dir = os.getcwd()
        except OSError:
            # The current directory may already be gone (e.g. an earlier
            # test changed into a temporary directory that was removed).
            start_dir = tempfile.gettempdir()

        with tempfile.TemporaryDirectory() as tmpdir:
            os.chdir(tmpdir)
            try:
                file = "file.txt"
                with open(file, "w") as f:
                    f.write("Hello, AlphaTrion!")

                # NOTE(review): the experiment ID is hard-coded; presumably it
                # is the ID assigned to "craft_exp" when the integration suite
                # runs in its usual order -- confirm, or look the ID up instead.
                at.log_artifact(2, paths=file, version="v1")
            finally:
                # Leave the temporary directory before it is cleaned up.
                os.chdir(start_dir)

        versions = exp._runtime._artifact.list_versions("craft_exp")
        assert "v1" in versions
10 changes: 3 additions & 7 deletions tests/unit/artifact/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@

import pytest

from alphatrion.artifact.artifact import Artifact
from alphatrion.runtime.runtime import Runtime
from alphatrion.runtime.runtime import global_runtime, init


@pytest.fixture
def artifact():
# We use a local registry for testing, it doesn't mean
# it will always successfully with cloud registries.
# We may need e2e tests for that.
runtime = Runtime(project_id="test_project")
artifact = Artifact(runtime=runtime, insecure=True)
init(project_id="test_project", artifact_insecure=True)
artifact = global_runtime()._artifact
yield artifact


Expand Down
6 changes: 3 additions & 3 deletions tests/unit/experiment/test_base_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from alphatrion.experiment.craft_exp import Experiment
from alphatrion.metadata.sql_models import ExperimentStatus
from alphatrion.runtime.runtime import Runtime
from alphatrion.runtime.runtime import init


@pytest.fixture
def exp():
runtime = Runtime(project_id="test_project")
exp = Experiment(runtime=runtime)
init(project_id="test_project", artifact_insecure=True)
exp = Experiment()
yield exp


Expand Down
Loading
Loading