diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index 75a84bf0b2..3965c1404f 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -29,6 +29,7 @@ from collections.abc import Iterator, Mapping from contextlib import AbstractContextManager as ContextManager from datetime import datetime +from subprocess import CalledProcessError from types import TracebackType from typing import Any, Literal @@ -187,15 +188,62 @@ def __init__( # pylint: disable=too-many-arguments tunables: TunableGroups, experiment_id: str, trial_id: int, - root_env_config: str, + root_env_config: str | None, description: str, opt_targets: dict[str, Literal["min", "max"]], + git_repo: str | None = None, + git_commit: str | None = None, + git_rel_root_env_config: str | None = None, ): self._tunables = tunables.copy() self._trial_id = trial_id self._experiment_id = experiment_id - (self._git_repo, self._git_commit, self._root_env_config) = get_git_info( - root_env_config + self._abs_root_env_config: str | None + if root_env_config is not None: + if git_repo or git_commit or git_rel_root_env_config: + # Extra args are only used when restoring an Experiment from the DB. + raise ValueError("Unexpected args: git_repo, git_commit, rel_root_env_config") + try: + ( + self._git_repo, + self._git_commit, + self._git_rel_root_env_config, + self._abs_root_env_config, + ) = get_git_info(root_env_config) + except CalledProcessError as e: + # Note: currently the Experiment schema requires git + # metadata to be set. We *could* set the git metadata to + # dummy values, but for now we just throw an error. + _LOG.warning( + "Failed to get git info for root_env_config %s: %s", + root_env_config, + e, + ) + raise e + else: + # Restoring from DB. 
+ if not (git_repo and git_commit and git_rel_root_env_config): + raise ValueError("Missing args: git_repo, git_commit, rel_root_env_config") + self._git_repo = git_repo + self._git_commit = git_commit + self._git_rel_root_env_config = git_rel_root_env_config + # Note: The absolute path to the root config is not stored in the DB, + # and resolving it is not always possible, so we omit this + # operation by default for now. + # See commit 0cb5948865662776e92ceaca3f0a80a34c6a39ef in + # this repository's history for prior + # implementation attempts. + self._abs_root_env_config = None + assert isinstance( + self._git_rel_root_env_config, str + ), "Failed to get relative root config path" + _LOG.info( + "Resolved relative root_config %s from %s at commit %s for Experiment %s to %s", + self._git_rel_root_env_config, + self._git_repo, + self._git_commit, + self._experiment_id, + self._abs_root_env_config, ) self._description = description self._opt_targets = opt_targets @@ -205,6 +253,8 @@ def __enter__(self) -> Storage.Experiment: """ Enter the context of the experiment. + Notes + ----- Override the `_setup` method to add custom context initialization. """ _LOG.debug("Starting experiment: %s", self) @@ -222,6 +272,8 @@ def __exit__( """ End the context of the experiment. + Notes + ----- Override the `_teardown` method to add custom context teardown logic. """ is_ok = exc_val is None @@ -247,14 +299,14 @@ def _setup(self) -> None: Create a record of the new experiment or find an existing one in the storage. - This method is called by `Storage.Experiment.__enter__()`. + This method is called by :py:meth:`.Storage.Experiment.__enter__`. """ def _teardown(self, is_ok: bool) -> None: """ Finalize the experiment in the storage. - This method is called by `Storage.Experiment.__exit__()`. + This method is called by :py:meth:`.Storage.Experiment.__exit__`. 
Parameters ---------- @@ -278,9 +330,35 @@ def description(self) -> str: return self._description @property - def root_env_config(self) -> str: - """Get the Experiment's root Environment config file path.""" - return self._root_env_config + def rel_root_env_config(self) -> str: + """Get the Experiment's root Environment config's relative file path to the + git repo root. + """ + return self._git_rel_root_env_config + + @property + def abs_root_env_config(self) -> str | None: + """ + Get the Experiment's root Environment config absolute file path. + + This attempts to return the current absolute path to the root config + for this process instead of the path relative to the git repo root. + + However, this may not always be possible if the git repo root is not + accessible, which can happen if the Experiment was restored from the + DB, but the process was started from a different working directory, + for instance. In that case ``None`` is returned. + + Notes + ----- + This is mostly useful for other components (e.g., + :py:class:`~mlos_bench.schedulers.base_scheduler.Scheduler`) to use + within the same process, and not across invocations. + """ + # TODO: In the future, we can consider fetching the git_repo to a + # standard working directory for ``mlos_bench`` and then resolving + # the root config path from there based on the relative path. 
+ return self._abs_root_env_config @property def tunables(self) -> TunableGroups: diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index 838b62a842..ea8c857495 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -38,9 +38,12 @@ def __init__( # pylint: disable=too-many-arguments tunables: TunableGroups, experiment_id: str, trial_id: int, - root_env_config: str, + root_env_config: str | None, description: str, opt_targets: dict[str, Literal["min", "max"]], + git_repo: str | None = None, + git_commit: str | None = None, + git_rel_root_env_config: str | None = None, ): super().__init__( tunables=tunables, @@ -49,6 +52,9 @@ def __init__( # pylint: disable=too-many-arguments root_env_config=root_env_config, description=description, opt_targets=opt_targets, + git_repo=git_repo, + git_commit=git_commit, + git_rel_root_env_config=git_rel_root_env_config, ) self._engine = engine self._schema = schema @@ -89,7 +95,7 @@ def _setup(self) -> None: description=self._description, git_repo=self._git_repo, git_commit=self._git_commit, - root_env_config=self._root_env_config, + root_env_config=self._git_rel_root_env_config, ) ) conn.execute( diff --git a/mlos_bench/mlos_bench/storage/sql/storage.py b/mlos_bench/mlos_bench/storage/sql/storage.py index dcad3bdf8a..383f4de146 100644 --- a/mlos_bench/mlos_bench/storage/sql/storage.py +++ b/mlos_bench/mlos_bench/storage/sql/storage.py @@ -156,9 +156,13 @@ def get_experiment_by_id( experiment_id=exp.exp_id, trial_id=-1, # will be loaded upon __enter__ which calls _setup() description=exp.description, - root_env_config=exp.root_env_config, + # Use special logic to load the experiment root config info directly. 
+ root_env_config=None, tunables=tunables, opt_targets=opt_targets, + git_repo=exp.git_repo, + git_commit=exp.git_commit, + git_rel_root_env_config=exp.root_env_config, ) def experiment( # pylint: disable=too-many-arguments diff --git a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py index dc8baf489c..310f22d98d 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py @@ -30,7 +30,7 @@ def test_exp_data_root_env_config( """Tests the root_env_config property of ExperimentData.""" # pylint: disable=protected-access assert exp_data.root_env_config == ( - exp_storage._root_env_config, + exp_storage._git_rel_root_env_config, exp_storage._git_repo, exp_storage._git_commit, ) diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index 0b25f963a9..6be568eafb 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -241,7 +241,7 @@ def exp_storage( with storage.experiment( experiment_id="Test-001", trial_id=1, - root_env_config="environment.jsonc", + root_env_config="my-environment.jsonc", description="pytest experiment", tunables=tunable_groups, opt_targets={"score": "min"}, @@ -375,7 +375,7 @@ def _dummy_run_exp( trial_runners=trial_runners, optimizer=opt, storage=storage, - root_env_config=exp.root_env_config, + root_env_config=exp.abs_root_env_config or "ERROR-UNKNOWN.jsonc", ) # Add some trial data to that experiment by "running" it. 
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/test_storage_schemas.py b/mlos_bench/mlos_bench/tests/storage/sql/test_storage_schemas.py index 7036419a48..f1d1b4c353 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/test_storage_schemas.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/test_storage_schemas.py @@ -17,9 +17,8 @@ # See Also: schema.py for an example of programmatic alembic config access. CURRENT_ALEMBIC_HEAD = "b61aa446e724" -# Try to test multiple DBMS engines. - +# Try to test multiple DBMS engines. @pytest.mark.parametrize( "some_sql_storage_fixture", [ diff --git a/mlos_bench/mlos_bench/tests/storage/test_storage_pickling.py b/mlos_bench/mlos_bench/tests/storage/test_storage_pickling.py index f4cc1793b9..21c204283c 100644 --- a/mlos_bench/mlos_bench/tests/storage/test_storage_pickling.py +++ b/mlos_bench/mlos_bench/tests/storage/test_storage_pickling.py @@ -72,7 +72,7 @@ def test_storage_pickle_restore_experiment_and_trial( assert restored_experiment is not experiment assert restored_experiment.experiment_id == experiment.experiment_id assert restored_experiment.description == experiment.description - assert restored_experiment.root_env_config == experiment.root_env_config + assert restored_experiment.rel_root_env_config == experiment.rel_root_env_config assert restored_experiment.tunables == experiment.tunables assert restored_experiment.opt_targets == experiment.opt_targets with restored_experiment: diff --git a/mlos_bench/mlos_bench/tests/util_git_test.py b/mlos_bench/mlos_bench/tests/util_git_test.py index 77fd2779c7..88fae7c7f4 100644 --- a/mlos_bench/mlos_bench/tests/util_git_test.py +++ b/mlos_bench/mlos_bench/tests/util_git_test.py @@ -3,14 +3,95 @@ # Licensed under the MIT License. 
# """Unit tests for get_git_info utility function.""" +import os import re +import tempfile +from pathlib import Path +from subprocess import CalledProcessError +from subprocess import check_call as run -from mlos_bench.util import get_git_info +import pytest + +from mlos_bench.util import get_git_info, get_git_root, path_join def test_get_git_info() -> None: - """Check that we can retrieve git info about the current repository correctly.""" - (git_repo, git_commit, rel_path) = get_git_info(__file__) + """Check that we can retrieve git info about the current repository correctly from a + file. + """ + (git_repo, git_commit, rel_path, abs_path) = get_git_info(__file__) assert "mlos" in git_repo.lower() assert re.match(r"[0-9a-f]{40}", git_commit) is not None assert rel_path == "mlos_bench/mlos_bench/tests/util_git_test.py" + assert abs_path == path_join(__file__, abs_path=True) + + +def test_get_git_info_dir() -> None: + """Check that we can retrieve git info about the current repository correctly from a + directory. + """ + dirname = os.path.dirname(__file__) + (git_repo, git_commit, rel_path, abs_path) = get_git_info(dirname) + assert "mlos" in git_repo.lower() + assert re.match(r"[0-9a-f]{40}", git_commit) is not None + assert rel_path == "mlos_bench/mlos_bench/tests" + assert abs_path == path_join(dirname, abs_path=True) + + +def test_non_git_dir() -> None: + """Check that we can handle a non-git directory.""" + with tempfile.TemporaryDirectory() as non_git_dir: + with pytest.raises(CalledProcessError): + # This should raise an error because the directory is not a git repository. + get_git_root(non_git_dir) + + +def test_non_upstream_git() -> None: + """Check that we can handle a git directory without an upstream.""" + with tempfile.TemporaryDirectory() as local_git_dir: + local_git_dir = path_join(local_git_dir, abs_path=True) + # Initialize a new git repository. 
+ run(["git", "init", local_git_dir, "-b", "main"]) + run(["git", "-C", local_git_dir, "config", "--local", "user.email", "pytest@example.com"]) + run(["git", "-C", local_git_dir, "config", "--local", "user.name", "PyTest User"]) + Path(local_git_dir).joinpath("README.md").touch() + run(["git", "-C", local_git_dir, "add", "README.md"]) + run(["git", "-C", local_git_dir, "commit", "-m", "Initial commit"]) + # This should have slightly different behavior when there is no upstream. + (git_repo, _git_commit, rel_path, abs_path) = get_git_info(local_git_dir) + assert git_repo == f"file://{local_git_dir}" + assert abs_path == local_git_dir + assert rel_path == "." + + +@pytest.mark.skipif( + os.environ.get("GITHUB_ACTIONS") != "true", + reason="Not running in GitHub Actions CI.", +) +def test_github_actions_git_info() -> None: + """ + Test that get_git_info matches GitHub Actions environment variables if running in + CI. + + Examples + -------- + Test locally with the following command: + + .. code-block:: shell + + export GITHUB_ACTIONS=true + export GITHUB_SHA=$(git rev-parse HEAD) + # GITHUB_REPOSITORY should be in "owner/repo" format. + # e.g., GITHUB_REPOSITORY="bpkroth/MLOS" or "microsoft/MLOS" + export GITHUB_REPOSITORY=$(git rev-parse --abbrev-ref --symbolic-full-name HEAD@{u} | cut -d/ -f1 | xargs git remote get-url | grep https://github.com | cut -d/ -f4-) + pytest -n0 mlos_bench/mlos_bench/tests/util_git_test.py + """ # pylint: disable=line-too-long # noqa: E501 + repo_env = os.environ.get("GITHUB_REPOSITORY") # "owner/repo" format + sha_env = os.environ.get("GITHUB_SHA") + assert repo_env, "GITHUB_REPOSITORY not set in environment." + assert sha_env, "GITHUB_SHA not set in environment." 
+ git_repo, git_commit, _rel_path, _abs_path = get_git_info(__file__) + assert git_repo.endswith(repo_env), f"git_repo '{git_repo}' does not end with '{repo_env}'" + assert ( + git_commit == sha_env + ), f"git_commit '{git_commit}' does not match GITHUB_SHA '{sha_env}'" diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index 7cf9bc640c..006e0e0a54 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -153,7 +153,7 @@ def path_join(*args: str, abs_path: bool = False) -> str: """ path = os.path.join(*args) if abs_path: - path = os.path.abspath(path) + path = os.path.realpath(path) return os.path.normpath(path).replace("\\", "/") @@ -274,33 +274,150 @@ def check_required_params(config: Mapping[str, Any], required_params: Iterable[s ) -def get_git_info(path: str = __file__) -> tuple[str, str, str]: +def get_git_root(path: str = __file__) -> str: """ - Get the git repository, commit hash, and local path of the given file. + Get the root dir of the git repository. + + Parameters + ---------- + path : str + Path to a file or directory in the git repository. + + Raises + ------ + subprocess.CalledProcessError + If the path is not a git repository or the command fails. + + Returns + ------- + str + The absolute path to the root directory of the git repository. + """ + abspath = path_join(path, abs_path=True) + if not os.path.exists(abspath) or not os.path.isdir(abspath): + dirname = os.path.dirname(abspath) + else: + dirname = abspath + git_root = subprocess.check_output( + ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True + ).strip() + return path_join(git_root, abs_path=True) + + +def get_git_remote_info(path: str, remote: str) -> str: + """ + Gets the remote URL for the given remote name in the git repository. Parameters ---------- path : str Path to the file in git repository. + remote : str + The name of the remote (e.g., "origin"). 
+ + Raises + ------ + subprocess.CalledProcessError + If the command fails or the remote does not exist. Returns ------- - (git_repo, git_commit, git_path) : tuple[str, str, str] - Git repository URL, last commit hash, and relative file path. + str + The URL of the remote repository. """ - dirname = os.path.dirname(path) - git_repo = subprocess.check_output( - ["git", "-C", dirname, "remote", "get-url", "origin"], text=True + return subprocess.check_output( + ["git", "-C", path, "remote", "get-url", remote], text=True ).strip() + + +def get_git_repo_info(path: str) -> str: + """ + Get the git repository URL for the given git repo. + + Tries to get the upstream branch URL, falling back to the "origin" remote + if the upstream branch is not set or does not exist. If that also fails, + it returns a file URL pointing to the local path. + + Parameters + ---------- + path : str + Path to the git repository. + + Raises + ------ + subprocess.CalledProcessError + If the command fails or the git repository does not exist. + + Returns + ------- + str + The upstream URL of the git repository. + """ + # In case "origin" remote is not set, or this branch has a different + # upstream, we should handle it gracefully. + # (e.g., fallback to the first one we find?) + path = path_join(path, abs_path=True) + cmd = ["git", "-C", path, "rev-parse", "--abbrev-ref", "--symbolic-full-name", "HEAD@{u}"] + try: + git_remote = subprocess.check_output(cmd, text=True).strip() + git_remote = git_remote.split("/", 1)[0] + git_repo = get_git_remote_info(path, git_remote) + except subprocess.CalledProcessError: + git_remote = "origin" + _LOG.warning( + "Failed to get the upstream branch for %s. Falling back to '%s' remote.", + path, + git_remote, + ) + try: + git_repo = get_git_remote_info(path, git_remote) + except subprocess.CalledProcessError: + git_repo = "file://" + path + _LOG.warning( + "Failed to get the upstream branch for %s. 
Falling back to '%s'.", + path, + git_repo, + ) + return git_repo + + +def get_git_info(path: str = __file__) -> tuple[str, str, str, str]: + """ + Get the git repository, commit hash, and local path of the given file. + + Parameters + ---------- + path : str + Path to the file in git repository. + + Raises + ------ + subprocess.CalledProcessError + If the path is not a git repository or the command fails. + + Returns + ------- + (git_repo, git_commit, rel_path, abs_path) : tuple[str, str, str, str] + Git repository URL, last commit hash, file path relative to the git root, + and current absolute path. + """ + abspath = path_join(path, abs_path=True) + if not os.path.exists(abspath) or not os.path.isdir(abspath): + dirname = os.path.dirname(abspath) + else: + dirname = abspath + git_root = get_git_root(path=abspath) + git_repo = get_git_repo_info(git_root) git_commit = subprocess.check_output( ["git", "-C", dirname, "rev-parse", "HEAD"], text=True ).strip() - git_root = subprocess.check_output( - ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True - ).strip() - _LOG.debug("Current git branch: %s %s", git_repo, git_commit) - rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root)) - return (git_repo, git_commit, rel_path.replace("\\", "/")) + _LOG.debug("Current git branch for %s: %s %s", git_root, git_repo, git_commit) + rel_path = os.path.relpath(abspath, os.path.abspath(git_root)) + # TODO: return the branch too? + return (git_repo, git_commit, rel_path.replace("\\", "/"), abspath) + + +# TODO: Add support for checking out the branch locally. # Note: to avoid circular imports, we don't specify TunableValue here.