Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions src/cpp/include/openvino/genai/cache_eviction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include <cstddef>
#include <set>

#include "openvino/core/except.hpp"

Expand All @@ -15,8 +16,11 @@ namespace ov::genai {
*/
enum class AggregationMode {
SUM, /**< In this mode the importance scores of each token will be summed after each step of generation */
NORM_SUM /**< Same as SUM, but the importance scores are additionally divided by the lifetime (in tokens generated)
NORM_SUM, /**< Same as SUM, but the importance scores are additionally divided by the lifetime (in tokens generated)
* of a given token in cache */
ADAPTIVE_RKV /** Switches the cache eviction algorithm to use Adaptive R-KV algorithm. The scores are aggregated within
a configurable window size of the latest generated tokens. May not be used together with the KVCrush
algorithm. */
};

/**
Expand Down Expand Up @@ -68,9 +72,15 @@ class KVCrushConfig {
}
};

/**
* @brief Configuration struct for the cache eviction algorithm.
*/
struct AdaptiveRKVConfig {
AdaptiveRKVConfig() = default;
AdaptiveRKVConfig(double attention_mass_, size_t window_size_) : attention_mass(attention_mass_), window_size(window_size_) {};

double attention_mass = 0.9;
size_t window_size = 8;
};


class CacheEvictionConfig {
public:
CacheEvictionConfig() = default;
Expand All @@ -81,14 +91,16 @@ class CacheEvictionConfig {
AggregationMode aggregation_mode_,
bool apply_rotation_ = false,
size_t snapkv_window_size_ = 8,
const KVCrushConfig& kvcrush_config_ = KVCrushConfig(0, KVCrushAnchorPointMode::RANDOM))
const KVCrushConfig& kvcrush_config_ = KVCrushConfig(0, KVCrushAnchorPointMode::RANDOM),
const AdaptiveRKVConfig& adaptive_rkv_config_ = AdaptiveRKVConfig())
: aggregation_mode(aggregation_mode_),
apply_rotation(apply_rotation_),
snapkv_window_size(snapkv_window_size_),
kvcrush_config(kvcrush_config_),
adaptive_rkv_config(adaptive_rkv_config_),
m_start_size(start_size),
m_recent_size(recent_size),
m_max_cache_size(max_cache_size),
kvcrush_config(kvcrush_config_) {
m_max_cache_size(max_cache_size) {
OPENVINO_ASSERT(start_size, "CacheEvictionConfig.start_size must be non-zero");
OPENVINO_ASSERT(recent_size, "CacheEvictionConfig.recent_size must be non-zero");
OPENVINO_ASSERT(max_cache_size, "CacheEvictionConfig.max_cache_size must be non-zero");
Expand Down Expand Up @@ -142,6 +154,8 @@ class CacheEvictionConfig {
* even if they are not among the most important ones.*/
KVCrushConfig kvcrush_config;

AdaptiveRKVConfig adaptive_rkv_config;

private:
/** Number of tokens in the *beginning* of KV cache that should be retained
* in the KV cache for this sequence during generation. Must be non-zero and a multiple of the KV cache block size for
Expand Down
5 changes: 5 additions & 0 deletions src/cpp/src/continuous_batching/attention_output.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,8 @@
using AttentionScoresForCacheOfSubsequence = ov::Tensor;
using AttentionScoresForEachDecoderLayer = std::vector<AttentionScoresForCacheOfSubsequence>;
using AttentionScoresForEachSubsequence = std::map<size_t, AttentionScoresForEachDecoderLayer>;


using TokenSimilarityForSubsequence = ov::Tensor;
using TokenSimilarityForEachDecoderLayer = std::vector<TokenSimilarityForSubsequence>;
using TokenSimilarityForEachSubsequence = std::map<size_t, TokenSimilarityForEachDecoderLayer>;
413 changes: 292 additions & 121 deletions src/cpp/src/continuous_batching/cache_eviction.cpp

Large diffs are not rendered by default.

28 changes: 27 additions & 1 deletion src/cpp/src/continuous_batching/cache_eviction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <vector>
#include <cstdlib>
#include <cmath>
#include <deque>

#include "openvino/openvino.hpp"
#include "continuous_batching/attention_output.hpp"
Expand All @@ -15,6 +16,7 @@

namespace ov::genai {


/**
* @brief Keeps track of the accumulated token scores across model inferences and their lifetime.
*/
Expand All @@ -39,8 +41,10 @@ class EvictionScoreManager {
* where `S` is equal to `snapkv_window_size`. In contrast, if this is set to 0, then the initial counter state would be
* `| L | L - 1 | ... | 2 | 1 |`,
* where L is the prompt size of the sequence in tokens.
* @param adaptive_rkv_window_size AggregationMode::ADAPTIVE_RKV only - Number of last token scores that will be aggregated (using mean)
* for purposes of determining blocks in the evictable area that comprise the most attention mass.
*/
explicit EvictionScoreManager(size_t block_size, size_t num_decoder_layers, size_t max_pool_window_size, AggregationMode aggregation_mode, size_t ignore_first_n_blocks = 0, size_t snapkv_window_size = 0) : m_block_size(block_size), m_num_decoder_layers(num_decoder_layers), m_scores(num_decoder_layers), m_cache_counter(num_decoder_layers), m_max_pool_window_size(max_pool_window_size), m_aggregation_mode(aggregation_mode), m_ignore_first_n_blocks(ignore_first_n_blocks), m_snapkv_window_size(snapkv_window_size), m_num_registered_snapkv_aggregated_scores(0) {}
explicit EvictionScoreManager(size_t block_size, size_t num_decoder_layers, size_t max_pool_window_size, AggregationMode aggregation_mode, size_t ignore_first_n_blocks = 0, size_t snapkv_window_size = 0, size_t adaptive_rkv_window_size = 8) : m_block_size(block_size), m_num_decoder_layers(num_decoder_layers), m_scores(num_decoder_layers), m_cache_counter(num_decoder_layers), m_max_pool_window_size(max_pool_window_size), m_aggregation_mode(aggregation_mode), m_ignore_first_n_blocks(ignore_first_n_blocks), m_snapkv_window_size(snapkv_window_size), m_num_registered_snapkv_aggregated_scores(0), m_adaptive_rkv_window_size(adaptive_rkv_window_size), m_previous_scores_queues(num_decoder_layers) {}

/**
* Registers new token scores and aggregates them internally as necessary. The token scores provided may be corresponding not to all
Expand Down Expand Up @@ -100,6 +104,22 @@ class EvictionScoreManager {
std::size_t m_ignore_first_n_blocks;
std::size_t m_snapkv_window_size;
std::size_t m_num_registered_snapkv_aggregated_scores;
size_t m_adaptive_rkv_window_size = 8;

struct EvictionScoreRecord {
EvictionScoreRecord(const std::vector<double>& score_, const std::set<size_t>& skips_) : score(score_), skips(skips_) {};
std::vector<double> score;
std::set<size_t> skips;
};

std::vector<std::deque<EvictionScoreRecord>> m_previous_scores_queues;

void _initialize_score_with_skips(std::vector<double>& dst, const std::vector<double>& src, const std::set<size_t> skipped_logical_block_ids);
void _accumulate_initial_scores(const std::vector<double>& max_pooled_hh_scores, size_t decoder_layer_idx, size_t num_snapkv_scores, const std::set<size_t>& skipped_logical_block_ids);

void _accumulate_layer_scores_to(size_t decoder_layer_idx, const std::vector<double>& src, const std::set<size_t>& skipped_logical_block_ids, std::vector<double>& dst);
void _accumulate_with_existing_scores(const std::vector<double>& max_pooled_hh_scores, size_t decoder_layer_idx, size_t num_snapkv_scores, const std::set<size_t>& skipped_logical_block_ids);
void _adjust_norm_sum_counters(size_t decoder_layer_idx, size_t old_size_in_tokens, size_t new_size_in_tokens);
};

class SnapKVScoreAggregationCalculator {
Expand Down Expand Up @@ -201,6 +221,8 @@ class CacheEvictionAlgorithm {
*/
std::vector<std::set<std::size_t>> evict_logical_blocks();

void register_token_similarity(const TokenSimilarityForEachDecoderLayer& token_similarity_for_all_decoder_layers);


private:
std::size_t get_num_blocks(std::size_t num_tokens) const;
Expand All @@ -221,6 +243,10 @@ class CacheEvictionAlgorithm {
std::size_t m_num_evicted_tokens = 0;
std::size_t m_num_decoder_layers;
EvictionScoreManager m_score_manager;

std::vector<std::vector<double>> m_last_token_similarity;
std::pair<std::set<size_t>, size_t> get_adaptive_rkv_similarity_set(size_t max_num_blocks_kept, const std::vector<double>& evictable_area_token_scores);
std::set<size_t> get_adaptive_rkv_diverse_blocks(size_t num_blocks_left_to_fill, const std::set<size_t>& similarity_set, const std::vector<double>& token_similarity);
};


Expand Down
108 changes: 106 additions & 2 deletions src/cpp/src/continuous_batching/model_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "continuous_batching/timer.hpp"

#include "continuous_batching/attention_output.hpp"
#include "continuous_batching/cache_eviction.hpp"

namespace ov::genai {

Expand All @@ -24,13 +25,20 @@ inline std::string get_paged_attention_score_output_for_decoder_layer(size_t dec
return ss.str();
}

inline std::string get_adaptive_rkv_similarity_score_output_for_decoder_layer(size_t decoder_layer_id) {
std::stringstream ss;
ss << "adaptive_rkv_similarity." << decoder_layer_id;
return ss.str();
}

/**
* @brief Runs the LLM infer request, parsing the continuous batching scheduler output into proper inputs in terms of OV API (e.g. token input IDs,
* KV cache block indices etc.) and returning the logit scores for the next token to be generated for each of the currently scheduled sequences.
*/
class ModelRunner {
ov::InferRequest m_request;
AttentionScoresForEachSubsequence m_last_attention_scores;
TokenSimilarityForEachSubsequence m_last_token_similarities;
size_t m_block_size;
size_t m_num_decoder_layers;
bool m_collect_attention_scores;
Expand All @@ -44,6 +52,9 @@ class ModelRunner {
bool m_is_aggregate_attention_scores;

bool m_is_use_xattention_inputs;

bool m_is_use_adaptive_rkv;

// A model to compute token embeddings.
// Input shape: [N, conversation length].
// Output shape: [1, conversation length, hidden_size].
Expand Down Expand Up @@ -84,7 +95,8 @@ class ModelRunner {
bool is_use_per_layer_cache_control = false,
bool is_use_rotation_inputs = false,
bool is_aggregate_attention_scores = false,
bool is_use_xattention_inputs = false)
bool is_use_xattention_inputs = false,
bool m_is_use_adaptive_rkv_inputs = false)
: m_request(std::move(request)),
m_block_size(block_size),
m_num_decoder_layers(num_decoder_layers),
Expand All @@ -93,7 +105,8 @@ class ModelRunner {
m_is_use_rotation_inputs(is_use_rotation_inputs),
m_rotated_block_logical_indices_per_sequence_for_each_layer(num_decoder_layers),
m_is_aggregate_attention_scores(is_aggregate_attention_scores),
m_is_use_xattention_inputs(is_use_xattention_inputs) {
m_is_use_xattention_inputs(is_use_xattention_inputs),
m_is_use_adaptive_rkv(m_is_use_adaptive_rkv_inputs) {
OPENVINO_ASSERT(m_num_decoder_layers != 0, "num_decoder_layers must be non-zero");
_reset_cache_rotation_coefficients();
}
Expand All @@ -118,6 +131,9 @@ class ModelRunner {
return m_last_attention_scores;
}

const TokenSimilarityForEachSubsequence& get_last_token_similarities() const {
return m_last_token_similarities;
}

void set_cache_rotation_trig_lut(ov::Tensor&& rotation_trig_lut) {
m_cache_rotation_trig_lut = std::move(rotation_trig_lut);
Expand Down Expand Up @@ -382,6 +398,10 @@ class ModelRunner {
_set_xattention_tensors(sequence_groups, scheduler_output, batch_size_in_sequences);
}

if (m_is_use_adaptive_rkv) {
_set_adaptive_rkv_tensors(sequence_groups, scheduler_output, batch_size_in_sequences);
}

if (matmul_gathering_is_available) {
// use pre-allocated tensor for gather_indices as well
ov::Tensor gather_indices = m_request.get_tensor("sampled_tokens_indices");
Expand All @@ -404,6 +424,10 @@ class ModelRunner {
_collect_attention_scores(sequence_groups, scheduler_output);
}

if (m_is_use_adaptive_rkv) {
_collect_token_similarities(sequence_groups, scheduler_output);
}

_reset_cache_rotation_coefficients();

// return logits
Expand Down Expand Up @@ -749,6 +773,51 @@ class ModelRunner {
}
}

void _collect_token_similarities(const std::vector<SequenceGroup::Ptr> & sequence_groups, const Scheduler::Output& scheduler_output) {
m_last_token_similarities.clear();
size_t num_sequence_groups = scheduler_output.m_scheduled_sequence_groups_ids.size();
using IndexSpan = std::pair<size_t, size_t>;
std::list<std::pair<size_t, IndexSpan>> running_seq_ids_and_kvcache_spans;
size_t offset = 0;
for (size_t i = 0; i < num_sequence_groups; ++i) {
size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id];
std::vector<Sequence::CPtr> running_sequences = sequence_group->get_running_sequences();

for (size_t seq_idx = 0; seq_idx < running_sequences.size(); ++seq_idx) {
Sequence::CPtr sequence = running_sequences[seq_idx];
size_t global_sequence_id = sequence->get_id();
auto it = scheduler_output.m_adaptive_rkv_evictable_sizes.find(global_sequence_id);
if (it == scheduler_output.m_adaptive_rkv_evictable_sizes.end()) {
// Adaptive R-KV similarity calculation was not scheduled for this sequence
continue;
}
size_t num_similarity_tokens_calculated = it->second * m_block_size;
// As we only evict during generation phase, so will the similarity calculation will only be
// scheduled after prefill is finished
OPENVINO_ASSERT(sequence_group->can_generate_tokens());

IndexSpan span = {offset, offset + num_similarity_tokens_calculated};
offset += num_similarity_tokens_calculated;
running_seq_ids_and_kvcache_spans.emplace_back(global_sequence_id, span);
}
}

for (const auto& seq_id_and_score_span : running_seq_ids_and_kvcache_spans) {
auto token_similarities_across_decoder_layers_for_current_sequence = TokenSimilarityForEachDecoderLayer(m_num_decoder_layers);
size_t global_sequence_id = seq_id_and_score_span.first;
IndexSpan span = seq_id_and_score_span.second;
for (size_t decoder_layer_id = 0; decoder_layer_id < m_num_decoder_layers; decoder_layer_id++) {
auto attention_score = m_request.get_tensor(get_adaptive_rkv_similarity_score_output_for_decoder_layer(decoder_layer_id));
auto scores_for_cache_of_current_sequence_group = ov::Tensor(attention_score, ov::Coordinate{span.first}, ov::Coordinate{span.second});
auto copied_tensor = ov::Tensor(scores_for_cache_of_current_sequence_group.get_element_type(), ov::Shape{span.second - span.first});
scores_for_cache_of_current_sequence_group.copy_to(copied_tensor);
token_similarities_across_decoder_layers_for_current_sequence[decoder_layer_id] = scores_for_cache_of_current_sequence_group;
}
m_last_token_similarities[global_sequence_id] = token_similarities_across_decoder_layers_for_current_sequence;
}
}

void _set_xattention_tensors(const std::vector<SequenceGroup::Ptr>& sequence_groups,
const Scheduler::Output& scheduler_output,
size_t batch_size_in_sequences) {
Expand Down Expand Up @@ -786,5 +855,40 @@ class ModelRunner {
}

}

void _set_adaptive_rkv_tensors(const std::vector<SequenceGroup::Ptr>& sequence_groups,
const Scheduler::Output& scheduler_output,
size_t batch_size_in_sequences) {
ov::Tensor adaptive_rkv_start_size(ov::element::i32, {});
adaptive_rkv_start_size.data<int32_t>()[0] = scheduler_output.m_adaptive_rkv_start_size;
m_request.set_tensor("adaptive_rkv_start_size", adaptive_rkv_start_size);

ov::Tensor adaptive_rkv_evictable_sizes(ov::element::i32, {batch_size_in_sequences});
float* adaptive_rkv_evictable_sizes_data = adaptive_rkv_evictable_sizes.data<float>();
for (size_t i = 0; i < scheduler_output.m_scheduled_sequence_groups_ids.size(); i++) {
size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id];
std::vector<Sequence::CPtr> running_sequences = sequence_group->get_running_sequences();
size_t num_running_sequences = running_sequences.size();
for (size_t k = 0; k < num_running_sequences; ++k) {
Sequence::CPtr sequence = running_sequences[k];
size_t seq_id = sequence->get_id();
size_t evictable_size = 0;

if (scheduler_output.m_adaptive_rkv_evictable_sizes.find(seq_id) != scheduler_output.m_adaptive_rkv_evictable_sizes.end()) {
evictable_size = scheduler_output.m_adaptive_rkv_evictable_sizes.at(seq_id);
}
*adaptive_rkv_evictable_sizes_data = evictable_size;
adaptive_rkv_evictable_sizes_data += 1;
}
}

m_request.set_tensor("adaptive_rkv_evictable_sizes", adaptive_rkv_evictable_sizes);

// Reserved for future use
ov::Tensor adaptive_rkv_diversity_block_set(ov::element::i32, {batch_size_in_sequences});
m_request.set_tensor("adaptive_rkv_diversity_block_set", adaptive_rkv_diversity_block_set);

}
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ namespace ov {
namespace genai {
namespace utils {

void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, bool per_layer_cache_control, bool allow_cache_rotation, bool allow_xattention) {
void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, bool per_layer_cache_control, bool allow_cache_rotation, bool allow_xattention, bool allow_adaptive_rkv) {
const ov::op::util::VariableVector& variables = model->get_variables();
OPENVINO_ASSERT(!variables.empty(), "Model is supposed to be stateful");

bool use_block_indices_inputs = per_layer_cache_control;
bool use_score_outputs = per_layer_cache_control;
ov::pass::SDPAToPagedAttention(use_block_indices_inputs, use_score_outputs, /* allow_score_aggregation = */ true, allow_cache_rotation, allow_xattention).run_on_model(model);
ov::pass::SDPAToPagedAttention(use_block_indices_inputs, use_score_outputs, /* allow_score_aggregation = */ true, allow_cache_rotation, allow_xattention/* FIXME (VSHAMPOR): allow_adaptive_rkv */).run_on_model(model);

std::map<std::string, std::shared_ptr<ov::op::v0::Parameter>> key_cache_params, value_cache_params;
for (const auto& param_ptr : model->get_parameters()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace utils {
* @param allow_xattention If true, then the transformations will enable additional per-layer inputs to control the XAttention block-sparse
* attention optimization.
*/
void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, bool per_layer_cache_control = false, bool allow_cache_rotation = false, bool allow_xattention = false);
void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, bool per_layer_cache_control = false, bool allow_cache_rotation = false, bool allow_xattention = false, bool allow_adaptive_rkv = false);

void apply_gather_before_matmul_transformation(std::shared_ptr<ov::Model> model);

Expand Down
Loading
Loading