11 changes: 7 additions & 4 deletions algorithmic_efficiency/pytorch_utils.py
@@ -67,10 +67,13 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
)
if isinstance(module, bn_layers):
if not update_batch_norm:
module.eval()
module.momentum_backup = module.momentum
if not hasattr(module, 'momentum_backup'):
module.momentum_backup = module.momentum

# module.momentum can be float or torch.Tensor.
module.momentum = 0. * module.momentum_backup
if torch.is_tensor(module.momentum_backup):
module.momentum = torch.zeros_like(module.momentum_backup)
else:
module.momentum = 0.0
elif hasattr(module, 'momentum_backup'):
module.momentum = module.momentum_backup
module.track_running_stats = update_batch_norm
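The guard above keeps repeated calls with update_batch_norm=False from clobbering momentum_backup with an already-zeroed momentum, and torch.zeros_like handles modules whose momentum is a Tensor rather than a float. A minimal sketch of the freeze/restore idea for a stock float-momentum layer (illustrative only, not the repo's helper):

```python
import torch

bn = torch.nn.BatchNorm1d(4)  # freshly constructed, training mode, momentum=0.1

# Freeze: back up momentum once, then zero it so running stats stop moving.
if not hasattr(bn, 'momentum_backup'):
  bn.momentum_backup = bn.momentum
bn.momentum = 0.0

_ = bn(torch.randn(8, 4))
assert torch.allclose(bn.running_mean, torch.zeros(4))  # unchanged

# Restore: hand the original momentum back when updates are wanted again.
bn.momentum = bn.momentum_backup
```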
9 changes: 7 additions & 2 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/models.py
@@ -28,11 +28,16 @@ class ResNet(nn.Module):
@nn.compact
def __call__(self,
x: spec.Tensor,
update_batch_norm: bool = True) -> spec.Tensor:
update_batch_norm: bool = True,
use_running_average_bn: bool = None) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)

# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
norm = functools.partial(
nn.BatchNorm,
use_running_average=not update_batch_norm,
use_running_average=use_running_average_bn,
momentum=0.9,
epsilon=1e-5,
dtype=self.dtype)
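In the Flax ResNets, nn.BatchNorm's use_running_average is now driven by the new argument, with a None default that reproduces the old coupling to update_batch_norm. A self-contained sketch of the pattern (TinyBNModel is a hypothetical stand-in, not the workload's ResNet):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

class TinyBNModel(nn.Module):
  @nn.compact
  def __call__(self, x, update_batch_norm=True, use_running_average_bn=None):
    # Preserve default behavior for backwards compatibility.
    if use_running_average_bn is None:
      use_running_average_bn = not update_batch_norm
    return nn.BatchNorm(
        use_running_average=use_running_average_bn,
        momentum=0.9,
        epsilon=1e-5)(x)

x = jnp.ones((8, 16))
model = TinyBNModel()
variables = model.init(jax.random.PRNGKey(0), x)

# Old behavior (flag omitted): normalize with batch statistics and collect the
# updated running averages through the mutable collection.
_, new_state = model.apply(variables, x, update_batch_norm=True,
                           mutable=['batch_stats'])

# New option: normalize with the stored running averages while leaving them
# untouched, independently of update_batch_norm.
y = model.apply(variables, x, update_batch_norm=True,
                use_running_average_bn=True, mutable=False)
```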
10 changes: 7 additions & 3 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
@@ -110,7 +110,9 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
@@ -119,14 +121,16 @@ def model_fn(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=['batch_stats'])
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn)
return logits, new_model_state
else:
logits = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=False)
mutable=False,
use_running_average_bn=use_running_average_bn)
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
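At the workload level the new argument is simply threaded through to Module.apply, as in the model_fn above. A hedged sketch of that calling pattern (forward and its arguments are hypothetical, not the workload's actual method):

```python
from typing import Optional

def forward(model, params, batch_stats, inputs, *,
            train: bool,
            update_batch_norm: bool,
            use_running_average_bn: Optional[bool] = None):
  variables = {'params': params, 'batch_stats': batch_stats}
  if train:
    # Training path: batch_stats stays mutable so updated running averages
    # can be handed back to the caller.
    logits, new_state = model.apply(
        variables, inputs,
        update_batch_norm=update_batch_norm,
        use_running_average_bn=use_running_average_bn,
        mutable=['batch_stats'])
    return logits, new_state
  # Eval path: no collections are mutable, state is returned unchanged.
  logits = model.apply(
      variables, inputs,
      update_batch_norm=update_batch_norm,
      use_running_average_bn=use_running_average_bn,
      mutable=False)
  return logits, {'batch_stats': batch_stats}
```

One caveat: Flax's nn.BatchNorm writes to the batch_stats collection whenever use_running_average is False, so on the mutable=False eval path the flag is effectively limited to None or True for these ResNets. The Conformer further below avoids this by gating its running-average writes on update_batch_norm instead.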
@@ -84,11 +84,16 @@ class ResNet(nn.Module):
@nn.compact
def __call__(self,
x: spec.Tensor,
update_batch_norm: bool = True) -> spec.Tensor:
update_batch_norm: bool = True,
use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)

# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
norm = functools.partial(
nn.BatchNorm,
use_running_average=not update_batch_norm,
use_running_average=use_running_average_bn,
momentum=0.9,
epsilon=1e-5,
dtype=self.dtype)
@@ -148,7 +148,9 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
@@ -157,14 +159,16 @@ def model_fn(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=['batch_stats'])
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn)
return logits, new_model_state
else:
logits = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=False)
mutable=False,
use_running_average_bn=use_running_average_bn)
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
@@ -454,15 +454,24 @@ def setup(self):
self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
update_batch_norm,
use_running_average_bn):
rank = inputs.ndim
reduce_over_dims = list(range(0, rank - 1))

padding = jnp.expand_dims(input_paddings, -1)
momentum = self.config.batch_norm_momentum
epsilon = self.config.batch_norm_epsilon

if train:
if use_running_average_bn:
mean = self.ra_mean.value
var = self.ra_var.value

else:
# compute batch statistics
mask = 1.0 - padding
sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
count_v = jnp.sum(
@@ -478,16 +487,13 @@ def __call__(self, inputs, input_paddings, train):

var = sum_vv / count_v

self.ra_mean.value = momentum * \
self.ra_mean.value + (1 - momentum) * mean
self.ra_var.value = momentum * \
self.ra_var.value + (1 - momentum) * var
else:
mean = self.ra_mean.value
var = self.ra_var.value
if update_batch_norm:
self.ra_mean.value = momentum * \
self.ra_mean.value + (1 - momentum) * mean
self.ra_var.value = momentum * \
self.ra_var.value + (1 - momentum) * var

inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)

bn_output = (inputs - mean) * inv + self.beta
bn_output *= 1.0 - padding
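The custom BatchNorm now takes two flags: use_running_average_bn selects which statistics normalize the batch, and update_batch_norm controls whether the running averages are folded forward. The statistics themselves are computed only over non-padded frames; a simplified standalone sketch of that masked reduction (hypothetical helper, per-feature reductions without the keepdims used above):

```python
import jax.numpy as jnp

def masked_batch_stats(inputs, input_paddings):
  # inputs: [batch, time, dim]; input_paddings: [batch, time], 1.0 marks padding.
  mask = 1.0 - jnp.expand_dims(input_paddings, -1)      # [batch, time, 1]
  count = jnp.maximum(jnp.sum(mask), 1.0)               # number of valid frames
  mean = jnp.sum(inputs * mask, axis=(0, 1)) / count    # [dim]
  var = jnp.sum(jnp.square(inputs - mean) * mask, axis=(0, 1)) / count
  return mean, var
```

When update_batch_norm is true, the running averages are then blended as ra = momentum * ra + (1 - momentum) * batch_stat, matching the update shown above.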

@@ -517,7 +523,12 @@ class ConvolutionBlock(nn.Module):
config: ConformerConfig

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm,
use_running_average_bn):
config = self.config
inputs = LayerNorm(dim=config.encoder_dim)(inputs)

@@ -546,7 +557,10 @@ def __call__(self, inputs, input_paddings, train):
kernel_init=nn.initializers.xavier_uniform())(
inputs)

inputs = BatchNorm(config)(inputs, input_paddings, train)
inputs = BatchNorm(config)(inputs,
input_paddings,
update_batch_norm,
use_running_average_bn)
if config.activation_function_name == 'swish':
activation_fn = nn.swish
elif config.activation_function_name == 'gelu':
@@ -586,7 +600,12 @@ class ConformerBlock(nn.Module):
config: ConformerConfig

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm,
use_running_average):
config = self.config
padding_mask = jnp.expand_dims(1 - input_paddings, -1)

@@ -597,7 +616,11 @@ def __call__(self, inputs, input_paddings, train):
inputs, input_paddings, train)

inputs = inputs + \
ConvolutionBlock(config)(inputs, input_paddings, train)
ConvolutionBlock(config)(inputs,
input_paddings,
train,
update_batch_norm,
use_running_average)

inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
inputs, padding_mask, train)
@@ -629,12 +652,23 @@ def setup(self):
.use_dynamic_time_mask_max_frames)

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm: Optional[bool] = None,
use_running_average_bn: Optional[bool] = None):
config = self.config

outputs = inputs
output_paddings = input_paddings

# Set BN args if not supplied for backwards compatibility
if update_batch_norm is None:
update_batch_norm = train
if use_running_average_bn is None:
use_running_average_bn = not train

# Compute normalized log mel spectrograms from input audio signal.
preprocessing_config = preprocessor.LibrispeechPreprocessingConfig()
outputs, output_paddings = preprocessor.MelFilterbankFrontend(
@@ -660,7 +694,11 @@ def __call__(self, inputs, input_paddings, train):

# Run the conformer encoder layers.
for _ in range(config.num_encoder_layers):
outputs = ConformerBlock(config)(outputs, output_paddings, train)
outputs = ConformerBlock(config)(outputs,
output_paddings,
train,
update_batch_norm,
use_running_average_bn)

outputs = LayerNorm(config.encoder_dim)(outputs)
# Run the decoder which in this case is a trivial projection layer.
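The encoder's __call__ keeps train as the only required mode flag and derives the two batch-norm flags from it when they are omitted, so existing callers are unaffected. A tiny illustration of the resulting combinations (resolve_bn_flags is a hypothetical restatement, not code from the model):

```python
from typing import Optional, Tuple

def resolve_bn_flags(train: bool,
                     update_batch_norm: Optional[bool] = None,
                     use_running_average_bn: Optional[bool] = None
                     ) -> Tuple[bool, bool]:
  if update_batch_norm is None:
    update_batch_norm = train
  if use_running_average_bn is None:
    use_running_average_bn = not train
  return update_batch_norm, use_running_average_bn

# Old-style training call: update running stats, normalize with batch stats.
assert resolve_bn_flags(train=True) == (True, False)
# Old-style eval call: freeze running stats, normalize with them.
assert resolve_bn_flags(train=False) == (False, True)
# New option: evaluate with batch statistics without touching running stats.
assert resolve_bn_flags(train=False, use_running_average_bn=False) == (False, False)
```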
@@ -81,8 +81,8 @@ def _get_mask(self,
jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
[batch_size, 1])
multiplicity_tensor = masks_per_frame * choose_range
multiplicity_weights = (multiplicity_weights <
multiplicity_tensor).astype(jnp.int32)
multiplicity_weights = (multiplicity_weights
< multiplicity_tensor).astype(jnp.int32)
pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
else:
pre_mask = jnp.einsum('bmt->bt', pre_mask)
@@ -107,7 +107,9 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
@@ -118,15 +120,17 @@ def model_fn(
input_paddings,
train=True,
rngs={'dropout' : rng},
mutable=['batch_stats'])
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn)
return (logits, logit_paddings), new_model_state
else:
logits, logit_paddings = self._model.apply(
variables,
inputs,
input_paddings,
train=False,
mutable=False)
mutable=False,
use_running_average_bn=use_running_average_bn)
return (logits, logit_paddings), model_state

def _build_input_queue(
@@ -40,7 +40,7 @@ class ConformerConfig:
time_masks_per_frame: float = 0.0
use_dynamic_time_mask_max_frames: bool = True
input_dropout_rate: float = 0.1
batch_norm_momentum: float = 0.999
batch_norm_momentum: float = 1 - 0.999
batch_norm_epsilon: float = 0.001
use_specaug: bool = True
attention_temperature: float = 1.0
@@ -369,10 +369,11 @@ def forward(self, inputs, input_paddings):
mean = (masked_inp).sum(dim=(0, 1)) / count
var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count

self.running_mean = self.momentum * self.running_mean + (
1 - self.momentum) * mean.detach()
self.running_var = self.momentum * self.running_var + (
1 - self.momentum) * var.detach()
self.running_mean = (1 - self.momentum) * self.running_mean + (
self.momentum) * mean.detach()
self.running_var = (1 - self.momentum) * self.running_var + (
self.momentum) * var.detach()

else:
mean = self.running_mean
var = self.running_var
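On the PyTorch side, batch_norm_momentum flips to 1 - 0.999 and the running-average update is rewritten so that momentum now weights the new batch statistic. This matches torch.nn.BatchNorm's convention, and it also means a zeroed momentum freezes the running statistics (under the old formula it would have overwritten them with the batch statistics). A quick sketch showing the two forms are numerically equivalent once the value is flipped:

```python
import torch

old_momentum = 0.999
new_momentum = 1 - 0.999
running = torch.tensor([1.0, 2.0])
batch_mean = torch.tensor([3.0, 5.0])

old_style = old_momentum * running + (1 - old_momentum) * batch_mean
new_style = (1 - new_momentum) * running + new_momentum * batch_mean
assert torch.allclose(old_style, new_style)

# With the new convention, a momentum of 0 leaves `running` unchanged, which
# is what zeroing the momentum relies on to freeze the statistics.
frozen = (1 - 0.0) * running + 0.0 * batch_mean
assert torch.equal(frozen, running)
```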
@@ -260,8 +260,9 @@ def greedy_decode(
idxs = torch.arange(
fin_result.numel(), device=result.device).view(*fin_result.shape)
mask = torch.arange(
fin_result.shape[1], device=result.device).view(
1, -1) < result.count_nonzero(dim=1).view(-1, 1)
fin_result.shape[1],
device=result.device).view(1, -1) < result.count_nonzero(dim=1).view(
-1, 1)
fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id]
padding = fin_result == 0
return fin_result, padding
@@ -36,7 +36,7 @@ class DeepspeechConfig:
time_mask_max_ratio: float = 0.05
time_masks_per_frame: float = 0.0
use_dynamic_time_mask_max_frames: bool = True
batch_norm_momentum: float = 0.999
batch_norm_momentum: float = 1 - 0.999
batch_norm_epsilon: float = 0.001
# If None, defaults to 0.1.
input_dropout_rate: Optional[float] = 0.1
@@ -264,10 +264,10 @@ def forward(self, inputs, input_paddings):
sum_ = dist_nn.all_reduce(sum_)
var = sum_ / count

self.running_mean = self.momentum * self.running_mean + (
1 - self.momentum) * mean.detach()
self.running_var = self.momentum * self.running_var + (
1 - self.momentum) * var.detach()
self.running_mean = (1 - self.momentum) * self.running_mean + (
self.momentum) * mean.detach()
self.running_var = (1 - self.momentum) * self.running_var + (
self.momentum) * var.detach()
else:
mean = self.running_mean
var = self.running_var
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
@@ -942,8 +942,8 @@ def forward(self,
# not the remaining zero elements.
if attn_mask is not None:
raise ValueError('Attention mask has to be None for decode == True.')
attn_mask = (torch.arange(max_len, device=k.device) >=
cache_index).reshape(1, max_len)
attn_mask = (torch.arange(max_len, device=k.device)
>= cache_index).reshape(1, max_len)

# Update sequence length to account for complete sequence.
seq_len = k.size(1)