File tree Expand file tree Collapse file tree 1 file changed +8
-4
lines changed 
src/diffusers/pipelines/flux Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -748,9 +748,11 @@ def __call__(
748748            )
749749
750750            # set control mode 
751+             orig_mode_type  =  type (control_mode )
751752            if  control_mode  is  not None :
752-                 control_mode  =  torch .tensor (control_mode ).to (device , dtype = torch .long )
753-                 control_mode  =  control_mode .reshape ([- 1 , 1 ])
753+                 control_mode  =  torch .tensor (control_mode ).to (device , dtype = torch .long ).view (- 1 ,1 )
754+                 if  orig_mode_type  ==  int :
755+                     control_mode  =  control_mode .repeat (control_image .shape [0 ], 1 )
754756
755757        elif  isinstance (self .controlnet , FluxMultiControlNetModel ):
756758            control_images  =  []
@@ -793,8 +795,10 @@ def __call__(
793795                        control_mode_ .append (- 1 )
794796                    else :
795797                        control_mode_ .append (cmode )
796-             control_mode  =  torch .tensor (control_mode_ ).to (device , dtype = torch .long )
797-             control_mode  =  control_mode .reshape ([- 1 , 1 ])
798+                 control_mode  =  torch .tensor (control_mode_ ).to (device , dtype = torch .long )
799+                 control_mode  =  control_mode .view (- 1 , 1 )
800+             else :
801+                 raise  ValueError ("For multi-controlnet, control_mode should be a list" )
798802
799803        # 4. Prepare latent variables 
800804        num_channels_latents  =  self .transformer .config .in_channels  //  4 
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments