55# LICENSE file in the root directory of this source tree.
66
77import os
8+ import shutil
89import typing
910from importlib import resources
10- from typing import Any , Dict , final , List
11+ from typing import Any , Dict , final , List , Optional
1112
1213import torch
1314from executorch .backends .aoti .aoti_backend import AotiBackend
@@ -36,6 +37,57 @@ class CudaBackend(AotiBackend, BackendDetails):
3637 def get_device_name (cls ) -> str :
3738 return "cuda"
3839
@staticmethod
def _find_ptxas_for_version(cuda_version: str) -> Optional[str]:  # noqa: C901
    """
    Find a ptxas binary that belongs to the expected CUDA version.

    Candidates are probed in this order and the first match wins:
      1. ``CUDA_HOME`` as resolved by ``torch.utils.cpp_extension``
      2. the ``CUDA_HOME``, ``CUDA_PATH`` and ``CUDA_ROOT`` environment
         variables
      3. ``/usr/local/cuda-<cuda_version>/bin/ptxas``
      4. ``ptxas`` found on the system ``PATH``
      5. ``/usr/local/cuda/bin/ptxas``

    Except for (3), whose path already names the version, a candidate is
    accepted only when its symlink-resolved path contains
    ``/cuda-<cuda_version>/``.

    Args:
        cuda_version: CUDA version string as it appears in the toolkit
            directory name (e.g. "12.6").

    Returns:
        Path to a matching ptxas file, or None when no candidate matches.
    """
    expected_version_marker = f"/cuda-{cuda_version}/"

    def _validate_ptxas_version(path: str) -> bool:
        """Check if ptxas at given path matches expected CUDA version."""
        # isfile rather than exists: a directory that happens to be named
        # "ptxas" can never be used as the assembler binary.
        if not os.path.isfile(path):
            return False
        # Resolve symlinks so an unversioned install directory that links to
        # a versioned one is matched through its real location.
        return expected_version_marker in os.path.realpath(path)

    # 1. PyTorch's CUDA_HOME. Only the import can raise here, so the try
    #    body is limited to it.
    try:
        from torch.utils.cpp_extension import CUDA_HOME as torch_cuda_home
    except ImportError:
        torch_cuda_home = None

    # 2. CUDA_HOME / CUDA_PATH / CUDA_ROOT environment variables, probed
    #    after PyTorch's value and in this order.
    candidate_roots = [torch_cuda_home]
    candidate_roots.extend(
        os.environ.get(env_var)
        for env_var in ("CUDA_HOME", "CUDA_PATH", "CUDA_ROOT")
    )
    for cuda_home in candidate_roots:
        if not cuda_home:
            continue
        ptxas_path = os.path.join(cuda_home, "bin", "ptxas")
        if _validate_ptxas_version(ptxas_path):
            return ptxas_path

    # 3. Versioned install path. The version is part of the path itself, so
    #    only the file's presence needs to be checked.
    versioned_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas"
    if os.path.isfile(versioned_path):
        return versioned_path

    # 4. System PATH via shutil.which
    ptxas_in_path = shutil.which("ptxas")
    if ptxas_in_path and _validate_ptxas_version(ptxas_in_path):
        return ptxas_in_path

    # 5. Default symlink path as last resort
    default_path = "/usr/local/cuda/bin/ptxas"
    if _validate_ptxas_version(default_path):
        return default_path

    return None
90+
3991 @staticmethod
4092 def _setup_cuda_environment_for_fatbin () -> bool :
4193 """
@@ -57,12 +109,9 @@ def _setup_cuda_environment_for_fatbin() -> bool:
57109
58110 # Set TRITON_PTXAS_PATH for CUDA 12.6+
59111 if major == 12 and minor >= 6 :
60- # Try versioned path first, fallback to symlinked path
61- ptxas_path = f"/usr/local/cuda-{ cuda_version } /bin/ptxas"
62- if not os .path .exists (ptxas_path ):
63- ptxas_path = "/usr/local/cuda/bin/ptxas"
64- if not os .path .exists (ptxas_path ):
65- return False
112+ ptxas_path = CudaBackend ._find_ptxas_for_version (cuda_version )
113+ if ptxas_path is None :
114+ return False
66115 os .environ ["TRITON_PTXAS_PATH" ] = ptxas_path
67116
68117 # Get compute capability of current CUDA device
0 commit comments