
Wan 2.2 with LightX2V offloading tries to multiply tensors from different devices and fails #12052

@luke14free

Description

Describe the bug

After @sayakpaul's great work in #12040, LightX2V now works. However, what doesn't work is combining a LoRA with offloading on transformer_2. I can get away with either one (i.e. offload both transformers but add a LoRA only to transformer and NOT to transformer_2, OR offload just transformer and add a LoRA to both transformer and transformer_2).

However, offloading transformer_2 is quite important: without it, VRAM usage doubles, and even a Q4_K_S model with LightX2V will use >24 GB of VRAM (as opposed to <9 GB in ComfyUI).
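
For reference, here are minimal sketches of the two configurations that do work (only the relevant lines; onload_device, offload_device, lora_path and converted_state_dict are defined as in the full reproduction below):

# Variant A (works): offload both transformers, LoRA on transformer only
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")
pipe.transformer_2.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")
pipe.load_lora_weights(lora_path)  # loads into transformer, NOT transformer_2

# Variant B (works): offload only transformer, LoRA on both transformers
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")
pipe.load_lora_weights(lora_path)
pipe.transformer_2.load_lora_adapter(converted_state_dict)  # transformer_2 stays on the GPU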

Reproduction

The script is the same as the one posted by @sayakpaul in PR #12040, with the addition of offloading:

import torch
from diffusers import WanImageToVideoPipeline
from huggingface_hub import hf_hub_download
import requests
from PIL import Image
from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers
from io import BytesIO
import safetensors.torch

# Load the Wan 2.2 I2V pipeline
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
    torch_dtype=torch.bfloat16
)

lora_path = hf_hub_download(
    repo_id="Kijai/WanVideo_comfy",
    filename="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"
)

# This is what is different: enable leaf-level group offloading on every component
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipe.vae.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")

# Without this line it works but uses 2x the VRAM
pipe.transformer_2.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")

pipe.text_encoder.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")

pipe.to("cuda")

pipe.load_lora_weights(lora_path)
# print(pipe.transformer.__class__.__name__)
# print(pipe.transformer.peft_config)
org_state_dict = safetensors.torch.load_file(lora_path)
converted_state_dict = _convert_non_diffusers_wan_lora_to_diffusers(org_state_dict)
pipe.transformer_2.load_lora_adapter(converted_state_dict)

image_url = "https://cloud.inference.sh/u/4mg21r6ta37mpaz6ktzwtt8krr/01k1g7k73eebnrmzmc6h0bghq6.png"
response = requests.get(image_url)
input_image = Image.open(BytesIO(response.content)).convert("RGB")

frames = pipe(input_image, "animate", num_inference_steps=4, guidance_scale=1.0)

Logs

[t+1m44s256ms] [ERROR] Traceback (most recent call last):
[t+1m44s256ms]   File "/server/tasks.py", line 50, in run_task
[t+1m44s256ms]     output = await result
[t+1m44s256ms]              ^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/src/inference.py", line 424, in run
[t+1m44s256ms]     output = self.pipe(
[t+1m44s256ms]              ^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[t+1m44s256ms]     return func(*args, **kwargs)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 754, in __call__
[t+1m44s256ms]     noise_pred = current_model(
[t+1m44s256ms]                  ^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[t+1m44s256ms]     return self._call_impl(*args, **kwargs)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[t+1m44s256ms]     return forward_call(*args, **kwargs)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/diffusers/hooks/hooks.py", line 189, in new_forward
[t+1m44s256ms]     output = function_reference.forward(*args, **kwargs)
[t+1m44s256ms]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/diffusers/models/transformers/transformer_wan.py", line 639, in forward
[t+1m44s256ms]     temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
[t+1m44s256ms]                                                                               ^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[t+1m44s256ms]     return self._call_impl(*args, **kwargs)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[t+1m44s256ms]     return forward_call(*args, **kwargs)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/diffusers/models/transformers/transformer_wan.py", line 332, in forward
[t+1m44s256ms]     temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[t+1m44s256ms]     return self._call_impl(*args, **kwargs)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[t+1m44s256ms]     return forward_call(*args, **kwargs)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/diffusers/models/embeddings.py", line 1290, in forward
[t+1m44s256ms]     sample = self.linear_1(sample)
[t+1m44s256ms]              ^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[t+1m44s256ms]     return self._call_impl(*args, **kwargs)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[t+1m44s256ms]     return forward_call(*args, **kwargs)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/peft/tuners/lora/layer.py", line 771, in forward
[t+1m44s256ms]     result = result + lora_B(lora_A(dropout(x))) * scaling
[t+1m44s256ms]                              ^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[t+1m44s256ms]     return self._call_impl(*args, **kwargs)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[t+1m44s256ms]     return forward_call(*args, **kwargs)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms]   File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward
[t+1m44s256ms]     return F.linear(input, self.weight, self.bias)
[t+1m44s256ms]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)
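
From the trace, the failure happens inside the PEFT LoRA branch (result + lora_B(lora_A(dropout(x)))), which suggests the base linear has been onloaded to cuda:0 by the group-offload hook while the LoRA weights added to transformer_2 are still on CPU. A possible (untested) workaround sketch, under the assumption that the group-offload hooks only cover modules that exist at the time enable_group_offload is called, would be to load the LoRA into transformer_2 before enabling offloading on it:

# Untested workaround sketch: register the LoRA first, then enable offloading,
# so the freshly added lora_A/lora_B linears are also covered by the offload hooks.
converted_state_dict = _convert_non_diffusers_wan_lora_to_diffusers(
    safetensors.torch.load_file(lora_path)
)
pipe.transformer_2.load_lora_adapter(converted_state_dict)
pipe.transformer_2.enable_group_offload(
    onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level"
)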

System Info

  • 🤗 Diffusers version: 0.35.0.dev0
  • Platform: Linux-5.15.0-136-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.7.1+cu126 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.34.3
  • Transformers version: 4.55.0.dev0
  • Accelerate version: 1.8.1
  • PEFT version: 0.16.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

Again, thanks a lot for the help, @sayakpaul @a-r-r-o-w
