Skip to content

Commit 0a340a2

Browse files
committed
small fix
1 parent a3a9b9f commit 0a340a2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def _compute_metrics(self,
183183
jax_sharding_utils.get_replicated_sharding(), # params
184184
jax_sharding_utils.get_batch_sharding(), # batch
185185
jax_sharding_utils.get_replicated_sharding(), # model_state
186-
jax_sharding _utils.get_batch_sharding(), # rng
186+
jax_sharding_utils.get_batch_sharding(), # rng
187187
),
188188
)
189189
def _eval_model(

0 commit comments

Comments
 (0)