Skip to content

Commit 24d9815

Browse files
Merge pull request #886 from mlcommons/dropout_fixes
Dropout fixes
2 parents 484e66c + 01cfbb0 commit 24d9815

File tree

5 files changed: +5 additions, -8 deletions (lines changed)

algoperf/workloads/criteo1tb/criteo1tb_jax/models.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,11 +223,7 @@ def scaled_init(key, shape, dtype=jnp.float_):
223223
top_mlp_input = nn.relu(top_mlp_input)
224224
if self.use_layer_norm:
225225
top_mlp_input = nn.LayerNorm()(top_mlp_input)
226-
if (
227-
dropout_rate is not None
228-
and dropout_rate > 0.0
229-
and layer_idx == num_layers_top - 2
230-
):
226+
if dropout_rate is not None and layer_idx == num_layers_top - 2:
231227
top_mlp_input = Dropout(dropout_rate, deterministic=not train)(
232228
top_mlp_input, rate=dropout_rate
233229
)

algoperf/workloads/imagenet_resnet/imagenet_jax/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def __call__(
9191
use_running_average_bn: Optional[bool] = None,
9292
) -> spec.Tensor:
9393
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
94-
9594
# Preserve default behavior for backwards compatibility
9695
if use_running_average_bn is None:
9796
use_running_average_bn = not update_batch_norm

algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,11 @@ def model_fn(
154154
rng: spec.RandomState,
155155
update_batch_norm: bool,
156156
use_running_average_bn: Optional[bool] = None,
157+
dropout_rate: Optional[float] = None,
157158
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
158159
del mode
159160
del rng
161+
del dropout_rate
160162
variables = {'params': params, **model_state}
161163
if update_batch_norm:
162164
logits, new_model_state = self._model.apply(

algoperf/workloads/librispeech_conformer/librispeech_jax/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE):
428428
use_bias=True,
429429
broadcast_dropout=False,
430430
attention_fn=attention_fn,
431-
dropout_rate=dropout_rate,
431+
dropout_rate=0.0,
432432
deterministic=not train,
433433
)(inputs_q=inputs, mask=attention_mask)
434434

algoperf/workloads/wmt/wmt_jax/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
223223
bias_init=cfg.bias_init,
224224
use_bias=False,
225225
broadcast_dropout=False,
226-
dropout_rate=dropout_rate,
226+
dropout_rate=0.0,
227227
deterministic=cfg.deterministic,
228228
)(cfg.attention_temp * x, x, mask=encoder_mask)
229229

0 commit comments

Comments (0)