Skip to content

Commit b7c2d9a

Browse files
committed
triton-cpu
stack-info: PR: #1037, branch: oulgen/stack/163
1 parent d5418aa commit b7c2d9a

File tree

3 files changed

+42
-9
lines changed

3 files changed

+42
-9
lines changed

.github/matrix.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@
7171
"container-options": "--device=/dev/kfd --device=/dev/dri",
7272
"pytorch-version": "pytorch-nightly",
7373
"alias": "mi325x"
74+
},
75+
{
76+
"runner": "linux.g5.4xlarge.nvidia.gpu",
77+
"python-version": "3.12",
78+
"ref-eager": false,
79+
"image": "nvidia/cuda:12.8.1-devel-ubuntu24.04",
80+
"runtime-version": "cpu",
81+
"container-options": "--gpus all",
82+
"pytorch-version": "pytorch-nightly",
83+
"alias": "cpu"
7484
}
7585
]
7686
}

.github/workflows/test.yml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ jobs:
9797
fi
9898
9999
- name: Install Triton
100-
if: steps.cache.outputs.cache-hit != 'true' && matrix.pytorch-version != 'pytorch-2.9'
100+
if: steps.cache.outputs.cache-hit != 'true' && (matrix.pytorch-version != 'pytorch-2.9' || contains(matrix.alias, 'cpu'))
101101
run: |
102102
set -x
103103
source .venv/bin/activate
@@ -110,7 +110,15 @@ jobs:
110110
cd /tmp/$USER
111111
uv pip uninstall triton pytorch-triton || true
112112
rm -rf triton/ || true
113-
git clone https://github.com/triton-lang/triton.git
113+
if [[ "${{ matrix.alias }}" == *cpu* ]]; then
114+
sudo apt install libs sleef-dev
115+
REPO_URL="https://github.com/triton-lang/triton-cpu.git"
116+
REPO_BRANCH="main-merged"
117+
else
118+
REPO_URL="https://github.com/triton-lang/triton.git"
119+
REPO_BRANCH="main"
120+
fi
121+
git clone -b "$REPO_BRANCH" "$REPO_URL" triton
114122
cd triton/
115123
uv pip install -r python/requirements.txt
116124
MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 uv pip install .

helion/_testing.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,32 @@
3434
from .runtime.kernel import Kernel
3535

3636

37-
DEVICE = torch.device("xpu") if torch.xpu.is_available() else torch.device("cuda")
38-
PROJECT_ROOT: Path = Path(__file__).parent.parent
39-
EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"
37+
def _get_triton_backend() -> str | None:
38+
try:
39+
return triton.runtime.driver.active.get_current_target().backend # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
40+
except Exception:
41+
return None
42+
43+
44+
def is_cpu() -> bool:
45+
"""Return True if running on Triton CPU backend."""
46+
return _get_triton_backend() == "cpu"
4047

4148

4249
def is_cuda() -> bool:
4350
"""Return True if running on CUDA (NVIDIA GPU)."""
44-
return (
45-
triton.runtime.driver.active.get_current_target().backend == "cuda" # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
46-
and DEVICE.type == "cuda"
47-
)
51+
return _get_triton_backend() == "cuda" and torch.cuda.is_available()
52+
53+
54+
PROJECT_ROOT: Path = Path(__file__).parent.parent
55+
EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"
56+
57+
if is_cpu():
58+
DEVICE = torch.device("cpu")
59+
elif torch.xpu.is_available():
60+
DEVICE = torch.device("xpu")
61+
else:
62+
DEVICE = torch.device("cuda")
4863

4964

5065
def get_nvidia_gpu_model() -> str:

0 commit comments

Comments
 (0)