Skip to content

Commit 594f285

Browse files
committed
pylint fixes
1 parent 0ea37ee commit 594f285

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

algoperf/pytorch_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from torch import Tensor
99
import torch.distributed as dist
10-
import torch.nn as nn
10+
from torch import nn
1111
import torch.nn.functional as F
1212

1313
from algoperf import spec
@@ -100,8 +100,8 @@ def __init__(self):
100100
super().__init__()
101101
self._supports_custom_dropout = True
102102

103-
def forward(self, input: Tensor, p: float) -> Tensor:
104-
return F.dropout2d(input, p, training=self.training)
103+
def forward(self, x: Tensor, p: float) -> Tensor:
104+
return F.dropout2d(x, p, training=self.training)
105105

106106

107107
class SequentialWithDropout(nn.Sequential):

algoperf/workloads/imagenet_vit/imagenet_jax/workload.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""ImageNet workload implemented in Jax."""
22

3-
from typing import Dict, Tuple
3+
from typing import Dict, Optional, Tuple
44

55
from flax import jax_utils
66
from flax import linen as nn
@@ -54,10 +54,12 @@ def model_fn(
5454
mode: spec.ForwardPassMode,
5555
rng: spec.RandomState,
5656
update_batch_norm: bool,
57+
use_running_average_bn: Optional[bool] = None,
5758
dropout_rate: float = models.DROPOUT_RATE
5859
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
5960
del model_state
6061
del update_batch_norm
62+
del use_running_average_bn
6163
train = mode == spec.ForwardPassMode.TRAIN
6264
logits = self._model.apply({'params': params},
6365
augmented_and_preprocessed_input_batch['inputs'],

0 commit comments

Comments
 (0)