22from redis .commands .search .aggregation import AggregateResult
33from redis .commands .search .result import Result
44
5- from redisvl .redis .connection import compare_versions
6-
75from redisvl .index import SearchIndex
86from redisvl .query import HybridAggregationQuery
9- from redisvl .query .filter import (
10- FilterExpression ,
11- Geo ,
12- GeoRadius ,
13- Num ,
14- Tag ,
15- Text ,
16- )
7+ from redisvl .query .filter import FilterExpression , Geo , GeoRadius , Num , Tag , Text
8+ from redisvl .redis .connection import compare_versions
179from redisvl .redis .utils import array_to_buffer
1810
1911
@@ -131,17 +123,24 @@ def test_empty_query_string():
131123 # test if text is empty
132124 with pytest .raises (ValueError ):
133125 hybrid_query = HybridAggregationQuery (
134- text = text , text_field_name = text_field , vector = vector , vector_field_name = vector_field
126+ text = text ,
127+ text_field_name = text_field ,
128+ vector = vector ,
129+ vector_field_name = vector_field ,
135130 )
136131
137132 # test if text becomes empty after stopwords are removed
138133 text = "with a for but and" # will all be removed as default stopwords
139134 with pytest .raises (ValueError ):
140135 hybrid_query = HybridAggregationQuery (
141- text = text , text_field_name = text_field , vector = vector , vector_field_name = vector_field
136+ text = text ,
137+ text_field_name = text_field ,
138+ vector = vector ,
139+ vector_field_name = vector_field ,
142140 )
143141
144- def test_aggregation_query_filter (index ):
142+
143+ def test_aggregation_query_with_filter (index ):
145144 redis_version = index .client .info ()["redis_version" ]
146145 if not compare_versions (redis_version , "7.2.0" ):
147146 pytest .skip ("Not using a late enough version of Redis" )
@@ -163,7 +162,7 @@ def test_aggregation_query_filter(index):
163162 )
164163
165164 results = index .aggregate_query (hybrid_query )
166- assert len (results ) == 3
165+ assert len (results ) == 2
167166 for result in results :
168167 assert result ["credit_score" ] == "high"
169168 assert int (result ["age" ]) > 30
@@ -179,7 +178,7 @@ def test_aggregation_query_with_geo_filter(index):
179178 vector = [0.1 , 0.1 , 0.5 ]
180179 vector_field = "user_embedding"
181180 return_fields = ["user" , "credit_score" , "age" , "job" , "location" , "description" ]
182- filter_expression = Geo ("location" ) == GeoRadius (37.7749 , - 122.4194 , 1000 )
181+ filter_expression = Geo ("location" ) == GeoRadius (- 122.4194 , 37.7749 , 1000 , "m" )
183182
184183 hybrid_query = HybridAggregationQuery (
185184 text = text ,
@@ -196,6 +195,36 @@ def test_aggregation_query_with_geo_filter(index):
196195 assert result ["location" ] is not None
197196
198197
@pytest.mark.parametrize("alpha", [0.1, 0.5, 0.9])
def test_aggregate_query_alpha(index, alpha):
    """Check that ``hybrid_score`` is the alpha-weighted blend of both scores.

    For every returned row the test recomputes
    ``alpha * vector_similarity + (1 - alpha) * text_score`` and compares it
    with the ``hybrid_score`` value reported in that row.

    Args:
        index: Search index fixture; used for ``client.info()`` and
            ``aggregate_query``.
        alpha: Weight passed to the query and applied to the vector
            similarity when recomputing the expected score.
    """
    redis_version = index.client.info()["redis_version"]
    if not compare_versions(redis_version, "7.2.0"):
        pytest.skip("Not using a late enough version of Redis")

    text = "a medical professional with expertise in lung cancer"
    text_field = "description"
    vector = [0.1, 0.1, 0.5]
    vector_field = "user_embedding"

    hybrid_query = HybridAggregationQuery(
        text=text,
        text_field_name=text_field,
        vector=vector,
        vector_field_name=vector_field,
        alpha=alpha,
    )

    results = index.aggregate_query(hybrid_query)
    assert len(results) == 7
    for result in results:
        expected = alpha * float(result["vector_similarity"]) + (1 - alpha) * float(
            result["text_score"]
        )
        # Compare the absolute difference: a one-sided check would accept a
        # hybrid_score that is arbitrarily smaller than the expected value.
        assert (
            abs(float(result["hybrid_score"]) - expected) <= 0.0001
        )  # allow for small floating point error
226+
227+
199228def test_aggregate_query_stopwords (index ):
200229 redis_version = index .client .info ()["redis_version" ]
201230 if not compare_versions (redis_version , "7.2.0" ):
@@ -205,24 +234,34 @@ def test_aggregate_query_stopwords(index):
205234 text_field = "description"
206235 vector = [0.1 , 0.1 , 0.5 ]
207236 vector_field = "user_embedding"
237+ alpha = 0.5
208238
209239 hybrid_query = HybridAggregationQuery (
210240 text = text ,
211241 text_field_name = text_field ,
212242 vector = vector ,
213243 vector_field_name = vector_field ,
214- alpha = 0.5 ,
244+ alpha = alpha ,
215245 stopwords = ["medical" , "expertise" ],
216246 )
217247
248+ query_string = hybrid_query ._build_query_string ()
249+
250+ assert "medical" not in query_string
251+ assert "expertise" not in query_string
252+
218253 results = index .aggregate_query (hybrid_query )
219254 assert len (results ) == 7
220- for r in results :
221- assert r ["text_score" ] == 0
222- assert r ["hybrid_score" ] == 0.5 * r ["vector_similarity" ]
255+ for result in results :
256+ score = alpha * float (result ["vector_similarity" ]) + (1 - alpha ) * float (
257+ result ["text_score" ]
258+ )
259+ assert (
260+ float (result ["hybrid_score" ]) - score <= 0.0001
261+ ) # allow for small floating point error
223262
224263
225- def test_aggregate_query_text_filter (index ):
264+ def test_aggregate_query_with_text_filter (index ):
226265 redis_version = index .client .info ()["redis_version" ]
227266 if not compare_versions (redis_version , "7.2.0" ):
228267 pytest .skip ("Not using a late enough version of Redis" )
@@ -231,19 +270,39 @@ def test_aggregate_query_text_filter(index):
231270 text_field = "description"
232271 vector = [0.1 , 0.1 , 0.5 ]
233272 vector_field = "user_embedding"
234- filter_expression = ( Text ("description" ) == ("medical" )) | ( Text ( "job" ) % ( "doct*" ) )
273+ filter_expression = Text (text_field ) == ("medical" )
235274
275+ # make sure we can still apply filters to the same text field we are querying
236276 hybrid_query = HybridAggregationQuery (
237277 text = text ,
238278 text_field_name = text_field ,
239279 vector = vector ,
240280 vector_field_name = vector_field ,
241281 alpha = 0.5 ,
242- filter_expression = filter_expression
243- )
282+ filter_expression = filter_expression ,
283+ return_fields = ["job" , "description" ],
284+ )
244285
245286 results = index .aggregate_query (hybrid_query )
246- assert len (results ) == 7
287+ assert len (results ) == 2
288+ for result in results :
289+ assert "medical" in result [text_field ].lower ()
290+
291+ filter_expression = (Text (text_field ) == ("medical" )) & (
292+ (Text (text_field ) != ("research" ))
293+ )
294+ hybrid_query = HybridAggregationQuery (
295+ text = text ,
296+ text_field_name = text_field ,
297+ vector = vector ,
298+ vector_field_name = vector_field ,
299+ alpha = 0.5 ,
300+ filter_expression = filter_expression ,
301+ return_fields = ["description" ],
302+ )
303+
304+ results = index .aggregate_query (hybrid_query )
305+ assert len (results ) == 2
247306 for result in results :
248- assert result [ "text_score" ] == 0
249- assert result [ "hybrid_score" ] == 0.5 * result ["vector_similarity" ]
307+ assert "medical" in result [ text_field ]. lower ()
308+ assert "research" not in result [text_field ]. lower ()
0 commit comments