src/huggingface_hub/hub_mixin.py (25 changes: 14 additions & 11 deletions)
@@ -41,6 +41,8 @@
     import torch  # type: ignore
 
 if is_safetensors_available():
+    import packaging.version
+    import safetensors
     from safetensors.torch import load_model as load_model_as_safetensor
     from safetensors.torch import save_model as save_model_as_safetensor
@@ -827,17 +829,18 @@ def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: b

     @classmethod
     def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
-        load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
-        if map_location != "cpu":
-            # TODO: remove this once https://github.com/huggingface/safetensors/pull/449 is merged.
-            logger.warning(
-                "Loading model weights on other devices than 'cpu' is not supported natively."
-                " This means that the model is loaded on 'cpu' first and then copied to the device."
-                " This leads to a slower loading time."
-                " Support for loading directly on other devices is planned to be added in future releases."
-                " See https://github.com/huggingface/huggingface_hub/pull/2086 for more details."
-            )
-            model.to(map_location)  # type: ignore [attr-defined]
+        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
+            load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
+            if map_location != "cpu":
+                logger.warning(
+                    "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
+                    " This means that the model is loaded on 'cpu' first and then copied to the device."
+                    " This leads to a slower loading time."
+                    " Please update safetensors to version 0.4.3 or above for improved performance."
+                )
+                model.to(map_location)  # type: ignore [attr-defined]
+        else:
+            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
         return model


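In practice, the change means that with safetensors >= 0.4.3 the weights are materialized directly on the requested device, while older versions still load on CPU and then copy the tensors over. A minimal sketch of the two paths outside the mixin, assuming a hypothetical local model.safetensors checkpoint and a small torch module (both placeholders, not part of the PR):

import packaging.version
import safetensors
import safetensors.torch
import torch
from safetensors.torch import load_model as load_model_as_safetensor

# Hypothetical module and checkpoint, used only for illustration.
model = torch.nn.Linear(4, 4)
model_file = "model.safetensors"  # assumed to exist locally
map_location = "cuda:0" if torch.cuda.is_available() else "cpu"

if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
    # Old path: weights are loaded on CPU first, then copied to the target device.
    load_model_as_safetensor(model, model_file, strict=False)
    model.to(map_location)
else:
    # New path: safetensors >= 0.4.3 accepts a `device` argument and loads
    # the tensors directly on the requested device.
    safetensors.torch.load_model(model, model_file, strict=False, device=map_location)

This is the same dispatch that _load_as_safetensor now performs internally, so callers of from_pretrained with a map_location get the faster device-direct load automatically once safetensors is upgraded.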