Skip to content

Commit 958b797

Browse files
committed
cache code change
Signed-off-by: KuntaiDu <[email protected]>
1 parent f2085d9 commit 958b797

File tree

4 files changed

+369
-360
lines changed

4 files changed

+369
-360
lines changed

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
1111
KVCacheSpec)
1212
from vllm.v1.request import Request
13+
from vllm.logger import init_logger
14+
15+
logger = init_logger(__name__)
1316

1417

1518
class KVCacheCoordinator(ABC):
@@ -74,18 +77,20 @@ def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int,
7477
return num_blocks_to_allocate
7578

7679
def get_num_blocks_to_allocate_for_connector(
77-
self, request_id: str, num_tokens: int,
78-
num_connector_prefix_tokens: int,
79-
new_computed_blocks: tuple[list[KVCacheBlock],
80-
...], num_encoder_tokens: int) -> int:
80+
self,
81+
request_id: str,
82+
num_tokens: int,
83+
num_connector_cached_tokens: int,
84+
new_computed_blocks: tuple[list[KVCacheBlock],...],
85+
num_encoder_tokens: int) -> int:
8186
"""
8287
Get the # of blocks to allocate for request when using connector.
8388
8489
Args:
8590
request_id: The request ID.
8691
num_tokens: The total number of tokens that need a slot (including
8792
tokens that are already allocated).
88-
num_connector_prefix_tokens: The number of tokens that hits
93+
num_connector_cached_tokens: The number of tokens that hit
8994
the prefix cache inside the connector.
9095
new_computed_blocks: The newly computed blocks that just hit the
9196
prefix cache.
@@ -97,19 +102,12 @@ def get_num_blocks_to_allocate_for_connector(
97102
"""
98103
num_blocks_to_allocate = 0
99104
for i, manager in enumerate(self.single_type_managers):
100-
if isinstance(manager, CrossAttentionManager):
101-
# Cross-attention does not support prefix cache
102-
# from connector yet.
103-
num_blocks_to_allocate += \
104-
manager.get_num_blocks_to_allocate_for_connector(
105-
request_id, num_encoder_tokens, 0, [])
106-
else:
107-
num_blocks_to_allocate += \
108-
manager.get_num_blocks_to_allocate_for_connector(
109-
request_id,
110-
num_tokens,
111-
num_connector_prefix_tokens,
112-
new_computed_blocks[i])
105+
num_blocks_to_allocate += \
106+
manager.get_num_blocks_to_allocate_for_connector(
107+
request_id,
108+
num_tokens,
109+
num_connector_cached_tokens,
110+
new_computed_blocks[i])
113111
return num_blocks_to_allocate
114112

115113
def save_new_computed_blocks(
@@ -152,6 +150,22 @@ def allocate_new_blocks(
152150
manager, CrossAttentionManager) else num_tokens)
153151
for manager in self.single_type_managers)
154152

153+
154+
def allocate_new_blocks_for_connector(
155+
self,
156+
request_id: str,
157+
num_connector_prefix_tokens: int,
158+
num_new_tokens: int,
159+
num_encoder_tokens: int = 0,
160+
) -> tuple[list[KVCacheBlock], ...]:
161+
return tuple(
162+
manager.allocate_new_blocks_for_connector(
163+
request_id,
164+
num_connector_prefix_tokens,
165+
num_encoder_tokens if isinstance(
166+
manager, CrossAttentionManager) else num_new_tokens)
167+
for manager in self.single_type_managers)
168+
155169
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
156170
"""
157171
Cache the blocks for the request.
@@ -196,8 +210,13 @@ def get_num_common_prefix_blocks(self, request_id: str,
196210
]
197211
return num_blocks_per_group
198212

199-
def remove_skipped_blocks(self, request_id: str,
200-
num_computed_tokens: int) -> None:
213+
def remove_skipped_and_allocate_necessary(
214+
self,
215+
request_id: str,
216+
num_computed_tokens: int,
217+
num_extra_tokens_from_connector: int,
218+
num_encoder_tokens: int,
219+
) -> tuple[list[KVCacheBlock], ...]:
201220
"""
202221
Remove the blocks that are no longer needed from `blocks` and replace
203222
the removed blocks with null_block.
@@ -206,8 +225,15 @@ def remove_skipped_blocks(self, request_id: str,
206225
request_id: The request ID.
207226
num_computed_tokens: The number of tokens that have been computed.
208227
"""
209-
for manager in self.single_type_managers:
210-
manager.remove_skipped_blocks(request_id, num_computed_tokens)
228+
return tuple(
229+
manager.remove_skipped_and_allocate_necessary(
230+
request_id,
231+
num_encoder_tokens if isinstance(
232+
manager, CrossAttentionManager) else num_computed_tokens,
233+
num_extra_tokens_from_connector,
234+
)
235+
for manager in self.single_type_managers
236+
)
211237

212238
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
213239
"""

0 commit comments

Comments
 (0)