From 023e94dbd6b1b78aecd7799461369a8d9a5972bf Mon Sep 17 00:00:00 2001 From: aydentang Date: Tue, 15 Mar 2022 20:14:44 +0800 Subject: [PATCH 01/19] wav2vec2 training --- wenet/bin/train.py | 17 +- wenet/wav2vec/grad_multiply.py | 18 + wenet/wav2vec/gumbel_vector_quantizer.py | 202 ++++++ wenet/wav2vec/w2v_loss.py | 92 +++ wenet/wav2vec/wav2vec2_encoder.py | 821 +++++++++++++++++++++++ wenet/wav2vec/wav2vec2_model.py | 765 +++++++++++++++++++++ 6 files changed, 1913 insertions(+), 2 deletions(-) create mode 100644 wenet/wav2vec/grad_multiply.py create mode 100644 wenet/wav2vec/gumbel_vector_quantizer.py create mode 100644 wenet/wav2vec/w2v_loss.py create mode 100644 wenet/wav2vec/wav2vec2_encoder.py create mode 100644 wenet/wav2vec/wav2vec2_model.py diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 764b43eb1..d6c5aa7c6 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -28,6 +28,7 @@ from wenet.dataset.dataset import Dataset from wenet.transformer.asr_model import init_asr_model +from wenet.wav2vec.wav2vec2_model import init_wav2vec2_model from wenet.utils.checkpoint import (load_checkpoint, save_checkpoint, load_trained_modules) from wenet.utils.executor import Executor @@ -179,6 +180,13 @@ def main(): input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins'] vocab_size = len(symbol_table) + pretrain=configs.get('pretrain',False) + if 'wav2vec_conf' in configs: + wav2vec_conf=configs['wav2vec_conf'] + wav2vec_conf['pretrain']=pretrain + else: + wav2vec_conf=None + # Save configs to model_dir/train.yaml for inference and export configs['input_dim'] = input_dim configs['output_dim'] = vocab_size @@ -191,7 +199,10 @@ def main(): fout.write(data) # Init asr model from configs - model = init_asr_model(configs) + if wav2vec_conf: + model=init_wav2vec2_model(configs) + else: + model = init_asr_model(configs) print(model) num_params = sum(p.numel() for p in model.parameters()) print('the number of model params: {}'.format(num_params)) @@ -199,7 +210,7 @@ def main(): # !!!IMPORTANT!!! # Try to export the model by script, if fails, we should refine # the code to satisfy the script export requirements - if args.rank == 0: + if args.rank == 0 and not pretrain and wav2vec_conf is None: script_model = torch.jit.script(model) script_model.save(os.path.join(args.model_dir, 'init.zip')) executor = Executor() @@ -218,6 +229,8 @@ def main(): num_epochs = configs.get('max_epoch', 100) model_dir = args.model_dir writer = None + if pretrain: + model.set_num_updates(step) if args.rank == 0: os.makedirs(model_dir, exist_ok=True) exp_id = os.path.basename(model_dir) diff --git a/wenet/wav2vec/grad_multiply.py b/wenet/wav2vec/grad_multiply.py new file mode 100644 index 000000000..08d15f55d --- /dev/null +++ b/wenet/wav2vec/grad_multiply.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None diff --git a/wenet/wav2vec/gumbel_vector_quantizer.py b/wenet/wav2vec/gumbel_vector_quantizer.py new file mode 100644 index 000000000..711343888 --- /dev/null +++ b/wenet/wav2vec/gumbel_vector_quantizer.py @@ -0,0 +1,202 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GumbelVectorQuantizer(nn.Module): + def __init__( + self, + dim, + num_vars, + temp, + groups, + combine_groups, + vq_dim, + time_first, + activation=nn.GELU(), + weight_proj_depth=1, + weight_proj_factor=1, + ): + """Vector quantization using gumbel softmax + + Args: + dim: input dimension (channels) + num_vars: number of quantized vectors per group + temp: temperature for training. this should be a tuple of 3 elements: (start, stop, decay factor) + groups: number of groups for vector quantization + combine_groups: whether to use the vectors for all groups + vq_dim: dimensionality of the resulting quantized vector + time_first: if true, expect input in BxTxC format, otherwise in BxCxT + activation: what activation to use (should be a module). this is only used if weight_proj_depth is > 1 + weight_proj_depth: number of layers (with activation in between) to project input before computing logits + weight_proj_factor: this is used only if weight_proj_depth is > 1. scales the inner dimensionality of + projections by this factor + """ + super().__init__() + + self.groups = groups + self.combine_groups = combine_groups + self.input_dim = dim + self.num_vars = num_vars + self.time_first = time_first + + assert ( + vq_dim % groups == 0 + ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation" + + var_dim = vq_dim // groups + num_groups = groups if not combine_groups else 1 + + self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim)) + nn.init.uniform_(self.vars) + + if weight_proj_depth > 1: + + def block(input_dim, output_dim): + return nn.Sequential(nn.Linear(input_dim, output_dim), activation) + + inner_dim = self.input_dim * weight_proj_factor + self.weight_proj = nn.Sequential( + *[ + block(self.input_dim if i == 0 else inner_dim, inner_dim) + for i in range(weight_proj_depth - 1) + ], + nn.Linear(inner_dim, groups * num_vars), + ) + else: + self.weight_proj = nn.Linear(self.input_dim, groups * num_vars) + nn.init.normal_(self.weight_proj.weight, mean=0, std=1) + nn.init.zeros_(self.weight_proj.bias) + + if isinstance(temp, str): + import ast + temp = ast.literal_eval(temp) + assert len(temp) == 3, f"{temp}, {len(temp)}" + + self.max_temp, self.min_temp, self.temp_decay = temp + self.curr_temp = self.max_temp + self.codebook_indices = None + + def set_num_updates(self, num_updates): + self.curr_temp = max( + self.max_temp * self.temp_decay ** num_updates, self.min_temp + ) + + def get_codebook_indices(self): + if self.codebook_indices is None: + from itertools import product + + p = [range(self.num_vars)] * self.groups + inds = list(product(*p)) + self.codebook_indices = torch.tensor( + inds, dtype=torch.long, device=self.vars.device + ).flatten() + + if not self.combine_groups: + self.codebook_indices = self.codebook_indices.view( + self.num_vars ** self.groups, -1 + ) + for b in range(1, self.groups): + self.codebook_indices[:, b] += self.num_vars * b + self.codebook_indices = self.codebook_indices.flatten() + return self.codebook_indices + + def codebook(self): + indices = self.get_codebook_indices() + return ( + self.vars.squeeze(0) + .index_select(0, indices) + .view(self.num_vars ** self.groups, -1) + ) + + def sample_from_codebook(self, b, n): + indices = self.get_codebook_indices() + indices = indices.view(-1, self.groups) + cb_size = 
indices.size(0) + assert ( + n < cb_size + ), f"sample size {n} is greater than size of codebook {cb_size}" + sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,)) + indices = indices[sample_idx] + + z = self.vars.squeeze(0).index_select(0, indices.flatten()).view(b, n, -1) + return z + + def to_codebook_index(self, indices): + res = indices.new_full(indices.shape[:-1], 0) + for i in range(self.groups): + exponent = self.groups - i - 1 + res += indices[..., i] * (self.num_vars ** exponent) + return res + + def forward_idx(self, x): + res = self.forward(x, produce_targets=True) + return res["x"], res["targets"] + + def forward(self, x, produce_targets=False): + + result = {"num_vars": self.num_vars * self.groups} + + if not self.time_first: + x = x.transpose(1, 2) + + bsz, tsz, fsz = x.shape + x = x.reshape(-1, fsz) + x = self.weight_proj(x) + x = x.view(bsz * tsz * self.groups, -1) + + _, k = x.max(-1) + hard_x = ( + x.new_zeros(*x.shape) + .scatter_(-1, k.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + result["code_perplexity"] = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ).sum() + + avg_probs = torch.softmax( + x.view(bsz * tsz, self.groups, -1).float(), dim=-1 + ).mean(dim=0) + result["prob_perplexity"] = torch.exp( + -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1) + ).sum() + + result["temp"] = self.curr_temp + + if self.training: + x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=True).type_as(x) + else: + x = hard_x + + x = x.view(bsz * tsz, -1) + + vars = self.vars + if self.combine_groups: + vars = vars.repeat(1, self.groups, 1) + + if produce_targets: + result["targets"] = ( + x.view(bsz * tsz * self.groups, -1) + .argmax(dim=-1) + .view(bsz, tsz, self.groups) + .detach() + ) + + x = x.unsqueeze(-1) * vars + x = x.view(bsz * tsz, self.groups, self.num_vars, -1) + x = x.sum(-2) + x = x.view(bsz, tsz, -1) + + if not self.time_first: + x = x.transpose(1, 2) # BTC -> BCT + + result["x"] = x + + return result diff --git a/wenet/wav2vec/w2v_loss.py b/wenet/wav2vec/w2v_loss.py new file mode 100644 index 000000000..3924cf340 --- /dev/null +++ b/wenet/wav2vec/w2v_loss.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +# The origial Wav2vec work is in: +# Paper: https://arxiv.org/pdf/2006.11477.pdf +# Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/wav2vec + + +from distutils.version import LooseVersion +import logging + +import numpy as np +import six +import torch +import torch.nn.functional as F +import math + +class W2vLoss(torch.nn.Module): + def __init__(self,infonce=False,loss_weights=None): + super().__init__() + self.infonce = infonce + self.loss_weights = loss_weights + + def forward(self, model, net_output ,reduce=True): + + losses = [] + weights = None + self.infonce=True + + logits = model.get_logits(net_output) + target = model.get_targets(None, net_output) + sample_size = target.numel() + + if self.infonce: + loss = F.cross_entropy( + logits, + target, + reduction="sum" if reduce else "none", + ) + else: + loss = F.binary_cross_entropy_with_logits( + logits, + target.float(), + weights, + reduction="sum" if reduce else "none", + ) + + losses.append(loss.detach().clone()) + + if self.loss_weights is not None: + extra_losses = model.get_extra_losses(net_output) + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = 
[self.loss_weights[0]] * len(extra_losses) + assert len(extra_losses) == len( + self.loss_weights + ), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, coef in zip(extra_losses, self.loss_weights): + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + losses.append(p) + + logging_output = { + "loss": loss.item()/sample_size / math.log(2) if reduce else loss, + "sample_size": sample_size, + } + + if len(losses) > 1: + for i, l in enumerate(losses): + logging_output[f"loss_{i}"] = l.item() + + if self.infonce: + with torch.no_grad(): + if logits.numel() == 0: + corr = 0 + count = 0 + else: + assert logits.dim() > 1, logits.shape + max = logits.argmax(-1) == 0 + min = logits.argmin(-1) == 0 + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = max.numel() + + logging_output["correct"] = corr + logging_output["count"] = count + + return loss,sample_size,logging_output + diff --git a/wenet/wav2vec/wav2vec2_encoder.py b/wenet/wav2vec/wav2vec2_encoder.py new file mode 100644 index 000000000..2325883f5 --- /dev/null +++ b/wenet/wav2vec/wav2vec2_encoder.py @@ -0,0 +1,821 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +# The origial Wav2vec work is in: +# Paper: https://arxiv.org/pdf/2006.11477.pdf +# Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/wav2vec + +"""Encoder definition.""" +from typing import Tuple, List, Optional + +import torch +from typeguard import check_argument_types + +from wenet.transformer.attention import MultiHeadedAttention +from wenet.transformer.attention import RelPositionMultiHeadedAttention +from wenet.transformer.convolution import ConvolutionModule +from wenet.transformer.embedding import PositionalEncoding +from wenet.transformer.embedding import RelPositionalEncoding +from wenet.transformer.embedding import NoPositionalEncoding +from wenet.transformer.encoder_layer import TransformerEncoderLayer +from wenet.transformer.encoder_layer import ConformerEncoderLayer +from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward +from wenet.transformer.subsampling import Conv2dSubsampling4 +from wenet.transformer.subsampling import Conv2dSubsampling6 +from wenet.transformer.subsampling import Conv2dSubsampling8 +from wenet.transformer.subsampling import LinearNoSubsampling +from wenet.utils.common import get_activation +from wenet.utils.mask import make_pad_mask +from wenet.utils.mask import add_optional_chunk_mask +from wenet.utils.mask import compute_mask_indices +from wenet.wav2vec.gumbel_vector_quantizer import GumbelVectorQuantizer + +def buffered_arange(max): + if not hasattr(buffered_arange, "buf"): + buffered_arange.buf = torch.LongTensor() + if max > buffered_arange.buf.numel(): + buffered_arange.buf.resize_(max) + torch.arange(max, out=buffered_arange.buf) + return buffered_arange.buf[:max] + + +class W2vBaseEncoder(torch.nn.Module): + def __init__( + self, + wav2vec_conf:dict, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + use_feature_norm: bool = False, + ): + """ + Args: + 
input_size (int): input dim + output_size (int): dimension of attention + attention_heads (int): the number of heads of multi head attention + linear_units (int): the hidden units number of position-wise feed + forward + num_blocks (int): the number of decoder blocks + dropout_rate (float): dropout rate + attention_dropout_rate (float): dropout rate in attention + positional_dropout_rate (float): dropout rate after adding + positional encoding + input_layer (str): input layer type. + optional [linear, conv2d, conv2d6, conv2d8] + pos_enc_layer_type (str): Encoder positional encoding layer type. + opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] + normalize_before (bool): + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + concat_after (bool): whether to concat attention layer's input + and output. + True: x -> x + linear(concat(x, att(x))) + False: x -> x + att(x) + static_chunk_size (int): chunk size for static chunk training and + decoding + use_dynamic_chunk (bool): whether use dynamic chunk size for + training or not, You can only use fixed chunk(chunk_size > 0) + or dyanmic chunk size(use_dynamic_chunk = True) + global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module + use_dynamic_left_chunk (bool): whether use dynamic left chunk in + dynamic chunk training + use_feature_norm (bool): whether to use layer_norm after + cnn subsampling + """ + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "rel_pos": + pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == "no_pos": + pos_enc_class = NoPositionalEncoding + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + if input_layer == "linear": + subsampling_class = LinearNoSubsampling + elif input_layer == "conv2d": + subsampling_class = Conv2dSubsampling4 + elif input_layer == "conv2d6": + subsampling_class = Conv2dSubsampling6 + elif input_layer == "conv2d8": + subsampling_class = Conv2dSubsampling8 + else: + raise ValueError("unknown input_layer: " + input_layer) + + self.global_cmvn = global_cmvn + self.embed = subsampling_class( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + + self.normalize_before = normalize_before + self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + + if use_feature_norm: + self.feature_norm = torch.nn.LayerNorm(output_size, eps=1e-5) + else: + self.feature_norm = None + + self.embed_dim=output_size + final_dim=output_size + self.encoder_embed_dim=output_size + + self.mask_prob = wav2vec_conf.get('mask_prob', 0.65) + self.mask_selection = "static" + self.mask_other = 0 + self.mask_length = 10 + self.no_mask_overlap =False + self.mask_min_space = 1 + + self.mask_channel_prob = wav2vec_conf.get('mask_channel_prob', 0.0) + self.mask_channel_selection = "static" + self.mask_channel_other = 0 + self.mask_channel_length = wav2vec_conf.get('mask_channel_length', 10) + self.no_mask_channel_overlap = False + self.mask_channel_min_space = 1 + + self.n_negatives = wav2vec_conf.get('num_negatives', 100) + self.cross_sample_negatives = 0 + self.codebook_negatives = 0 + self.negatives_from_everywhere = False + + self.quantize_targets=wav2vec_conf.get('quantize_targets', 
True) + self.project_targets=wav2vec_conf.get('project_targets', False) + self.project_final=wav2vec_conf.get('project_final', False) + self.quantize_input=wav2vec_conf.get('quantize_input', False) + self.latent_dim=wav2vec_conf.get('latent_dim',0) + self.latent_vars=wav2vec_conf.get('latent_vars',320) + self.latent_temp='(2,0.5,0.999995)' + self.latent_groups=wav2vec_conf.get('latent_groups',2) + use_target_feature_norm=wav2vec_conf.get('target_feature_norm', False) + + if use_target_feature_norm: + self.target_feature_norm = torch.nn.LayerNorm(output_size, eps=1e-12) + else: + self.target_feature_norm = None + + if self.quantize_targets: + vq_dim = self.latent_dim if self.latent_dim > 0 else final_dim + self.quantizer = GumbelVectorQuantizer( + dim=self.embed_dim, + num_vars=self.latent_vars, + temp=self.latent_temp, + groups=self.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_q = torch.nn.Linear(vq_dim, final_dim) + else: + self.project_q = torch.nn.Linear(self.embed_dim, final_dim) + + if self.quantize_input: + if self.quantizer is not None: + vq_dim = final_dim + self.input_quantizer = self.quantizer + else: + vq_dim = self.latent_dim if self.latent_dim > 0 else final_dim + self.input_quantizer = GumbelVectorQuantizer( + dim=self.embed_dim, + num_vars=self.latent_vars, + temp=self.latent_temp, + groups=self.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_inp = torch.nn.Linear(vq_dim, final_dim) + + self.mask=wav2vec_conf.get('mask', True) + + self.pretrain=wav2vec_conf.get('pretrain',True) + + self.final_proj = torch.nn.Linear(output_size, final_dim) + + self.target_glu=None + + self.logit_temp=0.1 + + self.mask_emb = torch.nn.Parameter( + torch.FloatTensor(self.encoder_embed_dim).uniform_() + ) + + self.feature_grad_mult=wav2vec_conf.get('feature_grad_mult',1.0) + + def output_size(self) -> int: + return self._output_size + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def sample_negatives(self, y, num): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + y = y.view(-1, fsz) # BTC => (BxT)C + + cross_high = tsz * bsz + high = tsz + with torch.no_grad(): + assert high > 1, f"{bsz,tsz,fsz}" + + if self.n_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand(-1, self.n_negatives) + .flatten() + ) + + neg_idxs = torch.randint( + low=0, high=high - 1, size=(bsz, self.n_negatives * num) + ) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand(-1, 
self.cross_sample_negatives) + .flatten() + ) + + cross_neg_idxs = torch.randint( + low=0, + high=cross_high - 1, + size=(bsz, self.cross_sample_negatives * num), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + for i in range(1, bsz): + neg_idxs[i] += i * high + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[neg_idxs.view(-1)] + negs = negs.view( + bsz, num, self.n_negatives + self.cross_sample_negatives, fsz + ).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def compute_preds(self, x, y, negatives): + + neg_is_pos = (y == negatives).all(-1) + y = y.unsqueeze(0) + targets = torch.cat([y, negatives], dim=0) + + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) + + logits /= self.logit_temp + + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + + return logits + + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + features_only: bool=False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed positions in tensor. + + Args: + xs: padded input tensor (B, T, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + """ + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + + if self.feature_grad_mult > 0: + features, pos_emb, masks = self.embed(xs, masks) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features, pos_emb, masks = self.embed(xs, masks) + + + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks=masks + # chunk_masks = add_optional_chunk_mask(xs, masks, + # self.use_dynamic_chunk, + # self.use_dynamic_left_chunk, + # decoding_chunk_size, + # self.static_chunk_size, + # num_decoding_left_chunks) + + + #L2 loss pen + features_pen = features.float().pow(2).mean() + + # features = self.feature_norm(features) + unmasked_features = features.clone() + input_features=None + + if self.target_feature_norm is not None: + unmasked_features=self.target_feature_norm(unmasked_features) + + if self.quantize_input: + q = self.input_quantizer(features, produce_targets=False) + features = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + features = self.project_inp(features) + + if self.mask and self.pretrain: + + x, mask_indices = self.apply_mask(features, None) + input_features=x.clone() + if mask_indices is not None: + y = unmasked_features[mask_indices].view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) + else: + y = unmasked_features + elif self.mask and self.training: + x, mask_indices = self.apply_mask(features, None) + if mask_indices is not None: + y = 
unmasked_features[mask_indices].view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) + else: + y = unmasked_features + x = features + mask_indices = None + + for layer in self.encoders: + x, chunk_masks, _ = layer(x, chunk_masks, pos_emb, mask_pad) + if self.normalize_before: + x = self.after_norm(x) + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later + ext_result={"features_pen": features_pen} + + if features_only: + return x, masks, ext_result + + if self.quantize_targets: + q = self.quantizer(y, produce_targets=True) + y = q["x"] + y_targets=q["targets"] + + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + + if self.project_targets: + y = self.project_q(y) + + if self.negatives_from_everywhere: + neg_cands, *_ = self.quantizer(unmasked_features, produce_targets=False) + negs, _ = self.sample_negatives(neg_cands, y.size(1)) + negs = self.project_q(negs) + else: + negs, _ = self.sample_negatives(y, y.size(1)) + + if self.codebook_negatives > 0: + cb_negs = self.quantizer.sample_from_codebook( + y.size(0) * y.size(1), self.codebook_negatives + ) + cb_negs = cb_negs.view( + self.codebook_negatives, y.size(0), y.size(1), -1 + ) # order doesnt matter + cb_negs = self.project_q(cb_negs) + negs = torch.cat([negs, cb_negs], dim=0) + else: + if self.project_targets: + y = self.project_q(y) + + if self.negatives_from_everywhere: + negs, _ = self.sample_negatives(unmasked_features, y.size(1)) + negs = self.project_q(negs) + else: + negs, _ = self.sample_negatives(y, y.size(1)) + + x = x[mask_indices].view(x.size(0), -1, x.size(-1)) + + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + + # final project + if self.project_final: + x = self.final_proj(x) + + x = self.compute_preds(x, y, negs) + + if prob_ppl is not None: + ext_result["prob_perplexity"] = prob_ppl + ext_result["code_perplexity"] = code_ppl + ext_result["num_vars"] = num_vars + ext_result["temp"] = curr_temp + + ext_result["x"]=x + + return x,masks,ext_result + + def get_logits(self, net_output): + logits=net_output["x"] + logits = logits.transpose(0, 2) + logits = logits.reshape(-1, logits.size(-1)) + return logits + + def get_targets(self, sample, net_output, expand_steps=True): + x=net_output["x"] + return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long) + + def get_extra_losses(self, net_output): + pen = [] + + if "prob_perplexity" in net_output: + pen.append( + (net_output["num_vars"] - net_output["prob_perplexity"]) + / net_output["num_vars"] + ) + + if "features_pen" in net_output: + pen.append(net_output["features_pen"]) + + return pen + + def forward_chunk( + self, + xs: torch.Tensor, + offset: int, + required_cache_size: int, + subsampling_cache: Optional[torch.Tensor] = None, + elayers_output_cache: Optional[List[torch.Tensor]] = None, + conformer_cnn_cache: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], + List[torch.Tensor]]: + """ Forward just one chunk + + Args: + xs (torch.Tensor): chunk input + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + subsampling_cache (Optional[torch.Tensor]): subsampling cache + elayers_output_cache (Optional[List[torch.Tensor]]): + 
transformer/conformer encoder layers output cache + conformer_cnn_cache (Optional[List[torch.Tensor]]): conformer + cnn cache + + Returns: + torch.Tensor: output of current input xs + torch.Tensor: subsampling cache required for next chunk computation + List[torch.Tensor]: encoder layers output cache required for next + chunk computation + List[torch.Tensor]: conformer cnn cache + + """ + assert xs.size(0) == 1 + # tmp_masks is just for interface compatibility + tmp_masks = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) + tmp_masks = tmp_masks.unsqueeze(1) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) + if self.feature_norm is not None: + xs = self.feature_norm(xs) + if subsampling_cache is not None: + cache_size = subsampling_cache.size(1) + xs = torch.cat((subsampling_cache, xs), dim=1) + else: + cache_size = 0 + pos_emb = self.embed.position_encoding(offset - cache_size, xs.size(1)) + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = xs.size(1) + else: + next_cache_start = max(xs.size(1) - required_cache_size, 0) + r_subsampling_cache = xs[:, next_cache_start:, :] + # Real mask for transformer/conformer layers + masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool) + masks = masks.unsqueeze(1) + r_elayers_output_cache = [] + r_conformer_cnn_cache = [] + for i, layer in enumerate(self.encoders): + if elayers_output_cache is None: + attn_cache = None + else: + attn_cache = elayers_output_cache[i] + if conformer_cnn_cache is None: + cnn_cache = None + else: + cnn_cache = conformer_cnn_cache[i] + xs, _, new_cnn_cache = layer(xs, + masks, + pos_emb, + output_cache=attn_cache, + cnn_cache=cnn_cache) + r_elayers_output_cache.append(xs[:, next_cache_start:, :]) + r_conformer_cnn_cache.append(new_cnn_cache) + if self.normalize_before: + xs = self.after_norm(xs) + + return (xs[:, cache_size:, :], r_subsampling_cache, + r_elayers_output_cache, r_conformer_cnn_cache) + + def forward_chunk_by_chunk( + self, + xs: torch.Tensor, + decoding_chunk_size: int, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Forward input chunk by chunk with chunk_size like a streaming + fashion + + Here we should pay special attention to computation cache in the + streaming style forward chunk by chunk. Three things should be taken + into account for computation in the current network: + 1. transformer/conformer encoder layers output cache + 2. convolution in conformer + 3. convolution in subsampling + + However, we don't implement subsampling cache for: + 1. We can control subsampling module to output the right result by + overlapping input instead of cache left context, even though it + wastes some computation, but subsampling only takes a very + small fraction of computation in the whole model. + 2. Typically, there are several covolution layers with subsampling + in subsampling module, it is tricky and complicated to do cache + with different convolution layers with different subsampling + rate. + 3. Currently, nn.Sequential is used to stack all the convolution + layers in subsampling, we need to rewrite it to make it work + with cache, which is not prefered. 
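+
+        As a rough worked example of the resulting windowing (assuming
+        the default conv2d subsampling, i.e. subsampling_rate = 4 and
+        right_context = 6; other subsampling modules change the numbers):
+            context = right_context + 1 = 7
+            stride = subsampling_rate * decoding_chunk_size
+            decoding_window = (decoding_chunk_size - 1) * subsampling_rate
+                              + context
+        so with decoding_chunk_size = 16 each step consumes a 67-frame
+        window, advances by 64 frames, and simply re-computes the 3
+        overlapping frames instead of caching subsampling state.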
+ Args: + xs (torch.Tensor): (1, max_len, dim) + chunk_size (int): decoding chunk size + """ + assert decoding_chunk_size > 0 + # The model is trained by static or dynamic chunk + assert self.static_chunk_size > 0 or self.use_dynamic_chunk + subsampling = self.embed.subsampling_rate + context = self.embed.right_context + 1 # Add current frame + stride = subsampling * decoding_chunk_size + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.size(1) + subsampling_cache: Optional[torch.Tensor] = None + elayers_output_cache: Optional[List[torch.Tensor]] = None + conformer_cnn_cache: Optional[List[torch.Tensor]] = None + outputs = [] + offset = 0 + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + # Feed forward overlap input step by step + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + (y, subsampling_cache, elayers_output_cache, + conformer_cnn_cache) = self.forward_chunk(chunk_xs, offset, + required_cache_size, + subsampling_cache, + elayers_output_cache, + conformer_cnn_cache) + outputs.append(y) + offset += y.size(1) + ys = torch.cat(outputs, 1) + masks = torch.ones(1, ys.size(1), device=ys.device, dtype=torch.bool) + masks = masks.unsqueeze(1) + return ys, masks + + +class TransformerEncoder(W2vBaseEncoder): + """Transformer encoder module.""" + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + use_feature_norm: bool = False, + ): + """ Construct TransformerEncoder + + See Encoder for the meaning of each parameter. 
+ """ + assert check_argument_types() + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + concat_after, static_chunk_size, use_dynamic_chunk, + global_cmvn, use_dynamic_left_chunk, use_feature_norm) + self.encoders = torch.nn.ModuleList([ + TransformerEncoderLayer( + output_size, + MultiHeadedAttention(attention_heads, output_size, + attention_dropout_rate), + PositionwiseFeedForward(output_size, linear_units, + dropout_rate), dropout_rate, + normalize_before, concat_after) for _ in range(num_blocks) + ]) + + +class W2vConformerEncoder(W2vBaseEncoder): + """Conformer encoder module.""" + def __init__( + self, + wav2vec_conf:dict, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "rel_pos", + normalize_before: bool = True, + concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + positionwise_conv_kernel_size: int = 1, + macaron_style: bool = True, + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + cnn_module_before: bool = False, + use_feature_norm: bool = False, + ): + """Construct ConformerEncoder + + Args: + input_size to use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + selfattention_layer_type (str): Encoder attention layer type, + the parameter has no effect now, it's just for configure + compatibility. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. 
+ """ + assert check_argument_types() + super().__init__(wav2vec_conf,input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + concat_after, static_chunk_size, use_dynamic_chunk, + global_cmvn, use_dynamic_left_chunk, use_feature_norm) + activation = get_activation(activation_type) + + # self-attention module definition + if selfattention_layer_type == "rel_selfattn": + encoder_selfattn_layer = RelPositionMultiHeadedAttention + else: + encoder_selfattn_layer = MultiHeadedAttention + + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + # feed-forward module definition + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = (output_size, cnn_module_kernel, activation, + cnn_module_norm, causal) + + self.encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer( + *positionwise_layer_args) if macaron_style else None, + convolution_layer( + *convolution_layer_args) if use_cnn_module else None, + cnn_module_before, + dropout_rate, + normalize_before, + concat_after, + ) for _ in range(num_blocks) + ]) diff --git a/wenet/wav2vec/wav2vec2_model.py b/wenet/wav2vec/wav2vec2_model.py new file mode 100644 index 000000000..5e3be90cc --- /dev/null +++ b/wenet/wav2vec/wav2vec2_model.py @@ -0,0 +1,765 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +# The origial Wav2vec work is in: +# Paper: https://arxiv.org/pdf/2006.11477.pdf +# Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/wav2vec + +from collections import defaultdict +from typing import List, Optional, Tuple + +import torch +import logging + +from torch.nn.utils.rnn import pad_sequence + +from wenet.transformer.cmvn import GlobalCMVN +from wenet.transformer.ctc import CTC +from wenet.transformer.decoder import (TransformerDecoder, + BiTransformerDecoder) +from wenet.transformer.encoder import ConformerEncoder +from wenet.transformer.encoder import TransformerEncoder +from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss +from wenet.utils.cmvn import load_cmvn +from wenet.utils.common import (IGNORE_ID, add_sos_eos, log_add, + remove_duplicates_and_blank, th_accuracy, + reverse_pad_list) +from wenet.utils.mask import (make_pad_mask, mask_finished_preds, + mask_finished_scores, subsequent_mask) +from wenet.wav2vec.wav2vec2_encoder import W2vConformerEncoder +from wenet.wav2vec.w2v_loss import W2vLoss + +class Wav2vec2Model(torch.nn.Module): + """CTC-attention hybrid Encoder-Decoder model""" + def __init__( + self, + wav2vec_conf:dict, + vocab_size: int, + encoder: TransformerEncoder, + decoder: TransformerDecoder, + ctc: CTC, + ctc_weight: float = 0.5, + ignore_id: int = IGNORE_ID, + reverse_weight: float = 0.0, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + ): + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + self.reverse_weight = reverse_weight + + self.encoder = 
encoder + self.decoder = decoder + self.ctc = ctc + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + self.pretrain=wav2vec_conf.get('pretrain',False) + + self.w2v_ext_loss=wav2vec_conf.get('w2v_ext_loss',True) + if self.w2v_ext_loss: + self.w2v_loss_weights=wav2vec_conf.get('w2v_loss_weights',[0.1,10]) + else: + self.w2v_loss_weights=None + if self.pretrain: + self.w2v_loss=W2vLoss(loss_weights=self.w2v_loss_weights) + + self.encoder_grad_mult=wav2vec_conf.get('encoder_grad_mult',1.0) + self.freeze_finetune_updates=wav2vec_conf.get('freeze_finetune_updates',0) + + self.num_updates=0 + + self.log_interval=wav2vec_conf.get('log_interval',100) + self.accum_grad=wav2vec_conf.get('accum_grad',4) + + def set_num_updates(self, num_updates): + + self.num_updates = num_updates + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + num_updates:int=0 + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor]]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0]), (speech.shape, speech_lengths.shape) + self.num_updates=num_updates + + if self.pretrain: + encoder_out, encoder_mask,ext_res = self.encoder(speech, speech_lengths,False) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + else: + if self.encoder_grad_mult>0 and self.num_updates>=self.freeze_finetune_updates: + encoder_out, encoder_mask ,ext_res= self.encoder(speech, speech_lengths,features_only=True) + if self.encoder_grad_mult !=1.0: + encoder_out=GradMultiply.apply(encoder_out,self.encoder_grad_mult) + encoder_mask=GradMultiply.apply(encoder_mask,self.encoder_grad_mult) + else: + with torch.no_grad(): + encoder_out, encoder_mask,ext_res = self.encoder(speech, speech_lengths,features_only=True) + # 1. Encoder + encoder_out, encoder_mask,ext_res = self.encoder(speech, speech_lengths,features_only=True) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + + + if self.pretrain: + loss,sample_size,logging_output=self.w2v_loss(self.encoder,ext_res) + loss_att=None + loss_ctc=None + if self.num_updates*self.accum_grad % 100 ==0: + logging.info(logging_output) + else: + # 2a. Attention-decoder branch + if self.ctc_weight != 1.0: + loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask, + text, text_lengths) + else: + loss_att = None + + # 2b. CTC branch + if self.ctc_weight != 0.0: + loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, + text_lengths) + else: + loss_ctc = None + + if loss_ctc is None: + loss = loss_att + elif loss_att is None: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + (1 - + self.ctc_weight) * loss_att + + + + return loss, loss_att, loss_ctc + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_mask: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, float]: + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # reverse the seq, used for right to left decoder + r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id)) + r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos, + self.ignore_id) + # 1. 
Forward decoder + decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask, + ys_in_pad, ys_in_lens, + r_ys_in_pad, + self.reverse_weight) + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + r_loss_att = torch.tensor(0.0) + if self.reverse_weight > 0.0: + r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad) + loss_att = loss_att * ( + 1 - self.reverse_weight) + r_loss_att * self.reverse_weight + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + return loss_att, acc_att + + def _forward_encoder( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Let's assume B = batch_size + # 1. Encoder + if simulate_streaming and decoding_chunk_size > 0: + encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk( + speech, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + else: + encoder_out, encoder_mask = self.encoder( + speech, + speech_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + return encoder_out, encoder_mask + + def recognize( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int = 10, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> torch.Tensor: + """ Apply beam search on attention decoder + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + + Returns: + torch.Tensor: decoding result, (batch, max_result_len) + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + device = speech.device + batch_size = speech.shape[0] + + # Let's assume B = batch_size and N = beam_size + # 1. Encoder + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + encoder_dim = encoder_out.size(2) + running_size = batch_size * beam_size + encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( + running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) + encoder_mask = encoder_mask.unsqueeze(1).repeat( + 1, beam_size, 1, 1).view(running_size, 1, + maxlen) # (B*N, 1, max_len) + + hyps = torch.ones([running_size, 1], dtype=torch.long, + device=device).fill_(self.sos) # (B*N, 1) + scores = torch.tensor([0.0] + [-float('inf')] * (beam_size - 1), + dtype=torch.float) + scores = scores.to(device).repeat([batch_size]).unsqueeze(1).to( + device) # (B*N, 1) + end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device) + cache: Optional[List[torch.Tensor]] = None + # 2. 
Decoder forward step by step + for i in range(1, maxlen + 1): + # Stop if all batch and all beam produce eos + if end_flag.sum() == running_size: + break + # 2.1 Forward decoder step + hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( + running_size, 1, 1).to(device) # (B*N, i, i) + # logp: (B*N, vocab) + logp, cache = self.decoder.forward_one_step( + encoder_out, encoder_mask, hyps, hyps_mask, cache) + # 2.2 First beam prune: select topk best prob at current time + top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) + top_k_logp = mask_finished_scores(top_k_logp, end_flag) + top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos) + # 2.3 Second beam prune: select topk score with history + scores = scores + top_k_logp # (B*N, N), broadcast add + scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N) + scores, offset_k_index = scores.topk(k=beam_size) # (B, N) + scores = scores.view(-1, 1) # (B*N, 1) + # 2.4. Compute base index in top_k_index, + # regard top_k_index as (B*N*N),regard offset_k_index as (B*N), + # then find offset_k_index in top_k_index + base_k_index = torch.arange(batch_size, device=device).view( + -1, 1).repeat([1, beam_size]) # (B, N) + base_k_index = base_k_index * beam_size * beam_size + best_k_index = base_k_index.view(-1) + offset_k_index.view( + -1) # (B*N) + + # 2.5 Update best hyps + best_k_pred = torch.index_select(top_k_index.view(-1), + dim=-1, + index=best_k_index) # (B*N) + best_hyps_index = best_k_index // beam_size + last_best_k_hyps = torch.index_select( + hyps, dim=0, index=best_hyps_index) # (B*N, i) + hyps = torch.cat((last_best_k_hyps, best_k_pred.view(-1, 1)), + dim=1) # (B*N, i+1) + + # 2.6 Update end flag + end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1) + + # 3. Select best of best + scores = scores.view(batch_size, beam_size) + # TODO: length normalization + best_scores, best_index = scores.max(dim=-1) + best_hyps_index = best_index + torch.arange( + batch_size, dtype=torch.long, device=device) * beam_size + best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index) + best_hyps = best_hyps[:, 1:] + return best_hyps, best_scores + + def ctc_greedy_search( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> List[List[int]]: + """ Apply CTC greedy search + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[List[int]]: best path result + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + batch_size = speech.shape[0] + # Let's assume B = batch_size + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + ctc_probs = self.ctc.log_softmax( + encoder_out) # (B, maxlen, vocab_size) + topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) + mask = make_pad_mask(encoder_out_lens, maxlen) # (B, maxlen) + topk_index = topk_index.masked_fill_(mask, self.eos) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] + scores = topk_prob.max(1) + hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] + return hyps, scores + + def _ctc_prefix_beam_search( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> Tuple[List[List[int]], torch.Tensor]: + """ CTC prefix beam search inner implementation + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + + Returns: + List[List[int]]: nbest results + torch.Tensor: encoder output, (1, max_len, encoder_dim), + it will be used for rescoring in attention rescoring mode + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + batch_size = speech.shape[0] + # For CTC prefix beam search, we only support batch_size=1 + assert batch_size == 1 + # Let's assume B = batch_size and N = beam_size + # 1. Encoder forward and get CTC score + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + ctc_probs = self.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + cur_hyps = [(tuple(), (0.0, -float('inf')))] + # 2. 
CTC beam search step by step + for t in range(0, maxlen): + logp = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) + # 2.1 First beam prune: select topk best + top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + for s in top_k_index: + s = s.item() + ps = logp[s].item() + for prefix, (pb, pnb) in cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == 0: # blank + n_pb, n_pnb = next_hyps[prefix] + n_pb = log_add([n_pb, pb + ps, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + elif s == last: + # Update *ss -> *s; + n_pb, n_pnb = next_hyps[prefix] + n_pnb = log_add([n_pnb, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + + # 2.2 Second beam prune + next_hyps = sorted(next_hyps.items(), + key=lambda x: log_add(list(x[1])), + reverse=True) + cur_hyps = next_hyps[:beam_size] + hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps] + return hyps, encoder_out + + def ctc_prefix_beam_search( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> List[int]: + """ Apply CTC prefix beam search + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + + Returns: + List[int]: CTC prefix beam search nbest results + """ + hyps, _ = self._ctc_prefix_beam_search(speech, speech_lengths, + beam_size, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) + return hyps[0] + + def attention_rescoring( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + ctc_weight: float = 0.0, + simulate_streaming: bool = False, + reverse_weight: float = 0.0, + ) -> List[int]: + """ Apply attention rescoring decoding, CTC prefix beam search + is applied first to get nbest, then we resoring the nbest on + attention decoder with corresponding encoder out + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + reverse_weight (float): right to left decoder weight + ctc_weight (float): ctc score weight + + Returns: + List[int]: Attention rescoring result + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + if reverse_weight > 0.0: + # decoder should be a bitransformer decoder if reverse_weight > 0.0 + assert hasattr(self.decoder, 'right_decoder') + device = speech.device + batch_size = speech.shape[0] + # For attention rescoring we only support batch_size=1 + assert batch_size == 1 + # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size + hyps, encoder_out = self._ctc_prefix_beam_search( + speech, speech_lengths, beam_size, decoding_chunk_size, + num_decoding_left_chunks, simulate_streaming) + + assert len(hyps) == beam_size + hyps_pad = pad_sequence([ + torch.tensor(hyp[0], device=device, dtype=torch.long) + for hyp in hyps + ], True, self.ignore_id) # (beam_size, max_hyps_len) + ori_hyps_pad = hyps_pad + hyps_lens = torch.tensor([len(hyp[0]) for hyp in hyps], + device=device, + dtype=torch.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + encoder_out = encoder_out.repeat(beam_size, 1, 1) + encoder_mask = torch.ones(beam_size, + 1, + encoder_out.size(1), + dtype=torch.bool, + device=device) + # used for right to left decoder + r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id) + r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, + self.ignore_id) + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, + reverse_weight) # (beam_size, max_hyps_len, vocab_size) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + decoder_out = decoder_out.cpu().numpy() + # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a + # conventional transformer decoder. 
+ r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + r_decoder_out = r_decoder_out.cpu().numpy() + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + score += decoder_out[i][len(hyp[0])][self.eos] + # add right to left decoder score + if reverse_weight > 0: + r_score = 0.0 + for j, w in enumerate(hyp[0]): + r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w] + r_score += r_decoder_out[i][len(hyp[0])][self.eos] + score = score * (1 - reverse_weight) + r_score * reverse_weight + # add ctc score + score += hyp[1] * ctc_weight + if score > best_score: + best_score = score + best_index = i + return hyps[best_index][0], best_score + + @torch.jit.export + def subsampling_rate(self) -> int: + """ Export interface for c++ call, return subsampling_rate of the + model + """ + return self.encoder.embed.subsampling_rate + + @torch.jit.export + def right_context(self) -> int: + """ Export interface for c++ call, return right_context of the model + """ + return self.encoder.embed.right_context + + @torch.jit.export + def sos_symbol(self) -> int: + """ Export interface for c++ call, return sos symbol id of the model + """ + return self.sos + + @torch.jit.export + def eos_symbol(self) -> int: + """ Export interface for c++ call, return eos symbol id of the model + """ + return self.eos + + @torch.jit.export + def forward_encoder_chunk( + self, + xs: torch.Tensor, + offset: int, + required_cache_size: int, + subsampling_cache: Optional[torch.Tensor] = None, + elayers_output_cache: Optional[List[torch.Tensor]] = None, + conformer_cnn_cache: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], + List[torch.Tensor]]: + """ Export interface for c++ call, give input chunk xs, and return + output from time 0 to current chunk. + + Args: + xs (torch.Tensor): chunk input + subsampling_cache (Optional[torch.Tensor]): subsampling cache + elayers_output_cache (Optional[List[torch.Tensor]]): + transformer/conformer encoder layers output cache + conformer_cnn_cache (Optional[List[torch.Tensor]]): conformer + cnn cache + + Returns: + torch.Tensor: output, it ranges from time 0 to current chunk. 
            torch.Tensor: subsampling cache
            List[torch.Tensor]: attention cache
            List[torch.Tensor]: conformer cnn cache

        """
        return self.encoder.forward_chunk(xs, offset, required_cache_size,
                                          subsampling_cache,
                                          elayers_output_cache,
                                          conformer_cnn_cache)

    @torch.jit.export
    def ctc_activation(self, xs: torch.Tensor) -> torch.Tensor:
        """ Export interface for c++ call, apply linear transform and log
            softmax before ctc
        Args:
            xs (torch.Tensor): encoder output

        Returns:
            torch.Tensor: activation before ctc

        """
        return self.ctc.log_softmax(xs)

    @torch.jit.export
    def is_bidirectional_decoder(self) -> bool:
        """ Export interface for c++ call, return whether the decoder has a
            right-to-left branch
        Returns:
            bool: True if the decoder is bidirectional
        """
        if hasattr(self.decoder, 'right_decoder'):
            return True
        else:
            return False

    @torch.jit.export
    def forward_attention_decoder(
        self,
        hyps: torch.Tensor,
        hyps_lens: torch.Tensor,
        encoder_out: torch.Tensor,
        reverse_weight: float = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Export interface for c++ call, forward decoder with multiple
            hypotheses from ctc prefix beam search and one encoder output
        Args:
            hyps (torch.Tensor): hyps from ctc prefix beam search, with sos
                already padded at the beginning
            hyps_lens (torch.Tensor): length of each hyp in hyps
            encoder_out (torch.Tensor): corresponding encoder output
            reverse_weight (float): > 0 means the right-to-left decoder is
                used; its reversed input (r_hyps, eos padded at the beginning)
                is built internally from hyps

        Returns:
            torch.Tensor: decoder output
            torch.Tensor: right-to-left decoder output (all zeros when the
                right-to-left decoder is not used)
        """
        assert encoder_out.size(0) == 1
        num_hyps = hyps.size(0)
        assert hyps_lens.size(0) == num_hyps
        encoder_out = encoder_out.repeat(num_hyps, 1, 1)
        encoder_mask = torch.ones(num_hyps,
                                  1,
                                  encoder_out.size(1),
                                  dtype=torch.bool,
                                  device=encoder_out.device)
        # input for right to left decoder
        # hyps_lens counts the sos token, so subtract it
        r_hyps_lens = hyps_lens - 1
        # hyps includes the sos token, so strip it before reversing
        r_hyps = hyps[:, 1:]
        r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
        r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps, hyps_lens, r_hyps,
            reverse_weight)  # (num_hyps, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)

        # the right to left decoder may not be used during decoding,
        # depending on the reverse_weight param.
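        # Both tensors are returned so the caller (e.g. the runtime rescorer)
        # can blend them as (1 - reverse_weight) * score + reverse_weight *
        # r_score, mirroring attention_rescoring() above.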
+ # r_dccoder_out will be 0.0, if reverse_weight is 0.0 + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + return decoder_out, r_decoder_out + + +def init_wav2vec2_model(configs): + if configs['cmvn_file'] is not None: + mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn']) + global_cmvn = GlobalCMVN( + torch.from_numpy(mean).float(), + torch.from_numpy(istd).float()) + else: + global_cmvn = None + + input_dim = configs['input_dim'] + vocab_size = configs['output_dim'] + + encoder_type = configs.get('encoder', 'conformer') + decoder_type = configs.get('decoder', 'bitransformer') + + wav2vec_conf=configs['wav2vec_conf'] + + if encoder_type == 'conformer': + encoder = W2vConformerEncoder(wav2vec_conf, + input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf']) + else: + encoder = TransformerEncoder(input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf']) + if decoder_type == 'transformer': + decoder = TransformerDecoder(vocab_size, encoder.output_size(), + **configs['decoder_conf']) + else: + assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0 + assert configs['decoder_conf']['r_num_blocks'] > 0 + decoder = BiTransformerDecoder(vocab_size, encoder.output_size(), + **configs['decoder_conf']) + ctc = CTC(vocab_size, encoder.output_size()) + model = Wav2vec2Model( + wav2vec_conf=wav2vec_conf, + vocab_size=vocab_size, + encoder=encoder, + decoder=decoder, + ctc=ctc, + **configs['model_conf'], + ) + return model From d41cdfc9726a8492a8a139599beb98d03d0553bf Mon Sep 17 00:00:00 2001 From: aydentang Date: Tue, 15 Mar 2022 20:22:59 +0800 Subject: [PATCH 02/19] ssl recipe --- .../conf/finetune/train_conformer_100h.yaml | 91 +++++ .../train_conformer_pretrain_w2v.yaml | 94 +++++ examples/librispeech/ssl/run.sh | 335 ++++++++++++++++++ 3 files changed, 520 insertions(+) create mode 100644 examples/librispeech/ssl/conf/finetune/train_conformer_100h.yaml create mode 100644 examples/librispeech/ssl/conf/pretrain/train_conformer_pretrain_w2v.yaml create mode 100644 examples/librispeech/ssl/run.sh diff --git a/examples/librispeech/ssl/conf/finetune/train_conformer_100h.yaml b/examples/librispeech/ssl/conf/finetune/train_conformer_100h.yaml new file mode 100644 index 000000000..acf7941c8 --- /dev/null +++ b/examples/librispeech/ssl/conf/finetune/train_conformer_100h.yaml @@ -0,0 +1,91 @@ +# network architecture +# encoder related +encoder: conformer +encoder_conf: + output_size: 512 # dimension of attention + attention_heads: 8 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.0 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + cnn_module_kernel: 31 + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 2 + linear_units: 512 + num_blocks: 1 + dropout_rate: 0.1 + positional_dropout_rate: 0.0 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.7 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +# use raw_wav or kaldi feature +raw_wav: true + +# dataset related +dataset_conf: + filter_conf: + max_length: 4000 + min_length: 50 + token_max_length: 400 + token_min_length: 1 + resample_conf: + 
resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 3 + num_f_mask: 2 + max_t: 50 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 12 + +pretrain: False +wav2vec_conf: + pretrain: False + quantize_targets: True + project_targets: True + latent_vars: 320 + latent_dim: 512 + latent_groups: 2 + mask: False + +grad_clip: 5 +accum_grad: 1 +max_epoch: 120 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + warmup_steps: 15000 diff --git a/examples/librispeech/ssl/conf/pretrain/train_conformer_pretrain_w2v.yaml b/examples/librispeech/ssl/conf/pretrain/train_conformer_pretrain_w2v.yaml new file mode 100644 index 000000000..6a11e8823 --- /dev/null +++ b/examples/librispeech/ssl/conf/pretrain/train_conformer_pretrain_w2v.yaml @@ -0,0 +1,94 @@ +# network architecture +# encoder related +encoder: conformer +encoder_conf: + output_size: 512 # dimension of attention + attention_heads: 8 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.0 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + cnn_module_kernel: 31 + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 8 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.0 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 1.0 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +# use raw_wav or kaldi feature +raw_wav: true + +# dataset related +dataset_conf: + filter_conf: + max_length: 2500 + min_length: 50 + token_max_length: 400 + token_min_length: 1 + resample_conf: + resample_rate: 16000 + speed_perturb: false + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: false + spec_aug_conf: + num_t_mask: 3 + num_f_mask: 2 + max_t: 50 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'dynamic' # static or dynamic + max_frames_in_batch: 20000 + +pretrain: True +wav2vec_conf: + pretrain: True + quantize_targets: True + project_targets: True + latent_vars: 320 + latent_dim: 512 + latent_groups: 2 + w2v_ext_loss: True + w2v_loss_weights: [1.5,0] + mask: True + mask_prob: 0.65 + +grad_clip: 5 +accum_grad: 4 +max_epoch: 90 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.002 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + warmup_steps: 25000 diff --git a/examples/librispeech/ssl/run.sh b/examples/librispeech/ssl/run.sh new file mode 100644 index 000000000..5768d2b85 --- /dev/null +++ b/examples/librispeech/ssl/run.sh @@ -0,0 +1,335 @@ +#!/bin/bash +# Copyright 2021 Tencent Inc. (Author: Kai Tang). +# Apach 2.0 + +. 
./path1.sh || exit 1; + +# Use this to control how many gpu you use, It's 1-gpu training if you specify +# just 1gpu, otherwise it's is multiple gpu training based on DDP in pytorch +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +stage=4 # start from 0 if you need to start from data preparation +stop_stage=4 +# data +data_url=www.openslr.org/resources/12 +# use your own data path +datadir= +# wav data dir +wave_data=data +# Optional train_config +# 1. conf/train_transformer_large.yaml: Standard transformer +pretrain_config=conf/pretrain/train_conformer_pretrain_w2v.yaml +# finetune_config +finetune_config=conf/finetune/train_conformer_100h.yaml +checkpoint= +cmvn=true +do_delta=false + +pretrain_dir=exp/pretrain_960h +finetune_dir=exp/finetune_spec_aug_100h + +# use average_checkpoint will get better result +average_checkpoint=true +decode_checkpoint=$dir/final.pt +# maybe you can try to adjust it if you can not get close results as README.md +average_num=20 +decode_modes="attention_rescoring ctc_greedy_search ctc_prefix_beam_search attention" + +. tools/parse_options.sh || exit 1; + +# bpemode (unigram or bpe) +nbpe=5000 +bpemode=unigram + +set -e +set -u +set -o pipefail + +train_set=train_960 +train_set_100h=train_clean_100 +dev_set=dev +recog_set="test_clean test_other dev_clean dev_other" + + +# pretrained w2v-conformer encoder +enc_init=$pretrain_dir/final.pt +enc_init_mods='encoder.encoders.0,encoder.encoders.1,encoder.encoders.2,encoder.encoders.3,encoder.encoders.4,encoder.encoders.5,encoder.encoders.6,encoder.encoders.7,encoder.encoders.8,encoder.encoders.9,encoder.encoders.10,encoder.encoders.11,encoder.embed' + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + echo "stage -1: Data Download" + for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do + local/download_and_untar.sh ${datadir} ${data_url} ${part} + done +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + ### Task dependent. You have to make data the following preparation part by yourself. + ### But you can utilize Kaldi recipes in most cases + echo "stage 0: Data preparation" + #for part in dev-clean test-clean test-other train-clean-100 train-clean-360 train-other-500; do + for part in dev-other ; do + # use underscore-separated names in data directories. + local/data_prep_torchaudio.sh ${datadir}/LibriSpeech/${part} $wave_data/${part//-/_} + done +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + ### Task dependent. You have to design training and dev sets by yourself. + ### But you can utilize Kaldi recipes in most cases + echo "stage 1: Feature Generation" + mkdir -p $wave_data/train_960 + # merge total training data + for set in train_clean_100 train_clean_360 train_other_500; do + for f in `ls $wave_data/$set`; do + cat $wave_data/$set/$f >> $wave_data/train_960/$f + done + done + mkdir -p $wave_data/dev + # merge total dev data + for set in dev_clean dev_other; do + for f in `ls $wave_data/$set`; do + cat $wave_data/$set/$f >> $wave_data/$dev_set/$f + done + done + + tools/compute_cmvn_stats.py --num_workers 16 --train_config $train_config \ + --in_scp $wave_data/$train_set/wav.scp \ + --out_cmvn $wave_data/$train_set/global_cmvn + +fi + + +dict=$wave_data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt +bpemodel=$wave_data/lang_char/${train_set}_${bpemode}${nbpe} +echo "dictionary: ${dict}" +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + ### Task dependent. 
You have to check non-linguistic symbols used in the corpus. + echo "stage 2: Dictionary and Json Data Preparation" + mkdir -p data/lang_char/ + + echo " 0" > ${dict} # 0 will be used for "blank" in CTC + echo " 1" >> ${dict} # must be 1 + + # we borrowed these code and scripts which are related bpe from ESPnet. + cut -f 2- -d" " $wave_data/${train_set}/text > $wave_data/lang_char/input.txt + tools/spm_train --input=$wave_data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 + tools/spm_encode --model=${bpemodel}.model --output_format=piece < $wave_data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+1}' >> ${dict} + num_token=$(cat $dict | wc -l) + echo " $num_token" >> $dict # + wc -l ${dict} +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # Prepare wenet requried data + echo "Prepare data, prepare requried format" + #for x in $dev_set ${recog_set} $train_set ; do + for x in $train_set_100h ; do + tools/make_raw_list.py $wave_data/$x/wav.scp $wave_data/$x/text \ + $wave_data/$x/data.list + done + +fi + + +#pretrain +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # Training + mkdir -p $pretrain_dir + INIT_FILE=$pretrain_dir/ddp_init + rm -f $INIT_FILE # delete old one before starting + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + # Use "nccl" if it works, otherwise use "gloo" + dist_backend="nccl" + cmvn_opts= + $cmvn && cmvn_opts="--cmvn $wave_data/${train_set}/global_cmvn" + # train.py will write $train_config to $dir/train.yaml with model input + # and output dimension, train.yaml will be used for inference or model + # export later + for ((i = 0; i < $num_gpus; ++i)); do + { + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + python wenet/bin/train.py --gpu $gpu_id \ + --config $pretrain_config \ + --data_type raw \ + --symbol_table $dict \ + --bpe_model ${bpemodel}.model \ + --train_data $wave_data/$train_set/data.list \ + --cv_data $wave_data/$dev_set/data.list \ + ${checkpoint:+--checkpoint $checkpoint} \ + --model_dir $pretrain_dir \ + --ddp.init_method $init_method \ + --ddp.world_size $num_gpus \ + --ddp.rank $i \ + --ddp.dist_backend $dist_backend \ + --num_workers 3 \ + $cmvn_opts \ + --pin_memory + } & + done + wait +fi + +#finetune 100h using wav2vec model +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # Training + mkdir -p $finetune_dir + INIT_FILE=$finetune_dir/ddp_init + rm -f $INIT_FILE # delete old one before starting + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + # Use "nccl" if it works, otherwise use "gloo" + dist_backend="nccl" + cmvn_opts= + $cmvn && cmvn_opts="--cmvn $wave_data/${train_set}/global_cmvn" + # train.py will write $train_config to $dir/train.yaml with model input + # and output dimension, train.yaml will be used for inference or model + # export later + for ((i = 0; i < $num_gpus; ++i)); do + { + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + python wenet/bin/train.py --gpu $gpu_id \ + --config $finetune_config \ + --data_type raw \ + --symbol_table $dict \ + --bpe_model ${bpemodel}.model \ + --train_data $wave_data/$train_set_100h/data.list \ + --cv_data $wave_data/$dev_set/data.list \ + ${checkpoint:+--checkpoint $checkpoint} \ + ${enc_init:+--enc_init $enc_init} \ + 
--enc_init_mods $enc_init_mods \ + --model_dir $finetune_dir \ + --ddp.init_method $init_method \ + --ddp.world_size $num_gpus \ + --ddp.rank $i \ + --ddp.dist_backend $dist_backend \ + --num_workers 3 \ + $cmvn_opts \ + --pin_memory + } & + done + wait +fi + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + # Test model, please specify the model you want to test by --checkpoint + cmvn_opts= + $cmvn && cmvn_opts="--cmvn data/${train_set}/global_cmvn" + # TODO, Add model average here + mkdir -p $finetune_dir/test + if [ ${average_checkpoint} == true ]; then + decode_checkpoint=$finetune_dir/avg_${average_num}.pt + echo "do model average and final checkpoint is $decode_checkpoint" + python wenet/bin/average_model.py \ + --dst_model $decode_checkpoint \ + --src_path $finetune_dir \ + --num ${average_num} \ + --val_best + fi + # Specify decoding_chunk_size if it's a unified dynamic chunk trained model + # -1 for full chunk + decoding_chunk_size= + ctc_weight=0.5 + # Polling GPU id begin with index 0 + num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + idx=0 + for test in $recog_set; do + for mode in ${decode_modes}; do + { + { + test_dir=$finetune_dir/${test}_${mode} + mkdir -p $test_dir + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$idx+1]) + python wenet/bin/recognize.py --gpu $gpu_id \ + --mode $mode \ + --config $finetune_dir/train.yaml \ + --data_type raw \ + --dict $dict \ + --bpe_model ${bpemodel}.model \ + --test_data $wave_data/$test/data.list \ + --checkpoint $decode_checkpoint \ + --beam_size 10 \ + --batch_size 1 \ + --penalty 0.0 \ + --result_file $test_dir/text_bpe \ + --ctc_weight $ctc_weight \ + ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size} + + cut -f2- -d " " $test_dir/text_bpe > $test_dir/text_bpe_value_tmp + cut -f1 -d " " $test_dir/text_bpe > $test_dir/text_bpe_key_tmp + tools/spm_decode --model=${bpemodel}.model --input_format=piece \ + < $test_dir/text_bpe_value_tmp | sed -e "s/▁/ /g" > $test_dir/text_value_tmp + paste -d " " $test_dir/text_bpe_key_tmp $test_dir/text_value_tmp > $test_dir/text + + python tools/compute-wer.py --char=1 --v=1 \ + $wave_data/$test/text $test_dir/text > $test_dir/wer + } & + + ((idx+=1)) + if [ $idx -eq $num_gpus ]; then + idx=0 + fi + } + done + done + wait + +fi + +if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + # Export the best model you want + python wenet/bin/export_jit.py \ + --config $finetune_dir/train.yaml \ + --checkpoint $finetune_dir/avg_${average_num}.pt \ + --output_file $finetune_dir/final.zip +fi + +# Optionally, you can add LM and test it with runtime. +if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then + lm=data/local/lm + lexicon=data/local/dict/lexicon.txt + mkdir -p $lm + mkdir -p data/local/dict + + # 7.1 Download & format LM + which_lm=3-gram.pruned.1e-7.arpa.gz + if [ ! -e ${lm}/${which_lm} ]; then + wget http://www.openslr.org/resources/11/${which_lm} -P ${lm} + fi + echo "unzip lm($which_lm)..." + gunzip -k ${lm}/${which_lm} -c > ${lm}/lm.arpa + echo "Lm saved as ${lm}/lm.arpa" + + # 7.2 Prepare dict + unit_file=$dict + bpemodel=$bpemodel + # use $dir/words.txt (unit_file) and $dir/train_960_unigram5000 (bpemodel) + # if you download pretrained librispeech conformer model + cp $unit_file data/local/dict/units.txt + if [ ! -e ${lm}/librispeech-lexicon.txt ]; then + wget http://www.openslr.org/resources/11/librispeech-lexicon.txt -P ${lm} + fi + echo "build lexicon..." 
+ tools/fst/prepare_dict.py $unit_file ${lm}/librispeech-lexicon.txt \ + $lexicon $bpemodel.model + echo "lexicon saved as '$lexicon'" + + # 7.3 Build decoding TLG + tools/fst/compile_lexicon_token_fst.sh \ + data/local/dict data/local/tmp data/local/lang + tools/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1; + + # 7.4 Decoding with runtime + fst_dir=data/lang_test + for test in ${recog_set}; do + ./tools/decode.sh --nj 6 \ + --beam 10.0 --lattice_beam 5 --max_active 7000 --blank_skip_thresh 0.98 \ + --ctc_weight 0.5 --rescoring_weight 1.0 --acoustic_scale 1.2 \ + --fst_path $fst_dir/TLG.fst \ + data/$test/wav.scp data/$test/text $dir/final.zip $fst_dir/words.txt \ + $dir/lm_with_runtime_${test} + tail $dir/lm_with_runtime_${test}/wer + done +fi + From f82c4434d2a43e27861b968304f2f4e5541ab62c Mon Sep 17 00:00:00 2001 From: aydentang Date: Tue, 15 Mar 2022 20:25:45 +0800 Subject: [PATCH 03/19] update recipe --- .../ssl/local/data_prep_torchaudio.sh | 54 +++++++++++ .../ssl/local/download_and_untar.sh | 97 +++++++++++++++++++ examples/librispeech/ssl/path.sh | 8 ++ examples/librispeech/ssl/tools | 1 + examples/librispeech/ssl/wenet | 1 + 5 files changed, 161 insertions(+) create mode 100755 examples/librispeech/ssl/local/data_prep_torchaudio.sh create mode 100755 examples/librispeech/ssl/local/download_and_untar.sh create mode 100644 examples/librispeech/ssl/path.sh create mode 120000 examples/librispeech/ssl/tools create mode 120000 examples/librispeech/ssl/wenet diff --git a/examples/librispeech/ssl/local/data_prep_torchaudio.sh b/examples/librispeech/ssl/local/data_prep_torchaudio.sh new file mode 100755 index 000000000..c7dc1deb7 --- /dev/null +++ b/examples/librispeech/ssl/local/data_prep_torchaudio.sh @@ -0,0 +1,54 @@ +#!/bin/bash + +# Copyright 2014 Vassil Panayotov +# 2014 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /export/a15/vpanayotov/data/LibriSpeech/dev-clean data/dev-clean" + exit 1 +fi + +src=$1 +dst=$2 + +# all utterances are FLAC compressed +if ! which flac >&/dev/null; then + echo "Please install 'flac' on ALL worker nodes!" + exit 1 +fi + +mkdir -p $dst || exit 1 + +[ ! -d $src ] && echo "$0: no such directory $src" && exit 1 + +wav_scp=$dst/wav.scp; [[ -f "$wav_scp" ]] && rm $wav_scp +trans=$dst/text; [[ -f "$trans" ]] && rm $trans + +for reader_dir in $(find -L $src -mindepth 1 -maxdepth 1 -type d | sort); do + reader=$(basename $reader_dir) + if ! [ $reader -eq $reader ]; then # not integer. + echo "$0: unexpected subdirectory name $reader" + exit 1 + fi + + for chapter_dir in $(find -L $reader_dir/ -mindepth 1 -maxdepth 1 -type d | sort); do + chapter=$(basename $chapter_dir) + if ! [ "$chapter" -eq "$chapter" ]; then + echo "$0: unexpected chapter-subdirectory name $chapter" + exit 1 + fi + + find -L $chapter_dir/ -iname "*.flac" | sort | xargs -I% basename % .flac | \ + awk -v "dir=$chapter_dir" '{printf "%s %s/%s.flac\n", $0, dir, $0}' >>$wav_scp|| exit 1 + + chapter_trans=$chapter_dir/${reader}-${chapter}.trans.txt + [ ! 
-f $chapter_trans ] && echo "$0: expected file $chapter_trans to exist" && exit 1 + cat $chapter_trans >>$trans + done +done + +echo "$0: successfully prepared data in $dst" + +exit 0 diff --git a/examples/librispeech/ssl/local/download_and_untar.sh b/examples/librispeech/ssl/local/download_and_untar.sh new file mode 100755 index 000000000..cd32fb6b9 --- /dev/null +++ b/examples/librispeech/ssl/local/download_and_untar.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +# Copyright 2014 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +remove_archive=false + +if [ "$1" == --remove-archive ]; then + remove_archive=true + shift +fi + +if [ $# -ne 3 ]; then + echo "Usage: $0 [--remove-archive] " + echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean" + echo "With --remove-archive it will remove the archive after successfully un-tarring it." + echo " can be one of: dev-clean, test-clean, dev-other, test-other," + echo " train-clean-100, train-clean-360, train-other-500." + exit 1 +fi + +data=$1 +url=$2 +part=$3 + +if [ ! -d "$data" ]; then + echo "$0: no such directory $data" + exit 1 +fi + +part_ok=false +list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500" +for x in $list; do + if [ "$part" == $x ]; then part_ok=true; fi +done +if ! $part_ok; then + echo "$0: expected to be one of $list, but got '$part'" + exit 1 +fi + +if [ -z "$url" ]; then + echo "$0: empty URL base." + exit 1 +fi + +if [ -f $data/LibriSpeech/$part/.complete ]; then + echo "$0: data part $part was already successfully extracted, nothing to do." + exit 0 +fi + + +# sizes of the archive files in bytes. This is some older versions. +sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128" +# sizes_new is the archive file sizes of the final release. Some of these sizes are of +# things we probably won't download. +sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606" + +if [ -f $data/$part.tar.gz ]; then + size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}') + size_ok=false + for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done + if ! $size_ok; then + echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size" + echo "does not equal the size of one of the archives." + rm $data/$part.tar.gz + else + echo "$data/$part.tar.gz exists and appears to be complete." + fi +fi + +if [ ! -f $data/$part.tar.gz ]; then + if ! which wget >/dev/null; then + echo "$0: wget is not installed." + exit 1 + fi + full_url=$url/$part.tar.gz + echo "$0: downloading data from $full_url. This may take some time, please be patient." + + if ! wget -P $data --no-check-certificate $full_url; then + echo "$0: error executing wget $full_url" + exit 1 + fi +fi + +if ! tar -C $data -xvzf $data/$part.tar.gz; then + echo "$0: error un-tarring archive $data/$part.tar.gz" + exit 1 +fi + +touch $data/LibriSpeech/$part/.complete + +echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz" + +if $remove_archive; then + echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied." + rm $data/$part.tar.gz +fi diff --git a/examples/librispeech/ssl/path.sh b/examples/librispeech/ssl/path.sh new file mode 100644 index 000000000..5ddca76cc --- /dev/null +++ b/examples/librispeech/ssl/path.sh @@ -0,0 +1,8 @@ +export WENET_DIR=$PWD/../../.. 
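# Note: BUILD_DIR and OPENFST_PREFIX_DIR below are only needed for the
# runtime/TLG decoding stage of run.sh; pure python training and decoding
# only rely on the PYTHONPATH setting at the bottom of this file.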
+export BUILD_DIR=${WENET_DIR}/runtime/server/x86/build +export OPENFST_PREFIX_DIR=${BUILD_DIR}/../fc_base/openfst-subbuild/openfst-populate-prefix +export PATH=$PWD:${BUILD_DIR}:${BUILD_DIR}/kaldi:${OPENFST_PREFIX_DIR}/bin:$PATH + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=../../../:$PYTHONPATH diff --git a/examples/librispeech/ssl/tools b/examples/librispeech/ssl/tools new file mode 120000 index 000000000..c92f4172d --- /dev/null +++ b/examples/librispeech/ssl/tools @@ -0,0 +1 @@ +../../../tools \ No newline at end of file diff --git a/examples/librispeech/ssl/wenet b/examples/librispeech/ssl/wenet new file mode 120000 index 000000000..702de77db --- /dev/null +++ b/examples/librispeech/ssl/wenet @@ -0,0 +1 @@ +../../../wenet \ No newline at end of file From 139001e23a040ce6695f20935723e1ca12ab794a Mon Sep 17 00:00:00 2001 From: aydentang Date: Tue, 15 Mar 2022 20:27:42 +0800 Subject: [PATCH 04/19] update recipe --- examples/librispeech/ssl/run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/librispeech/ssl/run.sh b/examples/librispeech/ssl/run.sh index 5768d2b85..2d202f104 100644 --- a/examples/librispeech/ssl/run.sh +++ b/examples/librispeech/ssl/run.sh @@ -2,7 +2,7 @@ # Copyright 2021 Tencent Inc. (Author: Kai Tang). # Apach 2.0 -. ./path1.sh || exit 1; +. ./path.sh || exit 1; # Use this to control how many gpu you use, It's 1-gpu training if you specify # just 1gpu, otherwise it's is multiple gpu training based on DDP in pytorch From 1754fe124aa92ef19aeaf2e653487e1099131fa4 Mon Sep 17 00:00:00 2001 From: aydentang Date: Tue, 15 Mar 2022 20:33:11 +0800 Subject: [PATCH 05/19] update recipe --- examples/librispeech/ssl/run.sh | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/librispeech/ssl/run.sh b/examples/librispeech/ssl/run.sh index 2d202f104..e536d15f6 100644 --- a/examples/librispeech/ssl/run.sh +++ b/examples/librispeech/ssl/run.sh @@ -65,8 +65,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ### Task dependent. You have to make data the following preparation part by yourself. ### But you can utilize Kaldi recipes in most cases echo "stage 0: Data preparation" - #for part in dev-clean test-clean test-other train-clean-100 train-clean-360 train-other-500; do - for part in dev-other ; do + #for part in dev-clean dev-other test-clean test-other train-clean-100 train-clean-360 train-other-500; do # use underscore-separated names in data directories. 
local/data_prep_torchaudio.sh ${datadir}/LibriSpeech/${part} $wave_data/${part//-/_} done @@ -121,8 +120,7 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # Prepare wenet requried data echo "Prepare data, prepare requried format" - #for x in $dev_set ${recog_set} $train_set ; do - for x in $train_set_100h ; do + for x in $dev_set ${recog_set} $train_set $train_set_100h; do tools/make_raw_list.py $wave_data/$x/wav.scp $wave_data/$x/text \ $wave_data/$x/data.list done From f8b3dfdbd868ce00cde88a5f3738f67927af49a8 Mon Sep 17 00:00:00 2001 From: aydentang Date: Tue, 15 Mar 2022 21:40:13 +0800 Subject: [PATCH 06/19] fix recog bug --- wenet/bin/recognize.py | 13 ++++++++++++- wenet/wav2vec/wav2vec2_encoder.py | 2 +- wenet/wav2vec/wav2vec2_model.py | 4 ++-- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index f8c116c70..f4e20fcb9 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -26,6 +26,7 @@ from wenet.dataset.dataset import Dataset from wenet.transformer.asr_model import init_asr_model +from wenet.wav2vec.wav2vec2_model import init_wav2vec2_model from wenet.utils.checkpoint import load_checkpoint from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols from wenet.utils.config import override_config @@ -142,6 +143,13 @@ def main(): test_conf['batch_conf']['batch_size'] = args.batch_size non_lang_syms = read_non_lang_symbols(args.non_lang_syms) + pretrain=configs.get('pretrain',False) + if 'wav2vec_conf' in configs: + wav2vec_conf=configs['wav2vec_conf'] + wav2vec_conf['pretrain']=pretrain + else: + wav2vec_conf=None + test_dataset = Dataset(args.data_type, args.test_data, symbol_table, @@ -153,7 +161,10 @@ def main(): test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) # Init asr model from configs - model = init_asr_model(configs) + if wav2vec_conf: + model=init_wav2vec2_model(configs) + else: + model = init_asr_model(configs) # Load dict char_dict = {v: k for k, v in symbol_table.items()} diff --git a/wenet/wav2vec/wav2vec2_encoder.py b/wenet/wav2vec/wav2vec2_encoder.py index 2325883f5..573f767ce 100644 --- a/wenet/wav2vec/wav2vec2_encoder.py +++ b/wenet/wav2vec/wav2vec2_encoder.py @@ -347,7 +347,7 @@ def forward( xs_lens: torch.Tensor, decoding_chunk_size: int = 0, num_decoding_left_chunks: int = -1, - features_only: bool=False, + features_only: bool=True, ) -> Tuple[torch.Tensor, torch.Tensor]: """Embed positions in tensor. 
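A minimal sketch (editor's note, not part of the patch) of the calling
convention these hunks establish: the encoder now defaults to
features_only=True and its forward returns (encoder_out, encoder_mask,
ext_res), so recognition paths drop the wav2vec extras while pretraining
requests them explicitly. Names follow the hunks above and below; the real
signature is in wenet/wav2vec/wav2vec2_encoder.py.

    # illustrative only: fine-tuning / decoding ignores the extra outputs
    encoder_out, encoder_mask, _ = self.encoder(speech, speech_lengths)
    # pretraining keeps the extra wav2vec results (quantizer/contrastive terms)
    encoder_out, encoder_mask, ext_res = self.encoder(
        speech, speech_lengths, features_only=False)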
diff --git a/wenet/wav2vec/wav2vec2_model.py b/wenet/wav2vec/wav2vec2_model.py index 5e3be90cc..4aac869b0 100644 --- a/wenet/wav2vec/wav2vec2_model.py +++ b/wenet/wav2vec/wav2vec2_model.py @@ -110,7 +110,7 @@ def forward( self.num_updates=num_updates if self.pretrain: - encoder_out, encoder_mask,ext_res = self.encoder(speech, speech_lengths,False) + encoder_out, encoder_mask,ext_res = self.encoder(speech, speech_lengths, features_only=False) encoder_out_lens = encoder_mask.squeeze(1).sum(1) else: if self.encoder_grad_mult>0 and self.num_updates>=self.freeze_finetune_updates: @@ -210,7 +210,7 @@ def _forward_encoder( num_decoding_left_chunks=num_decoding_left_chunks ) # (B, maxlen, encoder_dim) else: - encoder_out, encoder_mask = self.encoder( + encoder_out, encoder_mask,_ = self.encoder( speech, speech_lengths, decoding_chunk_size=decoding_chunk_size, From b9b2f1836123f6d4e072f04842dd8b438b408e26 Mon Sep 17 00:00:00 2001 From: aydentang Date: Wed, 16 Mar 2022 13:02:54 +0800 Subject: [PATCH 07/19] w2v mask --- wenet/utils/mask.py | 128 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/wenet/utils/mask.py b/wenet/utils/mask.py index 8dc6e28d7..f81ba9d31 100644 --- a/wenet/utils/mask.py +++ b/wenet/utils/mask.py @@ -4,6 +4,8 @@ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import torch +import numpy as np +from typing import Optional, Tuple ''' def subsequent_mask( @@ -285,3 +287,129 @@ def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor, beam_size = pred.size(-1) finished = flag.repeat([1, beam_size]) return pred.masked_fill_(finished, eos) + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask From 422c89e5737b035c33102b300c627580e2880d2c Mon Sep 17 00:00:00 2001 From: aydentang Date: Tue, 22 Mar 2022 19:21:52 +0800 Subject: [PATCH 08/19] fix encoder --- wenet/wav2vec/wav2vec2_encoder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/wenet/wav2vec/wav2vec2_encoder.py b/wenet/wav2vec/wav2vec2_encoder.py index 573f767ce..0f1239641 100644 --- a/wenet/wav2vec/wav2vec2_encoder.py +++ b/wenet/wav2vec/wav2vec2_encoder.py @@ -813,7 +813,6 @@ def __init__( *positionwise_layer_args) if macaron_style else None, convolution_layer( 
*convolution_layer_args) if use_cnn_module else None, - cnn_module_before, dropout_rate, normalize_before, concat_after, From e59b33cbfc88294bfafe28844e1781d6b24f7f1d Mon Sep 17 00:00:00 2001 From: aydentang Date: Wed, 30 Mar 2022 17:03:09 +0800 Subject: [PATCH 09/19] readme --- examples/librispeech/ssl/README.md | 33 ++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 examples/librispeech/ssl/README.md diff --git a/examples/librispeech/ssl/README.md b/examples/librispeech/ssl/README.md new file mode 100644 index 000000000..f144eccdd --- /dev/null +++ b/examples/librispeech/ssl/README.md @@ -0,0 +1,33 @@ +# Performance Record + +## Conformer Result (Base 12layer) + +pretrain Conformer +* pretrain config: conf/pretrain/train_conformer_pretrain_w2v.yaml +* finetune config: conf/finetune/train_conformer_100h.yaml +* beam: 10 +* num of gpu: 8 +* num of averaged model: 20 +* ctc weight (used for attention rescoring): 0.5 +* pretrain 90 epochs ,finetune 80 epochs + +test set results trained with 100 hours train-clean set + +## wav2vec2.0 Results + +test clean +| decoding mode | full | +|--------------------------|------| +| ctc prefix beam search | 5.77 | +| attention rescoring | 5.30 | + +test other +| decoding mode | full | +|--------------------------|------| +| ctc prefix beam search | 12.73 | +| attention rescoring | 12.14 | + + +## data2vec Results + +going \ No newline at end of file From 12f69731fd3e2829fe52627baf65ff767dd7fd48 Mon Sep 17 00:00:00 2001 From: aydentang Date: Wed, 30 Mar 2022 19:31:50 +0800 Subject: [PATCH 10/19] data2vec training --- .../train_conformer_100h_data2vec.yaml | 89 +++ .../train_conformer_pretrain_data2vec.yaml | 91 +++ examples/librispeech/ssl/run_data2vec.sh | 332 ++++++++ wenet/bin/recognize.py | 11 +- wenet/bin/train.py | 10 +- wenet/data2vec/data2vec_encoder.py | 617 ++++++++++++++ wenet/data2vec/data2vec_model.py | 756 ++++++++++++++++++ wenet/data2vec/ema.py | 205 +++++ wenet/transformer/encoder_layer.py | 11 +- wenet/utils/executor.py | 32 +- 10 files changed, 2141 insertions(+), 13 deletions(-) create mode 100644 examples/librispeech/ssl/conf/finetune/train_conformer_100h_data2vec.yaml create mode 100644 examples/librispeech/ssl/conf/pretrain/train_conformer_pretrain_data2vec.yaml create mode 100644 examples/librispeech/ssl/run_data2vec.sh create mode 100644 wenet/data2vec/data2vec_encoder.py create mode 100644 wenet/data2vec/data2vec_model.py create mode 100644 wenet/data2vec/ema.py diff --git a/examples/librispeech/ssl/conf/finetune/train_conformer_100h_data2vec.yaml b/examples/librispeech/ssl/conf/finetune/train_conformer_100h_data2vec.yaml new file mode 100644 index 000000000..9e2c38694 --- /dev/null +++ b/examples/librispeech/ssl/conf/finetune/train_conformer_100h_data2vec.yaml @@ -0,0 +1,89 @@ +# network architecture +# encoder related +encoder: conformer +encoder_conf: + output_size: 512 # dimension of attention + attention_heads: 8 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.0 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + cnn_module_kernel: 31 + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 2 + linear_units: 512 + num_blocks: 1 + dropout_rate: 
0.1 + positional_dropout_rate: 0.0 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.7 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +# use raw_wav or kaldi feature +raw_wav: true + +# dataset related +dataset_conf: + filter_conf: + max_length: 2000 + min_length: 50 + token_max_length: 400 + token_min_length: 1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 3 + num_f_mask: 2 + max_t: 50 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 10 + +pretrain: False +data2vec_conf: + pretrain: False + intermediate_layers: [4,5,6,7,8,9,10,11] + ema_anneal_end_step: 30000 + mask: False + mask_prob: 0.65 + +grad_clip: 5 +accum_grad: 1 +max_epoch: 120 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + warmup_steps: 15000 diff --git a/examples/librispeech/ssl/conf/pretrain/train_conformer_pretrain_data2vec.yaml b/examples/librispeech/ssl/conf/pretrain/train_conformer_pretrain_data2vec.yaml new file mode 100644 index 000000000..5744449bd --- /dev/null +++ b/examples/librispeech/ssl/conf/pretrain/train_conformer_pretrain_data2vec.yaml @@ -0,0 +1,91 @@ +# network architecture +# encoder related +encoder: conformer +encoder_conf: + output_size: 512 # dimension of attention + attention_heads: 8 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.0 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + cnn_module_kernel: 31 + cnn_module_norm: 'layer_norm' + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 8 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.0 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 1.0 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +# use raw_wav or kaldi feature +raw_wav: true + +# dataset related +dataset_conf: + filter_conf: + max_length: 2000 + min_length: 50 + token_max_length: 400 + token_min_length: 1 + resample_conf: + resample_rate: 16000 + speed_perturb: false + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: false + spec_aug_conf: + num_t_mask: 3 + num_f_mask: 2 + max_t: 50 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'dynamic' # static or dynamic + max_frames_in_batch: 20000 + batch_size: 20 + +pretrain: True +data2vec_conf: + pretrain: True + intermediate_layers: [4,5,6,7,8,9,10,11] + ema_anneal_end_step: 30000 + mask: True + mask_prob: 0.65 + +grad_clip: 5 +accum_grad: 4 +max_epoch: 90 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + 
warmup_steps: 25000 diff --git a/examples/librispeech/ssl/run_data2vec.sh b/examples/librispeech/ssl/run_data2vec.sh new file mode 100644 index 000000000..dba8481c6 --- /dev/null +++ b/examples/librispeech/ssl/run_data2vec.sh @@ -0,0 +1,332 @@ +#!/bin/bash +# Copyright 2022 Tencent Inc. (Author: Kai Tang). +# Apach 2.0 + +. ./path.sh || exit 1; + +# Use this to control how many gpu you use, It's 1-gpu training if you specify +# just 1gpu, otherwise it's is multiple gpu training based on DDP in pytorch +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +#export CUDA_VISIBLE_DEVICES="0" +stage=0 # start from 0 if you need to start from data preparation +stop_stage=5 +# data +data_url=www.openslr.org/resources/12 +# use your own data path +datadir=/apdcephfs/share_1157259/train/asr/egs/english2/dump +# wav data dir +wave_data=data +# Optional train_config +# 1. conf/train_transformer_large.yaml: Standard transformer +pretrain_config=conf/pretrain/train_conformer_pretrain_data2vec.yaml +# finetune_config +finetune_config=conf/finetune/train_conformer_100h_data2vec.yaml +checkpoint= +cmvn=true +do_delta=false + +pretrain_dir=exp/pretrain_d2v_960h +finetune_dir=exp/finetune_d2v_spec_aug_100h + +# use average_checkpoint will get better result +average_checkpoint=true +decode_checkpoint=$dir/final.pt +# maybe you can try to adjust it if you can not get close results as README.md +average_num=20 +decode_modes="attention_rescoring ctc_greedy_search ctc_prefix_beam_search attention" + +. tools/parse_options.sh || exit 1; + +# bpemode (unigram or bpe) +nbpe=5000 +bpemode=unigram + +set -e +set -u +set -o pipefail + +train_set=train_960 +train_set_100h=train_clean_100 +dev_set=dev +recog_set="test_clean test_other dev_clean dev_other" + + +# pretrained w2v-conformer encoder +enc_init=$pretrain_dir/final.pt +enc_init_mods='encoder.encoders.0,encoder.encoders.1,encoder.encoders.2,encoder.encoders.3,encoder.encoders.4,encoder.encoders.5,encoder.encoders.6,encoder.encoders.7,encoder.encoders.8,encoder.encoders.9,encoder.encoders.10,encoder.encoders.11,encoder.embed' + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + echo "stage -1: Data Download" + for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do + local/download_and_untar.sh ${datadir} ${data_url} ${part} + done +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + ### Task dependent. You have to make data the following preparation part by yourself. + ### But you can utilize Kaldi recipes in most cases + echo "stage 0: Data preparation" + for part in dev-clean dev-other test-clean test-other train-clean-100 train-clean-360 train-other-500; do + # use underscore-separated names in data directories. + local/data_prep_torchaudio.sh ${datadir}/LibriSpeech/${part} $wave_data/${part//-/_} + done +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + ### Task dependent. You have to design training and dev sets by yourself. 
+ ### But you can utilize Kaldi recipes in most cases + echo "stage 1: Feature Generation" + mkdir -p $wave_data/train_960 + # merge total training data + for set in train_clean_100 train_clean_360 train_other_500; do + for f in `ls $wave_data/$set`; do + cat $wave_data/$set/$f >> $wave_data/train_960/$f + done + done + mkdir -p $wave_data/dev + # merge total dev data + for set in dev_clean dev_other; do + for f in `ls $wave_data/$set`; do + cat $wave_data/$set/$f >> $wave_data/$dev_set/$f + done + done + + tools/compute_cmvn_stats.py --num_workers 16 --train_config $train_config \ + --in_scp $wave_data/$train_set/wav.scp \ + --out_cmvn $wave_data/$train_set/global_cmvn + +fi + + +dict=$wave_data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt +bpemodel=$wave_data/lang_char/${train_set}_${bpemode}${nbpe} +echo "dictionary: ${dict}" +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + ### Task dependent. You have to check non-linguistic symbols used in the corpus. + echo "stage 2: Dictionary and Json Data Preparation" + mkdir -p data/lang_char/ + + echo " 0" > ${dict} # 0 will be used for "blank" in CTC + echo " 1" >> ${dict} # must be 1 + + # we borrowed these code and scripts which are related bpe from ESPnet. + cut -f 2- -d" " $wave_data/${train_set}/text > $wave_data/lang_char/input.txt + tools/spm_train --input=$wave_data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 + tools/spm_encode --model=${bpemodel}.model --output_format=piece < $wave_data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+1}' >> ${dict} + num_token=$(cat $dict | wc -l) + echo " $num_token" >> $dict # + wc -l ${dict} +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # Prepare wenet requried data + echo "Prepare data, prepare requried format" + for x in $dev_set ${recog_set} $train_set ; do + tools/make_raw_list.py $wave_data/$x/wav.scp $wave_data/$x/text \ + $wave_data/$x/data.list + done + +fi + + +#data2vec pretrain +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # Training + mkdir -p $pretrain_dir + INIT_FILE=$pretrain_dir/ddp_init + rm -f $INIT_FILE # delete old one before starting + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + # Use "nccl" if it works, otherwise use "gloo" + dist_backend="nccl" + cmvn_opts= + $cmvn && cmvn_opts="--cmvn $wave_data/${train_set}/global_cmvn" + # train.py will write $train_config to $dir/train.yaml with model input + # and output dimension, train.yaml will be used for inference or model + # export later + for ((i = 0; i < $num_gpus; ++i)); do + { + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + python wenet/bin/train.py --gpu $gpu_id \ + --config $pretrain_config \ + --data_type raw \ + --symbol_table $dict \ + --bpe_model ${bpemodel}.model \ + --train_data $wave_data/$train_set/data.list \ + --cv_data $wave_data/$dev_set/data.list \ + ${checkpoint:+--checkpoint $checkpoint} \ + --model_dir $pretrain_dir \ + --ddp.init_method $init_method \ + --ddp.world_size $num_gpus \ + --ddp.rank $i \ + --ddp.dist_backend $dist_backend \ + --num_workers 3 \ + $cmvn_opts + }& + done + wait +fi + +#finetune 100h using data2vec model +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # Training + mkdir -p $finetune_dir + INIT_FILE=$finetune_dir/ddp_init + rm -f $INIT_FILE # delete old one before starting + init_method=file://$(readlink 
-f $INIT_FILE) + echo "$0: init method is $init_method" + num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + # Use "nccl" if it works, otherwise use "gloo" + dist_backend="nccl" + cmvn_opts= + $cmvn && cmvn_opts="--cmvn $wave_data/${train_set}/global_cmvn" + # train.py will write $train_config to $dir/train.yaml with model input + # and output dimension, train.yaml will be used for inference or model + # export later + for ((i = 0; i < $num_gpus; ++i)); do + { + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + python wenet/bin/train.py --gpu $gpu_id \ + --config $finetune_config \ + --data_type raw \ + --symbol_table $dict \ + --bpe_model ${bpemodel}.model \ + --train_data $wave_data/$train_set_100h/data.list \ + --cv_data $wave_data/$dev_set/data.list \ + ${checkpoint:+--checkpoint $checkpoint} \ + ${enc_init:+--enc_init $enc_init} \ + --enc_init_mods $enc_init_mods \ + --model_dir $finetune_dir \ + --ddp.init_method $init_method \ + --ddp.world_size $num_gpus \ + --ddp.rank $i \ + --ddp.dist_backend $dist_backend \ + --num_workers 3 \ + $cmvn_opts \ + } & + done + wait +fi + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + # Test model, please specify the model you want to test by --checkpoint + cmvn_opts= + $cmvn && cmvn_opts="--cmvn data/${train_set}/global_cmvn" + # TODO, Add model average here + mkdir -p $finetune_dir/test + if [ ${average_checkpoint} == true ]; then + decode_checkpoint=$finetune_dir/avg_${average_num}.pt + echo "do model average and final checkpoint is $decode_checkpoint" + python wenet/bin/average_model.py \ + --dst_model $decode_checkpoint \ + --src_path $finetune_dir \ + --num ${average_num} \ + --val_best + fi + # Specify decoding_chunk_size if it's a unified dynamic chunk trained model + # -1 for full chunk + decoding_chunk_size= + ctc_weight=0.5 + # Polling GPU id begin with index 0 + num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + idx=0 + for test in $recog_set; do + for mode in ${decode_modes}; do + { + { + test_dir=$finetune_dir/${test}_${mode} + mkdir -p $test_dir + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$idx+1]) + python wenet/bin/recognize.py --gpu $gpu_id \ + --mode $mode \ + --config $finetune_dir/train.yaml \ + --data_type raw \ + --dict $dict \ + --bpe_model ${bpemodel}.model \ + --test_data $wave_data/$test/data.list \ + --checkpoint $decode_checkpoint \ + --beam_size 10 \ + --batch_size 1 \ + --penalty 0.0 \ + --result_file $test_dir/text_bpe \ + --ctc_weight $ctc_weight \ + ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size} + + cut -f2- -d " " $test_dir/text_bpe > $test_dir/text_bpe_value_tmp + cut -f1 -d " " $test_dir/text_bpe > $test_dir/text_bpe_key_tmp + tools/spm_decode --model=${bpemodel}.model --input_format=piece \ + < $test_dir/text_bpe_value_tmp | sed -e "s/▁/ /g" > $test_dir/text_value_tmp + paste -d " " $test_dir/text_bpe_key_tmp $test_dir/text_value_tmp > $test_dir/text + + python tools/compute-wer.py --char=1 --v=1 \ + $wave_data/$test/text $test_dir/text > $test_dir/wer + } & + + ((idx+=1)) + if [ $idx -eq $num_gpus ]; then + idx=0 + fi + } + done + done + wait + +fi + +if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + # Export the best model you want + python wenet/bin/export_jit.py \ + --config $finetune_dir/train.yaml \ + --checkpoint $finetune_dir/avg_${average_num}.pt \ + --output_file $finetune_dir/final.zip +fi + +# Optionally, you can add LM and test it with runtime. 
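# Note (editor's assumption): this stage expects the compiled x86 runtime
# referenced in path.sh and a jit-exported final.zip (see stage 7 above).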
+if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then + lm=data/local/lm + lexicon=data/local/dict/lexicon.txt + mkdir -p $lm + mkdir -p data/local/dict + + # 7.1 Download & format LM + which_lm=3-gram.pruned.1e-7.arpa.gz + if [ ! -e ${lm}/${which_lm} ]; then + wget http://www.openslr.org/resources/11/${which_lm} -P ${lm} + fi + echo "unzip lm($which_lm)..." + gunzip -k ${lm}/${which_lm} -c > ${lm}/lm.arpa + echo "Lm saved as ${lm}/lm.arpa" + + # 7.2 Prepare dict + unit_file=$dict + bpemodel=$bpemodel + # use $dir/words.txt (unit_file) and $dir/train_960_unigram5000 (bpemodel) + # if you download pretrained librispeech conformer model + cp $unit_file data/local/dict/units.txt + if [ ! -e ${lm}/librispeech-lexicon.txt ]; then + wget http://www.openslr.org/resources/11/librispeech-lexicon.txt -P ${lm} + fi + echo "build lexicon..." + tools/fst/prepare_dict.py $unit_file ${lm}/librispeech-lexicon.txt \ + $lexicon $bpemodel.model + echo "lexicon saved as '$lexicon'" + + # 7.3 Build decoding TLG + tools/fst/compile_lexicon_token_fst.sh \ + data/local/dict data/local/tmp data/local/lang + tools/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1; + + # 7.4 Decoding with runtime + fst_dir=data/lang_test + for test in ${recog_set}; do + ./tools/decode.sh --nj 6 \ + --beam 10.0 --lattice_beam 5 --max_active 7000 --blank_skip_thresh 0.98 \ + --ctc_weight 0.5 --rescoring_weight 1.0 --acoustic_scale 1.2 \ + --fst_path $fst_dir/TLG.fst \ + data/$test/wav.scp data/$test/text $dir/final.zip $fst_dir/words.txt \ + $dir/lm_with_runtime_${test} + tail $dir/lm_with_runtime_${test}/wer + done +fi + diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index f4e20fcb9..519e76530 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -27,6 +27,7 @@ from wenet.dataset.dataset import Dataset from wenet.transformer.asr_model import init_asr_model from wenet.wav2vec.wav2vec2_model import init_wav2vec2_model +from wenet.data2vec.data2vec_model import init_data2vec_model from wenet.utils.checkpoint import load_checkpoint from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols from wenet.utils.config import override_config @@ -149,6 +150,11 @@ def main(): wav2vec_conf['pretrain']=pretrain else: wav2vec_conf=None + if 'data2vec_conf' in configs: + data2vec_conf=configs['data2vec_conf'] + data2vec_conf['pretrain']=pretrain + else: + data2vec_conf=None test_dataset = Dataset(args.data_type, args.test_data, @@ -161,11 +167,12 @@ def main(): test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) # Init asr model from configs - if wav2vec_conf: + if wav2vec_conf: model=init_wav2vec2_model(configs) + elif data2vec_conf: + model=init_data2vec_model(configs) else: model = init_asr_model(configs) - # Load dict char_dict = {v: k for k, v in symbol_table.items()} eos = len(char_dict) - 1 diff --git a/wenet/bin/train.py b/wenet/bin/train.py index d6c5aa7c6..06749284d 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -29,6 +29,7 @@ from wenet.dataset.dataset import Dataset from wenet.transformer.asr_model import init_asr_model from wenet.wav2vec.wav2vec2_model import init_wav2vec2_model +from wenet.data2vec.data2vec_model import init_data2vec_model from wenet.utils.checkpoint import (load_checkpoint, save_checkpoint, load_trained_modules) from wenet.utils.executor import Executor @@ -185,7 +186,12 @@ def main(): wav2vec_conf=configs['wav2vec_conf'] wav2vec_conf['pretrain']=pretrain else: - wav2vec_conf=None + wav2vec_conf=None + if 
'data2vec_conf' in configs: + data2vec_conf=configs['data2vec_conf'] + data2vec_conf['pretrain']=pretrain + else: + data2vec_conf=None # Save configs to model_dir/train.yaml for inference and export configs['input_dim'] = input_dim @@ -201,6 +207,8 @@ def main(): # Init asr model from configs if wav2vec_conf: model=init_wav2vec2_model(configs) + elif data2vec_conf: + model=init_data2vec_model(configs) else: model = init_asr_model(configs) print(model) diff --git a/wenet/data2vec/data2vec_encoder.py b/wenet/data2vec/data2vec_encoder.py new file mode 100644 index 000000000..a7adf1347 --- /dev/null +++ b/wenet/data2vec/data2vec_encoder.py @@ -0,0 +1,617 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +# The origial Data2vec work is in: +# Paper: https://arxiv.org/pdf/2006.11477.pdf +# Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/data2vec + +"""Encoder definition.""" +from typing import Tuple, List, Optional +import logging +import torch +import math +from typeguard import check_argument_types +import torch.nn.functional as F +import torch.distributed as dist +import numpy as np + +from wenet.transformer.attention import MultiHeadedAttention +from wenet.transformer.attention import RelPositionMultiHeadedAttention +from wenet.transformer.convolution import ConvolutionModule +from wenet.transformer.embedding import PositionalEncoding +from wenet.transformer.embedding import RelPositionalEncoding +from wenet.transformer.encoder_layer import TransformerEncoderLayer +from wenet.transformer.encoder_layer import ConformerEncoderLayer +from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward +from wenet.transformer.subsampling import Conv2dSubsampling4 +from wenet.transformer.subsampling import Conv2dSubsampling6 +from wenet.transformer.subsampling import Conv2dSubsampling8 +from wenet.transformer.subsampling import LinearNoSubsampling +from wenet.utils.common import get_activation +from wenet.utils.mask import make_pad_mask +from wenet.utils.mask import add_optional_chunk_mask +from wenet.utils.mask import compute_mask_indices +from wenet.data2vec.ema import EMA + +def get_annealed_rate(start, end, curr_step, total_steps): + r = end - start + pct_remaining = 1 - curr_step / total_steps + return end - r * pct_remaining + +class Data2vecBaseEncoder(torch.nn.Module): + def __init__( + self, + data2vec_conf:dict, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + ema: EMA = None + ): + + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "rel_pos": + pos_enc_class = RelPositionalEncoding + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + if input_layer == "linear": + subsampling_class = LinearNoSubsampling + elif input_layer == "conv2d": + subsampling_class = Conv2dSubsampling4 + elif input_layer == "conv2d6": + subsampling_class = Conv2dSubsampling6 + elif input_layer == "conv2d8": + subsampling_class = Conv2dSubsampling8 + else: + raise ValueError("unknown 
input_layer: " + input_layer) + + self.global_cmvn = global_cmvn + self.embed = subsampling_class( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + + self.normalize_before = normalize_before + self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-12) + self.feature_norm = torch.nn.LayerNorm(output_size, eps=1e-12) + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + + self.embed_dim=output_size + final_dim=output_size + self.encoder_embed_dim=output_size + + self.mask_prob = data2vec_conf.get('mask_prob', 0.65) + self.mask_selection = "static" + self.mask_other = 0 + self.mask_length = 10 + self.no_mask_overlap =False + self.mask_min_space = 1 + + self.mask_channel_prob = data2vec_conf.get('mask_channel_prob', 0.0) + self.mask_channel_selection = "static" + self.mask_channel_other = 0 + self.mask_channel_length = data2vec_conf.get('mask_channel_length', 10) + self.no_mask_channel_overlap = False + self.mask_channel_min_space = 1 + + self.ema = None + self.embed_dim=self.encoder_embed_dim + self.ema_decay=data2vec_conf.get("ema_decay",0.999) + self.ema_end_decay=data2vec_conf.get("ema_end_decay",0.9999) + self.ema_transformer_only=data2vec_conf.get("ema_transformer_only",True) + self.ema_layers_only=data2vec_conf.get("ema_layers_only",True) + self.ema_anneal_end_step=data2vec_conf.get("ema_anneal_end_step",30000) + + self.min_target_var=0.1 + self.min_pred_var=0.01 + + self.layer_norm_target_layer: bool = False + self.instance_norm_target_layer: bool = True + self.instance_norm_targets: bool = False + self.layer_norm_targets: bool = False + self.batch_norm_target_layer: bool = False + self.group_norm_target_layer: bool = False + + self.average_top_k_layers = data2vec_conf.get("average_top_k_layers",8) + self.loss_beta = data2vec_conf.get("loss_beta",0) + self.loss_scale = data2vec_conf.get("loss_scale",None) + + self.project_final=data2vec_conf.get('project_final', False) + self.intermediate_layers=data2vec_conf.get('intermediate_layers',None) + + self.mask=data2vec_conf.get('mask', True) + + self.pretrain=data2vec_conf.get('pretrain',True) + + self.final_proj = torch.nn.Linear(output_size, final_dim) + + self.target_glu=None + self.logit_temp=0.1 + + self.mask_emb = torch.nn.Parameter( + torch.FloatTensor(self.encoder_embed_dim).uniform_() + ) + self.feature_grad_mult=data2vec_conf.get('feature_grad_mult',1.0) + + self.num_updates=0 + + + def output_size(self) -> int: + return self._output_size + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def make_ema_teacher(self): + ema_config = { + "ema_decay": self.ema_decay, + 
"ema_fp32" :True, + "store_ema" : False, + "ema_start_update" : 0 , + "ema_update_freq" : 1, + } + skip_keys = set() + if self.ema_layers_only: + self.ema_transformer_only = True + skip_keys.add("embed.") + + self.ema = EMA( + self.encoders, + ema_config, + skip_keys=skip_keys, + ) + + def set_num_updates(self, num_updates): + + if self.ema is None and self.final_proj is not None: + logging.info(f"making ema teacher") + self.make_ema_teacher() + elif self.training and self.ema is not None and num_updates!= self.num_updates : + if self.ema_decay != self.ema_end_decay: + if num_updates >= self.ema_anneal_end_step: + decay = self.ema_end_decay + else: + decay = get_annealed_rate( + self.ema_decay, + self.ema_end_decay, + num_updates, + self.ema_anneal_end_step, + ) + self.ema._set_decay(decay) + if self.ema.get_decay() < 1: + self.ema.step(self.encoders) + + self.num_updates = num_updates + + def state_dict(self, destination=None, prefix="", keep_vars=False): + state = super().state_dict(destination, prefix, keep_vars) + logging.info(f"state_dict") + # if self.ema is not None: + # state[prefix + "_ema"] = self.ema.fp32_params + + return state + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + logging.info(f"_load_from_state_dict") + if self.ema is not None: + logging.info(f"_load_from_state_dict ema!") + k = prefix + "_ema" + assert k in state_dict + logging.info(k) + self.ema.restore(state_dict[k], True) + del state_dict[k] + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + def extract_features( + self, + x: torch.Tensor, + chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + ): + layer_idx=0 + inter_idx=0 + intermediate_outputs = [] + for layer in self.ema.model: + x, chunk_masks,_,layer_out = layer(x, chunk_masks, pos_emb) + encoder_output = x + if ( + self.intermediate_layers is not None + and inter_idx Tuple[torch.Tensor, torch.Tensor]: + + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + + if self.feature_grad_mult > 0: + features, pos_emb, masks = self.embed(xs, masks) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features, pos_emb, masks = self.embed(xs, masks) + + chunk_masks=masks + #L2 loss pen + features_pen = features.float().pow(2).mean() + + features = self.feature_norm(features) + unmasked_features = features.clone() + input_features=None + + if self.mask and self.pretrain: + + x, mask_indices = self.apply_mask(features, None) + input_features=x.clone() + if mask_indices is not None: + y = unmasked_features[mask_indices].view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) + else: + y = unmasked_features + elif self.mask and self.training: + x, mask_indices = self.apply_mask(features, None) + if mask_indices is not None: + y = unmasked_features[mask_indices].view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) + else: + y = unmasked_features + x = features + y = unmasked_features + mask_indices = None + + for layer in self.encoders: + x, chunk_masks, _ , _ = layer(x, chunk_masks, pos_emb) + + if self.normalize_before: + x = self.after_norm(x) + ext_result={"outputs":x} + + if features_only: + return x, masks,ext_result + + with torch.no_grad(): + self.ema.model.eval() + + if self.ema_transformer_only: + y, layer_results = self.extract_features( + unmasked_features, + chunk_masks, + pos_emb, + ) + y = { + "x": 
y, + "layer_results": layer_results, + } + + target_layer_results = [l[1] for l in y["layer_results"]] + + permuted = False + if self.instance_norm_target_layer or self.batch_norm_target_layer: + target_layer_results = [ + tl.permute(0, 2, 1) for tl in target_layer_results # TBC -> BCT + ] + permuted = True + + if self.batch_norm_target_layer: + target_layer_results = [ + F.batch_norm( + tl.float(), running_mean=None, running_var=None, training=True + ) + for tl in target_layer_results + ] + + if self.instance_norm_target_layer: + target_layer_results = [ + F.instance_norm(tl.float()) for tl in target_layer_results + ] + + if permuted: + target_layer_results = [ + tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC + ] + + if self.group_norm_target_layer: + target_layer_results = [ + F.layer_norm(tl.float(), tl.shape[-2:]) + for tl in target_layer_results + ] + + if self.layer_norm_target_layer: + target_layer_results = [ + F.layer_norm(tl.float(), tl.shape[-1:]) + for tl in target_layer_results + ] + + y = sum(target_layer_results) / len(target_layer_results) + + if self.layer_norm_targets: + y = F.layer_norm(y.float(), y.shape[-1:]) + + if self.instance_norm_targets: + y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2) + + # if not permuted: + # y = y.transpose(0, 1) + + y = y[mask_indices] + + x = x[mask_indices] + x = self.final_proj(x) + + sz = x.size(-1) + + if self.loss_beta == 0: + loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1) + else: + loss = F.smooth_l1_loss( + x.float(), y.float(), reduction="none", beta=self.loss_beta + ).sum(dim=-1) + if self.loss_scale is not None: + scale = self.loss_scale + else: + scale = 1 / math.sqrt(sz) + + ext_result["losses_reg"] = loss.sum() * scale + + if "sample_size" not in ext_result: + ext_result["sample_size"] = loss.numel() + + #bug: uio mode trainging will be hold on + + # with torch.no_grad(): + # ext_result["target_var"] = self.compute_var(y) + # ext_result["pred_var"] = self.compute_var(x.float()) + + # if self.num_updates > 5000 and ext_result["target_var"] < self.min_target_var: + # logging.warning( + # f"target var is {ext_result['target_var'].item()} < {self.min_target_var}, exiting" + # ) + # # raise Exception( + # # f"target var is {ext_result['target_var'].item()} < {self.min_target_var}, exiting" + # # ) + # if self.num_updates > 5000 and ext_result["pred_var"] < self.min_pred_var: + # logging.warning( + # f"pred var is {ext_result['pred_var'].item()} < {self.min_pred_var}, exiting" + # ) + # # raise Exception( + # # f"pred var is {ext_result['pred_var'].item()} < {self.min_pred_var}, exiting" + # # ) + + # if self.ema is not None: + # ext_result["ema_decay"] = self.ema.get_decay() * 1000 + + return x,masks,ext_result + + @staticmethod + def compute_var(y): + y = y.view(-1, y.size(-1)) + if dist.is_initialized(): + zc = torch.tensor(y.size(0)).cuda() + zs = y.sum(dim=0) + zss = (y ** 2).sum(dim=0) + + dist.all_reduce(zc) + dist.all_reduce(zs) + dist.all_reduce(zss) + + var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1)) + return torch.sqrt(var + 1e-6).mean() + else: + return torch.sqrt(y.var(dim=0) + 1e-6).mean() + + def forward_mask( + self, + xs: torch.Tensor, + masks: torch.Tensor, + decoding_chunk_size: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # def forward( + # self, + # xs: torch.Tensor, + # xs_lens: torch.Tensor, + # ) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed positions in tensor. 
+ + Args: + xs: padded input tensor (B, L, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk, it's + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + Returns: + encoder output tensor, lens and mask + """ + # batch, max_len = torch.tensor(xs.shape[:2]).tolist() + + + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + + chunk_masks=masks + + for layer in self.encoders: + xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb) + if self.normalize_before: + xs = self.after_norm(xs) + + return xs, masks + + +class Data2vecConformerEncoder(Data2vecBaseEncoder): + """Conformer encoder module.""" + def __init__( + self, + data2vec_conf:dict, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "rel_pos", + normalize_before: bool = True, + concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + positionwise_conv_kernel_size: int = 1, + macaron_style: bool = True, + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + ): + """Construct ConformerEncoder + + Args: + input_size to use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + selfattention_layer_type (str): Encoder attention layer type, + the parameter has no effect now, it's just for configure + compatibility. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. 
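+
+        Example (illustrative; the remaining options keep their defaults):
+            encoder = Data2vecConformerEncoder(
+                data2vec_conf={'mask_prob': 0.65, 'average_top_k_layers': 8},
+                input_size=80,
+                output_size=256)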
+ """ + assert check_argument_types() + super().__init__(data2vec_conf,input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + concat_after, static_chunk_size, use_dynamic_chunk, + global_cmvn) + activation = get_activation(activation_type) + + # self-attention module definition + if selfattention_layer_type == "rel_selfattn": + encoder_selfattn_layer = RelPositionMultiHeadedAttention + else: + encoder_selfattn_layer = MultiHeadedAttention + + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + # feed-forward module definition + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = (output_size, cnn_module_kernel, activation, + cnn_module_norm, causal) + + self.encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer( + *positionwise_layer_args) if macaron_style else None, + convolution_layer( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + ffn_res=True, + ) for _ in range(num_blocks) + ]) diff --git a/wenet/data2vec/data2vec_model.py b/wenet/data2vec/data2vec_model.py new file mode 100644 index 000000000..30beff4d0 --- /dev/null +++ b/wenet/data2vec/data2vec_model.py @@ -0,0 +1,756 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +# The origial Data2vec work is in: +# Paper: https://arxiv.org/pdf/2006.11477.pdf +# Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/data2vec + +from collections import defaultdict +from typing import List, Optional, Tuple + +import torch +import logging + +from torch.nn.utils.rnn import pad_sequence + +from wenet.transformer.cmvn import GlobalCMVN +from wenet.transformer.ctc import CTC +from wenet.transformer.decoder import (TransformerDecoder, + BiTransformerDecoder) +from wenet.transformer.encoder import ConformerEncoder +from wenet.transformer.encoder import TransformerEncoder +from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss +from wenet.utils.cmvn import load_cmvn +from wenet.utils.common import (IGNORE_ID, add_sos_eos, log_add, + remove_duplicates_and_blank, th_accuracy, + reverse_pad_list) +from wenet.utils.mask import (make_pad_mask, mask_finished_preds, + mask_finished_scores, subsequent_mask) +from wenet.data2vec.data2vec_encoder import Data2vecConformerEncoder + +class Data2vecModel(torch.nn.Module): + """CTC-attention hybrid Encoder-Decoder model""" + def __init__( + self, + data2vec_conf:dict, + vocab_size: int, + encoder: TransformerEncoder, + decoder: TransformerDecoder, + ctc: CTC, + ctc_weight: float = 0.5, + ignore_id: int = IGNORE_ID, + reverse_weight: float = 0.0, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + ): + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + self.reverse_weight = reverse_weight + + self.encoder = encoder + self.decoder = decoder + self.ctc = ctc + self.criterion_att = 
LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + self.pretrain=data2vec_conf.get('pretrain',False) + self.encoder_grad_mult=data2vec_conf.get('encoder_grad_mult',1.0) + self.freeze_finetune_updates=data2vec_conf.get('freeze_finetune_updates',0) + self.num_updates=0 + + self.log_interval=data2vec_conf.get('log_interval',100) + self.accum_grad=data2vec_conf.get('accum_grad',4) + + def set_num_updates(self, num_updates): + + self.num_updates = num_updates + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + num_updates:int=0 + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor]]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0]), (speech.shape, speech_lengths.shape) + self.num_updates=num_updates + + if self.pretrain: + #for EMA updates + if self.train: + self.encoder.set_num_updates(self.num_updates) + encoder_out, encoder_mask,ext_res = self.encoder(speech, speech_lengths, features_only=False) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + else: + if self.encoder_grad_mult>0 and self.num_updates>=self.freeze_finetune_updates: + encoder_out, encoder_mask ,ext_res= self.encoder(speech, speech_lengths,features_only=True) + if self.encoder_grad_mult !=1.0: + encoder_out=GradMultiply.apply(encoder_out,self.encoder_grad_mult) + encoder_mask=GradMultiply.apply(encoder_mask,self.encoder_grad_mult) + else: + with torch.no_grad(): + encoder_out, encoder_mask,ext_res = self.encoder(speech, speech_lengths,features_only=True) + # 1. Encoder + encoder_out, encoder_mask,ext_res = self.encoder(speech, speech_lengths,features_only=True) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + + + if self.pretrain: + loss=ext_res["losses_reg"] + sample_size=ext_res["sample_size"] + loss_att=None + loss_ctc=None + else: + # 2a. Attention-decoder branch + if self.ctc_weight != 1.0: + loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask, + text, text_lengths) + else: + loss_att = None + + # 2b. CTC branch + if self.ctc_weight != 0.0: + loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, + text_lengths) + else: + loss_ctc = None + + if loss_ctc is None: + loss = loss_att + elif loss_att is None: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + (1 - + self.ctc_weight) * loss_att + + + + return loss, loss_att, loss_ctc,sample_size + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_mask: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, float]: + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # reverse the seq, used for right to left decoder + r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id)) + r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos, + self.ignore_id) + # 1. Forward decoder + decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask, + ys_in_pad, ys_in_lens, + r_ys_in_pad, + self.reverse_weight) + # 2. 
Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + r_loss_att = torch.tensor(0.0) + if self.reverse_weight > 0.0: + r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad) + loss_att = loss_att * ( + 1 - self.reverse_weight) + r_loss_att * self.reverse_weight + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + return loss_att, acc_att + + def _forward_encoder( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Let's assume B = batch_size + # 1. Encoder + if simulate_streaming and decoding_chunk_size > 0: + encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk( + speech, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + else: + encoder_out, encoder_mask,_ = self.encoder( + speech, + speech_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + return encoder_out, encoder_mask + + def recognize( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int = 10, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> torch.Tensor: + """ Apply beam search on attention decoder + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + + Returns: + torch.Tensor: decoding result, (batch, max_result_len) + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + device = speech.device + batch_size = speech.shape[0] + + # Let's assume B = batch_size and N = beam_size + # 1. Encoder + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + encoder_dim = encoder_out.size(2) + running_size = batch_size * beam_size + encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( + running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) + encoder_mask = encoder_mask.unsqueeze(1).repeat( + 1, beam_size, 1, 1).view(running_size, 1, + maxlen) # (B*N, 1, max_len) + + hyps = torch.ones([running_size, 1], dtype=torch.long, + device=device).fill_(self.sos) # (B*N, 1) + scores = torch.tensor([0.0] + [-float('inf')] * (beam_size - 1), + dtype=torch.float) + scores = scores.to(device).repeat([batch_size]).unsqueeze(1).to( + device) # (B*N, 1) + end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device) + cache: Optional[List[torch.Tensor]] = None + # 2. 
Decoder forward step by step + for i in range(1, maxlen + 1): + # Stop if all batch and all beam produce eos + if end_flag.sum() == running_size: + break + # 2.1 Forward decoder step + hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( + running_size, 1, 1).to(device) # (B*N, i, i) + # logp: (B*N, vocab) + logp, cache = self.decoder.forward_one_step( + encoder_out, encoder_mask, hyps, hyps_mask, cache) + # 2.2 First beam prune: select topk best prob at current time + top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) + top_k_logp = mask_finished_scores(top_k_logp, end_flag) + top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos) + # 2.3 Second beam prune: select topk score with history + scores = scores + top_k_logp # (B*N, N), broadcast add + scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N) + scores, offset_k_index = scores.topk(k=beam_size) # (B, N) + scores = scores.view(-1, 1) # (B*N, 1) + # 2.4. Compute base index in top_k_index, + # regard top_k_index as (B*N*N),regard offset_k_index as (B*N), + # then find offset_k_index in top_k_index + base_k_index = torch.arange(batch_size, device=device).view( + -1, 1).repeat([1, beam_size]) # (B, N) + base_k_index = base_k_index * beam_size * beam_size + best_k_index = base_k_index.view(-1) + offset_k_index.view( + -1) # (B*N) + + # 2.5 Update best hyps + best_k_pred = torch.index_select(top_k_index.view(-1), + dim=-1, + index=best_k_index) # (B*N) + best_hyps_index = best_k_index // beam_size + last_best_k_hyps = torch.index_select( + hyps, dim=0, index=best_hyps_index) # (B*N, i) + hyps = torch.cat((last_best_k_hyps, best_k_pred.view(-1, 1)), + dim=1) # (B*N, i+1) + + # 2.6 Update end flag + end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1) + + # 3. Select best of best + scores = scores.view(batch_size, beam_size) + # TODO: length normalization + best_scores, best_index = scores.max(dim=-1) + best_hyps_index = best_index + torch.arange( + batch_size, dtype=torch.long, device=device) * beam_size + best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index) + best_hyps = best_hyps[:, 1:] + return best_hyps, best_scores + + def ctc_greedy_search( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> List[List[int]]: + """ Apply CTC greedy search + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[List[int]]: best path result + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + batch_size = speech.shape[0] + # Let's assume B = batch_size + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + ctc_probs = self.ctc.log_softmax( + encoder_out) # (B, maxlen, vocab_size) + topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) + mask = make_pad_mask(encoder_out_lens, maxlen) # (B, maxlen) + topk_index = topk_index.masked_fill_(mask, self.eos) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] + scores = topk_prob.max(1) + hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] + return hyps, scores + + def _ctc_prefix_beam_search( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> Tuple[List[List[int]], torch.Tensor]: + """ CTC prefix beam search inner implementation + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + + Returns: + List[List[int]]: nbest results + torch.Tensor: encoder output, (1, max_len, encoder_dim), + it will be used for rescoring in attention rescoring mode + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + batch_size = speech.shape[0] + # For CTC prefix beam search, we only support batch_size=1 + assert batch_size == 1 + # Let's assume B = batch_size and N = beam_size + # 1. Encoder forward and get CTC score + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + ctc_probs = self.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + cur_hyps = [(tuple(), (0.0, -float('inf')))] + # 2. 
CTC beam search step by step + for t in range(0, maxlen): + logp = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) + # 2.1 First beam prune: select topk best + top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + for s in top_k_index: + s = s.item() + ps = logp[s].item() + for prefix, (pb, pnb) in cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == 0: # blank + n_pb, n_pnb = next_hyps[prefix] + n_pb = log_add([n_pb, pb + ps, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + elif s == last: + # Update *ss -> *s; + n_pb, n_pnb = next_hyps[prefix] + n_pnb = log_add([n_pnb, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + + # 2.2 Second beam prune + next_hyps = sorted(next_hyps.items(), + key=lambda x: log_add(list(x[1])), + reverse=True) + cur_hyps = next_hyps[:beam_size] + hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps] + return hyps, encoder_out + + def ctc_prefix_beam_search( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> List[int]: + """ Apply CTC prefix beam search + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + + Returns: + List[int]: CTC prefix beam search nbest results + """ + hyps, _ = self._ctc_prefix_beam_search(speech, speech_lengths, + beam_size, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) + return hyps[0] + + def attention_rescoring( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + ctc_weight: float = 0.0, + simulate_streaming: bool = False, + reverse_weight: float = 0.0, + ) -> List[int]: + """ Apply attention rescoring decoding, CTC prefix beam search + is applied first to get nbest, then we resoring the nbest on + attention decoder with corresponding encoder out + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + reverse_weight (float): right to left decoder weight + ctc_weight (float): ctc score weight + + Returns: + List[int]: Attention rescoring result + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + if reverse_weight > 0.0: + # decoder should be a bitransformer decoder if reverse_weight > 0.0 + assert hasattr(self.decoder, 'right_decoder') + device = speech.device + batch_size = speech.shape[0] + # For attention rescoring we only support batch_size=1 + assert batch_size == 1 + # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size + hyps, encoder_out = self._ctc_prefix_beam_search( + speech, speech_lengths, beam_size, decoding_chunk_size, + num_decoding_left_chunks, simulate_streaming) + + assert len(hyps) == beam_size + hyps_pad = pad_sequence([ + torch.tensor(hyp[0], device=device, dtype=torch.long) + for hyp in hyps + ], True, self.ignore_id) # (beam_size, max_hyps_len) + ori_hyps_pad = hyps_pad + hyps_lens = torch.tensor([len(hyp[0]) for hyp in hyps], + device=device, + dtype=torch.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + encoder_out = encoder_out.repeat(beam_size, 1, 1) + encoder_mask = torch.ones(beam_size, + 1, + encoder_out.size(1), + dtype=torch.bool, + device=device) + # used for right to left decoder + r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id) + r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, + self.ignore_id) + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, + reverse_weight) # (beam_size, max_hyps_len, vocab_size) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + decoder_out = decoder_out.cpu().numpy() + # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a + # conventional transformer decoder. 
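+        # Illustration of the scoring loop below: each n-best hypothesis gets
+        #     final = (1 - reverse_weight) * att_score
+        #             + reverse_weight * r_att_score
+        #             + ctc_weight * ctc_score
+        # where att_score / r_att_score are the summed log-probs of the
+        # left-to-right / right-to-left decoder (each including its eos term)
+        # and ctc_score is the prefix-beam-search score stored in hyp[1].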
+ r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + r_decoder_out = r_decoder_out.cpu().numpy() + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + score += decoder_out[i][len(hyp[0])][self.eos] + # add right to left decoder score + if reverse_weight > 0: + r_score = 0.0 + for j, w in enumerate(hyp[0]): + r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w] + r_score += r_decoder_out[i][len(hyp[0])][self.eos] + score = score * (1 - reverse_weight) + r_score * reverse_weight + # add ctc score + score += hyp[1] * ctc_weight + if score > best_score: + best_score = score + best_index = i + return hyps[best_index][0], best_score + + @torch.jit.export + def subsampling_rate(self) -> int: + """ Export interface for c++ call, return subsampling_rate of the + model + """ + return self.encoder.embed.subsampling_rate + + @torch.jit.export + def right_context(self) -> int: + """ Export interface for c++ call, return right_context of the model + """ + return self.encoder.embed.right_context + + @torch.jit.export + def sos_symbol(self) -> int: + """ Export interface for c++ call, return sos symbol id of the model + """ + return self.sos + + @torch.jit.export + def eos_symbol(self) -> int: + """ Export interface for c++ call, return eos symbol id of the model + """ + return self.eos + + @torch.jit.export + def forward_encoder_chunk( + self, + xs: torch.Tensor, + offset: int, + required_cache_size: int, + subsampling_cache: Optional[torch.Tensor] = None, + elayers_output_cache: Optional[List[torch.Tensor]] = None, + conformer_cnn_cache: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], + List[torch.Tensor]]: + """ Export interface for c++ call, give input chunk xs, and return + output from time 0 to current chunk. + + Args: + xs (torch.Tensor): chunk input + subsampling_cache (Optional[torch.Tensor]): subsampling cache + elayers_output_cache (Optional[List[torch.Tensor]]): + transformer/conformer encoder layers output cache + conformer_cnn_cache (Optional[List[torch.Tensor]]): conformer + cnn cache + + Returns: + torch.Tensor: output, it ranges from time 0 to current chunk. 
+ torch.Tensor: subsampling cache + List[torch.Tensor]: attention cache + List[torch.Tensor]: conformer cnn cache + + """ + return self.encoder.forward_chunk(xs, offset, required_cache_size, + subsampling_cache, + elayers_output_cache, + conformer_cnn_cache) + + @torch.jit.export + def ctc_activation(self, xs: torch.Tensor) -> torch.Tensor: + """ Export interface for c++ call, apply linear transform and log + softmax before ctc + Args: + xs (torch.Tensor): encoder output + + Returns: + torch.Tensor: activation before ctc + + """ + return self.ctc.log_softmax(xs) + + @torch.jit.export + def is_bidirectional_decoder(self) -> bool: + """ + Returns: + torch.Tensor: decoder output + """ + if hasattr(self.decoder, 'right_decoder'): + return True + else: + return False + + @torch.jit.export + def forward_attention_decoder( + self, + hyps: torch.Tensor, + hyps_lens: torch.Tensor, + encoder_out: torch.Tensor, + reverse_weight: float = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Export interface for c++ call, forward decoder with multiple + hypothesis from ctc prefix beam search and one encoder output + Args: + hyps (torch.Tensor): hyps from ctc prefix beam search, already + pad sos at the begining + hyps_lens (torch.Tensor): length of each hyp in hyps + encoder_out (torch.Tensor): corresponding encoder output + r_hyps (torch.Tensor): hyps from ctc prefix beam search, already + pad eos at the begining which is used fo right to left decoder + reverse_weight: used for verfing whether used right to left decoder, + > 0 will use. + + Returns: + torch.Tensor: decoder output + """ + assert encoder_out.size(0) == 1 + num_hyps = hyps.size(0) + assert hyps_lens.size(0) == num_hyps + encoder_out = encoder_out.repeat(num_hyps, 1, 1) + encoder_mask = torch.ones(num_hyps, + 1, + encoder_out.size(1), + dtype=torch.bool, + device=encoder_out.device) + # input for right to left decoder + # this hyps_lens has count token, we need minus it. + r_hyps_lens = hyps_lens - 1 + # this hyps has included token, so it should be + # convert the original hyps. + r_hyps = hyps[:, 1:] + r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id)) + r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id) + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, + reverse_weight) # (num_hyps, max_hyps_len, vocab_size) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + + # right to left decoder may be not used during decoding process, + # which depends on reverse_weight param. 
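+        # Illustration: a left-to-right input row [sos, w1, w2, w3] was
+        # re-ordered above into the right-to-left input [sos, w3, w2, w1],
+        # so r_decoder_out scores the same tokens in reverse order.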
+ # r_dccoder_out will be 0.0, if reverse_weight is 0.0 + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + return decoder_out, r_decoder_out + + +def init_data2vec_model(configs): + if configs['cmvn_file'] is not None: + mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn']) + global_cmvn = GlobalCMVN( + torch.from_numpy(mean).float(), + torch.from_numpy(istd).float()) + else: + global_cmvn = None + + input_dim = configs['input_dim'] + vocab_size = configs['output_dim'] + + encoder_type = configs.get('encoder', 'conformer') + decoder_type = configs.get('decoder', 'bitransformer') + + data2vec_conf=configs['data2vec_conf'] + + if encoder_type == 'conformer': + encoder = Data2vecConformerEncoder(data2vec_conf, + input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf']) + else: + encoder = TransformerEncoder(input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf']) + if decoder_type == 'transformer': + decoder = TransformerDecoder(vocab_size, encoder.output_size(), + **configs['decoder_conf']) + else: + assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0 + assert configs['decoder_conf']['r_num_blocks'] > 0 + decoder = BiTransformerDecoder(vocab_size, encoder.output_size(), + **configs['decoder_conf']) + ctc = CTC(vocab_size, encoder.output_size()) + model = Data2vecModel( + data2vec_conf=data2vec_conf, + vocab_size=vocab_size, + encoder=encoder, + decoder=decoder, + ctc=ctc, + **configs['model_conf'], + ) + return model diff --git a/wenet/data2vec/ema.py b/wenet/data2vec/ema.py new file mode 100644 index 000000000..ba93448ae --- /dev/null +++ b/wenet/data2vec/ema.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 + +""" +This module has the EMA class used to store a copy of the exponentially decayed +model params. + +Typical usage of EMA class involves initializing an object using an existing +model (random or from a seed model) and setting the config like ema_decay, +ema_start_update which determine how the EMA model is updated. After every +update of the model i.e. at the end of the train_step, the EMA should be updated +by passing the new model to the EMA.step function. The EMA model state dict +can be stored in the extra state under the key of "ema" and dumped +into a checkpoint and loaded. The EMA object can be passed to tasks +by setting task.uses_ema property. +EMA is a smoothed/ensemble model which might have better performance +when used for inference or further fine-tuning. EMA class has a +reverse function to load the EMA params into a model and use it +like a regular model. +""" + +import copy +import logging + +import torch + +class EMA(object): + """Exponential Moving Average of Fairseq Models + EMA keeps a copy of the exponentially decayed model params. + The set of params should include both gradient-descent and + non-gradient descent params, such as batch mean/var and buffers. + This is a modified implementation of + the open source code in https://github.com/zhawe01/fairseq-gec.git, + and internal source code in + fbcode/mobile-vision/projects/classification_pytorch/lib/utils/model_ema.py. + + Similar to TF EMA. + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage. + EMA provides a averaged and smoothed set of model weights, and has been shown to + improve vision models. EMA class does all necessary functions to update, reload, + or init EMA methods. + + EMA object is initialized from an arbitrary model. 
By default, it is stored in + the same device (unless device specified at initialization) and with the + same precision as the model (unless ema_fp32 is True). ema_fp32 is recommended. + This stores the EMA parameters in fp32 only for the EMA update step, and + is used at the default precision otherwise. + EMA is usually enabled using EMAConfig with store_ema=True. Some important + parameters to configure EMA are + 1) ema_decay - The decay of EMA + 2) ema_update_freq - EMA is updated every this many model updates. + 3) ema_start_update - Start EMA update after this many model updates [default 0] + + Key methods: + 1) step - One update of EMA using new model + 2) restore - Update EMA from a state dict + 3) reverse - Load EMA into a model + 4) get_decay, _set_decay - Used to get or set the decay. Note _set_decay is + called from step. + 5) build_fp32_params - Used to initialize or update the fp32 copy of EMA params. + Note this is enabled only when ema_fp32=True + """ + + def __init__(self, model, config, device=None, skip_keys=None): + """ + @param model model to initialize the EMA with + @param config EMAConfig object with configuration like + ema_decay, ema_update_freq, ema_fp32 + @param device If provided, copy EMA to this device (e.g. gpu). + Otherwise EMA is in the same device as the model. + """ + + self.decay = config["ema_decay"] + self.model = copy.deepcopy(model) + self.model.requires_grad_(False) + self.config = config + self.skip_keys = skip_keys or set() + self.fp32_params = {} + + # if self.config.ema_seed_model is not None: + # state = checkpoint_utils.load_ema_from_checkpoint( + # self.config.ema_seed_model + # ) + # self.model.load_state_dict(state["model"], strict=True) + + if device is not None: + logging.info(f"Copying EMA model to device {device}") + self.model = self.model.to(device=device) + + if self.config["ema_fp32"]: + self.build_fp32_params() + + self.update_freq_counter = 0 + + def get_model(self): + return self.model + + def build_fp32_params(self, state_dict=None): + """ + Store a copy of the EMA params in fp32. + If state dict is passed, the EMA params is copied from + the provided state dict. Otherwise, it is copied from the + current EMA model parameters. + """ + if not self.config["ema_fp32"]: + raise RuntimeError( + "build_fp32_params should not be called if ema_fp32=False. " + "Use ema_fp32=True if this is really intended." 
+ ) + + if state_dict is None: + state_dict = self.model.state_dict() + + def _to_float(t): + return t.float() if torch.is_floating_point(t) else t + + for param_key in state_dict: + # self.fp32_params[param_key].copy_(state_dict[param_key]) + if param_key in self.fp32_params: + self.fp32_params[param_key].copy_(state_dict[param_key]) + else: + self.fp32_params[param_key] = _to_float(state_dict[param_key]) + + def restore(self, state_dict, build_fp32_params=False): + """Load data from a model spec into EMA model""" + self.model.load_state_dict(state_dict, strict=False) + self.model.requires_grad_(False) + if build_fp32_params: + self.build_fp32_params(state_dict) + + def _set_decay(self, decay): + self.decay = decay + + def get_decay(self): + return self.decay + + def _step_internal(self, new_model, updates=None): + """One update of the EMA model based on new model weights""" + decay = self.decay + + ema_state_dict = {} + ema_params = ( + self.fp32_params if self.config["ema_fp32"] else self.model.state_dict() + ) + for key, param in new_model.state_dict().items(): + if isinstance(param, dict): + continue + try: + ema_param = ema_params[key] + except KeyError: + ema_param = ( + param.float().clone() if param.ndim == 1 else copy.deepcopy(param) + ) + + if param.shape != ema_param.shape: + raise ValueError( + "incompatible tensor shapes between model param and ema param" + + "{} vs. {}".format(param.shape, ema_param.shape) + ) + + if "version" in key: + # Do not decay a model.version pytorch param + continue + + if key in self.skip_keys: + ema_param = param.to(dtype=ema_param.dtype).clone() + else: + ema_param.mul_(decay) + ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay) + ema_state_dict[key] = ema_param + self.restore(ema_state_dict, build_fp32_params=False) + + def step(self, new_model, updates=None): + """ + One update of EMA which is done every self.config.ema_update_freq + updates of the model. + + @param updates The current number of model updates done. + Decay is set of 0 if model updates < ema_start_update, which means + the model will be simply copied over to the EMA. + When model updates >= ema_start_updates, then EMA is updated with + a decay of self.config.ema_decay. + """ + if updates is not None: + self._set_decay( + 0 if updates < self.config["ema_start_update"] else self.config["ema_decay"] + ) + if updates is not None and self.config["ema_update_freq"] > 1: + self.update_freq_counter += 1 + if self.update_freq_counter >= self.config["ema_update_freq"]: + self._step_internal(new_model, updates) + self.update_freq_counter = 0 + else: + self._step_internal(new_model, updates) + + def reverse(self, model): + """ + Load the model parameters from EMA model. + Useful for inference or fine-tuning from the EMA model. 
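+
+        Typical flow (illustrative sketch; `model`, `loader` and the training
+        loop are placeholders, not part of this patch):
+
+            ema = EMA(model, {"ema_decay": 0.999, "ema_fp32": True,
+                              "ema_start_update": 0, "ema_update_freq": 1})
+            for step, batch in enumerate(loader):
+                ...                       # forward/backward + optimizer step
+                ema.step(model, step)     # refresh the smoothed copy
+                # data2vec additionally anneals the decay via ema._set_decay()
+            eval_model = ema.reverse(copy.deepcopy(model))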
+ """ + d = self.model.state_dict() + if "_ema" in d: + del d["_ema"] + + model.load_state_dict(d, strict=False) + return model diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py index 4d78a745f..3cd662c5e 100644 --- a/wenet/transformer/encoder_layer.py +++ b/wenet/transformer/encoder_layer.py @@ -151,6 +151,7 @@ def __init__( dropout_rate: float = 0.1, normalize_before: bool = True, concat_after: bool = False, + ffn_res: bool = False, ): """Construct an EncoderLayer object.""" super().__init__() @@ -176,6 +177,7 @@ def __init__( self.concat_after = concat_after self.concat_linear = nn.Linear(size + size, size) + self.ffn_res=ffn_res def forward( self, x: torch.Tensor, @@ -255,7 +257,9 @@ def forward( if self.normalize_before: x = self.norm_ff(x) - x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + x=self.feed_forward(x) + layer_out=x + x = residual + self.ff_scale * self.dropout(x) if not self.normalize_before: x = self.norm_ff(x) @@ -265,4 +269,7 @@ def forward( if output_cache is not None: x = torch.cat([output_cache, x], dim=1) - return x, mask, new_cnn_cache + if self.ffn_res: + return x, mask, new_cnn_cache,layer_out + else: + return x, mask, new_cnn_cache diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py index a1060e697..8436a8665 100644 --- a/wenet/utils/executor.py +++ b/wenet/utils/executor.py @@ -7,7 +7,7 @@ # from contextlib import suppress as nullcontext import torch from torch.nn.utils import clip_grad_norm_ - +import math class Executor: def __init__(self): @@ -25,6 +25,7 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, accum_grad = args.get('accum_grad', 1) is_distributed = args.get('is_distributed', True) use_amp = args.get('use_amp', False) + pretrain=args.get('pretrain',False) logging.info('using accumulate grad, new batch size is {} times' ' larger than before'.format(accum_grad)) if use_amp: @@ -62,8 +63,12 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, # The more details about amp can be found in # https://pytorch.org/docs/stable/notes/amp_examples.html with torch.cuda.amp.autocast(scaler is not None): - loss, loss_att, loss_ctc = model( - feats, feats_lengths, target, target_lengths) + if pretrain: + loss, loss_att, loss_ctc ,sample_size= model( + feats, feats_lengths, target, target_lengths,self.step) + else: + loss, loss_att, loss_ctc = model( + feats, feats_lengths, target, target_lengths) loss = loss / accum_grad if use_amp: scaler.scale(loss).backward() @@ -96,9 +101,14 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, self.step += 1 if batch_idx % log_interval == 0: lr = optimizer.param_groups[0]['lr'] - log_str = 'TRAIN Batch {}/{} loss {:.6f} '.format( - epoch, batch_idx, - loss.item() * accum_grad) + if pretrain: + log_str = 'TRAIN Batch {}/{} loss {:.6f} '.format( + epoch, batch_idx, + loss.item() * accum_grad/sample_size / math.log(2)) + else: + log_str = 'TRAIN Batch {}/{} loss {:.6f} '.format( + epoch, batch_idx, + loss.item() * accum_grad) if loss_att is not None: log_str += 'loss_att {:.6f} '.format(loss_att.item()) if loss_ctc is not None: @@ -113,6 +123,7 @@ def cv(self, model, data_loader, device, args): rank = args.get('rank', 0) epoch = args.get('epoch', 0) log_interval = args.get('log_interval', 10) + pretrain=args.get('pretrain',False) # in order to avoid division by 0 num_seen_utts = 1 total_loss = 0.0 @@ -126,8 +137,13 @@ def cv(self, model, data_loader, device, args): num_utts = 
target_lengths.size(0) if num_utts == 0: continue - loss, loss_att, loss_ctc = model(feats, feats_lengths, target, - target_lengths) + if pretrain: + loss, loss_att, loss_ctc ,sample_size= model(feats, feats_lengths, target, + target_lengths) + loss=loss / sample_size / math.log(2) + else: + loss, loss_att, loss_ctc = model(feats, feats_lengths, target, + target_lengths) if torch.isfinite(loss): num_seen_utts += num_utts total_loss += loss.item() * num_utts From ae098dc9ec6c2a454979416fcc38e2ca76a11995 Mon Sep 17 00:00:00 2001 From: aydentang Date: Wed, 30 Mar 2022 19:42:01 +0800 Subject: [PATCH 11/19] update config --- .../librispeech/ssl/conf/finetune/train_conformer_100h.yaml | 4 ++-- .../ssl/conf/finetune/train_conformer_100h_data2vec.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/librispeech/ssl/conf/finetune/train_conformer_100h.yaml b/examples/librispeech/ssl/conf/finetune/train_conformer_100h.yaml index acf7941c8..f01681615 100644 --- a/examples/librispeech/ssl/conf/finetune/train_conformer_100h.yaml +++ b/examples/librispeech/ssl/conf/finetune/train_conformer_100h.yaml @@ -20,9 +20,9 @@ encoder_conf: # decoder related decoder: transformer decoder_conf: - attention_heads: 2 + attention_heads: 4 linear_units: 512 - num_blocks: 1 + num_blocks: 2 dropout_rate: 0.1 positional_dropout_rate: 0.0 self_attention_dropout_rate: 0.0 diff --git a/examples/librispeech/ssl/conf/finetune/train_conformer_100h_data2vec.yaml b/examples/librispeech/ssl/conf/finetune/train_conformer_100h_data2vec.yaml index 9e2c38694..bd19fc72e 100644 --- a/examples/librispeech/ssl/conf/finetune/train_conformer_100h_data2vec.yaml +++ b/examples/librispeech/ssl/conf/finetune/train_conformer_100h_data2vec.yaml @@ -20,9 +20,9 @@ encoder_conf: # decoder related decoder: transformer decoder_conf: - attention_heads: 2 + attention_heads: 4 linear_units: 512 - num_blocks: 1 + num_blocks: 2 dropout_rate: 0.1 positional_dropout_rate: 0.0 self_attention_dropout_rate: 0.0 From 76169d4f81923a2dfc30752b1d114eeaeb328e57 Mon Sep 17 00:00:00 2001 From: aydentang Date: Thu, 31 Mar 2022 21:39:47 +0800 Subject: [PATCH 12/19] fix arxiv paper,fix jit export --- wenet/bin/train.py | 14 ++++++++------ wenet/data2vec/data2vec_encoder.py | 2 +- wenet/data2vec/data2vec_model.py | 2 +- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 06749284d..7ccfef540 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -211,14 +211,16 @@ def main(): model=init_data2vec_model(configs) else: model = init_asr_model(configs) - print(model) - num_params = sum(p.numel() for p in model.parameters()) - print('the number of model params: {}'.format(num_params)) + if args.rank == 0: + print(model) + num_params = sum(p.numel() for p in model.parameters()) + print('the number of model params: {}'.format(num_params)) # !!!IMPORTANT!!! 
- # Try to export the model by script, if fails, we should refine - # the code to satisfy the script export requirements - if args.rank == 0 and not pretrain and wav2vec_conf is None: + # Now the pretrain model does not support jit export + Try to export the model by script, if fails, we should refine + the code to satisfy the script export requirements + if args.rank == 0 and not pretrain and wav2vec_conf is None and data2vec_conf is None: script_model = torch.jit.script(model) script_model.save(os.path.join(args.model_dir, 'init.zip')) executor = Executor() diff --git a/wenet/data2vec/data2vec_encoder.py b/wenet/data2vec/data2vec_encoder.py index a7adf1347..b0ad07c25 100644 --- a/wenet/data2vec/data2vec_encoder.py +++ b/wenet/data2vec/data2vec_encoder.py @@ -3,7 +3,7 @@ # The origial Data2vec work is in: -# Paper: https://arxiv.org/pdf/2006.11477.pdf +# Paper: https://arxiv.org/pdf/2202.03555.pdf # Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/data2vec """Encoder definition.""" diff --git a/wenet/data2vec/data2vec_model.py b/wenet/data2vec/data2vec_model.py index 30beff4d0..d0c07e5a6 100644 --- a/wenet/data2vec/data2vec_model.py +++ b/wenet/data2vec/data2vec_model.py @@ -3,7 +3,7 @@ # The origial Data2vec work is in: -# Paper: https://arxiv.org/pdf/2006.11477.pdf +# Paper: https://arxiv.org/pdf/2202.03555.pdf # Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/data2vec from collections import defaultdict From 8a1cb96795b38ba9aeb23c4005242acd0d9829e8 Mon Sep 17 00:00:00 2001 From: aydentang Date: Thu, 31 Mar 2022 21:42:07 +0800 Subject: [PATCH 13/19] fix jit export --- wenet/bin/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 7ccfef540..f12decd8f 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -218,8 +218,8 @@ def main(): # !!!IMPORTANT!!! # Now the pretrain model does not support jit export - Try to export the model by script, if fails, we should refine - the code to satisfy the script export requirements + # Try to export the model by script, if fails, we should refine + # the code to satisfy the script export requirements if args.rank == 0 and not pretrain and wav2vec_conf is None and data2vec_conf is None: script_model = torch.jit.script(model) script_model.save(os.path.join(args.model_dir, 'init.zip')) From 54f6b368fdf47c20f1381b96e42f134f687a9782 Mon Sep 17 00:00:00 2001 From: aydentang Date: Thu, 25 Aug 2022 22:41:54 +0800 Subject: [PATCH 14/19] fix some bugs --- wenet/wav2vec/wav2vec2_encoder.py | 1 - wenet/wav2vec/wav2vec2_model.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/wenet/wav2vec/wav2vec2_encoder.py b/wenet/wav2vec/wav2vec2_encoder.py index 0f1239641..000c6dfe4 100644 --- a/wenet/wav2vec/wav2vec2_encoder.py +++ b/wenet/wav2vec/wav2vec2_encoder.py @@ -752,7 +752,6 @@ def __init__( cnn_module_kernel: int = 15, causal: bool = False, cnn_module_norm: str = "batch_norm", - cnn_module_before: bool = False, use_feature_norm: bool = False, ): """Construct ConformerEncoder diff --git a/wenet/wav2vec/wav2vec2_model.py b/wenet/wav2vec/wav2vec2_model.py index 4aac869b0..119c663fc 100644 --- a/wenet/wav2vec/wav2vec2_model.py +++ b/wenet/wav2vec/wav2vec2_model.py @@ -122,7 +122,7 @@ def forward( with torch.no_grad(): encoder_out, encoder_mask,ext_res = self.encoder(speech, speech_lengths,features_only=True) # 1. 
Encoder - encoder_out, encoder_mask,ext_res = self.encoder(speech, speech_lengths,features_only=True) + # encoder_out, encoder_mask,ext_res = self.encoder(speech, speech_lengths,features_only=True) encoder_out_lens = encoder_mask.squeeze(1).sum(1) @@ -157,7 +157,7 @@ def forward( - return loss, loss_att, loss_ctc + return loss, loss_att, loss_ctc,sample_size def _calc_att_loss( self, From 28842eb8f123ca004965b21b9b87bbe663c7b0e0 Mon Sep 17 00:00:00 2001 From: aydentang Date: Thu, 25 Aug 2022 23:00:41 +0800 Subject: [PATCH 15/19] aishell finetune example --- .../conf/train_base_conformer_100h.yaml | 77 ++++++ .../aishell/s0_ssl/conf/train_conformer.yaml | 77 ++++++ .../s0_ssl/conf/train_u2++_conformer.yaml | 83 ++++++ .../s0_ssl/conf/train_unified_conformer.yaml | 81 ++++++ .../aishell/s0_ssl/local/aishell_data_prep.sh | 65 +++++ .../aishell/s0_ssl/local/aishell_train_lms.sh | 59 +++++ .../s0_ssl/local/download_and_untar.sh | 105 ++++++++ examples/aishell/s0_ssl/path.sh | 8 + examples/aishell/s0_ssl/run_ssl.sh | 250 ++++++++++++++++++ 9 files changed, 805 insertions(+) create mode 100644 examples/aishell/s0_ssl/conf/train_base_conformer_100h.yaml create mode 100644 examples/aishell/s0_ssl/conf/train_conformer.yaml create mode 100644 examples/aishell/s0_ssl/conf/train_u2++_conformer.yaml create mode 100644 examples/aishell/s0_ssl/conf/train_unified_conformer.yaml create mode 100644 examples/aishell/s0_ssl/local/aishell_data_prep.sh create mode 100644 examples/aishell/s0_ssl/local/aishell_train_lms.sh create mode 100644 examples/aishell/s0_ssl/local/download_and_untar.sh create mode 100644 examples/aishell/s0_ssl/path.sh create mode 100644 examples/aishell/s0_ssl/run_ssl.sh diff --git a/examples/aishell/s0_ssl/conf/train_base_conformer_100h.yaml b/examples/aishell/s0_ssl/conf/train_base_conformer_100h.yaml new file mode 100644 index 000000000..e888c8681 --- /dev/null +++ b/examples/aishell/s0_ssl/conf/train_base_conformer_100h.yaml @@ -0,0 +1,77 @@ +# network architecture +# encoder related +encoder: conformer +encoder_conf: + output_size: 512 # dimension of attention + attention_heads: 8 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + cnn_module_kernel: 15 + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 512 + num_blocks: 2 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.7 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +dataset_conf: + filter_conf: + max_length: 40960 + min_length: 0 + token_max_length: 200 + token_min_length: 1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 0.1 + spec_aug: true + spec_aug_conf: + num_t_mask: 2 + num_f_mask: 2 + max_t: 50 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 16 + +grad_clip: 5 +accum_grad: 2 
+max_epoch: 240 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.0004 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + warmup_steps: 25000 diff --git a/examples/aishell/s0_ssl/conf/train_conformer.yaml b/examples/aishell/s0_ssl/conf/train_conformer.yaml new file mode 100644 index 000000000..b8ce511cd --- /dev/null +++ b/examples/aishell/s0_ssl/conf/train_conformer.yaml @@ -0,0 +1,77 @@ +# network architecture +# encoder related +encoder: conformer +encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + cnn_module_kernel: 15 + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +dataset_conf: + filter_conf: + max_length: 40960 + min_length: 0 + token_max_length: 200 + token_min_length: 1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 0.1 + spec_aug: true + spec_aug_conf: + num_t_mask: 2 + num_f_mask: 2 + max_t: 50 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 16 + +grad_clip: 5 +accum_grad: 4 +max_epoch: 240 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.002 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + warmup_steps: 25000 diff --git a/examples/aishell/s0_ssl/conf/train_u2++_conformer.yaml b/examples/aishell/s0_ssl/conf/train_u2++_conformer.yaml new file mode 100644 index 000000000..88742a0ff --- /dev/null +++ b/examples/aishell/s0_ssl/conf/train_u2++_conformer.yaml @@ -0,0 +1,83 @@ +# network architecture +# encoder related +encoder: conformer +encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + cnn_module_kernel: 8 + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' + causal: true + use_dynamic_chunk: true + cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster + use_dynamic_left_chunk: false + +# decoder related +decoder: bitransformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 3 + r_num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false 
+ reverse_weight: 0.3 + +dataset_conf: + filter_conf: + max_length: 40960 + min_length: 0 + token_max_length: 200 + token_min_length: 1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 2 + num_f_mask: 2 + max_t: 50 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 16 + +grad_clip: 5 +accum_grad: 1 +max_epoch: 360 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + warmup_steps: 25000 diff --git a/examples/aishell/s0_ssl/conf/train_unified_conformer.yaml b/examples/aishell/s0_ssl/conf/train_unified_conformer.yaml new file mode 100644 index 000000000..978d3d91c --- /dev/null +++ b/examples/aishell/s0_ssl/conf/train_unified_conformer.yaml @@ -0,0 +1,81 @@ +# network architecture +# encoder related +encoder: conformer +encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + cnn_module_kernel: 15 + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' + causal: true + use_dynamic_chunk: true + cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster + use_dynamic_left_chunk: false + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +dataset_conf: + filter_conf: + max_length: 40960 + min_length: 0 + token_max_length: 200 + token_min_length: 1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 0.1 + spec_aug: true + spec_aug_conf: + num_t_mask: 2 + num_f_mask: 2 + max_t: 50 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 16 + +grad_clip: 5 +accum_grad: 1 +max_epoch: 180 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + warmup_steps: 25000 diff --git a/examples/aishell/s0_ssl/local/aishell_data_prep.sh b/examples/aishell/s0_ssl/local/aishell_data_prep.sh new file mode 100644 index 000000000..fb4d5fb0a --- /dev/null +++ b/examples/aishell/s0_ssl/local/aishell_data_prep.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +# Copyright 2017 Xingyu Na +# Apache 2.0 + +. 
./path.sh || exit 1; + +if [ $# != 2 ]; then + echo "Usage: $0 " + echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript" + exit 1; +fi + +aishell_audio_dir=$1 +aishell_text=$2/aishell_transcript_v0.8.txt + +train_dir=data/local/train +dev_dir=data/local/dev +test_dir=data/local/test +tmp_dir=data/local/tmp + +mkdir -p $train_dir +mkdir -p $dev_dir +mkdir -p $test_dir +mkdir -p $tmp_dir + +# data directory check +if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then + echo "Error: $0 requires two directory arguments" + exit 1; +fi + +# find wav audio file for train, dev and test resp. +find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist +n=`cat $tmp_dir/wav.flist | wc -l` +[ $n -ne 141925 ] && \ + echo Warning: expected 141925 data data files, found $n + +grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1; +grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1; +grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1; + +rm -r $tmp_dir + +# Transcriptions preparation +for dir in $train_dir $dev_dir $test_dir; do + echo Preparing $dir transcriptions + sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list + paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all + tools/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt + awk '{print $1}' $dir/transcripts.txt > $dir/utt.list + tools/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp + sort -u $dir/transcripts.txt > $dir/text +done + +mkdir -p data/train data/dev data/test + +for f in wav.scp text; do + cp $train_dir/$f data/train/$f || exit 1; + cp $dev_dir/$f data/dev/$f || exit 1; + cp $test_dir/$f data/test/$f || exit 1; +done + +echo "$0: AISHELL data preparation succeeded" +exit 0; diff --git a/examples/aishell/s0_ssl/local/aishell_train_lms.sh b/examples/aishell/s0_ssl/local/aishell_train_lms.sh new file mode 100644 index 000000000..30ffb7973 --- /dev/null +++ b/examples/aishell/s0_ssl/local/aishell_train_lms.sh @@ -0,0 +1,59 @@ +#!/bin/bash + + +# To be run from one directory above this script. +. ./path.sh + +text=data/local/lm/text +lexicon=data/local/dict/lexicon.txt + +for f in "$text" "$lexicon"; do + [ ! -f $x ] && echo "$0: No such file $f" && exit 1; +done + +# Check SRILM tools +if ! which ngram-count > /dev/null; then + echo "srilm tools are not found, please download it and install it from: " + echo "http://www.speech.sri.com/projects/srilm/download.html" + echo "Then add the tools to your PATH" + exit 1 +fi + +# This script takes no arguments. It assumes you have already run +# aishell_data_prep.sh. +# It takes as input the files +# data/local/lm/text +# data/local/dict/lexicon.txt +dir=data/local/lm +mkdir -p $dir + + +cleantext=$dir/text.no_oov + +cat $text | awk -v lex=$lexicon 'BEGIN{while((getline0){ seen[$1]=1; } } + {for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf(" ");} } printf("\n");}' \ + > $cleantext || exit 1; + +cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | \ + sort -nr > $dir/word.counts || exit 1; + +# Get counts from acoustic training transcripts, and add one-count +# for each word in the lexicon (but not silence, we don't want it +# in the LM-- we'll add it optionally later). 
+cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \ + cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \ + sort | uniq -c | sort -nr > $dir/unigram.counts || exit 1; + +cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo ""; echo "" ) > $dir/wordlist + +heldout_sent=10000 # Don't change this if you want result to be comparable with + # kaldi_lm results +mkdir -p $dir +cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n $dir/heldout +cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n $dir/train + +ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \ + -map-unk "" -kndiscount -interpolate -lm $dir/lm.arpa +ngram -lm $dir/lm.arpa -ppl $dir/heldout diff --git a/examples/aishell/s0_ssl/local/download_and_untar.sh b/examples/aishell/s0_ssl/local/download_and_untar.sh new file mode 100644 index 000000000..58a278241 --- /dev/null +++ b/examples/aishell/s0_ssl/local/download_and_untar.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +# Copyright 2014 Johns Hopkins University (author: Daniel Povey) +# 2017 Xingyu Na +# Apache 2.0 + +remove_archive=false + +if [ "$1" == --remove-archive ]; then + remove_archive=true + shift +fi + +if [ $# -ne 3 ]; then + echo "Usage: $0 [--remove-archive] " + echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell" + echo "With --remove-archive it will remove the archive after successfully un-tarring it." + echo " can be one of: data_aishell, resource_aishell." +fi + +data=$1 +url=$2 +part=$3 + +if [ ! -d "$data" ]; then + echo "$0: no such directory $data" + exit 1; +fi + +part_ok=false +list="data_aishell resource_aishell" +for x in $list; do + if [ "$part" == $x ]; then part_ok=true; fi +done +if ! $part_ok; then + echo "$0: expected to be one of $list, but got '$part'" + exit 1; +fi + +if [ -z "$url" ]; then + echo "$0: empty URL base." + exit 1; +fi + +if [ -f $data/$part/.complete ]; then + echo "$0: data part $part was already successfully extracted, nothing to do." + exit 0; +fi + +# sizes of the archive files in bytes. +sizes="15582913665 1246920" + +if [ -f $data/$part.tgz ]; then + size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}') + size_ok=false + for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done + if ! $size_ok; then + echo "$0: removing existing file $data/$part.tgz because its size in bytes $size" + echo "does not equal the size of one of the archives." + rm $data/$part.tgz + else + echo "$data/$part.tgz exists and appears to be complete." + fi +fi + +if [ ! -f $data/$part.tgz ]; then + if ! which wget >/dev/null; then + echo "$0: wget is not installed." + exit 1; + fi + full_url=$url/$part.tgz + echo "$0: downloading data from $full_url. This may take some time, please be patient." + + cd $data + if ! wget --no-check-certificate $full_url; then + echo "$0: error executing wget $full_url" + exit 1; + fi +fi + +cd $data + +if ! tar -xvzf $part.tgz; then + echo "$0: error un-tarring archive $data/$part.tgz" + exit 1; +fi + +touch $data/$part/.complete + +if [ $part == "data_aishell" ]; then + cd $data/$part/wav + for wav in ./*.tar.gz; do + echo "Extracting wav from $wav" + tar -zxf $wav && rm $wav + done +fi + +echo "$0: Successfully downloaded and un-tarred $data/$part.tgz" + +if $remove_archive; then + echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied." 
+ rm $data/$part.tgz +fi + +exit 0; diff --git a/examples/aishell/s0_ssl/path.sh b/examples/aishell/s0_ssl/path.sh new file mode 100644 index 000000000..5ddca76cc --- /dev/null +++ b/examples/aishell/s0_ssl/path.sh @@ -0,0 +1,8 @@ +export WENET_DIR=$PWD/../../.. +export BUILD_DIR=${WENET_DIR}/runtime/server/x86/build +export OPENFST_PREFIX_DIR=${BUILD_DIR}/../fc_base/openfst-subbuild/openfst-populate-prefix +export PATH=$PWD:${BUILD_DIR}:${BUILD_DIR}/kaldi:${OPENFST_PREFIX_DIR}/bin:$PATH + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=../../../:$PYTHONPATH diff --git a/examples/aishell/s0_ssl/run_ssl.sh b/examples/aishell/s0_ssl/run_ssl.sh new file mode 100644 index 000000000..e4da77595 --- /dev/null +++ b/examples/aishell/s0_ssl/run_ssl.sh @@ -0,0 +1,250 @@ +#!/bin/bash +# Copyright 2022 Tencent Inc. (Author: Kai Tang). +# Apach 2.0 + +. ./path.sh || exit 1; + +# Use this to control how many gpu you use, It's 1-gpu training if you specify +# just 1gpu, otherwise it's is multiple gpu training based on DDP in pytorch +#export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +export CUDA_VISIBLE_DEVICES="0,1,2,3" +# The NCCL_SOCKET_IFNAME variable specifies which IP interface to use for nccl +# communication. More details can be found in +# https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html +# export NCCL_SOCKET_IFNAME=ens4f1 +export NCCL_DEBUG=INFO +stage=4 # start from 0 if you need to start from data preparation +stop_stage=4 + +# The num of machines(nodes) for multi-machine training, 1 is for one machine. +# NFS is required if num_nodes > 1. +num_nodes=1 + +# The rank of each node or machine, which ranges from 0 to `num_nodes - 1`. +# You should set the node_rank=0 on the first machine, set the node_rank=1 +# on the second machine, and so on. +node_rank=0 +# data +data_url=www.openslr.org/resources/33 + +nj=16 +dict=data/dict/lang_char.txt + +# data_type can be `raw` or `shard`. Typically, raw is used for small dataset, +# `shard` is used for large dataset which is over 1k hours, and `shard` is +# faster on reading data and training. +data_type=raw +num_utts_per_shard=1000 + +train_set=train +# Optional train_config +# 1. conf/train_transformer.yaml: Standard transformer +# 2. conf/train_conformer.yaml: Standard conformer +# 3. conf/train_unified_conformer.yaml: Unified dynamic chunk causal conformer +# 4. conf/train_unified_transformer.yaml: Unified dynamic chunk transformer +# 5. conf/train_u2++_conformer.yaml: U2++ conformer +# 6. conf/train_u2++_transformer.yaml: U2++ transformer +train_config=conf/train_base_conformer_100h.yaml +cmvn=true +dir=exp/conformer_w2vbase +checkpoint= +pretrain_dir=pretrain + +# use average_checkpoint will get better result +average_checkpoint=true +decode_checkpoint=$dir/final.pt +average_num=30 +decode_modes="ctc_greedy_search ctc_prefix_beam_search attention attention_rescoring" + +# pretrained w2v-conformer encoder +enc_init=$pretrain_dir/w2vc_base12_7wh.pt +#reinit last pretrained encoder layer : https://arxiv.org/pdf/2107.04734.pdf +enc_init_mods='encoder.global_cmvn,encoder.encoders.0.,encoder.encoders.1.,encoder.encoders.2.,encoder.encoders.3.,encoder.encoders.4.,encoder.encoders.5.,encoder.encoders.6.,encoder.encoders.7.,encoder.encoders.8.,encoder.encoders.9.,encoder.encoders.10.,encoder.embed' + +. 
tools/parse_options.sh || exit 1; + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + echo "stage -1: Data Download" + local/download_and_untar.sh ${data} ${data_url} data_aishell + local/download_and_untar.sh ${data} ${data_url} resource_aishell +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # Data preparation + local/aishell_data_prep.sh ${data}/data_aishell/wav \ + ${data}/data_aishell/transcript +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # remove the space between the text labels for Mandarin dataset + for x in train dev test; do + cp data/${x}/text data/${x}/text.org + paste -d " " <(cut -f 1 -d" " data/${x}/text.org) \ + <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \ + > data/${x}/text + rm data/${x}/text.org + done + + tools/compute_cmvn_stats.py --num_workers 16 --train_config $train_config \ + --in_scp data/${train_set}/wav.scp \ + --out_cmvn data/$train_set/global_cmvn +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "Make a dictionary" + mkdir -p $(dirname $dict) + echo " 0" > ${dict} # 0 is for "blank" in CTC + echo " 1" >> ${dict} # must be 1 + tools/text2token.py -s 1 -n 1 data/train/text | cut -f 2- -d" " \ + | tr " " "\n" | sort | uniq | grep -a -v -e '^\s*$' | \ + awk '{print $0 " " NR+1}' >> ${dict} + num_token=$(cat $dict | wc -l) + echo " $num_token" >> $dict +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "Prepare data, prepare requried format" + for x in dev test ${train_set}; do + if [ $data_type == "shard" ]; then + tools/make_shard_list.py --num_utts_per_shard $num_utts_per_shard \ + --num_threads 16 data/$x/wav.scp data/$x/text \ + $(realpath data/$x/shards) data/$x/data.list + else + tools/make_raw_list.py data/$x/wav.scp data/$x/text \ + data/$x/data.list + fi + done +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + mkdir -p $dir + # You have to rm `INIT_FILE` manually when you resume or restart a + # multi-machine training. + INIT_FILE=$dir/ddp_init + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + # Use "nccl" if it works, otherwise use "gloo" + #dist_backend="gloo" + dist_backend="nccl" + world_size=`expr $num_gpus \* $num_nodes` + echo "total gpus is: $world_size" + cmvn_opts= + $cmvn && cp data/${train_set}/global_cmvn $dir + $cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn" + + # train.py rewrite $train_config to $dir/train.yaml with model input + # and output dimension, and $dir/train.yaml will be used for inference + # and export. + for ((i = 0; i < $num_gpus; ++i)); do + { + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + # Rank of each gpu/process used for knowing whether it is + # the master of a worker. 
+ rank=`expr $node_rank \* $num_gpus + $i` + python wenet/bin/train.py --gpu $gpu_id \ + --config $train_config \ + --data_type $data_type \ + --symbol_table $dict \ + --train_data data/$train_set/data.list \ + --cv_data data/dev/data.list \ + ${checkpoint:+--checkpoint $checkpoint} \ + ${enc_init:+--enc_init $enc_init} \ + --enc_init_mods $enc_init_mods \ + --model_dir $dir \ + --ddp.init_method $init_method \ + --ddp.world_size $world_size \ + --ddp.rank $rank \ + --ddp.dist_backend $dist_backend \ + --num_workers 4 \ + $cmvn_opts \ + --pin_memory + }& + done + wait +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # Test model, please specify the model you want to test by --checkpoint + if [ ${average_checkpoint} == true ]; then + decode_checkpoint=$dir/avg_${average_num}.pt + echo "do model average and final checkpoint is $decode_checkpoint" + python wenet/bin/average_model.py \ + --dst_model $decode_checkpoint \ + --src_path $dir \ + --num ${average_num} \ + --val_best + fi + # Please specify decoding_chunk_size for unified streaming and + # non-streaming model. The default value is -1, which is full chunk + # for non-streaming inference. + decoding_chunk_size= + ctc_weight=0.5 + reverse_weight=0.0 + for mode in ${decode_modes}; do + { + test_dir=$dir/test_${mode} + mkdir -p $test_dir + python wenet/bin/recognize.py --gpu 0 \ + --mode $mode \ + --config $dir/train.yaml \ + --data_type $data_type \ + --test_data data/test/data.list \ + --checkpoint $decode_checkpoint \ + --beam_size 10 \ + --batch_size 1 \ + --penalty 0.0 \ + --dict $dict \ + --ctc_weight $ctc_weight \ + --reverse_weight $reverse_weight \ + --result_file $test_dir/text \ + ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size} + python tools/compute-wer.py --char=1 --v=1 \ + data/test/text $test_dir/text > $test_dir/wer + } & + done + wait +fi + + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + # Export the best model you want + python wenet/bin/export_jit.py \ + --config $dir/train.yaml \ + --checkpoint $dir/avg_${average_num}.pt \ + --output_file $dir/final.zip \ + --output_quant_file $dir/final_quant.zip +fi + +# Optionally, you can add LM and test it with runtime. 
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + # 7.1 Prepare dict + unit_file=$dict + mkdir -p data/local/dict + cp $unit_file data/local/dict/units.txt + tools/fst/prepare_dict.py $unit_file ${data}/resource_aishell/lexicon.txt \ + data/local/dict/lexicon.txt + # 7.2 Train lm + lm=data/local/lm + mkdir -p $lm + tools/filter_scp.pl data/train/text \ + $data/data_aishell/transcript/aishell_transcript_v0.8.txt > $lm/text + local/aishell_train_lms.sh + # 7.3 Build decoding TLG + tools/fst/compile_lexicon_token_fst.sh \ + data/local/dict data/local/tmp data/local/lang + tools/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1; + # 7.4 Decoding with runtime + chunk_size=-1 + ./tools/decode.sh --nj 16 \ + --beam 15.0 --lattice_beam 7.5 --max_active 7000 \ + --blank_skip_thresh 0.98 --ctc_weight 0.5 --rescoring_weight 1.0 \ + --chunk_size $chunk_size \ + --fst_path data/lang_test/TLG.fst \ + data/test/wav.scp data/test/text $dir/final.zip \ + data/lang_test/words.txt $dir/lm_with_runtime + # Please see $dir/lm_with_runtime for wer +fi + + From f1ba8524250f005305c050646550d964a3e65d39 Mon Sep 17 00:00:00 2001 From: aydentang Date: Fri, 26 Aug 2022 23:43:05 +0800 Subject: [PATCH 16/19] readme --- examples/aishell/s0_ssl/README.md | 63 +++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 examples/aishell/s0_ssl/README.md diff --git a/examples/aishell/s0_ssl/README.md b/examples/aishell/s0_ssl/README.md new file mode 100644 index 000000000..c57c1c329 --- /dev/null +++ b/examples/aishell/s0_ssl/README.md @@ -0,0 +1,63 @@ +# w2v-conformer + +This is a example to use unsupervised pretrained w2v-conformer model to fintune Aishell task. + +We pretrain conformer encoders using wav2vec 2.0 pre-training method and we use fbank features as inputs. + +The w2v-conformer model uses ISML datasets to pretrain, this is a internal dataset contains 60k hours Chinese. + + +## pretraining : + +We use two model configurations to pretrain the conformer encoder architecture: + +Base model contains 12 conformer blocks, model dimension 512, FFN dimension 2048 and 8 attention heads. +samples are batched together to not exceed 30000 frames per GPU. we train a total of 32 V100 GPUs for 800k steeps. + +Middle model contains 24 conformer blocks with model dimension 2048, FFN dimension 512 and 8 attention heads. We add a reconstruction loss to slightly improve performance. To speed up training procedure, we change The time stride of convolutional subsampling blocks to 3, so the length of the input feature becomes one sixth. samples are batched together to not exceed 20000 frames per GPU. we train a total of 32 V100 GPUs for 600k steps. + +We are also trying to train the large model with 300m parameters, and this work is ongoing. + +pretrained model link: +| model | Architecture | Model | +|-------------------|----|----| +| Base | 90Mb | - +| Middle | 180M | - | + + + +## finetuning tips: + +* After pretraining, we can build encoder-decoder based ASR system.The conformer based encoder takes the pretrained model as initialization and the transformer based decoder will be trained from scratch. Just set --enc_init_mods like 'encoder.embed.,encoder.encoders.0.,encoder.encoders.1. ...' to load customized pretrained parameters. 
+ +* In aishell task, we carefully adjusted the learning rate to 0.0004 to get best performence we also find that if too many layers are set for decoder,the migration performance of the pre-training model will be degraded, so we only build a small transformer decoder for joint training. If the downstream task is more than 500 hours, you can increase the learning rate and the parameter amount of the decoder. + +* Please note that the final layer of the pretraining model do not provide a good initialization for fine-tuning and would benefit from being re-initialized before fine-tuning. + +# Base model performance + +## Conformer Result + +* config: conf/train_conformer_base_100h.yaml +* Training info: lr 0.0004, batch size 16, 4 gpus on V100, acc_grad 1, 80 epochs +* Decoding info: ctc_weight 0.5, average_num 35 + +| decoding mode | CER | +|---------------------------|-------| +| ctc prefix beam search | 3.9 | +| attention rescoring | 3.75 | + +# Middle model performance + +## Conformer Result + +* config: conf/train_conformer_large_100h.yaml +* Training info: lr 0.0004, batch size 16, 4 gpus on V100, acc_grad 1, 80 epochs +* Decoding info: ctc_weight 0.5, average_num 35 + +| decoding mode | CER | +|---------------------------|-------| +| ctc prefix beam search | | +| attention rescoring | | + + From b3e9244b2f6c7cd66370a5dadee635eb6d7582f0 Mon Sep 17 00:00:00 2001 From: Emiyasstarx Date: Fri, 26 Aug 2022 23:45:15 +0800 Subject: [PATCH 17/19] Update README.md --- examples/aishell/s0_ssl/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/aishell/s0_ssl/README.md b/examples/aishell/s0_ssl/README.md index c57c1c329..b83bd8852 100644 --- a/examples/aishell/s0_ssl/README.md +++ b/examples/aishell/s0_ssl/README.md @@ -30,7 +30,7 @@ pretrained model link: * After pretraining, we can build encoder-decoder based ASR system.The conformer based encoder takes the pretrained model as initialization and the transformer based decoder will be trained from scratch. Just set --enc_init_mods like 'encoder.embed.,encoder.encoders.0.,encoder.encoders.1. ...' to load customized pretrained parameters. -* In aishell task, we carefully adjusted the learning rate to 0.0004 to get best performence we also find that if too many layers are set for decoder,the migration performance of the pre-training model will be degraded, so we only build a small transformer decoder for joint training. If the downstream task is more than 500 hours, you can increase the learning rate and the parameter amount of the decoder. +* In aishell task, we carefully adjust the learning rate to 0.0004 to get best performence. we also find that if too many layers are set for decoder,the migration performance of the pre-training model will be degraded, so we only build a small transformer decoder for joint training. If the downstream task is more than 500 hours, you can increase the learning rate and the parameter amount of the decoder. * Please note that the final layer of the pretraining model do not provide a good initialization for fine-tuning and would benefit from being re-initialized before fine-tuning. 
From ccc425ed40ff4b770678654e306e6d190d379088 Mon Sep 17 00:00:00 2001 From: Emiyasstarx Date: Fri, 26 Aug 2022 23:52:26 +0800 Subject: [PATCH 18/19] Update README.md --- examples/aishell/s0_ssl/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/aishell/s0_ssl/README.md b/examples/aishell/s0_ssl/README.md index b83bd8852..d680b29a4 100644 --- a/examples/aishell/s0_ssl/README.md +++ b/examples/aishell/s0_ssl/README.md @@ -12,9 +12,9 @@ The w2v-conformer model uses ISML datasets to pretrain, this is a internal datas We use two model configurations to pretrain the conformer encoder architecture: Base model contains 12 conformer blocks, model dimension 512, FFN dimension 2048 and 8 attention heads. -samples are batched together to not exceed 30000 frames per GPU. we train a total of 32 V100 GPUs for 800k steeps. +samples are batched together to not exceed 30000 frames per GPU. we train a total of 32 A100 GPUs for 800k steeps. -Middle model contains 24 conformer blocks with model dimension 2048, FFN dimension 512 and 8 attention heads. We add a reconstruction loss to slightly improve performance. To speed up training procedure, we change The time stride of convolutional subsampling blocks to 3, so the length of the input feature becomes one sixth. samples are batched together to not exceed 20000 frames per GPU. we train a total of 32 V100 GPUs for 600k steps. +Middle model contains 24 conformer blocks with model dimension 2048, FFN dimension 512 and 8 attention heads. We add a reconstruction loss to slightly improve performance. To speed up training procedure, we change The time stride of convolutional subsampling blocks to 3, so the length of the input feature becomes one sixth. samples are batched together to not exceed 20000 frames per GPU. we train a total of 32 A100 GPUs for 600k steps. We are also trying to train the large model with 300m parameters, and this work is ongoing. @@ -39,7 +39,7 @@ pretrained model link: ## Conformer Result * config: conf/train_conformer_base_100h.yaml -* Training info: lr 0.0004, batch size 16, 4 gpus on V100, acc_grad 1, 80 epochs +* Training info: lr 0.0004, batch size 16, 4 gpus on A100, acc_grad 1, 80 epochs * Decoding info: ctc_weight 0.5, average_num 35 | decoding mode | CER | @@ -52,7 +52,7 @@ pretrained model link: ## Conformer Result * config: conf/train_conformer_large_100h.yaml -* Training info: lr 0.0004, batch size 16, 4 gpus on V100, acc_grad 1, 80 epochs +* Training info: lr 0.0004, batch size 16, 4 gpus on A100, acc_grad 1, 80 epochs * Decoding info: ctc_weight 0.5, average_num 35 | decoding mode | CER | From 1269a6e5bbec440302e934f243f623baeebf2758 Mon Sep 17 00:00:00 2001 From: aydentang Date: Thu, 15 Sep 2022 10:56:54 +0800 Subject: [PATCH 19/19] update readme --- examples/aishell/s0_ssl/README.md | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/examples/aishell/s0_ssl/README.md b/examples/aishell/s0_ssl/README.md index d680b29a4..6acb14af2 100644 --- a/examples/aishell/s0_ssl/README.md +++ b/examples/aishell/s0_ssl/README.md @@ -12,11 +12,11 @@ The w2v-conformer model uses ISML datasets to pretrain, this is a internal datas We use two model configurations to pretrain the conformer encoder architecture: Base model contains 12 conformer blocks, model dimension 512, FFN dimension 2048 and 8 attention heads. -samples are batched together to not exceed 30000 frames per GPU. we train a total of 32 A100 GPUs for 800k steeps. 
+samples are batched together to not exceed 30000 frames per GPU. we train a total of 32 V100 GPUs for 800k steeps. -Middle model contains 24 conformer blocks with model dimension 2048, FFN dimension 512 and 8 attention heads. We add a reconstruction loss to slightly improve performance. To speed up training procedure, we change The time stride of convolutional subsampling blocks to 3, so the length of the input feature becomes one sixth. samples are batched together to not exceed 20000 frames per GPU. we train a total of 32 A100 GPUs for 600k steps. +Middle model contains 24 conformer blocks with model dimension 2048, FFN dimension 512 and 8 attention heads. We add a reconstruction loss to slightly improve performance. To speed up training procedure, we change The time stride of convolutional subsampling blocks to 3, so the length of the input feature becomes one sixth. samples are batched together to not exceed 20000 frames per GPU. we train a total of 32 V100 GPUs for 600k steps. -We are also trying to train the large model with 300m parameters, and this work is ongoing. +We are also trying to train the causal model for u2 training and large model with 300m parameters, and this work is ongoing. pretrained model link: | model | Architecture | Model | @@ -30,7 +30,7 @@ pretrained model link: * After pretraining, we can build encoder-decoder based ASR system.The conformer based encoder takes the pretrained model as initialization and the transformer based decoder will be trained from scratch. Just set --enc_init_mods like 'encoder.embed.,encoder.encoders.0.,encoder.encoders.1. ...' to load customized pretrained parameters. -* In aishell task, we carefully adjust the learning rate to 0.0004 to get best performence. we also find that if too many layers are set for decoder,the migration performance of the pre-training model will be degraded, so we only build a small transformer decoder for joint training. If the downstream task is more than 500 hours, you can increase the learning rate and the parameter amount of the decoder. +* In aishell task, we carefully adjust the learning rate to 0.0004~0.0005 to get best performence we also find that if too many layers are set for decoder,the migration performance of the pre-training model will be degraded, so we only build a small transformer decoder for joint training. If the downstream task is more than 500 hours, you can increase the learning rate and the parameter amount of the decoder. * Please note that the final layer of the pretraining model do not provide a good initialization for fine-tuning and would benefit from being re-initialized before fine-tuning. 
@@ -39,25 +39,27 @@ pretrained model link: ## Conformer Result * config: conf/train_conformer_base_100h.yaml -* Training info: lr 0.0004, batch size 16, 4 gpus on A100, acc_grad 1, 80 epochs -* Decoding info: ctc_weight 0.5, average_num 35 +* Training info: lr 0.0004, batch size 16, 4 gpus on A100, acc_grad 1, 250 epochs +* Decoding info: ctc_weight 0.5, average_num 35 | decoding mode | CER | |---------------------------|-------| -| ctc prefix beam search | 3.9 | -| attention rescoring | 3.75 | +| ctc greedy search | 3.86 | +| ctc prefix beam search | 3.86 | +| attention rescoring | 3.79 | # Middle model performance ## Conformer Result * config: conf/train_conformer_large_100h.yaml -* Training info: lr 0.0004, batch size 16, 4 gpus on A100, acc_grad 1, 80 epochs +* Training info: lr 0.0005, batch size 16, 4 gpus on A100, acc_grad 1, 250 epochs * Decoding info: ctc_weight 0.5, average_num 35 | decoding mode | CER | |---------------------------|-------| -| ctc prefix beam search | | -| attention rescoring | | +| ctc greedy search | 3.46 | +| ctc prefix beam search | 3.46 | +| attention rescoring | 3.37 |
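
For readers who want to see how the `--enc_init` / `--enc_init_mods` options used in `run_ssl.sh` translate into code, the sketch below shows one way to copy only prefix-matched encoder tensors from a pretrained checkpoint while leaving the last encoder block, the decoder and the CTC head at their random initialization (the re-initialization tip from the README). This is a minimal illustration, not the helper actually wired into `wenet/bin/train.py`; the function name, the checkpoint unwrapping and the example paths are assumptions.

```python
# Illustrative sketch only: shows the idea behind --enc_init / --enc_init_mods,
# not the real wenet loading helper.
import torch


def load_pretrained_encoder(model: torch.nn.Module,
                            ckpt_path: str,
                            init_prefixes: list) -> torch.nn.Module:
    """Copy only parameters whose names start with one of the given prefixes
    (e.g. 'encoder.embed.', 'encoder.encoders.0.') from a pretrained
    checkpoint into the fine-tuning model; everything else keeps its random
    initialization."""
    ckpt = torch.load(ckpt_path, map_location='cpu')
    # some checkpoints wrap the tensors in a 'state_dict' key (assumption)
    state_dict = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt

    # keep only the prefix-matched tensors
    selected = {
        name: param for name, param in state_dict.items()
        if any(name.startswith(p) for p in init_prefixes)
    }
    missing, unexpected = model.load_state_dict(selected, strict=False)
    print('loaded {} tensors, {} left at random init'.format(
        len(selected), len(missing)))
    return model


# usage sketch, prefixes mirroring --enc_init_mods in run_ssl.sh
# (encoder.embed plus blocks 0-10, so the last block stays re-initialized):
# prefixes = ['encoder.embed.'] + ['encoder.encoders.%d.' % i for i in range(11)]
# model = load_pretrained_encoder(model, 'pretrain/w2vc_base12_7wh.pt', prefixes)
```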