Skip to content

Commit 78ef532

Browse files
committed
more MLA tests
Signed-off-by: NickLucche <[email protected]>
1 parent dff9e35 commit 78ef532

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

tests/out_prefill

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
error: No justfile found

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,106 @@ def check_handshake(remote_tp_size: int):
557557
)
558558
check_handshake(6)
559559

560+
@patch(
561+
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
562+
FakeNixlWrapper,
563+
)
564+
@pytest.mark.parametrize("local_tp_size", [1, 2])
565+
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
566+
self, local_tp_size: int, dist_init
567+
):
568+
"""
569+
Verify that with prefill TP > decode TP on an MLA model, every P rank
570+
reports the request as done sending once D's read notification arrives.
571+
"""
572+
vllm_config = create_vllm_config()
573+
d_tp_size = 1
574+
p_tp_size = 2
575+
576+
# Build two separate connectors/workers to emulate P TP=2 ranks.
577+
conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
578+
conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
579+
conn_p0.connector_worker = FakeNixlConnectorWorker(
580+
vllm_config, conn_p0.engine_id, hand_shake_latency=0
581+
)
582+
conn_p1.connector_worker = FakeNixlConnectorWorker(
583+
vllm_config, conn_p1.engine_id, hand_shake_latency=0
584+
)
585+
586+
# Force P world size to 2 for both workers and emulate distinct tp_ranks.
587+
# Also enable MLA path so that expected_finished_count is updated.
588+
for rank, worker in enumerate(
589+
(conn_p0.connector_worker, conn_p1.connector_worker)
590+
):
591+
worker.world_size = p_tp_size
592+
worker.kv_topo.tp_size = p_tp_size
593+
worker.tp_rank = rank
594+
worker.use_mla = True
595+
596+
req_id = "req-ep-dp2-p0"
597+
now = time.perf_counter()
598+
# Register a request on P that is waiting for consumers to read
599+
# (both workers track it).
600+
conn_p0.connector_worker._reqs_to_send[req_id] = now + 10.0
601+
conn_p0.connector_worker._reqs_to_process.add(req_id)
602+
conn_p1.connector_worker._reqs_to_send[req_id] = now + 10.0
603+
conn_p1.connector_worker._reqs_to_process.add(req_id)
604+
605+
# Simulate a read notification coming from D with (tp=1, dp=2).
606+
notif = f"{req_id}:{d_tp_size}".encode()
607+
# D0-0->P0 and D0-0->P1 notifs (the same payload is delivered to both P ranks)
608+
conn_p0.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
609+
"agent": [notif]
610+
} # type: ignore[method-assign]
611+
conn_p1.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
612+
"agent": [notif]
613+
} # type: ignore[method-assign]
614+
615+
# Trigger notification processing via get_finished().
616+
done_sending0, _ = conn_p0.get_finished(finished_req_ids=set())
617+
done_sending1, _ = conn_p1.get_finished(finished_req_ids=set())
618+
assert req_id in done_sending0 and req_id in done_sending1
619+
620+
# E2E aggregation: ensure the aggregated output marks the request
621+
# as finished using the connector's expected_finished_count.
622+
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
623+
624+
aggregator = KVOutputAggregator.from_connector(conn_p0, world_size=2)
625+
626+
out0 = ModelRunnerOutput(
627+
req_ids=[req_id],
628+
req_id_to_index={req_id: 0},
629+
sampled_token_ids=[[0]],
630+
logprobs=None,
631+
prompt_logprobs_dict={},
632+
pooler_output=[None],
633+
kv_connector_output=KVConnectorOutput(
634+
finished_sending=done_sending0,
635+
finished_recving=None,
636+
),
637+
)
638+
out1 = ModelRunnerOutput(
639+
req_ids=[req_id],
640+
req_id_to_index={req_id: 0},
641+
sampled_token_ids=[[0]],
642+
logprobs=None,
643+
prompt_logprobs_dict={},
644+
pooler_output=[None],
645+
kv_connector_output=KVConnectorOutput(
646+
finished_sending=done_sending1,
647+
finished_recving=None,
648+
),
649+
)
650+
aggregated = aggregator.aggregate([out0, out1], output_rank=0)
651+
assert aggregated.kv_connector_output is not None
652+
assert aggregated.kv_connector_output.finished_sending == {req_id}
653+
654+
# Producers cleaned up state for the finished request.
655+
assert req_id not in conn_p0.connector_worker._reqs_to_send
656+
assert req_id not in conn_p0.connector_worker._reqs_to_process
657+
assert req_id not in conn_p1.connector_worker._reqs_to_send
658+
assert req_id not in conn_p1.connector_worker._reqs_to_process
659+
560660
@patch(
561661
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
562662
FakeNixlWrapper,

0 commit comments

Comments
 (0)