diff --git a/src/cpp/include/openvino/genai/cache_eviction.hpp b/src/cpp/include/openvino/genai/cache_eviction.hpp index e5f54b24ec..81db238693 100644 --- a/src/cpp/include/openvino/genai/cache_eviction.hpp +++ b/src/cpp/include/openvino/genai/cache_eviction.hpp @@ -4,6 +4,7 @@ #pragma once #include +#include #include "openvino/core/except.hpp" @@ -15,8 +16,11 @@ namespace ov::genai { */ enum class AggregationMode { SUM, /**< In this mode the importance scores of each token will be summed after each step of generation */ - NORM_SUM /**< Same as SUM, but the importance scores are additionally divided by the lifetime (in tokens generated) + NORM_SUM, /**< Same as SUM, but the importance scores are additionally divided by the lifetime (in tokens generated) * of a given token in cache */ + ADAPTIVE_RKV /** Switches the cache eviction algorithm to use Adaptive R-KV algorithm. The scores are aggregated within + a configurable window size of the latest generated tokens. May not be used together with the KVCrush + algorithm. */ }; /** @@ -68,9 +72,15 @@ class KVCrushConfig { } }; -/** -* @brief Configuration struct for the cache eviction algorithm. 
-*/ +struct AdaptiveRKVConfig { + AdaptiveRKVConfig() = default; + AdaptiveRKVConfig(double attention_mass_, size_t window_size_) : attention_mass(attention_mass_), window_size(window_size_) {}; + + double attention_mass = 0.9; + size_t window_size = 8; +}; + + class CacheEvictionConfig { public: CacheEvictionConfig() = default; @@ -81,14 +91,16 @@ class CacheEvictionConfig { AggregationMode aggregation_mode_, bool apply_rotation_ = false, size_t snapkv_window_size_ = 8, - const KVCrushConfig& kvcrush_config_ = KVCrushConfig(0, KVCrushAnchorPointMode::RANDOM)) + const KVCrushConfig& kvcrush_config_ = KVCrushConfig(0, KVCrushAnchorPointMode::RANDOM), + const AdaptiveRKVConfig& adaptive_rkv_config_ = AdaptiveRKVConfig()) : aggregation_mode(aggregation_mode_), apply_rotation(apply_rotation_), snapkv_window_size(snapkv_window_size_), + kvcrush_config(kvcrush_config_), + adaptive_rkv_config(adaptive_rkv_config_), m_start_size(start_size), m_recent_size(recent_size), - m_max_cache_size(max_cache_size), - kvcrush_config(kvcrush_config_) { + m_max_cache_size(max_cache_size) { OPENVINO_ASSERT(start_size, "CacheEvictionConfig.start_size must be non-zero"); OPENVINO_ASSERT(recent_size, "CacheEvictionConfig.recent_size must be non-zero"); OPENVINO_ASSERT(max_cache_size, "CacheEvictionConfig.max_cache_size must be non-zero"); @@ -142,6 +154,8 @@ class CacheEvictionConfig { * even if they are not among the most important ones.*/ KVCrushConfig kvcrush_config; + AdaptiveRKVConfig adaptive_rkv_config; + private: /** Number of tokens in the *beginning* of KV cache that should be retained * in the KV cache for this sequence during generation. 
Must be non-zero and a multiple of the KV cache block size for diff --git a/src/cpp/src/continuous_batching/attention_output.hpp b/src/cpp/src/continuous_batching/attention_output.hpp index 602fcda1a0..4463bb45b5 100644 --- a/src/cpp/src/continuous_batching/attention_output.hpp +++ b/src/cpp/src/continuous_batching/attention_output.hpp @@ -6,3 +6,8 @@ using AttentionScoresForCacheOfSubsequence = ov::Tensor; using AttentionScoresForEachDecoderLayer = std::vector; using AttentionScoresForEachSubsequence = std::map; + + +using TokenSimilarityForSubsequence = ov::Tensor; +using TokenSimilarityForEachDecoderLayer = std::vector; +using TokenSimilarityForEachSubsequence = std::map; diff --git a/src/cpp/src/continuous_batching/cache_eviction.cpp b/src/cpp/src/continuous_batching/cache_eviction.cpp index 7196b29e56..c2bf23401e 100644 --- a/src/cpp/src/continuous_batching/cache_eviction.cpp +++ b/src/cpp/src/continuous_batching/cache_eviction.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "continuous_batching/cache_eviction.hpp" +#include namespace ov::genai { @@ -48,6 +49,24 @@ namespace ov::genai { m_cache_counter[decoder_layer_idx] = new_counter; } + template + void _max_pool(std::vector& dst, const T* src_data, size_t size, size_t max_pool_window_size) { + OPENVINO_ASSERT(size == dst.size()); + for (size_t idx = 0; idx < size; idx++) { + size_t effective_window_size = max_pool_window_size; + size_t elements_left = size - idx; + if (elements_left < effective_window_size) { + effective_window_size = elements_left; + } + auto max_val = src_data[idx]; + for (size_t window_idx = 1; window_idx < effective_window_size; window_idx++) { + auto val = src_data[idx + window_idx]; + max_val = std::max(val, max_val); + } + dst[idx] = max_val; + } + } + void EvictionScoreManager::register_new_token_scores( const AttentionScoresForEachDecoderLayer &attention_scores_for_all_decoder_layers, const std::set& skipped_logical_block_ids, size_t num_snapkv_scores) { @@ 
-72,7 +91,6 @@ namespace ov::genai { return; } - std::set skip_set_adjusted; size_t num_skipped_blocks_in_ignore_area = 0; for (size_t i = 0; i < m_ignore_first_n_blocks; i++) { if (skipped_logical_block_ids.find(i) != skipped_logical_block_ids.end()) { @@ -81,13 +99,8 @@ namespace ov::genai { } OPENVINO_ASSERT(num_skipped_blocks_in_ignore_area <= m_ignore_first_n_blocks); - size_t start_token_offset_in_scores = (m_ignore_first_n_blocks - num_skipped_blocks_in_ignore_area) * m_block_size; - for (size_t skipped_block_id : skipped_logical_block_ids) { - if (skipped_block_id >= m_ignore_first_n_blocks) { - skip_set_adjusted.insert(skipped_block_id - m_ignore_first_n_blocks); - } // else do not include this block in the adjusted skip set since it is in the start area already - } + size_t start_token_offset_in_scores = (m_ignore_first_n_blocks - num_skipped_blocks_in_ignore_area) * m_block_size; auto hh_score = ov::Tensor( attention_scores, @@ -95,98 +108,150 @@ namespace ov::genai { ov::Coordinate{scores_size_in_tokens} ); - std::vector max_pooled_hh_scores(hh_score.get_size()); - auto hh_score_data = hh_score.data(); - size_t num_hh_scores = hh_score.get_size(); + std::vector processed_hh_scores(hh_score.get_size()); - for (size_t idx = 0; idx < num_hh_scores; idx++) { - size_t effective_window_size = m_max_pool_window_size; - size_t elements_left = num_hh_scores - idx; - if (elements_left < effective_window_size) { - effective_window_size = elements_left; - } - auto max_val = hh_score_data[idx]; - for (size_t window_idx = 1; window_idx < effective_window_size; window_idx++) { - auto val = hh_score_data[idx + window_idx]; - max_val = std::max(val, max_val); - } - max_pooled_hh_scores[idx] = max_val; + if (m_aggregation_mode == AggregationMode::ADAPTIVE_RKV) { + _max_pool(processed_hh_scores, hh_score.data(), hh_score.get_size(), 1); + } else { + _max_pool(processed_hh_scores, hh_score.data(), hh_score.get_size(), m_max_pool_window_size); } auto& 
accumulated_scores_for_current_decoder_layer = m_scores[decoder_layer_idx]; if (accumulated_scores_for_current_decoder_layer.empty()) { - if (m_snapkv_window_size != 0 && num_snapkv_scores == 0) { - // SnapKV window not yet reached, no meaningful scores to accumulate - continue; - } - // New sequence to track - if (skipped_logical_block_ids.empty()) { - accumulated_scores_for_current_decoder_layer = max_pooled_hh_scores; - } - else { - accumulated_scores_for_current_decoder_layer.resize(max_pooled_hh_scores.size() + m_block_size * skipped_logical_block_ids.size(), 0.0); - size_t src_idx = 0; - for (size_t dst_idx = 0; dst_idx < accumulated_scores_for_current_decoder_layer.size(); dst_idx++) { - size_t curr_logical_block_idx = dst_idx / m_block_size; - if (skipped_logical_block_ids.find(curr_logical_block_idx) != skipped_logical_block_ids.end()) { - dst_idx += m_block_size; - continue; - } - accumulated_scores_for_current_decoder_layer[dst_idx] = accumulated_scores_for_current_decoder_layer[src_idx]; - src_idx++; - } - OPENVINO_ASSERT(src_idx == max_pooled_hh_scores.size()); - } - - if (m_aggregation_mode == AggregationMode::NORM_SUM) { - std::size_t new_scores_size = num_hh_scores; - std::vector counter(new_scores_size); - if (m_snapkv_window_size == 0) { - // Will simulate that the tokens comprising the sequence were added one-by-one - // from the standpoint of the occurrence tracker - std::generate(counter.begin(), counter.begin() + new_scores_size, - [&new_scores_size] { return new_scores_size--; }); - } - else { - OPENVINO_ASSERT(num_snapkv_scores > 0); - OPENVINO_ASSERT(new_scores_size >= num_snapkv_scores); - std::fill(counter.begin(), counter.end() - num_snapkv_scores, num_snapkv_scores); - std::iota(counter.rbegin(), counter.rbegin() + num_snapkv_scores, 1); - } - m_cache_counter[decoder_layer_idx] = counter; - } + _accumulate_initial_scores(processed_hh_scores, decoder_layer_idx, num_snapkv_scores, skipped_logical_block_ids); } else { - size_t 
old_size_in_tokens = accumulated_scores_for_current_decoder_layer.size(); - size_t new_size_in_tokens = max_pooled_hh_scores.size() + m_block_size * skipped_logical_block_ids.size(); - - OPENVINO_ASSERT(new_size_in_tokens >= old_size_in_tokens); - size_t num_new_tokens = new_size_in_tokens - old_size_in_tokens; - if (m_aggregation_mode == AggregationMode::NORM_SUM) { - auto &counter_for_current_decoder_layer = m_cache_counter[decoder_layer_idx]; - counter_for_current_decoder_layer.resize(new_size_in_tokens); - if (m_snapkv_window_size == 0 || m_num_registered_snapkv_aggregated_scores == m_snapkv_window_size) { - // Increment occurrence counts of all currently tracked cache blocks - for (auto it = counter_for_current_decoder_layer.begin(); - it != counter_for_current_decoder_layer.end(); it++) { - *it += num_new_tokens; - } - // Add occurrence counts for new tokens like above - for (size_t i = 0; i < num_new_tokens; i++) { - auto idx = old_size_in_tokens + i; - counter_for_current_decoder_layer[idx] = num_new_tokens - i; - } - } - else { - OPENVINO_ASSERT(new_size_in_tokens >= m_num_registered_snapkv_aggregated_scores); - std::fill(counter_for_current_decoder_layer.begin(), counter_for_current_decoder_layer.end() - m_num_registered_snapkv_aggregated_scores, m_num_registered_snapkv_aggregated_scores); - std::iota(counter_for_current_decoder_layer.rbegin(), counter_for_current_decoder_layer.rbegin() + m_num_registered_snapkv_aggregated_scores, 1); - } + _accumulate_with_existing_scores(processed_hh_scores, decoder_layer_idx, num_snapkv_scores, skipped_logical_block_ids); + } + } + } + + + void EvictionScoreManager::_accumulate_initial_scores(const std::vector& max_pooled_hh_scores, size_t decoder_layer_idx, size_t num_snapkv_scores, const std::set& skipped_logical_block_ids) { + if (m_snapkv_window_size != 0 && num_snapkv_scores == 0) { + // SnapKV window not yet reached, no meaningful scores to accumulate + return; + } + + + 
OPENVINO_ASSERT(m_previous_scores_queues[decoder_layer_idx].empty()); + m_previous_scores_queues[decoder_layer_idx].emplace_back(max_pooled_hh_scores, skipped_logical_block_ids); + auto& accumulated_scores_for_current_decoder_layer = m_scores[decoder_layer_idx]; + _initialize_score_with_skips(accumulated_scores_for_current_decoder_layer, max_pooled_hh_scores, skipped_logical_block_ids); + + if (m_aggregation_mode == AggregationMode::NORM_SUM) { + std::size_t new_scores_size = max_pooled_hh_scores.size(); + std::vector counter(new_scores_size); + if (m_snapkv_window_size == 0) { + // Will simulate that the tokens comprising the sequence were added one-by-one + // from the standpoint of the occurrence tracker + std::generate(counter.begin(), counter.begin() + new_scores_size, + [&new_scores_size] { return new_scores_size--; }); + } + else { + OPENVINO_ASSERT(num_snapkv_scores > 0); + OPENVINO_ASSERT(new_scores_size >= num_snapkv_scores); + std::fill(counter.begin(), counter.end() - num_snapkv_scores, num_snapkv_scores); + std::iota(counter.rbegin(), counter.rbegin() + num_snapkv_scores, 1); + } + m_cache_counter[decoder_layer_idx] = counter; + } + } + + void EvictionScoreManager::_accumulate_with_existing_scores(const std::vector& max_pooled_hh_scores, size_t decoder_layer_idx, size_t num_snapkv_scores, const std::set& skipped_logical_block_ids) { + if (m_aggregation_mode == AggregationMode::ADAPTIVE_RKV) { + if (m_previous_scores_queues[decoder_layer_idx].size() >= m_adaptive_rkv_window_size) { + m_previous_scores_queues[decoder_layer_idx].pop_front(); + } + m_previous_scores_queues[decoder_layer_idx].emplace_back(max_pooled_hh_scores, skipped_logical_block_ids); + } + if (m_aggregation_mode == AggregationMode::ADAPTIVE_RKV) { + OPENVINO_ASSERT(!m_previous_scores_queues[decoder_layer_idx].empty()); + auto start_it = m_previous_scores_queues[decoder_layer_idx].begin(); + auto& dst = m_scores[decoder_layer_idx]; + _initialize_score_with_skips(dst, start_it->score, 
start_it->skips); + for (auto it = start_it + 1; it != m_previous_scores_queues[decoder_layer_idx].end(); it++) { + _accumulate_layer_scores_to(decoder_layer_idx, it->score, it->skips, dst); + } + // mean + size_t queue_size = m_previous_scores_queues[decoder_layer_idx].size(); + for (size_t i = 0; i < dst.size(); i++) { + dst[i] /= queue_size; + } + auto score_copy = dst; + _max_pool(dst, score_copy.data(), score_copy.size(), m_max_pool_window_size); + } else { + auto& accumulated_scores_for_current_decoder_layer = m_scores[decoder_layer_idx]; + size_t old_size_in_tokens = accumulated_scores_for_current_decoder_layer.size(); + _accumulate_layer_scores_to(decoder_layer_idx, max_pooled_hh_scores, skipped_logical_block_ids, accumulated_scores_for_current_decoder_layer); + size_t new_size_in_tokens = accumulated_scores_for_current_decoder_layer.size(); + + if (m_aggregation_mode == AggregationMode::NORM_SUM) { + _adjust_norm_sum_counters(decoder_layer_idx, old_size_in_tokens, new_size_in_tokens); + } + } + } + + void EvictionScoreManager::_accumulate_layer_scores_to(size_t decoder_layer_idx, const std::vector& src, const std::set& skipped_logical_block_ids, std::vector& dst) { + std::set skip_set_adjusted; + + for (size_t skipped_block_id : skipped_logical_block_ids) { + if (skipped_block_id >= m_ignore_first_n_blocks) { + skip_set_adjusted.insert(skipped_block_id - m_ignore_first_n_blocks); + } // else do not include this block in the adjusted skip set since it is in the start area already + } + size_t old_size_in_tokens = dst.size(); + size_t new_size_in_tokens = src.size() + m_block_size * skipped_logical_block_ids.size(); + + OPENVINO_ASSERT(new_size_in_tokens >= old_size_in_tokens); + dst.resize(new_size_in_tokens); + add_with_skips(dst, src, skip_set_adjusted); + } + + void EvictionScoreManager::_adjust_norm_sum_counters(size_t decoder_layer_idx, size_t old_size_in_tokens, size_t new_size_in_tokens) { + OPENVINO_ASSERT(new_size_in_tokens >= old_size_in_tokens); 
+ size_t num_new_tokens = new_size_in_tokens - old_size_in_tokens; + auto &counter_for_current_decoder_layer = m_cache_counter[decoder_layer_idx]; + counter_for_current_decoder_layer.resize(new_size_in_tokens); + if (m_snapkv_window_size == 0 || m_num_registered_snapkv_aggregated_scores == m_snapkv_window_size) { + // Increment occurrence counts of all currently tracked cache blocks + for (auto it = counter_for_current_decoder_layer.begin(); + it != counter_for_current_decoder_layer.end(); it++) { + *it += num_new_tokens; + } + // Add occurrence counts for new tokens like above + for (size_t i = 0; i < num_new_tokens; i++) { + auto idx = old_size_in_tokens + i; + counter_for_current_decoder_layer[idx] = num_new_tokens - i; + } + } + else { + OPENVINO_ASSERT(new_size_in_tokens >= m_num_registered_snapkv_aggregated_scores); + std::fill(counter_for_current_decoder_layer.begin(), counter_for_current_decoder_layer.end() - m_num_registered_snapkv_aggregated_scores, m_num_registered_snapkv_aggregated_scores); + std::iota(counter_for_current_decoder_layer.rbegin(), counter_for_current_decoder_layer.rbegin() + m_num_registered_snapkv_aggregated_scores, 1); + } + } + + void EvictionScoreManager::_initialize_score_with_skips(std::vector& dst, const std::vector& src, const std::set skipped_logical_block_ids) { + // New sequence to track + if (skipped_logical_block_ids.empty()) { + dst = src; + } + else { + dst.clear(); + dst.resize(src.size() + m_block_size * skipped_logical_block_ids.size(), 0.0); + size_t src_idx = 0; + for (size_t dst_idx = 0; dst_idx < dst.size(); dst_idx++) { + size_t curr_logical_block_idx = dst_idx / m_block_size; + if (skipped_logical_block_ids.find(curr_logical_block_idx) != skipped_logical_block_ids.end()) { + dst_idx = curr_logical_block_idx * m_block_size + m_block_size - 1; + continue; } - accumulated_scores_for_current_decoder_layer.resize(new_size_in_tokens); - add_with_skips(accumulated_scores_for_current_decoder_layer, max_pooled_hh_scores, 
skip_set_adjusted); + dst[dst_idx] = src[src_idx]; + src_idx++; } + OPENVINO_ASSERT(src_idx == src.size()); } } @@ -220,7 +285,7 @@ namespace ov::genai { CacheEvictionAlgorithm::CacheEvictionAlgorithm(const CacheEvictionConfig &eviction_config, size_t block_size, size_t num_decoder_layers, size_t max_pool_window_size) : m_eviction_config(eviction_config), m_block_size(block_size), m_num_decoder_layers(num_decoder_layers), - m_score_manager(block_size, num_decoder_layers, max_pool_window_size, eviction_config.aggregation_mode, eviction_config.get_start_size() / block_size, eviction_config.snapkv_window_size), m_kvcrush_algo(eviction_config.kvcrush_config, block_size) + m_score_manager(block_size, num_decoder_layers, max_pool_window_size, eviction_config.aggregation_mode, eviction_config.get_start_size() / block_size, eviction_config.snapkv_window_size, eviction_config.adaptive_rkv_config.window_size), m_kvcrush_algo(eviction_config.kvcrush_config, block_size) { OPENVINO_ASSERT(!(m_eviction_config.get_start_size() % m_block_size), "CacheEvictionConfig.start_size in tokens must be a multiple of block size ", m_block_size); @@ -260,42 +325,66 @@ namespace ov::genai { continue; } - // Only the blocks in the "intermediate" part of the logical KV cache will be considered for eviction + auto scores_for_all_evictable_blocks = get_scores_for_all_evictable_blocks(decoder_layer_idx); - size_t num_blocks_to_evict = get_num_blocks_to_evict(decoder_layer_idx); - auto evicted_block_indices = get_indices_of_blocks_to_evict(scores_for_all_evictable_blocks, num_blocks_to_evict); - - // KVCrush: start - bool should_apply_kvcrush = (m_eviction_config.kvcrush_config.budget > 0) && - (evicted_block_indices.size() >= m_eviction_config.kvcrush_config.budget); - if (should_apply_kvcrush) { - size_t num_tokens_in_evictable_blocks = scores_for_all_evictable_blocks.size() * m_block_size; - - auto kvcrush_retained_block_indices = m_kvcrush_algo.get_indices_of_blocks_to_retain_using_kvcrush( - 
num_tokens_in_evictable_blocks, - evicted_block_indices, - m_score_manager.get_scores()[decoder_layer_idx]); - - // Remove the indices in kvcrush_retained_block_indices from evicted_block_indices - if (!kvcrush_retained_block_indices.empty()) { - // Convert both vectors to sets for efficient operations - std::unordered_set retained_set(kvcrush_retained_block_indices.begin(), - kvcrush_retained_block_indices.end()); - - // Create a new vector containing only elements not in retained_set - std::vector filtered_evicted_indices; - filtered_evicted_indices.reserve(evicted_block_indices.size()); - - for (const auto& idx : evicted_block_indices) { - if (retained_set.find(idx) == retained_set.end()) { - filtered_evicted_indices.push_back(idx); + std::vector evicted_block_indices; + if (m_eviction_config.aggregation_mode == AggregationMode::ADAPTIVE_RKV) { + OPENVINO_ASSERT(!m_last_token_similarity.empty(), "Token similarity must be registered before each eviction in the Adaptive R-KV scenario"); + size_t num_evictable_blocks = get_num_evictable_blocks(decoder_layer_idx); + size_t num_similarity_tokens_registered = m_last_token_similarity[0].size(); + OPENVINO_ASSERT(num_similarity_tokens_registered / m_block_size == num_evictable_blocks, "Similarity score size mismatch - registered ", num_similarity_tokens_registered / m_block_size, " blocks worth of similarity scores, but have ", num_evictable_blocks, " evictable blocks"); + auto similarity_set_and_num_blocks_kept = get_adaptive_rkv_similarity_set(num_evictable_blocks, scores_for_all_evictable_blocks); + + const auto& similarity_set = similarity_set_and_num_blocks_kept.first; + size_t num_blocks_kept = similarity_set_and_num_blocks_kept.second; + + size_t num_evictable_blocks_to_keep_after_eviction = m_eviction_config.get_evictable_size() / m_block_size; + OPENVINO_ASSERT(num_blocks_kept <= num_evictable_blocks_to_keep_after_eviction); + size_t num_blocks_left_to_fill = num_evictable_blocks_to_keep_after_eviction - 
num_blocks_kept; + auto diverse_set = get_adaptive_rkv_diverse_blocks(num_blocks_left_to_fill, similarity_set, m_last_token_similarity[decoder_layer_idx]); + + for (size_t potentially_evicted_idx : similarity_set) { + if (diverse_set.find(potentially_evicted_idx) == diverse_set.end()) { + evicted_block_indices.push_back(potentially_evicted_idx); + } + } + + } else { + // Only the blocks in the "intermediate" part of the logical KV cache will be considered for eviction + size_t num_blocks_to_evict = get_num_blocks_to_evict(decoder_layer_idx); + evicted_block_indices = get_indices_of_blocks_to_evict(scores_for_all_evictable_blocks, num_blocks_to_evict); + // KVCrush: start + bool should_apply_kvcrush = (m_eviction_config.kvcrush_config.budget > 0) && + (evicted_block_indices.size() >= m_eviction_config.kvcrush_config.budget); + if (should_apply_kvcrush) { + size_t num_tokens_in_evictable_blocks = scores_for_all_evictable_blocks.size() * m_block_size; + + auto kvcrush_retained_block_indices = m_kvcrush_algo.get_indices_of_blocks_to_retain_using_kvcrush( + num_tokens_in_evictable_blocks, + evicted_block_indices, + m_score_manager.get_scores()[decoder_layer_idx]); + + // Remove the indices in kvcrush_retained_block_indices from evicted_block_indices + if (!kvcrush_retained_block_indices.empty()) { + // Convert both vectors to sets for efficient operations + std::unordered_set retained_set(kvcrush_retained_block_indices.begin(), + kvcrush_retained_block_indices.end()); + + // Create a new vector containing only elements not in retained_set + std::vector filtered_evicted_indices; + filtered_evicted_indices.reserve(evicted_block_indices.size()); + + for (const auto& idx : evicted_block_indices) { + if (retained_set.find(idx) == retained_set.end()) { + filtered_evicted_indices.push_back(idx); + } } + // Replace the original vector with the filtered one + evicted_block_indices = std::move(filtered_evicted_indices); } - // Replace the original vector with the filtered one - 
evicted_block_indices = std::move(filtered_evicted_indices); } + // KVCrush: end } - // KVCrush: end m_num_evicted_tokens += evicted_block_indices.size() * m_block_size; @@ -306,6 +395,72 @@ namespace ov::genai { for (auto &idx: evicted_block_indices) idx += get_num_blocks(m_eviction_config.get_start_size()); for (auto &idx: evicted_block_indices) retval[decoder_layer_idx].insert(idx); } + + m_last_token_similarity.clear(); + return retval; + } + + std::pair, size_t> CacheEvictionAlgorithm::get_adaptive_rkv_similarity_set(size_t max_num_blocks_kept, const std::vector& evictable_area_block_scores) { + struct ScoreAndBlockIdx { + double score; + size_t block_idx; + bool operator<(const ScoreAndBlockIdx& rhs) const { return score < rhs.score; } + }; + std::priority_queue score_block_queue; + double total_sum = 0.0; + for (size_t i = 0; i < evictable_area_block_scores.size(); i++) { + total_sum += evictable_area_block_scores[i]; + score_block_queue.push({evictable_area_block_scores[i], i}); + } + + double expected_sum = total_sum * m_eviction_config.adaptive_rkv_config.attention_mass; + std::set retval; + + double sum = 0.0; + size_t num_blocks_kept = 0; + while (sum < expected_sum && !score_block_queue.empty() && num_blocks_kept <= max_num_blocks_kept) { + // Blocks with most attention mass are kept + auto score_and_idx = score_block_queue.top(); + sum += score_and_idx.score; + score_block_queue.pop(); + num_blocks_kept += 1; + } + + // The rest will be further filtered according to their cosine similarity separately + while (!score_block_queue.empty()) { + auto score_and_idx = score_block_queue.top(); + retval.insert(score_and_idx.block_idx); + score_block_queue.pop(); + } + return {retval, num_blocks_kept}; + } + + std::set CacheEvictionAlgorithm::get_adaptive_rkv_diverse_blocks(size_t num_blocks_left_to_fill, const std::set& similarity_set, const std::vector& token_similarity) { + OPENVINO_ASSERT(num_blocks_left_to_fill <= similarity_set.size()); + struct 
ScoreAndBlockIdx { + double score; + size_t block_idx; + bool operator<(const ScoreAndBlockIdx& rhs) const { return score < rhs.score; } // sic! + }; + std::priority_queue score_block_queue; + OPENVINO_ASSERT(token_similarity.size() % m_block_size == 0); + for (size_t block_idx : similarity_set) { + OPENVINO_ASSERT(block_idx * m_block_size <= token_similarity.size()); + double block_diversity_score = 0.0; + for (size_t tok_idx = 0; tok_idx < m_block_size; tok_idx++) { + block_diversity_score -= token_similarity[block_idx * m_block_size + tok_idx]; + } + score_block_queue.push({block_diversity_score, block_idx}); + block_diversity_score = 0.0; + } + + std::set retval; + + while (retval.size() < num_blocks_left_to_fill && !score_block_queue.empty()) { + auto score_and_idx = score_block_queue.top(); + retval.insert(score_and_idx.block_idx); + score_block_queue.pop(); + } return retval; } @@ -423,6 +578,22 @@ namespace ov::genai { m_score_manager.remove_scores(evicted_block_indices, decoder_layer_idx); } + void CacheEvictionAlgorithm::register_token_similarity(const TokenSimilarityForEachDecoderLayer& token_similarity_for_all_decoder_layers) { + OPENVINO_ASSERT(m_last_token_similarity.empty(), "CacheEvictionAlgorithm already has token similarity, must evict before new similarity is registered"); + OPENVINO_ASSERT(token_similarity_for_all_decoder_layers.size() == m_num_decoder_layers); + m_last_token_similarity.resize(m_num_decoder_layers); + for (size_t layer_idx = 0; layer_idx < m_num_decoder_layers; layer_idx++) + { + const auto& layer_similarity = token_similarity_for_all_decoder_layers[layer_idx]; + const float* data = layer_similarity.data(); + size_t num_similarity_tokens = layer_similarity.get_size(); + OPENVINO_ASSERT(num_similarity_tokens % m_block_size == 0); + m_last_token_similarity[layer_idx].resize(num_similarity_tokens); + for (size_t tok_idx = 0; tok_idx < layer_similarity.get_size(); tok_idx++) { + m_last_token_similarity[layer_idx][tok_idx] = 
data[tok_idx]; + } + } + } CacheRotationCalculator::CacheRotationCalculator(size_t block_size, size_t max_context_length_in_blocks, diff --git a/src/cpp/src/continuous_batching/cache_eviction.hpp b/src/cpp/src/continuous_batching/cache_eviction.hpp index 1069ea6241..9c7c736540 100644 --- a/src/cpp/src/continuous_batching/cache_eviction.hpp +++ b/src/cpp/src/continuous_batching/cache_eviction.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include "openvino/openvino.hpp" #include "continuous_batching/attention_output.hpp" @@ -15,6 +16,7 @@ namespace ov::genai { + /** * @brief Keeps track of the accumulated token scores across model inferences and their lifetime. */ @@ -39,8 +41,10 @@ class EvictionScoreManager { * where `S` is equal to `snapkv_window_size`. In contrast, if this is set to 0, then the initial counter state would be * `| L | L - 1 | ... | 2 | 1 |`, * where L is the prompt size of the sequence in tokens. + * @param adaptive_rkv_window_size AggregationMode::ADAPTIVE_RKV only - Number of last token scores that will be aggregated (using mean) + * for purposes of determining blocks in the evictable area that comprise the most attention mass. 
*/ - explicit EvictionScoreManager(size_t block_size, size_t num_decoder_layers, size_t max_pool_window_size, AggregationMode aggregation_mode, size_t ignore_first_n_blocks = 0, size_t snapkv_window_size = 0) : m_block_size(block_size), m_num_decoder_layers(num_decoder_layers), m_scores(num_decoder_layers), m_cache_counter(num_decoder_layers), m_max_pool_window_size(max_pool_window_size), m_aggregation_mode(aggregation_mode), m_ignore_first_n_blocks(ignore_first_n_blocks), m_snapkv_window_size(snapkv_window_size), m_num_registered_snapkv_aggregated_scores(0) {} + explicit EvictionScoreManager(size_t block_size, size_t num_decoder_layers, size_t max_pool_window_size, AggregationMode aggregation_mode, size_t ignore_first_n_blocks = 0, size_t snapkv_window_size = 0, size_t adaptive_rkv_window_size = 8) : m_block_size(block_size), m_num_decoder_layers(num_decoder_layers), m_scores(num_decoder_layers), m_cache_counter(num_decoder_layers), m_max_pool_window_size(max_pool_window_size), m_aggregation_mode(aggregation_mode), m_ignore_first_n_blocks(ignore_first_n_blocks), m_snapkv_window_size(snapkv_window_size), m_num_registered_snapkv_aggregated_scores(0), m_adaptive_rkv_window_size(adaptive_rkv_window_size), m_previous_scores_queues(num_decoder_layers) {} /** * Registers new token scores and aggregates them internally as necessary. 
The token scores provided may be corresponding not to all @@ -100,6 +104,22 @@ class EvictionScoreManager { std::size_t m_ignore_first_n_blocks; std::size_t m_snapkv_window_size; std::size_t m_num_registered_snapkv_aggregated_scores; + size_t m_adaptive_rkv_window_size = 8; + + struct EvictionScoreRecord { + EvictionScoreRecord(const std::vector& score_, const std::set& skips_) : score(score_), skips(skips_) {}; + std::vector score; + std::set skips; + }; + + std::vector> m_previous_scores_queues; + + void _initialize_score_with_skips(std::vector& dst, const std::vector& src, const std::set skipped_logical_block_ids); + void _accumulate_initial_scores(const std::vector& max_pooled_hh_scores, size_t decoder_layer_idx, size_t num_snapkv_scores, const std::set& skipped_logical_block_ids); + + void _accumulate_layer_scores_to(size_t decoder_layer_idx, const std::vector& src, const std::set& skipped_logical_block_ids, std::vector& dst); + void _accumulate_with_existing_scores(const std::vector& max_pooled_hh_scores, size_t decoder_layer_idx, size_t num_snapkv_scores, const std::set& skipped_logical_block_ids); + void _adjust_norm_sum_counters(size_t decoder_layer_idx, size_t old_size_in_tokens, size_t new_size_in_tokens); }; class SnapKVScoreAggregationCalculator { @@ -201,6 +221,8 @@ class CacheEvictionAlgorithm { */ std::vector> evict_logical_blocks(); + void register_token_similarity(const TokenSimilarityForEachDecoderLayer& token_similarity_for_all_decoder_layers); + private: std::size_t get_num_blocks(std::size_t num_tokens) const; @@ -221,6 +243,10 @@ class CacheEvictionAlgorithm { std::size_t m_num_evicted_tokens = 0; std::size_t m_num_decoder_layers; EvictionScoreManager m_score_manager; + + std::vector> m_last_token_similarity; + std::pair, size_t> get_adaptive_rkv_similarity_set(size_t max_num_blocks_kept, const std::vector& evictable_area_token_scores); + std::set get_adaptive_rkv_diverse_blocks(size_t num_blocks_left_to_fill, const std::set& similarity_set, 
const std::vector& token_similarity); }; diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index bfe5a4ec2d..137a44ed62 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -15,6 +15,7 @@ #include "continuous_batching/timer.hpp" #include "continuous_batching/attention_output.hpp" +#include "continuous_batching/cache_eviction.hpp" namespace ov::genai { @@ -24,6 +25,12 @@ inline std::string get_paged_attention_score_output_for_decoder_layer(size_t dec return ss.str(); } +inline std::string get_adaptive_rkv_similarity_score_output_for_decoder_layer(size_t decoder_layer_id) { + std::stringstream ss; + ss << "adaptive_rkv_similarity." << decoder_layer_id; + return ss.str(); +} + /** * @brief Runs the LLM infer request, parsing the continuous batching scheduler output into proper inputs in terms of OV API (e.g. token input IDs, * KV cache block indices etc.) and returning the logit scores for the next token to be generated for each of the currently scheduled sequences. @@ -31,6 +38,7 @@ inline std::string get_paged_attention_score_output_for_decoder_layer(size_t dec class ModelRunner { ov::InferRequest m_request; AttentionScoresForEachSubsequence m_last_attention_scores; + TokenSimilarityForEachSubsequence m_last_token_similarities; size_t m_block_size; size_t m_num_decoder_layers; bool m_collect_attention_scores; @@ -44,6 +52,9 @@ class ModelRunner { bool m_is_aggregate_attention_scores; bool m_is_use_xattention_inputs; + + bool m_is_use_adaptive_rkv; + // A model to compute token embeddings. // Input shape: [N, conversation length]. // Output shape: [1, conversation length, hidden_size]. 
@@ -84,7 +95,8 @@ class ModelRunner { bool is_use_per_layer_cache_control = false, bool is_use_rotation_inputs = false, bool is_aggregate_attention_scores = false, - bool is_use_xattention_inputs = false) + bool is_use_xattention_inputs = false, + bool m_is_use_adaptive_rkv_inputs = false) : m_request(std::move(request)), m_block_size(block_size), m_num_decoder_layers(num_decoder_layers), @@ -93,7 +105,8 @@ class ModelRunner { m_is_use_rotation_inputs(is_use_rotation_inputs), m_rotated_block_logical_indices_per_sequence_for_each_layer(num_decoder_layers), m_is_aggregate_attention_scores(is_aggregate_attention_scores), - m_is_use_xattention_inputs(is_use_xattention_inputs) { + m_is_use_xattention_inputs(is_use_xattention_inputs), + m_is_use_adaptive_rkv(m_is_use_adaptive_rkv_inputs) { OPENVINO_ASSERT(m_num_decoder_layers != 0, "num_decoder_layers must be non-zero"); _reset_cache_rotation_coefficients(); } @@ -118,6 +131,9 @@ class ModelRunner { return m_last_attention_scores; } + const TokenSimilarityForEachSubsequence& get_last_token_similarities() const { + return m_last_token_similarities; + } void set_cache_rotation_trig_lut(ov::Tensor&& rotation_trig_lut) { m_cache_rotation_trig_lut = std::move(rotation_trig_lut); @@ -382,6 +398,10 @@ class ModelRunner { _set_xattention_tensors(sequence_groups, scheduler_output, batch_size_in_sequences); } + if (m_is_use_adaptive_rkv) { + _set_adaptive_rkv_tensors(sequence_groups, scheduler_output, batch_size_in_sequences); + } + if (matmul_gathering_is_available) { // use pre-allocated tensor for gather_indices as well ov::Tensor gather_indices = m_request.get_tensor("sampled_tokens_indices"); @@ -404,6 +424,10 @@ class ModelRunner { _collect_attention_scores(sequence_groups, scheduler_output); } + if (m_is_use_adaptive_rkv) { + _collect_token_similarities(sequence_groups, scheduler_output); + } + _reset_cache_rotation_coefficients(); // return logits @@ -749,6 +773,51 @@ class ModelRunner { } } + void 
_collect_token_similarities(const std::vector & sequence_groups, const Scheduler::Output& scheduler_output) { + m_last_token_similarities.clear(); + size_t num_sequence_groups = scheduler_output.m_scheduled_sequence_groups_ids.size(); + using IndexSpan = std::pair; + std::list> running_seq_ids_and_kvcache_spans; + size_t offset = 0; + for (size_t i = 0; i < num_sequence_groups; ++i) { + size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i]; + SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id]; + std::vector running_sequences = sequence_group->get_running_sequences(); + + for (size_t seq_idx = 0; seq_idx < running_sequences.size(); ++seq_idx) { + Sequence::CPtr sequence = running_sequences[seq_idx]; + size_t global_sequence_id = sequence->get_id(); + auto it = scheduler_output.m_adaptive_rkv_evictable_sizes.find(global_sequence_id); + if (it == scheduler_output.m_adaptive_rkv_evictable_sizes.end() || it->second == 0) { + // Adaptive R-KV similarity calculation was not scheduled for this sequence (no entry, or an evictable size of 0) + continue; + } + size_t num_similarity_tokens_calculated = it->second * m_block_size; + // Since we only evict during the generation phase, the similarity calculation is only + // scheduled after prefill is finished + OPENVINO_ASSERT(sequence_group->can_generate_tokens()); + + IndexSpan span = {offset, offset + num_similarity_tokens_calculated}; + offset += num_similarity_tokens_calculated; + running_seq_ids_and_kvcache_spans.emplace_back(global_sequence_id, span); + } + } + + for (const auto& seq_id_and_score_span : running_seq_ids_and_kvcache_spans) { + auto token_similarities_across_decoder_layers_for_current_sequence = TokenSimilarityForEachDecoderLayer(m_num_decoder_layers); + size_t global_sequence_id = seq_id_and_score_span.first; + IndexSpan span = seq_id_and_score_span.second; + for (size_t decoder_layer_id = 0; decoder_layer_id < m_num_decoder_layers; decoder_layer_id++) { + auto attention_score = 
m_request.get_tensor(get_adaptive_rkv_similarity_score_output_for_decoder_layer(decoder_layer_id)); + auto scores_for_cache_of_current_sequence_group = ov::Tensor(attention_score, ov::Coordinate{span.first}, ov::Coordinate{span.second}); + auto copied_tensor = ov::Tensor(scores_for_cache_of_current_sequence_group.get_element_type(), ov::Shape{span.second - span.first}); + scores_for_cache_of_current_sequence_group.copy_to(copied_tensor); + token_similarities_across_decoder_layers_for_current_sequence[decoder_layer_id] = copied_tensor; + } + m_last_token_similarities[global_sequence_id] = token_similarities_across_decoder_layers_for_current_sequence; + } + } + void _set_xattention_tensors(const std::vector& sequence_groups, const Scheduler::Output& scheduler_output, size_t batch_size_in_sequences) { @@ -786,5 +855,40 @@ class ModelRunner { } } + + void _set_adaptive_rkv_tensors(const std::vector& sequence_groups, + const Scheduler::Output& scheduler_output, + size_t batch_size_in_sequences) { + ov::Tensor adaptive_rkv_start_size(ov::element::i32, {}); + adaptive_rkv_start_size.data()[0] = scheduler_output.m_adaptive_rkv_start_size; + m_request.set_tensor("adaptive_rkv_start_size", adaptive_rkv_start_size); + + ov::Tensor adaptive_rkv_evictable_sizes(ov::element::i32, {batch_size_in_sequences}); + int32_t* adaptive_rkv_evictable_sizes_data = adaptive_rkv_evictable_sizes.data<int32_t>(); + for (size_t i = 0; i < scheduler_output.m_scheduled_sequence_groups_ids.size(); i++) { + size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i]; + SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id]; + std::vector running_sequences = sequence_group->get_running_sequences(); + size_t num_running_sequences = running_sequences.size(); + for (size_t k = 0; k < num_running_sequences; ++k) { + Sequence::CPtr sequence = running_sequences[k]; + size_t seq_id = sequence->get_id(); + size_t evictable_size = 0; + + if 
(scheduler_output.m_adaptive_rkv_evictable_sizes.find(seq_id) != scheduler_output.m_adaptive_rkv_evictable_sizes.end()) { + evictable_size = scheduler_output.m_adaptive_rkv_evictable_sizes.at(seq_id); + } + *adaptive_rkv_evictable_sizes_data = evictable_size; + adaptive_rkv_evictable_sizes_data += 1; + } + } + + m_request.set_tensor("adaptive_rkv_evictable_sizes", adaptive_rkv_evictable_sizes); + + // Reserved for future use + ov::Tensor adaptive_rkv_diversity_block_set(ov::element::i32, {batch_size_in_sequences}); + m_request.set_tensor("adaptive_rkv_diversity_block_set", adaptive_rkv_diversity_block_set); + + } }; } diff --git a/src/cpp/src/continuous_batching/paged_attention_transformations.cpp b/src/cpp/src/continuous_batching/paged_attention_transformations.cpp index f175ab2cde..9f529355b0 100644 --- a/src/cpp/src/continuous_batching/paged_attention_transformations.cpp +++ b/src/cpp/src/continuous_batching/paged_attention_transformations.cpp @@ -10,13 +10,13 @@ namespace ov { namespace genai { namespace utils { -void apply_paged_attention_transformations(std::shared_ptr model, bool per_layer_cache_control, bool allow_cache_rotation, bool allow_xattention) { +void apply_paged_attention_transformations(std::shared_ptr model, bool per_layer_cache_control, bool allow_cache_rotation, bool allow_xattention, bool allow_adaptive_rkv) { const ov::op::util::VariableVector& variables = model->get_variables(); OPENVINO_ASSERT(!variables.empty(), "Model is supposed to be stateful"); bool use_block_indices_inputs = per_layer_cache_control; bool use_score_outputs = per_layer_cache_control; - ov::pass::SDPAToPagedAttention(use_block_indices_inputs, use_score_outputs, /* allow_score_aggregation = */ true, allow_cache_rotation, allow_xattention).run_on_model(model); + ov::pass::SDPAToPagedAttention(use_block_indices_inputs, use_score_outputs, /* allow_score_aggregation = */ true, allow_cache_rotation, allow_xattention/* FIXME (VSHAMPOR): allow_adaptive_rkv 
*/).run_on_model(model); std::map> key_cache_params, value_cache_params; for (const auto& param_ptr : model->get_parameters()) { diff --git a/src/cpp/src/continuous_batching/paged_attention_transformations.hpp b/src/cpp/src/continuous_batching/paged_attention_transformations.hpp index d2b6445997..db02c57637 100644 --- a/src/cpp/src/continuous_batching/paged_attention_transformations.hpp +++ b/src/cpp/src/continuous_batching/paged_attention_transformations.hpp @@ -24,7 +24,7 @@ namespace utils { * @param allow_xattention If true, then the transformations will enable additional per-layer inputs to control the XAttention block-sparse * attention optimization. */ -void apply_paged_attention_transformations(std::shared_ptr model, bool per_layer_cache_control = false, bool allow_cache_rotation = false, bool allow_xattention = false); +void apply_paged_attention_transformations(std::shared_ptr model, bool per_layer_cache_control = false, bool allow_cache_rotation = false, bool allow_xattention = false, bool allow_adaptive_rkv = false); void apply_gather_before_matmul_transformation(std::shared_ptr model); diff --git a/src/cpp/src/continuous_batching/pipeline_impl.cpp b/src/cpp/src/continuous_batching/pipeline_impl.cpp index f29c3799ed..1439d1f9da 100644 --- a/src/cpp/src/continuous_batching/pipeline_impl.cpp +++ b/src/cpp/src/continuous_batching/pipeline_impl.cpp @@ -4,6 +4,7 @@ #include #include #include +#include "openvino/genai/cache_eviction.hpp" #ifdef __APPLE__ #include @@ -72,7 +73,8 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl( bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction; bool allow_cache_rotation = scheduler_config.cache_eviction_config.apply_rotation; bool allow_xattention = scheduler_config.use_sparse_attention && scheduler_config.sparse_attention_config.mode == SparseAttentionMode::XATTENTION; - utils::apply_paged_attention_transformations(model, is_need_per_layer_cache_control, 
allow_cache_rotation, allow_xattention); + bool allow_adaptive_rkv = scheduler_config.use_cache_eviction && scheduler_config.cache_eviction_config.aggregation_mode == AggregationMode::ADAPTIVE_RKV; + utils::apply_paged_attention_transformations(model, is_need_per_layer_cache_control, allow_cache_rotation, allow_xattention, allow_adaptive_rkv); utils::apply_gather_before_matmul_transformation(model); initialize_pipeline(model, scheduler_config, device, properties); @@ -182,6 +184,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline( m_scheduler = std::make_shared(m_block_size, cache_manager, normalized_config, m_num_decoder_layers, can_use_partial_preemption, eviction_config.snapkv_window_size); bool is_apply_rotation = eviction_config.apply_rotation; + bool is_use_adaptive_rkv = (eviction_config.aggregation_mode == AggregationMode::ADAPTIVE_RKV); m_model_runner = std::make_shared(infer_request, m_block_size, m_num_decoder_layers, @@ -189,7 +192,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline( /* is_use_per_layer_cache_control = */ true, /* is_use_rotation_inputs = */ is_apply_rotation, /* is_aggregate_attention_scores = */ true, - is_use_xattention); + is_use_xattention, + is_use_adaptive_rkv); if (eviction_config.apply_rotation) { _prepare_rotation_data_storage(normalized_config, cache_manager->get_v_head_size(0)); } @@ -201,7 +205,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline( /* is_use_per_layer_cache_control = */ false, /* is_use_rotation_inputs = */ false, /* is_aggregate_attention_scores = */ false, - is_use_xattention); + is_use_xattention, + /* is_use_adaptive_rkv = */ false); } m_sampler = std::make_shared(m_tokenizer, sampler_num_threads); @@ -327,10 +332,12 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() { m_pipeline_metrics.avg_cache_usage = _get_current_running_average_cache_usage(); const auto& sched_config = m_scheduler->get_config(); 
- if (sched_config.use_cache_eviction && sched_config.cache_eviction_config.apply_rotation) { - _compute_cache_rotation_data(m_requests, scheduler_output); - m_model_runner->set_cache_rotation_data(std::move(m_current_step_rotated_block_indices_per_sequence), - std::move(m_current_step_rotation_deltas)); + if (sched_config.use_cache_eviction) { + if (sched_config.cache_eviction_config.apply_rotation) { + _compute_cache_rotation_data(m_requests, scheduler_output); + m_model_runner->set_cache_rotation_data(std::move(m_current_step_rotated_block_indices_per_sequence), + std::move(m_current_step_rotation_deltas)); + } } } @@ -696,7 +703,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_compute_cache_rotation void ContinuousBatchingPipeline::ContinuousBatchingImpl::_maybe_evict_cache_blocks(const SchedulerConfig& sched_config, const Scheduler::Output& scheduler_output) { std::unordered_map seq_group_to_num_blocks_evicted_map; - auto sequence_attention_scores = m_model_runner->get_last_attention_scores(); + const auto& sequence_attention_scores = m_model_runner->get_last_attention_scores(); OPENVINO_ASSERT(!sequence_attention_scores.empty()); size_t num_decoder_layers = sequence_attention_scores.begin()->second.size(); @@ -735,6 +742,14 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_maybe_evict_cache_bloc continue; } + if (sched_config.cache_eviction_config.aggregation_mode == AggregationMode::ADAPTIVE_RKV) { + const auto& token_similarities = m_model_runner->get_last_token_similarities(); + auto it = token_similarities.find(seq_id); + if (it != token_similarities.end()) { + cache_eviction_algo.register_token_similarity(it->second); + } + } + m_previous_num_blocks_before_eviction_per_sequence[seq_id] = seq_group_ptr->get_num_logical_blocks(); auto logical_blocks_to_evict = cache_eviction_algo.evict_logical_blocks(); @@ -760,6 +775,12 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_maybe_evict_cache_bloc } } +void 
ContinuousBatchingPipeline::ContinuousBatchingImpl::_set_adaptive_rkv_diversity_blocks(const SchedulerConfig& sched_config, const Scheduler::Output& scheduler_output) { + // TODO(vshampor): implement +} + + + void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(std::vector& sequence_groups, ov::Tensor& logits) { const float * logits_data = logits.data(); ov::Shape logits_shape = logits.get_shape(); diff --git a/src/cpp/src/continuous_batching/pipeline_impl.hpp b/src/cpp/src/continuous_batching/pipeline_impl.hpp index 9e6cebdd99..4d5ffa604f 100644 --- a/src/cpp/src/continuous_batching/pipeline_impl.hpp +++ b/src/cpp/src/continuous_batching/pipeline_impl.hpp @@ -90,11 +90,13 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc */ void _maybe_evict_cache_blocks(const SchedulerConfig& sched_config, const Scheduler::Output& scheduler_output); + void _register_step_cache_usage(float step_cache_usage); void _reset_cache_usage_statistics(); float _get_current_running_average_cache_usage() const; void _compute_cache_rotation_data(const std::vector& sequence_groups, const Scheduler::Output& scheduler_output); void _prepare_rotation_data_storage(const SchedulerConfig& normalized_config, size_t embedding_size); + void _set_adaptive_rkv_diversity_blocks(const SchedulerConfig& sched_config, const Scheduler::Output& scheduler_output); virtual void drop_requests(); @@ -115,7 +117,7 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc const ov::AnyMap& properties, const ov::genai::GenerationConfig& generation_config, bool is_validation_mode_enabled = false); - + virtual ~ContinuousBatchingImpl(); GenerationHandle add_request(uint64_t request_id, diff --git a/src/cpp/src/continuous_batching/scheduler.hpp b/src/cpp/src/continuous_batching/scheduler.hpp index 3e565a9d5d..156bcdedfa 100644 --- a/src/cpp/src/continuous_batching/scheduler.hpp +++ b/src/cpp/src/continuous_batching/scheduler.hpp @@ 
-52,6 +52,9 @@ class Scheduler { size_t m_xattention_block_size = 0; size_t m_xattention_stride = 0; + size_t m_adaptive_rkv_start_size = 0; + // A value of 0 means that Adaptive R-KV similarity computation is not to be applied + std::map m_adaptive_rkv_evictable_sizes; // total number of scheduled tokens size_t m_total_num_scheduled_tokens = 0; @@ -334,6 +337,9 @@ class Scheduler { scheduler_output.m_xattention_thresholds[seq_id] = _schedule_xattention_threshold(sequence_group); scheduler_output.m_xattention_block_size = m_config.sparse_attention_config.xattention_block_size; scheduler_output.m_xattention_stride = m_config.sparse_attention_config.xattention_stride; + + scheduler_output.m_adaptive_rkv_start_size = m_config.cache_eviction_config.get_start_size(); + scheduler_output.m_adaptive_rkv_evictable_sizes[seq_id] = _schedule_adaptive_rkv_evictable_size(sequence_group); } } @@ -402,8 +408,13 @@ class Scheduler { scheduler_output.m_block_tables[seq_id] = m_block_manager->get_block_tables(seq_id); scheduler_output.m_score_aggregation_windows[seq_id] = _schedule_scores_to_aggregate(sequence_group); + + scheduler_output.m_xattention_thresholds[seq_id] = _schedule_xattention_threshold(sequence_group); scheduler_output.m_xattention_block_size = m_config.sparse_attention_config.xattention_block_size; scheduler_output.m_xattention_stride = m_config.sparse_attention_config.xattention_stride; + + scheduler_output.m_adaptive_rkv_start_size = m_config.cache_eviction_config.get_start_size(); + scheduler_output.m_adaptive_rkv_evictable_sizes[seq_id] = _schedule_adaptive_rkv_evictable_size(sequence_group); } @@ -476,8 +487,6 @@ class Scheduler { // add scheduling information { - Sequence::Ptr sequence = (*sequence_group)[0]; - uint64_t seq_id = sequence->get_id(); // and schedule tokens sequence_group->schedule_tokens(sequence_len); @@ -490,10 +499,15 @@ class Scheduler { uint64_t seq_id = sequence_group->get_running_sequences()[0]->get_id(); 
scheduler_output.m_block_tables[seq_id] = m_block_manager->get_block_tables(seq_id); scheduler_output.m_total_num_scheduled_tokens += sequence_len; + scheduler_output.m_score_aggregation_windows[seq_id] = _schedule_scores_to_aggregate(sequence_group); + scheduler_output.m_xattention_thresholds[seq_id] = _schedule_xattention_threshold(sequence_group); scheduler_output.m_xattention_block_size = m_config.sparse_attention_config.xattention_block_size; scheduler_output.m_xattention_stride = m_config.sparse_attention_config.xattention_stride; + + scheduler_output.m_adaptive_rkv_start_size = m_config.cache_eviction_config.get_start_size(); + scheduler_output.m_adaptive_rkv_evictable_sizes[seq_id] = _schedule_adaptive_rkv_evictable_size(sequence_group); } // update "is_prompt" flag @@ -587,6 +601,30 @@ class Scheduler { return m_config.sparse_attention_config.xattention_threshold; } + size_t _schedule_adaptive_rkv_evictable_size(SequenceGroup::Ptr sequence_group) { + if (!(m_config.use_cache_eviction && m_config.cache_eviction_config.aggregation_mode == AggregationMode::ADAPTIVE_RKV)) { + return 0; + } + if (!sequence_group->can_generate_tokens()) { + // Won't evict during prefill + return 0; + } + + // First similarity/diversity calculation will be scheduled when at least `max_cache_size` tokens are filled + if (sequence_group->get_num_processed_tokens() < m_config.cache_eviction_config.get_max_cache_size()) { + return 0; + } + + if (sequence_group->get_num_cached_tokens() % get_block_size() != 0) { + // Only request similarity computation once every block since eviction can only occur with a block granularity + return 0; + } + + size_t non_evictable_size = m_config.cache_eviction_config.get_max_cache_size() - m_config.cache_eviction_config.get_evictable_size(); + OPENVINO_ASSERT(sequence_group->get_num_logical_blocks() * get_block_size() >= non_evictable_size); + + return sequence_group->get_num_logical_blocks() * get_block_size() - non_evictable_size; + } }; } diff 
--git a/src/python/py_continuous_batching_pipeline.cpp b/src/python/py_continuous_batching_pipeline.cpp index 3a25667936..c4415adbb4 100644 --- a/src/python/py_continuous_batching_pipeline.cpp +++ b/src/python/py_continuous_batching_pipeline.cpp @@ -8,6 +8,7 @@ #include #include +#include "openvino/genai/cache_eviction.hpp" #include "openvino/genai/continuous_batching_pipeline.hpp" #include "openvino/genai/sparse_attention.hpp" #include "tokenizer/tokenizers_path.hpp" @@ -32,6 +33,7 @@ using ov::genai::SchedulerConfig; using ov::genai::PipelineMetrics; using ov::genai::KVCrushAnchorPointMode; using ov::genai::KVCrushConfig; +using ov::genai::AdaptiveRKVConfig; namespace { @@ -276,9 +278,12 @@ void init_continuous_batching_pipeline(py::module_& m) { py::enum_(m, "AggregationMode", R"(Represents the mode of per-token score aggregation when determining least important tokens for eviction from cache :param AggregationMode.SUM: In this mode the importance scores of each token will be summed after each step of generation - :param AggregationMode.NORM_SUM: Same as SUM, but the importance scores are additionally divided by the lifetime (in tokens generated) of a given token in cache)") + :param AggregationMode.NORM_SUM: Same as SUM, but the importance scores are additionally divided by the lifetime (in tokens generated) of a given token in cache + :param AggregationMode.ADAPTIVE_RKV: Switches the cache eviction algorithm to use Adaptive R-KV algorithm. The scores are aggregated within a configurable window + size of the latest generated tokens. 
May not be used together with the KVCrush algorithm (which is disabled automatically in this mode).)") .value("SUM", AggregationMode::SUM) - .value("NORM_SUM", AggregationMode::NORM_SUM); + .value("NORM_SUM", AggregationMode::NORM_SUM) + .value("ADAPTIVE_RKV", AggregationMode::ADAPTIVE_RKV); py::enum_(m, "KVCrushAnchorPointMode", R"(Represents the anchor point types for KVCrush cache eviction @@ -304,6 +309,13 @@ void init_continuous_batching_pipeline(py::module_& m) { .def_readwrite("anchor_point_mode", &KVCrushConfig::anchor_point_mode) .def_readwrite("rng_seed", &KVCrushConfig::rng_seed); + py::class_(m, "AdaptiveRKVConfig", "Configuration struct for the Adaptive R-KV cache eviction algorithm") + .def(py::init<>([](double attention_mass, size_t window_size) { return AdaptiveRKVConfig(attention_mass, window_size); }), + py::arg("attention_mass") = 0.9, + py::arg("window_size") = 8) + .def_readwrite("attention_mass", &AdaptiveRKVConfig::attention_mass) + .def_readwrite("window_size", &AdaptiveRKVConfig::window_size); + py::class_(m, "CacheEvictionConfig", cache_eviction_config_docstring) .def(py::init<>([](const size_t start_size, size_t recent_size, size_t max_cache_size, AggregationMode aggregation_mode, bool apply_rotation, size_t snapkv_window_size, py::object kvcrush_config) { @@ -328,7 +340,8 @@ void init_continuous_batching_pipeline(py::module_& m) { .def("get_start_size", &CacheEvictionConfig::get_start_size) .def("get_recent_size", &CacheEvictionConfig::get_recent_size) .def("get_max_cache_size", &CacheEvictionConfig::get_max_cache_size) - .def("get_evictable_size", &CacheEvictionConfig::get_evictable_size); + .def("get_evictable_size", &CacheEvictionConfig::get_evictable_size) + .def_readwrite("adaptive_rkv_config", &CacheEvictionConfig::adaptive_rkv_config); py::enum_(m, "SparseAttentionMode", R"(Represents the mode of sparse attention applied during generation. 
diff --git a/tests/cpp/cache_eviction.cpp b/tests/cpp/cache_eviction.cpp index 7c7f6209f5..56177118fd 100644 --- a/tests/cpp/cache_eviction.cpp +++ b/tests/cpp/cache_eviction.cpp @@ -376,11 +376,79 @@ TEST_P(EvictionScoreManagerRegisterScoresParameterizedTest, ScoresAndCountersAft } } + + + INSTANTIATE_TEST_SUITE_P(VariousInputs, EvictionScoreManagerRegisterScoresParameterizedTest, ::testing::ValuesIn(REGISTER_SCORES_TEST_CASES), [](const testing::TestParamInfo& info) { return info.param.test_id; }); + +struct EvictionScoreManagerAdaptiveRKVRegisterScoresTestStruct { + std::string test_id; + size_t block_size; + size_t max_pool_window_size; + size_t adaptive_rkv_window_size; + + std::vector>, std::set>> scores_and_skips; + std::vector> ref_scores; +}; + +using EvictionScoreManagerAdaptiveRKVRegisterScoresParameterizedTest = ::testing::TestWithParam; + +const std::vector ADAPTIVE_RKV_REGISTER_SCORES_TEST_CASES = { + // TODO(vshampor): fix + { "within_adaptive_rkv_window", + /* block_size =*/ 2, /* max_pool_window_size = */ 3, /* adaptive_rkv_window_size = */ 8, + { + { {{1.5, -0.8, 4.1, 7.7, 3.6, -7.4}, + {-0.9, 1.4, 6.4, -9.0, 8.1, 2.6}}, {} }, + { {{-7.4, 2.6, 8.9}, + {-3.1, -8.2, 5.9}}, {1, 2} } + }, + + { {2.05, 3.85, 3.85, 3.85, 4.45, 4.45, 4.45}, + {3.2, 3.2, 4.05, 4.05, 4.05, 2.95, 2.95} }, + }, + { "exceeding_adaptive_rkv_window", + /* block_size =*/ 2, /* max_pool_window_size = */ 3, /* adaptive_rkv_window_size = */ 3, + { + { {{ 1.5, -0.8, 4.1, 7.7, 3.6, -7.4}, + {-0.9, 1.4, 6.4, -9.0, 8.1, 2.6}}, {} }, + { {{-7.4, 2.6, 8.9}, + {-3.1, -8.2, 5.9}}, {1, 2} }, + { {{ 4.3, -4.1, -2.7, 8.3, -3.8, 4.9, 7.2, -6.2}, + {-2.2, 5.8, 7.0, 7.6, -9.8, -3.7, 1.4, -1.0 }}, {} }, + { {{ 9.8, -3.8, 1.0, -1.9, 6.2, 3.0, 0.7, -4.5, 6.7, -4.7}, + { 4.9, -7.6, 6.5, -6.7, 0.5, 6.7, 8.8, -7.5, 8.9, -0.5}}, {} } + }, + + { {2.233333, 2.133333, 2.1333333, 2.633333, 5.6, 5.6, 5.6, 2.233333, 2.233333, -1.566666}, + {4.5, 4.5, 4.5, 1.0, 5.366666, 5.3666666, 5.366666, 2.966666, 
2.966666, -0.1666666 } }, + }, +}; + +TEST_P(EvictionScoreManagerAdaptiveRKVRegisterScoresParameterizedTest, ScoresAfterRegistrationAreCorrect) { + const auto& test_struct = GetParam(); + ov::genai::EvictionScoreManager mgr(test_struct.block_size, DEFAULT_NUM_DECODER_LAYERS, test_struct.max_pool_window_size, ov::genai::AggregationMode::ADAPTIVE_RKV,/* ignore_first_n_blocks = */ 0, /* snapkv_window_size = */ 0, test_struct.adaptive_rkv_window_size); + for (const auto& score_and_skip : test_struct.scores_and_skips) { + mgr.register_new_token_scores(get_layer_scores_from_2d_vector(score_and_skip.first), score_and_skip.second); + } + const auto& test_scores = mgr.get_scores(); + ASSERT_EQ(test_scores.size(), DEFAULT_NUM_DECODER_LAYERS); + + float abs_tol = 1e-6; + for (size_t layer_idx = 0; layer_idx < DEFAULT_NUM_DECODER_LAYERS; layer_idx++) { + EXPECT_THAT(test_scores[layer_idx], ::testing::Pointwise(::testing::DoubleNear(abs_tol), test_struct.ref_scores[layer_idx])); + } +} + +INSTANTIATE_TEST_SUITE_P(VariousInputs, EvictionScoreManagerAdaptiveRKVRegisterScoresParameterizedTest, ::testing::ValuesIn(ADAPTIVE_RKV_REGISTER_SCORES_TEST_CASES), + [](const testing::TestParamInfo& info) { + return info.param.test_id; + }); + struct EvictionScoreManagerSnapKVCounterTestStruct { std::string test_id; size_t snapkv_window_size; @@ -910,6 +978,69 @@ INSTANTIATE_TEST_SUITE_P(VariousSetsOfLowScoreBlocks, CacheEvictionLowScoreBlock return info.param.test_id; }); +struct CacheEvictionAdaptiveRKVLowScoreAndSimilarityTestStruct { + std::string test_id; + size_t tokens_over_max_cache_size; + ov::genai::AdaptiveRKVConfig adaptive_rkv_config; + std::vector evictable_area_token_scores; + std::vector evictable_area_token_similarity; + std::set ref_evicted_blocks; +}; + +using CacheEvictionAdaptiveRKVLowScoreAndSimilarityParameterizedTest = ::testing::TestWithParam; + +// clang-format off +const std::vector ADAPTIVE_RKV_LOW_SCORE_AND_SIMILARITY_EVICTION_TEST_CASES = { + // Expecting 
`max_cache_size - start_area - recent_area equal` to 3 blocks, block size of 2 + // same, but with multiple blocks in evictable area + { + "three_blocks_overflow_one_hiscore_two_diverse_to_keep", + 2 * 2 + 1, // 2 blocks worth of overflow + 1 tokens, amounting to 3 blocks to be evicted + ov::genai::AdaptiveRKVConfig(0.9, 1), + {999.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, + {10.0, 11.0, 0.5, 0.1, 18.0, 19.0, 0.2, 0.4, 23.1, 24.2, 19.8, 18.7}, + {3, 5, 6} + }, + { + "two_blocks_overflow_two_hiscore_one_diverse_to_keep", + 2 * 2, // 2 blocks worth of overflow + ov::genai::AdaptiveRKVConfig(0.9, 1), + {0.1, 0.2, 0.3, 0.4, 0.1, 8.3, 11.0, 0.0, 0.0, 0.0}, + {2.1, 2.2, -3.5, 8.1, -19.0, -21.4, 8.2, 0.4, 1.2, 1.1}, + {1, 2} + } +}; +// clang-format on + +TEST_P(CacheEvictionAdaptiveRKVLowScoreAndSimilarityParameterizedTest, EvictsLowestScoredBlocksAndKeepsDiverse) { + auto test_struct = GetParam(); + size_t num_decoder_layers = DEFAULT_NUM_DECODER_LAYERS; + auto algo = ov::genai::CacheEvictionAlgorithm(ov::genai::CacheEvictionConfig(2, 2, 10, ov::genai::AggregationMode::ADAPTIVE_RKV, /* apply_rotation = */ false, /* snapkv_window_size = */ 0), /* block_size = */2, num_decoder_layers, /* max_pool_window_size = */ 1); + + auto scores = get_mock_scores(num_decoder_layers, algo.get_max_cache_size_after_eviction() + test_struct.tokens_over_max_cache_size); + for (size_t layer_idx = 0; layer_idx < num_decoder_layers; layer_idx++) { + auto& scores_per_layer = scores[layer_idx]; + fill_scores(scores_per_layer, 0, scores_per_layer.get_size(), 1.0); + for (size_t evictable_area_tok_idx = 0; evictable_area_tok_idx < test_struct.evictable_area_token_scores.size(); evictable_area_tok_idx++) { + scores_per_layer.data()[2 + evictable_area_tok_idx] = test_struct.evictable_area_token_scores[evictable_area_tok_idx]; + } + } + algo.register_new_token_scores(scores); + auto similarity = std::vector>(DEFAULT_NUM_DECODER_LAYERS, test_struct.evictable_area_token_similarity); + 
algo.register_token_similarity(get_layer_scores_from_2d_vector(similarity)); + + auto test_evicted_blocks = algo.evict_logical_blocks(); + auto ref_evicted_blocks = test_struct.ref_evicted_blocks; + for (size_t layer_idx = 0; layer_idx < num_decoder_layers; layer_idx++) { + EXPECT_EQ(test_evicted_blocks[layer_idx], ref_evicted_blocks); + } +} + +INSTANTIATE_TEST_SUITE_P(VariousSetsOfLowScoreAndDiverseBlocks, CacheEvictionAdaptiveRKVLowScoreAndSimilarityParameterizedTest, + ::testing::ValuesIn(ADAPTIVE_RKV_LOW_SCORE_AND_SIMILARITY_EVICTION_TEST_CASES), + [](const testing::TestParamInfo& info) { + return info.param.test_id; + }); static constexpr size_t BLOCKS_TO_EVICT = 3; // 3 blocks to evict struct NormalizationSettingTestStruct { @@ -1105,6 +1236,16 @@ TEST_P(CacheEvictionAlgoInitializationTest, ThrowsForInvalidConfigs) { INSTANTIATE_TEST_SUITE_P(VariousInvalidInitParams, CacheEvictionAlgoInitializationTest, ::testing::ValuesIn(INVALID_ALGO_INIT_PARAMS_CASES)); +TEST(CacheEvictionAlgoAdaptiveRKVTest, ThrowsIfEvictingWithoutSimilarityData) { + auto algo = ov::genai::CacheEvictionAlgorithm(ov::genai::CacheEvictionConfig(4, 4, 12, ov::genai::AggregationMode::ADAPTIVE_RKV, /* apply_rotation = */ false, /* snapkv_window_size = */ 0), DEFAULT_BLOCK_SIZE, DEFAULT_NUM_DECODER_LAYERS, DEFAULT_MAX_POOL_WINDOW_SIZE); + std::vector> mock_scores(2, std::vector(16, 0.0)); + std::vector> mock_similarity(2, std::vector(8, 0.0)); + algo.register_new_token_scores(get_layer_scores_from_2d_vector(mock_scores)); + EXPECT_THROW(algo.evict_logical_blocks(), ov::Exception); + algo.register_token_similarity(get_layer_scores_from_2d_vector(mock_similarity)); + EXPECT_NO_THROW(algo.evict_logical_blocks()); +} + TEST(CacheRotationCalculatorTest, CanInitializeWithBasicParams) { EXPECT_NO_THROW(ov::genai::CacheRotationCalculator(32, 128, 64)); }