8 changes: 8 additions & 0 deletions vllm/forward_context.py
@@ -205,6 +205,10 @@ class ForwardContext:

    ubatch_slices: UBatchSlices | None = None

    # set dynamically for each forward pass
    # True during memory profiling, False otherwise
    is_memory_profile: bool = False

Comment on lines +208 to +211
Member

Should we avoid adding too many things to the forward_context? It is becoming increasingly complicated and I am increasingly worried about this class getting more and more bloated. cc @WoosukKwon @youkaichao

    def __post_init__(self):
        assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
@@ -235,6 +239,7 @@ def create_forward_context(
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
    is_memory_profile: bool = False,
):
    return ForwardContext(
        no_compile_layers=vllm_config.compilation_config.static_forward_context,
@@ -244,6 +249,7 @@ def create_forward_context(
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
        is_memory_profile=is_memory_profile,
    )


@@ -272,6 +278,7 @@ def set_forward_context(
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
    is_memory_profile: bool = False,
):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
@@ -317,6 +324,7 @@
        cudagraph_runtime_mode,
        batch_descriptor,
        ubatch_slices,
        is_memory_profile,
    )

    try:
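The flag only takes effect if whatever drives the memory-profiling pass sets it when entering the context. A minimal usage sketch follows; the `profile_run` wrapper, the dummy model call, and the assumption that `set_forward_context` takes the attention metadata and the vLLM config as its first two arguments are illustrative, not part of this diff — only the `is_memory_profile=True` keyword comes from the change above.

```python
# Hypothetical usage sketch (not part of this PR): enter the forward context
# with is_memory_profile=True during the profiling pass so that backends which
# call get_forward_context() can over-allocate for their worst case.
from vllm.forward_context import get_forward_context, set_forward_context


def profile_run(model, dummy_input, attn_metadata, vllm_config):
    # set_forward_context is the context manager shown in the diff above;
    # every argument other than is_memory_profile keeps its default here.
    with set_forward_context(attn_metadata, vllm_config, is_memory_profile=True):
        assert get_forward_context().is_memory_profile
        return model(dummy_input)
```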
31 changes: 19 additions & 12 deletions vllm/v1/attention/backends/mla/common.py
@@ -209,7 +209,9 @@
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
@@ -1917,18 +1919,23 @@ def forward(
            )

        if attn_metadata is None:
            # During the profile run try to simulate to worse case output size
            # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
            # since this can be large
            _ = torch.empty(
                (
                    self.chunked_prefill_workspace_size,
                    self.num_heads,
                    self.qk_nope_head_dim + self.v_head_dim,
                ),
                device=k_c_normed.device,
                dtype=k_c_normed.dtype,
            )
            # During the profile run or cudagraph capture, try to simulate the
            # worst-case output size for `self.kv_b_proj(kv_c_normed)` in
            # `_compute_prefill_context`, since this can be large
            forward_ctx = get_forward_context()
            if (
                forward_ctx.is_memory_profile
                or forward_ctx.cudagraph_runtime_mode != CUDAGraphMode.NONE
            ):
                _ = torch.empty(
                    (
                        self.chunked_prefill_workspace_size,
                        self.num_heads,
                        self.qk_nope_head_dim + self.v_head_dim,
                    ),
                    device=k_c_normed.device,
                    dtype=k_c_normed.dtype,
                )

            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
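A note on why the guarded `torch.empty` above helps even though its result is discarded: the memory-profiling pass reads peak CUDA memory, so a transient worst-case allocation is enough to reserve headroom for the largest possible `kv_b_proj` output (and the same dummy allocation during CUDA graph capture presumably keeps that workspace inside the captured graph's memory pool). A standalone PyTorch illustration of the peak-memory effect, generic code not taken from vLLM:

```python
# Standalone illustration (not vLLM code): a transient allocation still raises
# the peak-memory counter that a profiling pass reads afterwards.
import torch

torch.cuda.reset_peak_memory_stats()
scratch = torch.empty((1024, 1024, 64), device="cuda", dtype=torch.float16)  # ~128 MiB
del scratch  # freed right away, but the recorded peak already includes it
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**20:.0f} MiB")
```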