
Commit 1965241

fix yapf
1 parent: 86114ef

30 files changed: +390 −390 lines
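
This commit is a pure formatter pass: every hunk below rewrites the same update_params signature from paren-aligned continuation lines to yapf's break-after-the-open-paren layout with a four-space continuation indent. As a minimal sketch of how such a pass can be reproduced with yapf's Python API (assuming yapf is installed, and guessing an inline style close to this repo's two-space body indent; the repository's actual .style.yapf is not shown in this commit):

# Minimal sketch, not the exact command used for this commit.
# The style string below is an assumption; the project's real yapf
# configuration lives outside this diff.
from yapf.yapflib.yapf_api import FormatCode

snippet = ("def update_params(workload,\n"
           "                  train_state=None\n"
           "                 ):\n"
           "  pass\n")

formatted, changed = FormatCode(
    snippet, style_config='{based_on_style: yapf, indent_width: 2}')
print(changed)    # True if yapf rewrote the snippet
print(formatted)  # the snippet reformatted under the chosen style

Tree-wide, an invocation along the lines of yapf --in-place --recursive over the affected directories applies the same pass; both flags are standard yapf CLI options.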

prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py

Lines changed: 13 additions & 13 deletions

@@ -252,19 +252,19 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
 
 
-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState,
-                  train_state: Optional[Dict[str, Any]] = None
-                 ) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type

prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py

Lines changed: 13 additions & 13 deletions

@@ -252,19 +252,19 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
 
 
-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState,
-                  train_state: Optional[Dict[str, Any]] = None
-                 ) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type

prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py

Lines changed: 13 additions & 13 deletions

@@ -224,19 +224,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
   return optimizer_state
 
 
-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState,
-                  train_state: Optional[Dict[str, Any]] = None
-                 ) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type

prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py

Lines changed: 13 additions & 13 deletions

@@ -224,19 +224,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
   return optimizer_state
 
 
-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState,
-                  train_state: Optional[Dict[str, Any]] = None
-                 ) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type

prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py

Lines changed: 13 additions & 13 deletions

@@ -264,19 +264,19 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
 
 
-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState,
-                  train_state: Optional[Dict[str, Any]] = None
-                 ) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type

prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py

Lines changed: 13 additions & 13 deletions

@@ -264,19 +264,19 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
 
 
-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState,
-                  train_state: Optional[Dict[str, Any]] = None
-                 ) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type

prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py

Lines changed: 13 additions & 13 deletions

@@ -236,19 +236,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
   return optimizer_state
 
 
-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState,
-                  train_state: Optional[Dict[str, Any]] = None
-                 ) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type

prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py

Lines changed: 13 additions & 13 deletions

@@ -236,19 +236,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
   return optimizer_state
 
 
-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState,
-                  train_state: Optional[Dict[str, Any]] = None
-                 ) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type

reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py

Lines changed: 13 additions & 13 deletions

@@ -110,19 +110,19 @@ def _loss_fn(params):
 
 # Not allowed to update the model parameters, hyperparameters, global step, or
 # optimzier state.
-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState,
-                  train_state: Optional[Dict[str, Any]] = None
-                 ) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type

reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py

Lines changed: 13 additions & 13 deletions

@@ -53,19 +53,19 @@ def init_optimizer_state(workload: spec.Workload,
   return optimizer_state
 
 
-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState,
-                  train_state: Optional[Dict[str, Any]] = None
-                 ) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params)."""
   del current_params_types
   del hyperparameters
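
The ten hunks shown here carry the identical mechanical change, and the commit summary (+390 −390 over 30 files, exactly 13 lines per file) suggests the remaining twenty files do as well. A conformance check over the touched files would confirm nothing still deviates from the configured style. A hedged sketch using yapf's FormatFile API; the two paths are taken from the diff above, but the check itself is illustrative and not part of this commit:

# Sketch: list files that yapf would still rewrite.
# Assumes yapf is installed and picks up the project's style file from
# the working directory; the paths are two files touched by this commit.
from yapf.yapflib.yapf_api import FormatFile

paths = [
    'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py',
    'reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py',
]
for path in paths:
    diff_text, _encoding, changed = FormatFile(path, print_diff=True)
    if changed:
        print(f'needs reformat: {path}')
        print(diff_text)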
