diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
index 706c2b51a..abf6b3ad2 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -223,11 +223,7 @@ def scaled_init(key, shape, dtype=jnp.float_):
         top_mlp_input = nn.relu(top_mlp_input)
       if self.use_layer_norm:
         top_mlp_input = nn.LayerNorm()(top_mlp_input)
-      if (
-        dropout_rate is not None
-        and dropout_rate > 0.0
-        and layer_idx == num_layers_top - 2
-      ):
+      if dropout_rate is not None and layer_idx == num_layers_top - 2:
         top_mlp_input = Dropout(dropout_rate, deterministic=not train)(
           top_mlp_input, rate=dropout_rate
         )
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py
index 84ad4fe21..ee1ddf427 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py
@@ -91,7 +91,6 @@ def __call__(
     use_running_average_bn: Optional[bool] = None,
   ) -> spec.Tensor:
     conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)

-    # Preserve default behavior for backwards compatibility
     if use_running_average_bn is None:
       use_running_average_bn = not update_batch_norm
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
index 5e30a2e95..f73a1b26e 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
@@ -154,9 +154,11 @@ def model_fn(
     rng: spec.RandomState,
     update_batch_norm: bool,
     use_running_average_bn: Optional[bool] = None,
+    dropout_rate: Optional[float] = None,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
+    del dropout_rate
     variables = {'params': params, **model_state}
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index 9fc2e39ef..6c275232f 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -428,7 +428,7 @@ def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE):
       use_bias=True,
       broadcast_dropout=False,
       attention_fn=attention_fn,
-      dropout_rate=dropout_rate,
+      dropout_rate=0.0,
       deterministic=not train,
     )(inputs_q=inputs, mask=attention_mask)

diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py
index 81f2ece4c..a16b99b70 100644
--- a/algoperf/workloads/wmt/wmt_jax/models.py
+++ b/algoperf/workloads/wmt/wmt_jax/models.py
@@ -223,7 +223,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
       bias_init=cfg.bias_init,
       use_bias=False,
       broadcast_dropout=False,
-      dropout_rate=dropout_rate,
+      dropout_rate=0.0,
       deterministic=cfg.deterministic,
     )(cfg.attention_temp * x, x, mask=encoder_mask)