Skip to content

Commit f7a4b9e

Browse files
adds hybrid aggregation query and tests. modifies search index to accept aggregate queries
1 parent 49b2aba commit f7a4b9e

File tree

8 files changed

+617
-39
lines changed

8 files changed

+617
-39
lines changed

redisvl/index/index.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Union,
1717
)
1818

19+
from redisvl.redis.utils import convert_bytes, make_dict
1920
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper
2021

2122
if TYPE_CHECKING:
@@ -30,7 +31,13 @@
3031

3132
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
3233
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
33-
from redisvl.query import BaseQuery, CountQuery, FilterQuery
34+
from redisvl.query import (
35+
AggregationQuery,
36+
BaseQuery,
37+
CountQuery,
38+
FilterQuery,
39+
HybridAggregationQuery,
40+
)
3441
from redisvl.query.filter import FilterExpression
3542
from redisvl.redis.connection import (
3643
RedisConnectionFactory,
@@ -101,6 +108,34 @@ def _process(doc: "Document") -> Dict[str, Any]:
101108
return [_process(doc) for doc in results.docs]
102109

103110

111+
def process_aggregate_results(
    results: "AggregateResult", query: AggregationQuery, storage_type: StorageType
) -> List[Dict[str, Any]]:
    """Convert an aggregate result object into a list of document dictionaries.

    Each row returned by Redis is decoded from bytes and converted into a
    dictionary, and the internal ``__score`` field (added by the text scorer)
    is stripped from every resulting document for consistency.

    Args:
        results (AggregateResult): The aggregate results from Redis.
        query (AggregationQuery): The aggregation query object used for the
            aggregation. Currently unused; kept for interface symmetry with
            ``process_results``.
        storage_type (StorageType): The storage type of the search index
            (json or hash). Currently unused.

    Returns:
        List[Dict[str, Any]]: A list of processed document dictionaries.
    """
    processed: List[Dict[str, Any]] = []
    for row in results.rows:
        # Rows arrive as flat field/value sequences, possibly as bytes.
        doc = make_dict(convert_bytes(row))
        # Drop the raw text score; callers see only the combined fields.
        doc.pop("__score", None)
        processed.append(doc)
    return processed
137+
138+
104139
class BaseSearchIndex:
105140
"""Base search engine class"""
106141

@@ -628,6 +663,44 @@ def fetch(self, id: str) -> Optional[Dict[str, Any]]:
628663
return convert_bytes(obj[0])
629664
return None
630665

666+
def aggregate_query(
667+
self, aggregation_query: AggregationQuery
668+
) -> List[Dict[str, Any]]:
669+
"""Execute an aggretation query and processes the results.
670+
671+
This method takes an AggregationHyridQuery object directly, runs the search, and
672+
handles post-processing of the search.
673+
674+
Args:
675+
aggregation_query (AggregationQuery): The aggregation query to run.
676+
677+
Returns:
678+
List[Result]: A list of search results.
679+
680+
.. code-block:: python
681+
682+
from redisvl.query import HybridAggregationQuery
683+
684+
aggregation = HybridAggregationQuery(
685+
text="the text to search for",
686+
text_field="description",
687+
vector=[0.16, -0.34, 0.98, 0.23],
688+
vector_field="embedding",
689+
num_results=3
690+
)
691+
692+
results = index.aggregate_query(aggregation_query)
693+
694+
"""
695+
results = self.aggregate(
696+
aggregation_query, query_params=aggregation_query.params
697+
)
698+
return process_aggregate_results(
699+
results,
700+
query=aggregation_query,
701+
storage_type=self.schema.index.storage_type,
702+
)
703+
631704
def aggregate(self, *args, **kwargs) -> "AggregateResult":
632705
"""Perform an aggregation operation against the index.
633706

redisvl/query/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
from redisvl.query.aggregate import AggregationQuery, HybridAggregationQuery
12
from redisvl.query.query import (
23
BaseQuery,
34
CountQuery,
45
FilterQuery,
5-
HybridQuery,
66
RangeQuery,
77
TextQuery,
88
VectorQuery,
@@ -17,5 +17,6 @@
1717
"VectorRangeQuery",
1818
"CountQuery",
1919
"TextQuery",
20-
"HybridQuery",
20+
"AggregationQuery",
21+
"HybridAggregationQuery",
2122
]

redisvl/query/aggregate.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
2+
3+
import nltk
4+
from nltk.corpus import stopwords as nltk_stopwords
5+
from redis.commands.search.aggregation import AggregateRequest, Desc
6+
7+
from redisvl.query.filter import FilterExpression
8+
from redisvl.redis.utils import array_to_buffer
9+
from redisvl.utils.token_escaper import TokenEscaper
10+
11+
12+
# base class
13+
class AggregationQuery(AggregateRequest):
14+
"""
15+
Base class for aggregation queries used to create aggregation queries for Redis.
16+
"""
17+
18+
def __init__(self, query_string):
19+
super().__init__(query_string)
20+
21+
22+
class HybridAggregationQuery(AggregationQuery):
    """
    HybridAggregationQuery combines text and vector search in Redis.
    It allows you to perform a hybrid search using both text and vector similarity.
    It scores documents based on a weighted combination of text and vector similarity.
    """

    # Alias under which the KNN distance is returned by Redis.
    DISTANCE_ID: str = "vector_distance"
    # Name of the query parameter holding the vector bytes.
    VECTOR_PARAM: str = "vector"

    def __init__(
        self,
        text: str,
        text_field: str,
        vector: Union[bytes, List[float]],
        vector_field: str,
        text_scorer: str = "BM25STD",
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        alpha: float = 0.7,
        dtype: str = "float32",
        num_results: int = 10,
        return_fields: Optional[List[str]] = None,
        stopwords: Optional[Union[str, Set[str]]] = "english",
        dialect: int = 4,
    ):
        """
        Instantiates a HybridAggregationQuery object.

        Args:
            text (str): The text to search for.
            text_field (str): The text field name to search in.
            vector (Union[bytes, List[float]]): The vector to perform vector similarity search.
            vector_field (str): The vector field name to search in.
            text_scorer (str, optional): The text scorer to use. Options are {TFIDF, TFIDF.DOCNORM,
                BM25, DISMAX, DOCSCORE, BM25STD}. Defaults to "BM25STD".
            filter_expression (Optional[FilterExpression], optional): The filter expression to use.
                Defaults to None.
            alpha (float, optional): The weight of the vector similarity. Documents will be scored
                as: hybrid_score = (alpha) * vector_score + (1-alpha) * text_score.
                Defaults to 0.7.
            dtype (str, optional): The data type of the vector. Defaults to "float32".
            num_results (int, optional): The number of results to return. Defaults to 10.
            return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
            stopwords (Optional[Union[str, Set[str]]], optional): The stopwords to remove from the
                provided text prior to search use. If a string such as "english" or "german" is
                provided then a default set of stopwords for that language will be used. If a list,
                set, or tuple of strings is provided then those will be used as stopwords.
                Defaults to "english". If set to None then no stopwords will be removed.
            dialect (int, optional): The Redis dialect version. Defaults to 4.

        Raises:
            ValueError: If the text string is empty, or if the text string becomes empty after
                stopwords are removed.
            TypeError: If the stopwords are not a set, list, or tuple of strings.

        .. code-block:: python

            from redisvl.query.aggregate import HybridAggregationQuery
            from redisvl.index import SearchIndex

            index = SearchIndex("my_index")

            query = HybridAggregationQuery(
                text="example text",
                text_field="text_field",
                vector=[0.1, 0.2, 0.3],
                vector_field="vector_field",
                text_scorer="BM25STD",
                filter_expression=None,
                alpha=0.7,
                dtype="float32",
                num_results=10,
                return_fields=["field1", "field2"],
                stopwords="english",
                dialect=4,
            )

            results = index.aggregate_query(query)
        """

        if not text.strip():
            raise ValueError("text string cannot be empty")

        self._text = text
        self._text_field = text_field
        self._vector = vector
        self._vector_field = vector_field
        self._filter_expression = filter_expression
        self._alpha = alpha
        self._dtype = dtype
        self._num_results = num_results
        # Must run before _build_query_string(), which tokenizes self._text
        # against self._stopwords.
        self.set_stopwords(stopwords)

        query_string = self._build_query_string()
        super().__init__(query_string)

        self.scorer(text_scorer)
        # ADDSCORES exposes the text score as @__score for the APPLY below.
        self.add_scores()
        # Converts the KNN distance into a similarity; assumes the distance
        # lies in [0, 2] (e.g. cosine distance) -- TODO(review): confirm for
        # other distance metrics.
        self.apply(
            vector_similarity=f"(2 - @{self.DISTANCE_ID})/2", text_score="@__score"
        )
        # Weighted blend of the two scores; alpha is baked into the expression
        # at construction time.
        self.apply(hybrid_score=f"{1-alpha}*@text_score + {alpha}*@vector_similarity")
        self.sort_by(Desc("@hybrid_score"), max=num_results)
        self.dialect(dialect)

        if return_fields:
            self.load(*return_fields)

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the aggregation.

        Returns:
            Dict[str, Any]: The parameters for the aggregation.
        """
        # Raw bytes are passed through; numeric lists are serialized using
        # the configured dtype.
        if isinstance(self._vector, bytes):
            vector = self._vector
        else:
            vector = array_to_buffer(self._vector, dtype=self._dtype)

        params = {self.VECTOR_PARAM: vector}

        return params

    @property
    def stopwords(self) -> Set[str]:
        """Return the stopwords used in the query.

        Returns:
            Set[str]: A copy of the stopwords used in the query (empty set if
                none are configured).
        """
        return self._stopwords.copy() if self._stopwords else set()

    def set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
        """Set the stopwords to use in the query.

        Args:
            stopwords (Optional[Union[str, Set[str]]]): The stopwords to use. If a string
                such as "english" or "german" is provided then a default set of stopwords for
                that language will be used. If a list, set, or tuple of strings is provided
                then those will be used as stopwords. Defaults to "english". If set to None
                then no stopwords will be removed.

        Raises:
            TypeError: If the stopwords are not a set, list, or tuple of strings.
            ValueError: If the language corpus cannot be loaded from nltk.
        """
        if not stopwords:
            self._stopwords = set()
        elif isinstance(stopwords, str):
            try:
                # NOTE(review): nltk.download is invoked on every call; it is
                # a no-op once the corpus is cached locally, but consider
                # quiet=True or a one-time check to avoid repeated output.
                nltk.download("stopwords")
                self._stopwords = set(nltk_stopwords.words(stopwords))
            except Exception as e:
                raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
        elif isinstance(stopwords, (Set, List, Tuple)) and all(  # type: ignore
            isinstance(word, str) for word in stopwords
        ):
            self._stopwords = set(stopwords)
        else:
            raise TypeError("stopwords must be a set, list, or tuple of strings")

    def tokenize_and_escape_query(self, user_query: str) -> str:
        """Convert a raw user query to a redis full text query joined by ORs.

        Args:
            user_query (str): The user query to tokenize and escape.

        Returns:
            str: The tokenized and escaped query string.

        Raises:
            ValueError: If the text string becomes empty after stopwords are removed.
        """

        escaper = TokenEscaper()

        # Normalize each whitespace-delimited token: trim, drop trailing
        # commas and curly quotes, lowercase, then escape Redis special chars.
        tokens = [
            escaper.escape(
                token.strip().strip(",").replace("“", "").replace("”", "").lower()
            )
            for token in user_query.split()
        ]
        # Join surviving tokens with "|" (full-text OR), skipping stopwords.
        tokenized = " | ".join(
            [token for token in tokens if token and token not in self._stopwords]
        )

        if not tokenized:
            raise ValueError("text string cannot be empty after removing stopwords")
        return tokenized

    def _build_query_string(self) -> str:
        """Build the full query string for text search with optional filtering."""
        if isinstance(self._filter_expression, FilterExpression):
            filter_expression = str(self._filter_expression)
        else:
            filter_expression = ""

        # base KNN query
        knn_query = f"KNN {self._num_results} @{self._vector_field} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}"

        # "~" makes the text clause optional, so KNN results are returned even
        # when no tokens match the text field.
        text = f"(~@{self._text_field}:({self.tokenize_and_escape_query(self._text)}))"

        # "*" (match-all) adds nothing, so only real filters are appended;
        # juxtaposition means intersection (AND) in Redis query syntax.
        if filter_expression and filter_expression != "*":
            text += f"({filter_expression})"

        return f"{text}=>[{knn_query}]"

0 commit comments

Comments
 (0)