FA3 variable length attention sort/swizzle #82
base: main
Conversation
seqlen = params.seqlen;
if constexpr (Prepared) {
    return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1
        ? cute::ceil_div(params.prepare_seqlen_q_ptr[batch_idx], kBlockM) : 0;
Right now vLLM is a bit annoying: we actually compute the attention metadata (and as a result the mha_fwd_get_scheduler_metadata call) before knowing how many requests we will pad to; this means the scheduler metadata will be for a different batch size than what params.b is at runtime. This is normally fine since cu_seqlens is padded so that all requests up to the max batch size have seqlen_q == 0, and FA returns before touching any bad memory; however, if this reads garbage from prepare_seqlen_q_ptr, it might break? We can probably zero the metadata here: https://github.com/neuralmagic/vllm/blob/a75c6e034abf00603fba527625e44baab7b42f80/vllm/v1/attention/backends/flash_attn.py#L333-L338
(This is a historical artifact of thinking that piecewise CUDA graphs would be enough in V1 and we wouldn't need attention to be in a cudagraph, so this may be re-architected in the near future.)
Actually we might have to do a more aggressive refactor on the vLLM side since I think an even bigger problem is that all of the offsets will be wrong:
int sort_offset = b_rounded * (use_dynamic_split ? 2 : 1);
int head_swizzle_offset = b_rounded * (num_prepare_batch_vectors - 1);
int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors;
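To make the mismatch concrete, here is a minimal sketch (assuming, purely for illustration, that the batch size is rounded up to a multiple of 32 and that the scheduler saw a smaller batch than the runtime one; none of these numbers come from the PR):

// Sketch only (not part of the PR): why the offsets go stale when the
// scheduler's batch size differs from the runtime one. The rounding
// multiple of 32 and the batch sizes below are made-up values.
#include <cstdio>

int round_up(int x, int multiple) { return (x + multiple - 1) / multiple * multiple; }

int main() {
    const bool use_dynamic_split = true;
    const int num_prepare_batch_vectors = 4;        // hypothetical value

    // Batch size the scheduler saw when it built the metadata...
    const int b_sched_rounded = round_up(24, 32);   // -> 32
    // ...versus the batch size vLLM later pads to at runtime.
    const int b_run_rounded = round_up(48, 32);     // -> 64

    const int sort_offset_sched = b_sched_rounded * (use_dynamic_split ? 2 : 1);   // 64
    const int sort_offset_run   = b_run_rounded   * (use_dynamic_split ? 2 : 1);   // 128
    std::printf("sort_offset: scheduler=%d runtime=%d\n", sort_offset_sched, sort_offset_run);

    // head_swizzle_offset and tile_count_semaphore_offset shift the same way,
    // so every section of the struct-of-arrays metadata ends up misaligned.
    return 0;
}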
The other option would be to make the scheduler metadata an "Array of Structs" instead of a "Struct of Arrays"; then the offsets wouldn't be dependent on the batch size the scheduler used (and we could more easily just zero out the rest of the metadata).
How hard do you think this would be / how badly do you think this would hurt perf?
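Roughly, the Array-of-Structs idea would look something like the sketch below (field names are illustrative only, not the PR's actual metadata layout):

// Sketch only: an "Array of Structs" layout for the scheduler metadata.
// Field names are illustrative; they do not match the PR's actual vectors.
struct PrepareBatchEntry {
    int num_m_blocks;   // ceil_div(seqlen_q, kBlockM)
    int num_splits;     // dynamic split count for this request
    int head_swizzle;   // swizzled head index
    int tile_count;     // per-request tile count / semaphore contribution
};

// With AoS, entry b always lives at metadata[b], independent of how many
// batches the scheduler prepared, so zeroing the padded tail is a single
// memset over entries [num_prepared, max_batch).
inline PrepareBatchEntry* entry_for_batch(PrepareBatchEntry* metadata, int batch_idx) {
    return metadata + batch_idx;
}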
The other option would be to make the scheduler metadata an "Array of Structs" instead of a "Struct of Arrays"; then the offsets wouldn't be dependent on the batch size the scheduler used (and we could more easily just zero out the rest of the metadata).
How hard do you think this would be / how badly do you think this would hurt perf?
I could write it out as an int4 array instead, but then we wouldn't have coalesced accesses when reading it back in, so I'd like to avoid that if at all possible.
Can we pass in a max batch size to set the offsets correctly?
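For concreteness, a sketch of that option (the names compute_prepare_offsets, max_num_batch, and the rounding multiple are all made up here, not the PR's API):

// Sketch only: derive the struct-of-arrays offsets from a caller-supplied
// maximum batch size rather than the batch size the scheduler happened to
// see, so the layout stays valid after vLLM pads the batch.
// All names and the rounding multiple below are hypothetical.
struct PrepareOffsets {
    int sort_offset;
    int head_swizzle_offset;
    int tile_count_semaphore_offset;
};

inline int round_up_batch(int b, int multiple = 32) {
    return (b + multiple - 1) / multiple * multiple;
}

inline PrepareOffsets compute_prepare_offsets(int max_num_batch,
                                              bool use_dynamic_split,
                                              int num_prepare_batch_vectors) {
    const int b_rounded = round_up_batch(max_num_batch);
    return PrepareOffsets{
        b_rounded * (use_dynamic_split ? 2 : 1),
        b_rounded * (num_prepare_batch_vectors - 1),
        b_rounded * num_prepare_batch_vectors,
    };
}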
vllm side mirror of Dao-AILab#1823