
Commit 582951c
remove unused baselines that rely on pmap
1 parent 8f6648c


46 files changed: +55 / -5970 lines

prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py

Lines changed: 52 additions & 26 deletions
@@ -192,24 +192,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
                                      workload.param_shapes)
   optimizer_state = opt_init_fn(params_zeros_like)
 
-  return jax_utils.replicate(optimizer_state), opt_update_fn
-
-
-@functools.partial(
-    jax.pmap,
-    axis_name='batch',
-    in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
-    static_broadcasted_argnums=(0, 1),
-    donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
-                       opt_update_fn,
-                       model_state,
-                       optimizer_state,
-                       current_param_container,
-                       batch,
-                       rng,
-                       grad_clip,
-                       label_smoothing):
+  return optimizer_state, opt_update_fn
+
+
+def train_step(workload,
+               opt_update_fn,
+               model_state,
+               optimizer_state,
+               current_param_container,
+               batch,
+               rng,
+               grad_clip,
+               label_smoothing):
 
   def _loss_fn(params):
     """Loss function used for training."""
@@ -232,9 +226,7 @@ def _loss_fn(params):
   grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
   (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
       current_param_container)
-  # Get correct global mean loss and grad.
-  (summed_loss, n_valid_examples, grad) = lax.psum(
-      (summed_loss, n_valid_examples, grad), axis_name='batch')
+  # Compute mean loss and grad
   loss = summed_loss / n_valid_examples
   grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
 
@@ -272,7 +264,6 @@ def update_params(
   del eval_results
 
   optimizer_state, opt_update_fn = optimizer_state
-  per_device_rngs = jax.random.split(rng, jax.local_device_count())
   if hasattr(hyperparameters, 'label_smoothing'):
     label_smoothing = hyperparameters.label_smoothing
   else:
@@ -281,13 +272,48 @@ def update_params(
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None
-  outputs = pmapped_train_step(workload,
+
+  # Get mesh
+  mesh = jax_sharding_utils.get_mesh()
+  # Create shardings for each argument
+  replicated = jax_sharding_utils.get_replicated_sharding(mesh)  # No partitioning
+  sharded = jax_sharding_utils.get_batch_sharding(
+      mesh)  # Partition along batch dimension
+
+  # Create the sharding rules for each argument
+  arg_shardings = (
+      # workload is static
+      # opt_update_fn is static
+      replicated,  # model_state
+      replicated,  # optimizer_state
+      replicated,  # current_param_container
+      sharded,  # batch
+      replicated,  # rng
+      replicated,  # grad_clip
+      replicated  # label_smoothing
+  )
+  out_shardings = (
+      replicated,  # new_optimizer_state
+      replicated,  # updated_params
+      replicated,  # new_model_state
+      replicated,  # loss
+      replicated  # grad_norm
+  )
+  # Jit with shardings
+  jitted_train_step = jax.jit(
+      train_step,
+      static_argnums=(0, 1),
+      donate_argnums=(2, 3, 4),
+      in_shardings=arg_shardings,
+      out_shardings=out_shardings)
+
+  new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload,
                                opt_update_fn,
                                model_state,
                                optimizer_state,
                                current_param_container,
                                batch,
-                               per_device_rngs,
+                               rng,
                                grad_clip,
                                label_smoothing)
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
@@ -296,8 +322,8 @@ def update_params(
   if global_step % 100 == 0 and workload.metrics_logger is not None:
     workload.metrics_logger.append_scalar_metrics(
         {
-            'loss': loss[0],
-            'grad_norm': grad_norm[0],
+            'loss': loss.item(),
+            'grad_norm': grad_norm.item()
         }, global_step)
   return (new_optimizer_state, opt_update_fn), new_params, new_model_state
 
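Note on the change above: with jax.pmap gone, the batch is sharded across a device mesh while all other arguments are replicated, so jax.jit inserts the cross-device reduction itself and the explicit lax.psum over the 'batch' axis is no longer needed. The jax_sharding_utils helpers called in the new code are not part of this commit; the sketch below shows one plausible implementation, assuming a 1-D mesh over all available devices with the axis named 'batch' (the function names come from the calls above, their bodies are assumptions):

# Hypothetical sketch of the jax_sharding_utils helpers used in the diff above.
# The real module is not shown in this commit; the bodies here are assumptions.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def get_mesh() -> Mesh:
  # Assumption: a 1-D mesh over all available devices, axis named 'batch'.
  return Mesh(np.asarray(jax.devices()), axis_names=('batch',))


def get_replicated_sharding(mesh: Mesh) -> NamedSharding:
  # An empty PartitionSpec replicates the array on every device in the mesh.
  return NamedSharding(mesh, PartitionSpec())


def get_batch_sharding(mesh: Mesh) -> NamedSharding:
  # Partition the leading (batch) dimension across the 'batch' mesh axis.
  return NamedSharding(mesh, PartitionSpec('batch'))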
prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
                                      workload.param_shapes)
   optimizer_state = opt_init_fn(params_zeros_like)
 
-  return jax_utils.replicate(optimizer_state), opt_update_fn
+  return optimizer_state, opt_update_fn
 
 
 @functools.partial(
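For context on why jax_utils.replicate is dropped in both files: flax.jax_utils.replicate stacks one copy of the pytree per local device (adding a leading device axis), which is the layout jax.pmap expects; under jax.jit a single logical copy suffices and can be placed with a replicated sharding. A minimal contrast sketch on a toy pytree, under the same 1-D 'batch' mesh assumption as above:

# Toy contrast between the pmap-era and jit-era placement of optimizer state.
import jax
import jax.numpy as jnp
import numpy as np
from flax import jax_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

toy_state = {'count': jnp.zeros(()), 'mu': jnp.zeros((4,))}  # toy optimizer pytree

# pmap-style: stack one copy per local device, giving every leaf a leading
# axis of size jax.local_device_count().
pmap_state = jax_utils.replicate(toy_state)

# jit-style: keep a single logical copy and place it with a replicated
# NamedSharding (assumption: a 1-D mesh named 'batch', as sketched above).
mesh = Mesh(np.asarray(jax.devices()), axis_names=('batch',))
jit_state = jax.device_put(toy_state, NamedSharding(mesh, PartitionSpec()))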

reference_algorithms/development_algorithms/__init__.py

Whitespace-only changes.

reference_algorithms/development_algorithms/cifar/__init__.py

Whitespace-only changes.

reference_algorithms/development_algorithms/cifar/cifar_jax/__init__.py

Whitespace-only changes.

reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py

Lines changed: 0 additions & 180 deletions
This file was deleted.

reference_algorithms/development_algorithms/cifar/cifar_pytorch/__init__.py

Whitespace-only changes.
