We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a3a9b9f commit 0a340a2Copy full SHA for 0a340a2
algoperf/workloads/cifar/cifar_jax/workload.py
@@ -183,7 +183,7 @@ def _compute_metrics(self,
183
jax_sharding_utils.get_replicated_sharding(), # params
184
jax_sharding_utils.get_batch_sharding(), # batch
185
jax_sharding_utils.get_replicated_sharding(), # model_state
186
- jax_sharding _utils.get_batch_sharding(), # rng
+ jax_sharding_utils.get_batch_sharding(), # rng
187
),
188
)
189
def _eval_model(
0 commit comments