
Commit db5b244 (parent: 9d22e73)

Fix accuracy issue

2 files changed: +64 -35 lines

tests/python_tests/test_kv_cache_eviction.py
Lines changed: 17 additions & 23 deletions
@@ -270,15 +270,14 @@ def test_optimized_generation_longbench(device, test_struct):
     assert avg_optimization_ratio >= test_struct.avg_cache_usage_optimization_ratio


-MILEBENCH_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=128, max_cache_size=672, aggregation_mode=AggregationMode.SUM)
+MILEBENCH_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=64, max_cache_size=352, aggregation_mode=AggregationMode.SUM)

 @pytest.mark.nightly
 @pytest.mark.parametrize("device", ["CPU", "GPU"])
 @pytest.mark.parametrize("test_struct", [
-    BenchmarkTestData("ALFRED", 3.2, 2.0, 3.3),
-    BenchmarkTestData("MMCoQA", 4, 1.6, 3.3),
-    BenchmarkTestData("TextNeedleInAHaystack", 3.2, 2.0, 3.3),
-    BenchmarkTestData("WikiVQA", 5.8, 1.29, 2.621),
+    BenchmarkTestData("ALFRED", 0.011, 1.440, 1.574),
+    BenchmarkTestData("MMCoQA", 0.032, 1.843, 1.620),
+    BenchmarkTestData("WikiVQA", 0.032, 1.412, 1.527),
 ])
 def test_optimized_generation_milebench(device, test_struct):
     seqs_per_request = 32
@@ -292,19 +291,13 @@ def test_optimized_generation_milebench(device, test_struct):
     if scheduler_config_opt.use_cache_eviction:
         scheduler_config_opt.cache_eviction_config = MILEBENCH_CACHE_EVICTION_CONFIG

-    model_cb_noopt = ContinuousBatchingPipeline(models_path, scheduler_config, device, {}, get_default_llm_properties())
-    model_cb_opt = ContinuousBatchingPipeline(models_path, scheduler_config_opt, device, {}, get_default_llm_properties())
+    model_cb_noopt = ContinuousBatchingPipeline(models_path, scheduler_config, device, properties=get_default_llm_properties())
+    model_cb_opt = ContinuousBatchingPipeline(models_path, scheduler_config_opt, device, properties=get_default_llm_properties())

     generation_config = GenerationConfig() # expecting default greedy sampling
     generation_config.num_return_sequences = 1
-    generation_config.max_new_tokens = 64
-
-    processor = retry_request(
-        lambda: transformers.AutoProcessor.from_pretrained(
-            model_id,
-            trust_remote_code=True,
-        )
-    )
+    generation_config.max_new_tokens = 512
+    generation_config.do_sample = False

     data_dir = "milebench_data" # HF_HOME / "milebench_data"
     subset = test_struct.subset
@@ -313,15 +306,13 @@ def test_optimized_generation_milebench(device, test_struct):
         subset=subset,
         subset_size=seqs_per_request,
     )
+
     with tqdm(total=len(data)) as progress_bar:
         prompts, images = [], []
         answers = []
         ref_answers = []
         for p_idx, data_sample in enumerate(data):
-            conversation = data_sample["conversation"]
-            prompt = processor.apply_chat_template(
-                conversation, tokenize=False, add_generation_prompt=True
-            )
+            prompt = data_sample["prompt"]
             image = data_sample["images"]

             progress_bar.update(1)
@@ -332,24 +323,27 @@ def test_optimized_generation_milebench(device, test_struct):

             if len(prompts) == seqs_per_request or p_idx == len(data) - 1:
                 ans_batch = model_cb_opt.generate(
-                    prompts, images=images, generation_config=[generation_config] * len(prompts)
+                    prompts, images=images, generation_config=[generation_config] * len(prompts),
                 )
                 ref_ans_batch = model_cb_noopt.generate(
-                    prompts, images=images, generation_config=[generation_config] * len(prompts)
+                    prompts, images=images, generation_config=[generation_config] * len(prompts),
                 )
+
                 for i, (opt_output, ref_output) in enumerate(zip(ans_batch, ref_ans_batch), start=p_idx-len(prompts)+1):
-                    answers[i]["pred"] = opt_output.m_generation_ids[0]
-                    ref_answers[i]["pred"] = ref_output.m_generation_ids[0]
+                    answers[i]["pred"] = opt_output.texts[0]
+                    ref_answers[i]["pred"] = ref_output.texts[0]
                 prompts.clear()
                 images.clear()

     question_type = data.annotation['meta_data']['question_type']
     scorer = Eval()
+
     score = scorer.evaluate(answers, subset, question_type)
     print(f"Score: {score}")

     ref_score = scorer.evaluate(ref_answers, subset, question_type)
     print(f"Reference score: {ref_score}")
+
     pipeline_opt_metrics = model_cb_opt.get_metrics()
     pipeline_noopt_metrics = model_cb_noopt.get_metrics()

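A note on the tightened MILEBENCH_CACHE_EVICTION_CONFIG: in the cache eviction scheme, max_cache_size covers a never-evicted start area, a never-evicted recent area, and the evictable region between them, so shrinking recent_size from 128 to 64 and max_cache_size from 672 to 352 halves the evictable budget per sequence. A minimal sketch of that arithmetic, assuming this three-region layout (the helper below is illustrative, not part of the openvino_genai API):

def evictable_budget(start_size: int, recent_size: int, max_cache_size: int) -> int:
    # Tokens in the start and recent areas are pinned; only the middle
    # region competes for the remaining token budget and can be evicted.
    assert max_cache_size > start_size + recent_size
    return max_cache_size - start_size - recent_size

# Old config: 672 - 32 - 128 = 512 evictable tokens.
# New config: 352 - 32 - 64  = 256 evictable tokens.
print(evictable_budget(32, 64, 352))  # 256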
tests/python_tests/utils/milebench.py
Lines changed: 47 additions & 12 deletions
@@ -6,6 +6,49 @@
 #
 # Licensed under the Apache License

+# To download the required subsets from the MileBench dataset, please run the following script:
+#
+#!/bin/bash
+# OUT_DIR="milebench_data"
+# KEEP_DIRS=("ALFRED" "MMCoQA" "WikiVQA")
+# BASE_URL="https://huggingface.co/datasets/FreedomIntelligence/MileBench/resolve/main"
+# # List of tar.gz parts to download
+# PARTS=(part0 part2 part5)
+# # Create output directory
+# mkdir -p "$OUT_DIR"
+# cd "$OUT_DIR" || exit 1
+# # Download and extract
+# for part in "${PARTS[@]}"; do
+#     FILENAME="MileBench_${part}.tar.gz"
+#     URL="${BASE_URL}/${FILENAME}"
+#     echo "Downloading $FILENAME..."
+#     curl -L -o "$FILENAME" "$URL"
+#     echo "Extracting $FILENAME..."
+#     tar -xzf "$FILENAME" || { echo "Failed to extract $FILENAME"; exit 1; }
+#     rm "$FILENAME"
+# done
+# # Remove unwanted folders
+# echo "Cleaning up..."
+# for dir in */ ; do
+#     dir=${dir%/}
+#     if [[ ! " ${KEEP_DIRS[@]} " =~ " ${dir} " ]]; then
+#         echo "Removing $dir"
+#         rm -rf "$dir"
+#     fi
+# done
+# echo "Removing combined_1_images folders and *-adv.json inside kept directories..."
+# for dir in "${KEEP_DIRS[@]}"; do
+#     TARGET="$dir/combined_1_images"
+#     if [ -d "$TARGET" ]; then
+#         rm -rf "$TARGET"
+#     fi
+#     ADV_FILE="$dir/${dir}-adv.json"
+#     if [ -f "$ADV_FILE" ]; then
+#         rm "$ADV_FILE"
+#     fi
+# done
+
+
 import os
 import json
 import re
@@ -67,25 +110,17 @@ def __getitem__(self, idx):
                 context += choice_str

         img_num = len(ann["task_instance"]["images_path"])
+        qwen2_vl_image_placeholder = "<|vision_start|><|image_pad|><|vision_end|>"
         for i in range(img_num):
             rmv_txt = "{image#%d}"% (i+1)
             rmv_tbl = "{table#%d}"% (i+1)
-            context = context.replace(rmv_txt, "")
-            context = context.replace(rmv_tbl, "")
+            context = context.replace(rmv_txt, qwen2_vl_image_placeholder)
+            context = context.replace(rmv_tbl, qwen2_vl_image_placeholder)

         task_instruction_id = ann["task_instruction_id"]
         context_str = task_instructions[task_instruction_id] + "\n" + context
         prompt = MileBenchDataset._transform_string(context_str)

-        conversation = [
-            {
-                "role": "user",
-                "content": [{"type": "text", "text": prompt}],
-            },
-        ]
-        for i in range(img_num):
-            conversation[0]["content"].append({"type": "image"})
-
         images = []
         for p in ann["task_instance"]["images_path"]:
             img_path = os.path.join(self.image_dir, p)
@@ -95,7 +130,7 @@ def __getitem__(self, idx):
             images.append(image_tensor)

         return {
-            "conversation": conversation,
+            "prompt": prompt,
             "images": images,
             "gt_answer": ann["response"],
             "choice_list": ann["task_instance"].get("choice_list", None),

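Because __getitem__ now returns a raw prompt string instead of a chat conversation, image positions must be encoded inline so the VLM pipeline can align each image with its place in the text. A minimal sketch of the substitution the diff performs, with an invented sample context for illustration:

QWEN2_VL_IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"

def inline_image_placeholders(context: str, img_num: int) -> str:
    # Swap MileBench's "{image#N}" / "{table#N}" markers for the
    # placeholder a Qwen2-VL style prompt expects at each image slot.
    for i in range(img_num):
        context = context.replace("{image#%d}" % (i + 1), QWEN2_VL_IMAGE_PLACEHOLDER)
        context = context.replace("{table#%d}" % (i + 1), QWEN2_VL_IMAGE_PLACEHOLDER)
    return context

# Invented example:
sample = "Compare {image#1} with {image#2} and answer the question."
print(inline_image_placeholders(sample, 2))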