12 changes: 12 additions & 0 deletions vllm/v1/attention/backends/mla/common.py
@@ -359,6 +359,8 @@
 class MLACommonDecodeMetadata:
     block_table: torch.Tensor
     seq_lens: torch.Tensor
+    query_base_positions: Optional[torch.Tensor] = field(default=None,

Check failure on line 362 in vllm/v1/attention/backends/mla/common.py (GitHub Actions / pre-commit):
No overload variant of "field" matches argument types "None", "bool" [call-overload]

+                                                          kw_only=True)


D = TypeVar("D", bound=MLACommonDecodeMetadata)
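
A minimal sketch of one way the pre-commit failure above could be addressed, assuming the error comes from dataclasses.field(kw_only=...) not being accepted by the type checker for the project's minimum Python version (kw_only was only added to field() in Python 3.10). The field name and types mirror the diff; whether a plain default is acceptable depends on whether any subclass declares required fields after it, which is exactly what kw_only=True is meant to allow.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MLACommonDecodeMetadata:
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    # Hypothetical alternative to field(default=None, kw_only=True): a plain
    # default avoids the call-overload error, but only works if no subclass
    # adds a non-default field after this one; otherwise keeping field(...)
    # with a "# type: ignore[call-overload]" is another option.
    query_base_positions: Optional[torch.Tensor] = None
```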
@@ -615,9 +617,19 @@
                       query_start_loc_cpu: torch.Tensor,
                       query_start_loc_device: torch.Tensor,
                       num_decode_tokens: int) -> MLACommonDecodeMetadata:
+
+        # Compute DCP query base positions if using DCP
+        query_base_positions = None
+
+        if self.dcp_world_size > 1:
+            query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+            query_base_positions = (seq_lens_cpu - query_lens).to(
+                seq_lens_device.device)
+
         return MLACommonDecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
+            query_base_positions=query_base_positions,
Comment on lines +621 to +632 (Contributor):
critical

This change correctly computes query_base_positions for DCP. However, FlashAttnMLAMetadataBuilder in vllm/v1/attention/backends/mla/flashattn_mla.py overrides _build_decode and does not call this base implementation. As a result, query_base_positions will always be None for the FlashAttention MLA backend, and MTP with context parallelism will not work correctly.

To fix this, either move this logic into FlashAttnMLAMetadataBuilder._build_decode or refactor it so that FlashAttnMLAMetadataBuilder can reuse it: for example, compute query_base_positions there and pass it to the FlashAttnMLADecodeMetadata constructor, as sketched below.
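
A rough sketch of that suggestion, not a drop-in patch: the parameter names mirror the base _build_decode shown above, the DCP computation is copied verbatim from it, and the remaining FlashAttnMLADecodeMetadata fields (scheduler metadata, split counts, and so on) are elided and assumed to be filled in exactly as they are today.

```python
    def _build_decode(self, block_table_tensor, seq_lens_cpu, seq_lens_device,
                      query_start_loc_cpu, query_start_loc_device,
                      num_decode_tokens):
        # Same DCP logic as the base builder: each decode request's query
        # starts at (total sequence length - number of new query tokens),
        # e.g. seq_len 100 with 2 speculative tokens gives base position 98.
        query_base_positions = None
        if self.dcp_world_size > 1:
            query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
            query_base_positions = (seq_lens_cpu - query_lens).to(
                seq_lens_device.device)

        return FlashAttnMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            query_base_positions=query_base_positions,
            # ... existing FlashAttn-specific fields unchanged ...
        )
```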

         )
 
     def build_for_cudagraph_capture(
9 changes: 3 additions & 6 deletions vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -11,7 +11,6 @@
 from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
                                            get_flash_attn_version)
 from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                     MLACommonDecodeMetadata,
@@ -99,11 +98,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         # pre-allocated during capture.
         self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
-        # TODO(lucas): Until we add support for the DCP custom masking we need
-        # to restrict decodes to q_len == 1 when DCP is enabled.
-        self.__class__.reorder_batch_threshold = 1 \
-            if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
-
     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                          max_seq_len, causal):
         if self.fa_aot_schedule:
@@ -262,6 +256,9 @@ def _forward_decode(
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
             num_splits=attn_metadata.decode.max_num_splits,
+            dcp_rank=self.dcp_rank,
+            dcp_world_size=self.dcp_world_size,
+            query_base_positions=attn_metadata.decode.query_base_positions,
         )
 
         if self.need_to_return_lse_for_decode:
6 changes: 0 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
@@ -474,12 +474,6 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
             return
 
         if self.reorder_batch_threshold is not None:
-            # NOTE(lucas): currently no backend supports the custom masking
-            # required for DCP with q_len > 1, so we assert here. Remove this
-            # assert once the custom mask is support is added to FA3.
-            if self.dcp_world_size > 1:
-                assert self.reorder_batch_threshold == 1, \
-                    "DCP not support reorder_batch_threshold > 1 now."
             reorder_batch_to_split_decodes_and_prefills(
                 self.input_batch,
                 scheduler_output,