[Attention] Use sparse prefill kernel for fp8 kv-cache in DeepSeek-v3.2 #27532
base: main
Conversation
Signed-off-by: Lucas Wilkinson <[email protected]>
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
def get(self, spec: "WorkspaceSpec") -> torch.Tensor:
    """Get a workspace tensor for the given spec.

    Args:
        spec: The workspace specification.

    Returns:
        A tensor view into the workspace buffer with the requested
        shape and dtype.
    """
    num_bytes = spec.num_bytes()
    current_workspace = self._ensure_workspace_size(num_bytes, spec.name)
    return current_workspace[:num_bytes].view(spec.dtype).reshape(spec.shape)
```
**Allocating workspaces fails due to invalid `view` call**

`WorkspaceManager.get` reinterprets the byte buffer with `current_workspace[:num_bytes].view(spec.dtype)`, but `Tensor.view` only accepts a shape, not a dtype. Passing a `torch.dtype` raises `TypeError: 'torch.dtype' object cannot be interpreted as an integer`, so every call to `reserve`/`get` will crash before returning a workspace. The manager needs to reshape using a size tuple and reinterpret the dtype separately (e.g. `view(-1).view(spec.dtype)` or `view(spec.dtype).reshape(...)`).
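For reference, here is a minimal standalone sketch of the reinterpret-then-reshape pattern this comment discusses. Note that recent PyTorch (>= 1.8) does ship a dtype overload of `Tensor.view`, so whether the flagged line actually crashes depends on the installed torch version; this is a sketch, not the PR's final fix.

```python
import torch

buf = torch.zeros(1024, dtype=torch.uint8)  # raw byte workspace buffer
shape = (4, 64)                             # stand-in for spec.shape
num_bytes = 4 * 64 * 2                      # bf16 is 2 bytes per element

# Slice the byte buffer, reinterpret the bytes as bf16 (dtype overload
# of view), then reshape to the requested shape.
view = buf[:num_bytes].view(torch.bfloat16).reshape(shape)
assert view.shape == shape and view.dtype == torch.bfloat16
```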
```python
# Process decode tokens
if num_decode_tokens > 0:
    attn_out = self._forward_fp8_kv(
        q[:num_decode_tokens],
        kv_cache,
        topk_indices_global[:num_decode_tokens],
        attn_metadata,
    )

if num_prefill_tokens > 0:
    decode_attn_out = attn_out
    attn_out = q.new_empty(
```
**Prefill-only batches reference `attn_out` before initialization**

In the fp8 path of `FlashMLASparseImpl.forward`, `attn_out` is only assigned inside the `if num_decode_tokens > 0` branch. The subsequent `if num_prefill_tokens > 0` branch unconditionally reads `decode_attn_out = attn_out`, which raises `UnboundLocalError` whenever a batch contains only prefill tokens. Prefill-only batches are common during initial context ingestion, so this path will always fail until `attn_out` is initialized for the prefill case.
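One hedged way to address this, sketched against the names in the diff above (the PR's actual fix may differ), is to allocate `attn_out` up front so a prefill-only batch never reads an unbound local:

```python
# Sketch: allocate attn_out unconditionally, then fill the decode
# prefix only when decode tokens exist. All names (q, num_actual_toks,
# self.num_heads, ...) are taken from the diff above and are
# assumptions about the surrounding scope.
attn_out = q.new_empty(
    (num_actual_toks, self.num_heads, self.kv_lora_rank),
    dtype=q.dtype,
    device=q.device,
)
if num_decode_tokens > 0:
    attn_out[:num_decode_tokens] = self._forward_fp8_kv(
        q[:num_decode_tokens],
        kv_cache,
        topk_indices_global[:num_decode_tokens],
        attn_metadata,
    )
```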
```python
if num_prefill_tokens > 0:
    decode_attn_out = attn_out
    attn_out = q.new_empty(
        (num_actual_toks, self.num_heads, self.kv_lora_rank),
        dtype=q.dtype,
        device=q.device,
    )
    attn_out[:num_prefill_tokens] = decode_attn_out[:num_prefill_tokens]
```
**Decode outputs stored into prefill slots**

When both decode and prefill tokens exist, the fp8 path copies decode attention results with `attn_out[:num_prefill_tokens] = decode_attn_out[:num_prefill_tokens]`. Decode tokens occupy the first `num_decode_tokens` entries, so this writes them into the wrong slice and fails whenever `num_prefill_tokens > num_decode_tokens` because the right-hand side is shorter than the target. The assignment should use `num_decode_tokens` to preserve decode outputs and avoid size mismatches.
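The corrected slice, per the comment (a one-line sketch against the names in the diff):

```python
# Preserve the decode prefix: decode tokens occupy the first
# num_decode_tokens rows of both tensors, so slice by that count.
attn_out[:num_decode_tokens] = decode_attn_out[:num_decode_tokens]
```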
When doing prefill, up-convert the kv-cache from fp8 to bf16 and call the bf16 prefill kernel instead of the decode kernel. This PR introduces global workspace management so that the bf16 workspace can overlap with the MoE workspace buffers.
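A rough sketch of the flow this description outlines, where every name (`bf16_prefill_kernel`, `workspace_mgr`, the per-block `scales`) is illustrative rather than the PR's actual API:

```python
import torch

def prefill_with_fp8_cache(q, kv_cache_fp8, scales, workspace_mgr, spec):
    # Borrow a bf16 view from the shared workspace (aliasing the MoE
    # buffers rather than allocating a dedicated one). spec.shape is
    # assumed to match the gathered cache pages.
    kv_bf16 = workspace_mgr.get(spec)
    # Up-convert the fp8 cache to bf16 and apply dequant scales, then
    # reuse the existing bf16 prefill kernel instead of the decode kernel.
    kv_bf16.copy_(kv_cache_fp8.to(torch.bfloat16))
    kv_bf16.mul_(scales)
    return bf16_prefill_kernel(q, kv_bf16)  # hypothetical kernel entry point
```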