Skip to content

Commit 807e733

Browse files
committed
update fope params
1 parent e9bd834 commit 807e733

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

lmdeploy/pytorch/nn/rotary_embedding.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,16 @@ def _get_llama3_parameters(config: PretrainedConfig):
9696

9797
def _get_fope_parameters(config: PretrainedConfig):
9898
"""Get fope parameters."""
99+
# check if fope is used
100+
rope_scaling = getattr(config, 'rope_scaling', dict())
101+
fope_keys = ['fope_sep_head', 'fope_num_inv_freq']
102+
is_fope = any(key in rope_scaling for key in fope_keys)
103+
if not is_fope:
104+
return dict()
105+
99106
params = FopeParameters()
100107
rope_scaling = config.rope_scaling
101-
params.num_inv_freq = rope_scaling['num_inv_freq']
108+
params.num_inv_freq = rope_scaling.get('fope_num_inv_freq', rope_scaling.get('num_inv_freq', params.num_inv_freq))
102109
params.num_key_value_heads = config.num_key_value_heads
103110
params.fope_sep_head = rope_scaling['fope_sep_head']
104111
return dict(fope_params=params)
@@ -112,9 +119,8 @@ def build_rotary_params(config: PretrainedConfig):
112119
if rope_scaling is not None:
113120
# BC: "rope_type" was originally "type"
114121
rope_type_str = config.rope_scaling.get('rope_type', config.rope_scaling.get('type', 'default'))
115-
if rope_type_str.startswith('fope'):
116-
params.update(_get_fope_parameters(config))
117-
rope_type_str = 'default' if rope_type_str == 'fope' else rope_type_str[5:]
122+
if rope_type_str == 'fope':
123+
rope_type_str = 'default'
118124
build_funcs = dict(default=_get_default_rope_parameters,
119125
linear=_get_linear_scaling_rope_parameters,
120126
dynamic=_get_dynamic_ntk_parameters,
@@ -123,6 +129,7 @@ def build_rotary_params(config: PretrainedConfig):
123129
su=_get_longrope_parameters,
124130
llama3=_get_llama3_parameters)
125131
params.update(build_funcs[rope_type_str](config))
132+
params.update(_get_fope_parameters(config))
126133

127134
# update partial_rotary_factor
128135
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, 'partial_rotary_factor') else None

0 commit comments

Comments
 (0)