Skip to content

Conversation

tolgacangoz
Copy link
Contributor

@tolgacangoz tolgacangoz commented Aug 29, 2025

This PR is fixing #12257

This PR is ready for review except for these current TODOs:

  • Converting complex numbers-based RoPE into a real numbers-based one to be able to compile. I think this can be done in a follow-up PR, too.
  • If there is something vectorizable, then do it, because diffusers assumes that num_frames, height, and width are the same in a batch, etc., as opposed to the original repo. There are many for loops in the original repo. This is my current priority now.
  • Complete documentation
  • Complete test file

When I equalize several parameters to be able to produce the same/similar videos:

wan.mp4
diffusers.mp4
Try WanSpeechToVideoPipeline!
!git clone https://github.com/tolgacangoz/diffusers.git
%cd diffusers
!git switch "integrations/wan2.2-s2v"
!pip install pip uv -qU
!uv pip install -e ".[dev]" -q
!uv pip install imageio-ffmpeg ftfy decord ninja packaging -q
# For Flash attention 2:
#!uv pip install flash-attn --no-build-isolation
# For Flash attention 3 in diffusers:
#import os
#os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "YES"


import numpy as np
import torch, os
from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline
from diffusers.utils import export_to_video, load_image, load_audio, load_video
from transformers import Wav2Vec2ForCTC

model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers"  # will be official
model_id = "tolgacangoz/Wan2.2-S2V-14B-Diffusers"
audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanSpeechToVideoPipeline.from_pretrained(
    model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16,
)#.to("cuda")
pipe.enable_model_cpu_offload()
#pipe.transformer.set_attention_backend("flash")  # FA 2
#pipe.transformer.set_attention_backend("_flash_3_hub")  # FA 3

first_frame = load_image("https://raw.githubusercontent.com/Wan-Video/Wan2.2/refs/heads/main/examples/i2v_input.JPG")
audio, sampling_rate = load_audio("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/talk.wav")

import math

def get_size_less_than_area(height,
                            width,
                            target_area=1024 * 704,
                            divisor=64):
    if height * width <= target_area:
        # If the original image area is already less than or equal to the target,
        # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target.
        max_upper_area = target_area
        min_scale = 0.1
        max_scale = 1.0
    else:
        # Resize to fit within the target area and then pad to multiples of `divisor`
        max_upper_area = target_area  # Maximum allowed total pixel count after padding
        d = divisor - 1
        b = d * (height + width)
        a = height * width
        c = d**2 - max_upper_area

        # Calculate scale boundaries using quadratic equation
        min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (
            2 * a)  # Scale when maximum padding is applied
        max_scale = math.sqrt(max_upper_area /
                                (height * width))  # Scale without any padding

    # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area
    # Use binary search-like iteration to find this scale
    find_it = False
    for i in range(100):
        scale = max_scale - (max_scale - min_scale) * i / 100
        new_height, new_width = int(height * scale), int(width * scale)

        # Pad to make dimensions divisible by 64
        pad_height = (64 - new_height % 64) % 64
        pad_width = (64 - new_width % 64) % 64
        pad_top = pad_height // 2
        pad_bottom = pad_height - pad_top
        pad_left = pad_width // 2
        pad_right = pad_width - pad_left

        padded_height, padded_width = new_height + pad_height, new_width + pad_width

        if padded_height * padded_width <= max_upper_area:
            find_it = True
            break

    if find_it:
        return padded_height, padded_width
    else:
        # Fallback: calculate target dimensions based on aspect ratio and divisor alignment
        aspect_ratio = width / height
        target_width = int(
            (target_area * aspect_ratio)**0.5 // divisor * divisor)
        target_height = int(
            (target_area / aspect_ratio)**0.5 // divisor * divisor)

        # Ensure the result is not larger than the original resolution
        if target_width >= width or target_height >= height:
            target_width = int(width // divisor * divisor)
            target_height = int(height // divisor * divisor)

        return target_height, target_width

height, width = get_size_less_than_area(first_frame.height, first_frame.width, target_area=480*832)

prompt = "Einstein singing a song."

output = pipe(
    image=first_frame, audio=audio, sampling_rate=sampling_rate,
    prompt=prompt, height=height, width=width, num_frames_per_chunk=80,
).frames[0]
export_to_video(output, "video.mp4", fps=16)

import logging, shutil, subprocess

def merge_video_audio(video_path: str, audio_path: str):
    """
    Merge the video and audio into a new video, with the duration set to the shorter of the two,
    and overwrite the original video file.

    Parameters:
    video_path (str): Path to the original video file
    audio_path (str): Path to the audio file
    """
    # set logging
    logging.basicConfig(level=logging.INFO)

    # check
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"video file {video_path} does not exist")
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"audio file {audio_path} does not exist")

    base, ext = os.path.splitext(video_path)
    temp_output = f"{base}_temp{ext}"

    try:
        # create ffmpeg command
        command = [
            'ffmpeg',
            '-y',  # overwrite
            '-i',
            video_path,
            '-i',
            audio_path,
            '-c:v',
            'copy',  # copy video stream
            '-c:a',
            'aac',  # use AAC audio encoder
            '-b:a',
            '192k',  # set audio bitrate (optional)
            '-map',
            '0:v:0',  # select the first video stream
            '-map',
            '1:a:0',  # select the first audio stream
            '-shortest',  # choose the shortest duration
            temp_output
        ]

        # execute the command
        logging.info("Start merging video and audio...")
        result = subprocess.run(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        # check result
        if result.returncode != 0:
            error_msg = f"FFmpeg execute failed: {result.stderr}"
            logging.error(error_msg)
            raise RuntimeError(error_msg)

        shutil.move(temp_output, video_path)
        logging.info(f"Merge completed, saved to {video_path}")

    except Exception as e:
        if os.path.exists(temp_output):
            os.remove(temp_output)
        logging.error(f"merge_video_audio failed with error: {e}")

import requests, tempfile
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT

response = requests.get(audio, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
with tempfile.NamedTemporaryFile(delete=False) as talk:
    for chunk in response.iter_content(chunk_size=8192):
        talk.write(chunk)
    talk_file = talk.name

merge_video_audio("video.mp4", talk_file)

@yiyixuxu @sayakpaul @a-r-r-o-w @asomoza @DN6 @stevhliu
@WanX-Video-1 @Steven-SWZhang @kelseyee
@SHYuanBest @J4BEZ @okaris @xziayro-ai @teith @luke14free

…date example imports

Add unit tests for WanSpeechToVideoPipeline and WanS2VTransformer3DModel and gguf
The previous audio encoding logic was a placeholder. It is now replaced with a `Wav2Vec2ForCTC` model and processor, including the full implementation for processing audio inputs. This involves resampling and aligning audio features with video frames to ensure proper synchronization.

Additionally, utility functions for loading audio from files or URLs are added, and the `audio_processor` module is refactored to correctly handle audio data types instead of image types.
Introduces support for audio and pose conditioning, replacing the previous image conditioning mechanism. The model now accepts audio embeddings and pose latents as input.

This change also adds two new, mutually exclusive motion processing modules:
- `MotionerTransformers`: A transformer-based module for encoding motion.
- `FramePackMotioner`: A module that packs frames from different temporal buckets for motion representation.

Additionally, an `AudioInjector` module is implemented to fuse audio features into specific transformer blocks using cross-attention.
The `MotionerTransformers` module is removed and its functionality is replaced by a `FramePackMotioner` module and a simplified standard motion processing pipeline.

The codebase is refactored to remove the `einops` dependency, replacing `rearrange` operations with standard PyTorch tensor manipulations for better code consistency.

Additionally, `AdaLayerNorm` is introduced for improved conditioning, and helper functions for Rotary Positional Embeddings (RoPE) are added (probably temporarily) and refactored for clarity and flexibility. The audio injection mechanism is also updated to align with the new model structure.
Removes the calculation of several unused variables and an unnecessary `deepcopy` operation on the latents tensor.

This change also removes the now-unused `deepcopy` import, simplifying the overall logic.
Refactors the `WanS2VTransformer3DModel` for clarity and better handling of various conditioning inputs like audio, pose, and motion.

Key changes:
- Simplifies the `WanS2VTransformerBlock` by removing projection layers and streamlining the forward pass.
- Introduces `after_transformer_block` to cleanly inject audio information after each transformer block, improving code organization.
- Enhances the main `forward` method to better process and combine multiple conditioning signals (image, audio, motion) before the transformer blocks.
- Adds support for a zero-value timestep to differentiate between image and video latents.
- Generalizes temporal embedding logic to support multiple model variations.
Introduces the necessary configurations and state dictionary key mappings to enable the conversion of S2V model checkpoints to the Diffusers format.

This includes:
- A new transformer configuration for the S2V model architecture, including parameters for audio and pose conditioning.
- A comprehensive rename dictionary to map the original S2V layer names to their Diffusers equivalents.
tolgacangoz and others added 13 commits September 17, 2025 09:20
Adds a utility function to merge video and audio files using ffmpeg.

This simplifies the process of combining audio and video outputs,
especially useful in pipelines like WanSpeechToVideoPipeline.

The function handles temporary file creation, command execution,
and error handling for a more robust merging process.
Consolidates audio injection functionality by moving the `after_transformer_block` method into the `AudioInjector` class. This change improves code organization and encapsulation, making the injection process more modular and maintainable.
Simplifies the audio injection process by directly passing injection layer indices to the `AudioInjector`.

This removes the need for a depth-first search and dictionary creation within the injector, making the code more efficient and readable.
@tolgacangoz
Copy link
Contributor Author

tolgacangoz commented Sep 17, 2025

Hi @kelseyee. The official repo at HF will be required. Will you open a placeholder repo, i.e., Wan-AI/Wan2.2-S2V-14B-Diffusers, and then I will be able to open a PR there?

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Sep 17, 2025

hi @tolgacangoz
I can help with the repo once it's ready
We are still refactoring the transformer, so the checkpoints are not finalized yet

motion_latents = videos_last_latents.to(dtype=motion_latents.dtype, device=motion_latents.device)

# Accumulate latents so as to decode them all at once at the end
all_latents.append(segment_latents)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, when testing on my side I saw duplicated frames between the end of a chunk and the start of the next chunk. It seems that one motion frame is still at the start of segment_latents in this line for all chunks except the first.
I think it's coming from the fact that num_latent_frames (21 with default inputs) is not the same as in self.prepare_latents (20), the extra 1 keeps the last motion frame. Replacing all_latents.append(segment_latents) with all_latents.append(latents) fixed it on my side, changing the num_latent_frame formula should do the same.
Super cool to do the chunking in latent space instead of video frames, I hope this helps a little :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, thanks a bunch for testing @gsprochette! I decided to follow the original repo, since this PR is supposed to be an integration PR rather than an extra-optimizations PR.

tolgacangoz and others added 6 commits September 24, 2025 10:49
…near resampling and adjust frame chunk settings

Updates the speech-to-video pipeline to perform a decode-encode cycle within the generation loop for each video chunk. This change improves temporal consistency between chunks by using the pixels of the previously generated frames, rather than their latents, to condition the next chunk.

Key changes include:
- Modifying the generation loop to decode latents into video frames, update the conditioning pixels, and then re-encode them for the next iteration's motion latents.
- Setting the default `num_frames_per_chunk` to 80 and adjusting the corresponding frame logic.
- Enabling `bilinear` resampling in the `VideoProcessor`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants