@@ -8,9 +8,12 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
configs=(
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
)
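
Each entry packs one test configuration into a single string of environment overrides. A hedged Python sketch of how such an entry could be split into key/value pairs (the real consumption happens inside run_accuracy_test.sh, not shown here):

```python
def parse_config(entry: str) -> dict[str, str]:
    # "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
    return dict(kv.split("=", 1) for kv in entry.split())

env = parse_config(
    "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 "
    "DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
)
assert env["PREFILLER_TP_SIZE"] == "2"
```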

run_tests() {
251 changes: 228 additions & 23 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -380,6 +380,8 @@ def __init__(
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout
# Mock the local xfer handle normally registered by register_kv_caches,
# for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}

def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@@ -396,23 +398,44 @@ def _nixl_handshake(

assert expected_engine_id == self.REMOTE_ENGINE_ID

remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_size=remote_tp_size,
)
return {0: remote_agent_name}
# Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation.
remote_block_lens = list(self.block_len_per_layer)
tp_ratio = self.kv_topo.tp_ratio(remote_tp_size)
if remote_tp_size > self.world_size:
# P TP > D TP case: remote block_len is smaller.
remote_block_lens = [
block_len // (-tp_ratio) for block_len in remote_block_lens
]
elif remote_tp_size < self.world_size:
remote_block_lens = [
block_len * tp_ratio for block_len in remote_block_lens
]

# When remote tp_size > local tp_size, handshake with multiple
# remote ranks.
num_handshakes = 1 if tp_ratio > 0 else -tp_ratio
remote_agents: dict[int, str] = {}
for remote_tp_rank in range(num_handshakes):
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=remote_tp_rank,
num_blocks=1,
block_lens=remote_block_lens,
attn_backend_name=self.backend_name,
# `self.kv_cache_layout` is only forced to HND when the vLLM engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
)
remote_agents[remote_tp_rank] = remote_agent_name
return remote_agents
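
For readers tracking the arithmetic above: a positive `tp_ratio` means local TP >= remote TP, a negative one means remote TP > local TP. A minimal sketch of that convention, with `tp_ratio` here as a hypothetical stand-in for `kv_topo.tp_ratio` (the real implementation lives in the connector):

```python
# Hedged sketch, not the connector's actual code.
def tp_ratio(local_tp: int, remote_tp: int) -> int:
    if remote_tp <= local_tp:
        assert local_tp % remote_tp == 0
        return local_tp // remote_tp  # local is tp_ratio x more sharded
    assert remote_tp % local_tp == 0
    return -(remote_tp // local_tp)  # negative: remote is more sharded

def remote_block_len(local_block_len: int, ratio: int) -> int:
    # Remote blocks shrink when remote TP is larger, grow when it is smaller.
    return local_block_len * ratio if ratio > 0 else local_block_len // -ratio

assert tp_ratio(local_tp=1, remote_tp=2) == -2  # P TP > D TP
assert remote_block_len(4096, -2) == 2048       # remote block_len is smaller
```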


class TestNixlHandshake:
@@ -443,7 +466,14 @@ def test_multi_xfer_one_engine(
vllm_config, connector.engine_id, hand_shake_latency=0
)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
worker = connector.connector_worker
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
# Simulate a completed handshake: remote engine id -> {remote rank: handle}.
worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
}
worker.kv_cache_layout = "HND"
num_xfers = 4
while True:
# For the same request_id, initiate multiple xfers across different
Expand Down Expand Up @@ -520,6 +550,9 @@ def test_async_load_kv(
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id
)

metadata = NixlConnectorMetadata()
metadata.add_new_req(
request_id="id",
@@ -555,6 +588,171 @@ def test_async_load_kv(
return
raise TimeoutError("Took too long to complete async handshake.")

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
self, local_tp_size: int, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations.
"""

vllm_config = create_vllm_config()
vllm_config.parallel_config.tensor_parallel_size = local_tp_size

connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
worker = connector.connector_worker
# Match the worker's world size to the parametrized local TP size,
# mirroring how the MLA test below forces it.
worker.world_size = local_tp_size

# Minimal local registration params used by add_remote_agent
worker.slot_size_per_layer = [4096]
worker.block_len_per_layer = [4096 * worker.block_size]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]

def check_handshake(remote_tp_size: int):
tp_ratio = remote_tp_size // local_tp_size
assert set(remote_agents.keys()) == set(range(tp_ratio))

remote_engine_id = worker.REMOTE_ENGINE_ID
assert worker._tp_size[remote_engine_id] == remote_tp_size
assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
# Ensure src_xfer_handles_by_tp_ratio is populated with tp_ratio chunks.
assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
assert remote_engine_id in worker.dst_xfer_side_handles
assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set(
range(tp_ratio)
)

remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=2 * local_tp_size,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(2 * local_tp_size)

# NOTE flexibility: a second remote with a higher number of ranks is
# discovered. This is not a scenario we actively support right now, but
# the connector allows it.
worker.REMOTE_ENGINE_ID = "remote_engine_2"
remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=6,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(6)

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
self, local_tp_size: int, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations for an MLA model.
"""
vllm_config = create_vllm_config()
d_tp_size = 1
p_tp_size = 2

# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p0.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p0.engine_id, hand_shake_latency=0
)
conn_p1.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p1.engine_id, hand_shake_latency=0
)

# Force P world size to 2 for both workers and emulate distinct tp_ranks.
# Also enable MLA path so that expected_finished_count is updated.
for rank, worker in enumerate(
(conn_p0.connector_worker, conn_p1.connector_worker)
):
worker.world_size = p_tp_size
worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size}
worker.tp_rank = rank
worker.use_mla = True

req_id = "req-ep-dp2-p0"
now = time.perf_counter()
# Register a request on P that is waiting for consumers to read
# (both workers track it).
conn_p0.connector_worker._reqs_to_send[req_id] = now + 10.0
conn_p0.connector_worker._reqs_to_process.add(req_id)
conn_p1.connector_worker._reqs_to_send[req_id] = now + 10.0
conn_p1.connector_worker._reqs_to_process.add(req_id)

# Simulate a read notification coming from D with (tp=1, dp=2).
notif = f"{req_id}:{d_tp_size}".encode()
# D0-0->P0 notif
conn_p0.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
"agent": [notif]
} # type: ignore[method-assign]
conn_p1.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
"agent": [notif]
} # type: ignore[method-assign]

# Trigger notification processing via get_finished().
done_sending0, _ = conn_p0.get_finished(finished_req_ids=set())
done_sending1, _ = conn_p1.get_finished(finished_req_ids=set())
assert req_id in done_sending0 and req_id in done_sending1

# E2E aggregation: ensure the aggregated output marks the request
# as finished using the connector's expected_finished_count.
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

aggregator = KVOutputAggregator.from_connector(conn_p0, world_size=2)

out0 = ModelRunnerOutput(
req_ids=[req_id],
req_id_to_index={req_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=done_sending0,
finished_recving=None,
),
)
out1 = ModelRunnerOutput(
req_ids=[req_id],
req_id_to_index={req_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=done_sending1,
finished_recving=None,
),
)
aggregated = aggregator.aggregate([out0, out1], output_rank=0)
assert aggregated.kv_connector_output is not None
assert aggregated.kv_connector_output.finished_sending == {req_id}

# Producers cleaned up state for the finished request.
assert req_id not in conn_p0.connector_worker._reqs_to_send
assert req_id not in conn_p0.connector_worker._reqs_to_process
assert req_id not in conn_p1.connector_worker._reqs_to_send
assert req_id not in conn_p1.connector_worker._reqs_to_process
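
The aggregation above boils down to a counting rule: a request is only globally finished once `expected_finished_count` ranks report it. A loose sketch of that rule (`aggregate_finished` is a hypothetical simplification; the real logic lives in `KVOutputAggregator`):

```python
from collections import Counter

def aggregate_finished(per_rank: list[set[str]], expected: int) -> set[str]:
    # A request counts as finished once `expected` ranks have reported it.
    counts = Counter(req for finished in per_rank for req in finished)
    return {req for req, n in counts.items() if n >= expected}

# Both P ranks reported the request, so with expected=2 it survives:
assert aggregate_finished([{"req-0"}, {"req-0"}], expected=2) == {"req-0"}
```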

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
@@ -573,6 +771,9 @@ def test_concurrent_load_kv(
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id
)
metadata = NixlConnectorMetadata()
total_reqs = 5
for i in range(total_reqs):
@@ -660,7 +861,6 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
with pytest.raises(RuntimeError):
# mismatched layout is expected to fail
worker.add_remote_agent(meta, remote_tp_size=2)
with pytest.raises(AssertionError):
worker.add_remote_agent(meta, remote_tp_size=1)

@patch(
@@ -1298,8 +1498,11 @@ def test_shutdown_cleans_up_resources(dist_init):
patch.object(nixl_wrapper, "deregister_memory") as mock_dereg,
):
worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
worker.src_xfer_side_handle = 456
worker.dst_xfer_side_handles = {"engine1": 789}
# Mock the handles normally registered by register_kv_caches.
worker.src_xfer_handles_by_block_size = {worker.block_size: 455}
# P TP = 2 * D TP case: two chunked local handles are registered.
worker.src_xfer_handles_by_tp_ratio = {-2: [456, 457]}
worker.dst_xfer_side_handles = {"engine1": {0: 789}}
worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"]

@@ -1320,8 +1523,10 @@
mock_listener.join.assert_called_once()

mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
mock_rel_dlist.assert_any_call(456) # src handle
assert mock_rel_dlist.call_count == 4
mock_rel_dlist.assert_any_call(455) # src handle (whole region)
mock_rel_dlist.assert_any_call(456) # src handle (1st chunk)
mock_rel_dlist.assert_any_call(457) # src handle (2nd chunk)
mock_rel_dlist.assert_any_call(789) # dst handle
mock_rem_agent.assert_called_once_with("agent1")
assert mock_dereg.call_count == 2
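
The new call counts follow from releasing every source handle (the whole-region handle plus the per-tp-ratio chunks) and every destination handle. A rough sketch of the loop those assertions imply; the method name mirrors the test's mock (`mock_rel_dlist`), not a verified API:

```python
def release_all_handles(worker):
    # Whole-region source handles, keyed by block size (455).
    for handle in worker.src_xfer_handles_by_block_size.values():
        worker.nixl_wrapper.release_dlist_handle(handle)
    # Chunked source handles for heterogeneous TP, keyed by tp_ratio (456, 457).
    for handles in worker.src_xfer_handles_by_tp_ratio.values():
        for handle in handles:
            worker.nixl_wrapper.release_dlist_handle(handle)
    # Destination handles, keyed by remote engine id, then remote rank (789).
    for per_rank in worker.dst_xfer_side_handles.values():
        for handle in per_rank.values():
            worker.nixl_wrapper.release_dlist_handle(handle)
```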