Describe the bug
When using quanto int8 quantization together with Diffusers group offloading, the forward pass fails with a device mismatch.
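For context, here is a minimal end-to-end sketch of the kind of setup that triggers this. It is illustrative only: TinyBlockModel, the layer sizes, and num_blocks_per_group=1 are placeholders rather than our actual pipeline, and the apply_group_offloading call follows the usage shown in the Diffusers group-offloading docs. The Reproduction section below isolates the underlying mechanism without needing Diffusers at all.

import torch
from diffusers.hooks import apply_group_offloading
from optimum.quanto import freeze, qint8, quantize


class TinyBlockModel(torch.nn.Module):
    # Stand-in for a real diffusers transformer/UNet (hypothetical, for illustration only).
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(64, 64) for _ in range(4))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


onload_device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cuda")

model = TinyBlockModel()
quantize(model, weights=qint8)  # quantize weights to int8 with optimum-quanto
freeze(model)

# Keep weights on CPU and onload one group of blocks right before its forward
# (num_blocks_per_group=1 is an arbitrary choice for this sketch).
apply_group_offloading(
    model,
    onload_device=onload_device,
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)

x = torch.randn(2, 64, device=onload_device)
out = model(x)  # expected to fail with the device-mismatch RuntimeError shown in the logs below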
Reproduction
"""
This demonstrates the core issue: when .to(device) is called on WeightQBytesTensor,
the wrapper reports the new device but internal components may not be moved properly
when accessed through parameter references.
"""
import torch
from optimum.quanto import freeze, qint8, quantize
class SimpleLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(64, 64)
def forward(self, x):
return self.linear(x)
def demonstrate_device_issue():
compute_device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cuda")
print(f"Using compute device: {compute_device}")
# Create and quantize
model = SimpleLinear()
quantize(model, weights=qint8)
freeze(model)
print("=" * 70)
print("Initial state (after quantization)")
print("=" * 70)
weight_param = model.linear.weight
print(f"Parameter device: {weight_param.device}")
print(f"Parameter.data type: {type(weight_param.data).__name__}")
print(f"Parameter.data device: {weight_param.data.device}")
if hasattr(weight_param.data, "_data"):
print(f"Parameter.data._data device: {weight_param.data._data.device}")
print(f"Parameter.data._scale device: {weight_param.data._scale.device}")
# Simulate what group offload hook does
print("\n" + "=" * 70)
print(f"Simulating group offload: moving to {compute_device}")
print("=" * 70)
print("Executing: param.data = param.data.to(device)")
# Save references BEFORE the move
old_data_id = id(weight_param.data)
if hasattr(weight_param.data, "_data"):
old_data_backing_id = id(weight_param.data._data)
# This is what the hook does
weight_param.data = weight_param.data.to(compute_device)
# Check AFTER the move
new_data_id = id(weight_param.data)
if hasattr(weight_param.data, "_data"):
new_data_backing_id = id(weight_param.data._data)
print(f"\nAfter .to({compute_device}):")
print(f"Parameter device: {weight_param.device}")
print(f"Parameter.data device: {weight_param.data.device}")
if hasattr(weight_param.data, "_data"):
print(f"Parameter.data._data device: {weight_param.data._data.device}")
print(f"Parameter.data._scale device: {weight_param.data._scale.device}")
print(f"\nObject identity changed:")
print(f" WeightQBytesTensor wrapper: {old_data_id} -> {new_data_id} ({'NEW' if old_data_id != new_data_id else 'SAME'})")
if hasattr(weight_param.data, "_data"):
print(
f" Backing _data tensor: {old_data_backing_id} -> {new_data_backing_id} ({'NEW' if old_data_backing_id != new_data_backing_id else 'SAME'})"
)
# Now try to use it
print("\n" + "=" * 70)
print("Attempting forward pass")
print("=" * 70)
input_tensor = torch.randn(2, 64, device=compute_device)
print(f"Input device: {input_tensor.device}")
try:
output = model(input_tensor)
print(f"✓ Success! Output device: {output.device}")
print("\n" + "=" * 70)
print("Analysis")
print("=" * 70)
print(
"If this succeeded, the parameter reassignment properly updated all references."
)
except RuntimeError as e:
print(f"✗ Failed with device mismatch!")
print(f"\nError: {e}")
print("\n" + "=" * 70)
print("Bug Confirmed")
print("=" * 70)
print("The parameter reassignment created a new WeightQBytesTensor")
print(
"with components on the target device, but the module's parameter"
)
print("reference is pointing to old tensor data.")
if __name__ == "__main__":
print("=" * 70)
print("Quanto WeightQBytesTensor Device Mismatch Demonstration")
print("=" * 70)
print("\nThis simulates what happens during Diffusers group offload hooks\n")
demonstrate_device_issue()
print("\n")
print("=" * 70)Logs
======================================================================
Quanto WeightQBytesTensor Device Mismatch Demonstration
======================================================================
This simulates what happens during Diffusers group offload hooks
Using compute device: mps
======================================================================
Initial state (after quantization)
======================================================================
Parameter device: cpu
Parameter.data type: WeightQBytesTensor
Parameter.data device: cpu
Parameter.data._data device: cpu
Parameter.data._scale device: cpu
======================================================================
Simulating group offload: moving to mps
======================================================================
Executing: param.data = param.data.to(device)
After .to(mps):
Parameter device: mps:0
Parameter.data device: cpu
Parameter.data._data device: cpu
Parameter.data._scale device: cpu
Object identity changed:
WeightQBytesTensor wrapper: 13019515792 -> 13019516672 (NEW)
Backing _data tensor: 6162576992 -> 4347766144 (NEW)
======================================================================
Attempting forward pass
======================================================================
Input device: mps:0
✗ Failed with device mismatch!
Error: Tensor for argument #2 'mat2' is on CPU, but expected it to be on GPU (while checking arguments for mm)
======================================================================
Bug Confirmed
======================================================================
The parameter reassignment created a new WeightQBytesTensor
with components on the target device, but the module's parameter
reference is pointing to old tensor data.

System Info
- 🤗 Diffusers version: 0.35.2
- Platform: macOS-15.5-arm64-arm-64bit
- Running on Google Colab?: No
- Python version: 3.12.12
- PyTorch version (GPU?): 2.9.0 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.36.0
- Transformers version: 4.57.1
- Accelerate version: 1.11.0
- PEFT version: 0.17.1
- Bitsandbytes version: not installed
- Safetensors version: 0.6.2
- xFormers version: not installed
- Accelerator: Apple M3 Max
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
The same issue also occurs with NVIDIA CUDA and AMD ROCm.
Who can help?