Skip to content

Commit 52bb875

Browse files
committed
set jax to 0.5.1
1 parent 93ff958 commit 52bb875

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ jax_core_deps = [
106106
"protobuf==4.25.5",
107107
]
108108
jax_cpu = [
109-
"jax",
109+
"jax==0.5.2",
110110
"algoperf[jax_core_deps]",
111111
]
112112
jax_gpu = [
113-
"jax[cuda12]",
113+
"jax[cuda12]==0.5.2",
114114
"algoperf[jax_core_deps]",
115115
]
116116
pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"]

submission_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings.
6565
# disable only for deepspeech if it works fine for other workloads
6666
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'
67+
os.environ['JAX_TRACEBACK_FILTERING'] = "off"
6768

6869
# TODO(znado): make a nicer registry of workloads that lookup in.
6970
BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR

0 commit comments

Comments
 (0)