
Commit 4f44dd4

Update modeling files
Signed-off-by: Mamta Singh <[email protected]>
1 parent 6aaa75a · commit 4f44dd4

36 files changed: +2287 −2819 lines

QEfficient/transformers/cache_utils.py

Lines changed: 148 additions & 99 deletions
Large diffs are not rendered by default.

QEfficient/transformers/modeling_utils.py

Lines changed: 5 additions & 2 deletions
@@ -50,6 +50,7 @@
     LlamaForCausalLM,
     LlamaModel,
     LlamaRMSNorm,
+    LlamaRotaryEmbedding,
 )
 from transformers.models.mistral.modeling_mistral import (
     MistralAttention,
@@ -93,7 +94,7 @@
 # Placeholder for all non-transformer models
 from .models.codegen.modeling_codegen import (
     QEffCodeGenAttention,
-    QeffCodeGenBlock,
+    QEffCodeGenBlock,
     QEffCodeGenForCausalLM,
     QEffCodeGenModel,
 )
@@ -122,6 +123,7 @@
     QEffLlamaDecoderLayer,
     QEffLlamaForCausalLM,
     QEffLlamaModel,
+    QEffLlamaRotaryEmbedding,
 )
 from .models.mistral.modeling_mistral import (
     QEffMistralAttention,
@@ -203,6 +205,7 @@
     LlamaForCausalLM: QEffLlamaForCausalLM,
     LlamaDecoderLayer: QEffLlamaDecoderLayer,
     LlamaRMSNorm: CustomRMSNormAIC,
+    LlamaRotaryEmbedding: QEffLlamaRotaryEmbedding,
     # Gemma model layers
     GemmaModel: QEffGemmaModel,
     GemmaAttention: QEffGemmaAttention,
@@ -224,7 +227,7 @@
     CodeGenAttention: QEffCodeGenAttention,
     CodeGenModel: QEffCodeGenModel,
     CodeGenForCausalLM: QEffCodeGenForCausalLM,
-    CodeGenBlock: QeffCodeGenBlock,
+    CodeGenBlock: QEffCodeGenBlock,
     # Mistral model layers
     MistralAttention: QEffMistralAttention,
     MistralDecoderLayer: QEffMistralDecoderLayer,
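
Reviewer note: the dictionary touched in the last two hunks maps stock Hugging Face module classes to their QEff counterparts. As a rough illustration of how such a mapping can be applied (a minimal sketch, not QEfficient's actual transform code; replace_module_classes is a hypothetical helper), one can swap module classes in place:

import torch.nn as nn

def replace_module_classes(model: nn.Module, mapping: dict) -> nn.Module:
    # Hypothetical helper: swap each submodule's class per `mapping`.
    # The QEff subclass overrides forward() but keeps the same parameters,
    # so an in-place __class__ swap is enough for this illustration.
    for module in model.modules():
        if type(module) in mapping:
            module.__class__ = mapping[type(module)]
    return model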

QEfficient/transformers/models/codegen/modeling_codegen.py

Lines changed: 36 additions & 74 deletions
@@ -47,8 +47,6 @@ def _attn(
 
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
-        attn_weights = attn_weights / self.scale_attn
-
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
         # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
         mask_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype).to(attn_weights.device)
@@ -57,6 +55,7 @@ def _attn(
         # Apply the attention mask
         attn_weights = torch.where(attention_mask, mask_value, attn_weights)
 
+        attn_weights = attn_weights / self.scale_attn
         attn_weights = nn.Softmax(dim=-1)(attn_weights)
         attn_weights = attn_weights.to(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)
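
Reviewer note: the two hunks above only move the scale_attn division from before the mask to after it. A standalone sketch of the reordered computation (illustrative only; MIN_MASKED_VALUE stands in for the repo's MIN_MASKED_ATTENTION_VALUE) shows that unmasked scores are unchanged while masked positions stay strongly negative, so softmax still drives them to ~0:

import torch

MIN_MASKED_VALUE = -1e4  # large negative constant; remains strongly negative even after scaling

def masked_scaled_softmax(query, key, attention_mask, scale_attn):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    mask_value = torch.tensor(MIN_MASKED_VALUE, dtype=attn_weights.dtype, device=attn_weights.device)
    # Write the mask constant first (True = masked), then scale, then softmax.
    attn_weights = torch.where(attention_mask, mask_value, attn_weights)
    attn_weights = attn_weights / scale_attn
    return torch.softmax(attn_weights, dim=-1)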
@@ -124,36 +123,16 @@ def forward(
         query = query.permute(0, 2, 1, 3)
 
         if layer_past is not None:
-            # Update the cache_kwargs with position_ids for Cloud AI 100
-            past_key_value = layer_past
-            cache_kwargs = {
-                "position_ids": position_ids,
-                "batch_index": batch_index,
-            }
-            pkv = QEffDynamicCache()
-            pkv.key_cache.append(past_key_value[0])
-            pkv.value_cache.append(past_key_value[1])
-            key, value = pkv.update(key, value, 0, cache_kwargs)
-
-        if use_cache is True:
-            # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
-            # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
-            present = (pkv.key_cache[0].to(hidden_states.dtype), pkv.value_cache[0])
-        else:
-            present = None
+            cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
+            key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)
 
         # compute self-attention: V x Softmax(QK^T)
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
 
         attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
         attn_output = self.out_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs  # a, present, (attentions)
+        return attn_output, attn_weights
 
 
 class QEffCodeGenModel(CodeGenModel):
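
With this change the CodeGen attention no longer builds a throwaway QEffDynamicCache from a legacy tuple; it relies on the Cache.update() contract of whatever object is passed in as layer_past. A minimal sketch of that contract using the stock transformers DynamicCache (QEffDynamicCache in cache_utils.py appears to customize the same interface, e.g. to honor position_ids and batch_index):

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
key = torch.randn(1, 4, 8, 16)    # (batch, kv_heads, seq_len, head_dim)
value = torch.randn(1, 4, 8, 16)

# update() stores the new key/value states for layer 0 and returns the full
# cached tensors that attention should read from.
cached_key, cached_value = cache.update(key, value, layer_idx=0)
assert cached_key.shape[-2] == cache.get_seq_length(0) == 8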
@@ -167,7 +146,7 @@ class QEffCodeGenModel(CodeGenModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
@@ -179,7 +158,8 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        **kwargs,  # NOOP kwargs, for now
+    ) -> Union[tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -200,20 +180,21 @@ def forward(
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
 
-        if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
 
-        if past_key_values is None:
-            past_length = 0
-            past_key_values = tuple([None] * len(self.h))
-        else:
-            past_length = past_key_values[0][0].size(-2)
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        seq_length = inputs_embeds.shape[1]
+        if cache_position is None:
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
 
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0)
+            position_ids = cache_position.unsqueeze(0)
 
         # Attention mask.
         if attention_mask is not None:
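
The model-level forward now accepts either a legacy tuple-of-tuples cache or a Cache object, converting on entry and (if needed) back on exit. A round-trip sketch of this pattern with the stock DynamicCache standing in for QEffDynamicCache:

import torch
from transformers.cache_utils import Cache, DynamicCache

# One layer's legacy (key, value) pair, shaped (batch, heads, seq, head_dim).
legacy = ((torch.randn(1, 4, 8, 16), torch.randn(1, 4, 8, 16)),)

past_key_values, return_legacy_cache = legacy, False
if not isinstance(past_key_values, Cache):
    return_legacy_cache = True
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)

past_seen_tokens = past_key_values.get_seq_length()  # 8

# ... decoder layers run here, mutating past_key_values in place ...

if return_legacy_cache:
    past_key_values = past_key_values.to_legacy_cache()  # back to nested tuples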
@@ -237,40 +218,33 @@ def forward(
 
         elif attention_mask is None:
             # 4d mask is passed through the layers
-            attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_length)
+            attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x num_attention_heads x N x N
         # head_mask has shape n_layer x batch x num_attention_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
 
-        if inputs_embeds is None:
-            inputs_embeds = self.wte(input_ids)
-
         hidden_states = inputs_embeds
 
         if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, seq_length)
             token_type_embeds = self.wte(token_type_ids)
             hidden_states = hidden_states + token_type_embeds
 
         hidden_states = self.drop(hidden_states)
+        output_shape = (-1, seq_length, hidden_states.size(-1))
 
-        output_shape = input_shape + (hidden_states.size(-1),)
-
-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
-
-        presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+        for i, block in enumerate(self.h):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             outputs = block(
-                hidden_states=hidden_states,
-                layer_past=layer_past,
+                hidden_states,
+                layer_past=past_key_values,
                 batch_index=batch_index,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
@@ -281,11 +255,9 @@ def forward(
             )
 
             hidden_states = outputs[0]
-            if use_cache is True:
-                presents = presents + (outputs[1],)
 
             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                all_self_attentions = all_self_attentions + (outputs[1],)
 
         hidden_states = self.ln_f(hidden_states)
 
@@ -294,12 +266,17 @@ def forward(
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
+        if return_legacy_cache:
+            past_key_values = past_key_values.to_legacy_cache()
+
         if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+            return tuple(
+                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
+            )
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=presents,
+            past_key_values=past_key_values,
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
         )
@@ -330,12 +307,6 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         transformer_outputs = self.transformer(
@@ -372,9 +343,7 @@ def forward(
         )
 
 
-class QeffCodeGenBlock(CodeGenBlock):
-    # Ignore copy
-
+class QEffCodeGenBlock(CodeGenBlock):
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
@@ -389,7 +358,7 @@ def forward(
     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
-        attn_outputs = self.attn(
+        attn_outputs, attn_weights = self.attn(
            hidden_states=hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
@@ -400,15 +369,8 @@ def forward(
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
-        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
-        outputs = attn_outputs[1:]
 
         feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = attn_output + feed_forward_hidden_states + residual
-
-        if use_cache:
-            outputs = (hidden_states,) + outputs
-        else:
-            outputs = (hidden_states,) + outputs[1:]
+        hidden_states = attn_outputs + feed_forward_hidden_states + residual
 
-        return outputs  # hidden_states, present, (attentions)
+        return hidden_states, attn_weights

QEfficient/transformers/models/falcon/modeling_falcon.py

Lines changed: 5 additions & 27 deletions
@@ -135,13 +135,12 @@ def forward(
         key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
         value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
 
-        kv_seq_len = key_layer.shape[-2]
-        kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
         cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
         query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
 
         if layer_past is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
             key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
 
         if attention_mask is not None:
@@ -162,10 +161,7 @@ def forward(
 
         attn_output = self.dense(attn_output)
 
-        if output_attentions:
-            return attn_output, layer_past, attention_scores
-        else:
-            return attn_output, layer_past
+        return attn_output, attention_scores
 
 
 class QEffFalconDecoderLayer(FalconDecoderLayer):
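
As in the CodeGen file, the Falcon attention and decoder layers now return just (output, attn_weights); the KV cache object is updated in place rather than threaded through the return tuple. A self-contained toy (illustrative names only, not the repo's classes) showing that calling convention:

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    # Stand-in attention: mutates `layer_past` (a plain dict here) in place
    # and returns (output, attn_weights) with no cache in the return value.
    def forward(self, hidden_states, layer_past=None):
        if layer_past is not None:
            layer_past.setdefault("keys", []).append(hidden_states)
        return hidden_states, torch.zeros(hidden_states.shape[:2])

class ToyDecoderLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.attn = ToyAttention()
        self.mlp = nn.Linear(dim, dim)

    def forward(self, hidden_states, layer_past=None):
        residual = hidden_states
        normed = self.ln(hidden_states)
        attn_out, attn_weights = self.attn(normed, layer_past=layer_past)
        # Parallel-attention residual, as in the CodeGen/Falcon blocks above.
        return attn_out + self.mlp(normed) + residual, attn_weights

layer, cache = ToyDecoderLayer(16), {}
out, weights = layer(torch.randn(2, 4, 16), layer_past=cache)  # cache mutated in place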
@@ -193,7 +189,7 @@ def forward(
         attention_layernorm_out = self.input_layernorm(hidden_states)
 
         # Self attention.
-        attn_outputs = self.self_attention(
+        attention_output, attn_weights = self.self_attention(
             attention_layernorm_out,
             layer_past=layer_past,
             attention_mask=attention_mask,
@@ -207,8 +203,6 @@ def forward(
             cache_position=cache_position,
         )
 
-        attention_output = attn_outputs[0]
-
         if not self.config.new_decoder_architecture:
             if self.config.parallel_attn:
                 mlp_layernorm_out = attention_layernorm_out
@@ -225,8 +219,6 @@ def forward(
         ):
             mlp_layernorm_out = attention_layernorm_out
 
-        outputs = attn_outputs[1:]
-
         # MLP.
         mlp_output = self.mlp(mlp_layernorm_out)
 
@@ -235,12 +227,7 @@ def forward(
 
         output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
 
-        if use_cache:
-            outputs = (output,) + outputs
-        else:
-            outputs = (output,) + outputs[1:]
-
-        return outputs  # hidden_states, past_kv, attentions
+        return output, attn_weights
 
 
 class QEffFalconModel(FalconModel):
@@ -367,22 +354,13 @@ def forward(
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         transformer_outputs = self.transformer(
