@@ -108,31 +108,16 @@ def prompt_clean(text):
     return text
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
-    encoder_output: torch.Tensor,
-    latents_mean: torch.Tensor,
-    latents_std: torch.Tensor,
-    generator: Optional[torch.Generator] = None,
-    sample_mode: str = "sample",
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
-        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
-        encoder_output.latent_dist.logvar = torch.clamp(
-            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
-        )
-        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
-        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.sample(generator)
     elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
-        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
-        encoder_output.latent_dist.logvar = torch.clamp(
-            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
-        )
-        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
-        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
-        return (encoder_output.latents - latents_mean) * latents_std
+        return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
 
@@ -412,13 +397,15 @@ def prepare_latents(
 
         if isinstance(generator, list):
             latent_condition = [
-                retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
+                retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
             ]
             latent_condition = torch.cat(latent_condition)
         else:
-            latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
+            latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
             latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
 
+        latent_condition = (latent_condition - latents_mean) * latents_std
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
         first_frame_mask = mask_lat_size[:, :, 0:1]
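
The diff leaves retrieve_latents as a plain copy of the stable-diffusion img2img helper and applies the (latent_condition - latents_mean) * latents_std normalization once inside prepare_latents. The sketch below is a minimal, self-contained check of why that ordering is interchangeable for the sample_mode="argmax" path used here: mode() returns the distribution mean, so normalizing the returned latents gives the same tensor as the removed in-place normalization of latent_dist.mean. The FakeLatentDist and FakeEncoderOutput classes and the tensor shapes are toy stand-ins for illustration, not diffusers classes.

# Minimal sketch (assumptions, not part of the diff): toy stand-ins for the VAE
# encoder output are used instead of any diffusers API.
import torch
from typing import Optional


class FakeLatentDist:
    """Toy diagonal-Gaussian stand-in exposing only what retrieve_latents touches."""

    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        self.mean = mean
        self.logvar = logvar

    def mode(self) -> torch.Tensor:
        # The mode of a diagonal Gaussian equals its mean.
        return self.mean

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        std = torch.exp(0.5 * self.logvar)
        return self.mean + std * torch.randn(self.mean.shape, generator=generator)


class FakeEncoderOutput:
    def __init__(self, latent_dist: FakeLatentDist):
        self.latent_dist = latent_dist


def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


mean = torch.randn(1, 16, 4, 8, 8)          # toy latent mean; shape is illustrative only
logvar = torch.zeros_like(mean)
latents_mean = torch.randn(1, 16, 1, 1, 1)  # per-channel statistics, broadcast over the rest
latents_std = torch.rand(1, 16, 1, 1, 1) + 0.5

encoder_output = FakeEncoderOutput(FakeLatentDist(mean, logvar))

# New flow: take the raw latents, then normalize once afterwards.
z = retrieve_latents(encoder_output, sample_mode="argmax")
normalized_after = (z - latents_mean) * latents_std

# Old flow (argmax branch): latent_dist.mean was normalized in place before mode().
normalized_before = (mean - latents_mean) * latents_std

assert torch.allclose(normalized_after, normalized_before)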