The prompt processing occurs as usual until at least two KV cache blocks have been completely filled (`t = 0, 1`).
After that, for the subsequent prompt chunks only the first block and the last/second-to-last processed blocks remain visible as KV cache contents, effectively introducing sparsity in the attention computation for the rest of the KV cache "body" (`t = 2-4`).
Upon reaching the tail of the prompt, the KV cache for the entire prompt is used in attention again, effectively switching back from the sparse attention mode to "dense" attention (`t = 5`).
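
To make the visibility rule above concrete, here is a minimal Python sketch of the block selection during tri-shape prefill. It is purely illustrative: the real logic lives in the continuous-batching scheduler and operates on hardware-specific KV cache blocks, and the helper name and signature below are made up for the example.

```python
def trishape_visible_blocks(filled_blocks: list[int],
                            remaining_prompt_tokens: int,
                            num_last_dense_tokens_in_prefill: int) -> list[int]:
    """Return the KV cache blocks visible to the current prompt chunk (illustrative)."""
    # Tail of the prompt: switch back to dense attention over the whole cache.
    if remaining_prompt_tokens <= num_last_dense_tokens_in_prefill:
        return list(filled_blocks)

    # Until enough blocks have been filled, there is nothing to skip yet.
    if len(filled_blocks) <= 3:
        return list(filled_blocks)

    # Sparse phase: keep the first block plus the last/second-to-last blocks,
    # skipping the KV cache "body" in between.
    return [filled_blocks[0]] + list(filled_blocks[-2:])

# e.g. with 6 filled blocks and the dense tail not yet reached:
# trishape_visible_blocks([0, 1, 2, 3, 4, 5], 4096, 1024) -> [0, 4, 5]
```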
Apart from improving the generation accuracy, this also makes it possible to effectively combine the tri-shape sparse prefill algorithm with the cache eviction algorithm, which relies on the model having "seen" the entire prompt KV cache when processing the last tokens of the prompt. The "dense attention" portion of the prompt can be configured using the `SparseAttentionConfig.num_last_dense_tokens_in_prefill` field.
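
As a usage illustration, the tri-shape mode might be enabled through the scheduler configuration roughly as follows. This is a sketch only: it assumes the Python `openvino_genai` bindings, a `use_sparse_attention` flag on `SchedulerConfig`, and a `SparseAttentionMode.TRISHAPE` enum value; check the API reference for the exact names and defaults.

```python
import openvino_genai as ov_genai

sparse_config = ov_genai.SparseAttentionConfig()
sparse_config.mode = ov_genai.SparseAttentionMode.TRISHAPE   # assumed enum value
# Process the last 1024 prompt tokens with dense attention so that cache eviction
# (and generation quality) still benefit from the full prompt context.
sparse_config.num_last_dense_tokens_in_prefill = 1024        # illustrative value

scheduler_config = ov_genai.SchedulerConfig()
scheduler_config.use_sparse_attention = True                 # assumed flag name
scheduler_config.sparse_attention_config = sparse_config

pipe = ov_genai.LLMPipeline("path/to/exported/model", "CPU",
                            scheduler_config=scheduler_config)
```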
### XAttention
With the XAttention algorithm, the prefill computation is accelerated by restricting attention to the most important regions of the attention matrix, determined dynamically through antidiagonal-based importance estimation. During the prefill stage, each query block attends only to the subset of key blocks whose cumulative estimated attention mass exceeds a predefined threshold, while the rest of the KV cache blocks are excluded from the attention computation.
The importance estimation procedure consists of two stages. In the first stage, using stride-based reshaping, the query and key tensors are permuted along antidiagonal patterns, with the stride value determined by the `SparseAttentionConfig.xattention_stride` parameter. The reshaped tensors are then used to compute a coarse estimate of the attention mass per block, with the block size defined by `SparseAttentionConfig.xattention_block_size`. The attention values within each block are summed to produce an importance score that represents the approximate total attention mass associated with that block.

In the second stage, for each query block the corresponding key blocks are sorted in descending order of their estimated attention mass. The algorithm then identifies the minimal subset of blocks whose cumulative estimated attention mass exceeds the predefined threshold `SparseAttentionConfig.xattention_threshold`. The block selection process always retains the diagonal blocks (corresponding to the most recently processed query positions) as well as the least recent KV cache block.
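
The sketch below illustrates both stages for a single attention head using plain NumPy. It is an illustration of the idea, not the production kernel: it materializes the full logits matrix for readability (the actual algorithm avoids this via stride-based reshaping of the query and key tensors), ignores causal masking, and simplifies how the scores are normalized; all function names are invented for the example.

```python
import numpy as np

def xattention_block_importance(q, k, block_size, stride):
    """Coarse per-block attention-mass estimate from antidiagonal sums (illustrative).

    q, k: (seq_len, head_dim) arrays for one head; seq_len is assumed to be a
    multiple of block_size, and block_size a multiple of stride.
    """
    seq_len, head_dim = q.shape
    n_blocks = seq_len // block_size
    logits = (q @ k.T) / np.sqrt(head_dim)   # full matrix, for readability only
    scores = np.zeros((n_blocks, n_blocks))
    for qb in range(n_blocks):
        for kb in range(n_blocks):
            tile = logits[qb * block_size:(qb + 1) * block_size,
                          kb * block_size:(kb + 1) * block_size]
            i, j = np.indices(tile.shape)
            # Sum over antidiagonals spaced `stride` apart as a proxy for the
            # total attention mass of this (query block, key block) pair.
            scores[qb, kb] = tile[(i + j) % stride == 0].sum()
    return scores

def select_key_blocks(block_scores_row, threshold):
    """Minimal set of key blocks covering `threshold` of the estimated mass for
    one query block (softmax over block scores is a simplification here)."""
    probs = np.exp(block_scores_row - block_scores_row.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]           # most important key blocks first
    selected, covered = [], 0.0
    for kb in order:
        selected.append(int(kb))
        covered += probs[kb]
        if covered >= threshold:
            break
    # The real selection additionally always keeps the diagonal (most recent)
    # block and the least recent (first) KV cache block.
    return sorted(selected)
```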
The picture above illustrates the XAttention algorithm in more detail. For simplicity, it is assumed that the prompt occupies 8 full KV cache blocks and is processed within 5 chunks. The `xattention_block_size` is taken to correspond to one hardware-dependent KV cache block of tokens.
The prompt processing occurs as usual until at least two KV cache blocks have been completely filled (`t = 0, 1`). Once the block-level importance scores have been computed (`t = 2-4`), only the minimal subset of KV cache blocks whose cumulative attention mass exceeds the `xattention_threshold` is retained for the attention computation, effectively introducing sparsity for the rest of the KV cache contents.
Upon reaching the tail of the prompt, the KV cache corresponding to the entire prompt becomes visible again, reverting to "dense" attention mode (`t = 5`). This transition ensures that the model attends to the complete prompt context before entering the generation stage. As with the tri-shape algorithm, the final dense portion of the prefill can be configured using the `SparseAttentionConfig.num_last_dense_tokens_in_prefill` field. Due to the block-wise cache organization and scheduler chunking, the actual number of prompt tokens processed with dense attention may slightly exceed the specified value, potentially extending across a full block or subsequence chunk depending on the hardware configuration.
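
The XAttention-specific parameters can then be set analogously to the tri-shape example above (again a sketch with illustrative values, assuming the `openvino_genai` field names quoted in this section):

```python
import openvino_genai as ov_genai

sparse_config = ov_genai.SparseAttentionConfig()
sparse_config.mode = ov_genai.SparseAttentionMode.XATTENTION  # assumed enum value
sparse_config.xattention_threshold = 0.8    # keep key blocks until ~80% of the estimated mass is covered
sparse_config.xattention_block_size = 64    # granularity of the importance scores
sparse_config.xattention_stride = 8         # antidiagonal stride for the coarse estimate
sparse_config.num_last_dense_tokens_in_prefill = 1024  # dense tail, as for tri-shape

scheduler_config = ov_genai.SchedulerConfig()
scheduler_config.use_sparse_attention = True            # assumed flag name
scheduler_config.sparse_attention_config = sparse_config
```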