Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 2 additions & 64 deletions tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin


enable_full_determinism()


class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: let's call it `AutoencoderTesterMixin` to be consistent with the naming of the other tester mixins. Also, it looks like `UNetTesterMixin` is used in all the autoencoder tests, even though it only has a single test method. Let's move that method into `AutoencoderTesterMixin` and have these classes inherit only from their specific mixin.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

model_class = AutoencoderKLHunyuanVideo
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -87,68 +87,6 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_enable_disable_tiling(self):
    """Check that tiling keeps outputs close and that disabling it restores the original output.

    The model is built from `prepare_init_args_and_inputs_for_common()` and run three times with the
    same seed: without tiling, after `enable_tiling()`, and after `disable_tiling()`.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    torch.manual_seed(0)
    model = self.model_class(**init_dict).to(torch_device)

    # Index [0] below relies on the tuple return format.
    inputs_dict.update({"return_dict": False})

    torch.manual_seed(0)
    output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    torch.manual_seed(0)
    model.enable_tiling()
    output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Use the absolute difference: a signed max() would let large negative deviations pass.
    diff = output_without_tiling.detach().cpu() - output_with_tiling.detach().cpu()
    self.assertLess(
        diff.abs().max().item(),
        0.5,
        "VAE tiling should not affect the inference results",
    )

    torch.manual_seed(0)
    model.disable_tiling()
    output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Compare the tensors elementwise; `a.all() == b.all()` would only compare two booleans.
    self.assertTrue(
        torch.allclose(output_without_tiling.detach().cpu(), output_without_tiling_2.detach().cpu()),
        "Without tiling outputs should match with the outputs when tiling is manually disabled.",
    )

def test_enable_disable_slicing(self):
    """Check that slicing keeps outputs close and that disabling it restores the original output.

    The model is built from `prepare_init_args_and_inputs_for_common()` and run three times with the
    same seed: without slicing, after `enable_slicing()`, and after `disable_slicing()`.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    torch.manual_seed(0)
    model = self.model_class(**init_dict).to(torch_device)

    # Index [0] below relies on the tuple return format.
    inputs_dict.update({"return_dict": False})

    torch.manual_seed(0)
    output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    torch.manual_seed(0)
    model.enable_slicing()
    output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Use the absolute difference: a signed max() would let large negative deviations pass.
    diff = output_without_slicing.detach().cpu() - output_with_slicing.detach().cpu()
    self.assertLess(
        diff.abs().max().item(),
        0.5,
        "VAE slicing should not affect the inference results",
    )

    torch.manual_seed(0)
    model.disable_slicing()
    output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Compare the tensors elementwise; `a.all() == b.all()` would only compare two booleans.
    self.assertTrue(
        torch.allclose(output_without_slicing.detach().cpu(), output_without_slicing_2.detach().cpu()),
        "Without slicing outputs should match with the outputs when slicing is manually disabled.",
    )

def test_gradient_checkpointing_is_applied(self):
expected_set = {
"HunyuanVideoDecoder3D",
Expand Down
66 changes: 2 additions & 64 deletions tests/models/autoencoders/test_models_autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin


enable_full_determinism()


class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, VAETestMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -83,68 +83,6 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_enable_disable_tiling(self):
    """Check that tiling keeps outputs close and that disabling it restores the original output.

    The model is built from `prepare_init_args_and_inputs_for_common()` and run three times with the
    same seed: without tiling, after `enable_tiling()`, and after `disable_tiling()`.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    torch.manual_seed(0)
    model = self.model_class(**init_dict).to(torch_device)

    # Index [0] below relies on the tuple return format.
    inputs_dict.update({"return_dict": False})

    torch.manual_seed(0)
    output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    torch.manual_seed(0)
    model.enable_tiling()
    output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Use the absolute difference: a signed max() would let large negative deviations pass.
    diff = output_without_tiling.detach().cpu() - output_with_tiling.detach().cpu()
    self.assertLess(
        diff.abs().max().item(),
        0.5,
        "VAE tiling should not affect the inference results",
    )

    torch.manual_seed(0)
    model.disable_tiling()
    output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Compare the tensors elementwise; `a.all() == b.all()` would only compare two booleans.
    self.assertTrue(
        torch.allclose(output_without_tiling.detach().cpu(), output_without_tiling_2.detach().cpu()),
        "Without tiling outputs should match with the outputs when tiling is manually disabled.",
    )

def test_enable_disable_slicing(self):
    """Check that slicing keeps outputs close and that disabling it restores the original output.

    The model is built from `prepare_init_args_and_inputs_for_common()` and run three times with the
    same seed: without slicing, after `enable_slicing()`, and after `disable_slicing()`.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    torch.manual_seed(0)
    model = self.model_class(**init_dict).to(torch_device)

    # Index [0] below relies on the tuple return format.
    inputs_dict.update({"return_dict": False})

    torch.manual_seed(0)
    output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    torch.manual_seed(0)
    model.enable_slicing()
    output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Use the absolute difference: a signed max() would let large negative deviations pass.
    diff = output_without_slicing.detach().cpu() - output_with_slicing.detach().cpu()
    self.assertLess(
        diff.abs().max().item(),
        0.5,
        "VAE slicing should not affect the inference results",
    )

    torch.manual_seed(0)
    model.disable_slicing()
    output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Compare the tensors elementwise; `a.all() == b.all()` would only compare two booleans.
    self.assertTrue(
        torch.allclose(output_without_slicing.detach().cpu(), output_without_slicing_2.detach().cpu()),
        "Without slicing outputs should match with the outputs when slicing is manually disabled.",
    )

def test_gradient_checkpointing_is_applied(self):
    """Run the inherited gradient-checkpointing check with this model's expected names."""
    # Names passed to the parent test as the `expected_set` keyword argument.
    super().test_gradient_checkpointing_is_applied(expected_set={"Decoder", "Encoder", "UNetMidBlock2D"})
Expand Down
66 changes: 2 additions & 64 deletions tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin


enable_full_determinism()


class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLCogVideoXTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCogVideoX
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -82,68 +82,6 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_enable_disable_tiling(self):
    """Check that tiling keeps outputs close and that disabling it restores the original output.

    The model is built from `prepare_init_args_and_inputs_for_common()` and run three times with the
    same seed: without tiling, after `enable_tiling()`, and after `disable_tiling()`.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    torch.manual_seed(0)
    model = self.model_class(**init_dict).to(torch_device)

    # Index [0] below relies on the tuple return format.
    inputs_dict.update({"return_dict": False})

    torch.manual_seed(0)
    output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    torch.manual_seed(0)
    model.enable_tiling()
    output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Use the absolute difference: a signed max() would let large negative deviations pass.
    diff = output_without_tiling.detach().cpu() - output_with_tiling.detach().cpu()
    self.assertLess(
        diff.abs().max().item(),
        0.5,
        "VAE tiling should not affect the inference results",
    )

    torch.manual_seed(0)
    model.disable_tiling()
    output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Compare the tensors elementwise; `a.all() == b.all()` would only compare two booleans.
    self.assertTrue(
        torch.allclose(output_without_tiling.detach().cpu(), output_without_tiling_2.detach().cpu()),
        "Without tiling outputs should match with the outputs when tiling is manually disabled.",
    )

def test_enable_disable_slicing(self):
    """Check that slicing keeps outputs close and that disabling it restores the original output.

    The model is built from `prepare_init_args_and_inputs_for_common()` and run three times with the
    same seed: without slicing, after `enable_slicing()`, and after `disable_slicing()`.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    torch.manual_seed(0)
    model = self.model_class(**init_dict).to(torch_device)

    # Index [0] below relies on the tuple return format.
    inputs_dict.update({"return_dict": False})

    torch.manual_seed(0)
    output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    torch.manual_seed(0)
    model.enable_slicing()
    output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Use the absolute difference: a signed max() would let large negative deviations pass.
    diff = output_without_slicing.detach().cpu() - output_with_slicing.detach().cpu()
    self.assertLess(
        diff.abs().max().item(),
        0.5,
        "VAE slicing should not affect the inference results",
    )

    torch.manual_seed(0)
    model.disable_slicing()
    output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Compare the tensors elementwise; `a.all() == b.all()` would only compare two booleans.
    self.assertTrue(
        torch.allclose(output_without_slicing.detach().cpu(), output_without_slicing_2.detach().cpu()),
        "Without slicing outputs should match with the outputs when slicing is manually disabled.",
    )

def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CogVideoXDownBlock3D",
Expand Down
35 changes: 2 additions & 33 deletions tests/models/autoencoders/test_models_autoencoder_ltx_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin


enable_full_determinism()


class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -167,34 +167,3 @@ def test_outputs_equivalence(self):
# Unconditionally skipped through the decorator (its message states the reason); the body is an
# intentional no-op placeholder.
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass

def test_enable_disable_tiling(self):
    """Check that tiling keeps outputs close and that disabling it restores the original output.

    The model is built from `prepare_init_args_and_inputs_for_common()` and run three times with the
    same seed: without tiling, after `enable_tiling()`, and after `disable_tiling()`.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    torch.manual_seed(0)
    model = self.model_class(**init_dict).to(torch_device)

    # Index [0] below relies on the tuple return format.
    inputs_dict.update({"return_dict": False})

    torch.manual_seed(0)
    output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    torch.manual_seed(0)
    model.enable_tiling()
    output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Use the absolute difference: a signed max() would let large negative deviations pass.
    diff = output_without_tiling.detach().cpu() - output_with_tiling.detach().cpu()
    self.assertLess(
        diff.abs().max().item(),
        0.5,
        "VAE tiling should not affect the inference results",
    )

    torch.manual_seed(0)
    model.disable_tiling()
    output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

    # Compare the tensors elementwise; `a.all() == b.all()` would only compare two booleans.
    self.assertTrue(
        torch.allclose(output_without_tiling.detach().cpu(), output_without_tiling_2.detach().cpu()),
        "Without tiling outputs should match with the outputs when tiling is manually disabled.",
    )
35 changes: 2 additions & 33 deletions tests/models/autoencoders/test_models_autoencoder_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin, VAETestMixin


enable_full_determinism()


class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderTinyTests(ModelTesterMixin, VAETestMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderTiny
main_input_name = "sample"
base_precision = 1e-2
Expand Down Expand Up @@ -81,37 +81,6 @@ def prepare_init_args_and_inputs_for_common(self):
def test_enable_disable_tiling(self):
pass

def test_enable_disable_slicing(self):
    """Check that slicing keeps outputs close and that disabling it restores the original output.

    The model is built from `prepare_init_args_and_inputs_for_common()` and run three times after
    reseeding the global RNG: without slicing, after `enable_slicing()`, and after `disable_slicing()`.
    Unlike the sibling autoencoder tests, no `generator` argument is passed to the forward call.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    torch.manual_seed(0)
    model = self.model_class(**init_dict).to(torch_device)

    # Index [0] below relies on the tuple return format.
    inputs_dict.update({"return_dict": False})

    torch.manual_seed(0)
    output_without_slicing = model(**inputs_dict)[0]

    torch.manual_seed(0)
    model.enable_slicing()
    output_with_slicing = model(**inputs_dict)[0]

    # Use the absolute difference: a signed max() would let large negative deviations pass.
    diff = output_without_slicing.detach().cpu() - output_with_slicing.detach().cpu()
    self.assertLess(
        diff.abs().max().item(),
        0.5,
        "VAE slicing should not affect the inference results",
    )

    torch.manual_seed(0)
    model.disable_slicing()
    output_without_slicing_2 = model(**inputs_dict)[0]

    # Compare the tensors elementwise; `a.all() == b.all()` would only compare two booleans.
    self.assertTrue(
        torch.allclose(output_without_slicing.detach().cpu(), output_without_slicing_2.detach().cpu()),
        "Without slicing outputs should match with the outputs when slicing is manually disabled.",
    )

# Unconditionally skipped through the decorator; the body is an intentional no-op placeholder.
@unittest.skip("Test not supported.")
def test_outputs_equivalence(self):
pass
Expand Down
Loading