11from typing import Any , Dict , List , Optional , Union
22
3+ from redis .commands .search .aggregation import AggregateRequest , Desc
34from redis .commands .search .query import Query as RedisQuery
45
56from redisvl .query .filter import FilterExpression
@@ -136,7 +137,7 @@ def __init__(
136137 """A query for a simple count operation provided some filter expression.
137138
138139 Args:
139- filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
140+ filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
140141 query with. Defaults to None.
141142 params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
142143
@@ -406,31 +407,32 @@ class RangeQuery(VectorRangeQuery):
406407
407408class TextQuery (FilterQuery ):
408409 def __init__ (
409- self ,
410+ self ,
410411 text : str ,
411412 text_field : str ,
412- text_scorer : str = "TFIDF" ,
413- return_fields : Optional [List [str ]] = None ,
413+ text_scorer : str = "BM25" ,
414414 filter_expression : Optional [Union [str , FilterExpression ]] = None ,
415+ return_fields : Optional [List [str ]] = None ,
415416 num_results : int = 10 ,
416417 return_score : bool = True ,
417418 dialect : int = 2 ,
418419 sort_by : Optional [str ] = None ,
419420 in_order : bool = False ,
421+ params : Optional [Dict [str , Any ]] = None ,
420422 ):
421423 """A query for running a full text and vector search, along with an optional
422424 filter expression.
423425
424426 Args:
425- text (str): The text string to perform the text search with.
427+ text (str): The text string to perform the text search with.
426428 text_field (str): The name of the document field to perform text search on.
427429 text_scorer (str, optional): The text scoring algorithm to use.
428- Defaults to TFIDF . Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
430+ Defaults to BM25 . Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
429431 See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
432+ filter_expression (Union[str, FilterExpression], optional): A filter to apply
433+ along with the text search. Defaults to None.
430434 return_fields (List[str]): The declared fields to return with search
431435 results.
432- filter_expression (Union[str, FilterExpression], optional): A filter to apply
433- along with the vector search. Defaults to None.
434436 num_results (int, optional): The top k results to return from the
435437 search. Defaults to 10.
436438 return_score (bool, optional): Whether to return the text score.
@@ -442,174 +444,82 @@ def __init__(
442444 in_order (bool): Requires the terms in the field to have
443445 the same order as the terms in the query filter, regardless of
444446 the offsets between them. Defaults to False.
445-
446- Raises:
447- TypeError: If filter_expression is not of type redisvl.query.FilterExpression
447+ params (Optional[Dict[str, Any]], optional): The parameters for the query.
448+ Defaults to None.
448449 """
450+ import nltk
451+ from nltk .corpus import stopwords
452+
453+ nltk .download ("stopwords" )
454+ self ._stopwords = set (stopwords .words ("english" ))
455+
456+ self ._text = text
449457 self ._text_field = text_field
450- self ._num_results = num_results
458+ self ._text_scorer = text_scorer
459+
451460 self .set_filter (filter_expression )
452- query_string = self ._build_query_string ()
453- from nltk .corpus import stopwords
454- import nltk
461+ self ._num_results = num_results
455462
456- nltk .download ('stopwords' )
457- self ._stopwords = set (stopwords .words ('english' ))
463+ query_string = self ._build_query_string ()
458464
459- super ().__init__ (query_string )
465+ super ().__init__ (
466+ query_string ,
467+ return_fields = return_fields ,
468+ num_results = num_results ,
469+ dialect = dialect ,
470+ sort_by = sort_by ,
471+ in_order = in_order ,
472+ params = params ,
473+ )
460474
461475 # Handle query modifiers
462- if return_fields :
463- self .return_fields (* return_fields )
464-
476+ self .scorer (self ._text_scorer )
465477 self .paging (0 , self ._num_results ).dialect (dialect )
466478
467479 if return_score :
468- self .return_fields (self .DISTANCE_ID ) #TODO
469-
470- if sort_by :
471- self .sort_by (sort_by )
472- else :
473- self .sort_by (self .DISTANCE_ID ) #TODO
474-
475- if in_order :
476- self .in_order ()
480+ self .with_scores ()
477481
478-
479- def _tokenize_query (self , user_query : str ) -> str :
482+ def tokenize_and_escape_query (self , user_query : str ) -> str :
480483 """Convert a raw user query to a redis full text query joined by ORs"""
484+ from redisvl .utils .token_escaper import TokenEscaper
481485
482- words = word_tokenize (user_query )
483-
484- tokens = [token .strip ().strip ("," ).lower () for token in user_query .split ()]
485- return " | " .join ([token for token in tokens if token not in self ._stopwords ])
486+ escaper = TokenEscaper ()
486487
488+ tokens = [
489+ escaper .escape (
490+ token .strip ().strip ("," ).replace ("“" , "" ).replace ("”" , "" ).lower ()
491+ )
492+ for token in user_query .split ()
493+ ]
494+ return " | " .join (
495+ [token for token in tokens if token and token not in self ._stopwords ]
496+ )
487497
488498 def _build_query_string (self ) -> str :
489499 """Build the full query string for text search with optional filtering."""
490500 filter_expression = self ._filter_expression
491- # TODO include text only
492501 if isinstance (filter_expression , FilterExpression ):
493502 filter_expression = str (filter_expression )
503+ else :
504+ filter_expression = ""
494505
495- text = f"(~{ Text (self ._text_field ) % self ._tokenize_query (user_query )} )"
496-
497- text_and_filter = text & self ._filter_expression
498-
499- #TODO is this method even needed? use
500- return text_and_filter
506+ text = f"(~@{ self ._text_field } :({ self .tokenize_and_escape_query (self ._text )} ))"
507+ if filter_expression and filter_expression != "*" :
508+ text += f"({ filter_expression } )"
509+ return text
501510
502- # from redisvl.utils.token_escaper import TokenEscaper
503- # escaper = TokenEscaper()
504- # def tokenize_and_escape_query(user_query: str) -> str:
505- # """Convert a raw user query to a redis full text query joined by ORs"""
506- # tokens = [escaper.escape(token.strip().strip(",").replace("“", "").replace("”", "").lower()) for token in user_query.split()]
507- # return " | ".join([token for token in tokens if token and token not in stopwords_en])
508511
509- class HybridQuery (VectorQuery , TextQuery ):
510- def __init__ ():
511- self ,
512- text : str ,
513- text_field : str ,
514- vector : Union [List [float ], bytes ],
515- vector_field_name : str ,
516- text_scorer : str = "TFIDF" ,
517- alpha : float = 0.7 ,
518- return_fields : Optional [List [str ]] = None ,
519- filter_expression : Optional [Union [str , FilterExpression ]] = None ,
520- dtype : str = "float32" ,
521- num_results : int = 10 ,
522- return_score : bool = True ,
523- dialect : int = 2 ,
524- sort_by : Optional [str ] = None ,
525- in_order : bool = False ,
512+ class HybridQuery (AggregateRequest ):
513+ def __init__ (
514+ self , text_query : TextQuery , vector_query : VectorQuery , alpha : float = 0.7
526515 ):
527- """A query for running a hybrid full text and vector search, along with
528- an optional filter expression.
516+ """An aggregate query for running a hybrid full text and vector search.
529517
530518 Args:
531- text (str): The text string to run text search with.
532- text_field (str): The name of the text field to search against.
533- vector (List[float]): The vector to perform the vector search with.
534- vector_field_name (str): The name of the vector field to search
535- against in the database.
536- text_scorer (str, optional): The text scoring algorithm to use.
537- Defaults to TFIDF.
519+ text_query (TextQuery): The text query to run text search with.
520+ vector_query (VectorQuery): The vector query to run vector search with.
538521 alpha (float, optional): The amount to weight the vector similarity
539522 score relative to the text similarity score. Defaults to 0.7
540- return_fields (List[str]): The declared fields to return with search
541- results.
542- filter_expression (Union[str, FilterExpression], optional): A filter to apply
543- along with the vector search. Defaults to None.
544- dtype (str, optional): The dtype of the vector. Defaults to
545- "float32".
546- num_results (int, optional): The top k results to return from the
547- vector search. Defaults to 10.
548- return_score (bool, optional): Whether to return the vector
549- distance. Defaults to True.
550- dialect (int, optional): The RediSearch query dialect.
551- Defaults to 2.
552- sort_by (Optional[str]): The field to order the results by. Defaults
553- to None. Results will be ordered by vector distance.
554- in_order (bool): Requires the terms in the field to have
555- the same order as the terms in the query filter, regardless of
556- the offsets between them. Defaults to False.
557-
558- Raises:
559- TypeError: If filter_expression is not of type redisvl.query.FilterExpression
560-
561- Note:
562- Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search
563- """
564- self ._text = text
565- self ._text_field_name = tex_field_name
566- self ._vector = vector
567- self ._vector_field_name = vector_field_name
568- self ._dtype = dtype
569- self ._num_results = num_results
570- self .set_filter (filter_expression )
571- query_string = self ._build_query_string ()
572-
573- # TODO how to handle multiple parents? call parent.__init__() manually?
574- super ().__init__ (query_string )
575-
576- # Handle query modifiers
577- if return_fields :
578- self .return_fields (* return_fields )
579-
580- self .paging (0 , self ._num_results ).dialect (dialect )
581523
582- if return_score :
583- self .return_fields (self .DISTANCE_ID )
584-
585- if sort_by :
586- self .sort_by (sort_by )
587- else :
588- self .sort_by (self .DISTANCE_ID )
589-
590- if in_order :
591- self .in_order ()
592-
593-
594- def _build_query_string (self ) -> str :
595- """Build the full query string for hybrid search with optional filtering."""
596- filter_expression = self ._filter_expression
597- # TODO include hybrid
598- if isinstance (filter_expression , FilterExpression ):
599- filter_expression = str (filter_expression )
600- return f"{ filter_expression } =>[KNN { self ._num_results } @{ self ._vector_field_name } ${ self .VECTOR_PARAM } AS { self .DISTANCE_ID } ]"
601-
602- @property
603- def params (self ) -> Dict [str , Any ]:
604- """Return the parameters for the query.
605-
606- Returns:
607- Dict[str, Any]: The parameters for the query.
608524 """
609- if isinstance (self ._vector , bytes ):
610- vector = self ._vector
611- else :
612- vector = array_to_buffer (self ._vector , dtype = self ._dtype )
613-
614- return {self .VECTOR_PARAM : vector }
615-
525+ pass
0 commit comments