
Commit cb2d3b9

support flf2video!
1 parent 1589e75 commit cb2d3b9

File tree

4 files changed: +368 -37 lines changed


src/diffusers/modular_pipelines/wan/before_denoise.py

Lines changed: 50 additions & 8 deletions
@@ -23,6 +23,7 @@
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
 from .modular_pipeline import WanModularPipeline
+from ...models import WanTransformer3DModel


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -194,6 +195,12 @@ def description(self) -> str:
             "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
             "have a final batch_size of batch_size * num_videos_per_prompt."
         )
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("transformer", WanTransformer3DModel),
+        ]

     @property
     def inputs(self) -> List[InputParam]:
@@ -223,7 +230,7 @@ def intermediate_outputs(self) -> List[str]:
             OutputParam(
                 "dtype",
                 type_hint=torch.dtype,
-                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+                description="Data type of model tensor inputs (determined by `transformer.dtype`)",
             ),
         ]

@@ -242,7 +249,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
         self.check_inputs(components, block_state)

         block_state.batch_size = block_state.prompt_embeds.shape[0]
-        block_state.dtype = block_state.prompt_embeds.dtype
+        block_state.dtype = components.transformer.dtype

         _, seq_len, _ = block_state.prompt_embeds.shape
         block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1)
@@ -269,8 +276,8 @@ class WanInputsDynamicStep(ModularPipelineBlocks):

     def __init__(
         self,
-        image_latent_inputs: List[str] = ["condition_latents"],
-        additional_batch_inputs: List[str] = ["image_embeds"],
+        image_latent_inputs: List[str] = ["first_frame_latents"],
+        additional_batch_inputs: List[str] = [],
     ):
         """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"

@@ -559,15 +566,15 @@ def description(self) -> str:
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("condition_latents", type_hint=Optional[torch.Tensor]),
+            InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
             InputParam("num_frames", type_hint=int),
         ]


     def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)

-        batch_size, _, _, latent_height, latent_width = block_state.condition_latents.shape
+        batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape

         mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
@@ -577,8 +584,43 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
         mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
         mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width)
         mask_lat_size = mask_lat_size.transpose(1, 2)
-        mask_lat_size = mask_lat_size.to(block_state.condition_latents.device)
-        block_state.condition_latents = torch.concat([mask_lat_size, block_state.condition_latents], dim=1)
+        mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
+        block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)

         self.set_block_state(state, block_state)
         return components, state
+
+
+class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
+    model_name = "wan"
+
+    @property
+    def description(self) -> str:
+        return "step that prepares the last frame mask latents and add it to the latent condition"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
+            InputParam("num_frames", type_hint=int),
+        ]
+
+
+    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
+
+        mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
+        mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
+
+        first_frame_mask = mask_lat_size[:, :, 0:1]
+        first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal)
+        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+        mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width)
+        mask_lat_size = mask_lat_size.transpose(1, 2)
+        mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
+        block_state.first_last_frame_latents = torch.concat([mask_lat_size, block_state.first_last_frame_latents], dim=1)
+
+        self.set_block_state(state, block_state)
+        return components, state
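
Note: the new `WanPrepareFirstLastFrameLatentsStep` mirrors the existing first-frame mask logic, except that both the first and the last pixel frame stay unmasked. Below is a minimal, self-contained sketch of just the mask construction, with assumed example shapes (in the pipeline these values come from `block_state` and `components`, not literals):

    import torch

    batch_size, num_frames = 1, 81            # assumed example values
    latent_height, latent_width = 60, 104     # assumed example values
    vae_scale_factor_temporal = 4             # Wan's VAE folds 4 pixel frames into 1 latent frame

    # Pixel-frame mask: 1 for conditioned frames (first and last), 0 elsewhere.
    mask = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
    mask[:, :, 1 : num_frames - 1] = 0

    # Repeat the first frame so the total count is divisible by the temporal
    # scale factor, then fold each group of 4 pixel frames into a channel-like
    # dim so the mask lines up with the latent temporal resolution.
    first = mask[:, :, 0:1].repeat_interleave(vae_scale_factor_temporal, dim=2)
    mask = torch.cat([first, mask[:, :, 1:]], dim=2)
    mask = mask.view(batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
    mask = mask.transpose(1, 2)

    print(mask.shape)  # torch.Size([1, 4, 21, 60, 104]): 4 mask channels x 21 latent frames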

src/diffusers/modular_pipelines/wan/denoise.py

Lines changed: 89 additions & 9 deletions
@@ -54,12 +54,18 @@ def inputs(self) -> List[InputParam]:
                 type_hint=torch.Tensor,
                 description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
             ),
+            InputParam(
+                "dtype",
+                required=True,
+                type_hint=torch.dtype,
+                description="The dtype of the model inputs. Can be generated in input step.",
+            ),
         ]


     @torch.no_grad()
     def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
-        block_state.latent_model_input = block_state.latents
+        block_state.latent_model_input = block_state.latents.to(block_state.dtype)
         return components, block_state


@@ -84,18 +90,67 @@ def inputs(self) -> List[InputParam]:
                 description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
             ),
             InputParam(
-                "condition_latents",
+                "first_frame_latents",
                 required=True,
                 type_hint=torch.Tensor,
-                description="The condition latents to use for the denoising process. Can be generated in prepare_condition_latents step.",
+                description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.",
+            ),
+            InputParam(
+                "dtype",
+                required=True,
+                type_hint=torch.dtype,
+                description="The dtype of the model inputs. Can be generated in input step.",
             ),
         ]

     @torch.no_grad()
     def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
-        block_state.latent_model_input = torch.cat([block_state.latents, block_state.condition_latents], dim=1)
+        block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(block_state.dtype)
+        block_state.image_embeds = block_state.image_embeds.to(block_state.dtype)
         return components, block_state

+
+class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
+    model_name = "wan"
+
+    @property
+    def description(self) -> str:
+        return (
+            "step within the denoising loop that prepares the latent input for the denoiser. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `WanDenoiseLoopWrapper`)"
+        )
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+            ),
+            InputParam(
+                "first_last_frame_latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.",
+            ),
+            InputParam(
+                "dtype",
+                required=True,
+                type_hint=torch.dtype,
+                description="The dtype of the model inputs. Can be generated in input step.",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+        block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_last_frame_latents], dim=1).to(block_state.dtype)
+        block_state.image_embeds = block_state.image_embeds.to(block_state.dtype)
+        return components, block_state
+
+
 class WanLoopDenoiserDynamic(ModularPipelineBlocks):
     model_name = "wan"

@@ -155,9 +210,6 @@ def inputs(self) -> List[Tuple[str, Any]]:
     def __call__(
         self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
     ) -> PipelineState:
-
-        transformer_dtype = components.transformer.dtype
-
         components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

         # The guider splits model inputs into separate batches for conditional/unconditional predictions.
@@ -179,8 +231,8 @@ def __call__(
             # Predict the noise residual
             # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
             guider_state_batch.noise_pred = components.transformer(
-                hidden_states=block_state.latent_model_input.to(transformer_dtype),
-                timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.latent_model_input.dtype),
+                hidden_states=block_state.latent_model_input.to(block_state.dtype),
+                timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
                 attention_kwargs=block_state.attention_kwargs,
                 return_dict=False,
                 **cond_kwargs,
@@ -300,6 +352,7 @@ def description(self) -> str:
             "Denoise step that iteratively denoise the latents. \n"
             "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
             "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+            " - `WanLoopBeforeDenoiser`\n"
             " - `WanLoopDenoiser`\n"
             " - `WanLoopAfterDenoiser`\n"
             "This block supports text-to-video tasks."
@@ -324,7 +377,34 @@ def description(self) -> str:
             "Denoise step that iteratively denoise the latents. \n"
             "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
             "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+            " - `WanImage2VideoLoopBeforeDenoiser`\n"
             " - `WanLoopDenoiser`\n"
             " - `WanLoopAfterDenoiser`\n"
             "This block supports image-to-video tasks."
         )
+
+
+class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper):
+    block_classes = [
+        WanFLF2VLoopBeforeDenoiser,
+        WanLoopDenoiserDynamic(
+            guider_input_fields={
+                "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
+                "encoder_hidden_states_image": "image_embeds",
+            }
+        ),
+        WanLoopAfterDenoiser,
+    ]
+    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoise the latents. \n"
+            "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
+            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+            " - `WanFLF2VLoopBeforeDenoiser`\n"
+            " - `WanLoopDenoiser`\n"
+            " - `WanLoopAfterDenoiser`\n"
+            "This block supports FLF2V tasks."
+        )
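
Note: per loop iteration, `WanFLF2VLoopBeforeDenoiser` simply concatenates the noisy latents with the mask-plus-frame condition latents along the channel dim and casts to the transformer dtype. A hedged sketch of that single operation, assuming Wan 2.1 shapes (16 latent channels plus 4 mask + 16 condition channels, so the transformer sees 36 input channels):

    import torch

    latents = torch.randn(1, 16, 21, 60, 104)                   # noisy video latents (assumed shape)
    first_last_frame_latents = torch.randn(1, 20, 21, 60, 104)  # 4 mask + 16 condition channels (assumed)
    dtype = torch.bfloat16                                      # assumed transformer dtype

    # Same concatenation the block applies to block_state each step:
    latent_model_input = torch.cat([latents, first_last_frame_latents], dim=1).to(dtype)
    print(latent_model_input.shape)  # torch.Size([1, 36, 21, 60, 104])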
