@@ -557,6 +557,106 @@ def check_handshake(remote_tp_size: int):
557557 )
558558 check_handshake (6 )
559559
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
    self, local_tp_size: int, dist_init
):
    """
    Check read-notification handling when prefill (P) TP exceeds decode (D)
    TP with the MLA flag enabled on the workers.

    Two connector workers emulate the P side with TP=2. Each one tracks the
    same pending request and receives a notification tagged with the D TP
    size (1). The test then asserts that:

    * every P rank reports the request as finished sending,
    * the aggregated output over both ranks reports it exactly once,
    * every P rank dropped its bookkeeping for the request.

    NOTE(review): ``local_tp_size`` is parametrized but never read in the
    body, so both parametrized cases run the same scenario.
    """
    vllm_config = create_vllm_config()
    d_tp_size = 1
    p_tp_size = 2

    # One connector per emulated P rank; all connectors are built first and
    # only then get their worker swapped for the fake implementation.
    connectors = [
        NixlConnector(vllm_config, KVConnectorRole.WORKER)
        for _ in range(p_tp_size)
    ]
    for connector in connectors:
        connector.connector_worker = FakeNixlConnectorWorker(
            vllm_config, connector.engine_id, hand_shake_latency=0
        )
    workers = [connector.connector_worker for connector in connectors]

    # Force the P world size on every worker and give each a distinct
    # tp_rank. Also enable the MLA path so that expected_finished_count is
    # updated.
    for tp_rank, p_worker in enumerate(workers):
        p_worker.world_size = p_tp_size
        p_worker.kv_topo.tp_size = p_tp_size
        p_worker.tp_rank = tp_rank
        p_worker.use_mla = True

    req_id = "req-ep-dp2-p0"
    now = time.perf_counter()
    # Every P rank tracks the request as waiting for consumers to read it.
    for p_worker in workers:
        p_worker._reqs_to_send[req_id] = now + 10.0
        p_worker._reqs_to_process.add(req_id)

    # Fake the read notification sent by D; its payload carries the D TP
    # size after the request id.
    notif = f"{req_id}:{d_tp_size}".encode()
    for p_worker in workers:
        p_worker.nixl_wrapper.get_new_notifs = lambda: {
            "agent": [notif]
        }  # type: ignore[method-assign]

    # Notification processing is driven through get_finished().
    finished_sending = []
    for connector in connectors:
        sent, _ = connector.get_finished(finished_req_ids=set())
        finished_sending.append(sent)
    assert all(req_id in sent for sent in finished_sending)

    # E2E aggregation: the aggregated output must mark the request as
    # finished using the connector's expected_finished_count.
    from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

    aggregator = KVOutputAggregator.from_connector(
        connectors[0], world_size=p_tp_size
    )

    rank_outputs = [
        ModelRunnerOutput(
            req_ids=[req_id],
            req_id_to_index={req_id: 0},
            sampled_token_ids=[[0]],
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[None],
            kv_connector_output=KVConnectorOutput(
                finished_sending=sent,
                finished_recving=None,
            ),
        )
        for sent in finished_sending
    ]
    aggregated = aggregator.aggregate(rank_outputs, output_rank=0)
    assert aggregated.kv_connector_output is not None
    assert aggregated.kv_connector_output.finished_sending == {req_id}

    # Producers cleaned up state for the finished request.
    for p_worker in workers:
        assert req_id not in p_worker._reqs_to_send
        assert req_id not in p_worker._reqs_to_process
659+
560660 @patch (
561661 "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper" ,
562662 FakeNixlWrapper ,
0 commit comments