1 change: 1 addition & 0 deletions .gitignore
@@ -26,6 +26,7 @@ var/

# IDE-related
.idea/
.vscode/

# Dev
venv
11 changes: 11 additions & 0 deletions flash_attn/flash_attn_interface.py
@@ -1389,6 +1389,9 @@ def flash_attn_varlen_func(
deterministic=False,
return_attn_probs=False,
block_table=None,
dcp_rank=None,
dcp_world_size=None,
query_base_positions=None,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1445,6 +1448,14 @@ def flash_attn_varlen_func(
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
# Context parallelism parameters are only supported in Flash Attention 3
# For Flash Attention 2, these parameters are ignored with a warning
if dcp_rank is not None or dcp_world_size is not None or query_base_positions is not None:
import warnings
warnings.warn("Context parallelism parameters (dcp_rank, dcp_world_size, query_base_positions) "
"are only supported in Flash Attention 3. These parameters will be ignored.",
UserWarning)

return FlashAttnVarlenFunc.apply(
q,
k,
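As a usage illustration of the guard above — a minimal sketch assuming a CUDA-capable FA2 install; the shapes, dtypes, and DCP values are made up — the new keyword arguments are accepted for signature compatibility, a UserWarning is emitted, and attention is computed without them:

import warnings
import torch
from flash_attn import flash_attn_varlen_func

# Toy varlen batch: one sequence of 8 tokens, 4 heads, head dim 64.
q = torch.randn(8, 4, 64, device="cuda", dtype=torch.float16)
k = torch.randn(8, 4, 64, device="cuda", dtype=torch.float16)
v = torch.randn(8, 4, 64, device="cuda", dtype=torch.float16)
cu_seqlens = torch.tensor([0, 8], device="cuda", dtype=torch.int32)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = flash_attn_varlen_func(
        q, k, v, cu_seqlens, cu_seqlens, 8, 8,
        causal=True,
        dcp_rank=0, dcp_world_size=2,  # accepted but ignored on the FA2 path
        query_base_positions=torch.tensor([0], device="cuda", dtype=torch.int32),
    )

assert any("Flash Attention 3" in str(w.message) for w in caught)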
5 changes: 5 additions & 0 deletions hopper/flash.h
@@ -161,6 +161,11 @@ struct Flash_fwd_params : public Qkv_params {

// The S extra matrix, (num_heads)
void *__restrict__ s_aux_ptr;

// Context parallelism parameters for MLA decode
int dcp_rank;
int dcp_world_size;
int *__restrict__ query_base_positions;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
28 changes: 26 additions & 2 deletions hopper/flash_api.cpp
@@ -701,7 +701,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin,
std::optional<const at::Tensor> &s_aux_ // (h)
std::optional<const at::Tensor> &s_aux_, // (h)
int dcp_rank,
int dcp_world_size,
std::optional<const at::Tensor> &query_base_positions_
) {

auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -1148,6 +1151,15 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
params.s_aux_ptr = nullptr;
}

// Set context parallelism parameters
params.dcp_rank = dcp_rank;
params.dcp_world_size = dcp_world_size;
if (query_base_positions_.has_value()) {
params.query_base_positions = query_base_positions_.value().data_ptr<int>();
} else {
params.query_base_positions = nullptr;
}

#ifdef FLASHATTENTION_DISABLE_LOCAL
TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
#endif
@@ -1280,7 +1292,10 @@ std::vector<at::Tensor> mha_bwd(
int window_size_right,
float const softcap,
bool const deterministic,
int const sm_margin) {
int const sm_margin,
int dcp_rank,
int dcp_world_size,
std::optional<const at::Tensor> &query_base_positions_) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -1513,6 +1528,15 @@ std::vector<at::Tensor> mha_bwd(
params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
params.dv = head_size; // We don't support hdim_v being different from hdim_qk for now

// Set context parallelism parameters
params.dcp_rank = dcp_rank;
params.dcp_world_size = dcp_world_size;
if (query_base_positions_.has_value()) {
params.query_base_positions = query_base_positions_.value().data_ptr<int>();
} else {
params.query_base_positions = nullptr;
}

// auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
// params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
// Will be zero'ed out in the backward preprocess kernel
10 changes: 8 additions & 2 deletions hopper/flash_api_torch_lib.cpp
@@ -52,7 +52,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin,
std::optional<const at::Tensor> &s_aux_
std::optional<const at::Tensor> &s_aux_,
int dcp_rank,
int dcp_world_size,
std::optional<const at::Tensor> &query_base_positions_
);

// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
@@ -120,7 +123,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin,"
" Tensor? s_aux) -> Tensor[]");
" Tensor? s_aux,"
" int dcp_rank,"
" int dcp_world_size,"
" Tensor? query_base_positions) -> Tensor[]");
ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));

ops.def("get_scheduler_metadata("
43 changes: 40 additions & 3 deletions hopper/flash_attn_interface.py
@@ -49,7 +49,10 @@ def _flash_attn_forward(
num_splits=1,
pack_gqa=None,
sm_margin=0,
s_aux=None):
s_aux=None,
dcp_rank=0,
dcp_world_size=1,
query_base_positions=None):
q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
@@ -61,6 +64,9 @@
]
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
seqlens_rotary = maybe_contiguous(seqlens_rotary)
# Handle context parallelism parameters
query_base_positions = maybe_contiguous(query_base_positions)

out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
q,
k,
@@ -95,7 +101,10 @@ def _flash_attn_forward(
num_splits,
pack_gqa,
sm_margin,
s_aux
s_aux,
dcp_rank,
dcp_world_size,
query_base_positions
)
return out, softmax_lse, *rest

@@ -122,9 +131,15 @@ def _flash_attn_backward(
softcap=0.0,
deterministic=False,
sm_margin=0,
dcp_rank=0,
dcp_world_size=1,
query_base_positions=None,
):
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
# Handle context parallelism parameters
query_base_positions = maybe_contiguous(query_base_positions)

dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
dout,
q,
@@ -148,6 +163,9 @@ def _flash_attn_backward(
softcap,
deterministic,
sm_margin,
dcp_rank,
dcp_world_size,
query_base_positions,
)
return dq, dk, dv, softmax_d

@@ -351,6 +369,9 @@ def forward(
deterministic=False,
sm_margin=0,
s_aux=None,
dcp_rank=0,
dcp_world_size=1,
query_base_positions=None,
):
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -380,6 +401,9 @@ def forward(
pack_gqa=pack_gqa,
sm_margin=sm_margin,
s_aux=s_aux,
dcp_rank=dcp_rank,
dcp_world_size=dcp_world_size,
query_base_positions=query_base_positions,
)
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
@@ -391,6 +415,10 @@ def forward(
ctx.softcap = softcap
ctx.deterministic = deterministic
ctx.sm_margin = sm_margin
# Save context parallelism parameters for backward pass
ctx.dcp_rank = dcp_rank
ctx.dcp_world_size = dcp_world_size
ctx.query_base_positions = query_base_positions
return out, softmax_lse

@staticmethod
@@ -419,11 +447,14 @@ def backward(ctx, dout, *args):
ctx.softcap,
ctx.deterministic,
ctx.sm_margin,
ctx.dcp_rank,
ctx.dcp_world_size,
ctx.query_base_positions,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
@@ -582,6 +613,9 @@ def flash_attn_varlen_func(
deterministic=False,
sm_margin=0,
s_aux=None,
dcp_rank=0,
dcp_world_size=1,
query_base_positions=None,
):
return FlashAttnVarlenFunc.apply(
q,
Expand All @@ -604,6 +638,9 @@ def flash_attn_varlen_func(
deterministic,
sm_margin,
s_aux,
dcp_rank,
dcp_world_size,
query_base_positions,
)


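For the FA3 entry point above, a hedged usage sketch under decode context parallelism. The import path and keyword names follow the upstream hopper interface (flash_attn_interface); the shapes, head counts, ranks, and base positions are hypothetical, and the per-sequence meaning of query_base_positions is inferred from the per-batch lookup in the mainloop change below.

import torch
from flash_attn_interface import flash_attn_varlen_func  # FA3 (hopper) interface

# Two decode sequences on this rank, each contributing a single query token.
q = torch.randn(2, 16, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1024, 1, 128, device="cuda", dtype=torch.bfloat16)  # GQA: 1 KV head
v = torch.randn(1024, 1, 128, device="cuda", dtype=torch.bfloat16)
cu_seqlens_q = torch.tensor([0, 1, 2], device="cuda", dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 512, 1024], device="cuda", dtype=torch.int32)

# Hypothetical global position of each sequence's first query token.
query_base_positions = torch.tensor([2048, 4096], device="cuda", dtype=torch.int32)

out, softmax_lse = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=1, max_seqlen_k=512,
    causal=True,
    dcp_rank=1, dcp_world_size=4,  # this rank's slot in the DCP group
    query_base_positions=query_base_positions,
)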
13 changes: 11 additions & 2 deletions hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
@@ -412,6 +412,9 @@ struct CollectiveMainloopFwdSm90 {
int const* const leftpad_k = nullptr;
int const* const seqlens_rotary = nullptr;
ElementSAux const* const ptr_S_aux = nullptr;
int const dcp_rank = 0;
int const dcp_world_size = 1;
int const* const query_base_positions = nullptr;
};

// Device side kernel params
@@ -469,6 +472,9 @@ struct CollectiveMainloopFwdSm90 {
int const* const leftpad_k = nullptr;
int const* const seqlens_rotary = nullptr;
ElementSAux const* const ptr_S_aux = nullptr;
int const dcp_rank = 0;
int const dcp_world_size = 1;
int const* const query_base_positions = nullptr;
};

static Params
@@ -584,7 +590,8 @@ struct CollectiveMainloopFwdSm90 {
args.kv_batch_idx,
args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary,
args.ptr_S_aux};
args.ptr_S_aux,
args.dcp_rank, args.dcp_world_size, args.query_base_positions};
}

/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -1093,7 +1100,9 @@ struct CollectiveMainloopFwdSm90 {
// But we subtract n_offset for consistency in mask calculations
flash::Mask<kBlockM, kBlockN, PackGQA, TiledMmaQK> mask(
thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 - n_offset /*sink_token_length*/,
params.qhead_per_khead_divmod
params.qhead_per_khead_divmod,
params.dcp_rank, params.dcp_world_size,
params.query_base_positions != nullptr ? params.query_base_positions[bidb] : 0
);

float softcap_val = params.softcap_val;
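That ternary is the only place the kernel reads the new tensor: batch entry bidb gets its own global base position, with a null pointer falling back to 0. A short Python restatement, with made-up values:

def base_position_for_batch(query_base_positions, bidb):
    # Mirrors: params.query_base_positions != nullptr ? params.query_base_positions[bidb] : 0
    return query_base_positions[bidb] if query_base_positions is not None else 0

query_base_positions = [1024, 4096, 0]        # one entry per sequence in the batch (illustrative)
assert base_position_for_batch(query_base_positions, 1) == 4096
assert base_position_for_batch(None, 1) == 0  # tensor not provided: behave as before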
15 changes: 13 additions & 2 deletions hopper/mask.h
@@ -23,18 +23,25 @@ struct Mask {
int const seqlen_q, seqlen_k;
int const window_size_left, window_size_right, sink_token_length;
cutlass::FastDivmod const qhead_per_khead_divmod;
// Context parallelism parameters for MLA decode
int const dcp_rank, dcp_world_size;
int const query_base_position;

CUTLASS_DEVICE
Mask(const int thread_idx, const int seqlen_q, const int seqlen_k,
const int window_size_left, const int window_size_right, const int sink_token_length,
cutlass::FastDivmod const &qhead_per_khead_divmod)
cutlass::FastDivmod const &qhead_per_khead_divmod,
const int dcp_rank = 0, const int dcp_world_size = 1, const int query_base_position = 0)
: thread_idx(thread_idx)
, seqlen_q(seqlen_q)
, seqlen_k(seqlen_k)
, window_size_left(window_size_left)
, window_size_right(window_size_right)
, sink_token_length(sink_token_length)
, qhead_per_khead_divmod(qhead_per_khead_divmod)
, dcp_rank(dcp_rank)
, dcp_world_size(dcp_world_size)
, query_base_position(query_base_position)
{
};

@@ -89,8 +96,12 @@ struct Mask {
int const row_idx = !PackGQA
? get<Row>(tScS_rowcol(m, _0{})) + m_block * kBlockM
: __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow);
// For context parallelism, adjust causal mask based on global query position
int const global_row_idx = query_base_position + row_idx;
int const kv_offset = (dcp_world_size > 1) ? (seqlen_k * dcp_rank / dcp_world_size) : 0;
int const adjusted_causal_row_offset = causal_row_offset - kv_offset;
int const col_limit_right = !Seqlenk_mask
? row_idx + causal_row_offset
? ((dcp_world_size > 1) ? global_row_idx + adjusted_causal_row_offset : row_idx + causal_row_offset)
: __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit);
#pragma unroll
for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
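A plain-Python mirror of the adjusted causal limit above (the !Seqlenk_mask branch only) may help when checking the arithmetic. With dcp_world_size > 1, the row's global position is query_base_position + row_idx and kv_offset = seqlen_k * dcp_rank / dcp_world_size is subtracted from causal_row_offset, presumably to discount key columns owned by lower ranks; with a world size of 1 the original formula is untouched. The inputs below are illustrative.

def causal_col_limit_right(row_idx, causal_row_offset, seqlen_k,
                           dcp_rank=0, dcp_world_size=1, query_base_position=0):
    # Mirrors the new !Seqlenk_mask branch in hopper/mask.h.
    if dcp_world_size > 1:
        kv_offset = seqlen_k * dcp_rank // dcp_world_size
        global_row_idx = query_base_position + row_idx
        return global_row_idx + (causal_row_offset - kv_offset)
    return row_idx + causal_row_offset

# Rank 1 of 4, local K/V extent 2048, a query whose global position is 4096:
print(causal_col_limit_right(row_idx=0, causal_row_offset=1, seqlen_k=2048,
                             dcp_rank=1, dcp_world_size=4, query_base_position=4096))
# -> 4096 + (1 - 512) = 3585
print(causal_col_limit_right(row_idx=0, causal_row_offset=1, seqlen_k=2048))
# -> 1 (unchanged single-rank behaviour)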
14 changes: 13 additions & 1 deletion vllm_flash_attn/flash_attn_interface.py
@@ -146,6 +146,10 @@ def flash_attn_varlen_func(
# Version selector
fa_version: int = DEFAULT_FA_VERSION,
s_aux=None,
# Context parallelism parameters for MLA decode
dcp_rank=None,
dcp_world_size=None,
query_base_positions=None,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -255,6 +259,11 @@
)
elif fa_version == 3:
assert alibi_slopes is None, "Alibi is not supported in FA3"

# Handle context parallelism parameters - convert None to default values
dcp_rank_val = dcp_rank if dcp_rank is not None else 0
dcp_world_size_val = dcp_world_size if dcp_world_size is not None else 1

out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
q, k, v,
None, None, # k_new, v_new
@@ -279,7 +288,10 @@
num_splits,
None, # pack_gqa
0, # sm_margin
s_aux # s_aux
s_aux, # s_aux
dcp_rank_val, # dcp_rank
dcp_world_size_val, # dcp_world_size
query_base_positions # query_base_positions
)
else:
raise ValueError(f"Unsupported FA version: {fa_version}")
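One follow-up note on the vLLM wrapper: the DCP arguments default to None and are folded to the neutral values 0 and 1 just before the FA3 op, and a world size of 1 skips the dcp_world_size > 1 branch in the mask entirely, so existing call sites keep their exact behaviour. A two-case sketch of that conversion:

def to_dcp_defaults(dcp_rank, dcp_world_size):
    # Mirrors the None handling added above in vllm_flash_attn/flash_attn_interface.py.
    return (dcp_rank if dcp_rank is not None else 0,
            dcp_world_size if dcp_world_size is not None else 1)

assert to_dcp_defaults(None, None) == (0, 1)  # DCP disabled: no mask adjustment
assert to_dcp_defaults(3, 8) == (3, 8)        # DCP enabled on rank 3 of 8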