Quanto + Group Offload causes device mismatch error (weights on cpu, mat1 on gpu) #12610

Describe the bug

When using quanto int8 quantization together with Diffusers group offloading, the forward pass fails with a device mismatch error (weights on CPU, inputs on GPU).
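
For context, the two features are combined roughly as follows. This is an untested sketch: the import path and keyword names for apply_group_offloading are assumptions based on the diffusers group offloading docs, and the toy Sequential model stands in for a real transformer.

import torch
from diffusers.hooks import apply_group_offloading
from optimum.quanto import freeze, qint8, quantize

compute_device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cuda")

# Quantize a toy module with quanto int8 weights
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
quantize(model, weights=qint8)
freeze(model)

# Attach Diffusers group offloading: weights live on CPU and are moved to the
# accelerator just before compute (kwargs assumed from the diffusers docs)
apply_group_offloading(
    model,
    onload_device=compute_device,
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)

# Forward pass raises: "Tensor for argument #2 'mat2' is on CPU, but expected
# it to be on GPU (while checking arguments for mm)"
out = model(torch.randn(2, 64, device=compute_device))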

Reproduction

"""
This demonstrates the core issue: when .to(device) is called on WeightQBytesTensor,
the wrapper reports the new device but internal components may not be moved properly
when accessed through parameter references.
"""

import torch
from optimum.quanto import freeze, qint8, quantize


class SimpleLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.linear(x)


def demonstrate_device_issue():
    compute_device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cuda")
    print(f"Using compute device: {compute_device}")
    # Create and quantize
    model = SimpleLinear()
    quantize(model, weights=qint8)
    freeze(model)

    print("=" * 70)
    print("Initial state (after quantization)")
    print("=" * 70)
    weight_param = model.linear.weight
    print(f"Parameter device: {weight_param.device}")
    print(f"Parameter.data type: {type(weight_param.data).__name__}")
    print(f"Parameter.data device: {weight_param.data.device}")
    if hasattr(weight_param.data, "_data"):
        print(f"Parameter.data._data device: {weight_param.data._data.device}")
        print(f"Parameter.data._scale device: {weight_param.data._scale.device}")

    # Simulate what group offload hook does
    print("\n" + "=" * 70)
    print(f"Simulating group offload: moving to {compute_device}")
    print("=" * 70)
    print("Executing: param.data = param.data.to(device)")

    # Save references BEFORE the move
    old_data_id = id(weight_param.data)
    if hasattr(weight_param.data, "_data"):
        old_data_backing_id = id(weight_param.data._data)

    # This is what the hook does
    weight_param.data = weight_param.data.to(compute_device)

    # Check AFTER the move
    new_data_id = id(weight_param.data)
    if hasattr(weight_param.data, "_data"):
        new_data_backing_id = id(weight_param.data._data)

    print(f"\nAfter .to({compute_device}):")
    print(f"Parameter device: {weight_param.device}")
    print(f"Parameter.data device: {weight_param.data.device}")
    if hasattr(weight_param.data, "_data"):
        print(f"Parameter.data._data device: {weight_param.data._data.device}")
        print(f"Parameter.data._scale device: {weight_param.data._scale.device}")

    print(f"\nObject identity changed:")
    print(f"  WeightQBytesTensor wrapper: {old_data_id} -> {new_data_id} ({'NEW' if old_data_id != new_data_id else 'SAME'})")
    if hasattr(weight_param.data, "_data"):
        print(
            f"  Backing _data tensor:       {old_data_backing_id} -> {new_data_backing_id} ({'NEW' if old_data_backing_id != new_data_backing_id else 'SAME'})"
        )

    # Now try to use it
    print("\n" + "=" * 70)
    print("Attempting forward pass")
    print("=" * 70)

    input_tensor = torch.randn(2, 64, device=compute_device)
    print(f"Input device: {input_tensor.device}")

    try:
        output = model(input_tensor)
        print(f"✓ Success! Output device: {output.device}")

        print("\n" + "=" * 70)
        print("Analysis")
        print("=" * 70)
        print(
            "If this succeeded, the parameter reassignment properly updated all references."
        )

    except RuntimeError as e:
        print(f"✗ Failed with device mismatch!")
        print(f"\nError: {e}")

        print("\n" + "=" * 70)
        print("Bug Confirmed")
        print("=" * 70)
        print("The parameter reassignment created a new WeightQBytesTensor")
        print(
            "with components on the target device, but the module's parameter"
        )
        print("reference is pointing to old tensor data.")


if __name__ == "__main__":
    print("=" * 70)
    print("Quanto WeightQBytesTensor Device Mismatch Demonstration")
    print("=" * 70)
    print("\nThis simulates what happens during Diffusers group offload hooks\n")

    demonstrate_device_issue()

    print("\n")
    print("=" * 70)

Logs

======================================================================
Quanto WeightQBytesTensor Device Mismatch Demonstration
======================================================================

This simulates what happens during Diffusers group offload hooks

Using compute device: mps
======================================================================
Initial state (after quantization)
======================================================================
Parameter device: cpu
Parameter.data type: WeightQBytesTensor
Parameter.data device: cpu
Parameter.data._data device: cpu
Parameter.data._scale device: cpu

======================================================================
Simulating group offload: moving to mps
======================================================================
Executing: param.data = param.data.to(device)

After .to(mps):
Parameter device: mps:0
Parameter.data device: cpu
Parameter.data._data device: cpu
Parameter.data._scale device: cpu

Object identity changed:
  WeightQBytesTensor wrapper: 13019515792 -> 13019516672 (NEW)
  Backing _data tensor:       6162576992 -> 4347766144 (NEW)

======================================================================
Attempting forward pass
======================================================================
Input device: mps:0
✗ Failed with device mismatch!

Error: Tensor for argument #2 'mat2' is on CPU, but expected it to be on GPU (while checking arguments for mm)

======================================================================
Bug Confirmed
======================================================================
The parameter reassignment created a new WeightQBytesTensor
with components on the target device, but the module's parameter
reference is pointing to old tensor data.
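
Given the identity check above, one possible direction for a workaround would be to replace the Parameter object itself rather than assigning through .data, so the module serves the freshly moved WeightQBytesTensor. This is a hypothetical, untested sketch; onload_quantized_param is not an existing diffusers or quanto helper.

import torch

def onload_quantized_param(module: torch.nn.Module, name: str, device: torch.device) -> None:
    # Hypothetical: rebuild the Parameter around the moved quantized tensor instead of
    # mutating .data, so no stale reference is left behind.
    param = module._parameters[name]
    moved = param.data.to(device)  # new WeightQBytesTensor, per the identity check above
    module._parameters[name] = torch.nn.Parameter(moved, requires_grad=param.requires_grad)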

System Info

  • 🤗 Diffusers version: 0.35.2
  • Platform: macOS-15.5-arm64-arm-64bit
  • Running on Google Colab?: No
  • Python version: 3.12.12
  • PyTorch version (GPU?): 2.9.0 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.36.0
  • Transformers version: 4.57.1
  • Accelerate version: 1.11.0
  • PEFT version: 0.17.1
  • Bitsandbytes version: not installed
  • Safetensors version: 0.6.2
  • xFormers version: not installed
  • Accelerator: Apple M3 Max
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

The same error also occurs with NVIDIA CUDA and AMD ROCm backends.

Who can help?

@DN6 @sayakpaul
