@@ -377,29 +377,60 @@ impl SnippetRetriever {
377377 Ok ( ( ) )
378378 }
379379
380- pub ( crate ) async fn search (
380+ pub ( crate ) async fn build_query (
381381 & self ,
382382 snippet : String ,
383+ strategy : BuildFrom ,
384+ ) -> Result < Vec < f32 > > {
385+ match strategy {
386+ BuildFrom :: Start => {
387+ let mut encoding = self . tokenizer . encode ( snippet. clone ( ) , true ) ?;
388+ encoding. truncate (
389+ self . model_config . max_input_size ,
390+ 1 ,
391+ TruncationDirection :: Right ,
392+ ) ;
393+ self . generate_embedding ( encoding, self . model . clone ( ) ) . await
394+ }
395+ BuildFrom :: Cursor { cursor_position } => {
396+ let ( before, after) = snippet. split_at ( cursor_position) ;
397+ let mut before_encoding = self . tokenizer . encode ( before, true ) ?;
398+ let mut after_encoding = self . tokenizer . encode ( after, true ) ?;
399+ let share = self . model_config . max_input_size / 2 ;
400+ before_encoding. truncate ( share, 1 , TruncationDirection :: Left ) ;
401+ after_encoding. truncate ( share, 1 , TruncationDirection :: Right ) ;
402+ before_encoding. take_overflowing ( ) ;
403+ after_encoding. take_overflowing ( ) ;
404+ before_encoding. merge_with ( after_encoding, false ) ;
405+ self . generate_embedding ( before_encoding, self . model . clone ( ) )
406+ . await
407+ }
408+ BuildFrom :: End => {
409+ let mut encoding = self . tokenizer . encode ( snippet. clone ( ) , true ) ?;
410+ encoding. truncate (
411+ self . model_config . max_input_size ,
412+ 1 ,
413+ TruncationDirection :: Left ,
414+ ) ;
415+ self . generate_embedding ( encoding, self . model . clone ( ) ) . await
416+ }
417+ }
418+ }
419+
420+ pub ( crate ) async fn search (
421+ & self ,
422+ query : & [ f32 ] ,
383423 filter : Option < FilterBuilder > ,
384424 ) -> Result < Vec < Snippet > > {
385425 let db = match self . db . as_ref ( ) {
386426 Some ( db) => db. clone ( ) ,
387427 None => return Err ( Error :: UninitialisedDatabase ) ,
388428 } ;
389429 let col = db. get_collection ( & self . collection_name ) . await ?;
390- let mut encoding = self . tokenizer . encode ( snippet. clone ( ) , true ) ?;
391- encoding. truncate (
392- self . model_config . max_input_size ,
393- 1 ,
394- TruncationDirection :: Right ,
395- ) ;
396- let query = self
397- . generate_embedding ( encoding, self . model . clone ( ) )
398- . await ?;
399430 let result = col
400431 . read ( )
401432 . await
402- . get ( & query, 5 , filter)
433+ . get ( query, 5 , filter)
403434 . await ?
404435 . iter ( )
405436 . map ( TryInto :: try_into)
@@ -537,3 +568,12 @@ impl SnippetRetriever {
537568 Ok ( ( ) )
538569 }
539570}
571+
/// Selects which part of a snippet `SnippetRetriever::build_query` keeps
/// when the snippet has to be truncated to fit the model's input size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum BuildFrom {
    /// Keep the text on both sides of the cursor.
    Cursor {
        /// Offset of the cursor inside the snippet; it is passed to
        /// `str::split_at`, so it is interpreted as a byte offset.
        cursor_position: usize,
    },
    /// Truncate from the left.
    End,
    /// Truncate from the right.
    // Presumably no caller constructs this variant yet, hence the lint
    // suppression; remove it once the variant is used.
    #[allow(dead_code)]
    Start,
}
0 commit comments