Skip to content

Commit b952422

Browse files
committed
fix name
1 parent 4f1c43e commit b952422

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

reference_algorithms/qualification_baselines/external_tuning/jax_nadamw_target_setting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,8 @@ def update_params(
274274
else:
275275
grad_clip = None
276276

277-
# Create shardings for each argument
278-
replicated = jax_sharding_utils.get_replicated_sharding() # No partitioning
277+
# Create shardings for each argument
278+
replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning
279279
sharded = jax_sharding_utils.get_batch_dim_sharding() # Partition along batch dimension
280280

281281
# Create the sharding rules for each argument

0 commit comments

Comments
 (0)