File tree Expand file tree Collapse file tree 2 files changed +6
-4
lines changed
workloads/imagenet_vit/imagenet_jax Expand file tree Collapse file tree 2 files changed +6
-4
lines changed Original file line number Diff line number Diff line change 77import torch
88from torch import Tensor
99import torch .distributed as dist
10- import torch . nn as nn
10+ from torch import nn
1111import torch .nn .functional as F
1212
1313from algoperf import spec
@@ -100,8 +100,8 @@ def __init__(self):
100100 super ().__init__ ()
101101 self ._supports_custom_dropout = True
102102
103- def forward (self , input : Tensor , p : float ) -> Tensor :
104- return F .dropout2d (input , p , training = self .training )
103+ def forward (self , x : Tensor , p : float ) -> Tensor :
104+ return F .dropout2d (x , p , training = self .training )
105105
106106
107107class SequentialWithDropout (nn .Sequential ):
Original file line number Diff line number Diff line change 11"""ImageNet workload implemented in Jax."""
22
3- from typing import Dict , Tuple
3+ from typing import Dict , Optional , Tuple
44
55from flax import jax_utils
66from flax import linen as nn
@@ -54,10 +54,12 @@ def model_fn(
5454 mode : spec .ForwardPassMode ,
5555 rng : spec .RandomState ,
5656 update_batch_norm : bool ,
57+ use_running_average_bn : Optional [bool ] = None ,
5758 dropout_rate : float = models .DROPOUT_RATE
5859 ) -> Tuple [spec .Tensor , spec .ModelAuxiliaryState ]:
5960 del model_state
6061 del update_batch_norm
62+ del use_running_average_bn
6163 train = mode == spec .ForwardPassMode .TRAIN
6264 logits = self ._model .apply ({'params' : params },
6365 augmented_and_preprocessed_input_batch ['inputs' ],
You can’t perform that action at this time.
0 commit comments