 
 # default Lion parameters
 HPARAMS = {
-    "dropout_rate": 0.1,
-    "learning_rate": 2e-4,
-    "one_minus_beta1": 0.05,
-    "beta2": 0.98,
-    "weight_decay": 0.5,
-    "warmup_factor": 0.02
+  'dropout_rate': 0.1,
+  'learning_rate': 2e-4,
+  'one_minus_beta1': 0.05,
+  'beta2': 0.98,
+  'weight_decay': 0.5,
+  'warmup_factor': 0.02,
 }
 HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)
 
+
 # Modified from https://github.com/google/automl/blob/master/lion/lion_pytorch.py.
 class Lion(Optimizer):
   def __init__(
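
Aside (not part of the diff): the hunk above only reformats the defaults, but it shows how they are consumed. The dict is frozen into an immutable namedtuple, and Lion's beta1 is expressed indirectly as `one_minus_beta1`, so the defaults correspond to `betas = (0.95, 0.98)`. A minimal self-contained sketch, using only values that appear in the diff:

```python
import collections

# Default Lion hyperparameters, copied from the hunk above.
hparams_dict = {
  'dropout_rate': 0.1,
  'learning_rate': 2e-4,
  'one_minus_beta1': 0.05,
  'beta2': 0.98,
  'weight_decay': 0.5,
  'warmup_factor': 0.02,
}
# Freeze the dict into an immutable, attribute-accessible namedtuple.
HPARAMS = collections.namedtuple('Hyperparameters', hparams_dict.keys())(**hparams_dict)

# beta1 is parameterized as 1 - one_minus_beta1; these are the betas later
# passed to Lion(...) in init_optimizer_state.
betas = (1.0 - HPARAMS.one_minus_beta1, HPARAMS.beta2)
print(betas)  # (0.95, 0.98)
```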
@@ -90,11 +91,13 @@ def step(self, closure=None):
     return loss
 
 
-def init_optimizer_state(workload: spec.Workload,
-                         model_params: spec.ParameterContainer,
-                         model_state: spec.ModelAuxiliaryState,
-                         hyperparameters: spec.Hyperparameters,
-                         rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+  workload: spec.Workload,
+  model_params: spec.ParameterContainer,
+  model_state: spec.ModelAuxiliaryState,
+  hyperparameters: spec.Hyperparameters,
+  rng: spec.RandomState,
+) -> spec.OptimizerState:
   """Creates a Lion optimizer and a learning rate schedule."""
   del model_state
   del rng
@@ -103,44 +106,47 @@ def init_optimizer_state(workload: spec.Workload,
   hyperparameters = HPARAMS
 
   optimizer_state = {
-      'optimizer':
-          Lion(
-              model_params.parameters(),
-              lr=HPARAMS.learning_rate,
-              betas=(1.0 - HPARAMS.one_minus_beta1,
-                     HPARAMS.beta2),
-              weight_decay=HPARAMS.weight_decay)
+    'optimizer': Lion(
+      model_params.parameters(),
+      lr=HPARAMS.learning_rate,
+      betas=(1.0 - HPARAMS.one_minus_beta1, HPARAMS.beta2),
+      weight_decay=HPARAMS.weight_decay,
+    )
   }
 
   def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
     warmup_steps = int(hyperparameters.warmup_factor * step_hint)
     warmup = LinearLR(
-        optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+      optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+    )
     cosine_steps = max(step_hint - warmup_steps, 1)
     cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
     return SequentialLR(
-        optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+      optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+    )
 
   optimizer_state['scheduler'] = pytorch_cosine_warmup(
-      workload.step_hint, HPARAMS, optimizer_state['optimizer'])
+    workload.step_hint, HPARAMS, optimizer_state['optimizer']
+  )
   optimizer_state['hyperparameters'] = hyperparameters
 
   return optimizer_state
 
 
 def update_params(
-    workload: spec.Workload,
-    current_param_container: spec.ParameterContainer,
-    current_params_types: spec.ParameterTypeTree,
-    model_state: spec.ModelAuxiliaryState,
-    hyperparameters: spec.Hyperparameters,
-    batch: Dict[str, spec.Tensor],
-    loss_type: spec.LossType,
-    optimizer_state: spec.OptimizerState,
-    eval_results: List[Tuple[int, float]],
-    global_step: int,
-    rng: spec.RandomState,
-    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+  workload: spec.Workload,
+  current_param_container: spec.ParameterContainer,
+  current_params_types: spec.ParameterTypeTree,
+  model_state: spec.ModelAuxiliaryState,
+  hyperparameters: spec.Hyperparameters,
+  batch: Dict[str, spec.Tensor],
+  loss_type: spec.LossType,
+  optimizer_state: spec.OptimizerState,
+  eval_results: List[Tuple[int, float]],
+  global_step: int,
+  rng: spec.RandomState,
+  train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
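
The schedule built by `pytorch_cosine_warmup` in the hunk above is standard PyTorch: a `LinearLR` warmup over `warmup_factor * step_hint` steps chained into a cosine decay via `SequentialLR`. Below is a self-contained sketch of the same construction; the dummy parameter, SGD optimizer, and made-up `step_hint` are stand-ins for the real model, the Lion optimizer, and `workload.step_hint`.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

# Dummy parameter and SGD stand in for the real model and the Lion optimizer.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=2e-4)  # lr plays the role of HPARAMS.learning_rate

step_hint = 1_000                      # stand-in for workload.step_hint
warmup_steps = int(0.02 * step_hint)   # HPARAMS.warmup_factor * step_hint
warmup = LinearLR(opt, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps)
cosine = CosineAnnealingLR(opt, T_max=max(step_hint - warmup_steps, 1))
sched = SequentialLR(opt, schedulers=[warmup, cosine], milestones=[warmup_steps])

for step in range(step_hint):
  opt.step()    # optimizer first, then scheduler, matching update_params below
  sched.step()
  if step in (0, warmup_steps - 1, step_hint // 2, step_hint - 1):
    # LR ramps from ~0 to 2e-4 during warmup, then decays along a cosine.
    print(step, sched.get_last_lr()[0])
```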
@@ -155,26 +161,30 @@ def update_params(
   optimizer_state['optimizer'].zero_grad()
 
   logits_batch, new_model_state = workload.model_fn(
-      params=current_model,
-      augmented_and_preprocessed_input_batch=batch,
-      model_state=model_state,
-      mode=spec.ForwardPassMode.TRAIN,
-      rng=rng,
-      update_batch_norm=True)
+    params=current_model,
+    augmented_and_preprocessed_input_batch=batch,
+    model_state=model_state,
+    mode=spec.ForwardPassMode.TRAIN,
+    rng=rng,
+    update_batch_norm=True,
+  )
 
   label_smoothing = (
-      hyperparameters.label_smoothing if hasattr(HPARAMS,
-                                                 'label_smoothing') else 0.0)
+    hyperparameters.label_smoothing
+    if hasattr(HPARAMS, 'label_smoothing')
+    else 0.0
+  )
   if hasattr(hyperparameters, 'grad_clip'):
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None
 
   loss_dict = workload.loss_fn(
-      label_batch=batch['targets'],
-      logits_batch=logits_batch,
-      mask_batch=batch.get('weights'),
-      label_smoothing=label_smoothing)
+    label_batch=batch['targets'],
+    logits_batch=logits_batch,
+    mask_batch=batch.get('weights'),
+    label_smoothing=label_smoothing,
+  )
   summed_loss = loss_dict['summed']
   n_valid_examples = loss_dict['n_valid_examples']
   if USE_PYTORCH_DDP:
@@ -187,7 +197,8 @@ def update_params(
 
   if grad_clip is not None:
     torch.nn.utils.clip_grad_norm_(
-        current_model.parameters(), max_norm=grad_clip)
+      current_model.parameters(), max_norm=grad_clip
+    )
   optimizer_state['optimizer'].step()
   optimizer_state['scheduler'].step()
 
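
For reference on the clipping call in the hunk above (its behavior is unchanged by the reformat): `clip_grad_norm_` rescales all gradients in place so that their global L2 norm is at most `max_norm`, and returns the pre-clipping norm. A small sketch with an illustrative threshold; in the submission the value comes from `hyperparameters.grad_clip` when that attribute exists:

```python
import torch

model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).square().mean().backward()

grad_clip = 1.0  # illustrative; the real value would be hyperparameters.grad_clip
if grad_clip is not None:
  # Rescales every gradient in place; returns the total norm before clipping.
  total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
  print(float(total_norm))
```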
@@ -196,31 +207,38 @@ def update_params(
     with torch.no_grad():
       parameters = [p for p in current_model.parameters() if p.grad is not None]
       grad_norm = torch.norm(
-          torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+        torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+      )
     if workload.metrics_logger is not None:
       workload.metrics_logger.append_scalar_metrics(
-          {
-              'loss': loss.item(),
-              'grad_norm': grad_norm.item(),
-          }, global_step)
-    logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
-                 global_step,
-                 loss.item(),
-                 grad_norm.item())
+        {
+          'loss': loss.item(),
+          'grad_norm': grad_norm.item(),
+        },
+        global_step,
+      )
+    logging.info(
+      '%d) loss = %0.3f, grad_norm = %0.3f',
+      global_step,
+      loss.item(),
+      grad_norm.item(),
+    )
 
   return (optimizer_state, current_param_container, new_model_state)
 
 
-def prepare_for_eval(workload: spec.Workload,
-                     current_param_container: spec.ParameterContainer,
-                     current_params_types: spec.ParameterTypeTree,
-                     model_state: spec.ModelAuxiliaryState,
-                     hyperparameters: spec.Hyperparameters,
-                     loss_type: spec.LossType,
-                     optimizer_state: spec.OptimizerState,
-                     eval_results: List[Tuple[int, float]],
-                     global_step: int,
-                     rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+  workload: spec.Workload,
+  current_param_container: spec.ParameterContainer,
+  current_params_types: spec.ParameterTypeTree,
+  model_state: spec.ModelAuxiliaryState,
+  hyperparameters: spec.Hyperparameters,
+  loss_type: spec.LossType,
+  optimizer_state: spec.OptimizerState,
+  eval_results: List[Tuple[int, float]],
+  global_step: int,
+  rng: spec.RandomState,
+) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params)."""
   del workload
   del hyperparameters
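
The `grad_norm` expression reformatted in the hunk above computes the global gradient norm: the L2 norm of the stacked per-parameter L2 norms equals the L2 norm of all gradient entries flattened into one vector. A quick, self-contained check of that identity:

```python
import torch

model = torch.nn.Linear(3, 3)
model(torch.randn(5, 3)).sum().backward()

parameters = [p for p in model.parameters() if p.grad is not None]
# Norm of per-parameter norms, as in the logging block above.
grad_norm = torch.norm(
  torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
)
# Equivalent: norm of all gradient entries concatenated into one vector.
flat_norm = torch.cat([p.grad.detach().flatten() for p in parameters]).norm(2)
assert torch.allclose(grad_norm, flat_norm)
```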
@@ -234,8 +252,8 @@ def prepare_for_eval(workload: spec.Workload,
 
 def get_batch_size(workload_name):
   # Return the global batch size.
-  if hasattr(HPARAMS, "batch_size"):
-    return HPARAMS.batch_size
+  if hasattr(HPARAMS, 'batch_size'):
+    return HPARAMS.batch_size
   if workload_name == 'criteo1tb':
     return 262_144
   elif workload_name == 'fastmri':
@@ -262,14 +280,16 @@ def get_batch_size(workload_name):
     raise ValueError(f'Unsupported workload name: {workload_name}.')
 
 
-def data_selection(workload: spec.Workload,
-                   input_queue: Iterator[Dict[str, spec.Tensor]],
-                   optimizer_state: spec.OptimizerState,
-                   current_param_container: spec.ParameterContainer,
-                   model_state: spec.ModelAuxiliaryState,
-                   hyperparameters: spec.Hyperparameters,
-                   global_step: int,
-                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+  workload: spec.Workload,
+  input_queue: Iterator[Dict[str, spec.Tensor]],
+  optimizer_state: spec.OptimizerState,
+  current_param_container: spec.ParameterContainer,
+  model_state: spec.ModelAuxiliaryState,
+  hyperparameters: spec.Hyperparameters,
+  global_step: int,
+  rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
   """Select data from the infinitely repeating, pre-shuffled input queue.
   Each element of the queue is a batch of training examples and labels.
   """
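
The body of `data_selection` lies outside this diff, so the following is only a hedged sketch of what a minimal implementation of this signature typically does under the spec API (pull the next pre-shuffled batch and ignore the remaining arguments); it is not necessarily what this file contains.

```python
def data_selection(
  workload,
  input_queue,
  optimizer_state,
  current_param_container,
  model_state,
  hyperparameters,
  global_step,
  rng,
):
  """Select data from the infinitely repeating, pre-shuffled input queue."""
  # Hypothetical minimal body: unused arguments are discarded and the next
  # batch dict is returned as-is.
  del workload, optimizer_state, current_param_container
  del model_state, hyperparameters, global_step, rng
  return next(input_queue)
```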