diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 3289a840e2b1..0d6513d88648 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers: - [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) - [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) - [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers) +- [Wan 2.2 S2V 14B](https://huggingface.co/Wan-AI/Wan2.2-S2V-14B-Diffusers) > [!TIP] > Click on the Wan models in the right sidebar for more examples of video generation. @@ -95,15 +96,15 @@ pipeline = WanPipeline.from_pretrained( pipeline.to("cuda") prompt = """ -The camera rushes from far to near in a low-angle shot, -revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in -for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. -Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic +The camera rushes from far to near in a low-angle shot, +revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in +for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. +Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
""" negative_prompt = """ -Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, -low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards """ @@ -236,6 +237,125 @@ export_to_video(output, "output.mp4", fps=16) + +### Wan-S2V: Audio-Driven Cinematic Video Generation + +[Wan-S2V](https://huggingface.co/papers/2508.18621) by the Wan Team. + +*Current state-of-the-art (SOTA) methods for audio-driven character animation demonstrate promising performance for scenarios primarily involving speech and singing. However, they often fall short in more complex film and television productions, which demand sophisticated elements such as nuanced character interactions, realistic body movements, and dynamic camera work. To address this long-standing challenge of achieving film-level character animation, we propose an audio-driven model, which we refer to as Wan-S2V, built upon Wan. Our model achieves significantly enhanced expressiveness and fidelity in cinematic contexts compared to existing approaches. We conducted extensive experiments, benchmarking our method against cutting-edge models such as Hunyuan-Avatar and Omnihuman. The experimental results consistently demonstrate that our approach significantly outperforms these existing solutions. Additionally, we explore the versatility of our method through its applications in long-form video generation and precise video lip-sync editing.* + +The example below demonstrates how to use the speech-to-video pipeline to generate a video using a text description, a starting frame, an audio clip, and, optionally, a pose video.
+ + + +```python +import numpy as np, math +import torch +from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline +from diffusers.utils import export_to_merged_video_audio, load_image, load_audio, load_video, export_to_video +from transformers import Wav2Vec2ForCTC +import requests +from PIL import Image +from io import BytesIO + + +model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" +audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32) +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = WanSpeechToVideoPipeline.from_pretrained( + model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +headers = {"User-Agent": "Mozilla/5.0"} +url = "https://upload.wikimedia.org/wikipedia/commons/4/46/Albert_Einstein_sticks_his_tongue.jpg" +resp = requests.get(url, headers=headers, timeout=30) +image = Image.open(BytesIO(resp.content)) + +audio, sampling_rate = load_audio("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/Five%20Hundred%20Miles.MP3") +#pose_video_path_or_url = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4" + +def get_size_less_than_area(height, + width, + target_area=1024 * 704, + divisor=64): + if height * width <= target_area: + # If the original image area is already less than or equal to the target, + # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target. + max_upper_area = target_area + min_scale = 0.1 + max_scale = 1.0 + else: + # Resize to fit within the target area and then pad to multiples of `divisor` + max_upper_area = target_area # Maximum allowed total pixel count after padding + d = divisor - 1 + b = d * (height + width) + a = height * width + c = d**2 - max_upper_area + + # Calculate scale boundaries using quadratic equation + min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (2 * a) # Scale when maximum padding is applied + max_scale = math.sqrt(max_upper_area / (height * width)) # Scale without any padding + + # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area + # Use binary search-like iteration to find this scale + find_it = False + for i in range(100): + scale = max_scale - (max_scale - min_scale) * i / 100 + new_height, new_width = int(height * scale), int(width * scale) + + # Pad to make dimensions divisible by 64 + pad_height = (64 - new_height % 64) % 64 + pad_width = (64 - new_width % 64) % 64 + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top + pad_left = pad_width // 2 + pad_right = pad_width - pad_left + + padded_height, padded_width = new_height + pad_height, new_width + pad_width + + if padded_height * padded_width <= max_upper_area: + find_it = True + break + + if find_it: + return padded_height, padded_width + else: + # Fallback: calculate target dimensions based on aspect ratio and divisor alignment + aspect_ratio = width / height + target_width = int( + (target_area * aspect_ratio)**0.5 // divisor * divisor) + target_height = int( + (target_area / aspect_ratio)**0.5 // divisor * divisor) + + # Ensure the result is not larger than the original resolution + if target_width >= width or target_height >= height: + target_width = int(width // divisor * divisor) + target_height = int(height // divisor * divisor) + + return target_height, target_width + +height, width = get_size_less_than_area(image.height, image.width, 480*832) + +prompt = 
"Einstein singing a song." + +output = pipe( + prompt=prompt, image=image, audio=audio, sampling_rate=sampling_rate, + height=height, width=width, num_frames_per_chunk=80, + #pose_video_path_or_url=pose_video_path_or_url, +).frames[0] +export_to_video(output, "output.mp4", fps=16) + +# Lastly, we need to merge the video and audio into a new video, with the duration set to +# the shorter of the two and overwrite the original video file. +export_to_merged_video_audio("output.mp4", "audio.mp3") +``` + + + + + ### Any-to-Video Controllable Generation Wan VACE supports various generation techniques which achieve controllable video generation. Some of the capabilities include: @@ -281,10 +401,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip # use "steamboat willie style" to trigger the LoRA prompt = """ - steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot, - revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in - for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. - Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic + steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot, + revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in + for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. + Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic shadows and warm highlights. Medium composition, front view, low angle, with depth of field. 
""" @@ -353,6 +473,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip - all - __call__ +## WanSpeechToVideoPipeline + +[[autodoc]] WanSpeechToVideoPipeline + - all + - __call__ + ## WanVideoToVideoPipeline [[autodoc]] WanVideoToVideoPipeline diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 39a364b07d78..c64bd6e9c3f5 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -6,13 +6,22 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download, snapshot_download from safetensors.torch import load_file -from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel +from transformers import ( + AutoProcessor, + AutoTokenizer, + CLIPVisionModelWithProjection, + UMT5EncoderModel, + Wav2Vec2ForCTC, + Wav2Vec2Processor, +) from diffusers import ( AutoencoderKLWan, UniPCMultistepScheduler, WanImageToVideoPipeline, WanPipeline, + WanS2VTransformer3DModel, + WanSpeechToVideoPipeline, WanTransformer3DModel, WanVACEPipeline, WanVACETransformer3DModel, @@ -105,8 +114,59 @@ "after_proj": "proj_out", } +S2V_TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + # Add attention component mappings + "self_attn.q": "attn1.to_q", + "self_attn.k": "attn1.to_k", + "self_attn.v": "attn1.to_v", + "self_attn.o": "attn1.to_out.0", + "self_attn.norm_q": "attn1.norm_q", + "self_attn.norm_k": "attn1.norm_k", + "cross_attn.q": "attn2.to_q", + "cross_attn.k": "attn2.to_k", + "cross_attn.v": "attn2.to_v", + "cross_attn.o": "attn2.to_out.0", + "cross_attn.norm_q": "attn2.norm_q", + "cross_attn.norm_k": "attn2.norm_k", + "attn2.to_k_img": "attn2.add_k_proj", + "attn2.to_v_img": "attn2.add_v_proj", + "attn2.norm_k_img": "attn2.norm_added_k", + # S2V-specific audio component mappings + "casual_audio_encoder.encoder.conv2.conv": "condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv", + "casual_audio_encoder.encoder.conv3.conv": "condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv", + "casual_audio_encoder.weights": "condition_embedder.causal_audio_encoder.weighted_avg.weights", + # Pose condition encoder mappings + "cond_encoder.weight": "condition_embedder.pose_embedder.weight", + "cond_encoder.bias": "condition_embedder.pose_embedder.bias", + "trainable_cond_mask": "trainable_condition_mask", + "patch_embedding": "motion_in.patch_embedding", + # Audio injector attention mappings - convert original q/k/v/o format to diffusers format + **{ + f"audio_injector.injector.{i}.{src}": f"audio_injector.injector.{i}.{dst}" + for i in range(12) + for src, dst in [("q", "to_q"), ("k", "to_k"), ("v", "to_v"), ("o", "to_out.0")] + }, +} + TRANSFORMER_SPECIAL_KEYS_REMAP = {} VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {} 
+S2V_TRANSFORMER_SPECIAL_KEYS_REMAP = {} def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: @@ -364,6 +424,36 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: } RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP + elif model_type == "Wan2.2-S2V-14B": + config = { + "model_id": "Wan-AI/Wan2.2-S2V-14B", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 16, + "num_attention_heads": 40, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + "audio_dim": 1024, + "audio_inject_layers": [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39], + "enable_adain": True, + "adain_mode": "attn_norm", + "pose_dim": 16, + "enable_framepack": True, + "framepack_drop_mode": "padd", + "add_last_motion": True, + "zero_timestep": True, + }, + } + RENAME_DICT = S2V_TRANSFORMER_KEYS_RENAME_DICT + SPECIAL_KEYS_REMAP = S2V_TRANSFORMER_SPECIAL_KEYS_REMAP return config, RENAME_DICT, SPECIAL_KEYS_REMAP @@ -380,7 +470,9 @@ def convert_transformer(model_type: str, stage: str = None): original_state_dict = load_sharded_safetensors(model_dir) with init_empty_weights(): - if "VACE" not in model_type: + if "S2V" in model_type: + transformer = WanS2VTransformer3DModel.from_config(diffusers_config) + elif "VACE" not in model_type: transformer = WanTransformer3DModel.from_config(diffusers_config) else: transformer = WanVACETransformer3DModel.from_config(diffusers_config) @@ -926,7 +1018,7 @@ def get_args(): if __name__ == "__main__": args = get_args() - if "Wan2.2" in args.model_type and "TI2V" not in args.model_type: + if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "S2V" not in args.model_type: transformer = convert_transformer(args.model_type, stage="high_noise_model") transformer_2 = convert_transformer(args.model_type, stage="low_noise_model") else: @@ -942,7 +1034,7 @@ def get_args(): tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") if "FLF2V" in args.model_type: flow_shift = 16.0 - elif "TI2V" in args.model_type: + elif "TI2V" in args.model_type or "S2V" in args.model_type: flow_shift = 5.0 else: flow_shift = 3.0 @@ -1016,6 +1108,22 @@ def get_args(): vae=vae, scheduler=scheduler, ) + elif "S2V" in args.model_type: + audio_encoder = Wav2Vec2ForCTC.from_pretrained( + "Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english" + ) + audio_processor = Wav2Vec2Processor.from_pretrained( + "Wan-AI/Wan2.2-S2V-14B", subfolder="wav2vec2-large-xlsr-53-english" + ) + pipe = WanSpeechToVideoPipeline( + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + audio_encoder=audio_encoder, + audio_processor=audio_processor, + ) else: pipe = WanPipeline( transformer=transformer, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8867250deda8..fce2632ed9c7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -258,6 +258,7 @@ "UNetSpatioTemporalConditionModel", "UVit2DModel", "VQModel", + "WanS2VTransformer3DModel", "WanTransformer3DModel", "WanVACETransformer3DModel", "attention_backend", @@ -616,6 +617,7 @@ "VQDiffusionPipeline", "WanImageToVideoPipeline", "WanPipeline", + "WanSpeechToVideoPipeline", "WanVACEPipeline", "WanVideoToVideoPipeline", "WuerstchenCombinedPipeline", @@ -945,6 
+947,7 @@ UNetSpatioTemporalConditionModel, UVit2DModel, VQModel, + WanS2VTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, attention_backend, @@ -1273,6 +1276,7 @@ VQDiffusionPipeline, WanImageToVideoPipeline, WanPipeline, + WanSpeechToVideoPipeline, WanVACEPipeline, WanVideoToVideoPipeline, WuerstchenCombinedPipeline, diff --git a/src/diffusers/audio_processor.py b/src/diffusers/audio_processor.py new file mode 100644 index 000000000000..491aacf530aa --- /dev/null +++ b/src/diffusers/audio_processor.py @@ -0,0 +1,71 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +import numpy as np +import torch + + +PipelineAudioInput = Union[ + np.ndarray, + torch.Tensor, + List[np.ndarray], + List[torch.Tensor], +] + + +def is_valid_audio(audio) -> bool: + r""" + Checks if the input is a valid audio. + + A valid audio can be: + - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). + + Args: + audio (`Union[np.ndarray, torch.Tensor]`): + The audio to validate. It can be a NumPy array or a torch tensor. + + Returns: + `bool`: + `True` if the input is a valid audio, `False` otherwise. + """ + return isinstance(audio, (np.ndarray, torch.Tensor)) and audio.ndim in (2, 3) + + +def is_valid_audio_audiolist(audios): + r""" + Checks if the input is a valid audio or list of audios. + + The input can be one of the following formats: + - A 4D tensor or numpy array (batch of audios). + - A valid single audio: `np.ndarray` or `torch.Tensor`. + - A list of valid audios. + + Args: + audios (`Union[np.ndarray, torch.Tensor, List]`): + The audio(s) to check. Can be a batch of audios (4D tensor/array), a single audio, or a list of valid + audios. + + Returns: + `bool`: + `True` if the input is valid, `False` otherwise. + """ + if isinstance(audios, (np.ndarray, torch.Tensor)) and audios.ndim == 4: + return True + elif is_valid_audio(audios): + return True + elif isinstance(audios, list): + return all(is_valid_audio(audio) for audio in audios) + return False diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 0e3082eada8a..086c013458ce 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -437,10 +437,11 @@ def _resize_and_crop( image: PIL.Image.Image, width: int, height: int, + resize_type: str = "fit_within", + crop_type: str = "paste_center", ) -> PIL.Image.Image: r""" - Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center - the image within the dimensions, cropping the excess. + Resize and crop the image using different strategies. Args: image (`PIL.Image.Image`): @@ -449,28 +450,55 @@ def _resize_and_crop( The width to resize the image to. height (`int`): The height to resize the image to. + resize_type (`str`, optional): + How to resize the image. 
Options: + - "fit_within": Resize to fit within dimensions, maintaining aspect ratio (default) + - "min_dimension": Resize so smaller dimension becomes min(width, height) + crop_type (`str`, optional): + How to handle the final cropping/positioning. Options: + - "paste_center": Paste resized image on centered canvas, pad with black (default) + - "center_crop": Center crop to exact dimensions, pad with black if needed Returns: `PIL.Image.Image`: The resized and cropped image. """ - ratio = width / height - src_ratio = image.width / image.height - src_w = width if ratio > src_ratio else image.width * height // image.height - src_h = height if ratio <= src_ratio else image.height * width // image.width + if resize_type == "fit_within": + # Resize to fit within dimensions + ratio = width / height + src_ratio = image.width / image.height - resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"]) - res = Image.new("RGB", (width, height)) - res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) - return res + src_w = width if ratio > src_ratio else image.width * height // image.height + src_h = height if ratio <= src_ratio else image.height * width // image.width + + resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample]) + elif resize_type == "min_dimension": + # Resize so smaller dimension becomes min(width, height) + from torchvision.transforms import Resize + + resized = Resize(min(height, width))(image) + else: + raise ValueError(f"Unknown resize_type: {resize_type}") + + if crop_type == "paste_center": + # Paste on canvas, center position (use the resized image's own size so this works for both resize types) + res = Image.new("RGB", (width, height), color=0) # Black background + res.paste(resized, box=(width // 2 - resized.width // 2, height // 2 - resized.height // 2)) + return res + elif crop_type == "center_crop": + from torchvision.transforms import CenterCrop + + return CenterCrop((height, width))(resized) + else: + raise ValueError(f"Unknown crop_type: {crop_type}") def resize( self, image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], height: int, width: int, - resize_mode: str = "default", # "default", "fill", "crop" + resize_mode: str = "default", # "default", "fill", "crop", "resize_min_center_crop" ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: """ Resize image. @@ -483,13 +511,16 @@ def resize( width (`int`): The width to resize to. resize_mode (`str`, *optional*, defaults to `default`): - The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit - within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, - will resize the image to fit within the specified width and height, maintaining the aspect ratio, and - then center the image within the dimensions, filling empty with data from image. If `crop`, will resize - the image to fit within the specified width and height, maintaining the aspect ratio, and then center - the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only - supported for PIL image input. + The resize mode to use, can be one of `default`, `fill`, `crop`, or `resize_min_center_crop`. If + `default`, will resize the image to fit within the specified width and height, and it may not + maintain the original aspect ratio. If `fill`, will resize the image to fit within the specified + width and height, maintaining the aspect ratio, and then center the image within the dimensions, + filling empty with data from image. 
If `crop`, will resize the image to fit within the specified width + and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the + excess. If `resize_min_center_crop`, will resize the image so that the smaller dimension becomes + min(width, height), then center crop to exact target dimensions (matches Wan2.2-S2V preprocessing). + Note that resize_mode `fill`, `crop`, and `resize_min_center_crop` are only supported for PIL image + input. Returns: `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: @@ -508,6 +539,10 @@ def resize( image = self._resize_and_fill(image, width, height) elif resize_mode == "crop": image = self._resize_and_crop(image, width, height) + elif resize_mode == "resize_min_center_crop": + image = self._resize_and_crop( + image, width, height, resize_type="min_dimension", crop_type="center_crop" + ) else: raise ValueError(f"resize_mode {resize_mode} is not supported") @@ -615,7 +650,7 @@ def preprocess( image: PipelineImageInput, height: Optional[int] = None, width: Optional[int] = None, - resize_mode: str = "default", # "default", "fill", "crop" + resize_mode: str = "default", # "default", "fill", "crop", "resize_min_center_crop" crops_coords: Optional[Tuple[int, int, int, int]] = None, ) -> torch.Tensor: """ @@ -631,13 +666,15 @@ def preprocess( width (`int`, *optional*): The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. resize_mode (`str`, *optional*, defaults to `default`): - The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within - the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will - resize the image to fit within the specified width and height, maintaining the aspect ratio, and then - center the image within the dimensions, filling empty with data from image. If `crop`, will resize the - image to fit within the specified width and height, maintaining the aspect ratio, and then center the - image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only - supported for PIL image input. + The resize mode, can be one of `default`, `fill`, `crop`, or `resize_min_center_crop`. If `default`, + will resize the image to fit within the specified width and height, and it may not maintaining the + original aspect ratio. If `fill`, will resize the image to fit within the specified width and height, + maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data + from image. If `crop`, will resize the image to fit within the specified width and height, maintaining + the aspect ratio, and then center the image within the dimensions, cropping the excess. If + `resize_min_center_crop`, will resize the image so that the smaller dimension becomes min(width, + height), then center crop to exact target dimensions (matches Wan2.2 preprocessing). Note that + resize_mode `fill`, `crop`, and `resize_min_center_crop` are only supported for PIL image input. crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): The crop coordinates for each image in the batch. If `None`, will not crop the image. 
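For reference, a small usage sketch of the `resize_min_center_crop` mode introduced above; it assumes this diff is installed and that torchvision is available, since the mode is implemented with torchvision's `Resize` and `CenterCrop`:

```python
from PIL import Image
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()
image = Image.new("RGB", (1280, 720))  # (width, height)

# The shorter side (720) is scaled to min(832, 480) = 480, giving roughly 853x480,
# then a center crop to (height=480, width=832) yields the exact target size.
resized = processor.resize(image, height=480, width=832, resize_mode="resize_min_center_crop")
print(resized.size)  # (832, 480)
```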
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 457f70448af3..e13a40c47207 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -100,6 +100,7 @@ _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] + _import_structure["transformers.transformer_wan_s2v"] = ["WanS2VTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] @@ -198,6 +199,7 @@ T5FilmDecoder, Transformer2DModel, TransformerTemporalModel, + WanS2VTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, ) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index b60f0636e6dc..ce7fbce40b87 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -36,4 +36,5 @@ from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel + from .transformer_wan_s2v import WanS2VTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_wan_s2v.py b/src/diffusers/models/transformers/transformer_wan_s2v.py new file mode 100644 index 000000000000..4d810eeca630 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_wan_s2v.py @@ -0,0 +1,1185 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
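With the registrations added to `src/diffusers/__init__.py` and `src/diffusers/models/__init__.py` above, the new classes are exposed at the top level of the package; a quick import smoke test, assuming this branch is installed:

```python
# Quick check that the newly registered public classes resolve from the top-level namespace.
from diffusers import WanS2VTransformer3DModel, WanSpeechToVideoPipeline

print(WanS2VTransformer3DModel.__name__, WanSpeechToVideoPipeline.__name__)
```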
+ +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin, get_parameter_dtype +from ..normalization import AdaLayerNorm, FP32LayerNorm +from .transformer_wan import ( + WanAttention, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.fused_projections: + if attn.cross_attention_dim_head is None: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + +def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor): + if attn.fused_projections: + key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) + else: + key_img = attn.add_k_proj(encoder_hidden_states_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + return key_img, value_img + + +class WanS2VAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "WanS2VAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." 
+ ) + + def __call__( + self, + attn: "WanAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if rotary_emb is not None: + + def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): + # dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64 + n = hidden_states.size(2) + # loop over samples + output = [] + for i in range(hidden_states.size(0)): + s = hidden_states.size(1) + x_i = torch.view_as_complex(hidden_states[i, :s].to(torch.float64).reshape(s, n, -1, 2)) + freqs_i = freqs[i, :s] + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, hidden_states[i, s:]]) + # append to collection + output.append(x_i) + return torch.stack(output).type_as(hidden_states) + + query = apply_rotary_emb(query, rotary_emb) + key = apply_rotary_emb(key, rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) + + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states_img = hidden_states_img.flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class WanS2VCausalConv1d(nn.Module): + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): + super().__init__() + + self.pad_mode = pad_mode + self.time_causal_padding = (kernel_size - 1, 0) # T + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class WanS2VCausalConvLayer(nn.Module): + """A layer that combines causal convolution, normalization, and activation in sequence.""" + + def __init__( + self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", eps=1e-6, **kwargs + ): 
+ super().__init__() + + self.conv = WanS2VCausalConv1d(chan_in, chan_out, kernel_size, stride, dilation, pad_mode, **kwargs) + self.norm = nn.LayerNorm(chan_out, elementwise_affine=False, eps=eps) + self.act = nn.SiLU() + + def forward(self, x): + x = x.permute(0, 2, 1) + x = self.conv(x) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = self.act(x) + return x + + +class WanS2VMotionEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_attention_heads: int, need_global: bool = True): + super().__init__() + + self.num_attention_heads = num_attention_heads + self.need_global = need_global + self.conv1_local = WanS2VCausalConv1d(in_dim, hidden_dim // 4 * num_attention_heads, 3, stride=1) + if need_global: + self.conv1_global = WanS2VCausalConv1d(in_dim, hidden_dim // 4, 3, stride=1) + self.conv2 = WanS2VCausalConvLayer(hidden_dim // 4, hidden_dim // 2, 3, stride=2) + self.conv3 = WanS2VCausalConvLayer(hidden_dim // 2, hidden_dim, 3, stride=2) + + if need_global: + self.final_linear = nn.Linear(hidden_dim, hidden_dim) + + self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6) + self.act = nn.SiLU() + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + x = x.permute(0, 2, 1) + residual = x.clone() + batch_size, num_channels, seq_len = x.shape + x = self.conv1_local(x) + x = x.unflatten(1, (self.num_attention_heads, -1)).permute(0, 1, 3, 2).flatten(0, 1) + x = self.norm1(x) + x = self.act(x) + x = self.conv2(x) + x = self.conv3(x) + x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) + padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + if not self.need_global: + return x_local + + x = self.conv1_global(residual) + x = x.permute(0, 2, 1) + x = self.norm1(x) + x = self.act(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.final_linear(x) + x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) + + return x, x_local + + +class WeightedAveragelayer(nn.Module): + def __init__(self, num_layers): + super().__init__() + self.weights = torch.nn.Parameter(torch.ones((1, num_layers, 1, 1)) * 0.01) + self.act = torch.nn.SiLU() + + def forward(self, features): + # features B * num_layers * dim * video_length + weights = self.act(self.weights) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f + + return weighted_feat + + +class CausalAudioEncoder(nn.Module): + def __init__(self, dim=5120, num_weighted_avg_layers=25, out_dim=2048, num_audio_token=4, need_global=False): + super().__init__() + self.weighted_avg = WeightedAveragelayer(num_weighted_avg_layers) + self.encoder = WanS2VMotionEncoder( + in_dim=dim, hidden_dim=out_dim, num_attention_heads=num_audio_token, need_global=need_global + ) + + def forward(self, features): + # features B * num_layers * dim * video_length + weighted_feat = self.weighted_avg(features) + weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim + res = self.encoder(weighted_feat) # b f n dim + + return res # b f n dim + + +class AudioInjector(nn.Module): + def __init__( + self, + num_injection_layers, + inject_layers, + dim=2048, + num_heads=32, + enable_adain=False, + adain_mode="attn_norm", + adain_dim=2048, + eps=1e-6, + added_kv_proj_dim=None, + ): + super().__init__() + self.enable_adain = enable_adain + self.adain_mode = adain_mode + self.injected_block_id = dict(zip(inject_layers, range(num_injection_layers))) + + # 
Cross-attention + self.injector = nn.ModuleList( + [ + WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + cross_attention_dim_head=dim // num_heads, + processor=WanS2VAttnProcessor(), + ) + for _ in range(num_injection_layers) + ] + ) + + self.injector_pre_norm_feat = nn.ModuleList( + [nn.LayerNorm(dim, elementwise_affine=False, eps=eps) for _ in range(num_injection_layers)] + ) + self.injector_pre_norm_vec = nn.ModuleList( + [nn.LayerNorm(dim, elementwise_affine=False, eps=eps) for _ in range(num_injection_layers)] + ) + + if enable_adain: + self.injector_adain_layers = nn.ModuleList( + [ + AdaLayerNorm(embedding_dim=adain_dim, output_dim=dim * 2, chunk_dim=1) + for _ in range(num_injection_layers) + ] + ) + if adain_mode != "attn_norm": + self.injector_adain_output_layers = nn.ModuleList( + [nn.Linear(dim, dim) for _ in range(num_injection_layers)] + ) + + def forward( + self, + block_idx, + hidden_states, + original_sequence_length, + merged_audio_emb_num_frames, + attn_audio_emb, + audio_emb_global, + ): + audio_attn_id = self.injected_block_id[block_idx] + + input_hidden_states = hidden_states[:, :original_sequence_length].clone() # B (F H W) C + input_hidden_states = input_hidden_states.unflatten(1, (merged_audio_emb_num_frames, -1)).flatten(0, 1) + + if self.enable_adain and self.adain_mode == "attn_norm": + attn_hidden_states = self.injector_adain_layers[audio_attn_id]( + input_hidden_states, temb=audio_emb_global[:, 0] + ) + else: + attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states) + + residual_out = self.injector[audio_attn_id](attn_hidden_states, attn_audio_emb, None, None) + residual_out = residual_out.unflatten(0, (-1, merged_audio_emb_num_frames)).flatten(1, 2) + hidden_states[:, :original_sequence_length] = hidden_states[:, :original_sequence_length] + residual_out + + return hidden_states + + +class FramePackMotioner(nn.Module): + def __init__( + self, + inner_dim=1024, + num_attention_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design + zip_frame_buckets=[ + 1, + 2, + 16, + ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames + drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion + patch_size=(1, 2, 2), + ): + super().__init__() + self.inner_dim = inner_dim + self.num_attention_heads = num_attention_heads + if (inner_dim % num_attention_heads) != 0 or (inner_dim // num_attention_heads) % 2 != 0: + raise ValueError( + "inner_dim must be divisible by num_attention_heads and inner_dim // num_attention_heads must be even" + ) + self.drop_mode = drop_mode + + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long) + + self.rope = WanS2VRotaryPosEmbed( + inner_dim // num_attention_heads, + patch_size=patch_size, + max_seq_len=1024, + num_attention_heads=num_attention_heads, + ) + + def forward(self, motion_latents, add_last_motion=2): + latent_height, latent_width = motion_latents.shape[3], motion_latents.shape[4] + padd_latent = torch.zeros( + (motion_latents.shape[0], 16, self.zip_frame_buckets.sum(), latent_height, latent_width), + 
device=motion_latents.device, + dtype=motion_latents.dtype, + ) + overlap_frame = min(padd_latent.shape[2], motion_latents.shape[2]) + if overlap_frame > 0: + padd_latent[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:] + + if add_last_motion < 2 and self.drop_mode != "drop": + zero_end_frame = self.zip_frame_buckets[: len(self.zip_frame_buckets) - add_last_motion - 1].sum() + padd_latent[:, :, -zero_end_frame:] = 0 + + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_latent[ + :, :, -self.zip_frame_buckets.sum() :, :, : + ].split(list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2, 1 + + # Patchify + clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) + clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2) + + if add_last_motion < 2 and self.drop_mode == "drop": + clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post + clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x + + motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) + + # RoPE + start_time_id = -(self.zip_frame_buckets[:1].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[0] + grid_sizes = ( + [] + if add_last_motion < 2 and self.drop_mode == "drop" + else [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0), + torch.tensor([end_time_id, latent_height // 2, latent_width // 2]).unsqueeze(0), + torch.tensor([self.zip_frame_buckets[0], latent_height // 2, latent_width // 2]).unsqueeze(0), + ] + ] + ) + + start_time_id = -(self.zip_frame_buckets[:2].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[1] // 2 + grid_sizes_2x = ( + [] + if add_last_motion < 1 and self.drop_mode == "drop" + else [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0), + torch.tensor([end_time_id, latent_height // 4, latent_width // 4]).unsqueeze(0), + torch.tensor([self.zip_frame_buckets[1], latent_height // 2, latent_width // 2]).unsqueeze(0), + ] + ] + ) + + start_time_id = -(self.zip_frame_buckets[:3].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[2] // 4 + grid_sizes_4x = [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0), + torch.tensor([end_time_id, latent_height // 8, latent_width // 8]).unsqueeze(0), + torch.tensor([self.zip_frame_buckets[2], latent_height // 2, latent_width // 2]).unsqueeze(0), + ] + ] + + grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x + + motion_rope_emb = self.rope( + motion_lat.detach().view( + motion_lat.shape[0], + motion_lat.shape[1], + self.num_attention_heads, + self.inner_dim // self.num_attention_heads, + ), + grid_sizes=grid_sizes, + ) + + return motion_lat, motion_rope_emb + + +class Motioner(nn.Module): + def __init__(self, inner_dim, num_attention_heads, patch_size=(1, 2, 2), in_channels=16, rope_max_seq_len=1024): + super().__init__() + self.inner_dim = inner_dim + self.num_attention_heads = num_attention_heads + + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + self.rope = WanS2VRotaryPosEmbed( + inner_dim // num_attention_heads, patch_size, rope_max_seq_len, num_attention_heads + ) + + def forward(self, motion_latents): + latent_motion_frames = motion_latents.shape[2] + mot = self.patch_embedding(motion_latents) + + height, width = mot.shape[3], mot.shape[4] + flat_mot = mot.flatten(2).transpose(1, 2).contiguous() + motion_grid_sizes = [ + [ 
+ torch.tensor([-latent_motion_frames, 0, 0]).unsqueeze(0), + torch.tensor([0, height, width]).unsqueeze(0), + torch.tensor([latent_motion_frames, height, width]).unsqueeze(0), + ] + ] + motion_rope_emb = self.rope( + flat_mot.detach().view( + flat_mot.shape[0], + flat_mot.shape[1], + self.num_attention_heads, + self.inner_dim // self.num_attention_heads, + ), + motion_grid_sizes, + ) + + return flat_mot, motion_rope_emb + + +class WanTimeTextAudioPoseEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + audio_embed_dim: int, + pose_embed_dim: int, + patch_size: Tuple[int], + enable_adain: bool, + num_weighted_avg_layers: int, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + self.causal_audio_encoder = CausalAudioEncoder( + dim=audio_embed_dim, + num_weighted_avg_layers=num_weighted_avg_layers, + out_dim=dim, + num_audio_token=4, + need_global=enable_adain, + ) + self.pose_embedder = nn.Conv3d(pose_embed_dim, dim, kernel_size=patch_size, stride=patch_size) + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + pose_hidden_states: Optional[torch.Tensor] = None, + timestep_seq_len: Optional[int] = None, + ): + timestep = self.timesteps_proj(timestep) + if timestep_seq_len is not None: + timestep = timestep.unflatten(0, (-1, timestep_seq_len)) + + time_embedder_dtype = get_parameter_dtype(self.time_embedder) + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + + audio_hidden_states = self.causal_audio_encoder(audio_hidden_states) + + pose_hidden_states = self.pose_embedder(pose_hidden_states) + + return temb, timestep_proj, encoder_hidden_states, audio_hidden_states, pose_hidden_states + + +class WanS2VRotaryPosEmbed(nn.Module): + def __init__( + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + max_seq_len: int, + num_attention_heads: int, + theta: float = 10000.0, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + self.num_attention_heads = num_attention_heads + + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + + freqs = [] + + for dim in [t_dim, h_dim, w_dim]: + freq = get_1d_rotary_pos_embed( + dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype + ) + freqs.append(freq) + + self.freqs = torch.cat(freqs, dim=1) + + def forward( + self, + hidden_states: torch.Tensor, + image_latents: Optional[torch.Tensor] = None, + grid_sizes: Optional[List[List[torch.Tensor]]] = None, + ) -> torch.Tensor: + if grid_sizes is None: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // 
p_h, width // p_w + + grid_sizes = torch.tensor([ppf, pph, ppw]).unsqueeze(0).repeat(batch_size, 1) + grid_sizes = [torch.zeros_like(grid_sizes), grid_sizes, grid_sizes] + + image_grid_sizes = [ + # The start index + torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + # The end index + torch.tensor([31, image_latents.shape[3] // p_h, image_latents.shape[4] // p_w]) + .unsqueeze(0) + .repeat(batch_size, 1), + # The range + torch.tensor([1, image_latents.shape[3] // p_h, image_latents.shape[4] // p_w]) + .unsqueeze(0) + .repeat(batch_size, 1), + ] + + grids = [grid_sizes, image_grid_sizes] + S = ppf * pph * ppw + image_latents.shape[3] // p_h * image_latents.shape[4] // p_w + else: # FramePack's RoPE + batch_size, S, _, _ = hidden_states.shape + grids = grid_sizes + + split_sizes = [ + self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), + self.attention_head_dim // 6, + self.attention_head_dim // 6, + ] + + freqs = self.freqs.split(split_sizes, dim=1) + + # Loop over samples + output = torch.view_as_complex( + torch.zeros( + (batch_size, S, self.num_attention_heads, self.attention_head_dim // 2, 2), + device=hidden_states.device, + dtype=torch.float64, + ) + ) + seq_bucket = [0] + for g in grids: + if type(g) is not list: + g = [torch.zeros_like(g), g] + batch_size = g[0].shape[0] + for i in range(batch_size): + f_o, h_o, w_o = g[0][i] + + f, h, w = g[1][i] + t_f, t_h, t_w = g[2][i] + seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o + seq_len = int(seq_f * seq_h * seq_w) + if seq_len > 0: + if t_f > 0: + # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item()) + if f_o >= 0: + f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist() + else: + f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist() + h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist() + w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist() + + assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0 + freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj() + freqs_0 = freqs_0.view(seq_f, 1, 1, -1) + + freqs_i = torch.cat( + [ + freqs_0.expand(seq_f, seq_h, seq_w, -1), + freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + # apply rotary embedding + output[i, seq_bucket[-1] : seq_bucket[-1] + seq_len] = freqs_i + seq_bucket.append(seq_bucket[-1] + seq_len) + + return output + + +@maybe_allow_in_graph +class WanS2VTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: Optional[int] = None, + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=None, + processor=WanS2VAttnProcessor(), + ) + + # 2. Cross-attention + self.attn2 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + cross_attention_dim_head=dim // num_heads, + processor=WanS2VAttnProcessor(), + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 3. 
Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: Tuple[torch.Tensor, torch.Tensor], + rotary_emb: torch.Tensor, + ) -> torch.Tensor: + seg_idx = temb[1].item() + seg_idx = min(max(0, seg_idx), hidden_states.shape[1]) + seg_idx = [0, seg_idx, hidden_states.shape[1]] + temb = temb[0] + # temb: batch_size, 6, 2, inner_dim + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(2) + temb.float() + ).chunk(6, dim=1) + # batch_size, 1, seq_len, inner_dim + shift_msa = shift_msa.squeeze(1) + scale_msa = scale_msa.squeeze(1) + gate_msa = gate_msa.squeeze(1) + c_shift_msa = c_shift_msa.squeeze(1) + c_scale_msa = c_scale_msa.squeeze(1) + c_gate_msa = c_gate_msa.squeeze(1) + + norm_hidden_states = self.norm1(hidden_states.float()) + parts = [] + for i in range(2): + parts.append( + norm_hidden_states[:, seg_idx[i] : seg_idx[i + 1]] * (1 + scale_msa[:, i : i + 1]) + + shift_msa[:, i : i + 1] + ) + norm_hidden_states = torch.cat(parts, dim=1).type_as(hidden_states) + + # 1. Self-attention + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) + z = [] + for i in range(2): + z.append(attn_output[:, seg_idx[i] : seg_idx[i + 1]] * gate_msa[:, i : i + 1]) + attn_output = torch.cat(z, dim=1) + hidden_states = (hidden_states.float() + attn_output).type_as(hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm3_hidden_states = self.norm3(hidden_states.float()) + parts = [] + for i in range(2): + parts.append( + norm3_hidden_states[:, seg_idx[i] : seg_idx[i + 1]] * (1 + c_scale_msa[:, i : i + 1]) + + c_shift_msa[:, i : i + 1] + ) + norm3_hidden_states = torch.cat(parts, dim=1).type_as(hidden_states) + ff_output = self.ffn(norm3_hidden_states) + z = [] + for i in range(2): + z.append(ff_output[:, seg_idx[i] : seg_idx[i + 1]] * c_gate_msa[:, i : i + 1]) + ff_output = torch.cat(z, dim=1) + hidden_states = (hidden_states.float() + ff_output.float()).type_as(hidden_states) + + return hidden_states + + +class WanS2VTransformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): + r""" + A Transformer model for video-like data used in the Wan2.2-S2V model. + + Args: + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `40`): + Fixed length for text embeddings. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_dim (`int`, defaults to `512`): + Input dimension for text embeddings. + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + The number of layers of transformer blocks to use. 
+ window_size (`Tuple[int]`, defaults to `(-1, -1)`): + Window size for local attention (-1 indicates global attention). + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + qk_norm (`bool`, defaults to `True`): + Enable query/key normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + add_img_emb (`bool`, defaults to `False`): + Whether to use img_emb. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + zero_timestep (`bool`, defaults to `True`): + Whether to assign 0 value timestep to image/motion + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] + _no_split_modules = ["WanS2VTransformerBlock"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3", "causal_audio_encoder"] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _repeated_blocks = ["WanS2VTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: Tuple[int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + audio_dim: int = 1024, + audio_inject_layers: Tuple[int] = (0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39), + enable_adain: bool = True, + adain_mode: str = "attn_norm", + pose_dim: int = 16, + ffn_dim: int = 13824, + num_layers: int = 40, + num_weighted_avg_layers: int = 25, + cross_attn_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", + eps: float = 1e-6, + added_kv_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + enable_framepack: bool = True, + framepack_drop_mode: str = "padd", + add_last_motion: bool = True, + zero_timestep: bool = True, + ) -> None: + super().__init__() + + self.inner_dim = inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Patch & position embedding + self.rope = WanS2VRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len, num_attention_heads) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + if enable_framepack: + self.frame_packer = FramePackMotioner( + inner_dim=inner_dim, + num_attention_heads=num_attention_heads, + zip_frame_buckets=[1, 2, 16], + drop_mode=framepack_drop_mode, + patch_size=patch_size, + ) + else: + self.motion_in = Motioner( + inner_dim=inner_dim, + num_attention_heads=num_attention_heads, + patch_size=patch_size, + in_channels=in_channels, + rope_max_seq_len=rope_max_seq_len, + ) + + self.trainable_condition_mask = nn.Embedding(3, inner_dim) + + # 2. Condition Embeddings + self.condition_embedder = WanTimeTextAudioPoseEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + audio_embed_dim=audio_dim, + pose_embed_dim=pose_dim, + patch_size=patch_size, + enable_adain=enable_adain, + num_weighted_avg_layers=num_weighted_avg_layers, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList( + [ + WanS2VTransformerBlock( + inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim + ) + for _ in range(num_layers) + ] + ) + + # 4. 
Audio Injector + self.audio_injector = AudioInjector( + num_injection_layers=len(audio_inject_layers), + inject_layers=audio_inject_layers, + dim=inner_dim, + num_heads=num_attention_heads, + enable_adain=enable_adain, + adain_dim=inner_dim, + adain_mode=adain_mode, + eps=eps, + ) + + # 4. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + def process_motion(self, motion_latents, drop_motion_frames=False): + flattern_mot, mot_remb = self.motion_in(motion_latents) + + if drop_motion_frames or motion_latents[0].shape[1] == 0: + return [], [] + else: + return flattern_mot, mot_remb + + def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2): + flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion) + + if drop_motion_frames: + return flattern_mot[:, :0], mot_remb[:, :0] + else: + return flattern_mot, mot_remb + + def inject_motion( + self, + hidden_states, + seq_lens, + rope_embs, + mask_input, + motion_latents, + drop_motion_frames=False, + add_last_motion=True, + ): + # Inject the motion frames token to the hidden states + if self.config.enable_framepack: + mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames, add_last_motion) + else: + mot, mot_remb = self.process_motion(motion_latents, drop_motion_frames) + + if len(mot) > 0: + hidden_states = torch.cat([hidden_states, mot], dim=1) + seq_lens = seq_lens + torch.tensor([mot.shape[1]], dtype=torch.long) + rope_embs = torch.cat([rope_embs, mot_remb], dim=1) + mask_input = torch.cat( + [ + mask_input, + 2 + * torch.ones( + [1, hidden_states.shape[1] - mask_input.shape[1]], + device=mask_input.device, + dtype=mask_input.dtype, + ), + ], + dim=1, + ) + return hidden_states, seq_lens, rope_embs, mask_input + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + motion_latents: torch.Tensor, + audio_embeds: torch.Tensor, + image_latents: torch.Tensor, + pose_latents: torch.Tensor, + motion_frames: List[int] = [17, 5], + drop_motion_frames: bool = False, + add_last_motion: int = 2, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + r""" + Parameters: + audio_embeds: + The input audio embedding [B, num_wav2vec_layer, C_a, T_a]. + motion_frames: + The number of motion frames and motion latents frames encoded by vae, i.e. [17, 5]. + add_last_motion: + For the motioner, if add_last_motion > 0, it means that the most recent frame (i.e., the last frame) + will be added. For frame packing, the behavior depends on the value of add_last_motion: add_last_motion + = 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included. add_last_motion = 1: + Both clean_latents_2x and clean_latents_4x are included. add_last_motion = 2: All motion-related + latents are used. + drop_motion_frames: + Bool, whether drop the motion frames info. 
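+            hidden_states:
+                The noisy video latents of shape [B, C, F, H, W].
+            timestep:
+                The current denoising timestep(s).
+            encoder_hidden_states:
+                The text embeddings used for cross-attention.
+            motion_latents:
+                Latents of the preceding motion frames that provide temporal context between chunks.
+            image_latents:
+                Latents of the reference image that are appended to the video token sequence.
+            pose_latents:
+                Latents of the pose video for the current chunk, added to the patch embeddings.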
+ """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + add_last_motion = self.config.add_last_motion * add_last_motion + + # 1. Rotary position embeddings + rotary_emb = self.rope(hidden_states, image_latents) + + # 2. Patch embeddings + hidden_states = self.patch_embedding(hidden_states) + image_latents = self.patch_embedding(image_latents) + + # 3. Condition embeddings + audio_embeds = torch.cat( + [audio_embeds[..., 0].unsqueeze(-1).repeat(1, 1, 1, motion_frames[0]), audio_embeds], dim=-1 + ) + + if self.config.zero_timestep: + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + + temb, timestep_proj, encoder_hidden_states, audio_hidden_states, pose_hidden_states = self.condition_embedder( + timestep, encoder_hidden_states, audio_embeds, pose_latents + ) + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if self.config.enable_adain: + audio_emb_global, audio_emb = audio_hidden_states + audio_emb_global = audio_emb_global[:, motion_frames[1] :].clone() + else: + audio_emb = audio_hidden_states + merged_audio_emb = audio_emb[:, motion_frames[1] :, :] + + hidden_states = hidden_states + pose_hidden_states + + hidden_states = hidden_states.flatten(2).transpose(1, 2) + image_latents = image_latents.flatten(2).transpose(1, 2) + + sequence_length = torch.tensor([hidden_states.shape[1]], dtype=torch.long) + original_sequence_length = sequence_length + sequence_length = sequence_length + torch.tensor([image_latents.shape[1]], dtype=torch.long) + hidden_states = torch.cat([hidden_states, image_latents], dim=1) + + # Initialize masks to indicate noisy latent, image latent, and motion latent. + # However, at this point, only the first two (noisy and image latents) are marked; + # the marking of motion latent will be implemented inside `inject_motion`. 
+ mask_input = torch.zeros([1, hidden_states.shape[1]], dtype=torch.long, device=hidden_states.device) + mask_input[:, original_sequence_length:] = 1 + + hidden_states, sequence_length, rotary_emb, mask_input = self.inject_motion( + hidden_states, + sequence_length, + rotary_emb, + mask_input, + motion_latents, + drop_motion_frames, + add_last_motion, + ) + + hidden_states = hidden_states + self.trainable_condition_mask(mask_input).to(hidden_states.dtype) + + if self.config.zero_timestep: + temb = temb[:-1] + zero_timestep_proj = timestep_proj[-1:] + timestep_proj = timestep_proj[:-1] + timestep_proj = torch.cat( + [timestep_proj.unsqueeze(2), zero_timestep_proj.unsqueeze(2).repeat(timestep_proj.shape[0], 1, 1, 1)], + dim=2, + ) + timestep_proj = [timestep_proj, original_sequence_length] + else: + timestep_proj = timestep_proj.unsqueeze(2).repeat(1, 1, 2, 1) + timestep_proj = [timestep_proj, 0] + + merged_audio_emb_num_frames = merged_audio_emb.shape[1] # B F N C + attn_audio_emb = merged_audio_emb.flatten(0, 1).to(hidden_states.dtype) + audio_emb_global = audio_emb_global.flatten(0, 1).to(hidden_states.dtype) + + # 5. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block_idx, block in enumerate(self.blocks): + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + ) + if block_idx in self.audio_injector.injected_block_id.keys(): + hidden_states = self.audio_injector( + block_idx, + hidden_states, + original_sequence_length, + merged_audio_emb_num_frames, + attn_audio_emb, + audio_emb_global, + ) + else: + for block_idx, block in enumerate(self.blocks): + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + if block_idx in self.audio_injector.injected_block_id.keys(): + hidden_states = self.audio_injector( + block_idx, + hidden_states, + original_sequence_length, + merged_audio_emb_num_frames, + attn_audio_emb, + audio_emb_global, + ) + + hidden_states = hidden_states[:, :original_sequence_length] + + # 6. Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up on. 
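+        # `shift` and `scale` have shape [B, 1, inner_dim] and broadcast over the token dimension.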
+ shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 190c7871d270..70a2ab714cc1 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -381,7 +381,13 @@ "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", ] - _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"] + _import_structure["wan"] = [ + "WanPipeline", + "WanImageToVideoPipeline", + "WanVideoToVideoPipeline", + "WanVACEPipeline", + "WanSpeechToVideoPipeline", + ] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", @@ -786,7 +792,13 @@ UniDiffuserTextDecoder, ) from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline - from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline + from .wan import ( + WanImageToVideoPipeline, + WanPipeline, + WanSpeechToVideoPipeline, + WanVACEPipeline, + WanVideoToVideoPipeline, + ) from .wuerstchen import ( WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, diff --git a/src/diffusers/pipelines/wan/__init__.py b/src/diffusers/pipelines/wan/__init__.py index bb96372b1db2..f21a66dbb7e6 100644 --- a/src/diffusers/pipelines/wan/__init__.py +++ b/src/diffusers/pipelines/wan/__init__.py @@ -24,6 +24,7 @@ else: _import_structure["pipeline_wan"] = ["WanPipeline"] _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"] + _import_structure["pipeline_wan_s2v"] = ["WanSpeechToVideoPipeline"] _import_structure["pipeline_wan_vace"] = ["WanVACEPipeline"] _import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -36,6 +37,7 @@ else: from .pipeline_wan import WanPipeline from .pipeline_wan_i2v import WanImageToVideoPipeline + from .pipeline_wan_s2v import WanSpeechToVideoPipeline from .pipeline_wan_vace import WanVACEPipeline from .pipeline_wan_video2video import WanVideoToVideoPipeline diff --git a/src/diffusers/pipelines/wan/pipeline_wan_s2v.py b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py new file mode 100644 index 000000000000..aa72f0fc24f3 --- /dev/null +++ b/src/diffusers/pipelines/wan/pipeline_wan_s2v.py @@ -0,0 +1,1054 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import regex as re +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import AutoTokenizer, UMT5EncoderModel, Wav2Vec2ForCTC, Wav2Vec2Processor + +from ...audio_processor import PipelineAudioInput +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, WanS2VTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, load_video, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import WanPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import numpy as np, math, requests + >>> import torch + >>> from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline + >>> from diffusers.utils import export_to_video, load_audio, export_to_merged_video_audio + >>> from transformers import Wav2Vec2ForCTC + >>> from PIL import Image + >>> from io import BytesIO + + >>> model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" + >>> audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32) + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = WanSpeechToVideoPipeline.from_pretrained( + ... model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> headers = {"User-Agent": "Mozilla/5.0"} + >>> url = "https://upload.wikimedia.org/wikipedia/commons/4/46/Albert_Einstein_sticks_his_tongue.jpg" + >>> resp = requests.get(url, headers=headers, timeout=30) + >>> image = Image.open(BytesIO(resp.content)) + + >>> audio, sampling_rate = load_audio( + ... "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/Five%20Hundred%20Miles.MP3" + ... ) + >>> # pose_video_path_or_url = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4" + + + >>> def get_size_less_than_area(height, width, target_area=1024 * 704, divisor=64): + ... if height * width <= target_area: + ... # If the original image area is already less than or equal to the target, + ... # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target. + ... max_upper_area = target_area + ... min_scale = 0.1 + ... max_scale = 1.0 + ... else: + ... # Resize to fit within the target area and then pad to multiples of `divisor` + ... max_upper_area = target_area # Maximum allowed total pixel count after padding + ... d = divisor - 1 + ... b = d * (height + width) + ... a = height * width + ... c = d**2 - max_upper_area + + ... # Calculate scale boundaries using quadratic equation + ... min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (2 * a) # Scale when maximum padding is applied + ... 
max_scale = math.sqrt(max_upper_area / (height * width)) # Scale without any padding + + ... # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area + ... # Use binary search-like iteration to find this scale + ... find_it = False + ... for i in range(100): + ... scale = max_scale - (max_scale - min_scale) * i / 100 + ... new_height, new_width = int(height * scale), int(width * scale) + + ... # Pad to make dimensions divisible by 64 + ... pad_height = (64 - new_height % 64) % 64 + ... pad_width = (64 - new_width % 64) % 64 + ... pad_top = pad_height // 2 + ... pad_bottom = pad_height - pad_top + ... pad_left = pad_width // 2 + ... pad_right = pad_width - pad_left + + ... padded_height, padded_width = new_height + pad_height, new_width + pad_width + + ... if padded_height * padded_width <= max_upper_area: + ... find_it = True + ... break + + ... if find_it: + ... return padded_height, padded_width + ... else: + ... # Fallback: calculate target dimensions based on aspect ratio and divisor alignment + ... aspect_ratio = width / height + ... target_width = int((target_area * aspect_ratio) ** 0.5 // divisor * divisor) + ... target_height = int((target_area / aspect_ratio) ** 0.5 // divisor * divisor) + + ... # Ensure the result is not larger than the original resolution + ... if target_width >= width or target_height >= height: + ... target_width = int(width // divisor * divisor) + ... target_height = int(height // divisor * divisor) + + ... return target_height, target_width + + + >>> height, width = get_size_less_than_area(image.height, image.width, target_area=480 * 832) + + >>> prompt = "Einstein singing a song." + + >>> output = pipe( + ... prompt=prompt, + ... image=image, + ... audio=audio, + ... sampling_rate=sampling_rate, + ... height=height, + ... width=width, + ... num_frames_per_chunk=80, + ... # pose_video_path_or_url=pose_video_path_or_url, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + + >>> # Lastly, we need to merge the video and audio into a new video, with the duration set to + >>> # the shorter of the two and overwrite the original video file. 
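+        >>> # Note: `export_to_merged_video_audio` expects a local audio file, so save the soundtrack
+        >>> # (e.g. as "audio.mp3") before calling it.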
+ + >>> export_to_merged_video_audio("output.mp4", "audio.mp3") + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None): + required_duration = num_sample / target_fps + required_origin_frames = int(np.ceil(required_duration * original_fps)) + if required_duration > total_frames / original_fps: + raise ValueError("required_duration must be less than video length") + + if fixed_start is not None and fixed_start >= 0: + start_frame = fixed_start + else: + max_start = total_frames - required_origin_frames + if max_start < 0: + raise ValueError("video length is too short") + start_frame = np.random.randint(0, max_start + 1) + start_time = start_frame / original_fps + + end_time = start_time + required_duration + time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) + + frame_indices = np.round(np.array(time_points) * original_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, total_frames - 1) + return frame_indices + + +def linear_interpolation(features, input_fps, output_fps, output_len=None): + """ + Args: + features: shape=[1, T, 512] + input_fps: fps for audio, f_a + output_fps: fps for video, f_m + output_len: video length + """ + features = features.transpose(1, 2) # [1, 512, T] + seq_len = features.shape[2] / float(input_fps) # T/f_a + output_len = int(seq_len * output_fps) # f_m*T/f_a + output_features = F.interpolate( + features, size=output_len, align_corners=True, mode="linear" + ) # [1, 512, output_len] + return output_features.transpose(1, 2) # [1, output_len, 512] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class WanSpeechToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for prompt-image-audio-to-video generation using Wan2.2-S2V. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`WanT2VTransformer3DModel`]): + Conditional Transformer to denoise the input latents. 
+ scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + audio_encoder ([`Wav2Vec2ForCTC`]): + Audio Encoder to process audio inputs. + audio_processor ([`Wav2Vec2Processor`]): + Audio Processor to preprocess audio inputs. + """ + + model_cpu_offload_seq = "text_encoder->audio_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + transformer: WanS2VTransformer3DModel, + audio_encoder: Wav2Vec2ForCTC, + audio_processor: Wav2Vec2Processor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + audio_encoder=audio_encoder, + audio_processor=audio_processor, + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, resample="bilinear") + self.audio_processor = audio_processor + self.motion_frames = 73 + self.drop_first_motion = True + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_audio( + self, + audio: PipelineAudioInput, + sampling_rate: int, + num_frames: int, + fps: int = 16, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + video_rate = 30 + audio_sample_m = 0 + + input_values = self.audio_processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_values + + # retrieve logits & take argmax + res = self.audio_encoder(input_values.to(self.audio_encoder.device), output_hidden_states=True) + feat = torch.cat(res.hidden_states) + + feat = linear_interpolation(feat, 
input_fps=50, output_fps=video_rate) + + audio_embed = feat.to(torch.float32) # Encoding for the motion + + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + scale = video_rate / fps + + num_repeat = int(audio_frame_num / (num_frames * scale)) + 1 + + bucket_num = num_repeat * num_frames + padd_audio_num = math.ceil(num_repeat * num_frames / fps * video_rate) - audio_frame_num + batch_idx = get_sample_indices( + original_fps=video_rate, + total_frames=audio_frame_num + padd_audio_num, + target_fps=fps, + num_sample=bucket_num, + fixed_start=0, + ) + batch_audio_eb = [] + audio_sample_stride = int(video_rate / fps) + for bi in batch_idx: + if bi < audio_frame_num: + chosen_idx = list( + range( + bi - audio_sample_m * audio_sample_stride, + bi + (audio_sample_m + 1) * audio_sample_stride, + audio_sample_stride, + ) + ) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = ( + torch.zeros([audio_dim * (2 * audio_sample_m + 1)], device=audio_embed.device) + if not return_all_layers + else torch.zeros([num_layers, audio_dim * (2 * audio_sample_m + 1)], device=audio_embed.device) + ) + batch_audio_eb.append(frame_audio_embed) + audio_embed_bucket = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + audio_embed_bucket = audio_embed_bucket.to(device) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0) + if len(audio_embed_bucket.shape) == 3: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) + elif len(audio_embed_bucket.shape) == 4: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) + return audio_embed_bucket, num_repeat + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + audio=None, + audio_embeds=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." 
+ ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if audio is not None and audio_embeds is not None: + raise ValueError( + f"Cannot forward both `audio`: {audio} and `audio_embeds`: {audio_embeds}. Please make sure to" + " only forward one of the two." + ) + elif audio is None and audio_embeds is None: + raise ValueError( + "Provide either `audio` or `audio_embeds`. Cannot leave both `audio` and `audio_embeds` undefined." + ) + elif audio is not None and not isinstance(audio, (np.ndarray)): + raise ValueError(f"`audio` has to be of type `np.ndarray` but is {type(audio)}") + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + latent_motion_frames: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames_per_chunk: int = 80, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + pose_video: Optional[List[Image.Image]] = None, + init_first_frame: bool = False, + num_chunks: int = 1, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor]]]: + num_latent_frames = ( + num_frames_per_chunk + 3 + self.motion_frames + ) // self.vae_scale_factor_temporal - latent_motion_frames + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + if image is not None: + image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] + + video_condition = image.to(device=device, dtype=self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + motion_pixels = torch.zeros([1, 3, self.motion_frames, height, width], dtype=self.vae.dtype, device=device) + # Get pose condition input if needed + pose_condition = self.load_pose_condition( + pose_video, num_chunks, num_frames_per_chunk, height, width, latents_mean, latents_std + ) + # Encode motion latents + videos_last_pixels = motion_pixels.detach() + if init_first_frame: + self.drop_first_motion = False + motion_pixels[:, :, -6:] = video_condition + motion_latents = retrieve_latents(self.vae.encode(motion_pixels), sample_mode="argmax") + motion_latents = (motion_latents - latents_mean) * latents_std + + return latents, latent_condition, videos_last_pixels, motion_latents, pose_condition + else: + return latents + + def load_pose_condition( + self, pose_video, num_chunks, num_frames_per_chunk, height, width, latents_mean, latents_std + ): + if pose_video is not None: + padding_frame_num = num_chunks * num_frames_per_chunk - pose_video.shape[2] + pose_video = pose_video.to(dtype=self.vae.dtype, device=self.vae.device) + pose_video = torch.cat( + [ + pose_video, + -torch.ones( + [1, 3, padding_frame_num, height, width], dtype=self.vae.dtype, device=self.vae.device + ), + ], + dim=2, + ) + + pose_video = torch.chunk(pose_video, num_chunks, dim=2) + else: + pose_video = [ + -torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=self.vae.dtype, device=self.vae.device) + ] + + # Vectorized processing: concatenate all chunks along batch dimension + all_poses = torch.cat( + [torch.cat([cond[:, :, 0:1], cond], dim=2) for cond in pose_video], dim=0 + ) # Shape: [num_chunks, 3, num_frames_per_chunk+1, height, width] + + pose_condition = retrieve_latents(self.vae.encode(all_poses), sample_mode="argmax")[:, :, 1:] + pose_condition = (pose_condition - latents_mean) * latents_std + + return pose_condition + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + audio: 
PipelineAudioInput, + sampling_rate: int, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]] = None, + pose_video_path_or_url: Optional[str] = None, + height: int = 480, + width: int = 832, + num_frames_per_chunk: int = 80, + num_inference_steps: int = 40, + guidance_scale: float = 4.5, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + audio_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + init_first_frame: bool = False, + sampling_fps: int = 16, + num_chunks: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + audio (`PipelineAudioInput`): + The audio input to condition the generation on. Must be an audio, a list of audios or a `torch.Tensor`. + sampling_rate (`int`): + The sampling rate of the audio input. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + pose_video_path_or_url (`str` or `List[str]`, *optional*): + The path or URL to the pose video to condition the generation on. + height (`int`, defaults to `480`): + The height of the generated video. + width (`int`, defaults to `832`): + The width of the generated video. + num_frames_per_chunk (`int`, defaults to `80`): + The number of frames in each chunk of the generated video. `num_frames_per_chunk` should be a multiple + of 4. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + audio_embeds (`torch.Tensor`, *optional*): + Pre-generated audio embeddings. Can be used to easily tweak audio inputs (weighting). If not provided, + audio embeddings are generated from the `audio` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + init_first_frame (`bool`, *optional*, defaults to False): + Whether to use the reference image as the first frame (i.e., standard image-to-video generation). + sampling_fps (`int`, *optional*, defaults to 16): + The frame rate (in frames per second) at which the generated video will be sampled. + num_chunks (`int`, *optional*, defaults to None): + The number of chunks to process. If not provided, the number of chunks will be determined by the audio + input to generate whole audio. E.g., If the input audio has 4 chunks, then user can set num_chunks=1 to + see 1 out of 4 chunks only without generating the whole video. 
+ Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + audio, + audio_embeds, + ) + + if num_frames_per_chunk % self.vae_scale_factor_temporal != 0: + num_frames_per_chunk = ( + num_frames_per_chunk // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + ) + logger.warning( + f"`num_frames_per_chunk` had to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number: {num_frames_per_chunk}" + ) + num_frames_per_chunk = max(num_frames_per_chunk, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + if audio_embeds is None: + audio_embeds, num_chunks_audio = self.encode_audio( + audio, sampling_rate, num_frames_per_chunk, sampling_fps, device + ) + if num_chunks is None or num_chunks > num_chunks_audio: + num_chunks = num_chunks_audio + audio_embeds = audio_embeds.to(transformer_dtype) + + latent_motion_frames = (self.motion_frames + 3) // self.vae_scale_factor_temporal + + # 5. 
Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + + pose_video = None + if pose_video_path_or_url is not None: + pose_video = load_video( + pose_video_path_or_url, + n_frames=num_frames_per_chunk * num_chunks, + target_fps=sampling_fps, + reverse=True, + ) + pose_video = self.video_processor.preprocess_video( + pose_video, height=height, width=width, resize_mode="resize_min_center_crop" + ).to(device, dtype=torch.float32) + + video_chunks = [] + for r in range(num_chunks): + latents_outputs = self.prepare_latents( + image if r == 0 else None, + batch_size * num_videos_per_prompt, + latent_motion_frames, + num_channels_latents, + height, + width, + num_frames_per_chunk, + torch.float32, + device, + generator, + latents if r == 0 else None, + pose_video, + init_first_frame, + num_chunks, + ) + + if r == 0: + latents, condition, videos_last_pixels, motion_latents, pose_condition = latents_outputs + else: + latents = latents_outputs + + with torch.no_grad(): + left_idx = r * num_frames_per_chunk + right_idx = r * num_frames_per_chunk + num_frames_per_chunk + pose_latents = pose_condition[r] if pose_video is not None else pose_condition[0] * 0 + pose_latents = pose_latents.to(dtype=transformer_dtype, device=device) + audio_embeds_input = audio_embeds[..., left_idx:right_idx] + motion_latents_input = motion_latents.to(transformer_dtype).clone() + + # 4. Prepare timesteps by resetting scheduler in each chunk + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents.to(transformer_dtype) + condition = condition.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + motion_latents=motion_latents_input, + image_latents=condition, + pose_latents=pose_latents, + audio_embeds=audio_embeds_input, + motion_frames=[self.motion_frames, latent_motion_frames], + drop_motion_frames=self.drop_first_motion and r == 0, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + motion_latents=motion_latents_input, + image_latents=condition, + pose_latents=pose_latents, + audio_embeds=0.0 * audio_embeds_input, + motion_frames=[self.motion_frames, latent_motion_frames], + drop_motion_frames=self.drop_first_motion and r == 0, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, 
callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not (self.drop_first_motion and r == 0): + decode_latents = torch.cat([motion_latents, latents], dim=2) + else: + decode_latents = torch.cat([condition, latents], dim=2) + + decode_latents = decode_latents.to(self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(decode_latents.device, decode_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + decode_latents.device, decode_latents.dtype + ) + decode_latents = decode_latents / latents_std + latents_mean + video = self.vae.decode(decode_latents, return_dict=False)[0] + video = video[:, :, -(num_frames_per_chunk):] + + if self.drop_first_motion and r == 0: + video = video[:, :, 3:] + + num_overlap_frames = min(self.motion_frames, video.shape[2]) + videos_last_pixels = torch.cat( + [videos_last_pixels[:, :, num_overlap_frames:], video[:, :, -num_overlap_frames:]], dim=2 + ) + + # Update motion_latents for next iteration + motion_latents = retrieve_latents(self.vae.encode(videos_last_pixels), sample_mode="argmax") + motion_latents = (motion_latents - latents_mean) * latents_std + + video_chunks.append(video) + + video_chunks = torch.cat(video_chunks, dim=2) + + self._current_timestep = None + + if not output_type == "latent": + video = self.video_processor.postprocess_video(video_chunks, output_type=output_type) + else: + # TODO + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 63932221b207..fd3c5807b002 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -41,7 +41,7 @@ from .deprecation_utils import deprecate from .doc_utils import replace_example_docstring from .dynamic_modules_utils import get_class_from_dynamic_module -from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video +from .export_utils import export_to_gif, export_to_merged_video_audio, export_to_obj, export_to_ply, export_to_video from .hub_utils import ( PushToHubMixin, _add_variant, @@ -122,7 +122,7 @@ is_xformers_version, requires_backends, ) -from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video +from .loading_utils import get_module_from_name, get_submodule_by_name, load_audio, load_image, load_video from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6e7d22797902..cb758232328b 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1473,6 +1473,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class WanS2VTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def 
from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class WanTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index bb8fea8c8a8b..76fcd8193ea0 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -3362,6 +3362,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class WanSpeechToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class WanVACEPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/export_utils.py b/src/diffusers/utils/export_utils.py index 07cf46928a44..6fc79418121d 100644 --- a/src/diffusers/utils/export_utils.py +++ b/src/diffusers/utils/export_utils.py @@ -1,6 +1,9 @@ import io +import os import random +import shutil import struct +import subprocess import tempfile from contextlib import contextmanager from typing import List, Optional, Union @@ -207,3 +210,62 @@ def export_to_video( writer.append_data(frame) return output_video_path + + +def export_to_merged_video_audio(video_path: str, audio_path: str): + """ + Merge the video and audio into a new video, with the duration set to the shorter of the two, and overwrite the + original video file. 
+
+    Parameters:
+        video_path (str): Path to the original video file
+        audio_path (str): Path to the audio file
+    """
+    if not os.path.exists(video_path):
+        raise FileNotFoundError(f"video file {video_path} does not exist")
+    if not os.path.exists(audio_path):
+        raise FileNotFoundError(f"audio file {audio_path} does not exist")
+
+    base, ext = os.path.splitext(video_path)
+    temp_output = f"{base}_temp{ext}"
+
+    try:
+        # Create ffmpeg command
+        command = [
+            "ffmpeg",
+            "-y",  # overwrite
+            "-i",
+            video_path,
+            "-i",
+            audio_path,
+            "-c:v",
+            "copy",  # copy video stream
+            "-c:a",
+            "aac",  # use AAC audio encoder
+            "-b:a",
+            "192k",  # set audio bitrate (optional)
+            "-map",
+            "0:v:0",  # select the first video stream
+            "-map",
+            "1:a:0",  # select the first audio stream
+            "-shortest",  # choose the shortest duration
+            temp_output,
+        ]
+
+        # Execute the command
+        logger.info("Merging video and audio...")
+        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+
+        # Check result
+        if result.returncode != 0:
+            error_msg = f"FFmpeg execution failed: {result.stderr}"
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+
+        shutil.move(temp_output, video_path)
+        logger.info(f"Merge completed, saved to {video_path}")
+
+    except Exception as e:
+        if os.path.exists(temp_output):
+            os.remove(temp_output)
+        raise RuntimeError(f"merge_video_audio failed with error: {e}") from e  # do not fail silently
diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py
index dd23ae73c861..29e8a7855fdd 100644
--- a/src/diffusers/utils/loading_utils.py
+++ b/src/diffusers/utils/loading_utils.py
@@ -3,6 +3,8 @@
 from typing import Any, Callable, List, Optional, Tuple, Union
 from urllib.parse import unquote, urlparse
 
+import librosa
+import numpy
 import PIL.Image
 import PIL.ImageOps
 import requests
@@ -57,6 +59,9 @@ def load_image(
 def load_video(
     video: str,
     convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
+    n_frames: Optional[int] = None,
+    target_fps: Optional[int] = None,
+    reverse: bool = False,
 ) -> List[PIL.Image.Image]:
     """
     Loads `video` to a list of PIL Image.
@@ -67,6 +72,13 @@ def load_video(
         convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
             A conversion method to apply to the video after loading it. When set to `None` the images will be converted
             to "RGB".
+        n_frames (`int`, *optional*):
+            Number of frames to sample from the video. If `None`, all frames are loaded.
+        target_fps (`int`, *optional*):
+            Target sampling frame rate. If `None`, the original frame rate is used.
+        reverse (`bool`, *optional*):
+            If `True`, samples frames starting from the beginning of the video; if `False`, samples frames starting
+            from the end. Defaults to `False`.
 
     Returns:
         `List[PIL.Image.Image]`:
@@ -125,9 +137,40 @@ def load_video(
         )
 
     with imageio.get_reader(video) as reader:
-        # Read all frames
-        for frame in reader:
-            pil_images.append(PIL.Image.fromarray(frame))
+        # Determine which frames to sample
+        if n_frames is not None and target_fps is not None:
+            # Get video metadata
+            total_frames = reader.count_frames()
+            original_fps = reader.get_meta_data().get("fps")
+
+            # Calculate sampling interval based on target fps
+            interval = max(1, round(original_fps / target_fps))
+            required_span = (n_frames - 1) * interval
+
+            if reverse:
+                start_frame = 0
+            else:
+                start_frame = max(0, total_frames - required_span - 1)
+
+            # Generate sampling indices
+            sampled_indices = []
+            for i in range(n_frames):
+                frame_index = start_frame + i * interval
+                if frame_index >= total_frames:
+                    break
+                sampled_indices.append(int(frame_index))
+
+            # Read specific frames
+            for idx in sampled_indices:
+                try:
+                    frame = reader.get_data(idx)
+                    pil_images.append(PIL.Image.fromarray(frame))
+                except IndexError:
+                    break
+        else:
+            # Read all frames
+            for frame in reader:
+                pil_images.append(PIL.Image.fromarray(frame))
 
     if was_tempfile_created:
         os.remove(video_path)
@@ -138,6 +181,55 @@ def load_video(
     return pil_images
 
 
+def load_audio(
+    audio: Union[str, numpy.ndarray], convert_method: Optional[Callable[[numpy.ndarray], numpy.ndarray]] = None
+) -> Tuple[numpy.ndarray, int]:
+    """
+    Loads `audio` to a numpy array.
+
+    Args:
+        audio (`str` or `numpy.ndarray`):
+            The audio to convert to the numpy array format.
+        convert_method (Callable[[numpy.ndarray], numpy.ndarray], *optional*):
+            A conversion method to apply to the audio after loading it. When set to `None`, no conversion is
+            applied.
+
+    Returns:
+        `numpy.ndarray`:
+            The audio as a numpy array.
+        `int`:
+            The sample rate of the audio.
+    """
+    if isinstance(audio, str):
+        if audio.startswith("http://") or audio.startswith("https://"):
+            # Download audio from URL and load with librosa
+            response = requests.get(audio, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
+            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+                for chunk in response.iter_content(chunk_size=8192):
+                    temp_file.write(chunk)
+                temp_audio_path = temp_file.name
+
+            audio, sample_rate = librosa.load(temp_audio_path, sr=16000)
+            os.remove(temp_audio_path)  # Clean up temporary file
+        elif os.path.isfile(audio):
+            audio, sample_rate = librosa.load(audio, sr=16000)
+        else:
+            raise ValueError(
+                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {audio} is not a valid path."
+            )
+    elif isinstance(audio, numpy.ndarray):
+        sample_rate = 16000  # Default sample rate for numpy arrays
+    else:
+        raise ValueError(
+            "Incorrect format used for the audio. Should be a URL linking to an audio, a local path, or a numpy array."
+        )
+
+    if convert_method is not None:
+        audio = convert_method(audio)
+
+    return audio, sample_rate
+
+
 # Taken from `transformers`.
 def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
     if "."
in tensor_name: diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index 59b59b47d2c7..dc6623e1e472 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -25,7 +25,9 @@ class VideoProcessor(VaeImageProcessor): r"""Simple video processor.""" - def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor: + def preprocess_video( + self, video, height: Optional[int] = None, width: Optional[int] = None, resize_mode: str = "default" + ) -> torch.Tensor: r""" Preprocesses input video(s). @@ -49,6 +51,9 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[ width (`int`, *optional*`, defaults to `None`): The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get the default width. + resize_mode (`str`, *optional*, defaults to `default`): + The resize mode, can be one of `default`, `fill`, `crop`, or `center_crop`. See + `VaeImageProcessor.preprocess` for detailed descriptions of each mode. """ if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: warnings.warn( @@ -79,7 +84,9 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[ "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" ) - video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + video = torch.stack( + [self.preprocess(img, height=height, width=width, resize_mode=resize_mode) for img in video], dim=0 + ) # move the number of channels before the number of frames. video = video.permute(0, 2, 1, 3, 4) diff --git a/tests/pipelines/wan/test_wan_speech_to_video.py b/tests/pipelines/wan/test_wan_speech_to_video.py new file mode 100644 index 000000000000..7396a151b3be --- /dev/null +++ b/tests/pipelines/wan/test_wan_speech_to_video.py @@ -0,0 +1,244 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel, Wav2Vec2ForCTC, Wav2Vec2Processor + +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + WanS2VTransformer3DModel, + WanSpeechToVideoPipeline, +) + +from ...testing_utils import enable_full_determinism, torch_device +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class WanSpeechToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WanSpeechToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = WanS2VTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=3, + num_weighted_avg_layers=5, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + audio_dim=16, + audio_inject_layers=[0, 2], + enable_adain=True, + enable_framepack=True, + ) + + torch.manual_seed(0) + audio_encoder = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2") + audio_processor = Wav2Vec2Processor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "audio_encoder": audio_encoder, + "audio_processor": audio_processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + # Use 64x64 so that after VAE downsampling (factor ~8) latent spatial size is 8x8, which matches + # the frame-packing conv kernel requirement. The largest kernel is (4, 8, 8) so we need at least 8x8 latents. 
+ height = 64 + width = 64 + + image = Image.new("RGB", (width, height)) + + sampling_rate = 16000 + audio_length = 0.5 + # Make audio generation deterministic by using a fixed seed + np_rng = np.random.RandomState(seed) + audio = np_rng.rand(int(sampling_rate * audio_length)).astype(np.float32) + + inputs = { + "image": image, + "audio": audio, + "sampling_rate": sampling_rate, + "prompt": "A person speaking", + "negative_prompt": "low quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.5, + "height": height, + "width": width, + "num_frames_per_chunk": 4, + "num_chunks": 2, + "max_sequence_length": 16, + "output_type": "pt", + "pose_video_path_or_url": None, + "init_first_frame": True, + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames[0] + expected_num_frames = inputs["num_frames_per_chunk"] * inputs["num_chunks"] + if not inputs["init_first_frame"]: + expected_num_frames -= 3 + self.assertEqual(video.shape, (expected_num_frames, 3, inputs["height"], inputs["width"])) + + def test_inference_with_pose(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pose_video_path_or_url"] = "https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/pose.mp4" + video = pipe(**inputs).frames[0] + expected_num_frames = inputs["num_frames_per_chunk"] * inputs["num_chunks"] + if not inputs["init_first_frame"]: + expected_num_frames -= 3 + self.assertEqual(video.shape, (expected_num_frames, 3, inputs["height"], inputs["width"])) + + def test_callback_cfg(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + if "guidance_scale" not in sig.parameters: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_increase_guidance(pipe, i, t, callback_kwargs): + pipe._guidance_scale += 1.0 + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # use cfg guidance because some pipelines modify the shape of the latents + # outside of the denoising loop + inputs["guidance_scale"] = 2.0 + inputs["callback_on_step_end"] = callback_increase_guidance + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + _ = pipe(**inputs)[0] + + # we increase the guidance scale by 1.0 at every step + # check that the guidance scale is increased by the number of scheduler timesteps + # accounts for models that modify the number of inference steps based on strength. 
+ # For this pipeline, the total number of timesteps is multiplied by num_chunks + # since each chunk runs independently with its own denoising loop + assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps * inputs["num_chunks"]) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("Errors out because passing multiple prompts at once is not yet supported by this pipeline.") + def test_encode_prompt_works_in_isolation(self): + pass + + @unittest.skip("Batching is not yet supported with this pipeline") + def test_inference_batch_consistent(self): + pass + + @unittest.skip("Batching is not yet supported with this pipeline") + def test_inference_batch_single_identical(self): + return super().test_inference_batch_single_identical() + + @unittest.skip( + "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs" + ) + def test_float16_inference(self): + pass + + @unittest.skip( + "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs" + ) + def test_save_load_float16(self): + pass diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 0f4fd408a7c1..8e53e5a4ce13 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -16,6 +16,7 @@ HiDreamImageTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline, + WanS2VTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, ) @@ -721,6 +722,33 @@ def get_dummy_inputs(self): } +class WanS2VGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/QuantStack/Wan2.2-S2V-14B-GGUF/blob/main/Wan2.2-S2V-14B-Q3_K_S.gguf" + torch_dtype = torch.bfloat16 + model_cls = WanS2VTransformer3DModel + expected_memory_use_in_gb = 9 + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states": torch.randn( + (1, 96, 2, 64, 64), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states_scale": torch.randn( + (8,), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + @require_torch_version_greater("2.7.1") class GGUFCompileTests(QuantCompileTests, unittest.TestCase): torch_dtype = torch.bfloat16