From 6eff507bcb402eac93372f8460feadb54e0b87cc Mon Sep 17 00:00:00 2001 From: kxhuang Date: Thu, 3 Aug 2023 23:31:18 +0800 Subject: [PATCH 01/10] Fix context graph English word segmentation problem and token consumption problem due to context mismatch. --- runtime/core/decoder/context_graph.cc | 34 ++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/runtime/core/decoder/context_graph.cc b/runtime/core/decoder/context_graph.cc index 5d6f3009b..a04c50cf4 100644 --- a/runtime/core/decoder/context_graph.cc +++ b/runtime/core/decoder/context_graph.cc @@ -68,6 +68,9 @@ void ContextGraph::BuildContextGraph( float score = (i * config_.incremental_context_score + config_.context_score) * UTF8StringLength(words[i]); + if (IsAlpha(words[i]) || words[i][0] == kSpaceSymbol[0]) { + score = i * config_.incremental_context_score + config_.context_score; + } next_state = (i < words.size() - 1) ? ofst->AddState() : start_state; ofst->AddArc(prev_state, fst::StdArc(word_id, word_id, score, next_state)); @@ -106,6 +109,24 @@ int ContextGraph::GetNextState(int cur_state, int word_id, float* score, break; } } + if (next_state != 0) { + return next_state; + } + for (fst::ArcIterator aiter(*graph_, 0); !aiter.Done(); + aiter.Next()) { + const fst::StdArc& arc = aiter.Value(); + if (arc.ilabel == word_id) { + next_state = arc.nextstate; + *score += arc.weight.Value(); + if (cur_state == 0) { + *is_start_boundary = true; + } + if (graph_->Final(arc.nextstate) == fst::StdArc::Weight::One()) { + *is_end_boundary = true; + } + break; + } + } return next_state; } @@ -117,6 +138,7 @@ bool ContextGraph::SplitUTF8StringToWords( SplitUTF8StringToChars(Trim(str), &chars); bool no_oov = true; + bool beginning = true; for (size_t start = 0; start < chars.size();) { for (size_t end = chars.size(); end > start; --end) { std::string word; @@ -126,18 +148,28 @@ bool ContextGraph::SplitUTF8StringToWords( // Skip space. if (word == " ") { start = end; + beginning = true; continue; } // Add '▁' at the beginning of English word. 
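// (Descriptive note on the change below: the new `beginning` flag is true
// only at the start of the phrase or right after a space, so '▁' is
// prepended to word-initial pieces only; word-internal alphabetic pieces
// are matched in their raw subword form.)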
- if (IsAlpha(word)) { + if (IsAlpha(word) && beginning == true) { word = kSpaceSymbol + word; } if (symbol_table->Find(word) != -1) { words->emplace_back(word); start = end; + beginning = false; continue; } + + // Matching using '▁' separately for English + if (end == start + 1 && word[0] == kSpaceSymbol[0]) { + words->emplace_back(string(kSpaceSymbol)); + beginning = false; + break; + } + if (end == start + 1) { ++start; no_oov = false; From 14f0865f2d91f415810bf66302deced95574faa9 Mon Sep 17 00:00:00 2001 From: kxhuang Date: Fri, 4 Aug 2023 00:11:05 +0800 Subject: [PATCH 02/10] fix the code format --- runtime/core/decoder/context_graph.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/core/decoder/context_graph.cc b/runtime/core/decoder/context_graph.cc index a04c50cf4..e2d5e6bdb 100644 --- a/runtime/core/decoder/context_graph.cc +++ b/runtime/core/decoder/context_graph.cc @@ -169,7 +169,7 @@ bool ContextGraph::SplitUTF8StringToWords( beginning = false; break; } - + if (end == start + 1) { ++start; no_oov = false; From 9d15ce683c35a881751c7ed3e9232892fe597e85 Mon Sep 17 00:00:00 2001 From: kxhuang Date: Fri, 4 Aug 2023 09:54:38 +0800 Subject: [PATCH 03/10] Modify logic --- runtime/core/decoder/context_graph.cc | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/runtime/core/decoder/context_graph.cc b/runtime/core/decoder/context_graph.cc index e2d5e6bdb..a44aab628 100644 --- a/runtime/core/decoder/context_graph.cc +++ b/runtime/core/decoder/context_graph.cc @@ -152,7 +152,7 @@ bool ContextGraph::SplitUTF8StringToWords( continue; } // Add '▁' at the beginning of English word. - if (IsAlpha(word) && beginning == true) { + if (IsAlpha(word) && beginning) { word = kSpaceSymbol + word; } @@ -163,14 +163,13 @@ bool ContextGraph::SplitUTF8StringToWords( continue; } - // Matching using '▁' separately for English - if (end == start + 1 && word[0] == kSpaceSymbol[0]) { - words->emplace_back(string(kSpaceSymbol)); - beginning = false; - break; - } - if (end == start + 1) { + // Matching using '▁' separately for English + if (word[0] == kSpaceSymbol[0]) { + words->emplace_back(string(kSpaceSymbol)); + beginning = false; + break; + } ++start; no_oov = false; LOG(WARNING) << word << " is oov."; From 6874fe5f88b8115f2496ca9379e4613262617515 Mon Sep 17 00:00:00 2001 From: kxhuang Date: Mon, 7 Aug 2023 20:56:08 +0800 Subject: [PATCH 04/10] Add nn bias base code --- wenet/bin/train.py | 16 +++- wenet/dataset/dataset.py | 6 ++ wenet/dataset/processor.py | 128 +++++++++++++++++++++++++++- wenet/transformer/asr_model.py | 30 ++++++- wenet/transformer/context_module.py | 116 +++++++++++++++++++++++++ wenet/utils/executor.py | 39 +++++++-- wenet/utils/init_model.py | 9 ++ 7 files changed, 331 insertions(+), 13 deletions(-) create mode 100644 wenet/transformer/context_module.py diff --git a/wenet/bin/train.py b/wenet/bin/train.py index da9a6f6bb..33f80176f 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -272,10 +272,24 @@ def main(): # Init asr model from configs model = init_model(configs) - print(model) if local_rank == 0 else None + # print(model) if local_rank == 0 else None num_params = sum(p.numel() for p in model.parameters()) print('the number of model params: {:,d}'.format(num_params)) if local_rank == 0 else None # noqa + if 'context_conf' in configs: + for p in model.ctc.parameters(): + p.requires_grad = False + if model.decoder is not None: + for p in model.decoder.parameters(): + p.requires_grad = False + for p in 
model.encoder.embed.parameters(): + p.requires_grad = False + for p in model.encoder.after_norm.parameters(): + p.requires_grad = False + for layer in model.encoder.encoders: + for p in layer.parameters(): + p.requires_grad = False + # !!!IMPORTANT!!! # Try to export the model by script, if fails, we should refine # the code to satisfy the script export requirements diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py index 6d799b5b5..8a338b411 100644 --- a/wenet/dataset/dataset.py +++ b/wenet/dataset/dataset.py @@ -189,5 +189,11 @@ def Dataset(data_type, batch_conf = conf.get('batch_conf', {}) dataset = Processor(dataset, processor.batch, **batch_conf) + + # context_conf = conf.get('context_conf', {}) + # if len(context_conf) != 0: + # dataset = Processor(dataset, processor.context_sampling, + # symbol_table, **context_conf) + dataset = Processor(dataset, processor.padding) return dataset diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index b69ceca85..1afc4f937 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -610,6 +610,107 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000): logging.fatal('Unsupported batch type {}'.format(batch_type)) +def context_sampling(data, + symbol_table, + len_min, + len_max, + batch_num_context, + ): + """context_sampling + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[List[{key, feat, label, context_list}]] + """ + rev_symbol_table = {} + for token in symbol_table: + rev_symbol_table[symbol_table[token]] = token + context_list_over_all = [] + print("@@@@@@@ start sampling") + for sample in data: + print("11111111", sample) + batch_label = [sample[i]['label'] for i in range(len(sample))] + print(batch_label) + context_list = [] + for utt_label in batch_label: + st_index_list = [] + for i in range(len(utt_label)): + if '▁' not in rev_symbol_table: + st_index_list.append(i) + elif rev_symbol_table[i][0] == '▁': + st_index_list.append(i) + + st_select = [] + en_select = [] + num_context = 3 + for _ in range(0, num_context): + random_len = random.randint(min(len(st_index_list), len_min), + min(len(st_index_list), len_max)) + random_index = random.randint(0, len(st_index_list) - + random_len - 1) + st_index = st_index_list[random_index] + en_index = st_index_list[random_index + random_len] + context_label = utt_label[st_index: en_index] + cross_flag = True + for i in range(len(st_select)): + if st_index >= st_select[i] and st_index < en_select[i]: + cross_flag = False + elif en_index > st_select[i] and en_index <= en_select[i]: + cross_flag = False + elif st_index < st_select[i] and en_index > en_select[i]: + cross_flag = False + if cross_flag: + context_list.append(context_label) + st_select.append(st_index) + en_select.append(en_index) + + if len(context_list) > batch_num_context: + context_list_over_all = context_list + elif len(context_list) + len(context_list_over_all) > batch_num_context: + context_list_over_all.extend(context_list) + context_list_over_all = context_list_over_all[-batch_num_context:] + else: + context_list_over_all.extend(context_list) + context_list = context_list_over_all + context_list.insert(0, torch.tensor([0])) + for i in range(len(sample)): + print(sample[i]['label']) + print(context_list) + print("_______________") + sample[i]['context_list'] = context_list + yield sample + + +def context_label_generate(label=[], context_list=[]): + """ generate context label + + Args: + + Returns + """ + context_labels = [] + 
context_label_length = [] + for x in label: + cur_len = len(x) + context_label = [] + count = 0 + for i in range(cur_len): + for j in range(1, len(context_list)): + if i + len(context_list[j]) > cur_len: + continue + if x[i:i + len(context_list[j])].equal(context_list[j]): + count = max(count, len(context_list[j])) + if count > 0: + context_label.append(x[i]) + count -= 1 + context_length.append(len(context_label)) + context_label = torch.tensor(context_label, dtype=torch.int64) + context_labels.append(context_label) + return context_labels + + def padding(data): """ Padding the data into training data @@ -641,5 +742,28 @@ def padding(data): batch_first=True, padding_value=-1) - yield (sorted_keys, padded_feats, padding_labels, feats_lengths, - label_lengths) + if 'context_list' not in sample[0]: + yield (sorted_keys, padded_feats, padding_labels, feats_lengths, + label_lengths, torch.tensor([0]), torch.tensor([0]), + torch.tensor([0]), torch.tensor([0])) + + sorted_context_lists = [sample[i]['context_list'] for i in order] + context_list_lengths = torch.tensor([x.size(0) + for x in sorted_context_lists], dtype=torch.int32) + padding_context_lists = pad_sequence(sorted_context_lists, + batch_first=True, + padding_value=-1) + + + sorted_context_labels = context_label_generate(sorted_labels, + sorted_context_lists) + context_label_lengths = torch.tensor([x.size(0) + for x in sorted_context_labels], dtype=torch.int32) + padding_context_labels = pad_sequence(sorted_context_labels, + batch_first=True, + padding_value=-1) + + yield (sorted_keys, padded_feats, padding_labels, + feats_lengths, label_lengths, padding_context_lists, + padding_context_labels, context_list_lengths, + context_label_lengths) \ No newline at end of file diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index f593d067b..7d44ae29c 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -35,6 +35,7 @@ from wenet.transformer.decoder import TransformerDecoder from wenet.transformer.encoder import TransformerEncoder from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss +from wenet.transformer.context_module import ContextModule from wenet.utils.common import (IGNORE_ID, add_sos_eos, log_add, remove_duplicates_and_blank, th_accuracy, reverse_pad_list) @@ -51,12 +52,14 @@ def __init__( encoder: TransformerEncoder, decoder: TransformerDecoder, ctc: CTC, + context_module: ContextModule = None, ctc_weight: float = 0.5, ignore_id: int = IGNORE_ID, reverse_weight: float = 0.0, lsm_weight: float = 0.0, length_normalized_loss: bool = False, lfmmi_dir: str = '', + bias_weight: float = 0.03, ): assert 0.0 <= ctc_weight <= 1.0, ctc_weight @@ -68,10 +71,12 @@ def __init__( self.ignore_id = ignore_id self.ctc_weight = ctc_weight self.reverse_weight = reverse_weight + self.bias_weight = bias_weight self.encoder = encoder self.decoder = decoder self.ctc = ctc + self.context_module = context_module self.criterion_att = LabelSmoothingLoss( size=vocab_size, padding_idx=ignore_id, @@ -88,6 +93,10 @@ def forward( speech_lengths: torch.Tensor, text: torch.Tensor, text_lengths: torch.Tensor, + context_list: torch.Tensor = torch.tensor([0]), + context_list_lengths: torch.Tensor = torch.tensor([0]), + context_label: torch.Tensor = torch.tensor([0]), + context_label_lengths: torch.Tensor = torch.tensor([0]), ) -> Dict[str, Optional[torch.Tensor]]: """Frontend + Encoder + Decoder + Calc loss @@ -107,6 +116,18 @@ def forward( encoder_out, encoder_mask = self.encoder(speech, 
speech_lengths)
         encoder_out_lens = encoder_mask.squeeze(1).sum(1)
 
+        # 1a. Context biasing branch
+        if self.context_module is not None:
+            context_emb = self.context_module. \
+                forward_context_emb(context_list, context_list_lengths)
+            encoder_out, bias_out = self.context_module(context_emb,
+                                                        encoder_out)
+            loss_bias = self.context_module.bias_loss(bias_out, context_label,
+                                                      encoder_out_lens,
+                                                      context_label_lengths)
+        else:
+            loss_bias = None
+
         # 2a. Attention-decoder branch
         if self.ctc_weight != 1.0:
             loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
@@ -125,14 +146,17 @@
         else:
             loss_ctc = None
 
+        # TODO: find out why jit script export fails here
         if loss_ctc is None:
             loss = loss_att
         elif loss_att is None:
             loss = loss_ctc
         else:
-            loss = self.ctc_weight * loss_ctc + (1 -
-                                                 self.ctc_weight) * loss_att
-        return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc}
+            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * \
+                loss_att #+ self.bias_weight * loss_bias
+
+        return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc,
+                "loss_bias": loss_bias}
 
     def _calc_att_loss(
         self,
diff --git a/wenet/transformer/context_module.py b/wenet/transformer/context_module.py
new file mode 100644
index 000000000..2afd2aee8
--- /dev/null
+++ b/wenet/transformer/context_module.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2023 ASLP@NWPU (authors: Kaixun Huang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
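For orientation, the module defined in this new file is the one wired into
ASRModel.forward earlier in this patch. A minimal usage sketch follows; the
sizes, phrase ids, and batch shapes are illustrative assumptions rather than
values taken from the patch:

    import torch
    from wenet.transformer.context_module import ContextModule

    vocab_size, embedding_size = 5002, 256   # assumed sizes
    module = ContextModule(vocab_size, embedding_size)

    # Three context phrases padded with -1; entry [0] is the "no-bias"
    # placeholder that processor.context_sampling prepends to the list.
    context_list = torch.tensor([[0, -1, -1],
                                 [11, 12, 13],
                                 [21, 22, -1]], dtype=torch.int64)
    context_lengths = torch.tensor([1, 3, 2], dtype=torch.int32)
    encoder_out = torch.randn(4, 100, embedding_size)  # (B, T, D)

    context_emb = module.forward_context_emb(context_list, context_lengths)
    encoder_bias_out, bias_out = module(context_emb, encoder_out)
    # encoder_bias_out: (B, T, D), biased encoder output for CTC/decoder
    # bias_out: (B, T, vocab_size), logits for the auxiliary bias CTC loss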
+ + +import torch +import torch.nn as nn +from typing import Tuple +from wenet.transformer.attention import MultiHeadedAttention + + +class BLSTM(torch.nn.Module): + """ + """ + + def __init__(self, + vocab_size, + embedding_size, + num_layers, + dropout=0.0): + super(BLSTM, self).__init__() + self.vocab_size = vocab_size + self.embedding_size = embedding_size + self.word_embedding = torch.nn.Embedding( + self.vocab_size, self.embedding_size) + + self.sen_rnn = torch.nn.LSTM(input_size=self.embedding_size, + hidden_size=self.embedding_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + bidirectional=True) + + def forward(self, sen_batch, sen_lengths): + sen_batch = torch.clamp(sen_batch, 0) + sen_batch = self.word_embedding(sen_batch) + pack_seq = torch.nn.utils.rnn.pack_padded_sequence( + sen_batch, sen_lengths.to('cpu').type(torch.int32), + batch_first=True, enforce_sorted=False) + _, last_state = self.sen_rnn(pack_seq) + laste_h = last_state[0] + laste_c = last_state[1] + state = torch.cat([laste_h[-1, :, :], laste_h[0, :, :], + laste_c[-1, :, :], laste_c[0, :, :]], dim=-1) + return state + + +class ContextModule(torch.nn.Module): + """ + """ + def __init__( + self, + vocab_size: int, + embedding_size: int, + encoder_layers: int = 2, + attention_heads: int = 4, + dropout_rate: float = 0.0, + ): + super().__init__() + self.embedding_size = embedding_size + self.encoder_layers = encoder_layers + self.vocab_size = vocab_size + self.attention_heads = attention_heads + self.dropout_rate = dropout_rate + + self.context_extractor = BLSTM(self.vocab_size, self.embedding_size, + self.encoder_layers) + self.context_encoder = nn.Sequential( + nn.Linear(self.embedding_size * 4, self.embedding_size), + nn.LayerNorm(self.embedding_size) + ) + + self.biasing_layer = MultiHeadedAttention( + n_head=self.attention_heads, + n_feat=self.embedding_size, + dropout_rate=self.dropout_rate + ) + + self.combiner = nn.Sequential( + nn.Linear(self.embedding_size, self.embedding_size), + nn.LayerNorm(self.embedding_size) + ) + self.norm_aft_combiner = nn.LayerNorm(self.embedding_size) + + self.context_decoder = nn.Sequential( + nn.Linear(self.embedding_size, self.embedding_size), + nn.LayerNorm(self.embedding_size), + nn.ReLU(inplace=True), + nn.Linear(self.embedding_size, self.vocab_size), + ) + + self.bias_loss = torch.nn.CTCLoss(reduction="sum") + + def forward_context_emb(self, context_list, context_lengths) -> torch.Tensor: + context_emb = self.context_extractor(context_list, context_lengths) + context_emb = self.context_encoder(context_emb.unsqueeze(0)) + return context_emb + + def forward(self, context_emb, encoder_out) -> Tuple[torch.Tensor, torch.Tensor]: + context_emb = context_emb.expand(encoder_out.shape[0],-1,-1) + context_emb, _ = self.biasing_layer(encoder_out, context_emb, + context_emb) + bias_out = self.context_decoder(context_emb) + encoder_bias_out = self.norm_aft_combiner(encoder_out + + self.combiner(context_emb)) + return encoder_bias_out, bias_out \ No newline at end of file diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py index a128f6d6c..419b0e278 100644 --- a/wenet/utils/executor.py +++ b/wenet/utils/executor.py @@ -58,14 +58,22 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, else: model_context = nullcontext num_seen_utts = 0 + print("____ mark 0") with model_context(): for batch_idx, batch in enumerate(data_loader): - key, feats, target, feats_lengths, target_lengths = batch + key, feats, target, feats_lengths, 
target_lengths, + context_list, context_label, context_list_lengths, + context_label_lengths = batch feats = feats.to(device) target = target.to(device) feats_lengths = feats_lengths.to(device) target_lengths = target_lengths.to(device) + context_list = context_list.to(device) + context_label = context_label.to(device) + context_list_lengths = context_list_lengths.to(device) + context_label_lengths = context_label_lengths.to(device) num_utts = target_lengths.size(0) + print("____ mark 1") if num_utts == 0: continue context = None @@ -85,7 +93,10 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, dtype=ds_dtype, cache_enabled=False ): loss_dict = model(feats, feats_lengths, target, - target_lengths) + target_lengths, context_list, + context_label, + context_list_lengths, + context_label_lengths) loss = loss_dict['loss'] # NOTE(xcsong): Zeroing the gradients is handled automatically by DeepSpeed after the weights # noqa # have been updated using a mini-batch. DeepSpeed also performs gradient averaging automatically # noqa @@ -99,7 +110,10 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, # https://pytorch.org/docs/stable/notes/amp_examples.html with torch.cuda.amp.autocast(scaler is not None): loss_dict = model(feats, feats_lengths, target, - target_lengths) + target_lengths, context_list, + context_label, + context_list_lengths, + context_label_lengths) loss = loss_dict['loss'] / accum_grad if use_amp: scaler.scale(loss).backward() @@ -170,11 +184,17 @@ def cv(self, model, data_loader, device, args): total_loss = 0.0 with torch.no_grad(): for batch_idx, batch in enumerate(data_loader): - key, feats, target, feats_lengths, target_lengths = batch + key, feats, target, feats_lengths, target_lengths, + context_list, context_label, context_list_lengths, + context_label_lengths = batch feats = feats.to(device) target = target.to(device) feats_lengths = feats_lengths.to(device) target_lengths = target_lengths.to(device) + context_list = context_list.to(device) + context_label = context_label.to(device) + context_list_lengths = context_list_lengths.to(device) + context_label_lengths = context_label_lengths.to(device) num_utts = target_lengths.size(0) if num_utts == 0: continue @@ -183,10 +203,15 @@ def cv(self, model, data_loader, device, args): enabled=ds_dtype is not None, dtype=ds_dtype, cache_enabled=False ): - loss_dict = model(feats, feats_lengths, - target, target_lengths) + loss_dict = model(feats, feats_lengths, target, + target_lengths, context_list, + context_label, context_list_lengths, + context_label_lengths) else: - loss_dict = model(feats, feats_lengths, target, target_lengths) + loss_dict = model(feats, feats_lengths, target, + target_lengths, context_list, + context_label, context_list_lengths, + context_label_lengths) loss = loss_dict['loss'] if torch.isfinite(loss): num_seen_utts += num_utts diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index f01e03469..acb496f58 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -22,6 +22,7 @@ from wenet.transformer.ctc import CTC from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder +from wenet.transformer.context_module import ContextModule from wenet.branchformer.encoder import BranchformerEncoder from wenet.squeezeformer.encoder import SqueezeformerEncoder from wenet.efficient_conformer.encoder import EfficientConformerEncoder @@ -79,6 +80,13 @@ 
def init_model(configs): **configs['decoder_conf']) ctc = CTC(vocab_size, encoder.output_size()) + context_module_type = configs.get('context_module', '') + if context_module_type == 'cppn': + context_module = ContextModule(vocab_size, + **configs['context_module_conf']) + else: + context_module = None + # Init joint CTC/Attention or Transducer model if 'predictor' in configs: predictor_type = configs.get('predictor', 'rnn') @@ -122,6 +130,7 @@ def init_model(configs): encoder=encoder, decoder=decoder, ctc=ctc, + context_module=context_module, lfmmi_dir=configs.get('lfmmi_dir', ''), **configs['model_conf']) return model From 0f2cd4c3497dd8fe431781b54deb2627729b4d1f Mon Sep 17 00:00:00 2001 From: kaixunhuang0 Date: Fri, 11 Aug 2023 07:07:43 -0400 Subject: [PATCH 05/10] nn bias dataset --- wenet/dataset/dataset.py | 8 ++++---- wenet/dataset/processor.py | 34 ++++++++++++++-------------------- wenet/transformer/asr_model.py | 5 +++-- wenet/utils/executor.py | 10 ++++------ 4 files changed, 25 insertions(+), 32 deletions(-) diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py index 8a338b411..5280a2578 100644 --- a/wenet/dataset/dataset.py +++ b/wenet/dataset/dataset.py @@ -190,10 +190,10 @@ def Dataset(data_type, batch_conf = conf.get('batch_conf', {}) dataset = Processor(dataset, processor.batch, **batch_conf) - # context_conf = conf.get('context_conf', {}) - # if len(context_conf) != 0: - # dataset = Processor(dataset, processor.context_sampling, - # symbol_table, **context_conf) + context_conf = conf.get('context_conf', {}) + if len(context_conf) != 0: + dataset = Processor(dataset, processor.context_sampling, + symbol_table, **context_conf) dataset = Processor(dataset, processor.padding) return dataset diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 1afc4f937..7b6ea93ae 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -628,26 +628,24 @@ def context_sampling(data, for token in symbol_table: rev_symbol_table[symbol_table[token]] = token context_list_over_all = [] - print("@@@@@@@ start sampling") for sample in data: - print("11111111", sample) batch_label = [sample[i]['label'] for i in range(len(sample))] - print(batch_label) context_list = [] for utt_label in batch_label: st_index_list = [] for i in range(len(utt_label)): - if '▁' not in rev_symbol_table: + if '▁' not in symbol_table: st_index_list.append(i) - elif rev_symbol_table[i][0] == '▁': + elif rev_symbol_table[utt_label[i]][0] == '▁': st_index_list.append(i) + st_index_list.append(len(utt_label)) st_select = [] en_select = [] num_context = 3 for _ in range(0, num_context): - random_len = random.randint(min(len(st_index_list), len_min), - min(len(st_index_list), len_max)) + random_len = random.randint(min(len(st_index_list) - 1, len_min), + min(len(st_index_list) - 1, len_max)) random_index = random.randint(0, len(st_index_list) - random_len - 1) st_index = st_index_list[random_index] @@ -673,13 +671,11 @@ def context_sampling(data, context_list_over_all = context_list_over_all[-batch_num_context:] else: context_list_over_all.extend(context_list) - context_list = context_list_over_all - context_list.insert(0, torch.tensor([0])) - for i in range(len(sample)): - print(sample[i]['label']) - print(context_list) - print("_______________") - sample[i]['context_list'] = context_list + context_list = context_list_over_all.copy() + context_list.insert(0, [0]) + for i in range(len(context_list)): + context_list[i] = torch.tensor(context_list[i], dtype=torch.int64) + 
sample[0]['context_list'] = context_list yield sample @@ -705,7 +701,6 @@ def context_label_generate(label=[], context_list=[]): if count > 0: context_label.append(x[i]) count -= 1 - context_length.append(len(context_label)) context_label = torch.tensor(context_label, dtype=torch.int64) context_labels.append(context_label) return context_labels @@ -747,22 +742,21 @@ def padding(data): label_lengths, torch.tensor([0]), torch.tensor([0]), torch.tensor([0]), torch.tensor([0])) - sorted_context_lists = [sample[i]['context_list'] for i in order] + context_lists = sample[0]['context_list'] context_list_lengths = torch.tensor([x.size(0) - for x in sorted_context_lists], dtype=torch.int32) - padding_context_lists = pad_sequence(sorted_context_lists, + for x in context_lists], dtype=torch.int32) + padding_context_lists = pad_sequence(context_lists, batch_first=True, padding_value=-1) sorted_context_labels = context_label_generate(sorted_labels, - sorted_context_lists) + context_lists) context_label_lengths = torch.tensor([x.size(0) for x in sorted_context_labels], dtype=torch.int32) padding_context_labels = pad_sequence(sorted_context_labels, batch_first=True, padding_value=-1) - yield (sorted_keys, padded_feats, padding_labels, feats_lengths, label_lengths, padding_context_lists, padding_context_labels, context_list_lengths, diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index 7d44ae29c..71bf439ca 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -94,8 +94,8 @@ def forward( text: torch.Tensor, text_lengths: torch.Tensor, context_list: torch.Tensor = torch.tensor([0]), - context_list_lengths: torch.Tensor = torch.tensor([0]), context_label: torch.Tensor = torch.tensor([0]), + context_list_lengths: torch.Tensor = torch.tensor([0]), context_label_lengths: torch.Tensor = torch.tensor([0]), ) -> Dict[str, Optional[torch.Tensor]]: """Frontend + Encoder + Decoder + Calc loss @@ -122,6 +122,7 @@ def forward( forward_context_emb(context_list, context_list_lengths) encoder_out, bias_out = self.context_module(context_emb, encoder_out) + bias_out = bias_out.transpose(0, 1).log_softmax(2) loss_bias = self.context_module.bias_loss(bias_out, context_label, encoder_out_lens, context_label_lengths) @@ -153,7 +154,7 @@ def forward( loss = loss_ctc else: loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * \ - loss_att #+ self.bias_weight * loss_bias + loss_att + self.bias_weight * loss_bias return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc, "loss_bias": loss_bias} diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py index 419b0e278..9d356da08 100644 --- a/wenet/utils/executor.py +++ b/wenet/utils/executor.py @@ -58,11 +58,10 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, else: model_context = nullcontext num_seen_utts = 0 - print("____ mark 0") with model_context(): for batch_idx, batch in enumerate(data_loader): - key, feats, target, feats_lengths, target_lengths, - context_list, context_label, context_list_lengths, + key, feats, target, feats_lengths, target_lengths, \ + context_list, context_label, context_list_lengths, \ context_label_lengths = batch feats = feats.to(device) target = target.to(device) @@ -73,7 +72,6 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, context_list_lengths = context_list_lengths.to(device) context_label_lengths = context_label_lengths.to(device) num_utts = target_lengths.size(0) - print("____ mark 1") if num_utts == 0: continue 
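                # Note: `context` below is the gradient-synchronization
                # context manager (model.no_sync under DDP accumulation,
                # else nullcontext); it is unrelated to the biasing context
                # tensors unpacked above.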
context = None @@ -184,8 +182,8 @@ def cv(self, model, data_loader, device, args): total_loss = 0.0 with torch.no_grad(): for batch_idx, batch in enumerate(data_loader): - key, feats, target, feats_lengths, target_lengths, - context_list, context_label, context_list_lengths, + key, feats, target, feats_lengths, target_lengths, \ + context_list, context_label, context_list_lengths, \ context_label_lengths = batch feats = feats.to(device) target = target.to(device) From 4b188b48405a5182ec498d9c54ad84968a0f0885 Mon Sep 17 00:00:00 2001 From: kaixunhuang0 Date: Wed, 30 Aug 2023 16:31:36 +0800 Subject: [PATCH 06/10] Deep biasing is supported on the AED model, not tested on the RNNT model. --- wenet/bin/recognize.py | 18 +++- wenet/bin/train.py | 31 +++---- wenet/dataset/processor.py | 45 +++++----- wenet/transducer/transducer.py | 30 ++++++- wenet/transformer/asr_model.py | 32 +++++-- wenet/transformer/context_module.py | 26 +++--- wenet/utils/checkpoint.py | 4 + wenet/utils/context_graph.py | 133 +++++++++++++++++++++++++++- wenet/utils/executor.py | 42 +++++---- wenet/utils/init_model.py | 1 + 10 files changed, 279 insertions(+), 83 deletions(-) diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index 991bba515..602834f67 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -172,7 +172,7 @@ def get_args(): help='Context list path') parser.add_argument('--context_graph_score', type=float, - default=0.0, + default=2.0, help='''The higher the score, the greater the degree of bias using decoding-graph for biasing''') @@ -225,6 +225,9 @@ def main(): test_conf['batch_conf']['batch_size'] = args.batch_size non_lang_syms = read_non_lang_symbols(args.non_lang_syms) + if 'context_conf' in test_conf: + del test_conf['context_conf'] + test_dataset = Dataset(args.data_type, args.test_data, symbol_table, @@ -256,17 +259,26 @@ def main(): paraformer_beam_search = None context_graph = None - if 'decoding-graph' in args.context_bias_mode: + if args.context_bias_mode != '': context_graph = ContextGraph(args.context_list_path, symbol_table, args.bpe_model, args.context_graph_score) + if 'deep-biasing' in args.context_bias_mode: + context_graph.deep_biasing = True + if 'decoding-graph' in args.context_bias_mode: + context_graph.graph_biasing = True + if 'deep-biasing' in args.context_bias_mode and \ + 'decoding-graph' in args.context_bias_mode: + context_graph.deep_biasing_weight = 1.0 + context_graph.graph_biasing_weight = 0.7 with torch.no_grad(), open(args.result_file, 'w') as fout: for batch_idx, batch in enumerate(test_data_loader): - keys, feats, target, feats_lengths, target_lengths = batch + keys, feats, target, feats_lengths, target_lengths, _, _, _, _ = batch feats = feats.to(device) target = target.to(device) feats_lengths = feats_lengths.to(device) target_lengths = target_lengths.to(device) + if args.mode == 'attention': hyps, _ = model.recognize( feats, diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 33f80176f..8756fee89 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -272,30 +272,25 @@ def main(): # Init asr model from configs model = init_model(configs) - # print(model) if local_rank == 0 else None + print(model) if local_rank == 0 else None num_params = sum(p.numel() for p in model.parameters()) print('the number of model params: {:,d}'.format(num_params)) if local_rank == 0 else None # noqa - if 'context_conf' in configs: - for p in model.ctc.parameters(): + # Freeze other parts of the model during training context bias module + if 
'context_module_conf' in configs: + for p in model.parameters(): p.requires_grad = False - if model.decoder is not None: - for p in model.decoder.parameters(): - p.requires_grad = False - for p in model.encoder.embed.parameters(): + for p in model.context_module.parameters(): + p.requires_grad = True + for p in model.context_module.context_decoder_ctc_linear.parameters(): p.requires_grad = False - for p in model.encoder.after_norm.parameters(): - p.requires_grad = False - for layer in model.encoder.encoders: - for p in layer.parameters(): - p.requires_grad = False - # !!!IMPORTANT!!! - # Try to export the model by script, if fails, we should refine - # the code to satisfy the script export requirements - if local_rank == 0: - script_model = torch.jit.script(model) - script_model.save(os.path.join(args.model_dir, 'init.zip')) + # # !!!IMPORTANT!!! + # # Try to export the model by script, if fails, we should refine + # # the code to satisfy the script export requirements + # if local_rank == 0: + # script_model = torch.jit.script(model) + # script_model.save(os.path.join(args.model_dir, 'init.zip')) executor = Executor() # If specify checkpoint, load some info from checkpoint if args.checkpoint is not None: diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 7b6ea93ae..73b8295f2 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -674,12 +674,12 @@ def context_sampling(data, context_list = context_list_over_all.copy() context_list.insert(0, [0]) for i in range(len(context_list)): - context_list[i] = torch.tensor(context_list[i], dtype=torch.int64) + context_list[i] = torch.tensor(context_list[i], dtype=torch.int32) sample[0]['context_list'] = context_list yield sample -def context_label_generate(label=[], context_list=[]): +def context_label_generate(label, context_list): """ generate context label Args: @@ -687,7 +687,6 @@ def context_label_generate(label=[], context_list=[]): Returns """ context_labels = [] - context_label_length = [] for x in label: cur_len = len(x) context_label = [] @@ -741,23 +740,23 @@ def padding(data): yield (sorted_keys, padded_feats, padding_labels, feats_lengths, label_lengths, torch.tensor([0]), torch.tensor([0]), torch.tensor([0]), torch.tensor([0])) - - context_lists = sample[0]['context_list'] - context_list_lengths = torch.tensor([x.size(0) - for x in context_lists], dtype=torch.int32) - padding_context_lists = pad_sequence(context_lists, - batch_first=True, - padding_value=-1) - - - sorted_context_labels = context_label_generate(sorted_labels, - context_lists) - context_label_lengths = torch.tensor([x.size(0) - for x in sorted_context_labels], dtype=torch.int32) - padding_context_labels = pad_sequence(sorted_context_labels, - batch_first=True, - padding_value=-1) - yield (sorted_keys, padded_feats, padding_labels, - feats_lengths, label_lengths, padding_context_lists, - padding_context_labels, context_list_lengths, - context_label_lengths) \ No newline at end of file + else: + context_lists = sample[0]['context_list'] + context_list_lengths = \ + torch.tensor([x.size(0) for x in context_lists], dtype=torch.int32) + padding_context_lists = pad_sequence(context_lists, + batch_first=True, + padding_value=-1) + + sorted_context_labels = context_label_generate(sorted_labels, + context_lists) + context_label_lengths = \ + torch.tensor([x.size(0) for x in sorted_context_labels], + dtype=torch.int32) + padding_context_labels = pad_sequence(sorted_context_labels, + batch_first=True, + padding_value=-1) + yield 
(sorted_keys, padded_feats, padding_labels, + feats_lengths, label_lengths, padding_context_lists, + padding_context_labels, context_list_lengths, + context_label_lengths) diff --git a/wenet/transducer/transducer.py b/wenet/transducer/transducer.py index c6b3b1f71..331677711 100644 --- a/wenet/transducer/transducer.py +++ b/wenet/transducer/transducer.py @@ -18,6 +18,7 @@ from wenet.transformer.ctc import CTC from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss +from wenet.transformer.context_module import ContextModule from wenet.utils.common import (IGNORE_ID, add_blank, add_sos_eos, reverse_pad_list) @@ -35,6 +36,8 @@ def __init__( BiTransformerDecoder]] = None, ctc: Optional[CTC] = None, ctc_weight: float = 0, + context_module: ContextModule = None, + bias_weight: float = 0, ignore_id: int = IGNORE_ID, reverse_weight: float = 0.0, lsm_weight: float = 0.0, @@ -50,8 +53,8 @@ def __init__( ) -> None: assert attention_weight + ctc_weight + transducer_weight == 1.0 super().__init__(vocab_size, encoder, attention_decoder, ctc, - ctc_weight, ignore_id, reverse_weight, lsm_weight, - length_normalized_loss) + context_module, ctc_weight, ignore_id, reverse_weight, + lsm_weight, length_normalized_loss, '', bias_weight) self.blank = blank self.transducer_weight = transducer_weight @@ -93,6 +96,10 @@ def forward( text: torch.Tensor, text_lengths: torch.Tensor, steps: int = 0, + context_list: torch.Tensor = torch.tensor([0]), + context_label: torch.Tensor = torch.tensor([0]), + context_list_lengths: torch.Tensor = torch.tensor([0]), + context_label_lengths: torch.Tensor = torch.tensor([0]), ) -> Dict[str, Optional[torch.Tensor]]: """Frontend + Encoder + predictor + joint + loss @@ -112,6 +119,22 @@ def forward( encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_out_lens = encoder_mask.squeeze(1).sum(1) + # Context biasing branch + loss_bias: Optional[torch.Tensor] = None + if self.context_module is not None: + context_emb = self.context_module. \ + forward_context_emb(context_list, context_list_lengths) + encoder_out, bias_out = self.context_module(context_emb, + encoder_out) + bias_out = bias_out.transpose(0, 1).log_softmax(2) + loss_bias = self.context_module.bias_loss(bias_out, + context_label, + encoder_out_lens, + context_label_lengths + ) / bias_out.size(1) + else: + loss_bias = None + # compute_loss loss_rnnt = compute_loss( self, @@ -143,12 +166,15 @@ def forward( loss = loss + self.ctc_weight * loss_ctc.sum() if loss_att is not None: loss = loss + self.attention_decoder_weight * loss_att.sum() + if loss_bias is not None: + loss = loss + self.bias_weight * loss_bias.sum() # NOTE: 'loss' must be in dict return { 'loss': loss, 'loss_att': loss_att, 'loss_ctc': loss_ctc, 'loss_rnnt': loss_rnnt, + 'loss_bias': loss_bias, } def init_bs(self): diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index 71bf439ca..7d986b71e 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -117,15 +117,17 @@ def forward( encoder_out_lens = encoder_mask.squeeze(1).sum(1) # 1a. Context biasing branch + loss_bias: Optional[torch.Tensor] = None if self.context_module is not None: context_emb = self.context_module. 
\
                forward_context_emb(context_list, context_list_lengths)
             encoder_out, bias_out = self.context_module(context_emb,
-                                                         encoder_out)
+                                                        encoder_out)
             bias_out = bias_out.transpose(0, 1).log_softmax(2)
             loss_bias = self.context_module.bias_loss(bias_out, context_label,
                                                       encoder_out_lens,
                                                       context_label_lengths)
+            loss_bias /= bias_out.size(1)
         else:
             loss_bias = None
@@ -147,14 +149,15 @@
         else:
             loss_ctc = None
 
-        # TODO: find out why jit script export fails here
         if loss_ctc is None:
             loss = loss_att
         elif loss_att is None:
             loss = loss_ctc
         else:
             loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * \
-                loss_att + self.bias_weight * loss_bias
+                loss_att
+        if loss_bias is not None:
+            loss = loss + self.bias_weight * loss_bias
 
         return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc,
                 "loss_bias": loss_bias}
@@ -386,6 +389,7 @@ def _ctc_prefix_beam_search(
             num_decoding_left_chunks: int = -1,
             simulate_streaming: bool = False,
             context_graph: ContextGraph = None,
+            context_filtering: bool = True,
     ) -> Tuple[List[List[int]], torch.Tensor]:
         """ CTC prefix beam search inner implementation
@@ -417,6 +421,24 @@
             speech, speech_lengths, decoding_chunk_size,
             num_decoding_left_chunks,
             simulate_streaming)  # (B, maxlen, encoder_dim)
+
+        if context_graph is not None and context_graph.deep_biasing:
+            if context_filtering:
+                ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0)
+                filtered_context_list = \
+                    context_graph.tow_stage_filtering(context_graph.context_list,
+                                                      ctc_probs, -6)
+                context_list, context_list_lengths = \
+                    context_graph.get_context_list_tensor(filtered_context_list)
+            else:
+                context_list, context_list_lengths = \
+                    context_graph.get_context_list_tensor(context_graph.context_list)
+            context_emb = self.context_module. \
+                forward_context_emb(context_list, context_list_lengths)
+            encoder_out, _ = \
+                self.context_module(context_emb, encoder_out,
+                                    context_graph.deep_biasing_weight, True)
+
         maxlen = encoder_out.size(1)
         ctc_probs = self.ctc.log_softmax(
             encoder_out)  # (1, maxlen, vocab_size)
@@ -451,7 +473,7 @@
                     n_prefix = prefix + (s, )
                     n_pb, n_pnb, _, _ = next_hyps[n_prefix]
                     new_c_state, new_c_score = 0, 0
-                    if context_graph is not None:
+                    if context_graph is not None and context_graph.graph_biasing:
                         new_c_state, new_c_score = context_graph. \
                             find_next_state(c_state, s)
                     n_pnb = log_add([n_pnb, pb + ps])
@@ -461,7 +483,7 @@
                     n_prefix = prefix + (s, )
                     n_pb, n_pnb, _, _ = next_hyps[n_prefix]
                     new_c_state, new_c_score = 0, 0
-                    if context_graph is not None:
+                    if context_graph is not None and context_graph.graph_biasing:
                        new_c_state, new_c_score = context_graph. 
\ find_next_state(c_state, s) n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) diff --git a/wenet/transformer/context_module.py b/wenet/transformer/context_module.py index 2afd2aee8..153ea7d6d 100644 --- a/wenet/transformer/context_module.py +++ b/wenet/transformer/context_module.py @@ -86,31 +86,35 @@ def __init__( dropout_rate=self.dropout_rate ) - self.combiner = nn.Sequential( - nn.Linear(self.embedding_size, self.embedding_size), - nn.LayerNorm(self.embedding_size) - ) + self.combiner = nn.Linear(self.embedding_size, self.embedding_size) self.norm_aft_combiner = nn.LayerNorm(self.embedding_size) self.context_decoder = nn.Sequential( nn.Linear(self.embedding_size, self.embedding_size), nn.LayerNorm(self.embedding_size), nn.ReLU(inplace=True), - nn.Linear(self.embedding_size, self.vocab_size), ) + self.context_decoder_ctc_linear = nn.Linear(self.embedding_size, + self.vocab_size) - self.bias_loss = torch.nn.CTCLoss(reduction="sum") + self.bias_loss = torch.nn.CTCLoss(reduction="sum", zero_infinity=True) def forward_context_emb(self, context_list, context_lengths) -> torch.Tensor: context_emb = self.context_extractor(context_list, context_lengths) context_emb = self.context_encoder(context_emb.unsqueeze(0)) return context_emb - def forward(self, context_emb, encoder_out) -> Tuple[torch.Tensor, torch.Tensor]: - context_emb = context_emb.expand(encoder_out.shape[0],-1,-1) + def forward(self, context_emb, encoder_out, + biasing_weight=1.0, recognize=False) \ + -> Tuple[torch.Tensor, torch.Tensor]: + context_emb = context_emb.expand(encoder_out.shape[0], -1, -1) context_emb, _ = self.biasing_layer(encoder_out, context_emb, context_emb) + encoder_bias_out = \ + self.norm_aft_combiner(encoder_out + + self.combiner(context_emb) * biasing_weight) + if recognize: + return encoder_bias_out, torch.tensor(0.0) bias_out = self.context_decoder(context_emb) - encoder_bias_out = self.norm_aft_combiner(encoder_out + - self.combiner(context_emb)) - return encoder_bias_out, bias_out \ No newline at end of file + bias_out = self.context_decoder_ctc_linear(bias_out) + return encoder_bias_out, bias_out diff --git a/wenet/utils/checkpoint.py b/wenet/utils/checkpoint.py index 8e0c413c7..aad9b6fa1 100644 --- a/wenet/utils/checkpoint.py +++ b/wenet/utils/checkpoint.py @@ -31,6 +31,10 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict: logging.info('Checkpoint: loading from checkpoint %s for CPU' % path) checkpoint = torch.load(path, map_location='cpu') model.load_state_dict(checkpoint, strict=False) + if hasattr(model, 'context_module') and \ + hasattr(model.context_module, 'context_decoder_ctc_linear'): + model.context_module.context_decoder_ctc_linear \ + .load_state_dict(model.ctc.ctc_lo.state_dict()) info_path = re.sub('.pt$', '.yaml', path) configs = {} if os.path.exists(info_path): diff --git a/wenet/utils/context_graph.py b/wenet/utils/context_graph.py index bb40fa1d8..9a7f426b8 100644 --- a/wenet/utils/context_graph.py +++ b/wenet/utils/context_graph.py @@ -1,4 +1,8 @@ +import torch +from torch.nn.utils.rnn import pad_sequence + from wenet.dataset.processor import __tokenize_by_bpe_model + from typing import Dict, List @@ -37,6 +41,8 @@ def tokenize(context_list_path, symbol_table, bpe_model=None): context_list.append(labels) return context_list +def tbbm(sp, context_txt): + return __tokenize_by_bpe_model(sp, context_txt) class ContextGraph: """ Context decoding graph, constructing graph using dict instead of WFST @@ -58,6 +64,10 @@ def __init__(self, self.state2token = {} self.back_score = {0: 
0.0} self.build_graph(self.context_list) + self.graph_biasing = False + self.deep_biasing = False + self.graph_biasing_weight = 1.0 + self.deep_biasing_weight = 1.5 def build_graph(self, context_list: List[List[int]]): """ Constructing the context decoding graph, add arcs with negative @@ -95,10 +105,127 @@ def find_next_state(self, now_state: int, token: int): from the starting state to avoid token consumption due to mismatches. """ if token in self.graph[now_state]: - return self.graph[now_state][token], self.context_score + return self.graph[now_state][token], \ + self.context_score * self.graph_biasing_weight back_score = self.back_score[now_state] now_state = 0 if token in self.graph[now_state]: return self.graph[now_state][ - token], back_score + self.context_score - return 0, back_score + token], (back_score + self.context_score) * self.graph_biasing_weight + return 0, back_score * self.graph_biasing_weight + + def get_context_list_tensor(self, context_list: List[List[int]]): + context_list_tensor = [torch.tensor([0], dtype=torch.int32)] + for context_token in context_list: + context_list_tensor.append(torch.tensor(context_token, dtype=torch.int32)) + context_list_lengths = torch.tensor([x.size(0) for x in context_list_tensor], + dtype=torch.int32) + context_list_tensor = pad_sequence(context_list_tensor, + batch_first=True, + padding_value=-1) + return context_list_tensor, context_list_lengths + + def tow_stage_filtering(self, + context_list: List[List[int]], + ctc_posterior: torch.Tensor, + filter_threshold: float = -8, + filter_window_size: int = 64): + if len(context_list) == 0: + return context_list + + # ctc_posterior = torch.clamp(ctc_posterior, min=2 * filter_threshold) + + SOC_score = {} + for t in range(1, ctc_posterior.shape[0]): + if t % (filter_window_size // 2) != 0 and t != ctc_posterior.shape[0] - 1: + continue + # calculate PSC + PSC_score = {} + max_posterior, _ = torch.max(ctc_posterior[max(0, + t - filter_window_size):t, :], + dim=0, keepdim=False) + max_posterior = max_posterior.tolist() + for i in range(len(context_list)): + score = sum(max_posterior[j] for j in context_list[i]) \ + / len(context_list[i]) + PSC_score[i] = max(SOC_score.get(i, -float('inf')), score) + PSC_filtered_index = [] + for i in PSC_score: + if PSC_score[i] > filter_threshold: + PSC_filtered_index.append(i) + if len(PSC_filtered_index) == 0: + continue + filtered_context_list = [] + for i in PSC_filtered_index: + filtered_context_list.append(context_list[i]) + + # calculate SOC + win_posterior = ctc_posterior[max(0, t - filter_window_size):t, :] + win_posterior = win_posterior.unsqueeze(0) \ + .expand(len(filtered_context_list), -1, -1) + select_win_posterior = [] + for i in range(len(filtered_context_list)): + select_win_posterior.append(torch.index_select( + win_posterior[0], 1, + torch.tensor(filtered_context_list[i], + device=ctc_posterior.device)).transpose(0, 1)) + select_win_posterior = \ + pad_sequence(select_win_posterior, + batch_first=True).transpose(1, 2).contiguous() + dp = torch.full((select_win_posterior.shape[0], + select_win_posterior.shape[2]), + -10000.0, dtype=torch.float32, + device=select_win_posterior.device) + dp[:, 0] = select_win_posterior[:, 0, 0] + for win_t in range(1, select_win_posterior.shape[1]): + temp = dp[:, :-1] + select_win_posterior[:, win_t, 1:] + idx = torch.where(temp > dp[:, 1:]) + idx_ = (idx[0], idx[1] + 1) + dp[idx_] = temp[idx] + dp[:, 0] = \ + torch.where(select_win_posterior[:, win_t, 0] > dp[:, 0], + select_win_posterior[:, win_t, 0], 
dp[:, 0]) + for i in range(len(filtered_context_list)): + SOC_score[PSC_filtered_index[i]] = \ + max(SOC_score.get(PSC_filtered_index[i], -float('inf')), + dp[i][len(filtered_context_list[i]) - 1] + / len(filtered_context_list[i])) + filtered_context_list = [] + for i in range(len(context_list)): + if SOC_score.get(i, -float('inf')) > filter_threshold: + filtered_context_list.append(context_list[i]) + return filtered_context_list + + # TODO: delete this method + def new_context_list(self, context_txts, symbol_table, bpe_model): + context_list = [] + if bpe_model is not None: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + else: + sp = None + for context_txt in context_txts: + context_txt = context_txt.strip() + + labels = [] + tokens = [] + if bpe_model is not None: + tokens = tbbm(sp, context_txt) + else: + for ch in context_txt: + if ch == ' ': + ch = "▁" + tokens.append(ch) + for ch in tokens: + if ch in symbol_table: + labels.append(symbol_table[ch]) + elif '' in symbol_table: + labels.append(symbol_table['']) + context_list.append(labels) + self.context_list = context_list + self.graph = {0: {}} + self.graph_size = 0 + self.state2token = {} + self.back_score = {0: 0.0} + self.build_graph(self.context_list) diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py index 9d356da08..220e2fec0 100644 --- a/wenet/utils/executor.py +++ b/wenet/utils/executor.py @@ -61,8 +61,8 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, with model_context(): for batch_idx, batch in enumerate(data_loader): key, feats, target, feats_lengths, target_lengths, \ - context_list, context_label, context_list_lengths, \ - context_label_lengths = batch + context_list, context_label, context_list_lengths, \ + context_label_lengths = batch feats = feats.to(device) target = target.to(device) feats_lengths = feats_lengths.to(device) @@ -91,10 +91,11 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, dtype=ds_dtype, cache_enabled=False ): loss_dict = model(feats, feats_lengths, target, - target_lengths, context_list, - context_label, - context_list_lengths, - context_label_lengths) + target_lengths, + context_list=context_list, + context_label=context_label, + context_list_lengths=context_list_lengths, + context_label_lengths=context_label_lengths) loss = loss_dict['loss'] # NOTE(xcsong): Zeroing the gradients is handled automatically by DeepSpeed after the weights # noqa # have been updated using a mini-batch. 
DeepSpeed also performs gradient averaging automatically # noqa @@ -108,10 +109,11 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, # https://pytorch.org/docs/stable/notes/amp_examples.html with torch.cuda.amp.autocast(scaler is not None): loss_dict = model(feats, feats_lengths, target, - target_lengths, context_list, - context_label, - context_list_lengths, - context_label_lengths) + target_lengths, + context_list=context_list, + context_label=context_label, + context_list_lengths=context_list_lengths, + context_label_lengths=context_label_lengths) loss = loss_dict['loss'] / accum_grad if use_amp: scaler.scale(loss).backward() @@ -183,8 +185,8 @@ def cv(self, model, data_loader, device, args): with torch.no_grad(): for batch_idx, batch in enumerate(data_loader): key, feats, target, feats_lengths, target_lengths, \ - context_list, context_label, context_list_lengths, \ - context_label_lengths = batch + context_list, context_label, context_list_lengths, \ + context_label_lengths = batch feats = feats.to(device) target = target.to(device) feats_lengths = feats_lengths.to(device) @@ -202,14 +204,18 @@ def cv(self, model, data_loader, device, args): dtype=ds_dtype, cache_enabled=False ): loss_dict = model(feats, feats_lengths, target, - target_lengths, context_list, - context_label, context_list_lengths, - context_label_lengths) + target_lengths, + context_list=context_list, + context_label=context_label, + context_list_lengths=context_list_lengths, + context_label_lengths=context_label_lengths) else: loss_dict = model(feats, feats_lengths, target, - target_lengths, context_list, - context_label, context_list_lengths, - context_label_lengths) + target_lengths, + context_list=context_list, + context_label=context_label, + context_list_lengths=context_list_lengths, + context_label_lengths=context_label_lengths) loss = loss_dict['loss'] if torch.isfinite(loss): num_seen_utts += num_utts diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index acb496f58..7ffa101e8 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -116,6 +116,7 @@ def init_model(configs): attention_decoder=decoder, joint=joint, ctc=ctc, + context_module=context_module, **configs['model_conf']) elif 'paraformer' in configs: predictor = Predictor(**configs['cif_predictor_conf']) From d7f7bae4ca182b912ea01362b0799f5bbd6ff123 Mon Sep 17 00:00:00 2001 From: kxhuang Date: Tue, 5 Sep 2023 02:14:52 +0800 Subject: [PATCH 07/10] Modify the method of passing context data during model training and support jit script exporting --- wenet/bin/recognize.py | 44 ++++++++++++-------- wenet/bin/train.py | 12 +++--- wenet/dataset/processor.py | 9 ++--- wenet/transducer/transducer.py | 12 +++--- wenet/transformer/asr_model.py | 48 ++++++++++++++-------- wenet/transformer/context_module.py | 21 +++++----- wenet/utils/context_graph.py | 63 ++++++----------------------- wenet/utils/executor.py | 42 +++++-------------- 8 files changed, 111 insertions(+), 140 deletions(-) diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index 602834f67..10506f1c4 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -160,12 +160,12 @@ def get_args(): default=0.0, help='lm scale for hlg attention rescore decode') - parser.add_argument( - '--context_bias_mode', - type=str, - default='', - help='''Context bias mode, selectable from the following - option: decoding-graph、deep-biasing''') + parser.add_argument('--context_bias_mode', + type=str, + default='', + help='''Context bias mode, 
selectable from the + following option: decoding_graph, + deep_biasing''') parser.add_argument('--context_list_path', type=str, default='', @@ -174,7 +174,17 @@ def get_args(): type=float, default=2.0, help='''The higher the score, the greater the degree of - bias using decoding-graph for biasing''') + bias using decoding_graph for biasing''') + parser.add_argument('--deep_biasing_score', + type=float, + default=1.5, + help='''The higher the score, the greater the degree of + bias using deep_biasing for biasing''') + parser.add_argument('--context_filtering', + action='store_true', + help='''Reduce the size of the context list through + filtering to enhance the effect of context + biasing''') args = parser.parse_args() print(args) @@ -262,22 +272,23 @@ def main(): if args.context_bias_mode != '': context_graph = ContextGraph(args.context_list_path, symbol_table, args.bpe_model, args.context_graph_score) - if 'deep-biasing' in args.context_bias_mode: + context_graph.context_filtering = args.context_filtering + context_list_all = context_graph.context_list + if 'deep_biasing' in args.context_bias_mode: context_graph.deep_biasing = True - if 'decoding-graph' in args.context_bias_mode: + context_graph.deep_biasing_score = args.deep_biasing_score + if 'decoding_graph' in args.context_bias_mode: context_graph.graph_biasing = True - if 'deep-biasing' in args.context_bias_mode and \ - 'decoding-graph' in args.context_bias_mode: - context_graph.deep_biasing_weight = 1.0 - context_graph.graph_biasing_weight = 0.7 with torch.no_grad(), open(args.result_file, 'w') as fout: for batch_idx, batch in enumerate(test_data_loader): - keys, feats, target, feats_lengths, target_lengths, _, _, _, _ = batch + keys, feats, target, feats_lengths, target_lengths, _ = batch feats = feats.to(device) target = target.to(device) feats_lengths = feats_lengths.to(device) target_lengths = target_lengths.to(device) + if context_graph is not None and args.context_filtering: + context_graph.context_list = context_list_all if args.mode == 'attention': hyps, _ = model.recognize( @@ -286,7 +297,7 @@ def main(): beam_size=args.beam_size, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, - simulate_streaming=args.simulate_streaming) + simulate_streaming=args.simulate_streaming,) hyps = [hyp.tolist() for hyp in hyps] elif args.mode == 'ctc_greedy_search': hyps, _ = model.ctc_greedy_search( @@ -294,7 +305,8 @@ def main(): feats_lengths, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, - simulate_streaming=args.simulate_streaming) + simulate_streaming=args.simulate_streaming, + context_graph=context_graph) elif args.mode == 'rnnt_greedy_search': assert (feats.size(0) == 1) assert 'predictor' in configs diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 8756fee89..0915b9dc1 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -285,12 +285,12 @@ def main(): for p in model.context_module.context_decoder_ctc_linear.parameters(): p.requires_grad = False - # # !!!IMPORTANT!!! - # # Try to export the model by script, if fails, we should refine - # # the code to satisfy the script export requirements - # if local_rank == 0: - # script_model = torch.jit.script(model) - # script_model.save(os.path.join(args.model_dir, 'init.zip')) + # !!!IMPORTANT!!! 
+ # Try to export the model by script, if fails, we should refine + # the code to satisfy the script export requirements + if local_rank == 0: + script_model = torch.jit.script(model) + script_model.save(os.path.join(args.model_dir, 'init.zip')) executor = Executor() # If specify checkpoint, load some info from checkpoint if args.checkpoint is not None: diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 73b8295f2..9b5adce6e 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -738,8 +738,7 @@ def padding(data): if 'context_list' not in sample[0]: yield (sorted_keys, padded_feats, padding_labels, feats_lengths, - label_lengths, torch.tensor([0]), torch.tensor([0]), - torch.tensor([0]), torch.tensor([0])) + label_lengths, []) else: context_lists = sample[0]['context_list'] context_list_lengths = \ @@ -757,6 +756,6 @@ def padding(data): batch_first=True, padding_value=-1) yield (sorted_keys, padded_feats, padding_labels, - feats_lengths, label_lengths, padding_context_lists, - padding_context_labels, context_list_lengths, - context_label_lengths) + feats_lengths, label_lengths, + [padding_context_lists, padding_context_labels, + context_list_lengths, context_label_lengths]) diff --git a/wenet/transducer/transducer.py b/wenet/transducer/transducer.py index d4498c69a..b0a591126 100644 --- a/wenet/transducer/transducer.py +++ b/wenet/transducer/transducer.py @@ -96,10 +96,7 @@ def forward( text: torch.Tensor, text_lengths: torch.Tensor, steps: int = 0, - context_list: torch.Tensor = torch.tensor([0]), - context_label: torch.Tensor = torch.tensor([0]), - context_list_lengths: torch.Tensor = torch.tensor([0]), - context_label_lengths: torch.Tensor = torch.tensor([0]), + context_data: List[torch.Tensor] = None, ) -> Dict[str, Optional[torch.Tensor]]: """Frontend + Encoder + predictor + joint + loss @@ -122,12 +119,17 @@ def forward( # Context biasing branch loss_bias: Optional[torch.Tensor] = None if self.context_module is not None: + assert len(context_data) == 4 + context_list = context_data[0] + context_label = context_data[1] + context_list_lengths = context_data[2] + context_label_lengths = context_data[3] context_emb = self.context_module. \ forward_context_emb(context_list, context_list_lengths) encoder_out, bias_out = self.context_module(context_emb, encoder_out) bias_out = bias_out.transpose(0, 1).log_softmax(2) - loss_bias = self.context_module.bias_loss(bias_out, + loss_bias = self.context_module.bias_loss(bias_out, context_label, encoder_out_lens, context_label_lengths diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index 7d986b71e..592b548bc 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -93,10 +93,7 @@ def forward( speech_lengths: torch.Tensor, text: torch.Tensor, text_lengths: torch.Tensor, - context_list: torch.Tensor = torch.tensor([0]), - context_label: torch.Tensor = torch.tensor([0]), - context_list_lengths: torch.Tensor = torch.tensor([0]), - context_label_lengths: torch.Tensor = torch.tensor([0]), + context_data: List[torch.Tensor], ) -> Dict[str, Optional[torch.Tensor]]: """Frontend + Encoder + Decoder + Calc loss @@ -119,6 +116,11 @@ def forward( # 1a. Context biasing branch loss_bias: Optional[torch.Tensor] = None if self.context_module is not None: + assert len(context_data) == 4 + context_list = context_data[0] + context_label = context_data[1] + context_list_lengths = context_data[2] + context_label_lengths = context_data[3] context_emb = self.context_module. 
\ forward_context_emb(context_list, context_list_lengths) encoder_out, bias_out = self.context_module(context_emb, @@ -156,7 +158,7 @@ def forward( else: loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * \ loss_att - if loss_bias is not None: + if loss is not None and loss_bias is not None: loss = loss + self.bias_weight * loss_bias return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc, @@ -342,6 +344,7 @@ def ctc_greedy_search( decoding_chunk_size: int = -1, num_decoding_left_chunks: int = -1, simulate_streaming: bool = False, + context_graph: ContextGraph = None, ) -> List[List[int]]: """ Apply CTC greedy search @@ -367,6 +370,22 @@ def ctc_greedy_search( speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) # (B, maxlen, encoder_dim) + + if context_graph is not None and context_graph.deep_biasing: + if context_graph.context_filtering: + ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0) + filtered_context_list = context_graph.two_stage_filtering( + context_graph.context_list, ctc_probs) + context_graph.context_list = filtered_context_list + context_list, context_list_lengths = \ + context_graph.get_context_list_tensor(context_graph.context_list) + context_list = context_list.to(encoder_out.device) + context_emb = self.context_module. \ + forward_context_emb(context_list, context_list_lengths) + encoder_out, _ = \ + self.context_module(context_emb, encoder_out, + context_graph.deep_biasing_score, True) + maxlen = encoder_out.size(1) encoder_out_lens = encoder_mask.squeeze(1).sum(1) ctc_probs = self.ctc.log_softmax( @@ -389,7 +408,6 @@ def _ctc_prefix_beam_search( num_decoding_left_chunks: int = -1, simulate_streaming: bool = False, context_graph: ContextGraph = None, - context_filtering: bool = True, ) -> Tuple[List[List[int]], torch.Tensor]: """ CTC prefix beam search inner implementation @@ -423,21 +441,19 @@ def _ctc_prefix_beam_search( simulate_streaming) # (B, maxlen, encoder_dim) if context_graph is not None and context_graph.deep_biasing: - if context_filtering: + if context_graph.context_filtering: ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0) - filtered_context_list = \ - context_graph.tow_stage_filtering(context_graph.context_list, - ctc_probs, -6) - context_list, context_list_lengths = \ - context_graph.get_context_list_tensor(filtered_context_list) - else: - context_list, context_list_lengths = \ - context_graph.get_context_list_tensor(context_graph.context_list) + filtered_context_list = context_graph.two_stage_filtering( + context_graph.context_list, ctc_probs) + context_graph.context_list = filtered_context_list + context_list, context_list_lengths = \ + context_graph.get_context_list_tensor(context_graph.context_list) + context_list = context_list.to(encoder_out.device) context_emb = self.context_module. \ forward_context_emb(context_list, context_list_lengths) encoder_out, _ = \ self.context_module(context_emb, encoder_out, - context_graph.deep_biasing_weight, True) + context_graph.deep_biasing_score, True) maxlen = encoder_out.size(1) ctc_probs = self.ctc.log_softmax( diff --git a/wenet/transformer/context_module.py b/wenet/transformer/context_module.py index 153ea7d6d..7cf66ac40 100644 --- a/wenet/transformer/context_module.py +++ b/wenet/transformer/context_module.py @@ -10,7 +10,7 @@ # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
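The asr_model.py changes above wire the context graph into CTC decoding: filter the context list with the utterance's own CTC posterior, embed it, then bias the encoder output. Condensed into one helper (PATCH 09 later factors it this way as ContextGraph.forward_deep_biasing), the step is roughly:

    def apply_deep_biasing(model, context_graph, encoder_out):
        # Optionally shrink the context list with this utterance's CTC
        # posterior (the two-stage PSC/SOC filtering), then bias encoder_out.
        if context_graph.context_filtering:
            ctc_probs = model.ctc.log_softmax(encoder_out).squeeze(0)
            context_list = context_graph.two_stage_filtering(
                context_graph.context_list, ctc_probs)
        else:
            context_list = context_graph.context_list
        tensors, lengths = context_graph.get_context_list_tensor(context_list)
        tensors = tensors.to(encoder_out.device)
        context_emb = model.context_module.forward_context_emb(tensors, lengths)
        # recognize=True skips the context-decoder branch that exists only
        # for the bias loss during training.
        encoder_out, _ = model.context_module(
            context_emb, encoder_out, context_graph.deep_biasing_score, True)
        return encoder_out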
See the # License for the specific language governing permissions and limitations -# under the License. +# under the License. import torch @@ -23,10 +23,10 @@ class BLSTM(torch.nn.Module): """ """ - def __init__(self, - vocab_size, - embedding_size, - num_layers, + def __init__(self, + vocab_size, + embedding_size, + num_layers, dropout=0.0): super(BLSTM, self).__init__() self.vocab_size = vocab_size @@ -56,8 +56,6 @@ def forward(self, sen_batch, sen_lengths): class ContextModule(torch.nn.Module): - """ - """ def __init__( self, vocab_size: int, @@ -104,15 +102,18 @@ def forward_context_emb(self, context_list, context_lengths) -> torch.Tensor: context_emb = self.context_encoder(context_emb.unsqueeze(0)) return context_emb - def forward(self, context_emb, encoder_out, - biasing_weight=1.0, recognize=False) \ + def forward(self, + context_emb: torch.Tensor, + encoder_out: torch.Tensor, + biasing_score: float = 1.0, + recognize: bool = False) \ -> Tuple[torch.Tensor, torch.Tensor]: context_emb = context_emb.expand(encoder_out.shape[0], -1, -1) context_emb, _ = self.biasing_layer(encoder_out, context_emb, context_emb) encoder_bias_out = \ self.norm_aft_combiner(encoder_out + - self.combiner(context_emb) * biasing_weight) + self.combiner(context_emb) * biasing_score) if recognize: return encoder_bias_out, torch.tensor(0.0) bias_out = self.context_decoder(context_emb) diff --git a/wenet/utils/context_graph.py b/wenet/utils/context_graph.py index 9a7f426b8..438e4341c 100644 --- a/wenet/utils/context_graph.py +++ b/wenet/utils/context_graph.py @@ -49,14 +49,14 @@ class ContextGraph: Args: context_list_path(str): context list path bpe_model(str): model for english bpe part - context_score(float): context score for each token + context_graph_score(float): context score for each token """ def __init__(self, context_list_path: str, symbol_table: Dict[str, int], bpe_model: str = None, - context_score: float = 6): - self.context_score = context_score + context_graph_score: float = 2.0): + self.context_graph_score = context_graph_score self.context_list = tokenize(context_list_path, symbol_table, bpe_model) self.graph = {0: {}} @@ -66,8 +66,8 @@ def __init__(self, self.build_graph(self.context_list) self.graph_biasing = False self.deep_biasing = False - self.graph_biasing_weight = 1.0 - self.deep_biasing_weight = 1.5 + self.deep_biasing_score = 1.0 + self.context_filtering = True def build_graph(self, context_list: List[List[int]]): """ Constructing the context decoding graph, add arcs with negative @@ -92,8 +92,8 @@ def build_graph(self, context_list: List[List[int]]): self.graph[now_state][context_token[i]] = self.graph_size now_state = self.graph_size if i != len(context_token) - 1: - self.back_score[now_state] = -(i + - 1) * self.context_score + self.back_score[now_state] = \ + -(i + 1) * self.context_graph_score else: self.back_score[now_state] = 0 self.state2token[now_state] = context_token[i] @@ -105,14 +105,13 @@ def find_next_state(self, now_state: int, token: int): from the starting state to avoid token consumption due to mismatches. 
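A self-contained toy of this lookup, with made-up token ids and context_graph_score = 2.0 (only the scoring behaviour is reproduced, not the real class):

    graph = {0: {7: 1}, 1: {8: 2}, 2: {}}    # one phrase: tokens [7, 8]
    back_score = {0: 0.0, 1: -2.0, 2: 0.0}   # final states keep their bonus
    score = 2.0                              # context_graph_score

    def find_next_state(now_state: int, token: int):
        if token in graph[now_state]:
            return graph[now_state][token], score
        fallback = back_score[now_state]     # revoke the partial-match bonus
        if token in graph[0]:                # retry from the start state
            return graph[0][token], fallback + score
        return 0, fallback

    print(find_next_state(1, 8))   # (2, 2.0)  -> phrase completed
    print(find_next_state(1, 9))   # (0, -2.0) -> mismatch, bonus revoked
    print(find_next_state(1, 7))   # (1, 0.0)  -> restart at no extra cost

The restart case is the behaviour this change adds: the token is re-matched against the start state instead of being consumed by the failed partial match.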
""" if token in self.graph[now_state]: - return self.graph[now_state][token], \ - self.context_score * self.graph_biasing_weight + return self.graph[now_state][token], self.context_graph_score back_score = self.back_score[now_state] now_state = 0 if token in self.graph[now_state]: - return self.graph[now_state][ - token], (back_score + self.context_score) * self.graph_biasing_weight - return 0, back_score * self.graph_biasing_weight + return self.graph[now_state][token], \ + back_score + self.context_graph_score + return 0, back_score def get_context_list_tensor(self, context_list: List[List[int]]): context_list_tensor = [torch.tensor([0], dtype=torch.int32)] @@ -125,16 +124,14 @@ def get_context_list_tensor(self, context_list: List[List[int]]): padding_value=-1) return context_list_tensor, context_list_lengths - def tow_stage_filtering(self, + def two_stage_filtering(self, context_list: List[List[int]], ctc_posterior: torch.Tensor, - filter_threshold: float = -8, + filter_threshold: float = -4, filter_window_size: int = 64): if len(context_list) == 0: return context_list - # ctc_posterior = torch.clamp(ctc_posterior, min=2 * filter_threshold) - SOC_score = {} for t in range(1, ctc_posterior.shape[0]): if t % (filter_window_size // 2) != 0 and t != ctc_posterior.shape[0] - 1: @@ -195,37 +192,3 @@ def tow_stage_filtering(self, if SOC_score.get(i, -float('inf')) > filter_threshold: filtered_context_list.append(context_list[i]) return filtered_context_list - - # TODO: delete this method - def new_context_list(self, context_txts, symbol_table, bpe_model): - context_list = [] - if bpe_model is not None: - import sentencepiece as spm - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) - else: - sp = None - for context_txt in context_txts: - context_txt = context_txt.strip() - - labels = [] - tokens = [] - if bpe_model is not None: - tokens = tbbm(sp, context_txt) - else: - for ch in context_txt: - if ch == ' ': - ch = "▁" - tokens.append(ch) - for ch in tokens: - if ch in symbol_table: - labels.append(symbol_table[ch]) - elif '' in symbol_table: - labels.append(symbol_table['']) - context_list.append(labels) - self.context_list = context_list - self.graph = {0: {}} - self.graph_size = 0 - self.state2token = {} - self.back_score = {0: 0.0} - self.build_graph(self.context_list) diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py index 220e2fec0..ad03d1498 100644 --- a/wenet/utils/executor.py +++ b/wenet/utils/executor.py @@ -61,16 +61,13 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, with model_context(): for batch_idx, batch in enumerate(data_loader): key, feats, target, feats_lengths, target_lengths, \ - context_list, context_label, context_list_lengths, \ - context_label_lengths = batch + context_data = batch feats = feats.to(device) target = target.to(device) feats_lengths = feats_lengths.to(device) target_lengths = target_lengths.to(device) - context_list = context_list.to(device) - context_label = context_label.to(device) - context_list_lengths = context_list_lengths.to(device) - context_label_lengths = context_label_lengths.to(device) + for i in range(len(context_data)): + context_data[i] = context_data[i].to(device) num_utts = target_lengths.size(0) if num_utts == 0: continue @@ -91,11 +88,7 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, dtype=ds_dtype, cache_enabled=False ): loss_dict = model(feats, feats_lengths, target, - target_lengths, - context_list=context_list, - context_label=context_label, - 
context_list_lengths=context_list_lengths, - context_label_lengths=context_label_lengths) + target_lengths, context_data) loss = loss_dict['loss'] # NOTE(xcsong): Zeroing the gradients is handled automatically by DeepSpeed after the weights # noqa # have been updated using a mini-batch. DeepSpeed also performs gradient averaging automatically # noqa @@ -109,11 +102,7 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, # https://pytorch.org/docs/stable/notes/amp_examples.html with torch.cuda.amp.autocast(scaler is not None): loss_dict = model(feats, feats_lengths, target, - target_lengths, - context_list=context_list, - context_label=context_label, - context_list_lengths=context_list_lengths, - context_label_lengths=context_label_lengths) + target_lengths, context_data) loss = loss_dict['loss'] / accum_grad if use_amp: scaler.scale(loss).backward() @@ -185,16 +174,13 @@ def cv(self, model, data_loader, device, args): with torch.no_grad(): for batch_idx, batch in enumerate(data_loader): key, feats, target, feats_lengths, target_lengths, \ - context_list, context_label, context_list_lengths, \ - context_label_lengths = batch + context_data = batch feats = feats.to(device) target = target.to(device) feats_lengths = feats_lengths.to(device) target_lengths = target_lengths.to(device) - context_list = context_list.to(device) - context_label = context_label.to(device) - context_list_lengths = context_list_lengths.to(device) - context_label_lengths = context_label_lengths.to(device) + for i in range(len(context_data)): + context_data[i] = context_data[i].to(device) num_utts = target_lengths.size(0) if num_utts == 0: continue @@ -204,18 +190,10 @@ def cv(self, model, data_loader, device, args): dtype=ds_dtype, cache_enabled=False ): loss_dict = model(feats, feats_lengths, target, - target_lengths, - context_list=context_list, - context_label=context_label, - context_list_lengths=context_list_lengths, - context_label_lengths=context_label_lengths) + target_lengths, context_data) else: loss_dict = model(feats, feats_lengths, target, - target_lengths, - context_list=context_list, - context_label=context_label, - context_list_lengths=context_list_lengths, - context_label_lengths=context_label_lengths) + target_lengths, context_data) loss = loss_dict['loss'] if torch.isfinite(loss): num_seen_utts += num_utts From 8b567f40304adc65980bb295330d265908d77fe7 Mon Sep 17 00:00:00 2001 From: kaixun huang Date: Tue, 5 Sep 2023 02:22:11 +0800 Subject: [PATCH 08/10] Modify the code formatting --- wenet/dataset/processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 9b5adce6e..52072d509 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -756,6 +756,6 @@ def padding(data): batch_first=True, padding_value=-1) yield (sorted_keys, padded_feats, padding_labels, - feats_lengths, label_lengths, + feats_lengths, label_lengths, [padding_context_lists, padding_context_labels, context_list_lengths, context_label_lengths]) From d009779aef5c8cb17cf41161cc1399c3f31956e9 Mon Sep 17 00:00:00 2001 From: kaixunhuang0 Date: Wed, 6 Sep 2023 07:46:56 +0800 Subject: [PATCH 09/10] Add comments and modify the location of the deep biasing forward code --- wenet/bin/recognize.py | 18 ++++++++------ wenet/dataset/processor.py | 14 +++++------ wenet/transducer/transducer.py | 2 ++ wenet/transformer/asr_model.py | 32 +++++-------------------- wenet/transformer/context_module.py | 27 +++++++++++++++++++-- 
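One detail of the context-list plumbing above that is easy to miss: get_context_list_tensor prepends a reserved no-bias entry (token 0) before padding with -1. A toy run with arbitrary token ids:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    phrases = [[23, 5], [9, 41, 12]]
    tensors = [torch.tensor([0], dtype=torch.int32)]   # reserved no-bias row
    tensors += [torch.tensor(p, dtype=torch.int32) for p in phrases]
    lengths = torch.tensor([t.size(0) for t in tensors])
    padded = pad_sequence(tensors, batch_first=True, padding_value=-1)
    print(lengths)   # tensor([1, 2, 3])
    print(padded)
    # tensor([[ 0, -1, -1],
    #         [23,  5, -1],
    #         [ 9, 41, 12]], dtype=torch.int32)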
wenet/utils/context_graph.py | 37 ++++++++++++++++++++++++++--- 6 files changed, 84 insertions(+), 46 deletions(-) diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index 10506f1c4..682fdc078 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -164,7 +164,7 @@ def get_args(): type=str, default='', help='''Context bias mode, selectable from the - following option: decoding_graph, + following option: context_graph, deep_biasing''') parser.add_argument('--context_list_path', type=str, @@ -174,10 +174,10 @@ def get_args(): type=float, default=2.0, help='''The higher the score, the greater the degree of - bias using decoding_graph for biasing''') + bias using context_graph for biasing''') parser.add_argument('--deep_biasing_score', type=float, - default=1.5, + default=1.0, help='''The higher the score, the greater the degree of bias using deep_biasing for biasing''') parser.add_argument('--context_filtering', @@ -185,6 +185,12 @@ def get_args(): help='''Reduce the size of the context list through filtering to enhance the effect of context biasing''') + parser.add_argument('--context_filtering_threshold', + type=float, + default=-4.0, + help='''The threshold for context filtering, the larger + the value, the closer it is to 0, and the fewer + remaining context phrases are filtered''') args = parser.parse_args() print(args) @@ -273,11 +279,11 @@ def main(): context_graph = ContextGraph(args.context_list_path, symbol_table, args.bpe_model, args.context_graph_score) context_graph.context_filtering = args.context_filtering - context_list_all = context_graph.context_list + context_graph.filter_threshold = args.context_filtering_threshold if 'deep_biasing' in args.context_bias_mode: context_graph.deep_biasing = True context_graph.deep_biasing_score = args.deep_biasing_score - if 'decoding_graph' in args.context_bias_mode: + if 'context_graph' in args.context_bias_mode: context_graph.graph_biasing = True with torch.no_grad(), open(args.result_file, 'w') as fout: @@ -287,8 +293,6 @@ def main(): target = target.to(device) feats_lengths = feats_lengths.to(device) target_lengths = target_lengths.to(device) - if context_graph is not None and args.context_filtering: - context_graph.context_list = context_list_all if args.mode == 'attention': hyps, _ = model.recognize( diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 52072d509..c37ac7c10 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -614,9 +614,11 @@ def context_sampling(data, symbol_table, len_min, len_max, + utt_num_context, batch_num_context, ): - """context_sampling + """Perform context sampling by randomly selecting context phrases from the + utterance to obtain a context list for the entire batch Args: data: Iterable[List[{key, feat, label}]] @@ -642,8 +644,7 @@ def context_sampling(data, st_select = [] en_select = [] - num_context = 3 - for _ in range(0, num_context): + for _ in range(0, utt_num_context): random_len = random.randint(min(len(st_index_list) - 1, len_min), min(len(st_index_list) - 1, len_max)) random_index = random.randint(0, len(st_index_list) - @@ -680,11 +681,8 @@ def context_sampling(data, def context_label_generate(label, context_list): - """ generate context label - - Args: - - Returns + """ Generate context labels corresponding to the utterances based on + the context list """ context_labels = [] for x in label: diff --git a/wenet/transducer/transducer.py b/wenet/transducer/transducer.py index b0a591126..364ef9fdd 100644 --- 
a/wenet/transducer/transducer.py +++ b/wenet/transducer/transducer.py @@ -105,6 +105,8 @@ def forward( speech_lengths: (Batch, ) text: (Batch, Length) text_lengths: (Batch,) + context_data: [context_list, context_label, + context_list_lengths, context_label_lengths] """ assert text_lengths.dim() == 1, text_lengths.shape # Check that batch_size is unified diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index 592b548bc..a651f591a 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -102,6 +102,8 @@ def forward( speech_lengths: (Batch, ) text: (Batch, Length) text_lengths: (Batch,) + context_data: [context_list, context_label, + context_list_lengths, context_label_lengths] """ assert text_lengths.dim() == 1, text_lengths.shape @@ -372,19 +374,8 @@ def ctc_greedy_search( simulate_streaming) # (B, maxlen, encoder_dim) if context_graph is not None and context_graph.deep_biasing: - if context_graph.context_filtering: - ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0) - filtered_context_list = context_graph.two_stage_filtering( - context_graph.context_list, ctc_probs) - context_graph.context_list = filtered_context_list - context_list, context_list_lengths = \ - context_graph.get_context_list_tensor(context_graph.context_list) - context_list = context_list.to(encoder_out.device) - context_emb = self.context_module. \ - forward_context_emb(context_list, context_list_lengths) - encoder_out, _ = \ - self.context_module(context_emb, encoder_out, - context_graph.deep_biasing_score, True) + encoder_out = context_graph.forward_deep_biasing( + encoder_out, self.context_module, self.ctc) maxlen = encoder_out.size(1) encoder_out_lens = encoder_mask.squeeze(1).sum(1) @@ -441,19 +432,8 @@ def _ctc_prefix_beam_search( simulate_streaming) # (B, maxlen, encoder_dim) if context_graph is not None and context_graph.deep_biasing: - if context_graph.context_filtering: - ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0) - filtered_context_list = context_graph.two_stage_filtering( - context_graph.context_list, ctc_probs) - context_graph.context_list = filtered_context_list - context_list, context_list_lengths = \ - context_graph.get_context_list_tensor(context_graph.context_list) - context_list = context_list.to(encoder_out.device) - context_emb = self.context_module. \ - forward_context_emb(context_list, context_list_lengths) - encoder_out, _ = \ - self.context_module(context_emb, encoder_out, - context_graph.deep_biasing_score, True) + encoder_out = context_graph.forward_deep_biasing( + encoder_out, self.context_module, self.ctc) maxlen = encoder_out.size(1) ctc_probs = self.ctc.log_softmax( diff --git a/wenet/transformer/context_module.py b/wenet/transformer/context_module.py index 7cf66ac40..2a6019bf8 100644 --- a/wenet/transformer/context_module.py +++ b/wenet/transformer/context_module.py @@ -20,7 +20,8 @@ class BLSTM(torch.nn.Module): - """ + """Context encoder, encoding unequal-length context phrases + into equal-length embedding representations. """ def __init__(self, @@ -56,6 +57,17 @@ def forward(self, sen_batch, sen_lengths): class ContextModule(torch.nn.Module): + """Context module, Using context information for deep contextual bias + + During the training process, the original parameters of the ASR model + are frozen, and only the parameters of context module are trained. 
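The deep-biasing combination this module implements can be checked shape-by-shape in isolation; the layers below are stand-ins with assumed sizes (embedding size 256, 4 heads), not the module's exact configuration:

    import torch

    B, T, D, N = 2, 50, 256, 10         # batch, frames, dim, context phrases
    encoder_out = torch.randn(B, T, D)
    context_emb = torch.randn(1, N, D)  # one shared table of phrase embeddings

    biasing_layer = torch.nn.MultiheadAttention(D, num_heads=4,
                                                batch_first=True)
    combiner = torch.nn.Linear(D, D)
    norm = torch.nn.LayerNorm(D)

    # Every utterance in the batch attends over the same context phrases.
    context_kv = context_emb.expand(B, -1, -1)                        # (B, N, D)
    attn_out, _ = biasing_layer(encoder_out, context_kv, context_kv)  # (B, T, D)
    biasing_score = 1.0  # scales how strongly the bias is mixed in
    encoder_bias_out = norm(encoder_out + combiner(attn_out) * biasing_score)
    print(encoder_bias_out.shape)       # torch.Size([2, 50, 256])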
+ + Args: + vocab_size (int): vocabulary size + embedding_size (int): number of ASR encoder projection units + encoder_layers (int): number of context encoder layers + attention_heads (int): number of heads in the biasing layer + """ def __init__( self, vocab_size: int, @@ -97,7 +109,12 @@ def __init__( self.bias_loss = torch.nn.CTCLoss(reduction="sum", zero_infinity=True) - def forward_context_emb(self, context_list, context_lengths) -> torch.Tensor: + def forward_context_emb(self, + context_list: torch.Tensor, + context_lengths: torch.Tensor + ) -> torch.Tensor: + """Extracting context embeddings + """ context_emb = self.context_extractor(context_list, context_lengths) context_emb = self.context_encoder(context_emb.unsqueeze(0)) return context_emb @@ -108,6 +125,12 @@ def forward(self, biasing_score: float = 1.0, recognize: bool = False) \ -> Tuple[torch.Tensor, torch.Tensor]: + """Using context embeddings for deep biasing. + + Args: + biasing_score (int): degree of context biasing + recognize (bool): no context decoder computation if True + """ context_emb = context_emb.expand(encoder_out.shape[0], -1, -1) context_emb, _ = self.biasing_layer(encoder_out, context_emb, context_emb) diff --git a/wenet/utils/context_graph.py b/wenet/utils/context_graph.py index 438e4341c..4ed05395d 100644 --- a/wenet/utils/context_graph.py +++ b/wenet/utils/context_graph.py @@ -2,6 +2,8 @@ from torch.nn.utils.rnn import pad_sequence from wenet.dataset.processor import __tokenize_by_bpe_model +from wenet.transformer.context_module import ContextModule +from wenet.transformer.ctc import CTC from typing import Dict, List @@ -68,6 +70,7 @@ def __init__(self, self.deep_biasing = False self.deep_biasing_score = 1.0 self.context_filtering = True + self.filter_threshold = -4.0 def build_graph(self, context_list: List[List[int]]): """ Constructing the context decoding graph, add arcs with negative @@ -114,6 +117,9 @@ def find_next_state(self, now_state: int, token: int): return 0, back_score def get_context_list_tensor(self, context_list: List[List[int]]): + """Add 0 as no-bias in the context list and obtain the tensor + form of the context list + """ context_list_tensor = [torch.tensor([0], dtype=torch.int32)] for context_token in context_list: context_list_tensor.append(torch.tensor(context_token, dtype=torch.int32)) @@ -124,11 +130,36 @@ def get_context_list_tensor(self, context_list: List[List[int]]): padding_value=-1) return context_list_tensor, context_list_lengths + def forward_deep_biasing(self, + encoder_out: torch.Tensor, + context_module: ContextModule, + ctc: CTC): + """Apply deep biasing based on encoder output and context list + """ + if self.context_filtering: + ctc_probs = ctc.log_softmax(encoder_out).squeeze(0) + filtered_context_list = self.two_stage_filtering( + self.context_list, ctc_probs) + context_list, context_list_lengths = self. \ + get_context_list_tensor(filtered_context_list) + else: + context_list, context_list_lengths = self. \ + get_context_list_tensor(self.context_list) + context_list = context_list.to(encoder_out.device) + context_emb = context_module. 
\ + forward_context_emb(context_list, context_list_lengths) + encoder_out, _ = \ + context_module(context_emb, encoder_out, + self.deep_biasing_score, True) + return encoder_out + def two_stage_filtering(self, context_list: List[List[int]], ctc_posterior: torch.Tensor, - filter_threshold: float = -4, filter_window_size: int = 64): + """Calculate PSC and SOC for context phrase filtering, + refer to: https://arxiv.org/abs/2301.06735 + """ if len(context_list) == 0: return context_list @@ -148,7 +179,7 @@ def two_stage_filtering(self, PSC_score[i] = max(SOC_score.get(i, -float('inf')), score) PSC_filtered_index = [] for i in PSC_score: - if PSC_score[i] > filter_threshold: + if PSC_score[i] > self.filter_threshold: PSC_filtered_index.append(i) if len(PSC_filtered_index) == 0: continue @@ -189,6 +220,6 @@ def two_stage_filtering(self, / len(filtered_context_list[i])) filtered_context_list = [] for i in range(len(context_list)): - if SOC_score.get(i, -float('inf')) > filter_threshold: + if SOC_score.get(i, -float('inf')) > self.filter_threshold: filtered_context_list.append(context_list[i]) return filtered_context_list From 762e1991a83c2162145a81179c07958c0fcec211 Mon Sep 17 00:00:00 2001 From: kaixunhuang0 Date: Thu, 21 Sep 2023 14:40:24 +0800 Subject: [PATCH 10/10] Fix the error of BLSTM forward state and force turn off use_dynamic_chunk during bias module training --- wenet/bin/train.py | 4 +++- wenet/transformer/context_module.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 0915b9dc1..eba63dfb3 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -276,14 +276,16 @@ def main(): num_params = sum(p.numel() for p in model.parameters()) print('the number of model params: {:,d}'.format(num_params)) if local_rank == 0 else None # noqa - # Freeze other parts of the model during training context bias module if 'context_module_conf' in configs: + # Freeze other parts of the model during training context bias module for p in model.parameters(): p.requires_grad = False for p in model.context_module.parameters(): p.requires_grad = True for p in model.context_module.context_decoder_ctc_linear.parameters(): p.requires_grad = False + # Turn off dynamic chunk because it will affect the training of bias + model.encoder.use_dynamic_chunk = False # !!!IMPORTANT!!! # Try to export the model by script, if fails, we should refine diff --git a/wenet/transformer/context_module.py b/wenet/transformer/context_module.py index 2a6019bf8..d12d9d511 100644 --- a/wenet/transformer/context_module.py +++ b/wenet/transformer/context_module.py @@ -51,8 +51,8 @@ def forward(self, sen_batch, sen_lengths): _, last_state = self.sen_rnn(pack_seq) laste_h = last_state[0] laste_c = last_state[1] - state = torch.cat([laste_h[-1, :, :], laste_h[0, :, :], - laste_c[-1, :, :], laste_c[0, :, :]], dim=-1) + state = torch.cat([laste_h[-1, :, :], laste_h[-2, :, :], + laste_c[-1, :, :], laste_c[-2, :, :]], dim=-1) return state
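The indexing change above can be sanity-checked in isolation. For a stacked bidirectional LSTM, PyTorch lays out h_n and c_n as (num_layers * num_directions, batch, hidden) in layer-major order: the last layer's forward and backward states sit at indices -2 and -1, while index 0 is the first layer's forward state, which is what the old code was concatenating. Sizes here are arbitrary:

    import torch

    rnn = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
                        bidirectional=True, batch_first=True)
    x = torch.randn(3, 5, 8)              # (batch, seq, feat)
    _, (h_n, c_n) = rnn(x)
    print(h_n.shape)                      # torch.Size([4, 3, 16])
    # Equivalent view: (num_layers, num_directions, batch, hidden)
    assert torch.equal(h_n[-2], h_n.view(2, 2, 3, 16)[1, 0])
    state = torch.cat([h_n[-1], h_n[-2], c_n[-1], c_n[-2]], dim=-1)
    print(state.shape)                    # torch.Size([3, 64])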