Skip to content

Commit df88a2e

Browse files
[P/D] mooncake_connector adapted to vLLM 0.10.1 (#2664)
### What this PR does / why we need it? In vLLM version 0.10.1, a new KVOutputAggregator was added to the executor, moving aggregation into the executor (vllm-project/vllm#19555). This broke mooncake_connector. This change fixes that bug and also adds a policy to forcibly release the KV cache when the prefill node times out. This PR is currently linked to a PR in vLLM (vllm-project/vllm#23917), which aims to modify the finished and sent count confirmation in heterogeneous TP configurations. Many UTs are deleted because a lot of the communication code has been removed, so the unit tests as a whole are more concise. - vLLM version: v0.10.1.1 - vLLM main: vllm-project/vllm@fa4311d --------- Signed-off-by: baxingpiaochong <[email protected]>
1 parent 07d44ad commit df88a2e

File tree

3 files changed

+133
-322
lines changed

3 files changed

+133
-322
lines changed

tests/ut/kv_connector/test_mooncake_connector.py

Lines changed: 43 additions & 248 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import msgspec
1313
import zmq
1414
from vllm.utils import make_zmq_path
15-
from zmq import Context # type: ignore
1615

1716
fake_engine = types.ModuleType("mooncake.engine")
1817
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
@@ -32,193 +31,17 @@
3231
class TestKVCacheTaskTrackerInit(unittest.TestCase):
3332

3433
def test_init_basic_properties(self):
35-
tracker = KVCacheTaskTracker(tp_rank=1,
36-
local_engine_id="engine1",
37-
target_count=10)
38-
self.assertEqual(tracker.tp_rank, 1)
39-
self.assertEqual(tracker.local_engine_id, "engine1")
40-
self.assertEqual(tracker.target_count, 10)
34+
tracker = KVCacheTaskTracker()
4135
self.assertIsInstance(tracker.done_task_lock, type(threading.Lock()))
42-
self.assertIsInstance(tracker.done_task_counts, defaultdict)
4336
self.assertIsInstance(tracker.finished_requests, set)
44-
45-
def test_socket_path_generation(self):
46-
tracker = KVCacheTaskTracker(tp_rank=1,
47-
local_engine_id="engine42",
48-
target_count=1)
49-
self.assertEqual(tracker.socket_path,
50-
"ipc:///tmp/vllm_mooncake_connector_engine42.ipc")
51-
52-
@patch("vllm_ascend.distributed.mooncake_connector.threading.Thread")
53-
def test_tp_rank_zero_initialization(self, mock_thread):
54-
tracker = KVCacheTaskTracker(tp_rank=0,
55-
local_engine_id="test",
56-
target_count=1)
57-
mock_thread.assert_called_once_with(
58-
target=tracker._listen_for_completion_signals,
59-
daemon=True,
60-
name="KVCacheTaskTrackerListenerThread")
61-
mock_thread.return_value.start.assert_called_once()
62-
self.assertIsNone(tracker.socket)
63-
self.assertTrue(tracker.listener.daemon)
64-
65-
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket")
66-
@patch("vllm_ascend.distributed.mooncake_connector.logger")
67-
def test_tp_rank_non_zero_initialization(self, mock_logger,
68-
mock_make_zmq_socket):
69-
mock_socket = MagicMock()
70-
mock_make_zmq_socket.return_value = mock_socket
71-
tracker = KVCacheTaskTracker(tp_rank=1,
72-
local_engine_id="test",
73-
target_count=1)
74-
mock_make_zmq_socket.assert_called_once_with(
75-
ctx=unittest.mock.ANY,
76-
path="ipc:///tmp/vllm_mooncake_connector_test.ipc",
77-
socket_type=zmq.PUSH, # type: ignore
78-
bind=False)
79-
mock_logger.info.assert_called_once_with(
80-
"Connecting to transfer socket at %s",
81-
"ipc:///tmp/vllm_mooncake_connector_test.ipc")
82-
self.assertIsNone(tracker.listener)
83-
self.assertEqual(tracker.socket, mock_socket)
84-
85-
86-
class TestKVCacheTaskTrackerListenMethod(unittest.TestCase):
87-
88-
def setUp(self):
89-
self.tp_rank = 0
90-
self.local_engine_id = "test_engine_ut"
91-
self.target_count = 3
92-
self.tracker = KVCacheTaskTracker(self.tp_rank, self.local_engine_id,
93-
self.target_count)
94-
self.original_listen = self.tracker._listen_for_completion_signals
95-
96-
def tearDown(self):
97-
self.tracker._listen_for_completion_signals = self.original_listen
98-
Context.instance().term()
99-
time.sleep(0.1)
100-
101-
def test_normal_message_processing(self):
102-
listener_thread = threading.Thread(
103-
target=self.tracker._listen_for_completion_signals, daemon=True)
104-
listener_thread.start()
105-
time.sleep(0.2)
106-
test_messages = [("request_001", 1), ("request_001", 2),
107-
("request_002", 0), ("request_003", 1)]
108-
ctx = Context()
109-
sender_socket = ctx.socket(zmq.PUSH) # type: ignore
110-
sender_socket.connect(self.tracker.socket_path)
111-
for msg in test_messages:
112-
sender_socket.send_pyobj(msg)
113-
time.sleep(0.05)
114-
sender_socket.close()
115-
time.sleep(0.2)
116-
117-
with self.tracker.done_task_lock:
118-
self.assertEqual(len(self.tracker.done_task_counts["request_001"]),
119-
2)
120-
self.assertIn(1, self.tracker.done_task_counts["request_001"])
121-
self.assertIn(2, self.tracker.done_task_counts["request_001"])
122-
self.assertEqual(len(self.tracker.done_task_counts["request_002"]),
123-
1)
124-
self.assertIn(0, self.tracker.done_task_counts["request_002"])
125-
self.assertEqual(len(self.tracker.done_task_counts["request_003"]),
126-
1)
127-
self.assertIn(1, self.tracker.done_task_counts["request_003"])
128-
129-
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket",
130-
autospec=True)
131-
def test_listen_with_timeout(self, mock_make_socket):
132-
mock_socket = MagicMock()
133-
134-
def mock_recv():
135-
start = time.time()
136-
while time.time() - start < 0.5:
137-
time.sleep(0.01)
138-
return ("req1", 0)
139-
140-
mock_socket.recv_pyobj = mock_recv
141-
mock_make_socket.return_value = mock_socket
142-
143-
test_thread = threading.Thread(
144-
target=self.tracker._listen_for_completion_signals, daemon=True)
145-
test_thread.start()
146-
test_thread.join(timeout=1.0)
147-
mock_make_socket.assert_called_once()
148-
149-
150-
class TestKVCacheTaskTrackerTP(unittest.TestCase):
151-
152-
def setUp(self):
153-
self.local_engine_id = "test_engine"
154-
self.target_count = 3
155-
156-
def test_update_done_task_count_tp_rank_0(self):
157-
tracker = KVCacheTaskTracker(tp_rank=0,
158-
local_engine_id=self.local_engine_id,
159-
target_count=self.target_count)
160-
test_request_id = "test_req_001"
161-
test_tp_rank = 1
162-
tracker.update_done_task_count(test_request_id, test_tp_rank)
163-
with tracker.done_task_lock:
164-
self.assertEqual(len(tracker.done_task_counts[test_request_id]), 1)
165-
self.assertIn(test_tp_rank,
166-
tracker.done_task_counts[test_request_id])
167-
168-
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket",
169-
autospec=True)
170-
def test_update_done_task_count_non_zero_tp(self, mock_make_socket):
171-
mock_socket = MagicMock()
172-
mock_make_socket.return_value = mock_socket
173-
tracker = KVCacheTaskTracker(tp_rank=1,
174-
local_engine_id=self.local_engine_id,
175-
target_count=self.target_count)
176-
test_request_id = "test_req_002"
177-
test_tp_rank = 1
178-
tracker.update_done_task_count(test_request_id, test_tp_rank)
179-
mock_socket.send_pyobj.assert_called_once_with(
180-
(test_request_id, test_tp_rank))
181-
with tracker.done_task_lock:
182-
self.assertNotIn(test_request_id, tracker.done_task_counts)
183-
184-
@patch("vllm_ascend.distributed.mooncake_connector.logger", autospec=True)
185-
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket",
186-
autospec=True)
187-
def test_update_done_task_count_logging(self, mock_make_socket,
188-
mock_logger):
189-
mock_socket = MagicMock()
190-
mock_make_socket.return_value = mock_socket
191-
tracker = KVCacheTaskTracker(tp_rank=2,
192-
local_engine_id=self.local_engine_id,
193-
target_count=self.target_count)
194-
test_request_id = "test_req_003"
195-
tracker.update_done_task_count(test_request_id, 2)
196-
mock_logger.debug.assert_called_once_with(
197-
"Sent done signal for request %s to tp 0", test_request_id)
198-
199-
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket",
200-
autospec=True)
201-
def test_update_multiple_calls(self, mock_make_socket):
202-
mock_socket = MagicMock()
203-
mock_make_socket.return_value = mock_socket
204-
tracker = KVCacheTaskTracker(tp_rank=1,
205-
local_engine_id=self.local_engine_id,
206-
target_count=self.target_count)
207-
test_data = [("req1", 1), ("req1", 1), ("req2", 1)]
208-
for req_id, rank in test_data:
209-
tracker.update_done_task_count(req_id, rank)
210-
self.assertEqual(mock_socket.send_pyobj.call_count, 3)
211-
mock_socket.send_pyobj.assert_called_with(("req2", 1))
37+
self.assertIsInstance(tracker.delayed_free_requests, deque)
21238

21339

21440
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
21541

21642
def setUp(self):
217-
self.tracker = KVCacheTaskTracker(tp_rank=0,
218-
local_engine_id="test",
219-
target_count=3)
43+
self.tracker = KVCacheTaskTracker()
22044
self.tracker.finished_requests = set()
221-
self.tracker.done_task_counts = defaultdict(set)
22245
self.tracker.done_task_lock = threading.Lock()
22346

22447
def test_empty_requests(self):
@@ -251,14 +74,6 @@ def test_concurrent_access(self, mock_logger):
25174
self.assertEqual(sum(1 for r in results if r), 1)
25275
self.assertEqual(len(self.tracker.finished_requests), 0)
25376

254-
def test_after_increment(self):
255-
self.tracker._increment_task_count("req_123", 0)
256-
self.tracker._increment_task_count("req_123", 1)
257-
self.tracker._increment_task_count("req_123", 2)
258-
result = self.tracker.get_and_clear_finished_requests()
259-
self.assertEqual(result, {"req_123"})
260-
self.assertEqual(self.tracker.get_and_clear_finished_requests(), set())
261-
26277

26378
class TestKVCacheSendingThreadInit(unittest.TestCase):
26479

@@ -282,47 +97,6 @@ def tearDown(self):
28297
if hasattr(thread, 'is_alive') and thread.is_alive():
28398
thread.join(timeout=0.1)
28499

285-
@patch('vllm_ascend.distributed.mooncake_connector.KVCacheTaskTracker')
286-
def test_initialization_basic(self, mock_tracker):
287-
thread = KVCacheSendingThread(**self.common_args)
288-
self.threads.append(thread)
289-
self.assertEqual(thread.tp_rank, 1)
290-
self.assertEqual(thread.decode_tp_size, 4)
291-
self.assertEqual(thread.local_engine_id, 'engine_1')
292-
mock_tracker.assert_called_once()
293-
args = mock_tracker.call_args[0]
294-
kwargs = mock_tracker.call_args[1]
295-
if args:
296-
self.assertEqual(args[0], 1)
297-
self.assertEqual(args[1], 'engine_1')
298-
self.assertEqual(args[2], 4)
299-
else:
300-
self.assertEqual(kwargs['tp_rank'], 1)
301-
self.assertEqual(kwargs['local_engine_id'], 'engine_1')
302-
self.assertEqual(kwargs['target_count'], 4)
303-
304-
@patch('vllm_ascend.distributed.mooncake_connector.KVCacheTaskTracker')
305-
def test_task_tracker_initialization(self, mock_tracker):
306-
args = self.common_args.copy()
307-
args.update({
308-
'tp_rank': 2,
309-
'decode_tp_size': 8,
310-
'local_engine_id': 'engine_2'
311-
})
312-
thread = KVCacheSendingThread(**args)
313-
self.threads.append(thread)
314-
mock_tracker.assert_called_once()
315-
call_args = mock_tracker.call_args[0]
316-
call_kwargs = mock_tracker.call_args[1]
317-
if call_args:
318-
self.assertEqual(call_args[0], 2)
319-
self.assertEqual(call_args[1], 'engine_2')
320-
self.assertEqual(call_args[2], 8)
321-
else:
322-
self.assertEqual(call_kwargs['tp_rank'], 2)
323-
self.assertEqual(call_kwargs['local_engine_id'], 'engine_2')
324-
self.assertEqual(call_kwargs['target_count'], 8)
325-
326100
def test_thread_daemon_property(self):
327101
thread = KVCacheSendingThread(**self.common_args)
328102
self.threads.append(thread)
@@ -542,7 +316,7 @@ def test_handle_request(self, mock_send, mock_transfer):
542316
mock_transfer.assert_called_once_with(self.test_req)
543317
mock_send.assert_called_once_with("req1", "localhost", 6666)
544318
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
545-
"req1", self.thread.tp_rank)
319+
"req1")
546320
self.mock_queue.task_done.assert_called_once()
547321

548322
@patch.object(KVCacheRecvingThread, '_get_remote_metadata')
@@ -675,9 +449,11 @@ def test_run_loop_normal(self, mock_handle):
675449
class MockVllmConfig:
676450

677451
def __init__(self):
452+
self.model_config = MagicMock()
678453
self.parallel_config = MagicMock()
679454
self.cache_config = MagicMock()
680455
self.kv_transfer_config = MagicMock()
456+
self.model_config.use_mla = True
681457
self.parallel_config.tensor_parallel_size = 2
682458
self.parallel_config.data_parallel_rank_local = 0
683459
self.parallel_config.data_parallel_size_local = 1
@@ -714,28 +490,40 @@ def __init__(self,
714490
class TestKVCacheTaskTracker(unittest.TestCase):
715491

716492
def setUp(self):
717-
self.tracker = KVCacheTaskTracker(tp_rank=0,
718-
local_engine_id="test_engine",
719-
target_count=2)
493+
self.tracker = KVCacheTaskTracker()
720494

721-
def test_update_task_count(self):
722-
self.assertEqual(len(self.tracker.done_task_counts), 0)
723-
self.assertEqual(len(self.tracker.finished_requests), 0)
724-
725-
self.tracker.update_done_task_count("req1", 0)
726-
self.tracker.update_done_task_count("req1", 1)
727-
728-
self.assertEqual(len(self.tracker.finished_requests), 1)
729-
self.assertTrue("req1" in self.tracker.finished_requests)
730-
731-
finished = self.tracker.get_and_clear_finished_requests()
732-
self.assertEqual(finished, {"req1"})
495+
def test_update_done_task_count(self):
733496
self.assertEqual(len(self.tracker.finished_requests), 0)
497+
self.assertEqual(len(self.tracker.delayed_free_requests), 0)
498+
499+
current_time = time.time()
500+
self.tracker.add_delayed_request("req_1", current_time)
501+
result = self.tracker.delayed_free_requests
502+
self.assertEqual(len(result), 1)
503+
self.assertEqual(result[0], ("req_1", current_time))
504+
505+
self.tracker.update_done_task_count("req_1")
506+
result_finished = self.tracker.finished_requests
507+
result_delayed = self.tracker.delayed_free_requests
508+
self.assertEqual(result_finished, {"req_1"})
509+
self.assertEqual(len(result_delayed), 0)
510+
511+
def test_retrieve_expired_requests(self):
512+
current_time = time.time()
513+
self.tracker.add_delayed_request("req_1", current_time - 600)
514+
self.tracker.add_delayed_request("req_2", current_time)
515+
result = self.tracker._retrieve_expired_requests()
516+
self.assertEqual(result, {
517+
"req_1",
518+
})
519+
result_delay = self.tracker.delayed_free_requests
520+
self.assertEqual(len(result_delay), 1)
521+
self.assertEqual(result_delay[0], ("req_2", current_time))
734522

735523
def test_duplicate_task_update(self):
736-
self.tracker.update_done_task_count("req1", 0)
737-
self.tracker.update_done_task_count("req1", 0)
738-
self.tracker.update_done_task_count("req1", 1)
524+
self.tracker.update_done_task_count("req1")
525+
self.tracker.update_done_task_count("req1")
526+
self.tracker.update_done_task_count("req1")
739527

740528
finished = self.tracker.get_and_clear_finished_requests()
741529
self.assertEqual(finished, {"req1"})
@@ -745,6 +533,9 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
745533

746534
def test_add_new_req(self):
747535
meta = MooncakeConnectorMetadata()
536+
self.assertEqual(len(meta.requests), 0)
537+
self.assertEqual(len(meta.requests_to_send), 0)
538+
748539
meta.add_new_req(request_id="req1",
749540
local_block_ids=[1, 2, 3],
750541
kv_transfer_params={
@@ -802,6 +593,10 @@ def test_build_connector_meta(self):
802593
self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3])
803594
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
804595

596+
def test_get_finished_count(self):
597+
count = self.scheduler.get_finished_count()
598+
self.assertEqual(count, 2)
599+
805600

806601
class TestHelperFunctions(unittest.TestCase):
807602

0 commit comments

Comments
 (0)