diff --git a/pyproject.toml b/pyproject.toml index cc404f4b5..17dd9c2cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,15 +106,15 @@ jax_core_deps = [ "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", + "jax==0.4.30", + "jaxlib==0.4.30", "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", - "jax-cuda12-plugin[with_cuda]==0.4.28", - "jax-cuda12-pjrt==0.4.28", + "jax==0.4.30", + "jaxlib==0.4.30", + "jax-cuda12-plugin[with_cuda]==0.4.30", + "jax-cuda12-pjrt==0.4.30", "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"]