Skip to content

Commit ea5d087

Browse files
wip: debugging aggregations and filters
1 parent c5ad696 commit ea5d087

File tree

3 files changed

+118
-50
lines changed

3 files changed

+118
-50
lines changed

redisvl/query/aggregate.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,26 @@ class HybridAggregationQuery(AggregationQuery):
3232
def __init__(
3333
self,
3434
text: str,
35-
text_field: str,
35+
text_field_name: str,
3636
vector: Union[bytes, List[float]],
37-
vector_field: str,
37+
vector_field_name: str,
3838
text_scorer: str = "BM25STD",
3939
filter_expression: Optional[Union[str, FilterExpression]] = None,
4040
alpha: float = 0.7,
4141
dtype: str = "float32",
4242
num_results: int = 10,
4343
return_fields: Optional[List[str]] = None,
4444
stopwords: Optional[Union[str, Set[str]]] = "english",
45-
dialect: int = 4,
45+
dialect: int = 2,
4646
):
4747
"""
4848
Instantiates a HybridAggregationQuery object.
4949
5050
Args:
5151
text (str): The text to search for.
52-
text_field (str): The text field name to search in.
52+
text_field_name (str): The text field name to search in.
5353
vector (Union[bytes, List[float]]): The vector to perform vector similarity search.
54-
vector_field (str): The vector field name to search in.
54+
vector_field_name (str): The vector field name to search in.
5555
text_scorer (str, optional): The text scorer to use. Options are {TFIDF, TFIDF.DOCNORM,
5656
BM25, DISMAX, DOCSCORE, BM25STD}. Defaults to "BM25STD".
5757
filter_expression (Optional[FilterExpression], optional): The filter expression to use.
@@ -67,7 +67,7 @@ def __init__(
6767
provided then a default set of stopwords for that language will be used. If a list,
6868
set, or tuple of strings is provided then those will be used as stopwords.
6969
Defaults to "english". If set to "None" then no stopwords will be removed.
70-
dialect (int, optional): The Redis dialect version. Defaults to 4.
70+
dialect (int, optional): The Redis dialect version. Defaults to 2.
7171
7272
Raises:
7373
ValueError: If the text string is empty, or if the text string becomes empty after
@@ -82,9 +82,9 @@ def __init__(
8282
8383
query = HybridAggregationQuery(
8484
text="example text",
85-
text_field="text_field",
85+
text_field_name="text_field",
8686
vector=[0.1, 0.2, 0.3],
87-
vector_field="vector_field",
87+
vector_field_name="vector_field",
8888
text_scorer="BM25STD",
8989
filter_expression=None,
9090
alpha=0.7,
@@ -102,16 +102,19 @@ def __init__(
102102
raise ValueError("text string cannot be empty")
103103

104104
self._text = text
105-
self._text_field = text_field
105+
self._text_field = text_field_name
106106
self._vector = vector
107-
self._vector_field = vector_field
107+
self._vector_field = vector_field_name
108108
self._filter_expression = filter_expression
109109
self._alpha = alpha
110110
self._dtype = dtype
111111
self._num_results = num_results
112112
self.set_stopwords(stopwords)
113113

114114
query_string = self._build_query_string()
115+
####
116+
print('query_string: ', query_string)
117+
####
115118
super().__init__(query_string)
116119

117120
self.scorer(text_scorer) # type: ignore[attr-defined]

redisvl/query/query.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ class TextQuery(FilterQuery):
691691
def __init__(
692692
self,
693693
text: str,
694-
text_field: str,
694+
text_field_name: str,
695695
text_scorer: str = "BM25STD",
696696
filter_expression: Optional[Union[str, FilterExpression]] = None,
697697
return_fields: Optional[List[str]] = None,
@@ -708,7 +708,7 @@ def __init__(
708708
709709
Args:
710710
text (str): The text string to perform the text search with.
711-
text_field (str): The name of the document field to perform text search on.
711+
text_field_name (str): The name of the document field to perform text search on.
712712
text_scorer (str, optional): The text scoring algorithm to use.
713713
Defaults to BM25STD. Options are {TFIDF, BM25STD, BM25, TFIDF.DOCNORM, DISMAX, DOCSCORE}.
714714
See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
@@ -740,7 +740,7 @@ def __init__(
740740
TypeError: If stopwords is not a valid iterable set of strings.
741741
"""
742742
self._text = text
743-
self._text_field = text_field
743+
self._text_field = text_field_name
744744
self._text_scorer = text_scorer
745745
self._num_results = num_results
746746

tests/integration/test_aggregation.py

Lines changed: 102 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from datetime import timedelta
2-
31
import pytest
42
from redis.commands.search.aggregation import AggregateResult
53
from redis.commands.search.result import Result
64

5+
from redisvl.redis.connection import compare_versions
6+
77
from redisvl.index import SearchIndex
88
from redisvl.query import HybridAggregationQuery
99
from redisvl.query.filter import (
@@ -13,31 +13,12 @@
1313
Num,
1414
Tag,
1515
Text,
16-
Timestamp,
1716
)
1817
from redisvl.redis.utils import array_to_buffer
1918

20-
# TODO expand to multiple schema types and sync + async
21-
22-
vector = ([0.1, 0.1, 0.5],)
23-
vector_field_name = ("user_embedding",)
24-
return_fields = (
25-
[
26-
"user",
27-
"credit_score",
28-
"age",
29-
"job",
30-
"location",
31-
"last_updated",
32-
],
33-
)
34-
filter_expression = (Tag("credit_score") == "high",)
35-
distance_threshold = (0.2,)
36-
3719

3820
@pytest.fixture
3921
def index(sample_data, redis_url):
40-
# construct a search index from the schema
4122
index = SearchIndex.from_dict(
4223
{
4324
"index": {
@@ -70,7 +51,7 @@ def index(sample_data, redis_url):
7051
# create the index (no data yet)
7152
index.create(overwrite=True)
7253

73-
# Prepare and load the data
54+
# prepare and load the data
7455
def hash_preprocess(item: dict) -> dict:
7556
return {
7657
**item,
@@ -87,7 +68,10 @@ def hash_preprocess(item: dict) -> dict:
8768

8869

8970
def test_aggregation_query(index):
90-
# *=>[KNN 7 @user_embedding $vector AS vector_distance]
71+
redis_version = index.client.info()["redis_version"]
72+
if not compare_versions(redis_version, "7.2.0"):
73+
pytest.skip("Not using a late enough version of Redis")
74+
9175
text = "a medical professional with expertise in lung cancer"
9276
text_field = "description"
9377
vector = [0.1, 0.1, 0.5]
@@ -96,17 +80,16 @@ def test_aggregation_query(index):
9680

9781
hybrid_query = HybridAggregationQuery(
9882
text=text,
99-
text_field=text_field,
83+
text_field_name=text_field,
10084
vector=vector,
101-
vector_field=vector_field,
85+
vector_field_name=vector_field,
10286
return_fields=return_fields,
10387
)
10488

10589
results = index.aggregate_query(hybrid_query)
10690
assert isinstance(results, list)
10791
assert len(results) == 7
10892
for doc in results:
109-
# ensure all return fields present
11093
assert doc["user"] in [
11194
"john",
11295
"derrick",
@@ -121,12 +104,11 @@ def test_aggregation_query(index):
121104
assert doc["job"] in ["engineer", "doctor", "dermatologist", "CEO", "dentist"]
122105
assert doc["credit_score"] in ["high", "low", "medium"]
123106

124-
# test num_results
125107
hybrid_query = HybridAggregationQuery(
126108
text=text,
127-
text_field=text_field,
109+
text_field_name=text_field,
128110
vector=vector,
129-
vector_field=vector_field,
111+
vector_field_name=vector_field,
130112
num_results=3,
131113
)
132114

@@ -139,7 +121,7 @@ def test_aggregation_query(index):
139121
)
140122

141123

142-
def test_empty_query_string(index):
124+
def test_empty_query_string():
143125
text = ""
144126
text_field = "description"
145127
vector = [0.1, 0.1, 0.5]
@@ -149,30 +131,86 @@ def test_empty_query_string(index):
149131
# test if text is empty
150132
with pytest.raises(ValueError):
151133
hybrid_query = HybridAggregationQuery(
152-
text=text, text_field=text_field, vector=vector, vector_field=vector_field
134+
text=text, text_field_name=text_field, vector=vector, vector_field_name=vector_field
153135
)
154136

155137
# test if text becomes empty after stopwords are removed
156138
text = "with a for but and" # will all be removed as default stopwords
157139
with pytest.raises(ValueError):
158140
hybrid_query = HybridAggregationQuery(
159-
text=text, text_field=text_field, vector=vector, vector_field=vector_field
141+
text=text, text_field_name=text_field, vector=vector, vector_field_name=vector_field
160142
)
161143

144+
def test_aggregation_query_filter(index):
145+
redis_version = index.client.info()["redis_version"]
146+
if not compare_versions(redis_version, "7.2.0"):
147+
pytest.skip("Not using a late enough version of Redis")
162148

163-
def test_aggregate_query_stopwords(index):
164149
text = "a medical professional with expertise in lung cancer"
165150
text_field = "description"
166151
vector = [0.1, 0.1, 0.5]
167152
vector_field = "user_embedding"
168153
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
169-
return
170-
# test num_results
154+
filter_expression = (Tag("credit_score") == ("high")) & (Num("age") > 30)
155+
171156
hybrid_query = HybridAggregationQuery(
172157
text=text,
173-
text_field=text_field,
158+
text_field_name=text_field,
174159
vector=vector,
175-
vector_field=vector_field,
160+
vector_field_name=vector_field,
161+
filter_expression=filter_expression,
162+
return_fields=return_fields,
163+
)
164+
165+
results = index.aggregate_query(hybrid_query)
166+
assert len(results) == 3
167+
for result in results:
168+
assert result["credit_score"] == "high"
169+
assert int(result["age"]) > 30
170+
171+
172+
def test_aggregation_query_with_geo_filter(index):
173+
redis_version = index.client.info()["redis_version"]
174+
if not compare_versions(redis_version, "7.2.0"):
175+
pytest.skip("Not using a late enough version of Redis")
176+
177+
text = "a medical professional with expertise in lung cancer"
178+
text_field = "description"
179+
vector = [0.1, 0.1, 0.5]
180+
vector_field = "user_embedding"
181+
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
182+
filter_expression = Geo("location") == GeoRadius(37.7749, -122.4194, 1000)
183+
184+
hybrid_query = HybridAggregationQuery(
185+
text=text,
186+
text_field_name=text_field,
187+
vector=vector,
188+
vector_field_name=vector_field,
189+
filter_expression=filter_expression,
190+
return_fields=return_fields,
191+
)
192+
193+
results = index.aggregate_query(hybrid_query)
194+
assert len(results) == 3
195+
for result in results:
196+
assert result["location"] is not None
197+
198+
199+
def test_aggregate_query_stopwords(index):
200+
redis_version = index.client.info()["redis_version"]
201+
if not compare_versions(redis_version, "7.2.0"):
202+
pytest.skip("Not using a late enough version of Redis")
203+
204+
text = "a medical professional with expertise in lung cancer"
205+
text_field = "description"
206+
vector = [0.1, 0.1, 0.5]
207+
vector_field = "user_embedding"
208+
209+
hybrid_query = HybridAggregationQuery(
210+
text=text,
211+
text_field_name=text_field,
212+
vector=vector,
213+
vector_field_name=vector_field,
176214
alpha=0.5,
177215
stopwords=["medical", "expertise"],
178216
)
@@ -182,3 +220,30 @@ def test_aggregate_query_stopwords(index):
182220
for r in results:
183221
assert r["text_score"] == 0
184222
assert r["hybrid_score"] == 0.5 * r["vector_similarity"]
223+
224+
225+
def test_aggregate_query_text_filter(index):
226+
redis_version = index.client.info()["redis_version"]
227+
if not compare_versions(redis_version, "7.2.0"):
228+
pytest.skip("Not using a late enough version of Redis")
229+
230+
text = "a medical professional with expertise in lung cancer"
231+
text_field = "description"
232+
vector = [0.1, 0.1, 0.5]
233+
vector_field = "user_embedding"
234+
filter_expression = (Text("description") == ("medical")) | (Text("job") % ("doct*"))
235+
236+
hybrid_query = HybridAggregationQuery(
237+
text=text,
238+
text_field_name=text_field,
239+
vector=vector,
240+
vector_field_name=vector_field,
241+
alpha=0.5,
242+
filter_expression=filter_expression
243+
)
244+
245+
results = index.aggregate_query(hybrid_query)
246+
assert len(results) == 7
247+
for result in results:
248+
assert result["text_score"] == 0
249+
assert result["hybrid_score"] == 0.5 * result["vector_similarity"]

0 commit comments

Comments
 (0)