Skip to content

Commit 82e37df

Browse files
More robust ptxas finding (#16229)
An earlier attempt failed some internal tests on machines with unusual paths. This change tries a more robust lookup.
1 parent 7b651e1 commit 82e37df

File tree

1 file changed

+56
-7
lines changed

1 file changed

+56
-7
lines changed

backends/cuda/cuda_backend.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import os
8+
import shutil
89
import typing
910
from importlib import resources
10-
from typing import Any, Dict, final, List
11+
from typing import Any, Dict, final, List, Optional
1112

1213
import torch
1314
from executorch.backends.aoti.aoti_backend import AotiBackend
@@ -36,6 +37,57 @@ class CudaBackend(AotiBackend, BackendDetails):
3637
def get_device_name(cls) -> str:
3738
return "cuda"
3839

40+
@staticmethod
def _find_ptxas_for_version(cuda_version: str) -> Optional[str]:  # noqa: C901
    """
    Find a ptxas binary that matches the expected CUDA version.

    Candidate locations are probed in this order, and the first acceptable
    one wins:

    1. ``CUDA_HOME`` as exported by ``torch.utils.cpp_extension``
    2. the ``CUDA_HOME`` / ``CUDA_PATH`` / ``CUDA_ROOT`` environment variables
    3. the versioned install path ``/usr/local/cuda-<version>/bin/ptxas``
    4. whatever ``ptxas`` resolves to on the system ``PATH``
    5. the default path ``/usr/local/cuda/bin/ptxas``

    Every candidate must be a regular, executable file. All candidates except
    the versioned path (3) must additionally resolve, after following
    symlinks, to a location containing ``/cuda-<version>/``; the versioned
    path already carries the version in its name.

    Args:
        cuda_version: CUDA version string as it appears in the toolkit
            directory name (e.g. ``"12.6"``).

    Returns:
        The path to a matching ptxas, or None if no candidate qualifies.
    """
    expected_version_marker = f"/cuda-{cuda_version}/"

    def _is_executable_file(path: str) -> bool:
        """Check that path is a regular file the current user may execute."""
        # os.path.exists alone would also accept a directory named "ptxas"
        # or a file without the execute bit, neither of which can be run.
        return os.path.isfile(path) and os.access(path, os.X_OK)

    def _validate_ptxas_version(path: str) -> bool:
        """Check if ptxas at given path matches expected CUDA version."""
        if not _is_executable_file(path):
            return False
        # Resolve symlinks so that e.g. /usr/local/cuda -> /usr/local/cuda-X.Y
        # is judged by the directory it really points to.
        resolved = os.path.realpath(path)
        return expected_version_marker in resolved

    # 1. Try PyTorch's CUDA_HOME
    try:
        from torch.utils.cpp_extension import CUDA_HOME

        if CUDA_HOME:
            ptxas_path = os.path.join(CUDA_HOME, "bin", "ptxas")
            if _validate_ptxas_version(ptxas_path):
                return ptxas_path
    except ImportError:
        # Best effort: without torch's helper we fall through to the
        # remaining lookup strategies.
        pass

    # 2. Try CUDA_HOME / CUDA_PATH / CUDA_ROOT environment variables
    for env_var in ("CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"):
        cuda_home = os.environ.get(env_var)
        if cuda_home:
            ptxas_path = os.path.join(cuda_home, "bin", "ptxas")
            if _validate_ptxas_version(ptxas_path):
                return ptxas_path

    # 3. Try versioned path directly (the version is part of the path itself,
    #    so only the file check is needed)
    versioned_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas"
    if _is_executable_file(versioned_path):
        return versioned_path

    # 4. Try system PATH via shutil.which
    ptxas_in_path = shutil.which("ptxas")
    if ptxas_in_path and _validate_ptxas_version(ptxas_in_path):
        return ptxas_in_path

    # 5. Try default symlink path as last resort
    default_path = "/usr/local/cuda/bin/ptxas"
    if _validate_ptxas_version(default_path):
        return default_path

    return None
90+
3991
@staticmethod
4092
def _setup_cuda_environment_for_fatbin() -> bool:
4193
"""
@@ -57,12 +109,9 @@ def _setup_cuda_environment_for_fatbin() -> bool:
57109

58110
# Set TRITON_PTXAS_PATH for CUDA 12.6+
59111
if major == 12 and minor >= 6:
60-
# Try versioned path first, fallback to symlinked path
61-
ptxas_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas"
62-
if not os.path.exists(ptxas_path):
63-
ptxas_path = "/usr/local/cuda/bin/ptxas"
64-
if not os.path.exists(ptxas_path):
65-
return False
112+
ptxas_path = CudaBackend._find_ptxas_for_version(cuda_version)
113+
if ptxas_path is None:
114+
return False
66115
os.environ["TRITON_PTXAS_PATH"] = ptxas_path
67116

68117
# Get compute capability of current CUDA device

0 commit comments

Comments
 (0)