diff --git a/tests/collections/asr/decoding/test_multi_task_decoding.py b/tests/collections/asr/decoding/test_multi_task_decoding.py
index fa21b2f60328..463a0f57e6a3 100644
--- a/tests/collections/asr/decoding/test_multi_task_decoding.py
+++ b/tests/collections/asr/decoding/test_multi_task_decoding.py
@@ -292,3 +292,38 @@ def test_transformer_aed_greedy_infer_strips_prompt(prompted_inputs, decoder_nm,
     torch.testing.assert_close(
         untrimmed[decoder_input_ids.shape[1] :], best_path
     )  # stripped the prompt from the beggining
+
+def test_beam_xattn_u_dim_matches_prefix_plus_output(
+    prompted_inputs, decoder_nm, nnet, tokenizer
+):
+    decoder_input_ids, encoder_hidden_states, encoder_input_mask = prompted_inputs
+    decoder_input_ids = torch.tensor([[1, 0, 2, 3, 4]], dtype=torch.long)
+    *_, classifier = nnet
+
+    gen = TransformerAEDBeamInfer(
+        decoder_nm,
+        classifier,
+        tokenizer,
+        return_xattn_scores=True,
+    )
+    (packed_result,) = gen(
+        encoder_hidden_states=encoder_hidden_states,
+        encoder_input_mask=encoder_input_mask,
+        decoder_input_ids=decoder_input_ids,
+    )
+
+    prefix_len = decoder_input_ids.shape[1]
+    hyp = packed_result[0]
+
+    assert hyp.xatt_scores is not None
+    assert hyp.y_sequence is not None
+
+    output_len = hyp.y_sequence.shape[0]
+
+    for layer_idx, xatt in enumerate(hyp.xatt_scores):
+        u_dim = xatt.shape[1]
+        expected_u = prefix_len + output_len
+        assert u_dim == expected_u, (
+            f"Layer {layer_idx}: xatt U dim {u_dim} != "
+            f"prefix_len({prefix_len}) + output_len({output_len}) = {expected_u}"
+        )