@@ -117,55 +117,42 @@ def retrieve_latents(
 
 
 def encode_vae_image(
-    image: torch.Tensor,
+    video_tensor: torch.Tensor,
     vae: AutoencoderKLWan,
     generator: torch.Generator,
     device: torch.device,
     dtype: torch.dtype,
-    num_frames: int = 81,
-    height: int = 480,
-    width: int = 832,
     latent_channels: int = 16,
 ):
-    if not isinstance(image, torch.Tensor):
-        raise ValueError(f"Expected image to be a tensor, got {type(image)}.")
+    if not isinstance(video_tensor, torch.Tensor):
+        raise ValueError(f"Expected video_tensor to be a tensor, got {type(video_tensor)}.")
 
-    if isinstance(generator, list) and len(generator) != image.shape[0]:
-        raise ValueError(f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {image.shape[0]}.")
+    if isinstance(generator, list) and len(generator) != video_tensor.shape[0]:
+        raise ValueError(f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {video_tensor.shape[0]}.")
 
-    # preprocessed image should be a 4D tensor: batch_size, num_channels, height, width
-    if image.dim() == 4:
-        image = image.unsqueeze(2)
-    elif image.dim() != 5:
-        raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
-
-    video_condition = torch.cat(
-        [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
-    )
-
-    video_condition = video_condition.to(device=device, dtype=dtype)
+    video_tensor = video_tensor.to(device=device, dtype=dtype)
 
     if isinstance(generator, list):
-        latent_condition = [
-            retrieve_latents(vae.encode(video_condition[i : i + 1]), generator=generator[i], sample_mode="argmax") for i in range(image.shape[0])
+        video_latents = [
+            retrieve_latents(vae.encode(video_tensor[i : i + 1]), generator=generator[i], sample_mode="argmax") for i in range(video_tensor.shape[0])
         ]
-        latent_condition = torch.cat(latent_condition, dim=0)
+        video_latents = torch.cat(video_latents, dim=0)
     else:
-        latent_condition = retrieve_latents(vae.encode(video_condition), sample_mode="argmax")
+        video_latents = retrieve_latents(vae.encode(video_tensor), sample_mode="argmax")
 
     latents_mean = (
         torch.tensor(vae.config.latents_mean)
         .view(1, latent_channels, 1, 1, 1)
-        .to(latent_condition.device, latent_condition.dtype)
+        .to(video_latents.device, video_latents.dtype)
     )
     latents_std = (
         1.0 / torch.tensor(vae.config.latents_std)
         .view(1, latent_channels, 1, 1, 1)
-        .to(latent_condition.device, latent_condition.dtype)
+        .to(video_latents.device, video_latents.dtype)
     )
-    latent_condition = (latent_condition - latents_mean) * latents_std
+    video_latents = (video_latents - latents_mean) * latents_std
 
-    return latent_condition
+    return video_latents
 
 
 
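With this change, `encode_vae_image` no longer builds the zero-padded conditioning video itself; that is now the caller's job. A minimal sketch of the new contract follows (the `first_frame` name and the 81-frame / 480x832 sizes are illustrative, mirroring the defaults this hunk removes from the signature, and the call at the end is commented out because it needs a real VAE):

```python
import torch

# Illustrative sizes, matching the defaults removed from encode_vae_image's signature.
batch, channels, num_frames, height, width = 1, 3, 81, 480, 832
first_frame = torch.randn(batch, channels, height, width)  # preprocessed image, 4D

# Caller-side construction: first frame followed by zeros, giving a 5D tensor
# of shape (batch, channels, num_frames, height, width).
video_tensor = torch.cat(
    [
        first_frame.unsqueeze(2),
        first_frame.new_zeros(batch, channels, num_frames - 1, height, width),
    ],
    dim=2,
)

# encode_vae_image is now only responsible for encoding and normalizing:
# video_latents = encode_vae_image(
#     video_tensor=video_tensor, vae=vae, generator=generator,
#     device=device, dtype=dtype, latent_channels=16,
# )
```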
@@ -441,7 +428,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):
 
     @property
     def description(self) -> str:
-        return "Vae Image Encoder step that generate first_frame_latents to guide the video generation"
+        return "Vae Image Encoder step that generate condition_latents to guide the video generation"
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
@@ -463,7 +450,7 @@ def inputs(self) -> List[InputParam]:
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
-            OutputParam("first_frame_latents", type_hint=torch.Tensor, description="The latent condition"),
+            OutputParam("condition_latents", type_hint=torch.Tensor, description="The condition latents"),
         ]
 
     @staticmethod
@@ -497,18 +484,21 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
         image_tensor = components.video_processor.preprocess(
             image, height=height, width=width).to(device=device, dtype=dtype)
 
-        latent_condition = encode_vae_image(
-            image=image_tensor,
+        if image_tensor.dim() == 4:
+            image_tensor = image_tensor.unsqueeze(2)
+
+        video_tensor = torch.cat(
+            [image_tensor, image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width)], dim=2
+        ).to(device=device, dtype=dtype)
+
+        block_state.condition_latents = encode_vae_image(
+            video_tensor=video_tensor,
             vae=components.vae,
             generator=block_state.generator,
             device=device,
             dtype=dtype,
-            num_frames=num_frames,
-            height=height,
-            width=width,
             latent_channels=components.num_channels_latents,
         )
 
-        block_state.first_frame_latents = latent_condition
         self.set_block_state(state, block_state)
         return components, state
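For a sense of what the renamed `condition_latents` output looks like downstream, here is a rough shape check. The 4x temporal / 8x spatial compression factors are assumptions based on the usual Wan VAE configuration, not something this diff asserts:

```python
import torch

# Assumed Wan VAE compression: 4x temporal (first frame kept whole), 8x spatial.
# Illustrative only; check vae_scale_factor_temporal / _spatial on the pipeline.
num_frames, height, width, latent_channels = 81, 480, 832, 16

latent_frames = (num_frames - 1) // 4 + 1  # 81 -> 21
latent_height = height // 8                # 480 -> 60
latent_width = width // 8                  # 832 -> 104

condition_latents = torch.randn(1, latent_channels, latent_frames, latent_height, latent_width)
print(condition_latents.shape)  # torch.Size([1, 16, 21, 60, 104])
```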