Skip to content

Commit 4e1df4a

Browse files
authored
Fix reading non-standard config for past_key_values in ONNX (#751)
1 parent 78502d8 commit 4e1df4a

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

backends/ort/src/lib.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,16 @@ use text_embeddings_backend_core::{
1515
pub struct Config {
1616
pub pad_token_id: Option<usize>,
1717
pub eos_token_id: Option<usize>,
18-
// NOTE: the fields below are only required when the ONNX model expects the `past_key_values`
19-
// as input i.e., whenever the ONNX model has been exported with optimized MHA nodes
18+
19+
// NOTE: The fields below are only required when the ONNX model expects the `past_key_values`
20+
// as input i.e., whenever the ONNX model has been exported with optimized MHA/MQA nodes
21+
// NOTE: The renames from `n_embd`, `n_layer`, and `n_head` have been included for some edge
22+
// cases as e.g. `nomic-ai/nomic-embed-text-v1`, given that those ONNX exports use MQA
23+
#[serde(rename = "n_embd")]
2024
pub hidden_size: usize,
25+
#[serde(rename = "n_layer")]
2126
pub num_hidden_layers: usize,
27+
#[serde(rename = "n_head")]
2228
pub num_key_value_heads: Option<usize>,
2329
}
2430

0 commit comments

Comments
 (0)