
Commit 5f75ce5

Allow passing model_args to ST (#2578)
* Use torch_dtype='auto' when loading auto-class
* Use torch_dtype='auto' as default if model_args doesn't have it
* Allow passing model_args for ST
* Make same change for T5 and MT5
* Update method documentation
* Disable test if CUDA is not available
* Add explicit dtype param to signature
* Add model kwargs, update documentation, add more tests
* Update based on suggestions
* Format fixes
* Update docstrings slightly
* Reintroduce config kwargs; needed for token, trust_remote_code, revision, etc.
* Propose dict kwargs for model, tokenizer & config
* Remove dict defaults
* Only cast to float() if necessary; adopted from v3.0-pre-release

---------

Co-authored-by: Tom Aarsen <[email protected]>
1 parent 50e499a commit 5f75ce5

File tree

3 files changed: +160 -32 lines changed

sentence_transformers/SentenceTransformer.py

Lines changed: 98 additions & 24 deletions
@@ -5,7 +5,7 @@
 import shutil
 from collections import OrderedDict
 import warnings
-from typing import List, Dict, Literal, Tuple, Iterable, Type, Union, Callable, Optional, TYPE_CHECKING
+from typing import List, Dict, Literal, Tuple, Iterable, Type, Union, Callable, Optional, TYPE_CHECKING, Any
 import numpy as np
 from numpy import ndarray
 import transformers
@@ -74,7 +74,40 @@ class SentenceTransformer(nn.Sequential):
     :param local_files_only: If `True`, avoid downloading the model.
     :param token: Hugging Face authentication token to download private models.
     :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
-        only applicable during inference when `.encode` is called.
+        only applicable during inference when :meth:`SentenceTransformer.encode` is called.
+    :param model_kwargs: Additional model configuration parameters to be passed to the Huggingface Transformers model.
+        Particularly useful options are:
+
+        - ``torch_dtype``: Override the default `torch.dtype` and load the model under a specific `dtype`.
+          The different options are:
+
+          1. ``torch.float16``, ``torch.bfloat16`` or ``torch.float``: load in a specified
+             ``dtype``, ignoring the model's ``config.torch_dtype`` if one exists. If not specified, the model
+             is loaded in ``torch.float`` (fp32).
+
+          2. ``"auto"``: use the ``torch_dtype`` entry in the model's ``config.json`` if it exists. If that
+             entry isn't found, fall back to the ``dtype`` of the first floating-point weight in
+             the checkpoint. This loads the model using the ``dtype`` it was saved in at the end of
+             training, which is not necessarily the ``dtype`` it was trained in: a model may be
+             trained in one of the half-precision dtypes but saved in fp32.
+        - ``attn_implementation``: The attention implementation to use in the model (if relevant). Can be any of
+          `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention
+          <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html>`_),
+          or `"flash_attention_2"` (using `Dao-AILab/flash-attention <https://github.com/Dao-AILab/flash-attention>`_).
+          By default, SDPA is used for torch>=2.1.1 if available; otherwise, the manual `"eager"`
+          implementation is the default.
+
+        See the `PreTrainedModel.from_pretrained
+        <https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained>`_
+        documentation for more details.
+    :param tokenizer_kwargs: Additional tokenizer configuration parameters to be passed to the Huggingface Transformers tokenizer.
+        See the `AutoTokenizer.from_pretrained
+        <https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained>`_
+        documentation for more details.
+    :param config_kwargs: Additional model configuration parameters to be passed to the Huggingface Transformers config.
+        See the `AutoConfig.from_pretrained
+        <https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained>`_
+        documentation for more details.
     """
 
     def __init__(
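
Taken together, the new `model_kwargs`, `tokenizer_kwargs`, and `config_kwargs` parameters forward options to the underlying `from_pretrained` calls. A minimal usage sketch (the model name is the tiny test model used in this commit's tests; the option values are illustrative, not required):

    import torch
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        "sentence-transformers-testing/stsb-bert-tiny-safetensors",
        model_kwargs={"torch_dtype": torch.float16},    # -> AutoModel.from_pretrained
        tokenizer_kwargs={"model_max_length": 256},     # -> AutoTokenizer.from_pretrained
        config_kwargs={"output_hidden_states": False},  # -> AutoConfig.from_pretrained
    )
    embeddings = model.encode(["Hello there!"], convert_to_tensor=True)
    print(embeddings.dtype)  # torch.float16
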
@@ -91,6 +124,9 @@ def __init__(
         token: Optional[Union[bool, str]] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         truncate_dim: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
     ):
         # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
         self.prompts = prompts or {}
@@ -220,6 +256,9 @@ def __init__(
                 revision=revision,
                 trust_remote_code=trust_remote_code,
                 local_files_only=local_files_only,
+                model_kwargs=model_kwargs,
+                tokenizer_kwargs=tokenizer_kwargs,
+                config_kwargs=config_kwargs,
             )
         else:
             modules = self._load_auto_model(
@@ -229,6 +268,9 @@ def __init__(
                 revision=revision,
                 trust_remote_code=trust_remote_code,
                 local_files_only=local_files_only,
+                model_kwargs=model_kwargs,
+                tokenizer_kwargs=tokenizer_kwargs,
+                config_kwargs=config_kwargs,
             )
 
         if modules is not None and not isinstance(modules, OrderedDict):
@@ -461,7 +503,10 @@ def encode(
             all_embeddings = torch.Tensor()
         elif convert_to_numpy:
             if not isinstance(all_embeddings, np.ndarray):
-                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+                if all_embeddings[0].dtype == torch.bfloat16:
+                    all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings])
+                else:
+                    all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
         elif isinstance(all_embeddings, np.ndarray):
             all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings]
 
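The new branch is needed because NumPy has no bfloat16 dtype, so `Tensor.numpy()` raises a TypeError for bfloat16 tensors; upcasting to float32 first avoids it. A quick standalone sketch of the failure and the workaround:

    import torch

    t = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
    try:
        t.numpy()  # NumPy has no bfloat16 dtype
    except TypeError as exc:
        print(exc)

    # The fix used above: cast to float32 before converting.
    print(t.float().numpy().dtype)  # float32
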
@@ -1235,30 +1280,35 @@ def _load_auto_model(
         revision: Optional[str] = None,
         trust_remote_code: bool = False,
         local_files_only: bool = False,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Creates a simple Transformer + Mean Pooling model and returns the modules
         """
         logger.warning(
-            "No sentence-transformers model found with name {}. Creating a new one with MEAN pooling.".format(
+            "No sentence-transformers model found with name {}. Creating a new one with mean pooling.".format(
                 model_name_or_path
             )
         )
+
+        shared_kwargs = {
+            "token": token,
+            "trust_remote_code": trust_remote_code,
+            "revision": revision,
+            "local_files_only": local_files_only,
+        }
+        model_kwargs = shared_kwargs if model_kwargs is None else {**shared_kwargs, **model_kwargs}
+        tokenizer_kwargs = shared_kwargs if tokenizer_kwargs is None else {**shared_kwargs, **tokenizer_kwargs}
+        config_kwargs = shared_kwargs if config_kwargs is None else {**shared_kwargs, **config_kwargs}
+
         transformer_model = Transformer(
             model_name_or_path,
             cache_dir=cache_folder,
-            model_args={
-                "token": token,
-                "trust_remote_code": trust_remote_code,
-                "revision": revision,
-                "local_files_only": local_files_only,
-            },
-            tokenizer_args={
-                "token": token,
-                "trust_remote_code": trust_remote_code,
-                "revision": revision,
-                "local_files_only": local_files_only,
-            },
+            model_args=model_kwargs,
+            tokenizer_args=tokenizer_kwargs,
+            config_args=config_kwargs,
         )
         pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), "mean")
         return [transformer_model, pooling_model]
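The `{**shared_kwargs, **model_kwargs}` pattern gives user-supplied values precedence over the shared hub arguments, because later entries in a dict literal overwrite earlier ones. A quick illustration with made-up values:

    shared_kwargs = {"revision": None, "local_files_only": False}
    user_kwargs = {"revision": "refs/pr/1", "torch_dtype": "auto"}

    merged = {**shared_kwargs, **user_kwargs}
    print(merged)
    # {'revision': 'refs/pr/1', 'local_files_only': False, 'torch_dtype': 'auto'}
    # -> the user's 'revision' wins; shared defaults fill in the rest
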
@@ -1271,6 +1321,9 @@ def _load_sbert_model(
         revision: Optional[str] = None,
         trust_remote_code: bool = False,
         local_files_only: bool = False,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Loads a full sentence-transformers model
@@ -1360,21 +1413,42 @@ def _load_sbert_model(
             if config_path is not None:
                 with open(config_path) as fIn:
                     kwargs = json.load(fIn)
+                    # Don't allow configs to set trust_remote_code
+                    if "model_args" in kwargs and "trust_remote_code" in kwargs["model_args"]:
+                        kwargs["model_args"].pop("trust_remote_code")
+                    if "tokenizer_args" in kwargs and "trust_remote_code" in kwargs["tokenizer_args"]:
+                        kwargs["tokenizer_args"].pop("trust_remote_code")
+                    if "config_args" in kwargs and "trust_remote_code" in kwargs["config_args"]:
+                        kwargs["config_args"].pop("trust_remote_code")
                 break
+
             hub_kwargs = {
                 "token": token,
                 "trust_remote_code": trust_remote_code,
                 "revision": revision,
                 "local_files_only": local_files_only,
             }
-            if "model_args" in kwargs:
-                kwargs["model_args"].update(hub_kwargs)
-            else:
-                kwargs["model_args"] = hub_kwargs
-            if "tokenizer_args" in kwargs:
-                kwargs["tokenizer_args"].update(hub_kwargs)
-            else:
-                kwargs["tokenizer_args"] = hub_kwargs
+            # 3rd priority: config file
+            if "model_args" not in kwargs:
+                kwargs["model_args"] = {}
+            if "tokenizer_args" not in kwargs:
+                kwargs["tokenizer_args"] = {}
+            if "config_args" not in kwargs:
+                kwargs["config_args"] = {}
+
+            # 2nd priority: hub_kwargs
+            kwargs["model_args"].update(hub_kwargs)
+            kwargs["tokenizer_args"].update(hub_kwargs)
+            kwargs["config_args"].update(hub_kwargs)
+
+            # 1st priority: kwargs passed to SentenceTransformer
+            if model_kwargs:
+                kwargs["model_args"].update(model_kwargs)
+            if tokenizer_kwargs:
+                kwargs["tokenizer_args"].update(tokenizer_kwargs)
+            if config_kwargs:
+                kwargs["config_args"].update(config_kwargs)
+
             module = Transformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
         else:
             # Normalize does not require any files to be loaded
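
Condensed, the merge order is: values from the saved config file are lowest priority, the hub kwargs override them, and the kwargs passed to `SentenceTransformer` override everything. A standalone sketch of that precedence with illustrative values:

    # 3rd priority: values read from the saved config file
    kwargs = {"model_args": {"revision": "old-pin"}, "tokenizer_args": {}, "config_args": {}}

    # 2nd priority: hub kwargs override the config file
    hub_kwargs = {"revision": "main", "trust_remote_code": False}
    kwargs["model_args"].update(hub_kwargs)

    # 1st priority: kwargs passed to SentenceTransformer override everything
    model_kwargs = {"revision": "refs/pr/1", "torch_dtype": "auto"}
    kwargs["model_args"].update(model_kwargs)

    print(kwargs["model_args"])
    # {'revision': 'refs/pr/1', 'trust_remote_code': False, 'torch_dtype': 'auto'}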

sentence_transformers/models/Transformer.py

Lines changed: 20 additions & 8 deletions
@@ -1,7 +1,7 @@
 from torch import nn
 from transformers import AutoModel, AutoTokenizer, AutoConfig, T5Config, MT5Config
 import json
-from typing import List, Dict, Optional, Union, Tuple
+from typing import Any, List, Dict, Optional, Union, Tuple
 import os
 
 
@@ -11,9 +11,10 @@ class Transformer(nn.Module):
 
     :param model_name_or_path: Huggingface models name (https://huggingface.co/models)
     :param max_seq_length: Truncate any inputs longer than max_seq_length
-    :param model_args: Arguments (key, value pairs) passed to the Huggingface Transformers model
+    :param model_args: Keyword arguments passed to the Huggingface Transformers model
+    :param tokenizer_args: Keyword arguments passed to the Huggingface Transformers tokenizer
+    :param config_args: Keyword arguments passed to the Huggingface Transformers config
     :param cache_dir: Cache dir for Huggingface Transformers to store/load models
-    :param tokenizer_args: Arguments (key, value pairs) passed to the Huggingface Tokenizer model
     :param do_lower_case: If true, lowercases the input (independent if the model is cased or not)
     :param tokenizer_name_or_path: Name or path of the tokenizer. When None, then model_name_or_path is used
     """
@@ -22,17 +23,24 @@ def __init__(
         self,
         model_name_or_path: str,
         max_seq_length: Optional[int] = None,
-        model_args: Dict = {},
+        model_args: Optional[Dict[str, Any]] = None,
+        tokenizer_args: Optional[Dict[str, Any]] = None,
+        config_args: Optional[Dict[str, Any]] = None,
         cache_dir: Optional[str] = None,
-        tokenizer_args: Dict = {},
         do_lower_case: bool = False,
         tokenizer_name_or_path: str = None,
     ):
         super(Transformer, self).__init__()
         self.config_keys = ["max_seq_length", "do_lower_case"]
         self.do_lower_case = do_lower_case
-
-        config = AutoConfig.from_pretrained(model_name_or_path, **model_args, cache_dir=cache_dir)
+        if model_args is None:
+            model_args = {}
+        if tokenizer_args is None:
+            tokenizer_args = {}
+        if config_args is None:
+            config_args = {}
+
+        config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
         self._load_model(model_name_or_path, config, cache_dir, **model_args)
 
         self.tokenizer = AutoTokenizer.from_pretrained(
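
Replacing `model_args: Dict = {}` with `Optional[Dict[str, Any]] = None` plus an explicit `None` check also removes a classic Python pitfall: a mutable default is created once at function definition and shared across every call. A minimal demonstration:

    from typing import Optional

    def bad(args: dict = {}):  # one dict object shared by all calls
        args["calls"] = args.get("calls", 0) + 1
        return args

    print(bad())  # {'calls': 1}
    print(bad())  # {'calls': 2}  <- state leaked from the previous call

    def good(args: Optional[dict] = None):
        if args is None:
            args = {}  # a fresh dict per call
        args["calls"] = args.get("calls", 0) + 1
        return args

    print(good())  # {'calls': 1}
    print(good())  # {'calls': 1}
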
@@ -182,6 +190,10 @@ def load(input_path: str):
         with open(sbert_config_path) as fIn:
             config = json.load(fIn)
         # Don't allow configs to set trust_remote_code
-        if "model_args" in config:
+        if "model_args" in config and "trust_remote_code" in config["model_args"]:
             config["model_args"].pop("trust_remote_code")
+        if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
+            config["tokenizer_args"].pop("trust_remote_code")
+        if "config_args" in config and "trust_remote_code" in config["config_args"]:
+            config["config_args"].pop("trust_remote_code")
         return Transformer(model_name_or_path=input_path, **config)
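
The point of stripping `trust_remote_code` in `load` is that a saved (possibly downloaded) config must never be able to opt itself into executing remote code; only the caller may do that. A sketch of the sanitization pattern with an illustrative config dict:

    config = {
        "model_args": {"trust_remote_code": True, "revision": "main"},
        "tokenizer_args": {"trust_remote_code": True},
    }

    # Drop trust_remote_code wherever the saved config tries to set it.
    for key in ("model_args", "tokenizer_args", "config_args"):
        if key in config and "trust_remote_code" in config[key]:
            config[key].pop("trust_remote_code")

    print(config)
    # {'model_args': {'revision': 'main'}, 'tokenizer_args': {}}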

tests/test_sentence_transformer.py

Lines changed: 42 additions & 0 deletions
@@ -339,6 +339,48 @@ def test_save_load_prompts() -> None:
     assert fresh_model.default_prompt_name == "query"
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
+def test_load_with_torch_dtype() -> None:
+    model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
+
+    assert model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float32
+
+    with tempfile.TemporaryDirectory() as tmp_folder:
+        fp16_model_dir = Path(tmp_folder) / "fp16_model"
+        model.half()
+        model.save(str(fp16_model_dir))
+        del model
+
+        fp16_model = SentenceTransformer(
+            str(fp16_model_dir),
+            model_kwargs={"torch_dtype": "auto"},
+        )
+        assert fp16_model.encode(["Hello there!"], convert_to_tensor=True).dtype == torch.float16
+
+
+def test_load_with_model_kwargs(monkeypatch: pytest.MonkeyPatch) -> None:
+    transformer_kwargs = {}
+    original_transformer_init = Transformer.__init__
+
+    def transformers_init(*args, **kwargs):
+        nonlocal transformer_kwargs
+        nonlocal original_transformer_init
+        transformer_kwargs = kwargs
+        return original_transformer_init(*args, **kwargs)
+
+    monkeypatch.setattr(Transformer, "__init__", transformers_init)
+
+    SentenceTransformer(
+        "sentence-transformers-testing/stsb-bert-tiny-safetensors",
+        model_kwargs={"attn_implementation": "eager", "low_cpu_mem_usage": False},
+    )
+
+    assert "low_cpu_mem_usage" in transformer_kwargs["model_args"]
+    assert transformer_kwargs["model_args"]["low_cpu_mem_usage"] is False
+    assert "attn_implementation" in transformer_kwargs["model_args"]
+    assert transformer_kwargs["model_args"]["attn_implementation"] == "eager"
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
 def test_encode_fp16() -> None:
     tiny_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
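
The fp16 round-trip in `test_load_with_torch_dtype` works because `model.half()` casts the weights to float16 before saving, and `torch_dtype="auto"` then restores the saved dtype on load. The same round-trip outside pytest (the test gates this behind CUDA, since fp16 inference on CPU is not universally supported; the save path is illustrative):

    import torch
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
    model.half()                   # cast all weights to float16
    model.save("/tmp/fp16_model")  # illustrative path

    reloaded = SentenceTransformer("/tmp/fp16_model", model_kwargs={"torch_dtype": "auto"})
    assert reloaded.encode(["Hi"], convert_to_tensor=True).dtype == torch.float16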
