Skip to content

Commit 16a3d94

Browse files
committed
Add test for gte-multilingual-base in test_gte.rs
1 parent b9969f4 commit 16a3d94

File tree

1 file changed

+45
-37
lines changed

1 file changed

+45
-37
lines changed

backends/candle/tests/test_gte.rs

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,43 +8,51 @@ use text_embeddings_backend_core::{Backend, ModelType, Pool};
88

99
#[test]
fn test_gte() -> Result<()> {
    // Each entry pairs a model id with the snapshot-name prefix used for its
    // insta snapshots. Distinct prefixes are required: the two models produce
    // different embeddings, so reusing one snapshot name ("gte_batch") across
    // loop iterations would make the second iteration compare against the
    // first model's stored snapshot and fail. The first model keeps the
    // original "gte" prefix so previously recorded snapshots stay valid.
    let models = [
        ("Alibaba-NLP/gte-base-en-v1.5", "gte"),
        // Included in the test because its safetensors weights come in a
        // different format (tensor names carry a "new." prefix).
        ("Alibaba-NLP/gte-multilingual-base", "gte_multilingual"),
    ];

    for (model_id, snapshot_prefix) in models {
        let model_root = download_artifacts(model_id, None)?;
        let tokenizer = load_tokenizer(&model_root)?;

        // CLS pooling, float32 — matches the configuration the snapshots
        // were recorded with.
        let backend = CandleBackend::new(
            &model_root,
            "float32".to_string(),
            ModelType::Embedding(Pool::Cls),
        )?;

        // Batch of three inputs; the first and third are identical so we can
        // check batch/single consistency below.
        let input_batch = batch(
            vec![
                tokenizer.encode("What is Deep Learning?", true).unwrap(),
                tokenizer.encode("Deep Learning is...", true).unwrap(),
                tokenizer.encode("What is Deep Learning?", true).unwrap(),
            ],
            [0, 1, 2].to_vec(),
            vec![],
        );

        let matcher = cosine_matcher();

        let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?);
        let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings);
        insta::assert_yaml_snapshot!(format!("{snapshot_prefix}_batch"), embeddings_batch, &matcher);

        let input_single = batch(
            vec![tokenizer.encode("What is Deep Learning?", true).unwrap()],
            [0].to_vec(),
            vec![],
        );

        let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?);
        let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings);

        insta::assert_yaml_snapshot!(format!("{snapshot_prefix}_single"), embeddings_single, &matcher);
        // A prompt embedded alone must equal the same prompt embedded inside
        // a batch (positions 0 and 2 of the batch are that prompt).
        assert_eq!(embeddings_batch[0], embeddings_single[0]);
        assert_eq!(embeddings_batch[2], embeddings_single[0]);
    }

    Ok(())
}

0 commit comments

Comments
 (0)