
Commit a424ab9

[Bug fix] Fix prefix cache in v1 (#3710)
* [Bug fix] Fix prefix cache in V1
* add comment
1 parent 10a95f8 commit a424ab9

File tree

2 files changed: +76 -9 lines changed

fastdeploy/cache_manager/prefix_cache_manager.py

Lines changed: 53 additions & 4 deletions
@@ -509,7 +509,7 @@ def request_match_blocks(self, task, block_size, *args):
             self.metrics.req_count += 1
             input_ids = task.prompt_token_ids
             req_id = task.request_id
-            logger.info(f"request_block_ids: start to allocate blocks for req_id {req_id}")
+            logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
             input_token_num = len(input_ids)
             common_block_ids = []
             # 1. match block
@@ -541,7 +541,7 @@ def request_match_blocks(self, task, block_size, *args):
                     cpu_recv_block_ids=[],
                 )
             else:
-                raise Exception("Not enough GPU memory to allocate cache for matched CPU Cache")
+                raise Exception("request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache")
 
             # record request cache info
             self.cache_info[req_id] = (match_block_node, input_ids)
@@ -563,11 +563,14 @@ def request_match_blocks(self, task, block_size, *args):
             if self.metrics.req_count % 10000 == 0:
                 self.metrics.reset_metrics()
             logger.info(
-                f"request_block_ids: request block for req_id {req_id}: common_block_ids {common_block_ids}"
+                f"request_match_blocks: request block for req_id {req_id}: common_block_ids {common_block_ids}"
             )
+            # set leaf node temporarily, then update it in update_cache_blocks
+            self.req_leaf_map[req_id] = match_block_node
+            self.leaf_req_map[match_block_node].add(req_id)
             return common_block_ids, matched_token_num, hit_info
         except Exception as e:
-            logger.error(f"request_block_ids: error: {type(e)} {e}")
+            logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
             raise e
 
     def request_block_ids(self, task, block_size, dec_token_num, *args):
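
The two maps touched above pair each request with the deepest radix-tree node it matched, so `_free_blocks` can later locate and unpin that leaf without re-matching the prompt. A minimal sketch of that bookkeeping, assuming `leaf_req_map` behaves like a `defaultdict(set)` (the `Node` class here is a hypothetical stand-in, not the real node type):

```python
from collections import defaultdict

class Node:
    """Hypothetical stand-in for a radix-tree node (hashable by identity)."""

req_leaf_map = {}                # req_id -> matched leaf node
leaf_req_map = defaultdict(set)  # leaf node -> req_ids currently pinning it

def remember_match(req_id, match_block_node):
    # Mirrors the diff: record the leaf temporarily; update_cache_blocks
    # may move it later once the request extends the cached path.
    req_leaf_map[req_id] = match_block_node
    leaf_req_map[match_block_node].add(req_id)

leaf = Node()
remember_match("req-1", leaf)
assert req_leaf_map["req-1"] is leaf and "req-1" in leaf_req_map[leaf]
```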
@@ -723,6 +726,43 @@ def release_block_ids(self, task):
         except Exception as e:
             logger.error(f"release_block_ids: error: {type(e)} {e}")
             raise e
+    def free_nodes_directly(self, node):
+        """
+        Recycle nodes by a query directly.
+        """
+        with self.request_release_lock:
+            try:
+                total_gpu_free_count = 0
+                while True:
+                    if node in self.gpu_lru_leaf_heap:
+                        self.gpu_lru_leaf_heap.remove(node)
+                        self.gpu_lru_leaf_set.remove(node)
+                    if node.shared_count == 0 and node.is_gpu_leaf_node:  # recycle directly
+                        self._handle_free_gpu_node_without_cpu(node)
+                        logger.info(f"free_nodes_directly: node {node}")
+                        total_gpu_free_count += 1
+                        cur_node = node
+                        node = node.parent
+                        if cur_node.hash_value in node.children:
+                            del node.children[cur_node.hash_value]
+                        if not node.children:
+                            if node in self.gpu_lru_leaf_set:
+                                continue
+                            if (
+                                node != self.radix_tree_root
+                                and node.shared_count == 0
+                                and node.is_gpu_leaf_node
+                                and node.is_persistent is False
+                            ):
+                                heapq.heappush(self.gpu_lru_leaf_heap, node)
+                                self.gpu_lru_leaf_set.add(node)
+                        else:
+                            break
+                    else:
+                        break
+            except Exception as e:
+                logger.error(f"free_nodes_directly: error: {type(e)} {e}")
+                raise e
 
     def _handle_free_gpu_node_without_cpu(self, node):
         """
@@ -1066,6 +1106,15 @@ def _update_matched_node_info(self, req_id, last_node, current_time):
             node.last_used_time = current_time
             node.req_id_set.add(req_id)
             node = node.parent
+
+    def decrease_request_share_count(self, req_id):
+        """
+        Decrease node shared count
+        """
+        node, input_ids = self.cache_info[req_id]
+        while node != self.radix_tree_root:
+            node.decrement_shared_count()
+            node = node.parent
 
     def build_path(
         self,
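
`decrease_request_share_count` undoes the per-node pin taken when the request matched the path: once every node on the path is back at `shared_count == 0`, `free_nodes_directly` may reclaim it. A toy illustration of that pairing, assuming `decrement_shared_count` simply decrements a counter, as the name suggests:

```python
class PathNode:
    def __init__(self, parent=None):
        self.parent = parent
        self.shared_count = 0

    def increment_shared_count(self):
        self.shared_count += 1

    def decrement_shared_count(self):
        self.shared_count -= 1

root = PathNode()
leaf = PathNode(PathNode(root))

# Matching pins every node from the leaf up to (not including) the root...
node = leaf
while node is not root:
    node.increment_shared_count()
    node = node.parent

# ...and decrease_request_share_count walks the same path back to zero.
node = leaf
while node is not root:
    node.decrement_shared_count()
    node = node.parent

assert leaf.shared_count == 0 and leaf.parent.shared_count == 0
```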

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 23 additions & 5 deletions
@@ -117,8 +117,8 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
             preempted_req = self.running.pop()
             preempted_req.status = RequestStatus.PREEMPTED
             preempted_req.num_computed_tokens = 0
-            preempted_req.prefill_block_num = 0
             self._free_blocks(preempted_req)
+            preempted_req.prefill_block_num = None
             self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
             preempted_reqs.append(preempted_req)
             scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
@@ -305,6 +305,7 @@ def schedule(self):
                 if self.config.cache_config.enable_prefix_caching:
                     success = self.get_prefix_cached_blocks(request)
                     if not success:
+                        self._free_blocks(request)
                         break
 
                 num_new_tokens = self._get_num_new_tokens(request, token_budget)
@@ -327,23 +328,33 @@ def schedule(self):
                     self.stop_flags[allocated_position] = False
                     self.req_dict[request.request_id] = allocated_position
                 else:
+                    if self.config.cache_config.enable_prefix_caching:
+                        self._free_blocks(request)
                     break
             elif request.status == RequestStatus.PREEMPTED:
                 request.need_prefill_tokens = (
                     request.num_total_tokens
                 )  # Before the preempted task is rescheduled it has already been sent to the engine and outputs no more tokens, so num_total_tokens stays static and correct here
+                if self.config.cache_config.enable_prefix_caching:
+                    success = self.get_prefix_cached_blocks(request)
+                    if not success:
+                        self._free_blocks(request)
+                        break
                 num_new_tokens = self._get_num_new_tokens(request, token_budget)
                 num_new_block = self.get_new_block_nums(request, num_new_tokens)
                 # Allocate blocks to prefill
                 if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
-                    request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
+                    if not request.get("skip_allocate", False):
+                        request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
                     self.waiting.popleft()
                     self.running.append(request)
                     scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                     token_budget -= num_new_tokens
                     request.num_computed_tokens += num_new_tokens
                     request.status = RequestStatus.RUNNING
                 else:
+                    if self.config.cache_config.enable_prefix_caching:
+                        self._free_blocks(request)
                     break
             else:
                 llm_logger.error("Unknown request status type")
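
The three `self._free_blocks(request)` insertions above all close the same leak: once `get_prefix_cached_blocks` succeeds, the request holds matched cache blocks, so every later failure path (no free slot, not enough allocatable blocks) must release them before `break`. A runnable sketch of that invariant with a fake scheduler (all names here are hypothetical):

```python
class FakeScheduler:
    """Toy model: tracks blocks pinned by a prefix-cache match."""

    enable_prefix_caching = True

    def __init__(self, match_ok, alloc_ok):
        self.match_ok, self.alloc_ok = match_ok, alloc_ok
        self.pinned_blocks = 0

    def get_prefix_cached_blocks(self, request):
        if self.match_ok:
            self.pinned_blocks += 4  # a successful match pins blocks
        return self.match_ok

    def can_allocate_gpu_blocks(self, request):
        return self.alloc_ok

    def free_blocks(self, request):
        self.pinned_blocks = 0

def try_schedule(request, s):
    # Mirrors the fixed schedule(): every early exit after a match frees it.
    if s.enable_prefix_caching and not s.get_prefix_cached_blocks(request):
        s.free_blocks(request)
        return False
    if not s.can_allocate_gpu_blocks(request):
        if s.enable_prefix_caching:
            s.free_blocks(request)
        return False
    return True

s = FakeScheduler(match_ok=True, alloc_ok=False)
assert try_schedule("req-1", s) is False
assert s.pinned_blocks == 0  # the match no longer leaks
```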
@@ -399,7 +410,7 @@ def get_prefix_cached_blocks(self, request: Request):
             main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
 
             if matched_token_num == request.prompt_token_ids_len:
-                request.num_computed_tokens = matched_token_num - 1
+                request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
                 request.skip_allocate = True
             else:
                 request.num_computed_tokens = matched_token_num
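
Why step back `block_size` tokens instead of `1` on a full prompt hit: prefill must still run over at least one block to produce logits for the first generated token, and `num_computed_tokens` has to stay block-aligned for the block table to line up. Subtracting `1` left the count mid-block; subtracting one full block does not. A quick check with illustrative numbers:

```python
block_size = 64
matched_token_num = 256               # the whole prompt hit the prefix cache

old = matched_token_num - 1           # 255: not block-aligned
new = matched_token_num - block_size  # 192: exactly one block left to recompute

assert old % block_size != 0
assert new % block_size == 0
print((matched_token_num - new) // block_size, "block recomputed during prefill")
```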
@@ -417,8 +428,15 @@ def add_request(self, request: Request) -> None:
     def _free_blocks(self, request: Request):
         if self.config.cache_config.enable_prefix_caching:
             # TODO(chengyanfu): support caching output blocks for prefix caching
-            self.cache_manager.release_block_ids_async(request)
-            self.cache_manager.recycle_gpu_blocks(request.block_tables[request.prefill_block_num :])
+            if request.get("prefill_block_num", None) is None:
+                leaf_node = self.cache_manager.req_leaf_map[request.request_id]
+                self.cache_manager.decrease_request_share_count(request.request_id)
+                self.cache_manager.free_nodes_directly(leaf_node)
+                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cache_info[0]:])
+
+            else:
+                self.cache_manager.release_block_ids_async(request)
+                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.prefill_block_num :])
         else:
             self.cache_manager.recycle_gpu_blocks(request.block_tables)
         request.block_tables = []
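
`_free_blocks` now keys off `prefill_block_num`: `None` means the request never completed prefill and only holds the leaf temporarily pinned by `request_match_blocks`, so the pin is undone synchronously via `decrease_request_share_count` plus `free_nodes_directly`; a set value means prefill finished and the original async release path still applies. This is also why `_trigger_preempt` above resets `prefill_block_num` to `None` only after its own `_free_blocks` call. A condensed, purely illustrative sketch of the branch:

```python
def pick_release_path(prefill_block_num):
    """Toy decision mirroring the new branch in _free_blocks."""
    if prefill_block_num is None:
        # Freed before prefill finished: synchronously unpin the matched
        # leaf, then recycle the blocks past the matched prefix.
        return "decrease_request_share_count + free_nodes_directly"
    # Prefill finished: async release, recycle past the prefill boundary.
    return "release_block_ids_async"

assert pick_release_path(None).startswith("decrease")
assert pick_release_path(12) == "release_block_ids_async"
```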
