Skip to content

Commit 15d9243

Browse files
committed
Lint algoperf/
1 parent 1318c91 commit 15d9243

File tree

61 files changed

+132
-183
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+132
-183
lines changed

algoperf/checkpoint_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
import os
88
from typing import Sequence, Tuple
99

10+
import jax
11+
import numpy as np
12+
import torch
1013
from absl import logging
1114
from flax import jax_utils
1215
from flax.training import checkpoints as flax_checkpoints
1316
from flax.training.checkpoints import latest_checkpoint
14-
import jax
15-
import numpy as np
1617
from tensorflow.io import gfile # pytype: disable=import-error
17-
import torch
1818

1919
from algoperf import spec
2020
from algoperf.pytorch_utils import pytorch_setup

algoperf/data_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
import torch
88
import torch.distributed as dist
99
import torch.nn.functional as F
10-
from torch.utils.data import DataLoader
11-
from torch.utils.data import DistributedSampler
12-
from torch.utils.data import Sampler
10+
from torch.utils.data import DataLoader, DistributedSampler, Sampler
1311

1412
from algoperf import spec
1513

@@ -266,7 +264,7 @@ def prefetched_loader(self) -> Iterable[Tuple[spec.Tensor, spec.Tensor]]:
266264
next_targets = next_targets.to(self.device, non_blocking=True)
267265

268266
if not first:
269-
yield inputs, targets
267+
yield inputs, targets # noqa: F821
270268
else:
271269
first = False
272270

algoperf/logger_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
import sys
1212
from typing import Any, Dict, Optional
1313

14-
from absl import flags
15-
from clu import metric_writers
1614
import GPUtil
1715
import pandas as pd
1816
import psutil
1917
import torch.distributed as dist
18+
from absl import flags
19+
from clu import metric_writers
2020

2121
from algoperf import spec
2222
from algoperf.pytorch_utils import pytorch_setup
@@ -198,7 +198,7 @@ def _get_system_hardware_info() -> Dict:
198198
try:
199199
system_hardware_info['cpu_model_name'] = _get_cpu_model_name()
200200
system_hardware_info['cpu_count'] = psutil.cpu_count()
201-
except: # pylint: disable=bare-except
201+
except: # noqa: E722
202202
logging.info('Unable to record cpu information. Continuing without it.')
203203

204204
gpus = GPUtil.getGPUs()
@@ -207,7 +207,7 @@ def _get_system_hardware_info() -> Dict:
207207
system_hardware_info['gpu_model_name'] = gpus[0].name
208208
system_hardware_info['gpu_count'] = len(gpus)
209209
system_hardware_info['gpu_driver'] = gpus[0].driver
210-
except: # pylint: disable=bare-except
210+
except: # noqa: E722
211211
logging.info('Unable to record gpu information. Continuing without it.')
212212

213213
return system_hardware_info
@@ -232,7 +232,7 @@ def _get_system_software_info() -> Dict:
232232
system_software_info['git_commit_hash'] = _get_git_commit_hash()
233233
# Note: do not store git repo url as it may be sensitive or contain a
234234
# secret.
235-
except: # pylint: disable=bare-except
235+
except: # noqa: E722
236236
logging.info('Unable to record git information. Continuing without it.')
237237

238238
return system_software_info
@@ -279,7 +279,7 @@ def _get_workload_properties(workload: spec.Workload) -> Dict:
279279
for key in keys:
280280
try:
281281
attr = getattr(workload, key)
282-
except: # pylint: disable=bare-except
282+
except: # noqa: E722
283283
logging.info(
284284
f'Unable to record workload.{key} information. Continuing without it.'
285285
)

algoperf/profiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
https://github.com/Lightning-AI/lightning/tree/master/src/pytorch_lightning/profilers.
55
"""
66

7-
from collections import defaultdict
8-
from contextlib import contextmanager
97
import os
108
import time
9+
from collections import defaultdict
10+
from contextlib import contextmanager
1111
from typing import Dict, Generator, List, Optional, Tuple
1212

1313
import numpy as np

algoperf/pytorch_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import os
22
from typing import Tuple
33

4-
from absl import logging
54
import jax
65
import tensorflow as tf
76
import torch
87
import torch.distributed as dist
8+
from absl import logging
99

1010
from algoperf import spec
1111
from algoperf.profiler import Profiler

algoperf/random_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
from typing import Any, List, Union
44

5-
from absl import flags
6-
from absl import logging
75
import numpy as np
6+
from absl import flags, logging
87

98
try:
109
import jax.random as jax_rng

algoperf/spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import functools
66
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
77

8-
from absl import logging
98
import jax
10-
from torch import nn
119
import torch.nn.functional as F
10+
from absl import logging
11+
from torch import nn
1212

1313

1414
class LossType(enum.Enum):

algoperf/workloads/cifar/cifar_jax/input_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
import functools
99
from typing import Dict, Iterator, Tuple
1010

11-
from flax import jax_utils
1211
import jax
1312
import tensorflow as tf
1413
import tensorflow_datasets as tfds
14+
from flax import jax_utils
1515

1616
from algoperf import spec
1717
from algoperf.data_utils import shard_and_maybe_pad_np

algoperf/workloads/cifar/cifar_jax/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import functools
88
from typing import Any, Callable, Tuple
99

10-
from flax import linen as nn
1110
import jax.numpy as jnp
11+
from flax import linen as nn
1212

1313
from algoperf import spec
1414
from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ResNetBlock

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@
33
import functools
44
from typing import Any, Dict, Iterator, Optional, Tuple
55

6-
from flax import jax_utils
7-
from flax import linen as nn
8-
from flax.core import pop
96
import jax
10-
from jax import lax
117
import jax.numpy as jnp
128
import optax
139
import tensorflow_datasets as tfds
10+
from flax import jax_utils
11+
from flax import linen as nn
12+
from flax.core import pop
13+
from jax import lax
1414

15-
from algoperf import param_utils
16-
from algoperf import spec
15+
from algoperf import param_utils, spec
1716
from algoperf.workloads.cifar.cifar_jax import models
1817
from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter
1918
from algoperf.workloads.cifar.workload import BaseCifarWorkload

0 commit comments

Comments
 (0)