Skip to content

Commit 8e98702

Browse files
committed
pin cudnn version
1 parent 3b5a623 commit 8e98702

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

docker/Dockerfile

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,7 @@ RUN if [ "$framework" = "jax" ] ; then \
7070
echo "Installing Jax GPU" \
7171
&& cd /algorithmic-efficiency \
7272
&& pip install -e '.[pytorch_cpu, full]' --extra-index-url https://download.pytorch.org/whl/cpu \
73-
# Todo: remove temporary nightly install
7473
&& pip install -e '.[jax_gpu]'; \
75-
# && pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/; \
7674
elif [ "$framework" = "pytorch" ] ; then \
7775
echo "Installing Pytorch GPU" \
7876
&& cd /algorithmic-efficiency \

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,14 @@ jax_core_deps = [
101101
"protobuf==4.25.5",
102102
]
103103
jax_cpu = [
104-
"jax==0.6.1",
104+
"jax==0.7.0",
105105
"algoperf[jax_core_deps]",
106106
]
107107

108108
jax_gpu = [
109-
"jax[cuda12]==0.6.1",
109+
"jax[cuda12]==0.7.0",
110110
"algoperf[jax_core_deps]",
111+
"nvidia-cudnn-cu12==9.10.2.21", # temporary workaround for https://github.com/jax-ml/jax/issues/30663
111112
]
112113

113114
pytorch_cpu = [

0 commit comments

Comments
 (0)