diff --git a/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py b/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py
index 3c2a424dcd09..e6598f0edc97 100644
--- a/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py
+++ b/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py
@@ -308,6 +308,8 @@ def format_hypotheses(
                     break  # empty sequence
             if pos < -1:
                 hyp.y_sequence = ids[: pos + 1]
+                if hyp.xatt_scores is not None:
+                    hyp.xatt_scores = [xatt_layer[:, : pos + 1, :] for xatt_layer in hyp.xatt_scores]
 
 
 @dataclass
diff --git a/tests/collections/asr/decoding/test_multi_task_decoding.py b/tests/collections/asr/decoding/test_multi_task_decoding.py
index fa21b2f60328..256e012d9c3f 100644
--- a/tests/collections/asr/decoding/test_multi_task_decoding.py
+++ b/tests/collections/asr/decoding/test_multi_task_decoding.py
@@ -292,3 +292,51 @@ def test_transformer_aed_greedy_infer_strips_prompt(prompted_inputs, decoder_nm,
     torch.testing.assert_close(
         untrimmed[decoder_input_ids.shape[1] :], best_path
     )  # stripped the prompt from the beggining
+
+
+def test_transformer_aed_beam_infer_trims_xatt_scores(prompted_inputs, decoder_nm, nnet, tokenizer):
+    decoder_input_ids, encoder_hidden_states, encoder_input_mask = prompted_inputs
+    *_, classifier = nnet
+
+    # Run the actual top-level module used by MultiTask AED model for decoding.
+    # This module is expected to trim eos and pads in xatt from the end.
+    gen = TransformerAEDBeamInfer(decoder_nm, classifier, tokenizer, return_xattn_scores=True)
+    ans = gen(
+        encoder_hidden_states=encoder_hidden_states,
+        encoder_input_mask=encoder_input_mask,
+        decoder_input_ids=decoder_input_ids,
+    )
+    hyp = ans[0][0]
+
+    assert hyp.xatt_scores is not None
+    seq_len = hyp.y_sequence.shape[0]
+    decoder_input_ids_len = decoder_input_ids.shape[1]
+    total_expected_len = seq_len + decoder_input_ids_len - 1
+
+    # Check that the expected trimming has indeed been done.
+    for layer_scores in hyp.xatt_scores:
+        assert layer_scores.shape[1] == total_expected_len
+
+
+def test_transformer_aed_greedy_infer_trims_xatt_scores(prompted_inputs, decoder_nm, nnet, tokenizer):
+    decoder_input_ids, encoder_hidden_states, encoder_input_mask = prompted_inputs
+    *_, classifier = nnet
+
+    # Run the actual top-level module used by MultiTask AED model for decoding.
+    # This module is expected to trim eos and pads in xatt from the end.
+    gen = TransformerAEDGreedyInfer(decoder_nm, classifier, tokenizer, return_xattn_scores=True)
+    ans = gen(
+        encoder_hidden_states=encoder_hidden_states,
+        encoder_input_mask=encoder_input_mask,
+        decoder_input_ids=decoder_input_ids,
+    )
+    hyp = ans[0][0]
+
+    assert hyp.xatt_scores is not None
+    seq_len = hyp.y_sequence.shape[0]
+    decoder_input_ids_len = decoder_input_ids.shape[1]
+    total_expected_len = seq_len + decoder_input_ids_len - 1
+
+    # Check that the expected trimming has indeed been done.
+    for layer_scores in hyp.xatt_scores:
+        assert layer_scores.shape[1] == total_expected_len