6 changes: 6 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,6 @@
{
"files.associations": {
"optional": "cpp",
"format": "cpp"
}
}
6 changes: 6 additions & 0 deletions CMakeLists.txt
@@ -7,6 +7,12 @@ set(CMAKE_CXX_EXTENSIONS OFF)
set(FA2_ENABLED ON)
set(FA3_ENABLED ON)

# Allow disabling FA3 from environment without changing setup.py
if(DEFINED ENV{FLASH_ATTN_DISABLE_FA3} AND NOT "$ENV{FLASH_ATTN_DISABLE_FA3}" STREQUAL "")
    message(STATUS "FLASH_ATTN_DISABLE_FA3 is set -> Disabling FA3 build")
    set(FA3_ENABLED OFF)
endif()

# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")

134 changes: 134 additions & 0 deletions README.md
@@ -334,6 +334,140 @@ def flash_attn_with_kvcache(
To see how these functions are used in a multi-head attention layer (which
includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).

## Auxiliary returns: abs_s (sum of |scores| before softmax) — FA2, experimental

This repository exposes an experimental auxiliary output for the FA2 varlen forward path via the vLLM wrapper; it is useful for numerical analysis and debugging. When enabled, the kernel returns an additional tensor containing the sum of the absolute attention scores (|scores|) computed before softmax, using the standard 1/sqrt(D) scale. This auxiliary output is not used in training or inference logic, and no causal/local masks are applied to it.

- Non-paged varlen (linear KV): returns abs_s with shape (Hq, total_q), where total_q is the sum of per-sequence query lengths.
- Paged-KV (vLLM page cache): returns per-page abs_s with shape (B, Hq, max_blocks), aligned to the KV page layout where page_size = k_paged.size(1) and block_table[b, p] maps the p-th page of sequence b to a physical block index. The last page can be partial as determined by seqused_k[b].

Key constraints
- Available only for FA2 through the vLLM wrapper function flash_attn_varlen_func.
- Set return_aux=True to request the auxiliary return; abs_s is the first extra tensor after out and lse (outputs[2] in the examples below).
- Paged-KV path requires block_table and seqused_k, and cu_seqlens_k must be None (mutually exclusive with seqused_k).
- The auxiliary tensors are returned in float32 for stable comparison/aggregation.
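
Before the examples, here is a minimal sketch of how the varlen metadata can be assembled. The per-sequence length tensors (`seqlens_q`, `seqlens_k`) are illustrative names, not part of the API; only the `(B+1,)` int32 prefix-sum layout matters.

```python
import torch
import torch.nn.functional as F

# Hypothetical per-sequence lengths (illustrative only).
seqlens_q = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")
seqlens_k = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")

# cu_seqlens_* are (B+1,) int32 prefix sums starting at 0.
cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, 0, dtype=torch.int32), (1, 0))
cu_seqlens_k = F.pad(torch.cumsum(seqlens_k, 0, dtype=torch.int32), (1, 0))

max_seqlen_q = int(seqlens_q.max())
max_seqlen_k = int(seqlens_k.max())
```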

Examples

1) Non-paged varlen: per-token abs_s (Hq, total_q)

```python
import torch
from vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func

# q: (total_q, Hq, D), k/v: (total_k, Hkv, D)
out, lse, abs_s = flash_attn_varlen_func(
    q, k, v,
    max_seqlen_q=max_seqlen_q,
    cu_seqlens_q=cu_seqlens_q,   # (B+1,) int32
    max_seqlen_k=max_seqlen_k,
    cu_seqlens_k=cu_seqlens_k,   # (B+1,) int32 — required in non-paged
    seqused_k=None,
    causal=True,
    softcap=0.0,
    return_softmax_lse=True,
    return_aux=True,
    fa_version=2,
)

print(abs_s.shape) # (Hq, total_q), dtype=float32
```
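
The `(Hq, total_q)` layout concatenates all sequences along the last dimension, so per-sequence statistics can be recovered with `cu_seqlens_q`. A small bookkeeping sketch (not part of the API):

```python
# Aggregate the per-token abs_s per sequence using cu_seqlens_q.
# Tokens of sequence b occupy columns cu_seqlens_q[b] : cu_seqlens_q[b+1].
B = cu_seqlens_q.numel() - 1
per_seq = []
for b in range(B):
    start, end = int(cu_seqlens_q[b]), int(cu_seqlens_q[b + 1])
    per_seq.append(abs_s[:, start:end].sum(dim=1))  # (Hq,) per sequence
per_seq_abs_s = torch.stack(per_seq)  # (B, Hq)
```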

2) Paged-KV varlen: per-page abs_s (B, Hq, max_blocks)

```python
import torch
from vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func

# Paged KV: k/v are (num_blocks, page_size, Hkv, D)
outputs = flash_attn_varlen_func(
    q, k_paged, v_paged,
    max_seqlen_q=max_seqlen_q,
    cu_seqlens_q=cu_seqlens_q,   # (B+1,) int32
    max_seqlen_k=max_seqlen_k,
    cu_seqlens_k=None,           # IMPORTANT: must be None with paged-KV
    seqused_k=seqused_k,         # (B,) int32 — required
    block_table=block_table,     # (B, max_blocks) int32 — required
    causal=True,
    softcap=0.0,
    return_softmax_lse=True,
    return_aux=True,
    fa_version=2,
)

out, lse = outputs[0], outputs[1]
per_page_abs_s = outputs[2] # (B, Hq, max_blocks), float32
```
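
The per-page output always has `max_blocks` slots per sequence, even when a sequence uses fewer pages. A hedged sketch for masking the unused slots, derived from `seqused_k` and the page size (variable names are illustrative):

```python
# Number of (possibly partial) pages actually used by each sequence.
page_size = k_paged.size(1)
num_pages = (seqused_k.to(torch.int64) + page_size - 1) // page_size  # (B,)

# Valid-slot mask broadcast over heads: (B, 1, max_blocks).
max_blocks = per_page_abs_s.size(-1)
page_idx = torch.arange(max_blocks, device=per_page_abs_s.device)
valid = (page_idx[None, :] < num_pages[:, None]).unsqueeze(1)

masked_abs_s = per_page_abs_s * valid  # zero out slots beyond the last used page
```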

Notes
- The auxiliary is returned only if your build includes the helper; otherwise outputs will contain just the usual (out, lse). Rebuild from source if needed.
- abs_s semantics: sum of the absolute scores for each head; for paged-KV it aggregates over (Sq, page_len) within each page. No masking is applied and the 1/sqrt(D) scale is used; a reference sketch follows below.
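
For validating a build, a hedged reference computation for the non-paged case is sketched below. It assumes Hq == Hkv (no GQA) and that the sum runs over every key of the same sequence, since no mask is applied; this is a sanity check, not the kernel's implementation.

```python
import math
import torch

def abs_s_reference(q, k, cu_seqlens_q, cu_seqlens_k):
    # q: (total_q, Hq, D), k: (total_k, Hq, D); returns (Hq, total_q) in float32.
    total_q, Hq, D = q.shape
    scale = 1.0 / math.sqrt(D)
    ref = torch.zeros(Hq, total_q, dtype=torch.float32, device=q.device)
    B = cu_seqlens_q.numel() - 1
    for b in range(B):
        qs, qe = int(cu_seqlens_q[b]), int(cu_seqlens_q[b + 1])
        ks, ke = int(cu_seqlens_k[b]), int(cu_seqlens_k[b + 1])
        qb, kb = q[qs:qe].float(), k[ks:ke].float()   # (Sq, Hq, D), (Sk, Hq, D)
        scores = torch.einsum("qhd,khd->hqk", qb, kb) * scale  # (Hq, Sq, Sk)
        ref[:, qs:qe] = scores.abs().sum(dim=-1)
    return ref

# torch.testing.assert_close(abs_s, abs_s_reference(q, k, cu_seqlens_q, cu_seqlens_k),
#                            rtol=1e-2, atol=1e-2)
```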


## Changelog

### 2.0: Complete rewrite, 2x faster