     merge_padding_and_attention_mask,
 )
 from keras_hub.src.models.smollm3.smollm3_utils import rope_init
-from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
 import math
 
 
@@ -39,6 +39,9 @@ def __init__(
         rope_layer_enabled_list: list[bool],
         layer_types: list[str],
         layer_idx: int,
+        max_position_embeddings: int = 2048,
+        rope_theta: float = 10000.0,
+        partial_rotary_factor: float = 1.0,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -50,19 +53,17 @@ def __init__(
         self.attention_dropout = attention_dropout
         self.rope_layer_enabled_list = rope_layer_enabled_list
         self.layer_types = layer_types
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.partial_rotary_factor = partial_rotary_factor
+
         self._dot_product_equation = "bquh,bkuh->buqk"
         self._combine_equation = "buqk,bkuh->bquh"
 
         self.head_dim = hidden_size // self.num_attention_heads
         self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
 
-        self.rotary_embedding = RotaryEmbedding(
-            max_wavelength=5000000.0,
-        )
-
         self.layer_idx = layer_idx
-
-        self.head_dim = self.hidden_size // self.num_attention_heads
         self.num_key_value_groups = (
             self.num_attention_heads // self.num_key_value_heads
         )
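
For context on the `num_key_value_groups` value computed at the end of this hunk: a minimal sketch with illustrative head counts (not values taken from this PR) of the grouped-query attention bookkeeping it supports.

```python
# Illustrative numbers only: 16 query heads share 4 key/value heads,
# so each KV head serves a group of 4 query heads.
num_attention_heads = 16
num_key_value_heads = 4
num_key_value_groups = num_attention_heads // num_key_value_heads  # 4

# Later in call(), K/V of shape (batch, seq, 4, head_dim) are tiled with
# ops.repeat(..., repeats=num_key_value_groups, axis=2) to
# (batch, seq, 16, head_dim) so they line up with the query heads.
```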
@@ -97,6 +98,15 @@ def __init__(
             else True
         )  # Default to True if index out of bounds
 
+        self.rotary_embedding = SmolLM3RotaryEmbedding(
+            hidden_size=self.hidden_size,
+            num_attention_heads=self.num_attention_heads,
+            max_position_embeddings=self.max_position_embeddings,
+            rope_theta=self.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
+            name="rotary_emb",
+        )
+
         self._softmax = layers.Softmax(
             axis=-1,
             dtype="float32",
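
The `rope_theta` and `partial_rotary_factor` values wired through here feed the inverse-frequency table. The `rope_init` helper imported at the top is not shown in this diff, so the sketch below is an assumption of the standard formulation (hypothetical function name and return value), included to show what the two parameters control.

```python
import numpy as np

def sketch_rope_init(rope_theta, partial_rotary_factor, head_dim):
    # Only a fraction of each head's dimensions is rotated when
    # partial_rotary_factor < 1.0.
    rotary_dim = int(head_dim * partial_rotary_factor)
    # Standard RoPE inverse frequencies: 1 / theta^(2i / rotary_dim).
    inv_freq = 1.0 / (
        rope_theta ** (np.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim)
    )
    return inv_freq

inv_freq = sketch_rope_init(rope_theta=10000.0, partial_rotary_factor=1.0, head_dim=64)
# inv_freq has shape (rotary_dim // 2,) == (32,)
```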
@@ -172,7 +182,15 @@ def _compute_kv_values(x_input):
                 value = value_cache
             else:
                 key_update, value_update = _compute_kv_values(hidden_states)
-                start = [0, self_attention_cache_update_index, 0, 0]
+
+                # Apply RoPE to key_update BEFORE caching
+                if self.use_rope:
+                    cos, sin = self.rotary_embedding(query, start_index=start_index)
+                    query_rope, key_update = apply_rotary_pos_emb(query, key_update, cos, sin, expansion_axis=2)
+                    query = query_rope
+
+                start = (0, self_attention_cache_update_index, 0, 0)
+
                 key = ops.slice_update(key_cache, start, key_update)
                 value = ops.slice_update(
                     value_cache, start, value_update
@@ -189,14 +207,13 @@ def _compute_kv_values(x_input):
                 )
             key, value = _compute_kv_values(hidden_states)
 
-        if self.use_rope:
-            query = self.rotary_embedding(query, start_index=start_index)
-            key = self.rotary_embedding(key, start_index=start_index)
+            # Apply RoPE when not using cache
+            if self.use_rope:
+                cos, sin = self.rotary_embedding(query, start_index=start_index)
+                query, key = apply_rotary_pos_emb(query, key, cos, sin, expansion_axis=2)
 
-        print('pre', key.shape, value.shape)
         key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
         value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
-        print('post', key.shape, value.shape)
 
         attn_output = self._compute_attention(
             query,
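
The `apply_rotary_pos_emb` helper used in both branches comes from `smollm3_utils` and is not part of this diff. A minimal sketch of the usual rotate-half RoPE application it presumably implements, where `expansion_axis=2` inserts the broadcast axis for the heads dimension (query/key here are laid out as `(batch, seq, heads, head_dim)`, per the `"bquh,bkuh->buqk"` einsum):

```python
from keras import ops

def rotate_half(x):
    # Split the last dim in half and swap the halves with a sign flip.
    x1, x2 = ops.split(x, 2, axis=-1)
    return ops.concatenate((-x2, x1), axis=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin, expansion_axis=2):
    # cos/sin: (batch, seq, head_dim) -> (batch, seq, 1, head_dim) so they
    # broadcast across the heads axis of q/k: (batch, seq, heads, head_dim).
    cos = ops.expand_dims(cos, axis=expansion_axis)
    sin = ops.expand_dims(sin, axis=expansion_axis)
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot
```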
@@ -400,6 +417,9 @@ def __init__(
         intermediate_size: int,
         mlp_bias: bool,
         layer_norm_epsilon: float,
+        max_position_embeddings: int = 2048,
+        rope_theta: float = 10000.0,
+        partial_rotary_factor: float = 1.0,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -415,6 +435,9 @@ def __init__(
             rope_layer_enabled_list=rope_layer_enabled_list,
             layer_types=layer_types,
             layer_idx=layer_idx,
+            max_position_embeddings=max_position_embeddings,
+            rope_theta=rope_theta,
+            partial_rotary_factor=partial_rotary_factor,
             name="self_attn",
         )
 
@@ -641,26 +664,34 @@ def call(
                 Shape can vary, but the last dimension is head_dim.
             position_ids: Tensor of position IDs of shape (batch_size, seq_len).
         """
-        inv_freq_expanded = ops.expand_dims(
-            ops.expand_dims(self.inv_freq, axis=0), axis=-1
-        )
-
         batch_size = ops.shape(x)[0]
         seq_len = ops.shape(x)[1]
         positions = ops.arange(seq_len, dtype="float32")
         positions = positions + ops.cast(start_index, dtype="float32")
 
+        # inv_freq: (inv_freq_dim,) -> (1, inv_freq_dim, 1) -> (batch, inv_freq_dim, 1)
+        inv_freq_expanded = ops.expand_dims(
+            ops.expand_dims(self.inv_freq, axis=0), axis=-1
+        )
         inv_freq_expanded = ops.broadcast_to(
             inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1)
         )
 
-        position_ids_expanded = ops.expand_dims(positions, axis=1).T
+        # positions: (seq_len,) -> (1, 1, seq_len) -> (batch, 1, seq_len)
+        position_ids_expanded = ops.expand_dims(
+            ops.expand_dims(positions, axis=0), axis=0
+        )
+        position_ids_expanded = ops.broadcast_to(
+            position_ids_expanded, (batch_size, 1, seq_len)
+        )
 
+        # matmul: (batch, inv_freq_dim, 1) @ (batch, 1, seq_len) -> (batch, inv_freq_dim, seq_len)
         freqs = ops.matmul(
             ops.cast(inv_freq_expanded, "float32"),
             ops.cast(position_ids_expanded, "float32"),
         )
 
+        # transpose: (batch, inv_freq_dim, seq_len) -> (batch, seq_len, inv_freq_dim)
         freqs = ops.transpose(freqs, axes=(0, 2, 1))
 
         emb = ops.concatenate((freqs, freqs), axis=-1)
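
The hunk ends before the function returns; judging from how the attention layer consumes it (`cos, sin = self.rotary_embedding(query, start_index=start_index)`), the concatenated frequencies are presumably turned into a cosine/sine pair. A shape-tracing sketch under that assumption, with illustrative sizes only:

```python
from keras import ops

batch_size, seq_len, inv_freq_dim = 2, 8, 32   # illustrative sizes, not from the PR

freqs = ops.zeros((batch_size, seq_len, inv_freq_dim))  # after the transpose above
emb = ops.concatenate((freqs, freqs), axis=-1)          # (2, 8, 64) == head_dim when
                                                        # partial_rotary_factor == 1.0
cos, sin = ops.cos(emb), ops.sin(emb)                   # the pair the attention layer unpacks
```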