@@ -56,7 +56,8 @@ def __init__(self,
         if hasattr(config, 'query_pre_attn_scalar'):
             self.scaling = config.query_pre_attn_scalar**-0.5
         if self.model_type == 'gemma3_text':
-            is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
+            sliding_window_pattern = getattr(config, 'sliding_window_pattern', 6)
+            is_sliding = bool((layer_idx + 1) % sliding_window_pattern)
             self.sliding_window = (getattr(config, 'sliding_window', -1) if is_sliding else -1)
         else:
             self.sliding_window = (getattr(config, 'sliding_window', -1) if not bool(layer_idx % 2) else -1)
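The `getattr` fallback above covers Gemma 3 text checkpoints whose config.json omits `sliding_window_pattern`. A minimal standalone sketch of the layer-selection logic, assuming the default of 6 and a bare `SimpleNamespace` in place of the real config class:

```python
from types import SimpleNamespace

# Hypothetical config whose config.json omits `sliding_window_pattern`.
config = SimpleNamespace(sliding_window=512)

# Same pattern as the patched code: fall back to 6 when the attribute is missing.
sliding_window_pattern = getattr(config, 'sliding_window_pattern', 6)

for layer_idx in range(12):
    # Every `sliding_window_pattern`-th layer gets full (global) attention;
    # all other layers keep the sliding window.
    is_sliding = bool((layer_idx + 1) % sliding_window_pattern)
    sliding_window = getattr(config, 'sliding_window', -1) if is_sliding else -1
    print(layer_idx, sliding_window)
```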
@@ -388,7 +389,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
             emb_type = RopeType.DynamicNTKScaling
         else:
             raise RuntimeError(f'Unsupported rope type: {rope_type}')
-        scaling_factor = rope_scaling.get('scaling_factor', scaling_factor)
+        scaling_factor = rope_scaling.get('scaling_factor', rope_scaling.get('factor', scaling_factor))
 
         rope_dim = config.head_dim
         rope_max_pos_emb = config.max_position_embeddings
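The nested `get` above accounts for the two spellings of the rope-scaling key seen in Hugging Face configs ('scaling_factor' in older dicts, 'factor' in newer ones). A small sketch of the lookup order; the dict contents here are made up for illustration:

```python
# Older-style rope_scaling dict.
rope_scaling_a = {'type': 'linear', 'scaling_factor': 8.0}
# Newer-style dict as emitted by recent transformers configs.
rope_scaling_b = {'rope_type': 'linear', 'factor': 8.0}

default = 1.0
for rope_scaling in (rope_scaling_a, rope_scaling_b):
    # Prefer 'scaling_factor', then 'factor', then the pre-existing default.
    scaling_factor = rope_scaling.get('scaling_factor', rope_scaling.get('factor', default))
    print(scaling_factor)  # 8.0 in both cases
```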
@@ -406,8 +407,8 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
                 rope_dim,
                 rope_max_pos_emb,
                 config.rope_local_base_freq,
-                scaling_factor,
-                emb_type=emb_type,
+                1.0,
+                emb_type=RopeType.LinearScaling,
             )
 
     def forward(
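The last hunk pins the local-frequency rotary embedding (built from `config.rope_local_base_freq`) to a scaling factor of 1.0 with linear scaling, so it ignores whatever global rope scaling was parsed above. A tiny numeric sketch of why a factor of 1.0 is a no-op under linear (position-interpolation) scaling; the helper below follows the standard RoPE formula, not this repository's builder:

```python
import torch

def linear_scaled_inv_freq(dim: int, base: float, scaling_factor: float) -> torch.Tensor:
    # Standard RoPE inverse frequencies; linear scaling divides positions by
    # `scaling_factor`, which is equivalent to dividing the frequencies.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    return inv_freq / scaling_factor

base_local = 10000.0  # stand-in for config.rope_local_base_freq
unscaled = linear_scaled_inv_freq(8, base_local, scaling_factor=1.0)
scaled = linear_scaled_inv_freq(8, base_local, scaling_factor=8.0)
print(torch.allclose(unscaled / 8.0, scaled))  # True; a factor of 1.0 changes nothing
```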