Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AsymmetricAutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
model_class = AsymmetricAutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
Expand Down
4 changes: 2 additions & 2 deletions tests/models/autoencoders/test_models_autoencoder_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from diffusers import AutoencoderKLCosmos

from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLCosmosTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCosmos
main_input_name = "sample"
base_precision = 1e-2
Expand Down
4 changes: 2 additions & 2 deletions tests/models/autoencoders/test_models_autoencoder_dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderDCTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderDC
main_input_name = "sample"
base_precision = 1e-2
Expand Down
73 changes: 4 additions & 69 deletions tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,15 @@
from diffusers import AutoencoderKLHunyuanVideo
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask

from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLHunyuanVideo
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -87,68 +84,6 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)

torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)

def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)

torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)

def test_gradient_checkpointing_is_applied(self):
expected_set = {
"HunyuanVideoDecoder3D",
Expand Down
67 changes: 3 additions & 64 deletions tests/models/autoencoders/test_models_autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -83,68 +84,6 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)

torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)

def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)

torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)

def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
Expand Down
67 changes: 3 additions & 64 deletions tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCogVideoX
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -82,68 +83,6 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)

torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)

def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)

torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)

def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CogVideoXDownBlock3D",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLTemporalDecoder
main_input_name = "sample"
base_precision = 1e-2
Expand Down
38 changes: 4 additions & 34 deletions tests/models/autoencoders/test_models_autoencoder_ltx_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -99,7 +100,7 @@ def test_forward_with_norm_groups(self):
pass


class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -167,34 +168,3 @@ def test_outputs_equivalence(self):
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass

def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)

torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
Loading
Loading