247 changes: 148 additions & 99 deletions QEfficient/transformers/cache_utils.py

Large diffs are not rendered by default.
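Since the cache_utils.py diff is collapsed here, below is a minimal sketch of the QEffDynamicCache surface that the model files in this PR rely on (from_legacy_cache, update with position_ids/batch_index, get_seq_length, to_legacy_cache). The calls mirror what the rendered diffs use; shapes and the internal append/scatter behaviour are assumptions, not the module's definitive contract.

```python
# Illustrative only: mirrors how the model files in this PR call QEffDynamicCache;
# the real implementation lives in QEfficient/transformers/cache_utils.py (not rendered above).
import torch
from QEfficient.transformers.cache_utils import QEffDynamicCache

# Wrap a legacy tuple-of-tuples cache (or None) into the Cache object that the
# patched decoder loops now thread through every layer.
past_key_values = QEffDynamicCache.from_legacy_cache(None)

# New key/value states for one layer; shapes are placeholders (batch, heads, seq, head_dim).
key = torch.zeros(1, 4, 8, 16)
value = torch.zeros(1, 4, 8, 16)
cache_kwargs = {"position_ids": torch.arange(8).unsqueeze(0), "batch_index": None}

# Write the new states into layer 0 of the cache and get back the full K/V views.
key_out, value_out = past_key_values.update(key, value, 0, cache_kwargs)

past_seen_tokens = past_key_values.get_seq_length()  # tokens already cached
legacy = past_key_values.to_legacy_cache()            # back to tuple form when a caller expects it
```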

7 changes: 5 additions & 2 deletions QEfficient/transformers/modeling_utils.py
@@ -50,6 +50,7 @@
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
@@ -93,7 +94,7 @@
# Placeholder for all non-transformer models
from .models.codegen.modeling_codegen import (
QEffCodeGenAttention,
QeffCodeGenBlock,
QEffCodeGenBlock,
QEffCodeGenForCausalLM,
QEffCodeGenModel,
)
@@ -122,6 +123,7 @@
QEffLlamaDecoderLayer,
QEffLlamaForCausalLM,
QEffLlamaModel,
QEffLlamaRotaryEmbedding,
)
from .models.mistral.modeling_mistral import (
QEffMistralAttention,
@@ -203,6 +205,7 @@
LlamaForCausalLM: QEffLlamaForCausalLM,
LlamaDecoderLayer: QEffLlamaDecoderLayer,
LlamaRMSNorm: CustomRMSNormAIC,
LlamaRotaryEmbedding: QEffLlamaRotaryEmbedding,
# Gemma model layers
GemmaModel: QEffGemmaModel,
GemmaAttention: QEffGemmaAttention,
@@ -224,7 +227,7 @@
CodeGenAttention: QEffCodeGenAttention,
CodeGenModel: QEffCodeGenModel,
CodeGenForCausalLM: QEffCodeGenForCausalLM,
CodeGenBlock: QeffCodeGenBlock,
CodeGenBlock: QEffCodeGenBlock,
# Mistral model layers
MistralAttention: QEffMistralAttention,
MistralDecoderLayer: QEffMistralDecoderLayer,
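For context on how entries such as `LlamaRotaryEmbedding: QEffLlamaRotaryEmbedding` take effect, here is a hedged sketch of the usual class-swap mechanism; the helper name and loop are illustrative, and the actual transform code is not part of this diff.

```python
# Hypothetical helper (not from this PR) showing how a type-to-type mapping like
# the dictionary above is typically applied to a loaded Hugging Face model.
import torch.nn as nn

def swap_modules(model: nn.Module, mapping: dict) -> nn.Module:
    for module in model.modules():
        cls = type(module)
        if cls in mapping:
            # Re-point the instance at the QEff subclass so its patched forward()
            # runs, while the already-loaded weights stay untouched.
            module.__class__ = mapping[cls]
    return model
```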
110 changes: 36 additions & 74 deletions QEfficient/transformers/models/codegen/modeling_codegen.py
@@ -47,8 +47,6 @@ def _attn(

attn_weights = torch.matmul(query, key.transpose(-1, -2))

attn_weights = attn_weights / self.scale_attn

# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype).to(attn_weights.device)
@@ -57,6 +55,7 @@
# Apply the attention mask
attn_weights = torch.where(attention_mask, mask_value, attn_weights)

attn_weights = attn_weights / self.scale_attn
attn_weights = nn.Softmax(dim=-1)(attn_weights)
attn_weights = attn_weights.to(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
@@ -124,36 +123,16 @@ def forward(
query = query.permute(0, 2, 1, 3)

if layer_past is not None:
# Update the cache_kwargs with position_ids for Cloud AI 100
past_key_value = layer_past
cache_kwargs = {
"position_ids": position_ids,
"batch_index": batch_index,
}
pkv = QEffDynamicCache()
pkv.key_cache.append(past_key_value[0])
pkv.value_cache.append(past_key_value[1])
key, value = pkv.update(key, value, 0, cache_kwargs)

if use_cache is True:
# Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
# Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
present = (pkv.key_cache[0].to(hidden_states.dtype), pkv.value_cache[0])
else:
present = None
cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)

# compute self-attention: V x Softmax(QK^T)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)

return outputs # a, present, (attentions)
return attn_output, attn_weights


class QEffCodeGenModel(CodeGenModel):
@@ -167,7 +146,7 @@ class QEffCodeGenModel(CodeGenModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
@@ -179,7 +158,8 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
**kwargs, # NOOP kwargs, for now
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -200,20 +180,21 @@ def forward(
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)

if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
seq_length = inputs_embeds.shape[1]
if cache_position is None:
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)

if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
position_ids = cache_position.unsqueeze(0)

# Attention mask.
if attention_mask is not None:
@@ -237,40 +218,33 @@ def forward(

elif attention_mask is None:
# 4d mask is passed through the layers
attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_length)
attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)

if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)

hidden_states = inputs_embeds

if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_length)
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds

hidden_states = self.drop(hidden_states)
output_shape = (-1, seq_length, hidden_states.size(-1))

output_shape = input_shape + (hidden_states.size(-1),)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
for i, block in enumerate(self.h):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

outputs = block(
hidden_states=hidden_states,
layer_past=layer_past,
hidden_states,
layer_past=past_key_values,
batch_index=batch_index,
attention_mask=attention_mask,
position_ids=position_ids,
Expand All @@ -281,11 +255,9 @@ def forward(
)

hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)

if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
all_self_attentions = all_self_attentions + (outputs[1],)

hidden_states = self.ln_f(hidden_states)

@@ -294,12 +266,17 @@ def forward(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if return_legacy_cache:
past_key_values = past_key_values.to_legacy_cache()

if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return tuple(
v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@@ -330,12 +307,6 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
@@ -372,9 +343,7 @@ def forward(
)


class QeffCodeGenBlock(CodeGenBlock):
# Ignore copy

class QEffCodeGenBlock(CodeGenBlock):
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
@@ -389,7 +358,7 @@ def forward(
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
attn_outputs, attn_weights = self.attn(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
@@ -400,15 +369,8 @@
output_attentions=output_attentions,
cache_position=cache_position,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]

feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = attn_output + feed_forward_hidden_states + residual

if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
hidden_states = attn_outputs + feed_forward_hidden_states + residual

return outputs # hidden_states, present, (attentions)
return hidden_states, attn_weights
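One detail worth calling out from QEffCodeGenModel.forward above: position bookkeeping is now derived from the cache via get_seq_length() rather than from past_key_values[0][0].size(-2). A small runnable illustration of that derivation (values are placeholders):

```python
# During decode, cache_position continues from the number of tokens already cached,
# and position_ids is just cache_position with a batch dimension, as in the diff above.
import torch

past_seen_tokens = 8   # e.g. prompt length already held in the KV cache
seq_length = 1         # one new token per decode step
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length)
position_ids = cache_position.unsqueeze(0)
print(cache_position.tolist(), position_ids.tolist())  # [8] [[8]]
```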
32 changes: 5 additions & 27 deletions QEfficient/transformers/models/falcon/modeling_falcon.py
@@ -135,13 +135,12 @@ def forward(
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)

kv_seq_len = key_layer.shape[-2]
kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)

if layer_past is not None:
cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)

if attention_mask is not None:
@@ -162,10 +161,7 @@

attn_output = self.dense(attn_output)

if output_attentions:
return attn_output, layer_past, attention_scores
else:
return attn_output, layer_past
return attn_output, attention_scores


class QEffFalconDecoderLayer(FalconDecoderLayer):
@@ -193,7 +189,7 @@ def forward(
attention_layernorm_out = self.input_layernorm(hidden_states)

# Self attention.
attn_outputs = self.self_attention(
attention_output, attn_weights = self.self_attention(
attention_layernorm_out,
layer_past=layer_past,
attention_mask=attention_mask,
@@ -207,8 +203,6 @@ def forward(
cache_position=cache_position,
)

attention_output = attn_outputs[0]

if not self.config.new_decoder_architecture:
if self.config.parallel_attn:
mlp_layernorm_out = attention_layernorm_out
@@ -225,8 +219,6 @@
):
mlp_layernorm_out = attention_layernorm_out

outputs = attn_outputs[1:]

# MLP.
mlp_output = self.mlp(mlp_layernorm_out)

@@ -235,12 +227,7 @@

output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)

if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]

return outputs # hidden_states, past_kv, attentions
return output, attn_weights


class QEffFalconModel(FalconModel):
@@ -367,22 +354,13 @@ def forward(
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
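The Falcon changes mirror the CodeGen ones: attention and decoder layers now return (output, attn_weights) and the shared cache is updated in place. The new logits_to_keep parameter on QEffFalconForCausalLM.forward is assumed to follow the upstream transformers convention, sketched below: an int n keeps only the last n positions before the LM head, and 0 keeps everything.

```python
# Assumed semantics of `logits_to_keep`, following the upstream transformers pattern;
# the Falcon forward body that consumes it is truncated in this diff view.
import torch

hidden_states = torch.randn(1, 8, 32)  # (batch, seq, hidden) placeholder
logits_to_keep = 1                     # decode step: only the last position feeds the LM head
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([1, 1, 32])
```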