4242from ..quantizers .quantization_config import QuantizationMethod
4343from ..utils import (
4444 CONFIG_NAME ,
45+ FLASHPACK_WEIGHTS_NAME ,
4546 FLAX_WEIGHTS_NAME ,
4647 HF_ENABLE_PARALLEL_LOADING ,
4748 SAFE_WEIGHTS_INDEX_NAME ,
5556 is_accelerate_available ,
5657 is_bitsandbytes_available ,
5758 is_bitsandbytes_version ,
59+ is_flashpack_available ,
5860 is_peft_available ,
5961 is_torch_version ,
6062 logging ,
@@ -673,6 +675,7 @@ def save_pretrained(
673675 variant : str | None = None ,
674676 max_shard_size : int | str = "10GB" ,
675677 push_to_hub : bool = False ,
678+ use_flashpack : bool = False ,
676679 ** kwargs ,
677680 ):
678681 """
@@ -725,7 +728,12 @@ def save_pretrained(
725728 " the logger on the traceback to understand the reason why the quantized model is not serializable."
726729 )
727730
728- weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
731+ weights_name = WEIGHTS_NAME
732+ if use_flashpack :
733+ weights_name = FLASHPACK_WEIGHTS_NAME
734+ elif safe_serialization :
735+ weights_name = SAFETENSORS_WEIGHTS_NAME
736+
729737 weights_name = _add_variant (weights_name , variant )
730738 weights_name_pattern = weights_name .replace (".bin" , "{suffix}.bin" ).replace (
731739 ".safetensors" , "{suffix}.safetensors"
@@ -752,58 +760,74 @@ def save_pretrained(
752760 # Save the model
753761 state_dict = model_to_save .state_dict ()
754762
755- # Save the model
756- state_dict_split = split_torch_state_dict_into_shards (
757- state_dict , max_shard_size = max_shard_size , filename_pattern = weights_name_pattern
758- )
759-
760- # Clean the folder from a previous save
761- if is_main_process :
762- for filename in os .listdir (save_directory ):
763- if filename in state_dict_split .filename_to_tensors .keys ():
764- continue
765- full_filename = os .path .join (save_directory , filename )
766- if not os .path .isfile (full_filename ):
767- continue
768- weights_without_ext = weights_name_pattern .replace (".bin" , "" ).replace (".safetensors" , "" )
769- weights_without_ext = weights_without_ext .replace ("{suffix}" , "" )
770- filename_without_ext = filename .replace (".bin" , "" ).replace (".safetensors" , "" )
771- # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
772- if (
773- filename .startswith (weights_without_ext )
774- and _REGEX_SHARD .fullmatch (filename_without_ext ) is not None
775- ):
776- os .remove (full_filename )
777-
778- for filename , tensors in state_dict_split .filename_to_tensors .items ():
779- shard = {tensor : state_dict [tensor ].contiguous () for tensor in tensors }
780- filepath = os .path .join (save_directory , filename )
781- if safe_serialization :
782- # At some point we will need to deal better with save_function (used for TPU and other distributed
783- # joyfulness), but for now this enough.
784- safetensors .torch .save_file (shard , filepath , metadata = {"format" : "pt" })
763+ if use_flashpack :
764+ if is_flashpack_available ():
765+ import flashpack
785766 else :
786- torch .save (shard , filepath )
767+ logger .error (
768+ "Saving a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
769+ "https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
770+ )
771+ raise ImportError ("Please install torch and flashpack to save a FlashPack checkpoint in PyTorch." )
787772
788- if state_dict_split .is_sharded :
789- index = {
790- "metadata" : state_dict_split .metadata ,
791- "weight_map" : state_dict_split .tensor_to_filename ,
792- }
793- save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
794- save_index_file = os .path .join (save_directory , _add_variant (save_index_file , variant ))
795- # Save the index as well
796- with open (save_index_file , "w" , encoding = "utf-8" ) as f :
797- content = json .dumps (index , indent = 2 , sort_keys = True ) + "\n "
798- f .write (content )
799- logger .info (
800- f"The model is bigger than the maximum size per checkpoint ({ max_shard_size } ) and is going to be "
801- f"split in { len (state_dict_split .filename_to_tensors )} checkpoint shards. You can find where each parameters has been saved in the "
802- f"index located at { save_index_file } ."
773+ flashpack .serialization .pack_to_file (
774+ state_dict_or_model = state_dict ,
775+ destination_path = os .path .join (save_directory , weights_name ),
776+ target_dtype = self .dtype ,
803777 )
804778 else :
805- path_to_weights = os .path .join (save_directory , weights_name )
806- logger .info (f"Model weights saved in { path_to_weights } " )
779+ # Save the model
780+ state_dict_split = split_torch_state_dict_into_shards (
781+ state_dict , max_shard_size = max_shard_size , filename_pattern = weights_name_pattern
782+ )
783+
784+ # Clean the folder from a previous save
785+ if is_main_process :
786+ for filename in os .listdir (save_directory ):
787+ if filename in state_dict_split .filename_to_tensors .keys ():
788+ continue
789+ full_filename = os .path .join (save_directory , filename )
790+ if not os .path .isfile (full_filename ):
791+ continue
792+ weights_without_ext = weights_name_pattern .replace (".bin" , "" ).replace (".safetensors" , "" )
793+ weights_without_ext = weights_without_ext .replace ("{suffix}" , "" )
794+ filename_without_ext = filename .replace (".bin" , "" ).replace (".safetensors" , "" )
795+ # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
796+ if (
797+ filename .startswith (weights_without_ext )
798+ and _REGEX_SHARD .fullmatch (filename_without_ext ) is not None
799+ ):
800+ os .remove (full_filename )
801+
802+ for filename , tensors in state_dict_split .filename_to_tensors .items ():
803+ shard = {tensor : state_dict [tensor ].contiguous () for tensor in tensors }
804+ filepath = os .path .join (save_directory , filename )
805+ if safe_serialization :
806+ # At some point we will need to deal better with save_function (used for TPU and other distributed
807+ # joyfulness), but for now this enough.
808+ safetensors .torch .save_file (shard , filepath , metadata = {"format" : "pt" })
809+ else :
810+ torch .save (shard , filepath )
811+
812+ if state_dict_split .is_sharded :
813+ index = {
814+ "metadata" : state_dict_split .metadata ,
815+ "weight_map" : state_dict_split .tensor_to_filename ,
816+ }
817+ save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
818+ save_index_file = os .path .join (save_directory , _add_variant (save_index_file , variant ))
819+ # Save the index as well
820+ with open (save_index_file , "w" , encoding = "utf-8" ) as f :
821+ content = json .dumps (index , indent = 2 , sort_keys = True ) + "\n "
822+ f .write (content )
823+ logger .info (
824+ f"The model is bigger than the maximum size per checkpoint ({ max_shard_size } ) and is going to be "
825+ f"split in { len (state_dict_split .filename_to_tensors )} checkpoint shards. You can find where each parameters has been saved in the "
826+ f"index located at { save_index_file } ."
827+ )
828+ else :
829+ path_to_weights = os .path .join (save_directory , weights_name )
830+ logger .info (f"Model weights saved in { path_to_weights } " )
807831
808832 if push_to_hub :
809833 # Create a new empty model card and eventually tag it
@@ -940,6 +964,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None
940964 disable_mmap ('bool', *optional*, defaults to 'False'):
941965 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
942966 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
967+ use_flashpack (`bool`, *optional*, defaults to `False`):
968+ If set to `True`, the model is loaded from `flashpack` weights.
969+ flashpack_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
970+ Kwargs passed to
971+ [`flashpack.deserialization.assign_from_file`](https://github.com/fal-ai/flashpack/blob/f1aa91c5cd9532a3dbf5bcc707ab9b01c274b76c/src/flashpack/deserialization.py#L408-L422)
972+
943973
944974 > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
945975 with `hf auth login`. You can also activate the special >
@@ -984,6 +1014,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None
9841014 dduf_entries : dict [str , DDUFEntry ] | None = kwargs .pop ("dduf_entries" , None )
9851015 disable_mmap = kwargs .pop ("disable_mmap" , False )
9861016 parallel_config : ParallelConfig | ContextParallelConfig | None = kwargs .pop ("parallel_config" , None )
1017+ use_flashpack = kwargs .pop ("use_flashpack" , False )
1018+ flashpack_kwargs = kwargs .pop ("flashpack_kwargs" , {})
9871019
9881020 is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
9891021 if is_parallel_loading_enabled and not low_cpu_mem_usage :
@@ -1212,30 +1244,37 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None
12121244 subfolder = subfolder or "" ,
12131245 dduf_entries = dduf_entries ,
12141246 )
1215- elif use_safetensors :
1216- try :
1217- resolved_model_file = _get_model_file (
1218- pretrained_model_name_or_path ,
1219- weights_name = _add_variant (SAFETENSORS_WEIGHTS_NAME , variant ),
1220- cache_dir = cache_dir ,
1221- force_download = force_download ,
1222- proxies = proxies ,
1223- local_files_only = local_files_only ,
1224- token = token ,
1225- revision = revision ,
1226- subfolder = subfolder ,
1227- user_agent = user_agent ,
1228- commit_hash = commit_hash ,
1229- dduf_entries = dduf_entries ,
1230- )
1247+ else :
1248+ if use_flashpack :
1249+ weights_name = FLASHPACK_WEIGHTS_NAME
1250+ elif use_safetensors :
1251+ weights_name = _add_variant (SAFETENSORS_WEIGHTS_NAME , variant )
1252+ else :
1253+ weights_name = None
1254+ if weights_name is not None :
1255+ try :
1256+ resolved_model_file = _get_model_file (
1257+ pretrained_model_name_or_path ,
1258+ weights_name = weights_name ,
1259+ cache_dir = cache_dir ,
1260+ force_download = force_download ,
1261+ proxies = proxies ,
1262+ local_files_only = local_files_only ,
1263+ token = token ,
1264+ revision = revision ,
1265+ subfolder = subfolder ,
1266+ user_agent = user_agent ,
1267+ commit_hash = commit_hash ,
1268+ dduf_entries = dduf_entries ,
1269+ )
12311270
1232- except IOError as e :
1233- logger .error (f"An error occurred while trying to fetch { pretrained_model_name_or_path } : { e } " )
1234- if not allow_pickle :
1235- raise
1236- logger .warning (
1237- "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
1238- )
1271+ except IOError as e :
1272+ logger .error (f"An error occurred while trying to fetch { pretrained_model_name_or_path } : { e } " )
1273+ if not allow_pickle :
1274+ raise
1275+ logger .warning (
1276+ "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
1277+ )
12391278
12401279 if resolved_model_file is None and not is_sharded :
12411280 resolved_model_file = _get_model_file (
@@ -1275,6 +1314,44 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None
12751314 with ContextManagers (init_contexts ):
12761315 model = cls .from_config (config , ** unused_kwargs )
12771316
1317+ if use_flashpack :
1318+ if is_flashpack_available ():
1319+ import flashpack
1320+ else :
1321+ logger .error (
1322+ "Loading a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
1323+ "https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
1324+ )
1325+ raise ImportError ("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch." )
1326+
1327+ if device_map is None :
1328+ logger .warning (
1329+ "`device_map` has not been provided for FlashPack, model will be on `cpu` - provide `device_map` to fully utilize "
1330+ "the benefit of FlashPack."
1331+ )
1332+ flashpack_device = torch .device ("cpu" )
1333+ else :
1334+ device = device_map ["" ]
1335+ if isinstance (device , str ) and device in ["auto" , "balanced" , "balanced_low_0" , "sequential" ]:
1336+ raise ValueError (
1337+ "FlashPack `device_map` should not be one of `auto`, `balanced`, `balanced_low_0`, `sequential`. Use a specific device instead, e.g., `device_map='cuda'` or `device_map='cuda:0'"
1338+ )
1339+ flashpack_device = torch .device (device ) if not isinstance (device , torch .device ) else device
1340+
1341+ flashpack .mixin .assign_from_file (
1342+ model = model ,
1343+ path = resolved_model_file [0 ],
1344+ device = flashpack_device ,
1345+ ** flashpack_kwargs ,
1346+ )
1347+ if dtype_orig is not None :
1348+ torch .set_default_dtype (dtype_orig )
1349+ if output_loading_info :
1350+ logger .warning ("`output_loading_info` is not supported with FlashPack." )
1351+ return model , {}
1352+
1353+ return model
1354+
12781355 if dtype_orig is not None :
12791356 torch .set_default_dtype (dtype_orig )
12801357
0 commit comments