@@ -96,9 +96,16 @@ def _get_llama3_parameters(config: PretrainedConfig):
9696
9797def _get_fope_parameters (config : PretrainedConfig ):
9898 """Get fope parameters."""
99+ # check if fope is used
100+ rope_scaling = getattr (config , 'rope_scaling' , dict ())
101+ fope_keys = ['fope_sep_head' , 'fope_num_inv_freq' ]
102+ is_fope = any (key in rope_scaling for key in fope_keys )
103+ if not is_fope :
104+ return dict ()
105+
99106 params = FopeParameters ()
100107 rope_scaling = config .rope_scaling
101- params .num_inv_freq = rope_scaling [ ' num_inv_freq']
108+ params .num_inv_freq = rope_scaling . get ( 'fope_num_inv_freq' , rope_scaling . get ( ' num_inv_freq', params . num_inv_freq ))
102109 params .num_key_value_heads = config .num_key_value_heads
103110 params .fope_sep_head = rope_scaling ['fope_sep_head' ]
104111 return dict (fope_params = params )
@@ -112,9 +119,8 @@ def build_rotary_params(config: PretrainedConfig):
112119 if rope_scaling is not None :
113120 # BC: "rope_type" was originally "type"
114121 rope_type_str = config .rope_scaling .get ('rope_type' , config .rope_scaling .get ('type' , 'default' ))
115- if rope_type_str .startswith ('fope' ):
116- params .update (_get_fope_parameters (config ))
117- rope_type_str = 'default' if rope_type_str == 'fope' else rope_type_str [5 :]
122+ if rope_type_str == 'fope' :
123+ rope_type_str = 'default'
118124 build_funcs = dict (default = _get_default_rope_parameters ,
119125 linear = _get_linear_scaling_rope_parameters ,
120126 dynamic = _get_dynamic_ntk_parameters ,
@@ -123,6 +129,7 @@ def build_rotary_params(config: PretrainedConfig):
123129 su = _get_longrope_parameters ,
124130 llama3 = _get_llama3_parameters )
125131 params .update (build_funcs [rope_type_str ](config ))
132+ params .update (_get_fope_parameters (config ))
126133
127134 # update partial_rotary_factor
128135 partial_rotary_factor = config .partial_rotary_factor if hasattr (config , 'partial_rotary_factor' ) else None
0 commit comments