@@ -195,6 +195,7 @@ def __init__(self,
             self.head_dim,
             num_kv_heads=self.num_key_value_heads,
             v_head_size=self.head_dim,
+            causal=False,
         )
 
         self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -1128,10 +1129,12 @@ def forward(
 
         # Collect intermediate layer outputs from encoder output
         all_intermediate_hidden_states = output[1]
+        all_intermediate_hidden_states = [
+            all_intermediate_hidden_states[i]
+            for i in self.intermediate_layers_indices
+        ]
         intermediate_hidden_states = torch.stack(
             all_intermediate_hidden_states, dim=-1)
-        intermediate_hidden_states = intermediate_hidden_states[
-            ..., self.intermediate_layers_indices]
 
         # Remove padding from intermediate hidden states
         intermediate_hidden_states = intermediate_hidden_states.reshape(
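A minimal sketch of why the hunk above is a pure reordering, using assumed toy shapes (the real tensors come from the vision encoder): selecting the layers listed in intermediate_layers_indices before torch.stack yields the same tensor as stacking every layer and slicing the last dimension afterwards, while stacking fewer tensors.

import torch

# toy stand-ins: 4 encoder layers, batch 2, 10 tokens, hidden size 8
all_intermediate_hidden_states = [torch.randn(2, 10, 8) for _ in range(4)]
intermediate_layers_indices = [1, 3]

# previous behaviour: stack every layer, then index the last dimension
old = torch.stack(all_intermediate_hidden_states,
                  dim=-1)[..., intermediate_layers_indices]

# patched behaviour: keep only the wanted layers, then stack those
new = torch.stack(
    [all_intermediate_hidden_states[i] for i in intermediate_layers_indices],
    dim=-1)

assert torch.equal(old, new)  # identical result, smaller intermediate stack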
@@ -1196,8 +1199,8 @@ def __init__(self,
         # preprocessor
         self.input_processor = MLlamaInputProcessor(self.config, dtype)
 
-    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
-                            attn_metadata: Any, input_ids: torch.LongTensor):
+    def flat_encoder_result(self, attn_metadata: Any,
+                            input_ids: torch.LongTensor):
         # since every state share the same shape
         full_text_row_masked_out_mask = torch.ones(
             (attn_metadata.q_seqlens.sum(), 1), dtype=torch.bool)
@@ -1208,9 +1211,9 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor,
             full_text_row_masked_out_mask[start_pos:img_id] = False
             start_pos += q_seq_len
         full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
-            cross_attention_states.device)
+            input_ids.device)
 
-        return cross_attention_states, full_text_row_masked_out_mask
+        return full_text_row_masked_out_mask
 
     def forward(
         self,
@@ -1227,6 +1230,19 @@ def forward(
     ):
         """model forward, return logits."""
 
+        if cross_attn_metadata is None:
+            full_text_row_masked_out_mask = None
+        # FIXME basically, we want to inference
+        # text requests and image requests separately
+        elif pixel_values is None and (cross_attn_metadata.kv_seqlens is None):
+            full_text_row_masked_out_mask = None
+        elif cross_attn_metadata.is_decoding:
+            full_text_row_masked_out_mask = input_ids.new_ones(
+                input_ids.size(-1), 1)
+        else:
+            full_text_row_masked_out_mask = self.flat_encoder_result(
+                cross_attn_metadata, input_ids)  # noqa
+
         cross_attention_states = None
         if pixel_values is not None:
             cross_attention_states = self.vision_model(
@@ -1240,21 +1256,6 @@ def forward(
             cross_attention_states = cross_attention_states.view(
                 bsz, -1, image_token_dim)
 
-        if cross_attn_metadata is None:
-            full_text_row_masked_out_mask = None
-        # FIXME basically, we want to inference
-        # text requests and image requests separately
-        elif cross_attention_states is None and (cross_attn_metadata.kv_seqlens
-                                                 is None):
-            full_text_row_masked_out_mask = None
-        elif cross_attn_metadata.is_decoding:
-            full_text_row_masked_out_mask = input_ids.new_ones(
-                input_ids.size(-1), 1)
-        else:
-            (cross_attention_states,
-             full_text_row_masked_out_mask) = self.flat_encoder_result(
-                 cross_attention_states, cross_attn_metadata,
-                 input_ids)  # noqa
         hidden_states = self.language_model(
             input_ids=input_ids,
             position_ids=position_ids,
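Since flat_encoder_result no longer receives cross_attention_states, here is a minimal runnable sketch of the refactored mask construction under assumed inputs (image_token_positions is a hypothetical stand-in for the per-sequence image-token offsets; the real method derives them from cross_attn_metadata):

import torch


def flat_encoder_result_sketch(q_seqlens, image_token_positions, input_ids):
    # one boolean row per flattened query token, all rows kept by default
    mask = torch.ones((int(q_seqlens.sum()), 1), dtype=torch.bool)
    start_pos = 0
    for q_seq_len, img_id in zip(q_seqlens.tolist(), image_token_positions):
        # mask out the rows of this sequence that precede its image token
        mask[start_pos:img_id] = False
        start_pos += q_seq_len
    # the mask now follows input_ids' device instead of the vision output,
    # which is what lets the signature drop cross_attention_states
    return mask.to(input_ids.device)


# toy usage: two sequences of 4 and 3 tokens, image tokens at offsets 1 and 6
q_seqlens = torch.tensor([4, 3])
input_ids = torch.zeros(7, dtype=torch.long)
print(flat_encoder_result_sketch(q_seqlens, [1, 6], input_ids).squeeze(-1))
# tensor([False,  True,  True,  True, False, False,  True])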