Skip to content
2 changes: 1 addition & 1 deletion mlos_bench/mlos_bench/storage/base_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def __init__( # pylint: disable=too-many-arguments
self._tunables = tunables.copy()
self._trial_id = trial_id
self._experiment_id = experiment_id
(self._git_repo, self._git_commit, self._root_env_config) = get_git_info(
(self._git_repo, self._git_commit, self._root_env_config, _future_pr) = get_git_info(
root_env_config
)
self._description = description
Expand Down
87 changes: 84 additions & 3 deletions mlos_bench/mlos_bench/tests/util_git_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,95 @@
# Licensed under the MIT License.
#
"""Unit tests for get_git_info utility function."""
import os
import re
import tempfile
from pathlib import Path
from subprocess import CalledProcessError

from mlos_bench.util import get_git_info
import pytest

from mlos_bench.util import get_git_info, get_git_root, path_join


def test_get_git_info() -> None:
"""Check that we can retrieve git info about the current repository correctly."""
(git_repo, git_commit, rel_path) = get_git_info(__file__)
"""Check that we can retrieve git info about the current repository correctly from a
file.
"""
(git_repo, git_commit, rel_path, abs_path) = get_git_info(__file__)
assert "mlos" in git_repo.lower()
assert re.match(r"[0-9a-f]{40}", git_commit) is not None
assert rel_path == "mlos_bench/mlos_bench/tests/util_git_test.py"
assert abs_path == path_join(__file__, abs_path=True)


def test_get_git_info_dir() -> None:
"""Check that we can retrieve git info about the current repository correctly from a
directory.
"""
dirname = os.path.dirname(__file__)
(git_repo, git_commit, rel_path, abs_path) = get_git_info(dirname)
assert "mlos" in git_repo.lower()
assert re.match(r"[0-9a-f]{40}", git_commit) is not None
assert rel_path == "mlos_bench/mlos_bench/tests"
assert abs_path == path_join(dirname, abs_path=True)


def test_non_git_dir() -> None:
"""Check that we can handle a non-git directory."""
with tempfile.TemporaryDirectory() as non_git_dir:
with pytest.raises(CalledProcessError):
# This should raise an error because the directory is not a git repository.
get_git_root(non_git_dir)


def test_non_upstream_git() -> None:
"""Check that we can handle a git directory without an upstream."""
with tempfile.TemporaryDirectory() as non_upstream_git_dir:
non_upstream_git_dir = path_join(non_upstream_git_dir, abs_path=True)
# Initialize a new git repository.
os.system(f"git init {non_upstream_git_dir} -b main")
os.system(f"git -C {non_upstream_git_dir} config --local user.email '[email protected]'")
os.system(f"git -C {non_upstream_git_dir} config --local user.name 'PyTest User'")
Path(non_upstream_git_dir).joinpath("README.md").touch()
os.system(f"git -C {non_upstream_git_dir} add README.md")
os.system(f"git -C {non_upstream_git_dir} commit -m 'Initial commit'")
# This should raise an error because the repository has no upstream.
(git_repo, _git_commit, rel_path, abs_path) = get_git_info(non_upstream_git_dir)
assert git_repo == f"file://{non_upstream_git_dir}"
assert abs_path == non_upstream_git_dir
assert rel_path == "."


@pytest.mark.skipif(
os.environ.get("GITHUB_ACTIONS") != "true",
reason="Not running in GitHub Actions CI.",
)
def test_github_actions_git_info() -> None:
"""
Test that get_git_info matches GitHub Actions environment variables if
running in CI.

Examples
--------
Test locally with the following command:

.. code-block:: shell

export GITHUB_ACTIONS=true
GITHUB_SHA=$(git rev-parse HEAD)
# GITHUB_REPOSITORY should be in "owner/repo" format.
# e.g., GITHUB_REPOSITORY="bpkroth/MLOS" or "microsoft/MLOS"
GITHUB_REPOSITORY=$(git rev-parse --abbrev-ref --symbolic-full-name HEAD@{u} | cut -d/ -f1 | xargs git remote get-url | grep https://github.com | cut -d/ -f4-)
pytest -n0 mlos_bench/mlos_bench/tests/util_git_test.py
```
""" # pylint: disable=line-too-long # noqa: E501
repo_env = os.environ.get("GITHUB_REPOSITORY") # "owner/repo" format
sha_env = os.environ.get("GITHUB_SHA")
assert repo_env, "GITHUB_REPOSITORY not set in environment."
assert sha_env, "GITHUB_SHA not set in environment."
git_repo, git_commit, _rel_path, _abs_path = get_git_info(__file__)
assert git_repo.endswith(repo_env), f"git_repo '{git_repo}' does not end with '{repo_env}'"
assert (
git_commit == sha_env
), f"git_commit '{git_commit}' does not match GITHUB_SHA '{sha_env}'"
142 changes: 128 additions & 14 deletions mlos_bench/mlos_bench/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def path_join(*args: str, abs_path: bool = False) -> str:
"""
path = os.path.join(*args)
if abs_path:
path = os.path.abspath(path)
path = os.path.realpath(path)
return os.path.normpath(path).replace("\\", "/")


Expand Down Expand Up @@ -274,33 +274,147 @@ def check_required_params(config: Mapping[str, Any], required_params: Iterable[s
)


def get_git_info(path: str = __file__) -> tuple[str, str, str]:
def get_git_root(path: str = __file__) -> str:
"""
Get the git repository, commit hash, and local path of the given file.
Get the root dir of the git repository.

Parameters
----------
path : str, optional
Path to the file in git repository.

Raises
------
subprocess.CalledProcessError
If the path is not a git repository or the command fails.

Returns
-------
str
The absolute path to the root directory of the git repository.
"""
abspath = path_join(path, abs_path=True)
if not os.path.exists(abspath) or not os.path.isdir(abspath):
dirname = os.path.dirname(abspath)
else:
dirname = abspath
git_root = subprocess.check_output(
["git", "-C", dirname, "rev-parse", "--show-toplevel"],
text=True,
).strip()
return path_join(git_root, abs_path=True)


def get_git_remote_info(path: str, remote: str) -> str:
"""
Gets the remote URL for the given remote name in the git repository.

Parameters
----------
path : str
Path to the file in git repository.
remote : str
The name of the remote (e.g., "origin").

Raises
------
subprocess.CalledProcessError
If the command fails or the remote does not exist.

Returns
-------
(git_repo, git_commit, git_path) : tuple[str, str, str]
Git repository URL, last commit hash, and relative file path.
str
The URL of the remote repository.
"""
dirname = os.path.dirname(path)
git_repo = subprocess.check_output(
["git", "-C", dirname, "remote", "get-url", "origin"], text=True
return subprocess.check_output(
["git", "-C", path, "remote", "get-url", remote], text=True
).strip()


def get_git_repo_info(path: str) -> str:
"""
Get the git repository URL for the given git repo.

Tries to get the upstream branch URL, falling back to the "origin" remote
if the upstream branch is not set or does not exist. If that also fails,
it returns a file URL pointing to the local path.

Parameters
----------
path : str
Path to the git repository.

Raises
------
subprocess.CalledProcessError
If the command fails or the git repository does not exist.

Returns
-------
str
The upstream URL of the git repository.
"""
# In case "origin" remote is not set, or this branch has a different
# upstream, we should handle it gracefully.
# (e.g., fallback to the first one we find?)
path = path_join(path, abs_path=True)
cmd = ["git", "-C", path, "rev-parse", "--abbrev-ref", "--symbolic-full-name", "HEAD@{u}"]
try:
git_remote = subprocess.check_output(cmd, text=True).strip()
git_remote = git_remote.split("/", 1)[0]
git_repo = get_git_remote_info(path, git_remote)
except subprocess.CalledProcessError:
git_remote = "origin"
_LOG.warning(
"Failed to get the upstream branch for %s. Falling back to '%s' remote.",
path,
git_remote,
)
try:
git_repo = get_git_remote_info(path, git_remote)
except subprocess.CalledProcessError:
git_repo = "file://" + path
_LOG.warning(
"Failed to get the upstream branch for %s. Falling back to '%s'.",
path,
git_repo,
)
return git_repo


def get_git_info(path: str = __file__) -> tuple[str, str, str, str]:
"""
Get the git repository, commit hash, and local path of the given file.

Parameters
----------
path : str
Path to the file in git repository.

Raises
------
subprocess.CalledProcessError
If the path is not a git repository or the command fails.

Returns
-------
(git_repo, git_commit, rel_path, abs_path) : tuple[str, str, str, str]
Git repository URL, last commit hash, and relative file path and current
absolute path.
"""
abspath = path_join(path, abs_path=True)
if not os.path.exists(abspath) or not os.path.isdir(abspath):
dirname = os.path.dirname(abspath)
else:
dirname = abspath
git_root = get_git_root(path=abspath)
git_repo = get_git_repo_info(git_root)
git_commit = subprocess.check_output(
["git", "-C", dirname, "rev-parse", "HEAD"], text=True
).strip()
git_root = subprocess.check_output(
["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True
).strip()
_LOG.debug("Current git branch: %s %s", git_repo, git_commit)
rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root))
return (git_repo, git_commit, rel_path.replace("\\", "/"))
_LOG.debug("Current git branch for %s: %s %s", git_root, git_repo, git_commit)
rel_path = os.path.relpath(abspath, os.path.abspath(git_root))
return (git_repo, git_commit, rel_path.replace("\\", "/"), abspath)


# Note: to avoid circular imports, we don't specify TunableValue here.
Expand Down
Loading