
Commit f557301

[reformat] ruff reformat lion submission
1 parent 22b07c8 commit f557301

File tree

1 file changed: +95 -75 lines


reference_algorithms/paper_baselines/lion/pytorch/submission.py

Lines changed: 95 additions & 75 deletions
@@ -18,15 +18,16 @@

 # default Lion parameters
 HPARAMS = {
-    "dropout_rate": 0.1,
-    "learning_rate": 2e-4,
-    "one_minus_beta1": 0.05,
-    "beta2": 0.98,
-    "weight_decay": 0.5,
-    "warmup_factor": 0.02
+  'dropout_rate': 0.1,
+  'learning_rate': 2e-4,
+  'one_minus_beta1': 0.05,
+  'beta2': 0.98,
+  'weight_decay': 0.5,
+  'warmup_factor': 0.02,
 }
 HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)

+
 # Modified from https://github.com/google/automl/blob/master/lion/lion_pytorch.py.
 class Lion(Optimizer):
   def __init__(
@@ -90,11 +91,13 @@ def step(self, closure=None):
     return loss


-def init_optimizer_state(workload: spec.Workload,
-                         model_params: spec.ParameterContainer,
-                         model_state: spec.ModelAuxiliaryState,
-                         hyperparameters: spec.Hyperparameters,
-                         rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+  workload: spec.Workload,
+  model_params: spec.ParameterContainer,
+  model_state: spec.ModelAuxiliaryState,
+  hyperparameters: spec.Hyperparameters,
+  rng: spec.RandomState,
+) -> spec.OptimizerState:
   """Creates a Lion optimizer and a learning rate schedule."""
   del model_state
   del rng
@@ -103,44 +106,47 @@ def init_optimizer_state(workload: spec.Workload,
     hyperparameters = HPARAMS

   optimizer_state = {
-      'optimizer':
-          Lion(
-              model_params.parameters(),
-              lr=HPARAMS.learning_rate,
-              betas=(1.0 - HPARAMS.one_minus_beta1,
-                     HPARAMS.beta2),
-              weight_decay=HPARAMS.weight_decay)
+    'optimizer': Lion(
+      model_params.parameters(),
+      lr=HPARAMS.learning_rate,
+      betas=(1.0 - HPARAMS.one_minus_beta1, HPARAMS.beta2),
+      weight_decay=HPARAMS.weight_decay,
+    )
   }

   def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
     warmup_steps = int(hyperparameters.warmup_factor * step_hint)
     warmup = LinearLR(
-        optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+      optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+    )
     cosine_steps = max(step_hint - warmup_steps, 1)
     cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
     return SequentialLR(
-        optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+      optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+    )

   optimizer_state['scheduler'] = pytorch_cosine_warmup(
-      workload.step_hint, HPARAMS, optimizer_state['optimizer'])
+    workload.step_hint, HPARAMS, optimizer_state['optimizer']
+  )
   optimizer_state['hyperparameters'] = hyperparameters

   return optimizer_state


 def update_params(
-    workload: spec.Workload,
-    current_param_container: spec.ParameterContainer,
-    current_params_types: spec.ParameterTypeTree,
-    model_state: spec.ModelAuxiliaryState,
-    hyperparameters: spec.Hyperparameters,
-    batch: Dict[str, spec.Tensor],
-    loss_type: spec.LossType,
-    optimizer_state: spec.OptimizerState,
-    eval_results: List[Tuple[int, float]],
-    global_step: int,
-    rng: spec.RandomState,
-    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+  workload: spec.Workload,
+  current_param_container: spec.ParameterContainer,
+  current_params_types: spec.ParameterTypeTree,
+  model_state: spec.ModelAuxiliaryState,
+  hyperparameters: spec.Hyperparameters,
+  batch: Dict[str, spec.Tensor],
+  loss_type: spec.LossType,
+  optimizer_state: spec.OptimizerState,
+  eval_results: List[Tuple[int, float]],
+  global_step: int,
+  rng: spec.RandomState,
+  train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
@@ -155,26 +161,30 @@ def update_params(
   optimizer_state['optimizer'].zero_grad()

   logits_batch, new_model_state = workload.model_fn(
-      params=current_model,
-      augmented_and_preprocessed_input_batch=batch,
-      model_state=model_state,
-      mode=spec.ForwardPassMode.TRAIN,
-      rng=rng,
-      update_batch_norm=True)
+    params=current_model,
+    augmented_and_preprocessed_input_batch=batch,
+    model_state=model_state,
+    mode=spec.ForwardPassMode.TRAIN,
+    rng=rng,
+    update_batch_norm=True,
+  )

   label_smoothing = (
-      hyperparameters.label_smoothing if hasattr(HPARAMS,
-                                                 'label_smoothing') else 0.0)
+    hyperparameters.label_smoothing
+    if hasattr(HPARAMS, 'label_smoothing')
+    else 0.0
+  )
   if hasattr(hyperparameters, 'grad_clip'):
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None

   loss_dict = workload.loss_fn(
-      label_batch=batch['targets'],
-      logits_batch=logits_batch,
-      mask_batch=batch.get('weights'),
-      label_smoothing=label_smoothing)
+    label_batch=batch['targets'],
+    logits_batch=logits_batch,
+    mask_batch=batch.get('weights'),
+    label_smoothing=label_smoothing,
+  )
   summed_loss = loss_dict['summed']
   n_valid_examples = loss_dict['n_valid_examples']
   if USE_PYTORCH_DDP:
@@ -187,7 +197,8 @@ def update_params(

   if grad_clip is not None:
     torch.nn.utils.clip_grad_norm_(
-        current_model.parameters(), max_norm=grad_clip)
+      current_model.parameters(), max_norm=grad_clip
+    )
   optimizer_state['optimizer'].step()
   optimizer_state['scheduler'].step()

@@ -196,31 +207,38 @@ def update_params(
     with torch.no_grad():
       parameters = [p for p in current_model.parameters() if p.grad is not None]
       grad_norm = torch.norm(
-          torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+        torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+      )
     if workload.metrics_logger is not None:
       workload.metrics_logger.append_scalar_metrics(
-          {
-              'loss': loss.item(),
-              'grad_norm': grad_norm.item(),
-          }, global_step)
-    logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
-                 global_step,
-                 loss.item(),
-                 grad_norm.item())
+        {
+          'loss': loss.item(),
+          'grad_norm': grad_norm.item(),
+        },
+        global_step,
+      )
+    logging.info(
+      '%d) loss = %0.3f, grad_norm = %0.3f',
+      global_step,
+      loss.item(),
+      grad_norm.item(),
+    )

   return (optimizer_state, current_param_container, new_model_state)


-def prepare_for_eval(workload: spec.Workload,
-                     current_param_container: spec.ParameterContainer,
-                     current_params_types: spec.ParameterTypeTree,
-                     model_state: spec.ModelAuxiliaryState,
-                     hyperparameters: spec.Hyperparameters,
-                     loss_type: spec.LossType,
-                     optimizer_state: spec.OptimizerState,
-                     eval_results: List[Tuple[int, float]],
-                     global_step: int,
-                     rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+  workload: spec.Workload,
+  current_param_container: spec.ParameterContainer,
+  current_params_types: spec.ParameterTypeTree,
+  model_state: spec.ModelAuxiliaryState,
+  hyperparameters: spec.Hyperparameters,
+  loss_type: spec.LossType,
+  optimizer_state: spec.OptimizerState,
+  eval_results: List[Tuple[int, float]],
+  global_step: int,
+  rng: spec.RandomState,
+) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params)."""
   del workload
   del hyperparameters
@@ -234,8 +252,8 @@ def prepare_for_eval(workload: spec.Workload,

 def get_batch_size(workload_name):
   # Return the global batch size.
-  if hasattr(HPARAMS, "batch_size"):
-    return HPARAMS.batch_size
+  if hasattr(HPARAMS, 'batch_size'):
+    return HPARAMS.batch_size
   if workload_name == 'criteo1tb':
     return 262_144
   elif workload_name == 'fastmri':
@@ -262,14 +280,16 @@ def get_batch_size(workload_name):
     raise ValueError(f'Unsupported workload name: {workload_name}.')


-def data_selection(workload: spec.Workload,
-                   input_queue: Iterator[Dict[str, spec.Tensor]],
-                   optimizer_state: spec.OptimizerState,
-                   current_param_container: spec.ParameterContainer,
-                   model_state: spec.ModelAuxiliaryState,
-                   hyperparameters: spec.Hyperparameters,
-                   global_step: int,
-                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+  workload: spec.Workload,
+  input_queue: Iterator[Dict[str, spec.Tensor]],
+  optimizer_state: spec.OptimizerState,
+  current_param_container: spec.ParameterContainer,
+  model_state: spec.ModelAuxiliaryState,
+  hyperparameters: spec.Hyperparameters,
+  global_step: int,
+  rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
   """Select data from the infinitely repeating, pre-shuffled input queue.
   Each element of the queue is a batch of training examples and labels.
   """
