
Conversation

@LucasWilkinson
Collaborator

When doing prefill, up-convert the kv-cache from fp8 to bf16 and call the bf16 prefill kernel instead of the decode kernel. This PR introduces global workspace management so that the bf16 workspace overlaps with the MoE workspace buffers.
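
A minimal sketch of the shared-workspace idea described above, assuming a WorkspaceSpec/WorkspaceManager pair like the one quoted in the review below; the fields, the grow-on-demand policy, and the spec names are illustrative, not the PR's actual implementation.

import math
from dataclasses import dataclass

import torch


@dataclass
class WorkspaceSpec:
    # Hypothetical fields; the real spec in the PR may differ.
    name: str
    shape: tuple[int, ...]
    dtype: torch.dtype

    def num_bytes(self) -> int:
        elem = torch.empty((), dtype=self.dtype).element_size()
        return math.prod(self.shape) * elem


class WorkspaceManager:
    """Hands out typed views into one global byte buffer so consumers
    (e.g. the bf16 up-convert workspace and MoE scratch) overlap in the
    same allocation instead of each holding private memory."""

    def __init__(self, device: torch.device = torch.device("cpu")):
        self._buffer = torch.empty(0, dtype=torch.uint8, device=device)

    def _ensure_workspace_size(self, num_bytes: int, name: str) -> torch.Tensor:
        if self._buffer.numel() < num_bytes:
            # Grow the single backing buffer to the new high-water mark.
            self._buffer = torch.empty(
                num_bytes, dtype=torch.uint8, device=self._buffer.device
            )
        return self._buffer

    def get(self, spec: WorkspaceSpec) -> torch.Tensor:
        num_bytes = spec.num_bytes()
        buf = self._ensure_workspace_size(num_bytes, spec.name)
        # Reinterpret the leading bytes of the shared buffer as the
        # requested dtype and shape; later callers reuse the same bytes.
        return buf[:num_bytes].view(-1).view(spec.dtype).reshape(spec.shape)


manager = WorkspaceManager()
bf16_ws = manager.get(WorkspaceSpec("prefill_bf16_kv", (4, 128), torch.bfloat16))
moe_ws = manager.get(WorkspaceSpec("moe_scratch", (2, 64), torch.float32))  # overlaps bf16_ws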

Signed-off-by: Lucas Wilkinson <[email protected]>
@mergify mergify bot added the deepseek (Related to DeepSeek models) and v1 labels on Oct 26, 2025

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +125 to +136
def get(self, spec: "WorkspaceSpec") -> torch.Tensor:
    """Get a workspace tensor for the given spec.

    Args:
        spec: The workspace specification.

    Returns:
        A tensor view into the workspace buffer with the requested shape and dtype.
    """
    num_bytes = spec.num_bytes()
    current_workspace = self._ensure_workspace_size(num_bytes, spec.name)
    return current_workspace[:num_bytes].view(spec.dtype).reshape(spec.shape)


P1: Allocating workspaces fails due to invalid view call

WorkspaceManager.get reinterprets the byte buffer with current_workspace[:num_bytes].view(spec.dtype), but Tensor.view only accepts a shape, not a dtype. Passing a torch.dtype raises TypeError: 'torch.dtype' object cannot be interpreted as an integer, so every call to reserve/get will crash before returning a workspace. The manager needs to flatten the byte slice first, reinterpret it with a dtype view, and then reshape with a size tuple (e.g. view(-1).view(spec.dtype).reshape(spec.shape)).
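
A self-contained sketch of the pattern this comment suggests, with illustrative buffer names and sizes: flatten the byte slice first, then reinterpret it with a dtype view before reshaping.

import torch

raw = torch.empty(4096, dtype=torch.uint8)                 # backing byte workspace (stand-in)
shape, dtype = (8, 16), torch.bfloat16
num_bytes = 8 * 16 * torch.empty((), dtype=dtype).element_size()

# view(-1) keeps a contiguous 1-D byte view; the dtype view then reinterprets
# those bytes, and reshape applies the requested size tuple.
ws = raw[:num_bytes].view(-1).view(dtype).reshape(shape)
assert ws.shape == torch.Size(shape) and ws.dtype == dtype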


Comment on lines +794 to +805
# Process decode tokens
if num_decode_tokens > 0:
    attn_out = self._forward_fp8_kv(
        q[:num_decode_tokens],
        kv_cache,
        topk_indices_global[:num_decode_tokens],
        attn_metadata,
    )

if num_prefill_tokens > 0:
    decode_attn_out = attn_out
    attn_out = q.new_empty(


P0: Prefill-only batches reference attn_out before initialization

In the fp8 path of FlashMLASparseImpl.forward, attn_out is only assigned inside the if num_decode_tokens > 0 branch. The subsequent if num_prefill_tokens > 0 branch unconditionally reads decode_attn_out = attn_out, which raises UnboundLocalError whenever a batch contains only prefill tokens. Prefill batches are common during initial context ingestion, so this path will always fail until attn_out is initialized for the prefill case.
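
A self-contained toy of one possible fix, assuming illustrative shapes and stand-in decode/prefill computations: allocate attn_out for the whole batch up front so a prefill-only batch never reads an unassigned variable.

import torch

num_decode_tokens, num_prefill_tokens = 0, 5               # prefill-only batch
num_actual_toks = num_decode_tokens + num_prefill_tokens
q = torch.randn(num_actual_toks, 4, 8)

attn_out = q.new_empty(num_actual_toks, 4, 8)               # allocated unconditionally
if num_decode_tokens > 0:
    attn_out[:num_decode_tokens] = q[:num_decode_tokens] * 2      # decode-path stand-in
if num_prefill_tokens > 0:
    attn_out[num_decode_tokens:] = q[num_decode_tokens:] * 3      # prefill-path stand-in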


Comment on lines +803 to +811
if num_prefill_tokens > 0:
    decode_attn_out = attn_out
    attn_out = q.new_empty(
        (num_actual_toks, self.num_heads, self.kv_lora_rank),
        dtype=q.dtype,
        device=q.device,
    )
    attn_out[:num_prefill_tokens] = decode_attn_out[:num_prefill_tokens]


P1: Decode outputs stored into prefill slots

When both decode and prefill tokens exist, the fp8 path copies decode attention results with attn_out[:num_prefill_tokens] = decode_attn_out[:num_prefill_tokens]. Decode tokens occupy the first num_decode_tokens entries, so this writes them into the wrong slice and fails whenever num_prefill_tokens > num_decode_tokens because the right-hand side is shorter than the target. The assignment should use num_decode_tokens to preserve decode outputs and avoid size mismatches.
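
A self-contained toy of the corrected copy, with stand-in tensors: decode results occupy the first num_decode_tokens rows, so that is the slice to preserve when widening the output buffer.

import torch

num_decode_tokens, num_prefill_tokens = 2, 6
num_actual_toks = num_decode_tokens + num_prefill_tokens
decode_attn_out = torch.randn(num_decode_tokens, 4, 8)      # decode results (first rows)

attn_out = decode_attn_out.new_empty(num_actual_toks, 4, 8)
# Copy by num_decode_tokens: a num_prefill_tokens slice would both misplace the
# decode rows and raise a size mismatch whenever prefill > decode.
attn_out[:num_decode_tokens] = decode_attn_out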


