[NPU] Add batch_size support for embedding model #2986
The first file extends the embedding evaluator's constructor to accept a batch size:

```diff
@@ -66,7 +66,8 @@ def __init__(
         gen_embeds_fn=None,
         pooling_type=None,
         normalize=None,
-        padding_side=None
+        padding_side=None,
+        batch_size=None
     ) -> None:
         assert (
             base_model is not None or gt_data is not None
```
```diff
@@ -80,6 +81,7 @@ def __init__(
         self.normalize = normalize or False
         self.padding_side = padding_side or 'right'
         self.gt_dir = os.path.dirname(gt_data)
+        self.batch_size = batch_size

         if base_model:
             self.gt_data = self._generate_data(
```
```diff
@@ -178,8 +180,14 @@ def default_gen_answer(model, tokenizer, passages, **kwargs):
         kwargs = {'padding_side': self.padding_side,
                   'pooling_type': self.pooling_type,
                   'normalize': self.normalize}
-        result = gen_answer_fn(model, self.tokenizer, data[0], **kwargs)
-        passages.append(data[0])
+
+        batch_size = self.batch_size or len(data[0])
+        assert batch_size <= len(data[0]), \
+            f"batch_size ({batch_size}) cannot be greater than data length ({len(data[0])})"
+        data_input = data[0][:batch_size]
```
A review thread is attached to the `data_input` line:

Contributor: What will be the behavior if the chunk of input data is less than the batch size?

Author: Added the check.

Contributor: I meant inside the plugin, there's no point in line

Author: An exception will be thrown if the batch size and the data size do not match. I don't know if it is redundant.

Contributor: Let's discuss this before making changes. If I understand correctly, the plugin will crash if we say the batch is 10 but provide 7 as input. @mengweiguo @as-suvorov What do you think about this?

Collaborator: Good point. Currently if we fix

mengweiguo marked this conversation as resolved.
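To make the thread concrete, here is a minimal sketch of the failure mode under discussion: the guard added in this PR raises before the model (and thus the NPU plugin) is ever called with a mismatched batch. The data values are invented for the example:

```python
# Hypothetical final chunk holding fewer passages than the requested batch.
data_chunk = ["passage one", "passage two"]
requested_batch = 10

# Same guard as in the diff above; it raises instead of handing a
# short batch to the plugin:
assert requested_batch <= len(data_chunk), \
    f"batch_size ({requested_batch}) cannot be greater than data length ({len(data_chunk)})"
# AssertionError: batch_size (10) cannot be greater than data length (2)
```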
The hunk continues:

```diff
+        result = gen_answer_fn(model, self.tokenizer, data_input, **kwargs)
+
+        passages.append(data_input)
         result_path = os.path.join(result_dir, f"embeds_{i}.npy")
         with open(result_path, 'wb') as f:
             np.save(f, result)
```
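Taken together, the hunk gives `batch_size` truncation semantics rather than true batching: each chunk of passages is cut down to its first `batch_size` entries, and only those are embedded and compared. A standalone sketch of the same selection logic, with invented sample data:

```python
def select_batch(chunk, batch_size=None):
    # Mirrors the logic in the diff: None means "use the whole chunk".
    effective = batch_size or len(chunk)
    assert effective <= len(chunk), \
        f"batch_size ({effective}) cannot be greater than data length ({len(chunk)})"
    return chunk[:effective]

chunk = ["p1", "p2", "p3", "p4", "p5", "p6", "p7"]
print(select_batch(chunk))                # all 7 passages (batch_size is None)
print(select_batch(chunk, batch_size=4))  # ['p1', 'p2', 'p3', 'p4']
```

Note that with a fixed `batch_size`, passages beyond the cut-off are dropped from the comparison entirely, which is adjacent to the reviewers' concern above.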
The second file wires the option into the benchmark CLI:

```diff
@@ -262,6 +262,12 @@ def parse_args():
         help="Config option assistant_confidence_threshold for Speculative decoding.",
     )

+    parser.add_argument(
+        '-bs', '--batch_size',
+        type=int,
+        default=None,
+        help='Batch size value')
+
     return parser.parse_args()
```
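As a quick sanity check of the new option, a minimal reproduction of just this argument (not the full benchmark parser) shows the default and parsed values:

```python
import argparse

# Re-creates only the option added above.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-bs', '--batch_size',
    type=int,
    default=None,
    help='Batch size value')

print(parser.parse_args(['-bs', '4']).batch_size)  # 4
print(parser.parse_args([]).batch_size)            # None -> evaluator uses the whole chunk
```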
```diff
@@ -635,6 +641,7 @@ def create_evaluator(base_model, args):
             pooling_type=args.embeds_pooling_type,
             normalize=args.embeds_normalize,
             padding_side=args.embeds_padding_side,
+            batch_size=args.batch_size,
         )
     elif task == "text-reranking":
         return EvaluatorCLS(
```
```diff
@@ -754,6 +761,8 @@ def main():
     logger.info(version_str)

+    kwargs = {}
+    kwargs["batch_size"] = args.batch_size

     if args.cb_config:
         kwargs["cb_config"] = read_cb_config(args.cb_config)
     if args.from_onnx:
```
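The `kwargs` dict is presumably forwarded to model creation further down in `main()`. A rough sketch of that pattern follows; the loader name and signature here are assumptions for illustration, not the actual benchmark code:

```python
def load_model(model_id, device, **kwargs):
    # Hypothetical loader: honour the batch size if the caller set one.
    batch_size = kwargs.get("batch_size")  # None unless -bs/--batch_size was given
    if batch_size is not None:
        print(f"compiling {model_id} for {device} with static batch {batch_size}")
    # ... real model construction would go here ...

kwargs = {"batch_size": 4}
load_model("some-embedding-model", "NPU", **kwargs)
```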
Reviewer: Shall it be documented and added to the help? Tests?

Author: Actually, this option `batch_size` is already there in the benchmark.