From 91093e43dd04a19957d802339db9f0c606cb7e92 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Thu, 31 Jul 2025 03:12:04 +0100 Subject: [PATCH 01/40] Initial version --- src/cpp/src/llm/pipeline.cpp | 19 +- src/cpp/src/llm/pipeline_stateful_npu.cpp | 96 +++ src/cpp/src/llm/pipeline_stateful_npu.hpp | 53 ++ .../speculative_decoding_impl.hpp | 25 +- .../speculative_decoding_npu.cpp | 663 ++++++++++++++++++ .../speculative_decoding_npu.hpp | 113 +++ src/cpp/src/utils.hpp | 26 +- 7 files changed, 955 insertions(+), 40 deletions(-) create mode 100644 src/cpp/src/llm/pipeline_stateful_npu.cpp create mode 100644 src/cpp/src/llm/pipeline_stateful_npu.hpp create mode 100644 src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp create mode 100644 src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index 76d1fe24dc..e598ec743e 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -9,10 +9,10 @@ #include "openvino/genai/llm_pipeline.hpp" #include "openvino/genai/perf_metrics.hpp" -#include "llm/pipeline_static.hpp" #include "llm/pipeline_stateful.hpp" #include "llm/pipeline_continuous_batching_adapter.hpp" #include "speculative_decoding/speculative_decoding_impl.hpp" +#include "llm/pipeline_stateful_npu.hpp" #include "utils.hpp" namespace ov { @@ -85,9 +85,7 @@ ov::genai::LLMPipeline::LLMPipeline( auto [device_properties, scheduler_config] = utils::extract_scheduler_config(properties, utils::get_latency_oriented_scheduler_config()); m_pimpl = std::make_unique(models_path, tokenizer, scheduler_config, device, device_properties); } else if (device == "NPU") { - m_pimpl = properties.count("STATIC_PIPELINE") - ? static_llm::LLMPipelineFactory::create(models_path, tokenizer, properties) - : std::make_unique(models_path, tokenizer, device, properties); + m_pimpl = std::make_unique(models_path, tokenizer, properties); } else if (attention_backend == PA_BACKEND) { // try to call CB adapter one more time, but with safe guard to silent exception try { @@ -122,9 +120,7 @@ ov::genai::LLMPipeline::LLMPipeline( auto [device_properties, scheduler_config] = utils::extract_scheduler_config(properties, utils::get_latency_oriented_scheduler_config()); m_pimpl = std::make_unique(models_path, scheduler_config, device, device_properties); } else if (device == "NPU") { - m_pimpl = properties.count("STATIC_PIPELINE") - ? static_llm::LLMPipelineFactory::create(models_path, properties) - : std::make_unique(models_path, device, properties); + m_pimpl = std::make_unique(models_path, properties); } else if (attention_backend == PA_BACKEND) { // try to call CB adapter one more time, but with safe guard to silent exception try { @@ -163,16 +159,9 @@ ov::genai::LLMPipeline::LLMPipeline( m_pimpl = std::make_unique(model_str, weights_tensor, tokenizer, scheduler_config, device, device_properties, generation_config); } else if (device == "NPU") { - m_pimpl = properties.count("STATIC_PIPELINE") - ? 
static_llm::LLMPipelineFactory::create( - utils::singleton_core().read_model(model_str, weights_tensor), - tokenizer, - properties, - generation_config) - : std::make_unique( + m_pimpl = std::make_unique( utils::singleton_core().read_model(model_str, weights_tensor), tokenizer, - device, properties, generation_config); } else if (attention_backend == PA_BACKEND) { diff --git a/src/cpp/src/llm/pipeline_stateful_npu.cpp b/src/cpp/src/llm/pipeline_stateful_npu.cpp new file mode 100644 index 0000000000..54e34f213c --- /dev/null +++ b/src/cpp/src/llm/pipeline_stateful_npu.cpp @@ -0,0 +1,96 @@ + +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "pipeline_stateful_npu.hpp" +#include "speculative_decoding/speculative_decoding_npu.hpp" +#include "llm/pipeline_stateful.hpp" +#include "llm/pipeline_static.hpp" +#include "utils.hpp" + +#include + +#include "openvino/runtime/core.hpp" +#include "openvino/core/parallel.hpp" +#include "openvino/genai/text_streamer.hpp" + +namespace { + ov::genai::ModelDesc + extract_draft_model_from_config(ov::AnyMap& config) { + ov::genai::ModelDesc draft_model; + if (config.find(ov::genai::utils::DRAFT_MODEL_ARG_NAME) != config.end()) { + draft_model = config.at(ov::genai::utils::DRAFT_MODEL_ARG_NAME).as(); + config.erase(ov::genai::utils::DRAFT_MODEL_ARG_NAME); + } + return draft_model; +} +} // anonymous namespace + +namespace ov::genai { + +// NB: No constructor for creation of pipeline from infer request, as pipeline from infer request +// for NPU is handled inside of ov::genai::StatefulLLMPipeline class iself. +StatefulLLMPipelineNPU::StatefulLLMPipelineNPU( + const std::filesystem::path& models_path, + const ov::genai::Tokenizer& tokenizer, + const ov::AnyMap& properties) + : StatefulLLMPipelineNPU( + utils::read_model(models_path, properties), + tokenizer, + properties, + utils::from_config_json_if_exists(models_path) + ) {} + +StatefulLLMPipelineNPU::StatefulLLMPipelineNPU( + const std::filesystem::path& models_path, + const ov::AnyMap& plugin_config) + : StatefulLLMPipelineNPU{models_path, Tokenizer(models_path, plugin_config), plugin_config} {} + +StatefulLLMPipelineNPU::StatefulLLMPipelineNPU( + const std::shared_ptr& model, + const ov::genai::Tokenizer& tokenizer, + const ov::AnyMap& properties, + const ov::genai::GenerationConfig& generation_config) + : LLMPipelineImplBase(tokenizer, generation_config) { + auto properties_without_draft_model = properties; + auto draft_model_descr = extract_draft_model_from_config(properties_without_draft_model); + if (draft_model_descr.model != nullptr) { + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, "NPU", properties_without_draft_model, {}, generation_config); + m_pimpl = std::make_unique(main_model_descr, draft_model_descr); + } else if (properties_without_draft_model.count("STATIC_PIPELINE")) { + m_pimpl = static_llm::LLMPipelineFactory::create(model, tokenizer, + properties_without_draft_model, generation_config); + } else { + m_pimpl = std::make_unique(model, tokenizer, "NPU", + properties_without_draft_model, generation_config); + } +} + +DecodedResults StatefulLLMPipelineNPU::generate( + StringInputs inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer) { + return m_pimpl->generate(inputs, generation_config, streamer); +} + +EncodedResults StatefulLLMPipelineNPU::generate( + const EncodedInputs& inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer) { + return m_pimpl->generate(inputs, 
generation_config, streamer); +} + +void StatefulLLMPipelineNPU::start_chat(const std::string& system_message) { + m_pimpl->start_chat(system_message); +} + +// FIXME: Do we need it? +// void StatefulLLMPipelineNPU::reset_kv_state() { +// m_pimpl->reset_kv_state(); +// } + +void StatefulLLMPipelineNPU::finish_chat() { + m_pimpl->finish_chat(); +} + +} // namespace ov::genai diff --git a/src/cpp/src/llm/pipeline_stateful_npu.hpp b/src/cpp/src/llm/pipeline_stateful_npu.hpp new file mode 100644 index 0000000000..e14aa30065 --- /dev/null +++ b/src/cpp/src/llm/pipeline_stateful_npu.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + + +#include + +#include "llm/pipeline_base.hpp" + +namespace ov::genai { + +class StatefulLLMPipelineNPU final : public LLMPipelineImplBase { +public: + StatefulLLMPipelineNPU( + const std::filesystem::path& models_path, + const ov::genai::Tokenizer& tokenizer, + const ov::AnyMap& plugin_config + ); + + StatefulLLMPipelineNPU( + const std::filesystem::path& models_path, + const ov::AnyMap& plugin_config + ); + + StatefulLLMPipelineNPU( + const std::shared_ptr& model, + const ov::genai::Tokenizer& tokenizer, + const ov::AnyMap& config, + const ov::genai::GenerationConfig& generation_config + ); + + DecodedResults generate( + StringInputs inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer + ) override; + + EncodedResults generate( + const EncodedInputs& inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer + ) override; + + void start_chat(const std::string& system_message) override; + + void finish_chat() override; + + ~StatefulLLMPipelineNPU() = default; + +private: + std::unique_ptr m_pimpl; +}; + +} // namespace ov::genai diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index 026d592569..b8ecdb2b76 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -8,33 +8,10 @@ #include "speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp" #include "speculative_decoding/speculative_decoding_metrics.hpp" #include "openvino/genai/speculative_decoding/perf_metrics.hpp" +#include "utils.hpp" namespace ov::genai { -struct ModelDesc { - std::string device; - ov::genai::SchedulerConfig scheduler_config; - ov::AnyMap properties; - ov::genai::GenerationConfig generation_config; - std::shared_ptr model = nullptr; - ov::genai::Tokenizer tokenizer; - - ModelDesc(const std::shared_ptr& model, - const ov::genai::Tokenizer& tokenizer, - const std::string& device = {}, - const ov::AnyMap& properties = {}, - const ov::genai::SchedulerConfig& scheduler_config = {}, - const ov::genai::GenerationConfig& generation_config = {}) : - model(model), - tokenizer(tokenizer), - device(device), - properties(properties), - scheduler_config(scheduler_config), - generation_config(generation_config) {} - - ModelDesc() = default; -}; - class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline { protected: std::shared_ptr m_main_pipeline, m_draft_pipeline; diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp new file mode 100644 index 0000000000..cd1e7c1a69 --- /dev/null +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -0,0 
+1,663 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "speculative_decoding_npu.hpp" +#include "openvino/runtime/core.hpp" +#include "openvino/core/parallel.hpp" +#include "openvino/genai/text_streamer.hpp" + +namespace ov::genai { +template struct overloaded : Ts... {using Ts::operator()...;}; +template overloaded(Ts...) -> overloaded; + +bool are_tokenizers_equal(ov::genai::Tokenizer& lhs, ov::genai::Tokenizer& rhs); +} // ov::genai + +namespace { +ov::Tensor make_tensor_slice(ov::Tensor tensor, size_t dim, size_t start_pos, size_t end_pos) { + ov::Shape start_shape(std::vector(tensor.get_shape().size(), 0u)); + start_shape[dim] = start_pos; + ov::Shape end_shape = tensor.get_shape(); + end_shape[dim] = end_pos; + return ov::Tensor(tensor, start_shape, end_shape); +} + +void stream_generated_tokens(std::shared_ptr streamer_ptr, + ov::genai::GenerationHandle& handle) { + if (streamer_ptr && handle->can_read()) { + std::unordered_map token = handle->read(); + auto streaming_status = streamer_ptr->write(token.begin()->second.generated_ids); + if (streaming_status != ov::genai::StreamingStatus::RUNNING) { + streaming_status == ov::genai::StreamingStatus::CANCEL ? handle->cancel() : handle->stop(); + } + } +} +} // anonymous namespace + +namespace ov { +namespace genai { + LLMInferWrapper::LLMInferWrapper( + const ov::genai::ModelDesc& model_desc +) : m_properties(model_desc.properties), + m_generation_config(model_desc.generation_config), + m_tokenizer(model_desc.tokenizer) { + m_kv_pos = ov::genai::utils::get_kv_axes_pos(model_desc.model); + if (model_desc.device == "NPU") { + auto [compiled, kv_desc] = utils::compile_decoder_for_npu(model_desc.model, m_properties, m_kv_pos); + m_max_prompt_len = kv_desc.max_prompt_len; + m_kvcache_total = kv_desc.max_prompt_len + kv_desc.min_response_len; + m_request = compiled.create_infer_request(); + } else { + // TODO: We might need it for manipulations with indices + // utils::apply_gather_before_matmul_transformation(model_desc.model); + m_request = ov::genai::utils::singleton_core().compile_model(model_desc.model, model_desc.device, m_properties).create_infer_request(); + } + + m_sampler.set_tokenizer(m_tokenizer); + m_sampler.set_seed(model_desc.generation_config.rng_seed); +} + +ov::genai::GenerationConfig LLMInferWrapper::get_generation_config() const { + return m_generation_config; +} + +void LLMInferWrapper::set_generation_config(ov::genai::GenerationConfig config) { + m_generation_config = config; +} + +int64_t LLMInferWrapper::infer_first(const ov::Tensor &input_ids, + const ov::Tensor &attention_mask, + const ov::Tensor &position_ids) { + m_request.set_tensor("input_ids", input_ids); + m_request.set_tensor("attention_mask", attention_mask); + m_request.set_tensor("position_ids", position_ids); + // set beam_idx for stateful model: no beam search is used and BATCH_SIZE = 1 + m_request.get_tensor("beam_idx").set_shape({BATCH_SIZE}); + m_request.get_tensor("beam_idx").data()[0] = 0; + + m_request.infer(); + m_num_processed_tokens = input_ids.get_shape()[1]; + + // Need for tokens sampling and streaming: + m_sequence_group = std::make_shared( + 0 /* request_id */, input_ids, m_generation_config, 1 /* block_size */); + m_sequence_group->schedule_tokens(m_sequence_group->get_prompt_len()); + auto logits = get_logits(); + m_sequence_group->set_output_seq_len(logits.get_shape().at(1)); + + // Initialize placeholder data for next inferences on input_ids of size 1 (if any) + // with values of previous 
iteration for simple increment on next iteration: + m_new_input_token = -1; + m_new_position_id = m_num_processed_tokens - 1; + m_new_atten_mask_data = std::vector(m_num_processed_tokens, 1); + set_already_allocated_input_for_1_token(); + + return std::get(sample_tokens(logits, 1u)); +} + +bool LLMInferWrapper::can_infer() { + OPENVINO_ASSERT(m_sequence_group, "can_infer() can be called only after infer_first()!"); + return (m_sequence_group->is_running() && !m_sequence_group->handle_stopped() && !m_sequence_group->handle_cancelled()); +} + +int64_t LLMInferWrapper::infer_next(int64_t token) { + OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next() can be called only after infer_first()!"); + + // FIXME: Uncomment for static model and throw exception instead + // if (m_num_processed_tokens + tokens_size == m_kvcache_total) { + // m_sequence_group->set_out_of_memory(); + // return -1; + // } + + m_sequence_group->schedule_tokens(1u); + m_sequence_group->set_output_seq_len(1u); + + // Just change the variables here, as pointers to them are already set to corresponding tensors + m_new_input_token = token; + ++m_new_position_id; + // However, attention_mask changes its shape on each iteration, it should be re-set explicitly + m_new_atten_mask_data.push_back(1); + m_request.set_tensor("attention_mask", ov::Tensor(ov::element::i64, ov::Shape{1,m_new_atten_mask_data.size()}, (void*)&m_new_atten_mask_data[0])); + + m_request.infer(); + + m_num_processed_tokens += 1u; + + return std::get(sample_tokens(get_logits(), 1u)); +} + +int64_t LLMInferWrapper::infer_next(const std::vector tokens) { + OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next() can be called only after infer_first()!"); + + m_sequence_group->schedule_tokens(tokens.size()); + m_sequence_group->set_output_seq_len(1u); + auto logits = infer_next_internal(tokens); + return std::get(sample_tokens(logits, 1u)); +} + +std::vector LLMInferWrapper::infer_next_return_all(const std::vector tokens) { + OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next_return_all() can be called only after infer_first()!"); + + auto tokens_size = tokens.size(); + m_sequence_group->schedule_tokens(tokens_size); + m_sequence_group->set_output_seq_len(tokens_size + 1u); + auto logits = infer_next_internal(tokens); + return std::get>(sample_tokens(logits, tokens_size + 1u)); +} + +ov::Tensor LLMInferWrapper::get_logits() { + return m_request.get_tensor("logits"); +} + +std::size_t LLMInferWrapper::get_num_processed_tokens() const { + return m_num_processed_tokens; +} + +// NB: This method should be called only after infer_first()! 
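+// The returned handle wraps the sequence group's generation stream, so the caller
+// can poll it via can_read()/read() to stream tokens as they are sampled.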
+GenerationHandle LLMInferWrapper::create_generation_handle() { + OPENVINO_ASSERT(m_num_processed_tokens > 0, "create_generation_handle() can be called only after infer_first()!"); + // NB: Controls what tokens are ready to be pushed into the streamer + m_handle = std::make_shared( + m_sequence_group->get_generation_stream(), m_sequence_group->get_sampling_parameters()); + return m_handle; +} + +// TODO: Debug +void LLMInferWrapper::remove_last_generated_tokens(const size_t tokens_to_remove) { + OPENVINO_ASSERT(m_sequence_group, "remove_last_generated_tokens() can be called only after infer_first()!"); + + // Remove last generated tokens + const auto running_sequences = m_sequence_group->get_running_sequences(); + const auto sequence = running_sequences.front(); + OPENVINO_ASSERT(running_sequences.size() == 1u); + const auto generated_token_ids = sequence->get_generated_ids(); + const auto sequence_generated_len = generated_token_ids.size(); + // auto& logit_processor = m_sampler->get_logit_processor(0); + OPENVINO_ASSERT(sequence_generated_len >= tokens_to_remove); + + // size_t start_pos = sequence_generated_len - tokens_to_remove; + // TODO: We might need it, however, we sample each token, unlike in CB pipeline + // for (size_t i = start_pos; i < sequence_generated_len; ++i) { + // logit_processor.decrease_generated_token_occurance(generated_token_ids[i]); + // } + sequence->remove_last_tokens(tokens_to_remove); + + // FIXME: Should we do it or shouldn't? + // if (is_update_logit_processor) { + // logit_processor.update_generated_len(min_candidate_len); + // } +} + +void LLMInferWrapper::trimm_kv_cache(const size_t tokens_to_remove) { + // Trim kv_cache values on tokens_to_remove + ov::genai::utils::KVCacheState to_trim_state; + to_trim_state.num_tokens_to_trim = tokens_to_remove; + to_trim_state.seq_length_axis = m_kv_pos.seq_len; + to_trim_state.reset_mem_state = false; + ov::genai::utils::trim_kv_cache(m_request, to_trim_state, {}); + m_num_processed_tokens -= tokens_to_remove; +} + +ov::genai::EncodedResults LLMInferWrapper::finalize() { + OPENVINO_ASSERT(m_sequence_group, "finalize() can be called only after infer_first()!"); + + ov::genai::EncodedResults results; + // NB: Only batch=1 is supported now + results.scores.resize(1u); + results.scores[0] = 0u; + results.tokens.resize(1u); + + OPENVINO_ASSERT(m_sequence_group->get_finished_sequences().size() == 1, + "finalize() should be called when inference for current model is finished till EOS or other stop criteria."); + + auto sequence = m_sequence_group->get_finished_sequences().front(); + results.tokens[0] = sequence->get_generated_ids(); + results.scores[0] = sequence->get_cumulative_log_prob(); + + m_sampler.clear_request_info(m_sequence_group->get_request_id()); + + return results; +} + +ov::genai::GenerationStatus LLMInferWrapper::get_generation_status() const { + OPENVINO_ASSERT(m_sequence_group, "get_generation_status() can be called only after infer_first()!"); + + return m_sequence_group->get_generation_stream()->get_status(); +} + +void LLMInferWrapper::reset_state() { + return m_request.reset_state(); +} + +ov::Tensor LLMInferWrapper::infer_next_internal(const std::vector tokens) { + OPENVINO_ASSERT(m_num_processed_tokens > 0, "ov::Tensor infer_next() can be called only after infer_first()!"); + + size_t tokens_size = tokens.size(); + + // FIXME: Uncomment for static model and throw exception instead + // if (m_num_processed_tokens + tokens_size == m_kvcache_total) { + // m_sequence_group->set_out_of_memory(); + // 
return -1; + // } + + auto input_ids = m_request.get_tensor("input_ids"); + input_ids.set_shape({BATCH_SIZE, tokens_size}); + std::copy_n(tokens.begin(), tokens_size, input_ids.data()); + + // FIXME: For model with static shapes we can just copy after + // the prefilled tokens, no reshape is needed. + auto attention_mask = m_request.get_tensor("attention_mask"); + std::vector attention_mask_copy(attention_mask.data(), + attention_mask.data() + m_num_processed_tokens); + attention_mask.set_shape({BATCH_SIZE, m_num_processed_tokens + tokens_size}); + std::copy_n(attention_mask_copy.begin(), m_num_processed_tokens, attention_mask.data()); + std::fill_n(attention_mask.data() + m_num_processed_tokens, tokens_size, 1); + + auto position_ids = m_request.get_tensor("position_ids"); + position_ids.set_shape({BATCH_SIZE, tokens_size}); + std::iota(position_ids.data(), + position_ids.data() + position_ids.get_size(), + m_num_processed_tokens); + + m_request.get_tensor("beam_idx").set_shape({BATCH_SIZE}); + m_request.get_tensor("beam_idx").data()[0] = 0; + + m_request.infer(); + + m_num_processed_tokens += tokens_size; + + // Update pre-allocated inputs for 1 token and return back to use it + // in case if next infer will be called on input_ids of size 1 + // (most frequent case). + m_new_input_token = -1; + m_new_position_id = m_num_processed_tokens - 1; + for (std::size_t i = 0; i < tokens_size; ++i) { + m_new_atten_mask_data.push_back(1); + } + set_already_allocated_input_for_1_token(); + + return get_logits(); +} + +void LLMInferWrapper::set_already_allocated_input_for_1_token() { + m_request.set_tensor("input_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast(&m_new_input_token))); + m_request.set_tensor("position_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast(&m_new_position_id))); +} + +// FIXME: It is wrong way to sample tokens, or right because of set output_seq_len in the sequence? +// get_generated_ids will return all ids? +std::variant> + LLMInferWrapper::sample_tokens(const ov::Tensor& logits, std::size_t num_tokens_to_return) { + OPENVINO_ASSERT(m_sequence_group, "sample_tokens() can be called only after infer_first()!"); + + m_sampler.sample({m_sequence_group}, logits); + const auto running_sequences = m_sequence_group->get_running_sequences(); + OPENVINO_ASSERT(running_sequences.size() == 1u); + auto sampled_tokens = running_sequences.front()->get_generated_ids(); + if (num_tokens_to_return == 1) { + return sampled_tokens.back(); + } else { + // FIXME condition can be switched to boolean? + OPENVINO_ASSERT(num_tokens_to_return == sampled_tokens.size()); + return sampled_tokens; + } +} + +void SpeculativeConfig::update_candidate_strategy(const size_t num_matches) { + // Dynamically adjust number of generated candidates based on number of matches + // we want to balance the benefits of getting candidates tokens correct with the + // cost of forecasting incorrect candidates tokens. + if (num_matches == num_pred_tokens) { + num_pred_tokens = std::min(num_pred_tokens + 2, max_pred_tokens); + } else { + num_pred_tokens = std::max(int64_t(num_pred_tokens) - 1, int64_t(1)); + } +} + +SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( + const ov::genai::ModelDesc& main_model_desc, + const ov::genai::ModelDesc& draft_model_desc +) : LLMPipelineImplBase(main_model_desc.tokenizer, main_model_desc.generation_config) { + auto draft_model = draft_model_desc.model; + + // FIXME: slicing produces incorrect results for some models on NPU. 
+ // On NPU, applying slice the safe way is done by the underlying plugin + if (draft_model_desc.device != "NPU") { + utils::apply_slice_before_matmul_transformation(draft_model); + // As draft_model_desc contains std::shared_ptr, + // this model update will be reflected in draft_model_desc + } + + // TODO: We might need it for manipulations with indices + // utils::apply_gather_before_matmul_transformation(main_model); + // utils::apply_gather_before_matmul_transformation(draft_model); + + // Main and Draft model can have different tokenizers + // to do: support retokenization: 154103 + ov::genai::Tokenizer main_model_tokenizer = main_model_desc.tokenizer; + ov::genai::Tokenizer draft_model_tokenizer = draft_model_desc.tokenizer; + // todo: remove this condition after support of CVS-154103 + OPENVINO_ASSERT(are_tokenizers_equal(main_model_tokenizer, draft_model_tokenizer), "Tokenizers for draft and main models are different!"); + m_tokenizer = main_model_tokenizer; + OPENVINO_ASSERT(draft_model_desc.generation_config.rng_seed == main_model_desc.generation_config.rng_seed, "Seed for sampling must be equal for draft and main models!"); + + // Draft model (which is smaller, less accurate but faster) + auto draft_model_desc_copy = draft_model_desc; + if (draft_model_desc_copy.device.empty()) { + draft_model_desc_copy.device = main_model_desc.device; + } + if (draft_model_desc_copy.properties.empty()) { + draft_model_desc_copy.properties = main_model_desc.properties; + } + m_draft_request = std::make_unique(draft_model_desc_copy); + + // Main model (which is bigger, more accurate but slower) + // FIXME: Need to support full logits tensor as output for main model on NPU. + m_main_request = std::make_unique(main_model_desc); + + m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); + + // FIXME: Where to take it when draft model will be on NPU? + size_t max_sequence_length = main_model_desc.generation_config.get_max_new_tokens(); + if (max_sequence_length == SIZE_MAX) { + // FIXME: NPUW_LLM_MAX_PROMPT_LEN + NPUW_LLM_MIN_RESPONSE_LEN + max_sequence_length = 100; + } + // FIXME: ? Use main_model.generation_config.num_assistant_tokens; It should be > 0, if we want draft_model.generation_config.is_speculative_decoding() == true. + const std::size_t candidates_num = 5; + m_speculative_config.max_seq_length = max_sequence_length; + m_speculative_config.num_pred_tokens = candidates_num; +} + +DecodedResults SpeculativeLLMPipelineNPU::generate( + StringInputs inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer +) { + auto start_time = std::chrono::steady_clock::now(); + + std::string prompt = std::visit(overloaded{ + [](const std::string& prompt) { + return prompt; + }, + [](std::vector& prompts) { + OPENVINO_ASSERT(prompts.size() == 1u, "Currently only batch size=1 is supported"); + return prompts.front(); + } + }, inputs); + + const GenerationConfig& config = generation_config.has_value() ? 
*generation_config : m_generation_config; + + ov::genai::TokenizedInputs tokenized_input; + if (m_is_chat_conversation) { + m_history.push_back({{"role", "user"}, {"content", prompt}}); + constexpr bool add_generation_prompt = true; + prompt = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); + // for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF + tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false)); + } else { + if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) { + ChatHistory history({{{"role", "user"}, {"content", prompt}}}); + constexpr bool add_generation_prompt = true; + auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); + tokenized_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)); + } else { + // in case when chat_template was not found in tokenizer_config.json or set + tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(true)); + } + } + + auto encode_stop_time = std::chrono::steady_clock::now(); + auto encoded_results = generate(tokenized_input, config, streamer); + + auto decode_start_time = std::chrono::steady_clock::now(); + DecodedResults decoded_results = {m_tokenizer.decode(encoded_results.tokens), encoded_results.scores}; + auto decode_stop_time = std::chrono::steady_clock::now(); + + if (m_is_chat_conversation) { + auto answer = decoded_results.texts[0]; + if (m_chat_generation_finish_status == GenerationStatus::CANCEL) + // If chat generation process was cancelled by user, let's rollback to previous state of history + m_history.pop_back(); + else + m_history.push_back({{"role", "assistant"}, {"content", answer}}); + } + + // generate_durations + // decoded_results.perf_metrics = encoded_results.perf_metrics; + // auto& raw_counters = decoded_results.perf_metrics.raw_metrics; + // auto stop_time = std::chrono::steady_clock::now(); + // raw_counters.generate_durations.clear(); + // raw_counters.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); + // raw_counters.tokenization_durations.emplace_back(PerfMetrics::get_microsec(encode_stop_time - start_time)); + // raw_counters.detokenization_durations.emplace_back(PerfMetrics::get_microsec(decode_stop_time - decode_start_time)); + // decoded_results.perf_metrics.m_evaluated = false; + // decoded_results.perf_metrics.evaluate_statistics(start_time); + return decoded_results; +} + +EncodedResults SpeculativeLLMPipelineNPU::generate( + const EncodedInputs& inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer) { + // from step() + auto& raw_perf_counters = m_perf_metrics.raw_metrics; + auto& main_raw_perf_counters = m_perf_metrics.main_model_metrics.raw_metrics; + // + + auto start_time = std::chrono::steady_clock::now(); + + // from generate() + ManualTimer generate_timer("speculative_decoding: generate()"); + generate_timer.start(); + // + + ov::Tensor input_ids; + ov::Tensor attention_mask; + + if (auto data = std::get_if(&inputs)) { + input_ids = *data; + attention_mask = ov::genai::utils::init_attention_mask(input_ids); + } else if (auto data = std::get_if(&inputs)) { + input_ids = data->input_ids; + attention_mask = data->attention_mask; + } + + ov::Shape prompts_shape = input_ids.get_shape(); + const size_t batch_size = prompts_shape[0]; + OPENVINO_ASSERT(batch_size == 1u, "Currently only batch size=1 is supported"); + + GenerationConfig config = 
(generation_config.has_value()) ? *generation_config : m_generation_config; + // If stop_token_ids were not provided, take value from default m_generation_config + if (config.stop_token_ids.empty()) + config.stop_token_ids = m_generation_config.stop_token_ids; + // If eos_token_id was not provided, take value from default m_generation_config + if (config.eos_token_id == -1) + config.set_eos_token_id(m_generation_config.eos_token_id); + config.validate(); + // FIXME: Update conditionally: + m_main_request->set_generation_config(config); + + std::shared_ptr streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); + + OPENVINO_ASSERT(config.is_greedy_decoding() || config.is_multinomial(), + "Currently only greedy and multinomial decoding are supported"); + + OPENVINO_ASSERT(config.num_return_sequences == 1u, + "Currently only \"num_return_sequences\" equal to 1 is supported!"); + + // FIXME: Return back for the static draft model. + // NB: Check if there is enough space in KV-cache to process input prompt + auto prompt_len = prompts_shape[1]; + // if (prompt_len > m_max_prompt_len) { + // OPENVINO_THROW("Static Stateful LLM pipeline may only process prompts up to " + // + std::to_string(m_max_prompt_len) + " tokens. " + // + "Set the \"MAX_PROMPT_LEN\" config option to increase the limit."); + // } + + ov::Tensor position_ids{ov::element::i64, input_ids.get_shape()}; + utils::initialize_position_ids(position_ids, attention_mask); + + // To collect KV-cache for the prompt and to get the next token, run the very first infer request + // for draft and main models: + m_draft_request->infer_first(input_ids, attention_mask, position_ids); + auto out_token = m_main_request->infer_first(input_ids, attention_mask, position_ids); + + // logits shape is [BATCH_SIZE, seq_len, vocab_size] + auto draft_logits = m_draft_request->get_logits(); + auto main_logits = m_main_request->get_logits(); + size_t draft_vocab_size = draft_logits.get_shape().back(); + size_t main_vocab_size = main_logits.get_shape().back(); + OPENVINO_ASSERT(draft_vocab_size == main_vocab_size, + "Vocab sizes should be the same for the both: main and draft models!"); + + + // FIXME: Apply this logic carefully in LLMInferRequest of prefill model, + // if needed. + // FIXME: Here is workaround to get only useful units of returned logits. + // If SliceOut is applied, there will be only 1 useful logit returned, + // nothing is required here. + // Other way, model will return logits of full context length, + // as internally prefill model is specially reshaped to return them. + // Fix should be done on OpenVINO side, so the model should return only + // useful logits of input prompt length, dropping the implementation-related + // padding ones. 
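+    // The commented-out slice below sketches that fix: it would keep only the
+    // last prompt_len positions of the returned logits.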
+ // auto sequence_len = all_logits.get_shape()[1]; + // if (sequence_len > 1) { + // logits = make_tensor_slice(all_logits, 1, sequence_len - prompt_len, sequence_len); + // } + OPENVINO_ASSERT(draft_logits.get_shape().at(1) <= main_logits.get_shape().at(1), + "Num of generated useful logits from draft models should be less" + "or equal than ones from main model."); + + // Config draft model to not stop on EOS and remove stop strings: + ov::genai::GenerationConfig draft_config = m_draft_request->get_generation_config(); + draft_config.ignore_eos = true; + draft_config.stop_strings = {}; + draft_config.validate(); + m_draft_request->set_generation_config(draft_config); + + GenerationHandle m_main_gen_handle = m_main_request->create_generation_handle(); + stream_generated_tokens(streamer_ptr, m_main_gen_handle); + + /* Speculative decoding works the following way. The draft model predicts the next K + tokens one by one in an autoregressive manner, while the main model validates these + predictions and corrects them if necessary. We go through each predicted token, and + if a difference is detected between the draft and main model, we stop and keep the + last token predicted by the main model. Then the draft model gets the latest main + prediction and again tries to predict the next K tokens, repeating the cycle. + + This approach reduces the need for multiple infer requests to the main model, + enhancing performance. For instance, in more predictable parts of text generation, + the draft model can, in best-case scenarios, generate the next K tokens that exactly + match the target. In that case they are validated in a single inference call to + the main model instead of running K subsequent requests. + */ + // Last generated token by draft model needs to be prepended before next run if it is accepted by the main model! + // So it will get into context too. + int64_t draft_prefix_token = -1; + while (m_main_request->can_infer()) { + // Phase 1: Generation of candidates with the draft model: + std::vector candidates; + // Limit candidates size by num_pred_tokens or by max_seq_length: + // FIXME: draft_prefix_token isn't taken into account! + // FIXME: How max_seq_length will limit further generation of main model? + size_t candidates_to_generate = std::min(m_speculative_config.num_pred_tokens, + m_speculative_config.max_seq_length - m_draft_request->get_num_processed_tokens() - 1); + candidates.reserve(candidates_to_generate); + + // If draft_prefix_token is present, prepend it to out_token in order to collect KV cache for it + auto candidate = out_token; + if (draft_prefix_token != -1) { + std::vector tokens_to_infer = {draft_prefix_token, out_token}; + // TODO: Handle OOM exception for static model here. + candidate = m_draft_request->infer_next(tokens_to_infer); + candidates.push_back(candidate); + candidates_to_generate--; + } + for (size_t i = 0; i < candidates_to_generate; i++) { + // TODO: Handle OOM exception for static model here. + candidate = m_draft_request->infer_next(candidate); + candidates.push_back(candidate); + } + + // Phase 2. Main inference. + // For the main network, candidates_size + 1 tokens will be fed at once in a single infer request: + // last token from previous main inference + all candidates from the draft stage + // FIXME: How max_seq_length will be handled? + auto input_for_main = candidates; + input_for_main.insert(candidates.begin(), out_token); + // TODO: Handle OOM exception for static model here. 
+ auto ref_out_tokens = m_main_request->infer_next_return_all(input_for_main); + + // Phase 3. Check if main model produced the same tokens as input candidates: + size_t accepted_tokens_number = 0u; + // Last token is a new token from the main model, skip it: + for (size_t i = 0; i < ref_out_tokens.size() - 1; ++i) { + if (ref_out_tokens[i] != candidates[i]) { + break; + } + accepted_tokens_number++; + } + + auto mismatched_candidates = candidates.size() - accepted_tokens_number; + + // Phase 4: Update inference wrappers based on found matches and mismatches + // This is the case when main model accepted all candidates from draft model + // we need to collect kv cache for draft last generated token by infering it.n + if (mismatched_candidates == 0) { + draft_prefix_token = candidate; + } else { + m_draft_request->remove_last_generated_tokens(mismatched_candidates); + m_draft_request->trimm_kv_cache(mismatched_candidates - 1); + // Check that this works correctly for the model with output seq length != 1 + m_main_request->remove_last_generated_tokens(mismatched_candidates + 1); + m_main_request->trimm_kv_cache(mismatched_candidates); + } + + m_speculative_config.update_candidate_strategy(accepted_tokens_number); + // Should be enough, if all will be streamed from logits? + stream_generated_tokens(streamer_ptr, m_main_gen_handle); + + // raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now()); + // raw_perf_counters.m_batch_sizes.emplace_back(batch_size); + } + + if (streamer_ptr) { // push streamer's cache + streamer_ptr->end(); + } + + m_draft_request->reset_state(); + m_main_request->reset_state(); + + ov::genai::EncodedResults results = m_main_request->finalize(); + m_chat_generation_finish_status = m_main_request->get_generation_status(); + + // auto stop_time = std::chrono::steady_clock::now(); + // If is called without tokenization then that stat will not be reported. + // auto& metrics = results.perf_metrics; + // metrics.num_input_tokens = batch_size * input_ids.get_shape().at(1); + // metrics.load_time = this->m_load_time_ms; + // metrics.raw_metrics.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); + // metrics.evaluate_statistics(start_time); + return results; +} + +void SpeculativeLLMPipelineNPU::start_chat(const std::string& system_message) { + if (!system_message.empty()) { + m_history.push_back({{"role", "system"}, {"content", system_message}}); + } + m_is_chat_conversation = true; +}; + +void SpeculativeLLMPipelineNPU::finish_chat() { + m_is_chat_conversation = false; + m_history.clear(); +}; + +SpeculativeLLMPipelineNPU::~SpeculativeLLMPipelineNPU() { + // FIXME: Do we need it? 
+ // m_request.get_compiled_model().release_memory(); +} +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp new file mode 100644 index 0000000000..ce1286d8a9 --- /dev/null +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp @@ -0,0 +1,113 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include "llm/pipeline_base.hpp" +#include "sampling/sampler.hpp" +#include "utils.hpp" + +namespace ov { +namespace genai { +constexpr size_t BATCH_SIZE = 1; + +class LLMInferWrapper { +public: + LLMInferWrapper::LLMInferWrapper(const ov::genai::ModelDesc& model_desc); + ov::genai::GenerationConfig get_generation_config() const; + void set_generation_config(ov::genai::GenerationConfig config); + int64_t infer_first(const ov::Tensor &input_ids, + const ov::Tensor &attention_mask, + const ov::Tensor &position_ids); + bool can_infer(); + int64_t infer_next(const std::vector tokens); + int64_t infer_next(int64_t out_token); + std::vector infer_next_return_all(const std::vector tokens); + ov::Tensor get_logits(); + std::size_t get_num_processed_tokens() const; + ov::genai::GenerationHandle create_generation_handle(); + void remove_last_generated_tokens(const std::size_t tokens_to_remove); + void trimm_kv_cache(const std::size_t tokens_to_remove); + ov::genai::EncodedResults finalize(); + ov::genai::GenerationStatus get_generation_status() const; + void reset_state(); + +private: + ov::Tensor infer_next_internal(const std::vector tokens); + void set_already_allocated_input_for_1_token(); + std::variant> sample_tokens( + const ov::Tensor& logits, std::size_t num_tokens_to_return); + +private: + ov::AnyMap m_properties; + ov::genai::GenerationConfig m_generation_config; + ov::genai::Tokenizer m_tokenizer; + + std::size_t m_num_processed_tokens = 0u; + uint32_t m_max_prompt_len = 0u; + uint32_t m_kvcache_total = 0u; + ov::genai::utils::KVAxesPosition m_kv_pos; + ov::InferRequest m_request; + ov::genai::Sampler m_sampler; + std::shared_ptr m_sequence_group = nullptr; + GenerationHandle m_handle = nullptr; + // Separate metrics? 
+ + // Data placeholder for 1-token inference: + int64_t m_new_input_token = -1; + int64_t m_new_position_id = -1; + std::vector m_new_atten_mask_data; +}; + +struct SpeculativeConfig { + void update_candidate_strategy(const size_t num_matches); + + std::size_t max_seq_length = SIZE_MAX; + std::size_t num_pred_tokens = 5; + const std::size_t max_pred_tokens = 10; +}; + +class SpeculativeLLMPipelineNPU : public ov::genai::LLMPipelineImplBase { +public: + SpeculativeLLMPipelineNPU( + const ov::genai::ModelDesc& main_model_desc, + const ov::genai::ModelDesc& draft_model_desc + ); + + DecodedResults generate( + StringInputs inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer + ) override; + + EncodedResults generate( + const EncodedInputs& inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer + ) override; + + void start_chat(const std::string& system_message) override; + void finish_chat() override; + ~SpeculativeLLMPipelineNPU(); + +private: + int64_t generate_next_token(const std::vector tokens); + std::vector generate_candidates(int64_t out_token); + void update_candidate_strategy(const size_t num_matches); + void update_kv_cache(const size_t seq_length); + +private: + uint32_t m_max_prompt_len = 0u; + uint32_t m_kvcache_total = 0u; + std::unique_ptr m_draft_request; + std::unique_ptr m_main_request; + SpeculativeConfig m_speculative_config; + ov::genai::SDPerModelsPerfMetrics m_perf_metrics; + + bool m_is_chat_conversation = false; + ChatHistory m_history; + ov::genai::GenerationStatus m_chat_generation_finish_status = ov::genai::GenerationStatus::RUNNING; +}; + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index 8cd5fc4e6d..9b87482df1 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -12,9 +12,10 @@ #include "openvino/runtime/core.hpp" #include "openvino/genai/generation_handle.hpp" +#include "openvino/genai/scheduler_config.hpp" +#include "openvino/genai/generation_config.hpp" #include "visual_language/processor_config.hpp" -#include "openvino/genai/generation_handle.hpp" #include "openvino/genai/streamer_base.hpp" namespace ov { @@ -23,6 +24,29 @@ namespace genai { extern const std::string PA_BACKEND; extern const std::string SDPA_BACKEND; +struct ModelDesc { + std::string device; + ov::genai::SchedulerConfig scheduler_config; + ov::AnyMap properties; + ov::genai::GenerationConfig generation_config; + std::shared_ptr model = nullptr; + ov::genai::Tokenizer tokenizer; + + ModelDesc(const std::shared_ptr& model, + const ov::genai::Tokenizer& tokenizer, + const std::string& device = {}, + const ov::AnyMap& properties = {}, + const ov::genai::SchedulerConfig& scheduler_config = {}, + const ov::genai::GenerationConfig& generation_config = {}) : + model(model), + tokenizer(tokenizer), + device(device), + properties(properties), + scheduler_config(scheduler_config), + generation_config(generation_config) {} + + ModelDesc() = default; +}; } // namespace genai } // namespace ov From 4aca9c3cf73a68b0eaf448895eccba52b8d791bc Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Wed, 6 Aug 2025 11:48:45 +0100 Subject: [PATCH 02/40] Fixes to make pipe functional --- .../speculative_decoding_npu.cpp | 31 +++++++++---------- .../speculative_decoding_npu.hpp | 26 ++++++++++++---- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index 
cd1e7c1a69..349c346e0d 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -241,23 +241,24 @@ ov::Tensor LLMInferWrapper::infer_next_internal(const std::vector token // } auto input_ids = m_request.get_tensor("input_ids"); - input_ids.set_shape({BATCH_SIZE, tokens_size}); - std::copy_n(tokens.begin(), tokens_size, input_ids.data()); + ov::Tensor new_input_ids(input_ids.get_element_type(), ov::Shape{BATCH_SIZE, tokens_size}); + std::copy_n(tokens.begin(), tokens_size, new_input_ids.data()); + m_request.set_tensor("input_ids", new_input_ids); // FIXME: For model with static shapes we can just copy after // the prefilled tokens, no reshape is needed. auto attention_mask = m_request.get_tensor("attention_mask"); - std::vector attention_mask_copy(attention_mask.data(), - attention_mask.data() + m_num_processed_tokens); - attention_mask.set_shape({BATCH_SIZE, m_num_processed_tokens + tokens_size}); - std::copy_n(attention_mask_copy.begin(), m_num_processed_tokens, attention_mask.data()); - std::fill_n(attention_mask.data() + m_num_processed_tokens, tokens_size, 1); + ov::Tensor new_attention_mask(attention_mask.get_element_type(), ov::Shape{BATCH_SIZE, m_num_processed_tokens + tokens_size}); + std::copy_n(attention_mask.data(), m_num_processed_tokens, new_attention_mask.data()); + std::fill_n(new_attention_mask.data() + m_num_processed_tokens, tokens_size, 1); + m_request.set_tensor("attention_mask", new_attention_mask); auto position_ids = m_request.get_tensor("position_ids"); - position_ids.set_shape({BATCH_SIZE, tokens_size}); - std::iota(position_ids.data(), - position_ids.data() + position_ids.get_size(), + ov::Tensor new_position_ids(position_ids.get_element_type(), ov::Shape{BATCH_SIZE, tokens_size}); + std::iota(new_position_ids.data(), + new_position_ids.data() + new_position_ids.get_size(), m_num_processed_tokens); + m_request.set_tensor("position_ids", new_position_ids); m_request.get_tensor("beam_idx").set_shape({BATCH_SIZE}); m_request.get_tensor("beam_idx").data()[0] = 0; @@ -284,8 +285,7 @@ void LLMInferWrapper::set_already_allocated_input_for_1_token() { m_request.set_tensor("position_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast(&m_new_position_id))); } -// FIXME: It is wrong way to sample tokens, or right because of set output_seq_len in the sequence? -// get_generated_ids will return all ids? +// FIXME: Need to use Sampler correctly. Sampler does all the validation itself! Just needs to configure it correctly. std::variant> LLMInferWrapper::sample_tokens(const ov::Tensor& logits, std::size_t num_tokens_to_return) { OPENVINO_ASSERT(m_sequence_group, "sample_tokens() can be called only after infer_first()!"); @@ -298,7 +298,6 @@ std::variant> return sampled_tokens.back(); } else { // FIXME condition can be switched to boolean? 
- OPENVINO_ASSERT(num_tokens_to_return == sampled_tokens.size()); return sampled_tokens; } } @@ -500,8 +499,8 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // To collect KV-cache for the prompt and to get the next token, run the very first infer request // for draft and main models: - m_draft_request->infer_first(input_ids, attention_mask, position_ids); auto out_token = m_main_request->infer_first(input_ids, attention_mask, position_ids); + m_draft_request->infer_first(input_ids, attention_mask, position_ids); // logits shape is [BATCH_SIZE, seq_len, vocab_size] auto draft_logits = m_draft_request->get_logits(); @@ -585,8 +584,8 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // For the main network, candidates_size + 1 tokens will be fed at once in a single infer request: // last token from previous main inference + all candidates from the draft stage // FIXME: How max_seq_length will be handled? - auto input_for_main = candidates; - input_for_main.insert(candidates.begin(), out_token); + std::vector input_for_main(candidates.begin(), candidates.end()); + input_for_main.insert(input_for_main.begin(), {out_token}); // TODO: Handle OOM exception for static model here. auto ref_out_tokens = m_main_request->infer_next_return_all(input_for_main); diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp index ce1286d8a9..93153b0a1d 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp @@ -14,27 +14,44 @@ constexpr size_t BATCH_SIZE = 1; class LLMInferWrapper { public: LLMInferWrapper::LLMInferWrapper(const ov::genai::ModelDesc& model_desc); + ov::genai::GenerationConfig get_generation_config() const; + void set_generation_config(ov::genai::GenerationConfig config); + int64_t infer_first(const ov::Tensor &input_ids, const ov::Tensor &attention_mask, const ov::Tensor &position_ids); + bool can_infer(); + int64_t infer_next(const std::vector tokens); + int64_t infer_next(int64_t out_token); + std::vector infer_next_return_all(const std::vector tokens); + ov::Tensor get_logits(); + std::size_t get_num_processed_tokens() const; + ov::genai::GenerationHandle create_generation_handle(); + void remove_last_generated_tokens(const std::size_t tokens_to_remove); + void trimm_kv_cache(const std::size_t tokens_to_remove); + ov::genai::EncodedResults finalize(); + ov::genai::GenerationStatus get_generation_status() const; + void reset_state(); private: ov::Tensor infer_next_internal(const std::vector tokens); + void set_already_allocated_input_for_1_token(); + std::variant> sample_tokens( const ov::Tensor& logits, std::size_t num_tokens_to_return); @@ -59,6 +76,7 @@ class LLMInferWrapper { std::vector m_new_atten_mask_data; }; +// FIXME: Do we need this? 
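+// It only holds the adaptive candidate-window state used by update_candidate_strategy():
+// num_pred_tokens grows by 2 after a fully accepted batch and shrinks by 1 otherwise,
+// staying within [1, max_pred_tokens].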
struct SpeculativeConfig { void update_candidate_strategy(const size_t num_matches); @@ -87,14 +105,10 @@ class SpeculativeLLMPipelineNPU : public ov::genai::LLMPipelineImplBase { ) override; void start_chat(const std::string& system_message) override; + void finish_chat() override; - ~SpeculativeLLMPipelineNPU(); -private: - int64_t generate_next_token(const std::vector tokens); - std::vector generate_candidates(int64_t out_token); - void update_candidate_strategy(const size_t num_matches); - void update_kv_cache(const size_t seq_length); + ~SpeculativeLLMPipelineNPU(); private: uint32_t m_max_prompt_len = 0u; From 50f1fd3a3da9b4cb83ffa267add12a4244d2d872 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Wed, 6 Aug 2025 21:25:39 +0100 Subject: [PATCH 03/40] Removed sampler, using greedy decoding for initial implementation --- .../speculative_decoding_npu.cpp | 224 ++++++++---------- .../speculative_decoding_npu.hpp | 9 +- 2 files changed, 98 insertions(+), 135 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index 349c346e0d..4e4afbba48 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -22,14 +22,11 @@ ov::Tensor make_tensor_slice(ov::Tensor tensor, size_t dim, size_t start_pos, si return ov::Tensor(tensor, start_shape, end_shape); } -void stream_generated_tokens(std::shared_ptr streamer_ptr, - ov::genai::GenerationHandle& handle) { - if (streamer_ptr && handle->can_read()) { - std::unordered_map token = handle->read(); - auto streaming_status = streamer_ptr->write(token.begin()->second.generated_ids); - if (streaming_status != ov::genai::StreamingStatus::RUNNING) { - streaming_status == ov::genai::StreamingStatus::CANCEL ? 
handle->cancel() : handle->stop(); - } + +ov::genai::StreamingStatus stream_generated_tokens(std::shared_ptr streamer_ptr, + const std::vector& tokens) { + if (streamer_ptr) { + return streamer_ptr->write(tokens); } } } // anonymous namespace @@ -52,9 +49,6 @@ namespace genai { // utils::apply_gather_before_matmul_transformation(model_desc.model); m_request = ov::genai::utils::singleton_core().compile_model(model_desc.model, model_desc.device, m_properties).create_infer_request(); } - - m_sampler.set_tokenizer(m_tokenizer); - m_sampler.set_seed(model_desc.generation_config.rng_seed); } ov::genai::GenerationConfig LLMInferWrapper::get_generation_config() const { @@ -77,13 +71,7 @@ int64_t LLMInferWrapper::infer_first(const ov::Tensor &input_ids, m_request.infer(); m_num_processed_tokens = input_ids.get_shape()[1]; - - // Need for tokens sampling and streaming: - m_sequence_group = std::make_shared( - 0 /* request_id */, input_ids, m_generation_config, 1 /* block_size */); - m_sequence_group->schedule_tokens(m_sequence_group->get_prompt_len()); - auto logits = get_logits(); - m_sequence_group->set_output_seq_len(logits.get_shape().at(1)); + m_first_prompt_len = m_num_processed_tokens; // Initialize placeholder data for next inferences on input_ids of size 1 (if any) // with values of previous iteration for simple increment on next iteration: @@ -92,12 +80,26 @@ int64_t LLMInferWrapper::infer_first(const ov::Tensor &input_ids, m_new_atten_mask_data = std::vector(m_num_processed_tokens, 1); set_already_allocated_input_for_1_token(); - return std::get(sample_tokens(logits, 1u)); + last_token = std::get(sample_tokens(get_logits(), 1u)); + return last_token; } bool LLMInferWrapper::can_infer() { - OPENVINO_ASSERT(m_sequence_group, "can_infer() can be called only after infer_first()!"); - return (m_sequence_group->is_running() && !m_sequence_group->handle_stopped() && !m_sequence_group->handle_cancelled()); + OPENVINO_ASSERT(m_num_processed_tokens > 0, "can_infer() can be called only after infer_first()!"); + + // FIXME: Add condition to get out of KV-Cache length for static models. 
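+    // Stop criteria checked below: EOS (unless ignore_eos is set), any configured
+    // stop_token_ids, and the max_new_tokens budget counted from the first prompt.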
+ auto stop_token_ids = m_generation_config.stop_token_ids; + if (!m_generation_config.ignore_eos && (last_token == m_generation_config.eos_token_id)) { + return false; + } + if (std::find(stop_token_ids.begin(), stop_token_ids.end(), last_token) != stop_token_ids.end()) { + return false; + } + if (m_num_processed_tokens - m_first_prompt_len + 1 >= m_generation_config.get_max_new_tokens()) { + return false; + } + + return true; } int64_t LLMInferWrapper::infer_next(int64_t token) { @@ -109,9 +111,6 @@ int64_t LLMInferWrapper::infer_next(int64_t token) { // return -1; // } - m_sequence_group->schedule_tokens(1u); - m_sequence_group->set_output_seq_len(1u); - // Just change the variables here, as pointers to them are already set to corresponding tensors m_new_input_token = token; ++m_new_position_id; @@ -123,26 +122,26 @@ int64_t LLMInferWrapper::infer_next(int64_t token) { m_num_processed_tokens += 1u; - return std::get(sample_tokens(get_logits(), 1u)); + last_token = std::get(sample_tokens(get_logits(), 1u)); + return last_token; } int64_t LLMInferWrapper::infer_next(const std::vector tokens) { OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next() can be called only after infer_first()!"); - m_sequence_group->schedule_tokens(tokens.size()); - m_sequence_group->set_output_seq_len(1u); auto logits = infer_next_internal(tokens); - return std::get(sample_tokens(logits, 1u)); + last_token = std::get(sample_tokens(logits, 1u)); + return last_token; } std::vector LLMInferWrapper::infer_next_return_all(const std::vector tokens) { OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next_return_all() can be called only after infer_first()!"); - auto tokens_size = tokens.size(); - m_sequence_group->schedule_tokens(tokens_size); - m_sequence_group->set_output_seq_len(tokens_size + 1u); auto logits = infer_next_internal(tokens); - return std::get>(sample_tokens(logits, tokens_size + 1u)); + auto tokens_size = tokens.size(); + auto sampled_tokens = std::get>(sample_tokens(logits, tokens_size + 1u)); + last_token = sampled_tokens[tokens_size]; + return sampled_tokens; } ov::Tensor LLMInferWrapper::get_logits() { @@ -153,41 +152,6 @@ std::size_t LLMInferWrapper::get_num_processed_tokens() const { return m_num_processed_tokens; } -// NB: This method should be called only after infer_first()! 
-GenerationHandle LLMInferWrapper::create_generation_handle() { - OPENVINO_ASSERT(m_num_processed_tokens > 0, "create_generation_handle() can be called only after infer_first()!"); - // NB: Controls what tokens are ready to be pushed into the streamer - m_handle = std::make_shared( - m_sequence_group->get_generation_stream(), m_sequence_group->get_sampling_parameters()); - return m_handle; -} - -// TODO: Debug -void LLMInferWrapper::remove_last_generated_tokens(const size_t tokens_to_remove) { - OPENVINO_ASSERT(m_sequence_group, "remove_last_generated_tokens() can be called only after infer_first()!"); - - // Remove last generated tokens - const auto running_sequences = m_sequence_group->get_running_sequences(); - const auto sequence = running_sequences.front(); - OPENVINO_ASSERT(running_sequences.size() == 1u); - const auto generated_token_ids = sequence->get_generated_ids(); - const auto sequence_generated_len = generated_token_ids.size(); - // auto& logit_processor = m_sampler->get_logit_processor(0); - OPENVINO_ASSERT(sequence_generated_len >= tokens_to_remove); - - // size_t start_pos = sequence_generated_len - tokens_to_remove; - // TODO: We might need it, however, we sample each token, unlike in CB pipeline - // for (size_t i = start_pos; i < sequence_generated_len; ++i) { - // logit_processor.decrease_generated_token_occurance(generated_token_ids[i]); - // } - sequence->remove_last_tokens(tokens_to_remove); - - // FIXME: Should we do it or shouldn't? - // if (is_update_logit_processor) { - // logit_processor.update_generated_len(min_candidate_len); - // } -} - void LLMInferWrapper::trimm_kv_cache(const size_t tokens_to_remove) { // Trim kv_cache values on tokens_to_remove ov::genai::utils::KVCacheState to_trim_state; @@ -198,39 +162,12 @@ void LLMInferWrapper::trimm_kv_cache(const size_t tokens_to_remove) { m_num_processed_tokens -= tokens_to_remove; } -ov::genai::EncodedResults LLMInferWrapper::finalize() { - OPENVINO_ASSERT(m_sequence_group, "finalize() can be called only after infer_first()!"); - - ov::genai::EncodedResults results; - // NB: Only batch=1 is supported now - results.scores.resize(1u); - results.scores[0] = 0u; - results.tokens.resize(1u); - - OPENVINO_ASSERT(m_sequence_group->get_finished_sequences().size() == 1, - "finalize() should be called when inference for current model is finished till EOS or other stop criteria."); - - auto sequence = m_sequence_group->get_finished_sequences().front(); - results.tokens[0] = sequence->get_generated_ids(); - results.scores[0] = sequence->get_cumulative_log_prob(); - - m_sampler.clear_request_info(m_sequence_group->get_request_id()); - - return results; -} - -ov::genai::GenerationStatus LLMInferWrapper::get_generation_status() const { - OPENVINO_ASSERT(m_sequence_group, "get_generation_status() can be called only after infer_first()!"); - - return m_sequence_group->get_generation_stream()->get_status(); -} - void LLMInferWrapper::reset_state() { return m_request.reset_state(); } ov::Tensor LLMInferWrapper::infer_next_internal(const std::vector tokens) { - OPENVINO_ASSERT(m_num_processed_tokens > 0, "ov::Tensor infer_next() can be called only after infer_first()!"); + OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next_internal() can be called only after infer_first()!"); size_t tokens_size = tokens.size(); @@ -285,19 +222,36 @@ void LLMInferWrapper::set_already_allocated_input_for_1_token() { m_request.set_tensor("position_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast(&m_new_position_id))); } -// 
FIXME: Need to use Sampler correctly. Sampler does all the validation itself! Just needs to configure it correctly. +// TODO: Use already provided Sampler API, that will support both greedy and +// multinomial decoding. std::variant> - LLMInferWrapper::sample_tokens(const ov::Tensor& logits, std::size_t num_tokens_to_return) { - OPENVINO_ASSERT(m_sequence_group, "sample_tokens() can be called only after infer_first()!"); - - m_sampler.sample({m_sequence_group}, logits); - const auto running_sequences = m_sequence_group->get_running_sequences(); - OPENVINO_ASSERT(running_sequences.size() == 1u); - auto sampled_tokens = running_sequences.front()->get_generated_ids(); - if (num_tokens_to_return == 1) { - return sampled_tokens.back(); + LLMInferWrapper::sample_tokens(const ov::Tensor& logits, std::size_t num_tokens_to_sample) { + OPENVINO_ASSERT(m_num_processed_tokens > 0, "sample_tokens() can be called only after infer_first()!"); + + // logits.shape = [1, seq_len, vocab_size]. + auto logits_shape = logits.get_shape(); + OPENVINO_ASSERT(logits_shape.size() == 3); + std::size_t batch_size = logits_shape[0]; + OPENVINO_ASSERT(batch_size == 1); + std::size_t seq_len = logits_shape[1]; + OPENVINO_ASSERT(num_tokens_to_sample <= seq_len); + std::size_t vocab_size = logits_shape[2]; + + auto sample_token = [&](const ov::Tensor& logits, std::size_t idx) { + size_t sequence_offset = idx * vocab_size; + float* logits_data = logits.data() + sequence_offset; + return std::max_element(logits_data, logits_data + vocab_size) - logits_data; + }; + + if (num_tokens_to_sample == 1) { + // Sample last logit: + return sample_token(logits, seq_len - 1); } else { - // FIXME condition can be switched to boolean? + // Sample last num_tokens_to_sample logits: + std::vector sampled_tokens; + for (std::size_t i = 0; i < num_tokens_to_sample; i++) { + sampled_tokens.push_back(sample_token(logits, seq_len - num_tokens_to_sample + i)); + } return sampled_tokens; } } @@ -338,7 +292,6 @@ SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( // todo: remove this condition after support of CVS-154103 OPENVINO_ASSERT(are_tokenizers_equal(main_model_tokenizer, draft_model_tokenizer), "Tokenizers for draft and main models are different!"); m_tokenizer = main_model_tokenizer; - OPENVINO_ASSERT(draft_model_desc.generation_config.rng_seed == main_model_desc.generation_config.rng_seed, "Seed for sampling must be equal for draft and main models!"); // Draft model (which is smaller, less accurate but faster) auto draft_model_desc_copy = draft_model_desc; @@ -357,7 +310,7 @@ SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); // FIXME: Where to take it when draft model will be on NPU? 
- size_t max_sequence_length = main_model_desc.generation_config.get_max_new_tokens(); + size_t max_sequence_length = main_model_desc.generation_config.max_length; if (max_sequence_length == SIZE_MAX) { // FIXME: NPUW_LLM_MAX_PROMPT_LEN + NPUW_LLM_MIN_RESPONSE_LEN max_sequence_length = 100; @@ -474,10 +427,6 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( if (config.eos_token_id == -1) config.set_eos_token_id(m_generation_config.eos_token_id); config.validate(); - // FIXME: Update conditionally: - m_main_request->set_generation_config(config); - - std::shared_ptr streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); OPENVINO_ASSERT(config.is_greedy_decoding() || config.is_multinomial(), "Currently only greedy and multinomial decoding are supported"); @@ -485,9 +434,21 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( OPENVINO_ASSERT(config.num_return_sequences == 1u, "Currently only \"num_return_sequences\" equal to 1 is supported!"); + // FIXME: Update conditionally: + m_main_request->set_generation_config(config); + + // Config draft model to not stop on EOS and remove stop strings: + ov::genai::GenerationConfig draft_config = m_draft_request->get_generation_config(); + draft_config.ignore_eos = true; + draft_config.stop_strings = {}; + draft_config.validate(); + m_draft_request->set_generation_config(draft_config); + + std::shared_ptr streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); + // FIXME: Return back for the static draft model. // NB: Check if there is enough space in KV-cache to process input prompt - auto prompt_len = prompts_shape[1]; + // auto prompt_len = prompts_shape[1]; // if (prompt_len > m_max_prompt_len) { // OPENVINO_THROW("Static Stateful LLM pipeline may only process prompts up to " // + std::to_string(m_max_prompt_len) + " tokens. " @@ -529,15 +490,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( "Num of generated useful logits from draft models should be less" "or equal than ones from main model."); - // Config draft model to not stop on EOS and remove stop strings: - ov::genai::GenerationConfig draft_config = m_draft_request->get_generation_config(); - draft_config.ignore_eos = true; - draft_config.stop_strings = {}; - draft_config.validate(); - m_draft_request->set_generation_config(draft_config); - - GenerationHandle m_main_gen_handle = m_main_request->create_generation_handle(); - stream_generated_tokens(streamer_ptr, m_main_gen_handle); + auto streaming_status = stream_generated_tokens(streamer_ptr, std::vector {out_token}); /* Speculative decoding works the following way. The draft model predicts the next K tokens one by one in an autoregressive manner, while the main model validates these @@ -555,7 +508,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // Last generated token by draft model needs to be prepended before next run if it is accepted by the main model! // So it will get into context too. int64_t draft_prefix_token = -1; - while (m_main_request->can_infer()) { + while (m_main_request->can_infer() && (streaming_status == ov::genai::StreamingStatus::RUNNING)) { // Phase 1: Generation of candidates with the draft model: std::vector candidates; // Limit candidates size by num_pred_tokens or by max_seq_length: @@ -586,6 +539,12 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // FIXME: How max_seq_length will be handled? 
std::vector input_for_main(candidates.begin(), candidates.end()); input_for_main.insert(input_for_main.begin(), {out_token}); + // Note: If model isn't sliced to return logit only for the last element, + // then it returns logits for all elements of the input prompt. + // In that tensor, for each token `t` of the input prompt it contains + // distribution (over the vocabulary) for the next possible token + // that is generated based on subsequence [first token,...,`t`] + // of the input prompt. // TODO: Handle OOM exception for static model here. auto ref_out_tokens = m_main_request->infer_next_return_all(input_for_main); @@ -600,23 +559,22 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( } auto mismatched_candidates = candidates.size() - accepted_tokens_number; - + std::vector validated_tokens(candidates.begin(), candidates.end() - mismatched_candidates); + validated_tokens.push_back(ref_out_tokens.back()); + // Phase 4: Update inference wrappers based on found matches and mismatches // This is the case when main model accepted all candidates from draft model // we need to collect kv cache for draft last generated token by infering it.n if (mismatched_candidates == 0) { draft_prefix_token = candidate; } else { - m_draft_request->remove_last_generated_tokens(mismatched_candidates); m_draft_request->trimm_kv_cache(mismatched_candidates - 1); - // Check that this works correctly for the model with output seq length != 1 - m_main_request->remove_last_generated_tokens(mismatched_candidates + 1); m_main_request->trimm_kv_cache(mismatched_candidates); } m_speculative_config.update_candidate_strategy(accepted_tokens_number); // Should be enough, if all will be streamed from logits? - stream_generated_tokens(streamer_ptr, m_main_gen_handle); + stream_generated_tokens(streamer_ptr, validated_tokens); // raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now()); // raw_perf_counters.m_batch_sizes.emplace_back(batch_size); @@ -629,8 +587,14 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( m_draft_request->reset_state(); m_main_request->reset_state(); - ov::genai::EncodedResults results = m_main_request->finalize(); - m_chat_generation_finish_status = m_main_request->get_generation_status(); + ov::genai::EncodedResults results; + // NB: Only batch=1 is supported now + results.scores.resize(1u); + results.scores[0] = 0u; + results.tokens.resize(1u); + // results.tokens[0] = sequence->get_generated_ids(); + // results.scores[0] = sequence->get_cumulative_log_prob(); + // m_chat_generation_finish_status = m_streaming_status; // auto stop_time = std::chrono::steady_clock::now(); // If is called without tokenization then that stat will not be reported. 
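The reworked sample_tokens() above drops the Sampler and does plain greedy decoding: the logits tensor has shape [1, seq_len, vocab_size], and each returned token is the argmax over the vocabulary of one of the last rows. A minimal standalone sketch of that argmax step, with std::vector standing in for ov::Tensor (function and variable names here are illustrative only, not the pipeline's API):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Greedy-sample the last `num_tokens_to_sample` positions of a flattened
// [1, seq_len, vocab_size] logits buffer: one argmax per position.
static std::vector<int64_t> greedy_sample(const std::vector<float>& logits,
                                          std::size_t seq_len,
                                          std::size_t vocab_size,
                                          std::size_t num_tokens_to_sample) {
    std::vector<int64_t> out;
    out.reserve(num_tokens_to_sample);
    for (std::size_t i = 0; i < num_tokens_to_sample; ++i) {
        std::size_t row = seq_len - num_tokens_to_sample + i;
        const float* begin = logits.data() + row * vocab_size;
        const float* end = begin + vocab_size;
        out.push_back(std::max_element(begin, end) - begin);  // argmax over the vocabulary
    }
    return out;
}

int main() {
    // Two positions, vocabulary of four tokens: the argmaxes are 2 and 1.
    std::vector<float> logits = {0.1f, 0.2f, 0.9f, 0.3f,
                                 0.4f, 0.8f, 0.1f, 0.2f};
    for (int64_t t : greedy_sample(logits, /*seq_len=*/2, /*vocab_size=*/4, /*num_tokens_to_sample=*/2)) {
        std::cout << t << " ";  // prints: 2 1
    }
    std::cout << "\n";
    return 0;
}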
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp index 93153b0a1d..68834a8567 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp @@ -60,14 +60,13 @@ class LLMInferWrapper { ov::genai::GenerationConfig m_generation_config; ov::genai::Tokenizer m_tokenizer; + std::size_t m_max_prompt_len = 0u; + std::size_t m_kvcache_total = 0u; + std::size_t m_first_prompt_len = 0u; std::size_t m_num_processed_tokens = 0u; - uint32_t m_max_prompt_len = 0u; - uint32_t m_kvcache_total = 0u; + int64_t last_token = -1; ov::genai::utils::KVAxesPosition m_kv_pos; ov::InferRequest m_request; - ov::genai::Sampler m_sampler; - std::shared_ptr m_sequence_group = nullptr; - GenerationHandle m_handle = nullptr; // Separate metrics? // Data placeholder for 1-token inference: From 26d04a50349bf699d67d2390405f1de78582bf5a Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Wed, 6 Aug 2025 22:22:29 +0100 Subject: [PATCH 04/40] Fixes after debug --- src/cpp/src/llm/pipeline.cpp | 10 +++++++--- .../speculative_decoding_npu.cpp | 15 +++++++++------ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index e598ec743e..42d2e63421 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -115,12 +115,16 @@ ov::genai::LLMPipeline::LLMPipeline( auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); + // First -> check draft model. for NPU leave it as is for the main model. + // if NPU + // if draft model is on NPU // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues - if (utils::explicitly_requires_paged_attention(user_properties)) { + if (device == "NPU") { + m_pimpl = std::make_unique(models_path, properties); + } else if (utils::explicitly_requires_paged_attention(user_properties)) { auto [device_properties, scheduler_config] = utils::extract_scheduler_config(properties, utils::get_latency_oriented_scheduler_config()); m_pimpl = std::make_unique(models_path, scheduler_config, device, device_properties); - } else if (device == "NPU") { - m_pimpl = std::make_unique(models_path, properties); + } else if (attention_backend == PA_BACKEND) { // try to call CB adapter one more time, but with safe guard to silent exception try { diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index 4e4afbba48..33376b13e5 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -139,8 +139,8 @@ std::vector LLMInferWrapper::infer_next_return_all(const std::vector>(sample_tokens(logits, tokens_size + 1u)); - last_token = sampled_tokens[tokens_size]; + auto sampled_tokens = std::get>(sample_tokens(logits, tokens_size)); + last_token = sampled_tokens[tokens_size - 1]; return sampled_tokens; } @@ -415,8 +415,8 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( attention_mask = data->attention_mask; } - ov::Shape prompts_shape = input_ids.get_shape(); - const size_t batch_size = prompts_shape[0]; + ov::Shape prompt_shape = input_ids.get_shape(); + const size_t batch_size = prompt_shape[0]; OPENVINO_ASSERT(batch_size == 1u, "Currently only batch size=1 is supported"); GenerationConfig config = (generation_config.has_value()) ? 
*generation_config : m_generation_config; @@ -436,6 +436,8 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // FIXME: Update conditionally: m_main_request->set_generation_config(config); + auto prompt_len = prompt_shape[1]; + m_speculative_config.max_seq_length = prompt_len + config.get_max_new_tokens(prompt_len); // Config draft model to not stop on EOS and remove stop strings: ov::genai::GenerationConfig draft_config = m_draft_request->get_generation_config(); @@ -560,7 +562,8 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( auto mismatched_candidates = candidates.size() - accepted_tokens_number; std::vector validated_tokens(candidates.begin(), candidates.end() - mismatched_candidates); - validated_tokens.push_back(ref_out_tokens.back()); + out_token = ref_out_tokens.back(); + validated_tokens.push_back(out_token); // Phase 4: Update inference wrappers based on found matches and mismatches // This is the case when main model accepted all candidates from draft model @@ -574,7 +577,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( m_speculative_config.update_candidate_strategy(accepted_tokens_number); // Should be enough, if all will be streamed from logits? - stream_generated_tokens(streamer_ptr, validated_tokens); + streaming_status = stream_generated_tokens(streamer_ptr, validated_tokens); // raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now()); // raw_perf_counters.m_batch_sizes.emplace_back(batch_size); From c74b8890c1b079aabf0f43a5893c207abd1badac Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Thu, 7 Aug 2025 01:01:11 +0100 Subject: [PATCH 05/40] More fixes for accuracy --- src/cpp/src/llm/pipeline.cpp | 30 +++++++++---------- src/cpp/src/llm/pipeline_stateful_npu.cpp | 8 +++-- src/cpp/src/llm/pipeline_stateful_npu.hpp | 3 ++ .../speculative_decoding_npu.cpp | 28 ++++++++++++----- src/cpp/src/utils.cpp | 22 +++++++++++++- src/cpp/src/utils.hpp | 4 +++ 6 files changed, 69 insertions(+), 26 deletions(-) diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index 42d2e63421..6cea401798 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -80,12 +80,12 @@ ov::genai::LLMPipeline::LLMPipeline( auto start_time = std::chrono::steady_clock::now(); auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); - // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues - if (utils::explicitly_requires_paged_attention(user_properties)) { + if (ov::genai::utils::is_npu_requested(device, properties)) { + m_pimpl = std::make_unique(models_path, tokenizer, device, properties); + } else if (utils::explicitly_requires_paged_attention(user_properties)) { + // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues auto [device_properties, scheduler_config] = utils::extract_scheduler_config(properties, utils::get_latency_oriented_scheduler_config()); m_pimpl = std::make_unique(models_path, tokenizer, scheduler_config, device, device_properties); - } else if (device == "NPU") { - m_pimpl = std::make_unique(models_path, tokenizer, properties); } else if (attention_backend == PA_BACKEND) { // try to call CB adapter one more time, but with safe guard to silent exception try { @@ -115,13 +115,10 @@ ov::genai::LLMPipeline::LLMPipeline( auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); - // First -> check draft model. for NPU leave it as is for the main model. 
- // if NPU - // if draft model is on NPU - // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues - if (device == "NPU") { - m_pimpl = std::make_unique(models_path, properties); + if (ov::genai::utils::is_npu_requested(device, properties)) { + m_pimpl = std::make_unique(models_path, device, properties); } else if (utils::explicitly_requires_paged_attention(user_properties)) { + // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues auto [device_properties, scheduler_config] = utils::extract_scheduler_config(properties, utils::get_latency_oriented_scheduler_config()); m_pimpl = std::make_unique(models_path, scheduler_config, device, device_properties); @@ -157,17 +154,18 @@ ov::genai::LLMPipeline::LLMPipeline( auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); - // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues - if (utils::explicitly_requires_paged_attention(user_properties)) { - auto [device_properties, scheduler_config] = utils::extract_scheduler_config(properties, utils::get_latency_oriented_scheduler_config()); - m_pimpl = std::make_unique(model_str, weights_tensor, - tokenizer, scheduler_config, device, device_properties, generation_config); - } else if (device == "NPU") { + if (ov::genai::utils::is_npu_requested(device, properties)) { m_pimpl = std::make_unique( utils::singleton_core().read_model(model_str, weights_tensor), tokenizer, + device, properties, generation_config); + } else if (utils::explicitly_requires_paged_attention(user_properties)) { + // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues + auto [device_properties, scheduler_config] = utils::extract_scheduler_config(properties, utils::get_latency_oriented_scheduler_config()); + m_pimpl = std::make_unique(model_str, weights_tensor, + tokenizer, scheduler_config, device, device_properties, generation_config); } else if (attention_backend == PA_BACKEND) { // try to call CB adapter one more time, but with safe guard to silent exception try { diff --git a/src/cpp/src/llm/pipeline_stateful_npu.cpp b/src/cpp/src/llm/pipeline_stateful_npu.cpp index 54e34f213c..83407d86fd 100644 --- a/src/cpp/src/llm/pipeline_stateful_npu.cpp +++ b/src/cpp/src/llm/pipeline_stateful_npu.cpp @@ -33,29 +33,33 @@ namespace ov::genai { StatefulLLMPipelineNPU::StatefulLLMPipelineNPU( const std::filesystem::path& models_path, const ov::genai::Tokenizer& tokenizer, + const std::string& device, const ov::AnyMap& properties) : StatefulLLMPipelineNPU( utils::read_model(models_path, properties), tokenizer, + device, properties, utils::from_config_json_if_exists(models_path) ) {} StatefulLLMPipelineNPU::StatefulLLMPipelineNPU( const std::filesystem::path& models_path, + const std::string& device, const ov::AnyMap& plugin_config) - : StatefulLLMPipelineNPU{models_path, Tokenizer(models_path, plugin_config), plugin_config} {} + : StatefulLLMPipelineNPU{models_path, Tokenizer(models_path, plugin_config), device, plugin_config} {} StatefulLLMPipelineNPU::StatefulLLMPipelineNPU( const std::shared_ptr& model, const ov::genai::Tokenizer& tokenizer, + const std::string& device, const ov::AnyMap& properties, const ov::genai::GenerationConfig& generation_config) : LLMPipelineImplBase(tokenizer, generation_config) { auto properties_without_draft_model = properties; auto draft_model_descr = extract_draft_model_from_config(properties_without_draft_model); if 
(draft_model_descr.model != nullptr) { - auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, "NPU", properties_without_draft_model, {}, generation_config); + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, {}, generation_config); m_pimpl = std::make_unique(main_model_descr, draft_model_descr); } else if (properties_without_draft_model.count("STATIC_PIPELINE")) { m_pimpl = static_llm::LLMPipelineFactory::create(model, tokenizer, diff --git a/src/cpp/src/llm/pipeline_stateful_npu.hpp b/src/cpp/src/llm/pipeline_stateful_npu.hpp index e14aa30065..5e88050501 100644 --- a/src/cpp/src/llm/pipeline_stateful_npu.hpp +++ b/src/cpp/src/llm/pipeline_stateful_npu.hpp @@ -13,17 +13,20 @@ class StatefulLLMPipelineNPU final : public LLMPipelineImplBase { StatefulLLMPipelineNPU( const std::filesystem::path& models_path, const ov::genai::Tokenizer& tokenizer, + const std::string& device, const ov::AnyMap& plugin_config ); StatefulLLMPipelineNPU( const std::filesystem::path& models_path, + const std::string& device, const ov::AnyMap& plugin_config ); StatefulLLMPipelineNPU( const std::shared_ptr& model, const ov::genai::Tokenizer& tokenizer, + const std::string& device, const ov::AnyMap& config, const ov::genai::GenerationConfig& generation_config ); diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index 33376b13e5..3b4d4bd650 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -443,6 +443,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( ov::genai::GenerationConfig draft_config = m_draft_request->get_generation_config(); draft_config.ignore_eos = true; draft_config.stop_strings = {}; + draft_config.max_new_tokens = config.get_max_new_tokens(); draft_config.validate(); m_draft_request->set_generation_config(draft_config); @@ -509,6 +510,9 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( */ // Last generated token by draft model needs to be prepended before next run if it is accepted by the main model! // So it will get into context too. + // Remove debug lines. + // std::cout << std::endl << "Launching spec decode for " << config.get_max_new_tokens(prompt_len) << " max new tokens." << std::endl << std::endl; + // std::vector> accepted_tokens; int64_t draft_prefix_token = -1; while (m_main_request->can_infer() && (streaming_status == ov::genai::StreamingStatus::RUNNING)) { // Phase 1: Generation of candidates with the draft model: @@ -534,7 +538,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( candidate = m_draft_request->infer_next(candidate); candidates.push_back(candidate); } - + // Phase 2. Main inference. // For the main network, candidates_size + 1 tokens will be fed at once in a single infer request: // last token from previous main inference + all candidates from the draft stage @@ -548,22 +552,23 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // that is generated based on subsequence [first token,...,`t`] // of the input prompt. // TODO: Handle OOM exception for static model here. - auto ref_out_tokens = m_main_request->infer_next_return_all(input_for_main); + auto ref_tokens = m_main_request->infer_next_return_all(input_for_main); // Phase 3. 
Check if main model produced the same tokens as input candidates: size_t accepted_tokens_number = 0u; // Last token is a new token from the main model, skip it: - for (size_t i = 0; i < ref_out_tokens.size() - 1; ++i) { - if (ref_out_tokens[i] != candidates[i]) { + for (size_t i = 0; i < ref_tokens.size() - 1; ++i) { + if (ref_tokens[i] != candidates[i]) { break; } accepted_tokens_number++; } + // FIXME: Remove debug line + // accepted_tokens.push_back({accepted_tokens_number, candidates.size()}); auto mismatched_candidates = candidates.size() - accepted_tokens_number; - std::vector validated_tokens(candidates.begin(), candidates.end() - mismatched_candidates); - out_token = ref_out_tokens.back(); - validated_tokens.push_back(out_token); + std::vector validated_tokens(ref_tokens.begin(), ref_tokens.end() - mismatched_candidates); + out_token = validated_tokens.back(); // Phase 4: Update inference wrappers based on found matches and mismatches // This is the case when main model accepted all candidates from draft model @@ -573,6 +578,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( } else { m_draft_request->trimm_kv_cache(mismatched_candidates - 1); m_main_request->trimm_kv_cache(mismatched_candidates); + draft_prefix_token = -1; } m_speculative_config.update_candidate_strategy(accepted_tokens_number); @@ -587,6 +593,14 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( streamer_ptr->end(); } + // Remove debug lines + // std::cout << std::endl << std::endl << "Acceptance ratios for each iteration from total of " << accepted_tokens.size() << "." << std::endl; + // std::cout << "Format: n/m per iteration, `n` accepted tokens from `m` candidates." << std::endl; + // for (int i = 0; i < accepted_tokens.size(); ++i) { + // std::cout << accepted_tokens[i].first << "/" << accepted_tokens[i].second << ", "; + // } + m_speculative_config.num_pred_tokens = 5; + m_draft_request->reset_state(); m_main_request->reset_state(); diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index 99996979d9..aed102863f 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -104,7 +104,6 @@ inline bool is_paged_attention_available() { return false; #endif } - } // anonymous namespace ov { @@ -205,6 +204,27 @@ ProcessorConfig from_any_map( return extracted_config; } +ov::genai::ModelDesc get_draft_model_from_config(const ov::AnyMap& config) { + ov::genai::ModelDesc draft_model; + if (config.find(utils::DRAFT_MODEL_ARG_NAME) != config.end()) { + draft_model = config.at(utils::DRAFT_MODEL_ARG_NAME).as(); + } + return draft_model; +} + +bool is_npu_requested(const std::string& device, const ov::AnyMap& properties) { + if (device == "NPU") { + return true; + } + + auto draft_model_descr = get_draft_model_from_config(properties); + if (draft_model_descr.model != nullptr) { + return draft_model_descr.device == "NPU"; + } + + return false; +} + ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& minuend, const ov::genai::TokenizedInputs& subtrahend) { auto minuend_size = minuend.input_ids.get_size(); auto subtrahend_size = subtrahend.input_ids.get_size(); diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index 9b87482df1..4995b3b754 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -118,6 +118,10 @@ ProcessorConfig from_any_map( const ProcessorConfig& initial ); +ov::genai::ModelDesc get_draft_model_from_config(const ov::AnyMap& config); + +bool is_npu_requested(const std::string& device, const ov::AnyMap& properties); + 
ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& minuend, const ov::genai::TokenizedInputs& subtrahend); void apply_slice_before_matmul_transformation(std::shared_ptr model); From 6901dbaf4b97749bd0fecf6c1f7793716cbacda6 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Thu, 7 Aug 2025 12:55:07 +0100 Subject: [PATCH 06/40] Fixed issue with infer on 1 token after KV-cache trim --- .../speculative_decoding_npu.cpp | 24 +++++++++++++------ .../speculative_decoding_npu.hpp | 2 +- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index 3b4d4bd650..efd0fa085a 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -152,7 +152,7 @@ std::size_t LLMInferWrapper::get_num_processed_tokens() const { return m_num_processed_tokens; } -void LLMInferWrapper::trimm_kv_cache(const size_t tokens_to_remove) { +void LLMInferWrapper::trim_kv_cache(const size_t tokens_to_remove) { // Trim kv_cache values on tokens_to_remove ov::genai::utils::KVCacheState to_trim_state; to_trim_state.num_tokens_to_trim = tokens_to_remove; @@ -160,6 +160,16 @@ void LLMInferWrapper::trimm_kv_cache(const size_t tokens_to_remove) { to_trim_state.reset_mem_state = false; ov::genai::utils::trim_kv_cache(m_request, to_trim_state, {}); m_num_processed_tokens -= tokens_to_remove; + + // Update pre-allocated inputs for 1 token and return back to use it + // in case if next infer will be called on input_ids of size 1 + // (most frequent case). + m_new_input_token = -1; + m_new_position_id = m_num_processed_tokens - 1; + for (std::size_t i = 0; i < tokens_to_remove; ++i) { + m_new_atten_mask_data.pop_back(); + } + set_already_allocated_input_for_1_token(); } void LLMInferWrapper::reset_state() { @@ -211,7 +221,7 @@ ov::Tensor LLMInferWrapper::infer_next_internal(const std::vector token m_new_position_id = m_num_processed_tokens - 1; for (std::size_t i = 0; i < tokens_size; ++i) { m_new_atten_mask_data.push_back(1); - } + } set_already_allocated_input_for_1_token(); return get_logits(); @@ -282,7 +292,7 @@ SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( } // TODO: We might need it for manipulations with indices - // utils::apply_gather_before_matmul_transformation(main_model); + // utils::apply_gather_before_matmul_transformation(main_model_desc.model); // utils::apply_gather_before_matmul_transformation(draft_model); // Main and Draft model can have different tokenizers @@ -553,7 +563,6 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // of the input prompt. // TODO: Handle OOM exception for static model here. auto ref_tokens = m_main_request->infer_next_return_all(input_for_main); - // Phase 3. 
Check if main model produced the same tokens as input candidates: size_t accepted_tokens_number = 0u; // Last token is a new token from the main model, skip it: @@ -574,10 +583,11 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // This is the case when main model accepted all candidates from draft model // we need to collect kv cache for draft last generated token by infering it.n if (mismatched_candidates == 0) { - draft_prefix_token = candidate; + draft_prefix_token = candidates.back(); } else { - m_draft_request->trimm_kv_cache(mismatched_candidates - 1); - m_main_request->trimm_kv_cache(mismatched_candidates); + // Last draft candidate is out of KV-Cache, as it is output token. + m_draft_request->trim_kv_cache(mismatched_candidates - 1); + m_main_request->trim_kv_cache(mismatched_candidates); draft_prefix_token = -1; } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp index 68834a8567..7aee31c14f 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp @@ -39,7 +39,7 @@ class LLMInferWrapper { void remove_last_generated_tokens(const std::size_t tokens_to_remove); - void trimm_kv_cache(const std::size_t tokens_to_remove); + void trim_kv_cache(const std::size_t tokens_to_remove); ov::genai::EncodedResults finalize(); From 3ac772ea1534fbf150aef03898a659a61747c7bc Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Thu, 7 Aug 2025 17:03:37 +0100 Subject: [PATCH 07/40] Handled specifics for launch of models on NPU --- src/cpp/src/llm/pipeline_stateful_npu.cpp | 2 +- .../speculative_decoding_npu.cpp | 319 +++++++++--------- .../speculative_decoding_npu.hpp | 15 +- 3 files changed, 176 insertions(+), 160 deletions(-) diff --git a/src/cpp/src/llm/pipeline_stateful_npu.cpp b/src/cpp/src/llm/pipeline_stateful_npu.cpp index 83407d86fd..1cce1f9516 100644 --- a/src/cpp/src/llm/pipeline_stateful_npu.cpp +++ b/src/cpp/src/llm/pipeline_stateful_npu.cpp @@ -58,7 +58,7 @@ StatefulLLMPipelineNPU::StatefulLLMPipelineNPU( : LLMPipelineImplBase(tokenizer, generation_config) { auto properties_without_draft_model = properties; auto draft_model_descr = extract_draft_model_from_config(properties_without_draft_model); - if (draft_model_descr.model != nullptr) { + if (draft_model_descr.model != nullptr) { auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, {}, generation_config); m_pimpl = std::make_unique(main_model_descr, draft_model_descr); } else if (properties_without_draft_model.count("STATIC_PIPELINE")) { diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index efd0fa085a..59d7608a5b 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -6,6 +6,8 @@ #include "openvino/core/parallel.hpp" #include "openvino/genai/text_streamer.hpp" +#include + namespace ov::genai { template struct overloaded : Ts... {using Ts::operator()...;}; template overloaded(Ts...) 
-> overloaded; @@ -14,20 +16,12 @@ bool are_tokenizers_equal(ov::genai::Tokenizer& lhs, ov::genai::Tokenizer& rhs); } // ov::genai namespace { -ov::Tensor make_tensor_slice(ov::Tensor tensor, size_t dim, size_t start_pos, size_t end_pos) { - ov::Shape start_shape(std::vector(tensor.get_shape().size(), 0u)); - start_shape[dim] = start_pos; - ov::Shape end_shape = tensor.get_shape(); - end_shape[dim] = end_pos; - return ov::Tensor(tensor, start_shape, end_shape); -} - - ov::genai::StreamingStatus stream_generated_tokens(std::shared_ptr streamer_ptr, const std::vector& tokens) { if (streamer_ptr) { return streamer_ptr->write(tokens); } + return ov::genai::StreamingStatus{}; } } // anonymous namespace @@ -35,11 +29,12 @@ namespace ov { namespace genai { LLMInferWrapper::LLMInferWrapper( const ov::genai::ModelDesc& model_desc -) : m_properties(model_desc.properties), +) : m_device(model_desc.device), + m_properties(model_desc.properties), m_generation_config(model_desc.generation_config), m_tokenizer(model_desc.tokenizer) { m_kv_pos = ov::genai::utils::get_kv_axes_pos(model_desc.model); - if (model_desc.device == "NPU") { + if (m_device == "NPU") { auto [compiled, kv_desc] = utils::compile_decoder_for_npu(model_desc.model, m_properties, m_kv_pos); m_max_prompt_len = kv_desc.max_prompt_len; m_kvcache_total = kv_desc.max_prompt_len + kv_desc.min_response_len; @@ -47,10 +42,14 @@ namespace genai { } else { // TODO: We might need it for manipulations with indices // utils::apply_gather_before_matmul_transformation(model_desc.model); - m_request = ov::genai::utils::singleton_core().compile_model(model_desc.model, model_desc.device, m_properties).create_infer_request(); + m_request = ov::genai::utils::singleton_core().compile_model(model_desc.model, m_device, m_properties).create_infer_request(); } } +std::string LLMInferWrapper::device() const { + return m_device; +} + ov::genai::GenerationConfig LLMInferWrapper::get_generation_config() const { return m_generation_config; } @@ -59,9 +58,37 @@ void LLMInferWrapper::set_generation_config(ov::genai::GenerationConfig config) m_generation_config = config; } +int64_t LLMInferWrapper::get_kvcache_capacity() const { + if (m_device == "NPU") { + return m_kvcache_total - m_num_processed_tokens; + } + return std::numeric_limits::max(); +} + +int64_t LLMInferWrapper::get_generation_capacity() const { + int64_t max_new_tokens = static_cast(m_generation_config.get_max_new_tokens()); + if (m_first_prompt_len > 0) { + int64_t generated_new_tokens = static_cast(m_num_processed_tokens - m_first_prompt_len) + 1; + return max_new_tokens - generated_new_tokens; + } else { + return max_new_tokens; + } +} + int64_t LLMInferWrapper::infer_first(const ov::Tensor &input_ids, const ov::Tensor &attention_mask, const ov::Tensor &position_ids) { + if (m_device == "NPU") { + // NB: Check if there is enough space in KV-cache to process input prompt + auto prompt_len = input_ids.get_shape()[1]; + if (prompt_len > m_max_prompt_len) { + OPENVINO_THROW("LLM model on NPU may only process prompts up to " + + std::to_string(m_max_prompt_len) + " tokens. 
" + + "Set the \"MAX_PROMPT_LEN\" config option to " + + "increase the limit."); + } + } + m_request.set_tensor("input_ids", input_ids); m_request.set_tensor("attention_mask", attention_mask); m_request.set_tensor("position_ids", position_ids); @@ -80,25 +107,31 @@ int64_t LLMInferWrapper::infer_first(const ov::Tensor &input_ids, m_new_atten_mask_data = std::vector(m_num_processed_tokens, 1); set_already_allocated_input_for_1_token(); + // Update last_token variable for `can_infer()` logic: last_token = std::get(sample_tokens(get_logits(), 1u)); return last_token; } -bool LLMInferWrapper::can_infer() { +bool LLMInferWrapper::can_infer(const std::size_t prompt_len) { OPENVINO_ASSERT(m_num_processed_tokens > 0, "can_infer() can be called only after infer_first()!"); - // FIXME: Add condition to get out of KV-Cache length for static models. + if (m_device == "NPU") { + if (prompt_len > get_kvcache_capacity()) { + // Not enough room in KVCache to process prompt_len tokens. + return false; + } + } + auto stop_token_ids = m_generation_config.stop_token_ids; if (!m_generation_config.ignore_eos && (last_token == m_generation_config.eos_token_id)) { return false; } if (std::find(stop_token_ids.begin(), stop_token_ids.end(), last_token) != stop_token_ids.end()) { - return false; + return false; } - if (m_num_processed_tokens - m_first_prompt_len + 1 >= m_generation_config.get_max_new_tokens()) { + if (get_generation_capacity() <= 0) { return false; } - return true; } @@ -122,78 +155,20 @@ int64_t LLMInferWrapper::infer_next(int64_t token) { m_num_processed_tokens += 1u; + // Update last_token variable for `can_infer()` logic: last_token = std::get(sample_tokens(get_logits(), 1u)); return last_token; } -int64_t LLMInferWrapper::infer_next(const std::vector tokens) { - OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next() can be called only after infer_first()!"); - - auto logits = infer_next_internal(tokens); - last_token = std::get(sample_tokens(logits, 1u)); - return last_token; -} - std::vector LLMInferWrapper::infer_next_return_all(const std::vector tokens) { OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next_return_all() can be called only after infer_first()!"); - auto logits = infer_next_internal(tokens); - auto tokens_size = tokens.size(); - auto sampled_tokens = std::get>(sample_tokens(logits, tokens_size)); - last_token = sampled_tokens[tokens_size - 1]; - return sampled_tokens; -} - -ov::Tensor LLMInferWrapper::get_logits() { - return m_request.get_tensor("logits"); -} - -std::size_t LLMInferWrapper::get_num_processed_tokens() const { - return m_num_processed_tokens; -} - -void LLMInferWrapper::trim_kv_cache(const size_t tokens_to_remove) { - // Trim kv_cache values on tokens_to_remove - ov::genai::utils::KVCacheState to_trim_state; - to_trim_state.num_tokens_to_trim = tokens_to_remove; - to_trim_state.seq_length_axis = m_kv_pos.seq_len; - to_trim_state.reset_mem_state = false; - ov::genai::utils::trim_kv_cache(m_request, to_trim_state, {}); - m_num_processed_tokens -= tokens_to_remove; - - // Update pre-allocated inputs for 1 token and return back to use it - // in case if next infer will be called on input_ids of size 1 - // (most frequent case). 
- m_new_input_token = -1; - m_new_position_id = m_num_processed_tokens - 1; - for (std::size_t i = 0; i < tokens_to_remove; ++i) { - m_new_atten_mask_data.pop_back(); - } - set_already_allocated_input_for_1_token(); -} - -void LLMInferWrapper::reset_state() { - return m_request.reset_state(); -} - -ov::Tensor LLMInferWrapper::infer_next_internal(const std::vector tokens) { - OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next_internal() can be called only after infer_first()!"); - size_t tokens_size = tokens.size(); - - // FIXME: Uncomment for static model and throw exception instead - // if (m_num_processed_tokens + tokens_size == m_kvcache_total) { - // m_sequence_group->set_out_of_memory(); - // return -1; - // } - auto input_ids = m_request.get_tensor("input_ids"); ov::Tensor new_input_ids(input_ids.get_element_type(), ov::Shape{BATCH_SIZE, tokens_size}); std::copy_n(tokens.begin(), tokens_size, new_input_ids.data()); m_request.set_tensor("input_ids", new_input_ids); - // FIXME: For model with static shapes we can just copy after - // the prefilled tokens, no reshape is needed. auto attention_mask = m_request.get_tensor("attention_mask"); ov::Tensor new_attention_mask(attention_mask.get_element_type(), ov::Shape{BATCH_SIZE, m_num_processed_tokens + tokens_size}); std::copy_n(attention_mask.data(), m_num_processed_tokens, new_attention_mask.data()); @@ -224,7 +199,49 @@ ov::Tensor LLMInferWrapper::infer_next_internal(const std::vector token } set_already_allocated_input_for_1_token(); - return get_logits(); + auto logits = get_logits(); + auto sampled_tokens = std::get>(sample_tokens(logits, tokens_size)); + // Update last_token variable for `can_infer()` logic: + last_token = sampled_tokens[tokens_size - 1]; + return sampled_tokens; +} + +ov::Tensor LLMInferWrapper::get_logits() { + return m_request.get_tensor("logits"); +} + +std::size_t LLMInferWrapper::get_num_processed_tokens() const { + return m_num_processed_tokens; +} + +void LLMInferWrapper::trim_kv_cache(const size_t tokens_to_remove) { + OPENVINO_ASSERT(m_num_processed_tokens > 0, "trim_kv_cache() can be called only after infer_first()!"); + + OPENVINO_ASSERT(tokens_to_remove < m_num_processed_tokens); + // For NPU "trim" is done by position ids on NPUW side. + if (m_device != "NPU") { + // Trim kv_cache values on tokens_to_remove + ov::genai::utils::KVCacheState to_trim_state; + to_trim_state.num_tokens_to_trim = tokens_to_remove; + to_trim_state.seq_length_axis = m_kv_pos.seq_len; + to_trim_state.reset_mem_state = false; + ov::genai::utils::trim_kv_cache(m_request, to_trim_state, {}); + } + m_num_processed_tokens -= tokens_to_remove; + + // Update pre-allocated inputs for 1 token and return back to use it + // in case if next infer will be called on input_ids of size 1 + // (most frequent case). + m_new_input_token = -1; + m_new_position_id = m_num_processed_tokens - 1; + for (std::size_t i = 0; i < tokens_to_remove; ++i) { + m_new_atten_mask_data.pop_back(); + } + set_already_allocated_input_for_1_token(); +} + +void LLMInferWrapper::reset_state() { + return m_request.reset_state(); } void LLMInferWrapper::set_already_allocated_input_for_1_token() { @@ -313,12 +330,6 @@ SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( } m_draft_request = std::make_unique(draft_model_desc_copy); - // Main model (which is bigger, more accurate but slower) - // FIXME: Need to support full logits tensor as output for main model on NPU. 
- m_main_request = std::make_unique(main_model_desc); - - m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); - // FIXME: Where to take it when draft model will be on NPU? size_t max_sequence_length = main_model_desc.generation_config.max_length; if (max_sequence_length == SIZE_MAX) { @@ -329,6 +340,15 @@ SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( const std::size_t candidates_num = 5; m_speculative_config.max_seq_length = max_sequence_length; m_speculative_config.num_pred_tokens = candidates_num; + + // Main model (which is bigger, more accurate but slower) + auto main_model_desc_copy = main_model_desc; + if (main_model_desc_copy.device == "NPU") { + main_model_desc_copy.properties["NPUW_LLM_MAX_GENERATION_TOKEN_LEN"] = m_speculative_config.num_pred_tokens + 1; + } + m_main_request = std::make_unique(main_model_desc_copy); + + m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); } DecodedResults SpeculativeLLMPipelineNPU::generate( @@ -378,8 +398,8 @@ DecodedResults SpeculativeLLMPipelineNPU::generate( if (m_is_chat_conversation) { auto answer = decoded_results.texts[0]; - if (m_chat_generation_finish_status == GenerationStatus::CANCEL) - // If chat generation process was cancelled by user, let's rollback to previous state of history + if (m_streaming_was_cancelled) + // If generation process was cancelled by user, let's rollback to previous state of history m_history.pop_back(); else m_history.push_back({{"role", "assistant"}, {"content", answer}}); @@ -447,7 +467,9 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // FIXME: Update conditionally: m_main_request->set_generation_config(config); auto prompt_len = prompt_shape[1]; - m_speculative_config.max_seq_length = prompt_len + config.get_max_new_tokens(prompt_len); + if (config.get_max_new_tokens(prompt_len) != SIZE_MAX) { + m_speculative_config.max_seq_length = prompt_len + config.get_max_new_tokens(prompt_len); + } // Config draft model to not stop on EOS and remove stop strings: ov::genai::GenerationConfig draft_config = m_draft_request->get_generation_config(); @@ -458,15 +480,12 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( m_draft_request->set_generation_config(draft_config); std::shared_ptr streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); - - // FIXME: Return back for the static draft model. - // NB: Check if there is enough space in KV-cache to process input prompt - // auto prompt_len = prompts_shape[1]; - // if (prompt_len > m_max_prompt_len) { - // OPENVINO_THROW("Static Stateful LLM pipeline may only process prompts up to " - // + std::to_string(m_max_prompt_len) + " tokens. " - // + "Set the \"MAX_PROMPT_LEN\" config option to increase the limit."); - // } + ov::genai::EncodedResults results; + // NB: Only batch=1 is supported now. + // NB: In the case of greedy decoding scores are filled with zeros. + results.scores.resize(1u); + results.scores[0] = 0u; + results.tokens.resize(1u); ov::Tensor position_ids{ov::element::i64, input_ids.get_shape()}; utils::initialize_position_ids(position_ids, attention_mask); @@ -484,26 +503,12 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( OPENVINO_ASSERT(draft_vocab_size == main_vocab_size, "Vocab sizes should be the same for the both: main and draft models!"); - - // FIXME: Apply this logic carefully in LLMInferRequest of prefill model, - // if needed. - // FIXME: Here is workaround to get only useful units of returned logits. - // If SliceOut is applied, there will be only 1 useful logit returned, - // nothing is required here. 
- // Other way, model will return logits of full context length, - // as internally prefill model is specially reshaped to return them. - // Fix should be done on OpenVINO side, so the model should return only - // useful logits of input prompt length, dropping the implementation-related - // padding ones. - // auto sequence_len = all_logits.get_shape()[1]; - // if (sequence_len > 1) { - // logits = make_tensor_slice(all_logits, 1, sequence_len - prompt_len, sequence_len); - // } OPENVINO_ASSERT(draft_logits.get_shape().at(1) <= main_logits.get_shape().at(1), "Num of generated useful logits from draft models should be less" "or equal than ones from main model."); auto streaming_status = stream_generated_tokens(streamer_ptr, std::vector {out_token}); + results.tokens[0].push_back(out_token); /* Speculative decoding works the following way. The draft model predicts the next K tokens one by one in an autoregressive manner, while the main model validates these @@ -519,51 +524,69 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( the main model instead of running K subsequent requests. */ // Last generated token by draft model needs to be prepended before next run if it is accepted by the main model! - // So it will get into context too. + // So it will get into the context too. // Remove debug lines. // std::cout << std::endl << "Launching spec decode for " << config.get_max_new_tokens(prompt_len) << " max new tokens." << std::endl << std::endl; // std::vector> accepted_tokens; int64_t draft_prefix_token = -1; while (m_main_request->can_infer() && (streaming_status == ov::genai::StreamingStatus::RUNNING)) { - // Phase 1: Generation of candidates with the draft model: + // Phase 1: Generation of candidates with the draft model. std::vector candidates; - // Limit candidates size by num_pred_tokens or by max_seq_length: - // FIXME: draft_prefix_token isn't taken into account! - // FIXME: How max_seq_length will limit further generation of main model? - size_t candidates_to_generate = std::min(m_speculative_config.num_pred_tokens, - m_speculative_config.max_seq_length - m_draft_request->get_num_processed_tokens() - 1); + int64_t kvcache_room_for_candidates = std::min( + m_draft_request->get_kvcache_capacity() - ((draft_prefix_token == -1) ? 
1 : 0), + // Take into the account reference token that is prefixed to candidates: + m_main_request->get_kvcache_capacity() - 1); + int64_t generation_room_for_candidates = std::min( + m_draft_request->get_generation_capacity(), + // Take into the account output token, generated on candidates: + m_main_request->get_generation_capacity() - 1); + int64_t candidates_can_be_generated = std::min( + kvcache_room_for_candidates, generation_room_for_candidates); + if (candidates_can_be_generated <= 0) { + auto remainder = m_main_request->get_generation_capacity(); + // If user asked for more tokens in answer and we have + // KVCache capacity to sequentellly infer them: + if (remainder > 0 && m_main_request->can_infer(remainder)) { + for (std::size_t i = 0; i < remainder; ++i) { + out_token = m_main_request->infer_next(out_token); + streaming_status = stream_generated_tokens(streamer_ptr, {out_token}); + results.tokens[0].push_back(out_token); + } + } + break; + } + auto candidates_to_generate = std::min(static_cast(m_speculative_config.num_pred_tokens), + candidates_can_be_generated);; candidates.reserve(candidates_to_generate); - // If draft_prefix_token is present, prepend it to out_token in order to collect KV cache for it - auto candidate = out_token; + // If draft_prefix_token is present, run an infer on it to collect KV cache for it if (draft_prefix_token != -1) { - std::vector tokens_to_infer = {draft_prefix_token, out_token}; - // TODO: Handle OOM exception for static model here. - candidate = m_draft_request->infer_next(tokens_to_infer); - candidates.push_back(candidate); - candidates_to_generate--; + m_draft_request->infer_next(draft_prefix_token); } + + int64_t candidate = out_token; for (size_t i = 0; i < candidates_to_generate; i++) { - // TODO: Handle OOM exception for static model here. candidate = m_draft_request->infer_next(candidate); candidates.push_back(candidate); } + draft_prefix_token = candidates.back(); // Phase 2. Main inference. - // For the main network, candidates_size + 1 tokens will be fed at once in a single infer request: - // last token from previous main inference + all candidates from the draft stage + // For the main network, candidates_size + 1 tokens will be fed at once in a + // single infer request: last token from previous main inference + all candidates + // from the draft stage. + // + // Note on model's return variable: If model isn't sliced to return logit only + // for the last element, then it returns logits for all elements of the input + // prompt. In that tensor, for each token `t` of the input prompt it contains + // distribution (over the vocabulary) for the next possible token that is + // generated based on subsequence [first token,...,`t`] of the input prompt. // FIXME: How max_seq_length will be handled? std::vector input_for_main(candidates.begin(), candidates.end()); input_for_main.insert(input_for_main.begin(), {out_token}); - // Note: If model isn't sliced to return logit only for the last element, - // then it returns logits for all elements of the input prompt. - // In that tensor, for each token `t` of the input prompt it contains - // distribution (over the vocabulary) for the next possible token - // that is generated based on subsequence [first token,...,`t`] - // of the input prompt. - // TODO: Handle OOM exception for static model here. auto ref_tokens = m_main_request->infer_next_return_all(input_for_main); - // Phase 3. Check if main model produced the same tokens as input candidates: + + // Phase 3. 
Validation of candidates by output of main model: size_t accepted_tokens_number = 0u; // Last token is a new token from the main model, skip it: for (size_t i = 0; i < ref_tokens.size() - 1; ++i) { @@ -580,25 +603,23 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( out_token = validated_tokens.back(); // Phase 4: Update inference wrappers based on found matches and mismatches - // This is the case when main model accepted all candidates from draft model - // we need to collect kv cache for draft last generated token by infering it.n - if (mismatched_candidates == 0) { - draft_prefix_token = candidates.back(); - } else { - // Last draft candidate is out of KV-Cache, as it is output token. + if (mismatched_candidates > 0) { m_draft_request->trim_kv_cache(mismatched_candidates - 1); m_main_request->trim_kv_cache(mismatched_candidates); + // We don't need last candidate in KVCache of draft model, as + // it fails validation. draft_prefix_token = -1; } - m_speculative_config.update_candidate_strategy(accepted_tokens_number); - // Should be enough, if all will be streamed from logits? + streaming_status = stream_generated_tokens(streamer_ptr, validated_tokens); + results.tokens[0].insert(results.tokens[0].end(), validated_tokens.begin(), validated_tokens.end()); // raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now()); // raw_perf_counters.m_batch_sizes.emplace_back(batch_size); } + m_streaming_was_cancelled = (streaming_status == ov::genai::StreamingStatus::CANCEL); if (streamer_ptr) { // push streamer's cache streamer_ptr->end(); } @@ -609,19 +630,11 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // for (int i = 0; i < accepted_tokens.size(); ++i) { // std::cout << accepted_tokens[i].first << "/" << accepted_tokens[i].second << ", "; // } - m_speculative_config.num_pred_tokens = 5; + // Reset all states. + m_speculative_config.num_pred_tokens = 5; m_draft_request->reset_state(); m_main_request->reset_state(); - - ov::genai::EncodedResults results; - // NB: Only batch=1 is supported now - results.scores.resize(1u); - results.scores[0] = 0u; - results.tokens.resize(1u); - // results.tokens[0] = sequence->get_generated_ids(); - // results.scores[0] = sequence->get_cumulative_log_prob(); - // m_chat_generation_finish_status = m_streaming_status; // auto stop_time = std::chrono::steady_clock::now(); // If is called without tokenization then that stat will not be reported. 
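The Phase 3/4 bookkeeping above hinges on an off-by-one between the two KV-caches: the main model has consumed the previous output token plus every candidate, while the draft model never consumed its own last candidate (it only produced it), which is why the draft cache is trimmed by mismatched_candidates - 1 and the main cache by mismatched_candidates. A self-contained sketch of that validation and trim arithmetic, assuming batch size 1 and greedy decoding (the struct and names below are made up for illustration, not the pipeline's API):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct ValidationResult {
    std::size_t accepted = 0;        // leading candidates confirmed by the main model
    std::size_t draft_trim = 0;      // entries to drop from the draft KV-cache
    std::size_t main_trim = 0;       // entries to drop from the main KV-cache
    std::vector<int64_t> validated;  // accepted candidates + one token sampled by the main model
};

ValidationResult validate(const std::vector<int64_t>& candidates,
                          const std::vector<int64_t>& ref_tokens) {
    // ref_tokens holds candidates.size() + 1 tokens: one prediction per fed
    // token (previous output token followed by all candidates).
    ValidationResult r;
    while (r.accepted < candidates.size() && ref_tokens[r.accepted] == candidates[r.accepted]) {
        ++r.accepted;
    }
    const std::size_t mismatched = candidates.size() - r.accepted;
    r.validated.assign(ref_tokens.begin(), ref_tokens.end() - mismatched);
    if (mismatched > 0) {
        // The last candidate was only produced by the draft model, never fed
        // back into it, so its KV-cache holds one entry fewer than the main one.
        r.draft_trim = mismatched - 1;
        r.main_trim = mismatched;
    }
    return r;
}

int main() {
    // Draft proposed {5, 7, 9}; the main model returned {5, 7, 2, 4}.
    auto r = validate({5, 7, 9}, {5, 7, 2, 4});
    std::cout << "accepted=" << r.accepted              // 2
              << " draft_trim=" << r.draft_trim         // 0
              << " main_trim=" << r.main_trim << "\n";  // 1
    return 0;
}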
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp index 7aee31c14f..a48eece3b2 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp @@ -15,17 +15,21 @@ class LLMInferWrapper { public: LLMInferWrapper::LLMInferWrapper(const ov::genai::ModelDesc& model_desc); + std::string device() const; + ov::genai::GenerationConfig get_generation_config() const; void set_generation_config(ov::genai::GenerationConfig config); + int64_t get_kvcache_capacity() const; + + int64_t get_generation_capacity() const; + int64_t infer_first(const ov::Tensor &input_ids, const ov::Tensor &attention_mask, const ov::Tensor &position_ids); - bool can_infer(); - - int64_t infer_next(const std::vector tokens); + bool can_infer(const std::size_t prompt_len = 0); int64_t infer_next(int64_t out_token); @@ -48,14 +52,13 @@ class LLMInferWrapper { void reset_state(); private: - ov::Tensor infer_next_internal(const std::vector tokens); - void set_already_allocated_input_for_1_token(); std::variant> sample_tokens( const ov::Tensor& logits, std::size_t num_tokens_to_return); private: + std::string m_device; ov::AnyMap m_properties; ov::genai::GenerationConfig m_generation_config; ov::genai::Tokenizer m_tokenizer; @@ -119,7 +122,7 @@ class SpeculativeLLMPipelineNPU : public ov::genai::LLMPipelineImplBase { bool m_is_chat_conversation = false; ChatHistory m_history; - ov::genai::GenerationStatus m_chat_generation_finish_status = ov::genai::GenerationStatus::RUNNING; + bool m_streaming_was_cancelled = false; }; } // namespace genai From 66114cb337d32007d5ca920ba37d3c01c3fef1b9 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Tue, 19 Aug 2025 17:58:09 +0100 Subject: [PATCH 08/40] Added perf and sd statistics --- .../speculative_decoding_npu.cpp | 262 +++++++++++------- .../speculative_decoding_npu.hpp | 38 +-- 2 files changed, 182 insertions(+), 118 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index 59d7608a5b..ec3d9a8e33 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "speculative_decoding_npu.hpp" +#include "continuous_batching/timer.hpp" #include "openvino/runtime/core.hpp" #include "openvino/core/parallel.hpp" #include "openvino/genai/text_streamer.hpp" @@ -23,6 +24,14 @@ ov::genai::StreamingStatus stream_generated_tokens(std::shared_ptr()[0] = 0; + if (m_device != "NPU") { + // set beam_idx for stateful model: no beam search is used and BATCH_SIZE = 1 + m_request.get_tensor("beam_idx").set_shape({BATCH_SIZE}); + m_request.get_tensor("beam_idx").data()[0] = 0; + } m_request.infer(); m_num_processed_tokens = input_ids.get_shape()[1]; @@ -109,6 +123,12 @@ int64_t LLMInferWrapper::infer_first(const ov::Tensor &input_ids, // Update last_token variable for `can_infer()` logic: last_token = std::get(sample_tokens(get_logits(), 1u)); + + infer_first_timer.end(); + auto generation_duration = infer_first_timer.get_duration_microsec(); + raw_perf_metrics.m_durations.emplace_back(generation_duration); + raw_perf_metrics.m_batch_sizes.emplace_back(BATCH_SIZE); + return last_token; } @@ -122,10 +142,10 @@ bool LLMInferWrapper::can_infer(const std::size_t prompt_len) { } } - auto stop_token_ids = 
m_generation_config.stop_token_ids; if (!m_generation_config.ignore_eos && (last_token == m_generation_config.eos_token_id)) { return false; } + auto stop_token_ids = m_generation_config.stop_token_ids; if (std::find(stop_token_ids.begin(), stop_token_ids.end(), last_token) != stop_token_ids.end()) { return false; } @@ -135,14 +155,11 @@ bool LLMInferWrapper::can_infer(const std::size_t prompt_len) { return true; } -int64_t LLMInferWrapper::infer_next(int64_t token) { +int64_t LLMInferWrapper::infer_next(int64_t token, bool skip_perf_stat) { OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next() can be called only after infer_first()!"); - // FIXME: Uncomment for static model and throw exception instead - // if (m_num_processed_tokens + tokens_size == m_kvcache_total) { - // m_sequence_group->set_out_of_memory(); - // return -1; - // } + ManualTimer infer_next_timer("infer_next()"); + infer_next_timer.start(); // Just change the variables here, as pointers to them are already set to corresponding tensors m_new_input_token = token; @@ -157,12 +174,23 @@ int64_t LLMInferWrapper::infer_next(int64_t token) { // Update last_token variable for `can_infer()` logic: last_token = std::get(sample_tokens(get_logits(), 1u)); + + infer_next_timer.end(); + if (!skip_perf_stat) { + auto generation_duration = infer_next_timer.get_duration_microsec(); + raw_perf_metrics.m_durations.emplace_back(generation_duration); + raw_perf_metrics.m_batch_sizes.emplace_back(BATCH_SIZE); + } + return last_token; } std::vector LLMInferWrapper::infer_next_return_all(const std::vector tokens) { OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next_return_all() can be called only after infer_first()!"); + ManualTimer infer_next_return_all_timer("infer_next_return_all()"); + infer_next_return_all_timer.start(); + size_t tokens_size = tokens.size(); auto input_ids = m_request.get_tensor("input_ids"); ov::Tensor new_input_ids(input_ids.get_element_type(), ov::Shape{BATCH_SIZE, tokens_size}); @@ -182,9 +210,6 @@ std::vector LLMInferWrapper::infer_next_return_all(const std::vector()[0] = 0; - m_request.infer(); m_num_processed_tokens += tokens_size; @@ -202,7 +227,13 @@ std::vector LLMInferWrapper::infer_next_return_all(const std::vector>(sample_tokens(logits, tokens_size)); // Update last_token variable for `can_infer()` logic: - last_token = sampled_tokens[tokens_size - 1]; + last_token = sampled_tokens.back(); + + infer_next_return_all_timer.end(); + auto generation_duration = infer_next_return_all_timer.get_duration_microsec(); + raw_perf_metrics.m_durations.emplace_back(generation_duration); + raw_perf_metrics.m_batch_sizes.emplace_back(tokens_size); + return sampled_tokens; } @@ -244,6 +275,10 @@ void LLMInferWrapper::reset_state() { return m_request.reset_state(); } +void LLMInferWrapper::release_memory() { + m_request.get_compiled_model().release_memory(); +} + void LLMInferWrapper::set_already_allocated_input_for_1_token() { m_request.set_tensor("input_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast(&m_new_input_token))); m_request.set_tensor("position_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast(&m_new_position_id))); @@ -283,17 +318,6 @@ std::variant> } } -void SpeculativeConfig::update_candidate_strategy(const size_t num_matches) { - // Dynamically adjust number of generated candidates based on number of matches - // we want to balance the benefits of getting candidates tokens correct with the - // cost of forecasting incorrect candidates tokens. 
- if (num_matches == num_pred_tokens) { - num_pred_tokens = std::min(num_pred_tokens + 2, max_pred_tokens); - } else { - num_pred_tokens = std::max(int64_t(num_pred_tokens) - 1, int64_t(1)); - } -} - SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc @@ -303,15 +327,11 @@ SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( // FIXME: slicing produces incorrect results for some models on NPU. // On NPU, applying slice the safe way is done by the underlying plugin if (draft_model_desc.device != "NPU") { - utils::apply_slice_before_matmul_transformation(draft_model); // As draft_model_desc contains std::shared_ptr, // this model update will be reflected in draft_model_desc + utils::apply_slice_before_matmul_transformation(draft_model); } - - // TODO: We might need it for manipulations with indices - // utils::apply_gather_before_matmul_transformation(main_model_desc.model); - // utils::apply_gather_before_matmul_transformation(draft_model); - + // Main and Draft model can have different tokenizers // to do: support retokenization: 154103 ov::genai::Tokenizer main_model_tokenizer = main_model_desc.tokenizer; @@ -330,25 +350,17 @@ SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( } m_draft_request = std::make_unique(draft_model_desc_copy); - // FIXME: Where to take it when draft model will be on NPU? - size_t max_sequence_length = main_model_desc.generation_config.max_length; - if (max_sequence_length == SIZE_MAX) { - // FIXME: NPUW_LLM_MAX_PROMPT_LEN + NPUW_LLM_MIN_RESPONSE_LEN - max_sequence_length = 100; - } - // FIXME: ? Use main_model.generation_config.num_assistant_tokens; It should be > 0, if we want draft_model.generation_config.is_speculative_decoding() == true. - const std::size_t candidates_num = 5; - m_speculative_config.max_seq_length = max_sequence_length; - m_speculative_config.num_pred_tokens = candidates_num; - // Main model (which is bigger, more accurate but slower) auto main_model_desc_copy = main_model_desc; if (main_model_desc_copy.device == "NPU") { - main_model_desc_copy.properties["NPUW_LLM_MAX_GENERATION_TOKEN_LEN"] = m_speculative_config.num_pred_tokens + 1; + main_model_desc_copy.properties["NPUW_LLM_MAX_GENERATION_TOKEN_LEN"] = 16; } m_main_request = std::make_unique(main_model_desc_copy); + + auto requested_candidates_num = main_model_desc.generation_config.num_assistant_tokens; + m_candidates_num = (requested_candidates_num != 0) ? 
requested_candidates_num : 5; - m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); + m_sd_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); } DecodedResults SpeculativeLLMPipelineNPU::generate( @@ -356,7 +368,10 @@ DecodedResults SpeculativeLLMPipelineNPU::generate( OptionalGenerationConfig generation_config, StreamerVariant streamer ) { - auto start_time = std::chrono::steady_clock::now(); + ManualTimer generate_timer("SpeculativeLLMPipelineNPU::generate()"); + generate_timer.start(); + ManualTimer encode_timer("Encode"); + encode_timer.start(); std::string prompt = std::visit(overloaded{ [](const std::string& prompt) { @@ -389,12 +404,13 @@ DecodedResults SpeculativeLLMPipelineNPU::generate( } } - auto encode_stop_time = std::chrono::steady_clock::now(); + encode_timer.end(); auto encoded_results = generate(tokenized_input, config, streamer); - auto decode_start_time = std::chrono::steady_clock::now(); + ManualTimer decode_timer("Decode"); + decode_timer.start(); DecodedResults decoded_results = {m_tokenizer.decode(encoded_results.tokens), encoded_results.scores}; - auto decode_stop_time = std::chrono::steady_clock::now(); + decode_timer.end(); if (m_is_chat_conversation) { auto answer = decoded_results.texts[0]; @@ -406,15 +422,16 @@ DecodedResults SpeculativeLLMPipelineNPU::generate( } // generate_durations - // decoded_results.perf_metrics = encoded_results.perf_metrics; - // auto& raw_counters = decoded_results.perf_metrics.raw_metrics; - // auto stop_time = std::chrono::steady_clock::now(); - // raw_counters.generate_durations.clear(); - // raw_counters.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); - // raw_counters.tokenization_durations.emplace_back(PerfMetrics::get_microsec(encode_stop_time - start_time)); - // raw_counters.detokenization_durations.emplace_back(PerfMetrics::get_microsec(decode_stop_time - decode_start_time)); - // decoded_results.perf_metrics.m_evaluated = false; - // decoded_results.perf_metrics.evaluate_statistics(start_time); + decoded_results.perf_metrics = encoded_results.perf_metrics; + decoded_results.extended_perf_metrics = encoded_results.extended_perf_metrics; + auto& raw_counters = decoded_results.perf_metrics.raw_metrics; + generate_timer.end(); + raw_counters.generate_durations.clear(); + raw_counters.generate_durations.emplace_back(generate_timer.get_duration_microsec()); + raw_counters.tokenization_durations.emplace_back(encode_timer.get_duration_microsec()); + raw_counters.detokenization_durations.emplace_back(decode_timer.get_duration_microsec()); + decoded_results.perf_metrics.m_evaluated = false; + decoded_results.perf_metrics.evaluate_statistics(generate_timer.get_start_time()); return decoded_results; } @@ -422,17 +439,8 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( const EncodedInputs& inputs, OptionalGenerationConfig generation_config, StreamerVariant streamer) { - // from step() - auto& raw_perf_counters = m_perf_metrics.raw_metrics; - auto& main_raw_perf_counters = m_perf_metrics.main_model_metrics.raw_metrics; - // - - auto start_time = std::chrono::steady_clock::now(); - - // from generate() - ManualTimer generate_timer("speculative_decoding: generate()"); + ManualTimer generate_timer("SpeculativeLLMPipelineNPU::generate()"); generate_timer.start(); - // ov::Tensor input_ids; ov::Tensor attention_mask; @@ -466,9 +474,9 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // FIXME: Update conditionally: m_main_request->set_generation_config(config); - auto prompt_len = 
prompt_shape[1]; - if (config.get_max_new_tokens(prompt_len) != SIZE_MAX) { - m_speculative_config.max_seq_length = prompt_len + config.get_max_new_tokens(prompt_len); + auto requested_candidates_num = config.num_assistant_tokens; + if (requested_candidates_num != 0) { + m_candidates_num = requested_candidates_num; } // Config draft model to not stop on EOS and remove stop strings: @@ -481,6 +489,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( std::shared_ptr streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); ov::genai::EncodedResults results; + auto& raw_perf_counters = m_sd_perf_metrics.raw_metrics; // NB: Only batch=1 is supported now. // NB: In the case of greedy decoding scores are filled with zeros. results.scores.resize(1u); @@ -505,7 +514,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( OPENVINO_ASSERT(draft_logits.get_shape().at(1) <= main_logits.get_shape().at(1), "Num of generated useful logits from draft models should be less" - "or equal than ones from main model."); + " or equal than ones from main model."); auto streaming_status = stream_generated_tokens(streamer_ptr, std::vector {out_token}); results.tokens[0].push_back(out_token); @@ -525,12 +534,15 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( */ // Last generated token by draft model needs to be prepended before next run if it is accepted by the main model! // So it will get into the context too. - // Remove debug lines. - // std::cout << std::endl << "Launching spec decode for " << config.get_max_new_tokens(prompt_len) << " max new tokens." << std::endl << std::endl; - // std::vector> accepted_tokens; int64_t draft_prefix_token = -1; while (m_main_request->can_infer() && (streaming_status == ov::genai::StreamingStatus::RUNNING)) { + ManualTimer iteration_timer("Speculative decode: infer iteration"); + iteration_timer.start(); + // Phase 1: Generation of candidates with the draft model. + ManualTimer candidates_timer("Draft model: candidates generation"); + candidates_timer.start(); + std::vector candidates; int64_t kvcache_room_for_candidates = std::min( m_draft_request->get_kvcache_capacity() - ((draft_prefix_token == -1) ? 
1 : 0), @@ -545,23 +557,35 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( if (candidates_can_be_generated <= 0) { auto remainder = m_main_request->get_generation_capacity(); // If user asked for more tokens in answer and we have - // KVCache capacity to sequentellly infer them: + // KVCache capacity to sequentially infer them: if (remainder > 0 && m_main_request->can_infer(remainder)) { for (std::size_t i = 0; i < remainder; ++i) { + ManualTimer main_timer("Main model: inference of the remainder"); + main_timer.start(); out_token = m_main_request->infer_next(out_token); + main_timer.end(); + streaming_status = stream_generated_tokens(streamer_ptr, {out_token}); results.tokens[0].push_back(out_token); + + iteration_timer.end(); + auto iteration_duration = iteration_timer.get_duration_microsec(); + update_perf_metrics(raw_perf_counters, iteration_duration, main_timer.get_end_time(), 1u); + iteration_timer.start(); } } break; } - auto candidates_to_generate = std::min(static_cast(m_speculative_config.num_pred_tokens), - candidates_can_be_generated);; + auto candidates_to_generate = std::min(static_cast(m_candidates_num), + candidates_can_be_generated); candidates.reserve(candidates_to_generate); + ManualTimer draft_prefix_timer("Draft model: prefix token inference"); // If draft_prefix_token is present, run an infer on it to collect KV cache for it if (draft_prefix_token != -1) { - m_draft_request->infer_next(draft_prefix_token); + draft_prefix_timer.start(); + m_draft_request->infer_next(draft_prefix_token, true); + draft_prefix_timer.end(); } int64_t candidate = out_token; @@ -569,8 +593,19 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( candidate = m_draft_request->infer_next(candidate); candidates.push_back(candidate); } + + // Dirty hack to update inference duration of token, that was calculated on + // { draft_prefix_token, out_token } + if (draft_prefix_token != -1) { + auto token_to_hack_it = m_draft_request->raw_perf_metrics.m_durations.end() - candidates_to_generate; + (*token_to_hack_it) += ov::genai::MicroSeconds(draft_prefix_timer.get_duration_microsec()); + } + draft_prefix_token = candidates.back(); + candidates_timer.end(); + m_sd_metrics.draft_duration += candidates_timer.get_duration(); + // Phase 2. Main inference. // For the main network, candidates_size + 1 tokens will be fed at once in a // single infer request: last token from previous main inference + all candidates @@ -581,11 +616,16 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // prompt. In that tensor, for each token `t` of the input prompt it contains // distribution (over the vocabulary) for the next possible token that is // generated based on subsequence [first token,...,`t`] of the input prompt. - // FIXME: How max_seq_length will be handled? + ManualTimer main_timer("Main model: inference on candidates"); + main_timer.start(); + std::vector input_for_main(candidates.begin(), candidates.end()); input_for_main.insert(input_for_main.begin(), {out_token}); auto ref_tokens = m_main_request->infer_next_return_all(input_for_main); + main_timer.end(); + m_sd_metrics.main_duration += main_timer.get_duration(); + // Phase 3. 
Validation of candidates by output of main model: size_t accepted_tokens_number = 0u; // Last token is a new token from the main model, skip it: @@ -596,8 +636,6 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( accepted_tokens_number++; } - // FIXME: Remove debug line - // accepted_tokens.push_back({accepted_tokens_number, candidates.size()}); auto mismatched_candidates = candidates.size() - accepted_tokens_number; std::vector validated_tokens(ref_tokens.begin(), ref_tokens.end() - mismatched_candidates); out_token = validated_tokens.back(); @@ -610,13 +648,21 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // it fails validation. draft_prefix_token = -1; } - m_speculative_config.update_candidate_strategy(accepted_tokens_number); + update_candidate_strategy(accepted_tokens_number); + + m_sd_metrics.update_acceptance_rate(0 /* request_id */, (accepted_tokens_number / candidates_to_generate) * 100); + m_sd_metrics.update_draft_accepted_tokens(0 /* request_id */, accepted_tokens_number); + if (utils::env_setup_for_print_debug_info()) { + m_sd_metrics.print(true); + m_sd_metrics.clean_up(); + } streaming_status = stream_generated_tokens(streamer_ptr, validated_tokens); results.tokens[0].insert(results.tokens[0].end(), validated_tokens.begin(), validated_tokens.end()); - // raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now()); - // raw_perf_counters.m_batch_sizes.emplace_back(batch_size); + iteration_timer.end(); + auto iteration_duration = iteration_timer.get_duration_microsec(); + update_perf_metrics(raw_perf_counters, iteration_duration, main_timer.get_end_time(), validated_tokens.size()); } m_streaming_was_cancelled = (streaming_status == ov::genai::StreamingStatus::CANCEL); @@ -624,25 +670,30 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( streamer_ptr->end(); } - // Remove debug lines - // std::cout << std::endl << std::endl << "Acceptance ratios for each iteration from total of " << accepted_tokens.size() << "." << std::endl; - // std::cout << "Format: n/m per iteration, `n` accepted tokens from `m` candidates." << std::endl; - // for (int i = 0; i < accepted_tokens.size(); ++i) { - // std::cout << accepted_tokens[i].first << "/" << accepted_tokens[i].second << ", "; - // } - // Reset all states. - m_speculative_config.num_pred_tokens = 5; + m_candidates_num = 5; + requested_candidates_num = config.num_assistant_tokens; + if (requested_candidates_num != 0) { + m_candidates_num = requested_candidates_num; + } m_draft_request->reset_state(); m_main_request->reset_state(); - // auto stop_time = std::chrono::steady_clock::now(); + generate_timer.end(); // If is called without tokenization then that stat will not be reported. 
- // auto& metrics = results.perf_metrics; - // metrics.num_input_tokens = batch_size * input_ids.get_shape().at(1); - // metrics.load_time = this->m_load_time_ms; - // metrics.raw_metrics.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); - // metrics.evaluate_statistics(start_time); + m_sd_perf_metrics.num_input_tokens = input_ids.get_shape().at(1); + m_sd_perf_metrics.load_time = this->m_load_time_ms; + m_sd_perf_metrics.raw_metrics.generate_durations.clear(); + m_sd_perf_metrics.raw_metrics.generate_durations.emplace_back(generate_timer.get_duration_microsec()); + + m_sd_perf_metrics.draft_model_metrics.raw_metrics = m_draft_request->raw_perf_metrics; + m_sd_perf_metrics.main_model_metrics.raw_metrics = m_main_request->raw_perf_metrics; + + m_sd_perf_metrics.evaluate_statistics(generate_timer.get_start_time()); + + results.perf_metrics = m_sd_perf_metrics; + results.extended_perf_metrics = std::make_shared(m_sd_perf_metrics); + return results; } @@ -659,8 +710,19 @@ void SpeculativeLLMPipelineNPU::finish_chat() { }; SpeculativeLLMPipelineNPU::~SpeculativeLLMPipelineNPU() { - // FIXME: Do we need it? - // m_request.get_compiled_model().release_memory(); + m_main_request->release_memory(); + m_draft_request->release_memory(); +} + +void SpeculativeLLMPipelineNPU::update_candidate_strategy(const std::size_t matches_num) { + // Dynamically adjust number of generated candidates based on number of matches, + // we want to balance the benefits of getting candidates tokens correct with the + // cost of forecasting incorrect candidates tokens. + if (matches_num == m_candidates_num) { + m_candidates_num = std::min(m_candidates_num + 2, m_max_candidates_num); + } else { + m_candidates_num = std::max(int64_t(m_candidates_num) - 1, int64_t(1)); + } } } // namespace genai } // namespace ov diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp index a48eece3b2..c5be05ed26 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp @@ -1,15 +1,16 @@ // Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 -#include -#include + +#include "speculative_decoding_metrics.hpp" #include "llm/pipeline_base.hpp" #include "sampling/sampler.hpp" #include "utils.hpp" +#include +#include namespace ov { namespace genai { -constexpr size_t BATCH_SIZE = 1; class LLMInferWrapper { public: @@ -31,7 +32,7 @@ class LLMInferWrapper { bool can_infer(const std::size_t prompt_len = 0); - int64_t infer_next(int64_t out_token); + int64_t infer_next(int64_t out_token, bool skip_perf_stat = false); std::vector infer_next_return_all(const std::vector tokens); @@ -51,7 +52,14 @@ class LLMInferWrapper { void reset_state(); + void release_memory(); + +public: + ov::genai::RawPerfMetrics raw_perf_metrics; + private: + static constexpr std::size_t BATCH_SIZE = 1; + void set_already_allocated_input_for_1_token(); std::variant> sample_tokens( @@ -70,7 +78,6 @@ class LLMInferWrapper { int64_t last_token = -1; ov::genai::utils::KVAxesPosition m_kv_pos; ov::InferRequest m_request; - // Separate metrics? // Data placeholder for 1-token inference: int64_t m_new_input_token = -1; @@ -78,15 +85,6 @@ class LLMInferWrapper { std::vector m_new_atten_mask_data; }; -// FIXME: Do we need this? 
-struct SpeculativeConfig { - void update_candidate_strategy(const size_t num_matches); - - std::size_t max_seq_length = SIZE_MAX; - std::size_t num_pred_tokens = 5; - const std::size_t max_pred_tokens = 10; -}; - class SpeculativeLLMPipelineNPU : public ov::genai::LLMPipelineImplBase { public: SpeculativeLLMPipelineNPU( @@ -113,12 +111,16 @@ class SpeculativeLLMPipelineNPU : public ov::genai::LLMPipelineImplBase { ~SpeculativeLLMPipelineNPU(); private: - uint32_t m_max_prompt_len = 0u; - uint32_t m_kvcache_total = 0u; + void update_candidate_strategy(const std::size_t matches_num); + +private: std::unique_ptr m_draft_request; std::unique_ptr m_main_request; - SpeculativeConfig m_speculative_config; - ov::genai::SDPerModelsPerfMetrics m_perf_metrics; + std::size_t m_candidates_num = 5; + const std::size_t m_max_candidates_num = 10; + + ov::genai::SpeculativeDecodingMetrics m_sd_metrics; + ov::genai::SDPerModelsPerfMetrics m_sd_perf_metrics; bool m_is_chat_conversation = false; ChatHistory m_history; From 27441379a2d829b67d8702ad7ef9582d80449b87 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Wed, 20 Aug 2025 02:08:27 +0100 Subject: [PATCH 09/40] Removed unneccessary copy of properties --- src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index ec3d9a8e33..ddbcd6e9eb 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -345,9 +345,6 @@ SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( if (draft_model_desc_copy.device.empty()) { draft_model_desc_copy.device = main_model_desc.device; } - if (draft_model_desc_copy.properties.empty()) { - draft_model_desc_copy.properties = main_model_desc.properties; - } m_draft_request = std::make_unique(draft_model_desc_copy); // Main model (which is bigger, more accurate but slower) From 9b74a99779ddfa27814c7265da505ee621453b1b Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Wed, 20 Aug 2025 02:59:48 +0100 Subject: [PATCH 10/40] Fixed setting for NPUW target generate model --- .../src/speculative_decoding/speculative_decoding_npu.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index ddbcd6e9eb..c30a9a164f 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -347,16 +347,16 @@ SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( } m_draft_request = std::make_unique(draft_model_desc_copy); + auto requested_candidates_num = main_model_desc.generation_config.num_assistant_tokens; + m_candidates_num = (requested_candidates_num != 0) ? requested_candidates_num : 5; + // Main model (which is bigger, more accurate but slower) auto main_model_desc_copy = main_model_desc; if (main_model_desc_copy.device == "NPU") { - main_model_desc_copy.properties["NPUW_LLM_MAX_GENERATION_TOKEN_LEN"] = 16; + main_model_desc_copy.properties["NPUW_LLM_MAX_GENERATION_TOKEN_LEN"] = m_max_candidates_num + 1; } m_main_request = std::make_unique(main_model_desc_copy); - auto requested_candidates_num = main_model_desc.generation_config.num_assistant_tokens; - m_candidates_num = (requested_candidates_num != 0) ? 
requested_candidates_num : 5; - m_sd_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); } From 719d49e882e4f7c74bdadb1debda2316fc1d0f1a Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Mon, 25 Aug 2025 17:26:22 +0100 Subject: [PATCH 11/40] Fixes for perf metrics --- src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp | 4 ++++ src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index c30a9a164f..a3a73e15da 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -634,6 +634,10 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( } auto mismatched_candidates = candidates.size() - accepted_tokens_number; + + auto& main_perf_gen_tokens = m_main_request->raw_perf_metrics.m_batch_sizes.back(); + main_perf_gen_tokens -= mismatched_candidates; + std::vector validated_tokens(ref_tokens.begin(), ref_tokens.end() - mismatched_candidates); out_token = validated_tokens.back(); diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp index c5be05ed26..cba80bd33c 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp @@ -14,7 +14,7 @@ namespace genai { class LLMInferWrapper { public: - LLMInferWrapper::LLMInferWrapper(const ov::genai::ModelDesc& model_desc); + LLMInferWrapper(const ov::genai::ModelDesc& model_desc); std::string device() const; From 5960b13e72ebf0ca83aff67859b02eadd795b25c Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Mon, 25 Aug 2025 17:34:05 +0100 Subject: [PATCH 12/40] Polishing --- .../src/speculative_decoding/speculative_decoding_npu.hpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp index cba80bd33c..7cd249d00b 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp @@ -40,16 +40,8 @@ class LLMInferWrapper { std::size_t get_num_processed_tokens() const; - ov::genai::GenerationHandle create_generation_handle(); - - void remove_last_generated_tokens(const std::size_t tokens_to_remove); - void trim_kv_cache(const std::size_t tokens_to_remove); - ov::genai::EncodedResults finalize(); - - ov::genai::GenerationStatus get_generation_status() const; - void reset_state(); void release_memory(); From df451183383025ee015ccee29b2a6018ab17a7cb Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Wed, 27 Aug 2025 17:02:14 +0100 Subject: [PATCH 13/40] Polishing --- .../speculative_decoding_npu.cpp | 59 +++++++++++-------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index a3a73e15da..51258db671 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -32,7 +32,14 @@ void update_perf_metrics(ov::genai::RawPerfMetrics& raw_perf_counters, const flo raw_perf_counters.m_new_token_times.emplace_back(new_token_time); 
raw_perf_counters.m_batch_sizes.emplace_back(num_generated_tokens); } -} // anonymous namespace + +// FIXME: Should we update infer duraion? +void update_perf_metrics(ov::genai::RawPerfMetrics& raw_perf_counters, + const float token_duration, const std::size_t num_generated_tokens) { + raw_perf_counters.m_durations.emplace_back(token_duration); + raw_perf_counters.m_batch_sizes.emplace_back(num_generated_tokens); +} +}// anonymous namespace namespace ov { namespace genai { @@ -125,10 +132,7 @@ int64_t LLMInferWrapper::infer_first(const ov::Tensor &input_ids, last_token = std::get(sample_tokens(get_logits(), 1u)); infer_first_timer.end(); - auto generation_duration = infer_first_timer.get_duration_microsec(); - raw_perf_metrics.m_durations.emplace_back(generation_duration); - raw_perf_metrics.m_batch_sizes.emplace_back(BATCH_SIZE); - + update_perf_metrics(raw_perf_metrics, infer_first_timer.get_duration_microsec(), BATCH_SIZE); return last_token; } @@ -177,9 +181,8 @@ int64_t LLMInferWrapper::infer_next(int64_t token, bool skip_perf_stat) { infer_next_timer.end(); if (!skip_perf_stat) { - auto generation_duration = infer_next_timer.get_duration_microsec(); - raw_perf_metrics.m_durations.emplace_back(generation_duration); - raw_perf_metrics.m_batch_sizes.emplace_back(BATCH_SIZE); + update_perf_metrics( + raw_perf_metrics, infer_next_timer.get_duration_microsec(), BATCH_SIZE); } return last_token; @@ -230,10 +233,8 @@ std::vector LLMInferWrapper::infer_next_return_all(const std::vector(draft_model_desc_copy); auto requested_candidates_num = main_model_desc.generation_config.num_assistant_tokens; @@ -591,8 +595,11 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( candidates.push_back(candidate); } - // Dirty hack to update inference duration of token, that was calculated on - // { draft_prefix_token, out_token } + // Dirty hack to update inference duration of token, that was calculated after inference + // on draft_prefix_token. + // As we are inferring on draft_prefix_token just to update KV-Cache and are not interested + // in predicted token output, then we accumulating this inference's duration with the one + // of the next token (which are interested in): if (draft_prefix_token != -1) { auto token_to_hack_it = m_draft_request->raw_perf_metrics.m_durations.end() - candidates_to_generate; (*token_to_hack_it) += ov::genai::MicroSeconds(draft_prefix_timer.get_duration_microsec()); @@ -634,10 +641,6 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( } auto mismatched_candidates = candidates.size() - accepted_tokens_number; - - auto& main_perf_gen_tokens = m_main_request->raw_perf_metrics.m_batch_sizes.back(); - main_perf_gen_tokens -= mismatched_candidates; - std::vector validated_tokens(ref_tokens.begin(), ref_tokens.end() - mismatched_candidates); out_token = validated_tokens.back(); @@ -651,6 +654,8 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( } update_candidate_strategy(accepted_tokens_number); + auto& main_perf_generated_tokens = m_main_request->raw_perf_metrics.m_batch_sizes.back(); + main_perf_generated_tokens -= mismatched_candidates; m_sd_metrics.update_acceptance_rate(0 /* request_id */, (accepted_tokens_number / candidates_to_generate) * 100); m_sd_metrics.update_draft_accepted_tokens(0 /* request_id */, accepted_tokens_number); if (utils::env_setup_for_print_debug_info()) { @@ -671,14 +676,16 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( streamer_ptr->end(); } - // Reset all states. 
- m_candidates_num = 5; - requested_candidates_num = config.num_assistant_tokens; - if (requested_candidates_num != 0) { - m_candidates_num = requested_candidates_num; + // If not chat conversation, then reset all states. + if (!m_is_chat_conversation) { + m_candidates_num = 5; + requested_candidates_num = config.num_assistant_tokens; + if (requested_candidates_num != 0) { + m_candidates_num = requested_candidates_num; + } + m_draft_request->reset_state(); + m_main_request->reset_state(); } - m_draft_request->reset_state(); - m_main_request->reset_state(); generate_timer.end(); // If is called without tokenization then that stat will not be reported. @@ -708,6 +715,8 @@ void SpeculativeLLMPipelineNPU::start_chat(const std::string& system_message) { void SpeculativeLLMPipelineNPU::finish_chat() { m_is_chat_conversation = false; m_history.clear(); + m_draft_request->reset_state(); + m_main_request->reset_state(); }; SpeculativeLLMPipelineNPU::~SpeculativeLLMPipelineNPU() { From 819e43049a806191c7ba401f9804cd2b0720587c Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Thu, 28 Aug 2025 12:24:34 +0100 Subject: [PATCH 14/40] Fixed perf metrics --- .../speculative_decoding_npu.cpp | 62 +++++++++++-------- .../speculative_decoding_npu.hpp | 2 +- 2 files changed, 38 insertions(+), 26 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index 51258db671..01f89ac2f4 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -33,10 +33,12 @@ void update_perf_metrics(ov::genai::RawPerfMetrics& raw_perf_counters, const flo raw_perf_counters.m_batch_sizes.emplace_back(num_generated_tokens); } -// FIXME: Should we update infer duraion? 
void update_perf_metrics(ov::genai::RawPerfMetrics& raw_perf_counters, - const float token_duration, const std::size_t num_generated_tokens) { + const float inference_duration, + const float token_duration, + const std::size_t num_generated_tokens) { raw_perf_counters.m_durations.emplace_back(token_duration); + raw_perf_counters.m_inference_durations[0] += ov::genai::MicroSeconds(inference_duration); raw_perf_counters.m_batch_sizes.emplace_back(num_generated_tokens); } }// anonymous namespace @@ -60,6 +62,7 @@ namespace genai { // utils::apply_gather_before_matmul_transformation(model_desc.model); m_request = ov::genai::utils::singleton_core().compile_model(model_desc.model, m_device, m_properties).create_infer_request(); } + raw_perf_metrics.m_inference_durations = {{ ov::genai::MicroSeconds(0.0f) }}; } std::string LLMInferWrapper::device() const { @@ -117,7 +120,10 @@ int64_t LLMInferWrapper::infer_first(const ov::Tensor &input_ids, m_request.get_tensor("beam_idx").data()[0] = 0; } + const auto infer_start = std::chrono::steady_clock::now(); m_request.infer(); + const auto infer_end = std::chrono::steady_clock::now(); + m_num_processed_tokens = input_ids.get_shape()[1]; m_first_prompt_len = m_num_processed_tokens; @@ -132,7 +138,9 @@ int64_t LLMInferWrapper::infer_first(const ov::Tensor &input_ids, last_token = std::get(sample_tokens(get_logits(), 1u)); infer_first_timer.end(); - update_perf_metrics(raw_perf_metrics, infer_first_timer.get_duration_microsec(), BATCH_SIZE); + update_perf_metrics(raw_perf_metrics, + ov::genai::PerfMetrics::get_microsec(infer_end - infer_start), + infer_first_timer.get_duration_microsec(), BATCH_SIZE); return last_token; } @@ -159,7 +167,7 @@ bool LLMInferWrapper::can_infer(const std::size_t prompt_len) { return true; } -int64_t LLMInferWrapper::infer_next(int64_t token, bool skip_perf_stat) { +int64_t LLMInferWrapper::infer_next(int64_t token, bool append_perf_stat) { OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next() can be called only after infer_first()!"); ManualTimer infer_next_timer("infer_next()"); @@ -172,7 +180,9 @@ int64_t LLMInferWrapper::infer_next(int64_t token, bool skip_perf_stat) { m_new_atten_mask_data.push_back(1); m_request.set_tensor("attention_mask", ov::Tensor(ov::element::i64, ov::Shape{1,m_new_atten_mask_data.size()}, (void*)&m_new_atten_mask_data[0])); + const auto infer_start = std::chrono::steady_clock::now(); m_request.infer(); + const auto infer_end = std::chrono::steady_clock::now(); m_num_processed_tokens += 1u; @@ -180,9 +190,16 @@ int64_t LLMInferWrapper::infer_next(int64_t token, bool skip_perf_stat) { last_token = std::get(sample_tokens(get_logits(), 1u)); infer_next_timer.end(); - if (!skip_perf_stat) { + // prepend perf stat + if (!append_perf_stat) { update_perf_metrics( - raw_perf_metrics, infer_next_timer.get_duration_microsec(), BATCH_SIZE); + raw_perf_metrics, + ov::genai::PerfMetrics::get_microsec(infer_end - infer_start), + infer_next_timer.get_duration_microsec(), + BATCH_SIZE); + } else { + raw_perf_metrics.m_durations.back() += infer_next_timer.get_duration_microsec(); + raw_perf_metrics.m_inference_durations[0] += ov::genai::PerfMetrics::get_microsec(infer_end - infer_start); } return last_token; @@ -213,7 +230,9 @@ std::vector LLMInferWrapper::infer_next_return_all(const std::vector LLMInferWrapper::infer_next_return_all(const std::vectorinfer_next(draft_prefix_token, true); - draft_prefix_timer.end(); + const bool draft_prefix_exists = (draft_prefix_token != -1); + if (draft_prefix_exists) { + 
m_draft_request->infer_next(draft_prefix_token); } + // Note: If `draft_prefix_exists == true`, then we append performance metrics of + // newly generated candidate to the previously generated token on draft prefix prompt, + // as we are only interested in one output from these two inference operations. + int64_t candidate = m_draft_request->infer_next(out_token, draft_prefix_exists); + candidates.push_back(candidate); - int64_t candidate = out_token; - for (size_t i = 0; i < candidates_to_generate; i++) { + for (size_t i = 1; i < candidates_to_generate; i++) { candidate = m_draft_request->infer_next(candidate); candidates.push_back(candidate); } - - // Dirty hack to update inference duration of token, that was calculated after inference - // on draft_prefix_token. - // As we are inferring on draft_prefix_token just to update KV-Cache and are not interested - // in predicted token output, then we accumulating this inference's duration with the one - // of the next token (which are interested in): - if (draft_prefix_token != -1) { - auto token_to_hack_it = m_draft_request->raw_perf_metrics.m_durations.end() - candidates_to_generate; - (*token_to_hack_it) += ov::genai::MicroSeconds(draft_prefix_timer.get_duration_microsec()); - } - draft_prefix_token = candidates.back(); candidates_timer.end(); diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp index 7cd249d00b..c96ce4485b 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp @@ -32,7 +32,7 @@ class LLMInferWrapper { bool can_infer(const std::size_t prompt_len = 0); - int64_t infer_next(int64_t out_token, bool skip_perf_stat = false); + int64_t infer_next(int64_t out_token, bool append_perf_stat = false); std::vector infer_next_return_all(const std::vector tokens); From 83dd8d95ba55269c09108a40ab12e6311b78c19a Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Thu, 28 Aug 2025 12:39:37 +0100 Subject: [PATCH 15/40] Fixed SD metrics --- .../src/speculative_decoding/speculative_decoding_impl.cpp | 5 +++++ .../src/speculative_decoding/speculative_decoding_npu.cpp | 2 ++ 2 files changed, 7 insertions(+) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index c98e0d7542..a7432f6447 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -218,7 +218,12 @@ void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() { raw_perf_counters.m_batch_sizes.emplace_back(num_generated_tokens); auto m_main_pipeline_metrics = m_main_pipeline->get_metrics(); +<<<<<<< HEAD main_raw_perf_counters.m_durations.push_back(MicroSeconds(main_duration)); +======= + main_raw_perf_counters.m_durations.push_back(MicroSeconds(main_model_gen_duration)); + // TODO: Ask about += +>>>>>>> 4b74b232 (Fixed SD metrics) main_raw_perf_counters.m_inference_durations[0] = MicroSeconds(m_main_pipeline_metrics.inference_duration); main_raw_perf_counters.m_batch_sizes.push_back(num_generated_tokens); // or should be processed + generated m_sd_metrics.update_generated_len(num_generated_tokens); diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index 01f89ac2f4..03c72cf5e1 100644 --- 
a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -668,8 +668,10 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( auto& main_perf_generated_tokens = m_main_request->raw_perf_metrics.m_batch_sizes.back(); main_perf_generated_tokens -= mismatched_candidates; + m_sd_metrics.update_draft_generated_token_len(candidates_to_generate); m_sd_metrics.update_acceptance_rate(0 /* request_id */, (accepted_tokens_number / candidates_to_generate) * 100); m_sd_metrics.update_draft_accepted_tokens(0 /* request_id */, accepted_tokens_number); + m_sd_metrics.update_generated_len(validated_tokens.size()); if (utils::env_setup_for_print_debug_info()) { m_sd_metrics.print(true); m_sd_metrics.clean_up(); From be91e0da723f8d95d2d165fca42ec88940616fee Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Thu, 28 Aug 2025 13:05:26 +0100 Subject: [PATCH 16/40] 1) 1 Returned SD metrics by pipeline. 2) Removed NPU constraint in llm_bench 3) Fixed pipeline TTFT --- .../speculative_decoding_npu.cpp | 26 +++++++++++++++---- .../speculative_decoding_npu.hpp | 3 +++ .../llm_bench/llm_bench_utils/model_utils.py | 4 +-- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index 03c72cf5e1..c54b2e458a 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -198,8 +198,10 @@ int64_t LLMInferWrapper::infer_next(int64_t token, bool append_perf_stat) { infer_next_timer.get_duration_microsec(), BATCH_SIZE); } else { - raw_perf_metrics.m_durations.back() += infer_next_timer.get_duration_microsec(); - raw_perf_metrics.m_inference_durations[0] += ov::genai::PerfMetrics::get_microsec(infer_end - infer_start); + raw_perf_metrics.m_durations.back() += + ov::genai::MicroSeconds(infer_next_timer.get_duration_microsec()); + raw_perf_metrics.m_inference_durations[0] += + ov::genai::MicroSeconds(ov::genai::PerfMetrics::get_microsec(infer_end - infer_start)); } return last_token; @@ -353,7 +355,7 @@ SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( // this model update will be reflected in draft_model_desc utils::apply_slice_before_matmul_transformation(draft_model); } - + // Main and Draft model can have different tokenizers // to do: support retokenization: 154103 ov::genai::Tokenizer main_model_tokenizer = main_model_desc.tokenizer; @@ -522,8 +524,16 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( utils::initialize_position_ids(position_ids, attention_mask); // To collect KV-cache for the prompt and to get the next token, run the very first infer request - // for draft and main models: + // for main and draft models: + ManualTimer first_token_timer("Speculative decode: first token timer"); + first_token_timer.start(); + auto out_token = m_main_request->infer_first(input_ids, attention_mask, position_ids); + + first_token_timer.end(); + update_perf_metrics(raw_perf_counters, first_token_timer.get_duration_microsec(), + first_token_timer.get_end_time(), 1u); + m_draft_request->infer_first(input_ids, attention_mask, position_ids); // logits shape is [BATCH_SIZE, seq_len, vocab_size] @@ -541,6 +551,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( auto streaming_status = stream_generated_tokens(streamer_ptr, std::vector {out_token}); results.tokens[0].push_back(out_token); + // first token? 
/* Speculative decoding works the following way. The draft model predicts the next K tokens one by one in an autoregressive manner, while the main model validates these predictions and corrects them if necessary. We go through each predicted token, and @@ -668,7 +679,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( auto& main_perf_generated_tokens = m_main_request->raw_perf_metrics.m_batch_sizes.back(); main_perf_generated_tokens -= mismatched_candidates; - m_sd_metrics.update_draft_generated_token_len(candidates_to_generate); + m_sd_metrics.update_draft_generated_len(0 /* request_id */, candidates_to_generate); m_sd_metrics.update_acceptance_rate(0 /* request_id */, (accepted_tokens_number / candidates_to_generate) * 100); m_sd_metrics.update_draft_accepted_tokens(0 /* request_id */, accepted_tokens_number); m_sd_metrics.update_generated_len(validated_tokens.size()); @@ -719,6 +730,11 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( return results; } +ov::genai::SpeculativeDecodingMetrics +SpeculativeLLMPipelineNPU::get_speculative_decoding_metrics() const { + return m_sd_metrics; +}; + void SpeculativeLLMPipelineNPU::start_chat(const std::string& system_message) { if (!system_message.empty()) { m_history.push_back({{"role", "system"}, {"content", system_message}}); diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp index c96ce4485b..cabaf8941c 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp @@ -100,6 +100,9 @@ class SpeculativeLLMPipelineNPU : public ov::genai::LLMPipelineImplBase { void finish_chat() override; + ov::genai::SpeculativeDecodingMetrics + get_speculative_decoding_metrics() const; + ~SpeculativeLLMPipelineNPU(); private: diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py index 56800cbea8..a37a528a1b 100644 --- a/tools/llm_bench/llm_bench_utils/model_utils.py +++ b/tools/llm_bench/llm_bench_utils/model_utils.py @@ -209,8 +209,8 @@ def analyze_args(args): if args.cb_config: cb_config = get_config(args.cb_config) model_args["cb_config"] = cb_config - if args.draft_model and (args.device == "NPU" or model_args['config']['ATTENTION_BACKEND'] != PA_ATTENTION_BACKEND): - log.warning("Speculative Decoding is supported only with Page Attention Backend and not supported for NPU device") + if args.draft_model and model_args['config']['ATTENTION_BACKEND'] != PA_ATTENTION_BACKEND: + log.warning("Speculative Decoding is supported only with Page Attention Backend") args.draft_model = None model_args['draft_model'] = args.draft_model model_args['draft_device'] = args.draft_device From b43bcdae59af87aae3b9e42c51b784542e319012 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Fri, 29 Aug 2025 11:22:12 +0100 Subject: [PATCH 17/40] Fixed typos in CB Speculative Decode perf metrics --- .../continuous_batching_for_speculative_decoding_impl.cpp | 2 +- .../src/speculative_decoding/speculative_decoding_impl.cpp | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp index def8f88372..aaede5360f 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp +++ 
b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp @@ -319,7 +319,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m auto pipeline_metrics = get_metrics(); if (num_generated_tokens > 0) { raw_perf_metrics.m_durations.emplace_back(generation_duration); - raw_perf_metrics.m_inference_durations[0] = MicroSeconds(pipeline_metrics.inference_duration); + raw_perf_metrics.m_inference_durations[0] += MicroSeconds(pipeline_metrics.inference_duration); raw_perf_metrics.m_batch_sizes.emplace_back(num_generated_tokens); } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index a7432f6447..4b2a4a13a2 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -218,13 +218,8 @@ void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() { raw_perf_counters.m_batch_sizes.emplace_back(num_generated_tokens); auto m_main_pipeline_metrics = m_main_pipeline->get_metrics(); -<<<<<<< HEAD main_raw_perf_counters.m_durations.push_back(MicroSeconds(main_duration)); -======= - main_raw_perf_counters.m_durations.push_back(MicroSeconds(main_model_gen_duration)); - // TODO: Ask about += ->>>>>>> 4b74b232 (Fixed SD metrics) - main_raw_perf_counters.m_inference_durations[0] = MicroSeconds(m_main_pipeline_metrics.inference_duration); + main_raw_perf_counters.m_inference_durations[0] += MicroSeconds(m_main_pipeline_metrics.inference_duration); main_raw_perf_counters.m_batch_sizes.push_back(num_generated_tokens); // or should be processed + generated m_sd_metrics.update_generated_len(num_generated_tokens); } From 967fbe243d96eacb250b1dd1003a33b3c6432462 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Mon, 1 Sep 2025 19:18:03 +0100 Subject: [PATCH 18/40] Extended the ManualTimer to be created only once --- src/cpp/src/continuous_batching/timer.hpp | 10 +++++++++ .../speculative_decoding_npu.cpp | 21 ++++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/cpp/src/continuous_batching/timer.hpp b/src/cpp/src/continuous_batching/timer.hpp index c924a87b96..1c16203c37 100644 --- a/src/cpp/src/continuous_batching/timer.hpp +++ b/src/cpp/src/continuous_batching/timer.hpp @@ -14,6 +14,8 @@ class ManualTimer { public: ManualTimer(const std::string& title) : m_total(0.), + m_start(std::chrono::steady_clock::duration::zero()), + m_end(std::chrono::steady_clock::duration::zero()), m_title(title) { } @@ -42,6 +44,14 @@ class ManualTimer { return m_total; } + void clear() { + m_total = 0.0; + m_start = std::chrono::steady_clock::time_point( + std::chrono::steady_clock::duration::zero()); + m_end = std::chrono::steady_clock::time_point( + std::chrono::steady_clock::duration::zero()); + } + ~ManualTimer() { // std::cout << m_title << ": " << m_total / 1e6 << " secs" << std::endl; } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp index c54b2e458a..d90b67ea6f 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp @@ -551,7 +551,11 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( auto streaming_status = stream_generated_tokens(streamer_ptr, std::vector {out_token}); results.tokens[0].push_back(out_token); - // first token? 
+ // Creating timers for performance metrics calculation: + ManualTimer iteration_timer("Speculative decode: infer iteration"); + ManualTimer candidates_timer("Draft model: candidates generation"); + ManualTimer main_timer("Main model"); + /* Speculative decoding works the following way. The draft model predicts the next K tokens one by one in an autoregressive manner, while the main model validates these predictions and corrects them if necessary. We go through each predicted token, and @@ -569,11 +573,9 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // So it will get into the context too. int64_t draft_prefix_token = -1; while (m_main_request->can_infer() && (streaming_status == ov::genai::StreamingStatus::RUNNING)) { - ManualTimer iteration_timer("Speculative decode: infer iteration"); iteration_timer.start(); // Phase 1: Generation of candidates with the draft model. - ManualTimer candidates_timer("Draft model: candidates generation"); candidates_timer.start(); std::vector candidates; @@ -593,7 +595,6 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // KVCache capacity to sequentially infer them: if (remainder > 0 && m_main_request->can_infer(remainder)) { for (std::size_t i = 0; i < remainder; ++i) { - ManualTimer main_timer("Main model: inference of the remainder"); main_timer.start(); out_token = m_main_request->infer_next(out_token); main_timer.end(); @@ -604,6 +605,9 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( iteration_timer.end(); auto iteration_duration = iteration_timer.get_duration_microsec(); update_perf_metrics(raw_perf_counters, iteration_duration, main_timer.get_end_time(), 1u); + + main_timer.clear(); + iteration_timer.clear(); iteration_timer.start(); } } @@ -632,6 +636,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( candidates_timer.end(); m_sd_metrics.draft_duration += candidates_timer.get_duration(); + candidates_timer.clear(); // Phase 2. Main inference. // For the main network, candidates_size + 1 tokens will be fed at once in a @@ -643,7 +648,6 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( // prompt. In that tensor, for each token `t` of the input prompt it contains // distribution (over the vocabulary) for the next possible token that is // generated based on subsequence [first token,...,`t`] of the input prompt. - ManualTimer main_timer("Main model: inference on candidates"); main_timer.start(); std::vector input_for_main(candidates.begin(), candidates.end()); @@ -652,6 +656,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( main_timer.end(); m_sd_metrics.main_duration += main_timer.get_duration(); + main_timer.clear(); // Phase 3. Validation of candidates by output of main model: size_t accepted_tokens_number = 0u; @@ -694,6 +699,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( iteration_timer.end(); auto iteration_duration = iteration_timer.get_duration_microsec(); update_perf_metrics(raw_perf_counters, iteration_duration, main_timer.get_end_time(), validated_tokens.size()); + iteration_timer.clear(); } m_streaming_was_cancelled = (streaming_status == ov::genai::StreamingStatus::CANCEL); @@ -727,6 +733,11 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( results.perf_metrics = m_sd_perf_metrics; results.extended_perf_metrics = std::make_shared(m_sd_perf_metrics); + // Reset all timers. 
+ generate_timer.clear(); + iteration_timer.clear(); + candidates_timer.clear(); + main_timer.clear(); return results; } From bb3a83fe103e7cee2de2749433d9eee7ea38f19b Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Mon, 1 Sep 2025 20:04:41 +0100 Subject: [PATCH 19/40] Renaming pipeline --- ....cpp => speculative_decoding_stateful.cpp} | 24 +++++++++---------- ....hpp => speculative_decoding_stateful.hpp} | 6 ++--- 2 files changed, 15 insertions(+), 15 deletions(-) rename src/cpp/src/speculative_decoding/{speculative_decoding_npu.cpp => speculative_decoding_stateful.cpp} (97%) rename src/cpp/src/speculative_decoding/{speculative_decoding_npu.hpp => speculative_decoding_stateful.hpp} (95%) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp similarity index 97% rename from src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp rename to src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp index d90b67ea6f..64896977b9 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 -#include "speculative_decoding_npu.hpp" +#include "speculative_decoding_stateful.hpp" #include "continuous_batching/timer.hpp" #include "openvino/runtime/core.hpp" #include "openvino/core/parallel.hpp" @@ -342,7 +342,7 @@ std::variant> } } -SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( +StatefulSpeculativeLLMPipeline::StatefulSpeculativeLLMPipeline( const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc ) : LLMPipelineImplBase(main_model_desc.tokenizer, main_model_desc.generation_config) { @@ -387,12 +387,12 @@ SpeculativeLLMPipelineNPU::SpeculativeLLMPipelineNPU( m_sd_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); } -DecodedResults SpeculativeLLMPipelineNPU::generate( +DecodedResults StatefulSpeculativeLLMPipeline::generate( StringInputs inputs, OptionalGenerationConfig generation_config, StreamerVariant streamer ) { - ManualTimer generate_timer("SpeculativeLLMPipelineNPU::generate()"); + ManualTimer generate_timer("StatefulSpeculativeLLMPipeline::generate()"); generate_timer.start(); ManualTimer encode_timer("Encode"); encode_timer.start(); @@ -459,11 +459,11 @@ DecodedResults SpeculativeLLMPipelineNPU::generate( return decoded_results; } -EncodedResults SpeculativeLLMPipelineNPU::generate( +EncodedResults StatefulSpeculativeLLMPipeline::generate( const EncodedInputs& inputs, OptionalGenerationConfig generation_config, StreamerVariant streamer) { - ManualTimer generate_timer("SpeculativeLLMPipelineNPU::generate()"); + ManualTimer generate_timer("StatefulSpeculativeLLMPipeline::generate()"); generate_timer.start(); ov::Tensor input_ids; @@ -656,7 +656,6 @@ EncodedResults SpeculativeLLMPipelineNPU::generate( main_timer.end(); m_sd_metrics.main_duration += main_timer.get_duration(); - main_timer.clear(); // Phase 3. 
 Validation of candidates by output of main model:
     size_t accepted_tokens_number = 0u;
@@ -700,6 +699,7 @@ EncodedResults SpeculativeLLMPipelineNPU::generate(
         auto iteration_duration = iteration_timer.get_duration_microsec();
         update_perf_metrics(raw_perf_counters, iteration_duration, main_timer.get_end_time(), validated_tokens.size());
         iteration_timer.clear();
+        main_timer.clear();
     }

     m_streaming_was_cancelled = (streaming_status == ov::genai::StreamingStatus::CANCEL);
@@ -742,30 +742,30 @@ EncodedResults SpeculativeLLMPipelineNPU::generate(
 }

 ov::genai::SpeculativeDecodingMetrics
-SpeculativeLLMPipelineNPU::get_speculative_decoding_metrics() const {
+StatefulSpeculativeLLMPipeline::get_speculative_decoding_metrics() const {
     return m_sd_metrics;
 };

-void SpeculativeLLMPipelineNPU::start_chat(const std::string& system_message) {
+void StatefulSpeculativeLLMPipeline::start_chat(const std::string& system_message) {
     if (!system_message.empty()) {
         m_history.push_back({{"role", "system"}, {"content", system_message}});
     }
     m_is_chat_conversation = true;
 };

-void SpeculativeLLMPipelineNPU::finish_chat() {
+void StatefulSpeculativeLLMPipeline::finish_chat() {
     m_is_chat_conversation = false;
     m_history.clear();
     m_draft_request->reset_state();
     m_main_request->reset_state();
 };

-SpeculativeLLMPipelineNPU::~SpeculativeLLMPipelineNPU() {
+StatefulSpeculativeLLMPipeline::~StatefulSpeculativeLLMPipeline() {
     m_main_request->release_memory();
     m_draft_request->release_memory();
 }

-void SpeculativeLLMPipelineNPU::update_candidate_strategy(const std::size_t matches_num) {
+void StatefulSpeculativeLLMPipeline::update_candidate_strategy(const std::size_t matches_num) {
     // Dynamically adjust number of generated candidates based on number of matches,
     // we want to balance the benefits of getting candidates tokens correct with the
     // cost of forecasting incorrect candidates tokens.
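For readers following the rename, the adaptive candidate-count policy referenced in the comment above is small enough to show on its own. The sketch below is illustrative only and not part of the applied diffs: it mirrors the update_candidate_strategy() logic as it appears later in this series, with the member fields m_candidates_num / m_max_candidates_num replaced by local variables and the initial value 5 and cap 10 taken from the header.

    // Standalone sketch of StatefulSpeculativeLLMPipeline::update_candidate_strategy():
    // grow the speculation window when every candidate was accepted, otherwise shrink it,
    // never going below one candidate.
    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <initializer_list>
    #include <iostream>

    int main() {
        std::size_t candidates_num = 5;             // initial value (num_assistant_tokens)
        const std::size_t max_candidates_num = 10;  // twice the initial value in the pipeline

        // First iteration: the main model accepted all candidates; second: only two of them.
        for (std::size_t matches_num : {std::size_t{5}, std::size_t{2}}) {
            if (matches_num == candidates_num) {
                candidates_num = std::min(candidates_num + 2, max_candidates_num);
            } else {
                candidates_num = static_cast<std::size_t>(
                    std::max<int64_t>(static_cast<int64_t>(candidates_num) - 1, 1));
            }
            std::cout << "next candidates_num = " << candidates_num << "\n";  // prints 7, then 6
        }
        return 0;
    }

The asymmetry (grow by 2, shrink by 1) keeps the draft model from repeatedly overcommitting after a single bad iteration while still converging quickly when the main model keeps accepting full candidate runs.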
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp similarity index 95% rename from src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp rename to src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp index cabaf8941c..2740933334 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_npu.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp @@ -77,9 +77,9 @@ class LLMInferWrapper { std::vector m_new_atten_mask_data; }; -class SpeculativeLLMPipelineNPU : public ov::genai::LLMPipelineImplBase { +class StatefulSpeculativeLLMPipeline : public ov::genai::LLMPipelineImplBase { public: - SpeculativeLLMPipelineNPU( + StatefulSpeculativeLLMPipeline( const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc ); @@ -103,7 +103,7 @@ class SpeculativeLLMPipelineNPU : public ov::genai::LLMPipelineImplBase { ov::genai::SpeculativeDecodingMetrics get_speculative_decoding_metrics() const; - ~SpeculativeLLMPipelineNPU(); + ~StatefulSpeculativeLLMPipeline(); private: void update_candidate_strategy(const std::size_t matches_num); From f03d14d03f047419560657f1411bd6ef941bb83d Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Tue, 2 Sep 2025 15:51:11 +0100 Subject: [PATCH 20/40] Dispatching of LLM pipeline on NPU to Stateful or StatefulSpeculative is done in factory method --- src/cpp/src/continuous_batching/pipeline.cpp | 18 +--- src/cpp/src/llm/pipeline.cpp | 64 ++++++++++-- src/cpp/src/llm/pipeline_stateful_npu.cpp | 100 ------------------- src/cpp/src/llm/pipeline_stateful_npu.hpp | 56 ----------- src/cpp/src/utils.cpp | 10 ++ src/cpp/src/utils.hpp | 2 + 6 files changed, 70 insertions(+), 180 deletions(-) delete mode 100644 src/cpp/src/llm/pipeline_stateful_npu.cpp delete mode 100644 src/cpp/src/llm/pipeline_stateful_npu.hpp diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp index 103c17d10a..36af66979c 100644 --- a/src/cpp/src/continuous_batching/pipeline.cpp +++ b/src/cpp/src/continuous_batching/pipeline.cpp @@ -19,16 +19,6 @@ using namespace ov::genai; namespace { -ov::genai::ModelDesc -extract_draft_model_from_config(ov::AnyMap& config) { - ov::genai::ModelDesc draft_model; - if (config.find(utils::DRAFT_MODEL_ARG_NAME) != config.end()) { - draft_model = config.at(utils::DRAFT_MODEL_ARG_NAME).as(); - config.erase(utils::DRAFT_MODEL_ARG_NAME); - } - return draft_model; -} - bool extract_prompt_lookup_from_config(ov::AnyMap& config) { bool res = false; @@ -53,7 +43,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p const ov::AnyMap& vision_encoder_properties) { auto start_time = std::chrono::steady_clock::now(); auto properties_without_draft_model = properties; - auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); + auto draft_model_desr = utils::extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); auto model = utils::read_model(models_path, properties); @@ -93,7 +83,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const ov::AnyMap& properties) { auto start_time = std::chrono::steady_clock::now(); auto properties_without_draft_model = properties; - auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); + auto draft_model_desr = 
utils::extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); auto model = utils::read_model(models_path, properties_without_draft_model); @@ -135,7 +125,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto start_time = std::chrono::steady_clock::now(); auto properties_without_draft_model = properties; - auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); + auto draft_model_desr = utils::extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); auto model = utils::singleton_core().read_model(model_str, weights_tensor); @@ -178,7 +168,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto start_time = std::chrono::steady_clock::now(); auto properties_without_draft_model = properties; - auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); + auto draft_model_desr = utils::extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); auto model_pair = utils::get_model_weights_pair(models_map, "language"); auto model = utils::singleton_core().read_model(model_pair.first, model_pair.second); diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index 6cea401798..d976471412 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -12,7 +12,7 @@ #include "llm/pipeline_stateful.hpp" #include "llm/pipeline_continuous_batching_adapter.hpp" #include "speculative_decoding/speculative_decoding_impl.hpp" -#include "llm/pipeline_stateful_npu.hpp" +#include "speculative_decoding/speculative_decoding_stateful.hpp" #include "utils.hpp" namespace ov { @@ -60,6 +60,51 @@ std::pair draft_model( return { utils::DRAFT_MODEL_ARG_NAME, Any::make(model, tokenizer, device, plugin_config, scheduler_config, generation_config) }; } +// NOTE: Should be used only when NPU device is requested +// either for main model or for draft model if last exists. 
+class StatefulNPUPipelineCreator { +public: +static std::unique_ptr create( + const std::filesystem::path& models_path, + const ov::genai::Tokenizer& tokenizer, + const std::string& device, + const ov::AnyMap& properties) { + return create( + ov::genai::utils::read_model(models_path, properties), + tokenizer, + device, + properties, + utils::from_config_json_if_exists(models_path)); +} + +static std::unique_ptr create( + const std::filesystem::path& models_path, + const std::string& device, + const ov::AnyMap& plugin_config) { + return create(models_path, Tokenizer(models_path, plugin_config), device, plugin_config); +} + +static std::unique_ptr create( + const std::shared_ptr& model, + const ov::genai::Tokenizer& tokenizer, + const std::string& device, + const ov::AnyMap& properties, + const ov::genai::GenerationConfig& generation_config) { + + auto properties_without_draft_model = properties; + auto draft_model_descr = ov::genai::utils::extract_draft_model_from_config(properties_without_draft_model); + if (draft_model_descr.model != nullptr) { + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, {}, generation_config); + OPENVINO_ASSERT((draft_model_descr.device == "NPU") || (main_model_descr.device == "NPU")); + return std::make_unique(main_model_descr, draft_model_descr); + } + + OPENVINO_ASSERT(device == "NPU"); + return std::make_unique(model, tokenizer, device, + properties_without_draft_model, generation_config); +} +}; + // Public LLMPipeline ov::genai::LLMPipeline::LLMPipeline( @@ -81,7 +126,7 @@ ov::genai::LLMPipeline::LLMPipeline( auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); if (ov::genai::utils::is_npu_requested(device, properties)) { - m_pimpl = std::make_unique(models_path, tokenizer, device, properties); + m_pimpl = StatefulNPUPipelineCreator::create(models_path, tokenizer, device, properties); } else if (utils::explicitly_requires_paged_attention(user_properties)) { // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues auto [device_properties, scheduler_config] = utils::extract_scheduler_config(properties, utils::get_latency_oriented_scheduler_config()); @@ -116,12 +161,11 @@ ov::genai::LLMPipeline::LLMPipeline( auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); if (ov::genai::utils::is_npu_requested(device, properties)) { - m_pimpl = std::make_unique(models_path, device, properties); + m_pimpl = StatefulNPUPipelineCreator::create(models_path, device, properties); } else if (utils::explicitly_requires_paged_attention(user_properties)) { // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues auto [device_properties, scheduler_config] = utils::extract_scheduler_config(properties, utils::get_latency_oriented_scheduler_config()); m_pimpl = std::make_unique(models_path, scheduler_config, device, device_properties); - } else if (attention_backend == PA_BACKEND) { // try to call CB adapter one more time, but with safe guard to silent exception try { @@ -155,12 +199,12 @@ ov::genai::LLMPipeline::LLMPipeline( auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); if (ov::genai::utils::is_npu_requested(device, properties)) { - m_pimpl = std::make_unique( - utils::singleton_core().read_model(model_str, weights_tensor), - tokenizer, - device, - properties, - generation_config); + m_pimpl = StatefulNPUPipelineCreator::create( + 
utils::singleton_core().read_model(model_str, weights_tensor), + tokenizer, + device, + properties, + generation_config); } else if (utils::explicitly_requires_paged_attention(user_properties)) { // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues auto [device_properties, scheduler_config] = utils::extract_scheduler_config(properties, utils::get_latency_oriented_scheduler_config()); diff --git a/src/cpp/src/llm/pipeline_stateful_npu.cpp b/src/cpp/src/llm/pipeline_stateful_npu.cpp deleted file mode 100644 index 1cce1f9516..0000000000 --- a/src/cpp/src/llm/pipeline_stateful_npu.cpp +++ /dev/null @@ -1,100 +0,0 @@ - -// Copyright (C) 2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#include "pipeline_stateful_npu.hpp" -#include "speculative_decoding/speculative_decoding_npu.hpp" -#include "llm/pipeline_stateful.hpp" -#include "llm/pipeline_static.hpp" -#include "utils.hpp" - -#include - -#include "openvino/runtime/core.hpp" -#include "openvino/core/parallel.hpp" -#include "openvino/genai/text_streamer.hpp" - -namespace { - ov::genai::ModelDesc - extract_draft_model_from_config(ov::AnyMap& config) { - ov::genai::ModelDesc draft_model; - if (config.find(ov::genai::utils::DRAFT_MODEL_ARG_NAME) != config.end()) { - draft_model = config.at(ov::genai::utils::DRAFT_MODEL_ARG_NAME).as(); - config.erase(ov::genai::utils::DRAFT_MODEL_ARG_NAME); - } - return draft_model; -} -} // anonymous namespace - -namespace ov::genai { - -// NB: No constructor for creation of pipeline from infer request, as pipeline from infer request -// for NPU is handled inside of ov::genai::StatefulLLMPipeline class iself. -StatefulLLMPipelineNPU::StatefulLLMPipelineNPU( - const std::filesystem::path& models_path, - const ov::genai::Tokenizer& tokenizer, - const std::string& device, - const ov::AnyMap& properties) - : StatefulLLMPipelineNPU( - utils::read_model(models_path, properties), - tokenizer, - device, - properties, - utils::from_config_json_if_exists(models_path) - ) {} - -StatefulLLMPipelineNPU::StatefulLLMPipelineNPU( - const std::filesystem::path& models_path, - const std::string& device, - const ov::AnyMap& plugin_config) - : StatefulLLMPipelineNPU{models_path, Tokenizer(models_path, plugin_config), device, plugin_config} {} - -StatefulLLMPipelineNPU::StatefulLLMPipelineNPU( - const std::shared_ptr& model, - const ov::genai::Tokenizer& tokenizer, - const std::string& device, - const ov::AnyMap& properties, - const ov::genai::GenerationConfig& generation_config) - : LLMPipelineImplBase(tokenizer, generation_config) { - auto properties_without_draft_model = properties; - auto draft_model_descr = extract_draft_model_from_config(properties_without_draft_model); - if (draft_model_descr.model != nullptr) { - auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, {}, generation_config); - m_pimpl = std::make_unique(main_model_descr, draft_model_descr); - } else if (properties_without_draft_model.count("STATIC_PIPELINE")) { - m_pimpl = static_llm::LLMPipelineFactory::create(model, tokenizer, - properties_without_draft_model, generation_config); - } else { - m_pimpl = std::make_unique(model, tokenizer, "NPU", - properties_without_draft_model, generation_config); - } -} - -DecodedResults StatefulLLMPipelineNPU::generate( - StringInputs inputs, - OptionalGenerationConfig generation_config, - StreamerVariant streamer) { - return m_pimpl->generate(inputs, generation_config, streamer); -} - -EncodedResults 
StatefulLLMPipelineNPU::generate( - const EncodedInputs& inputs, - OptionalGenerationConfig generation_config, - StreamerVariant streamer) { - return m_pimpl->generate(inputs, generation_config, streamer); -} - -void StatefulLLMPipelineNPU::start_chat(const std::string& system_message) { - m_pimpl->start_chat(system_message); -} - -// FIXME: Do we need it? -// void StatefulLLMPipelineNPU::reset_kv_state() { -// m_pimpl->reset_kv_state(); -// } - -void StatefulLLMPipelineNPU::finish_chat() { - m_pimpl->finish_chat(); -} - -} // namespace ov::genai diff --git a/src/cpp/src/llm/pipeline_stateful_npu.hpp b/src/cpp/src/llm/pipeline_stateful_npu.hpp deleted file mode 100644 index 5e88050501..0000000000 --- a/src/cpp/src/llm/pipeline_stateful_npu.hpp +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (C) 2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - - -#include - -#include "llm/pipeline_base.hpp" - -namespace ov::genai { - -class StatefulLLMPipelineNPU final : public LLMPipelineImplBase { -public: - StatefulLLMPipelineNPU( - const std::filesystem::path& models_path, - const ov::genai::Tokenizer& tokenizer, - const std::string& device, - const ov::AnyMap& plugin_config - ); - - StatefulLLMPipelineNPU( - const std::filesystem::path& models_path, - const std::string& device, - const ov::AnyMap& plugin_config - ); - - StatefulLLMPipelineNPU( - const std::shared_ptr& model, - const ov::genai::Tokenizer& tokenizer, - const std::string& device, - const ov::AnyMap& config, - const ov::genai::GenerationConfig& generation_config - ); - - DecodedResults generate( - StringInputs inputs, - OptionalGenerationConfig generation_config, - StreamerVariant streamer - ) override; - - EncodedResults generate( - const EncodedInputs& inputs, - OptionalGenerationConfig generation_config, - StreamerVariant streamer - ) override; - - void start_chat(const std::string& system_message) override; - - void finish_chat() override; - - ~StatefulLLMPipelineNPU() = default; - -private: - std::unique_ptr m_pimpl; -}; - -} // namespace ov::genai diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index aed102863f..5e7b86ec70 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -104,6 +104,7 @@ inline bool is_paged_attention_available() { return false; #endif } + } // anonymous namespace ov { @@ -212,6 +213,15 @@ ov::genai::ModelDesc get_draft_model_from_config(const ov::AnyMap& config) { return draft_model; } +ov::genai::ModelDesc extract_draft_model_from_config(ov::AnyMap& config) { + ov::genai::ModelDesc draft_model; + if (config.find(ov::genai::utils::DRAFT_MODEL_ARG_NAME) != config.end()) { + draft_model = config.at(ov::genai::utils::DRAFT_MODEL_ARG_NAME).as(); + config.erase(ov::genai::utils::DRAFT_MODEL_ARG_NAME); + } + return draft_model; +} + bool is_npu_requested(const std::string& device, const ov::AnyMap& properties) { if (device == "NPU") { return true; diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index 4995b3b754..471de1192d 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -120,6 +120,8 @@ ProcessorConfig from_any_map( ov::genai::ModelDesc get_draft_model_from_config(const ov::AnyMap& config); +ov::genai::ModelDesc extract_draft_model_from_config(ov::AnyMap& config); + bool is_npu_requested(const std::string& device, const ov::AnyMap& properties); ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& minuend, const ov::genai::TokenizedInputs& subtrahend); From 00580de2f9b872570c8ba1b8d99f3fbdfd4afc31 Mon Sep 17 
00:00:00 2001 From: Anastasiya Pronina Date: Tue, 2 Sep 2025 18:46:14 +0100 Subject: [PATCH 21/40] Refixed llm_bench --- tools/llm_bench/llm_bench_utils/model_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py index a37a528a1b..597561b136 100644 --- a/tools/llm_bench/llm_bench_utils/model_utils.py +++ b/tools/llm_bench/llm_bench_utils/model_utils.py @@ -209,9 +209,10 @@ def analyze_args(args): if args.cb_config: cb_config = get_config(args.cb_config) model_args["cb_config"] = cb_config - if args.draft_model and model_args['config']['ATTENTION_BACKEND'] != PA_ATTENTION_BACKEND: - log.warning("Speculative Decoding is supported only with Page Attention Backend") - args.draft_model = None + if args.draft_model: + if (args.draft_device != "NPU" and args.device != "NPU" and model_args['config']['ATTENTION_BACKEND'] != PA_ATTENTION_BACKEND): + log.warning("Speculative Decoding is supported only with Page Attention Backend for non-NPU devices") + args.draft_model = None model_args['draft_model'] = args.draft_model model_args['draft_device'] = args.draft_device draft_cb_config = None From 88067e466d2ce5c3bc120c24013a5e9fede72db5 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Mon, 8 Sep 2025 19:15:53 +0100 Subject: [PATCH 22/40] Factory method is StatefulPipeline now --- src/cpp/src/llm/pipeline.cpp | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index d976471412..954942d5b1 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -60,9 +60,7 @@ std::pair draft_model( return { utils::DRAFT_MODEL_ARG_NAME, Any::make(model, tokenizer, device, plugin_config, scheduler_config, generation_config) }; } -// NOTE: Should be used only when NPU device is requested -// either for main model or for draft model if last exists. 
-class StatefulNPUPipelineCreator { +class StatefulPipeline { public: static std::unique_ptr create( const std::filesystem::path& models_path, @@ -95,11 +93,9 @@ static std::unique_ptr create( auto draft_model_descr = ov::genai::utils::extract_draft_model_from_config(properties_without_draft_model); if (draft_model_descr.model != nullptr) { auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, {}, generation_config); - OPENVINO_ASSERT((draft_model_descr.device == "NPU") || (main_model_descr.device == "NPU")); return std::make_unique(main_model_descr, draft_model_descr); } - OPENVINO_ASSERT(device == "NPU"); return std::make_unique(model, tokenizer, device, properties_without_draft_model, generation_config); } @@ -126,7 +122,7 @@ ov::genai::LLMPipeline::LLMPipeline( auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); if (ov::genai::utils::is_npu_requested(device, properties)) { - m_pimpl = StatefulNPUPipelineCreator::create(models_path, tokenizer, device, properties); + m_pimpl = StatefulPipeline::create(models_path, tokenizer, device, properties); } else if (utils::explicitly_requires_paged_attention(user_properties)) { // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues auto [device_properties, scheduler_config] = utils::extract_scheduler_config(properties, utils::get_latency_oriented_scheduler_config()); @@ -145,7 +141,7 @@ ov::genai::LLMPipeline::LLMPipeline( } if (m_pimpl == nullptr) { - m_pimpl = std::make_unique(models_path, tokenizer, device, properties); + m_pimpl = StatefulPipeline::create(models_path, tokenizer, device, properties); } m_pimpl->save_load_time(start_time); @@ -161,7 +157,7 @@ ov::genai::LLMPipeline::LLMPipeline( auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); if (ov::genai::utils::is_npu_requested(device, properties)) { - m_pimpl = StatefulNPUPipelineCreator::create(models_path, device, properties); + m_pimpl = StatefulPipeline::create(models_path, device, properties); } else if (utils::explicitly_requires_paged_attention(user_properties)) { // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues auto [device_properties, scheduler_config] = utils::extract_scheduler_config(properties, utils::get_latency_oriented_scheduler_config()); @@ -180,7 +176,7 @@ ov::genai::LLMPipeline::LLMPipeline( } if (m_pimpl == nullptr) { - m_pimpl = std::make_unique(models_path, device, properties); + m_pimpl = StatefulPipeline::create(models_path, device, properties); } m_pimpl->save_load_time(start_time); @@ -199,7 +195,7 @@ ov::genai::LLMPipeline::LLMPipeline( auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); if (ov::genai::utils::is_npu_requested(device, properties)) { - m_pimpl = StatefulNPUPipelineCreator::create( + m_pimpl = StatefulPipeline::create( utils::singleton_core().read_model(model_str, weights_tensor), tokenizer, device, @@ -225,7 +221,7 @@ ov::genai::LLMPipeline::LLMPipeline( } if (m_pimpl == nullptr) { - m_pimpl = std::make_unique( + m_pimpl = StatefulPipeline::create( utils::singleton_core().read_model(model_str, weights_tensor), tokenizer, device, From 5cdf96e742fad6381c1f0130b6fd1d2f8e128112 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Tue, 9 Sep 2025 02:25:18 +0100 Subject: [PATCH 23/40] Removed PA backend constraint for Speculative Decode and added check before draft properties initialization --- 
.../speculative_decoding/speculative_decoding_stateful.cpp | 2 +- src/cpp/src/utils.cpp | 7 ------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp index 64896977b9..8d28b828dd 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp @@ -369,7 +369,7 @@ StatefulSpeculativeLLMPipeline::StatefulSpeculativeLLMPipeline( if (draft_model_desc_copy.device.empty()) { draft_model_desc_copy.device = main_model_desc.device; } - if (draft_model_desc_copy.properties.empty()) { + if (draft_model_desc_copy.properties.empty() && (draft_model_desc_copy.device == main_model_desc.device)) { draft_model_desc_copy.properties = main_model_desc.properties; } m_draft_request = std::make_unique(draft_model_desc_copy); diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index 5e7b86ec70..f17b05a2ce 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -642,13 +642,6 @@ bool explicitly_requires_paged_attention(const ov::AnyMap& properties) { OPENVINO_THROW("Continuous batching backend requires PagedAttention operation support, which is available on x86_64 or ARM64 platforms only"); } } - if (properties.find(utils::DRAFT_MODEL_ARG_NAME) != properties.end()) { - if (is_paged_attention_available()) { - return true; - } else { - OPENVINO_THROW("Speculative decoding requires PagedAttention operation support, which is available on x86_64 or ARM64 platforms only"); - } - } auto prompt_lookup_prop = properties.find("prompt_lookup"); if (prompt_lookup_prop != properties.end() && prompt_lookup_prop->second.as() == true) { From ba3965856edcad9b151d853e0c063a94f0962568 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Wed, 10 Sep 2025 15:10:45 +0100 Subject: [PATCH 24/40] Removed PA constraint for llm_bench --- tools/llm_bench/llm_bench_utils/model_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py index 597561b136..8fdaad05d0 100644 --- a/tools/llm_bench/llm_bench_utils/model_utils.py +++ b/tools/llm_bench/llm_bench_utils/model_utils.py @@ -209,10 +209,6 @@ def analyze_args(args): if args.cb_config: cb_config = get_config(args.cb_config) model_args["cb_config"] = cb_config - if args.draft_model: - if (args.draft_device != "NPU" and args.device != "NPU" and model_args['config']['ATTENTION_BACKEND'] != PA_ATTENTION_BACKEND): - log.warning("Speculative Decoding is supported only with Page Attention Backend for non-NPU devices") - args.draft_model = None model_args['draft_model'] = args.draft_model model_args['draft_device'] = args.draft_device draft_cb_config = None From 0879b6ef633ade80e8a39275988233dcfb1c1fd1 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Wed, 10 Sep 2025 15:32:55 +0100 Subject: [PATCH 25/40] Updated sample to reflect enabled feature --- samples/python/text_generation/speculative_decoding_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/samples/python/text_generation/speculative_decoding_lm.py b/samples/python/text_generation/speculative_decoding_lm.py index 8484f50c46..ec2795d1da 100755 --- a/samples/python/text_generation/speculative_decoding_lm.py +++ b/samples/python/text_generation/speculative_decoding_lm.py @@ -20,7 +20,7 @@ def main(): # User can run main and draft model on different devices. 
# Please, set device for main model in `openvino_genai.LLMPipeline` constructor and in openvino_genai.draft_model` for draft. - main_device = 'CPU' # GPU can be used as well + main_device = 'CPU' # GPU or NPU can be used as well draft_device = 'CPU' draft_model = openvino_genai.draft_model(args.draft_model_dir, draft_device) From cf766c0105d533a1a50abd737c501e38dd7b67f9 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Thu, 18 Sep 2025 22:45:40 +0100 Subject: [PATCH 26/40] Addressed review comments --- src/cpp/src/continuous_batching/timer.hpp | 10 +++---- .../speculative_decoding_stateful.cpp | 28 +++++++++---------- .../speculative_decoding_stateful.hpp | 4 +-- 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/cpp/src/continuous_batching/timer.hpp b/src/cpp/src/continuous_batching/timer.hpp index 1c16203c37..17d9229fa9 100644 --- a/src/cpp/src/continuous_batching/timer.hpp +++ b/src/cpp/src/continuous_batching/timer.hpp @@ -14,8 +14,8 @@ class ManualTimer { public: ManualTimer(const std::string& title) : m_total(0.), - m_start(std::chrono::steady_clock::duration::zero()), - m_end(std::chrono::steady_clock::duration::zero()), + m_start(), + m_end(), m_title(title) { } @@ -46,10 +46,8 @@ class ManualTimer { void clear() { m_total = 0.0; - m_start = std::chrono::steady_clock::time_point( - std::chrono::steady_clock::duration::zero()); - m_end = std::chrono::steady_clock::time_point( - std::chrono::steady_clock::duration::zero()); + m_start = std::chrono::steady_clock::time_point(); + m_end = std::chrono::steady_clock::time_point(); } ~ManualTimer() { diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp index 8d28b828dd..37bd722522 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp @@ -25,18 +25,18 @@ ov::genai::StreamingStatus stream_generated_tokens(std::shared_ptr(sample_tokens(get_logits(), 1u)); infer_first_timer.end(); - update_perf_metrics(raw_perf_metrics, + update_perf_stat_by_infer_duration(raw_perf_metrics, ov::genai::PerfMetrics::get_microsec(infer_end - infer_start), infer_first_timer.get_duration_microsec(), BATCH_SIZE); return last_token; @@ -192,7 +192,7 @@ int64_t LLMInferWrapper::infer_next(int64_t token, bool append_perf_stat) { infer_next_timer.end(); // prepend perf stat if (!append_perf_stat) { - update_perf_metrics( + update_perf_stat_by_infer_duration( raw_perf_metrics, ov::genai::PerfMetrics::get_microsec(infer_end - infer_start), infer_next_timer.get_duration_microsec(), @@ -254,7 +254,7 @@ std::vector LLMInferWrapper::infer_next_return_all(const std::vectorinfer_first(input_ids, attention_mask, position_ids); first_token_timer.end(); - update_perf_metrics(raw_perf_counters, first_token_timer.get_duration_microsec(), - first_token_timer.get_end_time(), 1u); + update_perf_stat_by_token_time(raw_perf_counters, first_token_timer.get_duration_microsec(), + first_token_timer.get_end_time(), 1u); m_draft_request->infer_first(input_ids, attention_mask, position_ids); @@ -604,7 +604,7 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( iteration_timer.end(); auto iteration_duration = iteration_timer.get_duration_microsec(); - update_perf_metrics(raw_perf_counters, iteration_duration, main_timer.get_end_time(), 1u); + update_perf_stat_by_token_time(raw_perf_counters, iteration_duration, main_timer.get_end_time(), 1u); main_timer.clear(); 
iteration_timer.clear(); @@ -697,7 +697,7 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( iteration_timer.end(); auto iteration_duration = iteration_timer.get_duration_microsec(); - update_perf_metrics(raw_perf_counters, iteration_duration, main_timer.get_end_time(), validated_tokens.size()); + update_perf_stat_by_token_time(raw_perf_counters, iteration_duration, main_timer.get_end_time(), validated_tokens.size()); iteration_timer.clear(); main_timer.clear(); } @@ -772,7 +772,7 @@ void StatefulSpeculativeLLMPipeline::update_candidate_strategy(const std::size_t if (matches_num == m_candidates_num) { m_candidates_num = std::min(m_candidates_num + 2, m_max_candidates_num); } else { - m_candidates_num = std::max(int64_t(m_candidates_num) - 1, int64_t(1)); + m_candidates_num = static_cast(std::max(static_cast(m_candidates_num) - 1, int64_t(1))); } } } // namespace genai diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp index 2740933334..87865b6dad 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp @@ -6,8 +6,8 @@ #include "llm/pipeline_base.hpp" #include "sampling/sampler.hpp" #include "utils.hpp" -#include -#include +#include "openvino/genai/perf_metrics.hpp" +#include "openvino/genai/speculative_decoding/perf_metrics.hpp" namespace ov { namespace genai { From 9f91e49633befed1b385ab18ad99626163ea27f9 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Mon, 22 Sep 2025 14:09:17 +0100 Subject: [PATCH 27/40] GenerationConfig.num_assistant_tokens behaviour is specified, added note about usage of GPU in samples --- .../speculative_decoding_lm.cpp | 5 +++ .../speculative_decoding_lm.py | 8 +++-- .../openvino/genai/generation_config.hpp | 2 ++ .../speculative_decoding_stateful.cpp | 34 ++++++++++++++++--- 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/samples/cpp/text_generation/speculative_decoding_lm.cpp b/samples/cpp/text_generation/speculative_decoding_lm.cpp index c659a75fbd..a3834972ec 100644 --- a/samples/cpp/text_generation/speculative_decoding_lm.cpp +++ b/samples/cpp/text_generation/speculative_decoding_lm.cpp @@ -18,6 +18,8 @@ int main(int argc, char* argv[]) try { config.num_assistant_tokens = 5; // add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than `assistant_confidence_threshold` // config.assistant_confidence_threshold = 0.4; + // Note: `config.num_assistant_tokens` behaves differently if Stateful and non Continuous Batching pipeline is called for Speculative Decode, number of candidates + // will still be chosen dynamically based on first hint from `config.num_assistant_tokens` std::string main_model_path = argv[1]; std::string draft_model_path = argv[2]; @@ -25,6 +27,9 @@ int main(int argc, char* argv[]) try { // User can run main and draft model on different devices. // Please, set device for main model in `LLMPipeline` constructor and in in `ov::genai::draft_model` for draft. + // CPU, GPU and NPU can be used. Please be aware that GPU is performant only with Continious Batching pipeline, + // so it is not recommented to use it in conjuction with NPU or in configuration when main model doesn't work + // in Paged Attention mode. 
std::string main_device = "CPU", draft_device = "CPU"; ov::genai::LLMPipeline pipe( diff --git a/samples/python/text_generation/speculative_decoding_lm.py b/samples/python/text_generation/speculative_decoding_lm.py index ec2795d1da..660078fb58 100755 --- a/samples/python/text_generation/speculative_decoding_lm.py +++ b/samples/python/text_generation/speculative_decoding_lm.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import argparse @@ -20,7 +20,9 @@ def main(): # User can run main and draft model on different devices. # Please, set device for main model in `openvino_genai.LLMPipeline` constructor and in openvino_genai.draft_model` for draft. - main_device = 'CPU' # GPU or NPU can be used as well + # CPU, GPU and NPU can be used. Please be aware that GPU is performant only with Continious Batching pipeline, so it is not + # recommented to use it in conjuction with NPU or in configuration when main model doesn't work in Paged Attention mode. + main_device = 'CPU' draft_device = 'CPU' draft_model = openvino_genai.draft_model(args.draft_model_dir, draft_device) @@ -34,6 +36,8 @@ def main(): config.num_assistant_tokens = 5 # add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than `assistant_confidence_threshold` # config.assistant_confidence_threshold = 0.4 + # Note: `config.num_assistant_tokens` behaves differently if Stateful and non Continuous Batching pipeline is called for Speculative Decode, number of candidates + # will still be chosen dynamically based on first hint from `config.num_assistant_tokens` # Since the streamer is set, the results will be printed # every time a new token is generated and put into the streamer queue. diff --git a/src/cpp/include/openvino/genai/generation_config.hpp b/src/cpp/include/openvino/genai/generation_config.hpp index 3020be34bc..c01d99ac2f 100644 --- a/src/cpp/include/openvino/genai/generation_config.hpp +++ b/src/cpp/include/openvino/genai/generation_config.hpp @@ -274,6 +274,8 @@ operator|(const StructuredOutputConfig::CompoundGrammar& lhs, * Assisting generation parameters: * @param assistant_confidence_threshold the lower token probability of candidate to be validated by main model in case of dynamic strategy candidates number update. * @param num_assistant_tokens the defined candidates number to be generated by draft model/prompt lookup in case of static strategy candidates number update. + * NOTE: Unlike the default purpose of this parameter, it will be used with dynamic strategy in Stateful (non Continuous Batching) Speculative + * Decoding pipeline, as given pipeline supports only `num_assistant_tokens` parameter and dynamic strategy for current moment. * @param max_ngram_size is maximum ngram to use when looking for matches in the prompt. * * @param structured_output_config if set, the output will be a string constrained by the specified json_schema, regex, or EBNF grammar. 
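To show how the documented assisted-generation parameters fit together, here is a minimal usage sketch modelled on the speculative_decoding_lm.cpp sample touched in this series. It is illustrative only: the model directories, devices and prompt are placeholders, not values mandated by the patch.

    #include <iostream>
    #include <string>

    #include "openvino/genai/llm_pipeline.hpp"

    int main() {
        std::string main_model_path  = "path/to/main_model";   // placeholder
        std::string draft_model_path = "path/to/draft_model";  // placeholder
        // CPU, GPU and NPU are accepted; see the notes in the sample above about GPU.
        std::string main_device = "CPU", draft_device = "CPU";

        ov::genai::GenerationConfig config;
        config.max_new_tokens = 100;
        // Number of candidates produced by the draft model per iteration; the Stateful
        // backend treats it only as the starting point and adapts it at runtime.
        config.num_assistant_tokens = 5;
        // Mutually exclusive alternative, ContinuousBatching backend only:
        // config.assistant_confidence_threshold = 0.4f;

        ov::genai::LLMPipeline pipe(
            main_model_path,
            main_device,
            ov::genai::draft_model(draft_model_path, draft_device));

        std::string result = pipe.generate("Why is the Sun yellow?", config);
        std::cout << result << std::endl;
        return 0;
    }

As in the real sample, a streamer callable can be passed as the third argument of generate() to print tokens as they are produced instead of waiting for the full result.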
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp index 37bd722522..a4dd0c998c 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp @@ -41,6 +41,18 @@ void update_perf_stat_by_infer_duration(ov::genai::RawPerfMetrics& raw_perf_coun raw_perf_counters.m_inference_durations[0] += ov::genai::MicroSeconds(inference_duration); raw_perf_counters.m_batch_sizes.emplace_back(num_generated_tokens); } + +void ensure_num_assistant_tokens_is_set(ov::genai::GenerationConfig& generation_config) { + auto assistant_confidence_threshold = generation_config.assistant_confidence_threshold; + OPENVINO_ASSERT(assistant_confidence_threshold == 0.f, + "Stateful (non Continuous Batching) Speculative Decoding pipeline only supports `num_assistant_tokens` " + "as parameter in GenerationConfig and doesn't work with `assistant_confidence_threshold`.\nPlease " + "remove its specification or set it to 0.f."); + + if (generation_config.num_assistant_tokens == 0) { + generation_config.num_assistant_tokens = 5; + } +} }// anonymous namespace namespace ov { @@ -374,8 +386,10 @@ StatefulSpeculativeLLMPipeline::StatefulSpeculativeLLMPipeline( } m_draft_request = std::make_unique(draft_model_desc_copy); - auto requested_candidates_num = main_model_desc.generation_config.num_assistant_tokens; - m_candidates_num = (requested_candidates_num != 0) ? requested_candidates_num : 5; + // Specifying number candidates to generate + ensure_num_assistant_tokens_is_set(m_generation_config); + m_candidates_num = m_generation_config.num_assistant_tokens; + m_max_candidates_num = m_candidates_num * 2; // Main model (which is bigger, more accurate but slower) auto main_model_desc_copy = main_model_desc; @@ -407,7 +421,13 @@ DecodedResults StatefulSpeculativeLLMPipeline::generate( } }, inputs); - const GenerationConfig& config = generation_config.has_value() ? *generation_config : m_generation_config; + GenerationConfig config = m_generation_config; + if (generation_config.has_value()) { + config = *generation_config; + ensure_num_assistant_tokens_is_set(config); + m_candidates_num = config.num_assistant_tokens; + m_max_candidates_num = m_candidates_num * 2; + } ov::genai::TokenizedInputs tokenized_input; if (m_is_chat_conversation) { @@ -481,7 +501,13 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( const size_t batch_size = prompt_shape[0]; OPENVINO_ASSERT(batch_size == 1u, "Currently only batch size=1 is supported"); - GenerationConfig config = (generation_config.has_value()) ? 
*generation_config : m_generation_config; + GenerationConfig config = m_generation_config; + if (generation_config.has_value()) { + config = *generation_config; + ensure_num_assistant_tokens_is_set(config); + m_candidates_num = config.num_assistant_tokens; + m_max_candidates_num = m_candidates_num * 2; + } // If stop_token_ids were not provided, take value from default m_generation_config if (config.stop_token_ids.empty()) config.stop_token_ids = m_generation_config.stop_token_ids; From 3d7b8dcd01740a8744b6c42fb58eb8cbd970859a Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Mon, 22 Sep 2025 15:35:08 +0100 Subject: [PATCH 28/40] Rewritten NOTE-s --- .../text_generation/speculative_decoding_lm.cpp | 16 ++++++++-------- .../text_generation/speculative_decoding_lm.py | 13 +++++++------ .../include/openvino/genai/generation_config.hpp | 4 ++-- .../speculative_decoding_stateful.hpp | 2 +- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/samples/cpp/text_generation/speculative_decoding_lm.cpp b/samples/cpp/text_generation/speculative_decoding_lm.cpp index a3834972ec..2799d57152 100644 --- a/samples/cpp/text_generation/speculative_decoding_lm.cpp +++ b/samples/cpp/text_generation/speculative_decoding_lm.cpp @@ -13,13 +13,14 @@ int main(int argc, char* argv[]) try { ov::genai::GenerationConfig config; config.max_new_tokens = 100; - // Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded - // add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration + // Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded. + // Add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration. + // NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial + // value and adjusts it based on recent number of accepted tokens. config.num_assistant_tokens = 5; - // add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than `assistant_confidence_threshold` + // Add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than + // `assistant_confidence_threshold`. // config.assistant_confidence_threshold = 0.4; - // Note: `config.num_assistant_tokens` behaves differently if Stateful and non Continuous Batching pipeline is called for Speculative Decode, number of candidates - // will still be chosen dynamically based on first hint from `config.num_assistant_tokens` std::string main_model_path = argv[1]; std::string draft_model_path = argv[2]; @@ -27,9 +28,8 @@ int main(int argc, char* argv[]) try { // User can run main and draft model on different devices. // Please, set device for main model in `LLMPipeline` constructor and in in `ov::genai::draft_model` for draft. - // CPU, GPU and NPU can be used. Please be aware that GPU is performant only with Continious Batching pipeline, - // so it is not recommented to use it in conjuction with NPU or in configuration when main model doesn't work - // in Paged Attention mode. + // CPU, GPU and NPU can be used. 
Please be aware that GPU is performant only with Continuous Batching pipeline, so it is not recommented + // to use it in conjuction with NPU or in configuration when main model doesn't work in Paged Attention mode. std::string main_device = "CPU", draft_device = "CPU"; ov::genai::LLMPipeline pipe( diff --git a/samples/python/text_generation/speculative_decoding_lm.py b/samples/python/text_generation/speculative_decoding_lm.py index 660078fb58..b6f86e5373 100755 --- a/samples/python/text_generation/speculative_decoding_lm.py +++ b/samples/python/text_generation/speculative_decoding_lm.py @@ -20,7 +20,7 @@ def main(): # User can run main and draft model on different devices. # Please, set device for main model in `openvino_genai.LLMPipeline` constructor and in openvino_genai.draft_model` for draft. - # CPU, GPU and NPU can be used. Please be aware that GPU is performant only with Continious Batching pipeline, so it is not + # CPU, GPU and NPU can be used. Please be aware that GPU is performant only with Continuous Batching pipeline, so it is not # recommented to use it in conjuction with NPU or in configuration when main model doesn't work in Paged Attention mode. main_device = 'CPU' draft_device = 'CPU' @@ -31,13 +31,14 @@ def main(): config = openvino_genai.GenerationConfig() config.max_new_tokens = 100 - # Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded - # add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration + # Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded. + # Add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration. + # NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial + # value and adjusts it based on recent number of accepted tokens. config.num_assistant_tokens = 5 - # add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than `assistant_confidence_threshold` + # Add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than + # `assistant_confidence_threshold`. # config.assistant_confidence_threshold = 0.4 - # Note: `config.num_assistant_tokens` behaves differently if Stateful and non Continuous Batching pipeline is called for Speculative Decode, number of candidates - # will still be chosen dynamically based on first hint from `config.num_assistant_tokens` # Since the streamer is set, the results will be printed # every time a new token is generated and put into the streamer queue. diff --git a/src/cpp/include/openvino/genai/generation_config.hpp b/src/cpp/include/openvino/genai/generation_config.hpp index c01d99ac2f..ca19e335e3 100644 --- a/src/cpp/include/openvino/genai/generation_config.hpp +++ b/src/cpp/include/openvino/genai/generation_config.hpp @@ -274,8 +274,8 @@ operator|(const StructuredOutputConfig::CompoundGrammar& lhs, * Assisting generation parameters: * @param assistant_confidence_threshold the lower token probability of candidate to be validated by main model in case of dynamic strategy candidates number update. * @param num_assistant_tokens the defined candidates number to be generated by draft model/prompt lookup in case of static strategy candidates number update. 
- * NOTE: Unlike the default purpose of this parameter, it will be used with dynamic strategy in Stateful (non Continuous Batching) Speculative - * Decoding pipeline, as given pipeline supports only `num_assistant_tokens` parameter and dynamic strategy for current moment. + * NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial value and adjusts it + * based on recent number of accepted tokens. * @param max_ngram_size is maximum ngram to use when looking for matches in the prompt. * * @param structured_output_config if set, the output will be a string constrained by the specified json_schema, regex, or EBNF grammar. diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp index 87865b6dad..0de9950eef 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp @@ -112,7 +112,7 @@ class StatefulSpeculativeLLMPipeline : public ov::genai::LLMPipelineImplBase { std::unique_ptr m_draft_request; std::unique_ptr m_main_request; std::size_t m_candidates_num = 5; - const std::size_t m_max_candidates_num = 10; + std::size_t m_max_candidates_num = 10; ov::genai::SpeculativeDecodingMetrics m_sd_metrics; ov::genai::SDPerModelsPerfMetrics m_sd_perf_metrics; From 918336b5b1b983b08c0c868b36e988edff7e4e0f Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Tue, 23 Sep 2025 15:38:25 +0100 Subject: [PATCH 29/40] Alignment of behavior between Stateful and ContinuousBatching Speculative pipelines --- .../cpp/text_generation/speculative_decoding_lm.cpp | 6 ++++-- .../python/text_generation/speculative_decoding_lm.py | 8 +++++--- src/cpp/include/openvino/genai/generation_config.hpp | 6 ++++-- ...tinuous_batching_for_speculative_decoding_impl.cpp | 5 +++++ .../speculative_decoding_impl.cpp | 10 ++++++++-- .../speculative_decoding_stateful.cpp | 11 +---------- 6 files changed, 27 insertions(+), 19 deletions(-) diff --git a/samples/cpp/text_generation/speculative_decoding_lm.cpp b/samples/cpp/text_generation/speculative_decoding_lm.cpp index 2799d57152..e75c929552 100644 --- a/samples/cpp/text_generation/speculative_decoding_lm.cpp +++ b/samples/cpp/text_generation/speculative_decoding_lm.cpp @@ -16,10 +16,12 @@ int main(int argc, char* argv[]) try { // Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded. // Add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration. // NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial - // value and adjusts it based on recent number of accepted tokens. - config.num_assistant_tokens = 5; + // value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set it will be defaulted to `5` + // for both backends. + // config.num_assistant_tokens = 5; // Add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than // `assistant_confidence_threshold`. + // NOTE: `assistant_confidence_threshold` is supported only by ContinuousBatching backend. 
// config.assistant_confidence_threshold = 0.4; std::string main_model_path = argv[1]; diff --git a/samples/python/text_generation/speculative_decoding_lm.py b/samples/python/text_generation/speculative_decoding_lm.py index b6f86e5373..a3ff2d82ee 100755 --- a/samples/python/text_generation/speculative_decoding_lm.py +++ b/samples/python/text_generation/speculative_decoding_lm.py @@ -19,7 +19,7 @@ def main(): args = parser.parse_args() # User can run main and draft model on different devices. - # Please, set device for main model in `openvino_genai.LLMPipeline` constructor and in openvino_genai.draft_model` for draft. + # Please, set device for main model in `openvino_genai.LLMPipeline` constructor and in `openvino_genai.draft_model` for draft. # CPU, GPU and NPU can be used. Please be aware that GPU is performant only with Continuous Batching pipeline, so it is not # recommented to use it in conjuction with NPU or in configuration when main model doesn't work in Paged Attention mode. main_device = 'CPU' @@ -34,10 +34,12 @@ def main(): # Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded. # Add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration. # NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial - # value and adjusts it based on recent number of accepted tokens. - config.num_assistant_tokens = 5 + # value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set it will be defaulted to `5` + # for both backends. + # config.num_assistant_tokens = 5 # Add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than # `assistant_confidence_threshold`. + # NOTE: `assistant_confidence_threshold` is supported only by ContinuousBatching backend. # config.assistant_confidence_threshold = 0.4 # Since the streamer is set, the results will be printed diff --git a/src/cpp/include/openvino/genai/generation_config.hpp b/src/cpp/include/openvino/genai/generation_config.hpp index ca19e335e3..61ffce69e3 100644 --- a/src/cpp/include/openvino/genai/generation_config.hpp +++ b/src/cpp/include/openvino/genai/generation_config.hpp @@ -273,9 +273,11 @@ operator|(const StructuredOutputConfig::CompoundGrammar& lhs, * * Assisting generation parameters: * @param assistant_confidence_threshold the lower token probability of candidate to be validated by main model in case of dynamic strategy candidates number update. + NOTE: `assistant_confidence_threshold` is supported only by ContinuousBatching backend for Speculative Decode. * @param num_assistant_tokens the defined candidates number to be generated by draft model/prompt lookup in case of static strategy candidates number update. - * NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial value and adjusts it - * based on recent number of accepted tokens. + * NOTE: ContinuousBatching backend for Speculative Decode uses `num_assistant_tokens` as is. Stateful backend for Speculative Decode uses `num_assistant_tokens`'s + * copy as initial value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set it will be defaulted to `5` for both + * backends. * @param max_ngram_size is maximum ngram to use when looking for matches in the prompt. 
* * @param structured_output_config if set, the output will be a string constrained by the specified json_schema, regex, or EBNF grammar. diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp index aaede5360f..beff8eac23 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp @@ -14,6 +14,11 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::Contin bool is_validation_mode_enabled) { m_tokenizer = tokenizer; m_generation_config = generation_config; + if (m_generation_config.assistant_confidence_threshold == 0.f) { + if (m_generation_config.num_assistant_tokens == 0) { + m_generation_config.num_assistant_tokens = 5; + } + } m_is_validation_mode_enabled = is_validation_mode_enabled; initialize_pipeline(model, scheduler_config, device, plugin_config); } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 4b2a4a13a2..6db9a092ad 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -262,9 +262,15 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< std::vector main_generations; for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); - main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], sampling_params[request_id])); + auto main_sampling_params = sampling_params[request_id]; + if (main_sampling_params.assistant_confidence_threshold == 0.f) { + if (main_sampling_params.num_assistant_tokens == 0) { + main_sampling_params.num_assistant_tokens = 5; + } + } + main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], main_sampling_params)); - auto draft_sampling_params = sampling_params[request_id]; + auto draft_sampling_params = main_sampling_params; // set the parameters do not stop draft generation without stopping of the same request for main pipeline draft_sampling_params.ignore_eos = true; draft_sampling_params.stop_strings = {}; diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp index a4dd0c998c..427c2fa786 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp @@ -522,12 +522,7 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( OPENVINO_ASSERT(config.num_return_sequences == 1u, "Currently only \"num_return_sequences\" equal to 1 is supported!"); - // FIXME: Update conditionally: m_main_request->set_generation_config(config); - auto requested_candidates_num = config.num_assistant_tokens; - if (requested_candidates_num != 0) { - m_candidates_num = requested_candidates_num; - } // Config draft model to not stop on EOS and remove stop strings: ov::genai::GenerationConfig draft_config = m_draft_request->get_generation_config(); @@ -735,11 +730,7 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( // If not chat conversation, then reset all states. 
if (!m_is_chat_conversation) { - m_candidates_num = 5; - requested_candidates_num = config.num_assistant_tokens; - if (requested_candidates_num != 0) { - m_candidates_num = requested_candidates_num; - } + m_candidates_num = config.num_assistant_tokens; m_draft_request->reset_state(); m_main_request->reset_state(); } From 3635d1f446502a48b811917ab24301c04d466370 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Tue, 23 Sep 2025 15:59:57 +0100 Subject: [PATCH 30/40] Fixed review comments --- samples/cpp/text_generation/speculative_decoding_lm.cpp | 4 ++-- samples/python/text_generation/speculative_decoding_lm.py | 4 ++-- src/cpp/include/openvino/genai/generation_config.hpp | 3 +-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/samples/cpp/text_generation/speculative_decoding_lm.cpp b/samples/cpp/text_generation/speculative_decoding_lm.cpp index e75c929552..b8a027f792 100644 --- a/samples/cpp/text_generation/speculative_decoding_lm.cpp +++ b/samples/cpp/text_generation/speculative_decoding_lm.cpp @@ -16,8 +16,8 @@ int main(int argc, char* argv[]) try { // Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded. // Add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration. // NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial - // value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set it will be defaulted to `5` - // for both backends. + // value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set, it defaults to `5` for both + // backends. // config.num_assistant_tokens = 5; // Add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than // `assistant_confidence_threshold`. diff --git a/samples/python/text_generation/speculative_decoding_lm.py b/samples/python/text_generation/speculative_decoding_lm.py index a3ff2d82ee..7280b2ea48 100755 --- a/samples/python/text_generation/speculative_decoding_lm.py +++ b/samples/python/text_generation/speculative_decoding_lm.py @@ -34,8 +34,8 @@ def main(): # Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded. # Add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration. # NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial - # value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set it will be defaulted to `5` - # for both backends. + # value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set, it defaults to `5` for both + # backends. # config.num_assistant_tokens = 5 # Add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than # `assistant_confidence_threshold`. 
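For reference, a minimal Python sketch (not part of the patches) of how the two mutually exclusive knobs discussed in the sample comments above are set. The model directories are placeholders for locally exported OpenVINO models, and the device strings assume a plain CPU run:

    import openvino_genai as ov_genai

    # Placeholder paths: point these at locally exported OpenVINO models.
    main_model_dir = "SmolLM2-360M-ov"
    draft_model_dir = "SmolLM2-135M-ov"

    pipe = ov_genai.LLMPipeline(
        main_model_dir, "CPU",
        draft_model=ov_genai.draft_model(draft_model_dir, "CPU"))

    config = ov_genai.GenerationConfig()
    config.max_new_tokens = 100
    # Static candidate count: ContinuousBatching uses it as is, the Stateful backend
    # treats it as an initial value and adapts it; leaving it unset defaults to 5.
    config.num_assistant_tokens = 5
    # Mutually exclusive alternative, ContinuousBatching backend only:
    # config.assistant_confidence_threshold = 0.4

    print(pipe.generate("Alan Turing was a", config))
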
diff --git a/src/cpp/include/openvino/genai/generation_config.hpp b/src/cpp/include/openvino/genai/generation_config.hpp index 61ffce69e3..74a7cdfe90 100644 --- a/src/cpp/include/openvino/genai/generation_config.hpp +++ b/src/cpp/include/openvino/genai/generation_config.hpp @@ -276,8 +276,7 @@ operator|(const StructuredOutputConfig::CompoundGrammar& lhs, NOTE: `assistant_confidence_threshold` is supported only by ContinuousBatching backend for Speculative Decode. * @param num_assistant_tokens the defined candidates number to be generated by draft model/prompt lookup in case of static strategy candidates number update. * NOTE: ContinuousBatching backend for Speculative Decode uses `num_assistant_tokens` as is. Stateful backend for Speculative Decode uses `num_assistant_tokens`'s - * copy as initial value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set it will be defaulted to `5` for both - * backends. + * copy as initial value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set, it defaults to `5` for both backends. * @param max_ngram_size is maximum ngram to use when looking for matches in the prompt. * * @param structured_output_config if set, the output will be a string constrained by the specified json_schema, regex, or EBNF grammar. From 7af9df6cb7040681ed30db74dd59877a14ff5e3a Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Mon, 29 Sep 2025 20:23:29 +0100 Subject: [PATCH 31/40] Fixed last comments and added tests --- .../speculative_decoding_lm.cpp | 2 +- .../speculative_decoding_lm.py | 2 +- .../speculative_decoding_stateful.cpp | 2 +- .../test_stateful_speculative_decoding.py | 205 ++++++++++++++++++ .../python_tests/utils/ov_genai_pipelines.py | 15 +- 5 files changed, 220 insertions(+), 6 deletions(-) create mode 100644 tests/python_tests/test_stateful_speculative_decoding.py diff --git a/samples/cpp/text_generation/speculative_decoding_lm.cpp b/samples/cpp/text_generation/speculative_decoding_lm.cpp index b8a027f792..044b9c843e 100644 --- a/samples/cpp/text_generation/speculative_decoding_lm.cpp +++ b/samples/cpp/text_generation/speculative_decoding_lm.cpp @@ -18,7 +18,7 @@ int main(int argc, char* argv[]) try { // NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial // value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set, it defaults to `5` for both // backends. - // config.num_assistant_tokens = 5; + config.num_assistant_tokens = 4; // Add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than // `assistant_confidence_threshold`. // NOTE: `assistant_confidence_threshold` is supported only by ContinuousBatching backend. diff --git a/samples/python/text_generation/speculative_decoding_lm.py b/samples/python/text_generation/speculative_decoding_lm.py index 7280b2ea48..58a1954b6c 100755 --- a/samples/python/text_generation/speculative_decoding_lm.py +++ b/samples/python/text_generation/speculative_decoding_lm.py @@ -36,7 +36,7 @@ def main(): # NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial # value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set, it defaults to `5` for both # backends. 
- # config.num_assistant_tokens = 5 + config.num_assistant_tokens = 4 # Add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than # `assistant_confidence_threshold`. # NOTE: `assistant_confidence_threshold` is supported only by ContinuousBatching backend. diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp index 427c2fa786..2ece0cf4a8 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp @@ -705,7 +705,7 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( auto& main_perf_generated_tokens = m_main_request->raw_perf_metrics.m_batch_sizes.back(); main_perf_generated_tokens -= mismatched_candidates; m_sd_metrics.update_draft_generated_len(0 /* request_id */, candidates_to_generate); - m_sd_metrics.update_acceptance_rate(0 /* request_id */, (accepted_tokens_number / candidates_to_generate) * 100); + m_sd_metrics.update_acceptance_rate(0 /* request_id */, (accepted_tokens_number * 100.f) / candidates_to_generate); m_sd_metrics.update_draft_accepted_tokens(0 /* request_id */, accepted_tokens_number); m_sd_metrics.update_generated_len(validated_tokens.size()); if (utils::env_setup_for_print_debug_info()) { diff --git a/tests/python_tests/test_stateful_speculative_decoding.py b/tests/python_tests/test_stateful_speculative_decoding.py new file mode 100644 index 0000000000..7c64cfabbb --- /dev/null +++ b/tests/python_tests/test_stateful_speculative_decoding.py @@ -0,0 +1,205 @@ +# Copyright (C) 2023-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import pytest +import numpy as np +import logging + +import openvino as ov +import openvino_genai as ov_genai + +from utils.constants import get_default_llm_properties +from utils.hugging_face import generation_config_to_hf, download_and_convert_model, run_hugging_face +from utils.comparation import compare_generation_results +from utils.ov_genai_pipelines import create_ov_pipeline, generate_and_compare, get_main_pipeline_types, PipelineType, convert_decoded_results_to_generation_result + +test_cases = [ + ('CPU', 'CPU'), + ('CPU', 'NPUW:CPU'), + ('NPUW:CPU', 'CPU'), + ('NPUW:CPU', 'NPUW:CPU') +] +@pytest.mark.parametrize("main_device,draft_device", test_cases) +@pytest.mark.precommit +def test_string_inputs(main_device, draft_device): + # FIXME: For now SmolLM2-135M is used as a main and a draft model in the test. + # However, it is more desirable to use SmolLM2-360M as a main one to simulate the real case + # for speculative decoding. + # It seems like temporary directory from model downloading stage isn't removed after test + # launch for SmolLM2-360M model, that is why it is not used now. + MODEL_UNDER_TEST = { + "name": "HuggingFaceTB/SmolLM2-135M", + "convert_args": ['--trust-remote-code'] + } + prompt = "Alan Turing was a" + + # Download and convert model: + main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(MODEL_UNDER_TEST["name"]) + draft_model_path = main_model_path + + # Create OpenVINO GenAI pipeline: + draft_config = get_default_llm_properties() + if draft_device == "NPUW:CPU": + draft_device = "NPU" + draft_config["NPUW_DEVICES"] = "CPU" + draft_config["GENERATE_HINT"] = "BEST_PERF" + # FIXME: Currently, the same draft and main model fails to work in NPUW_WEIGHTS_BANK: shared mode. 
+ # To workaround this, we name banks differently for draft and main. + draft_config["NPUW_WEIGHTS_BANK"] = "draft" + ov_draft_model = ov_genai.draft_model(draft_model_path, draft_device, **draft_config) + + main_config = get_default_llm_properties() + if main_device == "NPUW:CPU": + main_device = "NPU" + main_config["NPUW_DEVICES"] = "CPU" + # FIXME: SmolLM-135M with GENERATE_HINT: FAST_COMPILE will output garbage on NPUW:CPU if used with configuration + # NPUW_LLM_MAX_GENERATION_TOKEN_LEN > 1. + # Setting GENERATE_HINT: BEST_PERF to workaround an issue currently. + main_config["GENERATE_HINT"] = "BEST_PERF" + # FIXME: Currently, the same draft and main model fails to work in NPUW_WEIGHTS_BANK: shared mode. + # To workaround this, we name banks differently for draft and main. + main_config["NPUW_WEIGHTS_BANK"] = "main" + main_config["ATTENTION_BACKEND"] = "SDPA" + ov_pipe = ov_genai.LLMPipeline(main_model_path, main_device, main_config, draft_model=ov_draft_model) + + # Run reference HF model: + ov_generation_config = ov_genai.GenerationConfig(max_new_tokens=20) + main_hf_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + ref_gen_results = run_hugging_face(main_opt_model, main_hf_tokenizer, [prompt], ov_generation_config) + + # Run OpenVINO GenAI pipeline: + ov_decoded_results = ov_pipe.generate([prompt], ov_generation_config) + ov_gen_results = convert_decoded_results_to_generation_result(ov_decoded_results, 1, 1, False) + + del ov_pipe + + # Compare results: + compare_generation_results([prompt], ref_gen_results, ov_gen_results, ov_generation_config) + +@pytest.mark.precommit +def test_perf_metrics(): + import time + start_time = time.perf_counter() + model_id = 'katuni4ka/tiny-random-gemma2' + generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) + _, _, model_path = download_and_convert_model(model_id) + ov_pipe = create_ov_pipeline(model_path, pipeline_type=PipelineType.STATEFUL_SPECULATIVE_DECODING) + prompt = 'table is made of' + perf_metrics = ov_pipe.generate([prompt], generation_config).perf_metrics + total_time = (time.perf_counter() - start_time) * 1000 + + # Check that load time is adequate. + load_time = perf_metrics.get_load_time() + assert load_time > 0 and load_time < total_time + + # Check that num input and generated tokens are adequate. + num_generated_tokens = perf_metrics.get_num_generated_tokens() + assert num_generated_tokens > 0 and num_generated_tokens <= generation_config.max_new_tokens + + num_input_tokens = perf_metrics.get_num_input_tokens() + assert num_input_tokens > 0 and num_input_tokens <= len(prompt) + + mean_ttft, std_ttft = perf_metrics.get_ttft() + assert (mean_ttft, std_ttft) == (perf_metrics.get_ttft().mean, perf_metrics.get_ttft().std) + assert mean_ttft > 0 and mean_ttft < 1000.0 + + raw_metrics = perf_metrics.raw_metrics + durations = np.array(raw_metrics.m_durations) / 1000 + # Check that prefill is not included in durations for TPOT calculation. + # For the very long prompt prefill is slow and TTFT is much larger than any other token generation duration. 
+ assert np.all(mean_ttft > durations) + + mean_tpot, std_tpot = perf_metrics.get_tpot() + assert (mean_tpot, std_tpot) == (perf_metrics.get_tpot().mean, perf_metrics.get_tpot().std) + assert mean_tpot > 0 and mean_ttft < 1000.0 + + mean_throughput, std_throughput = perf_metrics.get_throughput() + assert (mean_throughput, std_throughput) == (perf_metrics.get_throughput().mean, perf_metrics.get_throughput().std) + assert mean_throughput > 0 and mean_throughput < 20000.0 + + mean_gen_duration, std_gen_duration = perf_metrics.get_generate_duration() + assert (mean_gen_duration, std_gen_duration) == (perf_metrics.get_generate_duration().mean, perf_metrics.get_generate_duration().std) + assert mean_gen_duration > 0 and load_time + mean_gen_duration < total_time + assert std_gen_duration == 0 + + mean_tok_duration, std_tok_duration = perf_metrics.get_tokenization_duration() + assert (mean_tok_duration, std_tok_duration) == (perf_metrics.get_tokenization_duration().mean, perf_metrics.get_tokenization_duration().std) + assert mean_tok_duration > 0 and mean_tok_duration < mean_gen_duration + assert std_tok_duration == 0 + + mean_detok_duration, std_detok_duration = perf_metrics.get_detokenization_duration() + assert (mean_detok_duration, std_detok_duration) == (perf_metrics.get_detokenization_duration().mean, perf_metrics.get_detokenization_duration().std) + assert mean_detok_duration > 0 and mean_detok_duration < mean_gen_duration + assert std_detok_duration == 0 + + # assert that calculating statistics manually from the raw counters we get the same restults as from PerfMetrics + assert np.allclose(mean_tpot, np.mean(durations)) + assert np.allclose(std_tpot, np.std(durations), atol=0.00002) + + raw_dur = np.array(raw_metrics.generate_durations) / 1000 + assert np.allclose(mean_gen_duration, np.mean(raw_dur)) + assert np.allclose(std_gen_duration, np.std(raw_dur)) + + raw_dur = np.array(raw_metrics.tokenization_durations) / 1000 + assert np.allclose(mean_tok_duration, np.mean(raw_dur)) + assert np.allclose(std_tok_duration, np.std(raw_dur)) + + raw_dur = np.array(raw_metrics.detokenization_durations) / 1000 + assert np.allclose(mean_detok_duration, np.mean(raw_dur)) + assert np.allclose(std_detok_duration, np.std(raw_dur)) + + assert len(raw_metrics.m_times_to_first_token) > 0 + assert len(raw_metrics.m_batch_sizes) > 0 + assert len(raw_metrics.m_durations) > 0 + +@pytest.mark.precommit +def test_extended_perf_metrics(): + import time + start_time = time.perf_counter() + model_id : str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) + _, _, model_path = download_and_convert_model(model_id) + ov_pipe = create_ov_pipeline(model_path, pipeline_type=PipelineType.STATEFUL_SPECULATIVE_DECODING) + extended_perf_metrics = ov_pipe.generate(["Why is the Sun yellow?"], generation_config).extended_perf_metrics + total_time = (time.perf_counter() - start_time) * 1000 + + assert not extended_perf_metrics is None + assert not extended_perf_metrics.main_model_metrics is None + assert not extended_perf_metrics.draft_model_metrics is None + + assert extended_perf_metrics.get_num_accepted_tokens() > 0 + + num_generated_tokens_main = extended_perf_metrics.main_model_metrics.get_num_generated_tokens() + assert num_generated_tokens_main > 0 and num_generated_tokens_main <= generation_config.max_new_tokens + + num_generated_tokens_draft = extended_perf_metrics.draft_model_metrics.get_num_generated_tokens() 
+ # As Stateful Speculative Decoding pipeline is dynamically adjusting its number of candidates at + # each step, here we check that generated tokens is less than upper candidates limit multiplied by + # maximum number of generated tokens. + assert num_generated_tokens_draft > 0 and \ + num_generated_tokens_draft < ((generation_config.max_new_tokens - 1) * \ + generation_config.num_assistant_tokens * 2 + 1) + + total_iteration_number_main = len(extended_perf_metrics.main_model_metrics.raw_metrics.m_durations) + assert total_iteration_number_main > 0 and total_iteration_number_main <= generation_config.max_new_tokens + + total_iteration_number_draft = len(extended_perf_metrics.draft_model_metrics.raw_metrics.m_durations) + assert total_iteration_number_draft > 0 and \ + total_iteration_number_draft < ((generation_config.max_new_tokens - 1) * \ + generation_config.num_assistant_tokens * 2 + 1) + + for model_metrics in [extended_perf_metrics.main_model_metrics, extended_perf_metrics.draft_model_metrics]: + mean_ttst, std_ttst = model_metrics.get_ttst() + assert (mean_ttst, std_ttst) == (model_metrics.get_ttst().mean, model_metrics.get_ttst().std) + assert mean_ttst > 0 and mean_ttst < model_metrics.get_ttft().mean + assert std_ttst == 0 + + mean_latency, std_latency = model_metrics.get_latency() + assert (mean_latency, std_latency) == (model_metrics.get_latency().mean, model_metrics.get_latency().std) + assert mean_latency > 0 and mean_latency < 1000.0 + + mean_gen_duration, std_gen_duration = model_metrics.get_generate_duration() + assert (mean_gen_duration, std_gen_duration) == (model_metrics.get_generate_duration().mean, model_metrics.get_generate_duration().std) + assert mean_gen_duration > 0 and mean_gen_duration < total_time + assert std_gen_duration == 0 diff --git a/tests/python_tests/utils/ov_genai_pipelines.py b/tests/python_tests/utils/ov_genai_pipelines.py index 167ff94bac..46e27612a0 100644 --- a/tests/python_tests/utils/ov_genai_pipelines.py +++ b/tests/python_tests/utils/ov_genai_pipelines.py @@ -41,13 +41,15 @@ class PipelineType(Enum): PAGED_ATTENTION = 2 CONTINUOUS_BATCHING = 3 SPECULATIVE_DECODING = 4 - PROMPT_LOOKUP_DECODING = 5 - AUTO = 6 + STATEFUL_SPECULATIVE_DECODING = 5 + PROMPT_LOOKUP_DECODING = 6 + AUTO = 7 def get_all_pipeline_types(): - return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.CONTINUOUS_BATCHING, PipelineType.SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING, PipelineType.AUTO] + return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.CONTINUOUS_BATCHING, PipelineType.SPECULATIVE_DECODING, PipelineType.STATEFUL_SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING, PipelineType.AUTO] +# TODO: Add PipelineType.STATEFUL_SPECULATIVE_DECODING, make its tests green. 
def get_main_pipeline_types(): return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING] @@ -97,6 +99,10 @@ def create_ov_pipeline(models_path: Path, elif pipeline_type == PipelineType.SPECULATIVE_DECODING: ov_draft_model = draft_model(models_path) if draft_model_path is None else draft_model(draft_model_path) return LLMPipeline(models_path, device, ov_config, scheduler_config=scheduler_config, draft_model=ov_draft_model) + elif pipeline_type == PipelineType.STATEFUL_SPECULATIVE_DECODING: + ov_draft_model = draft_model(models_path) if draft_model_path is None else draft_model(draft_model_path) + ov_config["ATTENTION_BACKEND"] = "SDPA" + return LLMPipeline(models_path, device, ov_config, draft_model=ov_draft_model) elif pipeline_type == PipelineType.PROMPT_LOOKUP_DECODING: return LLMPipeline(models_path, device, ov_config, scheduler_config=scheduler_config, prompt_lookup=True) else: @@ -127,6 +133,9 @@ def prepare_generation_config_by_pipe_type(generation_config : GenerationConfig, if pipeline_type == PipelineType.SPECULATIVE_DECODING: assert not generation_config.is_beam_search() generation_config.assistant_confidence_threshold = 0.9 + elif pipeline_type == PipelineType.STATEFUL_SPECULATIVE_DECODING: + assert not generation_config.is_beam_search() + generation_config.num_assistant_tokens = 5 elif pipeline_type == PipelineType.PROMPT_LOOKUP_DECODING: assert not generation_config.is_beam_search() generation_config.num_assistant_tokens = 5 From c8cc3c436a7205fc6b46ddf0ce9722df6c37ebc8 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Wed, 1 Oct 2025 14:19:10 +0100 Subject: [PATCH 32/40] Fixed new review comments after team discussion --- ...batching_for_speculative_decoding_impl.cpp | 2 +- ...batching_for_speculative_decoding_impl.hpp | 2 + .../speculative_decoding_impl.cpp | 2 +- .../speculative_decoding_stateful.cpp | 68 ++++++++++++------- .../speculative_decoding_stateful.hpp | 2 +- 5 files changed, 48 insertions(+), 28 deletions(-) diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp index beff8eac23..4853b8bac6 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp @@ -16,7 +16,7 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::Contin m_generation_config = generation_config; if (m_generation_config.assistant_confidence_threshold == 0.f) { if (m_generation_config.num_assistant_tokens == 0) { - m_generation_config.num_assistant_tokens = 5; + m_generation_config.num_assistant_tokens = default_num_assistant_tokens; } } m_is_validation_mode_enabled = is_validation_mode_enabled; diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp index f81c7f2d37..40db6a2ddd 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp @@ -11,6 +11,8 @@ namespace ov::genai { class ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl : public ContinuousBatchingPipeline::ContinuousBatchingImpl { public: + const std::size_t default_num_assistant_tokens = 
5; + ContinuousBatchingForSpeculativeDecodingImpl() = default; ContinuousBatchingForSpeculativeDecodingImpl(const std::shared_ptr& model, diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 6db9a092ad..17d0bef93a 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -265,7 +265,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< auto main_sampling_params = sampling_params[request_id]; if (main_sampling_params.assistant_confidence_threshold == 0.f) { if (main_sampling_params.num_assistant_tokens == 0) { - main_sampling_params.num_assistant_tokens = 5; + main_sampling_params.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; } } main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], main_sampling_params)); diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp index 2ece0cf4a8..55e0d52ce6 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp @@ -49,8 +49,9 @@ void ensure_num_assistant_tokens_is_set(ov::genai::GenerationConfig& generation_ "as parameter in GenerationConfig and doesn't work with `assistant_confidence_threshold`.\nPlease " "remove its specification or set it to 0.f."); + constexpr std::size_t default_num_assistant_tokens = 5; if (generation_config.num_assistant_tokens == 0) { - generation_config.num_assistant_tokens = 5; + generation_config.num_assistant_tokens = default_num_assistant_tokens; } } }// anonymous namespace @@ -219,7 +220,7 @@ int64_t LLMInferWrapper::infer_next(int64_t token, bool append_perf_stat) { return last_token; } -std::vector LLMInferWrapper::infer_next_return_all(const std::vector tokens) { +std::vector LLMInferWrapper::infer_next_return_all(const std::vector& tokens) { OPENVINO_ASSERT(m_num_processed_tokens > 0, "infer_next_return_all() can be called only after infer_first()!"); ManualTimer infer_next_return_all_timer("infer_next_return_all()"); @@ -389,6 +390,8 @@ StatefulSpeculativeLLMPipeline::StatefulSpeculativeLLMPipeline( // Specifying number candidates to generate ensure_num_assistant_tokens_is_set(m_generation_config); m_candidates_num = m_generation_config.num_assistant_tokens; + // We set the upper limit for candidates number as two times the number requested + // by user. m_max_candidates_num = m_candidates_num * 2; // Main model (which is bigger, more accurate but slower) @@ -426,6 +429,8 @@ DecodedResults StatefulSpeculativeLLMPipeline::generate( config = *generation_config; ensure_num_assistant_tokens_is_set(config); m_candidates_num = config.num_assistant_tokens; + // We set the upper limit for candidates number as two times the number + // requested by user. m_max_candidates_num = m_candidates_num * 2; } @@ -506,6 +511,8 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( config = *generation_config; ensure_num_assistant_tokens_is_set(config); m_candidates_num = config.num_assistant_tokens; + // We set the upper limit for candidates number as two times the number + // requested by user. 
m_max_candidates_num = m_candidates_num * 2; } // If stop_token_ids were not provided, take value from default m_generation_config @@ -516,8 +523,8 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( config.set_eos_token_id(m_generation_config.eos_token_id); config.validate(); - OPENVINO_ASSERT(config.is_greedy_decoding() || config.is_multinomial(), - "Currently only greedy and multinomial decoding are supported"); + OPENVINO_ASSERT(config.is_greedy_decoding(), + "Currently only greedy decoding are supported"); OPENVINO_ASSERT(config.num_return_sequences == 1u, "Currently only \"num_return_sequences\" equal to 1 is supported!"); @@ -528,7 +535,6 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( ov::genai::GenerationConfig draft_config = m_draft_request->get_generation_config(); draft_config.ignore_eos = true; draft_config.stop_strings = {}; - draft_config.max_new_tokens = config.get_max_new_tokens(); draft_config.validate(); m_draft_request->set_generation_config(draft_config); @@ -577,21 +583,21 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( ManualTimer candidates_timer("Draft model: candidates generation"); ManualTimer main_timer("Main model"); - /* Speculative decoding works the following way. The draft model predicts the next K - tokens one by one in an autoregressive manner, while the main model validates these - predictions and corrects them if necessary. We go through each predicted token, and - if a difference is detected between the draft and main model, we stop and keep the - last token predicted by the main model. Then the draft model gets the latest main - prediction and again tries to predict the next K tokens, repeating the cycle. - - This approach reduces the need for multiple infer requests to the main model, - enhancing performance. For instance, in more predictable parts of text generation, - the draft model can, in best-case scenarios, generate the next K tokens that exactly - match the target. In that case they are validated in a single inference call to - the main model instead of running K subsequent requests. - */ - // Last generated token by draft model needs to be prepended before next run if it is accepted by the main model! - // So it will get into the context too. + // Speculative decoding works the following way. The draft model predicts the next K + // tokens one by one in an autoregressive manner, while the main model validates these + // predictions and corrects them if necessary. We go through each predicted token, and + // if a difference is detected between the draft and main model, we stop and keep the + // last token predicted by the main model. Then the draft model gets the latest main + // prediction and again tries to predict the next K tokens, repeating the cycle. + + // This approach reduces the need for multiple infer requests to the main model, + // enhancing performance. For instance, in more predictable parts of text generation, + // the draft model can, in best-case scenarios, generate the next K tokens that exactly + // match the target. In that case they are validated in a single inference call to + // the main model instead of running K subsequent requests. + + // Last generated token by draft model needs to be prepended before next run if it is accepted + // by the main model! So it will get into the kvcache of the draft model. 
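The draft/verify round described in the comment above can be sketched schematically as follows. This is an illustrative stand-in working on plain token-id lists, not the pipeline code; `draft_propose` and `main_validate` are hypothetical callables standing in for the draft and main infer wrappers:

    def speculative_round(draft_propose, main_validate, last_token, k):
        # draft_propose(token, k) -> up to k candidate token ids (hypothetical callable).
        # main_validate(prefix_token, candidates) -> len(candidates) + 1 token ids that the
        # main model predicts after [prefix_token] + candidates (hypothetical callable).
        candidates = draft_propose(last_token, k)
        main_out = main_validate(last_token, candidates)

        accepted = []
        for cand, ref in zip(candidates, main_out):
            if cand != ref:
                break          # first disagreement: discard the remaining candidates
            accepted.append(cand)
        # The main model's own token is always kept: the correction on a mismatch,
        # or one extra token when every candidate was accepted.
        accepted.append(main_out[len(accepted)])
        return accepted
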
int64_t draft_prefix_token = -1; while (m_main_request->can_infer() && (streaming_status == ov::genai::StreamingStatus::RUNNING)) { iteration_timer.start(); @@ -601,13 +607,25 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( std::vector candidates; int64_t kvcache_room_for_candidates = std::min( + // Take into the account the draft prefix token, described above + // (before the while loop). If it is needed to be prepended to kvcache, + // then we can generate candidates as number of left kvcache space of + // draft model but minus 1: m_draft_request->get_kvcache_capacity() - ((draft_prefix_token == -1) ? 1 : 0), - // Take into the account reference token that is prefixed to candidates: + // Take into the account reference token that is prefixed to candidates. + // We can generate candidates as number of left kvcache space of main + // model, but as main model will consume candidates + its previous output + // then we need to preserve this one spot in main kvcache for previous + // output. m_main_request->get_kvcache_capacity() - 1); - int64_t generation_room_for_candidates = std::min( - m_draft_request->get_generation_capacity(), - // Take into the account output token, generated on candidates: - m_main_request->get_generation_capacity() - 1); + int64_t generation_room_for_candidates = + // Take into the account output token, generated on candidates. + // If we accept all candidates by the main model, then we will generate + // output of length equal to number of candidates + one output token from + // the main model. + // As output token number is limited we can generate candidates of only + // remained output tokens number - 1 (for output token). + m_main_request->get_generation_capacity() - 1; int64_t candidates_can_be_generated = std::min( kvcache_room_for_candidates, generation_room_for_candidates); if (candidates_can_be_generated <= 0) { diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp index 0de9950eef..6ff3b5508f 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.hpp @@ -34,7 +34,7 @@ class LLMInferWrapper { int64_t infer_next(int64_t out_token, bool append_perf_stat = false); - std::vector infer_next_return_all(const std::vector tokens); + std::vector infer_next_return_all(const std::vector& tokens); ov::Tensor get_logits(); From 0678708791b51dcff416e60d6019f70d5b8e5040 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Wed, 1 Oct 2025 15:56:44 +0100 Subject: [PATCH 33/40] Fixed setting of `max_new_tokens` for draft model --- .../speculative_decoding/speculative_decoding_stateful.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp index 55e0d52ce6..2135ddbf52 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp @@ -535,6 +535,10 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( ov::genai::GenerationConfig draft_config = m_draft_request->get_generation_config(); draft_config.ignore_eos = true; draft_config.stop_strings = {}; + // Need to set `max_new_tokens` as GenerationConfig requires it if `ignore_eos` is true. + // However, this parameter won't be utilized in pipeline, only main's `max_new_tokens` + // will be utilized. 
+ draft_config.max_new_tokens = config.get_max_new_tokens(); draft_config.validate(); m_draft_request->set_generation_config(draft_config); From f7ff986060999cf9a0b4cfed15057eee37e4504d Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Thu, 2 Oct 2025 18:13:57 +0100 Subject: [PATCH 34/40] Used SmolLM2-360M as main model in tests --- .../test_stateful_speculative_decoding.py | 33 +++++-------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/tests/python_tests/test_stateful_speculative_decoding.py b/tests/python_tests/test_stateful_speculative_decoding.py index 7c64cfabbb..d2f04966ac 100644 --- a/tests/python_tests/test_stateful_speculative_decoding.py +++ b/tests/python_tests/test_stateful_speculative_decoding.py @@ -14,29 +14,21 @@ from utils.comparation import compare_generation_results from utils.ov_genai_pipelines import create_ov_pipeline, generate_and_compare, get_main_pipeline_types, PipelineType, convert_decoded_results_to_generation_result -test_cases = [ +models_and_input = [ + ("HuggingFaceTB/SmolLM2-360M", "HuggingFaceTB/SmolLM2-135M", "Alan Turing was a")] +devices = [ ('CPU', 'CPU'), ('CPU', 'NPUW:CPU'), ('NPUW:CPU', 'CPU'), ('NPUW:CPU', 'NPUW:CPU') ] -@pytest.mark.parametrize("main_device,draft_device", test_cases) +@pytest.mark.parametrize("main_model,draft_model,prompt", models_and_input) +@pytest.mark.parametrize("main_device,draft_device", devices) @pytest.mark.precommit -def test_string_inputs(main_device, draft_device): - # FIXME: For now SmolLM2-135M is used as a main and a draft model in the test. - # However, it is more desirable to use SmolLM2-360M as a main one to simulate the real case - # for speculative decoding. - # It seems like temporary directory from model downloading stage isn't removed after test - # launch for SmolLM2-360M model, that is why it is not used now. - MODEL_UNDER_TEST = { - "name": "HuggingFaceTB/SmolLM2-135M", - "convert_args": ['--trust-remote-code'] - } - prompt = "Alan Turing was a" - +def test_string_inputs(main_model, main_device, draft_model, draft_device, prompt): # Download and convert model: - main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(MODEL_UNDER_TEST["name"]) - draft_model_path = main_model_path + main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(main_model) + __, __, draft_model_path = download_and_convert_model(draft_model) # Create OpenVINO GenAI pipeline: draft_config = get_default_llm_properties() @@ -44,22 +36,13 @@ def test_string_inputs(main_device, draft_device): draft_device = "NPU" draft_config["NPUW_DEVICES"] = "CPU" draft_config["GENERATE_HINT"] = "BEST_PERF" - # FIXME: Currently, the same draft and main model fails to work in NPUW_WEIGHTS_BANK: shared mode. - # To workaround this, we name banks differently for draft and main. - draft_config["NPUW_WEIGHTS_BANK"] = "draft" ov_draft_model = ov_genai.draft_model(draft_model_path, draft_device, **draft_config) main_config = get_default_llm_properties() if main_device == "NPUW:CPU": main_device = "NPU" main_config["NPUW_DEVICES"] = "CPU" - # FIXME: SmolLM-135M with GENERATE_HINT: FAST_COMPILE will output garbage on NPUW:CPU if used with configuration - # NPUW_LLM_MAX_GENERATION_TOKEN_LEN > 1. - # Setting GENERATE_HINT: BEST_PERF to workaround an issue currently. main_config["GENERATE_HINT"] = "BEST_PERF" - # FIXME: Currently, the same draft and main model fails to work in NPUW_WEIGHTS_BANK: shared mode. - # To workaround this, we name banks differently for draft and main. 
- main_config["NPUW_WEIGHTS_BANK"] = "main" main_config["ATTENTION_BACKEND"] = "SDPA" ov_pipe = ov_genai.LLMPipeline(main_model_path, main_device, main_config, draft_model=ov_draft_model) From 88aed09f85c8b21d31d0388954f21f425f6cab69 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Tue, 7 Oct 2025 11:49:03 +0100 Subject: [PATCH 35/40] Added assert on launch of StatefulSpeculativeLLMPipeline with GPU --- src/cpp/src/llm/pipeline.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index 954942d5b1..8e34471a87 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -92,6 +92,9 @@ static std::unique_ptr create( auto properties_without_draft_model = properties; auto draft_model_descr = ov::genai::utils::extract_draft_model_from_config(properties_without_draft_model); if (draft_model_descr.model != nullptr) { + OPENVINO_ASSERT(device != "GPU" && draft_model_descr.device != "GPU", + "Speculative Decoding with \"ATTENTION_BACKEND\" : \"SDPA\" or any of the models on NPU " + "doesn't support GPU device either for main or draft models currently!"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, {}, generation_config); return std::make_unique(main_model_descr, draft_model_descr); } From e74f25e174f7d2d85108648fd20cce5a4bb5f548 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Tue, 7 Oct 2025 13:56:24 +0100 Subject: [PATCH 36/40] Restrict StatefulSpeculativeLLMPipeline to launch only if NPU specified for one or both the models --- .../speculative_decoding_lm.cpp | 2 +- .../speculative_decoding_lm.py | 2 +- src/cpp/src/llm/pipeline.cpp | 19 +++++--- src/cpp/src/utils.cpp | 8 ++++ .../test_stateful_speculative_decoding.py | 46 +++++++++++-------- .../python_tests/utils/ov_genai_pipelines.py | 15 ++---- 6 files changed, 53 insertions(+), 39 deletions(-) diff --git a/samples/cpp/text_generation/speculative_decoding_lm.cpp b/samples/cpp/text_generation/speculative_decoding_lm.cpp index 044b9c843e..d0822dd221 100644 --- a/samples/cpp/text_generation/speculative_decoding_lm.cpp +++ b/samples/cpp/text_generation/speculative_decoding_lm.cpp @@ -31,7 +31,7 @@ int main(int argc, char* argv[]) try { // User can run main and draft model on different devices. // Please, set device for main model in `LLMPipeline` constructor and in in `ov::genai::draft_model` for draft. // CPU, GPU and NPU can be used. Please be aware that GPU is performant only with Continuous Batching pipeline, so it is not recommented - // to use it in conjuction with NPU or in configuration when main model doesn't work in Paged Attention mode. + // to use it in conjuction with NPU. std::string main_device = "CPU", draft_device = "CPU"; ov::genai::LLMPipeline pipe( diff --git a/samples/python/text_generation/speculative_decoding_lm.py b/samples/python/text_generation/speculative_decoding_lm.py index 58a1954b6c..7aeb099e39 100755 --- a/samples/python/text_generation/speculative_decoding_lm.py +++ b/samples/python/text_generation/speculative_decoding_lm.py @@ -21,7 +21,7 @@ def main(): # User can run main and draft model on different devices. # Please, set device for main model in `openvino_genai.LLMPipeline` constructor and in `openvino_genai.draft_model` for draft. # CPU, GPU and NPU can be used. 
Please be aware that GPU is performant only with Continuous Batching pipeline, so it is not - # recommented to use it in conjuction with NPU or in configuration when main model doesn't work in Paged Attention mode. + # recommented to use it in conjuction with NPU. main_device = 'CPU' draft_device = 'CPU' diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index 8e34471a87..d294d219e2 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -92,9 +92,10 @@ static std::unique_ptr create( auto properties_without_draft_model = properties; auto draft_model_descr = ov::genai::utils::extract_draft_model_from_config(properties_without_draft_model); if (draft_model_descr.model != nullptr) { - OPENVINO_ASSERT(device != "GPU" && draft_model_descr.device != "GPU", - "Speculative Decoding with \"ATTENTION_BACKEND\" : \"SDPA\" or any of the models on NPU " - "doesn't support GPU device either for main or draft models currently!"); + // FIXME: Add support for StatefulSpeculativeLLMPipeline for non-NPU devices for both models. + OPENVINO_ASSERT(device == "NPU" || draft_model_descr.device == "NPU", + "Stateful Speculative Decoding is expected to be launched when NPU is requsted as " + "execution device for one or both models."); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, {}, generation_config); return std::make_unique(main_model_descr, draft_model_descr); } @@ -144,7 +145,9 @@ ov::genai::LLMPipeline::LLMPipeline( } if (m_pimpl == nullptr) { - m_pimpl = StatefulPipeline::create(models_path, tokenizer, device, properties); + // FIXME: Switch to StatefulPipeline::create after resolving issues + // with GPU and CPU for StatefulSpeculativeLLMPipeline + m_pimpl = std::make_unique(models_path, tokenizer, device, properties); } m_pimpl->save_load_time(start_time); @@ -179,7 +182,9 @@ ov::genai::LLMPipeline::LLMPipeline( } if (m_pimpl == nullptr) { - m_pimpl = StatefulPipeline::create(models_path, device, properties); + // FIXME: Switch to StatefulPipeline::create after resolving issues + // with GPU and CPU for StatefulSpeculativeLLMPipeline + m_pimpl = std::make_unique(models_path, device, properties); } m_pimpl->save_load_time(start_time); @@ -224,7 +229,9 @@ ov::genai::LLMPipeline::LLMPipeline( } if (m_pimpl == nullptr) { - m_pimpl = StatefulPipeline::create( + // FIXME: Switch to StatefulPipeline::create after resolving issues + // with GPU and CPU for StatefulSpeculativeLLMPipeline + m_pimpl = std::make_unique( utils::singleton_core().read_model(model_str, weights_tensor), tokenizer, device, diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index f17b05a2ce..1e55b322a7 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -643,6 +643,14 @@ bool explicitly_requires_paged_attention(const ov::AnyMap& properties) { } } + if (properties.find(utils::DRAFT_MODEL_ARG_NAME) != properties.end()) { + if (is_paged_attention_available()) { + return true; + } else { + OPENVINO_THROW("Speculative decoding on non-NPU devices requires PagedAttention operation support, which is available on x86_64 or ARM64 platforms only"); + } + } + auto prompt_lookup_prop = properties.find("prompt_lookup"); if (prompt_lookup_prop != properties.end() && prompt_lookup_prop->second.as() == true) { if (is_paged_attention_available()) { diff --git a/tests/python_tests/test_stateful_speculative_decoding.py b/tests/python_tests/test_stateful_speculative_decoding.py index d2f04966ac..9c6e92a4f0 100644 --- 
a/tests/python_tests/test_stateful_speculative_decoding.py +++ b/tests/python_tests/test_stateful_speculative_decoding.py @@ -14,13 +14,19 @@ from utils.comparation import compare_generation_results from utils.ov_genai_pipelines import create_ov_pipeline, generate_and_compare, get_main_pipeline_types, PipelineType, convert_decoded_results_to_generation_result +def get_npu_llm_properties_for_test(): + config = get_default_llm_properties() + config["NPUW_DEVICES"] = "CPU" + config["GENERATE_HINT"] = "BEST_PERF" + return config + models_and_input = [ ("HuggingFaceTB/SmolLM2-360M", "HuggingFaceTB/SmolLM2-135M", "Alan Turing was a")] devices = [ - ('CPU', 'CPU'), - ('CPU', 'NPUW:CPU'), - ('NPUW:CPU', 'CPU'), - ('NPUW:CPU', 'NPUW:CPU') + # FIXME: add 'CPU' and 'GPU' cases in future + ('CPU', 'NPU'), + ('NPU', 'CPU'), + ('NPU', 'NPU') ] @pytest.mark.parametrize("main_model,draft_model,prompt", models_and_input) @pytest.mark.parametrize("main_device,draft_device", devices) @@ -31,19 +37,14 @@ def test_string_inputs(main_model, main_device, draft_model, draft_device, promp __, __, draft_model_path = download_and_convert_model(draft_model) # Create OpenVINO GenAI pipeline: - draft_config = get_default_llm_properties() - if draft_device == "NPUW:CPU": - draft_device = "NPU" - draft_config["NPUW_DEVICES"] = "CPU" - draft_config["GENERATE_HINT"] = "BEST_PERF" + draft_config = get_npu_llm_properties_for_test() \ + if (draft_device == "NPU") else \ + get_default_llm_properties() ov_draft_model = ov_genai.draft_model(draft_model_path, draft_device, **draft_config) - main_config = get_default_llm_properties() - if main_device == "NPUW:CPU": - main_device = "NPU" - main_config["NPUW_DEVICES"] = "CPU" - main_config["GENERATE_HINT"] = "BEST_PERF" - main_config["ATTENTION_BACKEND"] = "SDPA" + main_config = get_npu_llm_properties_for_test() \ + if (main_device == "NPU") else \ + get_default_llm_properties() ov_pipe = ov_genai.LLMPipeline(main_model_path, main_device, main_config, draft_model=ov_draft_model) # Run reference HF model: @@ -65,10 +66,14 @@ def test_perf_metrics(): import time start_time = time.perf_counter() model_id = 'katuni4ka/tiny-random-gemma2' - generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) _, _, model_path = download_and_convert_model(model_id) - ov_pipe = create_ov_pipeline(model_path, pipeline_type=PipelineType.STATEFUL_SPECULATIVE_DECODING) + + # Create OpenVINO GenAI pipeline: + ov_draft_model = ov_genai.draft_model(model_path, "NPU", **get_npu_llm_properties_for_test()) + ov_pipe = ov_genai.LLMPipeline(model_path, "NPU", get_npu_llm_properties_for_test(), draft_model=ov_draft_model) + prompt = 'table is made of' + generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) perf_metrics = ov_pipe.generate([prompt], generation_config).perf_metrics total_time = (time.perf_counter() - start_time) * 1000 @@ -141,9 +146,12 @@ def test_extended_perf_metrics(): import time start_time = time.perf_counter() model_id : str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" - generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) _, _, model_path = download_and_convert_model(model_id) - ov_pipe = create_ov_pipeline(model_path, pipeline_type=PipelineType.STATEFUL_SPECULATIVE_DECODING) + + ov_draft_model = ov_genai.draft_model(model_path, "NPU", **get_npu_llm_properties_for_test()) + ov_pipe = 
ov_genai.LLMPipeline(model_path, "NPU", get_npu_llm_properties_for_test(), draft_model=ov_draft_model) + + generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) extended_perf_metrics = ov_pipe.generate(["Why is the Sun yellow?"], generation_config).extended_perf_metrics total_time = (time.perf_counter() - start_time) * 1000 diff --git a/tests/python_tests/utils/ov_genai_pipelines.py b/tests/python_tests/utils/ov_genai_pipelines.py index 46e27612a0..167ff94bac 100644 --- a/tests/python_tests/utils/ov_genai_pipelines.py +++ b/tests/python_tests/utils/ov_genai_pipelines.py @@ -41,15 +41,13 @@ class PipelineType(Enum): PAGED_ATTENTION = 2 CONTINUOUS_BATCHING = 3 SPECULATIVE_DECODING = 4 - STATEFUL_SPECULATIVE_DECODING = 5 - PROMPT_LOOKUP_DECODING = 6 - AUTO = 7 + PROMPT_LOOKUP_DECODING = 5 + AUTO = 6 def get_all_pipeline_types(): - return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.CONTINUOUS_BATCHING, PipelineType.SPECULATIVE_DECODING, PipelineType.STATEFUL_SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING, PipelineType.AUTO] + return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.CONTINUOUS_BATCHING, PipelineType.SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING, PipelineType.AUTO] -# TODO: Add PipelineType.STATEFUL_SPECULATIVE_DECODING, make its tests green. def get_main_pipeline_types(): return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING] @@ -99,10 +97,6 @@ def create_ov_pipeline(models_path: Path, elif pipeline_type == PipelineType.SPECULATIVE_DECODING: ov_draft_model = draft_model(models_path) if draft_model_path is None else draft_model(draft_model_path) return LLMPipeline(models_path, device, ov_config, scheduler_config=scheduler_config, draft_model=ov_draft_model) - elif pipeline_type == PipelineType.STATEFUL_SPECULATIVE_DECODING: - ov_draft_model = draft_model(models_path) if draft_model_path is None else draft_model(draft_model_path) - ov_config["ATTENTION_BACKEND"] = "SDPA" - return LLMPipeline(models_path, device, ov_config, draft_model=ov_draft_model) elif pipeline_type == PipelineType.PROMPT_LOOKUP_DECODING: return LLMPipeline(models_path, device, ov_config, scheduler_config=scheduler_config, prompt_lookup=True) else: @@ -133,9 +127,6 @@ def prepare_generation_config_by_pipe_type(generation_config : GenerationConfig, if pipeline_type == PipelineType.SPECULATIVE_DECODING: assert not generation_config.is_beam_search() generation_config.assistant_confidence_threshold = 0.9 - elif pipeline_type == PipelineType.STATEFUL_SPECULATIVE_DECODING: - assert not generation_config.is_beam_search() - generation_config.num_assistant_tokens = 5 elif pipeline_type == PipelineType.PROMPT_LOOKUP_DECODING: assert not generation_config.is_beam_search() generation_config.num_assistant_tokens = 5 From 8ca0d6b258f9bdc653080ff4ce3e06fdb386573b Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Fri, 10 Oct 2025 14:53:11 +0100 Subject: [PATCH 37/40] Addressed comments --- samples/cpp/text_generation/speculative_decoding_lm.cpp | 6 +++--- samples/python/text_generation/speculative_decoding_lm.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/samples/cpp/text_generation/speculative_decoding_lm.cpp b/samples/cpp/text_generation/speculative_decoding_lm.cpp index d0822dd221..093c62973f 100644 --- a/samples/cpp/text_generation/speculative_decoding_lm.cpp +++ 
b/samples/cpp/text_generation/speculative_decoding_lm.cpp @@ -29,9 +29,9 @@ int main(int argc, char* argv[]) try { std::string prompt = argv[3]; // User can run main and draft model on different devices. - // Please, set device for main model in `LLMPipeline` constructor and in in `ov::genai::draft_model` for draft. - // CPU, GPU and NPU can be used. Please be aware that GPU is performant only with Continuous Batching pipeline, so it is not recommented - // to use it in conjuction with NPU. + // Please, set device for main model in `LLMPipeline` constructor and in `ov::genai::draft_model` for draft. + // CPU, GPU and NPU can be used. For NPU, the preferred configuration is when both the main and draft models + // use NPU. std::string main_device = "CPU", draft_device = "CPU"; ov::genai::LLMPipeline pipe( diff --git a/samples/python/text_generation/speculative_decoding_lm.py b/samples/python/text_generation/speculative_decoding_lm.py index 7aeb099e39..0fd0eb8a9e 100755 --- a/samples/python/text_generation/speculative_decoding_lm.py +++ b/samples/python/text_generation/speculative_decoding_lm.py @@ -20,8 +20,7 @@ def main(): # User can run main and draft model on different devices. # Please, set device for main model in `openvino_genai.LLMPipeline` constructor and in `openvino_genai.draft_model` for draft. - # CPU, GPU and NPU can be used. Please be aware that GPU is performant only with Continuous Batching pipeline, so it is not - # recommented to use it in conjuction with NPU. + # CPU, GPU and NPU can be used. For NPU, the preferred configuration is when both the main and draft models use NPU. main_device = 'CPU' draft_device = 'CPU' From 3862c96dde75f77b53bc91529da0d925cfc5c378 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Fri, 10 Oct 2025 16:04:45 +0100 Subject: [PATCH 38/40] Last polishing --- .../speculative_decoding_stateful.cpp | 6 +++--- src/cpp/src/utils.cpp | 2 +- tests/python_tests/test_stateful_speculative_decoding.py | 9 ++++----- tools/llm_bench/llm_bench_utils/model_utils.py | 5 +++++ 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp index 2135ddbf52..b11161c9e6 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp @@ -524,7 +524,7 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( config.validate(); OPENVINO_ASSERT(config.is_greedy_decoding(), - "Currently only greedy decoding are supported"); + "Currently only greedy decoding is supported"); OPENVINO_ASSERT(config.num_return_sequences == 1u, "Currently only \"num_return_sequences\" equal to 1 is supported!"); @@ -686,8 +686,8 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate( // single infer request: last token from previous main inference + all candidates // from the draft stage. // - // Note on model's return variable: If model isn't sliced to return logit only - // for the last element, then it returns logits for all elements of the input + // Note on model's return variable: If model isn't sliced to return only + // certain logits, then it returns logits for all elements of the input // prompt. In that tensor, for each token `t` of the input prompt it contains // distribution (over the vocabulary) for the next possible token that is // generated based on subsequence [first token,...,`t`] of the input prompt. 
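As a worked illustration of the logits comment above (a numpy sketch with toy shapes and hypothetical token ids, not the C++ code): when the main model returns logits for every input position, the greedy choice at position i is the token it expects after the prefix ending at i, which is what the candidate placed at position i + 1 is checked against.

    import numpy as np

    prev_token = 7       # last token already confirmed by the main model (hypothetical id)
    candidates = [1, 0]  # draft proposals (hypothetical ids)
    # Logits produced for the input [prev_token] + candidates: one row per input position.
    logits = np.array([[0.1, 2.0, 0.3],   # after [prev_token]        -> expects token 1
                       [1.5, 0.2, 0.1],   # after [prev_token, 1]     -> expects token 0
                       [0.0, 0.1, 3.0]],  # after [prev_token, 1, 0]  -> expects token 2
                      dtype=np.float32)   # shape: [len(candidates) + 1, vocab_size]

    predicted = logits.argmax(axis=-1)    # greedy next-token choice after every prefix

    accepted = 0
    for cand, ref in zip(candidates, predicted[:-1]):
        if cand != int(ref):
            break
        accepted += 1
    # predicted[accepted] is the main model's own output: the correction on a mismatch,
    # or the extra token when all candidates were accepted (here: 2 accepted, next token 2).
    next_token = int(predicted[accepted])
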
diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index 1e55b322a7..681ffc5c96 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -647,7 +647,7 @@ bool explicitly_requires_paged_attention(const ov::AnyMap& properties) { if (is_paged_attention_available()) { return true; } else { - OPENVINO_THROW("Speculative decoding on non-NPU devices requires PagedAttention operation support, which is available on x86_64 or ARM64 platforms only"); + OPENVINO_THROW("Speculative decoding requires PagedAttention operation support on non-NPU devices, which is available on x86_64 or ARM64 platforms only"); } } diff --git a/tests/python_tests/test_stateful_speculative_decoding.py b/tests/python_tests/test_stateful_speculative_decoding.py index 9c6e92a4f0..0d62f8018c 100644 --- a/tests/python_tests/test_stateful_speculative_decoding.py +++ b/tests/python_tests/test_stateful_speculative_decoding.py @@ -1,18 +1,17 @@ -# Copyright (C) 2023-2025 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import pytest -import numpy as np -import logging +import numpy as np import openvino as ov import openvino_genai as ov_genai from utils.constants import get_default_llm_properties -from utils.hugging_face import generation_config_to_hf, download_and_convert_model, run_hugging_face +from utils.hugging_face import download_and_convert_model, run_hugging_face from utils.comparation import compare_generation_results -from utils.ov_genai_pipelines import create_ov_pipeline, generate_and_compare, get_main_pipeline_types, PipelineType, convert_decoded_results_to_generation_result +from utils.ov_genai_pipelines import convert_decoded_results_to_generation_result def get_npu_llm_properties_for_test(): config = get_default_llm_properties() diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py index 8fdaad05d0..af1b9e344c 100644 --- a/tools/llm_bench/llm_bench_utils/model_utils.py +++ b/tools/llm_bench/llm_bench_utils/model_utils.py @@ -209,6 +209,11 @@ def analyze_args(args): if args.cb_config: cb_config = get_config(args.cb_config) model_args["cb_config"] = cb_config + if args.draft_model: + if args.device != "NPU" and args.draft_model.device != "NPU": + if model_args['config']['ATTENTION_BACKEND'] != PA_ATTENTION_BACKEND: + log.warning("Speculative Decoding is supported only with Paged Attention Backend for non-NPU devices") + args.draft_model = None model_args['draft_model'] = args.draft_model model_args['draft_device'] = args.draft_device draft_cb_config = None From 63d38e8fe840cc790d8feb22f6ac92a78a61b765 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Mon, 13 Oct 2025 04:56:01 +0100 Subject: [PATCH 39/40] Fixed llm_bench --- tools/llm_bench/llm_bench_utils/model_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py index af1b9e344c..715700ffc3 100644 --- a/tools/llm_bench/llm_bench_utils/model_utils.py +++ b/tools/llm_bench/llm_bench_utils/model_utils.py @@ -210,10 +210,9 @@ def analyze_args(args): cb_config = get_config(args.cb_config) model_args["cb_config"] = cb_config if args.draft_model: - if args.device != "NPU" and args.draft_model.device != "NPU": - if model_args['config']['ATTENTION_BACKEND'] != PA_ATTENTION_BACKEND: - log.warning("Speculative Decoding is supported only with Paged Attention Backend for non-NPU devices") - args.draft_model = None + if 
(args.draft_device != "NPU" and args.device != "NPU" and model_args['config']['ATTENTION_BACKEND'] != PA_ATTENTION_BACKEND): + log.warning("Speculative Decoding is supported only with Paged Attention Backend for non-NPU devices") + args.draft_model = None model_args['draft_model'] = args.draft_model model_args['draft_device'] = args.draft_device draft_cb_config = None From 7f6b318e510026e7b07fce31092d4aac5a720468 Mon Sep 17 00:00:00 2001 From: Anastasiya Pronina Date: Mon, 13 Oct 2025 05:53:16 +0100 Subject: [PATCH 40/40] Fixed 'explicitly_requires_paged_attention()' check --- src/cpp/src/llm/pipeline.cpp | 16 ++++++++++------ src/cpp/src/utils.cpp | 9 +++++---- src/cpp/src/utils.hpp | 4 ++-- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index d294d219e2..fd1eb4763d 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -123,9 +123,11 @@ ov::genai::LLMPipeline::LLMPipeline( const ov::AnyMap& user_properties) : m_device(device) { auto start_time = std::chrono::steady_clock::now(); - auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); - if (ov::genai::utils::is_npu_requested(device, properties)) { + bool is_npu_requested = ov::genai::utils::is_npu_requested(device, user_properties); + auto [properties, attention_backend] = utils::extract_attention_backend(user_properties, is_npu_requested); + + if (is_npu_requested) { m_pimpl = StatefulPipeline::create(models_path, tokenizer, device, properties); } else if (utils::explicitly_requires_paged_attention(user_properties)) { // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues @@ -160,9 +162,10 @@ ov::genai::LLMPipeline::LLMPipeline( m_device(device) { auto start_time = std::chrono::steady_clock::now(); - auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); + bool is_npu_requested = ov::genai::utils::is_npu_requested(device, user_properties); + auto [properties, attention_backend] = utils::extract_attention_backend(user_properties, is_npu_requested); - if (ov::genai::utils::is_npu_requested(device, properties)) { + if (is_npu_requested) { m_pimpl = StatefulPipeline::create(models_path, device, properties); } else if (utils::explicitly_requires_paged_attention(user_properties)) { // If CB is invoked explicitly, create CB adapter as is and re-throw in case if internal issues @@ -200,9 +203,10 @@ ov::genai::LLMPipeline::LLMPipeline( m_device(device) { auto start_time = std::chrono::steady_clock::now(); - auto [properties, attention_backend] = utils::extract_attention_backend(user_properties); + bool is_npu_requested = ov::genai::utils::is_npu_requested(device, user_properties); + auto [properties, attention_backend] = utils::extract_attention_backend(user_properties, is_npu_requested); - if (ov::genai::utils::is_npu_requested(device, properties)) { + if (is_npu_requested) { m_pimpl = StatefulPipeline::create( utils::singleton_core().read_model(model_str, weights_tensor), tokenizer, diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index 681ffc5c96..c23131b4f1 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -631,7 +631,7 @@ SchedulerConfig get_latency_oriented_scheduler_config() { return default_config; } -bool explicitly_requires_paged_attention(const ov::AnyMap& properties) { +bool explicitly_requires_paged_attention(const ov::AnyMap& properties, bool is_npu_requested) { auto attention_backend_it = 
properties.find("ATTENTION_BACKEND"); if (properties.find(ov::genai::scheduler_config.name()) != properties.end() || @@ -643,7 +643,7 @@ bool explicitly_requires_paged_attention(const ov::AnyMap& properties) { } } - if (properties.find(utils::DRAFT_MODEL_ARG_NAME) != properties.end()) { + if (properties.find(utils::DRAFT_MODEL_ARG_NAME) != properties.end() && !is_npu_requested) { if (is_paged_attention_available()) { return true; } else { @@ -662,7 +662,8 @@ bool explicitly_requires_paged_attention(const ov::AnyMap& properties) { return false; } -std::pair extract_attention_backend(const ov::AnyMap& external_properties) { +std::pair extract_attention_backend(const ov::AnyMap& external_properties, + bool is_npu_requested) { std::string attention_backend = PA_BACKEND; ov::AnyMap properties = external_properties; @@ -674,7 +675,7 @@ std::pair extract_attention_backend(const ov::AnyMap& e properties.erase(it); } - if (explicitly_requires_paged_attention(external_properties)) { + if (explicitly_requires_paged_attention(external_properties, is_npu_requested)) { OPENVINO_ASSERT(attention_backend == PA_BACKEND, "User properties are conflicting: some of them requires PagedAttention backend, while 'ATTENTION_BACKEND' is set to 'SDPA'"); } diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index 471de1192d..9bcfd1361f 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -278,9 +278,9 @@ std::pair extract_scheduler_config(const ov::AnyMap SchedulerConfig get_latency_oriented_scheduler_config(); -bool explicitly_requires_paged_attention(const ov::AnyMap& properties); +bool explicitly_requires_paged_attention(const ov::AnyMap& properties, bool is_npu_requested = false); -std::pair extract_attention_backend(const ov::AnyMap& external_properties); +std::pair extract_attention_backend(const ov::AnyMap& external_properties, bool is_npu_requested = false); void save_openvino_model(const std::shared_ptr& model, const std::string& save_path, bool compress_to_fp16);
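For completeness, a hypothetical end-to-end usage sketch of the NPU speculative-decoding path this series enables, following the pattern of the updated samples; the model directories are placeholders rather than paths shipped with the repository, and only greedy decoding is used, matching the assert in the stateful speculative pipeline.

// Hypothetical usage sketch: main and draft models both on NPU (preferred configuration).
#include <iostream>
#include <string>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;  // greedy decoding only on this path

    // With both devices set to NPU, dispatch goes to the stateful speculative pipeline.
    ov::genai::LLMPipeline pipe(
        "TinyLlama-1.1B-Chat-v1.0",                          // main model dir (placeholder)
        "NPU",
        ov::genai::draft_model("TinyLlama-draft", "NPU"));   // draft model dir (placeholder)

    std::string result = pipe.generate("Why is the sky blue?", config);
    std::cout << result << std::endl;
    return 0;
}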