Continuous Batching sliding window attention mask is wrong #41184

@NixGD

Description

System Info

  • transformers version: 4.57.0.dev0
  • Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.35
  • Python version: 3.11.11
  • Huggingface_hub version: 1.0.0.rc1
  • Safetensors version: 0.6.2
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: no
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@remi-or

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Here's a minimal example showing that the sliding-window mask built by build_attention_mask allows more query/key pairs to be attended to than the full causal mask, when it should allow fewer.

import torch
from transformers.generation.continuous_batching.continuous_api import build_attention_mask

seq_len = 16
sliding_window = 4
cumulative_seqlens_q = torch.tensor([0, seq_len])
cumulative_seqlens_k = torch.tensor([0, seq_len])

# initialize masks to all ones -- attention is allowed everywhere
window_mask = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32)
full_mask = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32)

# build_attention_mask converts this to 0 & -inf
build_attention_mask(window_mask, cumulative_seqlens_q, cumulative_seqlens_k, sliding_window=sliding_window)
build_attention_mask(full_mask, cumulative_seqlens_q, cumulative_seqlens_k, sliding_window=1)

# entries that are still 0 allow the model to attend to that query-key pair
print("Key/Query pairs the model can attend to (full mask):", (full_mask == 0).sum().item())
print("Key/Query pairs the model can attend to (window mask):", (window_mask == 0).sum().item())

This outputs

Key/Query pairs the model can attend to (full mask): 136
Key/Query pairs the model can attend to (window mask): 202
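
For reference, the counts I would expect: a plain causal mask over 16 tokens allows 16 * 17 / 2 = 136 pairs (matching the full-mask output above), while additionally restricting each query to the last 4 keys should reduce that to 58, not raise it to 202. A quick sanity check (my own reference computation, not using the transformers API):

import torch

seq_len = 16
sliding_window = 4

# a pair is allowed iff the key is not in the future and is within the window
q = torch.arange(seq_len).unsqueeze(-1)  # query positions, shape (seq_len, 1)
k = torch.arange(seq_len).unsqueeze(0)   # key positions, shape (1, seq_len)
causal = k <= q
windowed = causal & (q - k < sliding_window)

print("Expected pairs (full causal mask):", causal.sum().item())    # 136
print("Expected pairs (windowed mask):   ", windowed.sum().item())  # 58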

Expected behavior

Using Continuous Batching with gpt-oss gives meaningless results; I believe this mask bug is the reason.

My understanding is that the root cause is this line: it resets mask elements from -inf back to 0, which re-allows attention to those query/key pairs. The sliding-window mask should instead be strictly more restrictive than the full causal mask, never less.
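
A minimal sketch of the behavior I would expect (not the actual continuous_batching implementation; apply_sliding_window, q_pos, and k_pos are hypothetical names): the window restriction should only ever add -inf entries on top of the causal mask, never flip an existing -inf back to 0.

import torch

def apply_sliding_window(mask, q_pos, k_pos, sliding_window):
    # Block every query/key pair whose distance is >= sliding_window.
    # masked_fill only adds -inf entries, so pairs that are already masked
    # (e.g. by causality) stay masked -- the window can only remove attention.
    outside = (q_pos.unsqueeze(-1) - k_pos.unsqueeze(0)) >= sliding_window
    return mask.masked_fill(outside, float("-inf"))

# Intended use: start from the (causal) full mask and apply the window on top of it.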
