Skip to content

Commit 370f382

Browse files
flux controlnet mode to take into account batch size
1 parent f28a8c2 commit 370f382

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -748,9 +748,11 @@ def __call__(
748748
)
749749

750750
# set control mode
751+
orig_mode_type = type(control_mode)
751752
if control_mode is not None:
752-
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
753-
control_mode = control_mode.reshape([-1, 1])
753+
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long).view(-1,1)
754+
if orig_mode_type == int:
755+
control_mode = control_mode.repeat(control_image.shape[0], 1)
754756

755757
elif isinstance(self.controlnet, FluxMultiControlNetModel):
756758
control_images = []
@@ -793,8 +795,10 @@ def __call__(
793795
control_mode_.append(-1)
794796
else:
795797
control_mode_.append(cmode)
796-
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
797-
control_mode = control_mode.reshape([-1, 1])
798+
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
799+
control_mode = control_mode.view(-1, 1)
800+
else:
801+
raise ValueError("For multi-controlnet, control_mode should be a list")
798802

799803
# 4. Prepare latent variables
800804
num_channels_latents = self.transformer.config.in_channels // 4

0 commit comments

Comments
 (0)