diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index ceb3ee235a..0acb8f04fe 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -41,6 +41,8 @@ import torch # type: ignore if is_safetensors_available(): + import packaging.version + import safetensors from safetensors.torch import load_model as load_model_as_safetensor from safetensors.torch import save_model as save_model_as_safetensor @@ -827,17 +829,18 @@ def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: b @classmethod def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: - load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type] - if map_location != "cpu": - # TODO: remove this once https://github.com/huggingface/safetensors/pull/449 is merged. - logger.warning( - "Loading model weights on other devices than 'cpu' is not supported natively." - " This means that the model is loaded on 'cpu' first and then copied to the device." - " This leads to a slower loading time." - " Support for loading directly on other devices is planned to be added in future releases." - " See https://github.com/huggingface/huggingface_hub/pull/2086 for more details." - ) - model.to(map_location) # type: ignore [attr-defined] + if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"): + load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type] + if map_location != "cpu": + logger.warning( + "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors." + " This means that the model is loaded on 'cpu' first and then copied to the device." + " This leads to a slower loading time." + " Please update safetensors to version 0.4.3 or above for improved performance." + ) + model.to(map_location) # type: ignore [attr-defined] + else: + safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) return model