Skip to content

Commit 1b58a22

Browse files
committed
Add batch_size support for embedding model
1 parent 91dc71e commit 1b58a22

File tree

5 files changed

+18
-4
lines changed

5 files changed

+18
-4
lines changed

tools/llm_bench/llm_bench_utils/ov_utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -683,6 +683,7 @@ def create_genai_text_embed_model(model_path, device, memory_data_collector, **k
683683
if padding_side:
684684
config.padding_side = padding_side
685685

686+
config.batch_size = kwargs.get("batch_size")
686687
ov_config = kwargs['config']
687688

688689
if kwargs.get("mem_consumption"):

tools/who_what_benchmark/whowhatbench/embeddings_evaluator.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,8 @@ def __init__(
6666
gen_embeds_fn=None,
6767
pooling_type=None,
6868
normalize=None,
69-
padding_side=None
69+
padding_side=None,
70+
batch_size=None
7071
) -> None:
7172
assert (
7273
base_model is not None or gt_data is not None
@@ -80,6 +81,7 @@ def __init__(
8081
self.normalize = normalize or False
8182
self.padding_side = padding_side or 'right'
8283
self.gt_dir = os.path.dirname(gt_data)
84+
self.batch_size = batch_size
8385

8486
if base_model:
8587
self.gt_data = self._generate_data(
@@ -178,7 +180,10 @@ def default_gen_answer(model, tokenizer, passages, **kwargs):
178180
kwargs = {'padding_side': self.padding_side,
179181
'pooling_type': self.pooling_type,
180182
'normalize': self.normalize}
181-
result = gen_answer_fn(model, self.tokenizer, data[0], **kwargs)
183+
batch_size = self.batch_size or len(data[0])
184+
data_input = data[0][:batch_size]
185+
result = gen_answer_fn(model, self.tokenizer, data_input, **kwargs)
186+
182187
passages.append(data[0])
183188
result_path = os.path.join(result_dir, f"embeds_{i}.npy")
184189
with open(result_path, 'wb') as f:

tools/who_what_benchmark/whowhatbench/model_loaders.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -504,6 +504,7 @@ def load_embedding_genai_pipeline(model_dir, device="CPU", ov_config=None, **kwa
504504
config.max_length = EMBED_DEFAULT_MAX_LENGTH
505505
config.normalize = kwargs.get("embeds_normalize", False)
506506
config.pad_to_max_length = True
507+
config.batch_size = kwargs.get("batch_size", config.batch_size)
507508

508509
logger.info("Using OpenVINO GenAI TextEmbeddingPipeline API")
509510
pipeline = openvino_genai.TextEmbeddingPipeline(model_dir, device.upper(), config, **ov_config)

tools/who_what_benchmark/whowhatbench/whowhat_metrics.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -189,7 +189,11 @@ def evaluate(self, data_gold, data_prediction):
189189
with open(prediction, 'rb') as f:
190190
prediction_data = np.load(f)
191191

192-
cos_sim = F.cosine_similarity(torch.from_numpy(gold_data), torch.from_numpy(prediction_data))
192+
min_len = min(gold_data.shape[0], prediction_data.shape[0])
193+
gold_trimmed = gold_data[:min_len]
194+
pred_trimmed = prediction_data[:min_len]
195+
196+
cos_sim = F.cosine_similarity(torch.from_numpy(gold_trimmed), torch.from_numpy(pred_trimmed))
193197
metric_per_passages.append(cos_sim.detach().numpy())
194198
metric_per_gen.append(torch.mean(cos_sim).item())
195199

tools/who_what_benchmark/whowhatbench/wwb.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -232,7 +232,7 @@ def parse_args():
232232
"If the base/target model is a local path, gguf-file should be just the filename (e.g., 'model.gguf'). "
233233
"If the base/target model is a HuggingFace model ID, gguf-file should be a relative path.",
234234
)
235-
235+
parser.add_argument('-bs', '--batch_size', type=int, default=None, help='Batch size value')
236236
return parser.parse_args()
237237

238238

@@ -478,6 +478,7 @@ def genai_gen_visual_text(model, prompt, image, processor, tokenizer, max_new_to
478478

479479
def genai_gen_embedding(model, tokenizer, passages, **kwargs):
480480
embeddings = model.embed_documents(passages)
481+
481482
return embeddings
482483

483484

@@ -588,6 +589,7 @@ def create_evaluator(base_model, args):
588589
pooling_type=args.embeds_pooling_type,
589590
normalize=args.embeds_normalize,
590591
padding_side=args.embeds_padding_side,
592+
batch_size=args.batch_size
591593
)
592594
elif task == "text-reranking":
593595
return EvaluatorCLS(
@@ -724,6 +726,7 @@ def main():
724726
kwargs["embeds_pooling"] = args.embeds_pooling_type
725727
kwargs["embeds_normalize"] = args.embeds_normalize
726728
kwargs["embeds_padding_side"] = args.embeds_padding_side
729+
kwargs["batch_size"] = args.batch_size
727730

728731
if args.gt_data and os.path.exists(args.gt_data):
729732
evaluator = create_evaluator(None, args)

0 commit comments

Comments
 (0)