@@ -380,6 +380,8 @@ def __init__(
380380 super ().__init__ (* args , ** kwargs )
381381 self ._hand_shake_latency = hand_shake_latency
382382 self .kv_cache_layout = kv_cache_layout
383+ # Mock register_kv_caches attribute needed for tests that do not call it.
384+ self .src_xfer_handles_by_block_size = {self .block_size : 1 }
383385
384386 def _nixl_handshake (
385387 self , host : str , port : int , remote_tp_size : int , expected_engine_id : str
@@ -470,6 +472,7 @@ def test_multi_xfer_one_engine(
470472 worker .dst_xfer_side_handles = {
471473 FakeNixlConnectorWorker .REMOTE_ENGINE_ID : {0 : 1 }
472474 }
475+ # worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
473476 worker .kv_cache_layout = "HND"
474477 num_xfers = 4
475478 while True :
@@ -547,6 +550,9 @@ def test_async_load_kv(
547550 connector .connector_worker = FakeNixlConnectorWorker (
548551 vllm_config , connector .engine_id
549552 )
553+ # worker = connector.connector_worker
554+ # worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
555+
550556 metadata = NixlConnectorMetadata ()
551557 metadata .add_new_req (
552558 request_id = "id" ,
@@ -619,9 +625,9 @@ def check_handshake(remote_tp_size: int):
619625 remote_engine_id = worker .REMOTE_ENGINE_ID
620626 assert worker ._tp_size [remote_engine_id ] == remote_tp_size
621627 assert - tp_ratio == worker .kv_topo .tp_ratio_from_engine_id (remote_engine_id )
622- # ensure src_xfer_side_chunked_handles is populated with tpratio chunks
623- assert - tp_ratio in worker .src_xfer_side_chunked_handles
624- assert len (worker .src_xfer_side_chunked_handles [- tp_ratio ]) == tp_ratio
628+ # ensure src_xfer_handles_by_tp_ratio is populated with tp_ratio chunks
629+ assert - tp_ratio in worker .src_xfer_handles_by_tp_ratio
630+ assert len (worker .src_xfer_handles_by_tp_ratio [- tp_ratio ]) == tp_ratio
625631 assert remote_engine_id in worker .dst_xfer_side_handles
626632 assert set (worker .dst_xfer_side_handles [remote_engine_id ].keys ()) == set (
627633 range (tp_ratio )
@@ -679,7 +685,7 @@ def test_prefill_tp_size_greater_than_decode_tp_size_mla(
679685 (conn_p0 .connector_worker , conn_p1 .connector_worker )
680686 ):
681687 worker .world_size = p_tp_size
682- worker .kv_topo .tp_size = p_tp_size
688+ worker .kv_topo .remote_tp_size = { worker . engine_id : p_tp_size }
683689 worker .tp_rank = rank
684690 worker .use_mla = True
685691
@@ -765,6 +771,9 @@ def test_concurrent_load_kv(
765771 connector .connector_worker = FakeNixlConnectorWorker (
766772 vllm_config , connector .engine_id
767773 )
774+ # Register (mocked) local xfer handler
775+ # worker = connector.connector_worker
776+ # worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
768777 metadata = NixlConnectorMetadata ()
769778 total_reqs = 5
770779 for i in range (total_reqs ):
@@ -1489,8 +1498,10 @@ def test_shutdown_cleans_up_resources(dist_init):
14891498 patch .object (nixl_wrapper , "deregister_memory" ) as mock_dereg ,
14901499 ):
14911500 worker ._recving_transfers = {"req1" : [(123 , time .perf_counter ())]}
1492- worker .src_xfer_side_handle = 456
1493- worker .src_xfer_side_chunked_handles = {- 2 : [456 ]}
1501+ # Mock register_kv_caches, which registers the local handle
1502+ worker .src_xfer_handles_by_block_size = {worker .block_size : 455 }
1503+ # P TP = 2 * D TP case, we should register 2 local handles
1504+ worker .src_xfer_handles_by_tp_ratio = {- 2 : [456 , 457 ]}
14941505 worker .dst_xfer_side_handles = {"engine1" : {0 : 789 }}
14951506 worker ._remote_agents = {"engine1" : {0 : "agent1" }}
14961507 worker ._registered_descs = ["desc1" , "desc2" ]
@@ -1512,8 +1523,10 @@ def test_shutdown_cleans_up_resources(dist_init):
15121523 mock_listener .join .assert_called_once ()
15131524
15141525 mock_rel_xfer .assert_called_once_with (123 )
1515- assert mock_rel_dlist .call_count == 3
1516- mock_rel_dlist .assert_any_call (456 ) # src handle
1526+ assert mock_rel_dlist .call_count == 4
1527+ mock_rel_dlist .assert_any_call (455 ) # src handle (whole region)
1528+ mock_rel_dlist .assert_any_call (456 ) # src handle (1st chunk)
1529+ mock_rel_dlist .assert_any_call (457 ) # src handle (2nd chunk)
15171530 mock_rel_dlist .assert_any_call (789 ) # dst handle
15181531 mock_rem_agent .assert_called_once_with ("agent1" )
15191532 assert mock_dereg .call_count == 2
0 commit comments