
Conversation

@LCAIZJ (Contributor) commented Aug 29, 2025

Purpose

  1. D (decode) is memory-bound, so a larger TP size would normally yield better performance. In the MLA scenario, however, each TP worker maintains the full KV cache, so increasing TP brings no actual GPU memory savings. As a result, when MLA is enabled, the TP size of P (prefill) often exceeds that of D.
  2. When P's TP size exceeds D's, the current kv_output_aggregator mechanism fails to function correctly. As illustrated in the figure below, both finished_sending and finished_recving should be 2 in this scenario. However, kv_output_aggregator incorrectly initializes the expected count to world_size (4 here), so with only 2 workers reporting, the aggregation never completes and P's KV cache is never released.
[figure omitted: finished_sending / finished_recving counting under heterogeneous P/D TP]
  3. To resolve this issue, we introduce a new interface, get_finished_count(), in the connector's base class. Connector subclasses implement it to return the correct expected values for finished_sending and finished_recving. When the executors invoke get_finished_count() to initialize kv_output_aggregator, the counters are configured correctly (a minimal sketch follows).
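
A minimal sketch of the proposed interface and one connector implementing it. The subclass, its constructor parameters, and the min(P_TP, D_TP) rule are illustrative assumptions, not the PR's exact code:

    from typing import Optional


    class KVConnectorBase_V1:
        def get_finished_count(self) -> Optional[int]:
            # How many workers are expected to report finished_sending /
            # finished_recving. None by default; subclasses that know the
            # heterogeneous P/D TP layout override this.
            return None


    class HypotheticalMLAConnector(KVConnectorBase_V1):
        # Illustrative subclass: with an MLA-style replicated KV cache,
        # only min(P_TP, D_TP) workers take part in a given transfer, so
        # that is the count the aggregator should wait for (2 in the
        # figure's example, instead of world_size = 4).
        def __init__(self, prefill_tp: int, decode_tp: int):
            self._finished_count = min(prefill_tp, decode_tp)

        def get_finished_count(self) -> Optional[int]:
            return self._finished_count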

Test Plan

Model: DeepSeek-R1
Distributed Strategy: 1P (TP16), 1D (DP4, TP4)

Test Result

Successfully processed 10,000 benchmark prompts with 4K input tokens and 1.5K output tokens per request.
[screenshot of benchmark results omitted]
Since the primary goal was to verify overall functionality, we did not enable many optimization features during the benchmark. The TTFT and TPOT metrics are therefore suboptimal, but the core functionality was successfully verified.



@gemini-code-assist (bot) left a comment

Code Review

This pull request refactors the initialization of KVOutputAggregator to support heterogeneous configurations by querying the number of participants from the KV connector. The changes are well-structured, introducing a new method in the connector's base class and updating executors to use it. However, there is a critical issue in both MultiprocExecutor and RayDistributedExecutor where a None return from get_finished_count() is not handled, which would lead to a TypeError at runtime. I've provided suggestions to fix this by adding a fallback mechanism.

Comment on lines 328 to 335

critical

The get_finished_count() method on KVConnectorBase_V1 is defined to return Optional[int], and its base implementation returns None. The KVOutputAggregator constructor expects an int, so passing None to it will cause a TypeError. This is a critical issue that can lead to a runtime crash if a connector that does not override get_finished_count() is used. You should handle the None case, for example by falling back to self.parallel_config.world_size.

Suggested change

Before:

    def init_kv_output_aggregator(self) -> None:
        if has_kv_transfer_group():
            kv_connector = get_kv_transfer_group()
            self.kv_output_aggregator = KVOutputAggregator(
                kv_connector.get_finished_count())
        else:
            self.kv_output_aggregator = KVOutputAggregator(
                self.parallel_config.world_size)

After:

    def init_kv_output_aggregator(self) -> None:
        world_size = self.parallel_config.world_size
        if has_kv_transfer_group():
            kv_connector = get_kv_transfer_group()
            finished_count = kv_connector.get_finished_count()
            if finished_count is not None:
                world_size = finished_count
        self.kv_output_aggregator = KVOutputAggregator(world_size)
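
A minimal unit-test sketch for the suggested fallback. The fake connector and the extracted helper are assumptions for illustration; the actual suggestion inlines this logic in the executors:

    class FakeConnector:
        def __init__(self, count):
            self._count = count

        def get_finished_count(self):
            return self._count


    def resolve_aggregator_count(connector, world_size: int) -> int:
        # Mirrors the suggested change: prefer the connector's count and
        # fall back to world_size when the connector returns None.
        finished_count = connector.get_finished_count()
        return finished_count if finished_count is not None else world_size


    def test_fallback_to_world_size():
        assert resolve_aggregator_count(FakeConnector(None), 4) == 4


    def test_connector_count_wins():
        assert resolve_aggregator_count(FakeConnector(2), 4) == 2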

Comment on lines 111 to 118

critical

The same issue and the same suggested fallback as above apply here: handle a None return from get_finished_count() by falling back to self.parallel_config.world_size.

@LCAIZJ force-pushed the feat/finish_count branch from fa99f75 to da4ed3d (August 29, 2025 09:05)
@github-actions (bot) commented

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@LCAIZJ changed the title from "[WIP]Fix kv_output_aggregator support heterogeneous" to "[WIP] kv_output_aggregator support heterogeneous" (Aug 31, 2025)
@LCAIZJ force-pushed the feat/finish_count branch 2 times, most recently from 7a0d2b4 to 1b59338 (September 1, 2025 03:55)
@LCAIZJ changed the title from "[WIP] kv_output_aggregator support heterogeneous" to "kv_output_aggregator support heterogeneous" (Sep 2, 2025)
@LCAIZJ force-pushed the feat/finish_count branch 2 times, most recently from 70f10c8 to 7073da9 (September 2, 2025 16:07)
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Sep 4, 2025
### What this PR does / why we need it?
In vLLM 0.10.1, a new KVOutputAggregator was added, moving result aggregation into the executor (vllm-project/vllm#19555). This broke mooncake_connector. This change fixes that bug and also adds a policy to forcibly release the KV cache when the prefill node times out.

This PR is currently linked to a PR in vllm
(vllm-project/vllm#23917). The vllm PR aims to
modify the finish and send count confirmation in heterogeneous TP
situations.

Many unit tests were deleted because much of the communication code was removed; the remaining UTs are correspondingly more concise.

- vLLM version: v0.10.1.1
- vLLM main:
vllm-project/vllm@fa4311d

---------

Signed-off-by: baxingpiaochong <[email protected]>
@NickLucche (Collaborator) left a comment

Thanks a lot for contributing to this @LCAIZJ !

One big issue with this interface is that heterogeneous TP is a "runtime property" for some connectors like Nixl, and arguably for discovery-based connectors more broadly: the connector only finds out about the heterogeneous state after the first handshake (assuming a "static" P configuration on the other side).
Hence, here we wouldn't be able to return a sensible get_finished_count at init time.

Successfully processed 10,000 benchmark prompts with 4K input tokens and 1.5K output tokens per request.

What connector did you use for this, as I don't see any concrete get_finished_count impl on this PR?

Also, it would be nice if we could add some minimal unit tests, perhaps testing the example you reported.


PS

When the TP size for P exceeds that of D, the current mechanism of kv_output_aggregator fails to function correctly

This is also partly the reason why in disagg PD settings with NixlConnector we do not allow P TP size to exceed D's https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py#L909.
But I would like to understand: how popular/important is this MLA + P TP > D TP use-case, rather than assuming MoE and going wide EP?
I feel maintaining both P-TP < D-TP and P-TP > D-TP in NixlConnector might complicate the code quite a bit.
Nothing against supporting that for other connectors though, of course.
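
To make the "runtime property" concern concrete, here is a sketch of a discovery-based connector (all names hypothetical): the remote TP size is only learned during the first handshake, so an init-time get_finished_count() has nothing sensible to return.

    from typing import Optional


    class DiscoveryBasedConnector:
        # Hypothetical NIXL-like connector: heterogeneous TP is a runtime
        # property discovered on the first handshake with the remote side.
        def __init__(self, local_tp: int):
            self.local_tp = local_tp
            self.remote_tp: Optional[int] = None  # unknown until handshake

        def on_handshake(self, remote_tp: int) -> None:
            # Only now do we learn whether the deployment is heterogeneous.
            self.remote_tp = remote_tp

        def get_finished_count(self) -> Optional[int]:
            if self.remote_tp is None:
                # Called at init time, before any handshake: no answer yet.
                return None
            return min(self.local_tp, self.remote_tp)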

@zzy-ContiLearn commented

(quoting @NickLucche's full review above)

Regarding which connector designs involve P's TP size exceeding D's, please refer to the implementation in this PR and file: https://github.com/vllm-project/vllm-ascend/pull/2664/files#diff-033f9a4af4a59c8ca0ed782d5821675ffcc15f64625bc37afa95a886edb373d1.

From my personal perspective, it would be ideal for the connector design to be compatible with arbitrary TP-size ratios (relying on long-term evolution). Currently, the connector mentioned above is primarily used in the vllm-ascend community.

As a side note, another viable approach would be to write the TP information of P and D into kv_connector_extra_config within kv_transfer_config, read it in the executor, calculate the actual target_count (the true completion count), and pass it when initializing the aggregator; a sketch follows.
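
A sketch of that side note. The prefill_tp / decode_tp keys inside kv_connector_extra_config and the min() rule are assumptions for illustration, not an existing vLLM schema:

    def target_count_from_config(kv_transfer_config, world_size: int) -> int:
        # Compute the true completion count from user-provided P/D TP
        # sizes; fall back to world_size if they are not configured.
        extra = kv_transfer_config.kv_connector_extra_config or {}
        prefill_tp = extra.get("prefill_tp")
        decode_tp = extra.get("decode_tp")
        if prefill_tp is None or decode_tp is None:
            return world_size
        # With an MLA-style replicated KV cache, only min(P_TP, D_TP)
        # workers participate in a given transfer (an assumption, as above).
        return min(int(prefill_tp), int(decode_tp))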

@LCAIZJ (Contributor, Author) commented Sep 5, 2025

(quoting @NickLucche's full review above)

Hi @NickLucche ,

Thank you for your review comments. My responses to your questions are as follows:

  1. Our test is based on the MooncakeConnector in vllm-ascend (https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/distributed/mooncake_connector.py), which implements the get_finished_count() method ([P/D] mooncake_connector adapted to 0.10.1, vllm-ascend#2664).

  2. Since the TP sizes of P and D are known at engine initialization, the connector reads them from the configuration file and can therefore return a valid get_finished_count during initialization.

[figure omitted]

  3. In cases like MLA and GQA, P's TP size is larger than D's in our production environments. Different hardware characteristics may also influence the optimal TP sizes for P and D. Therefore, I believe the connector should ideally support all scenarios where the P and D TP sizes are heterogeneous.

  4. Regarding the unit-test concerns: this PR primarily modifies interfaces. We believe the changes are safe, as the CI pipeline already validates the implementation classes of the interface through existing unit tests, and we have verified the changes end-to-end with 10,000 test cases. Should you identify specific components requiring additional unit-test coverage, we're happy to add them in follow-up commits.

@youkaichao (Member) commented

how popular/important is this MLA+P TP>D PT use-case

@NickLucche For MLA, people usually do Decode TP=1, and Prefill TP > 1.

@LCAIZJ LCAIZJ requested a review from NickLucche September 5, 2025 23:32
@LCAIZJ (Contributor, Author) commented Sep 8, 2025

how popular/important is this MLA+P TP>D PT use-case

@NickLucche For MLA, people usually do Decode TP=1, and Prefill TP > 1.

Indeed

@NickLucche (Collaborator) commented

Apologies for the delay, let me get back to this PR tomorrow.

@robertgshaw2-redhat changed the title from "kv_output_aggregator support heterogeneous" to "[P/D][NIXL] kv_output_aggregator support heterogeneous" (Sep 8, 2025)
@mergify (bot) commented Sep 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LCAIZJ.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@LCAIZJ (Contributor, Author) commented Sep 13, 2025

@NickLucche All CI pipelines have passed. Would you mind reviewing the PR again when you have time?

@LCAIZJ LCAIZJ requested a review from NickLucche September 13, 2025 12:33
@NickLucche NickLucche enabled auto-merge (squash) September 15, 2025 06:55
@NickLucche (Collaborator) commented

This is likely the most minimal enabling set of changes now, thank you!

@LCAIZJ (Contributor, Author) commented Sep 15, 2025

This is likely the most minimal enabling set of changes now, thank you!

@NickLucche Thank you for your feedback. If everything looks good, could you please approve this PR? It's currently configured to require review approval before it can be merged automatically.

@NickLucche NickLucche merged commit 8de261b into vllm-project:main Sep 15, 2025
48 checks passed
@NickLucche (Collaborator) commented Sep 15, 2025

Apologies, I had missed the force push

bbartels pushed a commit to bbartels/vllm that referenced this pull request Sep 15, 2025
Signed-off-by: LCAIZJ <[email protected]>
Co-authored-by: leichao.lc <[email protected]>
Signed-off-by: bbartels <[email protected]>
offline893 pushed a commit to offline893/vllm-ascend that referenced this pull request Sep 16, 2025 (same commit message as above)
wangxiaoteng888 pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Sep 25, 2025 (same commit message as above)
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025 (same commit message as above)
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
Signed-off-by: LCAIZJ <[email protected]>
Co-authored-by: leichao.lc <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025 (same commit message as above)
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: LCAIZJ <[email protected]>
Co-authored-by: leichao.lc <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>

Labels

ready, v1


4 participants