Vllm_flash_attn_with_attention_weights #88
This pull request introduces an experimental auxiliary output for the FA2 variable-length forward path, letting users obtain the sum of absolute pre-softmax attention scores (|S|) for each head and token, or per page in paged-KV mode. The feature is exposed via the vLLM wrapper and is intended primarily for numerical analysis and debugging. The implementation spans the Python wrapper, the C++ binding layer, and the CUDA kernels.
**Feature: Auxiliary `abs_s` Output for FA2 Varlen Forward (Numerical Analysis/Debugging)**

- Added an optional auxiliary output to `flash_attn_varlen_func`, requested by passing `return_aux=True`. The output is the sum of absolute pre-softmax attention scores, scaled by `1/sqrt(D)`, for each head and token (non-paged) or per page (paged-KV); a usage sketch follows.
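A minimal usage sketch, assuming the wrapper follows the upstream flash-attn calling convention and returns the auxiliary tensor alongside the output when `return_aux=True`; the import path, argument order, and the shape of `abs_s` are assumptions, not confirmed by the PR:

```python
import torch
from vllm_flash_attn import flash_attn_varlen_func  # assumed import path

# Two packed sequences of lengths 3 and 5; 4 heads, head_dim 64.
q = torch.randn(8, 4, 64, dtype=torch.float16, device="cuda")
k = torch.randn(8, 4, 64, dtype=torch.float16, device="cuda")
v = torch.randn(8, 4, 64, dtype=torch.float16, device="cuda")
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")

# With return_aux=True the call additionally yields abs_s: per head and
# query token (non-paged), the sum of |pre-softmax scores| * 1/sqrt(64).
out, abs_s = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=5, max_seqlen_k=5,
    return_aux=True,  # experimental flag added by this PR
)
print(out.shape, abs_s.shape)  # assumed: (8, 4, 64) and (8, 4)
```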
- Implemented `varlen_fwd_with_abs_aux` in `flash_api_torch_lib.cpp`, which computes and returns the auxiliary tensor (`abs_s`) alongside the usual outputs, and registered the function in the PyTorch extension; a sketch follows.
**Kernel/Parameter Changes for Per-Page Accumulation**

- Extended the `Flash_fwd_params` struct in `flash.h` with pointers and stride information for accumulating pre-softmax |S| per page, enabling efficient per-page statistics in the CUDA kernel; see the sketch below.
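The PR description only states that pointers and strides were added, so the field names below are hypothetical:

```cpp
// flash.h (sketch): hypothetical additions to Flash_fwd_params.
struct Flash_fwd_params {
    // ... existing q/k/v/o pointers, strides, and sizes elided ...

    // Per-page |S| accumulator; nullptr when the aux output is disabled.
    void *__restrict__ abs_s_ptr;
    // Strides used to index the accumulator by batch, head, and page
    // (names and layout are illustrative, not from the PR).
    int64_t abs_s_batch_stride;
    int64_t abs_s_head_stride;
    int64_t abs_s_page_stride;
};
```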
- Added `accumulate_abslogits_per_page` in `flash_fwd_kernel.h`, which atomically accumulates the absolute values of the pre-softmax scores into the provided buffer for each batch, head, query, and page; it is called from all relevant kernel paths. A simplified sketch follows.
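A simplified, standalone sketch of the accumulation pattern; in the real kernel the helper operates on register-resident score fragments right after the Q·K^T product, and everything here except the function's name and its use of atomics is assumed:

```cuda
#include <cuda_runtime.h>

// Simplified stand-in for accumulate_abslogits_per_page: sum |S| over this
// thread's fragment of pre-softmax scores (already scaled by 1/sqrt(D)) and
// atomically add it to the accumulator slot. Atomics are required because
// several warps/blocks may contribute to the same (batch, head, query, page).
__device__ void accumulate_abslogits_per_page(
    float *abs_s,         // global accumulator buffer
    const float *scores,  // this thread's score fragment
    int n_scores,         // fragment length
    long slot) {          // flattened (batch, head, query, page) index
  float local_sum = 0.f;
  for (int i = 0; i < n_scores; ++i) {
    local_sum += fabsf(scores[i]);
  }
  atomicAdd(&abs_s[slot], local_sum);
}
```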
**Developer Experience**

- Updated `.vscode/settings.json` to improve code navigation in VS Code by associating certain file types with C++.
**Build Configuration**

- Updated `CMakeLists.txt` to allow disabling FA3 via the `FLASH_ATTN_DISABLE_FA3` environment variable, improving build flexibility for users who only need FA2; a sketch of the gate follows.
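A sketch of what such a gate might look like; the PR confirms only the `FLASH_ATTN_DISABLE_FA3` environment variable, so the surrounding logic and the `FA3_ENABLED` variable name are illustrative:

```cmake
# Skip FA3 kernels when the user sets FLASH_ATTN_DISABLE_FA3 (sketch).
if(DEFINED ENV{FLASH_ATTN_DISABLE_FA3})
  message(STATUS "FLASH_ATTN_DISABLE_FA3 set; building FA2 only")
  set(FA3_ENABLED OFF)
else()
  set(FA3_ENABLED ON)
endif()
```

With a gate like this, something along the lines of `FLASH_ATTN_DISABLE_FA3=1 pip install -e .` would build only the FA2 kernels.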