diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py
index 1a622eed0..2167962b6 100644
--- a/sentence_transformers/models/Transformer.py
+++ b/sentence_transformers/models/Transformer.py
@@ -88,16 +88,19 @@ def __init__(
             **tokenizer_args,
         )
 
-        # No max_seq_length set. Try to infer from model
-        if max_seq_length is None:
-            if (
-                hasattr(self.auto_model, "config")
-                and hasattr(self.auto_model.config, "max_position_embeddings")
-                and hasattr(self.tokenizer, "model_max_length")
-            ):
-                max_seq_length = min(self.auto_model.config.max_position_embeddings, self.tokenizer.model_max_length)
-
-        self.max_seq_length = max_seq_length
+        # Collect every known length limit and use the smallest one
+        max_seq_options = []
+        if max_seq_length is not None:
+            max_seq_options.append(max_seq_length)
+        if (
+            hasattr(self.auto_model, "config")
+            and hasattr(self.auto_model.config, "max_position_embeddings")
+        ):
+            max_seq_options.append(self.auto_model.config.max_position_embeddings)
+        if hasattr(self.tokenizer, "model_max_length"):
+            max_seq_options.append(self.tokenizer.model_max_length)
+
+        self.max_seq_length = min(max_seq_options, default=None)
 
         if tokenizer_name_or_path is not None:
             self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
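
For context, a minimal standalone sketch of the selection logic this diff introduces. The function and parameter names (infer_max_seq_length, config_limit, tokenizer_limit) are hypothetical stand-ins for auto_model.config.max_position_embeddings and tokenizer.model_max_length, not part of the patch; the point is that the old code required both attributes to be present before inferring anything, while the rewritten block takes the minimum of whichever limits are available.

# Hypothetical sketch, not part of the patch: names stand in for the
# real attributes on auto_model.config and tokenizer.
def infer_max_seq_length(max_seq_length=None, config_limit=None, tokenizer_limit=None):
    max_seq_options = []
    if max_seq_length is not None:
        max_seq_options.append(max_seq_length)
    if config_limit is not None:
        max_seq_options.append(config_limit)
    if tokenizer_limit is not None:
        max_seq_options.append(tokenizer_limit)
    # default=None mirrors the old behaviour of leaving the value unset
    # when no limit is known, instead of raising on an empty list
    return min(max_seq_options, default=None)

# Both limits known: the smaller one wins, as in the old code
assert infer_max_seq_length(config_limit=512, tokenizer_limit=384) == 384
# Only one source available: the old code left this as None
assert infer_max_seq_length(config_limit=512) == 512
# Nothing known at all: still None rather than an exception
assert infer_max_seq_length() is None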