
Commit da67b52

removing unnecessary code

1 parent 8bbe5e2

2 files changed: 0 additions & 14 deletions

reference_algorithms/schedule_free/jax/submission.py

Lines changed: 0 additions & 6 deletions
@@ -36,12 +36,6 @@ def init_optimizer_state(workload: spec.Workload,
                          model_params
   del model_state
   del rng
-  lr=HPARAMS['learning_rate']
-  betas=(1.0 - HPARAMS['one_minus_beta1'], HPARAMS['beta2'])
-  warmup_steps=int(HPARAMS['warmup_factor'] * workload.step_hint * 0.75)
-  weight_decay=HPARAMS['weight_decay']
-  weight_lr_power=HPARAMS['weight_lr_power']
-  r=HPARAMS['r']

   opt_init_fn, opt_update_fn = schedule_free_adamw(
       learning_rate=HPARAMS['learning_rate'],
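The deleted assignments were dead code: nothing read those locals, because the factory call that follows pulls each value from HPARAMS directly. A minimal sketch of how the deleted expressions presumably map onto that call; only `learning_rate` is visible in the diff, so every other keyword name is an assumption inferred from the deleted lines, all values are illustrative placeholders, and `schedule_free_adamw` is assumed to be the factory defined elsewhere in this file:

```python
# Hedged sketch, not the repo's exact code. Keyword names other than
# `learning_rate` are inferred from the deleted local assignments, and the
# HPARAMS values below are placeholders, not the submission's tuning.
HPARAMS = {
    'learning_rate': 1e-3,
    'one_minus_beta1': 0.1,
    'beta2': 0.999,
    'warmup_factor': 0.05,
    'weight_decay': 1e-2,
    'weight_lr_power': 2.0,
    'r': 0.75,
}
step_hint = 100_000  # placeholder for workload.step_hint

opt_init_fn, opt_update_fn = schedule_free_adamw(  # factory assumed in scope
    learning_rate=HPARAMS['learning_rate'],  # the only kwarg shown in the diff
    betas=(1.0 - HPARAMS['one_minus_beta1'], HPARAMS['beta2']),
    warmup_steps=int(HPARAMS['warmup_factor'] * step_hint * 0.75),
    weight_decay=HPARAMS['weight_decay'],
    weight_lr_power=HPARAMS['weight_lr_power'],
    r=HPARAMS['r'],
)
```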

reference_algorithms/schedule_free/pytorch/submission.py

Lines changed: 0 additions & 8 deletions
@@ -266,14 +266,6 @@ def closure():

    loss = optimizer_state['optimizer'].step(closure)

-   # # current_state = current_model.module.state_dict() is a workaround for DDP
-   # if global_step==1:
-   #   torch.save(current_model.module.state_dict(), "/results/pytorch_base_model_criteo1tb_11may_global_step2.pth")
-   #   import torch.distributed as dist
-   #   import sys
-   #   dist.destroy_process_group()
-   #   sys.exit(0)
-
    return (optimizer_state, current_param_container, new_model_state)

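The removed block was an already commented-out DDP debugging snippet (dump a checkpoint at step 1, tear down the process group, exit), so deleting it changes no behavior. For context on the surviving line above it, `Optimizer.step(closure)` is the standard PyTorch pattern: the optimizer invokes the closure to recompute the loss and gradients. A minimal self-contained sketch of that pattern, not the submission's actual training loop:

```python
import torch

# Toy model and optimizer; stand-ins for the submission's real objects.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

def closure():
    # Re-evaluates the model and returns the loss; step() may call this
    # more than once for optimizers that need it (e.g. LBFGS).
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss

# Same call shape as the diff's optimizer_state['optimizer'].step(closure).
loss = optimizer.step(closure)
print(loss.item())
```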
