Skip to content

Commit a893712

Browse files
Fix local attention (#102)
* fix local attention
  Signed-off-by: Lucas Wilkinson <[email protected]>
* use tot_seqlen_k
  Signed-off-by: Lucas Wilkinson <[email protected]>
* more stable fix
  Signed-off-by: Lucas Wilkinson <[email protected]>
* better error message
  Signed-off-by: Lucas Wilkinson <[email protected]>
* make consistent with og dcp pr
  Signed-off-by: Lucas Wilkinson <[email protected]>
---------
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent d9e577e commit a893712

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

hopper/block.h

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -40,11 +40,12 @@ struct BlockMN {
4040
// If local, blocking (m_idx_max - m_idx_min + window_size_right + window_size_left)
4141
// when cp is not enabled, tot_seqlen_k is equal to seqlen_k, and cp_world_size is 1.
4242
// cp_world_size is guaranteed to be greater than 0
43-
n_block_max = std::min(n_block_max,
44-
cute::ceil_div(
45-
cute::ceil_div(m_idx_max + seqlen_info.tot_seqlen_k - seqlen_q + window_size_right - seqlen_info.cp_rank,
46-
seqlen_info.cp_world_size),
47-
kBlockN));
43+
int tot_seqlen_k = (Is_local) ? seqlen_k : seqlen_info.tot_seqlen_k;
44+
int n_token_max = m_idx_max + tot_seqlen_k - seqlen_q + window_size_right;
45+
if (seqlen_info.cp_world_size > 1 && !Is_local) {
46+
n_token_max = cute::ceil_div(n_token_max - seqlen_info.cp_rank, seqlen_info.cp_world_size);
47+
}
48+
n_block_max = std::min(n_block_max, cute::ceil_div(n_token_max, kBlockN));
4849
}
4950
// Now, only adjust n_block_min if split
5051
int n_block_min = 0;

hopper/flash_api.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1170,6 +1170,9 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
11701170
TORCH_CHECK(cp_world_size > 0, "cp_world_size must be positive, required by downstream unified code path. Use 1 if CP is not enabled.");
11711171
TORCH_CHECK(cp_world_size != 1 || cp_rank == 0, "When context parallelism is disabled, cp_rank must be zero");
11721172
TORCH_CHECK(cp_world_size == 1 || cp_tot_seqused_k_.has_value(), "cp_tot_seqused_k_ must be provided when context parallelism is enabled.");
1173+
TORCH_CHECK(!(params.is_local && cp_world_size > 1),
1174+
"Local attention (sliding window) is not currently supported with context parallelism (cp_world_size > 1)."
1175+
"Requires proper n_offset handling in block boundary calculations in mainloop and block.h");
11731176

11741177
#ifdef FLASHATTENTION_DISABLE_LOCAL
11751178
TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");

0 commit comments

Comments
 (0)