Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,6 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
num_padded_decodes = attn_metadata.num_padded_decodes

# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
Expand Down Expand Up @@ -281,7 +280,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
state_indices_tensor,
num_prefill_tokens,
num_prefills,
num_padded_decodes,
num_decodes,
)
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
Expand Down Expand Up @@ -470,24 +469,24 @@ def split_batch_to_prefill_and_decode(
state_indices_tensor: torch.Tensor,
num_prefill_tokens: int,
num_prefills: int,
num_padded_decodes: int,
num_decodes: int,
) -> PrefillDecodeSplit:
num_actual_tokens = num_prefill_tokens + num_padded_decodes
num_actual_tokens = num_prefill_tokens + num_decodes

# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d, hidden_states_BC_p = torch.split(
hidden_states_BC[..., :num_actual_tokens],
[num_padded_decodes, num_prefill_tokens],
[num_decodes, num_prefill_tokens],
dim=-1,
)
gate_d, gate_p = torch.split(
gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1
gate[..., :num_actual_tokens], [num_decodes, num_prefill_tokens], dim=-1
)

# num_padded_decodes accounts for CUDA graph padding when applicable
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor[: num_padded_decodes + num_prefills],
[num_padded_decodes, num_prefills],
state_indices_tensor[: num_decodes + num_prefills],
[num_decodes, num_prefills],
dim=0,
)

Expand Down
3 changes: 0 additions & 3 deletions vllm/v1/attention/backends/mamba1_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class Mamba1AttentionMetadata:
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_padded_decodes: int

block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
Expand Down Expand Up @@ -68,7 +67,6 @@ def build(

has_initial_states_p = None
query_start_loc_p = None
padded_decodes = num_decodes
num_computed_tokens, num_computed_tokens_p = None, None
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
Expand Down Expand Up @@ -157,7 +155,6 @@ def build(
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_padded_decodes=padded_decodes,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
Expand Down