@@ -13,7 +13,7 @@ import { operationWithFallback } from "../../../helpers/operationWithFallback.js
1313import { AGG_COUNT_MAX_TIME_MS_CAP , ONE_MB , CURSOR_LIMITS_TO_LLM_TEXT } from "../../../helpers/constants.js" ;
1414import { zEJSON } from "../../args.js" ;
1515import { LogId } from "../../../common/logger.js" ;
16- import { SupportedEmbeddingModels , zSupportedEmbeddingModels } from "../../../common/search/embeddingsProvider.js" ;
16+ import { zSupportedEmbeddingModels } from "../../../common/search/embeddingsProvider.js" ;
1717
1818const AnyStage = zEJSON ( ) ;
1919const VectorSearchStage = z . object ( {
@@ -47,9 +47,11 @@ const VectorSearchStage = z.object({
4747 filter : zEJSON ( )
4848 . optional ( )
4949 . describe ( "MQL filter that can only use pre-filter fields from the index definition." ) ,
50- embeddingModel : zSupportedEmbeddingModels . describe (
51- "The embedding model to use to generate embeddings before search. Note to LLM: If unsure, ask the user before providing one."
52- ) ,
50+ embeddingModel : zSupportedEmbeddingModels
51+ . optional ( )
52+ . describe (
53+ "The embedding model to use to generate embeddings before search. Note to LLM: If unsure, ask the user before providing one."
54+ ) ,
5355 } )
5456 . passthrough ( ) ,
5557} ) ;
@@ -224,32 +226,36 @@ export class AggregateTool extends MongoDBToolBase {
224226 pipeline : Document [ ] ;
225227 } ) : Promise < Document [ ] > {
226228 for ( const stage of pipeline ) {
227- if ( stage . $vectorSearch ) {
228- if ( "queryVector" in stage . $vectorSearch && Array . isArray ( stage . $vectorSearch . queryVector ) ) {
229- // if it's already embeddings, don't do anything
229+ if ( "$vectorSearch" in stage ) {
230+ const { $vectorSearch : vectorSearchStage } = stage as z . infer < typeof VectorSearchStage > ;
231+
232+ if ( Array . isArray ( vectorSearchStage . queryVector ) ) {
230233 continue ;
231234 }
232235
233- if ( ! ( "embeddingModel" in stage . $vectorSearch ) ) {
236+ if ( ! vectorSearchStage . embeddingModel ) {
234237 throw new MongoDBError (
235238 ErrorCodes . AtlasVectorSearchInvalidQuery ,
236239 "embeddingModel is mandatory if queryVector is a raw string."
237240 ) ;
238241 }
239242
240- const model = stage . $vectorSearch . embeddingModel as SupportedEmbeddingModels ;
241- delete stage . $vectorSearch . embeddingModel ;
243+ const model = vectorSearchStage . embeddingModel ;
244+ delete vectorSearchStage . embeddingModel ;
242245
243246 const [ embeddings ] = await this . session . vectorSearchEmbeddingsManager . generateEmbeddings ( {
244247 database,
245248 collection,
246- path : stage . $vectorSearch . path ,
249+ path : vectorSearchStage . path ,
247250 model,
248- rawValues : stage . $vectorSearch . queryVector ,
251+ rawValues : [ vectorSearchStage . queryVector ] ,
249252 inputType : "query" ,
250253 } ) ;
251254
252- stage . $vectorSearch . queryVector = embeddings ;
255+ // $vectorSearch.queryVector can be a BSON.Binary: that it's not either number or an array.
256+ // It's not exactly valid from the LLM perspective. That's why we overwrite the
257+ // stage in an untyped way, as what we expose and what we can use are different.
258+ vectorSearchStage . queryVector = embeddings as number [ ] ;
253259 }
254260 }
255261
0 commit comments