Skip to content

Commit f8b6bb0

Browse files
remove negative index bug in multicontrolnet, also support both regular and union controlnets for multi
1 parent 370f382 commit f8b6bb0

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

src/diffusers/models/controlnet_flux.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,13 +479,20 @@ def forward(
479479
# Regular Multi-ControlNets
480480
# load all ControlNets into memories
481481
else:
482+
482483
for i, (image, mode, scale, controlnet) in enumerate(
483484
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
484485
):
486+
if mode is not None:
487+
mode_ = torch.LongTensor([[mode]]).to(device=image.device)
488+
else:
489+
mode_ = mode
490+
if controlnet.union and mode_ is None:
491+
raise ValueError(f"controlnet #{i} in `self.nets` is a union model, controlnet_mode[{i}] cannot be `None`")
485492
block_samples, single_block_samples = controlnet(
486493
hidden_states=hidden_states,
487494
controlnet_cond=image,
488-
controlnet_mode=mode[:, None],
495+
controlnet_mode=mode_,
489496
conditioning_scale=scale,
490497
timestep=timestep,
491498
guidance=guidance,

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,7 @@ def __call__(
749749

750750
# set control mode
751751
orig_mode_type = type(control_mode)
752+
752753
if control_mode is not None:
753754
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long).view(-1,1)
754755
if orig_mode_type == int:
@@ -788,16 +789,7 @@ def __call__(
788789
control_image = control_images
789790

790791
# set control mode
791-
control_mode_ = []
792-
if isinstance(control_mode, list):
793-
for cmode in control_mode:
794-
if cmode is None:
795-
control_mode_.append(-1)
796-
else:
797-
control_mode_.append(cmode)
798-
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
799-
control_mode = control_mode.view(-1, 1)
800-
else:
792+
if not isinstance(control_mode, list):
801793
raise ValueError("For multi-controlnet, control_mode should be a list")
802794

803795
# 4. Prepare latent variables

0 commit comments

Comments
 (0)