Merged
src/diffusers/models/transformers/transformer_wan.py (2 changes: 1 addition & 1 deletion)
@@ -324,7 +324,7 @@ def forward(
):
timestep = self.timesteps_proj(timestep)
if timestep_seq_len is not None:
-    timestep = timestep.unflatten(0, (1, timestep_seq_len))
+    timestep = timestep.unflatten(0, (-1, timestep_seq_len))
Comment (Collaborator Author): so that it works with batch_size > 1


time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
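A minimal sketch of what the `unflatten` fix changes, assuming a flattened per-token timestep tensor of shape `(batch_size * timestep_seq_len, dim)` (names match the diff; the concrete shapes are illustrative):

```python
import torch

batch_size, timestep_seq_len, dim = 2, 5, 8
# Flattened per-token timestep embeddings, as seen at this point in forward().
timestep = torch.randn(batch_size * timestep_seq_len, dim)

# Old code: unflatten(0, (1, timestep_seq_len)) hard-codes a batch of 1 and
# raises a RuntimeError here, since 10 elements cannot reshape to (1, 5).
# timestep = timestep.unflatten(0, (1, timestep_seq_len))

# New code: -1 lets PyTorch infer the leading (batch) dimension.
timestep = timestep.unflatten(0, (-1, timestep_seq_len))
print(timestep.shape)  # torch.Size([2, 5, 8])
```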
src/diffusers/pipelines/wan/pipeline_wan.py (12 changes: 8 additions & 4 deletions)
@@ -125,15 +125,15 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):

model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2"]
_optional_components = ["transformer", "transformer_2"]

def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
-transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
+transformer: Optional[WanTransformer3DModel] = None,
transformer_2: Optional[WanTransformer3DModel] = None,
boundary_ratio: Optional[float] = None,
expand_timesteps: bool = False, # Wan2.2 ti2v
@@ -526,7 +526,7 @@ def __call__(
device=device,
)

-transformer_dtype = self.transformer.dtype
+transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
@@ -536,7 +536,11 @@
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables
-num_channels_latents = self.transformer.config.in_channels
+num_channels_latents = (
+    self.transformer.config.in_channels
+    if self.transformer is not None
+    else self.transformer_2.config.in_channels
+)
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
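In practice, this makes `transformer` fully optional for `WanPipeline`. A hedged sketch of what the change enables (the checkpoint id is a placeholder, not taken from the diff):

```python
from diffusers import WanPipeline

# Assumption: a Wan 2.2-style checkpoint that only needs the low-noise
# transformer_2; the repo id below is illustrative.
pipe = WanPipeline.from_pretrained(
    "org/wan2.2-checkpoint",  # placeholder
    transformer=None,  # valid now that "transformer" is in _optional_components
)

# The pipeline then falls back to transformer_2 for dtype and config lookups,
# mirroring the guards added above:
dtype = pipe.transformer.dtype if pipe.transformer is not None else pipe.transformer_2.dtype
```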
src/diffusers/pipelines/wan/pipeline_wan_i2v.py (10 changes: 6 additions & 4 deletions)
@@ -162,17 +162,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):

model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2", "image_encoder", "image_processor"]
_optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]

def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
-transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
image_processor: CLIPImageProcessor = None,
image_encoder: CLIPVisionModel = None,
+transformer: WanTransformer3DModel = None,
transformer_2: WanTransformer3DModel = None,
boundary_ratio: Optional[float] = None,
expand_timesteps: bool = False,
@@ -669,12 +669,13 @@ def __call__(
)

# Encode image embedding
-transformer_dtype = self.transformer.dtype
+transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

if self.config.boundary_ratio is None and not self.config.expand_timesteps:
+    # only the Wan 2.1 I2V transformer accepts image_embeds
+    if self.transformer is not None and self.transformer.config.image_dim is not None:
if image_embeds is None:
if last_image is None:
image_embeds = self.encode_image(image, device)
Expand Down Expand Up @@ -709,6 +710,7 @@ def __call__(
last_image,
)
if self.config.expand_timesteps:
+    # Wan 2.2 5B I2V uses first_frame_mask to mask timesteps
latents, condition, first_frame_mask = latents_outputs
else:
latents, condition = latents_outputs
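The I2V pipeline gets the same treatment as the T2V one: every `self.transformer` attribute access now either falls back to `transformer_2` or is guarded by a `None` check. A toy sketch of that shared pattern (illustrative only, not the real pipeline class; the `ValueError` guard is an assumption, not shown in the diff):

```python
from typing import Optional


class WanStyleSketch:
    """Illustrative stand-in for the Wan pipelines, not diffusers code."""

    def __init__(self, transformer: Optional[object] = None, transformer_2: Optional[object] = None):
        if transformer is None and transformer_2 is None:
            raise ValueError("Provide at least one of `transformer` or `transformer_2`.")
        self.transformer = transformer
        self.transformer_2 = transformer_2

    @property
    def reference_transformer(self):
        # Fallback used for dtype / config lookups, as in the diff above.
        return self.transformer if self.transformer is not None else self.transformer_2
```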
tests/pipelines/wan/test_wan.py (59 changes: 42 additions & 17 deletions)
@@ -13,8 +13,10 @@
# limitations under the License.

import gc
+import tempfile
import unittest

+import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

Expand Down Expand Up @@ -85,29 +87,13 @@ def get_dummy_components(self):
rope_max_seq_len=32,
)

-torch.manual_seed(0)

Comment (Collaborator Author): revert the change in #12004

-transformer_2 = WanTransformer3DModel(
-    patch_size=(1, 2, 2),
-    num_attention_heads=2,
-    attention_head_dim=12,
-    in_channels=16,
-    out_channels=16,
-    text_dim=32,
-    freq_dim=256,
-    ffn_dim=32,
-    num_layers=2,
-    cross_attn_norm=True,
-    qk_norm="rms_norm_across_heads",
-    rope_max_seq_len=32,
-)

components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer_2": transformer_2,
"transformer_2": None,
}
return components

@@ -155,6 +141,45 @@ def test_inference(self):
def test_attention_slicing_forward_pass(self):
pass

+# _optional_components includes transformer and transformer_2, but only transformer_2 is optional for this Wan 2.1 T2V pipeline
+def test_save_load_optional_components(self, expected_max_difference=1e-4):
+    optional_component = "transformer_2"
+
+    components = self.get_dummy_components()
+    components[optional_component] = None
+    pipe = self.pipeline_class(**components)
+    for component in pipe.components.values():
+        if hasattr(component, "set_default_attn_processor"):
+            component.set_default_attn_processor()
+    pipe.to(torch_device)
+    pipe.set_progress_bar_config(disable=None)
+
+    generator_device = "cpu"
+    inputs = self.get_dummy_inputs(generator_device)
+    torch.manual_seed(0)
+    output = pipe(**inputs)[0]
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        pipe.save_pretrained(tmpdir, safe_serialization=False)
+        pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+        for component in pipe_loaded.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe_loaded.to(torch_device)
+        pipe_loaded.set_progress_bar_config(disable=None)
+
+    self.assertTrue(
+        getattr(pipe_loaded, optional_component) is None,
+        f"`{optional_component}` did not stay set to None after loading.",
+    )
+
+    inputs = self.get_dummy_inputs(generator_device)
+    torch.manual_seed(0)
+    output_loaded = pipe_loaded(**inputs)[0]
+
+    max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+    self.assertLess(max_diff, expected_max_difference)
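Why the `None` survives the round trip (a sketch; the exact serialization is a diffusers implementation detail, not shown in this diff): `save_pretrained` records a null entry for the missing optional component in `model_index.json`, and `from_pretrained` maps that entry back to `None` instead of trying to load weights.

```python
import json

# Assumption: "saved_pipeline" is a directory produced by pipe.save_pretrained()
# with transformer_2=None, as in the test above.
with open("saved_pipeline/model_index.json") as f:
    index = json.load(f)

# Expected to be a null (library, class) entry rather than a subfolder
# reference, e.g. [None, None].
print(index["transformer_2"])
```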


@slow
@require_torch_accelerator