[Attention] Support MTP with DCP #24997
Conversation
Signed-off-by: Matthew Bonanni <[email protected]>
Code Review
This pull request aims to enable Multi-Token Prediction (MTP) with decode context parallelism (DCP) by supporting query_len > 1 in the FlashAttention MLA backend. The changes remove previous restrictions and add metadata for a custom causal mask.
I've found a critical issue: the new logic that computes query_base_positions in MLACommonMetadataBuilder is not used by FlashAttnMLAMetadataBuilder, because the latter overrides the _build_decode method. This will prevent the feature from working as intended. Please see my detailed comment.
```python
# Compute DCP query base positions if using DCP
query_base_positions = None

if self.dcp_world_size > 1:
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    query_base_positions = (seq_lens_cpu - query_lens).to(
        seq_lens_device.device)

return MLACommonDecodeMetadata(
    block_table=block_table_tensor,
    seq_lens=seq_lens_device,
    query_base_positions=query_base_positions,
```
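For intuition: query_base_positions is each request's sequence length minus its number of new query tokens, i.e. the global position of the request's first query token. Below is a minimal standalone sketch of the same arithmetic; the tensor values are made up for illustration, not taken from the PR.

```python
import torch

# Two decode requests, each carrying 3 speculative (MTP) query tokens:
# cumulative query offsets [0, 3, 6] -> per-request query_lens [3, 3].
query_start_loc_cpu = torch.tensor([0, 3, 6], dtype=torch.int32)
seq_lens_cpu = torch.tensor([10, 7], dtype=torch.int32)  # totals include the new tokens

query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
query_base_positions = seq_lens_cpu - query_lens  # position of each first query token

print(query_lens.tolist())            # [3, 3]
print(query_base_positions.tolist())  # [7, 4]
```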
This change correctly computes query_base_positions for DCP. However, FlashAttnMLAMetadataBuilder in vllm/v1/attention/backends/mla/flashattn_mla.py overrides _build_decode and does not call this base implementation. As a result, query_base_positions will be None for the FlashAttention MLA backend, and the MTP with context parallelism feature will not work correctly.
To fix this, you should move this logic to FlashAttnMLAMetadataBuilder._build_decode, or refactor it so that FlashAttnMLAMetadataBuilder can reuse it. For example, you could add the logic to FlashAttnMLAMetadataBuilder._build_decode and pass query_base_positions to the FlashAttnMLADecodeMetadata constructor, as sketched below.
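A rough sketch of that suggestion, mirroring the base-class snippet above. The parameter list and the extra FlashAttnMLADecodeMetadata fields are assumptions for illustration, not the actual signatures in the repository:

```python
# Hypothetical sketch -- argument names and metadata fields are assumed,
# not copied from FlashAttnMLAMetadataBuilder.
def _build_decode(self, block_table_tensor, seq_lens_cpu, seq_lens_device,
                  query_start_loc_cpu, query_start_loc_device):
    # Mirror the base-class DCP logic so the FlashAttention MLA path also
    # knows where each request's first query token sits in the full sequence.
    query_base_positions = None
    if self.dcp_world_size > 1:
        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        query_base_positions = (seq_lens_cpu - query_lens).to(
            seq_lens_device.device)

    return FlashAttnMLADecodeMetadata(
        block_table=block_table_tensor,
        seq_lens=seq_lens_device,
        query_base_positions=query_base_positions,
        # ... plus whatever backend-specific fields this builder already sets
        # (e.g. scheduler metadata, max sequence length).
    )
```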
This pull request has merge conflicts that must be resolved before it can be merged.
Superseded by #25049
Purpose
#24453 added DCP support but did not support query_len > 1. This PR, which depends on a corresponding FlashAttention PR (vllm-project/flash-attention#92), implements a custom causal mask to take advantage of the FlashAttention MLA backend's support for query_len > 1, thereby enabling MTP.
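To make the masking problem concrete, here is a conceptual sketch only; the round-robin sharding, names, and helper function are assumptions for illustration and not this PR's implementation. With DCP the KV cache is split across ranks, so when a decode step carries several speculative (MTP) query tokens, each query token can legally attend to a different number of the keys held by a given rank, which is what a custom causal mask has to express:

```python
import torch

def local_causal_mask(base_position: int, num_query_tokens: int,
                      num_local_keys: int, dcp_rank: int, dcp_world_size: int):
    """Boolean mask [num_query_tokens, num_local_keys] for one DCP rank,
    assuming keys are distributed round-robin across ranks (an illustrative
    assumption, not necessarily vLLM's actual layout)."""
    # Global positions of the keys held by this rank.
    key_pos = dcp_rank + dcp_world_size * torch.arange(num_local_keys)
    # Global positions of the speculative query tokens.
    query_pos = base_position + torch.arange(num_query_tokens)
    # A query token may attend only to keys at or before its own position.
    return key_pos[None, :] <= query_pos[:, None]

# Example: first query token at global position 7, 3 MTP tokens,
# rank 1 of 2 holding 5 local keys.
print(local_causal_mask(7, 3, 5, dcp_rank=1, dcp_world_size=2).int())
```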
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.