@@ -30,9 +30,11 @@ def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
     def forward(self, x):
 
         batch_size, seq_len, feat_dim = x.size()
-        num_padding_frames = (self.downsample_rate - seq_len % self.downsample_rate) % self.downsample_rate
+        num_padding_frames = (
+            self.downsample_rate - seq_len % self.downsample_rate
+        ) % self.downsample_rate
         if num_padding_frames > 0:
-            x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))  
+            x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))
         seq_len = x.size(1)
 
         x = x.contiguous()
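
As an aside, the padding computation above uses the standard round-up-to-a-multiple identity so that the sequence length becomes divisible by the downsample rate. A minimal standalone sketch (the function name here is illustrative, not part of the patch):

def num_padding_frames(seq_len: int, downsample_rate: int) -> int:
    # (rate - len % rate) % rate is 0 when len is already a multiple of rate.
    return (downsample_rate - seq_len % downsample_rate) % downsample_rate

assert num_padding_frames(23, 5) == 2  # 23 frames -> pad 2 -> 25, divisible by 5
assert num_padding_frames(25, 5) == 0  # already a multiple, no padding needed
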
@@ -62,12 +64,14 @@ def __init__(
         self,
         encoder_embed: nn.Module,
         encoder: EncoderInterface,
+        ctc_output: nn.Module,
         llm: nn.Module,
         encoder_projector: nn.Module,
     ):
         super().__init__()
         self.encoder_embed = encoder_embed
         self.encoder = encoder
+        self.ctc_output = ctc_output
         self.llm = llm
         self.encoder_projector = encoder_projector
 
@@ -186,7 +190,7 @@ def _merge_input_ids_with_speech_features(
             (final_attention_mask == 0), 1
         )
 
-        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+        # 6. Mask out the embeddings at padding positions, as we later use past_key_values to determine the non-attended tokens.
         batch_indices, pad_indices = torch.where(
             input_ids == self.llm.config.pad_token_id
         )
@@ -230,6 +234,57 @@ def forward_encoder(
 
         return encoder_out, encoder_out_lens
 
+    def ctc_compress(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        blank_id: int = 0,
+    ) -> torch.Tensor:
+        """
+        Remove frames where the CTC argmax prediction is blank.
+        Args:
+          encoder_out: Tensor of shape (N, T, C), encoder output.
+          encoder_out_lens: Tensor of shape (N,), lengths before padding.
+          blank_id: CTC blank token ID (default: 0).
+
+        Returns:
+          Compressed CTC output of shape (N, T', C).
+        """
+        # 1. Compute CTC argmax predictions
+        ctc_output = self.ctc_output(encoder_out)
+        ctc_preds = ctc_output.argmax(dim=-1)
+
+        # 2. Create non-blank, non-pad mask
+        padding_mask = make_pad_mask(encoder_out_lens)
+        non_blank_mask = (ctc_preds != blank_id) & (~padding_mask)
+
+        # 3. Compute lengths after compression
+        compressed_lens = non_blank_mask.sum(dim=1)
+        max_len = compressed_lens.max().item()
+
+        # 4. Right-pad the output so every row can keep exactly max_len frames
+        pad_lens_list = (
+            torch.full_like(
+                compressed_lens,
+                max_len,
+                device=ctc_output.device,
+            )
+            - compressed_lens
+        )
+        max_pad_len = int(pad_lens_list.max())
+        padded_ctc_output = torch.nn.functional.pad(ctc_output, [0, 0, 0, max_pad_len])
+
+        # 5. Create final mask: non-blank frames plus per-row filler frames
+        padding_mask = ~make_pad_mask(pad_lens_list)
+        total_mask = torch.concat([non_blank_mask, padding_mask], dim=1)
+
+        # 6. Apply mask and reshape
+        compressed_output = padded_ctc_output[total_mask].reshape(
+            ctc_output.shape[0], -1, ctc_output.shape[2]
+        )
+
+        return compressed_output
+
     def forward(
         self,
         fbank: torch.Tensor,
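
To see what ctc_compress computes, here is a self-contained sketch of the same masking logic on a toy batch. make_pad_mask is stubbed locally to mirror the recipe utility (True at padded positions), and all shapes and values are made up for illustration:

import torch

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # Stand-in for the utility used in the patch: True at padded positions.
    max_len = int(lengths.max())
    return torch.arange(max_len)[None, :] >= lengths[:, None]

# Toy batch: N=2, T=4, C=3 classes, blank_id=0; row 1 has only 2 valid frames.
preds = torch.tensor([[1, 0, 2, 1], [0, 2, 0, 0]])
ctc_output = torch.nn.functional.one_hot(preds, num_classes=3).float()
encoder_out_lens = torch.tensor([4, 2])

ctc_preds = ctc_output.argmax(dim=-1)  # recovers `preds`
non_blank_mask = (ctc_preds != 0) & ~make_pad_mask(encoder_out_lens)

# Rows keep different numbers of frames, so each row is re-padded up to the max.
compressed_lens = non_blank_mask.sum(dim=1)              # tensor([3, 1])
pad_lens = int(compressed_lens.max()) - compressed_lens  # tensor([0, 2])

padded = torch.nn.functional.pad(ctc_output, [0, 0, 0, int(pad_lens.max())])
keep_filler = ~make_pad_mask(pad_lens)  # True for the first pad_lens[i] filler frames
total_mask = torch.concat([non_blank_mask, keep_filler], dim=1)

compressed = padded[total_mask].reshape(2, -1, 3)
print(compressed.shape)  # torch.Size([2, 3, 3]): every row now has max_len frames
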
@@ -238,9 +293,11 @@ def forward(
         attention_mask: torch.Tensor,
         labels: torch.LongTensor,
     ):
-        encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)
+        encoder_outs, encoder_out_lens = self.forward_encoder(fbank, fbank_lens)
 
-        speech_features = self.encoder_projector(encoder_outs)
+        compressed_encoder_outs = self.ctc_compress(encoder_outs, encoder_out_lens)
+
+        speech_features = self.encoder_projector(compressed_encoder_outs)
 
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
 