Skip to content

Commit 7848c7b

Browse files
adds TextQuery class
1 parent 4e37406 commit 7848c7b

File tree

3 files changed

+134
-150
lines changed

3 files changed

+134
-150
lines changed

redisvl/query/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
BaseQuery,
33
CountQuery,
44
FilterQuery,
5+
HybridQuery,
56
RangeQuery,
7+
TextQuery,
68
VectorQuery,
79
VectorRangeQuery,
810
)
@@ -14,4 +16,6 @@
1416
"RangeQuery",
1517
"VectorRangeQuery",
1618
"CountQuery",
19+
"TextQuery",
20+
"HybridQuery",
1721
]

redisvl/query/query.py

Lines changed: 59 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict, List, Optional, Union
22

3+
from redis.commands.search.aggregation import AggregateRequest, Desc
34
from redis.commands.search.query import Query as RedisQuery
45

56
from redisvl.query.filter import FilterExpression
@@ -136,7 +137,7 @@ def __init__(
136137
"""A query for a simple count operation provided some filter expression.
137138
138139
Args:
139-
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
140+
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
140141
query with. Defaults to None.
141142
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
142143
@@ -406,31 +407,32 @@ class RangeQuery(VectorRangeQuery):
406407

407408
class TextQuery(FilterQuery):
408409
def __init__(
409-
self,
410+
self,
410411
text: str,
411412
text_field: str,
412-
text_scorer: str = "TFIDF",
413-
return_fields: Optional[List[str]] = None,
413+
text_scorer: str = "BM25",
414414
filter_expression: Optional[Union[str, FilterExpression]] = None,
415+
return_fields: Optional[List[str]] = None,
415416
num_results: int = 10,
416417
return_score: bool = True,
417418
dialect: int = 2,
418419
sort_by: Optional[str] = None,
419420
in_order: bool = False,
421+
params: Optional[Dict[str, Any]] = None,
420422
):
421423
"""A query for running a full text and vector search, along with an optional
422424
filter expression.
423425
424426
Args:
425-
text (str): The text string to perform the text search with.
427+
text (str): The text string to perform the text search with.
426428
text_field (str): The name of the document field to perform text search on.
427429
text_scorer (str, optional): The text scoring algorithm to use.
428-
Defaults to TFIDF. Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
430+
Defaults to BM25. Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
429431
See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
432+
filter_expression (Union[str, FilterExpression], optional): A filter to apply
433+
along with the text search. Defaults to None.
430434
return_fields (List[str]): The declared fields to return with search
431435
results.
432-
filter_expression (Union[str, FilterExpression], optional): A filter to apply
433-
along with the vector search. Defaults to None.
434436
num_results (int, optional): The top k results to return from the
435437
search. Defaults to 10.
436438
return_score (bool, optional): Whether to return the text score.
@@ -442,174 +444,82 @@ def __init__(
442444
in_order (bool): Requires the terms in the field to have
443445
the same order as the terms in the query filter, regardless of
444446
the offsets between them. Defaults to False.
445-
446-
Raises:
447-
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
447+
params (Optional[Dict[str, Any]], optional): The parameters for the query.
448+
Defaults to None.
448449
"""
450+
import nltk
451+
from nltk.corpus import stopwords
452+
453+
nltk.download("stopwords")
454+
self._stopwords = set(stopwords.words("english"))
455+
456+
self._text = text
449457
self._text_field = text_field
450-
self._num_results = num_results
458+
self._text_scorer = text_scorer
459+
451460
self.set_filter(filter_expression)
452-
query_string = self._build_query_string()
453-
from nltk.corpus import stopwords
454-
import nltk
461+
self._num_results = num_results
455462

456-
nltk.download('stopwords')
457-
self._stopwords = set(stopwords.words('english'))
463+
query_string = self._build_query_string()
458464

459-
super().__init__(query_string)
465+
super().__init__(
466+
query_string,
467+
return_fields=return_fields,
468+
num_results=num_results,
469+
dialect=dialect,
470+
sort_by=sort_by,
471+
in_order=in_order,
472+
params=params,
473+
)
460474

461475
# Handle query modifiers
462-
if return_fields:
463-
self.return_fields(*return_fields)
464-
476+
self.scorer(self._text_scorer)
465477
self.paging(0, self._num_results).dialect(dialect)
466478

467479
if return_score:
468-
self.return_fields(self.DISTANCE_ID) #TODO
469-
470-
if sort_by:
471-
self.sort_by(sort_by)
472-
else:
473-
self.sort_by(self.DISTANCE_ID) #TODO
474-
475-
if in_order:
476-
self.in_order()
480+
self.with_scores()
477481

478-
479-
def _tokenize_query(self, user_query: str) -> str:
482+
def tokenize_and_escape_query(self, user_query: str) -> str:
480483
"""Convert a raw user query to a redis full text query joined by ORs"""
484+
from redisvl.utils.token_escaper import TokenEscaper
481485

482-
words = word_tokenize(user_query)
483-
484-
tokens = [token.strip().strip(",").lower() for token in user_query.split()]
485-
return " | ".join([token for token in tokens if token not in self._stopwords])
486+
escaper = TokenEscaper()
486487

488+
tokens = [
489+
escaper.escape(
490+
token.strip().strip(",").replace("“", "").replace("”", "").lower()
491+
)
492+
for token in user_query.split()
493+
]
494+
return " | ".join(
495+
[token for token in tokens if token and token not in self._stopwords]
496+
)
487497

488498
def _build_query_string(self) -> str:
489499
"""Build the full query string for text search with optional filtering."""
490500
filter_expression = self._filter_expression
491-
# TODO include text only
492501
if isinstance(filter_expression, FilterExpression):
493502
filter_expression = str(filter_expression)
503+
else:
504+
filter_expression = ""
494505

495-
text = f"(~{Text(self._text_field) % self._tokenize_query(user_query)})"
496-
497-
text_and_filter = text & self._filter_expression
498-
499-
#TODO is this method even needed? use
500-
return text_and_filter
506+
text = f"(~@{self._text_field}:({self.tokenize_and_escape_query(self._text)}))"
507+
if filter_expression and filter_expression != "*":
508+
text += f"({filter_expression})"
509+
return text
501510

502-
# from redisvl.utils.token_escaper import TokenEscaper
503-
# escaper = TokenEscaper()
504-
# def tokenize_and_escape_query(user_query: str) -> str:
505-
# """Convert a raw user query to a redis full text query joined by ORs"""
506-
# tokens = [escaper.escape(token.strip().strip(",").replace("“", "").replace("”", "").lower()) for token in user_query.split()]
507-
# return " | ".join([token for token in tokens if token and token not in stopwords_en])
508511

509-
class HybridQuery(VectorQuery, TextQuery):
510-
def __init__():
511-
self,
512-
text: str,
513-
text_field: str,
514-
vector: Union[List[float], bytes],
515-
vector_field_name: str,
516-
text_scorer: str = "TFIDF",
517-
alpha: float = 0.7,
518-
return_fields: Optional[List[str]] = None,
519-
filter_expression: Optional[Union[str, FilterExpression]] = None,
520-
dtype: str = "float32",
521-
num_results: int = 10,
522-
return_score: bool = True,
523-
dialect: int = 2,
524-
sort_by: Optional[str] = None,
525-
in_order: bool = False,
512+
class HybridQuery(AggregateRequest):
513+
def __init__(
514+
self, text_query: TextQuery, vector_query: VectorQuery, alpha: float = 0.7
526515
):
527-
"""A query for running a hybrid full text and vector search, along with
528-
an optional filter expression.
516+
"""An aggregate query for running a hybrid full text and vector search.
529517
530518
Args:
531-
text (str): The text string to run text search with.
532-
text_field (str): The name of the text field to search against.
533-
vector (List[float]): The vector to perform the vector search with.
534-
vector_field_name (str): The name of the vector field to search
535-
against in the database.
536-
text_scorer (str, optional): The text scoring algorithm to use.
537-
Defaults to TFIDF.
519+
text_query (TextQuery): The text query to run text search with.
520+
vector_query (VectorQuery): The vector query to run vector search with.
538521
alpha (float, optional): The amount to weight the vector similarity
539522
score relative to the text similarity score. Defaults to 0.7
540-
return_fields (List[str]): The declared fields to return with search
541-
results.
542-
filter_expression (Union[str, FilterExpression], optional): A filter to apply
543-
along with the vector search. Defaults to None.
544-
dtype (str, optional): The dtype of the vector. Defaults to
545-
"float32".
546-
num_results (int, optional): The top k results to return from the
547-
vector search. Defaults to 10.
548-
return_score (bool, optional): Whether to return the vector
549-
distance. Defaults to True.
550-
dialect (int, optional): The RediSearch query dialect.
551-
Defaults to 2.
552-
sort_by (Optional[str]): The field to order the results by. Defaults
553-
to None. Results will be ordered by vector distance.
554-
in_order (bool): Requires the terms in the field to have
555-
the same order as the terms in the query filter, regardless of
556-
the offsets between them. Defaults to False.
557-
558-
Raises:
559-
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
560-
561-
Note:
562-
Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search
563-
"""
564-
self._text = text
565-
self._text_field_name = tex_field_name
566-
self._vector = vector
567-
self._vector_field_name = vector_field_name
568-
self._dtype = dtype
569-
self._num_results = num_results
570-
self.set_filter(filter_expression)
571-
query_string = self._build_query_string()
572-
573-
# TODO how to handle multiple parents? call parent.__init__() manually?
574-
super().__init__(query_string)
575-
576-
# Handle query modifiers
577-
if return_fields:
578-
self.return_fields(*return_fields)
579-
580-
self.paging(0, self._num_results).dialect(dialect)
581523
582-
if return_score:
583-
self.return_fields(self.DISTANCE_ID)
584-
585-
if sort_by:
586-
self.sort_by(sort_by)
587-
else:
588-
self.sort_by(self.DISTANCE_ID)
589-
590-
if in_order:
591-
self.in_order()
592-
593-
594-
def _build_query_string(self) -> str:
595-
"""Build the full query string for hybrid search with optional filtering."""
596-
filter_expression = self._filter_expression
597-
# TODO include hybrid
598-
if isinstance(filter_expression, FilterExpression):
599-
filter_expression = str(filter_expression)
600-
return f"{filter_expression}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]"
601-
602-
@property
603-
def params(self) -> Dict[str, Any]:
604-
"""Return the parameters for the query.
605-
606-
Returns:
607-
Dict[str, Any]: The parameters for the query.
608524
"""
609-
if isinstance(self._vector, bytes):
610-
vector = self._vector
611-
else:
612-
vector = array_to_buffer(self._vector, dtype=self._dtype)
613-
614-
return {self.VECTOR_PARAM: vector}
615-
525+
pass

tests/unit/test_query_types.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33
from redis.commands.search.result import Result
44

55
from redisvl.index.index import process_results
6-
from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery
6+
from redisvl.query import (
7+
CountQuery,
8+
FilterQuery,
9+
HybridQuery,
10+
RangeQuery,
11+
TextQuery,
12+
VectorQuery,
13+
)
714
from redisvl.query.filter import Tag
815

916
# Sample data for testing
@@ -187,6 +194,69 @@ def test_range_query():
187194
assert range_query._in_order
188195

189196

197+
def test_text_query():
198+
text_string = "the toon squad play basketball against a gang of aliens"
199+
text_field_name = "description"
200+
return_fields = ["title", "genre", "rating"]
201+
text_query = TextQuery(
202+
text=text_string,
203+
text_field=text_field_name,
204+
return_fields=return_fields,
205+
return_score=False,
206+
)
207+
208+
# Check properties
209+
assert text_query._return_fields == return_fields
210+
assert text_query._num_results == 10
211+
assert (
212+
text_query.filter
213+
== f"(~@{text_field_name}:({text_query.tokenize_and_escape_query(text_string)}))"
214+
)
215+
assert isinstance(text_query, Query)
216+
assert isinstance(text_query.query, Query)
217+
assert isinstance(text_query.params, dict)
218+
assert text_query._text_scorer == "BM25"
219+
assert text_query.params == {}
220+
assert text_query._dialect == 2
221+
assert text_query._in_order == False
222+
223+
# Test paging functionality
224+
text_query.paging(5, 7)
225+
assert text_query._offset == 5
226+
assert text_query._num == 7
227+
assert text_query._num_results == 10
228+
229+
# Test sort_by functionality
230+
filter_expression = Tag("genre") == "comedy"
231+
scorer = "TFIDF"
232+
text_query = TextQuery(
233+
text_string,
234+
text_field_name,
235+
scorer,
236+
filter_expression,
237+
return_fields,
238+
num_results=10,
239+
sort_by="rating",
240+
)
241+
assert text_query._sortby is not None
242+
243+
# Test in_order functionality
244+
text_query = TextQuery(
245+
text_string,
246+
text_field_name,
247+
scorer,
248+
filter_expression,
249+
return_fields,
250+
num_results=10,
251+
in_order=True,
252+
)
253+
assert text_query._in_order
254+
255+
256+
def test_hybrid_query():
257+
pass
258+
259+
190260
@pytest.mark.parametrize(
191261
"query",
192262
[

0 commit comments

Comments
 (0)