Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,7 @@ def scaled_init(key, shape, dtype=jnp.float_):
top_mlp_input = nn.relu(top_mlp_input)
if self.use_layer_norm:
top_mlp_input = nn.LayerNorm()(top_mlp_input)
if (
dropout_rate is not None
and dropout_rate > 0.0
and layer_idx == num_layers_top - 2
):
if dropout_rate is not None and layer_idx == num_layers_top - 2:
top_mlp_input = Dropout(dropout_rate, deterministic=not train)(
top_mlp_input, rate=dropout_rate
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def __call__(
use_running_average_bn: Optional[bool] = None,
) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)

# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
Expand Down
2 changes: 2 additions & 0 deletions algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,11 @@ def model_fn(
rng: spec.RandomState,
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None,
dropout_rate: Optional[float] = None,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
del dropout_rate
variables = {'params': params, **model_state}
if update_batch_norm:
logits, new_model_state = self._model.apply(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE):
use_bias=True,
broadcast_dropout=False,
attention_fn=attention_fn,
dropout_rate=dropout_rate,
dropout_rate=0.0,
deterministic=not train,
)(inputs_q=inputs, mask=attention_mask)

Expand Down
2 changes: 1 addition & 1 deletion algoperf/workloads/wmt/wmt_jax/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
bias_init=cfg.bias_init,
use_bias=False,
broadcast_dropout=False,
dropout_rate=dropout_rate,
dropout_rate=0.0,
deterministic=cfg.deterministic,
)(cfg.attention_temp * x, x, mask=encoder_mask)

Expand Down