Skip to content

Commit 1672ea3

Browse files
fixes query string parsing. adds more tests
1 parent ea5d087 commit 1672ea3

File tree

6 files changed

+212
-62
lines changed

6 files changed

+212
-62
lines changed

redisvl/query/aggregate.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,6 @@ def __init__(
112112
self.set_stopwords(stopwords)
113113

114114
query_string = self._build_query_string()
115-
####
116-
print('query_string: ', query_string)
117-
####
118115
super().__init__(query_string)
119116

120117
self.scorer(text_scorer) # type: ignore[attr-defined]
@@ -125,7 +122,6 @@ def __init__(
125122
self.apply(hybrid_score=f"{1-alpha}*@text_score + {alpha}*@vector_similarity")
126123
self.sort_by(Desc("@hybrid_score"), max=num_results)
127124
self.dialect(dialect) # type: ignore[attr-defined]
128-
129125
if return_fields:
130126
self.load(*return_fields)
131127

@@ -216,9 +212,9 @@ def _build_query_string(self) -> str:
216212
# base KNN query
217213
knn_query = f"KNN {self._num_results} @{self._vector_field} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}"
218214

219-
text = f"(~@{self._text_field}:({self.tokenize_and_escape_query(self._text)}))"
215+
text = f"(~@{self._text_field}:({self.tokenize_and_escape_query(self._text)})"
220216

221217
if filter_expression and filter_expression != "*":
222-
text += f"({filter_expression})"
218+
text += f" AND {filter_expression}"
223219

224-
return f"{text}=>[{knn_query}]"
220+
return f"{text})=>[{knn_query}]"

redisvl/query/query.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ class RangeQuery(VectorRangeQuery):
687687
pass
688688

689689

690-
class TextQuery(FilterQuery):
690+
class TextQuery(BaseQuery):
691691
def __init__(
692692
self,
693693
text: str,
@@ -747,22 +747,26 @@ def __init__(
747747
self.set_stopwords(stopwords)
748748
self.set_filter(filter_expression)
749749

750-
query_string = self._build_query_string()
750+
if params:
751+
self._params = params
751752

752-
super().__init__(
753-
query_string,
754-
return_fields=return_fields,
755-
num_results=num_results,
756-
dialect=dialect,
757-
sort_by=sort_by,
758-
in_order=in_order,
759-
params=params,
760-
)
753+
self._num_results = num_results
761754

762-
# Handle query modifiers
763-
self.scorer(self._text_scorer)
755+
# initialize the base query with the full query string and filter expression
756+
query_string = self._build_query_string()
757+
super().__init__(query_string)
758+
759+
# Handle query settings
760+
if return_fields:
761+
self.return_fields(*return_fields)
764762
self.paging(0, self._num_results).dialect(dialect)
765763

764+
if sort_by:
765+
self.sort_by(sort_by)
766+
767+
if in_order:
768+
self.in_order()
769+
766770
if return_score:
767771
self.with_scores()
768772

@@ -812,7 +816,7 @@ def _build_query_string(self) -> str:
812816
else:
813817
filter_expression = ""
814818

815-
text = f"(@{self._text_field}:({self.tokenize_and_escape_query(self._text)}))"
819+
text = f"@{self._text_field}:({self.tokenize_and_escape_query(self._text)})"
816820
if filter_expression and filter_expression != "*":
817-
text += f"({filter_expression})"
821+
text += f" AND {filter_expression}"
818822
return text

tests/integration/test_aggregation.py

Lines changed: 85 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,10 @@
22
from redis.commands.search.aggregation import AggregateResult
33
from redis.commands.search.result import Result
44

5-
from redisvl.redis.connection import compare_versions
6-
75
from redisvl.index import SearchIndex
86
from redisvl.query import HybridAggregationQuery
9-
from redisvl.query.filter import (
10-
FilterExpression,
11-
Geo,
12-
GeoRadius,
13-
Num,
14-
Tag,
15-
Text,
16-
)
7+
from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text
8+
from redisvl.redis.connection import compare_versions
179
from redisvl.redis.utils import array_to_buffer
1810

1911

@@ -131,17 +123,24 @@ def test_empty_query_string():
131123
# test if text is empty
132124
with pytest.raises(ValueError):
133125
hybrid_query = HybridAggregationQuery(
134-
text=text, text_field_name=text_field, vector=vector, vector_field_name=vector_field
126+
text=text,
127+
text_field_name=text_field,
128+
vector=vector,
129+
vector_field_name=vector_field,
135130
)
136131

137132
# test if text becomes empty after stopwords are removed
138133
text = "with a for but and" # will all be removed as default stopwords
139134
with pytest.raises(ValueError):
140135
hybrid_query = HybridAggregationQuery(
141-
text=text, text_field_name=text_field, vector=vector, vector_field_name=vector_field
136+
text=text,
137+
text_field_name=text_field,
138+
vector=vector,
139+
vector_field_name=vector_field,
142140
)
143141

144-
def test_aggregation_query_filter(index):
142+
143+
def test_aggregation_query_with_filter(index):
145144
redis_version = index.client.info()["redis_version"]
146145
if not compare_versions(redis_version, "7.2.0"):
147146
pytest.skip("Not using a late enough version of Redis")
@@ -163,7 +162,7 @@ def test_aggregation_query_filter(index):
163162
)
164163

165164
results = index.aggregate_query(hybrid_query)
166-
assert len(results) == 3
165+
assert len(results) == 2
167166
for result in results:
168167
assert result["credit_score"] == "high"
169168
assert int(result["age"]) > 30
@@ -179,7 +178,7 @@ def test_aggregation_query_with_geo_filter(index):
179178
vector = [0.1, 0.1, 0.5]
180179
vector_field = "user_embedding"
181180
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
182-
filter_expression = Geo("location") == GeoRadius(37.7749, -122.4194, 1000)
181+
filter_expression = Geo("location") == GeoRadius(-122.4194, 37.7749, 1000, "m")
183182

184183
hybrid_query = HybridAggregationQuery(
185184
text=text,
@@ -196,6 +195,36 @@ def test_aggregation_query_with_geo_filter(index):
196195
assert result["location"] is not None
197196

198197

198+
@pytest.mark.parametrize("alpha", [0.1, 0.5, 0.9])
199+
def test_aggregate_query_alpha(index, alpha):
200+
redis_version = index.client.info()["redis_version"]
201+
if not compare_versions(redis_version, "7.2.0"):
202+
pytest.skip("Not using a late enough version of Redis")
203+
204+
text = "a medical professional with expertise in lung cancer"
205+
text_field = "description"
206+
vector = [0.1, 0.1, 0.5]
207+
vector_field = "user_embedding"
208+
209+
hybrid_query = HybridAggregationQuery(
210+
text=text,
211+
text_field_name=text_field,
212+
vector=vector,
213+
vector_field_name=vector_field,
214+
alpha=alpha,
215+
)
216+
217+
results = index.aggregate_query(hybrid_query)
218+
assert len(results) == 7
219+
for result in results:
220+
score = alpha * float(result["vector_similarity"]) + (1 - alpha) * float(
221+
result["text_score"]
222+
)
223+
assert (
224+
float(result["hybrid_score"]) - score <= 0.0001
225+
) # allow for small floating point error
226+
227+
199228
def test_aggregate_query_stopwords(index):
200229
redis_version = index.client.info()["redis_version"]
201230
if not compare_versions(redis_version, "7.2.0"):
@@ -205,24 +234,34 @@ def test_aggregate_query_stopwords(index):
205234
text_field = "description"
206235
vector = [0.1, 0.1, 0.5]
207236
vector_field = "user_embedding"
237+
alpha = 0.5
208238

209239
hybrid_query = HybridAggregationQuery(
210240
text=text,
211241
text_field_name=text_field,
212242
vector=vector,
213243
vector_field_name=vector_field,
214-
alpha=0.5,
244+
alpha=alpha,
215245
stopwords=["medical", "expertise"],
216246
)
217247

248+
query_string = hybrid_query._build_query_string()
249+
250+
assert "medical" not in query_string
251+
assert "expertize" not in query_string
252+
218253
results = index.aggregate_query(hybrid_query)
219254
assert len(results) == 7
220-
for r in results:
221-
assert r["text_score"] == 0
222-
assert r["hybrid_score"] == 0.5 * r["vector_similarity"]
255+
for result in results:
256+
score = alpha * float(result["vector_similarity"]) + (1 - alpha) * float(
257+
result["text_score"]
258+
)
259+
assert (
260+
float(result["hybrid_score"]) - score <= 0.0001
261+
) # allow for small floating point error
223262

224263

225-
def test_aggregate_query_text_filter(index):
264+
def test_aggregate_query_with_text_filter(index):
226265
redis_version = index.client.info()["redis_version"]
227266
if not compare_versions(redis_version, "7.2.0"):
228267
pytest.skip("Not using a late enough version of Redis")
@@ -231,19 +270,39 @@ def test_aggregate_query_text_filter(index):
231270
text_field = "description"
232271
vector = [0.1, 0.1, 0.5]
233272
vector_field = "user_embedding"
234-
filter_expression = (Text("description") == ("medical")) | (Text("job") % ("doct*"))
273+
filter_expression = Text(text_field) == ("medical")
235274

275+
# make sure we can still apply filters to the same text field we are querying
236276
hybrid_query = HybridAggregationQuery(
237277
text=text,
238278
text_field_name=text_field,
239279
vector=vector,
240280
vector_field_name=vector_field,
241281
alpha=0.5,
242-
filter_expression=filter_expression
243-
)
282+
filter_expression=filter_expression,
283+
return_fields=["job", "description"],
284+
)
244285

245286
results = index.aggregate_query(hybrid_query)
246-
assert len(results) == 7
287+
assert len(results) == 2
288+
for result in results:
289+
assert "medical" in result[text_field].lower()
290+
291+
filter_expression = (Text(text_field) == ("medical")) & (
292+
(Text(text_field) != ("research"))
293+
)
294+
hybrid_query = HybridAggregationQuery(
295+
text=text,
296+
text_field_name=text_field,
297+
vector=vector,
298+
vector_field_name=vector_field,
299+
alpha=0.5,
300+
filter_expression=filter_expression,
301+
return_fields=["description"],
302+
)
303+
304+
results = index.aggregate_query(hybrid_query)
305+
assert len(results) == 2
247306
for result in results:
248-
assert result["text_score"] == 0
249-
assert result["hybrid_score"] == 0.5 * result["vector_similarity"]
307+
assert "medical" in result[text_field].lower()
308+
assert "research" not in result[text_field].lower()

tests/integration/test_query.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
from redis.commands.search.result import Result
55

66
from redisvl.index import SearchIndex
7-
from redisvl.query import CountQuery, FilterQuery, VectorQuery, VectorRangeQuery
7+
from redisvl.query import (
8+
CountQuery,
9+
FilterQuery,
10+
TextQuery,
11+
VectorQuery,
12+
VectorRangeQuery,
13+
)
814
from redisvl.query.filter import (
915
FilterExpression,
1016
Geo,
@@ -147,6 +153,7 @@ def index(sample_data, redis_url):
147153
"storage_type": "hash",
148154
},
149155
"fields": [
156+
{"name": "description", "type": "text"},
150157
{"name": "credit_score", "type": "tag"},
151158
{"name": "job", "type": "text"},
152159
{"name": "age", "type": "numeric"},
@@ -790,3 +797,78 @@ def test_range_query_normalize_bad_input(index):
790797
return_fields=["user", "credit_score", "age", "job", "location"],
791798
distance_threshold=1.2,
792799
)
800+
801+
802+
def test_text_query(index):
803+
text = "a medical professional with expertise in lung cancer"
804+
text_field = "description"
805+
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
806+
807+
text_query = TextQuery(
808+
text=text,
809+
text_field_name=text_field,
810+
return_fields=return_fields,
811+
)
812+
results = index.query(text_query)
813+
814+
assert len(results) == 4
815+
816+
# make sure at least one word from the query is in the description
817+
for result in results:
818+
assert any(word in result[text_field] for word in text.split())
819+
820+
821+
# test that text queries work with filter expressions
822+
def test_text_query_with_filter(index):
823+
text = "a medical professional with expertise in lung cancer"
824+
text_field = "description"
825+
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
826+
filter_expression = (Tag("credit_score") == ("high")) & (Num("age") > 30)
827+
828+
text_query = TextQuery(
829+
text=text,
830+
text_field_name=text_field,
831+
filter_expression=filter_expression,
832+
return_fields=return_fields,
833+
)
834+
results = index.query(text_query)
835+
assert len(results) == 2
836+
for result in results:
837+
assert any(word in result[text_field] for word in text.split())
838+
assert result["credit_score"] == "high"
839+
assert int(result["age"]) > 30
840+
841+
842+
# test that text queries work with text filter expressions on the same text field
843+
def test_text_query_with_text_filter(index):
844+
text = "a medical professional with expertise in lung cancer"
845+
text_field = "description"
846+
return_fields = ["age", "job", "description"]
847+
filter_expression = Text(text_field) == ("medical")
848+
849+
text_query = TextQuery(
850+
text=text,
851+
text_field_name=text_field,
852+
filter_expression=filter_expression,
853+
return_fields=return_fields,
854+
)
855+
results = index.query(text_query)
856+
assert len(results) == 2
857+
for result in results:
858+
assert any(word in result[text_field] for word in text.split())
859+
assert "medical" in result[text_field]
860+
861+
filter_expression = Text(text_field) != ("research")
862+
863+
text_query = TextQuery(
864+
text=text,
865+
text_field_name=text_field,
866+
filter_expression=filter_expression,
867+
return_fields=return_fields,
868+
)
869+
870+
results = index.query(text_query)
871+
assert len(results) == 3
872+
for result in results:
873+
assert any(word in result[text_field] for word in text.split())
874+
assert "research" not in result[text_field]

0 commit comments

Comments
 (0)