# limitations under the License.

import math
-import paddle
from typing import Optional

+import paddle
+
from ..utils.log import logger
from .configuration_utils import PretrainedConfig

+
def standardize_rope_params(config, rope_theta: float | dict[str, float] | None = None):
    """
    Helper to standardize the config's rope params field by ensuring the params are defined for each
@@ -69,6 +71,7 @@ def standardize_rope_params(config, rope_theta: float | dict[str, float] | None
        }
        config.rope_parameters = rope_parameters_per_layer_type

+
def _compute_linear_scaling_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional[str] = None,
@@ -416,7 +419,7 @@ def _compute_longrope_parameters(
    else:
        ext_factors = paddle.to_tensor(short_factor, dtype=paddle.float32, place=device)
    inv_freq_shape = paddle.arange(0, dim, 2, dtype="int64", device=device).float() / dim
-    inv_freq = 1.0 / (ext_factors * base ** inv_freq_shape)
+    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)

    return inv_freq, attention_factor

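Aside on the hunk above: the changed line appears to adjust whitespace only; the computation itself is the standard RoPE inverse-frequency schedule, rescaled per dimension pair by the LongRoPE short/long factors. A plain-Python restatement of the same arithmetic, with made-up values (dim, base, and the factors below are illustrative, not defaults from this module):

    # Plain-Python sketch of: inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
    # where inv_freq_shape = arange(0, dim, 2) / dim. Values are illustrative.
    dim = 8                              # rotary dimension (even)
    base = 10000.0                       # rope_theta
    ext_factors = [1.0, 1.0, 2.0, 4.0]   # one factor per frequency pair (len == dim // 2)

    inv_freq = [
        1.0 / (factor * base ** (2 * i / dim))
        for i, factor in enumerate(ext_factors)
    ]
    # Larger factors shrink that pair's inverse frequency, i.e. stretch its wavelength.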
@@ -511,6 +514,7 @@ def _compute_llama3_parameters(
    "llama3": _compute_llama3_parameters,
}

+
def _check_received_keys(
    rope_type: str,
    received_keys: set,
@@ -539,6 +543,7 @@ def _check_received_keys(
    if unused_keys:
        logger.warning(f"Unrecognized keys in `rope_parameters` for 'rope_type'='{rope_type}': {unused_keys}")

+
def _validate_default_rope_parameters(
    rope_parameters: dict, config: Optional[PretrainedConfig] = None, ignore_keys: Optional[set] = None
):
@@ -547,6 +552,7 @@ def _validate_default_rope_parameters(
    rope_type = rope_parameters["rope_type"]
    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

+
def _validate_linear_scaling_rope_parameters(
    rope_parameters: dict, config: Optional[PretrainedConfig] = None, ignore_keys: Optional[set] = None
):
@@ -559,6 +565,7 @@ def _validate_linear_scaling_rope_parameters(
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")

+
def _validate_dynamic_scaling_rope_parameters(
    rope_parameters: dict, config: Optional[PretrainedConfig] = None, ignore_keys: Optional[set] = None
):
@@ -584,7 +591,7 @@ def _validate_yarn_parameters(
        "original_max_position_embeddings",
        "mscale",
        "mscale_all_dim",
-        "truncate"
+        "truncate",
    }
    received_keys = set(rope_parameters.keys())
    rope_type = rope_parameters["rope_type"]
@@ -729,6 +736,7 @@ def _validate_llama3_parameters(rope_parameters: dict, config: PretrainedConfig,
            f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
        )

+
# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
ROPE_VALIDATION_FUNCTIONS = {
    "default": _validate_default_rope_parameters,
@@ -739,6 +747,7 @@ def _validate_llama3_parameters(rope_parameters: dict, config: PretrainedConfig,
    "llama3": _validate_llama3_parameters,
}

+
def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    """
    Validate the RoPE config arguments, given a `PreTrainedConfig` object
@@ -767,4 +776,4 @@ def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set]
    else:
        logger.warning(
            f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
-        )
+        )
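For reviewers following the last two hunks: `rope_config_validation` dispatches on `rope_type` through `ROPE_VALIDATION_FUNCTIONS`, and the comment above that mapping says it can be updated dynamically for custom RoPE types. A rough sketch of how that could look, assuming the module path shown below and a config that already carries a `rope_parameters` dict (the custom type name, validator body, and key names are invented for illustration):

    # Assumed import path; this snippet presumes it runs inside the same package as the module in the diff.
    from .modeling_rope_utils import ROPE_VALIDATION_FUNCTIONS, rope_config_validation

    def _validate_my_rope_parameters(rope_parameters, config=None, ignore_keys=None):
        # Hypothetical check: require a positive scaling factor for the custom type.
        factor = rope_parameters.get("factor")
        if factor is None or factor <= 0:
            raise ValueError(f"`factor` must be a positive number for 'my_rope', got {factor}")

    ROPE_VALIDATION_FUNCTIONS["my_rope"] = _validate_my_rope_parameters

    # config.rope_parameters = {"rope_type": "my_rope", "factor": 2.0}
    # rope_config_validation(config)  # would now dispatch to the custom validator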