diff --git a/src/plugins/intel_npu/src/al/include/intel_npu/config/npuw.hpp b/src/plugins/intel_npu/src/al/include/intel_npu/config/npuw.hpp index 3d48f0dc77a3bc..6db5ed1f4b0b0c 100644 --- a/src/plugins/intel_npu/src/al/include/intel_npu/config/npuw.hpp +++ b/src/plugins/intel_npu/src/al/include/intel_npu/config/npuw.hpp @@ -107,6 +107,7 @@ DEFINE_OPT(NPUW_LLM, bool, false, npuw::llm::enabled, RunTime); DEFINE_OPT(NPUW_LLM_BATCH_DIM, uint32_t, 0, npuw::llm::batch_dim, RunTime); DEFINE_OPT(NPUW_LLM_SEQ_LEN_DIM, uint32_t, 2, npuw::llm::seq_len_dim, RunTime); DEFINE_OPT(NPUW_LLM_MAX_PROMPT_LEN, uint32_t, 1024, npuw::llm::max_prompt_len, RunTime); +DEFINE_OPT(NPUW_LLM_MAX_GENERATION_TOKEN_LEN, uint32_t, 1, npuw::llm::max_generation_token_len, RunTime); DEFINE_OPT(NPUW_LLM_MIN_RESPONSE_LEN, uint32_t, 128, npuw::llm::min_response_len, RunTime); DEFINE_OPT(NPUW_LLM_OPTIMIZE_V_TENSORS, bool, true, npuw::llm::optimize_v_tensors, RunTime); DEFINE_OPT(NPUW_LLM_CACHE_ROPE, bool, true, npuw::llm::cache_rope, CompileTime); diff --git a/src/plugins/intel_npu/src/al/include/intel_npu/npuw_private_properties.hpp b/src/plugins/intel_npu/src/al/include/intel_npu/npuw_private_properties.hpp index 245305c86468a6..2a487f50b5ae64 100644 --- a/src/plugins/intel_npu/src/al/include/intel_npu/npuw_private_properties.hpp +++ b/src/plugins/intel_npu/src/al/include/intel_npu/npuw_private_properties.hpp @@ -423,6 +423,14 @@ static constexpr ov::Property<uint32_t> seq_len_dim{"NPUW_LLM_SEQ_LEN_DIM"}; */ static constexpr ov::Property<uint32_t> max_prompt_len{"NPUW_LLM_MAX_PROMPT_LEN"}; +/** + * @brief + * Type: uint32_t. + * Desired maximum number of input tokens the generation model accepts per inference (e.g. the number of draft candidates validated at once in speculative decoding). + * Default value: 1. + */ +static constexpr ov::Property<uint32_t> max_generation_token_len{"NPUW_LLM_MAX_GENERATION_TOKEN_LEN"}; + /** * @brief * Type: uint32_t. 
diff --git a/src/plugins/intel_npu/src/al/src/config/npuw.cpp b/src/plugins/intel_npu/src/al/src/config/npuw.cpp index 359aff75225a9f..66f38c903f225f 100644 --- a/src/plugins/intel_npu/src/al/src/config/npuw.cpp +++ b/src/plugins/intel_npu/src/al/src/config/npuw.cpp @@ -66,6 +66,7 @@ void intel_npu::registerNPUWLLMOptions(OptionsDesc& desc) { desc.add(); desc.add(); desc.add(); + desc.add(); desc.add(); desc.add(); desc.add(); diff --git a/src/plugins/intel_npu/src/plugin/include/properties.hpp b/src/plugins/intel_npu/src/plugin/include/properties.hpp index d6f0b5f04fa9c3..4f27e8552e3df9 100644 --- a/src/plugins/intel_npu/src/plugin/include/properties.hpp +++ b/src/plugins/intel_npu/src/plugin/include/properties.hpp @@ -111,6 +111,7 @@ class Properties final { ov::intel_npu::npuw::llm::batch_dim.name(), ov::intel_npu::npuw::llm::seq_len_dim.name(), ov::intel_npu::npuw::llm::max_prompt_len.name(), + ov::intel_npu::npuw::llm::max_generation_token_len.name(), ov::intel_npu::npuw::llm::min_response_len.name(), ov::intel_npu::npuw::llm::optimize_v_tensors.name(), ov::intel_npu::npuw::llm::prefill_hint.name(), diff --git a/src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp b/src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp index 28f3302505f263..cdf886efbbced0 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp @@ -499,10 +499,13 @@ void reshape_to_static(std::shared_ptr model, model->reshape(new_shapes); } -void reshape_sliced_head_to_static(std::shared_ptr lm_head_model, const uint32_t& batch_dim) { - // We have only one input with dynamic shapes: output of Slice operation, and this output - // should have "1" for dimension representing number of embeddings to send to the matmul. - // Batch size should be also equal "1" for NPU. +void reshape_sliced_head_to_static(std::shared_ptr lm_head_model, + const uint32_t& batch_dim, + std::size_t max_generation_token_len) { + // We have only one input with dynamic shapes: output embeds. + // Output embeds should have "max_generation_token_len" for dimension representing + // number of embeddings to send to the matmul. Batch size should be equal to "1" + // for NPU. const auto& input = lm_head_model->input(0); const auto& partial_shape = input.get_partial_shape(); NPUW_ASSERT(partial_shape.size() == 3); @@ -512,7 +515,7 @@ void reshape_sliced_head_to_static(std::shared_ptr lm_head_model, con // Left dynamic axis will be for number of embeddings for (auto i = 0; i < new_shape.rank().get_length(); i++) { if (new_shape[i].is_dynamic()) { - new_shape[i] = 1; + new_shape[i] = max_generation_token_len; // Sanity check that only one left dimension is dynamic, as // another one should contain embedding space rank break; @@ -522,7 +525,9 @@ void reshape_sliced_head_to_static(std::shared_ptr lm_head_model, con lm_head_model->reshape(new_shape); } -void slice_out_embeds(std::shared_ptr model, const uint32_t& batch_dim) { +void slice_out_embeds(std::shared_ptr model, + const uint32_t& batch_dim, + std::size_t max_generation_token_len) { std::shared_ptr embed_result; for (auto&& output : model->outputs()) { if (output.get_any_name() == ov::npuw::LLMCompiledModel::output_embeds) { @@ -533,15 +538,16 @@ void slice_out_embeds(std::shared_ptr model, const uint32_t& batch_di if (embed_result) { auto shape = embed_result->input(0).get_shape(); // If shape.size() is 3, then last axis should be the Vocab size. - // But 1st and 2nd axis can mean different things. 
+ // But 1st and 2nd axes can mean different things. // 1st axis can represent the batch size, while 2nd - the number of embeddings, // or vice-versa (in chatglm) if (shape.size() == 3) { uint32_t num_embeds_dim = 1 - batch_dim; - if (shape[num_embeds_dim] > 1) { - std::vector start_pos{static_cast(batch_dim * (shape[num_embeds_dim] - 1)), - static_cast(num_embeds_dim * (shape[num_embeds_dim] - 1)), - 0}; + if (shape[num_embeds_dim] > max_generation_token_len) { + std::vector start_pos{ + static_cast(batch_dim * (shape[num_embeds_dim] - max_generation_token_len)), + static_cast(num_embeds_dim * (shape[num_embeds_dim] - max_generation_token_len)), + 0}; std::vector stop_pos{static_cast(batch_dim * (shape[num_embeds_dim] - 1)) + 1, static_cast(num_embeds_dim * (shape[num_embeds_dim] - 1)) + 1, static_cast(shape[2])}; @@ -673,6 +679,9 @@ ov::AnyMap get_default_generate_config(const std::optional& npudesc, if (hint == ::intel_npu::npuw::llm::GenerateHint::FAST_COMPILE) { config.emplace("NPUW_UNFOLD_IREQS", "YES"); } + // We don't need slice out for the kv-cache model, especially for speculative decoding, which needs + // to generate more than one token per inference + config.erase("NPUW_SLICE_OUT"); return config; } @@ -849,6 +858,10 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr& m KVAxesPosition axes{batch_dim, seq_len_dim}; uint32_t max_prompt_len = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MAX_PROMPT_LEN>(), 64u); const uint32_t min_response_len = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MIN_RESPONSE_LEN>(), 64u); + uint32_t max_generation_token_len = m_cfg.get<::intel_npu::NPUW_LLM_MAX_GENERATION_TOKEN_LEN>(); + if (max_generation_token_len != 1) { + max_generation_token_len = align_to(max_generation_token_len, 8u); + } // If chunk size covers the entire prompt, just follow the static behavior. // Otherwise, use chunking and align the prompt size to the chunk size. 
@@ -872,7 +885,9 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr& m LOG_VERB("Prefill chunk size: " << m_prefill_chunk_size); LOG_VERB("Maximum prompt length: " << max_prompt_len); - m_kvcache_desc = KVCacheDesc{max_prompt_len, max_prompt_len + min_response_len, 0u, seq_len_dim}; + m_kvcache_desc = + KVCacheDesc{max_prompt_len, max_prompt_len + min_response_len, 0u, seq_len_dim, max_generation_token_len}; + LOG_DEBUG("Make prefill model with static shapes"); m_max_lora_rank = m_cfg.get<::intel_npu::NPUW_LLM_MAX_LORA_RANK>(); if (m_use_chunk_prefill) { @@ -889,14 +904,18 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr& m m_max_lora_rank); } LOG_DEBUG("Make kvcache model with static shapes"); - reshape_to_static(kvcache_model, 1u, m_kvcache_desc.total_size, axes, m_max_lora_rank); + reshape_to_static(kvcache_model, + m_kvcache_desc.max_generation_token_len, + m_kvcache_desc.total_size, + axes, + m_max_lora_rank); if (lm_head_model) { LOG_DEBUG("Shared LM head: slice the prefill output"); - // KVCache model is already reshaped to [1, 1, embed size], so only apply slice to - // the Prefill model: - slice_out_embeds(prefill_model, axes.batch); + // KVCache model is already reshaped to [1, max_generation_token_len, embed size], + // so only apply slice to the Prefill model: + slice_out_embeds(prefill_model, axes.batch, m_kvcache_desc.max_generation_token_len); LOG_DEBUG("Make LM head model with static shapes"); - reshape_sliced_head_to_static(lm_head_model, axes.batch); + reshape_sliced_head_to_static(lm_head_model, axes.batch, m_kvcache_desc.max_generation_token_len); } LOG_DEBUG("5.1, decompose GroupQueryAttention OP"); @@ -1089,6 +1108,7 @@ void ov::npuw::LLMCompiledModel::serialize(std::ostream& stream, const ov::npuw: write(model_stream, m_kvcache_desc.total_size); write(model_stream, m_kvcache_desc.num_stored_tokens); write(model_stream, m_kvcache_desc.dim); + write(model_stream, m_kvcache_desc.max_generation_token_len); write(model_stream, m_kvcache_desc.v_tensors_transposed); write(model_stream, m_prefill_chunk_size); write(model_stream, m_use_chunk_prefill); @@ -1297,6 +1317,7 @@ std::shared_ptr ov::npuw::LLMCompiledModel::deserial read(model_stream, compiled->m_kvcache_desc.total_size); read(model_stream, compiled->m_kvcache_desc.num_stored_tokens); read(model_stream, compiled->m_kvcache_desc.dim); + read(model_stream, compiled->m_kvcache_desc.max_generation_token_len); read(model_stream, compiled->m_kvcache_desc.v_tensors_transposed); read(model_stream, compiled->m_prefill_chunk_size); read(model_stream, compiled->m_use_chunk_prefill); diff --git a/src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.hpp b/src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.hpp index fbeedcbe809969..36445f2858c246 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.hpp +++ b/src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.hpp @@ -24,6 +24,7 @@ class LLMCompiledModel : public ov::npuw::ICompiledModel { uint32_t total_size = 0u; uint32_t num_stored_tokens = 0u; uint32_t dim = 0u; + uint32_t max_generation_token_len = 0u; bool v_tensors_transposed = false; }; diff --git a/src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp b/src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp index abc5ed7f248579..8861c9d419a918 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp @@ -178,12 +178,9 @@ void pad_position_ids(const ov::SoPtr& 
padded_position_ids, const o OPENVINO_ASSERT(position_shape.size() <= 3); - size_t diff_dim = 0; - for (size_t i = 0; i < padded_shape.size(); ++i) { - if (padded_shape[i] != position_shape[i]) { - diff_dim = i; - break; - } + size_t diff_dim = position_shape.size() - 1; + for (size_t i = 0; i < diff_dim; ++i) { + OPENVINO_ASSERT(padded_shape[i] == position_shape[i]); } size_t keep_elements = padded_shape[diff_dim] - position_shape[diff_dim]; @@ -584,11 +581,34 @@ void ov::npuw::LLMInferRequest::update_kvcache_for( kvcache_desc.num_stored_tokens - num_tokens, kvcache_desc.num_stored_tokens); auto src_tensor = request->get_tensor(out_ports.at(output_name)); - copy_tensor_by_dim(src_tensor, dst_slice, kv_dim); + + // NOTE: Sometimes the present kv layer can contain a greater seq_len + // than the number of tokens sent for processing + uint32_t src_seq_len = static_cast(src_tensor->get_shape()[kv_dim]); + OPENVINO_ASSERT(num_tokens <= src_seq_len); + if (src_seq_len > num_tokens) { + auto src_slice = make_tensor_slice(src_tensor, kv_dim, src_seq_len - num_tokens, src_seq_len); + copy_tensor_by_dim(src_slice, dst_slice, kv_dim); + } else { + copy_tensor_by_dim(src_tensor, dst_slice, kv_dim); + } } LOG_DEBUG("Done."); } +void ov::npuw::LLMInferRequest::trim_kvcache_for_speculative_decoding(ov::SoPtr position_ids) { + auto& kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc; + // FIXME: It cannot work with OmniThinker for now. + OPENVINO_ASSERT((position_ids->get_shape().size() >= 2) && (position_ids->get_shape().back() >= 1)); + auto position_id = position_ids->data()[0]; + auto dirty_num = kvcache_desc.num_stored_tokens - static_cast(position_id); + if (dirty_num > 0) { + LOG_DEBUG("Trim kv-cache from " << kvcache_desc.num_stored_tokens << " tokens" + << " to " << position_id << " tokens"); + } + kvcache_desc.num_stored_tokens -= dirty_num; +} + void ov::npuw::LLMInferRequest::clear_chunk_prefill_kv_cache() { const auto& prefill_compiled = m_prefill_request->get_compiled_model(); @@ -767,50 +787,68 @@ void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr input_ids, ov::SoPtr position_ids) { LOG_DEBUG("Calling inference for generate model..."); LOG_BLOCK(); + auto& kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc; + uint32_t input_tokens_len = static_cast(input_ids->get_shape()[INPUT_IDS_SEQ_LEN_DIM]); + if (input_tokens_len > kvcache_desc.max_generation_token_len) { + OPENVINO_THROW("Input prompt length is greater than \"NPUW_LLM_MAX_GENERATION_TOKEN_LEN\": ", + kvcache_desc.max_generation_token_len, + ".\nPlease adjust it."); + } if (!m_generate_initialized) { LOG_DEBUG("Copy kv-cache from prefill to generate model."); copy_kvcache(); - LOG_DEBUG("Prepare attention mask pattern."); - auto kv_attn_mask = m_kvcache_request->get_tensor(m_kvcache_in_ports.at(layer_names::attention_mask)); - fill_tensor(kv_attn_mask, 0); - // NOTE: Attention mask pattern for generate model requires last "1" to be in the end of the mask. - // We can safely set this "1" once and then copy on one "1" less in the infer_generate(). 
- kv_attn_mask->data()[m_npuw_llm_compiled_model->m_kvcache_desc.total_size - 1] = 1; - + LOG_DEBUG("Prepare inputs."); + fill_tensor_bytes(m_kvcache_request->get_tensor(m_kvcache_in_ports.at(m_input_ids_name)), 0u); + fill_tensor(m_kvcache_request->get_tensor(m_kvcache_in_ports.at(layer_names::attention_mask)), 0); + fill_tensor(m_kvcache_request->get_tensor(m_kvcache_in_ports.at(layer_names::position_ids)), 0); m_generate_initialized = true; } - auto& kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc; // NB: KV-cache is full, further generation is impossible - if (kvcache_desc.num_stored_tokens == kvcache_desc.total_size) { + if (kvcache_desc.num_stored_tokens + input_tokens_len > kvcache_desc.total_size) { OPENVINO_THROW("KV-Cache is full."); } // FIXME: these tensors should be shared between the parent & child models - auto kv_input_ids = m_kvcache_request->get_tensor(m_kvcache_in_ports.at(m_input_ids_name)); // NB: input_ids can be either fp32(VLM) or i64(LLM) - std::copy_n(reinterpret_cast(input_ids->data()), - input_ids->get_byte_size(), - reinterpret_cast(kv_input_ids->data())); + auto kv_input_ids = m_kvcache_request->get_tensor(m_kvcache_in_ports.at(m_input_ids_name)); + // NOTE: As `input_tokens_len` can be less than `max_generation_token_len`, which the + // input layers of the generation model are resized to, we need to put the + // `input_tokens_len` tokens at the right end of the `max_generation_token_len`-sized tensors. + // The attention mask must rule out all of the unused space on the left. + std::copy_n( + reinterpret_cast(input_ids->data()), + input_ids->get_byte_size(), + reinterpret_cast(kv_input_ids->data()) + kv_input_ids->get_byte_size() - input_ids->get_byte_size()); - // NOTE: Attention mask pattern for generate model requires last "1" to be in the end of the mask. - // As it is already set above, here we copy on one "1" unit less. + // NOTE: The attention mask pattern for the generate model requires a run of "1"s + // of the current input length on the right (for the present + // kv layers) and a run of "1"s of the number of previously computed + // tokens on the left (for the past kv layers). 
auto kv_attn_mask = m_kvcache_request->get_tensor(m_kvcache_in_ports.at(layer_names::attention_mask)); - std::copy_n(attention_mask->data(), attention_mask->get_size() - 1, kv_attn_mask->data()); + std::copy_n(attention_mask->data(), + attention_mask->get_size() - input_tokens_len, + kv_attn_mask->data()); + if (input_tokens_len < kvcache_desc.max_generation_token_len) { + std::fill_n(kv_attn_mask->data() + kv_attn_mask->get_size() - kvcache_desc.max_generation_token_len, + kvcache_desc.max_generation_token_len - input_tokens_len, + 0); + } + std::fill_n(kv_attn_mask->data() + kv_attn_mask->get_size() - input_tokens_len, input_tokens_len, 1); auto kv_pos_ids = m_kvcache_request->get_tensor(m_kvcache_in_ports.at(layer_names::position_ids)); - std::copy_n(position_ids->data(), position_ids->get_size(), kv_pos_ids->data()); + pad_position_ids(kv_pos_ids, position_ids); m_kvcache_request->infer(); - kvcache_desc.num_stored_tokens += 1; + kvcache_desc.num_stored_tokens += input_tokens_len; if (m_lm_head_request) { LOG_DEBUG("Calling inference for LM head model asynchronously"); m_lm_head_request->start_async(); if (kvcache_desc.num_stored_tokens < kvcache_desc.total_size) { - update_kvcache_for(m_kvcache_request, m_kvcache_in_ports, m_kvcache_out_ports, 1); + update_kvcache_for(m_kvcache_request, m_kvcache_in_ports, m_kvcache_out_ports, input_tokens_len); } m_lm_head_request->wait(); LOG_DEBUG("Calling inference for LM head model -- done."); @@ -818,7 +856,7 @@ void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr input_ids, m_logits = m_lm_head_request->get_tensor(m_lm_head_logits_port); } else { if (kvcache_desc.num_stored_tokens < kvcache_desc.total_size) { - update_kvcache_for(m_kvcache_request, m_kvcache_in_ports, m_kvcache_out_ports, 1); + update_kvcache_for(m_kvcache_request, m_kvcache_in_ports, m_kvcache_out_ports, input_tokens_len); } m_logits = m_kvcache_request->get_tensor(m_kvcache_out_ports.at(layer_names::logits)); @@ -842,10 +880,31 @@ void ov::npuw::LLMInferRequest::infer() { OPENVINO_ASSERT(ov::element::i64 == position_ids->get_element_type()); // NB: Check the sequence length provided for input_ids - // in order to distinguish prefill / generate stages - if (input_ids->get_shape()[INPUT_IDS_SEQ_LEN_DIM] != 1) { + // and the start position index in order to distinguish the prefill + // and generate stages. + // Notes for Speculative Decoding: + // 1. If the model is the draft one in a speculative decoding setup, + // we expect it to be launched on more than one token only once, + // while all subsequent candidates are generated consecutively + // from the previous token's output. + // 2. If the model is the main one in a speculative decoding setup, + // it can be launched on multiple tokens at every iteration. + // The first iteration takes an input prompt of variable + // length in the range [0, NPUW_LLM_MAX_PROMPT_LEN], while the others + // are launched on a variable number of candidates in the range + // [0, NPUW_LLM_MAX_GENERATION_TOKEN_LEN]. + // NPUW_LLM_MAX_GENERATION_TOKEN_LEN is much smaller than + // NPUW_LLM_MAX_PROMPT_LEN. So, for the second and later iterations + // the generate model is used, which is reshaped to take + // NPUW_LLM_MAX_GENERATION_TOKEN_LEN tokens and output the same + // number of logits. + // The outcome of these two points is that the prefill and generate stages + // can be safely distinguished by the start position id for + // both the main and draft models. 
+ if (input_ids->get_shape()[INPUT_IDS_SEQ_LEN_DIM] > 1 && position_ids->data()[0] == 0) { infer_prefill(input_ids, attention_mask, position_ids); } else { + trim_kvcache_for_speculative_decoding(position_ids); infer_generate(input_ids, attention_mask, position_ids); } } diff --git a/src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.hpp b/src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.hpp index 7197ef4f19fe54..86ccfef6e41700 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.hpp +++ b/src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.hpp @@ -52,6 +52,7 @@ class LLMInferRequest final : public ov::ISyncInferRequest { std::unordered_map> in_ports, std::unordered_map> out_ports, uint32_t tokens); + void trim_kvcache_for_speculative_decoding(ov::SoPtr position_ids); void infer_chunked_prefill(ov::SoPtr input_ids, ov::SoPtr attention_mask, diff --git a/src/plugins/intel_npu/src/plugin/npuw/serialization.hpp b/src/plugins/intel_npu/src/plugin/npuw/serialization.hpp index 613ad079d807bc..77e90c13f06c25 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/serialization.hpp +++ b/src/plugins/intel_npu/src/plugin/npuw/serialization.hpp @@ -34,7 +34,7 @@ const constexpr ov::npuw::s11n::IndicatorType NPUW_COMPILED_MODEL_INDICATOR = const constexpr ov::npuw::s11n::IndicatorType NPUW_LLM_COMPILED_MODEL_INDICATOR = {char{0x4c}, char{0x4c}, char{0x4d}, char{0x43}, char{0x4d}, char{0x4f}}; -const constexpr char* NPUW_SERIALIZATION_VERSION = "0.8"; +const constexpr char* NPUW_SERIALIZATION_VERSION = "0.9"; // Forward declaration namespace intel_npu { diff --git a/src/plugins/intel_npu/src/plugin/src/plugin.cpp b/src/plugins/intel_npu/src/plugin/src/plugin.cpp index 4a6206c3731b4e..5f4d82cda9436d 100644 --- a/src/plugins/intel_npu/src/plugin/src/plugin.cpp +++ b/src/plugins/intel_npu/src/plugin/src/plugin.cpp @@ -329,6 +329,7 @@ void Plugin::init_options() { REGISTER_OPTION(NPUW_LLM_BATCH_DIM); REGISTER_OPTION(NPUW_LLM_SEQ_LEN_DIM); REGISTER_OPTION(NPUW_LLM_MAX_PROMPT_LEN); + REGISTER_OPTION(NPUW_LLM_MAX_GENERATION_TOKEN_LEN); REGISTER_OPTION(NPUW_LLM_MIN_RESPONSE_LEN); REGISTER_OPTION(NPUW_LLM_OPTIMIZE_V_TENSORS); REGISTER_OPTION(NPUW_LLM_CACHE_ROPE); diff --git a/src/plugins/intel_npu/src/plugin/src/properties.cpp b/src/plugins/intel_npu/src/plugin/src/properties.cpp index 507048ac090dcf..73be75a56caec2 100644 --- a/src/plugins/intel_npu/src/plugin/src/properties.cpp +++ b/src/plugins/intel_npu/src/plugin/src/properties.cpp @@ -446,6 +446,7 @@ void Properties::registerPluginProperties() { TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::llm::batch_dim, NPUW_LLM_BATCH_DIM); TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::llm::seq_len_dim, NPUW_LLM_SEQ_LEN_DIM); TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::llm::max_prompt_len, NPUW_LLM_MAX_PROMPT_LEN); + TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::llm::max_generation_token_len, NPUW_LLM_MAX_GENERATION_TOKEN_LEN); TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::llm::min_response_len, NPUW_LLM_MIN_RESPONSE_LEN); TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::llm::optimize_v_tensors, NPUW_LLM_OPTIMIZE_V_TENSORS); TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::llm::prefill_hint, NPUW_LLM_PREFILL_HINT);
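For readers of this patch: the right-aligned generate-stage input layout and the kv-cache trim described in the infer_generate() and trim_kvcache_for_speculative_decoding() comments above can be sketched with the small, self-contained C++ program below. It is a simplified illustration only, not part of the patch: the sizes, token values, and position id are hypothetical, and plain std:: containers stand in for the real ov::Tensor inputs.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Hypothetical sizes: a 16-token kv-cache window and a generate model
    // reshaped to accept NPUW_LLM_MAX_GENERATION_TOKEN_LEN = 8 input tokens.
    const std::size_t total_size = 16;
    const std::size_t max_generation_token_len = 8;
    std::size_t num_stored_tokens = 5;  // tokens already accumulated in the kv-cache

    // A main model in speculative decoding receives several draft candidates at once.
    const std::vector<int64_t> candidates{101, 102, 103};
    const std::size_t input_tokens_len = candidates.size();

    // input_ids are placed at the right end of the max_generation_token_len-sized tensor.
    std::vector<int64_t> kv_input_ids(max_generation_token_len, 0);
    std::copy_n(candidates.begin(), input_tokens_len, kv_input_ids.end() - input_tokens_len);

    // Attention mask: "1"s on the left for the previously computed (past) tokens,
    // "1"s on the right for the current input tokens, zeros in between.
    std::vector<int64_t> attention_mask(total_size, 0);
    std::fill_n(attention_mask.begin(), num_stored_tokens, 1);
    std::fill_n(attention_mask.end() - input_tokens_len, input_tokens_len, 1);

    // If the first position id is behind num_stored_tokens, the tail of the
    // kv-cache holds rejected candidates and is trimmed before inference.
    const std::size_t first_position_id = 4;
    if (first_position_id < num_stored_tokens) {
        num_stored_tokens = first_position_id;
    }

    std::cout << "input_ids:      ";
    for (auto v : kv_input_ids) std::cout << v << ' ';
    std::cout << "\nattention_mask: ";
    for (auto v : attention_mask) std::cout << v << ' ';
    std::cout << "\nstored tokens after trim: " << num_stored_tokens << '\n';
}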