
Conversation

Contributor

@MatthewBonanni MatthewBonanni commented Dec 5, 2025

Purpose

Starting up DeepSeek R1 DP8/EP on 8xH200 currently OOMs at the default --gpu-memory-utilization (0.9). This PR prevents an unnecessary 4 GiB allocation during the post-CUDA-graph-capture dummy run, allowing the server to start without reducing --gpu-memory-utilization. It still OOMs once prompted, though, so while this improves the situation and restores the original intent of the code, it does not fully solve the problem.

Test Plan

vllm serve deepseek-ai/DeepSeek-R1 -dp 8 --enable-expert-parallel

Test Result

main: OOM during server startup

PR branch: no longer OOMs during startup


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Matthew Bonanni <[email protected]>
@mergify mergify bot added the deepseek (Related to DeepSeek models) and v1 labels Dec 5, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses an out-of-memory (OOM) error during the startup of DeepSeek R1 models. The fix is achieved by making a large, unnecessary memory allocation conditional. This allocation, intended for worst-case memory profiling or CUDA graph capture, is now skipped during other dummy runs, such as the warmup phase before graph capture.

The changes are implemented by introducing an is_memory_profile flag in the ForwardContext and using it, along with the cudagraph_runtime_mode, to control the allocation in vllm/v1/attention/backends/mla/common.py.
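
For reference, a minimal sketch of the guard described above, assuming the check sits behind a small helper in the MLA backend; the helper name and import paths are assumptions, while is_memory_profile, cudagraph_runtime_mode, and the worst-case allocation itself come from the PR and review text:

```python
# Hypothetical sketch, not the actual diff.
from vllm.config import CUDAGraphMode
from vllm.forward_context import get_forward_context


def _needs_worst_case_workspace() -> bool:
    # Only materialize the large worst-case buffer when profiling peak
    # memory or capturing full CUDA graphs; skip it for other dummy runs,
    # e.g. the warmup pass that follows graph capture.
    ctx = get_forward_context()
    return (
        ctx.is_memory_profile
        or ctx.cudagraph_runtime_mode == CUDAGraphMode.FULL
    )
```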

My review of the changes indicates that the logic is sound and correctly targets the source of the OOM issue. The modifications are clean, well-contained, and effectively resolve the problem without introducing any apparent side effects. The code quality is good, and I have no further suggestions for improvement.

@MatthewBonanni MatthewBonanni marked this pull request as ready for review December 6, 2025 00:09
@MatthewBonanni
Contributor Author

cc @LucasWilkinson


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Signed-off-by: Matthew Bonanni <[email protected]>
Comment on lines +208 to +211
# set dynamically for each forward pass
# True during memory profiling, False otherwise
is_memory_profile: bool = False
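
The quoted comment says the flag is set dynamically for each forward pass. A rough, self-contained illustration of that pattern (a simplified stand-in for the real forward-context plumbing; only the is_memory_profile field comes from the diff, everything else is hypothetical):

```python
# Simplified, hypothetical stand-in for the forward-context plumbing;
# only the is_memory_profile field is taken from the quoted diff.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class ForwardContext:
    # set dynamically for each forward pass
    # True during memory profiling, False otherwise
    is_memory_profile: bool = False


_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    assert _forward_context is not None, "no forward context is active"
    return _forward_context


@contextmanager
def set_forward_context(is_memory_profile: bool = False) -> Iterator[ForwardContext]:
    # A fresh context per forward pass, so the flag from the profiling
    # dummy run never leaks into later prefill/decode passes.
    global _forward_context
    prev = _forward_context
    _forward_context = ForwardContext(is_memory_profile=is_memory_profile)
    try:
        yield _forward_context
    finally:
        _forward_context = prev


# Usage: only the memory-profiling dummy run opts in.
# with set_forward_context(is_memory_profile=True):
#     model(dummy_inputs)
```

In the actual change, the flag presumably defaults to False and is flipped only for the profiling run, which is what keeps the 4 GiB workspace out of the post-capture warmup pass.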

Member

Should we avoid adding too many things to the forward_context? The class is becoming increasingly complicated, and I am worried about it getting more and more bloated. cc @WoosukKwon @youkaichao

Member

@zhuohan123 zhuohan123 left a comment

Also seems like this is probably no longer needed with model runner v2?
