Skip to content

Commit 2cfa2a9

Browse files
committed
set jax to 0.5.1
1 parent d3a06fc commit 2cfa2a9

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ jax_core_deps = [
106106
"protobuf==4.25.5",
107107
]
108108
jax_cpu = [
109-
"jax",
109+
"jax==0.5.2",
110110
"algoperf[jax_core_deps]",
111111
]
112112
jax_gpu = [
113-
"jax[cuda12]",
113+
"jax[cuda12]==0.5.2",
114114
"algoperf[jax_core_deps]",
115115
]
116116
pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"]

0 commit comments

Comments
 (0)