Skip to content

Commit 3dc2b62

Browse files
wip: adding Text and Hybrid queries
1 parent 29cb397 commit 3dc2b62

File tree

1 file changed

+210
-2
lines changed

1 file changed

+210
-2
lines changed

redisvl/query/query.py

Lines changed: 210 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def __init__(
9090
num_results (Optional[int], optional): The number of results to return. Defaults to 10.
9191
dialect (int, optional): The query dialect. Defaults to 2.
9292
sort_by (Optional[str], optional): The field to order the results by. Defaults to None.
93-
in_order (bool, optional): Requires the terms in the field to have the same order as the terms in the query filter. Defaults to False.
93+
in_order (bool, optional): Requires the terms in the field to have the same order as the
94+
terms in the query filter. Defaults to False.
9495
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
9596
9697
Raises:
@@ -135,7 +136,8 @@ def __init__(
135136
"""A query for a simple count operation provided some filter expression.
136137
137138
Args:
138-
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to query with. Defaults to None.
139+
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
140+
query with. Defaults to None.
139141
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
140142
141143
Raises:
@@ -204,6 +206,7 @@ def __init__(
204206
"float32".
205207
num_results (int, optional): The top k results to return from the
206208
vector search. Defaults to 10.
209+
207210
return_score (bool, optional): Whether to return the vector
208211
distance. Defaults to True.
209212
dialect (int, optional): The RediSearch query dialect.
@@ -399,3 +402,208 @@ def params(self) -> Dict[str, Any]:
399402
class RangeQuery(VectorRangeQuery):
400403
# keep for backwards compatibility
401404
pass
405+
406+
407+
class TextQuery(FilterQuery):
408+
def __init__(
409+
self,
410+
text: str,
411+
text_field: str,
412+
text_scorer: str = "TFIDF",
413+
return_fields: Optional[List[str]] = None,
414+
filter_expression: Optional[Union[str, FilterExpression]] = None,
415+
num_results: int = 10,
416+
return_score: bool = True,
417+
dialect: int = 2,
418+
sort_by: Optional[str] = None,
419+
in_order: bool = False,
420+
):
421+
"""A query for running a full text and vector search, along with an optional
422+
filter expression.
423+
424+
Args:
425+
text (str): The text string to perform the text search with.
426+
text_field (str): The name of the document field to perform text search on.
427+
text_scorer (str, optional): The text scoring algorithm to use.
428+
Defaults to TFIDF. Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
429+
See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
430+
return_fields (List[str]): The declared fields to return with search
431+
results.
432+
filter_expression (Union[str, FilterExpression], optional): A filter to apply
433+
along with the vector search. Defaults to None.
434+
num_results (int, optional): The top k results to return from the
435+
search. Defaults to 10.
436+
return_score (bool, optional): Whether to return the text score.
437+
Defaults to True.
438+
dialect (int, optional): The RediSearch query dialect.
439+
Defaults to 2.
440+
sort_by (Optional[str]): The field to order the results by. Defaults
441+
to None. Results will be ordered by text score.
442+
in_order (bool): Requires the terms in the field to have
443+
the same order as the terms in the query filter, regardless of
444+
the offsets between them. Defaults to False.
445+
446+
Raises:
447+
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
448+
"""
449+
self._text_field = text_field
450+
self._num_results = num_results
451+
self.set_filter(filter_expression)
452+
query_string = self._build_query_string()
453+
from nltk.corpus import stopwords
454+
import nltk
455+
456+
nltk.download('stopwords')
457+
self._stopwords = set(stopwords.words('english'))
458+
459+
460+
super().__init__(query_string)
461+
462+
# Handle query modifiers
463+
if return_fields:
464+
self.return_fields(*return_fields)
465+
466+
self.paging(0, self._num_results).dialect(dialect)
467+
468+
if return_score:
469+
self.return_fields(self.DISTANCE_ID) #TODO
470+
471+
if sort_by:
472+
self.sort_by(sort_by)
473+
else:
474+
self.sort_by(self.DISTANCE_ID) #TODO
475+
476+
if in_order:
477+
self.in_order()
478+
479+
480+
def _tokenize_query(self, user_query: str) -> str:
481+
"""Convert a raw user query to a redis full text query joined by ORs"""
482+
words = word_tokenize(user_query)
483+
484+
tokens = [token.strip().strip(",").lower() for token in user_query.split()]
485+
return " | ".join([token for token in tokens if token not in self._stopwords])
486+
487+
488+
def _build_query_string(self) -> str:
489+
"""Build the full query string for text search with optional filtering."""
490+
filter_expression = self._filter_expression
491+
# TODO include text only
492+
if isinstance(filter_expression, FilterExpression):
493+
filter_expression = str(filter_expression)
494+
495+
text = f"(~{Text(self._text_field) % self._tokenize_query(user_query)})"
496+
497+
text_and_filter = text & self._filter_expression
498+
499+
#TODO is this method even needed? use
500+
return text_and_filter
501+
502+
503+
class HybridQuery(VectorQuery, TextQuery):
504+
def __init__():
505+
self,
506+
text: str,
507+
text_field: str,
508+
vector: Union[List[float], bytes],
509+
vector_field_name: str,
510+
text_scorer: str = "TFIDF",
511+
alpha: float = 0.7,
512+
return_fields: Optional[List[str]] = None,
513+
filter_expression: Optional[Union[str, FilterExpression]] = None,
514+
dtype: str = "float32",
515+
num_results: int = 10,
516+
return_score: bool = True,
517+
dialect: int = 2,
518+
sort_by: Optional[str] = None,
519+
in_order: bool = False,
520+
):
521+
"""A query for running a hybrid full text and vector search, along with
522+
an optional filter expression.
523+
524+
Args:
525+
text (str): The text string to run text search with.
526+
text_field (str): The name of the text field to search against.
527+
vector (List[float]): The vector to perform the vector search with.
528+
vector_field_name (str): The name of the vector field to search
529+
against in the database.
530+
text_scorer (str, optional): The text scoring algorithm to use.
531+
Defaults to TFIDF.
532+
alpha (float, optional): The amount to weight the vector similarity
533+
score relative to the text similarity score. Defaults to 0.7
534+
return_fields (List[str]): The declared fields to return with search
535+
results.
536+
filter_expression (Union[str, FilterExpression], optional): A filter to apply
537+
along with the vector search. Defaults to None.
538+
dtype (str, optional): The dtype of the vector. Defaults to
539+
"float32".
540+
num_results (int, optional): The top k results to return from the
541+
vector search. Defaults to 10.
542+
return_score (bool, optional): Whether to return the vector
543+
distance. Defaults to True.
544+
dialect (int, optional): The RediSearch query dialect.
545+
Defaults to 2.
546+
sort_by (Optional[str]): The field to order the results by. Defaults
547+
to None. Results will be ordered by vector distance.
548+
in_order (bool): Requires the terms in the field to have
549+
the same order as the terms in the query filter, regardless of
550+
the offsets between them. Defaults to False.
551+
552+
Raises:
553+
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
554+
555+
Note:
556+
Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search
557+
"""
558+
self._text = text
559+
self._text_field_name = tex_field_name
560+
self._vector = vector
561+
self._vector_field_name = vector_field_name
562+
self._dtype = dtype
563+
self._num_results = num_results
564+
self.set_filter(filter_expression)
565+
query_string = self._build_query_string()
566+
567+
# TODO how to handle multiple parents? call parent.__init__() manually?
568+
super().__init__(query_string)
569+
570+
# Handle query modifiers
571+
if return_fields:
572+
self.return_fields(*return_fields)
573+
574+
self.paging(0, self._num_results).dialect(dialect)
575+
576+
if return_score:
577+
self.return_fields(self.DISTANCE_ID)
578+
579+
if sort_by:
580+
self.sort_by(sort_by)
581+
else:
582+
self.sort_by(self.DISTANCE_ID)
583+
584+
if in_order:
585+
self.in_order()
586+
587+
588+
def _build_query_string(self) -> str:
589+
"""Build the full query string for hybrid search with optional filtering."""
590+
filter_expression = self._filter_expression
591+
# TODO include hybrid
592+
if isinstance(filter_expression, FilterExpression):
593+
filter_expression = str(filter_expression)
594+
return f"{filter_expression}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]"
595+
596+
@property
597+
def params(self) -> Dict[str, Any]:
598+
"""Return the parameters for the query.
599+
600+
Returns:
601+
Dict[str, Any]: The parameters for the query.
602+
"""
603+
if isinstance(self._vector, bytes):
604+
vector = self._vector
605+
else:
606+
vector = array_to_buffer(self._vector, dtype=self._dtype)
607+
608+
return {self.VECTOR_PARAM: vector}
609+

0 commit comments

Comments
 (0)