Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions xtuner/v1/utils/env_check.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Callable, List


def check_torch_accelerator_available():
"""Check if PyTorch is installed and the torch accelerator is available.

Expand All @@ -14,21 +13,24 @@ def check_torch_accelerator_available():
except Exception:
return False


def check_triton_available():
"""Check if Triton is installed.

Returns:
bool: True if Triton is installed, False otherwise.
"""
import os

if os.environ.get("XTUNER_USE_TRITON", "1") == "0":
return False

try:
import triton # noqa: F401

return True
except ImportError:
return False


def get_env_not_available_func(env_name_list: List[str]) -> Callable:
"""Get a function that raises an error indicating the environment is not
available.
Expand All @@ -42,7 +44,6 @@ def env_not_available_func(*args: Any, **kwargs: Any) -> Any:

return env_not_available_func


def get_rollout_engine_version() -> dict:
import os

Expand Down
Loading