 import os
 from typing import Sequence, Tuple

+import jax
+import numpy as np
+import torch
 from absl import logging
 from flax import jax_utils
 from flax.training import checkpoints as flax_checkpoints
 from flax.training.checkpoints import latest_checkpoint
-import jax
-import numpy as np
 from tensorflow.io import gfile  # pytype: disable=import-error
-import torch

 from algoperf import spec
 from algoperf.pytorch_utils import pytorch_setup

 _, _, DEVICE, _ = pytorch_setup()
-CheckpointReturn = Tuple[spec.OptimizerState,
-                         spec.ParameterContainer,
-                         spec.ModelAuxiliaryState,
-                         dict,
-                         list,
-                         int,
-                         int]
-
-
-def maybe_restore_checkpoint(framework: str,
-                             optimizer_state: spec.OptimizerState,
-                             model_params: spec.ParameterContainer,
-                             model_state: spec.ModelAuxiliaryState,
-                             train_state: dict,
-                             eval_results: list,
-                             global_step: int,
-                             preemption_count: int,
-                             checkpoint_dir: str) -> CheckpointReturn:
+CheckpointReturn = Tuple[
+  spec.OptimizerState,
+  spec.ParameterContainer,
+  spec.ModelAuxiliaryState,
+  dict,
+  list,
+  int,
+  int,
+]
+
+
+def maybe_restore_checkpoint(
+  framework: str,
+  optimizer_state: spec.OptimizerState,
+  model_params: spec.ParameterContainer,
+  model_state: spec.ModelAuxiliaryState,
+  train_state: dict,
+  eval_results: list,
+  global_step: int,
+  preemption_count: int,
+  checkpoint_dir: str,
+) -> CheckpointReturn:
   """Optionally restores from a checkpoint.

   The checkpoint logic is as follows: if there is a checkpoint in
@@ -69,20 +73,22 @@ def maybe_restore_checkpoint(framework: str,
   uninitialized_global_step = -1
   uninitialized_preemption_count = -1
   checkpoint_state = {
-      'model_params': model_params,
-      'optimizer_state': opt_state,
-      'model_state': model_state,
-      'train_state': train_state,
-      'eval_results': None,
-      'global_step': uninitialized_global_step,
-      'preemption_count': uninitialized_preemption_count,
+    'model_params': model_params,
+    'optimizer_state': opt_state,
+    'model_state': model_state,
+    'train_state': train_state,
+    'eval_results': None,
+    'global_step': uninitialized_global_step,
+    'preemption_count': uninitialized_preemption_count,
   }

   if framework == 'jax':
     latest_ckpt = flax_checkpoints.restore_checkpoint(
-        checkpoint_dir, target=checkpoint_state)
-    save_path = os.path.join(checkpoint_dir,
-                             'checkpoint_' + str(latest_ckpt['global_step']))
+      checkpoint_dir, target=checkpoint_state
+    )
+    save_path = os.path.join(
+      checkpoint_dir, 'checkpoint_' + str(latest_ckpt['global_step'])
+    )
   else:
     latest_ckpt = checkpoint_state
     save_path = latest_checkpoint(checkpoint_dir)
@@ -94,55 +100,64 @@ def maybe_restore_checkpoint(framework: str,
   found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step

   if not found_checkpoint:
-    return (optimizer_state,
-            model_params,
-            model_state,
-            train_state,
-            eval_results,
-            global_step,
-            preemption_count)
+    return (
+      optimizer_state,
+      model_params,
+      model_state,
+      train_state,
+      eval_results,
+      global_step,
+      preemption_count,
+    )

   # If there's the latest checkpoint in the checkpoint_dir, restore from that.
   if framework == 'jax':
     checkpoint_state = replicate_checkpoint(
-        latest_ckpt,
-        pytree_keys=[
-            'optimizer_state',
-            'model_params',
-            'model_state',
-        ])
-    checkpoint_state['optimizer_state'] = (checkpoint_state['optimizer_state'],
-                                           opt_update_fn)
+      latest_ckpt,
+      pytree_keys=[
+        'optimizer_state',
+        'model_params',
+        'model_state',
+      ],
+    )
+    checkpoint_state['optimizer_state'] = (
+      checkpoint_state['optimizer_state'],
+      opt_update_fn,
+    )
     checkpoint_state['eval_results'] = [
-        (value, key) for key, value in latest_ckpt['eval_results'].items()
+      (value, key) for key, value in latest_ckpt['eval_results'].items()
     ]

   else:
     checkpoint_state = latest_ckpt
     if isinstance(
-        model_params,
-        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
+      model_params,
+      (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
+    ):
       model_params = model_params.module
     model_params.load_state_dict(checkpoint_state['model_params'])
     checkpoint_state['model_params'] = model_params
     for key in optimizer_state.keys():
       optimizer_state[key].load_state_dict(
-          checkpoint_state['optimizer_state'][key])
+        checkpoint_state['optimizer_state'][key]
+      )
       checkpoint_state['optimizer_state'][key] = optimizer_state[key]

   logging.info(f'Loaded checkpoint from {save_path}.')
-  return (checkpoint_state['optimizer_state'],
-          checkpoint_state['model_params'],
-          checkpoint_state['model_state'],
-          checkpoint_state['train_state'],
-          list(checkpoint_state['eval_results']),
-          checkpoint_state['global_step'],
-          checkpoint_state['preemption_count'] + 1)
-
-
-def replicate_checkpoint(latest: dict,
-                         pytree_keys: Sequence[str],
-                         replicate: bool = True) -> dict:
+  return (
+    checkpoint_state['optimizer_state'],
+    checkpoint_state['model_params'],
+    checkpoint_state['model_state'],
+    checkpoint_state['train_state'],
+    list(checkpoint_state['eval_results']),
+    checkpoint_state['global_step'],
+    checkpoint_state['preemption_count'] + 1,
+  )
+
+
+def replicate_checkpoint(
+  latest: dict, pytree_keys: Sequence[str], replicate: bool = True
+) -> dict:
   """Restores from the provided checkpoint.

   Args:
@@ -163,16 +178,18 @@ def replicate_checkpoint(latest: dict,
   return pytree


-def save_checkpoint(framework: str,
-                    optimizer_state: spec.OptimizerState,
-                    model_params: spec.ParameterContainer,
-                    model_state: spec.ModelAuxiliaryState,
-                    train_state: dict,
-                    eval_results: list,
-                    global_step: int,
-                    preemption_count: int,
-                    checkpoint_dir: str,
-                    save_intermediate_checkpoints: bool) -> None:
+def save_checkpoint(
+  framework: str,
+  optimizer_state: spec.OptimizerState,
+  model_params: spec.ParameterContainer,
+  model_state: spec.ModelAuxiliaryState,
+  train_state: dict,
+  eval_results: list,
+  global_step: int,
+  preemption_count: int,
+  checkpoint_dir: str,
+  save_intermediate_checkpoints: bool,
+) -> None:
   """Save the checkpoint in `checkpoint_dir`.

   Args:
@@ -199,8 +216,9 @@ def save_checkpoint(framework: str,
     model_state = jax.device_get(jax_utils.unreplicate(model_state))
   else:
     if isinstance(
-        model_params,
-        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
+      model_params,
+      (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
+    ):
       model_params = model_params.module
     model_params = model_params.state_dict()
     optimizer_state_dict = {}
@@ -209,33 +227,36 @@ def save_checkpoint(framework: str,
         optimizer_state_dict[key] = optimizer_state[key].state_dict()
       else:
         logging.warning(
-            f'The optimizer state for key {key} is not saved, because '
-            f'{type(optimizer_state[key])} has not implemented a state_dict() '
-            'method.')
+          f'The optimizer state for key {key} is not saved, because '
+          f'{type(optimizer_state[key])} has not implemented a state_dict() '
+          'method.'
+        )
     opt_state = optimizer_state_dict

   checkpoint_state = {
-      'model_params': model_params,
-      'optimizer_state': opt_state,
-      'model_state': model_state,
-      'train_state': train_state,
-      'eval_results': tuple(eval_results),
-      'global_step': global_step,
-      'preemption_count': preemption_count,
+    'model_params': model_params,
+    'optimizer_state': opt_state,
+    'model_state': model_state,
+    'train_state': train_state,
+    'eval_results': tuple(eval_results),
+    'global_step': global_step,
+    'preemption_count': preemption_count,
   }

   save_path = os.path.join(checkpoint_dir, f'checkpoint_{global_step}')
   if framework == 'jax':
     flax_checkpoints.save_checkpoint(
-        checkpoint_dir,
-        target=checkpoint_state,
-        step=global_step,
-        overwrite=True,
-        keep=np.Inf if save_intermediate_checkpoints else 1)
+      checkpoint_dir,
+      target=checkpoint_state,
+      step=global_step,
+      overwrite=True,
+      keep=np.Inf if save_intermediate_checkpoints else 1,
+    )
   else:
     if not save_intermediate_checkpoints:
       checkpoint_files = gfile.glob(
-          os.path.join(checkpoint_dir, 'checkpoint_*'))
+        os.path.join(checkpoint_dir, 'checkpoint_*')
+      )
       for path in checkpoint_files:
         logging.info('Removing checkpoint at %s', path)
         gfile.rmtree(path)
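
For orientation, the following is a minimal sketch, not part of this commit, of how the two helpers touched above are typically paired in a training loop: restore once at startup, then save periodically. It assumes the diffed file lives at algoperf/checkpoint_utils.py; init_optimizer_and_model and run_one_step are hypothetical placeholders rather than algoperf APIs, and the argument order simply follows the signatures shown in the diff.

# Hedged sketch only: `init_optimizer_and_model` and `run_one_step` are
# hypothetical placeholders; the argument order follows the signatures above.
from algoperf import checkpoint_utils


def train(framework: str, checkpoint_dir: str, num_steps: int = 1000) -> None:
  optimizer_state, model_params, model_state = init_optimizer_and_model()
  train_state, eval_results = {}, []
  global_step, preemption_count = 0, 0

  # Resume from the latest checkpoint in `checkpoint_dir` if one exists;
  # otherwise the inputs above are returned unchanged.
  (
    optimizer_state,
    model_params,
    model_state,
    train_state,
    eval_results,
    global_step,
    preemption_count,
  ) = checkpoint_utils.maybe_restore_checkpoint(
    framework,
    optimizer_state,
    model_params,
    model_state,
    train_state,
    eval_results,
    global_step,
    preemption_count,
    checkpoint_dir,
  )

  while global_step < num_steps:
    optimizer_state, model_params, model_state = run_one_step(
      optimizer_state, model_params, model_state
    )
    global_step += 1
    if global_step % 100 == 0:
      # Keep only the newest checkpoint; pass True to retain all of them.
      checkpoint_utils.save_checkpoint(
        framework,
        optimizer_state,
        model_params,
        model_state,
        train_state,
        eval_results,
        global_step,
        preemption_count,
        checkpoint_dir,
        save_intermediate_checkpoints=False,
      )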