From 06cce6e972b9bc4404375c44e4c42a00d3761dc2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 14 Apr 2025 23:17:52 +0000 Subject: [PATCH 1/2] upgrade to 0.4.30 --- .../librispeech_jax/workload.py | 6 ++++-- pyproject.toml | 12 ++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index d3b616f43..2d88a18d0 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -51,8 +51,10 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + model_state = jax_utils.replicate(model_state, jax.devices()) + params = jax_utils.replicate(params, jax.devices()) + params_shapes = jax.tree.map(lambda x: jnp.shape(x), model_state) + jax.debug.print("params shapes {params_shapes}", params_shapes=params_shapes) return params, model_state def model_fn( diff --git a/pyproject.toml b/pyproject.toml index cc404f4b5..17dd9c2cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,15 +106,15 @@ jax_core_deps = [ "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", + "jax==0.4.30", + "jaxlib==0.4.30", "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", - "jax-cuda12-plugin[with_cuda]==0.4.28", - "jax-cuda12-pjrt==0.4.28", + "jax==0.4.30", + "jaxlib==0.4.30", + "jax-cuda12-plugin[with_cuda]==0.4.30", + "jax-cuda12-pjrt==0.4.30", "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] From 5c52bd7e167f589d91b381dd28597090987ae1de Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 15 Apr 2025 00:07:24 +0000 Subject: [PATCH 2/2] remove debugging statements --- .../librispeech_deepspeech/librispeech_jax/workload.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 2d88a18d0..d3b616f43 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -51,10 +51,8 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state, jax.devices()) - params = jax_utils.replicate(params, jax.devices()) - params_shapes = jax.tree.map(lambda x: jnp.shape(x), model_state) - jax.debug.print("params shapes {params_shapes}", params_shapes=params_shapes) + model_state = jax_utils.replicate(model_state) + params = jax_utils.replicate(params) return params, model_state def model_fn(