@@ -8,43 +8,51 @@ use text_embeddings_backend_core::{Backend, ModelType, Pool};
 
 #[test]
 fn test_gte() -> Result<()> {
-    let model_root = download_artifacts("Alibaba-NLP/gte-base-en-v1.5", None)?;
-    let tokenizer = load_tokenizer(&model_root)?;
-
-    let backend = CandleBackend::new(
-        &model_root,
-        "float32".to_string(),
-        ModelType::Embedding(Pool::Cls),
-    )?;
-
-    let input_batch = batch(
-        vec![
-            tokenizer.encode("What is Deep Learning?", true).unwrap(),
-            tokenizer.encode("Deep Learning is...", true).unwrap(),
-            tokenizer.encode("What is Deep Learning?", true).unwrap(),
-        ],
-        [0, 1, 2].to_vec(),
-        vec![],
-    );
-
-    let matcher = cosine_matcher();
-
-    let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?);
-    let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings);
-    insta::assert_yaml_snapshot!("gte_batch", embeddings_batch, &matcher);
-
-    let input_single = batch(
-        vec![tokenizer.encode("What is Deep Learning?", true).unwrap()],
-        [0].to_vec(),
-        vec![],
-    );
-
-    let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?);
-    let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings);
-
-    insta::assert_yaml_snapshot!("gte_single", embeddings_single, &matcher);
-    assert_eq!(embeddings_batch[0], embeddings_single[0]);
-    assert_eq!(embeddings_batch[2], embeddings_single[0]);
+    let model_ids = vec![
+        "Alibaba-NLP/gte-base-en-v1.5",
+        // Included in the test because its safetensors come in a different
+        // format, with tensor names carrying the "new." prefix
+        "Alibaba-NLP/gte-multilingual-base",
+    ];
+
+    for model_id in model_ids {
+        let model_root = download_artifacts(model_id, None)?;
+        let tokenizer = load_tokenizer(&model_root)?;
+
+        let backend = CandleBackend::new(
+            &model_root,
+            "float32".to_string(),
+            ModelType::Embedding(Pool::Cls),
+        )?;
+
+        let input_batch = batch(
+            vec![
+                tokenizer.encode("What is Deep Learning?", true).unwrap(),
+                tokenizer.encode("Deep Learning is...", true).unwrap(),
+                tokenizer.encode("What is Deep Learning?", true).unwrap(),
+            ],
+            [0, 1, 2].to_vec(),
+            vec![],
+        );
+
+        let matcher = cosine_matcher();
+
+        let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?);
+        let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings);
+        insta::assert_yaml_snapshot!("gte_batch", embeddings_batch, &matcher);
+
+        let input_single = batch(
+            vec![tokenizer.encode("What is Deep Learning?", true).unwrap()],
+            [0].to_vec(),
+            vec![],
+        );
+
+        let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?);
+        let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings);
+
+        insta::assert_yaml_snapshot!("gte_single", embeddings_single, &matcher);
+        assert_eq!(embeddings_batch[0], embeddings_single[0]);
+        assert_eq!(embeddings_batch[2], embeddings_single[0]);
+    }
 
     Ok(())
 }
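For context on the "new." prefix mentioned in the diff comment, here is a minimal sketch of how a loader might normalize such a checkpoint before building the model. This is not part of the PR and not TEI's actual handling; the single file name `model.safetensors`, the helper name, and the prefix-detection heuristic are assumptions for illustration.

```rust
// Sketch (assumption, not the TEI implementation): load a safetensors
// checkpoint and strip a leading "new." from every tensor name when present,
// so downstream code can look weights up under their plain names.
use std::collections::HashMap;
use std::path::Path;

use candle_core::{safetensors, Device, Result, Tensor};

fn load_weights_stripping_new_prefix(
    model_root: &Path,
    device: &Device,
) -> Result<HashMap<String, Tensor>> {
    // Assumes a single-file checkpoint named "model.safetensors".
    let tensors = safetensors::load(model_root.join("model.safetensors"), device)?;

    // Only rename when every tensor carries the prefix, to avoid touching
    // checkpoints that already use plain names.
    let all_prefixed = tensors.keys().all(|name| name.starts_with("new."));

    Ok(tensors
        .into_iter()
        .map(|(name, tensor)| {
            let name = if all_prefixed {
                name.trim_start_matches("new.").to_string()
            } else {
                name
            };
            (name, tensor)
        })
        .collect())
}
```

In practice, a candle `VarBuilder` could achieve the same effect by pushing the prefix (`vb.pp("new")`) when it is detected, rather than rewriting the tensor map.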