Skip to content

Commit 3899d23

Browse files
committed
fix tests
Signed-off-by: NickLucche <[email protected]>
1 parent 3e39b57 commit 3899d23

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,8 @@ def __init__(
380380
super().__init__(*args, **kwargs)
381381
self._hand_shake_latency = hand_shake_latency
382382
self.kv_cache_layout = kv_cache_layout
383+
# Mock register_kv_caches attribute needed for tests that do not call it.
384+
self.src_xfer_handles_by_block_size = {self.block_size: 1}
383385

384386
def _nixl_handshake(
385387
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@@ -470,6 +472,7 @@ def test_multi_xfer_one_engine(
470472
worker.dst_xfer_side_handles = {
471473
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
472474
}
475+
# worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
473476
worker.kv_cache_layout = "HND"
474477
num_xfers = 4
475478
while True:
@@ -547,6 +550,9 @@ def test_async_load_kv(
547550
connector.connector_worker = FakeNixlConnectorWorker(
548551
vllm_config, connector.engine_id
549552
)
553+
# worker = connector.connector_worker
554+
# worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
555+
550556
metadata = NixlConnectorMetadata()
551557
metadata.add_new_req(
552558
request_id="id",
@@ -619,9 +625,9 @@ def check_handshake(remote_tp_size: int):
619625
remote_engine_id = worker.REMOTE_ENGINE_ID
620626
assert worker._tp_size[remote_engine_id] == remote_tp_size
621627
assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
622-
# ensure src_xfer_side_chunked_handles is populated with tpratio chunks
623-
assert -tp_ratio in worker.src_xfer_side_chunked_handles
624-
assert len(worker.src_xfer_side_chunked_handles[-tp_ratio]) == tp_ratio
628+
# ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks
629+
assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
630+
assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
625631
assert remote_engine_id in worker.dst_xfer_side_handles
626632
assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set(
627633
range(tp_ratio)
@@ -679,7 +685,7 @@ def test_prefill_tp_size_greater_than_decode_tp_size_mla(
679685
(conn_p0.connector_worker, conn_p1.connector_worker)
680686
):
681687
worker.world_size = p_tp_size
682-
worker.kv_topo.tp_size = p_tp_size
688+
worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size}
683689
worker.tp_rank = rank
684690
worker.use_mla = True
685691

@@ -765,6 +771,9 @@ def test_concurrent_load_kv(
765771
connector.connector_worker = FakeNixlConnectorWorker(
766772
vllm_config, connector.engine_id
767773
)
774+
# Register (mocked) local xfer handler
775+
# worker = connector.connector_worker
776+
# worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
768777
metadata = NixlConnectorMetadata()
769778
total_reqs = 5
770779
for i in range(total_reqs):
@@ -1489,8 +1498,10 @@ def test_shutdown_cleans_up_resources(dist_init):
14891498
patch.object(nixl_wrapper, "deregister_memory") as mock_dereg,
14901499
):
14911500
worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
1492-
worker.src_xfer_side_handle = 456
1493-
worker.src_xfer_side_chunked_handles = {-2: [456]}
1501+
# Mock register_kv_cache which registers local handle
1502+
worker.src_xfer_handles_by_block_size = {worker.block_size: 455}
1503+
# P TP = 2 * D TP case, we should register 2 local handles
1504+
worker.src_xfer_handles_by_tp_ratio = {-2: [456, 457]}
14941505
worker.dst_xfer_side_handles = {"engine1": {0: 789}}
14951506
worker._remote_agents = {"engine1": {0: "agent1"}}
14961507
worker._registered_descs = ["desc1", "desc2"]
@@ -1512,8 +1523,10 @@ def test_shutdown_cleans_up_resources(dist_init):
15121523
mock_listener.join.assert_called_once()
15131524

15141525
mock_rel_xfer.assert_called_once_with(123)
1515-
assert mock_rel_dlist.call_count == 3
1516-
mock_rel_dlist.assert_any_call(456) # src handle
1526+
assert mock_rel_dlist.call_count == 4
1527+
mock_rel_dlist.assert_any_call(455) # src handle (whole region)
1528+
mock_rel_dlist.assert_any_call(456) # src handle (1st chunk)
1529+
mock_rel_dlist.assert_any_call(457) # src handle (2nd chunk)
15171530
mock_rel_dlist.assert_any_call(789) # dst handle
15181531
mock_rem_agent.assert_called_once_with("agent1")
15191532
assert mock_dereg.call_count == 2

0 commit comments

Comments
 (0)