
Commit 7a3710f

Merge pull request #790 from Niccolo-Ajroldi/pass_train_state
Pass `train_state` to `update_params`
2 parents 5ce9e5a + 1965241 commit 7a3710f

File tree

34 files changed: +460 -358 lines


DOCUMENTATION.md

Lines changed: 2 additions & 0 deletions

@@ -199,6 +199,7 @@ def update_params(
     batch: Dict[str, Tensor],
     loss_type: LossType,
     optimizer_state: OptimizerState,
+    train_state: Dict[str, Any],
     eval_results: List[Tuple[int, float]],
     global_step: int,
     rng: RandomState
@@ -212,6 +213,7 @@ def update_params(
 - The `loss_fn` produces a loss per example and a summed loss (both only for one device), which both can be used.
 - Allowed to update state for the optimizer.
 - Uses the `model_fn` of the `workload` in order to decouple the loss from the model so that model outputs (forward passes) can be reused (by storing them in the optimizer state).
+- The submission can access the elapsed training time and get further information about the evaluation through `train_state`.
 - The submission can access the target evaluation metric via the `workload` variable.
 - **A call to this function will be considered a step**
 - The time between a call to this function and the next call to this function will be considered the per-step time.

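To make the documented behaviour concrete, here is a hedged sketch of how a submission might inspect `train_state` inside `update_params`; the dictionary keys and the helper name are illustrative assumptions, not something this commit defines.

# Illustrative sketch only (not from this commit): a submission could read
# run-time information out of `train_state` inside `update_params`. The key
# names used here are assumptions about what the harness provides.
def _read_train_state(train_state):
  if train_state is None:  # spec.py defaults the argument to None.
    return None, True
  elapsed = train_state.get('accumulated_submission_time')  # assumed key
  time_remaining = train_state.get('is_time_remaining', True)  # assumed key
  return elapsed, time_remaining

A submission that has no use for the extra information can simply discard the argument, which is what the updated baselines below do via `del train_state`.
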
algorithmic_efficiency/spec.py

Lines changed: 4 additions & 2 deletions

@@ -403,7 +403,8 @@ def init_optimizer_state(workload: Workload,
     OptimizerState,
     List[Tuple[int, float]],
     int,
-    RandomState
+    RandomState,
+    Optional[Dict[str, Any]]
 ],
 UpdateReturn]

@@ -424,7 +425,8 @@ def update_params(workload: Workload,
                   optimizer_state: OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: RandomState) -> UpdateReturn:
+                  rng: RandomState,
+                  train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   pass

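Because the widened `Callable` signature appends an `Optional[Dict[str, Any]]` argument that defaults to `None`, existing submissions that ignore it keep working. As a minimal, non-authoritative sketch (not part of this commit), a submission accepting the new argument would look like the following, reusing the same `spec` types as the baselines below:

# Minimal sketch only, not part of this commit: a submission signature that
# matches the widened UpdateParamsFn type. The body is a placeholder no-op.
from typing import Any, Dict, List, Optional, Tuple

from algorithmic_efficiency import spec


def update_params(
    workload: spec.Workload,
    current_param_container: spec.ParameterContainer,
    current_params_types: spec.ParameterTypeTree,
    model_state: spec.ModelAuxiliaryState,
    hyperparameters: spec.Hyperparameters,
    batch: Dict[str, spec.Tensor],
    loss_type: spec.LossType,
    optimizer_state: spec.OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: spec.RandomState,
    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
  """Placeholder submission: returns all state unchanged."""
  del workload, current_params_types, hyperparameters, batch, loss_type
  del eval_results, global_step, rng, train_state
  return optimizer_state, current_param_container, model_state

The baselines below take exactly this shape, except that they perform a real NAdamW update and discard the argument with `del train_state`.
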
prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py

Lines changed: 14 additions & 11 deletions

@@ -252,20 +252,23 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   optimizer_state, opt_update_fn = optimizer_state

prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py

Lines changed: 14 additions & 11 deletions

@@ -252,20 +252,23 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   optimizer_state, opt_update_fn = optimizer_state

prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py

Lines changed: 15 additions & 12 deletions

@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

 import math
-from typing import Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -224,20 +224,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
   return optimizer_state


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   current_model = current_param_container

prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py

Lines changed: 15 additions & 12 deletions

@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

 import math
-from typing import Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -224,20 +224,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
   return optimizer_state


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   current_model = current_param_container

prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py

Lines changed: 14 additions & 11 deletions

@@ -264,20 +264,23 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results
   del hyperparameters

prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py

Lines changed: 14 additions & 11 deletions

@@ -264,20 +264,23 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results
   del hyperparameters

prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py

Lines changed: 15 additions & 12 deletions

@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

 import math
-from typing import Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -236,20 +236,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
   return optimizer_state


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results
   del hyperparameters

prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py

Lines changed: 15 additions & 12 deletions

@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

 import math
-from typing import Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -236,20 +236,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
   return optimizer_state


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results
   del hyperparameters

0 commit comments
