@@ -97,10 +97,11 @@ def _get_llama3_parameters(config: PretrainedConfig):
9797def _get_fope_parameters (config : PretrainedConfig ):
9898 """Get fope parameters."""
9999 params = FopeParameters ()
100- params .num_inv_freq = config .num_inv_freq
100+ rope_scaling = config .rope_scaling
101+ params .num_inv_freq = rope_scaling ['num_inv_freq' ]
101102 params .num_key_value_heads = config .num_key_value_heads
102- params .fope_sep_head = config . fope_sep_head
103- return dict (use_fope = True , fope_params = params )
103+ params .fope_sep_head = rope_scaling [ ' fope_sep_head' ]
104+ return dict (fope_params = params )
104105
105106
106107def build_rotary_params (config : PretrainedConfig ):
@@ -111,6 +112,9 @@ def build_rotary_params(config: PretrainedConfig):
111112 if rope_scaling is not None :
112113 # BC: "rope_type" was originally "type"
113114 rope_type_str = config .rope_scaling .get ('rope_type' , config .rope_scaling .get ('type' , 'default' ))
115+ if rope_type_str .startswith ('fope' ):
116+ params .update (_get_fope_parameters (config ))
117+ rope_type_str = 'default' if rope_type_str == 'fope' else rope_type_str [5 :]
114118 build_funcs = dict (default = _get_default_rope_parameters ,
115119 linear = _get_linear_scaling_rope_parameters ,
116120 dynamic = _get_dynamic_ntk_parameters ,
@@ -125,9 +129,6 @@ def build_rotary_params(config: PretrainedConfig):
125129 if partial_rotary_factor is not None :
126130 params ['partial_rotary_factor' ] = partial_rotary_factor
127131
128- if getattr (config , 'use_fope' , False ):
129- params .update (_get_fope_parameters (config ))
130-
131132 return params
132133
133134
@@ -140,8 +141,7 @@ def build_rotary_embedding(dim: int,
140141 llama3_params : Llama3Parameters = None ,
141142 fope_params : FopeParameters = None ,
142143 emb_type : RopeType = RopeType .Default ,
143- partial_rotary_factor : float = None ,
144- use_fope : bool = False ) -> nn .Module :
144+ partial_rotary_factor : float = None ) -> nn .Module :
145145 """Build rotary embedding op."""
146146 backend = get_backend ()
147147
@@ -159,8 +159,7 @@ def build_rotary_embedding(dim: int,
159159 llama3_params = llama3_params ,
160160 emb_type = emb_type )
161161
162- if use_fope :
163- assert fope_params is not None , 'fope_params should not be None when use_fope is True.'
162+ if fope_params is not None :
164163 inv_freq = impl .inv_freq
165164 fope_params .inv_freq = inv_freq
166165 fope = FopeRotaryEmbedding (dim , max_position_embeddings , scaling_factor , fope_params )
0 commit comments