1010from vllm .v1 .kv_cache_interface import (FullAttentionSpec , KVCacheConfig ,
1111 KVCacheSpec )
1212from vllm .v1 .request import Request
13+ from vllm .logger import init_logger
14+
15+ logger = init_logger (__name__ )
1316
1417
1518class KVCacheCoordinator (ABC ):
@@ -74,18 +77,20 @@ def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int,
7477 return num_blocks_to_allocate
7578
7679 def get_num_blocks_to_allocate_for_connector (
77- self , request_id : str , num_tokens : int ,
78- num_connector_prefix_tokens : int ,
79- new_computed_blocks : tuple [list [KVCacheBlock ],
80- ...], num_encoder_tokens : int ) -> int :
80+ self ,
81+ request_id : str ,
82+ num_tokens : int ,
83+ num_connector_cached_tokens : int ,
84+ new_computed_blocks : tuple [list [KVCacheBlock ],...],
85+ num_encoder_tokens : int ) -> int :
8186 """
8287 Get the # of blocks to allocate for request when using connector.
8388
8489 Args:
8590 request_id: The request ID.
8691 num_tokens: The total number of tokens that need a slot (including
8792 tokens that are already allocated).
88- num_connector_prefix_tokens : The number of tokens that hits
93+ num_connector_cached_tokens : The number of tokens that hits
8994 the prefix cache inside connector.
9095 new_computed_blocks: The new computed blocks just hitting the
9196 prefix caching.
@@ -97,19 +102,12 @@ def get_num_blocks_to_allocate_for_connector(
97102 """
98103 num_blocks_to_allocate = 0
99104 for i , manager in enumerate (self .single_type_managers ):
100- if isinstance (manager , CrossAttentionManager ):
101- # Cross-attention does not support prefix cache
102- # from connector yet.
103- num_blocks_to_allocate += \
104- manager .get_num_blocks_to_allocate_for_connector (
105- request_id , num_encoder_tokens , 0 , [])
106- else :
107- num_blocks_to_allocate += \
108- manager .get_num_blocks_to_allocate_for_connector (
109- request_id ,
110- num_tokens ,
111- num_connector_prefix_tokens ,
112- new_computed_blocks [i ])
105+ num_blocks_to_allocate += \
106+ manager .get_num_blocks_to_allocate_for_connector (
107+ request_id ,
108+ num_tokens ,
109+ num_connector_cached_tokens ,
110+ new_computed_blocks [i ])
113111 return num_blocks_to_allocate
114112
115113 def save_new_computed_blocks (
@@ -152,6 +150,22 @@ def allocate_new_blocks(
152150 manager , CrossAttentionManager ) else num_tokens )
153151 for manager in self .single_type_managers )
154152
153+
154+ def allocate_new_blocks_for_connector (
155+ self ,
156+ request_id : str ,
157+ num_connector_prefix_tokens : int ,
158+ num_new_tokens : int ,
159+ num_encoder_tokens : int = 0 ,
160+ ) -> tuple [list [KVCacheBlock ], ...]:
161+ return tuple (
162+ manager .allocate_new_blocks_for_connector (
163+ request_id ,
164+ num_connector_prefix_tokens ,
165+ num_encoder_tokens if isinstance (
166+ manager , CrossAttentionManager ) else num_new_tokens )
167+ for manager in self .single_type_managers )
168+
155169 def cache_blocks (self , request : Request , num_computed_tokens : int ) -> None :
156170 """
157171 Cache the blocks for the request.
@@ -196,8 +210,13 @@ def get_num_common_prefix_blocks(self, request_id: str,
196210 ]
197211 return num_blocks_per_group
198212
199- def remove_skipped_blocks (self , request_id : str ,
200- num_computed_tokens : int ) -> None :
213+ def remove_skipped_and_allocate_necessary (
214+ self ,
215+ request_id : str ,
216+ num_computed_tokens : int ,
217+ num_extra_tokens_from_connector : int ,
218+ num_encoder_tokens : int ,
219+ ) -> tuple [list [KVCacheBlock ], ...]:
201220 """
202221 Remove the blocks that are no longer needed from `blocks` and replace
203222 the removed blocks with null_block.
@@ -206,8 +225,15 @@ def remove_skipped_blocks(self, request_id: str,
206225 request_id: The request ID.
207226 num_computed_tokens: The number of tokens that have been computed.
208227 """
209- for manager in self .single_type_managers :
210- manager .remove_skipped_blocks (request_id , num_computed_tokens )
228+ return tuple (
229+ manager .remove_skipped_and_allocate_necessary (
230+ request_id ,
231+ num_encoder_tokens if isinstance (
232+ manager , CrossAttentionManager ) else num_computed_tokens ,
233+ num_extra_tokens_from_connector ,
234+ )
235+ for manager in self .single_type_managers
236+ )
211237
212238 def get_blocks (self , request_id : str ) -> tuple [list [KVCacheBlock ], ...]:
213239 """
0 commit comments