@@ -14,20 +14,23 @@ size_t hash(const std::vector<int>& vec)
     return seed;
 }
 
-BlockTrie::BlockTrie(size_t block_seq_len, std::shared_ptr<BlockManager> block_manager, bool enable_prefix_caching):
-    block_seq_len_(block_seq_len), block_manager_(block_manager), enable_prefix_caching_(enable_prefix_caching)
+BlockTrie::BlockTrie(size_t block_len, std::shared_ptr<BlockManager> block_manager):
+    block_seq_len_(block_len), block_manager_(block_manager)
 {
     root_ = std::make_shared<TrieNode>();
 }
 
-void BlockTrie::match(Sequence& seq)
+std::tuple<BlockIds, UniqueIds> BlockTrie::Match(const Sequence& seq)
 {
     BlockIds  matched_blocks;
     UniqueIds matched_unique_ids;
 
     std::shared_ptr<TrieNode> curr_node   = root_;
     int                       num_matched = 0;
 
+    // Warning: Do not use "<=" operator even when seq.prompt length is evenly
+    // divisible by block_seq_len_. This may produce an input_length of zero for
+    // the sequence, violating the precondition checked in LlamaBatch::Forward.
     while (num_matched + block_seq_len_ < seq.prompt.size()) {
         std::vector<int> curr_tokens(seq.prompt.begin() + num_matched,
                                      seq.prompt.begin() + num_matched + block_seq_len_);
@@ -40,44 +43,47 @@ void BlockTrie::match(Sequence& seq)
         }
 
         if (curr_tokens != it->second->tokens) {
+            TM_LOG_WARNING("hash key cache hit, but tokens are not the same");
             break;
         }
 
-        matched_blocks.push_back(it->second->block_id);
-        matched_unique_ids.push_back(it->second->block_unique_id);
+        matched_blocks.emplace_back(it->second->block_id);
+        matched_unique_ids.emplace_back(it->second->block_unique_id);
         curr_node = it->second;
         num_matched += block_seq_len_;
     }
-
-    if (matched_blocks.size() > 0) {
-        // add use count
-        block_manager_->Lock(matched_blocks);
-        block_manager_->Touch(matched_blocks);
-        // only consider no history blocks
-        seq.blocks.insert(seq.blocks.end(), matched_blocks.begin(), matched_blocks.end());
-        seq.block_unique_ids.insert(seq.block_unique_ids.end(), matched_unique_ids.begin(), matched_unique_ids.end());
-    }
+    return std::make_tuple(matched_blocks, matched_unique_ids);
 }
 
-void BlockTrie::cache(const Sequence& seq)
+std::tuple<BlockIds, UniqueIds> BlockTrie::Cache(const Sequence& seq, const std::vector<int>& tokens)
 {
-    std::shared_ptr<TrieNode> curr_node = root_;
-    int num_matched = 0;
-    int idx = 0;
-    BlockIds cached_blocks;
+    FT_CHECK(seq.status != Sequence::kCached);
+    FT_CHECK(tokens.size() <= seq.blocks.size() * block_seq_len_);
 
-    while (num_matched + block_seq_len_ <= seq.prompt.size()) {
-        std::vector<int> curr_tokens(seq.prompt.begin() + num_matched,
-                                     seq.prompt.begin() + num_matched + block_seq_len_);
-        size_t hash_key = hash(curr_tokens);
+    std::shared_ptr<TrieNode> curr_node = root_;
+    int                       idx       = 0;
 
-        auto it = curr_node->children.find(hash_key);
+    BlockIds  cache_block_ids;
+    UniqueIds cache_block_unique_ids;
+
+    // We don't cache the last block of the sequence, since it might not be full
+    // TODO(lvhan): determine whether the last block is full or not. It is not trivial
+    // considering chunk prefill
+    for (int idx = 0; idx < (int)seq.blocks.size() - 1; ++idx) {
+        auto start = tokens.begin() + idx * block_seq_len_;
+        auto end   = start + block_seq_len_;
+
+        std::vector<int> curr_tokens(start, end);
+        // TODO(lvhan): add salt to ensure the hash security
+        size_t hash_key = hash(curr_tokens);
 
         int      block_id        = seq.blocks[idx];
         uint64_t block_unique_id = seq.block_unique_ids[idx];
 
+        auto it = curr_node->children.find(hash_key);
         if (it != curr_node->children.end()) {
             if (curr_tokens != it->second->tokens) {
+                TM_LOG_WARNING("[BlockTrie][cache] hash key cache hit, but tokens are not the same");
                 break;
             }
             curr_node = it->second;
@@ -91,38 +97,33 @@ void BlockTrie::cache(const Sequence& seq)
             node->tokens          = curr_tokens;
             node->block_id        = block_id;
             node->block_unique_id = block_unique_id;
-            node->num_matched     = num_matched + block_seq_len_;
             curr_node->children[hash_key] = node;
             curr_node = node;
         }
-
-        cached_blocks.push_back(curr_node->block_id);
-        num_matched += block_seq_len_;
-        idx++;
+        cache_block_ids.emplace_back(block_id);
+        cache_block_unique_ids.emplace_back(block_unique_id);
     }
 
-    block_manager_->Touch(cached_blocks);
+    return std::make_tuple(cache_block_ids, cache_block_unique_ids);
 }
 
-int BlockTrie::verify()
+void BlockTrie::Verify()
 {
-    return verify_traverse(root_);
+    DFS(root_);
 }
 
-int BlockTrie::verify_traverse(std::shared_ptr<TrieNode>& node)
+void BlockTrie::DFS(std::shared_ptr<TrieNode>& node)
 {
-    int valid_count = 1;
     for (auto it = node->children.begin(); it != node->children.end();) {
         if (block_manager_->unique_id(it->second->block_id) != it->second->block_unique_id) {
             // child invalid
             it = node->children.erase(it);
         }
         else {
-            DFS(it->second);
+            DFS(it->second);
             it++;
         }
     }
-    return valid_count;
 }
 
 } // namespace turbomind
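
Note on the API change this diff makes: Match no longer mutates the sequence and Cache no longer touches block use counts; both now return the affected block ids and unique ids to the caller, so the Lock/Touch and block-appending bookkeeping that the old match()/cache() bodies did internally is expected to move to the call site. The snippet below is only a minimal sketch of such a hypothetical caller; the names block_trie, block_manager, seq, and the use of seq.tokens as the cached token stream are illustrative assumptions, not code from this commit.

// Hypothetical call site (illustration only, not part of this commit).
auto [matched_blocks, matched_ids] = block_trie.Match(seq);
if (!matched_blocks.empty()) {
    // Use-count and recency bookkeeping, as the removed match() body did.
    block_manager->Lock(matched_blocks);
    block_manager->Touch(matched_blocks);
    seq.blocks.insert(seq.blocks.end(), matched_blocks.begin(), matched_blocks.end());
    seq.block_unique_ids.insert(seq.block_unique_ids.end(), matched_ids.begin(), matched_ids.end());
}

// Later, register the sequence's full blocks in the trie; the caller decides
// what to do with the ids that were actually cached (the old cache() touched them).
auto [cached_blocks, cached_ids] = block_trie.Cache(seq, seq.tokens);
block_manager->Touch(cached_blocks);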