File tree Expand file tree Collapse file tree 3 files changed +42
-9
lines changed Expand file tree Collapse file tree 3 files changed +42
-9
lines changed Original file line number Diff line number Diff line change 7171      "container-options" : " --device=/dev/kfd --device=/dev/dri" 
7272      "pytorch-version" : " pytorch-nightly" 
7373      "alias" : " mi325x" 
74+     },
75+     {
76+       "runner" : " linux.g5.4xlarge.nvidia.gpu" 
77+       "python-version" : " 3.12" 
78+       "ref-eager" : false ,
79+       "image" : " nvidia/cuda:12.8.1-devel-ubuntu24.04" 
80+       "runtime-version" : " cpu" 
81+       "container-options" : " --gpus all" 
82+       "pytorch-version" : " pytorch-nightly" 
83+       "alias" : " cpu" 
7484    }
7585  ]
7686}
Original file line number Diff line number Diff line change 9797          fi 
9898
9999name : Install Triton 
100-         if : steps.cache.outputs.cache-hit != 'true' && matrix.pytorch-version != 'pytorch-2.9' 
100+         if : steps.cache.outputs.cache-hit != 'true' && ( matrix.pytorch-version != 'pytorch-2.9' || contains(matrix.alias, 'cpu'))  
101101        run : | 
102102          set -x 
103103          source .venv/bin/activate 
@@ -110,7 +110,15 @@ jobs:
110110          cd /tmp/$USER 
111111          uv pip uninstall triton pytorch-triton || true  
112112          rm -rf triton/ || true  
113-           git clone https://github.com/triton-lang/triton.git 
113+           if [[ "${{ matrix.alias }}" == *cpu* ]]; then 
114+             sudo apt install libs sleef-dev 
115+             REPO_URL="https://github.com/triton-lang/triton-cpu.git" 
116+             REPO_BRANCH="main-merged" 
117+           else 
118+             REPO_URL="https://github.com/triton-lang/triton.git" 
119+             REPO_BRANCH="main" 
120+           fi 
121+           git clone -b "$REPO_BRANCH" "$REPO_URL" triton 
114122          cd triton/ 
115123          uv pip install -r python/requirements.txt 
116124          MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 uv pip install . 
Original file line number Diff line number Diff line change 3434    from  .runtime .kernel  import  Kernel 
3535
3636
37- DEVICE  =  torch .device ("xpu" ) if  torch .xpu .is_available () else  torch .device ("cuda" )
38- PROJECT_ROOT : Path  =  Path (__file__ ).parent .parent 
39- EXAMPLES_DIR : Path  =  PROJECT_ROOT  /  "examples" 
37+ def  _get_triton_backend () ->  str  |  None :
38+     try :
39+         return  triton .runtime .driver .active .get_current_target ().backend   # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] 
40+     except  Exception :
41+         return  None 
42+ 
43+ 
44+ def  is_cpu () ->  bool :
45+     """Return True if running on Triton CPU backend.""" 
46+     return  _get_triton_backend () ==  "cpu" 
4047
4148
4249def  is_cuda () ->  bool :
4350    """Return True if running on CUDA (NVIDIA GPU).""" 
44-     return  (
45-         triton .runtime .driver .active .get_current_target ().backend  ==  "cuda"   # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] 
46-         and  DEVICE .type  ==  "cuda" 
47-     )
51+     return  _get_triton_backend () ==  "cuda"  and  torch .cuda .is_available ()
52+ 
53+ 
54+ PROJECT_ROOT : Path  =  Path (__file__ ).parent .parent 
55+ EXAMPLES_DIR : Path  =  PROJECT_ROOT  /  "examples" 
56+ 
57+ if  is_cpu ():
58+     DEVICE  =  torch .device ("cpu" )
59+ elif  torch .xpu .is_available ():
60+     DEVICE  =  torch .device ("xpu" )
61+ else :
62+     DEVICE  =  torch .device ("cuda" )
4863
4964
5065def  get_nvidia_gpu_model () ->  str :
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments