Skip to content
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
6b379a9
fix variant-idenitification.
sayakpaul Aug 23, 2024
f155ec7
fix variant
sayakpaul Aug 23, 2024
3f36e59
Merge branch 'main' into variant-tests
sayakpaul Aug 23, 2024
91253e8
fix sharded variant checkpoint loading.
sayakpaul Aug 27, 2024
dd5941e
Merge branch 'main' into variant-tests
sayakpaul Aug 27, 2024
564b8b4
Apply suggestions from code review
sayakpaul Aug 27, 2024
fdd0435
Merge branch 'main' into variant-tests
sayakpaul Sep 4, 2024
d5cad9e
Merge branch 'main' into variant-tests
sayakpaul Sep 10, 2024
c0b1ceb
fixes.
sayakpaul Sep 10, 2024
247dd93
more fixes.
sayakpaul Sep 10, 2024
b024a6d
remove print.
sayakpaul Sep 10, 2024
fdfdc5f
Merge branch 'main' into variant-tests
sayakpaul Sep 11, 2024
dcf1852
Merge branch 'main' into variant-tests
yiyixuxu Sep 12, 2024
3a71ad9
fixes
sayakpaul Sep 13, 2024
ab91852
fixes
sayakpaul Sep 13, 2024
aa631c5
comments
sayakpaul Sep 13, 2024
453bfa5
fixes
sayakpaul Sep 13, 2024
11e4b71
Merge branch 'main' into variant-tests
sayakpaul Sep 13, 2024
dbdf0f9
apply suggestions.
sayakpaul Sep 14, 2024
671038a
hub_utils.py
sayakpaul Sep 14, 2024
57382f2
Merge branch 'main' into variant-tests
sayakpaul Sep 14, 2024
ea5ecdb
fix test
sayakpaul Sep 14, 2024
a510a9b
Merge branch 'main' into variant-tests
sayakpaul Sep 17, 2024
f583dad
Merge branch 'main' into variant-tests
sayakpaul Sep 18, 2024
dc0255a
updates
sayakpaul Sep 19, 2024
f2ab3de
Merge branch 'main' into variant-tests
sayakpaul Sep 21, 2024
10baa9d
Merge branch 'main' into variant-tests
sayakpaul Sep 23, 2024
25ac01f
fixes
sayakpaul Sep 23, 2024
bac62ac
Merge branch 'main' into variant-tests
sayakpaul Sep 24, 2024
b6794ed
Merge branch 'main' into variant-tests
sayakpaul Sep 25, 2024
fcb4e39
Merge branch 'main' into variant-tests
sayakpaul Sep 26, 2024
4c0c5d2
fixes
sayakpaul Sep 26, 2024
0b1c2a6
Merge branch 'main' into variant-tests
sayakpaul Sep 27, 2024
8ad6b23
Apply suggestions from code review
sayakpaul Sep 28, 2024
1190f7d
updates.
sayakpaul Sep 28, 2024
59cfefb
removep patch file.
sayakpaul Sep 28, 2024
d72f5c1
Merge branch 'main' into variant-tests
sayakpaul Sep 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
WEIGHTS_INDEX_NAME,
_add_variant,
_get_model_file,
deprecate,
is_accelerate_available,
is_torch_version,
logging,
Expand Down Expand Up @@ -228,3 +229,67 @@ def _fetch_index_file(
index_file = None

return index_file


def _fetch_index_file_legacy(
is_local,
pretrained_model_name_or_path,
subfolder,
use_safetensors,
cache_dir,
variant,
force_download,
proxies,
local_files_only,
token,
revision,
user_agent,
commit_hash,
):
if is_local:
index_file = Path(
pretrained_model_name_or_path,
subfolder or "",
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
).as_posix()
splits = index_file.split(".")
split_index = -3 if ".cache" in index_file else -2
splits = splits[:-split_index] + [variant] + splits[-split_index:]
index_file = ".".join(splits)
if os.path.exists(index_file):
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
index_file = Path(index_file)
else:
index_file = None
else:
if variant is not None:
index_file_in_repo = Path(
subfolder or "",
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
).as_posix()
splits = index_file_in_repo.split(".")
split_index = -2
splits = splits[:-split_index] + [variant] + splits[-split_index:]
index_file_in_repo = ".".join(splits)
try:
index_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=index_file_in_repo,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
)
index_file = Path(index_file)
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
except (EntryNotFoundError, EnvironmentError):
index_file = None

return index_file
40 changes: 24 additions & 16 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from .model_loading_utils import (
_determine_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
Expand Down Expand Up @@ -315,7 +316,9 @@ def save_pretrained(
weights_name = _add_variant(weights_name, variant)
weight_name_split = weights_name.split(".")
if len(weight_name_split) in [2, 3]:
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
else:
raise ValueError(f"Invalid {weights_name} provided.")

Expand Down Expand Up @@ -628,21 +631,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
is_sharded = False
index_file = None
is_local = os.path.isdir(pretrained_model_name_or_path)
index_file = _fetch_index_file(
is_local=is_local,
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder or "",
use_safetensors=use_safetensors,
cache_dir=cache_dir,
variant=variant,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
user_agent=user_agent,
commit_hash=commit_hash,
)
index_file_kwargs = {
"is_local": is_local,
"pretrained_model_name_or_path": pretrained_model_name_or_path,
"subfolder": subfolder or "",
"use_safetensors": use_safetensors,
"cache_dir": cache_dir,
"variant": variant,
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"revision": revision,
"user_agent": user_agent,
"commit_hash": commit_hash,
}
index_file = _fetch_index_file(**index_file_kwargs)
# In case the index file was not found we still have to consider the legacy format.
# this becomes applicable when the variant is not None.
if variant is not None and (index_file is None or not os.path.exists(index_file)):
index_file = _fetch_index_file_legacy(**index_file_kwargs)
if index_file is not None and index_file.is_file():
is_sharded = True

Expand Down
32 changes: 20 additions & 12 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,14 @@
DEPRECATED_REVISION_ARGS,
BaseOutput,
PushToHubMixin,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_torch_npu_available,
is_torch_version,
logging,
numpy_to_pil,
)
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
from ..utils.torch_utils import is_compiled_module


Expand Down Expand Up @@ -735,6 +734,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
cached_folder = pretrained_model_name_or_path

# The variant filenames can have the legacy sharding checkpoint format that we check and throw
# a warning if detected.
if variant is not None and _check_legacy_sharding_variant_format(cached_folder, variant):
warn_msg = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
logger.warning(warn_msg)

config_dict = cls.load_config(cached_folder)

# pop out "_ignore_files" as it is only needed for download
Expand All @@ -745,6 +750,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
# with variant being `"fp16"`.
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
if len(model_variants) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
Expand Down Expand Up @@ -1251,6 +1259,9 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
model_info_call_error = e # save error to reraise it if model is not cached locally

if not local_files_only:
filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

config_file = hf_hub_download(
pretrained_model_name,
cls.config_name,
Expand All @@ -1267,9 +1278,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

Comment on lines -1270 to -1272
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was moved up to raise error earlier in code.

diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

Expand All @@ -1292,13 +1300,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)

if len(variant_filenames) == 0 and variant is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's not remove this error in download

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not an error, though. It's a deprecation. Do we exactly want to keep it that way? If so, we will have to remove it anyway because the deprecation is supposed to expire after "0.24.0" version.

Instead, we are erroring out now from from_pretrained():

model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah got it. I think this should be resolved now.

WDYT about catching these errors without having to download the actual files and leveraging model_info() (in case we're querying the Hub) or regular string matching (in case it's local)? Currently, we're still calling download() in case we don't have the model files cached. I think many errors can be caught and warnings can be thrown without having to do that.

This could live in a future PR.

deprecation_message = (
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
"if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
"modeling files is deprecated."
)
deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

# remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames)
Expand Down Expand Up @@ -1368,6 +1371,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]
is_sharded = any("index.json" in f and f != "model_index.json" for f in filenames)

if (
use_safetensors
Expand All @@ -1392,9 +1396,13 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
# `not is_sharded` because sharded checkpoints with a variant
# ("fp16") for example may have lesser shards actually. Consider
# https://huggingface.co/fal/AuraFlow/tree/main/transformer, for example.
if (
len(safetensors_variant_filenames) > 0
and safetensors_model_filenames != safetensors_variant_filenames
and not is_sharded
):
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
Expand Down
13 changes: 11 additions & 2 deletions src/diffusers/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
splits = weights_name.split(".")
split_index = -2 if weights_name.endswith(".index.json") else -1
splits = splits[:-split_index] + [variant] + splits[-split_index:]
splits = splits[:-1] + [variant] + splits[-1:]
weights_name = ".".join(splits)

return weights_name
Expand Down Expand Up @@ -502,6 +501,16 @@ def _get_checkpoint_shard_files(
return cached_folder, sharded_metadata


def _check_legacy_sharding_variant_format(folder: str, variant: str):
filenames = []
for _, _, files in os.walk(folder):
for file in files:
filenames.append(os.path.basename(file))
transformers_index_format = r"\d{5}-of-\d{5}"
variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
return any(variant_file_re.match(f) is not None for f in filenames)


class PushToHubMixin:
"""
A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub.
Expand Down
94 changes: 92 additions & 2 deletions tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
import requests_mock
import torch
from accelerate.utils import compute_module_sizes
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub import ModelCard, delete_repo, snapshot_download
from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
from requests.exceptions import HTTPError

from diffusers.models import UNet2DConditionModel
Expand All @@ -39,7 +40,13 @@
XFormersAttnProcessor,
)
from diffusers.training_utils import EMAModel
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging
from diffusers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
WEIGHTS_INDEX_NAME,
is_torch_npu_available,
is_xformers_available,
logging,
)
from diffusers.utils.hub_utils import _add_variant
from diffusers.utils.testing_utils import (
CaptureLogger,
Expand Down Expand Up @@ -100,6 +107,48 @@ def test_accelerate_loading_error_message(self):
# make sure that error message states what keys are missing
assert "conv_out.bias" in str(error_context.exception)

@parameterized.expand(
[
("hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", "unet", False),
("hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", "unet", True),
("hf-internal-testing/tiny-sd-unet-with-sharded-ckpt", None, False),
("hf-internal-testing/tiny-sd-unet-with-sharded-ckpt", None, True),
]
)
def test_variant_sharded_ckpt_legacy_format_raises_warning(self, repo_id, subfolder, use_local):
def load_model(path):
kwargs = {"variant": "fp16"}
if subfolder:
kwargs["subfolder"] = subfolder
return UNet2DConditionModel.from_pretrained(path, **kwargs)

with self.assertWarns(FutureWarning) as warning:
if use_local:
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = snapshot_download(repo_id=repo_id)
_ = load_model(tmpdirname)
else:
_ = load_model(repo_id)

warning_message = str(warning.warnings[0].message)
self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_message)

# Local tests are already covered down below.
@parameterized.expand(
[
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolde", "unet"),
]
)
def test_variant_sharded_ckpt_loads_from_hub(self, repo_id, subfolder):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a nice test! let's also add to the @parameterized to test non-variant(if not already tested), and device_map

Copy link
Member Author

@sayakpaul sayakpaul Sep 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Added parameterized to have subfolder and variant testing in all the sharding tests here:
    https://github.com/huggingface/diffusers/blob/main/tests/models/unets/test_models_unet_2d_condition.py

  2. Modified this test to have non-variant checkpoints as well.

Ran everything with "pytest tests/models/ -k "sharded" and it was green.

Commit: 1190f7d

def load_model():
kwargs = {"variant": "fp16"}
if subfolder:
kwargs["subfolder"] = subfolder
return UNet2DConditionModel.from_pretrained(repo_id, **kwargs)

assert load_model()

def test_cached_files_are_used_when_no_internet(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
Expand Down Expand Up @@ -924,6 +973,7 @@ def test_sharded_checkpoints_with_variant(self):
# testing if loading works with the variant when the checkpoint is sharded should be
# enough.
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)

index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename)))

Expand Down Expand Up @@ -976,6 +1026,46 @@ def test_sharded_checkpoints_device_map(self):
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

# This test is okay without a GPU because we're not running any execution. We're just serializing
# and check if the resultant files are following an expected format.
def test_variant_sharded_ckpt_right_format(self):
for use_safe in [True, False]:
extension = ".safetensors" if use_safe else ".bin"
config, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
if model._no_split_modules is None:
return

model_size = compute_module_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(
tmp_dir, variant=variant, max_shard_size=f"{max_shard_size}KB", safe_serialization=use_safe
)
index_variant = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safe else WEIGHTS_INDEX_NAME, variant)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_variant)))

# Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it.
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_variant))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(extension)])
self.assertTrue(actual_num_shards == expected_num_shards)

# Check if the variant is present as a substring in the checkpoints.
shard_files = [
file
for file in os.listdir(tmp_dir)
if file.endswith(extension) or ("index" in file and "json" in file)
]
assert all(variant in f for f in shard_files)

# Check if the sharded checkpoints were serialized in the right format.
shard_files = [file for file in os.listdir(tmp_dir) if file.endswith(extension)]
# Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)


@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
Expand Down
Loading
Loading