@@ -23,6 +23,7 @@
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
 from .modular_pipeline import WanModularPipeline
+from ...models import WanTransformer3DModel


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -194,6 +195,12 @@ def description(self) -> str:
             "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
             "have a final batch_size of batch_size * num_videos_per_prompt."
         )
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("transformer", WanTransformer3DModel),
+        ]

     @property
     def inputs(self) -> List[InputParam]:
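Reviewer note, not part of the diff: declaring the transformer in `expected_components` is what makes it reachable as `components.transformer` inside `__call__` further down. A minimal sketch of the pattern as this file appears to use it (the block name and body are hypothetical; the `ComponentSpec` semantics are assumed from this diff):

```python
# Hypothetical block, reusing the imports at the top of this file.
class ExampleDtypeStep(ModularPipelineBlocks):
    @property
    def expected_components(self) -> List[ComponentSpec]:
        # Ask the pipeline to attach the Wan transformer to this block.
        return [ComponentSpec("transformer", WanTransformer3DModel)]

    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        # The declared component is now available on `components`.
        block_state.dtype = components.transformer.dtype
        self.set_block_state(state, block_state)
        return components, state
```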
@@ -223,7 +230,7 @@ def intermediate_outputs(self) -> List[str]:
             OutputParam(
                 "dtype",
                 type_hint=torch.dtype,
-                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+                description="Data type of model tensor inputs (determined by `transformer.dtype`)",
             ),
         ]

@@ -242,7 +249,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
         self.check_inputs(components, block_state)

         block_state.batch_size = block_state.prompt_embeds.shape[0]
-        block_state.dtype = block_state.prompt_embeds.dtype
+        block_state.dtype = components.transformer.dtype

         _, seq_len, _ = block_state.prompt_embeds.shape
         block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1)
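Aside on the change above (rationale assumed, not stated in the commit): sourcing `dtype` from the transformer instead of from `prompt_embeds` keeps latents aligned with the model that consumes them, even when the text encoder runs in a different precision. A self-contained illustration with made-up shapes:

```python
import torch

transformer_dtype = torch.bfloat16         # e.g. transformer loaded in bf16
prompt_embeds = torch.randn(1, 512, 4096)  # text encoder output, float32 by default

# Downstream blocks allocate latents in block_state.dtype; keying it off the
# transformer avoids a precision mismatch the old code would have inherited.
latents = torch.randn(1, 16, 21, 60, 104, dtype=transformer_dtype)
assert latents.dtype == transformer_dtype
assert prompt_embeds.dtype != transformer_dtype
```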
@@ -269,8 +276,8 @@ class WanInputsDynamicStep(ModularPipelineBlocks):

     def __init__(
         self,
-        image_latent_inputs: List[str] = ["condition_latents"],
-        additional_batch_inputs: List[str] = ["image_embeds"],
+        image_latent_inputs: List[str] = ["first_frame_latents"],
+        additional_batch_inputs: List[str] = [],
     ):
         """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"

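For reference, the block stays configurable; a sketch of instantiating it under the new and the removed defaults (constructor arguments taken from this diff, usage context assumed):

```python
# New defaults: only first_frame_latents is treated as an image-latent input,
# and no additional inputs are batch-expanded.
i2v_inputs = WanInputsDynamicStep()

# The old defaults can still be requested explicitly if a workflow needs them:
legacy_inputs = WanInputsDynamicStep(
    image_latent_inputs=["condition_latents"],
    additional_batch_inputs=["image_embeds"],
)
```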
@@ -559,15 +566,15 @@ def description(self) -> str:
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("condition_latents", type_hint=Optional[torch.Tensor]),
+            InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
             InputParam("num_frames", type_hint=int),
         ]


     def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)

-        batch_size, _, _, latent_height, latent_width = block_state.condition_latents.shape
+        batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape

         mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
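Shape walk-through for the mask logic above and in the next hunk (numbers are hypothetical; this assumes `num_frames` counts pixel-space frames and `vae_scale_factor_temporal = 4`):

```python
# Standalone reproduction of the mask reshaping with concrete sizes.
import torch

batch_size, num_frames, scale, h, w = 1, 81, 4, 60, 104

mask = torch.ones(batch_size, 1, num_frames, h, w)
mask[:, :, list(range(1, num_frames))] = 0              # only frame 0 is conditioned

first = torch.repeat_interleave(mask[:, :, 0:1], dim=2, repeats=scale)
mask = torch.concat([first, mask[:, :, 1:, :]], dim=2)  # 4 + 80 = 84 frames
mask = mask.view(batch_size, -1, scale, h, w).transpose(1, 2)

print(mask.shape)  # torch.Size([1, 4, 21, 60, 104]): 4 mask channels per latent frame
```

The mask ends up with `vae_scale_factor_temporal` channels per latent frame, ready to be concatenated with the condition latents along `dim=1`.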
@@ -577,8 +584,43 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
         mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
         mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width)
         mask_lat_size = mask_lat_size.transpose(1, 2)
-        mask_lat_size = mask_lat_size.to(block_state.condition_latents.device)
-        block_state.condition_latents = torch.concat([mask_lat_size, block_state.condition_latents], dim=1)
+        mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
+        block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)

         self.set_block_state(state, block_state)
         return components, state
+
+
+class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
+    model_name = "wan"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the first/last frame mask latents and adds them to the latent condition"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
+            InputParam("num_frames", type_hint=int),
+        ]
+
+
+    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
+
+        mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
+        mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
+
+        first_frame_mask = mask_lat_size[:, :, 0:1]
+        first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal)
+        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+        mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width)
+        mask_lat_size = mask_lat_size.transpose(1, 2)
+        mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
+        block_state.first_last_frame_latents = torch.concat([mask_lat_size, block_state.first_last_frame_latents], dim=1)
+
+        self.set_block_state(state, block_state)
+        return components, state
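The new step is identical to the first-frame variant except that the zeroed range stops one frame early, so the last frame also stays unmasked. A tiny check (frame count assumed):

```python
# Minimal demo of the first/last mask pattern.
import torch

num_frames = 81
mask = torch.ones(1, 1, num_frames, 8, 8)
mask[:, :, list(range(1, num_frames - 1))] = 0

print(mask[0, 0, 0].max().item())     # 1.0 -> first frame kept
print(mask[0, 0, -1].max().item())    # 1.0 -> last frame kept
print(mask[0, 0, 1:-1].max().item())  # 0.0 -> everything in between masked
```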