Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def get_version() -> str:
extras["testing"] = (
extras["cli"]
+ extras["inference"]
+ extras["torch"]
+ [
"jedi",
"Jinja2",
Expand Down
25 changes: 14 additions & 11 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import torch # type: ignore

if is_safetensors_available():
import packaging.version
import safetensors
from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor

Expand Down Expand Up @@ -827,17 +829,18 @@ def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: b

@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
    """Load safetensors weights from `model_file` into `model` in place and return it.

    Args:
        model: the (already constructed) model instance to populate.
        model_file: path to the `.safetensors` checkpoint.
        map_location: target device string (e.g. `"cpu"`, `"cuda"`).
        strict: whether to error on missing/unexpected keys (forwarded to safetensors).

    Returns:
        The same `model` instance, with weights loaded and moved to `map_location`.
    """
    if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
        # Older safetensors cannot load directly onto a device: load on CPU first,
        # then copy the whole model over. Slower, hence the warning.
        load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
        if map_location != "cpu":
            logger.warning(
                "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
                " This means that the model is loaded on 'cpu' first and then copied to the device."
                " This leads to a slower loading time."
                " Please update safetensors to version 0.4.3 or above for improved performance."
            )
            model.to(map_location)  # type: ignore [attr-defined]
    else:
        # safetensors >= 0.4.3 supports loading directly on the target device via `device=`.
        # Use the aliased import (see top of file) instead of `safetensors.torch.load_model`,
        # which relied on the submodule being bound as an attribute as a side effect of the
        # `from safetensors.torch import ...` lines.
        load_model_as_safetensor(model, model_file, strict=strict, device=map_location)  # type: ignore [arg-type]
    return model


Expand Down