
Commit dd8b38e

add Qwen3_moe and cleanup
Signed-off-by: Mamta Singh <[email protected]>
1 parent dc877cc commit dd8b38e

18 files changed: +135 −659 lines changed


QEfficient/transformers/modeling_utils.py
Lines changed: 5 additions & 2 deletions

@@ -50,6 +50,7 @@
     LlamaForCausalLM,
     LlamaModel,
     LlamaRMSNorm,
+    LlamaRotaryEmbedding,
 )
 from transformers.models.mistral.modeling_mistral import (
     MistralAttention,
@@ -93,7 +94,7 @@
 # Placeholder for all non-transformer models
 from .models.codegen.modeling_codegen import (
     QEffCodeGenAttention,
-    QeffCodeGenBlock,
+    QEffCodeGenBlock,
     QEffCodeGenForCausalLM,
     QEffCodeGenModel,
 )
@@ -122,6 +123,7 @@
     QEffLlamaDecoderLayer,
     QEffLlamaForCausalLM,
     QEffLlamaModel,
+    QEffLlamaRotaryEmbedding,
 )
 from .models.mistral.modeling_mistral import (
     QEffMistralAttention,
@@ -203,6 +205,7 @@
     LlamaForCausalLM: QEffLlamaForCausalLM,
     LlamaDecoderLayer: QEffLlamaDecoderLayer,
     LlamaRMSNorm: CustomRMSNormAIC,
+    LlamaRotaryEmbedding: QEffLlamaRotaryEmbedding,
     # Gemma model layers
     GemmaModel: QEffGemmaModel,
     GemmaAttention: QEffGemmaAttention,
@@ -224,7 +227,7 @@
     CodeGenAttention: QEffCodeGenAttention,
     CodeGenModel: QEffCodeGenModel,
     CodeGenForCausalLM: QEffCodeGenForCausalLM,
-    CodeGenBlock: QeffCodeGenBlock,
+    CodeGenBlock: QEffCodeGenBlock,
     # Mistral model layers
     MistralAttention: QEffMistralAttention,
     MistralDecoderLayer: QEffMistralDecoderLayer,
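
The dictionary edited in the last two hunks maps stock transformers classes to their QEff counterparts. A minimal sketch of how such a class mapping can be applied, assuming a hypothetical helper swap_module_classes and a toy EXAMPLE_MAPPING (this is not the repository's actual transform API):

import torch.nn as nn

# Hypothetical mapping in the same style as the dict edited above: stock class -> replacement class.
EXAMPLE_MAPPING = {nn.ReLU: nn.GELU}

def swap_module_classes(model: nn.Module, mapping: dict) -> nn.Module:
    # Re-point each matching submodule's class so its forward() is replaced in place;
    # no weights are copied, so the replacement class must keep the parent's attributes.
    for module in model.modules():
        if type(module) in mapping:
            module.__class__ = mapping[type(module)]
    return model

model = swap_module_classes(nn.Sequential(nn.Linear(8, 8), nn.ReLU()), EXAMPLE_MAPPING)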

QEfficient/transformers/models/codegen/modeling_codegen.py
Lines changed: 2 additions & 19 deletions

@@ -123,10 +123,7 @@ def forward(
         query = query.permute(0, 2, 1, 3)
 
         if layer_past is not None:
-            cache_kwargs = {
-                "position_ids": position_ids,
-                "batch_index": batch_index,
-            }
+            cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
             key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)
 
         # compute self-attention: V x Softmax(QK^T)
@@ -163,12 +160,6 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,  # NOOP kwargs, for now
     ) -> Union[tuple, BaseModelOutputWithPast]:
-        r"""
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
-            model's internal embedding lookup matrix.
-        """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -316,12 +307,6 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         transformer_outputs = self.transformer(
@@ -358,9 +343,7 @@ def forward(
         )
 
 
-class QeffCodeGenBlock(CodeGenBlock):
-    # Ignore copy
-
+class QEffCodeGenBlock(CodeGenBlock):
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
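
For orientation, a simplified sketch of the layer_past.update() pattern these hunks call into. ToyStaticKVCache is hypothetical and much smaller than the repository's cache class: new key/value states are scattered into a preallocated cache at the slots named by position_ids, and the full static-length cache is returned for attention.

import torch

class ToyStaticKVCache:
    # Hypothetical stand-in for the object behind layer_past; the real cache also
    # handles batch_index for continuous batching, which is omitted here.
    def __init__(self, num_layers, batch, heads, max_len, head_dim):
        shape = (batch, heads, max_len, head_dim)
        self.keys = [torch.zeros(shape) for _ in range(num_layers)]
        self.values = [torch.zeros(shape) for _ in range(num_layers)]

    def update(self, key, value, layer_idx, cache_kwargs):
        # Scatter new states at the positions given by position_ids (batch, seq_len),
        # then hand back the full static-length cache tensors.
        position_ids = cache_kwargs["position_ids"]
        idx = position_ids[:, None, :, None].expand(-1, key.shape[1], -1, key.shape[-1])
        self.keys[layer_idx].scatter_(2, idx, key)
        self.values[layer_idx].scatter_(2, idx, value)
        return self.keys[layer_idx], self.values[layer_idx]

With an update() of this shape, the attention code only needs to pass position_ids (and batch_index) in cache_kwargs, which is exactly what the trimmed-down cache_kwargs dicts in this commit carry.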

QEfficient/transformers/models/falcon/modeling_falcon.py
Lines changed: 4 additions & 25 deletions

@@ -140,7 +140,7 @@ def forward(
             query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
 
         if layer_past is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
             key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
 
         if attention_mask is not None:
@@ -161,10 +161,7 @@ def forward(
 
         attn_output = self.dense(attn_output)
 
-        if output_attentions:
-            return attn_output, layer_past, attention_scores
-        else:
-            return attn_output, layer_past
+        return attn_output, attention_scores
 
 
 class QEffFalconDecoderLayer(FalconDecoderLayer):
@@ -192,7 +189,7 @@ def forward(
         attention_layernorm_out = self.input_layernorm(hidden_states)
 
         # Self attention.
-        attn_outputs = self.self_attention(
+        attention_output, attn_weights = self.self_attention(
             attention_layernorm_out,
             layer_past=layer_past,
             attention_mask=attention_mask,
@@ -206,8 +203,6 @@ def forward(
             cache_position=cache_position,
         )
 
-        attention_output = attn_outputs[0]
-
         if not self.config.new_decoder_architecture:
             if self.config.parallel_attn:
                 mlp_layernorm_out = attention_layernorm_out
@@ -224,8 +219,6 @@ def forward(
         ):
             mlp_layernorm_out = attention_layernorm_out
 
-        outputs = attn_outputs[1:]
-
         # MLP.
         mlp_output = self.mlp(mlp_layernorm_out)
 
@@ -234,12 +227,7 @@ def forward(
 
         output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
 
-        if use_cache:
-            outputs = (output,) + outputs
-        else:
-            outputs = (output,) + outputs[1:]
-
-        return outputs  # hidden_states, past_kv, attentions
+        return output, attn_weights
 
 
 class QEffFalconModel(FalconModel):
@@ -366,22 +354,13 @@ def forward(
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         transformer_outputs = self.transformer(
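
One effect of this cleanup, shown below with a toy module (ToyAttention is illustrative, not the Falcon classes): the attention and decoder-layer forwards now return fixed two-element tuples, so callers unpack them directly instead of branching on use_cache / output_attentions to work out the tuple length.

from typing import Optional, Tuple
import torch

class ToyAttention(torch.nn.Module):
    # Illustrative only: mirrors the fixed (output, weights) return contract.
    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        weights = torch.softmax(hidden_states @ hidden_states.transpose(-1, -2), dim=-1)
        return weights @ hidden_states, weights

# Caller-side unpacking is unconditional, matching the edited decoder layer:
attention_output, attn_weights = ToyAttention()(torch.randn(1, 4, 8))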

QEfficient/transformers/models/gemma/modeling_gemma.py
Lines changed: 11 additions & 49 deletions

@@ -5,7 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -104,7 +104,6 @@ def eager_attention_forward(
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
     scaling: float,
-    **kwargs,
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
@@ -154,11 +153,10 @@ def forward(
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        attention_interface: Callable = eager_attention_forward
+        attention_interface = eager_attention_forward
 
         attn_output, attn_weights = attention_interface(
             self,
@@ -167,12 +165,12 @@ def forward(
             value_states,
             attention_mask,
             scaling=self.scaling,
-            **kwargs,
         )
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
-        return attn_output, attn_weights, past_key_value
+
+        return attn_output, attn_weights
 
 
 class QEffGemmaDecoderLayer(GemmaDecoderLayer):
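
The eager_attention_forward touched above follows the standard grouped-query eager attention pattern. A self-contained sketch under assumed names (toy_eager_attention and toy_repeat_kv are illustrative; the mask is applied as a plain additive bias, which may differ from the repository's exact masking):

from typing import Optional
import torch

def toy_repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

def toy_eager_attention(query, key, value, attention_mask: Optional[torch.Tensor], scaling: float, n_rep: int = 1):
    # Repeat KV heads for grouped-query attention, scale QK^T, add the mask as a bias,
    # softmax, then take the weighted sum of the value states.
    key_states = toy_repeat_kv(key, n_rep)
    value_states = toy_repeat_kv(value, n_rep)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights  # (batch, seq, heads, head_dim), attention weights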
@@ -189,7 +187,6 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         batch_index: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
@@ -200,9 +197,6 @@ def forward(
             attention_mask (`torch.FloatTensor`, *optional*):
                 attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                 query_sequence_length, key_sequence_length)` if default attention is used.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
             use_cache (`bool`, *optional*):
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
@@ -215,13 +209,12 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             batch_index=batch_index,
-            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
@@ -234,15 +227,7 @@ def forward(
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
 
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (present_key_value,)
-
-        return outputs
+        return hidden_states
 
 
 class QEffGemmaModel(GemmaModel):
@@ -261,18 +246,14 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
@@ -308,27 +289,21 @@ def forward(
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
 
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
                 batch_index=batch_index,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 **kwargs,
             )
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.norm(hidden_states)
 
@@ -339,13 +314,11 @@ def forward(
         if return_legacy_cache:
             past_key_values = past_key_values.to_legacy_cache()
 
-        output = BaseModelOutputWithPast(
+        return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
             hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )
-        return output if return_dict else output.to_tuple()
 
 
 class QEffGemmaForCausalLM(GemmaForCausalLM):
@@ -363,21 +336,14 @@ def forward(
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
@@ -387,19 +353,15 @@ def forward(
             batch_index=batch_index,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             cache_position=cache_position,
             **kwargs,
         )
 
         # Cast to INT32 to avoid issue while running in ONNXRT
         logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
-        hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
-
-        logits = self.lm_head(hidden_states).float()
-        logits = logits.float()
+        hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
+        logits = self.lm_head(hidden_states).float()
 
         return CausalLMOutputWithPast(
             loss=None,
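
The logit gathering in the last hunk generalizes beyond Gemma. A standalone illustration with toy tensors (lm_head here is a stand-in Linear layer, and the padding convention in position_ids is assumed, not taken from the repository): argmax over position_ids lands on each row's last real token, and only that hidden state is projected to logits.

import torch

batch, seq_len, hidden, vocab = 2, 8, 16, 32
last_hidden_state = torch.randn(batch, seq_len, hidden)
# Padded slots are assumed to hold smaller values, so argmax picks the last real position.
position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 0, 0],
                             [0, 1, 2, 3, 0, 0, 0, 0]])
lm_head = torch.nn.Linear(hidden, vocab, bias=False)  # stand-in for the model head

# Cast to INT32 (per the diff's comment, to avoid an ONNX Runtime issue), gather the
# hidden state of the last valid position per sequence, then project to vocab logits.
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)               # (batch, 1)
hidden_states = last_hidden_state[torch.arange(batch).view(-1, 1), logit_index]  # (batch, 1, hidden)
logits = lm_head(hidden_states).float()                                          # (batch, 1, vocab)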
