|
41 | 41 | import torch # type: ignore |
42 | 42 |
|
43 | 43 | if is_safetensors_available(): |
| 44 | + import packaging.version |
| 45 | + import safetensors |
44 | 46 | from safetensors.torch import load_model as load_model_as_safetensor |
45 | 47 | from safetensors.torch import save_model as save_model_as_safetensor |
46 | 48 |
|
@@ -827,17 +829,18 @@ def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: b |
827 | 829 |
|
828 | 830 | @classmethod |
829 | 831 | def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: |
830 | | - load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type] |
831 | | - if map_location != "cpu": |
832 | | - # TODO: remove this once https://github.com/huggingface/safetensors/pull/449 is merged. |
833 | | - logger.warning( |
834 | | - "Loading model weights on other devices than 'cpu' is not supported natively." |
835 | | - " This means that the model is loaded on 'cpu' first and then copied to the device." |
836 | | - " This leads to a slower loading time." |
837 | | - " Support for loading directly on other devices is planned to be added in future releases." |
838 | | - " See https://github.com/huggingface/huggingface_hub/pull/2086 for more details." |
839 | | - ) |
840 | | - model.to(map_location) # type: ignore [attr-defined] |
| 832 | + if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"): |
| 833 | + load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type] |
| 834 | + if map_location != "cpu": |
| 835 | + logger.warning( |
| 836 | + "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors." |
| 837 | + " This means that the model is loaded on 'cpu' first and then copied to the device." |
| 838 | + " This leads to a slower loading time." |
| 839 | + " Please update safetensors to version 0.4.3 or above for improved performance." |
| 840 | + ) |
| 841 | + model.to(map_location) # type: ignore [attr-defined] |
| 842 | + else: |
| 843 | + safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) |
841 | 844 | return model |
842 | 845 |
|
843 | 846 |
|
|
0 commit comments