Describe the bug
After @sayakpaul's great work in #12040, LightX2V now works. What still doesn't work is combining a LoRA with offloading on transformer_2. I can get away with either one (i.e. offload both transformers but add a LoRA only to transformer and NOT to transformer_2, OR offload just transformer and add a LoRA to both transformer and transformer_2).
However, offloading transformer_2 is quite important: keeping it resident doubles VRAM usage, and even a Q4_K_S model with LightX2V will then use >24 GB of VRAM (as opposed to <9 GB in ComfyUI).
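To make the matrix concrete, here is a minimal sketch of the two configurations that do work (assuming the same pipe, lora_path, and converted_state_dict as in the reproduction below):
onload, offload = torch.device("cuda"), torch.device("cpu")
# Configuration A (works): offload both transformers, LoRA on transformer only
pipe.transformer.enable_group_offload(onload_device=onload, offload_device=offload, offload_type="leaf_level")
pipe.transformer_2.enable_group_offload(onload_device=onload, offload_device=offload, offload_type="leaf_level")
pipe.load_lora_weights(lora_path)  # loads into transformer only
# Configuration B (works): offload transformer only, LoRA on both transformers
# pipe.transformer.enable_group_offload(onload_device=onload, offload_device=offload, offload_type="leaf_level")
# pipe.load_lora_weights(lora_path)
# pipe.transformer_2.load_lora_adapter(converted_state_dict)
The failing case, offloading both transformers and loading the LoRA into both, is the full reproduction below.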
Reproduction
The script is the same as the one @sayakpaul posted in #12040, with the addition of group offloading:
import torch
from diffusers import WanImageToVideoPipeline
from huggingface_hub import hf_hub_download
import requests
from PIL import Image
from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers
from io import BytesIO
import safetensors.torch
# Load the base pipeline
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
    torch_dtype=torch.bfloat16,
)
lora_path = hf_hub_download(
    repo_id="Kijai/WanVideo_comfy",
    filename="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
)
# This is what is different: enable group offloading on every component
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipe.vae.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")
# Without the next line it works, but uses 2x the VRAM
pipe.transformer_2.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")
pipe.text_encoder.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")
pipe.to("cuda")
pipe.load_lora_weights(lora_path)
# print(pipe.transformer.__class__.__name__)
# print(pipe.transformer.peft_config)
# Convert the ComfyUI-format LoRA and load it into transformer_2 as well
org_state_dict = safetensors.torch.load_file(lora_path)
converted_state_dict = _convert_non_diffusers_wan_lora_to_diffusers(org_state_dict)
pipe.transformer_2.load_lora_adapter(converted_state_dict)
# Run a short 4-step generation
image_url = "https://cloud.inference.sh/u/4mg21r6ta37mpaz6ktzwtt8krr/01k1g7k73eebnrmzmc6h0bghq6.png"
response = requests.get(image_url)
input_image = Image.open(BytesIO(response.content)).convert("RGB")
frames = pipe(image=input_image, prompt="animate", num_inference_steps=4, guidance_scale=1.0).frames[0]
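As an aside, a possible workaround I have not yet verified (assuming the group offloading hooks only wrap modules that exist when enable_group_offload is called): load all LoRA weights first, then enable offloading, so the hooks see and move the LoRA layers too.
# Hypothetical reordering: attach all LoRA weights before enabling group offload
pipe.load_lora_weights(lora_path)
pipe.transformer_2.load_lora_adapter(converted_state_dict)
for component in (pipe.vae, pipe.transformer, pipe.transformer_2, pipe.text_encoder):
    component.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")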
Logs
[t+1m44s256ms] [ERROR] Traceback (most recent call last):
[t+1m44s256ms] File "/server/tasks.py", line 50, in run_task
[t+1m44s256ms] output = await result
[t+1m44s256ms] ^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/src/inference.py", line 424, in run
[t+1m44s256ms] output = self.pipe(
[t+1m44s256ms] ^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[t+1m44s256ms] return func(*args, **kwargs)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 754, in __call__
[t+1m44s256ms] noise_pred = current_model(
[t+1m44s256ms] ^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[t+1m44s256ms] return self._call_impl(*args, **kwargs)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[t+1m44s256ms] return forward_call(*args, **kwargs)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/diffusers/hooks/hooks.py", line 189, in new_forward
[t+1m44s256ms] output = function_reference.forward(*args, **kwargs)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/diffusers/models/transformers/transformer_wan.py", line 639, in forward
[t+1m44s256ms] temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[t+1m44s256ms] return self._call_impl(*args, **kwargs)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[t+1m44s256ms] return forward_call(*args, **kwargs)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/diffusers/models/transformers/transformer_wan.py", line 332, in forward
[t+1m44s256ms] temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[t+1m44s256ms] return self._call_impl(*args, **kwargs)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[t+1m44s256ms] return forward_call(*args, **kwargs)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/diffusers/models/embeddings.py", line 1290, in forward
[t+1m44s256ms] sample = self.linear_1(sample)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[t+1m44s256ms] return self._call_impl(*args, **kwargs)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[t+1m44s256ms] return forward_call(*args, **kwargs)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/peft/tuners/lora/layer.py", line 771, in forward
[t+1m44s256ms] result = result + lora_B(lora_A(dropout(x))) * scaling
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[t+1m44s256ms] return self._call_impl(*args, **kwargs)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[t+1m44s256ms] return forward_call(*args, **kwargs)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] File "/inferencesh/apps/gpu/65b8e0w0x60df8we0x6njqx9kc/venv/3.12/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward
[t+1m44s256ms] return F.linear(input, self.weight, self.bias)
[t+1m44s256ms] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[t+1m44s256ms] RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)
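The failure happens inside the PEFT LoRA layer (result + lora_B(lora_A(dropout(x)))), which suggests the base weights of transformer_2 are being onloaded to cuda:0 by the offloading hooks while the LoRA weights, loaded after enable_group_offload, are left behind on CPU. A quick sketch to check that hypothesis, printing where each LoRA parameter of transformer_2 lives:
# Diagnostic sketch: if the hypothesis is right, these stay on cpu
# while the hooked base weights are moved to cuda during the forward pass.
for name, param in pipe.transformer_2.named_parameters():
    if "lora_" in name:
        print(name, param.device)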
System Info
- 🤗 Diffusers version: 0.35.0.dev0
- Platform: Linux-5.15.0-136-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.7.1+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.34.3
- Transformers version: 4.55.0.dev0
- Accelerate version: 1.8.1
- PEFT version: 0.16.0
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help?
Again, thanks a lot for the help, @sayakpaul @a-r-r-o-w!