@@ -82,6 +82,54 @@ def get_num_blocks_to_allocate(
8282 for blk in new_computed_blocks )
8383 return num_new_blocks + num_evictable_computed_blocks
8484
def get_num_blocks_to_allocate_for_connector(
    self,
    request_id: str,
    num_connector_prefix_tokens: int,
    num_new_tokens: int,
    new_computed_blocks: list[KVCacheBlock],
) -> int:
    """Upper-bound block count needed when a KV connector supplies a prefix.

    NOTE(Kuntai):
    The tricky part is the external prefix from the connector. The existing
    `get_num_blocks_to_allocate` assumes we allocate for all tokens and all
    layers, which makes vLLM unable to admit a long external prefix cache
    even when it would fit into GPU memory after evicting tokens outside
    the sliding window. This method (together with
    `allocate_new_blocks_for_connector`) lets subclasses cap the estimate.

    NOTE(Kuntai): the value returned is an *upper bound* on the number of
    blocks that will actually be allocated.

    Args:
        request_id: The request ID.
        num_connector_prefix_tokens: The number of tokens from the external
            prefix.
        num_new_tokens: The number of new tokens.
        new_computed_blocks: The new computed blocks just hitting the
            prefix caching.

    Returns:
        The upper bound of # of blocks to allocate.
    """
    token_budget = num_connector_prefix_tokens + num_new_tokens
    blocks_required = cdiv(token_budget, self.block_size)
    blocks_already_held = (len(new_computed_blocks) +
                           len(self.req_to_blocks[request_id]))
    num_fresh_blocks = blocks_required - blocks_already_held
    # A computed block sitting in the free queue (ref_cnt == 0) is an
    # eviction candidate; taking it for this request turns it back into an
    # allocated block, so it counts against the allocation budget as well.
    num_evictable = sum(1 for blk in new_computed_blocks
                        if blk.ref_cnt == 0 and not blk.is_null)
    return num_fresh_blocks + num_evictable
85133 def save_new_computed_blocks (
86134 self , request_id : str ,
87135 new_computed_blocks : list [KVCacheBlock ]) -> None :
@@ -127,6 +175,31 @@ def allocate_new_blocks(self, request_id: str,
127175 req_blocks .extend (new_blocks )
128176 return new_blocks
129177
def allocate_new_blocks_for_connector(
    self,
    request_id: str,
    num_connector_prefix_tokens: int,
    num_new_tokens: int,
) -> list[KVCacheBlock]:
    """Allocate blocks for a request whose prefix comes from a KV connector.

    NOTE(Kuntai): the connector prefix tokens and the newly scheduled
    tokens are passed separately because sliding-window managers treat
    them differently (see the sliding window layer implementation); this
    base implementation simply allocates for their sum.

    Args:
        request_id: The request ID.
        num_connector_prefix_tokens: The number of tokens from the external
            prefix.
        num_new_tokens: The number of new tokens.

    Returns:
        The new allocated blocks.
    """
    total_token_count = num_connector_prefix_tokens + num_new_tokens
    return self.allocate_new_blocks(request_id, total_token_count)
130203 def cache_blocks (self , request : Request , num_tokens : int ) -> None :
131204 """
132205 Cache the blocks for the request.
@@ -390,6 +463,59 @@ def get_num_common_prefix_blocks(self, request_id: str,
390463 """
391464 return 0
392465
def get_num_blocks_to_allocate_for_connector(
    self,
    request_id: str,
    num_connector_prefix_tokens: int,
    num_new_tokens: int,
    new_computed_blocks: list[KVCacheBlock],
) -> int:
    """Upper-bound block count for a connector-supplied prefix under
    sliding-window attention.

    Only the last sliding window of the external prefix has to stay
    resident, so the base-class estimate is capped at one window's worth
    of blocks plus the blocks needed for the new tokens.
    """
    # Cap: one full sliding window plus the newly scheduled tokens.
    window_block_cap = cdiv(self.sliding_window + num_new_tokens,
                            self.block_size)
    uncapped_estimate = super().get_num_blocks_to_allocate_for_connector(
        request_id, num_connector_prefix_tokens, num_new_tokens,
        new_computed_blocks)
    return min(uncapped_estimate, window_block_cap)
def allocate_new_blocks_for_connector(
    self,
    request_id: str,
    num_connector_prefix_tokens: int,
    num_new_tokens: int,
) -> list[KVCacheBlock]:
    """Allocate blocks for a connector-supplied prefix, first releasing
    the prefix blocks that have fallen out of the sliding window.

    Blocks that lie entirely before the last useful token are skipped
    during the attention computation, so they are freed and replaced by
    the null block. The block table is then padded with null blocks up to
    that point if needed, and the remaining allocation is delegated to the
    base implementation.
    """
    # Index of the first block that still overlaps the sliding window.
    first_useful_block = (
        (num_connector_prefix_tokens - self.sliding_window + 1)
        // self.block_size)
    block_table = self.req_to_blocks[request_id]

    # Swap out-of-window blocks for the null block and free them.
    freed: list[KVCacheBlock] = []
    for idx, blk in enumerate(block_table):
        if idx >= first_useful_block:
            break
        if not blk.is_null:
            freed.append(blk)
            block_table[idx] = self._null_block
    self.block_pool.free_blocks(freed)

    # Pad with null blocks so the table reaches first_useful_block
    # entries (a non-positive shortfall is a no-op, as in the base logic).
    shortfall = first_useful_block - len(block_table)
    if shortfall > 0:
        block_table.extend([self._null_block] * shortfall)

    # Then fall back to the normal allocation path.
    return super().allocate_new_blocks_for_connector(
        request_id, num_connector_prefix_tokens, num_new_tokens)
393519
394520class ChunkedLocalAttentionManager (SingleTypeKVCacheManager ):
395521
0 commit comments