Skip to content

Commit 20fd00b

Browse files
authored
[Tests] Add single file tester mixin for Models and remove unittest dependency (#12352)
* update * update * update * update * update
1 parent 76d4e41 commit 20fd00b

23 files changed

+173
-390
lines changed

tests/single_file/single_file_testing_utils.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import gc
12
import tempfile
23
from io import BytesIO
34

@@ -9,7 +10,10 @@
910
from diffusers.models.attention_processor import AttnProcessor
1011

1112
from ..testing_utils import (
13+
backend_empty_cache,
14+
nightly,
1215
numpy_cosine_similarity_distance,
16+
require_torch_accelerator,
1317
torch_device,
1418
)
1519

@@ -47,6 +51,93 @@ def download_diffusers_config(repo_id, tmpdir):
4751
return path
4852

4953

54+
@nightly
@require_torch_accelerator
class SingleFileModelTesterMixin:
    """Mixin with shared tests comparing `from_pretrained` and `from_single_file` loading.

    Subclasses must define:
        model_class: the diffusers model class under test.
        repo_id: Hub repo id used with `from_pretrained`.
        ckpt_path: URL/path of the single-file checkpoint used with `from_single_file`.

    Optional class attributes:
        subfolder: subfolder inside `repo_id` that holds the model.
        torch_dtype: dtype forwarded to both loading paths.
        alternate_keys_ckpt_paths: checkpoints with non-standard key layouts.
    """

    def setup_method(self):
        # Free accelerator memory before each test; the models involved are large.
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def _get_loading_kwargs(self):
        """Build (pretrained_kwargs, single_file_kwargs) from the optional class attributes."""
        pretrained_kwargs = {}
        single_file_kwargs = {}

        # `subfolder` only applies to repo-based loading.
        if getattr(self, "subfolder", None):
            pretrained_kwargs["subfolder"] = self.subfolder

        # `torch_dtype` must match on both sides so configs/parameters are comparable.
        if getattr(self, "torch_dtype", None):
            pretrained_kwargs["torch_dtype"] = self.torch_dtype
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        return pretrained_kwargs, single_file_kwargs

    def _load_model_pair(self):
        """Load the same model through `from_pretrained` and `from_single_file`."""
        pretrained_kwargs, single_file_kwargs = self._get_loading_kwargs()
        model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
        return model, model_single_file

    def test_single_file_model_config(self):
        """Configs produced by both loading paths must agree, ignoring loader bookkeeping keys."""
        model, model_single_file = self._load_model_pair()

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert model.config[param_name] == param_value, (
                f"{param_name} differs between pretrained loading and single file loading"
            )

    def test_single_file_model_parameters(self):
        """State dicts from both loading paths must match in keys, shapes, and values."""
        model, model_single_file = self._load_model_pair()

        state_dict = model.state_dict()
        state_dict_single_file = model_single_file.state_dict()

        assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
            "Model parameters keys differ between pretrained and single file loading"
        )

        for key in state_dict.keys():
            param = state_dict[key]
            param_single_file = state_dict_single_file[key]

            assert param.shape == param_single_file.shape, (
                f"Parameter shape mismatch for {key}: "
                f"pretrained {param.shape} vs single file {param_single_file.shape}"
            )

            assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
                f"Parameter values differ for {key}: "
                f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
            )

    def test_checkpoint_altered_keys_loading(self):
        """Checkpoints with altered key layouts must still load through `from_single_file`."""
        # Skip silently when a subclass declares no alternate checkpoints.
        if not getattr(self, "alternate_keys_ckpt_paths", None):
            return

        for ckpt_path in self.alternate_keys_ckpt_paths:
            backend_empty_cache(torch_device)

            _, single_file_kwargs = self._get_loading_kwargs()
            model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)

            # Only successful loading is under test; release the model immediately.
            del model
            gc.collect()
            backend_empty_cache(torch_device)
139+
140+
50141
class SDSingleFileTesterMixin:
51142
single_file_kwargs = {}
52143

tests/single_file/test_lumina2_transformer.py

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,61 +13,26 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import gc
17-
import unittest
1816

1917
from diffusers import (
2018
Lumina2Transformer2DModel,
2119
)
2220

2321
from ..testing_utils import (
24-
backend_empty_cache,
2522
enable_full_determinism,
26-
require_torch_accelerator,
27-
torch_device,
2823
)
24+
from .single_file_testing_utils import SingleFileModelTesterMixin
2925

3026

3127
enable_full_determinism()
3228

3329

34-
@require_torch_accelerator
35-
class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
30+
class TestLumina2Transformer2DModelSingleFile(SingleFileModelTesterMixin):
    """Single-file loading tests for Lumina2 via `SingleFileModelTesterMixin`.

    All test methods come from the mixin; this class only supplies the
    model class and checkpoint locations the mixin reads.
    """

    model_class = Lumina2Transformer2DModel
    ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
    # Same checkpoint reused here; exercises the altered-keys loading path as a smoke test.
    alternate_keys_ckpt_paths = [
        "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
    ]

    repo_id = "Alpha-VLLM/Lumina-Image-2.0"
    subfolder = "transformer"

tests/single_file/test_model_autoencoder_dc_single_file.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import gc
17-
import unittest
1816

1917
import torch
2018

@@ -23,38 +21,24 @@
2321
)
2422

2523
from ..testing_utils import (
26-
backend_empty_cache,
2724
enable_full_determinism,
2825
load_hf_numpy,
2926
numpy_cosine_similarity_distance,
30-
require_torch_accelerator,
31-
slow,
3227
torch_device,
3328
)
29+
from .single_file_testing_utils import SingleFileModelTesterMixin
3430

3531

3632
enable_full_determinism()
3733

3834

39-
@slow
40-
@require_torch_accelerator
41-
class AutoencoderDCSingleFileTests(unittest.TestCase):
35+
class TestAutoencoderDCSingleFile(SingleFileModelTesterMixin):
4236
model_class = AutoencoderDC
4337
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
4438
repo_id = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"
4539
main_input_name = "sample"
4640
base_precision = 1e-2
4741

48-
def setUp(self):
49-
super().setUp()
50-
gc.collect()
51-
backend_empty_cache(torch_device)
52-
53-
def tearDown(self):
54-
super().tearDown()
55-
gc.collect()
56-
backend_empty_cache(torch_device)
57-
5842
def get_file_format(self, seed, shape):
5943
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
6044

@@ -80,18 +64,6 @@ def test_single_file_inference_same_as_pretrained(self):
8064

8165
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
8266

83-
def test_single_file_components(self):
84-
model = self.model_class.from_pretrained(self.repo_id)
85-
model_single_file = self.model_class.from_single_file(self.ckpt_path)
86-
87-
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
88-
for param_name, param_value in model_single_file.config.items():
89-
if param_name in PARAMS_TO_IGNORE:
90-
continue
91-
assert model.config[param_name] == param_value, (
92-
f"{param_name} differs between pretrained loading and single file loading"
93-
)
94-
9567
def test_single_file_in_type_variant_components(self):
9668
# `in` variant checkpoints require passing in a `config` parameter
9769
# in order to set the scaling factor correctly.

tests/single_file/test_model_controlnet_single_file.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import gc
17-
import unittest
1816

1917
import torch
2018

@@ -23,46 +21,19 @@
2321
)
2422

2523
from ..testing_utils import (
26-
backend_empty_cache,
2724
enable_full_determinism,
28-
require_torch_accelerator,
29-
slow,
30-
torch_device,
3125
)
26+
from .single_file_testing_utils import SingleFileModelTesterMixin
3227

3328

3429
enable_full_determinism()
3530

3631

37-
@slow
38-
@require_torch_accelerator
39-
class ControlNetModelSingleFileTests(unittest.TestCase):
32+
class TestControlNetModelSingleFile(SingleFileModelTesterMixin):
4033
model_class = ControlNetModel
4134
ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
4235
repo_id = "lllyasviel/control_v11p_sd15_canny"
4336

44-
def setUp(self):
45-
super().setUp()
46-
gc.collect()
47-
backend_empty_cache(torch_device)
48-
49-
def tearDown(self):
50-
super().tearDown()
51-
gc.collect()
52-
backend_empty_cache(torch_device)
53-
54-
def test_single_file_components(self):
55-
model = self.model_class.from_pretrained(self.repo_id)
56-
model_single_file = self.model_class.from_single_file(self.ckpt_path)
57-
58-
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
59-
for param_name, param_value in model_single_file.config.items():
60-
if param_name in PARAMS_TO_IGNORE:
61-
continue
62-
assert model.config[param_name] == param_value, (
63-
f"{param_name} differs between single file loading and pretrained loading"
64-
)
65-
6637
def test_single_file_arguments(self):
6738
model_default = self.model_class.from_single_file(self.ckpt_path)
6839

tests/single_file/test_model_flux_transformer_single_file.py

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616
import gc
17-
import unittest
1817

1918
from diffusers import (
2019
FluxTransformer2DModel,
@@ -23,52 +22,21 @@
2322
from ..testing_utils import (
2423
backend_empty_cache,
2524
enable_full_determinism,
26-
require_torch_accelerator,
2725
torch_device,
2826
)
27+
from .single_file_testing_utils import SingleFileModelTesterMixin
2928

3029

3130
enable_full_determinism()
3231

3332

34-
@require_torch_accelerator
35-
class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
33+
class TestFluxTransformer2DModelSingleFile(SingleFileModelTesterMixin):
3634
model_class = FluxTransformer2DModel
3735
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
3836
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
3937

4038
repo_id = "black-forest-labs/FLUX.1-dev"
41-
42-
def setUp(self):
43-
super().setUp()
44-
gc.collect()
45-
backend_empty_cache(torch_device)
46-
47-
def tearDown(self):
48-
super().tearDown()
49-
gc.collect()
50-
backend_empty_cache(torch_device)
51-
52-
def test_single_file_components(self):
53-
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
54-
model_single_file = self.model_class.from_single_file(self.ckpt_path)
55-
56-
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
57-
for param_name, param_value in model_single_file.config.items():
58-
if param_name in PARAMS_TO_IGNORE:
59-
continue
60-
assert model.config[param_name] == param_value, (
61-
f"{param_name} differs between single file loading and pretrained loading"
62-
)
63-
64-
def test_checkpoint_loading(self):
65-
for ckpt_path in self.alternate_keys_ckpt_paths:
66-
backend_empty_cache(torch_device)
67-
model = self.model_class.from_single_file(ckpt_path)
68-
69-
del model
70-
gc.collect()
71-
backend_empty_cache(torch_device)
39+
subfolder = "transformer"
7240

7341
def test_device_map_cuda(self):
7442
backend_empty_cache(torch_device)

tests/single_file/test_model_motion_adapter_single_file.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
1716

1817
from diffusers import (
1918
MotionAdapter,
@@ -27,7 +26,7 @@
2726
enable_full_determinism()
2827

2928

30-
class MotionAdapterSingleFileTests(unittest.TestCase):
29+
class MotionAdapterSingleFileTests:
3130
model_class = MotionAdapter
3231

3332
def test_single_file_components_version_v1_5(self):

0 commit comments

Comments
 (0)