From cfe1e2e3fa63aef113a43d359291a17a2cf06928 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 23 Sep 2025 13:32:05 +0530 Subject: [PATCH 1/7] up --- .../test_models_autoencoder_hunyuan_video.py | 66 +---------------- .../test_models_autoencoder_kl.py | 66 +---------------- .../test_models_autoencoder_kl_cogvideox.py | 66 +---------------- .../test_models_autoencoder_ltx_video.py | 35 +-------- .../test_models_autoencoder_tiny.py | 35 +-------- .../test_models_autoencoder_wan.py | 68 +----------------- .../test_models_consistency_decoder_vae.py | 68 +----------------- tests/models/test_modeling_common.py | 71 +++++++++++++++++++ 8 files changed, 85 insertions(+), 390 deletions(-) diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py index 6f91f8bfa91b..f5d630a3623b 100644 --- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py @@ -25,13 +25,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin enable_full_determinism() -class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderKLHunyuanVideo main_input_name = "sample" base_precision = 1e-2 @@ -87,68 +87,6 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict - def test_enable_disable_tiling(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - - torch.manual_seed(0) - output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - torch.manual_seed(0) - model.enable_tiling() - output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertLess( - (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), - 0.5, - "VAE tiling should not affect the inference results", - ) - - torch.manual_seed(0) - model.disable_tiling() - output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertEqual( - output_without_tiling.detach().cpu().numpy().all(), - output_without_tiling_2.detach().cpu().numpy().all(), - "Without tiling outputs should match with the outputs when tiling is manually disabled.", - ) - - def test_enable_disable_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - - torch.manual_seed(0) - output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - torch.manual_seed(0) - model.enable_slicing() - output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertLess( - (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), - 0.5, - "VAE slicing should not affect the inference results", - ) - - torch.manual_seed(0) - model.disable_slicing() - output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertEqual( - output_without_slicing.detach().cpu().numpy().all(), - output_without_slicing_2.detach().cpu().numpy().all(), - "Without slicing outputs should match with the outputs when slicing is manually disabled.", - ) - def test_gradient_checkpointing_is_applied(self): expected_set = { "HunyuanVideoDecoder3D", diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py index 662a3f1b80b7..80fb7f4b0dbc 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl.py @@ -35,13 +35,13 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin enable_full_determinism() -class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, VAETestMixin, unittest.TestCase): model_class = AutoencoderKL main_input_name = "sample" base_precision = 1e-2 @@ -83,68 +83,6 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict - def test_enable_disable_tiling(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - - torch.manual_seed(0) - output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - torch.manual_seed(0) - model.enable_tiling() - output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertLess( - (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), - 0.5, - "VAE tiling should not affect the inference results", - ) - - torch.manual_seed(0) - model.disable_tiling() - output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertEqual( - output_without_tiling.detach().cpu().numpy().all(), - output_without_tiling_2.detach().cpu().numpy().all(), - "Without tiling outputs should match with the outputs when tiling is manually disabled.", - ) - - def test_enable_disable_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - - torch.manual_seed(0) - output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - torch.manual_seed(0) - model.enable_slicing() - output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertLess( - (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), - 0.5, - "VAE slicing should not affect the inference results", - ) - - torch.manual_seed(0) - model.disable_slicing() - output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertEqual( - output_without_slicing.detach().cpu().numpy().all(), - output_without_slicing_2.detach().cpu().numpy().all(), - "Without slicing outputs should match with the outputs when slicing is manually disabled.", - ) - def test_gradient_checkpointing_is_applied(self): expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py index 739daf2a492d..694fcfa70ffd 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py @@ -24,13 +24,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin enable_full_determinism() -class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLCogVideoXTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderKLCogVideoX main_input_name = "sample" base_precision = 1e-2 @@ -82,68 +82,6 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict - def test_enable_disable_tiling(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - - torch.manual_seed(0) - output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - torch.manual_seed(0) - model.enable_tiling() - output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertLess( - (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), - 0.5, - "VAE tiling should not affect the inference results", - ) - - torch.manual_seed(0) - model.disable_tiling() - output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertEqual( - output_without_tiling.detach().cpu().numpy().all(), - output_without_tiling_2.detach().cpu().numpy().all(), - "Without tiling outputs should match with the outputs when tiling is manually disabled.", - ) - - def test_enable_disable_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - - torch.manual_seed(0) - output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - torch.manual_seed(0) - model.enable_slicing() - output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertLess( - (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), - 0.5, - "VAE slicing should not affect the inference results", - ) - - torch.manual_seed(0) - model.disable_slicing() - output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertEqual( - output_without_slicing.detach().cpu().numpy().all(), - output_without_slicing_2.detach().cpu().numpy().all(), - "Without slicing outputs should match with the outputs when slicing is manually disabled.", - ) - def test_gradient_checkpointing_is_applied(self): expected_set = { "CogVideoXDownBlock3D", diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py index 21ab3896c890..4fc5402f6a09 100644 --- a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py @@ -24,13 +24,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin enable_full_determinism() -class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderKLLTXVideo main_input_name = "sample" base_precision = 1e-2 @@ -167,34 +167,3 @@ def test_outputs_equivalence(self): @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.") def test_forward_with_norm_groups(self): pass - - def test_enable_disable_tiling(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - - torch.manual_seed(0) - output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - torch.manual_seed(0) - model.enable_tiling() - output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertLess( - (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), - 0.5, - "VAE tiling should not affect the inference results", - ) - - torch.manual_seed(0) - model.disable_tiling() - output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertEqual( - output_without_tiling.detach().cpu().numpy().all(), - output_without_tiling_2.detach().cpu().numpy().all(), - "Without tiling outputs should match with the outputs when tiling is manually disabled.", - ) diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py index 4d1dc69cfaad..3beb934fdd50 100644 --- a/tests/models/autoencoders/test_models_autoencoder_tiny.py +++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py @@ -31,13 +31,13 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin enable_full_determinism() -class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderTinyTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderTiny main_input_name = "sample" base_precision = 1e-2 @@ -81,37 +81,6 @@ def prepare_init_args_and_inputs_for_common(self): def test_enable_disable_tiling(self): pass - def test_enable_disable_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - - torch.manual_seed(0) - output_without_slicing = model(**inputs_dict)[0] - - torch.manual_seed(0) - model.enable_slicing() - output_with_slicing = model(**inputs_dict)[0] - - self.assertLess( - (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), - 0.5, - "VAE slicing should not affect the inference results", - ) - - torch.manual_seed(0) - model.disable_slicing() - output_without_slicing_2 = model(**inputs_dict)[0] - - self.assertEqual( - output_without_slicing.detach().cpu().numpy().all(), - output_without_slicing_2.detach().cpu().numpy().all(), - "Without slicing outputs should match with the outputs when slicing is manually disabled.", - ) - @unittest.skip("Test not supported.") def test_outputs_equivalence(self): pass diff --git a/tests/models/autoencoders/test_models_autoencoder_wan.py b/tests/models/autoencoders/test_models_autoencoder_wan.py index cc9c88868157..16fef65e59e4 100644 --- a/tests/models/autoencoders/test_models_autoencoder_wan.py +++ b/tests/models/autoencoders/test_models_autoencoder_wan.py @@ -15,18 +15,16 @@ import unittest -import torch - from diffusers import AutoencoderKLWan from ...testing_utils import enable_full_determinism, floats_tensor, torch_device -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin enable_full_determinism() -class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLWanTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderKLWan main_input_name = "sample" base_precision = 1e-2 @@ -76,68 +74,6 @@ def prepare_init_args_and_inputs_for_tiling(self): inputs_dict = self.dummy_input_tiling return init_dict, inputs_dict - def test_enable_disable_tiling(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - - torch.manual_seed(0) - output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - torch.manual_seed(0) - model.enable_tiling(96, 96, 64, 64) - output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertLess( - (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), - 0.5, - "VAE tiling should not affect the inference results", - ) - - torch.manual_seed(0) - model.disable_tiling() - output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertEqual( - output_without_tiling.detach().cpu().numpy().all(), - output_without_tiling_2.detach().cpu().numpy().all(), - "Without tiling outputs should match with the outputs when tiling is manually disabled.", - ) - - def test_enable_disable_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - - torch.manual_seed(0) - output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - torch.manual_seed(0) - model.enable_slicing() - output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertLess( - (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), - 0.05, - "VAE slicing should not affect the inference results", - ) - - torch.manual_seed(0) - model.disable_slicing() - output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertEqual( - output_without_slicing.detach().cpu().numpy().all(), - output_without_slicing_2.detach().cpu().numpy().all(), - "Without slicing outputs should match with the outputs when slicing is manually disabled.", - ) - @unittest.skip("Gradient checkpointing has not been implemented yet") def test_gradient_checkpointing_is_applied(self): pass diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py index 7e44edba3624..233af4084b1a 100644 --- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py +++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py @@ -30,13 +30,13 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin, VAETestMixin enable_full_determinism() -class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase): +class ConsistencyDecoderVAETests(ModelTesterMixin, VAETestMixin, unittest.TestCase): model_class = ConsistencyDecoderVAE main_input_name = "sample" base_precision = 1e-2 @@ -92,70 +92,6 @@ def init_dict(self): def prepare_init_args_and_inputs_for_common(self): return self.init_dict, self.inputs_dict() - def test_enable_disable_tiling(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - _ = inputs_dict.pop("generator") - - torch.manual_seed(0) - output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - torch.manual_seed(0) - model.enable_tiling() - output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertLess( - (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), - 0.5, - "VAE tiling should not affect the inference results", - ) - - torch.manual_seed(0) - model.disable_tiling() - output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertEqual( - output_without_tiling.detach().cpu().numpy().all(), - output_without_tiling_2.detach().cpu().numpy().all(), - "Without tiling outputs should match with the outputs when tiling is manually disabled.", - ) - - def test_enable_disable_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - _ = inputs_dict.pop("generator") - - torch.manual_seed(0) - output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - torch.manual_seed(0) - model.enable_slicing() - output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertLess( - (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), - 0.5, - "VAE slicing should not affect the inference results", - ) - - torch.manual_seed(0) - model.disable_slicing() - output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] - - self.assertEqual( - output_without_slicing.detach().cpu().numpy().all(), - output_without_slicing_2.detach().cpu().numpy().all(), - "Without slicing outputs should match with the outputs when slicing is manually disabled.", - ) - @slow class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase): diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 5e7be62342c3..1dbb33863091 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1941,6 +1941,77 @@ def test_passing_dict_device_map_works(self, name, device): _ = loaded_model(**inputs_dict) +class VAETestMixin: + """ + Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks + usually don't do slicing and tiling. + """ + + def test_enable_disable_tiling(self): + if not hasattr(self.model_class, "enable_tiling"): + pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.") + + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + _ = inputs_dict.pop("generator") + + torch.manual_seed(0) + output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_tiling() + output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + assert ( + output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy() + ).max() < 0.5, "VAE tiling should not affect the inference results" + + torch.manual_seed(0) + model.disable_tiling() + output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + assert np.allclose( + output_without_tiling.detach().cpu().numpy().all(), + output_without_tiling_2.detach().cpu().numpy().all(), + ), "Without tiling outputs should match with the outputs when tiling is manually disabled." + + def test_enable_disable_slicing(self): + if not hasattr(self.model_class, "enable_slicing"): + pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.") + + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + _ = inputs_dict.pop("generator") + + torch.manual_seed(0) + output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_slicing() + output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + assert ( + output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy() + ).max() < 0.5, "VAE slicing should not affect the inference results" + + torch.manual_seed(0) + model.disable_slicing() + output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + assert np.allclose( + output_without_slicing.detach().cpu().numpy().all(), + output_without_slicing_2.detach().cpu().numpy().all(), + ), "Without slicing outputs should match with the outputs when slicing is manually disabled." + + @is_staging_test class ModelPushToHubTester(unittest.TestCase): identifier = uuid.uuid4() From 490c4761b495e0f6f53efb661fe08e2ec5c5f7f5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 23 Sep 2025 13:45:27 +0530 Subject: [PATCH 2/7] up --- tests/models/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 1dbb33863091..3964dfa7ade3 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1989,7 +1989,7 @@ def test_enable_disable_slicing(self): model = self.model_class(**init_dict).to(torch_device) inputs_dict.update({"return_dict": False}) - _ = inputs_dict.pop("generator") + _ = inputs_dict.pop("generator", None) torch.manual_seed(0) output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] From 01aa188d8d908251eaf5a2897eefae63ad4ea3f9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 23 Sep 2025 14:05:51 +0530 Subject: [PATCH 3/7] up --- tests/models/test_modeling_common.py | 35 ++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 3964dfa7ade3..5d02747986c5 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1947,6 +1947,12 @@ class VAETestMixin: usually don't do slicing and tiling. """ + @staticmethod + def _accepts_generator(model): + model_sig = inspect.signature(model.forward) + accepts_generator = "generator" in model_sig.parameters + return accepts_generator + def test_enable_disable_tiling(self): if not hasattr(self.model_class, "enable_tiling"): pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.") @@ -1957,14 +1963,19 @@ def test_enable_disable_tiling(self): model = self.model_class(**init_dict).to(torch_device) inputs_dict.update({"return_dict": False}) - _ = inputs_dict.pop("generator") + _ = inputs_dict.pop("generator", None) + accepts_generator = self._accepts_generator(model) torch.manual_seed(0) - output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_without_tiling = model(**inputs_dict)[0] torch.manual_seed(0) model.enable_tiling() - output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_with_tiling = model(**inputs_dict)[0] assert ( output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy() @@ -1972,7 +1983,9 @@ def test_enable_disable_tiling(self): torch.manual_seed(0) model.disable_tiling() - output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_without_tiling_2 = model(**inputs_dict)[0] assert np.allclose( output_without_tiling.detach().cpu().numpy().all(), @@ -1990,13 +2003,19 @@ def test_enable_disable_slicing(self): inputs_dict.update({"return_dict": False}) _ = inputs_dict.pop("generator", None) + accepts_generator = self._accepts_generator(model) + + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) torch.manual_seed(0) - output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + output_without_slicing = model(**inputs_dict)[0] torch.manual_seed(0) model.enable_slicing() - output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_with_slicing = model(**inputs_dict)[0] assert ( output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy() @@ -2004,7 +2023,9 @@ def test_enable_disable_slicing(self): torch.manual_seed(0) model.disable_slicing() - output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_without_slicing_2 = model(**inputs_dict)[0] assert np.allclose( output_without_slicing.detach().cpu().numpy().all(), From 769b7452ed37485d032c9dd7b20cf5cb0efea2c0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 24 Sep 2025 09:15:50 +0530 Subject: [PATCH 4/7] up --- .../autoencoders/test_models_autoencoder_hunyuan_video.py | 4 ++-- tests/models/autoencoders/test_models_autoencoder_kl.py | 4 ++-- .../autoencoders/test_models_autoencoder_kl_cogvideox.py | 4 ++-- .../models/autoencoders/test_models_autoencoder_ltx_video.py | 4 ++-- tests/models/autoencoders/test_models_autoencoder_tiny.py | 4 ++-- tests/models/autoencoders/test_models_autoencoder_wan.py | 4 ++-- .../autoencoders/test_models_consistency_decoder_vae.py | 4 ++-- tests/models/test_modeling_common.py | 2 +- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py index f5d630a3623b..60d901eaac9f 100644 --- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py @@ -25,13 +25,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin, UNetTesterMixin enable_full_determinism() -class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderKLHunyuanVideo main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py index 80fb7f4b0dbc..9300f9b91e2f 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl.py @@ -35,13 +35,13 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin, UNetTesterMixin enable_full_determinism() -class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, VAETestMixin, unittest.TestCase): +class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, AutoencoderTesterMixin, unittest.TestCase): model_class = AutoencoderKL main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py index 694fcfa70ffd..08b19a7c51d5 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py @@ -24,13 +24,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin, UNetTesterMixin enable_full_determinism() -class AutoencoderKLCogVideoXTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderKLCogVideoX main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py index 4fc5402f6a09..5a13d180ed95 100644 --- a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py @@ -24,13 +24,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin, UNetTesterMixin enable_full_determinism() -class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderKLLTXVideo main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py index 3beb934fdd50..080b5c1934e7 100644 --- a/tests/models/autoencoders/test_models_autoencoder_tiny.py +++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py @@ -31,13 +31,13 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin, UNetTesterMixin enable_full_determinism() -class AutoencoderTinyTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderTinyTests(ModelTesterMixin, AutoencoderTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderTiny main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_wan.py b/tests/models/autoencoders/test_models_autoencoder_wan.py index 16fef65e59e4..ae905f48bd6c 100644 --- a/tests/models/autoencoders/test_models_autoencoder_wan.py +++ b/tests/models/autoencoders/test_models_autoencoder_wan.py @@ -18,13 +18,13 @@ from diffusers import AutoencoderKLWan from ...testing_utils import enable_full_determinism, floats_tensor, torch_device -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin, UNetTesterMixin enable_full_determinism() -class AutoencoderKLWanTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderKLWan main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py index 233af4084b1a..b9820578ad80 100644 --- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py +++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py @@ -30,13 +30,13 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, VAETestMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin enable_full_determinism() -class ConsistencyDecoderVAETests(ModelTesterMixin, VAETestMixin, unittest.TestCase): +class ConsistencyDecoderVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): model_class = ConsistencyDecoderVAE main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 5d02747986c5..671b9a6500a8 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1941,7 +1941,7 @@ def test_passing_dict_device_map_works(self, name, device): _ = loaded_model(**inputs_dict) -class VAETestMixin: +class AutoencoderTesterMixin: """ Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks usually don't do slicing and tiling. From 3a106f05ee83b813a5f24be154838439fd6a3de0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 24 Sep 2025 09:20:32 +0530 Subject: [PATCH 5/7] up --- .../autoencoders/test_models_asymmetric_autoencoder_kl.py | 4 ++-- tests/models/autoencoders/test_models_autoencoder_cosmos.py | 4 ++-- tests/models/autoencoders/test_models_autoencoder_dc.py | 4 ++-- .../autoencoders/test_models_autoencoder_hunyuan_video.py | 4 ++-- tests/models/autoencoders/test_models_autoencoder_kl.py | 4 ++-- .../autoencoders/test_models_autoencoder_kl_cogvideox.py | 4 ++-- .../test_models_autoencoder_kl_temporal_decoder.py | 4 ++-- .../autoencoders/test_models_autoencoder_ltx_video.py | 6 +++--- tests/models/autoencoders/test_models_autoencoder_magvit.py | 4 ++-- tests/models/autoencoders/test_models_autoencoder_mochi.py | 4 ++-- .../models/autoencoders/test_models_autoencoder_oobleck.py | 4 ++-- tests/models/autoencoders/test_models_autoencoder_tiny.py | 4 ++-- tests/models/autoencoders/test_models_autoencoder_wan.py | 4 ++-- tests/models/autoencoders/test_models_vq.py | 4 ++-- tests/models/test_modeling_common.py | 2 +- 15 files changed, 30 insertions(+), 30 deletions(-) diff --git a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py index 7eb830cd5097..36207c614383 100644 --- a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py +++ b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py @@ -35,13 +35,13 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin enable_full_determinism() -class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AsymmetricAutoencoderKLTests(ModelTesterMixin, unittest.TestCase): model_class = AsymmetricAutoencoderKL main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_cosmos.py b/tests/models/autoencoders/test_models_autoencoder_cosmos.py index ceccc2364e26..45ae737860ca 100644 --- a/tests/models/autoencoders/test_models_autoencoder_cosmos.py +++ b/tests/models/autoencoders/test_models_autoencoder_cosmos.py @@ -17,13 +17,13 @@ from diffusers import AutoencoderKLCosmos from ...testing_utils import enable_full_determinism, floats_tensor, torch_device -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin enable_full_determinism() -class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLCosmosTests(ModelTesterMixin, unittest.TestCase): model_class = AutoencoderKLCosmos main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_dc.py b/tests/models/autoencoders/test_models_autoencoder_dc.py index 56f172f1c869..0b48818bd1a8 100644 --- a/tests/models/autoencoders/test_models_autoencoder_dc.py +++ b/tests/models/autoencoders/test_models_autoencoder_dc.py @@ -22,13 +22,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin enable_full_determinism() -class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderDCTests(ModelTesterMixin, unittest.TestCase): model_class = AutoencoderDC main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py index 60d901eaac9f..e8e82b50c7b4 100644 --- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py @@ -25,13 +25,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin enable_full_determinism() -class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): model_class = AutoencoderKLHunyuanVideo main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py index 9300f9b91e2f..3cd3c3fe15dc 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl.py @@ -35,13 +35,13 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin enable_full_determinism() -class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, AutoencoderTesterMixin, unittest.TestCase): +class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): model_class = AutoencoderKL main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py index 08b19a7c51d5..148d5d8dd2a9 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py @@ -24,13 +24,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin enable_full_determinism() -class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): model_class = AutoencoderKLCogVideoX main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py index 6cb427bff8e1..66062e21e7f5 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py @@ -22,13 +22,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin enable_full_determinism() -class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, unittest.TestCase): model_class = AutoencoderKLTemporalDecoder main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py index 5a13d180ed95..bd0bb62adbed 100644 --- a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py @@ -24,13 +24,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin enable_full_determinism() -class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): model_class = AutoencoderKLLTXVideo main_input_name = "sample" base_precision = 1e-2 @@ -99,7 +99,7 @@ def test_forward_with_norm_groups(self): pass -class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, unittest.TestCase): model_class = AutoencoderKLLTXVideo main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_magvit.py b/tests/models/autoencoders/test_models_autoencoder_magvit.py index 58cbfc05bd03..f77564fbb04d 100644 --- a/tests/models/autoencoders/test_models_autoencoder_magvit.py +++ b/tests/models/autoencoders/test_models_autoencoder_magvit.py @@ -18,13 +18,13 @@ from diffusers import AutoencoderKLMagvit from ...testing_utils import enable_full_determinism, floats_tensor, torch_device -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin enable_full_determinism() -class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLMagvitTests(ModelTesterMixin, unittest.TestCase): model_class = AutoencoderKLMagvit main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_mochi.py b/tests/models/autoencoders/test_models_autoencoder_mochi.py index b8c5aaaa1eb6..371f0dd68f77 100755 --- a/tests/models/autoencoders/test_models_autoencoder_mochi.py +++ b/tests/models/autoencoders/test_models_autoencoder_mochi.py @@ -22,13 +22,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin enable_full_determinism() -class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLMochiTests(ModelTesterMixin, unittest.TestCase): model_class = AutoencoderKLMochi main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_oobleck.py b/tests/models/autoencoders/test_models_autoencoder_oobleck.py index eb7bd50f4a54..debbb5fdebc2 100644 --- a/tests/models/autoencoders/test_models_autoencoder_oobleck.py +++ b/tests/models/autoencoders/test_models_autoencoder_oobleck.py @@ -30,13 +30,13 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin enable_full_determinism() -class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderOobleckTests(ModelTesterMixin, unittest.TestCase): model_class = AutoencoderOobleck main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py index 080b5c1934e7..e25bc1e76903 100644 --- a/tests/models/autoencoders/test_models_autoencoder_tiny.py +++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py @@ -31,13 +31,13 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin enable_full_determinism() -class AutoencoderTinyTests(ModelTesterMixin, AutoencoderTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderTinyTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): model_class = AutoencoderTiny main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_autoencoder_wan.py b/tests/models/autoencoders/test_models_autoencoder_wan.py index ae905f48bd6c..a476272a25c8 100644 --- a/tests/models/autoencoders/test_models_autoencoder_wan.py +++ b/tests/models/autoencoders/test_models_autoencoder_wan.py @@ -18,13 +18,13 @@ from diffusers import AutoencoderKLWan from ...testing_utils import enable_full_determinism, floats_tensor, torch_device -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin enable_full_determinism() -class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): model_class = AutoencoderKLWan main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py index 1c636b081733..c92b0227791b 100644 --- a/tests/models/autoencoders/test_models_vq.py +++ b/tests/models/autoencoders/test_models_vq.py @@ -25,13 +25,13 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin enable_full_determinism() -class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class VQModelTests(ModelTesterMixin, unittest.TestCase): model_class = VQModel main_input_name = "sample" diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 671b9a6500a8..080158de41eb 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1941,7 +1941,7 @@ def test_passing_dict_device_map_works(self, name, device): _ = loaded_model(**inputs_dict) -class AutoencoderTesterMixin: +class AutoencoderTesterMixin(UNetTesterMixin): """ Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks usually don't do slicing and tiling. From 6a01c4681cf80fded080cb01a8754cac025894f5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 24 Sep 2025 09:30:34 +0530 Subject: [PATCH 6/7] u[ --- tests/models/test_modeling_common.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 080158de41eb..aac6474c48af 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -450,7 +450,15 @@ def get_dummy_inputs(): class UNetTesterMixin: + @staticmethod + def _accepts_norm_num_groups(model_class): + model_sig = inspect.signature(model_class.__init__) + accepts_norm_groups = "norm_num_groups" in model_sig.parameters + return accepts_norm_groups + def test_forward_with_norm_groups(self): + if not self._accepts_norm_num_groups(self.model_class): + pytest.skip(f"Test not supported for {self.model_class.__name__}") init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict["norm_num_groups"] = 16 From 7b8817ec04357424d74e6d203c93f95a72b38247 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 24 Sep 2025 10:44:15 +0530 Subject: [PATCH 7/7] up --- .../test_models_autoencoder_hunyuan_video.py | 9 +- .../test_models_autoencoder_kl.py | 3 +- .../test_models_autoencoder_kl_cogvideox.py | 3 +- .../test_models_autoencoder_ltx_video.py | 3 +- .../test_models_autoencoder_tiny.py | 3 +- .../test_models_autoencoder_wan.py | 3 +- .../test_models_consistency_decoder_vae.py | 3 +- tests/models/autoencoders/testing_utils.py | 127 ++++++++++++++++++ tests/models/test_modeling_common.py | 92 ------------- 9 files changed, 142 insertions(+), 104 deletions(-) create mode 100644 tests/models/autoencoders/testing_utils.py diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py index e8e82b50c7b4..9813772a7c55 100644 --- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py @@ -20,12 +20,9 @@ from diffusers import AutoencoderKLHunyuanVideo from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask -from ...testing_utils import ( - enable_full_determinism, - floats_tensor, - torch_device, -) -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin +from ...testing_utils import enable_full_determinism, floats_tensor, torch_device +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin enable_full_determinism() diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py index 3cd3c3fe15dc..5f11c6cb0ab3 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl.py @@ -35,7 +35,8 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin enable_full_determinism() diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py index 148d5d8dd2a9..b6d59489d9c6 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py @@ -24,7 +24,8 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin enable_full_determinism() diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py index bd0bb62adbed..527be1b4ecb5 100644 --- a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py @@ -24,7 +24,8 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin enable_full_determinism() diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py index e25bc1e76903..68232aa12fdf 100644 --- a/tests/models/autoencoders/test_models_autoencoder_tiny.py +++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py @@ -31,7 +31,8 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin enable_full_determinism() diff --git a/tests/models/autoencoders/test_models_autoencoder_wan.py b/tests/models/autoencoders/test_models_autoencoder_wan.py index a476272a25c8..051098dc7aac 100644 --- a/tests/models/autoencoders/test_models_autoencoder_wan.py +++ b/tests/models/autoencoders/test_models_autoencoder_wan.py @@ -18,7 +18,8 @@ from diffusers import AutoencoderKLWan from ...testing_utils import enable_full_determinism, floats_tensor, torch_device -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin enable_full_determinism() diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py index b9820578ad80..ef04d151ecd1 100644 --- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py +++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py @@ -30,7 +30,8 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import AutoencoderTesterMixin, ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin enable_full_determinism() diff --git a/tests/models/autoencoders/testing_utils.py b/tests/models/autoencoders/testing_utils.py new file mode 100644 index 000000000000..da89f1b97dce --- /dev/null +++ b/tests/models/autoencoders/testing_utils.py @@ -0,0 +1,127 @@ +import inspect + +import numpy as np +import pytest +import torch + +from diffusers.utils.torch_utils import torch_device + + +class AutoencoderTesterMixin: + """ + Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks + usually don't do slicing and tiling. + """ + + @staticmethod + def _accepts_generator(model): + model_sig = inspect.signature(model.forward) + accepts_generator = "generator" in model_sig.parameters + return accepts_generator + + @staticmethod + def _accepts_norm_num_groups(model_class): + model_sig = inspect.signature(model_class.__init__) + accepts_norm_groups = "norm_num_groups" in model_sig.parameters + return accepts_norm_groups + + def test_forward_with_norm_groups(self): + if not self._accepts_norm_num_groups(self.model_class): + pytest.skip(f"Test not supported for {self.model_class.__name__}") + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["norm_num_groups"] = 16 + init_dict["block_out_channels"] = (16, 32) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_enable_disable_tiling(self): + if not hasattr(self.model_class, "enable_tiling"): + pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.") + + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + _ = inputs_dict.pop("generator", None) + accepts_generator = self._accepts_generator(model) + + torch.manual_seed(0) + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_without_tiling = model(**inputs_dict)[0] + + torch.manual_seed(0) + model.enable_tiling() + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_with_tiling = model(**inputs_dict)[0] + + assert ( + output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy() + ).max() < 0.5, "VAE tiling should not affect the inference results" + + torch.manual_seed(0) + model.disable_tiling() + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_without_tiling_2 = model(**inputs_dict)[0] + + assert np.allclose( + output_without_tiling.detach().cpu().numpy().all(), + output_without_tiling_2.detach().cpu().numpy().all(), + ), "Without tiling outputs should match with the outputs when tiling is manually disabled." + + def test_enable_disable_slicing(self): + if not hasattr(self.model_class, "enable_slicing"): + pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.") + + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + _ = inputs_dict.pop("generator", None) + accepts_generator = self._accepts_generator(model) + + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + + torch.manual_seed(0) + output_without_slicing = model(**inputs_dict)[0] + + torch.manual_seed(0) + model.enable_slicing() + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_with_slicing = model(**inputs_dict)[0] + + assert ( + output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy() + ).max() < 0.5, "VAE slicing should not affect the inference results" + + torch.manual_seed(0) + model.disable_slicing() + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_without_slicing_2 = model(**inputs_dict)[0] + + assert np.allclose( + output_without_slicing.detach().cpu().numpy().all(), + output_without_slicing_2.detach().cpu().numpy().all(), + ), "Without slicing outputs should match with the outputs when slicing is manually disabled." diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index aac6474c48af..90bb0b355d4a 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1949,98 +1949,6 @@ def test_passing_dict_device_map_works(self, name, device): _ = loaded_model(**inputs_dict) -class AutoencoderTesterMixin(UNetTesterMixin): - """ - Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks - usually don't do slicing and tiling. - """ - - @staticmethod - def _accepts_generator(model): - model_sig = inspect.signature(model.forward) - accepts_generator = "generator" in model_sig.parameters - return accepts_generator - - def test_enable_disable_tiling(self): - if not hasattr(self.model_class, "enable_tiling"): - pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.") - - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - _ = inputs_dict.pop("generator", None) - accepts_generator = self._accepts_generator(model) - - torch.manual_seed(0) - if accepts_generator: - inputs_dict["generator"] = torch.manual_seed(0) - output_without_tiling = model(**inputs_dict)[0] - - torch.manual_seed(0) - model.enable_tiling() - if accepts_generator: - inputs_dict["generator"] = torch.manual_seed(0) - output_with_tiling = model(**inputs_dict)[0] - - assert ( - output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy() - ).max() < 0.5, "VAE tiling should not affect the inference results" - - torch.manual_seed(0) - model.disable_tiling() - if accepts_generator: - inputs_dict["generator"] = torch.manual_seed(0) - output_without_tiling_2 = model(**inputs_dict)[0] - - assert np.allclose( - output_without_tiling.detach().cpu().numpy().all(), - output_without_tiling_2.detach().cpu().numpy().all(), - ), "Without tiling outputs should match with the outputs when tiling is manually disabled." - - def test_enable_disable_slicing(self): - if not hasattr(self.model_class, "enable_slicing"): - pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.") - - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict).to(torch_device) - - inputs_dict.update({"return_dict": False}) - _ = inputs_dict.pop("generator", None) - accepts_generator = self._accepts_generator(model) - - if accepts_generator: - inputs_dict["generator"] = torch.manual_seed(0) - - torch.manual_seed(0) - output_without_slicing = model(**inputs_dict)[0] - - torch.manual_seed(0) - model.enable_slicing() - if accepts_generator: - inputs_dict["generator"] = torch.manual_seed(0) - output_with_slicing = model(**inputs_dict)[0] - - assert ( - output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy() - ).max() < 0.5, "VAE slicing should not affect the inference results" - - torch.manual_seed(0) - model.disable_slicing() - if accepts_generator: - inputs_dict["generator"] = torch.manual_seed(0) - output_without_slicing_2 = model(**inputs_dict)[0] - - assert np.allclose( - output_without_slicing.detach().cpu().numpy().all(), - output_without_slicing_2.detach().cpu().numpy().all(), - ), "Without slicing outputs should match with the outputs when slicing is manually disabled." - - @is_staging_test class ModelPushToHubTester(unittest.TestCase): identifier = uuid.uuid4()