Commit e72008c
Restrict StatefulSpeculativeLLMPipeline to launch only if NPU is specified for one or both of the models
1 parent ec10cb7 commit e72008c

File tree

6 files changed: +64 -41 lines changed


samples/cpp/text_generation/speculative_decoding_lm.cpp

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ int main(int argc, char* argv[]) try {
     // User can run main and draft model on different devices.
     // Please, set device for main model in `LLMPipeline` constructor and in `ov::genai::draft_model` for draft.
     // CPU, GPU and NPU can be used. Please be aware that GPU is performant only with Continuous Batching pipeline, so it is not recommended
-    // to use it in conjunction with NPU or in configuration when main model doesn't work in Paged Attention mode.
+    // to use it in conjunction with NPU.
     std::string main_device = "CPU", draft_device = "CPU";

     ov::genai::LLMPipeline pipe(

samples/python/text_generation/speculative_decoding_lm.py

Lines changed: 12 additions & 2 deletions

@@ -21,13 +21,13 @@ def main():
     # User can run main and draft model on different devices.
     # Please, set device for main model in `openvino_genai.LLMPipeline` constructor and in `openvino_genai.draft_model` for draft.
     # CPU, GPU and NPU can be used. Please be aware that GPU is performant only with Continuous Batching pipeline, so it is not
-    # recommended to use it in conjunction with NPU or in configuration when main model doesn't work in Paged Attention mode.
+    # recommended to use it in conjunction with NPU.
     main_device = 'CPU'
     draft_device = 'CPU'

     draft_model = openvino_genai.draft_model(args.draft_model_dir, draft_device)

-    pipe = openvino_genai.LLMPipeline(args.model_dir, main_device, draft_model=draft_model)
+    pipe = openvino_genai.LLMPipeline(args.model_dir, "CPU", draft_model=draft_model)

     config = openvino_genai.GenerationConfig()
     config.max_new_tokens = 100
@@ -69,5 +69,15 @@ def main():
     print(f"  Total iteration number: {len(draft_model_metrics.raw_metrics.m_durations)}")
     print()

+    print(f"DRAFT MODEL")
+    print(f"  Generate time: {draft_model_metrics.get_generate_duration().mean:.2f} ms")
+    print(f"  TTFT: {draft_model_metrics.get_ttft().mean:.2f} ms")
+    print(f"  TTST: {draft_model_metrics.get_ttst().mean:.2f} ms/token")
+    print(f"  TPOT: {draft_model_metrics.get_tpot().mean:.2f} ± {draft_model_metrics.get_tpot().std:.2f} ms/token")
+    print(f"  AVG Latency: {draft_model_metrics.get_latency().mean:.2f} ± {draft_model_metrics.get_latency().std:.2f} ms/iteration")
+    print(f"  Num generated token: {draft_model_metrics.get_num_generated_tokens()} tokens")
+    print(f"  Total iteration number: {len(draft_model_metrics.raw_metrics.m_durations)}")
+    print()
+
 if '__main__' == __name__:
     main()
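
Editor's note: the added block mirrors the existing metrics report for the draft model. As a reference for the getters it relies on, here is a minimal, hedged sketch of reading the same PerfMetrics fields from a generation result; the model directory and prompt are placeholders, and the `perf_metrics` accessor is the one the tests below already use.

import openvino_genai

# Placeholder model directory and prompt; only the PerfMetrics getters come from the sample above.
pipe = openvino_genai.LLMPipeline("model_dir", "CPU")

config = openvino_genai.GenerationConfig()
config.max_new_tokens = 100

result = pipe.generate(["The Sun is yellow because"], config)
metrics = result.perf_metrics
print(f"Generate time: {metrics.get_generate_duration().mean:.2f} ms")
print(f"TTFT: {metrics.get_ttft().mean:.2f} ms")
print(f"TPOT: {metrics.get_tpot().mean:.2f} ± {metrics.get_tpot().std:.2f} ms/token")
print(f"Num generated tokens: {metrics.get_num_generated_tokens()}")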

src/cpp/src/llm/pipeline.cpp

Lines changed: 13 additions & 7 deletions

@@ -92,9 +92,10 @@ static std::unique_ptr<LLMPipelineImplBase> create(
     auto properties_without_draft_model = properties;
     auto draft_model_descr = ov::genai::utils::extract_draft_model_from_config(properties_without_draft_model);
     if (draft_model_descr.model != nullptr) {
-        OPENVINO_ASSERT(device != "GPU" && draft_model_descr.device != "GPU",
-                        "Speculative Decoding with \"ATTENTION_BACKEND\" : \"SDPA\" or any of the models on NPU "
-                        "doesn't support GPU device either for main or draft models currently!");
+        // FIXME: Add support for StatefulSpeculativeLLMPipeline for non-NPU devices for both models.
+        OPENVINO_ASSERT(device == "NPU" || draft_model_descr.device == "NPU",
+                        "Stateful Speculative Decoding is expected to be launched when NPU is requested as "
+                        "execution device for one or both models.");
         auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, {}, generation_config);
         return std::make_unique<StatefulSpeculativeLLMPipeline>(main_model_descr, draft_model_descr);
     }
@@ -144,7 +145,9 @@ ov::genai::LLMPipeline::LLMPipeline(
     }

     if (m_pimpl == nullptr) {
-        m_pimpl = StatefulPipeline::create(models_path, tokenizer, device, properties);
+        // FIXME: Switch to StatefulPipeline::create after resolving issues
+        // with GPU and CPU for StatefulSpeculativeLLMPipeline
+        m_pimpl = std::make_unique<StatefulLLMPipeline>(models_path, tokenizer, device, properties);
     }

     m_pimpl->save_load_time(start_time);
@@ -158,7 +161,6 @@ ov::genai::LLMPipeline::LLMPipeline(
     auto start_time = std::chrono::steady_clock::now();

     auto [properties, attention_backend] = utils::extract_attention_backend(user_properties);
-
     if (ov::genai::utils::is_npu_requested(device, properties)) {
         m_pimpl = StatefulPipeline::create(models_path, device, properties);
     } else if (utils::explicitly_requires_paged_attention(user_properties)) {
@@ -179,7 +181,9 @@ ov::genai::LLMPipeline::LLMPipeline(
     }

     if (m_pimpl == nullptr) {
-        m_pimpl = StatefulPipeline::create(models_path, device, properties);
+        // FIXME: Switch to StatefulPipeline::create after resolving issues
+        // with GPU and CPU for StatefulSpeculativeLLMPipeline
+        m_pimpl = std::make_unique<StatefulLLMPipeline>(models_path, device, properties);
     }

     m_pimpl->save_load_time(start_time);
@@ -224,7 +228,9 @@ ov::genai::LLMPipeline::LLMPipeline(
     }

     if (m_pimpl == nullptr) {
-        m_pimpl = StatefulPipeline::create(
+        // FIXME: Switch to StatefulPipeline::create after resolving issues
+        // with GPU and CPU for StatefulSpeculativeLLMPipeline
+        m_pimpl = std::make_unique<StatefulLLMPipeline>(
             utils::singleton_core().read_model(model_str, weights_tensor),
             tokenizer,
             device,
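
Editor's note: in user-facing terms, attaching a draft model and requesting NPU for at least one of the two models is now the configuration that reaches StatefulSpeculativeLLMPipeline, while a draft model with no NPU goes through the Paged-Attention-based path checked in utils.cpp below. A hedged Python sketch with placeholder model directories (not taken from the repository's samples):

import openvino_genai as ov_genai

# Placeholder directories for exported main/draft models.
main_dir, draft_dir = "main_model_dir", "draft_model_dir"

# NPU requested for the draft (or the main) model: handled by StatefulSpeculativeLLMPipeline,
# per the assert added above; the tests below exercise exactly these device combinations.
draft_on_npu = ov_genai.draft_model(draft_dir, "NPU")
pipe_npu = ov_genai.LLMPipeline(main_dir, "CPU", draft_model=draft_on_npu)

# No NPU anywhere: speculative decoding goes to the Paged-Attention-based pipeline,
# or throws on platforms where PagedAttention is unavailable.
draft_on_cpu = ov_genai.draft_model(draft_dir, "CPU")
pipe_cpu = ov_genai.LLMPipeline(main_dir, "CPU", draft_model=draft_on_cpu)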

src/cpp/src/utils.cpp

Lines changed: 8 additions & 0 deletions

@@ -643,6 +643,14 @@ bool explicitly_requires_paged_attention(const ov::AnyMap& properties) {
         }
     }

+    if (properties.find(utils::DRAFT_MODEL_ARG_NAME) != properties.end()) {
+        if (is_paged_attention_available()) {
+            return true;
+        } else {
+            OPENVINO_THROW("Speculative decoding on non-NPU devices requires PagedAttention operation support, which is available on x86_64 or ARM64 platforms only");
+        }
+    }
+
     auto prompt_lookup_prop = properties.find("prompt_lookup");
     if (prompt_lookup_prop != properties.end() && prompt_lookup_prop->second.as<bool>() == true) {
         if (is_paged_attention_available()) {
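
Editor's note: taken together with the assert in pipeline.cpp, the selection rules can be summarized in a simplified sketch (Python pseudocode for illustration only; the real dispatch lives in the C++ shown in these two files and consults more properties than listed here):

def pick_pipeline(main_device, draft_device, paged_attention_available):
    # draft_device is None when no draft model was supplied.
    if draft_device is not None:
        if main_device == "NPU" or draft_device == "NPU":
            # New restriction introduced by this commit.
            return "StatefulSpeculativeLLMPipeline"
        if paged_attention_available:
            return "Paged-Attention (Continuous Batching) speculative decoding"
        raise RuntimeError(
            "Speculative decoding on non-NPU devices requires PagedAttention operation "
            "support, which is available on x86_64 or ARM64 platforms only")
    return "non-speculative pipeline (stateful or continuous batching)"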

tests/python_tests/test_stateful_speculative_decoding.py

Lines changed: 27 additions & 19 deletions

@@ -14,13 +14,19 @@
 from utils.comparation import compare_generation_results
 from utils.ov_genai_pipelines import create_ov_pipeline, generate_and_compare, get_main_pipeline_types, PipelineType, convert_decoded_results_to_generation_result

+def get_npu_llm_properties_for_test():
+    config = get_default_llm_properties()
+    config["NPUW_DEVICES"] = "CPU"
+    config["GENERATE_HINT"] = "BEST_PERF"
+    return config
+
 models_and_input = [
     ("HuggingFaceTB/SmolLM2-360M", "HuggingFaceTB/SmolLM2-135M", "Alan Turing was a")]
 devices = [
-    ('CPU', 'CPU'),
-    ('CPU', 'NPUW:CPU'),
-    ('NPUW:CPU', 'CPU'),
-    ('NPUW:CPU', 'NPUW:CPU')
+    # FIXME: add 'CPU' and 'GPU' cases in future
+    ('CPU', 'NPU'),
+    ('NPU', 'CPU'),
+    ('NPU', 'NPU')
 ]
 @pytest.mark.parametrize("main_model,draft_model,prompt", models_and_input)
 @pytest.mark.parametrize("main_device,draft_device", devices)
@@ -31,19 +37,14 @@ def test_string_inputs(main_model, main_device, draft_model, draft_device, promp
     __, __, draft_model_path = download_and_convert_model(draft_model)

     # Create OpenVINO GenAI pipeline:
-    draft_config = get_default_llm_properties()
-    if draft_device == "NPUW:CPU":
-        draft_device = "NPU"
-        draft_config["NPUW_DEVICES"] = "CPU"
-        draft_config["GENERATE_HINT"] = "BEST_PERF"
+    draft_config = get_npu_llm_properties_for_test() \
+        if (draft_device == "NPU") else \
+        get_default_llm_properties()
     ov_draft_model = ov_genai.draft_model(draft_model_path, draft_device, **draft_config)

-    main_config = get_default_llm_properties()
-    if main_device == "NPUW:CPU":
-        main_device = "NPU"
-        main_config["NPUW_DEVICES"] = "CPU"
-        main_config["GENERATE_HINT"] = "BEST_PERF"
-        main_config["ATTENTION_BACKEND"] = "SDPA"
+    main_config = get_npu_llm_properties_for_test() \
+        if (main_device == "NPU") else \
+        get_default_llm_properties()
     ov_pipe = ov_genai.LLMPipeline(main_model_path, main_device, main_config, draft_model=ov_draft_model)

     # Run reference HF model:
@@ -65,10 +66,14 @@ def test_perf_metrics():
     import time
     start_time = time.perf_counter()
     model_id = 'katuni4ka/tiny-random-gemma2'
-    generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5)
     _, _, model_path = download_and_convert_model(model_id)
-    ov_pipe = create_ov_pipeline(model_path, pipeline_type=PipelineType.STATEFUL_SPECULATIVE_DECODING)
+
+    # Create OpenVINO GenAI pipeline:
+    ov_draft_model = ov_genai.draft_model(model_path, "NPU", **get_npu_llm_properties_for_test())
+    ov_pipe = ov_genai.LLMPipeline(model_path, "NPU", get_npu_llm_properties_for_test(), draft_model=ov_draft_model)
+
     prompt = 'table is made of'
+    generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5)
     perf_metrics = ov_pipe.generate([prompt], generation_config).perf_metrics
     total_time = (time.perf_counter() - start_time) * 1000
@@ -141,9 +146,12 @@ def test_extended_perf_metrics():
     import time
     start_time = time.perf_counter()
     model_id : str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-    generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5)
     _, _, model_path = download_and_convert_model(model_id)
-    ov_pipe = create_ov_pipeline(model_path, pipeline_type=PipelineType.STATEFUL_SPECULATIVE_DECODING)
+
+    ov_draft_model = ov_genai.draft_model(model_path, "NPU", **get_npu_llm_properties_for_test())
+    ov_pipe = ov_genai.LLMPipeline(model_path, "NPU", get_npu_llm_properties_for_test(), draft_model=ov_draft_model)
+
+    generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5)
     extended_perf_metrics = ov_pipe.generate(["Why is the Sun yellow?"], generation_config).extended_perf_metrics
     total_time = (time.perf_counter() - start_time) * 1000
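
Editor's note: the get_npu_llm_properties_for_test helper lets these tests target the NPU plugin on machines without NPU hardware by redirecting NPUW to CPU. Below is a minimal sketch of the same configuration outside the test harness, with a placeholder model directory and only the NPU-specific properties (the tests additionally merge in get_default_llm_properties()):

import openvino_genai as ov_genai

# CPU-backed NPUW, as the tests use for CI machines without an NPU.
npu_props = {"NPUW_DEVICES": "CPU", "GENERATE_HINT": "BEST_PERF"}

model_path = "tiny-random-gemma2"  # placeholder converted-model directory
ov_draft_model = ov_genai.draft_model(model_path, "NPU", **npu_props)
ov_pipe = ov_genai.LLMPipeline(model_path, "NPU", npu_props, draft_model=ov_draft_model)

generation_config = ov_genai.GenerationConfig(max_new_tokens=20, num_assistant_tokens=5)
print(ov_pipe.generate(["table is made of"], generation_config))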

tests/python_tests/utils/ov_genai_pipelines.py

Lines changed: 3 additions & 12 deletions

@@ -41,15 +41,13 @@ class PipelineType(Enum):
     PAGED_ATTENTION = 2
     CONTINUOUS_BATCHING = 3
     SPECULATIVE_DECODING = 4
-    STATEFUL_SPECULATIVE_DECODING = 5
-    PROMPT_LOOKUP_DECODING = 6
-    AUTO = 7
+    PROMPT_LOOKUP_DECODING = 5
+    AUTO = 6


 def get_all_pipeline_types():
-    return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.CONTINUOUS_BATCHING, PipelineType.SPECULATIVE_DECODING, PipelineType.STATEFUL_SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING, PipelineType.AUTO]
+    return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.CONTINUOUS_BATCHING, PipelineType.SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING, PipelineType.AUTO]

-# TODO: Add PipelineType.STATEFUL_SPECULATIVE_DECODING, make its tests green.
 def get_main_pipeline_types():
     return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING]

@@ -99,10 +97,6 @@ def create_ov_pipeline(models_path: Path,
     elif pipeline_type == PipelineType.SPECULATIVE_DECODING:
         ov_draft_model = draft_model(models_path) if draft_model_path is None else draft_model(draft_model_path)
         return LLMPipeline(models_path, device, ov_config, scheduler_config=scheduler_config, draft_model=ov_draft_model)
-    elif pipeline_type == PipelineType.STATEFUL_SPECULATIVE_DECODING:
-        ov_draft_model = draft_model(models_path) if draft_model_path is None else draft_model(draft_model_path)
-        ov_config["ATTENTION_BACKEND"] = "SDPA"
-        return LLMPipeline(models_path, device, ov_config, draft_model=ov_draft_model)
     elif pipeline_type == PipelineType.PROMPT_LOOKUP_DECODING:
         return LLMPipeline(models_path, device, ov_config, scheduler_config=scheduler_config, prompt_lookup=True)
     else:
@@ -133,9 +127,6 @@ def prepare_generation_config_by_pipe_type(generation_config : GenerationConfig,
     if pipeline_type == PipelineType.SPECULATIVE_DECODING:
         assert not generation_config.is_beam_search()
         generation_config.assistant_confidence_threshold = 0.9
-    elif pipeline_type == PipelineType.STATEFUL_SPECULATIVE_DECODING:
-        assert not generation_config.is_beam_search()
-        generation_config.num_assistant_tokens = 5
    elif pipeline_type == PipelineType.PROMPT_LOOKUP_DECODING:
         assert not generation_config.is_beam_search()
         generation_config.num_assistant_tokens = 5
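
Editor's note: with the STATEFUL_SPECULATIVE_DECODING branch removed, tests that need this pipeline build it directly, as test_perf_metrics and test_extended_perf_metrics above now do. A hypothetical helper capturing that pattern (not part of ov_genai_pipelines.py) might look like:

from pathlib import Path
from openvino_genai import LLMPipeline, draft_model

def create_stateful_speculative_pipeline(models_path: Path, draft_model_path=None, ov_config=None):
    # Hypothetical replacement for the removed create_ov_pipeline branch:
    # NPU must be requested for at least one model, so both are placed on NPU here.
    ov_config = dict(ov_config or {})
    draft_path = models_path if draft_model_path is None else draft_model_path
    ov_draft_model = draft_model(draft_path, "NPU", **ov_config)
    return LLMPipeline(models_path, "NPU", ov_config, draft_model=ov_draft_model)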
