Skip to content

Commit b506898

Browse files
authored
Merge branch 'main' into flux-inference-ptxla
2 parents ccd902d + 5063aa5 commit b506898

File tree

10 files changed

+269
-73
lines changed

10 files changed

+269
-73
lines changed

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@
146146
"phonemizer",
147147
"opencv-python",
148148
"timm",
149+
"flashpack",
149150
]
150151

151152
# this is a lookup table with items like:
@@ -250,6 +251,7 @@ def run(self):
250251
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
251252
extras["torchao"] = deps_list("torchao", "accelerate")
252253
extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[hf]")
254+
extras["flashpack"] = deps_list("flashpack")
253255

254256
if os.name == "nt": # windows
255257
extras["flax"] = [] # jax is not supported on windows

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,5 @@
5353
"phonemizer": "phonemizer",
5454
"opencv-python": "opencv-python",
5555
"timm": "timm",
56+
"flashpack": "flashpack",
5657
}

src/diffusers/models/modeling_utils.py

Lines changed: 149 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from ..quantizers.quantization_config import QuantizationMethod
4343
from ..utils import (
4444
CONFIG_NAME,
45+
FLASHPACK_WEIGHTS_NAME,
4546
FLAX_WEIGHTS_NAME,
4647
HF_ENABLE_PARALLEL_LOADING,
4748
SAFE_WEIGHTS_INDEX_NAME,
@@ -55,6 +56,7 @@
5556
is_accelerate_available,
5657
is_bitsandbytes_available,
5758
is_bitsandbytes_version,
59+
is_flashpack_available,
5860
is_peft_available,
5961
is_torch_version,
6062
logging,
@@ -673,6 +675,7 @@ def save_pretrained(
673675
variant: str | None = None,
674676
max_shard_size: int | str = "10GB",
675677
push_to_hub: bool = False,
678+
use_flashpack: bool = False,
676679
**kwargs,
677680
):
678681
"""
@@ -725,7 +728,12 @@ def save_pretrained(
725728
" the logger on the traceback to understand the reason why the quantized model is not serializable."
726729
)
727730

728-
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
731+
weights_name = WEIGHTS_NAME
732+
if use_flashpack:
733+
weights_name = FLASHPACK_WEIGHTS_NAME
734+
elif safe_serialization:
735+
weights_name = SAFETENSORS_WEIGHTS_NAME
736+
729737
weights_name = _add_variant(weights_name, variant)
730738
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
731739
".safetensors", "{suffix}.safetensors"
@@ -752,58 +760,74 @@ def save_pretrained(
752760
# Save the model
753761
state_dict = model_to_save.state_dict()
754762

755-
# Save the model
756-
state_dict_split = split_torch_state_dict_into_shards(
757-
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
758-
)
759-
760-
# Clean the folder from a previous save
761-
if is_main_process:
762-
for filename in os.listdir(save_directory):
763-
if filename in state_dict_split.filename_to_tensors.keys():
764-
continue
765-
full_filename = os.path.join(save_directory, filename)
766-
if not os.path.isfile(full_filename):
767-
continue
768-
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
769-
weights_without_ext = weights_without_ext.replace("{suffix}", "")
770-
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
771-
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
772-
if (
773-
filename.startswith(weights_without_ext)
774-
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
775-
):
776-
os.remove(full_filename)
777-
778-
for filename, tensors in state_dict_split.filename_to_tensors.items():
779-
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
780-
filepath = os.path.join(save_directory, filename)
781-
if safe_serialization:
782-
# At some point we will need to deal better with save_function (used for TPU and other distributed
783-
# joyfulness), but for now this enough.
784-
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
763+
if use_flashpack:
764+
if is_flashpack_available():
765+
import flashpack
785766
else:
786-
torch.save(shard, filepath)
767+
logger.error(
768+
"Saving a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
769+
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
770+
)
771+
raise ImportError("Please install torch and flashpack to save a FlashPack checkpoint in PyTorch.")
787772

788-
if state_dict_split.is_sharded:
789-
index = {
790-
"metadata": state_dict_split.metadata,
791-
"weight_map": state_dict_split.tensor_to_filename,
792-
}
793-
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
794-
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
795-
# Save the index as well
796-
with open(save_index_file, "w", encoding="utf-8") as f:
797-
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
798-
f.write(content)
799-
logger.info(
800-
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
801-
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
802-
f"index located at {save_index_file}."
773+
flashpack.serialization.pack_to_file(
774+
state_dict_or_model=state_dict,
775+
destination_path=os.path.join(save_directory, weights_name),
776+
target_dtype=self.dtype,
803777
)
804778
else:
805-
path_to_weights = os.path.join(save_directory, weights_name)
806-
logger.info(f"Model weights saved in {path_to_weights}")
779+
# Save the model
780+
state_dict_split = split_torch_state_dict_into_shards(
781+
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
782+
)
783+
784+
# Clean the folder from a previous save
785+
if is_main_process:
786+
for filename in os.listdir(save_directory):
787+
if filename in state_dict_split.filename_to_tensors.keys():
788+
continue
789+
full_filename = os.path.join(save_directory, filename)
790+
if not os.path.isfile(full_filename):
791+
continue
792+
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
793+
weights_without_ext = weights_without_ext.replace("{suffix}", "")
794+
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
795+
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
796+
if (
797+
filename.startswith(weights_without_ext)
798+
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
799+
):
800+
os.remove(full_filename)
801+
802+
for filename, tensors in state_dict_split.filename_to_tensors.items():
803+
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
804+
filepath = os.path.join(save_directory, filename)
805+
if safe_serialization:
806+
# At some point we will need to deal better with save_function (used for TPU and other distributed
807+
# joyfulness), but for now this enough.
808+
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
809+
else:
810+
torch.save(shard, filepath)
811+
812+
if state_dict_split.is_sharded:
813+
index = {
814+
"metadata": state_dict_split.metadata,
815+
"weight_map": state_dict_split.tensor_to_filename,
816+
}
817+
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
818+
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
819+
# Save the index as well
820+
with open(save_index_file, "w", encoding="utf-8") as f:
821+
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
822+
f.write(content)
823+
logger.info(
824+
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
825+
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
826+
f"index located at {save_index_file}."
827+
)
828+
else:
829+
path_to_weights = os.path.join(save_directory, weights_name)
830+
logger.info(f"Model weights saved in {path_to_weights}")
807831

808832
if push_to_hub:
809833
# Create a new empty model card and eventually tag it
@@ -940,6 +964,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None
940964
disable_mmap ('bool', *optional*, defaults to 'False'):
941965
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
942966
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
967+
use_flashpack (`bool`, *optional*, defaults to `False`):
968+
If set to `True`, the model is loaded from `flashpack` weights.
969+
flashpack_kwargs(`dict[str, Any]`, *optional*, defaults to `{}`):
970+
Kwargs passed to
971+
[`flashpack.deserialization.assign_from_file`](https://github.com/fal-ai/flashpack/blob/f1aa91c5cd9532a3dbf5bcc707ab9b01c274b76c/src/flashpack/deserialization.py#L408-L422)
972+
943973
944974
> [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
945975
with `hf > auth login`. You can also activate the special >
@@ -984,6 +1014,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None
9841014
dduf_entries: dict[str, DDUFEntry] | None = kwargs.pop("dduf_entries", None)
9851015
disable_mmap = kwargs.pop("disable_mmap", False)
9861016
parallel_config: ParallelConfig | ContextParallelConfig | None = kwargs.pop("parallel_config", None)
1017+
use_flashpack = kwargs.pop("use_flashpack", False)
1018+
flashpack_kwargs = kwargs.pop("flashpack_kwargs", {})
9871019

9881020
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
9891021
if is_parallel_loading_enabled and not low_cpu_mem_usage:
@@ -1212,30 +1244,37 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None
12121244
subfolder=subfolder or "",
12131245
dduf_entries=dduf_entries,
12141246
)
1215-
elif use_safetensors:
1216-
try:
1217-
resolved_model_file = _get_model_file(
1218-
pretrained_model_name_or_path,
1219-
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1220-
cache_dir=cache_dir,
1221-
force_download=force_download,
1222-
proxies=proxies,
1223-
local_files_only=local_files_only,
1224-
token=token,
1225-
revision=revision,
1226-
subfolder=subfolder,
1227-
user_agent=user_agent,
1228-
commit_hash=commit_hash,
1229-
dduf_entries=dduf_entries,
1230-
)
1247+
else:
1248+
if use_flashpack:
1249+
weights_name = FLASHPACK_WEIGHTS_NAME
1250+
elif use_safetensors:
1251+
weights_name = _add_variant(SAFETENSORS_WEIGHTS_NAME, variant)
1252+
else:
1253+
weights_name = None
1254+
if weights_name is not None:
1255+
try:
1256+
resolved_model_file = _get_model_file(
1257+
pretrained_model_name_or_path,
1258+
weights_name=weights_name,
1259+
cache_dir=cache_dir,
1260+
force_download=force_download,
1261+
proxies=proxies,
1262+
local_files_only=local_files_only,
1263+
token=token,
1264+
revision=revision,
1265+
subfolder=subfolder,
1266+
user_agent=user_agent,
1267+
commit_hash=commit_hash,
1268+
dduf_entries=dduf_entries,
1269+
)
12311270

1232-
except IOError as e:
1233-
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
1234-
if not allow_pickle:
1235-
raise
1236-
logger.warning(
1237-
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
1238-
)
1271+
except IOError as e:
1272+
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
1273+
if not allow_pickle:
1274+
raise
1275+
logger.warning(
1276+
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
1277+
)
12391278

12401279
if resolved_model_file is None and not is_sharded:
12411280
resolved_model_file = _get_model_file(
@@ -1275,6 +1314,44 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None
12751314
with ContextManagers(init_contexts):
12761315
model = cls.from_config(config, **unused_kwargs)
12771316

1317+
if use_flashpack:
1318+
if is_flashpack_available():
1319+
import flashpack
1320+
else:
1321+
logger.error(
1322+
"Loading a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
1323+
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
1324+
)
1325+
raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.")
1326+
1327+
if device_map is None:
1328+
logger.warning(
1329+
"`device_map` has not been provided for FlashPack, model will be on `cpu` - provide `device_map` to fully utilize "
1330+
"the benefit of FlashPack."
1331+
)
1332+
flashpack_device = torch.device("cpu")
1333+
else:
1334+
device = device_map[""]
1335+
if isinstance(device, str) and device in ["auto", "balanced", "balanced_low_0", "sequential"]:
1336+
raise ValueError(
1337+
"FlashPack `device_map` should not be one of `auto`, `balanced`, `balanced_low_0`, `sequential`. Use a specific device instead, e.g., `device_map='cuda'` or `device_map='cuda:0'"
1338+
)
1339+
flashpack_device = torch.device(device) if not isinstance(device, torch.device) else device
1340+
1341+
flashpack.mixin.assign_from_file(
1342+
model=model,
1343+
path=resolved_model_file[0],
1344+
device=flashpack_device,
1345+
**flashpack_kwargs,
1346+
)
1347+
if dtype_orig is not None:
1348+
torch.set_default_dtype(dtype_orig)
1349+
if output_loading_info:
1350+
logger.warning("`output_loading_info` is not supported with FlashPack.")
1351+
return model, {}
1352+
1353+
return model
1354+
12781355
if dtype_orig is not None:
12791356
torch.set_default_dtype(dtype_orig)
12801357

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from .. import __version__
3030
from ..utils import (
31+
FLASHPACK_WEIGHTS_NAME,
3132
FLAX_WEIGHTS_NAME,
3233
ONNX_EXTERNAL_WEIGHTS_NAME,
3334
ONNX_WEIGHTS_NAME,
@@ -194,6 +195,7 @@ def filter_model_files(filenames):
194195
FLAX_WEIGHTS_NAME,
195196
ONNX_WEIGHTS_NAME,
196197
ONNX_EXTERNAL_WEIGHTS_NAME,
198+
FLASHPACK_WEIGHTS_NAME,
197199
]
198200

199201
if is_transformers_available():
@@ -413,6 +415,9 @@ def get_class_obj_and_candidates(
413415
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
414416
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
415417

418+
if class_name.startswith("FlashPack"):
419+
class_name = class_name.removeprefix("FlashPack")
420+
416421
if is_pipeline_module:
417422
pipeline_module = getattr(pipelines, library_name)
418423

@@ -760,6 +765,7 @@ def load_sub_model(
760765
provider_options: Any,
761766
disable_mmap: bool,
762767
quantization_config: Any | None = None,
768+
use_flashpack: bool = False,
763769
):
764770
"""Helper method to load the module `name` from `library_name` and `class_name`"""
765771
from ..quantizers import PipelineQuantizationConfig
@@ -838,6 +844,9 @@ def load_sub_model(
838844
loading_kwargs["variant"] = model_variants.pop(name, None)
839845
loading_kwargs["use_safetensors"] = use_safetensors
840846

847+
if is_diffusers_model:
848+
loading_kwargs["use_flashpack"] = use_flashpack
849+
841850
if from_flax:
842851
loading_kwargs["from_flax"] = True
843852

@@ -887,7 +896,7 @@ def load_sub_model(
887896
# else load from the root directory
888897
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
889898

890-
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
899+
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict) and not use_flashpack:
891900
# remove hooks
892901
remove_hook_from_module(loaded_sub_model, recurse=True)
893902
needs_offloading_to_cpu = device_map[""] == "cpu"
@@ -1093,6 +1102,7 @@ def _get_ignore_patterns(
10931102
allow_pickle: bool,
10941103
use_onnx: bool,
10951104
is_onnx: bool,
1105+
use_flashpack: bool,
10961106
variant: str | None = None,
10971107
) -> list[str]:
10981108
if (
@@ -1118,6 +1128,9 @@ def _get_ignore_patterns(
11181128
if not use_onnx:
11191129
ignore_patterns += ["*.onnx", "*.pb"]
11201130

1131+
elif use_flashpack:
1132+
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb", "*.msgpack"]
1133+
11211134
else:
11221135
ignore_patterns = ["*.safetensors", "*.msgpack"]
11231136

0 commit comments

Comments
 (0)