1212import msgspec
1313import zmq
1414from vllm .utils import make_zmq_path
15- from zmq import Context # type: ignore
1615
1716fake_engine = types .ModuleType ("mooncake.engine" )
1817fake_engine .TransferEngine = MagicMock () # type: ignore[attr-defined]
3231class TestKVCacheTaskTrackerInit (unittest .TestCase ):
3332
3433 def test_init_basic_properties (self ):
35- tracker = KVCacheTaskTracker (tp_rank = 1 ,
36- local_engine_id = "engine1" ,
37- target_count = 10 )
38- self .assertEqual (tracker .tp_rank , 1 )
39- self .assertEqual (tracker .local_engine_id , "engine1" )
40- self .assertEqual (tracker .target_count , 10 )
34+ tracker = KVCacheTaskTracker ()
4135 self .assertIsInstance (tracker .done_task_lock , type (threading .Lock ()))
42- self .assertIsInstance (tracker .done_task_counts , defaultdict )
4336 self .assertIsInstance (tracker .finished_requests , set )
44-
45- def test_socket_path_generation (self ):
46- tracker = KVCacheTaskTracker (tp_rank = 1 ,
47- local_engine_id = "engine42" ,
48- target_count = 1 )
49- self .assertEqual (tracker .socket_path ,
50- "ipc:///tmp/vllm_mooncake_connector_engine42.ipc" )
51-
52- @patch ("vllm_ascend.distributed.mooncake_connector.threading.Thread" )
53- def test_tp_rank_zero_initialization (self , mock_thread ):
54- tracker = KVCacheTaskTracker (tp_rank = 0 ,
55- local_engine_id = "test" ,
56- target_count = 1 )
57- mock_thread .assert_called_once_with (
58- target = tracker ._listen_for_completion_signals ,
59- daemon = True ,
60- name = "KVCacheTaskTrackerListenerThread" )
61- mock_thread .return_value .start .assert_called_once ()
62- self .assertIsNone (tracker .socket )
63- self .assertTrue (tracker .listener .daemon )
64-
65- @patch ("vllm_ascend.distributed.mooncake_connector.make_zmq_socket" )
66- @patch ("vllm_ascend.distributed.mooncake_connector.logger" )
67- def test_tp_rank_non_zero_initialization (self , mock_logger ,
68- mock_make_zmq_socket ):
69- mock_socket = MagicMock ()
70- mock_make_zmq_socket .return_value = mock_socket
71- tracker = KVCacheTaskTracker (tp_rank = 1 ,
72- local_engine_id = "test" ,
73- target_count = 1 )
74- mock_make_zmq_socket .assert_called_once_with (
75- ctx = unittest .mock .ANY ,
76- path = "ipc:///tmp/vllm_mooncake_connector_test.ipc" ,
77- socket_type = zmq .PUSH , # type: ignore
78- bind = False )
79- mock_logger .info .assert_called_once_with (
80- "Connecting to transfer socket at %s" ,
81- "ipc:///tmp/vllm_mooncake_connector_test.ipc" )
82- self .assertIsNone (tracker .listener )
83- self .assertEqual (tracker .socket , mock_socket )
84-
85-
86- class TestKVCacheTaskTrackerListenMethod (unittest .TestCase ):
87-
88- def setUp (self ):
89- self .tp_rank = 0
90- self .local_engine_id = "test_engine_ut"
91- self .target_count = 3
92- self .tracker = KVCacheTaskTracker (self .tp_rank , self .local_engine_id ,
93- self .target_count )
94- self .original_listen = self .tracker ._listen_for_completion_signals
95-
96- def tearDown (self ):
97- self .tracker ._listen_for_completion_signals = self .original_listen
98- Context .instance ().term ()
99- time .sleep (0.1 )
100-
101- def test_normal_message_processing (self ):
102- listener_thread = threading .Thread (
103- target = self .tracker ._listen_for_completion_signals , daemon = True )
104- listener_thread .start ()
105- time .sleep (0.2 )
106- test_messages = [("request_001" , 1 ), ("request_001" , 2 ),
107- ("request_002" , 0 ), ("request_003" , 1 )]
108- ctx = Context ()
109- sender_socket = ctx .socket (zmq .PUSH ) # type: ignore
110- sender_socket .connect (self .tracker .socket_path )
111- for msg in test_messages :
112- sender_socket .send_pyobj (msg )
113- time .sleep (0.05 )
114- sender_socket .close ()
115- time .sleep (0.2 )
116-
117- with self .tracker .done_task_lock :
118- self .assertEqual (len (self .tracker .done_task_counts ["request_001" ]),
119- 2 )
120- self .assertIn (1 , self .tracker .done_task_counts ["request_001" ])
121- self .assertIn (2 , self .tracker .done_task_counts ["request_001" ])
122- self .assertEqual (len (self .tracker .done_task_counts ["request_002" ]),
123- 1 )
124- self .assertIn (0 , self .tracker .done_task_counts ["request_002" ])
125- self .assertEqual (len (self .tracker .done_task_counts ["request_003" ]),
126- 1 )
127- self .assertIn (1 , self .tracker .done_task_counts ["request_003" ])
128-
129- @patch ("vllm_ascend.distributed.mooncake_connector.make_zmq_socket" ,
130- autospec = True )
131- def test_listen_with_timeout (self , mock_make_socket ):
132- mock_socket = MagicMock ()
133-
134- def mock_recv ():
135- start = time .time ()
136- while time .time () - start < 0.5 :
137- time .sleep (0.01 )
138- return ("req1" , 0 )
139-
140- mock_socket .recv_pyobj = mock_recv
141- mock_make_socket .return_value = mock_socket
142-
143- test_thread = threading .Thread (
144- target = self .tracker ._listen_for_completion_signals , daemon = True )
145- test_thread .start ()
146- test_thread .join (timeout = 1.0 )
147- mock_make_socket .assert_called_once ()
148-
149-
150- class TestKVCacheTaskTrackerTP (unittest .TestCase ):
151-
152- def setUp (self ):
153- self .local_engine_id = "test_engine"
154- self .target_count = 3
155-
156- def test_update_done_task_count_tp_rank_0 (self ):
157- tracker = KVCacheTaskTracker (tp_rank = 0 ,
158- local_engine_id = self .local_engine_id ,
159- target_count = self .target_count )
160- test_request_id = "test_req_001"
161- test_tp_rank = 1
162- tracker .update_done_task_count (test_request_id , test_tp_rank )
163- with tracker .done_task_lock :
164- self .assertEqual (len (tracker .done_task_counts [test_request_id ]), 1 )
165- self .assertIn (test_tp_rank ,
166- tracker .done_task_counts [test_request_id ])
167-
168- @patch ("vllm_ascend.distributed.mooncake_connector.make_zmq_socket" ,
169- autospec = True )
170- def test_update_done_task_count_non_zero_tp (self , mock_make_socket ):
171- mock_socket = MagicMock ()
172- mock_make_socket .return_value = mock_socket
173- tracker = KVCacheTaskTracker (tp_rank = 1 ,
174- local_engine_id = self .local_engine_id ,
175- target_count = self .target_count )
176- test_request_id = "test_req_002"
177- test_tp_rank = 1
178- tracker .update_done_task_count (test_request_id , test_tp_rank )
179- mock_socket .send_pyobj .assert_called_once_with (
180- (test_request_id , test_tp_rank ))
181- with tracker .done_task_lock :
182- self .assertNotIn (test_request_id , tracker .done_task_counts )
183-
184- @patch ("vllm_ascend.distributed.mooncake_connector.logger" , autospec = True )
185- @patch ("vllm_ascend.distributed.mooncake_connector.make_zmq_socket" ,
186- autospec = True )
187- def test_update_done_task_count_logging (self , mock_make_socket ,
188- mock_logger ):
189- mock_socket = MagicMock ()
190- mock_make_socket .return_value = mock_socket
191- tracker = KVCacheTaskTracker (tp_rank = 2 ,
192- local_engine_id = self .local_engine_id ,
193- target_count = self .target_count )
194- test_request_id = "test_req_003"
195- tracker .update_done_task_count (test_request_id , 2 )
196- mock_logger .debug .assert_called_once_with (
197- "Sent done signal for request %s to tp 0" , test_request_id )
198-
199- @patch ("vllm_ascend.distributed.mooncake_connector.make_zmq_socket" ,
200- autospec = True )
201- def test_update_multiple_calls (self , mock_make_socket ):
202- mock_socket = MagicMock ()
203- mock_make_socket .return_value = mock_socket
204- tracker = KVCacheTaskTracker (tp_rank = 1 ,
205- local_engine_id = self .local_engine_id ,
206- target_count = self .target_count )
207- test_data = [("req1" , 1 ), ("req1" , 1 ), ("req2" , 1 )]
208- for req_id , rank in test_data :
209- tracker .update_done_task_count (req_id , rank )
210- self .assertEqual (mock_socket .send_pyobj .call_count , 3 )
211- mock_socket .send_pyobj .assert_called_with (("req2" , 1 ))
37+ self .assertIsInstance (tracker .delayed_free_requests , deque )
21238
21339
21440class TestGetAndClearFinishedSingleRequests (unittest .TestCase ):
21541
21642 def setUp (self ):
217- self .tracker = KVCacheTaskTracker (tp_rank = 0 ,
218- local_engine_id = "test" ,
219- target_count = 3 )
43+ self .tracker = KVCacheTaskTracker ()
22044 self .tracker .finished_requests = set ()
221- self .tracker .done_task_counts = defaultdict (set )
22245 self .tracker .done_task_lock = threading .Lock ()
22346
22447 def test_empty_requests (self ):
@@ -251,14 +74,6 @@ def test_concurrent_access(self, mock_logger):
25174 self .assertEqual (sum (1 for r in results if r ), 1 )
25275 self .assertEqual (len (self .tracker .finished_requests ), 0 )
25376
254- def test_after_increment (self ):
255- self .tracker ._increment_task_count ("req_123" , 0 )
256- self .tracker ._increment_task_count ("req_123" , 1 )
257- self .tracker ._increment_task_count ("req_123" , 2 )
258- result = self .tracker .get_and_clear_finished_requests ()
259- self .assertEqual (result , {"req_123" })
260- self .assertEqual (self .tracker .get_and_clear_finished_requests (), set ())
261-
26277
26378class TestKVCacheSendingThreadInit (unittest .TestCase ):
26479
@@ -282,47 +97,6 @@ def tearDown(self):
28297 if hasattr (thread , 'is_alive' ) and thread .is_alive ():
28398 thread .join (timeout = 0.1 )
28499
285- @patch ('vllm_ascend.distributed.mooncake_connector.KVCacheTaskTracker' )
286- def test_initialization_basic (self , mock_tracker ):
287- thread = KVCacheSendingThread (** self .common_args )
288- self .threads .append (thread )
289- self .assertEqual (thread .tp_rank , 1 )
290- self .assertEqual (thread .decode_tp_size , 4 )
291- self .assertEqual (thread .local_engine_id , 'engine_1' )
292- mock_tracker .assert_called_once ()
293- args = mock_tracker .call_args [0 ]
294- kwargs = mock_tracker .call_args [1 ]
295- if args :
296- self .assertEqual (args [0 ], 1 )
297- self .assertEqual (args [1 ], 'engine_1' )
298- self .assertEqual (args [2 ], 4 )
299- else :
300- self .assertEqual (kwargs ['tp_rank' ], 1 )
301- self .assertEqual (kwargs ['local_engine_id' ], 'engine_1' )
302- self .assertEqual (kwargs ['target_count' ], 4 )
303-
304- @patch ('vllm_ascend.distributed.mooncake_connector.KVCacheTaskTracker' )
305- def test_task_tracker_initialization (self , mock_tracker ):
306- args = self .common_args .copy ()
307- args .update ({
308- 'tp_rank' : 2 ,
309- 'decode_tp_size' : 8 ,
310- 'local_engine_id' : 'engine_2'
311- })
312- thread = KVCacheSendingThread (** args )
313- self .threads .append (thread )
314- mock_tracker .assert_called_once ()
315- call_args = mock_tracker .call_args [0 ]
316- call_kwargs = mock_tracker .call_args [1 ]
317- if call_args :
318- self .assertEqual (call_args [0 ], 2 )
319- self .assertEqual (call_args [1 ], 'engine_2' )
320- self .assertEqual (call_args [2 ], 8 )
321- else :
322- self .assertEqual (call_kwargs ['tp_rank' ], 2 )
323- self .assertEqual (call_kwargs ['local_engine_id' ], 'engine_2' )
324- self .assertEqual (call_kwargs ['target_count' ], 8 )
325-
326100 def test_thread_daemon_property (self ):
327101 thread = KVCacheSendingThread (** self .common_args )
328102 self .threads .append (thread )
@@ -542,7 +316,7 @@ def test_handle_request(self, mock_send, mock_transfer):
542316 mock_transfer .assert_called_once_with (self .test_req )
543317 mock_send .assert_called_once_with ("req1" , "localhost" , 6666 )
544318 self .thread .task_tracker .update_done_task_count .assert_called_once_with (
545- "req1" , self . thread . tp_rank )
319+ "req1" )
546320 self .mock_queue .task_done .assert_called_once ()
547321
548322 @patch .object (KVCacheRecvingThread , '_get_remote_metadata' )
@@ -675,9 +449,11 @@ def test_run_loop_normal(self, mock_handle):
675449class MockVllmConfig :
676450
677451 def __init__ (self ):
452+ self .model_config = MagicMock ()
678453 self .parallel_config = MagicMock ()
679454 self .cache_config = MagicMock ()
680455 self .kv_transfer_config = MagicMock ()
456+ self .model_config .use_mla = True
681457 self .parallel_config .tensor_parallel_size = 2
682458 self .parallel_config .data_parallel_rank_local = 0
683459 self .parallel_config .data_parallel_size_local = 1
@@ -714,28 +490,40 @@ def __init__(self,
714490class TestKVCacheTaskTracker (unittest .TestCase ):
715491
716492 def setUp (self ):
717- self .tracker = KVCacheTaskTracker (tp_rank = 0 ,
718- local_engine_id = "test_engine" ,
719- target_count = 2 )
493+ self .tracker = KVCacheTaskTracker ()
720494
721- def test_update_task_count (self ):
722- self .assertEqual (len (self .tracker .done_task_counts ), 0 )
723- self .assertEqual (len (self .tracker .finished_requests ), 0 )
724-
725- self .tracker .update_done_task_count ("req1" , 0 )
726- self .tracker .update_done_task_count ("req1" , 1 )
727-
728- self .assertEqual (len (self .tracker .finished_requests ), 1 )
729- self .assertTrue ("req1" in self .tracker .finished_requests )
730-
731- finished = self .tracker .get_and_clear_finished_requests ()
732- self .assertEqual (finished , {"req1" })
495+ def test_update_done_task_count (self ):
733496 self .assertEqual (len (self .tracker .finished_requests ), 0 )
497+ self .assertEqual (len (self .tracker .delayed_free_requests ), 0 )
498+
499+ current_time = time .time ()
500+ self .tracker .add_delayed_request ("req_1" , current_time )
501+ result = self .tracker .delayed_free_requests
502+ self .assertEqual (len (result ), 1 )
503+ self .assertEqual (result [0 ], ("req_1" , current_time ))
504+
505+ self .tracker .update_done_task_count ("req_1" )
506+ result_finished = self .tracker .finished_requests
507+ result_delayed = self .tracker .delayed_free_requests
508+ self .assertEqual (result_finished , {"req_1" })
509+ self .assertEqual (len (result_delayed ), 0 )
510+
511+ def test_retrieve_expired_requests (self ):
512+ current_time = time .time ()
513+ self .tracker .add_delayed_request ("req_1" , current_time - 600 )
514+ self .tracker .add_delayed_request ("req_2" , current_time )
515+ result = self .tracker ._retrieve_expired_requests ()
516+ self .assertEqual (result , {
517+ "req_1" ,
518+ })
519+ result_delay = self .tracker .delayed_free_requests
520+ self .assertEqual (len (result_delay ), 1 )
521+ self .assertEqual (result_delay [0 ], ("req_2" , current_time ))
734522
735523 def test_duplicate_task_update (self ):
736- self .tracker .update_done_task_count ("req1" , 0 )
737- self .tracker .update_done_task_count ("req1" , 0 )
738- self .tracker .update_done_task_count ("req1" , 1 )
524+ self .tracker .update_done_task_count ("req1" )
525+ self .tracker .update_done_task_count ("req1" )
526+ self .tracker .update_done_task_count ("req1" )
739527
740528 finished = self .tracker .get_and_clear_finished_requests ()
741529 self .assertEqual (finished , {"req1" })
@@ -745,6 +533,9 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
745533
746534 def test_add_new_req (self ):
747535 meta = MooncakeConnectorMetadata ()
536+ self .assertEqual (len (meta .requests ), 0 )
537+ self .assertEqual (len (meta .requests_to_send ), 0 )
538+
748539 meta .add_new_req (request_id = "req1" ,
749540 local_block_ids = [1 , 2 , 3 ],
750541 kv_transfer_params = {
@@ -802,6 +593,10 @@ def test_build_connector_meta(self):
802593 self .assertEqual (meta .requests ["req1" ].remote_block_ids , [1 , 2 , 3 ])
803594 self .assertEqual (len (self .scheduler ._reqs_need_recv ), 0 )
804595
596+ def test_get_finished_count (self ):
597+ count = self .scheduler .get_finished_count ()
598+ self .assertEqual (count , 2 )
599+
805600
806601class TestHelperFunctions (unittest .TestCase ):
807602
0 commit comments