
Commit f185372

Fix all lint and format issues

1 parent f06d2b1 · commit f185372

181 files changed: +15217 −12669 lines


algoperf/checkpoint_utils.py

Lines changed: 111 additions & 90 deletions
@@ -7,37 +7,41 @@
 import os
 from typing import Sequence, Tuple
 
+import jax
+import numpy as np
+import torch
 from absl import logging
 from flax import jax_utils
 from flax.training import checkpoints as flax_checkpoints
 from flax.training.checkpoints import latest_checkpoint
-import jax
-import numpy as np
 from tensorflow.io import gfile  # pytype: disable=import-error
-import torch
 
 from algoperf import spec
 from algoperf.pytorch_utils import pytorch_setup
 
 _, _, DEVICE, _ = pytorch_setup()
-CheckpointReturn = Tuple[spec.OptimizerState,
-                         spec.ParameterContainer,
-                         spec.ModelAuxiliaryState,
-                         dict,
-                         list,
-                         int,
-                         int]
-
-
-def maybe_restore_checkpoint(framework: str,
-                             optimizer_state: spec.OptimizerState,
-                             model_params: spec.ParameterContainer,
-                             model_state: spec.ModelAuxiliaryState,
-                             train_state: dict,
-                             eval_results: list,
-                             global_step: int,
-                             preemption_count: int,
-                             checkpoint_dir: str) -> CheckpointReturn:
+CheckpointReturn = Tuple[
+  spec.OptimizerState,
+  spec.ParameterContainer,
+  spec.ModelAuxiliaryState,
+  dict,
+  list,
+  int,
+  int,
+]
+
+
+def maybe_restore_checkpoint(
+  framework: str,
+  optimizer_state: spec.OptimizerState,
+  model_params: spec.ParameterContainer,
+  model_state: spec.ModelAuxiliaryState,
+  train_state: dict,
+  eval_results: list,
+  global_step: int,
+  preemption_count: int,
+  checkpoint_dir: str,
+) -> CheckpointReturn:
   """Optionally restores from a checkpoint.
 
   The checkpoint logic is as follows: if there is a checkpoint in
@@ -69,20 +73,22 @@ def maybe_restore_checkpoint(framework: str,
   uninitialized_global_step = -1
   uninitialized_preemption_count = -1
   checkpoint_state = {
-      'model_params': model_params,
-      'optimizer_state': opt_state,
-      'model_state': model_state,
-      'train_state': train_state,
-      'eval_results': None,
-      'global_step': uninitialized_global_step,
-      'preemption_count': uninitialized_preemption_count,
+    'model_params': model_params,
+    'optimizer_state': opt_state,
+    'model_state': model_state,
+    'train_state': train_state,
+    'eval_results': None,
+    'global_step': uninitialized_global_step,
+    'preemption_count': uninitialized_preemption_count,
   }
 
   if framework == 'jax':
     latest_ckpt = flax_checkpoints.restore_checkpoint(
-        checkpoint_dir, target=checkpoint_state)
-    save_path = os.path.join(checkpoint_dir,
-                             'checkpoint_' + str(latest_ckpt['global_step']))
+      checkpoint_dir, target=checkpoint_state
+    )
+    save_path = os.path.join(
+      checkpoint_dir, 'checkpoint_' + str(latest_ckpt['global_step'])
+    )
   else:
     latest_ckpt = checkpoint_state
     save_path = latest_checkpoint(checkpoint_dir)
@@ -94,55 +100,64 @@ def maybe_restore_checkpoint(framework: str,
   found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step
 
   if not found_checkpoint:
-    return (optimizer_state,
-            model_params,
-            model_state,
-            train_state,
-            eval_results,
-            global_step,
-            preemption_count)
+    return (
+      optimizer_state,
+      model_params,
+      model_state,
+      train_state,
+      eval_results,
+      global_step,
+      preemption_count,
+    )
 
   # If there's the latest checkpoint in the checkpoint_dir, restore from that.
   if framework == 'jax':
     checkpoint_state = replicate_checkpoint(
-        latest_ckpt,
-        pytree_keys=[
-            'optimizer_state',
-            'model_params',
-            'model_state',
-        ])
-    checkpoint_state['optimizer_state'] = (checkpoint_state['optimizer_state'],
-                                           opt_update_fn)
+      latest_ckpt,
+      pytree_keys=[
+        'optimizer_state',
+        'model_params',
+        'model_state',
+      ],
+    )
+    checkpoint_state['optimizer_state'] = (
+      checkpoint_state['optimizer_state'],
+      opt_update_fn,
+    )
     checkpoint_state['eval_results'] = [
-        (value, key) for key, value in latest_ckpt['eval_results'].items()
+      (value, key) for key, value in latest_ckpt['eval_results'].items()
     ]
 
   else:
     checkpoint_state = latest_ckpt
     if isinstance(
-        model_params,
-        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
+      model_params,
+      (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
+    ):
       model_params = model_params.module
     model_params.load_state_dict(checkpoint_state['model_params'])
     checkpoint_state['model_params'] = model_params
     for key in optimizer_state.keys():
       optimizer_state[key].load_state_dict(
-          checkpoint_state['optimizer_state'][key])
+        checkpoint_state['optimizer_state'][key]
+      )
       checkpoint_state['optimizer_state'][key] = optimizer_state[key]
 
   logging.info(f'Loaded checkpoint from {save_path}.')
-  return (checkpoint_state['optimizer_state'],
-          checkpoint_state['model_params'],
-          checkpoint_state['model_state'],
-          checkpoint_state['train_state'],
-          list(checkpoint_state['eval_results']),
-          checkpoint_state['global_step'],
-          checkpoint_state['preemption_count'] + 1)
-
-
-def replicate_checkpoint(latest: dict,
-                         pytree_keys: Sequence[str],
-                         replicate: bool = True) -> dict:
+  return (
+    checkpoint_state['optimizer_state'],
+    checkpoint_state['model_params'],
+    checkpoint_state['model_state'],
+    checkpoint_state['train_state'],
+    list(checkpoint_state['eval_results']),
+    checkpoint_state['global_step'],
+    checkpoint_state['preemption_count'] + 1,
+  )
+
+
+def replicate_checkpoint(
+  latest: dict, pytree_keys: Sequence[str], replicate: bool = True
+) -> dict:
   """Restores from the provided checkpoint.
 
   Args:
@@ -163,16 +178,18 @@ def replicate_checkpoint(latest: dict,
   return pytree
 
 
-def save_checkpoint(framework: str,
-                    optimizer_state: spec.OptimizerState,
-                    model_params: spec.ParameterContainer,
-                    model_state: spec.ModelAuxiliaryState,
-                    train_state: dict,
-                    eval_results: list,
-                    global_step: int,
-                    preemption_count: int,
-                    checkpoint_dir: str,
-                    save_intermediate_checkpoints: bool) -> None:
+def save_checkpoint(
+  framework: str,
+  optimizer_state: spec.OptimizerState,
+  model_params: spec.ParameterContainer,
+  model_state: spec.ModelAuxiliaryState,
+  train_state: dict,
+  eval_results: list,
+  global_step: int,
+  preemption_count: int,
+  checkpoint_dir: str,
+  save_intermediate_checkpoints: bool,
+) -> None:
   """Save the checkpoint in `checkpoint_dir`.
 
   Args:
@@ -199,8 +216,9 @@ def save_checkpoint(framework: str,
     model_state = jax.device_get(jax_utils.unreplicate(model_state))
   else:
     if isinstance(
-        model_params,
-        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
+      model_params,
+      (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
+    ):
       model_params = model_params.module
     model_params = model_params.state_dict()
   optimizer_state_dict = {}
@@ -209,33 +227,36 @@ def save_checkpoint(framework: str,
       optimizer_state_dict[key] = optimizer_state[key].state_dict()
     else:
       logging.warning(
-          f'The optimizer state for key {key} is not saved, because '
-          f'{type(optimizer_state[key])} has not implemented a state_dict() '
-          'method.')
+        f'The optimizer state for key {key} is not saved, because '
+        f'{type(optimizer_state[key])} has not implemented a state_dict() '
+        'method.'
+      )
   opt_state = optimizer_state_dict
 
   checkpoint_state = {
-      'model_params': model_params,
-      'optimizer_state': opt_state,
-      'model_state': model_state,
-      'train_state': train_state,
-      'eval_results': tuple(eval_results),
-      'global_step': global_step,
-      'preemption_count': preemption_count,
+    'model_params': model_params,
+    'optimizer_state': opt_state,
+    'model_state': model_state,
+    'train_state': train_state,
+    'eval_results': tuple(eval_results),
+    'global_step': global_step,
+    'preemption_count': preemption_count,
   }
 
   save_path = os.path.join(checkpoint_dir, f'checkpoint_{global_step}')
   if framework == 'jax':
     flax_checkpoints.save_checkpoint(
-        checkpoint_dir,
-        target=checkpoint_state,
-        step=global_step,
-        overwrite=True,
-        keep=np.Inf if save_intermediate_checkpoints else 1)
+      checkpoint_dir,
+      target=checkpoint_state,
+      step=global_step,
+      overwrite=True,
+      keep=np.Inf if save_intermediate_checkpoints else 1,
+    )
   else:
     if not save_intermediate_checkpoints:
       checkpoint_files = gfile.glob(
-          os.path.join(checkpoint_dir, 'checkpoint_*'))
+        os.path.join(checkpoint_dir, 'checkpoint_*')
+      )
       for path in checkpoint_files:
         logging.info('Removing checkpoint at %s', path)
         gfile.rmtree(path)
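
For context, a minimal usage sketch of the two helpers this diff reformats, based only on the signatures shown above. The framework string, checkpoint directory, and initial state objects are hypothetical placeholders (in practice they come from the workload/submission initialization), not part of this commit.

  # Hypothetical usage sketch; argument values are illustrative placeholders.
  from algoperf import checkpoint_utils

  # model_params, optimizer_state, model_state, train_state, eval_results are
  # assumed to have been created by the workload / submission init functions.
  (
    optimizer_state,
    model_params,
    model_state,
    train_state,
    eval_results,
    global_step,
    preemption_count,
  ) = checkpoint_utils.maybe_restore_checkpoint(
    framework='pytorch',  # or 'jax'
    optimizer_state=optimizer_state,
    model_params=model_params,
    model_state=model_state,
    train_state=train_state,
    eval_results=eval_results,
    global_step=0,
    preemption_count=0,
    checkpoint_dir='/tmp/experiment_dir',  # placeholder path
  )

  # ... run training steps, updating global_step and eval_results ...

  # Persist the latest state; keep only the newest checkpoint on disk.
  checkpoint_utils.save_checkpoint(
    framework='pytorch',
    optimizer_state=optimizer_state,
    model_params=model_params,
    model_state=model_state,
    train_state=train_state,
    eval_results=eval_results,
    global_step=global_step,
    preemption_count=preemption_count,
    checkpoint_dir='/tmp/experiment_dir',
    save_intermediate_checkpoints=False,
  )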
