@@ -47,8 +47,6 @@ def _attn(

         attn_weights = torch.matmul(query, key.transpose(-1, -2))

-        attn_weights = attn_weights / self.scale_attn
-
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
         # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
         mask_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype).to(attn_weights.device)
@@ -57,6 +55,7 @@ def _attn(
         # Apply the attention mask
         attn_weights = torch.where(attention_mask, mask_value, attn_weights)

+        attn_weights = attn_weights / self.scale_attn
         attn_weights = nn.Softmax(dim=-1)(attn_weights)
         attn_weights = attn_weights.to(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)
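Note (not part of the diff): a minimal standalone sketch of the reordered _attn path above, where the attention mask is applied before the division by scale_attn. The tensor shapes, the scale value, and the MIN_MASKED_ATTENTION_VALUE constant below are illustrative placeholders, not values taken from the repository.

import torch
import torch.nn as nn

MIN_MASKED_ATTENTION_VALUE = -1e4                            # placeholder constant
query = torch.randn(1, 4, 8, 16)                             # (batch, heads, seq_len, head_dim), made-up shapes
key = torch.randn(1, 4, 8, 16)
attention_mask = torch.zeros(1, 1, 8, 8, dtype=torch.bool)   # True marks masked positions

attn_weights = torch.matmul(query, key.transpose(-1, -2))
mask_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype)
attn_weights = torch.where(attention_mask, mask_value, attn_weights)   # mask first
attn_weights = attn_weights / (16 ** 0.5)                              # then scale (stand-in for self.scale_attn)
attn_weights = nn.Softmax(dim=-1)(attn_weights)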
@@ -124,36 +123,16 @@ def forward(
         query = query.permute(0, 2, 1, 3)

         if layer_past is not None:
-            # Update the cache_kwargs with position_ids for Cloud AI 100
-            past_key_value = layer_past
-            cache_kwargs = {
-                "position_ids": position_ids,
-                "batch_index": batch_index,
-            }
-            pkv = QEffDynamicCache()
-            pkv.key_cache.append(past_key_value[0])
-            pkv.value_cache.append(past_key_value[1])
-            key, value = pkv.update(key, value, 0, cache_kwargs)
-
-        if use_cache is True:
-            # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
-            # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
-            present = (pkv.key_cache[0].to(hidden_states.dtype), pkv.value_cache[0])
-        else:
-            present = None
+            cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
+            key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)

         # compute self-attention: V x Softmax(QK^T)
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

         attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
         attn_output = self.out_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs  # a, present, (attentions)
+        return attn_output, attn_weights


 class QEffCodeGenModel(CodeGenModel):
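Note (not part of the diff): the single update() call above replaces the manual key_cache/value_cache bookkeeping. A minimal sketch of that interface, using transformers' DynamicCache as a stand-in for QEffDynamicCache (the QEff variant additionally receives position_ids and batch_index through cache_kwargs, as shown in the hunk above):

import torch
from transformers.cache_utils import DynamicCache  # stand-in for QEffDynamicCache

cache = DynamicCache()
key = torch.randn(1, 4, 1, 16)     # one decode step of new key/value states, made-up shapes
value = torch.randn(1, 4, 1, 16)
key_all, value_all = cache.update(key, value, layer_idx=0)   # returns the full cached k/v for layer 0
print(key_all.shape)               # torch.Size([1, 4, 1, 16]) after the first update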
@@ -167,7 +146,7 @@ class QEffCodeGenModel(CodeGenModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
@@ -179,7 +158,8 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        **kwargs,  # NOOP kwargs, for now
+    ) -> Union[tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -200,20 +180,21 @@ def forward(
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)

-        if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

-        if past_key_values is None:
-            past_length = 0
-            past_key_values = tuple([None] * len(self.h))
-        else:
-            past_length = past_key_values[0][0].size(-2)
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        seq_length = inputs_embeds.shape[1]
+        if cache_position is None:
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)

         if position_ids is None:
-            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0)
+            position_ids = cache_position.unsqueeze(0)

         # Attention mask.
         if attention_mask is not None:
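Note (not part of the diff): the return_legacy_cache handling above converts a tuple-of-tuples cache into a Cache object on the way in and back again on the way out. A minimal round-trip sketch, again using transformers' DynamicCache as a stand-in for QEffDynamicCache:

import torch
from transformers.cache_utils import DynamicCache  # stand-in for QEffDynamicCache

legacy = ((torch.zeros(1, 4, 8, 16), torch.zeros(1, 4, 8, 16)),)   # one layer of (key, value), made-up shapes
cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())        # 8 tokens already cached
print(len(cache.to_legacy_cache()))  # 1 layer, back to the tuple-of-tuples format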
@@ -237,40 +218,33 @@ def forward(

         elif attention_mask is None:
             # 4d mask is passed through the layers
-            attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_length)
+            attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x num_attention_heads x N x N
         # head_mask has shape n_layer x batch x num_attention_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

-        if inputs_embeds is None:
-            inputs_embeds = self.wte(input_ids)
-
         hidden_states = inputs_embeds

         if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, seq_length)
             token_type_embeds = self.wte(token_type_ids)
             hidden_states = hidden_states + token_type_embeds

         hidden_states = self.drop(hidden_states)
+        output_shape = (-1, seq_length, hidden_states.size(-1))

-        output_shape = input_shape + (hidden_states.size(-1),)
-
-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
-
-        presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+        for i, block in enumerate(self.h):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             outputs = block(
-                hidden_states=hidden_states,
-                layer_past=layer_past,
+                hidden_states,
+                layer_past=past_key_values,
                 batch_index=batch_index,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
@@ -281,11 +255,9 @@ def forward(
             )

             hidden_states = outputs[0]
-            if use_cache is True:
-                presents = presents + (outputs[1],)

             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                all_self_attentions = all_self_attentions + (outputs[1],)

         hidden_states = self.ln_f(hidden_states)

@@ -294,12 +266,17 @@ def forward(
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

+        if return_legacy_cache:
+            past_key_values = past_key_values.to_legacy_cache()
+
         if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+            return tuple(
+                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
+            )

         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=presents,
+            past_key_values=past_key_values,
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
         )
@@ -330,12 +307,6 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         transformer_outputs = self.transformer(
@@ -372,9 +343,7 @@ def forward(
         )


-class QeffCodeGenBlock(CodeGenBlock):
-    # Ignore copy
-
+class QEffCodeGenBlock(CodeGenBlock):
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
@@ -389,7 +358,7 @@ def forward(
     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
-        attn_outputs = self.attn(
+        attn_outputs, attn_weights = self.attn(
             hidden_states=hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
@@ -400,15 +369,8 @@ def forward(
             output_attentions=output_attentions,
             cache_position=cache_position,
         )
-        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
-        outputs = attn_outputs[1:]

         feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = attn_output + feed_forward_hidden_states + residual
-
-        if use_cache:
-            outputs = (hidden_states,) + outputs
-        else:
-            outputs = (hidden_states,) + outputs[1:]
+        hidden_states = attn_outputs + feed_forward_hidden_states + residual

-        return outputs  # hidden_states, present, (attentions)
+        return hidden_states, attn_weights