 import shutil
 from collections import OrderedDict
 import warnings
-from typing import List, Dict, Literal, Tuple, Iterable, Type, Union, Callable, Optional, TYPE_CHECKING
+from typing import List, Dict, Literal, Tuple, Iterable, Type, Union, Callable, Optional, TYPE_CHECKING, Any
 import numpy as np
 from numpy import ndarray
 import transformers
@@ -74,7 +74,40 @@ class SentenceTransformer(nn.Sequential):
     :param local_files_only: If `True`, avoid downloading the model.
     :param token: Hugging Face authentication token to download private models.
     :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
-        only applicable during inference when `.encode` is called.
+        only applicable during inference when :meth:`SentenceTransformer.encode` is called.
+    :param model_kwargs: Additional model configuration parameters to be passed to the Huggingface Transformers model.
+        Particularly useful options are:
+
+        - ``torch_dtype``: Override the default `torch.dtype` and load the model under a specific `dtype`.
+          The different options are:
+
+          1. ``torch.float16``, ``torch.bfloat16`` or ``torch.float``: load in the specified ``dtype``,
+             ignoring the model's ``config.torch_dtype`` if one exists. If not specified, the model is
+             loaded in ``torch.float`` (fp32).
+
+          2. ``"auto"``: the ``torch_dtype`` entry in the model's ``config.json`` file is used if present.
+             If that entry isn't found, the ``dtype`` of the first floating-point weight in the checkpoint
+             is used instead. This loads the model with the ``dtype`` it was saved in at the end of training,
+             which is not necessarily the precision it was trained in: it may have been trained in one of the
+             half-precision dtypes but saved in fp32.
+        - ``attn_implementation``: The attention implementation to use in the model (if relevant). Can be any of
+          `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention
+          <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html>`_),
+          or `"flash_attention_2"` (using `Dao-AILab/flash-attention <https://github.com/Dao-AILab/flash-attention>`_).
+          By default, SDPA is used for torch>=2.1.1 when available; otherwise the manual `"eager"`
+          implementation is the default.
+
+        See the `PreTrainedModel.from_pretrained
+        <https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained>`_
+        documentation for more details.
+    :param tokenizer_kwargs: Additional tokenizer configuration parameters to be passed to the Huggingface Transformers tokenizer.
+        See the `AutoTokenizer.from_pretrained
+        <https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained>`_
+        documentation for more details.
+    :param config_kwargs: Additional model configuration parameters to be passed to the Huggingface Transformers config.
+        See the `AutoConfig.from_pretrained
+        <https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained>`_
+        documentation for more details.
     """
 
     def __init__(
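The three new keyword arguments documented above are forwarded to the underlying Huggingface Transformers model, tokenizer, and config loaders. A minimal usage sketch; the model name, the CUDA device, and the specific sub-options are only illustrative, not part of this commit:

import torch
from sentence_transformers import SentenceTransformer

# Load the underlying Transformers model in fp16 (assumes a CUDA device is available).
model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",           # example model name
    device="cuda",
    model_kwargs={"torch_dtype": torch.float16},
    tokenizer_kwargs={"model_max_length": 256},          # illustrative tokenizer option
    config_kwargs={"output_hidden_states": False},       # illustrative config option
)
embeddings = model.encode(["An example sentence."])
print(embeddings.shape)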
@@ -91,6 +124,9 @@ def __init__(
         token: Optional[Union[bool, str]] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         truncate_dim: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
     ):
         # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
         self.prompts = prompts or {}
@@ -220,6 +256,9 @@ def __init__(
                 revision=revision,
                 trust_remote_code=trust_remote_code,
                 local_files_only=local_files_only,
+                model_kwargs=model_kwargs,
+                tokenizer_kwargs=tokenizer_kwargs,
+                config_kwargs=config_kwargs,
             )
         else:
             modules = self._load_auto_model(
@@ -229,6 +268,9 @@ def __init__(
                 revision=revision,
                 trust_remote_code=trust_remote_code,
                 local_files_only=local_files_only,
+                model_kwargs=model_kwargs,
+                tokenizer_kwargs=tokenizer_kwargs,
+                config_kwargs=config_kwargs,
             )
 
         if modules is not None and not isinstance(modules, OrderedDict):
@@ -461,7 +503,10 @@ def encode(
                 all_embeddings = torch.Tensor()
         elif convert_to_numpy:
             if not isinstance(all_embeddings, np.ndarray):
-                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+                if all_embeddings[0].dtype == torch.bfloat16:
+                    all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings])
+                else:
+                    all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
         elif isinstance(all_embeddings, np.ndarray):
             all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings]
 
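The new bfloat16 branch exists because NumPy has no bfloat16 dtype: calling .numpy() on a bfloat16 tensor raises a TypeError, so the embeddings are upcast to float32 first. A minimal standalone sketch of that behaviour (not the library's own code path):

import numpy as np
import torch

emb = torch.randn(4, dtype=torch.bfloat16)

try:
    emb.numpy()                             # fails: NumPy has no bfloat16 dtype
except TypeError as exc:
    print("direct conversion failed:", exc)

arr = np.asarray(emb.float().numpy())       # upcast to float32 first, as encode() now does
print(arr.dtype)                            # float32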
@@ -1235,30 +1280,35 @@ def _load_auto_model(
         revision: Optional[str] = None,
         trust_remote_code: bool = False,
         local_files_only: bool = False,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Creates a simple Transformer + Mean Pooling model and returns the modules
         """
         logger.warning(
-            "No sentence-transformers model found with name {}. Creating a new one with MEAN pooling.".format(
+            "No sentence-transformers model found with name {}. Creating a new one with mean pooling.".format(
                 model_name_or_path
             )
         )
+
+        shared_kwargs = {
+            "token": token,
+            "trust_remote_code": trust_remote_code,
+            "revision": revision,
+            "local_files_only": local_files_only,
+        }
+        model_kwargs = shared_kwargs if model_kwargs is None else {**shared_kwargs, **model_kwargs}
+        tokenizer_kwargs = shared_kwargs if tokenizer_kwargs is None else {**shared_kwargs, **tokenizer_kwargs}
+        config_kwargs = shared_kwargs if config_kwargs is None else {**shared_kwargs, **config_kwargs}
+
         transformer_model = Transformer(
             model_name_or_path,
             cache_dir=cache_folder,
-            model_args={
-                "token": token,
-                "trust_remote_code": trust_remote_code,
-                "revision": revision,
-                "local_files_only": local_files_only,
-            },
-            tokenizer_args={
-                "token": token,
-                "trust_remote_code": trust_remote_code,
-                "revision": revision,
-                "local_files_only": local_files_only,
-            },
+            model_args=model_kwargs,
+            tokenizer_args=tokenizer_kwargs,
+            config_args=config_kwargs,
         )
         pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), "mean")
         return [transformer_model, pooling_model]
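In the merge above, `{**shared_kwargs, **model_kwargs}` keeps the hub-related defaults but lets any user-supplied key win, because later entries in a dict unpacking override earlier ones. A small illustration with made-up values:

shared_kwargs = {"revision": None, "trust_remote_code": False}
user_model_kwargs = {"revision": "main", "torch_dtype": "auto"}  # hypothetical caller input

merged = {**shared_kwargs, **user_model_kwargs}
print(merged)
# {'revision': 'main', 'trust_remote_code': False, 'torch_dtype': 'auto'}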
@@ -1271,6 +1321,9 @@ def _load_sbert_model(
         revision: Optional[str] = None,
         trust_remote_code: bool = False,
         local_files_only: bool = False,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Loads a full sentence-transformers model
@@ -1360,21 +1413,42 @@ def _load_sbert_model(
                     if config_path is not None:
                         with open(config_path) as fIn:
                             kwargs = json.load(fIn)
+                        # Don't allow configs to set trust_remote_code
+                        if "model_args" in kwargs and "trust_remote_code" in kwargs["model_args"]:
+                            kwargs["model_args"].pop("trust_remote_code")
+                        if "tokenizer_args" in kwargs and "trust_remote_code" in kwargs["tokenizer_args"]:
+                            kwargs["tokenizer_args"].pop("trust_remote_code")
+                        if "config_args" in kwargs and "trust_remote_code" in kwargs["config_args"]:
+                            kwargs["config_args"].pop("trust_remote_code")
                         break
+
                 hub_kwargs = {
                     "token": token,
                     "trust_remote_code": trust_remote_code,
                     "revision": revision,
                     "local_files_only": local_files_only,
                 }
-                if "model_args" in kwargs:
-                    kwargs["model_args"].update(hub_kwargs)
-                else:
-                    kwargs["model_args"] = hub_kwargs
-                if "tokenizer_args" in kwargs:
-                    kwargs["tokenizer_args"].update(hub_kwargs)
-                else:
-                    kwargs["tokenizer_args"] = hub_kwargs
+                # 3rd priority: config file
+                if "model_args" not in kwargs:
+                    kwargs["model_args"] = {}
+                if "tokenizer_args" not in kwargs:
+                    kwargs["tokenizer_args"] = {}
+                if "config_args" not in kwargs:
+                    kwargs["config_args"] = {}
+
+                # 2nd priority: hub_kwargs
+                kwargs["model_args"].update(hub_kwargs)
+                kwargs["tokenizer_args"].update(hub_kwargs)
+                kwargs["config_args"].update(hub_kwargs)
+
+                # 1st priority: kwargs passed to SentenceTransformer
+                if model_kwargs:
+                    kwargs["model_args"].update(model_kwargs)
+                if tokenizer_kwargs:
+                    kwargs["tokenizer_args"].update(tokenizer_kwargs)
+                if config_kwargs:
+                    kwargs["config_args"].update(config_kwargs)
+
                 module = Transformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
             else:
                 # Normalize does not require any files to be loaded
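The block above establishes a clear precedence: values from the on-disk config file are the weakest, the hub-related kwargs overwrite them, and the kwargs passed to SentenceTransformer overwrite everything, with trust_remote_code additionally stripped from the config file so that only the explicit argument can enable remote code. A sketch of that ordering with made-up values:

# 3rd priority: args loaded from the config file (contents here are invented for illustration)
kwargs = {"model_args": {"trust_remote_code": True, "revision": "from-config"}}
kwargs["model_args"].pop("trust_remote_code", None)   # configs may not enable remote code

hub_kwargs = {"token": None, "trust_remote_code": False, "revision": None, "local_files_only": False}
user_model_kwargs = {"revision": "refs/pr/7", "torch_dtype": "auto"}  # hypothetical caller input

kwargs["model_args"].update(hub_kwargs)          # 2nd priority overwrites the config file
kwargs["model_args"].update(user_model_kwargs)   # 1st priority overwrites everything
print(kwargs["model_args"]["revision"])            # 'refs/pr/7'
print(kwargs["model_args"]["trust_remote_code"])   # False: only the explicit argument counts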