# -----------------------------------------------------------------------------

import math
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

import torch
from torch import nn
@@ -52,16 +52,14 @@ def eager_attention_forward_vision(
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
+
    if attention_mask is not None:
        attn_weights = torch.where(
            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
        )

    attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

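The surviving masking path in `eager_attention_forward_vision` treats `attention_mask` as a boolean mask and overwrites masked scores with a large negative constant before the softmax, rather than adding a pre-built additive causal mask. A minimal standalone sketch of that pattern; the `MIN_MASKED_ATTENTION_VALUE` value, shapes, and helper name below are illustrative assumptions, not the module's exact definitions:

```python
import torch

MIN_MASKED_ATTENTION_VALUE = -1e4  # assumed placeholder for the module-level constant


def mask_then_softmax(attn_weights: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """attention_mask is boolean: True marks positions that must not be attended to."""
    if attention_mask is not None:
        attn_weights = torch.where(
            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
        )
    # softmax in float32 for numerical stability, as in the diff
    return torch.nn.functional.softmax(attn_weights.float(), dim=-1)


# toy check: 1 batch, 1 head, 2 queries, 3 keys; the last key is masked out
scores = torch.randn(1, 1, 2, 3)
mask = torch.tensor([[[[False, False, True], [False, False, True]]]])
probs = mask_then_softmax(scores, mask)
assert torch.allclose(probs[..., -1], torch.zeros(1, 1, 2), atol=1e-3)
```

Dropping the separate additive causal-mask block and the dropout call leaves a single masking path, which reads as an inference-oriented simplification.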
@@ -183,7 +181,7 @@ def forward(
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

-        attention_interface: Callable = eager_attention_forward_vision
+        attention_interface = eager_attention_forward_vision

        attn_output, attn_weights = attention_interface(
            self,
@@ -378,7 +376,6 @@ def eager_attention_forward(
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
-    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
@@ -403,55 +400,16 @@ def __qeff_init__(self):
403
400
404
401
405
402
class QEffLlama4TextMoe (Llama4TextMoe ):
406
- def forward (self , hidden : torch .Tensor ):
407
- B , S , H = hidden .shape
408
- T = B * S
409
- hidden = hidden .view (T , H )
410
-
411
- router_logits = self .router (hidden )
412
- # *top-k = 1* → LLama4
413
- top_w , top_i = torch .topk (router_logits , self .top_k , dim = - 1 ) # both [T, K]
414
- masked_logits = torch .full_like (router_logits , float ("-inf" ))
415
- masked_logits .scatter_ (1 , top_i , top_w )
416
-
417
- # Here we multiply by scores before experts, different only for Llama4
418
- x = hidden * torch .sigmoid (top_w .float ())
419
-
420
- # ── Book-keeping: create one boolean mask per expert once ───────────────
421
- # routing_weights[e] == True where token routed to that expert. Shape [E, T]
422
- routing_weights = torch .sigmoid (masked_logits .float ()).to (hidden .dtype )
423
-
424
- # ────────────────── allocate the two big tensors ─────
425
- ffn_dim = self .experts .intermediate_size # = 8/3 · H
426
- upgate = x .new_zeros ((T , ffn_dim ))
427
- expert_out = x .new_zeros ((T , H )) # accum-out buffer
428
-
429
- # ───────────────────────── Stage-1 : Up-Gate ─────────────────────────────
430
- # Loop over experts
431
- for e in range (self .num_experts ):
432
- W_g , W_u = self .experts .gate_proj [e ], self .experts .up_proj [e ]
433
- routing_weight = routing_weights [:, e ].unsqueeze (- 1 )
434
- masked_up = torch .where (
435
- routing_weights [:, e ].unsqueeze (- 1 ) > 0 ,
436
- ((self .experts .act_fn (x @ W_g )) * (x @ W_u )),
437
- torch .zeros_like (upgate ),
438
- )
439
- upgate += masked_up
440
-
441
- # At this point upgate[t] holds UpGate(x_t) for that token’s expert,
442
- # and arbitrary (zeros) data for tokens not routed to that expert.
443
- # ───────────────────────── Stage-2 : Down ────────────────────────────────
444
- for e in range (self .num_experts ):
445
- routing_weight = routing_weights [:, e ].unsqueeze (- 1 )
446
- masked_down = torch .where (
447
- routing_weight > 0 , (upgate @ self .experts .down_proj [e ]), torch .zeros_like (expert_out )
448
- )
449
- expert_out += masked_down
403
+ def forward (self , hidden_states ):
404
+ hidden_states = hidden_states .reshape (- 1 , self .hidden_dim )
405
+ router_scores , router_logits = self .router (hidden_states )
406
+ routed_in = hidden_states .repeat (router_scores .shape [1 ], 1 )
450
407
451
- # ───────────────────────── Stage-3 : Shared expert ───────────────────────
452
- shared_out = self .shared_expert (hidden ) # [T, H]
453
- final = shared_out + expert_out # restore [B,S,H]
454
- return final .view (B , S , H ), router_logits
408
+ routed_in = routed_in * router_scores .reshape (- 1 , 1 )
409
+ routed_out = self .experts (routed_in )
410
+ out = self .shared_expert (hidden_states )
411
+ out .add_ (routed_out .reshape (router_scores .shape [1 ], - 1 , routed_out .shape [- 1 ]).sum (dim = 0 ))
412
+ return out , router_logits
455
413
456
414
457
415
class QEffLlama4TextAttention (Llama4TextAttention ):
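The rewritten `QEffLlama4TextMoe.forward` replaces the per-expert masked accumulation loops with a replicate-scale-and-sum dispatch: each token is repeated once per expert, scaled by its router score, run through the expert stack in one batched call, and the per-expert outputs are summed onto the shared-expert output. A minimal sketch of that dispatch pattern with toy linear experts; the toy router, expert modules, and shapes are illustrative assumptions, not the Llama4 modules:

```python
import torch
from torch import nn

hidden_dim, num_experts, num_tokens = 8, 4, 5

router = nn.Linear(hidden_dim, num_experts)        # toy router
experts = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts))
shared_expert = nn.Linear(hidden_dim, hidden_dim)  # toy shared expert

hidden_states = torch.randn(num_tokens, hidden_dim)

router_logits = router(hidden_states)              # [T, E]
router_scores = torch.sigmoid(router_logits).t()   # [E, T]; dense scores just for illustration

# Replicate every token once per expert, then scale each copy by its routing score.
routed_in = hidden_states.repeat(num_experts, 1)   # [E*T, H], expert-major ordering
routed_in = routed_in * router_scores.reshape(-1, 1)

# Run each expert on its slice of the replicated batch (the real model fuses this step).
chunks = routed_in.reshape(num_experts, num_tokens, hidden_dim)
routed_out = torch.stack([experts[e](chunks[e]) for e in range(num_experts)])  # [E, T, H]

# Shared expert plus the sum of the score-weighted expert outputs.
out = shared_expert(hidden_states) + routed_out.sum(dim=0)  # [T, H]
print(out.shape)  # torch.Size([5, 8])
```

Because every token passes through every expert (down-weighted by its score rather than gathered), tensor shapes stay static across steps, which is presumably the motivation for the loop-free form in an export/compilation flow.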
@@ -475,10 +433,6 @@ def forward(
        key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-
-        kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        ##
        if self.use_rope:  # the 16E model skips rope for long context on certain layers
            query_states, key_states = qeff_apply_rotary_emb(
                query_states, key_states, position_embeddings.to(query_states.device)
@@ -506,12 +460,11 @@ def forward(
506
460
chunk_position_ids = torch .where (
507
461
chunk_position_ids != - 1 , chunk_position_ids % self .config .attention_chunk_size , chunk_position_ids
508
462
)
509
-
510
463
# sin and cos are specific to RoPE models; cache_position needed for the static cache
511
464
cache_kwargs = {"batch_index" : batch_index , "position_ids" : chunk_position_ids }
512
465
key_states , value_states = past_key_value .update (key_states , value_states , self .layer_idx , cache_kwargs )
513
466
514
- attention_interface : Callable = eager_attention_forward
467
+ attention_interface = eager_attention_forward
515
468
516
469
attn_output , attn_weights = attention_interface (
517
470
self ,
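Chunked attention wraps position ids into a chunk-local index before the cache update, leaving padded positions (encoded as -1) untouched. A small standalone example of the modulo wrap used above; the chunk size of 4 is only for illustration:

```python
import torch

attention_chunk_size = 4  # illustrative value
position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, -1, -1]])

chunk_position_ids = torch.where(
    position_ids != -1, position_ids % attention_chunk_size, position_ids
)
print(chunk_position_ids)  # tensor([[ 0,  1,  2,  3,  0,  1,  2, -1, -1]])
```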
@@ -520,7 +473,6 @@ def forward(
            value_states,
            attention_mask,
            scaling=self.scaling,
-            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
@@ -552,10 +504,6 @@ def forward(
552
504
) -> Tuple [torch .FloatTensor , Optional [Tuple [torch .FloatTensor , torch .FloatTensor ]]]:
553
505
residual = hidden_states
554
506
555
- # use local attention mask for ROPE layers
556
- if self .use_chunked_attention :
557
- attention_mask = chunk_causal_mask
558
-
559
507
hidden_states = self .input_layernorm (hidden_states )
560
508
561
509
# Self Attention
@@ -654,12 +602,12 @@ def forward(
            position_ids = cache_position.unsqueeze(0)

        causal_mask = _create_causal_mask(
-            position_ids=position_ids, target_length=past_key_values.key_cache[3].shape[-2]
+            position_ids=position_ids, target_length=past_key_values.layers[3].keys.shape[-2]
        )
        chunk_position_ids = torch.where(
            position_ids != -1, position_ids % self.config.attention_chunk_size, position_ids
        )
-        target_length = min(past_key_values.key_cache[0].shape[-2], torch.tensor(self.config.attention_chunk_size))
+        target_length = min(past_key_values.layers[0].keys.shape[-2], torch.tensor(self.config.attention_chunk_size))
        chunk_causal_mask = _create_causal_mask(position_ids=chunk_position_ids, target_length=target_length)

        # embed positions
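The last hunk follows a newer KV-cache layout in which each layer object carries its own `keys`/`values` tensors, replacing indexing into the flat `key_cache` list. A stand-in sketch of the two accessors; the `_Layer`/`_Cache` classes below are illustrative stand-ins, not the transformers `Cache` classes:

```python
import torch
from dataclasses import dataclass, field
from typing import List


@dataclass
class _Layer:
    keys: torch.Tensor
    values: torch.Tensor


@dataclass
class _Cache:
    layers: List[_Layer] = field(default_factory=list)


# one layer of dummy states shaped [batch, num_heads, seq_len, head_dim]
cache = _Cache(layers=[_Layer(torch.randn(1, 2, 5, 16), torch.randn(1, 2, 5, 16))])

# older accessor removed by the diff:  past_key_values.key_cache[0].shape[-2]
# newer accessor added by the diff:
print(cache.layers[0].keys.shape[-2])  # 5 -> cached sequence length for layer 0
```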