We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4f1c43e commit b952422Copy full SHA for b952422
reference_algorithms/qualification_baselines/external_tuning/jax_nadamw_target_setting.py
@@ -274,8 +274,8 @@ def update_params(
274
else:
275
grad_clip = None
276
277
- # Create shardings for each argument
278
- replicated = jax_sharding_utils.get_replicated_sharding() # No partitioning
+ # Create shardings for each argument
+ replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning
279
sharded = jax_sharding_utils.get_batch_dim_sharding() # Partition along batch dimension
280
281
# Create the sharding rules for each argument
0 commit comments