Merged

Commits
31 commits
6001899  xx  (yao-matrix, Jun 10, 2025)
8a1d6e5  fix  (yao-matrix, Jun 11, 2025)
603257b  Update model_loading_utils.py  (yao-matrix, Jun 11, 2025)
8cdfdd8  Update test_models_unet_2d_condition.py  (yao-matrix, Jun 11, 2025)
45e29bd  Update test_models_unet_2d_condition.py  (yao-matrix, Jun 11, 2025)
2a7c17d  Merge branch 'main' into xpu  (yao-matrix, Jun 11, 2025)
fae7c70  fix style  (yao-matrix, Jun 11, 2025)
80fdbfc  Merge branch 'main' into xpu  (yao-matrix, Jun 11, 2025)
97a37a1  Merge branch 'main' into xpu  (yao-matrix, Jun 11, 2025)
5f0c794  Merge branch 'main' into xpu  (yao-matrix, Jun 12, 2025)
8cd06b3  Merge branch 'main' into xpu  (yao-matrix, Jun 13, 2025)
02a6a35  Merge branch 'main' into xpu  (yao-matrix, Jun 17, 2025)
ed1a788  Merge branch 'main' into xpu  (yao-matrix, Jun 18, 2025)
220ce94  Merge branch 'main' into xpu  (yao-matrix, Jun 18, 2025)
e59cb0c  Merge branch 'main' into xpu  (yao-matrix, Jun 24, 2025)
fd618b5  Merge branch 'main' into xpu  (yao-matrix, Jun 24, 2025)
c340f9e  Merge branch 'main' into xpu  (yao-matrix, Jun 24, 2025)
e674ce7  Merge branch 'main' into xpu  (yao-matrix, Jun 27, 2025)
d389758  Merge branch 'main' into xpu  (yao-matrix, Jun 30, 2025)
7e8ae22  Merge branch 'main' into xpu  (yao-matrix, Jul 1, 2025)
c43bb19  Merge branch 'main' into xpu  (yao-matrix, Jul 3, 2025)
49ac5d4  Merge branch 'main' into xpu  (yao-matrix, Jul 7, 2025)
b7148d6  Merge branch 'main' into xpu  (yao-matrix, Jul 8, 2025)
bda0afd  Merge branch 'main' into xpu  (yao-matrix, Jul 18, 2025)
1ba8a88  fix comments  (yao-matrix, Jul 18, 2025)
fd9fa99  Update unet_2d_blocks.py  (yao-matrix, Jul 18, 2025)
692f0bd  Merge branch 'main' into xpu  (sayakpaul, Jul 18, 2025)
ab5f55c  Merge branch 'main' into xpu  (sayakpaul, Jul 18, 2025)
9b41f3a  Merge branch 'main' into xpu  (yao-matrix, Jul 19, 2025)
9948c9c  update  (yao-matrix, Jul 19, 2025)
38ff983  Merge branch 'main' into xpu  (yao-matrix, Jul 21, 2025)
3 changes: 2 additions & 1 deletion src/diffusers/models/unets/unet_2d_blocks.py
@@ -2557,7 +2557,8 @@ def forward(
b1=self.b1,
b2=self.b2,
)

if hidden_states.device != res_hidden_states.device:
res_hidden_states = res_hidden_states.to(hidden_states.device)
Member
We shouldn't need this, since both hidden_states and res_hidden_states should already be on the same device, no? The pre-forward hook added by accelerate should move all the inputs to the same device.
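(For context, a minimal sketch of what such a pre-forward hook does, written with a plain PyTorch forward pre-hook rather than accelerate's actual implementation; accelerate's real hook also handles keyword arguments and nested containers.)

```python
import torch
import torch.nn as nn


def make_align_inputs_hook(execution_device: torch.device):
    # Illustrative stand-in for accelerate's pre-forward device alignment:
    # every tensor positional argument is moved onto the module's execution
    # device just before forward() runs.
    def hook(module: nn.Module, args):
        return tuple(
            a.to(execution_device) if isinstance(a, torch.Tensor) else a
            for a in args
        )

    return hook


block = nn.Linear(8, 8)  # lives on CPU in this sketch
block.register_forward_pre_hook(make_align_inputs_hook(torch.device("cpu")))

x = torch.randn(2, 8)  # wherever the caller left this tensor
y = block(x)           # inputs are realigned before forward executes
```

Note that such a hook only intercepts inputs entering a module's forward; operations performed between submodules are never seen by it, which is what the discussion below turns on.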

Contributor Author
@SunMarc, I suppose this is a corner case? torch.cat is a weight-less function, so it seems it cannot be covered by the pre-forward hook set by accelerate...

Member
I mean, since hidden_states and res_hidden_states_tuple are in the forward definition, they should be moved to the same device by the pre-forward hook added by accelerate.

Contributor Author (yao-matrix, Jun 18, 2025)
@SunMarc We ran into a corner case here. Since we have 8 cards, the device_map determined by https://github.com/huggingface/diffusers/blob/1bc6f3dc0f21779480db70a4928d14282c0198ed/src/diffusers/models/model_loading_utils.py#L64C5-L64C26 is:

device_map: OrderedDict([('conv_in', 0), ('time_proj', 0), ('time_embedding', 0), ('down_blocks.0', 0), ('down_blocks.1.resnets.0', 1), ('up_blocks.0.resnets.0', 1), ('up_blocks.0.resnets.1', 2), ('up_blocks.0.upsamplers', 2), ('up_blocks.1', 3), ('mid_block.attentions', 3), ('conv_norm_out', 4), ('conv_act', 4), ('conv_out', 4), ('mid_block.resnets', 4)])

We can see that UpBlock is not an atomic module here: its submodules are assigned to different devices (up_blocks.0.resnets.0 and up_blocks.0.resnets.1), so the pre-forward hook on UpBlock does not help in this case. And since torch.cat is not pre-hooked (and cannot be, since it is a function rather than a module), the issue happens.

If there were no torch.cat between the sub-blocks of UpBlock, everything would be fine.
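(A minimal sketch of this failure mode and of the guard added in the diff above; it assumes at least two accelerator devices are available, e.g. cuda:0/cuda:1 or xpu:0/xpu:1, and the device strings are placeholders.)

```python
import torch

# Two tensors that ended up on different devices, as in the device_map above.
hidden_states = torch.randn(1, 4, 8, 8, device="cuda:0")
res_hidden_states = torch.randn(1, 4, 8, 8, device="cuda:1")

try:
    torch.cat([hidden_states, res_hidden_states], dim=1)
except RuntimeError as err:
    # "Expected all tensors to be on the same device, but found at least two devices ..."
    print(err)

# The guard added in this diff realigns the residual before concatenation.
if hidden_states.device != res_hidden_states.device:
    res_hidden_states = res_hidden_states.to(hidden_states.device)

out = torch.cat([hidden_states, res_hidden_states], dim=1)
print(out.device)  # cuda:0
```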

Contributor Author
@SunMarc, we need your input on how to proceed with this corner case, thanks.

Contributor Author

@SunMarc We can see a similar case in the transformers unit test pytest -rA tests/models/chameleon/test_modeling_chameleon.py::ChameleonVision2SeqModelTest::test_model_parallel_beam_search with 2 cards; the error is "RuntimeError: Expected all tensors to be on the same device, but found at least two devices" in src/transformers/models/chameleon/modeling_chameleon.py. The reason is that even though residual starts on the same device as hidden_states, after both pass through some operators as input and output they end up on different devices, and when they reach +, which is not an nn.Module (so accelerate cannot pre-hook it), the error happens. Do you have some insights on such issues?

Contributor Author
@SunMarc, could you share your insights on the issue I mentioned above? Thanks very much.

Member
Sorry for the long wait @yao-matrix. If you add UpBlock2D to _no_split_modules of UNet2DConditionModel, the test should pass:
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"]
I've tested it on my end and it works.

As for ChameleonVision2SeqModelTest, we probably need to update its _no_split_modules as well.
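(For reference, a sketch of applying the suggested list. Normally this line lives in the UNet2DConditionModel class definition in unet_2d_condition.py; setting it on the class here is only to illustrate the effect on the auto device map.)

```python
from diffusers import UNet2DConditionModel

# Modules listed here are treated as atomic when accelerate builds the auto
# device map, so an UpBlock2D's resnets can no longer land on different devices
# and the torch.cat inside its forward always sees tensors on a single device.
UNet2DConditionModel._no_split_modules = [
    "BasicTransformerBlock",
    "ResnetBlock2D",
    "CrossAttnUpBlock2D",
    "UpBlock2D",
]
print(UNet2DConditionModel._no_split_modules)
```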

hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

if torch.is_grad_enabled() and self.gradient_checkpointing:
5 changes: 2 additions & 3 deletions tests/models/test_modeling_common.py
@@ -74,7 +74,6 @@
require_torch_2,
require_torch_accelerator,
require_torch_accelerator_with_training,
require_torch_gpu,
require_torch_multi_accelerator,
require_torch_version_greater,
run_test_in_subprocess,
@@ -1902,7 +1901,7 @@ def test_push_to_hub_library_name(self):
delete_repo(self.repo_id, token=TOKEN)


@require_torch_gpu
@require_torch_accelerator
@require_torch_2
Member
This change is unrelated to this PR. Going forward, please prefer not to include unrelated changes in a particular PR.

Contributor Author
Sure, I will follow the rule going forward.

@is_torch_compile
@slow
@@ -1970,7 +1969,7 @@ def test_compile_with_group_offloading(self):
model.eval()
# TODO: Can test for other group offloading kwargs later if needed.
group_offload_kwargs = {
"onload_device": "cuda",
"onload_device": torch_device,
"offload_device": "cpu",
"offload_type": "block_level",
"num_blocks_per_group": 1,
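(The test presumably feeds these kwargs to the model's group-offloading API; below is a minimal sketch assuming the ModelMixin.enable_group_offload interface, with a default-configured UNet2DConditionModel standing in for whatever dummy model the test actually builds.)

```python
from diffusers import UNet2DConditionModel
from diffusers.utils.testing_utils import torch_device  # "cuda", "xpu", ... depending on backend

# Stand-in model; the real test constructs its own small dummy config.
model = UNet2DConditionModel()
model.eval()

group_offload_kwargs = {
    "onload_device": torch_device,  # backend-agnostic, instead of a hard-coded "cuda"
    "offload_device": "cpu",
    "offload_type": "block_level",
    "num_blocks_per_group": 1,
}
# Assumed consumer of the kwargs above: blocks are offloaded to CPU and moved
# back onto the accelerator one group at a time during the forward pass.
model.enable_group_offload(**group_offload_kwargs)
```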