
Commit e9bd834

update config format
1 parent 732a274 commit e9bd834


lmdeploy/pytorch/nn/rotary_embedding.py

Lines changed: 9 additions & 10 deletions
@@ -97,10 +97,11 @@ def _get_llama3_parameters(config: PretrainedConfig):
 def _get_fope_parameters(config: PretrainedConfig):
     """Get fope parameters."""
     params = FopeParameters()
-    params.num_inv_freq = config.num_inv_freq
+    rope_scaling = config.rope_scaling
+    params.num_inv_freq = rope_scaling['num_inv_freq']
     params.num_key_value_heads = config.num_key_value_heads
-    params.fope_sep_head = config.fope_sep_head
-    return dict(use_fope=True, fope_params=params)
+    params.fope_sep_head = rope_scaling['fope_sep_head']
+    return dict(fope_params=params)


 def build_rotary_params(config: PretrainedConfig):
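For context, a minimal sketch of the config shape this hunk appears to expect after the change: the FoPE fields move from top-level config attributes into the rope_scaling dict. The concrete values and the use of a bare PretrainedConfig are illustrative, not taken from the repo.

# Illustrative config, assuming a HuggingFace-style PretrainedConfig whose
# rope_scaling dict now carries the FoPE fields read by _get_fope_parameters().
from transformers import PretrainedConfig

config = PretrainedConfig(
    num_key_value_heads=8,          # still read from the top-level config
    rope_scaling={
        'rope_type': 'fope',        # handled by the new branch in build_rotary_params
        'num_inv_freq': 64,         # was config.num_inv_freq before this commit
        'fope_sep_head': True,      # was config.fope_sep_head before this commit
    },
)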
@@ -111,6 +112,9 @@ def build_rotary_params(config: PretrainedConfig):
     if rope_scaling is not None:
         # BC: "rope_type" was originally "type"
         rope_type_str = config.rope_scaling.get('rope_type', config.rope_scaling.get('type', 'default'))
+        if rope_type_str.startswith('fope'):
+            params.update(_get_fope_parameters(config))
+            rope_type_str = 'default' if rope_type_str == 'fope' else rope_type_str[5:]
         build_funcs = dict(default=_get_default_rope_parameters,
                            linear=_get_linear_scaling_rope_parameters,
                            dynamic=_get_dynamic_ntk_parameters,
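The added branch resolves combined rope types by stripping the 'fope_' prefix before dispatching to build_funcs, while a bare 'fope' falls back to the default parameters. A small sketch of that resolution as I read the diff, with 'fope_dynamic' and 'fope_llama3' as hypothetical combined names:

# Mirrors the added rope_type_str handling; not code from the repo.
for rope_type_str in ('fope', 'fope_dynamic', 'fope_llama3', 'linear'):
    resolved = rope_type_str
    if resolved.startswith('fope'):
        # the real code also collects FoPE params here via _get_fope_parameters(config)
        resolved = 'default' if resolved == 'fope' else resolved[5:]  # drop 'fope_'
    print(rope_type_str, '->', resolved)
# fope -> default
# fope_dynamic -> dynamic
# fope_llama3 -> llama3
# linear -> linear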
@@ -125,9 +129,6 @@
     if partial_rotary_factor is not None:
         params['partial_rotary_factor'] = partial_rotary_factor
 
-    if getattr(config, 'use_fope', False):
-        params.update(_get_fope_parameters(config))
-
     return params
 
 
@@ -140,8 +141,7 @@ def build_rotary_embedding(dim: int,
                            llama3_params: Llama3Parameters = None,
                            fope_params: FopeParameters = None,
                            emb_type: RopeType = RopeType.Default,
-                           partial_rotary_factor: float = None,
-                           use_fope: bool = False) -> nn.Module:
+                           partial_rotary_factor: float = None) -> nn.Module:
     """Build rotary embedding op."""
     backend = get_backend()
 
@@ -159,8 +159,7 @@ def build_rotary_embedding(dim: int,
                                llama3_params=llama3_params,
                                emb_type=emb_type)
 
-    if use_fope:
-        assert fope_params is not None, 'fope_params should not be None when use_fope is True.'
+    if fope_params is not None:
         inv_freq = impl.inv_freq
         fope_params.inv_freq = inv_freq
         fope = FopeRotaryEmbedding(dim, max_position_embeddings, scaling_factor, fope_params)
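After this change, FoPE is enabled by passing a non-None fope_params rather than a separate use_fope flag. A hedged usage sketch, assuming FopeParameters and build_rotary_embedding are importable from this module and that max_position_embeddings is a keyword parameter; dim and all field values are placeholders:

# Illustrative call sites only; values are made up.
from lmdeploy.pytorch.nn.rotary_embedding import (FopeParameters,
                                                  build_rotary_embedding)

# Plain RoPE: without fope_params the ordinary rotary embedding impl is returned.
rotary = build_rotary_embedding(dim=128, max_position_embeddings=4096)

# FoPE: a non-None fope_params now triggers the FopeRotaryEmbedding wrapper;
# the removed use_fope=True argument is no longer accepted.
fope_params = FopeParameters()  # normally filled by build_rotary_params(config)
fope_params.num_inv_freq = 64
fope_params.num_key_value_heads = 8
fope_params.fope_sep_head = True
rotary_fope = build_rotary_embedding(dim=128,
                                     max_position_embeddings=4096,
                                     fope_params=fope_params)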
