From 16420920ef23b0ff02101a2da6adfa8d18cc7240 Mon Sep 17 00:00:00 2001 From: azziko Date: Thu, 5 Mar 2026 11:57:08 +0000 Subject: [PATCH 1/4] Add xatt trimming for multitask beam decoding Signed-off-by: azziko --- .../asr/parts/submodules/multitask_beam_decoding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py b/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py index 3c2a424dcd09..fa9e790fd64b 100644 --- a/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py @@ -308,7 +308,8 @@ def format_hypotheses( break # empty sequence if pos < -1: hyp.y_sequence = ids[: pos + 1] - + if hyp.xatt_scores is not None: + hyp.xatt_scores = [xatt_layer[:, : pos + 1, :] for xatt_layer in hyp.xatt_scores] @dataclass class AEDBeamInferConfig: From b338ca33d5e71c5e9343a244c5ce18865f72bccb Mon Sep 17 00:00:00 2001 From: azziko Date: Thu, 5 Mar 2026 12:15:24 +0000 Subject: [PATCH 2/4] Apply isort and black reformatting Signed-off-by: azziko --- nemo/collections/asr/parts/submodules/multitask_beam_decoding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py b/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py index fa9e790fd64b..e6598f0edc97 100644 --- a/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py @@ -311,6 +311,7 @@ def format_hypotheses( if hyp.xatt_scores is not None: hyp.xatt_scores = [xatt_layer[:, : pos + 1, :] for xatt_layer in hyp.xatt_scores] + @dataclass class AEDBeamInferConfig: beam_size: int = 1 From 0dbba499161a85f1113da898e3c928e9a3746aaf Mon Sep 17 00:00:00 2001 From: azziko Date: Thu, 5 Mar 2026 18:30:36 +0000 Subject: [PATCH 3/4] Add unit tests for cross-attention's decoder_inputs trimming Signed-off-by: azziko --- .../asr/decoding/test_multi_task_decoding.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/collections/asr/decoding/test_multi_task_decoding.py b/tests/collections/asr/decoding/test_multi_task_decoding.py index fa21b2f60328..a25bba587213 100644 --- a/tests/collections/asr/decoding/test_multi_task_decoding.py +++ b/tests/collections/asr/decoding/test_multi_task_decoding.py @@ -292,3 +292,50 @@ def test_transformer_aed_greedy_infer_strips_prompt(prompted_inputs, decoder_nm, torch.testing.assert_close( untrimmed[decoder_input_ids.shape[1] :], best_path ) # stripped the prompt from the beggining + +def test_transformer_aed_beam_infer_trims_xatt_scores(prompted_inputs, decoder_nm, nnet, tokenizer): + decoder_input_ids, encoder_hidden_states, encoder_input_mask = prompted_inputs + *_, classifier = nnet + + # Run the actual top-level module used by MultiTask AED model for decoding. + # This module is expected to trim eos and pads in xatt from the end. + gen = TransformerAEDBeamInfer(decoder_nm, classifier, tokenizer, return_xattn_scores=True) + ans = gen( + encoder_hidden_states=encoder_hidden_states, + encoder_input_mask=encoder_input_mask, + decoder_input_ids=decoder_input_ids, + ) + hyp = ans[0][0] + + assert hyp.xatt_scores is not None + seq_len = hyp.y_sequence.shape[0] + decoder_input_ids_len = decoder_input_ids.shape[1] + total_expected_len = seq_len + decoder_input_ids_len - 1 + + # Check that the expected trimming has indeed been done. + for layer_scores in hyp.xatt_scores: + assert layer_scores.shape[1] == total_expected_len + + +def test_transformer_aed_greedy_infer_trims_xatt_scores(prompted_inputs, decoder_nm, nnet, tokenizer): + decoder_input_ids, encoder_hidden_states, encoder_input_mask = prompted_inputs + *_, classifier = nnet + + # Run the actual top-level module used by MultiTask AED model for decoding. + # This module is expected to trim eos and pads in xatt from the end. + gen = TransformerAEDGreedyInfer(decoder_nm, classifier, tokenizer, return_xattn_scores=True) + ans = gen( + encoder_hidden_states=encoder_hidden_states, + encoder_input_mask=encoder_input_mask, + decoder_input_ids=decoder_input_ids, + ) + hyp = ans[0][0] + + assert hyp.xatt_scores is not None + seq_len = hyp.y_sequence.shape[0] + decoder_input_ids_len = decoder_input_ids.shape[1] + total_expected_len = seq_len + decoder_input_ids_len - 1 + + # Check that the expected trimming has indeed been done. + for layer_scores in hyp.xatt_scores: + assert layer_scores.shape[1] == total_expected_len \ No newline at end of file From 2d0b82f4281ffb2c60f78d961d33afa60ffd5a3f Mon Sep 17 00:00:00 2001 From: azziko Date: Thu, 5 Mar 2026 18:31:22 +0000 Subject: [PATCH 4/4] Apply isort and black reformatting Signed-off-by: azziko --- tests/collections/asr/decoding/test_multi_task_decoding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/collections/asr/decoding/test_multi_task_decoding.py b/tests/collections/asr/decoding/test_multi_task_decoding.py index a25bba587213..256e012d9c3f 100644 --- a/tests/collections/asr/decoding/test_multi_task_decoding.py +++ b/tests/collections/asr/decoding/test_multi_task_decoding.py @@ -293,6 +293,7 @@ def test_transformer_aed_greedy_infer_strips_prompt(prompted_inputs, decoder_nm, untrimmed[decoder_input_ids.shape[1] :], best_path ) # stripped the prompt from the beggining + def test_transformer_aed_beam_infer_trims_xatt_scores(prompted_inputs, decoder_nm, nnet, tokenizer): decoder_input_ids, encoder_hidden_states, encoder_input_mask = prompted_inputs *_, classifier = nnet @@ -338,4 +339,4 @@ def test_transformer_aed_greedy_infer_trims_xatt_scores(prompted_inputs, decoder # Check that the expected trimming has indeed been done. for layer_scores in hyp.xatt_scores: - assert layer_scores.shape[1] == total_expected_len \ No newline at end of file + assert layer_scores.shape[1] == total_expected_len