From b165cf37420fb415fd0967d2c0ceb2248d97c9da Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Tue, 15 Jul 2025 03:03:29 +0200
Subject: [PATCH 01/16] rearrange the params to groups: default params /image
 params /batch params / callback params

---
 tests/pipelines/pipeline_params.py            |  57 ++--
 .../test_stable_diffusion_xl_modular.py       | 256 ++++++++++++++++++
 2 files changed, 287 insertions(+), 26 deletions(-)
 create mode 100644 tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py

diff --git a/tests/pipelines/pipeline_params.py b/tests/pipelines/pipeline_params.py
index 4e2c4dcdd9cb..2023d026636e 100644
--- a/tests/pipelines/pipeline_params.py
+++ b/tests/pipelines/pipeline_params.py
@@ -20,12 +20,6 @@
     ]
 )
 
-TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
-
-TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
-
-IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
-
 IMAGE_VARIATION_PARAMS = frozenset(
     [
         "image",
@@ -35,8 +29,6 @@
     ]
 )
 
-IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
-
 TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
     [
         "prompt",
@@ -50,8 +42,6 @@
     ]
 )
 
-TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
-
 TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
     [
         # Text guided image variation with an image mask
@@ -67,8 +57,6 @@
     ]
 )
 
-TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
-
 IMAGE_INPAINTING_PARAMS = frozenset(
     [
         # image variation with an image mask
@@ -80,8 +68,6 @@
     ]
 )
 
-IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
-
 IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
     [
         "example_image",
@@ -93,20 +79,12 @@
     ]
 )
 
-IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
+UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
 
 CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
 
 CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
 
-UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
-
-UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
-
-UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
-
-UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
-
 TEXT_TO_AUDIO_PARAMS = frozenset(
     [
         "prompt",
@@ -119,11 +97,38 @@
     ]
 )
 
-TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
 
 TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
 
-TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
+UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
+
+# image params
+TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
+
+IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
+
+
+# batch params
+TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
+
+IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
+
+TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
+
+TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
+
+IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
-TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
+IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
+
+UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
+
+UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
+
+TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
+
+TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"]) VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"]) + +# callback params +TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"]) \ No newline at end of file diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py new file mode 100644 index 000000000000..2027de2546f4 --- /dev/null +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import gc +import tempfile +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import ( + ModularPipeline, + ComponentSpec, + ComponentsManager, + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LCMScheduler, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLPipeline, + UNet2DConditionModel, + UniPCMultistepScheduler, +) +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + load_image, + numpy_cosine_similarity_distance, + require_torch_accelerator, + slow, + torch_device, +) + +from ..pipeline_params import ( + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import ( + IPAdapterTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + SDFunctionTesterMixin, +) + + +enable_full_determinism() + + +class StableDiffusionXLModularPipelineFastTests( + SDFunctionTesterMixin, + IPAdapterTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + unittest.TestCase, +): + pipeline_class = StableDiffusionXLPipeline + params = (TEXT_TO_IMAGE_PARAMS | IMAGE_INPAINTING_PARAMS) - {"guidance_scale"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS | IMAGE_INPAINTING_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + test_layerwise_casting = False + test_group_offloading = False + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "output_type": "np", + } + return inputs + + def test_stable_diffusion_xl_euler(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe") + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs, output="images") + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 
64, 3) + expected_slice = np.array([0.5388, 0.5452, 0.4694, 0.4583, 0.5253, 0.4832, 0.5288, 0.5035, 0.47]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_xl_euler_lcm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe") + sd_pipe.update_components(scheduler=LCMScheduler.from_config(sd_pipe.scheduler.config)) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs, output="images") + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4917, 0.6555, 0.4348, 0.5219, 0.7324, 0.4855, 0.5168, 0.5447, 0.5156]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_xl_euler_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe") + sd_pipe.update_components(scheduler=LCMScheduler.from_config(sd_pipe.scheduler.config)) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + image = sd_pipe(**inputs, output="images") + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4917, 0.6555, 0.4348, 0.5219, 0.7324, 0.4855, 0.5168, 0.5447, 0.5156]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @require_torch_accelerator + def test_stable_diffusion_xl_offloads(self): + pipes = [] + sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe",).to(torch_device) + pipes.append(sd_pipe) + + cm = ComponentsManager() + cm.enable_auto_cpu_offload(device=torch_device) + sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe", components_manager=cm).to(torch_device) + pipes.append(sd_pipe) + + image_slices = [] + for pipe in pipes: + inputs = self.get_dummy_inputs(torch_device) + image = pipe(**inputs, output="images") + + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3 + + def test_stable_diffusion_xl_multi_prompts(self): + sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + + # forward with single prompt + inputs = self.get_dummy_inputs(torch_device) + output = sd_pipe(**inputs, output="images") + image_slice_1 = output.images[0, -3:, -3:, -1] + + # forward with same prompt duplicated + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = inputs["prompt"] + output = sd_pipe(**inputs, output="images") + image_slice_2 = output.images[0, -3:, -3:, -1] + + # ensure the results are equal + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + + # forward with different prompt + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "different prompt" + output = sd_pipe(**inputs, output="images") + image_slice_3 = output.images[0, -3:, -3:, -1] + + # ensure the results are not equal + assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + + # manually set a negative_prompt + inputs = self.get_dummy_inputs(torch_device) + 
inputs["negative_prompt"] = "negative prompt" + output = sd_pipe(**inputs, output="images") + image_slice_1 = output.images[0, -3:, -3:, -1] + + # forward with same negative_prompt duplicated + inputs = self.get_dummy_inputs(torch_device) + inputs["negative_prompt"] = "negative prompt" + inputs["negative_prompt_2"] = inputs["negative_prompt"] + output = sd_pipe(**inputs, output="images") + image_slice_2 = output.images[0, -3:, -3:, -1] + + # ensure the results are equal + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + + # forward with different negative_prompt + inputs = self.get_dummy_inputs(torch_device) + inputs["negative_prompt"] = "negative prompt" + inputs["negative_prompt_2"] = "different negative prompt" + output = sd_pipe(**inputs, output="images") + image_slice_3 = output.images[0, -3:, -3:, -1] + + # ensure the results are not equal + assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + + def test_stable_diffusion_xl_negative_conditions(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs, output="images") + image_slice_with_no_neg_cond = image[0, -3:, -3:, -1] + + image = sd_pipe( + **inputs, + negative_original_size=(512, 512), + negative_crops_coords_top_left=(0, 0), + negative_target_size=(1024, 1024), + output="images", + ) + image_slice_with_neg_cond = image[0, -3:, -3:, -1] + + self.assertTrue(np.abs(image_slice_with_no_neg_cond - image_slice_with_neg_cond).max() > 1e-2) + + def test_stable_diffusion_xl_save_from_pretrained(self): + pipes = [] + sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + pipes.append(sd_pipe) + + with tempfile.TemporaryDirectory() as tmpdirname: + sd_pipe.save_pretrained(tmpdirname) + sd_pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device) + pipes.append(sd_pipe) + + image_slices = [] + for pipe in pipes: + pipe.unet.set_default_attn_processor() + + inputs = self.get_dummy_inputs(torch_device) + image = pipe(**inputs, output="images") + + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 \ No newline at end of file From 0fa58127f8a0f2070be64967e8e47a24ed988085 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 15 Jul 2025 03:05:36 +0200 Subject: [PATCH 02/16] make style --- tests/pipelines/pipeline_params.py | 2 +- .../test_stable_diffusion_xl_modular.py | 31 ++++++------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/tests/pipelines/pipeline_params.py b/tests/pipelines/pipeline_params.py index 2023d026636e..3db7c9fa1b0c 100644 --- a/tests/pipelines/pipeline_params.py +++ b/tests/pipelines/pipeline_params.py @@ -131,4 +131,4 @@ VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"]) # callback params -TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"]) \ No newline at end of file +TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"]) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py index 2027de2546f4..aaa780cf982d 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py +++ 
b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py @@ -13,43 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy -import gc import tempfile import unittest import numpy as np import torch -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import ( - ModularPipeline, - ComponentSpec, ComponentsManager, - AutoencoderKL, - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, LCMScheduler, - StableDiffusionXLImg2ImgPipeline, + ModularPipeline, StableDiffusionXLPipeline, - UNet2DConditionModel, - UniPCMultistepScheduler, ) from diffusers.utils.testing_utils import ( - backend_empty_cache, enable_full_determinism, - load_image, - numpy_cosine_similarity_distance, require_torch_accelerator, - slow, torch_device, ) from ..pipeline_params import ( + IMAGE_INPAINTING_BATCH_PARAMS, + IMAGE_INPAINTING_PARAMS, TEXT_TO_IMAGE_BATCH_PARAMS, - TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS, ) @@ -143,12 +128,16 @@ def test_stable_diffusion_xl_euler_lcm_custom_timesteps(self): @require_torch_accelerator def test_stable_diffusion_xl_offloads(self): pipes = [] - sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe",).to(torch_device) + sd_pipe = ModularPipeline.from_pretrained( + "hf-internal-testing/tiny-sd-pipe", + ).to(torch_device) pipes.append(sd_pipe) cm = ComponentsManager() cm.enable_auto_cpu_offload(device=torch_device) - sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe", components_manager=cm).to(torch_device) + sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe", components_manager=cm).to( + torch_device + ) pipes.append(sd_pipe) image_slices = [] @@ -253,4 +242,4 @@ def test_stable_diffusion_xl_save_from_pretrained(self): image_slices.append(image[0, -3:, -3:, -1].flatten()) - assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 \ No newline at end of file + assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 From 0a5c90ed47eb9927942bc6d116c245d4a530c7df Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 16 Jul 2025 04:25:26 +0200 Subject: [PATCH 03/16] add names property to pipeline blocks --- .../modular_pipelines/modular_pipeline.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index b99478cb58d1..c343ebc5ac46 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -478,6 +478,23 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> combined_dict[output_param.name] = output_param return list(combined_dict.values()) + + @property + def input_names(self) -> List[str]: + return [input_param.name for input_param in self.inputs] + + @property + def intermediate_input_names(self) -> List[str]: + return [input_param.name for input_param in self.intermediate_inputs] + + @property + def intermediate_output_names(self) -> List[str]: + return [output_param.name for output_param in self.intermediate_outputs] + + @property + def output_names(self) -> List[str]: + return [output_param.name for output_param in self.outputs] + class PipelineBlock(ModularPipelineBlocks): @@ -2825,3 +2842,9 @@ def _dict_to_component_spec( type_hint=type_hint, **spec_dict, 
) + + + def set_progress_bar_config(self, **kwargs): + for sub_block_name, sub_block in self.blocks.sub_blocks.items(): + if hasattr(sub_block, "set_progress_bar_config"): + sub_block.set_progress_bar_config(**kwargs) \ No newline at end of file From d92855ddf0a0c7160dc8e56da486ee93adee2141 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 16 Jul 2025 04:26:27 +0200 Subject: [PATCH 04/16] style --- .../modular_pipelines/modular_pipeline.py | 12 +- .../test_stable_diffusion_xl_modular.py | 107 ++- .../test_modular_pipelines_common.py | 689 ++++++++++++++++++ 3 files changed, 744 insertions(+), 64 deletions(-) create mode 100644 tests/pipelines/test_modular_pipelines_common.py diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index c343ebc5ac46..a1972167c152 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -478,23 +478,22 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> combined_dict[output_param.name] = output_param return list(combined_dict.values()) - + @property def input_names(self) -> List[str]: return [input_param.name for input_param in self.inputs] - + @property def intermediate_input_names(self) -> List[str]: return [input_param.name for input_param in self.intermediate_inputs] - + @property def intermediate_output_names(self) -> List[str]: return [output_param.name for output_param in self.intermediate_outputs] - + @property def output_names(self) -> List[str]: return [output_param.name for output_param in self.outputs] - class PipelineBlock(ModularPipelineBlocks): @@ -2843,8 +2842,7 @@ def _dict_to_component_spec( **spec_dict, ) - def set_progress_bar_config(self, **kwargs): for sub_block_name, sub_block in self.blocks.sub_blocks.items(): if hasattr(sub_block, "set_progress_bar_config"): - sub_block.set_progress_bar_config(**kwargs) \ No newline at end of file + sub_block.set_progress_bar_config(**kwargs) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py index aaa780cf982d..0d9c3d0f20cc 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py @@ -23,7 +23,8 @@ ComponentsManager, LCMScheduler, ModularPipeline, - StableDiffusionXLPipeline, + StableDiffusionXLAutoBlocks, + StableDiffusionXLModularPipeline, ) from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -38,11 +39,9 @@ TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS, ) -from ..test_pipelines_common import ( - IPAdapterTesterMixin, - PipelineLatentTesterMixin, - PipelineTesterMixin, - SDFunctionTesterMixin, +from ..test_modular_pipelines_common import ( + ModularIPAdapterTesterMixin, + ModularPipelineTesterMixin, ) @@ -50,18 +49,25 @@ class StableDiffusionXLModularPipelineFastTests( - SDFunctionTesterMixin, - IPAdapterTesterMixin, - PipelineLatentTesterMixin, - PipelineTesterMixin, + ModularIPAdapterTesterMixin, + ModularPipelineTesterMixin, unittest.TestCase, ): - pipeline_class = StableDiffusionXLPipeline - params = (TEXT_TO_IMAGE_PARAMS | IMAGE_INPAINTING_PARAMS) - {"guidance_scale"} + pipeline_class = StableDiffusionXLModularPipeline + pipeline_blocks_class = StableDiffusionXLAutoBlocks + repo = "hf-internal-testing/tiny-sdxl-modular" + params = (TEXT_TO_IMAGE_PARAMS | IMAGE_INPAINTING_PARAMS) - { + 
"guidance_scale", + "prompt_embeds", + "negative_prompt_embeds", + } batch_params = TEXT_TO_IMAGE_BATCH_PARAMS | IMAGE_INPAINTING_BATCH_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - test_layerwise_casting = False - test_group_offloading = False + + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager) + pipeline.load_default_components(torch_dtype=torch_dtype) + return pipeline def get_dummy_inputs(self, device, seed=0): if str(device).startswith("mps"): @@ -78,7 +84,7 @@ def get_dummy_inputs(self, device, seed=0): def test_stable_diffusion_xl_euler(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator - sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe") + sd_pipe = self.get_pipeline() sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) @@ -87,13 +93,17 @@ def test_stable_diffusion_xl_euler(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5388, 0.5452, 0.4694, 0.4583, 0.5253, 0.4832, 0.5288, 0.5035, 0.47]) + expected_slice = np.array( + [0.5966781, 0.62939394, 0.48465094, 0.51573336, 0.57593524, 0.47035995, 0.53410417, 0.51436996, 0.47313565] + ) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f"image_slice: {image_slice.flatten()}, expected_slice: {expected_slice.flatten()}" + ) def test_stable_diffusion_xl_euler_lcm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator - sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe") + sd_pipe = self.get_pipeline() sd_pipe.update_components(scheduler=LCMScheduler.from_config(sd_pipe.scheduler.config)) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) @@ -103,41 +113,23 @@ def test_stable_diffusion_xl_euler_lcm(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4917, 0.6555, 0.4348, 0.5219, 0.7324, 0.4855, 0.5168, 0.5447, 0.5156]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - def test_stable_diffusion_xl_euler_lcm_custom_timesteps(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe") - sd_pipe.update_components(scheduler=LCMScheduler.from_config(sd_pipe.scheduler.config)) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - del inputs["num_inference_steps"] - inputs["timesteps"] = [999, 499] - image = sd_pipe(**inputs, output="images") - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4917, 0.6555, 0.4348, 0.5219, 0.7324, 0.4855, 0.5168, 0.5447, 0.5156]) + expected_slice = np.array( + [0.6880376, 0.6511651, 0.587455, 0.61763, 0.55432945, 0.52064973, 0.5783733, 0.54915607, 0.5460011] + ) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( + f"image_slice: {image_slice.flatten()}, expected_slice: {expected_slice.flatten()}" + ) @require_torch_accelerator def test_stable_diffusion_xl_offloads(self): pipes = [] - sd_pipe = ModularPipeline.from_pretrained( - "hf-internal-testing/tiny-sd-pipe", - 
).to(torch_device) + sd_pipe = self.get_pipeline().to(torch_device) pipes.append(sd_pipe) cm = ComponentsManager() cm.enable_auto_cpu_offload(device=torch_device) - sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe", components_manager=cm).to( - torch_device - ) + sd_pipe = self.get_pipeline(components_manager=cm) pipes.append(sd_pipe) image_slices = [] @@ -148,21 +140,20 @@ def test_stable_diffusion_xl_offloads(self): image_slices.append(image[0, -3:, -3:, -1].flatten()) assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 - assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3 def test_stable_diffusion_xl_multi_prompts(self): - sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + sd_pipe = self.get_pipeline().to(torch_device) # forward with single prompt inputs = self.get_dummy_inputs(torch_device) output = sd_pipe(**inputs, output="images") - image_slice_1 = output.images[0, -3:, -3:, -1] + image_slice_1 = output[0, -3:, -3:, -1] # forward with same prompt duplicated inputs = self.get_dummy_inputs(torch_device) inputs["prompt_2"] = inputs["prompt"] output = sd_pipe(**inputs, output="images") - image_slice_2 = output.images[0, -3:, -3:, -1] + image_slice_2 = output[0, -3:, -3:, -1] # ensure the results are equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 @@ -171,7 +162,7 @@ def test_stable_diffusion_xl_multi_prompts(self): inputs = self.get_dummy_inputs(torch_device) inputs["prompt_2"] = "different prompt" output = sd_pipe(**inputs, output="images") - image_slice_3 = output.images[0, -3:, -3:, -1] + image_slice_3 = output[0, -3:, -3:, -1] # ensure the results are not equal assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 @@ -180,14 +171,14 @@ def test_stable_diffusion_xl_multi_prompts(self): inputs = self.get_dummy_inputs(torch_device) inputs["negative_prompt"] = "negative prompt" output = sd_pipe(**inputs, output="images") - image_slice_1 = output.images[0, -3:, -3:, -1] + image_slice_1 = output[0, -3:, -3:, -1] # forward with same negative_prompt duplicated inputs = self.get_dummy_inputs(torch_device) inputs["negative_prompt"] = "negative prompt" inputs["negative_prompt_2"] = inputs["negative_prompt"] output = sd_pipe(**inputs, output="images") - image_slice_2 = output.images[0, -3:, -3:, -1] + image_slice_2 = output[0, -3:, -3:, -1] # ensure the results are equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 @@ -197,15 +188,14 @@ def test_stable_diffusion_xl_multi_prompts(self): inputs["negative_prompt"] = "negative prompt" inputs["negative_prompt_2"] = "different negative prompt" output = sd_pipe(**inputs, output="images") - image_slice_3 = output.images[0, -3:, -3:, -1] + image_slice_3 = output[0, -3:, -3:, -1] # ensure the results are not equal assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 def test_stable_diffusion_xl_negative_conditions(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator - sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) - sd_pipe = sd_pipe.to(device) + sd_pipe = self.get_pipeline().to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) @@ -225,21 +215,24 @@ def test_stable_diffusion_xl_negative_conditions(self): def test_stable_diffusion_xl_save_from_pretrained(self): pipes = [] - sd_pipe = 
ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + sd_pipe = self.get_pipeline().to(torch_device) pipes.append(sd_pipe) with tempfile.TemporaryDirectory() as tmpdirname: sd_pipe.save_pretrained(tmpdirname) sd_pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device) + sd_pipe.load_default_components(torch_dtype=torch.float32) + sd_pipe.to(torch_device) pipes.append(sd_pipe) image_slices = [] for pipe in pipes: - pipe.unet.set_default_attn_processor() - inputs = self.get_dummy_inputs(torch_device) image = pipe(**inputs, output="images") image_slices.append(image[0, -3:, -3:, -1].flatten()) assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) diff --git a/tests/pipelines/test_modular_pipelines_common.py b/tests/pipelines/test_modular_pipelines_common.py new file mode 100644 index 000000000000..f8e46b2ca592 --- /dev/null +++ b/tests/pipelines/test_modular_pipelines_common.py @@ -0,0 +1,689 @@ +import gc +import unittest +from typing import Any, Callable, Dict, Union + +import numpy as np +import torch + +import diffusers +from diffusers import ( + ClassifierFreeGuidance, + DiffusionPipeline, +) +from diffusers.loaders import ModularIPAdapterMixin +from diffusers.utils import logging +from diffusers.utils.testing_utils import ( + backend_empty_cache, + numpy_cosine_similarity_distance, + require_accelerator, + require_torch, + torch_device, +) + +from ..models.unets.test_models_unet_2d_condition import ( + create_ip_adapter_faceid_state_dict, + create_ip_adapter_state_dict, +) + + +def to_np(tensor): + if isinstance(tensor, torch.Tensor): + tensor = tensor.detach().cpu().numpy() + + return tensor + + +def check_same_shape(tensor_list): + shapes = [tensor.shape for tensor in tensor_list] + return all(shape == shapes[0] for shape in shapes[1:]) + + +class ModularIPAdapterTesterMixin: + """ + This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. + It provides a set of common tests for pipelines that support IP Adapters. 
+    """
+
+    def test_pipeline_inputs_and_blocks(self):
+        blocks = self.pipeline_blocks_class()
+        parameters = blocks.input_names
+
+        assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
+        self.assertIn(
+            "ip_adapter_image",
+            parameters,
+            "`ip_adapter_image` argument must be supported by the `__call__` method",
+        )
+        self.assertIn(
+            "ip_adapter",
+            blocks.sub_blocks,
+            "pipeline must contain an IPAdapter block",
+        )
+
+        _ = blocks.sub_blocks.pop("ip_adapter")
+        parameters = blocks.input_names
+        intermediate_parameters = blocks.intermediate_input_names
+        self.assertNotIn(
+            "ip_adapter_image",
+            parameters,
+            "`ip_adapter_image` argument must be removed from the `__call__` method",
+        )
+        self.assertNotIn(
+            "ip_adapter_image_embeds",
+            intermediate_parameters,
+            "`ip_adapter_image_embeds` argument must be removed from the `__call__` method",
+        )
+
+    def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
+        return torch.randn((1, 1, cross_attention_dim), device=torch_device)
+
+    def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
+        return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
+
+    def _get_dummy_masks(self, input_size: int = 64):
+        _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
+        _masks[0, :, :, : int(input_size / 2)] = 1
+        return _masks
+
+    def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
+        blocks = self.pipeline_blocks_class()
+        _ = blocks.sub_blocks.pop("ip_adapter")
+        parameters = blocks.input_names
+        if "image" in parameters and "strength" in parameters:
+            inputs["num_inference_steps"] = 4
+
+        inputs["output_type"] = "np"
+        return inputs
+
+    def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
+        r"""Tests for IP-Adapter.
+
+        The following scenarios are tested:
+        - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+        - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+        - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+        - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+        """
+        # Raising the tolerance for this test when it's run on a CPU because we
+        # compare against static slices and that can be shaky (with a very low probability).
+        expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
+
+        blocks = self.pipeline_blocks_class()
+        _ = blocks.sub_blocks.pop("ip_adapter")
+        pipe = blocks.init_pipeline(self.repo)
+        pipe.load_default_components(torch_dtype=torch.float32)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
+
+        # forward pass without ip adapter
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        if expected_pipe_slice is None:
+            output_without_adapter = pipe(**inputs, output="images")
+        else:
+            output_without_adapter = expected_pipe_slice
+
+        # 1.
Single IP-Adapter test cases + adapter_state_dict = create_ip_adapter_state_dict(pipe.unet) + pipe.unet._load_ip_adapter_weights(adapter_state_dict) + + # forward pass with single ip adapter, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + pipe.set_ip_adapter_scale(0.0) + output_without_adapter_scale = pipe(**inputs, output="images") + if expected_pipe_slice is not None: + output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with single ip adapter, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + pipe.set_ip_adapter_scale(42.0) + output_with_adapter_scale = pipe(**inputs, output="images") + if expected_pipe_slice is not None: + output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() + max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() + + self.assertLess( + max_diff_without_adapter_scale, + expected_max_diff, + "Output without ip-adapter must be same as normal inference", + ) + self.assertGreater( + max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference" + ) + + # 2. Multi IP-Adapter test cases + adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet) + adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet) + pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2]) + + # forward pass with multi ip adapter, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 + pipe.set_ip_adapter_scale([0.0, 0.0]) + output_without_multi_adapter_scale = pipe(**inputs, output="images") + if expected_pipe_slice is not None: + output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with multi ip adapter, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 + pipe.set_ip_adapter_scale([42.0, 42.0]) + output_with_multi_adapter_scale = pipe(**inputs, output="images") + if expected_pipe_slice is not None: + output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_multi_adapter_scale = np.abs( + output_without_multi_adapter_scale - output_without_adapter + ).max() + max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max() + self.assertLess( + max_diff_without_multi_adapter_scale, + expected_max_diff, + "Output without multi-ip-adapter must 
be same as normal inference", + ) + self.assertGreater( + max_diff_with_multi_adapter_scale, + 1e-2, + "Output with multi-ip-adapter scale must be different from normal inference", + ) + + def test_ip_adapter_cfg(self, expected_max_diff: float = 1e-4): + blocks = self.pipeline_blocks_class() + _ = blocks.sub_blocks.pop("ip_adapter") + pipe = blocks.init_pipeline(self.repo) + pipe.load_default_components(torch_dtype=torch.float32) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32) + + adapter_state_dict = create_ip_adapter_state_dict(pipe.unet) + pipe.unet._load_ip_adapter_weights(adapter_state_dict) + pipe.set_ip_adapter_scale(1.0) + + # forward pass with CFG not applied + guider = ClassifierFreeGuidance(guidance_scale=1.0) + pipe.update_components(guider=guider) + + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)[0].unsqueeze(0)] + out_no_cfg = pipe(**inputs, output="images") + + # forward pass with CFG applied + guider = ClassifierFreeGuidance(guidance_scale=7.5) + pipe.update_components(guider=guider) + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + out_cfg = pipe(**inputs, output="images") + + assert out_cfg.shape == out_no_cfg.shape + + def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4): + blocks = self.pipeline_blocks_class() + _ = blocks.sub_blocks.pop("ip_adapter") + pipe = blocks.init_pipeline(self.repo) + pipe.load_default_components(torch_dtype=torch.float32) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32) + sample_size = pipe.unet.config.get("sample_size", 32) + block_out_channels = pipe.vae.config.get("block_out_channels", [128, 256, 512, 512]) + input_size = sample_size * (2 ** (len(block_out_channels) - 1)) + + # forward pass without ip adapter + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + output_without_adapter = pipe(**inputs, output="images") + output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten() + + adapter_state_dict = create_ip_adapter_state_dict(pipe.unet) + pipe.unet._load_ip_adapter_weights(adapter_state_dict) + + # forward pass with single ip adapter and masks, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]} + pipe.set_ip_adapter_scale(0.0) + output_without_adapter_scale = pipe(**inputs, output="images") + output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with single ip adapter and masks, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + inputs["negative_ip_adapter_embeds"] = 
[self._get_dummy_image_embeds(cross_attention_dim)] + inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]} + pipe.set_ip_adapter_scale(42.0) + output_with_adapter_scale = pipe(**inputs, output="images") + output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() + max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() + + self.assertLess( + max_diff_without_adapter_scale, + expected_max_diff, + "Output without ip-adapter must be same as normal inference", + ) + self.assertGreater( + max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference" + ) + + def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4): + blocks = self.pipeline_blocks_class() + _ = blocks.sub_blocks.pop("ip_adapter") + pipe = blocks.init_pipeline(self.repo) + pipe.load_default_components(torch_dtype=torch.float32) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32) + + # forward pass without ip adapter + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + output_without_adapter = pipe(**inputs, output="images") + output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten() + + adapter_state_dict = create_ip_adapter_faceid_state_dict(pipe.unet) + pipe.unet._load_ip_adapter_weights(adapter_state_dict) + + # forward pass with single ip adapter, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)] + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)] + pipe.set_ip_adapter_scale(0.0) + output_without_adapter_scale = pipe(**inputs, output="images") + output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with single ip adapter, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)] + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)] + pipe.set_ip_adapter_scale(42.0) + output_with_adapter_scale = pipe(**inputs, output="images") + output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() + max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() + + self.assertLess( + max_diff_without_adapter_scale, + expected_max_diff, + "Output without ip-adapter must be same as normal inference", + ) + self.assertGreater( + max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference" + ) + + +@require_torch +class ModularPipelineTesterMixin: + """ + This mixin is designed to be used with unittest.TestCase classes. + It provides a set of common tests for each PyTorch pipeline, e.g. saving and loading the pipeline, + equivalence of dict and tuple outputs, etc. + """ + + # Canonical parameters that are passed to `__call__` regardless + # of the type of pipeline. 
They are always optional and have common
+    # sense default values.
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "num_images_per_prompt",
+            "latents",
+            "output_type",
+        ]
+    )
+    # generator needs to be an intermediate input because it's mutable
+    required_intermediate_params = frozenset(
+        [
+            "generator",
+        ]
+    )
+
+    def get_generator(self, seed):
+        device = torch_device if torch_device != "mps" else "cpu"
+        generator = torch.Generator(device).manual_seed(seed)
+        return generator
+
+    @property
+    def pipeline_class(self) -> Union[Callable, DiffusionPipeline]:
+        raise NotImplementedError(
+            "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
+            "See existing pipeline tests for reference."
+        )
+
+    @property
+    def repo(self) -> str:
+        raise NotImplementedError(
+            "You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
+        )
+
+    @property
+    def pipeline_blocks_class(self) -> Union[Callable, DiffusionPipeline]:
+        raise NotImplementedError(
+            "You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
+            "See existing pipeline tests for reference."
+        )
+
+    def get_pipeline(self):
+        raise NotImplementedError(
+            "You need to implement `get_pipeline(self)` in the child test class. "
+            "See existing pipeline tests for reference."
+        )
+
+    def get_dummy_inputs(self, device, seed=0):
+        raise NotImplementedError(
+            "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
+            "See existing pipeline tests for reference."
+        )
+
+    @property
+    def params(self) -> frozenset:
+        raise NotImplementedError(
+            "You need to set the attribute `params` in the child test class. "
+            "`params` are checked to ensure all values are present in `__call__`'s signature."
+            " You can set `params` using one of the common sets of parameters defined in `pipeline_params.py`"
+            " e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
+            "image pipelines, including prompts and prompt embedding overrides."
+            "If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
+            "do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
+            "with non-configurable height and width arguments should set the attribute as "
+            "`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
+            "See existing pipeline tests for reference."
+        )
+
+    @property
+    def batch_params(self) -> frozenset:
+        raise NotImplementedError(
+            "You need to set the attribute `batch_params` in the child test class. "
+            "`batch_params` are the parameters required to be batched when passed to the pipeline's "
+            "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
+            "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
+            "set of batch arguments has minor changes from one of the common sets of batch arguments, "
+            "do not make modifications to the existing common sets of batch arguments. I.e. a text to "
+            "image pipeline where `negative_prompt` is not batched should set the attribute as "
+            "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
+            "See existing pipeline tests for reference."
+ ) + + def setUp(self): + # clean up the VRAM before each test + super().setUp() + torch.compiler.reset() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + # clean up the VRAM after each test in case of CUDA runtime errors + super().tearDown() + torch.compiler.reset() + gc.collect() + backend_empty_cache(torch_device) + + def test_pipeline_call_signature(self): + pipe = self.get_pipeline() + parameters = pipe.blocks.input_names + optional_parameters = pipe.default_call_parameters + intermediate_parameters = pipe.blocks.intermediate_input_names + + remaining_required_parameters = set() + + for param in self.params: + if param not in parameters: + remaining_required_parameters.add(param) + + self.assertTrue( + len(remaining_required_parameters) == 0, + f"Required parameters not present: {remaining_required_parameters}", + ) + + remaining_required_intermediate_parameters = set() + + for param in self.required_intermediate_params: + if param not in intermediate_parameters: + remaining_required_intermediate_parameters.add(param) + + self.assertTrue( + len(remaining_required_intermediate_parameters) == 0, + f"Required intermediate parameters not present: {remaining_required_intermediate_parameters}", + ) + + remaining_required_optional_parameters = set() + + for param in self.required_optional_params: + if param not in optional_parameters: + remaining_required_optional_parameters.add(param) + + self.assertTrue( + len(remaining_required_optional_parameters) == 0, + f"Required optional parameters not present: {remaining_required_optional_parameters}", + ) + + def test_inference_batch_consistent(self, batch_sizes=[2]): + self._test_inference_batch_consistent(batch_sizes=batch_sizes) + + def _test_inference_batch_consistent( + self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"], batch_generator=True + ): + pipe = self.get_pipeline() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + inputs["generator"] = self.get_generator(0) + + logger = logging.get_logger(pipe.__module__) + logger.setLevel(level=diffusers.logging.FATAL) + + # prepare batched inputs + batched_inputs = [] + for batch_size in batch_sizes: + batched_input = {} + batched_input.update(inputs) + + for name in self.batch_params: + if name not in inputs: + continue + + value = inputs[name] + if name == "prompt": + len_prompt = len(value) + # make unequal batch sizes + batched_input[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] + + # make last batch super long + batched_input[name][-1] = 100 * "very long" + + else: + batched_input[name] = batch_size * [value] + + if batch_generator and "generator" in inputs: + batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)] + + if "batch_size" in inputs: + batched_input["batch_size"] = batch_size + + batched_inputs.append(batched_input) + + logger.setLevel(level=diffusers.logging.WARNING) + for batch_size, batched_input in zip(batch_sizes, batched_inputs): + output = pipe(**batched_input, output="images") + assert len(output) == batch_size + + def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=1e-4): + self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff) + + def _test_inference_batch_single_identical( + self, + batch_size=2, + expected_max_diff=1e-4, + additional_params_copy_to_batched_inputs=["num_inference_steps"], + ): + pipe = 
self.get_pipeline()
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        inputs = self.get_dummy_inputs(torch_device)
+        # Reset generator in case it has been used in self.get_dummy_inputs
+        inputs["generator"] = self.get_generator(0)
+
+        logger = logging.get_logger(pipe.__module__)
+        logger.setLevel(level=diffusers.logging.FATAL)
+
+        # batchify inputs
+        batched_inputs = {}
+        batched_inputs.update(inputs)
+
+        for name in self.batch_params:
+            if name not in inputs:
+                continue
+
+            value = inputs[name]
+            if name == "prompt":
+                len_prompt = len(value)
+                batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+                batched_inputs[name][-1] = 100 * "very long"
+
+            else:
+                batched_inputs[name] = batch_size * [value]
+
+        if "generator" in inputs:
+            batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+        if "batch_size" in inputs:
+            batched_inputs["batch_size"] = batch_size
+
+        for arg in additional_params_copy_to_batched_inputs:
+            batched_inputs[arg] = inputs[arg]
+
+        output = pipe(**inputs, output="images")
+        output_batch = pipe(**batched_inputs, output="images")
+
+        assert output_batch.shape[0] == batch_size
+
+        max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
+        assert max_diff < expected_max_diff
+
+    @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+    @require_accelerator
+    def test_float16_inference(self, expected_max_diff=5e-2):
+        pipe = self.get_pipeline(torch_dtype=torch.float32)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe_fp16 = self.get_pipeline(torch_dtype=torch.float16)
+        for component in pipe_fp16.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe_fp16.to(torch_device, torch.float16)
+        pipe_fp16.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        # Reset generator in case it is used inside dummy inputs
+        if "generator" in inputs:
+            inputs["generator"] = self.get_generator(0)
+        output = pipe(**inputs, output="images")
+
+        fp16_inputs = self.get_dummy_inputs(torch_device)
+        # Reset generator in case it is used inside dummy inputs
+        if "generator" in fp16_inputs:
+            fp16_inputs["generator"] = self.get_generator(0)
+        output_fp16 = pipe_fp16(**fp16_inputs, output="images")
+
+        if isinstance(output, torch.Tensor):
+            output = output.cpu()
+            output_fp16 = output_fp16.cpu()
+
+        max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
+        assert max_diff < expected_max_diff
+
+    @require_accelerator
+    def test_to_device(self):
+        pipe = self.get_pipeline()
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.to("cpu")
+        model_devices = [
+            component.device.type for component in pipe.components.values() if hasattr(component, "device")
+        ]
+        self.assertTrue(all(device == "cpu" for device in model_devices))
+
+        output_cpu = pipe(**self.get_dummy_inputs("cpu"), output="images")
+        self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+        pipe.to(torch_device)
+        model_devices = [
+            component.device.type for component in pipe.components.values() if hasattr(component, "device")
+        ]
+        self.assertTrue(all(device == torch_device for
device in model_devices)) + + output_device = pipe(**self.get_dummy_inputs(torch_device), output="images") + self.assertTrue(np.isnan(to_np(output_device)).sum() == 0) + + def test_num_images_per_prompt(self): + pipe = self.get_pipeline() + + if "num_images_per_prompt" not in pipe.blocks.input_names: + return + + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + batch_sizes = [1, 2] + num_images_per_prompts = [1, 2] + + for batch_size in batch_sizes: + for num_images_per_prompt in num_images_per_prompts: + inputs = self.get_dummy_inputs(torch_device) + + for key in inputs.keys(): + if key in self.batch_params: + inputs[key] = batch_size * [inputs[key]] + + images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images") + + assert images.shape[0] == batch_size * num_images_per_prompt + + def test_cfg(self): + pipe = self.get_pipeline() + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + guider = ClassifierFreeGuidance(guidance_scale=1.0) + pipe.update_components(guider=guider) + + inputs = self.get_dummy_inputs(torch_device) + out_no_cfg = pipe(**inputs, output="images") + + guider = ClassifierFreeGuidance(guidance_scale=7.5) + pipe.update_components(guider=guider) + + out_cfg = pipe(**inputs, output="images") + + assert out_cfg.shape == out_no_cfg.shape + + +# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. +# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a +# reference image. +def assert_mean_pixel_difference(image, expected_image, expected_max_diff=10): + image = np.asarray(DiffusionPipeline.numpy_to_pil(image)[0], dtype=np.float32) + expected_image = np.asarray(DiffusionPipeline.numpy_to_pil(expected_image)[0], dtype=np.float32) + avg_diff = np.abs(image - expected_image).mean() + assert avg_diff < expected_max_diff, f"Error image deviates {avg_diff} pixels on average" From d8fa2de36f5076618b8a9ca1336a5f8068b5c031 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 16 Jul 2025 04:29:27 +0200 Subject: [PATCH 05/16] remove more unused func --- tests/pipelines/test_modular_pipelines_common.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/pipelines/test_modular_pipelines_common.py b/tests/pipelines/test_modular_pipelines_common.py index f8e46b2ca592..9ce11ccbde82 100644 --- a/tests/pipelines/test_modular_pipelines_common.py +++ b/tests/pipelines/test_modular_pipelines_common.py @@ -33,11 +33,6 @@ def to_np(tensor): return tensor -def check_same_shape(tensor_list): - shapes = [tensor.shape for tensor in tensor_list] - return all(shape == shapes[0] for shape in shapes[1:]) - - class ModularIPAdapterTesterMixin: """ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. @@ -677,13 +672,3 @@ def test_cfg(self): out_cfg = pipe(**inputs, output="images") assert out_cfg.shape == out_no_cfg.shape - - -# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. -# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a -# reference image. 
-def assert_mean_pixel_difference(image, expected_image, expected_max_diff=10):
-    image = np.asarray(DiffusionPipeline.numpy_to_pil(image)[0], dtype=np.float32)
-    expected_image = np.asarray(DiffusionPipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
-    avg_diff = np.abs(image - expected_image).mean()
-    assert avg_diff < expected_max_diff, f"Error image deviates {avg_diff} pixels on average"

From 4b7a9e9fa9f845b7499a858bec04b538f7acd65e Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Wed, 16 Jul 2025 11:57:29 +0200
Subject: [PATCH 06/16] prepare_latents_inpaint always returns noise and image_latents

---
 .../stable_diffusion_xl/before_denoise.py    | 16 +++------------
 1 file changed, 3 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
index c56f4af1b8a5..1800a613ec0f 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
@@ -744,8 +744,6 @@ def prepare_latents_inpaint(
         timestep=None,
         is_strength_max=True,
         add_noise=True,
-        return_noise=False,
-        return_image_latents=False,
     ):
         shape = (
             batch_size,
@@ -768,7 +766,7 @@ def prepare_latents_inpaint(
         if image.shape[1] == 4:
             image_latents = image.to(device=device, dtype=dtype)
             image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
-        elif return_image_latents or (latents is None and not is_strength_max):
+        elif latents is None and not is_strength_max:
             image = image.to(device=device, dtype=dtype)
             image_latents = self._encode_vae_image(components, image=image, generator=generator)
             image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
@@ -786,13 +784,7 @@ def prepare_latents_inpaint(
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
             latents = image_latents.to(device)
 
-        outputs = (latents,)
-
-        if return_noise:
-            outputs += (noise,)
-
-        if return_image_latents:
-            outputs += (image_latents,)
+        outputs = (latents, noise, image_latents)
 
         return outputs
 
@@ -864,7 +856,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
         block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
         block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
 
-        block_state.latents, block_state.noise = self.prepare_latents_inpaint(
+        block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
            components,
            block_state.batch_size * block_state.num_images_per_prompt,
            components.num_channels_latents,
@@ -878,8 +870,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
             timestep=block_state.latent_timestep,
             is_strength_max=block_state.is_strength_max,
             add_noise=block_state.add_noise,
-            return_noise=True,
-            return_image_latents=False,
         )
 
         # 7. Prepare mask latent variables

From 5f560d05a2242903c121955a05812efccd8c772e Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Wed, 16 Jul 2025 11:58:23 +0200
Subject: [PATCH 07/16] up

---
 .../test_stable_diffusion_xl_modular.py      | 497 ++++++++++++++----
 .../test_modular_pipelines_common.py         | 312 +----------
 2 files changed, 400 insertions(+), 409 deletions(-)

diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py
index 0d9c3d0f20cc..3c9ace762b7e 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py
@@ -13,34 +13,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import random
 import tempfile
 import unittest
+from typing import Any, Dict
 
 import numpy as np
 import torch
+from PIL import Image
 
 from diffusers import (
+    ClassifierFreeGuidance,
     ComponentsManager,
-    LCMScheduler,
     ModularPipeline,
     StableDiffusionXLAutoBlocks,
     StableDiffusionXLModularPipeline,
 )
+from diffusers.loaders import ModularIPAdapterMixin
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
+    floats_tensor,
     require_torch_accelerator,
     torch_device,
 )
 
+from ...models.unets.test_models_unet_2d_condition import (
+    create_ip_adapter_state_dict,
+)
 from ..pipeline_params import (
     IMAGE_INPAINTING_BATCH_PARAMS,
     IMAGE_INPAINTING_PARAMS,
     TEXT_TO_IMAGE_BATCH_PARAMS,
-    TEXT_TO_IMAGE_IMAGE_PARAMS,
     TEXT_TO_IMAGE_PARAMS,
 )
 from ..test_modular_pipelines_common import (
-    ModularIPAdapterTesterMixin,
     ModularPipelineTesterMixin,
 )
 
@@ -48,11 +54,11 @@
 enable_full_determinism()
 
 
-class StableDiffusionXLModularPipelineFastTests(
-    ModularIPAdapterTesterMixin,
-    ModularPipelineTesterMixin,
-    unittest.TestCase,
-):
+class SDXLModularTests:
+    """
+    This mixin defines methods to create the pipeline, the base inputs, and the base tests shared across all SDXL modular tests.
+ """ + pipeline_class = StableDiffusionXLModularPipeline pipeline_blocks_class = StableDiffusionXLAutoBlocks repo = "hf-internal-testing/tiny-sdxl-modular" @@ -62,7 +68,6 @@ class StableDiffusionXLModularPipelineFastTests( "negative_prompt_embeds", } batch_params = TEXT_TO_IMAGE_BATCH_PARAMS | IMAGE_INPAINTING_BATCH_PARAMS - image_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager) @@ -82,7 +87,7 @@ def get_dummy_inputs(self, device, seed=0): } return inputs - def test_stable_diffusion_xl_euler(self): + def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2): device = "cpu" # ensure determinism for the device-dependent torch.Generator sd_pipe = self.get_pipeline() sd_pipe = sd_pipe.to(device) @@ -92,126 +97,339 @@ def test_stable_diffusion_xl_euler(self): image = sd_pipe(**inputs, output="images") image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 64, 64, 3) - expected_slice = np.array( - [0.5966781, 0.62939394, 0.48465094, 0.51573336, 0.57593524, 0.47035995, 0.53410417, 0.51436996, 0.47313565] + assert image.shape == expected_image_shape + + assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, ( + f"image_slice: {image_slice.flatten()}, expected_slice: {expected_slice}" ) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( - f"image_slice: {image_slice.flatten()}, expected_slice: {expected_slice.flatten()}" + +class SDXLModularIPAdapterTests: + """ + This mixin is designed to test IP Adapter. + """ + + def test_pipeline_inputs_and_blocks(self): + blocks = self.pipeline_blocks_class() + parameters = blocks.input_names + + assert issubclass(self.pipeline_class, ModularIPAdapterMixin) + self.assertIn( + "ip_adapter_image", + parameters, + "`ip_adapter_image` argument must be supported by the `__call__` method", + ) + self.assertIn( + "ip_adapter", + blocks.sub_blocks, + "pipeline must contain an IPAdapter block", ) - def test_stable_diffusion_xl_euler_lcm(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - sd_pipe = self.get_pipeline() - sd_pipe.update_components(scheduler=LCMScheduler.from_config(sd_pipe.scheduler.config)) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) + _ = blocks.sub_blocks.pop("ip_adapter") + parameters = blocks.input_names + intermediate_parameters = blocks.intermediate_input_names + self.assertNotIn( + "ip_adapter_image", + parameters, + "`ip_adapter_image` argument must be removed from the `__call__` method", + ) + self.assertNotIn( + "ip_adapter_image_embeds", + intermediate_parameters, + "`ip_adapter_image_embeds` argument must be supported by the `__call__` method", + ) - inputs = self.get_dummy_inputs(device) - image = sd_pipe(**inputs, output="images") - image_slice = image[0, -3:, -3:, -1] + def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): + return torch.randn((1, 1, cross_attention_dim), device=torch_device) + + def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32): + return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device) + + def _get_dummy_masks(self, input_size: int = 64): + _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device) + _masks[0, :, :, : int(input_size / 2)] = 1 + return _masks - assert image.shape == (1, 64, 64, 3) - expected_slice = 
np.array( - [0.6880376, 0.6511651, 0.587455, 0.61763, 0.55432945, 0.52064973, 0.5783733, 0.54915607, 0.5460011] + def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]): + blocks = self.pipeline_blocks_class() + _ = blocks.sub_blocks.pop("ip_adapter") + parameters = blocks.input_names + if "image" in parameters and "strength" in parameters: + inputs["num_inference_steps"] = 4 + + inputs["output_type"] = "np" + return inputs + + def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None): + r"""Tests for IP-Adapter. + + The following scenarios are tested: + - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter. + - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter. + - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter. + - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter. + """ + # Raising the tolerance for this test when it's run on a CPU because we + # compare against static slices and that can be shaky (with a VVVV low probability). + expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff + + blocks = self.pipeline_blocks_class() + _ = blocks.sub_blocks.pop("ip_adapter") + pipe = blocks.init_pipeline(self.repo) + pipe.load_default_components(torch_dtype=torch.float32) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + cross_attention_dim = pipe.unet.config.get("cross_attention_dim") + + # forward pass without ip adapter + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + if expected_pipe_slice is None: + output_without_adapter = pipe(**inputs, output="images") + else: + output_without_adapter = expected_pipe_slice + + # 1. 
Single IP-Adapter test cases + adapter_state_dict = create_ip_adapter_state_dict(pipe.unet) + pipe.unet._load_ip_adapter_weights(adapter_state_dict) + + # forward pass with single ip adapter, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + pipe.set_ip_adapter_scale(0.0) + output_without_adapter_scale = pipe(**inputs, output="images") + if expected_pipe_slice is not None: + output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with single ip adapter, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + pipe.set_ip_adapter_scale(42.0) + output_with_adapter_scale = pipe(**inputs, output="images") + if expected_pipe_slice is not None: + output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() + max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() + + self.assertLess( + max_diff_without_adapter_scale, + expected_max_diff, + "Output without ip-adapter must be same as normal inference", + ) + self.assertGreater( + max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference" ) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, ( - f"image_slice: {image_slice.flatten()}, expected_slice: {expected_slice.flatten()}" + # 2. 
Multi IP-Adapter test cases + adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet) + adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet) + pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2]) + + # forward pass with multi ip adapter, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 + pipe.set_ip_adapter_scale([0.0, 0.0]) + output_without_multi_adapter_scale = pipe(**inputs, output="images") + if expected_pipe_slice is not None: + output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with multi ip adapter, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 + pipe.set_ip_adapter_scale([42.0, 42.0]) + output_with_multi_adapter_scale = pipe(**inputs, output="images") + if expected_pipe_slice is not None: + output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_multi_adapter_scale = np.abs( + output_without_multi_adapter_scale - output_without_adapter + ).max() + max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max() + self.assertLess( + max_diff_without_multi_adapter_scale, + expected_max_diff, + "Output without multi-ip-adapter must be same as normal inference", + ) + self.assertGreater( + max_diff_with_multi_adapter_scale, + 1e-2, + "Output with multi-ip-adapter scale must be different from normal inference", ) - @require_torch_accelerator - def test_stable_diffusion_xl_offloads(self): - pipes = [] - sd_pipe = self.get_pipeline().to(torch_device) - pipes.append(sd_pipe) - cm = ComponentsManager() - cm.enable_auto_cpu_offload(device=torch_device) - sd_pipe = self.get_pipeline(components_manager=cm) - pipes.append(sd_pipe) +class SDXLModularControlNetTests: + """ + This mixin is designed to test ControlNet. + """ - image_slices = [] - for pipe in pipes: - inputs = self.get_dummy_inputs(torch_device) - image = pipe(**inputs, output="images") + def test_pipeline_inputs(self): + blocks = self.pipeline_blocks_class() + parameters = blocks.input_names - image_slices.append(image[0, -3:, -3:, -1].flatten()) + self.assertIn( + "control_image", + parameters, + "`control_image` argument must be supported by the `__call__` method", + ) + self.assertIn( + "controlnet_conditioning_scale", + parameters, + "`controlnet_conditioning_scale` argument must be supported by the `__call__` method", + ) - assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]): + controlnet_embedder_scale_factor = 2 + image = torch.randn( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + device=torch_device, + ) + inputs["control_image"] = image + return inputs - def test_stable_diffusion_xl_multi_prompts(self): - sd_pipe = self.get_pipeline().to(torch_device) + def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None): + r"""Tests for ControlNet. 
- # forward with single prompt - inputs = self.get_dummy_inputs(torch_device) - output = sd_pipe(**inputs, output="images") - image_slice_1 = output[0, -3:, -3:, -1] + The following scenarios are tested: + - Single ControlNet with scale=0 should produce same output as no ControlNet. + - Single ControlNet with scale!=0 should produce different output compared to no ControlNet. + """ + # Raising the tolerance for this test when it's run on a CPU because we + # compare against static slices and that can be shaky (with a VVVV low probability). + expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff + + pipe = self.get_pipeline() + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) - # forward with same prompt duplicated + # forward pass without controlnet inputs = self.get_dummy_inputs(torch_device) - inputs["prompt_2"] = inputs["prompt"] - output = sd_pipe(**inputs, output="images") - image_slice_2 = output[0, -3:, -3:, -1] + output_without_controlnet = pipe(**inputs, output="images") + output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten() + + # forward pass with single controlnet, but scale=0 which should have no effect + inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device)) + inputs["controlnet_conditioning_scale"] = 0.0 + output_without_controlnet_scale = pipe(**inputs, output="images") + output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten() + + # forward pass with single controlnet, but with scale of adapter weights + inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device)) + inputs["controlnet_conditioning_scale"] = 42.0 + output_with_controlnet_scale = pipe(**inputs, output="images") + output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max() + max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max() + + self.assertLess( + max_diff_without_controlnet_scale, + expected_max_diff, + "Output without controlnet must be same as normal inference", + ) + self.assertGreater( + max_diff_with_controlnet_scale, 1e-2, "Output with controlnet must be different from normal inference" + ) - # ensure the results are equal - assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + def test_controlnet_cfg(self): + pipe = self.get_pipeline() + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) - # forward with different prompt - inputs = self.get_dummy_inputs(torch_device) - inputs["prompt_2"] = "different prompt" - output = sd_pipe(**inputs, output="images") - image_slice_3 = output[0, -3:, -3:, -1] + # forward pass with CFG not applied + guider = ClassifierFreeGuidance(guidance_scale=1.0) + pipe.update_components(guider=guider) - # ensure the results are not equal - assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device)) + out_no_cfg = pipe(**inputs, output="images") - # manually set a negative_prompt - inputs = self.get_dummy_inputs(torch_device) - inputs["negative_prompt"] = "negative prompt" - output = sd_pipe(**inputs, output="images") - image_slice_1 = output[0, -3:, -3:, -1] + # forward pass with CFG applied + guider = ClassifierFreeGuidance(guidance_scale=7.5) + 
pipe.update_components(guider=guider) + inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device)) + out_cfg = pipe(**inputs, output="images") - # forward with same negative_prompt duplicated - inputs = self.get_dummy_inputs(torch_device) - inputs["negative_prompt"] = "negative prompt" - inputs["negative_prompt_2"] = inputs["negative_prompt"] - output = sd_pipe(**inputs, output="images") - image_slice_2 = output[0, -3:, -3:, -1] + assert out_cfg.shape == out_no_cfg.shape + max_diff = np.abs(out_cfg - out_no_cfg).max() + self.assertGreater(max_diff, 1e-2) + + +class SDXLModularGuiderTests: + def test_guider_cfg(self): + pipe = self.get_pipeline() + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # forward pass with CFG not applied + guider = ClassifierFreeGuidance(guidance_scale=1.0) + pipe.update_components(guider=guider) - # ensure the results are equal - assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + inputs = self.get_dummy_inputs(torch_device) + out_no_cfg = pipe(**inputs, output="images") - # forward with different negative_prompt + # forward pass with CFG applied + guider = ClassifierFreeGuidance(guidance_scale=7.5) + pipe.update_components(guider=guider) inputs = self.get_dummy_inputs(torch_device) - inputs["negative_prompt"] = "negative prompt" - inputs["negative_prompt_2"] = "different negative prompt" - output = sd_pipe(**inputs, output="images") - image_slice_3 = output[0, -3:, -3:, -1] + out_cfg = pipe(**inputs, output="images") - # ensure the results are not equal - assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + assert out_cfg.shape == out_no_cfg.shape + max_diff = np.abs(out_cfg - out_no_cfg).max() + self.assertGreater(max_diff, 1e-2) - def test_stable_diffusion_xl_negative_conditions(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - sd_pipe = self.get_pipeline().to(device) - sd_pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_inputs(device) - image = sd_pipe(**inputs, output="images") - image_slice_with_no_neg_cond = image[0, -3:, -3:, -1] +class SDXLModularPipelineFastTests( + SDXLModularTests, + SDXLModularIPAdapterTests, + SDXLModularControlNetTests, + SDXLModularGuiderTests, + ModularPipelineTesterMixin, + unittest.TestCase, +): + """Test cases for Stable Diffusion XL modular pipeline fast tests.""" - image = sd_pipe( - **inputs, - negative_original_size=(512, 512), - negative_crops_coords_top_left=(0, 0), - negative_target_size=(1024, 1024), - output="images", + def test_stable_diffusion_xl_euler(self): + self._test_stable_diffusion_xl_euler( + expected_image_shape=(1, 64, 64, 3), + expected_slice=[ + 0.5966781, + 0.62939394, + 0.48465094, + 0.51573336, + 0.57593524, + 0.47035995, + 0.53410417, + 0.51436996, + 0.47313565, + ], + expected_max_diff=1e-2, ) - image_slice_with_neg_cond = image[0, -3:, -3:, -1] - self.assertTrue(np.abs(image_slice_with_no_neg_cond - image_slice_with_neg_cond).max() > 1e-2) + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + + @require_torch_accelerator + def test_stable_diffusion_xl_offloads(self): + pipes = [] + sd_pipe = self.get_pipeline().to(torch_device) + pipes.append(sd_pipe) + + cm = ComponentsManager() + cm.enable_auto_cpu_offload(device=torch_device) + sd_pipe = self.get_pipeline(components_manager=cm) + pipes.append(sd_pipe) + + image_slices = [] + for pipe in pipes: + inputs = 
self.get_dummy_inputs(torch_device) + image = pipe(**inputs, output="images") + + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 def test_stable_diffusion_xl_save_from_pretrained(self): pipes = [] @@ -234,5 +452,88 @@ def test_stable_diffusion_xl_save_from_pretrained(self): assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + +class SDXLImg2ImgModularPipelineFastTests( + SDXLModularTests, + SDXLModularIPAdapterTests, + SDXLModularControlNetTests, + SDXLModularGuiderTests, + ModularPipelineTesterMixin, + unittest.TestCase, +): + """Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests.""" + + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed) + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) + image = image / 2 + 0.5 + inputs["image"] = image + inputs["strength"] = 0.8 + + return inputs + + def test_stable_diffusion_xl_euler(self): + self._test_stable_diffusion_xl_euler( + expected_image_shape=(1, 64, 64, 3), + expected_slice=[ + 0.56943184, + 0.4702148, + 0.48048905, + 0.6235963, + 0.551138, + 0.49629188, + 0.60031277, + 0.5688907, + 0.43996853, + ], + expected_max_diff=1e-2, + ) + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + + +class SDXLInpaintingModularPipelineFastTests( + SDXLModularTests, + SDXLModularIPAdapterTests, + SDXLModularControlNetTests, + SDXLModularGuiderTests, + ModularPipelineTesterMixin, + unittest.TestCase, +): + """Test cases for Stable Diffusion XL inpainting modular pipeline fast tests.""" + + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed) + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + # create mask + image[8:, 8:, :] = 255 + mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64)) + + inputs["image"] = init_image + inputs["mask_image"] = mask_image + inputs["strength"] = 1.0 + + return inputs + + def test_stable_diffusion_xl_euler(self): + self._test_stable_diffusion_xl_euler( + expected_image_shape=(1, 64, 64, 3), + expected_slice=[ + 0.40872607, + 0.38842705, + 0.34893104, + 0.47837183, + 0.43792963, + 0.5332134, + 0.3716843, + 0.47274873, + 0.45000193, + ], + expected_max_diff=1e-2, + ) + def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) diff --git a/tests/pipelines/test_modular_pipelines_common.py b/tests/pipelines/test_modular_pipelines_common.py index 9ce11ccbde82..f5f08ac6a72e 100644 --- a/tests/pipelines/test_modular_pipelines_common.py +++ b/tests/pipelines/test_modular_pipelines_common.py @@ -1,16 +1,14 @@ import gc import unittest -from typing import Any, Callable, Dict, Union +from typing import Callable, Union import numpy as np import torch import diffusers from diffusers import ( - ClassifierFreeGuidance, DiffusionPipeline, ) -from diffusers.loaders import ModularIPAdapterMixin from diffusers.utils import logging from diffusers.utils.testing_utils import ( backend_empty_cache, @@ -20,11 +18,6 @@ torch_device, ) -from ..models.unets.test_models_unet_2d_condition import ( - create_ip_adapter_faceid_state_dict, - create_ip_adapter_state_dict, -) - def to_np(tensor): if isinstance(tensor, torch.Tensor): @@ -33,291 
+26,6 @@ def to_np(tensor): return tensor -class ModularIPAdapterTesterMixin: - """ - This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. - It provides a set of common tests for pipelines that support IP Adapters. - """ - - def test_pipeline_inputs_and_blocks(self): - blocks = self.pipeline_blocks_class() - parameters = blocks.input_names - - assert issubclass(self.pipeline_class, ModularIPAdapterMixin) - self.assertIn( - "ip_adapter_image", - parameters, - "`ip_adapter_image` argument must be supported by the `__call__` method", - ) - self.assertIn( - "ip_adapter", - blocks.sub_blocks, - "pipeline must contain an IPAdapter block", - ) - - _ = blocks.sub_blocks.pop("ip_adapter") - parameters = blocks.input_names - intermediate_parameters = blocks.intermediate_input_names - self.assertNotIn( - "ip_adapter_image", - parameters, - "`ip_adapter_image` argument must be removed from the `__call__` method", - ) - self.assertNotIn( - "ip_adapter_image_embeds", - intermediate_parameters, - "`ip_adapter_image_embeds` argument must be supported by the `__call__` method", - ) - - def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): - return torch.randn((1, 1, cross_attention_dim), device=torch_device) - - def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32): - return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device) - - def _get_dummy_masks(self, input_size: int = 64): - _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device) - _masks[0, :, :, : int(input_size / 2)] = 1 - return _masks - - def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]): - blocks = self.pipeline_blocks_class() - _ = blocks.sub_blocks.pop("ip_adapter") - parameters = blocks.input_names - if "image" in parameters and "strength" in parameters: - inputs["num_inference_steps"] = 4 - - inputs["output_type"] = "np" - return inputs - - def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None): - r"""Tests for IP-Adapter. - - The following scenarios are tested: - - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter. - - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter. - - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter. - - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter. - """ - # Raising the tolerance for this test when it's run on a CPU because we - # compare against static slices and that can be shaky (with a VVVV low probability). - expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff - - blocks = self.pipeline_blocks_class() - _ = blocks.sub_blocks.pop("ip_adapter") - pipe = blocks.init_pipeline(self.repo) - pipe.load_default_components(torch_dtype=torch.float32) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - cross_attention_dim = pipe.unet.config.get("cross_attention_dim") - - # forward pass without ip adapter - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - if expected_pipe_slice is None: - output_without_adapter = pipe(**inputs, output="images") - else: - output_without_adapter = expected_pipe_slice - - # 1. 
Single IP-Adapter test cases - adapter_state_dict = create_ip_adapter_state_dict(pipe.unet) - pipe.unet._load_ip_adapter_weights(adapter_state_dict) - - # forward pass with single ip adapter, but scale=0 which should have no effect - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] - inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] - pipe.set_ip_adapter_scale(0.0) - output_without_adapter_scale = pipe(**inputs, output="images") - if expected_pipe_slice is not None: - output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() - - # forward pass with single ip adapter, but with scale of adapter weights - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] - inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] - pipe.set_ip_adapter_scale(42.0) - output_with_adapter_scale = pipe(**inputs, output="images") - if expected_pipe_slice is not None: - output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() - - max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() - max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() - - self.assertLess( - max_diff_without_adapter_scale, - expected_max_diff, - "Output without ip-adapter must be same as normal inference", - ) - self.assertGreater( - max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference" - ) - - # 2. Multi IP-Adapter test cases - adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet) - adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet) - pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2]) - - # forward pass with multi ip adapter, but scale=0 which should have no effect - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 - inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 - pipe.set_ip_adapter_scale([0.0, 0.0]) - output_without_multi_adapter_scale = pipe(**inputs, output="images") - if expected_pipe_slice is not None: - output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten() - - # forward pass with multi ip adapter, but with scale of adapter weights - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 - inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 - pipe.set_ip_adapter_scale([42.0, 42.0]) - output_with_multi_adapter_scale = pipe(**inputs, output="images") - if expected_pipe_slice is not None: - output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten() - - max_diff_without_multi_adapter_scale = np.abs( - output_without_multi_adapter_scale - output_without_adapter - ).max() - max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max() - self.assertLess( - max_diff_without_multi_adapter_scale, - expected_max_diff, - "Output without multi-ip-adapter must 
be same as normal inference", - ) - self.assertGreater( - max_diff_with_multi_adapter_scale, - 1e-2, - "Output with multi-ip-adapter scale must be different from normal inference", - ) - - def test_ip_adapter_cfg(self, expected_max_diff: float = 1e-4): - blocks = self.pipeline_blocks_class() - _ = blocks.sub_blocks.pop("ip_adapter") - pipe = blocks.init_pipeline(self.repo) - pipe.load_default_components(torch_dtype=torch.float32) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32) - - adapter_state_dict = create_ip_adapter_state_dict(pipe.unet) - pipe.unet._load_ip_adapter_weights(adapter_state_dict) - pipe.set_ip_adapter_scale(1.0) - - # forward pass with CFG not applied - guider = ClassifierFreeGuidance(guidance_scale=1.0) - pipe.update_components(guider=guider) - - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)[0].unsqueeze(0)] - out_no_cfg = pipe(**inputs, output="images") - - # forward pass with CFG applied - guider = ClassifierFreeGuidance(guidance_scale=7.5) - pipe.update_components(guider=guider) - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] - inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] - out_cfg = pipe(**inputs, output="images") - - assert out_cfg.shape == out_no_cfg.shape - - def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4): - blocks = self.pipeline_blocks_class() - _ = blocks.sub_blocks.pop("ip_adapter") - pipe = blocks.init_pipeline(self.repo) - pipe.load_default_components(torch_dtype=torch.float32) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32) - sample_size = pipe.unet.config.get("sample_size", 32) - block_out_channels = pipe.vae.config.get("block_out_channels", [128, 256, 512, 512]) - input_size = sample_size * (2 ** (len(block_out_channels) - 1)) - - # forward pass without ip adapter - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - output_without_adapter = pipe(**inputs, output="images") - output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten() - - adapter_state_dict = create_ip_adapter_state_dict(pipe.unet) - pipe.unet._load_ip_adapter_weights(adapter_state_dict) - - # forward pass with single ip adapter and masks, but scale=0 which should have no effect - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] - inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] - inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]} - pipe.set_ip_adapter_scale(0.0) - output_without_adapter_scale = pipe(**inputs, output="images") - output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() - - # forward pass with single ip adapter and masks, but with scale of adapter weights - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] - inputs["negative_ip_adapter_embeds"] = 
[self._get_dummy_image_embeds(cross_attention_dim)] - inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]} - pipe.set_ip_adapter_scale(42.0) - output_with_adapter_scale = pipe(**inputs, output="images") - output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() - - max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() - max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() - - self.assertLess( - max_diff_without_adapter_scale, - expected_max_diff, - "Output without ip-adapter must be same as normal inference", - ) - self.assertGreater( - max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference" - ) - - def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4): - blocks = self.pipeline_blocks_class() - _ = blocks.sub_blocks.pop("ip_adapter") - pipe = blocks.init_pipeline(self.repo) - pipe.load_default_components(torch_dtype=torch.float32) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32) - - # forward pass without ip adapter - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - output_without_adapter = pipe(**inputs, output="images") - output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten() - - adapter_state_dict = create_ip_adapter_faceid_state_dict(pipe.unet) - pipe.unet._load_ip_adapter_weights(adapter_state_dict) - - # forward pass with single ip adapter, but scale=0 which should have no effect - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - inputs["ip_adapter_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)] - inputs["negative_ip_adapter_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)] - pipe.set_ip_adapter_scale(0.0) - output_without_adapter_scale = pipe(**inputs, output="images") - output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() - - # forward pass with single ip adapter, but with scale of adapter weights - inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) - inputs["ip_adapter_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)] - inputs["negative_ip_adapter_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)] - pipe.set_ip_adapter_scale(42.0) - output_with_adapter_scale = pipe(**inputs, output="images") - output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() - - max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() - max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() - - self.assertLess( - max_diff_without_adapter_scale, - expected_max_diff, - "Output without ip-adapter must be same as normal inference", - ) - self.assertGreater( - max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference" - ) - - @require_torch class ModularPipelineTesterMixin: """ @@ -654,21 +362,3 @@ def test_num_images_per_prompt(self): images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images") assert images.shape[0] == batch_size * num_images_per_prompt - - def test_cfg(self): - pipe = self.get_pipeline() - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - guider = 
ClassifierFreeGuidance(guidance_scale=1.0) - pipe.update_components(guider=guider) - - inputs = self.get_dummy_inputs(torch_device) - out_no_cfg = pipe(**inputs, output="images") - - guider = ClassifierFreeGuidance(guidance_scale=7.5) - pipe.update_components(guider=guider) - - out_cfg = pipe(**inputs, output="images") - - assert out_cfg.shape == out_no_cfg.shape From 0998bd75ad4c39a069a881dd6de28c357d6f21f3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 16 Jul 2025 12:02:58 +0200 Subject: [PATCH 08/16] up --- tests/pipelines/test_modular_pipelines_common.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/test_modular_pipelines_common.py b/tests/pipelines/test_modular_pipelines_common.py index f5f08ac6a72e..4bd45e207b91 100644 --- a/tests/pipelines/test_modular_pipelines_common.py +++ b/tests/pipelines/test_modular_pipelines_common.py @@ -30,8 +30,13 @@ def to_np(tensor): class ModularPipelineTesterMixin: """ This mixin is designed to be used with unittest.TestCase classes. - It provides a set of common tests for each PyTorch pipeline, e.g. saving and loading the pipeline, - equivalence of dict and tuple outputs, etc. + It provides a set of common tests for each modular pipeline, + including: + - test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters + - test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs + - test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input + - test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs + - test_to_device: check if the pipeline's __call__ method can handle different devices """ # Canonical parameters that are passed to `__call__` regardless @@ -45,7 +50,7 @@ class ModularPipelineTesterMixin: "output_type", ] ) - # generator needs to be a intermediate input because it's mutable + # this is modular specific: generator needs to be a intermediate input because it's mutable required_intermediate_params = frozenset( [ "generator", From 625cc8ede873a56ce9b21f017bf00a30728315b6 Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 17 Jul 2025 07:14:35 +0530 Subject: [PATCH 09/16] update --- tests/modular/__init__.py | 0 .../modular/test_modular_pipelines_common.py | 359 ++++++++++++++++++ 2 files changed, 359 insertions(+) create mode 100644 tests/modular/__init__.py create mode 100644 tests/modular/test_modular_pipelines_common.py diff --git a/tests/modular/__init__.py b/tests/modular/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular/test_modular_pipelines_common.py b/tests/modular/test_modular_pipelines_common.py new file mode 100644 index 000000000000..2c176d746d69 --- /dev/null +++ b/tests/modular/test_modular_pipelines_common.py @@ -0,0 +1,359 @@ +import gc +import unittest +from typing import Callable, Union + +import numpy as np +import torch + +import diffusers +from diffusers import ( + DiffusionPipeline, +) +from diffusers.utils import logging +from diffusers.utils.testing_utils import ( + backend_empty_cache, + numpy_cosine_similarity_distance, + require_accelerator, + require_torch, + torch_device, +) + + +def to_np(tensor): + if isinstance(tensor, torch.Tensor): + tensor = tensor.detach().cpu().numpy() + + return tensor + + +@require_torch +class ModularPipelineTesterMixin: + """ + This mixin is designed to be used with unittest.TestCase classes. 
+    It provides a set of common tests for each modular pipeline,
+    including:
+    - test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
+    - test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
+    - test_inference_batch_single_identical: check if batched and single inputs produce identical outputs
+    - test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
+    - test_to_device: check if the pipeline's __call__ method can handle different devices
+    """
+
+    # Canonical parameters that are passed to `__call__` regardless
+    # of the type of pipeline. They are always optional and have common
+    # sense default values.
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "num_images_per_prompt",
+            "latents",
+            "output_type",
+        ]
+    )
+    # this is modular specific: generator needs to be an intermediate input because it's mutable
+    required_intermediate_params = frozenset(
+        [
+            "generator",
+        ]
+    )
+
+    def get_generator(self, seed):
+        device = torch_device if torch_device != "mps" else "cpu"
+        generator = torch.Generator(device).manual_seed(seed)
+        return generator
+
+    @property
+    def pipeline_class(self) -> Union[Callable, DiffusionPipeline]:
+        raise NotImplementedError(
+            "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
+            "See existing pipeline tests for reference."
+        )
+
+    @property
+    def repo(self) -> str:
+        raise NotImplementedError(
+            "You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
+        )
+
+    @property
+    def pipeline_blocks_class(self) -> Union[Callable, DiffusionPipeline]:
+        raise NotImplementedError(
+            "You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
+            "See existing pipeline tests for reference."
+        )
+
+    def get_pipeline(self):
+        raise NotImplementedError(
+            "You need to implement `get_pipeline(self)` in the child test class. "
+            "See existing pipeline tests for reference."
+        )
+
+    def get_dummy_inputs(self, device, seed=0):
+        raise NotImplementedError(
+            "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
+            "See existing pipeline tests for reference."
+        )
+
+    @property
+    def params(self) -> frozenset:
+        raise NotImplementedError(
+            "You need to set the attribute `params` in the child test class. "
+            "`params` is checked to ensure all of its values are present in `__call__`'s signature. "
+            "You can set `params` using one of the common sets of parameters defined in `pipeline_params.py`, "
+            "e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
+            "image pipelines, including prompts and prompt embedding overrides. "
+            "If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
+            "do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
+            "with non-configurable height and width arguments should set the attribute as "
+            "`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
+            "See existing pipeline tests for reference."
+        )
+
+    @property
+    def batch_params(self) -> frozenset:
+        raise NotImplementedError(
+            "You need to set the attribute `batch_params` in the child test class. "
+            "`batch_params` are the parameters required to be batched when passed to the pipeline's "
+            "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
+            "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
+            "set of batch arguments has minor changes from one of the common sets of batch arguments, "
+            "do not make modifications to the existing common sets of batch arguments. I.e. a text to "
+            "image pipeline whose `negative_prompt` is not batched should set the attribute as "
+            "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
+            "See existing pipeline tests for reference."
+        )
+
+    def setUp(self):
+        # clean up the VRAM before each test
+        super().setUp()
+        torch.compiler.reset()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        # clean up the VRAM after each test in case of CUDA runtime errors
+        super().tearDown()
+        torch.compiler.reset()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_pipeline_call_signature(self):
+        pipe = self.get_pipeline()
+        parameters = pipe.blocks.input_names
+        optional_parameters = pipe.default_call_parameters
+        intermediate_parameters = pipe.blocks.intermediate_input_names
+
+        remaining_required_parameters = set()
+
+        for param in self.params:
+            if param not in parameters:
+                remaining_required_parameters.add(param)
+
+        self.assertTrue(
+            len(remaining_required_parameters) == 0,
+            f"Required parameters not present: {remaining_required_parameters}",
+        )
+
+        remaining_required_intermediate_parameters = set()
+
+        for param in self.required_intermediate_params:
+            if param not in intermediate_parameters:
+                remaining_required_intermediate_parameters.add(param)
+
+        self.assertTrue(
+            len(remaining_required_intermediate_parameters) == 0,
+            f"Required intermediate parameters not present: {remaining_required_intermediate_parameters}",
+        )
+
+        remaining_required_optional_parameters = set()
+
+        for param in self.required_optional_params:
+            if param not in optional_parameters:
+                remaining_required_optional_parameters.add(param)
+
+        self.assertTrue(
+            len(remaining_required_optional_parameters) == 0,
+            f"Required optional parameters not present: {remaining_required_optional_parameters}",
+        )
+
+    def test_inference_batch_consistent(self, batch_sizes=[2]):
+        self._test_inference_batch_consistent(batch_sizes=batch_sizes)
+
+    def _test_inference_batch_consistent(
+        self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"], batch_generator=True
+    ):
+        pipe = self.get_pipeline()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["generator"] = self.get_generator(0)
+
+        logger = logging.get_logger(pipe.__module__)
+        logger.setLevel(level=diffusers.logging.FATAL)
+
+        # prepare batched inputs
+        batched_inputs = []
+        for batch_size in batch_sizes:
+            batched_input = {}
+            batched_input.update(inputs)
+
+            for name in self.batch_params:
+                if name not in inputs:
+                    continue
+
+                value = inputs[name]
+                if name == "prompt":
+                    len_prompt = len(value)
+                    # make unequal batch sizes
+                    batched_input[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+
+                    # make last batch super long
+                    batched_input[name][-1] = 100 * "very long"
+
+                else:
+                    batched_input[name] = batch_size * [value]
+
+            if batch_generator and "generator" in inputs:
+                batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+            if "batch_size" in inputs:
+                batched_input["batch_size"] = batch_size
+
+            batched_inputs.append(batched_input)
+
+        logger.setLevel(level=diffusers.logging.WARNING)
+        for batch_size, batched_input in zip(batch_sizes, batched_inputs):
+            output = pipe(**batched_input, output="images")
+            assert len(output) == batch_size
+
+    def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=1e-4):
+        self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff)
+
+    def _test_inference_batch_single_identical(
+        self,
+        batch_size=2,
+        expected_max_diff=1e-4,
+        additional_params_copy_to_batched_inputs=["num_inference_steps"],
+    ):
+        pipe = self.get_pipeline()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        inputs = self.get_dummy_inputs(torch_device)
+        # Reset generator in case it has been used in self.get_dummy_inputs
+        inputs["generator"] = self.get_generator(0)
+
+        logger = logging.get_logger(pipe.__module__)
+        logger.setLevel(level=diffusers.logging.FATAL)
+
+        # batchify inputs
+        batched_inputs = {}
+        batched_inputs.update(inputs)
+
+        for name in self.batch_params:
+            if name not in inputs:
+                continue
+
+            value = inputs[name]
+            if name == "prompt":
+                len_prompt = len(value)
+                batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+                batched_inputs[name][-1] = 100 * "very long"
+
+            else:
+                batched_inputs[name] = batch_size * [value]
+
+        if "generator" in inputs:
+            batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+        if "batch_size" in inputs:
+            batched_inputs["batch_size"] = batch_size
+
+        for arg in additional_params_copy_to_batched_inputs:
+            batched_inputs[arg] = inputs[arg]
+
+        output = pipe(**inputs, output="images")
+        output_batch = pipe(**batched_inputs, output="images")
+
+        assert output_batch.shape[0] == batch_size
+
+        max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
+        assert max_diff < expected_max_diff
+
+    @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+    @require_accelerator
+    def test_float16_inference(self, expected_max_diff=5e-2):
+        pipe = self.get_pipeline(torch_dtype=torch.float32)
+
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe_fp16 = self.get_pipeline(torch_dtype=torch.float16)
+        pipe_fp16.to(torch_device, torch.float16)
+        pipe_fp16.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        # Reset generator in case it is used inside dummy inputs
+        if "generator" in inputs:
+            inputs["generator"] = self.get_generator(0)
+        output = pipe(**inputs, output="images")
+
+        fp16_inputs = self.get_dummy_inputs(torch_device)
+        # Reset generator in case it is used inside dummy inputs
+        if "generator" in fp16_inputs:
+            fp16_inputs["generator"] = self.get_generator(0)
+        output_fp16 = pipe_fp16(**fp16_inputs, output="images")
+
+        if isinstance(output, torch.Tensor):
+            output = output.cpu()
+            output_fp16 = output_fp16.cpu()
+
+        max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
+        assert max_diff < expected_max_diff
+
+    @require_accelerator
+    def test_to_device(self):
+        pipe = self.get_pipeline()
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.to("cpu")
+        model_devices = [
+            component.device.type for component in pipe.components.values() if hasattr(component, "device")
+        ]
+        self.assertTrue(all(device == "cpu" for device in model_devices))
+
+        output_cpu = pipe(**self.get_dummy_inputs("cpu"), output="images")
+        self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+        pipe.to(torch_device)
+        model_devices = [
+            component.device.type for component in pipe.components.values() if hasattr(component, "device")
+        ]
+        self.assertTrue(all(device == torch_device for device in model_devices))
+
+        output_device = pipe(**self.get_dummy_inputs(torch_device), output="images")
+        self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
+
+    def test_num_images_per_prompt(self):
+        pipe = self.get_pipeline()
+
+        if "num_images_per_prompt" not in pipe.blocks.input_names:
+            return
+
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        batch_sizes = [1, 2]
+        num_images_per_prompts = [1, 2]
+
+        for batch_size in batch_sizes:
+            for num_images_per_prompt in num_images_per_prompts:
+                inputs = self.get_dummy_inputs(torch_device)
+
+                for key in inputs.keys():
+                    if key in self.batch_params:
+                        inputs[key] = batch_size * [inputs[key]]
+
+                images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
+
+                assert images.shape[0] == batch_size * num_images_per_prompt

From 80702d222d5e798d321beea59fbaee945a625615 Mon Sep 17 00:00:00 2001
From: DN6
Date: Thu, 17 Jul 2025 13:05:43 +0530
Subject: [PATCH 10/16] update

---
 .../__init__.py                               |   0
 ...t_modular_pipeline_stable_diffusion_xl.py} |  30 ++--
 .../test_modular_pipelines_common.py          | 131 ++++++------------
 3 files changed, 61 insertions(+), 100 deletions(-)
 rename tests/{modular => modular_pipelines}/__init__.py (100%)
 rename tests/{pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py => modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py} (97%)
 rename tests/{modular => modular_pipelines}/test_modular_pipelines_common.py (73%)

diff --git a/tests/modular/__init__.py b/tests/modular_pipelines/__init__.py
similarity index 100%
rename from tests/modular/__init__.py
rename to tests/modular_pipelines/__init__.py
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
similarity index 97%
rename from tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py
rename to tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
index 3c9ace762b7e..b8acca687c08 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_modular.py
+++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
@@ -40,12 +40,6 @@
 from ...models.unets.test_models_unet_2d_condition import (
     create_ip_adapter_state_dict,
 )
-from ..pipeline_params import (
-    IMAGE_INPAINTING_BATCH_PARAMS,
-    IMAGE_INPAINTING_PARAMS,
-    TEXT_TO_IMAGE_BATCH_PARAMS,
-    TEXT_TO_IMAGE_PARAMS,
-)
 from ..test_modular_pipelines_common import (
     ModularPipelineTesterMixin,
 )
@@ -62,12 +56,18 @@ class SDXLModularTests:
     pipeline_class = StableDiffusionXLModularPipeline
     pipeline_blocks_class = StableDiffusionXLAutoBlocks
     repo = "hf-internal-testing/tiny-sdxl-modular"
-    params = (TEXT_TO_IMAGE_PARAMS | IMAGE_INPAINTING_PARAMS) - {
-        "guidance_scale",
-        "prompt_embeds",
-        "negative_prompt_embeds",
-    }
-    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS | IMAGE_INPAINTING_BATCH_PARAMS
+    params = frozenset(
+        [
+            "prompt",
+            "height",
+            "width",
+            "negative_prompt",
+            "cross_attention_kwargs",
+            "image",
+            "mask_image",
+        ]
+    )
+    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
 
     def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
components_manager=components_manager) @@ -99,9 +99,9 @@ def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, assert image.shape == expected_image_shape - assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, ( - f"image_slice: {image_slice.flatten()}, expected_slice: {expected_slice}" - ) + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff + ), f"Image Slice does not match expected slice" class SDXLModularIPAdapterTests: diff --git a/tests/modular/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py similarity index 73% rename from tests/modular/test_modular_pipelines_common.py rename to tests/modular_pipelines/test_modular_pipelines_common.py index 2c176d746d69..e150a1e59e07 100644 --- a/tests/modular/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -2,13 +2,11 @@ import unittest from typing import Callable, Union +from diffusers.utils.dummy_pt_objects import ModularPipeline, ModularPipelineBlocks import numpy as np import torch import diffusers -from diffusers import ( - DiffusionPipeline, -) from diffusers.utils import logging from diffusers.utils.testing_utils import ( backend_empty_cache, @@ -42,7 +40,7 @@ class ModularPipelineTesterMixin: # Canonical parameters that are passed to `__call__` regardless # of the type of pipeline. They are always optional and have common # sense default values. - required_optional_params = frozenset( + optional_params = frozenset( [ "num_inference_steps", "num_images_per_prompt", @@ -51,7 +49,7 @@ class ModularPipelineTesterMixin: ] ) # this is modular specific: generator needs to be a intermediate input because it's mutable - required_intermediate_params = frozenset( + intermediate_params = frozenset( [ "generator", ] @@ -63,7 +61,7 @@ def get_generator(self, seed): return generator @property - def pipeline_class(self) -> Union[Callable, DiffusionPipeline]: + def pipeline_class(self) -> Union[Callable, ModularPipeline]: raise NotImplementedError( "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. " "See existing pipeline tests for reference." @@ -76,7 +74,7 @@ def repo(self) -> str: ) @property - def pipeline_blocks_class(self) -> Union[Callable, DiffusionPipeline]: + def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]: raise NotImplementedError( "You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. " "See existing pipeline tests for reference." 
@@ -139,49 +137,21 @@ def tearDown(self): def test_pipeline_call_signature(self): pipe = self.get_pipeline() - parameters = pipe.blocks.input_names - optional_parameters = pipe.default_call_parameters + input_parameters = pipe.blocks.input_names intermediate_parameters = pipe.blocks.intermediate_input_names + optional_parameters = pipe.default_call_parameters - remaining_required_parameters = set() - - for param in self.params: - if param not in parameters: - remaining_required_parameters.add(param) - - self.assertTrue( - len(remaining_required_parameters) == 0, - f"Required parameters not present: {remaining_required_parameters}", - ) - - remaining_required_intermediate_parameters = set() - - for param in self.required_intermediate_params: - if param not in intermediate_parameters: - remaining_required_intermediate_parameters.add(param) - - self.assertTrue( - len(remaining_required_intermediate_parameters) == 0, - f"Required intermediate parameters not present: {remaining_required_intermediate_parameters}", - ) - - remaining_required_optional_parameters = set() - - for param in self.required_optional_params: - if param not in optional_parameters: - remaining_required_optional_parameters.add(param) - - self.assertTrue( - len(remaining_required_optional_parameters) == 0, - f"Required optional parameters not present: {remaining_required_optional_parameters}", - ) + def _check_for_parameters(parameters, expected_parameters, param_type): + remaining_parameters = set(param for param in parameters if param not in expected_parameters) + assert ( + len(remaining_parameters) == 0 + ), f"Required {param_type} parameters not present: {remaining_parameters}" - def test_inference_batch_consistent(self, batch_sizes=[2]): - self._test_inference_batch_consistent(batch_sizes=batch_sizes) + _check_for_parameters(self.params, input_parameters, "input") + _check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate") + _check_for_parameters(self.optional_params, optional_parameters, "optional") - def _test_inference_batch_consistent( - self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"], batch_generator=True - ): + def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True): pipe = self.get_pipeline() pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -203,16 +173,7 @@ def _test_inference_batch_consistent( continue value = inputs[name] - if name == "prompt": - len_prompt = len(value) - # make unequal batch sizes - batched_input[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] - - # make last batch super long - batched_input[name][-1] = 100 * "very long" - - else: - batched_input[name] = batch_size * [value] + batched_input[name] = batch_size * [value] if batch_generator and "generator" in inputs: batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)] @@ -225,21 +186,18 @@ def _test_inference_batch_consistent( logger.setLevel(level=diffusers.logging.WARNING) for batch_size, batched_input in zip(batch_sizes, batched_inputs): output = pipe(**batched_input, output="images") - assert len(output) == batch_size - - def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=1e-4): - self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff) + assert len(output) == batch_size, "Output is different from expected batch size" - def _test_inference_batch_single_identical( + def 
test_batch_inference_identical_to_single( self, batch_size=2, expected_max_diff=1e-4, - additional_params_copy_to_batched_inputs=["num_inference_steps"], ): pipe = self.get_pipeline() pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(torch_device) + # Reset generator in case it is has been used in self.get_dummy_inputs inputs["generator"] = self.get_generator(0) @@ -255,13 +213,7 @@ def _test_inference_batch_single_identical( continue value = inputs[name] - if name == "prompt": - len_prompt = len(value) - batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] - batched_inputs[name][-1] = 100 * "very long" - - else: - batched_inputs[name] = batch_size * [value] + batched_inputs[name] = batch_size * [value] if "generator" in inputs: batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)] @@ -269,26 +221,22 @@ def _test_inference_batch_single_identical( if "batch_size" in inputs: batched_inputs["batch_size"] = batch_size - for arg in additional_params_copy_to_batched_inputs: - batched_inputs[arg] = inputs[arg] - output = pipe(**inputs, output="images") output_batch = pipe(**batched_inputs, output="images") assert output_batch.shape[0] == batch_size max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max() - assert max_diff < expected_max_diff + assert max_diff < expected_max_diff, "Batch inference results different from single inference results" @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") @require_accelerator def test_float16_inference(self, expected_max_diff=5e-2): - pipe = self.get_pipeline(torch_dtype=torch.float32) - - pipe.to(torch_device) + pipe = self.get_pipeline() + pipe.to(torch_device, torch.float32) pipe.set_progress_bar_config(disable=None) - pipe_fp16 = self.get_pipeline(torch_dtype=torch.float16) + pipe_fp16 = self.get_pipeline() pipe_fp16.to(torch_device, torch.float16) pipe_fp16.set_progress_bar_config(disable=None) @@ -309,7 +257,7 @@ def test_float16_inference(self, expected_max_diff=5e-2): output_fp16 = output_fp16.cpu() max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten()) - assert max_diff < expected_max_diff + assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference" @require_accelerator def test_to_device(self): @@ -320,19 +268,32 @@ def test_to_device(self): model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - self.assertTrue(all(device == "cpu" for device in model_devices)) - - output_cpu = pipe(**self.get_dummy_inputs("cpu"), output="images") - self.assertTrue(np.isnan(output_cpu).sum() == 0) + assert all(device == "cpu" for device in model_devices), "All pipeline components are not on CPU" pipe.to(torch_device) model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - self.assertTrue(all(device == torch_device for device in model_devices)) + assert all( + device == torch_device for device in model_devices + ), "All pipeline components are not on accelerator device" + + def test_inference_is_not_nan_cpu(self): + pipe = self.get_pipeline() + pipe.set_progress_bar_config(disable=None) + pipe.to("cpu") + + output = pipe(**self.get_dummy_inputs("cpu"), output="np") + assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN" + + @require_accelerator + def test_inferece_is_not_nan(self): + pipe = self.get_pipeline() + 
pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) - output_device = pipe(**self.get_dummy_inputs(torch_device), output="images") - self.assertTrue(np.isnan(to_np(output_device)).sum() == 0) + output = pipe(**self.get_dummy_inputs(torch_device), output="np") + assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN" def test_num_images_per_prompt(self): pipe = self.get_pipeline() From 54e17f3084b5d3b73b63c2e102ab7f796203ea73 Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 21 Jul 2025 08:56:47 +0530 Subject: [PATCH 11/16] update --- .github/workflows/pr_modular_tests.yml | 141 +++++++ .../stable_diffusion_xl/__init__.py | 0 ...st_modular_pipeline_stable_diffusion_xl.py | 90 ++--- .../test_modular_pipelines_common.py | 12 +- .../test_modular_pipelines_common.py | 369 ------------------ 5 files changed, 178 insertions(+), 434 deletions(-) create mode 100644 .github/workflows/pr_modular_tests.yml create mode 100644 tests/modular_pipelines/stable_diffusion_xl/__init__.py delete mode 100644 tests/pipelines/test_modular_pipelines_common.py diff --git a/.github/workflows/pr_modular_tests.yml b/.github/workflows/pr_modular_tests.yml new file mode 100644 index 000000000000..e01345e32524 --- /dev/null +++ b/.github/workflows/pr_modular_tests.yml @@ -0,0 +1,141 @@ +name: Fast PR tests for Modular + +on: + pull_request: + branches: [main] + paths: + - "src/diffusers/modular_pipelines/**.py" + - "src/diffusers/models/modeling_utils.py" + - "src/diffusers/models/model_loading_utils.py" + - "src/diffusers/pipelines/pipeline_utils.py" + - "src/diffusers/pipeline_loading_utils.py" + - "src/diffusers/loaders/lora_base.py" + - "src/diffusers/loaders/lora_pipeline.py" + - "src/diffusers/loaders/peft.py" + - "tests/modular_pipelines/**.py" + - ".github/**.yml" + - "utils/**.py" + - "setup.py" + push: + branches: + - ci-* + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +env: + DIFFUSERS_IS_CI: yes + HF_HUB_ENABLE_HF_TRANSFER: 1 + OMP_NUM_THREADS: 4 + MKL_NUM_THREADS: 4 + PYTEST_TIMEOUT: 60 + +jobs: + check_code_quality: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[quality] + - name: Check quality + run: make quality + - name: Check if failure + if: ${{ failure() }} + run: | + echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY + + check_repository_consistency: + needs: check_code_quality + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[quality] + - name: Check repo consistency + run: | + python utils/check_copies.py + python utils/check_dummies.py + python utils/check_support_list.py + make deps_table_check_updated + - name: Check if failure + if: ${{ failure() }} + run: | + echo "Repo consistency check failed. 
Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY + + run_fast_tests: + needs: [check_code_quality, check_repository_consistency] + strategy: + fail-fast: false + matrix: + config: + - name: Fast PyTorch Modular Pipeline CPU tests + framework: pytorch_pipelines + runner: aws-highmemory-32-plus + image: diffusers/diffusers-pytorch-cpu + report: torch_cpu_modular_pipelines + + name: ${{ matrix.config.name }} + + runs-on: + group: ${{ matrix.config.runner }} + + container: + image: ${{ matrix.config.image }} + options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ + + defaults: + run: + shell: bash + + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - name: Install dependencies + run: | + python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" + python -m uv pip install -e [quality,test] + pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps + pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps + + - name: Environment + run: | + python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" + python utils/print_env.py + + - name: Run fast PyTorch Pipeline CPU tests + if: ${{ matrix.config.framework == 'pytorch_pipelines' }} + run: | + python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" + python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \ + -s -v -k "not Flax and not Onnx" \ + --make-reports=tests_${{ matrix.config.report }} \ + tests/modular_pipelines + + - name: Failure short reports + if: ${{ failure() }} + run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt + + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v4 + with: + name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports + path: reports + + diff --git a/tests/modular_pipelines/stable_diffusion_xl/__init__.py b/tests/modular_pipelines/stable_diffusion_xl/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py index b8acca687c08..b8a9a0c9a6d6 100644 --- a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py +++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py @@ -101,7 +101,7 @@ def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, assert ( np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff - ), f"Image Slice does not match expected slice" + ), "Image Slice does not match expected slice" class SDXLModularIPAdapterTests: @@ -114,30 +114,20 @@ def test_pipeline_inputs_and_blocks(self): parameters = blocks.input_names assert issubclass(self.pipeline_class, ModularIPAdapterMixin) - self.assertIn( - "ip_adapter_image", - parameters, - "`ip_adapter_image` argument must be supported by the `__call__` method", - ) - self.assertIn( - "ip_adapter", - blocks.sub_blocks, - "pipeline must contain an IPAdapter block", - ) + assert ( + "ip_adapter_image" in parameters + ), "`ip_adapter_image` argument must be supported by the `__call__` method" + assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an 
IPAdapter block" _ = blocks.sub_blocks.pop("ip_adapter") parameters = blocks.input_names intermediate_parameters = blocks.intermediate_input_names - self.assertNotIn( - "ip_adapter_image", - parameters, - "`ip_adapter_image` argument must be removed from the `__call__` method", - ) - self.assertNotIn( - "ip_adapter_image_embeds", - intermediate_parameters, - "`ip_adapter_image_embeds` argument must be supported by the `__call__` method", - ) + assert ( + "ip_adapter_image" not in parameters + ), "`ip_adapter_image` argument must be removed from the `__call__` method" + assert ( + "ip_adapter_image_embeds" not in intermediate_parameters + ), "`ip_adapter_image_embeds` argument must be supported by the `__call__` method" def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): return torch.randn((1, 1, cross_attention_dim), device=torch_device) @@ -213,14 +203,10 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() - self.assertLess( - max_diff_without_adapter_scale, - expected_max_diff, - "Output without ip-adapter must be same as normal inference", - ) - self.assertGreater( - max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference" - ) + assert ( + max_diff_without_adapter_scale < expected_max_diff + ), "Output without ip-adapter must be same as normal inference" + assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference" # 2. Multi IP-Adapter test cases adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet) @@ -249,16 +235,12 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N output_without_multi_adapter_scale - output_without_adapter ).max() max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max() - self.assertLess( - max_diff_without_multi_adapter_scale, - expected_max_diff, - "Output without multi-ip-adapter must be same as normal inference", - ) - self.assertGreater( - max_diff_with_multi_adapter_scale, - 1e-2, - "Output with multi-ip-adapter scale must be different from normal inference", - ) + assert ( + max_diff_without_multi_adapter_scale < expected_max_diff + ), "Output without multi-ip-adapter must be same as normal inference" + assert ( + max_diff_with_multi_adapter_scale > 1e-2 + ), "Output with multi-ip-adapter scale must be different from normal inference" class SDXLModularControlNetTests: @@ -270,16 +252,10 @@ def test_pipeline_inputs(self): blocks = self.pipeline_blocks_class() parameters = blocks.input_names - self.assertIn( - "control_image", - parameters, - "`control_image` argument must be supported by the `__call__` method", - ) - self.assertIn( - "controlnet_conditioning_scale", - parameters, - "`controlnet_conditioning_scale` argument must be supported by the `__call__` method", - ) + assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method" + assert ( + "controlnet_conditioning_scale" in parameters + ), "`controlnet_conditioning_scale` argument must be supported by the `__call__` method" def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]): controlnet_embedder_scale_factor = 2 @@ -325,14 +301,10 @@ def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N 
max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max() max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max() - self.assertLess( - max_diff_without_controlnet_scale, - expected_max_diff, - "Output without controlnet must be same as normal inference", - ) - self.assertGreater( - max_diff_with_controlnet_scale, 1e-2, "Output with controlnet must be different from normal inference" - ) + assert ( + max_diff_without_controlnet_scale < expected_max_diff + ), "Output without controlnet must be same as normal inference" + assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference" def test_controlnet_cfg(self): pipe = self.get_pipeline() @@ -354,7 +326,7 @@ def test_controlnet_cfg(self): assert out_cfg.shape == out_no_cfg.shape max_diff = np.abs(out_cfg - out_no_cfg).max() - self.assertGreater(max_diff, 1e-2) + assert max_diff > 1e-2, "Output with CFG must be different from normal inference" class SDXLModularGuiderTests: @@ -378,7 +350,7 @@ def test_guider_cfg(self): assert out_cfg.shape == out_no_cfg.shape max_diff = np.abs(out_cfg - out_no_cfg).max() - self.assertGreater(max_diff, 1e-2) + assert max_diff > 1e-2, "Output with CFG must be different from normal inference" class SDXLModularPipelineFastTests( diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index e150a1e59e07..1bb9bcf2cb27 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -2,12 +2,12 @@ import unittest from typing import Callable, Union -from diffusers.utils.dummy_pt_objects import ModularPipeline, ModularPipelineBlocks import numpy as np import torch import diffusers from diffusers.utils import logging +from diffusers.utils.dummy_pt_objects import ModularPipeline, ModularPipelineBlocks from diffusers.utils.testing_utils import ( backend_empty_cache, numpy_cosine_similarity_distance, @@ -142,7 +142,7 @@ def test_pipeline_call_signature(self): optional_parameters = pipe.default_call_parameters def _check_for_parameters(parameters, expected_parameters, param_type): - remaining_parameters = set(param for param in parameters if param not in expected_parameters) + remaining_parameters = {param for param in parameters if param not in expected_parameters} assert ( len(remaining_parameters) == 0 ), f"Required {param_type} parameters not present: {remaining_parameters}" @@ -188,7 +188,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True) output = pipe(**batched_input, output="images") assert len(output) == batch_size, "Output is different from expected batch size" - def test_batch_inference_identical_to_single( + def test_inference_batch_single_identical( self, batch_size=2, expected_max_diff=1e-4, @@ -283,16 +283,16 @@ def test_inference_is_not_nan_cpu(self): pipe.set_progress_bar_config(disable=None) pipe.to("cpu") - output = pipe(**self.get_dummy_inputs("cpu"), output="np") + output = pipe(**self.get_dummy_inputs("cpu"), output="images") assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN" @require_accelerator - def test_inferece_is_not_nan(self): + def test_inference_is_not_nan(self): pipe = self.get_pipeline() pipe.set_progress_bar_config(disable=None) pipe.to(torch_device) - output = pipe(**self.get_dummy_inputs(torch_device), output="np") + output = 
pipe(**self.get_dummy_inputs(torch_device), output="images") assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN" def test_num_images_per_prompt(self): diff --git a/tests/pipelines/test_modular_pipelines_common.py b/tests/pipelines/test_modular_pipelines_common.py deleted file mode 100644 index 4bd45e207b91..000000000000 --- a/tests/pipelines/test_modular_pipelines_common.py +++ /dev/null @@ -1,369 +0,0 @@ -import gc -import unittest -from typing import Callable, Union - -import numpy as np -import torch - -import diffusers -from diffusers import ( - DiffusionPipeline, -) -from diffusers.utils import logging -from diffusers.utils.testing_utils import ( - backend_empty_cache, - numpy_cosine_similarity_distance, - require_accelerator, - require_torch, - torch_device, -) - - -def to_np(tensor): - if isinstance(tensor, torch.Tensor): - tensor = tensor.detach().cpu().numpy() - - return tensor - - -@require_torch -class ModularPipelineTesterMixin: - """ - This mixin is designed to be used with unittest.TestCase classes. - It provides a set of common tests for each modular pipeline, - including: - - test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters - - test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs - - test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input - - test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs - - test_to_device: check if the pipeline's __call__ method can handle different devices - """ - - # Canonical parameters that are passed to `__call__` regardless - # of the type of pipeline. They are always optional and have common - # sense default values. - required_optional_params = frozenset( - [ - "num_inference_steps", - "num_images_per_prompt", - "latents", - "output_type", - ] - ) - # this is modular specific: generator needs to be a intermediate input because it's mutable - required_intermediate_params = frozenset( - [ - "generator", - ] - ) - - def get_generator(self, seed): - device = torch_device if torch_device != "mps" else "cpu" - generator = torch.Generator(device).manual_seed(seed) - return generator - - @property - def pipeline_class(self) -> Union[Callable, DiffusionPipeline]: - raise NotImplementedError( - "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. " - "See existing pipeline tests for reference." - ) - - @property - def repo(self) -> str: - raise NotImplementedError( - "You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference." - ) - - @property - def pipeline_blocks_class(self) -> Union[Callable, DiffusionPipeline]: - raise NotImplementedError( - "You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. " - "See existing pipeline tests for reference." - ) - - def get_pipeline(self): - raise NotImplementedError( - "You need to implement `get_pipeline(self)` in the child test class. " - "See existing pipeline tests for reference." - ) - - def get_dummy_inputs(self, device, seed=0): - raise NotImplementedError( - "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. " - "See existing pipeline tests for reference." - ) - - @property - def params(self) -> frozenset: - raise NotImplementedError( - "You need to set the attribute `params` in the child test class. 
" - "`params` are checked for if all values are present in `__call__`'s signature." - " You can set `params` using one of the common set of parameters defined in `pipeline_params.py`" - " e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to " - "image pipelines, including prompts and prompt embedding overrides." - "If your pipeline's set of arguments has minor changes from one of the common sets of arguments, " - "do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline " - "with non-configurable height and width arguments should set the attribute as " - "`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. " - "See existing pipeline tests for reference." - ) - - @property - def batch_params(self) -> frozenset: - raise NotImplementedError( - "You need to set the attribute `batch_params` in the child test class. " - "`batch_params` are the parameters required to be batched when passed to the pipeline's " - "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as " - "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's " - "set of batch arguments has minor changes from one of the common sets of batch arguments, " - "do not make modifications to the existing common sets of batch arguments. I.e. a text to " - "image pipeline `negative_prompt` is not batched should set the attribute as " - "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. " - "See existing pipeline tests for reference." - ) - - def setUp(self): - # clean up the VRAM before each test - super().setUp() - torch.compiler.reset() - gc.collect() - backend_empty_cache(torch_device) - - def tearDown(self): - # clean up the VRAM after each test in case of CUDA runtime errors - super().tearDown() - torch.compiler.reset() - gc.collect() - backend_empty_cache(torch_device) - - def test_pipeline_call_signature(self): - pipe = self.get_pipeline() - parameters = pipe.blocks.input_names - optional_parameters = pipe.default_call_parameters - intermediate_parameters = pipe.blocks.intermediate_input_names - - remaining_required_parameters = set() - - for param in self.params: - if param not in parameters: - remaining_required_parameters.add(param) - - self.assertTrue( - len(remaining_required_parameters) == 0, - f"Required parameters not present: {remaining_required_parameters}", - ) - - remaining_required_intermediate_parameters = set() - - for param in self.required_intermediate_params: - if param not in intermediate_parameters: - remaining_required_intermediate_parameters.add(param) - - self.assertTrue( - len(remaining_required_intermediate_parameters) == 0, - f"Required intermediate parameters not present: {remaining_required_intermediate_parameters}", - ) - - remaining_required_optional_parameters = set() - - for param in self.required_optional_params: - if param not in optional_parameters: - remaining_required_optional_parameters.add(param) - - self.assertTrue( - len(remaining_required_optional_parameters) == 0, - f"Required optional parameters not present: {remaining_required_optional_parameters}", - ) - - def test_inference_batch_consistent(self, batch_sizes=[2]): - self._test_inference_batch_consistent(batch_sizes=batch_sizes) - - def _test_inference_batch_consistent( - self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"], batch_generator=True - ): - pipe = self.get_pipeline() - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - inputs 
= self.get_dummy_inputs(torch_device) - inputs["generator"] = self.get_generator(0) - - logger = logging.get_logger(pipe.__module__) - logger.setLevel(level=diffusers.logging.FATAL) - - # prepare batched inputs - batched_inputs = [] - for batch_size in batch_sizes: - batched_input = {} - batched_input.update(inputs) - - for name in self.batch_params: - if name not in inputs: - continue - - value = inputs[name] - if name == "prompt": - len_prompt = len(value) - # make unequal batch sizes - batched_input[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] - - # make last batch super long - batched_input[name][-1] = 100 * "very long" - - else: - batched_input[name] = batch_size * [value] - - if batch_generator and "generator" in inputs: - batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)] - - if "batch_size" in inputs: - batched_input["batch_size"] = batch_size - - batched_inputs.append(batched_input) - - logger.setLevel(level=diffusers.logging.WARNING) - for batch_size, batched_input in zip(batch_sizes, batched_inputs): - output = pipe(**batched_input, output="images") - assert len(output) == batch_size - - def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=1e-4): - self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff) - - def _test_inference_batch_single_identical( - self, - batch_size=2, - expected_max_diff=1e-4, - additional_params_copy_to_batched_inputs=["num_inference_steps"], - ): - pipe = self.get_pipeline() - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_inputs(torch_device) - # Reset generator in case it is has been used in self.get_dummy_inputs - inputs["generator"] = self.get_generator(0) - - logger = logging.get_logger(pipe.__module__) - logger.setLevel(level=diffusers.logging.FATAL) - - # batchify inputs - batched_inputs = {} - batched_inputs.update(inputs) - - for name in self.batch_params: - if name not in inputs: - continue - - value = inputs[name] - if name == "prompt": - len_prompt = len(value) - batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] - batched_inputs[name][-1] = 100 * "very long" - - else: - batched_inputs[name] = batch_size * [value] - - if "generator" in inputs: - batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)] - - if "batch_size" in inputs: - batched_inputs["batch_size"] = batch_size - - for arg in additional_params_copy_to_batched_inputs: - batched_inputs[arg] = inputs[arg] - - output = pipe(**inputs, output="images") - output_batch = pipe(**batched_inputs, output="images") - - assert output_batch.shape[0] == batch_size - - max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max() - assert max_diff < expected_max_diff - - @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") - @require_accelerator - def test_float16_inference(self, expected_max_diff=5e-2): - pipe = self.get_pipeline(torch_dtype=torch.float32) - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - pipe_fp16 = self.get_pipeline(torch_dtype=torch.float16) - for component in pipe_fp16.components.values(): - if 
hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe_fp16.to(torch_device, torch.float16) - pipe_fp16.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - # Reset generator in case it is used inside dummy inputs - if "generator" in inputs: - inputs["generator"] = self.get_generator(0) - output = pipe(**inputs, output="images") - - fp16_inputs = self.get_dummy_inputs(torch_device) - # Reset generator in case it is used inside dummy inputs - if "generator" in fp16_inputs: - fp16_inputs["generator"] = self.get_generator(0) - output_fp16 = pipe_fp16(**fp16_inputs, output="images") - - if isinstance(output, torch.Tensor): - output = output.cpu() - output_fp16 = output_fp16.cpu() - - max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten()) - assert max_diff < expected_max_diff - - @require_accelerator - def test_to_device(self): - pipe = self.get_pipeline() - pipe.set_progress_bar_config(disable=None) - - pipe.to("cpu") - model_devices = [ - component.device.type for component in pipe.components.values() if hasattr(component, "device") - ] - self.assertTrue(all(device == "cpu" for device in model_devices)) - - output_cpu = pipe(**self.get_dummy_inputs("cpu"), output="images") - self.assertTrue(np.isnan(output_cpu).sum() == 0) - - pipe.to(torch_device) - model_devices = [ - component.device.type for component in pipe.components.values() if hasattr(component, "device") - ] - self.assertTrue(all(device == torch_device for device in model_devices)) - - output_device = pipe(**self.get_dummy_inputs(torch_device), output="images") - self.assertTrue(np.isnan(to_np(output_device)).sum() == 0) - - def test_num_images_per_prompt(self): - pipe = self.get_pipeline() - - if "num_images_per_prompt" not in pipe.blocks.input_names: - return - - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - batch_sizes = [1, 2] - num_images_per_prompts = [1, 2] - - for batch_size in batch_sizes: - for num_images_per_prompt in num_images_per_prompts: - inputs = self.get_dummy_inputs(torch_device) - - for key in inputs.keys(): - if key in self.batch_params: - inputs[key] = batch_size * [inputs[key]] - - images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images") - - assert images.shape[0] == batch_size * num_images_per_prompt From 39be37459142642066f8821986d32a03b8820967 Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 21 Jul 2025 09:03:32 +0530 Subject: [PATCH 12/16] update --- ...st_modular_pipeline_stable_diffusion_xl.py | 54 +++++++++---------- .../test_modular_pipelines_common.py | 12 ++--- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py index b8a9a0c9a6d6..1b9c79136113 100644 --- a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py +++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py @@ -99,9 +99,9 @@ def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, assert image.shape == expected_image_shape - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff - ), "Image Slice does not match expected slice" + assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, ( + "Image Slice does not match expected slice" + ) class 
SDXLModularIPAdapterTests: @@ -114,20 +114,20 @@ def test_pipeline_inputs_and_blocks(self): parameters = blocks.input_names assert issubclass(self.pipeline_class, ModularIPAdapterMixin) - assert ( - "ip_adapter_image" in parameters - ), "`ip_adapter_image` argument must be supported by the `__call__` method" + assert "ip_adapter_image" in parameters, ( + "`ip_adapter_image` argument must be supported by the `__call__` method" + ) assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block" _ = blocks.sub_blocks.pop("ip_adapter") parameters = blocks.input_names intermediate_parameters = blocks.intermediate_input_names - assert ( - "ip_adapter_image" not in parameters - ), "`ip_adapter_image` argument must be removed from the `__call__` method" - assert ( - "ip_adapter_image_embeds" not in intermediate_parameters - ), "`ip_adapter_image_embeds` argument must be supported by the `__call__` method" + assert "ip_adapter_image" not in parameters, ( + "`ip_adapter_image` argument must be removed from the `__call__` method" + ) + assert "ip_adapter_image_embeds" not in intermediate_parameters, ( + "`ip_adapter_image_embeds` argument must be supported by the `__call__` method" + ) def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): return torch.randn((1, 1, cross_attention_dim), device=torch_device) @@ -203,9 +203,9 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() - assert ( - max_diff_without_adapter_scale < expected_max_diff - ), "Output without ip-adapter must be same as normal inference" + assert max_diff_without_adapter_scale < expected_max_diff, ( + "Output without ip-adapter must be same as normal inference" + ) assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference" # 2. 
Multi IP-Adapter test cases @@ -235,12 +235,12 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N output_without_multi_adapter_scale - output_without_adapter ).max() max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max() - assert ( - max_diff_without_multi_adapter_scale < expected_max_diff - ), "Output without multi-ip-adapter must be same as normal inference" - assert ( - max_diff_with_multi_adapter_scale > 1e-2 - ), "Output with multi-ip-adapter scale must be different from normal inference" + assert max_diff_without_multi_adapter_scale < expected_max_diff, ( + "Output without multi-ip-adapter must be same as normal inference" + ) + assert max_diff_with_multi_adapter_scale > 1e-2, ( + "Output with multi-ip-adapter scale must be different from normal inference" + ) class SDXLModularControlNetTests: @@ -253,9 +253,9 @@ def test_pipeline_inputs(self): parameters = blocks.input_names assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method" - assert ( - "controlnet_conditioning_scale" in parameters - ), "`controlnet_conditioning_scale` argument must be supported by the `__call__` method" + assert "controlnet_conditioning_scale" in parameters, ( + "`controlnet_conditioning_scale` argument must be supported by the `__call__` method" + ) def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]): controlnet_embedder_scale_factor = 2 @@ -301,9 +301,9 @@ def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max() max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max() - assert ( - max_diff_without_controlnet_scale < expected_max_diff - ), "Output without controlnet must be same as normal inference" + assert max_diff_without_controlnet_scale < expected_max_diff, ( + "Output without controlnet must be same as normal inference" + ) assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference" def test_controlnet_cfg(self): diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index 1bb9bcf2cb27..56e8254c8c26 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -143,9 +143,9 @@ def test_pipeline_call_signature(self): def _check_for_parameters(parameters, expected_parameters, param_type): remaining_parameters = {param for param in parameters if param not in expected_parameters} - assert ( - len(remaining_parameters) == 0 - ), f"Required {param_type} parameters not present: {remaining_parameters}" + assert len(remaining_parameters) == 0, ( + f"Required {param_type} parameters not present: {remaining_parameters}" + ) _check_for_parameters(self.params, input_parameters, "input") _check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate") @@ -274,9 +274,9 @@ def test_to_device(self): model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - assert all( - device == torch_device for device in model_devices - ), "All pipeline components are not on accelerator device" + assert all(device == torch_device for device in model_devices), ( + "All pipeline components are not on accelerator device" + ) def 
test_inference_is_not_nan_cpu(self): pipe = self.get_pipeline() From 3aabef5de464788db946c53cd685609a96e4fec5 Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 24 Jul 2025 22:18:15 +0530 Subject: [PATCH 13/16] update --- .../test_modular_pipelines_common.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index 56e8254c8c26..5a8227f5ce7e 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -143,9 +143,9 @@ def test_pipeline_call_signature(self): def _check_for_parameters(parameters, expected_parameters, param_type): remaining_parameters = {param for param in parameters if param not in expected_parameters} - assert len(remaining_parameters) == 0, ( - f"Required {param_type} parameters not present: {remaining_parameters}" - ) + assert ( + len(remaining_parameters) == 0 + ), f"Required {param_type} parameters not present: {remaining_parameters}" _check_for_parameters(self.params, input_parameters, "input") _check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate") @@ -274,9 +274,9 @@ def test_to_device(self): model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - assert all(device == torch_device for device in model_devices), ( - "All pipeline components are not on accelerator device" - ) + assert all( + device == torch_device for device in model_devices + ), "All pipeline components are not on accelerator device" def test_inference_is_not_nan_cpu(self): pipe = self.get_pipeline() @@ -318,3 +318,13 @@ def test_num_images_per_prompt(self): images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images") assert images.shape[0] == batch_size * num_images_per_prompt + + @require_accelerator + def test_components_auto_cpu_offload(self): + base_pipe = self.get_pipeline().to(torch_device) + for component in base_pipe.components: + assert component.device == torch_device + + cm = ComponentsManager() + cm.enable_auto_cpu_offload(device=torch_device) + offload_pipe = self.get_pipeline(components_manager=cm) From a176cfde843857ce83e53437bcf9d6b9cc682fea Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 24 Jul 2025 22:23:55 +0530 Subject: [PATCH 14/16] update --- ...st_modular_pipeline_stable_diffusion_xl.py | 74 +++++++------------ .../test_modular_pipelines_common.py | 35 ++++++++- 2 files changed, 59 insertions(+), 50 deletions(-) diff --git a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py index 1b9c79136113..7f080bf8de63 100644 --- a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py +++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py @@ -99,9 +99,9 @@ def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, assert image.shape == expected_image_shape - assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, ( - "Image Slice does not match expected slice" - ) + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff + ), "Image Slice does not match expected slice" class SDXLModularIPAdapterTests: @@ -114,20 +114,20 @@ def test_pipeline_inputs_and_blocks(self): parameters = 
blocks.input_names assert issubclass(self.pipeline_class, ModularIPAdapterMixin) - assert "ip_adapter_image" in parameters, ( - "`ip_adapter_image` argument must be supported by the `__call__` method" - ) + assert ( + "ip_adapter_image" in parameters + ), "`ip_adapter_image` argument must be supported by the `__call__` method" assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block" _ = blocks.sub_blocks.pop("ip_adapter") parameters = blocks.input_names intermediate_parameters = blocks.intermediate_input_names - assert "ip_adapter_image" not in parameters, ( - "`ip_adapter_image` argument must be removed from the `__call__` method" - ) - assert "ip_adapter_image_embeds" not in intermediate_parameters, ( - "`ip_adapter_image_embeds` argument must be supported by the `__call__` method" - ) + assert ( + "ip_adapter_image" not in parameters + ), "`ip_adapter_image` argument must be removed from the `__call__` method" + assert ( + "ip_adapter_image_embeds" not in intermediate_parameters + ), "`ip_adapter_image_embeds` argument must be supported by the `__call__` method" def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): return torch.randn((1, 1, cross_attention_dim), device=torch_device) @@ -203,9 +203,9 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() - assert max_diff_without_adapter_scale < expected_max_diff, ( - "Output without ip-adapter must be same as normal inference" - ) + assert ( + max_diff_without_adapter_scale < expected_max_diff + ), "Output without ip-adapter must be same as normal inference" assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference" # 2. 
Multi IP-Adapter test cases @@ -235,12 +235,12 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N output_without_multi_adapter_scale - output_without_adapter ).max() max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max() - assert max_diff_without_multi_adapter_scale < expected_max_diff, ( - "Output without multi-ip-adapter must be same as normal inference" - ) - assert max_diff_with_multi_adapter_scale > 1e-2, ( - "Output with multi-ip-adapter scale must be different from normal inference" - ) + assert ( + max_diff_without_multi_adapter_scale < expected_max_diff + ), "Output without multi-ip-adapter must be same as normal inference" + assert ( + max_diff_with_multi_adapter_scale > 1e-2 + ), "Output with multi-ip-adapter scale must be different from normal inference" class SDXLModularControlNetTests: @@ -253,9 +253,9 @@ def test_pipeline_inputs(self): parameters = blocks.input_names assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method" - assert "controlnet_conditioning_scale" in parameters, ( - "`controlnet_conditioning_scale` argument must be supported by the `__call__` method" - ) + assert ( + "controlnet_conditioning_scale" in parameters + ), "`controlnet_conditioning_scale` argument must be supported by the `__call__` method" def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]): controlnet_embedder_scale_factor = 2 @@ -301,9 +301,9 @@ def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max() max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max() - assert max_diff_without_controlnet_scale < expected_max_diff, ( - "Output without controlnet must be same as normal inference" - ) + assert ( + max_diff_without_controlnet_scale < expected_max_diff + ), "Output without controlnet must be same as normal inference" assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference" def test_controlnet_cfg(self): @@ -383,26 +383,6 @@ def test_stable_diffusion_xl_euler(self): def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) - @require_torch_accelerator - def test_stable_diffusion_xl_offloads(self): - pipes = [] - sd_pipe = self.get_pipeline().to(torch_device) - pipes.append(sd_pipe) - - cm = ComponentsManager() - cm.enable_auto_cpu_offload(device=torch_device) - sd_pipe = self.get_pipeline(components_manager=cm) - pipes.append(sd_pipe) - - image_slices = [] - for pipe in pipes: - inputs = self.get_dummy_inputs(torch_device) - image = pipe(**inputs, output="images") - - image_slices.append(image[0, -3:, -3:, -1].flatten()) - - assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 - def test_stable_diffusion_xl_save_from_pretrained(self): pipes = [] sd_pipe = self.get_pipeline().to(torch_device) diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index 5a8227f5ce7e..684ba3c5832c 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -320,11 +320,40 @@ def test_num_images_per_prompt(self): assert images.shape[0] == batch_size * num_images_per_prompt @require_accelerator - def 
test_components_auto_cpu_offload(self): + def test_components_auto_cpu_offload_inference_consistent(self): base_pipe = self.get_pipeline().to(torch_device) - for component in base_pipe.components: - assert component.device == torch_device cm = ComponentsManager() cm.enable_auto_cpu_offload(device=torch_device) offload_pipe = self.get_pipeline(components_manager=cm) + + image_slices = [] + for pipe in [base_pipe, offload_pipe]: + inputs = self.get_dummy_inputs(torch_device) + image = pipe(**inputs, output="images") + + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + + def test_save_from_pretrained(self): + pipes = [] + base_pipe = self.get_pipeline().to(torch_device) + pipes.append(base_pipe) + + with tempfile.TemporaryDirectory() as tmpdirname: + base_pipe.save_pretrained(tmpdirname) + pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device) + pipe.load_default_components(torch_dtype=torch.float32) + pipe.to(torch_device) + + pipes.append(pipe) + + image_slices = [] + for pipe in pipes: + inputs = self.get_dummy_inputs(torch_device) + image = pipe(**inputs, output="images") + + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 From 1f0570dba0967dafd1d7f19493b651f6b817d35e Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 24 Jul 2025 22:26:33 +0530 Subject: [PATCH 15/16] update --- ...st_modular_pipeline_stable_diffusion_xl.py | 56 +++++++++---------- .../test_modular_pipelines_common.py | 17 +++--- 2 files changed, 36 insertions(+), 37 deletions(-) diff --git a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py index 7f080bf8de63..2e7c90d8e490 100644 --- a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py +++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py @@ -24,7 +24,6 @@ from diffusers import ( ClassifierFreeGuidance, - ComponentsManager, ModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, @@ -33,7 +32,6 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, floats_tensor, - require_torch_accelerator, torch_device, ) @@ -99,9 +97,9 @@ def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, assert image.shape == expected_image_shape - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff - ), "Image Slice does not match expected slice" + assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, ( + "Image Slice does not match expected slice" + ) class SDXLModularIPAdapterTests: @@ -114,20 +112,20 @@ def test_pipeline_inputs_and_blocks(self): parameters = blocks.input_names assert issubclass(self.pipeline_class, ModularIPAdapterMixin) - assert ( - "ip_adapter_image" in parameters - ), "`ip_adapter_image` argument must be supported by the `__call__` method" + assert "ip_adapter_image" in parameters, ( + "`ip_adapter_image` argument must be supported by the `__call__` method" + ) assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block" _ = blocks.sub_blocks.pop("ip_adapter") parameters = blocks.input_names intermediate_parameters = blocks.intermediate_input_names - assert ( - "ip_adapter_image" not in parameters - ), "`ip_adapter_image` argument must be removed from the `__call__` method" 
- assert ( - "ip_adapter_image_embeds" not in intermediate_parameters - ), "`ip_adapter_image_embeds` argument must be supported by the `__call__` method" + assert "ip_adapter_image" not in parameters, ( + "`ip_adapter_image` argument must be removed from the `__call__` method" + ) + assert "ip_adapter_image_embeds" not in intermediate_parameters, ( + "`ip_adapter_image_embeds` argument must be supported by the `__call__` method" + ) def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): return torch.randn((1, 1, cross_attention_dim), device=torch_device) @@ -203,9 +201,9 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() - assert ( - max_diff_without_adapter_scale < expected_max_diff - ), "Output without ip-adapter must be same as normal inference" + assert max_diff_without_adapter_scale < expected_max_diff, ( + "Output without ip-adapter must be same as normal inference" + ) assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference" # 2. Multi IP-Adapter test cases @@ -235,12 +233,12 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N output_without_multi_adapter_scale - output_without_adapter ).max() max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max() - assert ( - max_diff_without_multi_adapter_scale < expected_max_diff - ), "Output without multi-ip-adapter must be same as normal inference" - assert ( - max_diff_with_multi_adapter_scale > 1e-2 - ), "Output with multi-ip-adapter scale must be different from normal inference" + assert max_diff_without_multi_adapter_scale < expected_max_diff, ( + "Output without multi-ip-adapter must be same as normal inference" + ) + assert max_diff_with_multi_adapter_scale > 1e-2, ( + "Output with multi-ip-adapter scale must be different from normal inference" + ) class SDXLModularControlNetTests: @@ -253,9 +251,9 @@ def test_pipeline_inputs(self): parameters = blocks.input_names assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method" - assert ( - "controlnet_conditioning_scale" in parameters - ), "`controlnet_conditioning_scale` argument must be supported by the `__call__` method" + assert "controlnet_conditioning_scale" in parameters, ( + "`controlnet_conditioning_scale` argument must be supported by the `__call__` method" + ) def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]): controlnet_embedder_scale_factor = 2 @@ -301,9 +299,9 @@ def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max() max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max() - assert ( - max_diff_without_controlnet_scale < expected_max_diff - ), "Output without controlnet must be same as normal inference" + assert max_diff_without_controlnet_scale < expected_max_diff, ( + "Output without controlnet must be same as normal inference" + ) assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference" def test_controlnet_cfg(self): diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py 
diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py
index 684ba3c5832c..24c36b9fa93f 100644
--- a/tests/modular_pipelines/test_modular_pipelines_common.py
+++ b/tests/modular_pipelines/test_modular_pipelines_common.py
@@ -1,4 +1,5 @@
 import gc
+import tempfile
 import unittest
 from typing import Callable, Union
@@ -6,8 +7,8 @@
 import torch

 import diffusers
+from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
 from diffusers.utils import logging
-from diffusers.utils.dummy_pt_objects import ModularPipeline, ModularPipelineBlocks
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
     numpy_cosine_similarity_distance,
@@ -143,9 +144,9 @@ def test_pipeline_call_signature(self):
         def _check_for_parameters(parameters, expected_parameters, param_type):
             remaining_parameters = {param for param in parameters if param not in expected_parameters}
-            assert (
-                len(remaining_parameters) == 0
-            ), f"Required {param_type} parameters not present: {remaining_parameters}"
+            assert len(remaining_parameters) == 0, (
+                f"Required {param_type} parameters not present: {remaining_parameters}"
+            )

         _check_for_parameters(self.params, input_parameters, "input")
         _check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate")
@@ -274,9 +275,9 @@ def test_to_device(self):
         model_devices = [
             component.device.type for component in pipe.components.values() if hasattr(component, "device")
         ]
-        assert all(
-            device == torch_device for device in model_devices
-        ), "All pipeline components are not on accelerator device"
+        assert all(device == torch_device for device in model_devices), (
+            "Not all pipeline components are on the accelerator device"
+        )

     def test_inference_is_not_nan_cpu(self):
         pipe = self.get_pipeline()
@@ -344,7 +345,7 @@ def test_save_from_pretrained(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             base_pipe.save_pretrained(tmpdirname)
             pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
-            pipe.load_default_components(torch_dtype=torch.float32)
+            pipe.load_default_components(torch_dtype=torch.float16)
             pipe.to(torch_device)

             pipes.append(pipe)
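
Note that `test_to_device` treats the pipeline as a mapping of components and only inspects those that expose a `.device` (schedulers and tokenizers typically do not), which is why it filters with `hasattr`. Pulled out of the test class, the check is roughly:

    def assert_components_on(pipe, expected_device_type: str):
        # Only torch modules carry a device; skip schedulers, tokenizers, etc.
        model_devices = [
            component.device.type
            for component in pipe.components.values()
            if hasattr(component, "device")
        ]
        assert all(device == expected_device_type for device in model_devices), (
            f"Not all pipeline components are on {expected_device_type}: {model_devices}"
        )
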
From 798e0caac1275c195c0080c443cd4024d2cf88ee Mon Sep 17 00:00:00 2001
From: DN6
Date: Fri, 8 Aug 2025 18:23:37 +0530
Subject: [PATCH 16/16] update

---
 ...st_modular_pipeline_stable_diffusion_xl.py | 23 -------------------
 .../test_modular_pipelines_common.py          |  2 +-
 2 files changed, 1 insertion(+), 24 deletions(-)

diff --git a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
index 2e7c90d8e490..4127d00c8e1a 100644
--- a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
+++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
@@ -14,7 +14,6 @@
 # limitations under the License.

 import random
-import tempfile
 import unittest
 from typing import Any, Dict
@@ -24,7 +23,6 @@
 from diffusers import (
     ClassifierFreeGuidance,
-    ModularPipeline,
     StableDiffusionXLAutoBlocks,
     StableDiffusionXLModularPipeline,
 )
@@ -381,27 +379,6 @@ def test_stable_diffusion_xl_euler(self):
     def test_inference_batch_single_identical(self):
         super().test_inference_batch_single_identical(expected_max_diff=3e-3)

-    def test_stable_diffusion_xl_save_from_pretrained(self):
-        pipes = []
-        sd_pipe = self.get_pipeline().to(torch_device)
-        pipes.append(sd_pipe)
-
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            sd_pipe.save_pretrained(tmpdirname)
-            sd_pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
-            sd_pipe.load_default_components(torch_dtype=torch.float32)
-            sd_pipe.to(torch_device)
-            pipes.append(sd_pipe)
-
-        image_slices = []
-        for pipe in pipes:
-            inputs = self.get_dummy_inputs(torch_device)
-            image = pipe(**inputs, output="images")
-
-            image_slices.append(image[0, -3:, -3:, -1].flatten())
-
-        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
-

 class SDXLImg2ImgModularPipelineFastTests(
     SDXLModularTests,
diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py
index 24c36b9fa93f..6240797742d4 100644
--- a/tests/modular_pipelines/test_modular_pipelines_common.py
+++ b/tests/modular_pipelines/test_modular_pipelines_common.py
@@ -345,7 +345,7 @@ def test_save_from_pretrained(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             base_pipe.save_pretrained(tmpdirname)
             pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
-            pipe.load_default_components(torch_dtype=torch.float16)
+            pipe.load_default_components(torch_dtype=torch.float32)
             pipe.to(torch_device)

             pipes.append(pipe)
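
After this final patch, the SDXL-specific save/load test is gone and only the common `test_save_from_pretrained` remains, settled on float32 so the reloaded pipeline matches the float32 base pipeline. A sketch of the round trip it exercises, assuming (as the test itself suggests, since it calls `load_default_components` after `from_pretrained`) that saving writes only the modular config and weights are materialized on load:

    import tempfile

    import numpy as np
    import torch
    from diffusers import ModularPipeline

    # `dummy_inputs` is a hypothetical stand-in for the suite's get_dummy_inputs.
    def check_save_load_roundtrip(base_pipe, dummy_inputs, device):
        with tempfile.TemporaryDirectory() as tmpdirname:
            base_pipe.save_pretrained(tmpdirname)
            pipe = ModularPipeline.from_pretrained(tmpdirname)
            # Materialize the model weights; the dtype must match the base
            # pipeline for the closeness check below to hold.
            pipe.load_default_components(torch_dtype=torch.float32)
            pipe.to(device)

        image_slices = []
        for p in (base_pipe, pipe):
            image = p(**dummy_inputs(device), output="images")
            image_slices.append(image[0, -3:, -3:, -1].flatten())

        # Saving and reloading must not change the outputs.
        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
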