
Commit 1589e75

refactor vae encoder block
1 parent 846b5f9 commit 1589e75

4 files changed, +36 −46 lines changed

src/diffusers/modular_pipelines/wan/before_denoise.py

Lines changed: 5 additions & 5 deletions

@@ -269,7 +269,7 @@ class WanInputsDynamicStep(ModularPipelineBlocks):
 
     def __init__(
         self,
-        image_latent_inputs: List[str] = ["first_frame_latents"],
+        image_latent_inputs: List[str] = ["condition_latents"],
         additional_batch_inputs: List[str] = ["image_embeds"],
     ):
         """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
@@ -559,15 +559,15 @@ def description(self) -> str:
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
+            InputParam("condition_latents", type_hint=Optional[torch.Tensor]),
             InputParam("num_frames", type_hint=int),
         ]
 
 
     def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
 
-        batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
+        batch_size, _, _, latent_height, latent_width = block_state.condition_latents.shape
 
         mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
@@ -577,8 +577,8 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
         mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
         mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width)
         mask_lat_size = mask_lat_size.transpose(1, 2)
-        mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
-        block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
+        mask_lat_size = mask_lat_size.to(block_state.condition_latents.device)
+        block_state.condition_latents = torch.concat([mask_lat_size, block_state.condition_latents], dim=1)
 
         self.set_block_state(state, block_state)
         return components, state
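Note: aside from the rename to condition_latents, the mask handling in this block is unchanged: a single-channel frame mask keeps only the first frame, the first-frame mask is repeated so the temporal axis folds evenly into the VAE's temporal compression, and the result is concatenated with the condition latents along the channel dimension. Below is a minimal standalone sketch of that shape logic in plain torch, using illustrative sizes (batch size, frame count, channel count, and vae_scale_factor_temporal are example values, not pipeline defaults):

import torch

# Illustrative sizes only.
batch_size, num_frames = 1, 81
latent_height, latent_width = 60, 104
vae_scale_factor_temporal = 4
num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
condition_latents = torch.randn(batch_size, 16, num_latent_frames, latent_height, latent_width)

# Mask over pixel-space frames: 1 for the conditioning (first) frame, 0 elsewhere.
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask_lat_size[:, :, 1:] = 0

# Repeat the first-frame mask so the temporal length folds evenly into the
# temporal compression factor, then move the folded frames into a channel-like axis.
first_frame_mask = mask_lat_size[:, :, :1].repeat(1, 1, vae_scale_factor_temporal, 1, 1)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:]], dim=2)
mask_lat_size = mask_lat_size.view(batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
mask_lat_size = mask_lat_size.transpose(1, 2)

# Channel-wise concat with the condition latents, as in the block above.
condition_latents = torch.concat([mask_lat_size, condition_latents], dim=1)
print(condition_latents.shape)  # torch.Size([1, 20, 21, 60, 104]) with these example sizes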

src/diffusers/modular_pipelines/wan/denoise.py

Lines changed: 3 additions & 3 deletions

@@ -84,16 +84,16 @@ def inputs(self) -> List[InputParam]:
                 description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
             ),
             InputParam(
-                "first_frame_latents",
+                "condition_latents",
                 required=True,
                 type_hint=torch.Tensor,
-                description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.",
+                description="The condition latents to use for the denoising process. Can be generated in prepare_condition_latents step.",
             ),
         ]
 
     @torch.no_grad()
     def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
-        block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1)
+        block_state.latent_model_input = torch.cat([block_state.latents, block_state.condition_latents], dim=1)
         return components, block_state
 
 class WanLoopDenoiserDynamic(ModularPipelineBlocks):
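Note: the only functional line here is the channel-wise concatenation of the noisy latents with the renamed condition_latents before each denoiser call. A small hedged illustration of the resulting shape, with example channel counts only (the channel split the transformer actually expects depends on the checkpoint):

import torch

# Illustrative shapes: 16 noise channels plus 20 condition channels (mask + latents).
latents = torch.randn(1, 16, 21, 60, 104)
condition_latents = torch.randn(1, 20, 21, 60, 104)

latent_model_input = torch.cat([latents, condition_latents], dim=1)
print(latent_model_input.shape)  # torch.Size([1, 36, 21, 60, 104])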

src/diffusers/modular_pipelines/wan/encoders.py

Lines changed: 25 additions & 35 deletions

@@ -117,55 +117,42 @@ def retrieve_latents(
 
 
 def encode_vae_image(
-    image: torch.Tensor,
+    video_tensor: torch.Tensor,
     vae: AutoencoderKLWan,
     generator: torch.Generator,
     device: torch.device,
     dtype: torch.dtype,
-    num_frames: int = 81,
-    height: int = 480,
-    width: int = 832,
     latent_channels: int = 16,
 ):
-    if not isinstance(image, torch.Tensor):
-        raise ValueError(f"Expected image to be a tensor, got {type(image)}.")
+    if not isinstance(video_tensor, torch.Tensor):
+        raise ValueError(f"Expected video_tensor to be a tensor, got {type(video_tensor)}.")
 
-    if isinstance(generator, list) and len(generator) != image.shape[0]:
-        raise ValueError(f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {image.shape[0]}.")
+    if isinstance(generator, list) and len(generator) != video_tensor.shape[0]:
+        raise ValueError(f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {video_tensor.shape[0]}.")
 
-    # preprocessed image should be a 4D tensor: batch_size, num_channels, height, width
-    if image.dim() == 4:
-        image = image.unsqueeze(2)
-    elif image.dim() != 5:
-        raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
-
-    video_condition = torch.cat(
-        [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
-    )
-
-    video_condition = video_condition.to(device=device, dtype=dtype)
+    video_tensor = video_tensor.to(device=device, dtype=dtype)
 
     if isinstance(generator, list):
-        latent_condition = [
-            retrieve_latents(vae.encode(video_condition[i : i + 1]), generator=generator[i], sample_mode="argmax") for i in range(image.shape[0])
+        video_latents = [
+            retrieve_latents(vae.encode(video_tensor[i : i + 1]), generator=generator[i], sample_mode="argmax") for i in range(video_tensor.shape[0])
         ]
-        latent_condition = torch.cat(latent_condition, dim=0)
+        video_latents = torch.cat(video_latents, dim=0)
     else:
-        latent_condition = retrieve_latents(vae.encode(video_condition), sample_mode="argmax")
+        video_latents = retrieve_latents(vae.encode(video_tensor), sample_mode="argmax")
 
     latents_mean = (
         torch.tensor(vae.config.latents_mean)
         .view(1, latent_channels, 1, 1, 1)
-        .to(latent_condition.device, latent_condition.dtype)
+        .to(video_latents.device, video_latents.dtype)
     )
     latents_std = (
         1.0 / torch.tensor(vae.config.latents_std)
        .view(1, latent_channels, 1, 1, 1)
-        .to(latent_condition.device, latent_condition.dtype)
+        .to(video_latents.device, video_latents.dtype)
     )
-    latent_condition = (latent_condition - latents_mean) * latents_std
+    video_latents = (video_latents - latents_mean) * latents_std
 
-    return latent_condition
+    return video_latents
 
 
 
@@ -441,7 +428,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):
 
     @property
     def description(self) -> str:
-        return "Vae Image Encoder step that generate first_frame_latents to guide the video generation"
+        return "Vae Image Encoder step that generate condition_latents to guide the video generation"
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
@@ -463,7 +450,7 @@ def inputs(self) -> List[InputParam]:
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
-            OutputParam("first_frame_latents", type_hint=torch.Tensor, description="The latent condition"),
+            OutputParam("condition_latents", type_hint=torch.Tensor, description="The condition latents"),
         ]
 
     @staticmethod
@@ -497,18 +484,21 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
         image_tensor = components.video_processor.preprocess(
             image, height=height, width=width).to(device=device, dtype=dtype)
 
-        latent_condition = encode_vae_image(
-            image=image_tensor,
+        if image_tensor.dim() == 4:
+            image_tensor = image_tensor.unsqueeze(2)
+
+        video_tensor = torch.cat(
+            [image_tensor, image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width)], dim=2
+        ).to(device=device, dtype=dtype)
+
+        block_state.condition_latents = encode_vae_image(
+            video_tensor=video_tensor,
             vae=components.vae,
             generator=block_state.generator,
            device=device,
            dtype=dtype,
-            num_frames=num_frames,
-            height=height,
-            width=width,
             latent_channels=components.num_channels_latents,
         )
 
-        block_state.first_frame_latents = latent_condition
         self.set_block_state(state, block_state)
         return components, state
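Note: this is the substance of the "refactor vae encoder block" commit. encode_vae_image no longer builds the zero-padded conditioning video itself (num_frames, height, and width are dropped from its signature); it only moves, normalizes, and encodes whatever 5D video tensor it is given. The first-frame-plus-zero-padding construction now lives in WanVaeImageEncoderStep before the call. A minimal sketch of that caller-side preparation, in plain torch with illustrative sizes (the VAE encode itself is omitted, and the trailing call is only indicative):

import torch

# Illustrative sizes; the real values come from the pipeline inputs.
num_frames, height, width = 81, 480, 832
image_tensor = torch.randn(1, 3, height, width)  # preprocessed first frame

# Add a temporal axis: (B, C, H, W) -> (B, C, 1, H, W).
if image_tensor.dim() == 4:
    image_tensor = image_tensor.unsqueeze(2)

# Pad with zero frames so the conditioning clip spans the full video length,
# mirroring what WanVaeImageEncoderStep now does before calling encode_vae_image.
video_tensor = torch.cat(
    [image_tensor, image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width)],
    dim=2,
)
print(video_tensor.shape)  # torch.Size([1, 3, 81, 480, 832])

# The refactored helper would then be called roughly as:
# condition_latents = encode_vae_image(video_tensor=video_tensor, vae=vae, generator=generator,
#                                      device=device, dtype=dtype, latent_channels=16)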

src/diffusers/modular_pipelines/wan/modular_blocks.py

Lines changed: 3 additions & 3 deletions

@@ -81,7 +81,7 @@ def description(self):
 class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
     block_classes = [
         WanTextInputStep,
-        WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"]),
+        WanInputsDynamicStep(image_latent_inputs=["condition_latents"]),
         WanSetTimestepsStep,
         WanPrepareLatentsStep,
         WanPrepareFirstFrameLatentsStep,
@@ -137,7 +137,7 @@ class WanAutoDenoiseStep(AutoPipelineBlocks):
         WanCoreDenoiseStep,
     ]
     block_names = ["image2video", "text2video"]
-    block_trigger_inputs = ["first_frame_latents", None]
+    block_trigger_inputs = ["condition_latents", None]
 
     @property
     def description(self) -> str:
@@ -193,7 +193,7 @@ def description(self):
         ("image_encoder", WanImage2VideoImageEncoderStep),
         ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
         ("input", WanTextInputStep),
-        ("additional_inputs", WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"])),
+        ("additional_inputs", WanInputsDynamicStep(image_latent_inputs=["condition_latents"])),
         ("set_timesteps", WanSetTimestepsStep),
         ("prepare_latents", WanPrepareLatentsStep),
         ("denoise", WanImage2VideoCoreDenoiseStep),
