
[Chroma] Enabling attention backend checks causes error #12575

@zzlol63

Description


Describe the bug

I am using OneTrainer, which uses diffusers to train a LoRA for Chroma. The latest diffusers introduces the concept of a selectable attention backend, which can be enabled via the method documented here:
https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends

I tried enabling only the DIFFUSERS_ATTN_CHECKS flag via an environment variable, which forces diffusers to validate the shapes and dtypes of the attention inputs as part of attention backend routing. However, even with the standard default native backend (SDPA), an error occurs (see below).
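For completeness, this is how I set the flag (a minimal sketch; as far as I can tell the variable is read once when diffusers is imported, so it has to be set beforehand):

```python
import os

# Must be set before importing diffusers, since the flag appears to be
# read once at import time (DIFFUSERS_ATTN_CHECKS=1 per the docs above).
os.environ["DIFFUSERS_ATTN_CHECKS"] = "1"

import diffusers  # shape/dtype checks now run inside the attention dispatcher
```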

OneTrainer passes a 2-dimensional attention mask, which diffusers then expands into a 4-dimensional one on the following line:

```python
attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
```

This 4-dimensional attention mask fails the validation enforced in _check_shape:

```python
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
```

I can't speak to what the correct shape of the mask is, as I'm not acquainted with the model's architecture, so I'm unsure whether this is an issue in OneTrainer or in diffusers itself.
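To illustrate the mismatch with made-up sizes (a standalone sketch; I'm assuming the joint attention key spans the concatenated text + image tokens in a (batch, heads, seq, head_dim) layout, which I haven't verified against the model):

```python
import torch

batch, heads, head_dim = 1, 24, 128
text_len, image_len = 512, 4096  # hypothetical sequence lengths

# 2-D padding mask over the text tokens, as passed by OneTrainer
mask_2d = torch.ones(batch, text_len, dtype=torch.bool)

# The expansion diffusers performs (the line quoted above)
mask_4d = mask_2d[:, None, None, :] * mask_2d[:, None, :, None]
print(mask_4d.shape)  # torch.Size([1, 1, 512, 512])

# If the key covers the joint text + image sequence...
key = torch.randn(batch, heads, text_len + image_len, head_dim)

# ...the comparison from _check_shape fires:
print(mask_4d.shape[-1] != key.shape[-2])  # True -> ValueError
```

For what it's worth, torch.nn.functional.scaled_dot_product_attention documents attn_mask as broadcastable to (N, ..., L, S) with S being the key's sequence length, so the check itself seems consistent with SDPA's contract; the question is what shape the mask should have by the time it reaches the dispatcher.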

Reproduction

Set DIFFUSERS_ATTN_CHECKS=1, run OneTrainer with one of the Chroma LoRA presets (chosen depending on available VRAM), and observe the error produced by diffusers.
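If it helps triage, something along these lines may reproduce the check failure without OneTrainer (hypothetical: I'm guessing at dispatch_attention_fn's signature and tensor layout from the traceback, so treat this as a sketch rather than a verified repro):

```python
import os
os.environ["DIFFUSERS_ATTN_CHECKS"] = "1"  # before importing diffusers

import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn

batch, heads, head_dim = 1, 24, 128
text_len, image_len = 512, 4096  # hypothetical lengths

q = torch.randn(batch, heads, text_len + image_len, head_dim)
k = torch.randn(batch, heads, text_len + image_len, head_dim)
v = torch.randn(batch, heads, text_len + image_len, head_dim)

mask_2d = torch.ones(batch, text_len, dtype=torch.bool)
attn_mask = mask_2d[:, None, None, :] * mask_2d[:, None, :, None]

# Expected: ValueError("Attention mask must match the key's second to last dimension.")
dispatch_attention_fn(q, k, v, attn_mask=attn_mask)
```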

Logs

File "C:\OneTrainer\modules\modelSetup\BaseChromaSetup.py", line 233, in predict
    packed_predicted_flow = model.transformer(
  File "C:\OneTrainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\OneTrainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\OneTrainer\venv\src\diffusers\src\diffusers\models\transformers\transformer_chroma.py", line 577, in forward
    encoder_hidden_states, hidden_states = block(
  File "C:\OneTrainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\OneTrainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\OneTrainer\modules\util\checkpointing_util.py", line 169, in forward
    return custom_forward(None, *args, **kwargs)
  File "C:\OneTrainer\modules\util\checkpointing_util.py", line 148, in custom_forward
    return orig_forward(
  File "C:\OneTrainer\venv\src\diffusers\src\diffusers\models\transformers\transformer_chroma.py", line 327, in forward
    attention_outputs = self.attn(
  File "C:\OneTrainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\OneTrainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\OneTrainer\venv\src\diffusers\src\diffusers\models\transformers\transformer_flux.py", line 342, in forward
    return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
  File "C:\OneTrainer\venv\src\diffusers\src\diffusers\models\transformers\transformer_flux.py", line 116, in __call__
    hidden_states = dispatch_attention_fn(
  File "C:\OneTrainer\venv\src\diffusers\src\diffusers\models\attention_dispatch.py", line 271, in dispatch_attention_fn
    check(**kwargs)
  File "C:\OneTrainer\venv\src\diffusers\src\diffusers\models\attention_dispatch.py", line 335, in _check_shape
    raise ValueError("Attention mask must match the key's second to last dimension.")
ValueError: Attention mask must match the key's second to last dimension.

System Info

  • 🤗 Diffusers version: 0.36.0.dev0
  • Platform: Windows-10-10.0.26100-SP0
  • Running on Google Colab?: No
  • Python version: 3.10.16
  • PyTorch version (GPU?): 2.7.1+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.34.4
  • Transformers version: 4.52.4
  • Accelerate version: 1.7.0
  • PEFT version: not installed
  • Bitsandbytes version: 0.46.0
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 5090, 32607 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

No response
