Commit 9641159

update llama4
Signed-off-by: Mamta Singh <[email protected]>
1 parent: d74ae4c

File tree: 1 file changed (+16, -68 lines)

QEfficient/transformers/models/llama4/modeling_llama4.py

Lines changed: 16 additions & 68 deletions
@@ -6,7 +6,7 @@
 # -----------------------------------------------------------------------------

 import math
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -52,16 +52,14 @@ def eager_attention_forward_vision(
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
+
     if attention_mask is not None:
         attn_weights = torch.where(
             attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
         )

     attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()

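Note (not part of the commit): a minimal sketch of the boolean masking style this hunk keeps, next to the additive causal-mask style it drops. The value of MIN_MASKED_ATTENTION_VALUE below is an assumed placeholder; the real constant lives in the QEfficient codebase.

import torch

MIN_MASKED_ATTENTION_VALUE = -1e4  # assumed stand-in for the QEfficient constant

scores = torch.randn(1, 2, 4, 4)  # [batch, heads, q_len, kv_len]
bool_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)  # True = masked

# Removed style: add a float mask holding large negative values.
additive = torch.zeros(4, 4).masked_fill(bool_mask, MIN_MASKED_ATTENTION_VALUE)
probs_old = torch.softmax(scores + additive, dim=-1)

# Retained style (as in the hunk): overwrite masked positions with the sentinel value.
masked = torch.where(bool_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE), scores)
probs_new = torch.softmax(masked, dim=-1)

assert torch.allclose(probs_old, probs_new, atol=1e-6)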
@@ -183,7 +181,7 @@ def forward(
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)

-        attention_interface: Callable = eager_attention_forward_vision
+        attention_interface = eager_attention_forward_vision

         attn_output, attn_weights = attention_interface(
             self,
@@ -378,7 +376,6 @@ def eager_attention_forward(
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
     scaling: float,
-    **kwargs,
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
@@ -403,55 +400,16 @@ def __qeff_init__(self):


 class QEffLlama4TextMoe(Llama4TextMoe):
-    def forward(self, hidden: torch.Tensor):
-        B, S, H = hidden.shape
-        T = B * S
-        hidden = hidden.view(T, H)
-
-        router_logits = self.router(hidden)
-        # *top-k = 1* → LLama4
-        top_w, top_i = torch.topk(router_logits, self.top_k, dim=-1)  # both [T, K]
-        masked_logits = torch.full_like(router_logits, float("-inf"))
-        masked_logits.scatter_(1, top_i, top_w)
-
-        # Here we multiply by scores before experts, different only for Llama4
-        x = hidden * torch.sigmoid(top_w.float())
-
-        # ── Book-keeping: create one boolean mask per expert once ───────────────
-        # routing_weights[e] == True where token routed to that expert. Shape [E, T]
-        routing_weights = torch.sigmoid(masked_logits.float()).to(hidden.dtype)
-
-        # ────────────────── allocate the two big tensors ─────
-        ffn_dim = self.experts.intermediate_size  # = 8/3 · H
-        upgate = x.new_zeros((T, ffn_dim))
-        expert_out = x.new_zeros((T, H))  # accum-out buffer
-
-        # ───────────────────────── Stage-1 : Up-Gate ─────────────────────────────
-        # Loop over experts
-        for e in range(self.num_experts):
-            W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e]
-            routing_weight = routing_weights[:, e].unsqueeze(-1)
-            masked_up = torch.where(
-                routing_weights[:, e].unsqueeze(-1) > 0,
-                ((self.experts.act_fn(x @ W_g)) * (x @ W_u)),
-                torch.zeros_like(upgate),
-            )
-            upgate += masked_up
-
-        # At this point upgate[t] holds UpGate(x_t) for that token’s expert,
-        # and arbitrary (zeros) data for tokens not routed to that expert.
-        # ───────────────────────── Stage-2 : Down ────────────────────────────────
-        for e in range(self.num_experts):
-            routing_weight = routing_weights[:, e].unsqueeze(-1)
-            masked_down = torch.where(
-                routing_weight > 0, (upgate @ self.experts.down_proj[e]), torch.zeros_like(expert_out)
-            )
-            expert_out += masked_down
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_scores, router_logits = self.router(hidden_states)
+        routed_in = hidden_states.repeat(router_scores.shape[1], 1)

-        # ───────────────────────── Stage-3 : Shared expert ───────────────────────
-        shared_out = self.shared_expert(hidden)  # [T, H]
-        final = shared_out + expert_out  # restore [B,S,H]
-        return final.view(B, S, H), router_logits
+        routed_in = routed_in * router_scores.reshape(-1, 1)
+        routed_out = self.experts(routed_in)
+        out = self.shared_expert(hidden_states)
+        out.add_(routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum(dim=0))
+        return out, router_logits


 class QEffLlama4TextAttention(Llama4TextAttention):
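Note (not part of the commit): a self-contained toy of the routing pattern the new QEffLlama4TextMoe.forward uses, with stand-in tensors in place of the Llama4 router, experts, and shared expert. Shapes and the expert-major layout of the replicated tokens are assumptions for illustration only.

import torch

T, H, E = 6, 8, 4                                   # tokens, hidden size, experts
hidden_states = torch.randn(T, H)

# Stand-in router: raw logits per token/expert plus sigmoid scores laid out expert-major.
router_logits = torch.randn(T, E)
router_scores = torch.sigmoid(router_logits).transpose(0, 1)    # [E, T]

# Replicate every token once per expert and scale each copy by its routing score.
routed_in = hidden_states.repeat(E, 1)                          # [E*T, H], expert-major blocks
routed_in = routed_in * router_scores.reshape(-1, 1)            # broadcast over H

# Stand-in experts: one weight matrix per expert, applied to that expert's block of tokens.
expert_weights = torch.randn(E, H, H)
routed_out = torch.bmm(routed_in.view(E, T, H), expert_weights).view(E * T, H)

# Shared expert runs on the original tokens; per-expert outputs are summed back in.
shared_expert = torch.nn.Linear(H, H)
out = shared_expert(hidden_states)                              # [T, H]
out = out + routed_out.reshape(E, T, H).sum(dim=0)              # [T, H]

Compared with the removed implementation, the per-expert Python loops are replaced by a single batched pass through self.experts, with the expert outputs summed over the expert axis before being added to the shared-expert output.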
@@ -475,10 +433,6 @@ def forward(
         key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-
-        kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        ##
         if self.use_rope:  # the 16E model skips rope for long context on certain layers
             query_states, key_states = qeff_apply_rotary_emb(
                 query_states, key_states, position_embeddings.to(query_states.device)
@@ -506,12 +460,11 @@ def forward(
            chunk_position_ids = torch.where(
                chunk_position_ids != -1, chunk_position_ids % self.config.attention_chunk_size, chunk_position_ids
            )
-
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_position_ids}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        attention_interface: Callable = eager_attention_forward
+        attention_interface = eager_attention_forward

         attn_output, attn_weights = attention_interface(
             self,
@@ -520,7 +473,6 @@ def forward(
             value_states,
             attention_mask,
             scaling=self.scaling,
-            **kwargs,
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
@@ -552,10 +504,6 @@ def forward(
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states

-        # use local attention mask for ROPE layers
-        if self.use_chunked_attention:
-            attention_mask = chunk_causal_mask
-
         hidden_states = self.input_layernorm(hidden_states)

         # Self Attention
@@ -654,12 +602,12 @@ def forward(
            position_ids = cache_position.unsqueeze(0)

        causal_mask = _create_causal_mask(
-            position_ids=position_ids, target_length=past_key_values.key_cache[3].shape[-2]
+            position_ids=position_ids, target_length=past_key_values.layers[3].keys.shape[-2]
        )
        chunk_position_ids = torch.where(
            position_ids != -1, position_ids % self.config.attention_chunk_size, position_ids
        )
-        target_length = min(past_key_values.key_cache[0].shape[-2], torch.tensor(self.config.attention_chunk_size))
+        target_length = min(past_key_values.layers[0].keys.shape[-2], torch.tensor(self.config.attention_chunk_size))
        chunk_causal_mask = _create_causal_mask(position_ids=chunk_position_ids, target_length=target_length)

        # embed positions
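Note (not part of the commit): a rough sketch of how the chunked-attention quantities in the hunk above are derived, with toy values. The actual masks come from QEfficient's _create_causal_mask helper, and the cached key length from past_key_values.layers[...].keys, neither of which is reproduced here.

import torch

attention_chunk_size = 4                        # assumed config value
position_ids = torch.arange(10).unsqueeze(0)    # [1, seq_len]
cached_key_len = torch.tensor(10)               # stand-in for past_key_values.layers[0].keys.shape[-2]

# Positions restart at 0 every attention_chunk_size tokens (padding positions of -1 are kept as-is),
chunk_position_ids = torch.where(
    position_ids != -1, position_ids % attention_chunk_size, position_ids
)
# so the chunked causal mask never needs to cover more than attention_chunk_size key positions.
target_length = min(cached_key_len, torch.tensor(attention_chunk_size))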
