Skip to content

Commit 8436848

Browse files
authored
build trie in prefill and add hit rate (#4184)
1 parent 359c5a0 commit 8436848

File tree

5 files changed

+35
-1
lines changed

5 files changed

+35
-1
lines changed

lmdeploy/messages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,7 @@ class ScheduleMetrics:
546546
active_blocks: int = 0
547547
cached_blocks: int = 0
548548
free_blocks: int = 0
549+
prefix_cache_hit_rate: float = 0
549550

550551

551552
@dataclass

lmdeploy/metrics/loggers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@ def log(self):
118118
f'Unfinished: {scheduler_stats.num_total_reqs-scheduler_stats.num_finished_reqs} reqs, '
119119
f'Running: {scheduler_stats.num_running_reqs} reqs, '
120120
f'Waiting: {scheduler_stats.num_waiting_reqs} reqs, '
121-
f'GPU KV cache usage: {scheduler_stats.gpu_cache_usage * 100 :.1f}%')
121+
f'GPU KV cache usage: {scheduler_stats.gpu_cache_usage * 100 :.1f}%, '
122+
f'Prefix cache hit rate: {scheduler_stats.prefix_cache_hit_rate * 100 :.1f}%')
123+
122124
print(log_msg, flush=True)
123125
self.log_spec_msg()
124126

lmdeploy/metrics/stats.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ class SchedulerStats:
2020
num_running_reqs: Currently executing requests.
2121
num_waiting_reqs: Requests queued waiting for execution.
2222
gpu_cache_usage: Fraction of GPU KV blocks utilized (0.0 to 1.0).
23+
prefix_cache_hit_rate: Prefix caching hit rate.
2324
"""
2425

2526
num_total_reqs: int = 0
2627
num_finished_reqs: int = 0
2728
num_running_reqs: int = 0
2829
num_waiting_reqs: int = 0
2930
gpu_cache_usage: float = 0.0
31+
prefix_cache_hit_rate: float = 0.0
3032

3133
def __repr__(self):
3234
return ('SchedulerStats(\n'
@@ -35,12 +37,14 @@ def __repr__(self):
3537
f' num_running_reqs={self.num_running_reqs},\n'
3638
f' num_waiting_reqs={self.num_waiting_reqs},\n'
3739
f' gpu_cache_usage={self.gpu_cache_usage:.6f},\n'
40+
f' prefix_cache_hit_rate={self.prefix_cache_hit_rate:.6f},\n'
3841
')')
3942

4043
def update_from_schedule_metrics(self, scheduled_metrics: ScheduleMetrics):
4144
self.num_running_reqs = scheduled_metrics.active_seqs
4245
self.num_waiting_reqs = scheduled_metrics.waiting_seqs
4346
self.gpu_cache_usage = 1.0 - (scheduled_metrics.free_blocks / scheduled_metrics.total_blocks)
47+
self.prefix_cache_hit_rate = scheduled_metrics.prefix_cache_hit_rate
4448

4549

4650
class RequestStats:

lmdeploy/pytorch/paging/block_trie.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import heapq
3+
from dataclasses import dataclass
34
from typing import Dict, Set
45

56
import numpy as np
@@ -10,6 +11,20 @@
1011
from .block_manager import BaseBlockManager
1112

1213

@dataclass
class PrefixCacheStats:
    """Accumulated prefix-cache statistics.

    Tracks how many prompt tokens were looked up against the block trie
    and how many of those lookups were served from cached blocks, so the
    scheduler can report a prefix-cache hit rate.
    """

    # Total tokens queried against the prefix cache.
    num_query_tokens: int = 0
    # Tokens that were satisfied by already-cached blocks.
    num_hit_tokens: int = 0

    def reset(self):
        """Zero both counters, starting a fresh measurement window."""
        self.num_query_tokens = 0
        self.num_hit_tokens = 0

    def hit_rate(self):
        """Return hits / queries as a float; 0.0 when nothing was queried yet."""
        if self.num_query_tokens <= 0:
            return 0.0
        return self.num_hit_tokens / self.num_query_tokens
27+
1328
class Node:
1429
"""Node of block trie."""
1530

@@ -54,6 +69,11 @@ def __init__(self, cache_config: CacheConfig, block_manager: BaseBlockManager):
5469
# caches with different adapter should not be shared.
5570
self._roots: Dict[str, Node] = dict()
5671
self.leaves: Set[Node] = set()
72+
self.stats = PrefixCacheStats()
73+
74+
def hit_rate(self):
75+
"""Get hit rate."""
76+
return self.stats.hit_rate()
5777

5878
def get_root(self, adapter_name: str):
5979
"""Get root by adapter name."""
@@ -73,6 +93,7 @@ def match(self, seq: SchedulerSequence):
7393
curr: Node = getattr(logical_blocks, 'last_shared_node', None)
7494
if curr is None:
7595
curr = self.get_root(seq.adapter_name)
96+
init_num_matched = curr.num_matched
7697
num_matched = curr.num_matched
7798

7899
def __match_success(node: Node):
@@ -101,6 +122,10 @@ def __match_success(node: Node):
101122
seq.logical_blocks.append(matched_blocks)
102123
seq.set_step(num_matched)
103124

125+
# record prefix hit
126+
self.stats.num_query_tokens += seq.num_all_ids - init_num_matched
127+
self.stats.num_hit_tokens += num_matched - init_num_matched
128+
104129
seq.logical_blocks.last_shared_node = curr
105130

106131
def allocate(self, seq: SchedulerSequence):

lmdeploy/pytorch/paging/scheduler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def _reorder_waiting():
235235

236236
# allocate session memory
237237
self.block_manager.allocate(seq, prealloc_size)
238+
self.block_trie.allocate(seq)
238239
if self.is_ssm:
239240
self.state_manager.allocate(seq)
240241
_to_running(seq)
@@ -451,4 +452,5 @@ def schedule_metrics(self):
451452
waiting_seqs=self.num_waiting() + self.num_running(),
452453
total_blocks=self.block_manager.num_gpu_blocks,
453454
free_blocks=self.block_manager.get_num_free_gpu_blocks(),
455+
prefix_cache_hit_rate=self.block_trie.hit_rate(),
454456
)

0 commit comments

Comments (0)