[Core] fix variant-identification. #9253
Changes from 28 commits
The diff first touches the pipeline loading utilities:

```diff
@@ -50,15 +50,14 @@
     DEPRECATED_REVISION_ARGS,
     BaseOutput,
     PushToHubMixin,
-    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     is_torch_npu_available,
     is_torch_version,
     logging,
     numpy_to_pil,
 )
-from ..utils.hub_utils import load_or_create_model_card, populate_model_card
+from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import is_compiled_module
 
 
```
```diff
@@ -735,6 +734,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         else:
             cached_folder = pretrained_model_name_or_path
 
+        # The variant filenames can have the legacy sharding checkpoint format that we check and throw
+        # a warning if detected.
+        if variant is not None and _check_legacy_sharding_variant_format(cached_folder, variant):
+            warn_msg = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
+            logger.warning(warn_msg)
+
         config_dict = cls.load_config(cached_folder)
 
         # pop out "_ignore_files" as it is only needed for download
```
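For context: the legacy and standardized sharded-variant layouts differ only in where the variant tag sits relative to the shard counter. Below is a minimal sketch of the kind of filename check `_check_legacy_sharding_variant_format` performs; the helper body and regex here are illustrative assumptions, not the actual implementation.

```python
import re

def looks_like_legacy_variant_shard(filename: str, variant: str) -> bool:
    # Illustrative pattern only: legacy shards place the shard counter before the
    # variant tag, e.g. `diffusion_pytorch_model-00001-of-00002.fp16.safetensors`,
    # while the standardized (transformers-style) format places the variant first,
    # e.g. `diffusion_pytorch_model.fp16-00001-of-00002.safetensors`.
    legacy = re.compile(rf"-\d{{5}}-of-\d{{5}}\.{re.escape(variant)}\.")
    return bool(legacy.search(filename))

print(looks_like_legacy_variant_shard("diffusion_pytorch_model-00001-of-00002.fp16.safetensors", "fp16"))  # True
print(looks_like_legacy_variant_shard("diffusion_pytorch_model.fp16-00001-of-00002.safetensors", "fp16"))  # False
```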
```diff
@@ -745,6 +750,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
         # with variant being `"fp16"`.
         model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
+        if len(model_variants) == 0 and variant is not None:
+            error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
+            raise ValueError(error_message)
 
         # 3. Load the pipeline class, if using custom module then load it from the hub
         # if we load from explicit class, let's use it
```
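The user-facing effect of this hunk: requesting a variant that the repo does not ship now fails fast with a `ValueError` instead of silently falling back to the default weights. A quick sketch, with a placeholder repo id:

```python
from diffusers import DiffusionPipeline

try:
    # "some-org/some-pipeline" stands in for any repo that publishes no fp16 files.
    DiffusionPipeline.from_pretrained("some-org/some-pipeline", variant="fp16")
except ValueError as err:
    # "You are trying to load the model files of the `variant=fp16`, but no such
    # modeling files are available."
    print(err)
```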
```diff
@@ -1251,6 +1259,9 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
                 model_info_call_error = e  # save error to reraise it if model is not cached locally
 
         if not local_files_only:
+            filenames = {sibling.rfilename for sibling in info.siblings}
+            model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
+
             config_file = hf_hub_download(
                 pretrained_model_name,
                 cls.config_name,
```
```diff
@@ -1267,9 +1278,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             # retrieve all folder_names that contain relevant files
             folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
 
-            filenames = {sibling.rfilename for sibling in info.siblings}
-            model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
-
             diffusers_module = importlib.import_module(__name__.split(".")[0])
             pipelines = getattr(diffusers_module, "pipelines")
```

Comment on lines -1270 to -1272:

> This was moved up to raise the error earlier in the code.
```diff
@@ -1292,13 +1300,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             )
 
             if len(variant_filenames) == 0 and variant is not None:
-                deprecation_message = (
-                    f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
-                    f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
-                    "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
-                    "modeling files is deprecated."
-                )
-                deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
+                error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
+                raise ValueError(error_message)
 
             # remove ignored filenames
             model_filenames = set(model_filenames) - set(ignore_filenames)
```

Review thread on this block:

> let's not remove this error in […]

> It's not an error, though. It's a deprecation. Do we exactly want to keep it that way? If so, we will have to remove it anyway because the deprecation is supposed to expire after the "0.24.0" version. Instead, we are erroring out now from […]

> Ah, got it. I think this should be resolved now. WDYT about catching these errors without having to download the actual files and leveraging […]? This could live in a future PR.

sayakpaul marked this conversation as resolved.
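The follow-up idea from the thread, catching a bad variant without downloading any files, could be sketched with `huggingface_hub.model_info`, which only queries the Hub API for the repo's file listing. The helper below is a speculative sketch of that future work, not code from this PR, and the filename heuristic is an assumption:

```python
from huggingface_hub import model_info

def variant_files_exist(repo_id: str, variant: str) -> bool:
    # `model_info` hits the Hub API only; no weight files are downloaded.
    filenames = {sibling.rfilename for sibling in model_info(repo_id).siblings}
    # Heuristic: matches both `model.fp16.safetensors` and
    # `model.fp16-00001-of-00002.safetensors`.
    return any(f".{variant}." in name or f".{variant}-" in name for name in filenames)
```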
```diff
@@ -1368,6 +1371,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             )
             expected_components, _ = cls._get_signature_keys(pipeline_class)
             passed_components = [k for k in expected_components if k in kwargs]
+            is_sharded = any("index.json" in f and f != "model_index.json" for f in filenames)
 
             if (
                 use_safetensors
```
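To make the new `is_sharded` check concrete: any `*index.json` file other than the top-level `model_index.json` is taken as evidence of a sharded checkpoint. The filenames below are illustrative; the variant index name is assumed to follow `_add_variant`'s output:

```python
filenames = {
    "model_index.json",  # pipeline config; deliberately excluded from the check
    "unet/diffusion_pytorch_model.safetensors.fp16.index.json",  # sharding index
    "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
    "unet/diffusion_pytorch_model.fp16-00002-of-00002.safetensors",
}
is_sharded = any("index.json" in f and f != "model_index.json" for f in filenames)
print(is_sharded)  # True, because of the sharding index file
```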
```diff
@@ -1392,9 +1396,13 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
 
                 safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
                 safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
+                # `not is_sharded` because sharded checkpoints with a variant
+                # ("fp16", for example) may actually have fewer shards. Consider
+                # https://huggingface.co/fal/AuraFlow/tree/main/transformer, for example.
                 if (
                     len(safetensors_variant_filenames) > 0
                     and safetensors_model_filenames != safetensors_variant_filenames
+                    and not is_sharded
                 ):
                     logger.warning(
                         f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
```
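Why the mixed-filename warning is now guarded with `not is_sharded`: an fp16 variant can legitimately serialize into fewer shards than the full-precision weights (the AuraFlow transformer linked in the code comment is the motivating case), so the two filename sets differ even when the variant is complete. The shard counts below are made up for illustration:

```python
# Hypothetical shard listings modeled on the AuraFlow case; counts are illustrative.
safetensors_model_filenames = {
    "diffusion_pytorch_model-00001-of-00003.safetensors",
    "diffusion_pytorch_model-00002-of-00003.safetensors",
    "diffusion_pytorch_model-00003-of-00003.safetensors",
}
safetensors_variant_filenames = {
    "diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
    "diffusion_pytorch_model.fp16-00002-of-00002.safetensors",
}
# The sets are unequal, so without `not is_sharded` the "mixture of variant and
# non-variant filenames" warning would fire even though the fp16 variant is complete.
assert safetensors_model_filenames != safetensors_variant_filenames
```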
The second file in the diff updates the model tests:

```diff
@@ -27,8 +27,9 @@
 import requests_mock
 import torch
 from accelerate.utils import compute_module_sizes
-from huggingface_hub import ModelCard, delete_repo
+from huggingface_hub import ModelCard, delete_repo, snapshot_download
 from huggingface_hub.utils import is_jinja_available
+from parameterized import parameterized
 from requests.exceptions import HTTPError
 
 from diffusers.models import UNet2DConditionModel
```
```diff
@@ -39,7 +40,13 @@
     XFormersAttnProcessor,
 )
 from diffusers.training_utils import EMAModel
-from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging
+from diffusers.utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
+    WEIGHTS_INDEX_NAME,
+    is_torch_npu_available,
+    is_xformers_available,
+    logging,
+)
 from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
```
```diff
@@ -100,6 +107,48 @@ def test_accelerate_loading_error_message(self):
         # make sure that error message states what keys are missing
         assert "conv_out.bias" in str(error_context.exception)
 
+    @parameterized.expand(
+        [
+            ("hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", "unet", False),
+            ("hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", "unet", True),
+            ("hf-internal-testing/tiny-sd-unet-with-sharded-ckpt", None, False),
+            ("hf-internal-testing/tiny-sd-unet-with-sharded-ckpt", None, True),
+        ]
+    )
+    def test_variant_sharded_ckpt_legacy_format_raises_warning(self, repo_id, subfolder, use_local):
+        def load_model(path):
+            kwargs = {"variant": "fp16"}
+            if subfolder:
+                kwargs["subfolder"] = subfolder
+            return UNet2DConditionModel.from_pretrained(path, **kwargs)
+
+        with self.assertWarns(FutureWarning) as warning:
+            if use_local:
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    tmpdirname = snapshot_download(repo_id=repo_id)
+                    _ = load_model(tmpdirname)
+            else:
+                _ = load_model(repo_id)
+
+        warning_message = str(warning.warnings[0].message)
+        self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_message)
+
+    # Local tests are already covered down below.
+    @parameterized.expand(
+        [
+            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", None),
+            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolde", "unet"),
+        ]
+    )
+    def test_variant_sharded_ckpt_loads_from_hub(self, repo_id, subfolder):
+        def load_model():
+            kwargs = {"variant": "fp16"}
+            if subfolder:
+                kwargs["subfolder"] = subfolder
+            return UNet2DConditionModel.from_pretrained(repo_id, **kwargs)
+
+        assert load_model()
+
     def test_cached_files_are_used_when_no_internet(self):
         # A mock response for an HTTP head request to emulate server down
         response_mock = mock.Mock()
```
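For readers unfamiliar with `parameterized.expand`: each tuple in the list becomes its own generated test method, so the four tuples above produce four independent test cases. A self-contained sketch:

```python
import unittest

from parameterized import parameterized


class ExpandExample(unittest.TestCase):
    @parameterized.expand([("repo-a", None), ("repo-b", "unet")])
    def test_load(self, repo_id, subfolder):
        # Expands into test_load_0_repo_a and test_load_1_repo_b, so a failure in
        # one case is reported separately from the other.
        self.assertIsInstance(repo_id, str)
```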
```diff
@@ -924,6 +973,7 @@ def test_sharded_checkpoints_with_variant(self):
         # testing if loading works with the variant when the checkpoint is sharded should be
         # enough.
         model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
+
         index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename)))
```
```diff
@@ -976,6 +1026,46 @@ def test_sharded_checkpoints_device_map(self):
             new_output = new_model(**inputs_dict)
             self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
+    # This test is okay without a GPU because we're not running any execution. We're just serializing
+    # and checking that the resultant files follow an expected format.
+    def test_variant_sharded_ckpt_right_format(self):
+        for use_safe in [True, False]:
+            extension = ".safetensors" if use_safe else ".bin"
+            config, _ = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**config).eval()
+            if model._no_split_modules is None:
+                return
+
+            model_size = compute_module_sizes(model)[""]
+            max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
+            variant = "fp16"
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(
+                    tmp_dir, variant=variant, max_shard_size=f"{max_shard_size}KB", safe_serialization=use_safe
+                )
+                index_variant = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safe else WEIGHTS_INDEX_NAME, variant)
+                self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_variant)))
+
+                # Now check if the right number of shards exists. First, let's get the number of shards.
+                # Since this number can be dependent on the model being tested, it's important that we calculate it
+                # instead of hardcoding it.
+                expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_variant))
+                actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(extension)])
+                self.assertTrue(actual_num_shards == expected_num_shards)
+
+                # Check if the variant is present as a substring in the checkpoints.
+                shard_files = [
+                    file
+                    for file in os.listdir(tmp_dir)
+                    if file.endswith(extension) or ("index" in file and "json" in file)
+                ]
+                assert all(variant in f for f in shard_files)
+
+                # Check if the sharded checkpoints were serialized in the right format.
+                shard_files = [file for file in os.listdir(tmp_dir) if file.endswith(extension)]
+                # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
+                assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)
+
 
 @is_staging_test
 class ModelPushToHubTester(unittest.TestCase):
```
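`caculate_expected_num_shards` (spelled as in the source) is a helper defined elsewhere in the test module and not shown in this diff. A minimal sketch of what it needs to do, assuming the standard `weight_map` layout of Hugging Face index files; the actual helper may differ:

```python
import json

def caculate_expected_num_shards(index_filepath: str) -> int:
    # The index's `weight_map` maps each parameter name to the shard file holding
    # it, so the number of distinct mapped files is the expected shard count.
    with open(index_filepath) as f:
        weight_map = json.load(f)["weight_map"]
    return len(set(weight_map.values()))
```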