#
# -----------------------------------------------------------------------------

-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

import torch
from torch import nn
@@ -104,7 +104,6 @@ def eager_attention_forward(
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
-    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
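For readers skimming the diff: the `repeat_kv` calls above expand the key/value heads so grouped-query attention can be computed with a plain matmul against all query heads. A minimal, standalone sketch of the equivalent tensor operation (shapes are made up for illustration and are not taken from the model config):

```python
import torch

# Hypothetical shapes: 1 batch, 2 KV heads, 4 query heads
# (so num_key_value_groups = 2), sequence of 5, head_dim 8.
batch, num_kv_heads, seq_len, head_dim = 1, 2, 5, 8
n_rep = 2  # num_key_value_groups

key = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Equivalent of repeat_kv: duplicate each KV head n_rep times along the head axis.
expanded = key[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
key_states = expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

print(key_states.shape)  # torch.Size([1, 4, 5, 8])
```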
@@ -154,11 +153,10 @@ def forward(
        query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        attention_interface: Callable = eager_attention_forward
+        attention_interface = eager_attention_forward

        attn_output, attn_weights = attention_interface(
            self,
@@ -167,12 +165,12 @@ def forward(
            value_states,
            attention_mask,
            scaling=self.scaling,
-            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
-        return attn_output, attn_weights, past_key_value
+
+        return attn_output, attn_weights


class QEffGemmaDecoderLayer(GemmaDecoderLayer):
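The simplified `cache_kwargs` in the attention hunk above pass only `batch_index` and `position_ids`, i.e. the cache is addressed by explicit positions rather than by appending, and rotary sin/cos are no longer threaded through the cache update. A toy sketch of a position-indexed scatter update, purely illustrative (all names and shapes are hypothetical, not the QEfficient cache implementation):

```python
import torch

# Toy, preallocated KV buffer updated at the slots given by position_ids,
# roughly the role the {"batch_index", "position_ids"} cache_kwargs play above.
batch, num_kv_heads, max_len, head_dim = 2, 2, 16, 4
key_cache = torch.zeros(batch, num_kv_heads, max_len, head_dim)

# One new token per sequence, written at (possibly different) positions.
new_key = torch.randn(batch, num_kv_heads, 1, head_dim)
position_ids = torch.tensor([[3], [7]])  # made-up decode positions

# Scatter the new keys into the cache along the sequence axis.
index = position_ids.view(batch, 1, 1, 1).expand(batch, num_kv_heads, 1, head_dim)
key_cache.scatter_(2, index, new_key)

print(key_cache[0, 0, 3], key_cache[1, 0, 7])  # the updated slots
```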
@@ -189,7 +187,6 @@ def forward(
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        batch_index: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
@@ -200,9 +197,6 @@ def forward(
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
@@ -215,13 +209,12 @@ def forward(
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            batch_index=batch_index,
-            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
@@ -234,15 +227,7 @@ def forward(
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (present_key_value,)
-
-        return outputs
+        return hidden_states


class QEffGemmaModel(GemmaModel):
@@ -261,18 +246,14 @@ def forward(
        batch_index: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
@@ -308,27 +289,21 @@ def forward(

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                batch_index=batch_index,
-                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

@@ -339,13 +314,11 @@ def forward(
        if return_legacy_cache:
            past_key_values = past_key_values.to_legacy_cache()

-        output = BaseModelOutputWithPast(
+        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
        )
-        return output if return_dict else output.to_tuple()


class QEffGemmaForCausalLM(GemmaForCausalLM):
@@ -363,21 +336,14 @@ def forward(
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        batch_index: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
-        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
@@ -387,19 +353,15 @@ def forward(
            batch_index=batch_index,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
-            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        # Cast to INT32 to avoid issue while running in ONNXRT
        logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
-        hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
-
-        logits = self.lm_head(hidden_states).float()
-        logits = logits.float()
+        hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
+        logits = self.lm_head(hidden_states).float()

        return CausalLMOutputWithPast(
            loss=None,
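The final hunk keeps the ONNXRT-friendly pattern of casting `position_ids` to INT32 and using `argmax` to locate, per sequence, the most recent token, then gathering only that token's hidden state before `lm_head`. A standalone toy example (made-up shapes and values) of that gather:

```python
import torch

# argmax over position_ids finds the index of the largest position per row,
# and advanced indexing pulls out just that token's hidden state.
hidden = torch.randn(2, 5, 8)                      # (batch, seq_len, hidden_dim)
position_ids = torch.tensor([[0, 1, 2, 3, 4],
                             [5, 6, 7, 0, 0]])     # second row padded at the end

logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)  # [[4], [2]]
last = hidden[torch.arange(2).view(-1, 1), logit_index]             # (2, 1, 8)
print(logit_index.squeeze(1).tolist(), last.shape)
```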