1- from datetime import timedelta
2-
31import pytest
42from redis .commands .search .aggregation import AggregateResult
53from redis .commands .search .result import Result
64
5+ from redisvl .redis .connection import compare_versions
6+
77from redisvl .index import SearchIndex
88from redisvl .query import HybridAggregationQuery
99from redisvl .query .filter import (
1313 Num ,
1414 Tag ,
1515 Text ,
16- Timestamp ,
1716)
1817from redisvl .redis .utils import array_to_buffer
1918
20- # TODO expand to multiple schema types and sync + async
21-
22- vector = ([0.1 , 0.1 , 0.5 ],)
23- vector_field_name = ("user_embedding" ,)
24- return_fields = (
25- [
26- "user" ,
27- "credit_score" ,
28- "age" ,
29- "job" ,
30- "location" ,
31- "last_updated" ,
32- ],
33- )
34- filter_expression = (Tag ("credit_score" ) == "high" ,)
35- distance_threshold = (0.2 ,)
36-
3719
3820@pytest .fixture
3921def index (sample_data , redis_url ):
40- # construct a search index from the schema
4122 index = SearchIndex .from_dict (
4223 {
4324 "index" : {
@@ -70,7 +51,7 @@ def index(sample_data, redis_url):
7051 # create the index (no data yet)
7152 index .create (overwrite = True )
7253
73- # Prepare and load the data
54+ # prepare and load the data
7455 def hash_preprocess (item : dict ) -> dict :
7556 return {
7657 ** item ,
@@ -87,7 +68,10 @@ def hash_preprocess(item: dict) -> dict:
8768
8869
8970def test_aggregation_query (index ):
90- # *=>[KNN 7 @user_embedding $vector AS vector_distance]
71+ redis_version = index .client .info ()["redis_version" ]
72+ if not compare_versions (redis_version , "7.2.0" ):
73+ pytest .skip ("Not using a late enough version of Redis" )
74+
9175 text = "a medical professional with expertise in lung cancer"
9276 text_field = "description"
9377 vector = [0.1 , 0.1 , 0.5 ]
@@ -96,17 +80,16 @@ def test_aggregation_query(index):
9680
9781 hybrid_query = HybridAggregationQuery (
9882 text = text ,
99- text_field = text_field ,
83+ text_field_name = text_field ,
10084 vector = vector ,
101- vector_field = vector_field ,
85+ vector_field_name = vector_field ,
10286 return_fields = return_fields ,
10387 )
10488
10589 results = index .aggregate_query (hybrid_query )
10690 assert isinstance (results , list )
10791 assert len (results ) == 7
10892 for doc in results :
109- # ensure all return fields present
11093 assert doc ["user" ] in [
11194 "john" ,
11295 "derrick" ,
@@ -121,12 +104,11 @@ def test_aggregation_query(index):
121104 assert doc ["job" ] in ["engineer" , "doctor" , "dermatologist" , "CEO" , "dentist" ]
122105 assert doc ["credit_score" ] in ["high" , "low" , "medium" ]
123106
124- # test num_results
125107 hybrid_query = HybridAggregationQuery (
126108 text = text ,
127- text_field = text_field ,
109+ text_field_name = text_field ,
128110 vector = vector ,
129- vector_field = vector_field ,
111+ vector_field_name = vector_field ,
130112 num_results = 3 ,
131113 )
132114
@@ -139,7 +121,7 @@ def test_aggregation_query(index):
139121 )
140122
141123
142- def test_empty_query_string (index ):
124+ def test_empty_query_string ():
143125 text = ""
144126 text_field = "description"
145127 vector = [0.1 , 0.1 , 0.5 ]
@@ -149,30 +131,86 @@ def test_empty_query_string(index):
149131 # test if text is empty
150132 with pytest .raises (ValueError ):
151133 hybrid_query = HybridAggregationQuery (
152- text = text , text_field = text_field , vector = vector , vector_field = vector_field
134+ text = text , text_field_name = text_field , vector = vector , vector_field_name = vector_field
153135 )
154136
155137 # test if text becomes empty after stopwords are removed
156138 text = "with a for but and" # will all be removed as default stopwords
157139 with pytest .raises (ValueError ):
158140 hybrid_query = HybridAggregationQuery (
159- text = text , text_field = text_field , vector = vector , vector_field = vector_field
141+ text = text , text_field_name = text_field , vector = vector , vector_field_name = vector_field
160142 )
161143
144+ def test_aggregation_query_filter (index ):
145+ redis_version = index .client .info ()["redis_version" ]
146+ if not compare_versions (redis_version , "7.2.0" ):
147+ pytest .skip ("Not using a late enough version of Redis" )
162148
163- def test_aggregate_query_stopwords (index ):
164149 text = "a medical professional with expertise in lung cancer"
165150 text_field = "description"
166151 vector = [0.1 , 0.1 , 0.5 ]
167152 vector_field = "user_embedding"
168153 return_fields = ["user" , "credit_score" , "age" , "job" , "location" , "description" ]
169- return
170- # test num_results
154+ filter_expression = ( Tag ( "credit_score" ) == ( "high" )) & ( Num ( "age" ) > 30 )
155+
171156 hybrid_query = HybridAggregationQuery (
172157 text = text ,
173- text_field = text_field ,
158+ text_field_name = text_field ,
174159 vector = vector ,
175- vector_field = vector_field ,
160+ vector_field_name = vector_field ,
161+ filter_expression = filter_expression ,
162+ return_fields = return_fields ,
163+ )
164+
165+ results = index .aggregate_query (hybrid_query )
166+ assert len (results ) == 3
167+ for result in results :
168+ assert result ["credit_score" ] == "high"
169+ assert int (result ["age" ]) > 30
170+
171+
172+ def test_aggregation_query_with_geo_filter (index ):
173+ redis_version = index .client .info ()["redis_version" ]
174+ if not compare_versions (redis_version , "7.2.0" ):
175+ pytest .skip ("Not using a late enough version of Redis" )
176+
177+ text = "a medical professional with expertise in lung cancer"
178+ text_field = "description"
179+ vector = [0.1 , 0.1 , 0.5 ]
180+ vector_field = "user_embedding"
181+ return_fields = ["user" , "credit_score" , "age" , "job" , "location" , "description" ]
182+ filter_expression = Geo ("location" ) == GeoRadius (37.7749 , - 122.4194 , 1000 )
183+
184+ hybrid_query = HybridAggregationQuery (
185+ text = text ,
186+ text_field_name = text_field ,
187+ vector = vector ,
188+ vector_field_name = vector_field ,
189+ filter_expression = filter_expression ,
190+ return_fields = return_fields ,
191+ )
192+
193+ results = index .aggregate_query (hybrid_query )
194+ assert len (results ) == 3
195+ for result in results :
196+ assert result ["location" ] is not None
197+
198+
199+ def test_aggregate_query_stopwords (index ):
200+ redis_version = index .client .info ()["redis_version" ]
201+ if not compare_versions (redis_version , "7.2.0" ):
202+ pytest .skip ("Not using a late enough version of Redis" )
203+
204+ text = "a medical professional with expertise in lung cancer"
205+ text_field = "description"
206+ vector = [0.1 , 0.1 , 0.5 ]
207+ vector_field = "user_embedding"
208+
209+ hybrid_query = HybridAggregationQuery (
210+ text = text ,
211+ text_field_name = text_field ,
212+ vector = vector ,
213+ vector_field_name = vector_field ,
176214 alpha = 0.5 ,
177215 stopwords = ["medical" , "expertise" ],
178216 )
@@ -182,3 +220,30 @@ def test_aggregate_query_stopwords(index):
182220 for r in results :
183221 assert r ["text_score" ] == 0
184222 assert r ["hybrid_score" ] == 0.5 * r ["vector_similarity" ]
223+
224+
225+ def test_aggregate_query_text_filter (index ):
226+ redis_version = index .client .info ()["redis_version" ]
227+ if not compare_versions (redis_version , "7.2.0" ):
228+ pytest .skip ("Not using a late enough version of Redis" )
229+
230+ text = "a medical professional with expertise in lung cancer"
231+ text_field = "description"
232+ vector = [0.1 , 0.1 , 0.5 ]
233+ vector_field = "user_embedding"
234+ filter_expression = (Text ("description" ) == ("medical" )) | (Text ("job" ) % ("doct*" ))
235+
236+ hybrid_query = HybridAggregationQuery (
237+ text = text ,
238+ text_field_name = text_field ,
239+ vector = vector ,
240+ vector_field_name = vector_field ,
241+ alpha = 0.5 ,
242+ filter_expression = filter_expression
243+ )
244+
245+ results = index .aggregate_query (hybrid_query )
246+ assert len (results ) == 7
247+ for result in results :
248+ assert result ["text_score" ] == 0
249+ assert result ["hybrid_score" ] == 0.5 * result ["vector_similarity" ]
0 commit comments