Skip to content

Commit 8267424

Browse files
committed
Fixed last comments and added tests
1 parent 90a4124 commit 8267424

File tree

6 files changed

+227
-9
lines changed

6 files changed

+227
-9
lines changed

samples/cpp/text_generation/speculative_decoding_lm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ int main(int argc, char* argv[]) try {
1818
// NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial
1919
// value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set, it defaults to `5` for both
2020
// backends.
21-
// config.num_assistant_tokens = 5;
21+
config.num_assistant_tokens = 4;
2222
// Add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than
2323
// `assistant_confidence_threshold`.
2424
// NOTE: `assistant_confidence_threshold` is supported only by ContinuousBatching backend.

samples/python/text_generation/speculative_decoding_lm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def main():
3636
# NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial
3737
# value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set, it defaults to `5` for both
3838
# backends.
39-
# config.num_assistant_tokens = 5
39+
config.num_assistant_tokens = 4
4040
# Add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than
4141
# `assistant_confidence_threshold`.
4242
# NOTE: `assistant_confidence_threshold` is supported only by ContinuousBatching backend.

src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ EncodedResults StatefulSpeculativeLLMPipeline::generate(
705705
auto& main_perf_generated_tokens = m_main_request->raw_perf_metrics.m_batch_sizes.back();
706706
main_perf_generated_tokens -= mismatched_candidates;
707707
m_sd_metrics.update_draft_generated_len(0 /* request_id */, candidates_to_generate);
708-
m_sd_metrics.update_acceptance_rate(0 /* request_id */, (accepted_tokens_number / candidates_to_generate) * 100);
708+
m_sd_metrics.update_acceptance_rate(0 /* request_id */, (accepted_tokens_number * 100.f) / candidates_to_generate);
709709
m_sd_metrics.update_draft_accepted_tokens(0 /* request_id */, accepted_tokens_number);
710710
m_sd_metrics.update_generated_len(validated_tokens.size());
711711
if (utils::env_setup_for_print_debug_info()) {

tests/python_tests/test_sampling.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from openvino_genai import GenerationConfig, StopCriteria
1111

12-
from utils.ov_genai_pipelines import generate_and_compare, run_ov_pipeline, get_main_pipeline_types
12+
from utils.ov_genai_pipelines import generate_and_compare, run_ov_pipeline, get_main_pipeline_types, PipelineType
1313
from utils.hugging_face import download_and_convert_model
1414

1515
@pytest.mark.precommit
@@ -51,7 +51,12 @@ def test_basic_stop_criteria(generation_config, prompt):
5151
"multiple_stop_strings_exclude_from_output",
5252
"multiple_stop_strings_include_to_output",
5353
"multiple_stop_strings_one_no_match_and_long_exclude_from_output"])
54-
@pytest.mark.parametrize("pipeline_type", get_main_pipeline_types())
54+
# FIXME: PipelineType.STATEFUL_SPECULATIVE_DECODING currently fails these tests
55+
def main_pipelines_wo_stateful_speculative():
56+
main_pipe_types = get_main_pipeline_types()
57+
main_pipe_types.remove(PipelineType.STATEFUL_SPECULATIVE_DECODING)
58+
return main_pipe_types
59+
@pytest.mark.parametrize("pipeline_type", main_pipelines_wo_stateful_speculative())
5560
def test_stop_strings(generation_config, model_id, pipeline_type):
5661
prompts = [ "What is OpenVINO?" ]
5762
generate_and_compare(model_id, prompts, generation_config, pipeline_type=pipeline_type)
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# Copyright (C) 2023-2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
5+
import pytest
6+
import numpy as np
7+
import logging
8+
9+
import openvino as ov
10+
import openvino_genai as ov_genai
11+
12+
from utils.constants import get_default_llm_properties
13+
from utils.hugging_face import generation_config_to_hf, download_and_convert_model, run_hugging_face
14+
from utils.comparation import compare_generation_results
15+
from utils.ov_genai_pipelines import create_ov_pipeline, generate_and_compare, get_main_pipeline_types, PipelineType, convert_decoded_results_to_generation_result
16+
17+
test_cases = [
18+
('CPU', 'CPU'),
19+
('CPU', 'NPUW:CPU'),
20+
('NPUW:CPU', 'CPU'),
21+
('NPUW:CPU', 'NPUW:CPU')
22+
]
23+
@pytest.mark.parametrize("main_device,draft_device", test_cases)
24+
@pytest.mark.precommit
25+
def test_string_inputs(main_device, draft_device):
26+
# FIXME: For now SmolLM2-135M is used as a main and a draft model in the test.
27+
# However, it is more desirable to use SmolLM2-360M as a main one to simulate the real case
28+
# for speculative decoding.
29+
# It seems like temporary directory from model downloading stage isn't removed after test
30+
# launch for SmolLM2-360M model, that is why it is not used now.
31+
MODEL_UNDER_TEST = {
32+
"name": "HuggingFaceTB/SmolLM2-135M",
33+
"convert_args": ['--trust-remote-code']
34+
}
35+
prompt = "Alan Turing was a"
36+
37+
# Download and convert model:
38+
main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(MODEL_UNDER_TEST["name"])
39+
draft_model_path = main_model_path
40+
41+
# Create OpenVINO GenAI pipeline:
42+
draft_config = get_default_llm_properties()
43+
if draft_device == "NPUW:CPU":
44+
draft_device = "NPU"
45+
draft_config["NPUW_DEVICES"] = "CPU"
46+
draft_config["GENERATE_HINT"] = "BEST_PERF"
47+
# FIXME: Currently, the same draft and main model fails to work in NPUW_WEIGHTS_BANK: shared mode.
48+
# To workaround this, we name banks differently for draft and main.
49+
draft_config["NPUW_WEIGHTS_BANK"] = "draft"
50+
ov_draft_model = ov_genai.draft_model(draft_model_path, draft_device, **draft_config)
51+
52+
main_config = get_default_llm_properties()
53+
if main_device == "NPUW:CPU":
54+
main_device = "NPU"
55+
main_config["NPUW_DEVICES"] = "CPU"
56+
# FIXME: SmolLM-135M with GENERATE_HINT: FAST_COMPILE will output garbage on NPUW:CPU if used with configuration
57+
# NPUW_LLM_MAX_GENERATION_TOKEN_LEN > 1.
58+
# Setting GENERATE_HINT: BEST_PERF to workaround an issue currently.
59+
main_config["GENERATE_HINT"] = "BEST_PERF"
60+
# FIXME: Currently, the same draft and main model fails to work in NPUW_WEIGHTS_BANK: shared mode.
61+
# To workaround this, we name banks differently for draft and main.
62+
main_config["NPUW_WEIGHTS_BANK"] = "main"
63+
main_config["ATTENTION_BACKEND"] = "SDPA"
64+
ov_pipe = ov_genai.LLMPipeline(main_model_path, main_device, main_config, draft_model=ov_draft_model)
65+
66+
# Run reference HF model:
67+
ov_generation_config = ov_genai.GenerationConfig(max_new_tokens=20)
68+
main_hf_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
69+
ref_gen_results = run_hugging_face(main_opt_model, main_hf_tokenizer, [prompt], ov_generation_config)
70+
71+
# Run OpenVINO GenAI pipeline:
72+
ov_decoded_results = ov_pipe.generate([prompt], ov_generation_config)
73+
ov_gen_results = convert_decoded_results_to_generation_result(ov_decoded_results, 1, 1, False)
74+
75+
del ov_pipe
76+
77+
# Compare results:
78+
compare_generation_results([prompt], ref_gen_results, ov_gen_results, ov_generation_config)
79+
80+
@pytest.mark.precommit
81+
def test_perf_metrics():
82+
import time
83+
start_time = time.perf_counter()
84+
model_id = 'katuni4ka/tiny-random-gemma2'
85+
generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5)
86+
_, _, model_path = download_and_convert_model(model_id)
87+
ov_pipe = create_ov_pipeline(model_path, pipeline_type=PipelineType.STATEFUL_SPECULATIVE_DECODING)
88+
prompt = 'table is made of'
89+
perf_metrics = ov_pipe.generate([prompt], generation_config).perf_metrics
90+
total_time = (time.perf_counter() - start_time) * 1000
91+
92+
# Check that load time is adequate.
93+
load_time = perf_metrics.get_load_time()
94+
assert load_time > 0 and load_time < total_time
95+
96+
# Check that num input and generated tokens are adequate.
97+
num_generated_tokens = perf_metrics.get_num_generated_tokens()
98+
assert num_generated_tokens > 0 and num_generated_tokens <= generation_config.max_new_tokens
99+
100+
num_input_tokens = perf_metrics.get_num_input_tokens()
101+
assert num_input_tokens > 0 and num_input_tokens <= len(prompt)
102+
103+
mean_ttft, std_ttft = perf_metrics.get_ttft()
104+
assert (mean_ttft, std_ttft) == (perf_metrics.get_ttft().mean, perf_metrics.get_ttft().std)
105+
assert mean_ttft > 0 and mean_ttft < 1000.0
106+
107+
raw_metrics = perf_metrics.raw_metrics
108+
durations = np.array(raw_metrics.m_durations) / 1000
109+
# Check that prefill is not included in durations for TPOT calculation.
110+
# For the very long prompt prefill is slow and TTFT is much larger than any other token generation duration.
111+
assert np.all(mean_ttft > durations)
112+
113+
mean_tpot, std_tpot = perf_metrics.get_tpot()
114+
assert (mean_tpot, std_tpot) == (perf_metrics.get_tpot().mean, perf_metrics.get_tpot().std)
115+
assert mean_tpot > 0 and mean_ttft < 1000.0
116+
117+
mean_throughput, std_throughput = perf_metrics.get_throughput()
118+
assert (mean_throughput, std_throughput) == (perf_metrics.get_throughput().mean, perf_metrics.get_throughput().std)
119+
assert mean_throughput > 0 and mean_throughput < 20000.0
120+
121+
mean_gen_duration, std_gen_duration = perf_metrics.get_generate_duration()
122+
assert (mean_gen_duration, std_gen_duration) == (perf_metrics.get_generate_duration().mean, perf_metrics.get_generate_duration().std)
123+
assert mean_gen_duration > 0 and load_time + mean_gen_duration < total_time
124+
assert std_gen_duration == 0
125+
126+
mean_tok_duration, std_tok_duration = perf_metrics.get_tokenization_duration()
127+
assert (mean_tok_duration, std_tok_duration) == (perf_metrics.get_tokenization_duration().mean, perf_metrics.get_tokenization_duration().std)
128+
assert mean_tok_duration > 0 and mean_tok_duration < mean_gen_duration
129+
assert std_tok_duration == 0
130+
131+
mean_detok_duration, std_detok_duration = perf_metrics.get_detokenization_duration()
132+
assert (mean_detok_duration, std_detok_duration) == (perf_metrics.get_detokenization_duration().mean, perf_metrics.get_detokenization_duration().std)
133+
assert mean_detok_duration > 0 and mean_detok_duration < mean_gen_duration
134+
assert std_detok_duration == 0
135+
136+
# assert that calculating statistics manually from the raw counters we get the same restults as from PerfMetrics
137+
assert np.allclose(mean_tpot, np.mean(durations))
138+
assert np.allclose(std_tpot, np.std(durations), atol=0.00002)
139+
140+
raw_dur = np.array(raw_metrics.generate_durations) / 1000
141+
assert np.allclose(mean_gen_duration, np.mean(raw_dur))
142+
assert np.allclose(std_gen_duration, np.std(raw_dur))
143+
144+
raw_dur = np.array(raw_metrics.tokenization_durations) / 1000
145+
assert np.allclose(mean_tok_duration, np.mean(raw_dur))
146+
assert np.allclose(std_tok_duration, np.std(raw_dur))
147+
148+
raw_dur = np.array(raw_metrics.detokenization_durations) / 1000
149+
assert np.allclose(mean_detok_duration, np.mean(raw_dur))
150+
assert np.allclose(std_detok_duration, np.std(raw_dur))
151+
152+
assert len(raw_metrics.m_times_to_first_token) > 0
153+
assert len(raw_metrics.m_batch_sizes) > 0
154+
assert len(raw_metrics.m_durations) > 0
155+
156+
@pytest.mark.precommit
157+
def test_extended_perf_metrics():
158+
import time
159+
start_time = time.perf_counter()
160+
model_id : str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
161+
generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5)
162+
_, _, model_path = download_and_convert_model(model_id)
163+
ov_pipe = create_ov_pipeline(model_path, pipeline_type=PipelineType.STATEFUL_SPECULATIVE_DECODING)
164+
extended_perf_metrics = ov_pipe.generate(["Why is the Sun yellow?"], generation_config).extended_perf_metrics
165+
total_time = (time.perf_counter() - start_time) * 1000
166+
167+
assert not extended_perf_metrics is None
168+
assert not extended_perf_metrics.main_model_metrics is None
169+
assert not extended_perf_metrics.draft_model_metrics is None
170+
171+
assert extended_perf_metrics.get_num_accepted_tokens() > 0
172+
173+
num_generated_tokens_main = extended_perf_metrics.main_model_metrics.get_num_generated_tokens()
174+
assert num_generated_tokens_main > 0 and num_generated_tokens_main <= generation_config.max_new_tokens
175+
176+
num_generated_tokens_draft = extended_perf_metrics.draft_model_metrics.get_num_generated_tokens()
177+
# As Stateful Speculative Decoding pipeline is dynamically adjusting its number of candidates at
178+
# each step, here we check that generated tokens is less than upper candidates limit multiplied by
179+
# maximum number of generated tokens.
180+
assert num_generated_tokens_draft > 0 and \
181+
num_generated_tokens_draft < ((generation_config.max_new_tokens - 1) * \
182+
generation_config.num_assistant_tokens * 2 + 1)
183+
184+
total_iteration_number_main = len(extended_perf_metrics.main_model_metrics.raw_metrics.m_durations)
185+
assert total_iteration_number_main > 0 and total_iteration_number_main <= generation_config.max_new_tokens
186+
187+
total_iteration_number_draft = len(extended_perf_metrics.draft_model_metrics.raw_metrics.m_durations)
188+
assert total_iteration_number_draft > 0 and \
189+
total_iteration_number_draft < ((generation_config.max_new_tokens - 1) * \
190+
generation_config.num_assistant_tokens * 2 + 1)
191+
192+
for model_metrics in [extended_perf_metrics.main_model_metrics, extended_perf_metrics.draft_model_metrics]:
193+
mean_ttst, std_ttst = model_metrics.get_ttst()
194+
assert (mean_ttst, std_ttst) == (model_metrics.get_ttst().mean, model_metrics.get_ttst().std)
195+
assert mean_ttst > 0 and mean_ttst < model_metrics.get_ttft().mean
196+
assert std_ttst == 0
197+
198+
mean_latency, std_latency = model_metrics.get_latency()
199+
assert (mean_latency, std_latency) == (model_metrics.get_latency().mean, model_metrics.get_latency().std)
200+
assert mean_latency > 0 and mean_latency < 1000.0
201+
202+
mean_gen_duration, std_gen_duration = model_metrics.get_generate_duration()
203+
assert (mean_gen_duration, std_gen_duration) == (model_metrics.get_generate_duration().mean, model_metrics.get_generate_duration().std)
204+
assert mean_gen_duration > 0 and mean_gen_duration < total_time
205+
assert std_gen_duration == 0

tests/python_tests/utils/ov_genai_pipelines.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,16 @@ class PipelineType(Enum):
4141
PAGED_ATTENTION = 2
4242
CONTINUOUS_BATCHING = 3
4343
SPECULATIVE_DECODING = 4
44-
PROMPT_LOOKUP_DECODING = 5
45-
AUTO = 6
44+
STATEFUL_SPECULATIVE_DECODING = 5
45+
PROMPT_LOOKUP_DECODING = 6
46+
AUTO = 7
4647

4748

4849
def get_all_pipeline_types():
49-
return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.CONTINUOUS_BATCHING, PipelineType.SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING, PipelineType.AUTO]
50+
return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.CONTINUOUS_BATCHING, PipelineType.SPECULATIVE_DECODING, PipelineType.STATEFUL_SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING, PipelineType.AUTO]
5051

5152
def get_main_pipeline_types():
52-
return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING]
53+
return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION, PipelineType.SPECULATIVE_DECODING, PipelineType.STATEFUL_SPECULATIVE_DECODING, PipelineType.PROMPT_LOOKUP_DECODING]
5354

5455
def get_gguf_pipeline_types():
5556
return [PipelineType.STATEFUL, PipelineType.PAGED_ATTENTION]
@@ -97,6 +98,10 @@ def create_ov_pipeline(models_path: Path,
9798
elif pipeline_type == PipelineType.SPECULATIVE_DECODING:
9899
ov_draft_model = draft_model(models_path) if draft_model_path is None else draft_model(draft_model_path)
99100
return LLMPipeline(models_path, device, ov_config, scheduler_config=scheduler_config, draft_model=ov_draft_model)
101+
elif pipeline_type == PipelineType.STATEFUL_SPECULATIVE_DECODING:
102+
ov_draft_model = draft_model(models_path) if draft_model_path is None else draft_model(draft_model_path)
103+
ov_config["ATTENTION_BACKEND"] = "SDPA"
104+
return LLMPipeline(models_path, device, ov_config, draft_model=ov_draft_model)
100105
elif pipeline_type == PipelineType.PROMPT_LOOKUP_DECODING:
101106
return LLMPipeline(models_path, device, ov_config, scheduler_config=scheduler_config, prompt_lookup=True)
102107
else:
@@ -127,6 +132,9 @@ def prepare_generation_config_by_pipe_type(generation_config : GenerationConfig,
127132
if pipeline_type == PipelineType.SPECULATIVE_DECODING:
128133
assert not generation_config.is_beam_search()
129134
generation_config.assistant_confidence_threshold = 0.9
135+
elif pipeline_type == PipelineType.STATEFUL_SPECULATIVE_DECODING:
136+
assert not generation_config.is_beam_search()
137+
generation_config.num_assistant_tokens = 2
130138
elif pipeline_type == PipelineType.PROMPT_LOOKUP_DECODING:
131139
assert not generation_config.is_beam_search()
132140
generation_config.num_assistant_tokens = 5

0 commit comments

Comments
 (0)