
Commit df4b761

Improve turbomind's prefix cache (#3835)

* compile successfully
* trail whitespaces
* update
* fix linting
* fix linting
* fix linting
* put kInconsistency to DisableInvalidRequests
* update according to reviewer comments
* fix according to reviewer's comments
* interactive chat cannot be used when prefix caching is enabled
* fix
* fix according to reviewer comments
* fix linting

Parent: 8be72a5

File tree: 11 files changed (+291 lines, -143 lines)

lmdeploy/cli/chat.py
Lines changed: 3 additions & 0 deletions

@@ -14,6 +14,9 @@ def input_prompt():
 
 def build_pipe(model_path, backend, **kwargs):
     engine_config = None
+    if kwargs.get('enable_prefix_caching', False):
+        print('interactive chat cannot be used when prefix caching is enabled')
+        exit(-1)
     if backend == 'turbomind':
         engine_config = TurbomindEngineConfig()
         for key, value in kwargs.items():
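Prefix caching remains available through the batch `pipeline` API; only the interactive CLI chat now refuses it. A minimal usage sketch (the model path is a placeholder, not part of this commit):

```python
from lmdeploy import pipeline, TurbomindEngineConfig

# Prefix caching reuses KV blocks across requests that share a prompt prefix.
pipe = pipeline('internlm/internlm2_5-7b-chat',  # placeholder model
                backend_config=TurbomindEngineConfig(enable_prefix_caching=True))

# The second request can reuse the KV blocks cached for the shared prefix.
shared = 'You are a meticulous code reviewer. ' * 32
print(pipe([shared + 'Review this diff.', shared + 'Summarize this diff.']))
```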

lmdeploy/messages.py
Lines changed: 10 additions & 0 deletions

@@ -410,6 +410,7 @@ class ResponseType(enum.Enum):
     INPUT_LENGTH_ERROR = enum.auto()
     INTERNAL_ENGINE_ERROR = enum.auto()
     CANCEL = enum.auto()
+    PREFIX_CACHE_CONFLICT_INTERACTIVE_MODE = enum.auto()
 
 
 @dataclass
@@ -444,6 +445,15 @@ class Response:
     last_hidden_state: torch.Tensor = None
     index: int = 0
 
+    def __repr__(self):
+        logits = 'logits=None' if self.logits is None else f'logits.shape={self.logits.shape}\nlogits={self.logits}'
+        hidden_state = (
+            'last_hidden_state=None' if self.last_hidden_state is None else
+            f'last_hidden_state.shape={self.last_hidden_state.shape}\nlast_hidden_state={self.last_hidden_state}')
+        s = (f'text={self.text}\ngenerate_token_len={self.generate_token_len}\nfinish_reason="{self.finish_reason}"\n'
+             f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\n{logits}\n{hidden_state}')
+        return s
+
 
 # modified from https://github.com/vllm-project/vllm/blob/main/vllm/v1/engine/__init__.py
 class EventType(enum.IntEnum):
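A quick sketch of what the new `__repr__` prints (the field values are invented, and the exact constructor signature may differ slightly from this sketch):

```python
from lmdeploy.messages import Response

# With logits and last_hidden_state unset, __repr__ prints 'logits=None' and
# 'last_hidden_state=None' instead of dumping tensors and their shapes.
resp = Response(text='Hello!', generate_token_len=2, input_token_len=5,
                finish_reason='stop', token_ids=[9906, 0], logprobs=None)
print(resp)
```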

lmdeploy/serve/async_engine.py
Lines changed: 12 additions & 7 deletions

@@ -690,11 +690,6 @@ async def generate(
         gen_config.stop_token_ids = self.stop_words
         gen_config.update_from_hf_gen_cfg(self.hf_gen_cfg, self.tokenizer.eos_token_id)
         if not gen_config.do_sample:
-            logger.warning(f'GenerationConfig: {gen_config}')
-            logger.warning('Since v0.6.0, lmdeploy add `do_sample` in '
-                           'GenerationConfig. It defaults to False, meaning greedy '
-                           'decoding. Please set `do_sample=True` if sampling '
-                           ' decoding is needed')
             # greedy decode
             gen_config.top_k = 1
             # avoid unnecessary process
@@ -704,8 +699,7 @@
         elif gen_config.random_seed is None and sequence_start:
             gen_config.random_seed = random.getrandbits(64)
         if gen_config.n > 1:
-            logger.ERROR(f"n({gen_config.n}) > 1 hasn't been supported yet. "
-                         f'Fallback to 1')
+            logger.warning(f'n({gen_config.n}) > 1 hasn\'t been supported yet. Fallback to 1')
             gen_config.n = 1
         if messages:
             prompt = messages
@@ -742,6 +736,17 @@
         if sequence_end is True and sequence_start is False:
             await self.end_session(session_id)
             return
+        if self.backend_config.enable_prefix_caching and (gen_config.output_last_hidden_state == 'all'
+                                                          or gen_config.output_logits == 'all'):
+            errmsg = ('lmdeploy does not support outputting all token\'s logits or last_hidden_state '
+                      'when prefix caching is ON')
+            yield GenOut(response=errmsg,
+                         history_token_len=self.id2step[session_id],
+                         input_token_len=len(input_ids),
+                         generate_token_len=0,
+                         finish_reason='error',
+                         token_ids=[])
+            return
 
         def is_error(status):
             return status not in [ResponseType.SUCCESS, ResponseType.FINISH, ResponseType.CANCEL]
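The new guard exists because blocks matched from the prefix cache skip the prefill forward pass, so per-token logits and hidden states for those positions are never computed; returning partial tensors would be silently wrong. A sketch of how the error now surfaces (the model path is a placeholder):

```python
from lmdeploy import GenerationConfig, TurbomindEngineConfig, pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat',  # placeholder model
                backend_config=TurbomindEngineConfig(enable_prefix_caching=True))

# Requesting every token's logits cannot be satisfied with prefix caching on.
out = pipe('hello', gen_config=GenerationConfig(output_logits='all'))
print(out.finish_reason)  # expected: 'error'
```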

lmdeploy/turbomind/turbomind.py
Lines changed: 1 addition & 0 deletions

@@ -538,6 +538,7 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea
             6: ResponseType.INPUT_LENGTH_ERROR,
             7: ResponseType.FINISH,
             8: ResponseType.CANCEL,
+            9: ResponseType.PREFIX_CACHE_CONFLICT_INTERACTIVE_MODE,
             -1: ResponseType.INTERNAL_ENGINE_ERROR,
         }
 

src/turbomind/engine/request.h
Lines changed: 10 additions & 9 deletions

@@ -140,15 +140,16 @@ struct Request {
 
     enum
     {
-        kOk       = 0,
-        kInvalid  = 1,  // Sequence not exist or both `start` & `stop` (instead of `end`) is set
-        kConflict = 2,  // Concurrent requests to the same sequence
-        kBusy     = 3,  // Sequence is already running
-        kInactive = 4,  // Sequence to `stop` is not active
-        kFail     = 5,  // Can't find sequence for `stop` request or internal error during inference
-        kTooLong  = 6,  // history + prompt > session_len,
-        kFinish   = 7,
-        kCancel   = 8,
+        kOk            = 0,
+        kInvalid       = 1,  // Sequence does not exist, or both `start` & `stop` (instead of `end`) are set
+        kConflict      = 2,  // Concurrent requests to the same sequence
+        kBusy          = 3,  // Sequence is already running
+        kInactive      = 4,  // Sequence to `stop` is not active
+        kFail          = 5,  // Can't find sequence for `stop` request, or internal error during inference
+        kTooLong       = 6,  // history + prompt > session_len
+        kFinish        = 7,
+        kCancel        = 8,
+        kInconsistency = 9,  // Inconsistent request parameters, e.g. prefix caching is not allowed in interactive mode
     };
 };
 

src/turbomind/models/llama/BlockTrie.cc
Lines changed: 37 additions & 36 deletions

@@ -14,20 +14,23 @@ size_t hash(const std::vector<int>& vec)
     return seed;
 }
 
-BlockTrie::BlockTrie(size_t block_seq_len, std::shared_ptr<BlockManager> block_manager, bool enable_prefix_caching):
-    block_seq_len_(block_seq_len), block_manager_(block_manager), enable_prefix_caching_(enable_prefix_caching)
+BlockTrie::BlockTrie(size_t block_len, std::shared_ptr<BlockManager> block_manager):
+    block_seq_len_(block_len), block_manager_(block_manager)
 {
     root_ = std::make_shared<TrieNode>();
 }
 
-void BlockTrie::match(Sequence& seq)
+std::tuple<BlockIds, UniqueIds> BlockTrie::Match(const Sequence& seq)
 {
     BlockIds  matched_blocks;
     UniqueIds matched_unique_ids;
 
     std::shared_ptr<TrieNode> curr_node   = root_;
     int                       num_matched = 0;
 
+    // Warning: Do not use the "<=" operator, even when seq.prompt's length is evenly
+    // divisible by block_seq_len_. That may produce an input_length of zero for
+    // the sequence, violating the precondition checked in LlamaBatch::Forward.
     while (num_matched + block_seq_len_ < seq.prompt.size()) {
         std::vector<int> curr_tokens(seq.prompt.begin() + num_matched,
                                      seq.prompt.begin() + num_matched + block_seq_len_);
@@ -40,44 +43,47 @@ void BlockTrie::match(Sequence& seq)
         }
 
         if (curr_tokens != it->second->tokens) {
+            TM_LOG_WARNING("hash key cache hit, but tokens are not the same");
             break;
         }
 
-        matched_blocks.push_back(it->second->block_id);
-        matched_unique_ids.push_back(it->second->block_unique_id);
+        matched_blocks.emplace_back(it->second->block_id);
+        matched_unique_ids.emplace_back(it->second->block_unique_id);
         curr_node = it->second;
         num_matched += block_seq_len_;
     }
-
-    if (matched_blocks.size() > 0) {
-        // add use count
-        block_manager_->Lock(matched_blocks);
-        block_manager_->Touch(matched_blocks);
-        // only consider no history blocks
-        seq.blocks.insert(seq.blocks.end(), matched_blocks.begin(), matched_blocks.end());
-        seq.block_unique_ids.insert(seq.block_unique_ids.end(), matched_unique_ids.begin(), matched_unique_ids.end());
-    }
+    return std::make_tuple(matched_blocks, matched_unique_ids);
 }
 
-void BlockTrie::cache(const Sequence& seq)
+std::tuple<BlockIds, UniqueIds> BlockTrie::Cache(const Sequence& seq, const std::vector<int>& tokens)
 {
-    std::shared_ptr<TrieNode> curr_node = root_;
-    int num_matched = 0;
-    int idx = 0;
-    BlockIds cached_blocks;
+    FT_CHECK(seq.status != Sequence::kCached);
+    FT_CHECK(tokens.size() <= seq.blocks.size() * block_seq_len_);
 
-    while (num_matched + block_seq_len_ <= seq.prompt.size()) {
-        std::vector<int> curr_tokens(seq.prompt.begin() + num_matched,
-                                     seq.prompt.begin() + num_matched + block_seq_len_);
-        size_t hash_key = hash(curr_tokens);
+    std::shared_ptr<TrieNode> curr_node = root_;
+    int                       idx       = 0;
 
-        auto it = curr_node->children.find(hash_key);
+    BlockIds  cache_block_ids;
+    UniqueIds cache_block_unique_ids;
+
+    // We don't cache the last block of the sequence, since it might not be full
+    // TODO(lvhan): determine whether the last block is full or not. It is not trivial
+    // considering chunked prefill
+    for (int idx = 0; idx < (int)seq.blocks.size() - 1; ++idx) {
+        auto start = tokens.begin() + idx * block_seq_len_;
+        auto end   = start + block_seq_len_;
+
+        std::vector<int> curr_tokens(start, end);
+        // TODO(lvhan): add salt to ensure the hash security
+        size_t hash_key = hash(curr_tokens);
 
         int      block_id        = seq.blocks[idx];
         uint64_t block_unique_id = seq.block_unique_ids[idx];
 
+        auto it = curr_node->children.find(hash_key);
         if (it != curr_node->children.end()) {
             if (curr_tokens != it->second->tokens) {
+                TM_LOG_WARNING("[BlockTrie][cache] hash key cache hit, but tokens are not the same");
                 break;
             }
             curr_node = it->second;
@@ -91,38 +97,33 @@ void BlockTrie::cache(const Sequence& seq)
             node->tokens          = curr_tokens;
             node->block_id        = block_id;
             node->block_unique_id = block_unique_id;
-            node->num_matched     = num_matched + block_seq_len_;
             curr_node->children[hash_key] = node;
             curr_node = node;
         }
-
-        cached_blocks.push_back(curr_node->block_id);
-        num_matched += block_seq_len_;
-        idx++;
+        cache_block_ids.emplace_back(block_id);
+        cache_block_unique_ids.emplace_back(block_unique_id);
     }
 
-    block_manager_->Touch(cached_blocks);
+    return std::make_tuple(cache_block_ids, cache_block_unique_ids);
 }
 
-int BlockTrie::verify()
+void BlockTrie::Verify()
 {
-    return verify_traverse(root_);
+    DFS(root_);
 }
 
-int BlockTrie::verify_traverse(std::shared_ptr<TrieNode>& node)
+void BlockTrie::DFS(std::shared_ptr<TrieNode>& node)
 {
-    int valid_count = 1;
     for (auto it = node->children.begin(); it != node->children.end();) {
         if (block_manager_->unique_id(it->second->block_id) != it->second->block_unique_id) {
             // child invalid
            it = node->children.erase(it);
         }
         else {
-            valid_count += verify_traverse(it->second);
+            DFS(it->second);
            it++;
         }
     }
-    return valid_count;
 }
 
 }  // namespace turbomind
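The matching logic is easier to see outside C++. Below is a self-contained Python sketch (illustrative only, not turbomind code) of the block-granular lookup in `BlockTrie::Match`, including the strict `<` bound: with a block length of 4 and an 8-token prompt, only the first block matches, leaving 4 tokens as prefill input; a `<=` bound would match both blocks and leave the sequence with zero input tokens, violating the precondition checked in `LlamaBatch::Forward`.

```python
block_seq_len = 4

def match(trie, prompt):
    """Walk the trie one full block at a time; return matched block ids."""
    node, matched, num_matched = trie, [], 0
    # Strict '<' mirrors the C++ loop: the final full block stays unmatched
    # so the sequence always keeps a non-empty input for prefill.
    while num_matched + block_seq_len < len(prompt):
        block = tuple(prompt[num_matched:num_matched + block_seq_len])
        child = node['children'].get(hash(block))
        # Compare the raw tokens too: a hash collision must not fake a match.
        if child is None or child['tokens'] != block:
            break
        matched.append(child['block_id'])
        node = child
        num_matched += block_seq_len
    return matched

trie = {'children': {}}
b0 = (1, 2, 3, 4)
trie['children'][hash(b0)] = {'children': {}, 'tokens': b0, 'block_id': 0}

# Two full blocks in the prompt, but only block 0 is matched: tokens 5..8
# are reserved as prefill input by the strict '<' bound.
print(match(trie, [1, 2, 3, 4, 5, 6, 7, 8]))  # -> [0]
```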

src/turbomind/models/llama/BlockTrie.h
Lines changed: 35 additions & 14 deletions

@@ -22,27 +22,48 @@ struct TrieNode {
 
 class BlockTrie {
 public:
-    explicit BlockTrie(size_t block_len_, std::shared_ptr<BlockManager> block_manager, bool enable_prefix_caching);
+    explicit BlockTrie(size_t block_len, std::shared_ptr<BlockManager> block_manager);
 
-    bool enabled()
-    {
-        return enable_prefix_caching_;
-    }
+    /**
+     * @brief Attempt to match cached key-value (KV) blocks for a given sequence.
+     *
+     * This function iterates over the sequence's tokens and attempts to match
+     * them against the cached KV blocks. When the longest prefix match is found,
+     * it returns the IDs and unique IDs of the matched blocks.
+     *
+     * @param seq The sequence whose tokens are to be matched against the cached KV blocks.
+     * @return A tuple containing the following:
+     *         - BlockIds: a list of IDs of the matched blocks.
+     *         - UniqueIds: a list of unique IDs of the matched blocks.
+     *
+     * @note If no blocks are matched, all containers in the returned tuple will be empty.
+     */
+    std::tuple<BlockIds, UniqueIds> Match(const Sequence& seq);
 
-    // get cached blocks for sequence
-    void match(Sequence& seq);
+    /**
+     * @brief Cache the key-value (KV) blocks of a given sequence.
+     *
+     * This function caches the KV blocks of the specified sequence. Only valid
+     * blocks of a sequence whose status is NOT `Sequence::kCached` are
+     * considered for caching.
+     *
+     * @param seq The sequence whose KV blocks are to be cached.
+     * @param tokens The token list corresponding to the KV blocks.
+     * @return A tuple containing the following:
+     *         - BlockIds: a list of IDs of the cached blocks.
+     *         - UniqueIds: a list of unique IDs of the cached blocks.
+     */
+    std::tuple<BlockIds, UniqueIds> Cache(const Sequence& seq, const std::vector<int>& tokens);
 
-    // cache computed blocks for sequence
-    void cache(const Sequence& seq);
-
-    // remove invalid nodes, return valid count
-    int verify();
+    /**
+     * @brief Remove invalid nodes from the trie.
+     */
+    void Verify();
 
 private:
-    int verify_traverse(std::shared_ptr<TrieNode>& node);
+    void DFS(std::shared_ptr<TrieNode>& node);
 
 private:
-    bool enable_prefix_caching_;
     size_t block_seq_len_;
 
     std::shared_ptr<BlockManager> block_manager_;
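Since `Match` and `Cache` no longer lock or touch blocks, nor extend the sequence's block list, that bookkeeping moves to the caller. A Python sketch of the new division of labor and of the pruning `Verify` performs (names such as `unique_id_of_block` are illustrative stand-ins for `BlockManager::unique_id`):

```python
def on_new_request(trie, block_manager, seq):
    # Match is now a pure query; the caller takes the use count and extends
    # the sequence's block list itself.
    blocks, unique_ids = trie.match(seq)
    if blocks:
        block_manager.lock(blocks)
        block_manager.touch(blocks)
        seq.blocks.extend(blocks)
        seq.block_unique_ids.extend(unique_ids)

def verify(node, unique_id_of_block):
    # Prune children whose block was evicted and reassigned, i.e. whose
    # recorded unique id no longer matches the block manager's; recurse
    # into the survivors, like BlockTrie::DFS.
    for key in list(node['children']):
        child = node['children'][key]
        if unique_id_of_block(child['block_id']) != child['unique_id']:
            del node['children'][key]
        else:
            verify(child, unique_id_of_block)
```

Passing `tokens` into `Cache` explicitly, instead of reading `seq.prompt`, also decouples caching from the prompt, so callers can, in principle, cache blocks filled during generation through the same code path.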
