
Commit 5a06a0d

adding train_state to all submissions

1 parent ce8eb18

30 files changed: +107 −77 lines

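Every diff in this commit makes the same edit: `train_state` moves from a required positional parameter (between `optimizer_state` and `eval_results`) to an optional trailing parameter that defaults to `None`, and `Optional` is added to the `typing` imports where it was missing. Below is a minimal sketch of the resulting signature; the leading parameters and the `from algorithmic_efficiency import spec` import are assumed from the surrounding AlgoPerf submission API, since they are not shown in the hunks below, and the body is an illustrative no-op rather than a real optimizer step.

```python
from typing import Any, Dict, List, Optional, Tuple

from algorithmic_efficiency import spec  # assumed package path for `spec`


def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState,
                  train_state: Optional[Dict[str, Any]] = None
                  ) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  # No-op body for illustration only: return the state unchanged.
  del workload, current_params_types, hyperparameters, batch, loss_type
  del eval_results, global_step, rng, train_state
  return optimizer_state, current_param_container, model_state
```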

prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -260,10 +260,11 @@ def update_params(workload: spec.Workload,
                   batch: Dict[str, spec.Tensor],
                   loss_type: spec.LossType,
                   optimizer_state: spec.OptimizerState,
-                  train_state: Dict[str, Any],
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  train_state: Optional[Dict[str, Any]] = None
+                  ) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
```

prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -260,10 +260,11 @@ def update_params(workload: spec.Workload,
                   batch: Dict[str, spec.Tensor],
                   loss_type: spec.LossType,
                   optimizer_state: spec.OptimizerState,
-                  train_state: Dict[str, Any],
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  train_state: Optional[Dict[str, Any]] = None
+                  ) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
```

prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""
 
 import math
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple
 
 from absl import logging
 import torch
@@ -232,10 +232,11 @@ def update_params(workload: spec.Workload,
                   batch: Dict[str, spec.Tensor],
                   loss_type: spec.LossType,
                   optimizer_state: spec.OptimizerState,
-                  train_state: Dict[str, Any],
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  train_state: Optional[Dict[str, Any]] = None
+                  ) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
```

prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""
 
 import math
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple
 
 from absl import logging
 import torch
@@ -232,10 +232,11 @@ def update_params(workload: spec.Workload,
                   batch: Dict[str, spec.Tensor],
                   loss_type: spec.LossType,
                   optimizer_state: spec.OptimizerState,
-                  train_state: Dict[str, Any],
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  train_state: Optional[Dict[str, Any]] = None
+                  ) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
```

prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -272,10 +272,11 @@ def update_params(workload: spec.Workload,
                   batch: Dict[str, spec.Tensor],
                   loss_type: spec.LossType,
                   optimizer_state: spec.OptimizerState,
-                  train_state: Dict[str, Any],
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  train_state: Optional[Dict[str, Any]] = None
+                  ) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
```

prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -272,10 +272,11 @@ def update_params(workload: spec.Workload,
                   batch: Dict[str, spec.Tensor],
                   loss_type: spec.LossType,
                   optimizer_state: spec.OptimizerState,
-                  train_state: Dict[str, Any],
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  train_state: Optional[Dict[str, Any]] = None
+                  ) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
```

prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""
 
 import math
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple
 
 from absl import logging
 import torch
@@ -244,10 +244,11 @@ def update_params(workload: spec.Workload,
                   batch: Dict[str, spec.Tensor],
                   loss_type: spec.LossType,
                   optimizer_state: spec.OptimizerState,
-                  train_state: Dict[str, Any],
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  train_state: Optional[Dict[str, Any]] = None
+                  ) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
```

prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""
 
 import math
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple
 
 from absl import logging
 import torch
@@ -244,10 +244,11 @@ def update_params(workload: spec.Workload,
                   batch: Dict[str, spec.Tensor],
                   loss_type: spec.LossType,
                   optimizer_state: spec.OptimizerState,
-                  train_state: Dict[str, Any],
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  train_state: Optional[Dict[str, Any]] = None
+                  ) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
```

reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -1,7 +1,7 @@
 """Training algorithm track submission functions for CIFAR10."""
 
 import functools
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple
 
 from flax import jax_utils
 import jax
@@ -118,10 +118,11 @@ def update_params(workload: spec.Workload,
                   batch: Dict[str, spec.Tensor],
                   loss_type: spec.LossType,
                   optimizer_state: spec.OptimizerState,
-                  train_state: Dict[str, Any],
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  train_state: Optional[Dict[str, Any]] = None
+                  ) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
```

reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -1,6 +1,6 @@
 """Training algorithm track submission functions for CIFAR10."""
 
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple
 
 import torch
 from torch.optim.lr_scheduler import CosineAnnealingLR
@@ -61,10 +61,11 @@ def update_params(workload: spec.Workload,
                   batch: Dict[str, spec.Tensor],
                   loss_type: spec.LossType,
                   optimizer_state: spec.OptimizerState,
-                  train_state: Dict[str, Any],
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  train_state: Optional[Dict[str, Any]] = None
+                  ) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params)."""
   del current_params_types
   del hyperparameters
```
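Because the new parameter defaults to `None` and sits at the end of the list, callers that never pass `train_state` keep working unchanged, while an updated harness can supply it by keyword. A self-contained sketch of that calling pattern; the stub and the `is_time_remaining` key are illustrative, not code from this commit:

```python
from typing import Any, Dict, Optional


def update_params(global_step: int,
                  train_state: Optional[Dict[str, Any]] = None) -> int:
  """Stub with the same defaulted trailing parameter as the real signature."""
  if train_state is not None:
    # Submissions can read harness-provided state when it is supplied.
    print('train_state keys:', sorted(train_state))
  return global_step + 1


step = update_params(0)  # old-style call without train_state: still valid
step = update_params(step, train_state={'is_time_remaining': True})
```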
