Skip to content

Commit f2085d9

Browse files
committed
add a new API to KV cache coordinator
Signed-off-by: KuntaiDu <[email protected]>
1 parent eb838a0 commit f2085d9

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,45 @@ def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int,
7373
request_id, num_tokens, new_computed_blocks[i])
7474
return num_blocks_to_allocate
7575

76+
def get_num_blocks_to_allocate_for_connector(
        self, request_id: str, num_tokens: int,
        num_connector_prefix_tokens: int,
        new_computed_blocks: tuple[list[KVCacheBlock],
                                   ...], num_encoder_tokens: int) -> int:
    """
    Count the blocks a request needs when a connector is in use.

    Each entry of ``self.single_type_managers`` is asked for its own
    requirement and the per-manager answers are added together.

    Args:
        request_id: The request ID.
        num_tokens: The total number of tokens that need a slot
            (including tokens that are already allocated).
        num_connector_prefix_tokens: The number of tokens that hit the
            prefix cache inside the connector.
        new_computed_blocks: The newly computed blocks that just hit the
            prefix cache, one list per manager (indexed by the manager's
            position in ``self.single_type_managers``).
        num_encoder_tokens: The number of encoder tokens, used to size
            the allocation for cross-attention.

    Returns:
        The total number of blocks to allocate across all managers.
    """
    total_blocks = 0
    for idx, mgr in enumerate(self.single_type_managers):
        if not isinstance(mgr, CrossAttentionManager):
            total_blocks += mgr.get_num_blocks_to_allocate_for_connector(
                request_id,
                num_tokens,
                num_connector_prefix_tokens,
                new_computed_blocks[idx],
            )
            continue

        # Cross-attention does not support prefix cache from connector
        # yet, so it is queried with the encoder token count, zero
        # connector-prefix tokens and no computed blocks.
        total_blocks += mgr.get_num_blocks_to_allocate_for_connector(
            request_id, num_encoder_tokens, 0, [])
    return total_blocks
114+
76115
def save_new_computed_blocks(
77116
self, request_id: str,
78117
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None:

0 commit comments

Comments
 (0)