
Conversation


@NickLucche NickLucche commented Oct 21, 2025

Overview

This PR addresses the case where the prefill (P) tensor-parallel size is greater than the decode (D) tensor-parallel size.

I think it helps to differentiate two main cases:

MLA

For MLA models, the workflow is simpler: each D worker reads from a single P worker (reads are fanned out so that not all D workers hit the same remote), since the MLA cache is duplicated across P workers. Some P workers will not be read from at all.
Note that this also holds for DP/EP deployments, where the TP size on D will often be 1!
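A rough sketch of the fan-out idea (the mapping below is an illustrative assumption, not the exact one used in this PR):

```python
# Hypothetical helper: pick the single remote P rank an MLA D worker reads from.
# Since the MLA cache is replicated on every P worker, any choice is correct;
# spreading D ranks over P ranks just avoids everyone hitting the same remote.
def pick_mla_remote_rank(d_rank: int, d_tp_size: int, p_tp_size: int) -> int:
    assert p_tp_size >= d_tp_size and p_tp_size % d_tp_size == 0
    tp_ratio = p_tp_size // d_tp_size
    return d_rank * tp_ratio

# e.g. P TP=4, D TP=2 -> D rank 0 reads from P rank 0, D rank 1 from P rank 2;
# P ranks 1 and 3 are never read from.
```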

[diagram: MLA case, D workers fanning reads out across P workers]

The diagram is from PR #23917, which also serves as a good use case. As explained in that PR, the number of requests to "expect" is indeed the number of remote instances reading from P.

The main issue with implementing this in Nixl is that each P worker tracks requests as they come in (_reqs_to_send, _reqs_to_process), and those structures are only cleared properly when a read is detected (otherwise timeouts would be raised on P).
To address that, I am allowing MLA D ranks to execute only one transfer, while notifying all affected remotes that the read is completed (sending multiple nixl notifs).
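A minimal sketch of the multi-notif idea, assuming a generic send_notif callable standing in for the actual NIXL notification API:

```python
# One D rank performs the single MLA transfer, then notifies every P rank that
# is tracking the request (_reqs_to_send / _reqs_to_process), so none of them
# keeps waiting for a read that will never come and raises a timeout.
def notify_all_affected_p_ranks(send_notif, req_id: str, p_agent_names: list[str]) -> None:
    for agent_name in p_agent_names:
        send_notif(agent_name, req_id)  # "read completed" for this request
```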

cc @njhill @markmc

Dense

For dense models, every D worker will read from n P workers to re-compose its own KV cache, where n is referred to as tp_ratio in the code.

[diagram: dense case, a D worker re-composing its KV cache from n P workers]

This is possible because the number of heads on each P worker is H/n, where H is the number of heads on each D worker, so D can efficiently read into its own cache using the HND layout. That is, in memory, you're just laying out flat ND tensors of H/n heads, n times.
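A minimal sketch of why the HND layout makes this read efficient (shapes and names are assumptions for illustration, not the PR's actual code):

```python
import torch

def compose_block_from_p_workers(d_block: torch.Tensor, p_slices: list[torch.Tensor]) -> None:
    """Fill one D-side KV block (HND layout, H heads) from n remote P slices.

    d_block:  (H, N, D) block on the D worker
    p_slices: n tensors of shape (H // n, N, D), one per remote P worker
    """
    n = len(p_slices)
    heads_per_p = d_block.shape[0] // n
    for i, p_slice in enumerate(p_slices):
        # Each remote read lands in a flat, contiguous head range of D's block,
        # which is exactly the "H/n heads, n times" layout described above.
        d_block[i * heads_per_p : (i + 1) * heads_per_p] = p_slice
```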

Side note: the current design is flexible and allows for dynamic discovery of remotes with different tp_sizes. This is not a feature that is currently supported, but it helps to keep in mind when considering implementation choices; it's an optional route I'd like to keep open.

Changes

The main change this PR needs to enable is for a D worker to read from multiple P workers.
The practical edits it introduces to do so:

  • src_xfer_side_chunked_handles: local regions need to be split differently based on how many remotes we want to read from. This is prepared during the handshake, once.
  • a few structures go from single-remote to [engine_id][rank_no] keyed, to accommodate the above
  • get_target_remote->get_target_remotes for the same reason, plus a bunch of for loops over its result (see the sketch after this list)
  • P has to wait for at most a single read notification (communicated from D)
  • tp_ratio extension to indicate a remote P size greater than D's
  • multiple xfers/handles per request: this was partly already supported, I just fixed a bug in _pop_done_transfers
  • multiple notifs per single read, to optimize for MLA models
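A sketch of the get_target_remotes idea referenced above (illustrative rank math, not the connector's exact mapping):

```python
def get_target_remotes(d_rank: int, d_tp_size: int, p_tp_size: int) -> list[int]:
    """Remote P ranks a D rank reads from: one when D TP >= P TP, tp_ratio otherwise."""
    if p_tp_size <= d_tp_size:
        # Previous behaviour: a single remote P rank serves this D rank.
        return [d_rank * p_tp_size // d_tp_size]
    tp_ratio = p_tp_size // d_tp_size
    # New behaviour: this D rank re-composes its cache from tp_ratio P ranks.
    return [d_rank * tp_ratio + i for i in range(tp_ratio)]

# e.g. P TP=4, D TP=2 -> D rank 0 reads from P ranks [0, 1], D rank 1 from [2, 3].
```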

How to test

pytest -v -s -x tests/v1/kv_connector/unit/test_nixl_connector.py::TestNixlHandshake::test_prefill_tp_size_greater_than_decode_tp_size/test_prefill_tp_size_greater_than_decode_tp_size_mla

And check out tp_config_sweep_accuracy with config:

PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh

TODO

Coming soon to this PR:

  • Avoid stranding requests on P for MLA
  • [ ] On MLA with DP/EP, avoid having all workers read from the same remote (deferring)
  • DP_EP tests

This PR does NOT support the replicated-KV-heads scenario (tp_size > num_heads). That is definitely doable, but I believe demand for it is weak at the moment, so we can postpone it.

@NickLucche

cc @GuanLuo, let me know if this PR covers the set of features you were aiming for with your work. Thank you!

class DummyModelRunnerOutput(ModelRunnerOutput):
def __init__(
self,
finished_sending: set[str] | None = None,
@NickLucche (author):

ignore file, to be rebased once #26734 lands

"""
Get the count of requests expected to complete send/receive operations
via this connector.
via this connector. This method is used to initialize the
@NickLucche (author):

ignore, to be rebased once #26734 lands

tp_ratio,
)

### (Optional) Register local agent memory regions. MLA is not split.
@NickLucche (author):

gist of the PR
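The gist, sketched (hypothetical helper; the real handshake code registers NIXL descriptors rather than plain tuples): each local block region on D is split into tp_ratio contiguous chunks, one per remote P worker it will read from, while MLA regions are left unsplit.

```python
def chunk_local_region(base_addr: int, region_len: int, tp_ratio: int, use_mla: bool) -> list[tuple[int, int]]:
    """Return (addr, length) chunks of a local D region, one per remote P worker."""
    if use_mla or tp_ratio <= 1:
        return [(base_addr, region_len)]  # MLA (or 1:1): read the region whole
    chunk_len = region_len // tp_ratio
    return [(base_addr + i * chunk_len, chunk_len) for i in range(tp_ratio)]
```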

# on notification so that dst worker can wait before freeing blocks.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
# Cap to 1 when P TP > D TP: only a single rank will read from remote.
tp_ratio = max(1, self.kv_topo.tp_ratio_from_engine_id(dst_engine_id))
@NickLucche (author):

this is to have P only wait for 1 request instead of -tp_ratio
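A worked example of the cap, under the assumed convention that tp_ratio is negative when P TP > D TP (see the MLA branch further down in the diff):

```python
# D TP=4, P TP=2 -> tp_ratio =  2: each request is read by 2 D ranks, so P waits for 2 notifs.
# D TP=2, P TP=4 -> tp_ratio = -2: only a single D rank reads from this P, so cap the wait to 1.
tp_ratio = -2
expected_reads_per_request = max(1, tp_ratio)
assert expected_reads_per_request == 1
```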

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KV cache helper for store.
@NickLucche (author):

ignore file, to be rebased once #26734 lands


import asyncio
import time
from abc import ABC, abstractmethod
@NickLucche (author):

ignore file, to be rebased once #26734 lands

include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
log_stats=self.log_stats,
block_size=scheduler_block_size,
)
@NickLucche (author):

ignore, to be rebased once #26734 lands

class KVConnectorOutput:
# [req_ids]
finished_sending: set[str] | None = None
finished_recving: set[str] | None = None
@NickLucche (author):

ignore, to be rebased once #26734 lands

@mergify

mergify bot commented Oct 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


@mergify mergify bot added the needs-rebase label Oct 23, 2025
@NickLucche NickLucche marked this pull request as ready for review October 23, 2025 16:35
@NickLucche NickLucche requested a review from ApostaC as a code owner October 23, 2025 16:35
@mergify mergify bot removed the needs-rebase label Oct 23, 2025
@NickLucche

PR's now ready for review!

@NickLucche NickLucche requested a review from andylolu2 October 24, 2025 08:51
@NickLucche

cc @xuechendi for xpu


@mergify mergify bot added the needs-rebase label Oct 27, 2025
@xuechendi

@zhenwei-intel, please help to review, thx

@mergify mergify bot removed the needs-rebase label Nov 7, 2025
@NickLucche NickLucche added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 9, 2025

@mergify mergify bot added the needs-rebase label Nov 14, 2025
self.num_layers = 0

# nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0
@NickLucche (author):

dropped default self.src_xfer_side_handle in favor of
self.src_xfer_handles_by_block_size[self.block_size]

Comment on lines +2053 to +2055
if self.use_mla and tp_ratio < 0:
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
@NickLucche (author):

important mla logic

@NickLucche NickLucche requested a review from xuechendi November 21, 2025 14:22
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
self.device_id: int = 0
# Current rank may pull from multiple remote TP workers.
self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = (
Reviewer (Contributor):

It would be helpful to add a comment explaining the leveled dict, e.g.:

# EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer
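For illustration, the nested mapping the suggested comment describes could be populated like this (addresses are made up):

```python
from collections import defaultdict

EngineId = str

# engine_id -> remote tp_rank -> per-layer KV cache base addresses
kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = defaultdict(dict)

# A D rank pulling from remote P ranks 0 and 1 of engine "prefill-0":
kv_caches_base_addr["prefill-0"][0] = [0x7F000000, 0x7F100000]  # layer 0, layer 1
kv_caches_base_addr["prefill-0"][1] = [0x7F200000, 0x7F300000]
```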

@xuechendi

PR is verified with the heter_block_size test, and it looks good.
Thanks
