diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index 75a84bf0b2..de81bb94d7 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -194,7 +194,7 @@ def __init__( # pylint: disable=too-many-arguments self._tunables = tunables.copy() self._trial_id = trial_id self._experiment_id = experiment_id - (self._git_repo, self._git_commit, self._root_env_config) = get_git_info( + (self._git_repo, self._git_commit, self._root_env_config, _future_pr) = get_git_info( root_env_config ) self._description = description diff --git a/mlos_bench/mlos_bench/tests/util_git_test.py b/mlos_bench/mlos_bench/tests/util_git_test.py index 77fd2779c7..88fae7c7f4 100644 --- a/mlos_bench/mlos_bench/tests/util_git_test.py +++ b/mlos_bench/mlos_bench/tests/util_git_test.py @@ -3,14 +3,95 @@ # Licensed under the MIT License. # """Unit tests for get_git_info utility function.""" +import os import re +import tempfile +from pathlib import Path +from subprocess import CalledProcessError +from subprocess import check_call as run -from mlos_bench.util import get_git_info +import pytest + +from mlos_bench.util import get_git_info, get_git_root, path_join def test_get_git_info() -> None: - """Check that we can retrieve git info about the current repository correctly.""" - (git_repo, git_commit, rel_path) = get_git_info(__file__) + """Check that we can retrieve git info about the current repository correctly from a + file. + """ + (git_repo, git_commit, rel_path, abs_path) = get_git_info(__file__) assert "mlos" in git_repo.lower() assert re.match(r"[0-9a-f]{40}", git_commit) is not None assert rel_path == "mlos_bench/mlos_bench/tests/util_git_test.py" + assert abs_path == path_join(__file__, abs_path=True) + + +def test_get_git_info_dir() -> None: + """Check that we can retrieve git info about the current repository correctly from a + directory. + """ + dirname = os.path.dirname(__file__) + (git_repo, git_commit, rel_path, abs_path) = get_git_info(dirname) + assert "mlos" in git_repo.lower() + assert re.match(r"[0-9a-f]{40}", git_commit) is not None + assert rel_path == "mlos_bench/mlos_bench/tests" + assert abs_path == path_join(dirname, abs_path=True) + + +def test_non_git_dir() -> None: + """Check that we can handle a non-git directory.""" + with tempfile.TemporaryDirectory() as non_git_dir: + with pytest.raises(CalledProcessError): + # This should raise an error because the directory is not a git repository. + get_git_root(non_git_dir) + + +def test_non_upstream_git() -> None: + """Check that we can handle a git directory without an upstream.""" + with tempfile.TemporaryDirectory() as local_git_dir: + local_git_dir = path_join(local_git_dir, abs_path=True) + # Initialize a new git repository. + run(["git", "init", local_git_dir, "-b", "main"]) + run(["git", "-C", local_git_dir, "config", "--local", "user.email", "pytest@example.com"]) + run(["git", "-C", local_git_dir, "config", "--local", "user.name", "PyTest User"]) + Path(local_git_dir).joinpath("README.md").touch() + run(["git", "-C", local_git_dir, "add", "README.md"]) + run(["git", "-C", local_git_dir, "commit", "-m", "Initial commit"]) + # This should have slightly different behavior when there is no upstream. + (git_repo, _git_commit, rel_path, abs_path) = get_git_info(local_git_dir) + assert git_repo == f"file://{local_git_dir}" + assert abs_path == local_git_dir + assert rel_path == "." + + +@pytest.mark.skipif( + os.environ.get("GITHUB_ACTIONS") != "true", + reason="Not running in GitHub Actions CI.", +) +def test_github_actions_git_info() -> None: + """ + Test that get_git_info matches GitHub Actions environment variables if running in + CI. + + Examples + -------- + Test locally with the following command: + + .. code-block:: shell + + export GITHUB_ACTIONS=true + export GITHUB_SHA=$(git rev-parse HEAD) + # GITHUB_REPOSITORY should be in "owner/repo" format. + # e.g., GITHUB_REPOSITORY="bpkroth/MLOS" or "microsoft/MLOS" + export GITHUB_REPOSITORY=$(git rev-parse --abbrev-ref --symbolic-full-name HEAD@{u} | cut -d/ -f1 | xargs git remote get-url | grep https://github.com | cut -d/ -f4-) + pytest -n0 mlos_bench/mlos_bench/tests/util_git_test.py + """ # pylint: disable=line-too-long # noqa: E501 + repo_env = os.environ.get("GITHUB_REPOSITORY") # "owner/repo" format + sha_env = os.environ.get("GITHUB_SHA") + assert repo_env, "GITHUB_REPOSITORY not set in environment." + assert sha_env, "GITHUB_SHA not set in environment." + git_repo, git_commit, _rel_path, _abs_path = get_git_info(__file__) + assert git_repo.endswith(repo_env), f"git_repo '{git_repo}' does not end with '{repo_env}'" + assert ( + git_commit == sha_env + ), f"git_commit '{git_commit}' does not match GITHUB_SHA '{sha_env}'" diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index 7cf9bc640c..fe09a013fd 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -153,7 +153,7 @@ def path_join(*args: str, abs_path: bool = False) -> str: """ path = os.path.join(*args) if abs_path: - path = os.path.abspath(path) + path = os.path.realpath(path) return os.path.normpath(path).replace("\\", "/") @@ -274,33 +274,150 @@ def check_required_params(config: Mapping[str, Any], required_params: Iterable[s ) -def get_git_info(path: str = __file__) -> tuple[str, str, str]: +def get_git_root(path: str = __file__) -> str: """ - Get the git repository, commit hash, and local path of the given file. + Get the root dir of the git repository. + + Parameters + ---------- + path : Optional[str] + Path to the file in git repository. + + Raises + ------ + subprocess.CalledProcessError + If the path is not a git repository or the command fails. + + Returns + ------- + str + The absolute path to the root directory of the git repository. + """ + abspath = path_join(path, abs_path=True) + if not os.path.exists(abspath) or not os.path.isdir(abspath): + dirname = os.path.dirname(abspath) + else: + dirname = abspath + git_root = subprocess.check_output( + ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True + ).strip() + return path_join(git_root, abs_path=True) + + +def get_git_remote_info(path: str, remote: str) -> str: + """ + Gets the remote URL for the given remote name in the git repository. Parameters ---------- path : str Path to the file in git repository. + remote : str + The name of the remote (e.g., "origin"). + + Raises + ------ + subprocess.CalledProcessError + If the command fails or the remote does not exist. Returns ------- - (git_repo, git_commit, git_path) : tuple[str, str, str] - Git repository URL, last commit hash, and relative file path. + str + The URL of the remote repository. """ - dirname = os.path.dirname(path) - git_repo = subprocess.check_output( - ["git", "-C", dirname, "remote", "get-url", "origin"], text=True + return subprocess.check_output( + ["git", "-C", path, "remote", "get-url", remote], text=True ).strip() + + +def get_git_repo_info(path: str) -> str: + """ + Get the git repository URL for the given git repo. + + Tries to get the upstream branch URL, falling back to the "origin" remote + if the upstream branch is not set or does not exist. If that also fails, + it returns a file URL pointing to the local path. + + Parameters + ---------- + path : str + Path to the git repository. + + Raises + ------ + subprocess.CalledProcessError + If the command fails or the git repository does not exist. + + Returns + ------- + str + The upstream URL of the git repository. + """ + # In case "origin" remote is not set, or this branch has a different + # upstream, we should handle it gracefully. + # (e.g., fallback to the first one we find?) + path = path_join(path, abs_path=True) + cmd = ["git", "-C", path, "rev-parse", "--abbrev-ref", "--symbolic-full-name", "HEAD@{u}"] + try: + git_remote = subprocess.check_output(cmd, text=True).strip() + git_remote = git_remote.split("/", 1)[0] + git_repo = get_git_remote_info(path, git_remote) + except subprocess.CalledProcessError: + git_remote = "origin" + _LOG.warning( + "Failed to get the upstream branch for %s. Falling back to '%s' remote.", + path, + git_remote, + ) + try: + git_repo = get_git_remote_info(path, git_remote) + except subprocess.CalledProcessError: + git_repo = "file://" + path + _LOG.warning( + "Failed to get the upstream branch for %s. Falling back to '%s'.", + path, + git_repo, + ) + return git_repo + + +def get_git_info(path: str = __file__) -> tuple[str, str, str, str]: + """ + Get the git repository, commit hash, and local path of the given file. + + Parameters + ---------- + path : str + Path to the file in git repository. + + Raises + ------ + subprocess.CalledProcessError + If the path is not a git repository or the command fails. + + Returns + ------- + (git_repo, git_commit, rel_path, abs_path) : tuple[str, str, str, str] + Git repository URL, last commit hash, and relative file path and current + absolute path. + """ + abspath = path_join(path, abs_path=True) + if os.path.exists(abspath) and os.path.isdir(abspath): + dirname = abspath + else: + dirname = os.path.dirname(abspath) + git_root = get_git_root(path=abspath) + git_repo = get_git_repo_info(git_root) git_commit = subprocess.check_output( ["git", "-C", dirname, "rev-parse", "HEAD"], text=True ).strip() - git_root = subprocess.check_output( - ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True - ).strip() - _LOG.debug("Current git branch: %s %s", git_repo, git_commit) - rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root)) - return (git_repo, git_commit, rel_path.replace("\\", "/")) + _LOG.debug("Current git branch for %s: %s %s", git_root, git_repo, git_commit) + rel_path = os.path.relpath(abspath, os.path.abspath(git_root)) + # TODO: return the branch too? + return (git_repo, git_commit, rel_path.replace("\\", "/"), abspath) + + +# TODO: Add support for checking out the branch locally. # Note: to avoid circular imports, we don't specify TunableValue here.