File tree Expand file tree Collapse file tree 2 files changed +3
-4
lines changed Expand file tree Collapse file tree 2 files changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -70,9 +70,7 @@ RUN if [ "$framework" = "jax" ] ; then \
7070        echo "Installing Jax GPU"  \
7171        && cd /algorithmic-efficiency \
7272        && pip install -e '.[pytorch_cpu, full]'  --extra-index-url https://download.pytorch.org/whl/cpu \
73-         #  Todo: remove temporary nightly install
7473        && pip install -e '.[jax_gpu]' ; \
75-         #  &&  pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/; \
7674    elif [ "$framework"  = "pytorch"  ] ; then \
7775        echo "Installing Pytorch GPU"  \
7876        && cd /algorithmic-efficiency \
Original file line number Diff line number Diff line change @@ -101,13 +101,14 @@ jax_core_deps = [
101101  " protobuf==4.25.5" 
102102]
103103jax_cpu  = [
104-   " jax==0.6.1 " 
104+   " jax==0.7.0 " 
105105  " algoperf[jax_core_deps]" 
106106]
107107
108108jax_gpu  = [
109-   " jax[cuda12]==0.6.1 " 
109+   " jax[cuda12]==0.7.0 " 
110110  " algoperf[jax_core_deps]" 
111+   " nvidia-cudnn-cu12==9.10.2.21" #  temporary workaround for https://github.com/jax-ml/jax/issues/30663
111112]
112113
113114pytorch_cpu  = [
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments