From 6001899f30abd180bd1b74839ace5372f67cc8cf Mon Sep 17 00:00:00 2001 From: YAO Matrix Date: Tue, 10 Jun 2025 23:21:44 +0000 Subject: [PATCH 1/9] xx --- src/diffusers/models/model_loading_utils.py | 1 + src/diffusers/models/modeling_utils.py | 2 ++ tests/models/test_modeling_common.py | 4 ++-- tests/models/unets/test_models_unet_2d_condition.py | 5 +++++ 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index ebc7d79aeb28..3330cab61655 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -107,6 +107,7 @@ def _determine_device_map( device_map_kwargs["max_memory"] = max_memory device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs) + print(f"333333 device_map: {device_map}") if hf_quantizer is not None: hf_quantizer.validate_environment(device_map=device_map) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 55ce0cf79fb9..1b2a57b35eb9 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1201,9 +1201,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) # Now that the model is loaded, we can determine the device_map + print(f"111111 device_map: {device_map}") device_map = _determine_device_map( model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer ) + print(f"222222 device_map: {device_map}") if hf_quantizer is not None: hf_quantizer.validate_environment(device_map=device_map) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 5087bd0094a5..448e4d076b71 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1744,7 +1744,7 @@ def test_push_to_hub_library_name(self): delete_repo(self.repo_id, token=TOKEN) -@require_torch_gpu +@require_torch_accelerator @require_torch_2 @is_torch_compile @slow @@ -1789,7 +1789,7 @@ def test_compile_with_group_offloading(self): model.eval() # TODO: Can test for other group offloading kwargs later if needed. group_offload_kwargs = { - "onload_device": "cuda", + "onload_device": torch_device, "offload_device": "cpu", "offload_type": "block_level", "num_blocks_per_group": 1, diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index ab0dcbc1de11..c542c19b9b79 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1015,6 +1015,8 @@ def test_load_sharded_checkpoint_from_hub_local(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True) + from torchviz import make_dot + make_dot(yhat, params=dict(list(loaded_model.named_parameters()))).render("unet_torchviz", format="png") loaded_model = loaded_model.to(torch_device) new_output = loaded_model(**inputs_dict) @@ -1067,8 +1069,11 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto") + from torchviz import make_dot new_output = loaded_model(**inputs_dict) + make_dot(new_output.sample, params=dict(loaded_model.named_parameters())).render("unet", format="png") + assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) From 8a1d6e5d885f35feed34fc15b314924dba9e10ac Mon Sep 17 00:00:00 2001 From: YAO Matrix Date: Wed, 11 Jun 2025 01:14:14 +0000 Subject: [PATCH 2/9] fix Signed-off-by: YAO Matrix --- src/diffusers/models/modeling_utils.py | 2 -- src/diffusers/models/unets/unet_2d_blocks.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 1b2a57b35eb9..55ce0cf79fb9 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1201,11 +1201,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) # Now that the model is loaded, we can determine the device_map - print(f"111111 device_map: {device_map}") device_map = _determine_device_map( model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer ) - print(f"222222 device_map: {device_map}") if hf_quantizer is not None: hf_quantizer.validate_environment(device_map=device_map) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index e082d524e766..f29680bc4c17 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -2557,7 +2557,8 @@ def forward( b1=self.b1, b2=self.b2, ) - + if hidden_states.device != res_hidden_states.device: + res_hidden_states = res_hidden_states.to(hidden_states.device) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: From 603257bd997e982c3e6515a35bb0a84f702480c1 Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Wed, 11 Jun 2025 09:16:48 +0800 Subject: [PATCH 3/9] Update model_loading_utils.py --- src/diffusers/models/model_loading_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 3330cab61655..ebc7d79aeb28 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -107,7 +107,6 @@ def _determine_device_map( device_map_kwargs["max_memory"] = max_memory device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs) - print(f"333333 device_map: {device_map}") if hf_quantizer is not None: hf_quantizer.validate_environment(device_map=device_map) From 8cdfdd8e775bd94cd23dd5915d0f923a73495b3c Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Wed, 11 Jun 2025 09:17:42 +0800 Subject: [PATCH 4/9] Update test_models_unet_2d_condition.py --- tests/models/unets/test_models_unet_2d_condition.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index c542c19b9b79..0518a5c8e2c1 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1015,8 +1015,6 @@ def test_load_sharded_checkpoint_from_hub_local(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True) - from torchviz import make_dot - make_dot(yhat, params=dict(list(loaded_model.named_parameters()))).render("unet_torchviz", format="png") loaded_model = loaded_model.to(torch_device) new_output = loaded_model(**inputs_dict) From 45e29bdff51db8e94c6200eb88ac982b5df2a86a Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Wed, 11 Jun 2025 09:18:16 +0800 Subject: [PATCH 5/9] Update test_models_unet_2d_condition.py --- tests/models/unets/test_models_unet_2d_condition.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 0518a5c8e2c1..ab0dcbc1de11 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1067,11 +1067,8 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto") - from torchviz import make_dot new_output = loaded_model(**inputs_dict) - make_dot(new_output.sample, params=dict(loaded_model.named_parameters())).render("unet", format="png") - assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) From fae7c7064b6cd7b14d1bd6c83f4a721e1a54816c Mon Sep 17 00:00:00 2001 From: YAO Matrix Date: Wed, 11 Jun 2025 13:29:46 +0000 Subject: [PATCH 6/9] fix style Signed-off-by: YAO Matrix --- tests/models/test_modeling_common.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 448e4d076b71..aa6db128d5e8 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -70,7 +70,6 @@ require_torch_2, require_torch_accelerator, require_torch_accelerator_with_training, - require_torch_gpu, require_torch_multi_accelerator, run_test_in_subprocess, slow, From 1ba8a88dc14f8c3528e111d23aeefba181e08d3c Mon Sep 17 00:00:00 2001 From: Matrix Yao Date: Fri, 18 Jul 2025 15:01:09 +0000 Subject: [PATCH 7/9] fix comments Signed-off-by: Matrix Yao --- src/diffusers/models/unets/unet_2d_blocks.py | 2 -- src/diffusers/models/unets/unet_2d_condition.py | 2 +- tests/models/test_modeling_common.py | 8 ++++---- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index d7778b04c1c8..edf171d78e87 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -2557,8 +2557,6 @@ def forward( b1=self.b1, b2=self.b2, ) - if hidden_states.device != res_hidden_states.device: - res_hidden_states = res_hidden_states.to(hidden_states.device) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 0f789d3961fc..736deb28c376 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -165,7 +165,7 @@ class conditioning with `class_embed_type` equal to `None`. """ _supports_gradient_checkpointing = True - _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"] _skip_layerwise_casting_patterns = ["norm"] _repeated_blocks = ["BasicTransformerBlock"] diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 13b799e9aee6..8309700ce106 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1828,8 +1828,8 @@ def test_wrong_device_map_raises_error(self, device_map, msg_substring): assert msg_substring in str(err_ctx.exception) - @parameterized.expand([0, "cuda", torch.device("cuda")]) - @require_torch_gpu + @parameterized.expand([0, torch_device, torch.device(torch_device)]) + @require_torch_accelerator def test_passing_non_dict_device_map_works(self, device_map): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).eval() @@ -1838,8 +1838,8 @@ def test_passing_non_dict_device_map_works(self, device_map): loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map) _ = loaded_model(**inputs_dict) - @parameterized.expand([("", "cuda"), ("", torch.device("cuda"))]) - @require_torch_gpu + @parameterized.expand([("", torch_device), ("", torch.device(torch_device))]) + @require_torch_accelerator def test_passing_dict_device_map_works(self, name, device): # There are other valid dict-based `device_map` values too. It's best to refer to # the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap. From fd9fa9912172dc3dc2409cf7e6418e76c5bb4469 Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Fri, 18 Jul 2025 08:07:38 -0700 Subject: [PATCH 8/9] Update unet_2d_blocks.py --- src/diffusers/models/unets/unet_2d_blocks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index edf171d78e87..94a9245e567c 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -2557,6 +2557,7 @@ def forward( b1=self.b1, b2=self.b2, ) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: From 9948c9c12ea8b20295c52ba5f9312fb109a3fd37 Mon Sep 17 00:00:00 2001 From: Matrix Yao Date: Sat, 19 Jul 2025 14:02:05 +0000 Subject: [PATCH 9/9] update Signed-off-by: Matrix Yao --- tests/models/unets/test_models_unet_2d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index abf44aa7447b..123dff16f8b0 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -358,7 +358,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test model_class = UNet2DConditionModel main_input_name = "sample" # We override the items here because the unet under consideration is small. - model_split_percents = [0.5, 0.3, 0.4] + model_split_percents = [0.5, 0.34, 0.4] @property def dummy_input(self):