33from redisvl .extensions .llmcache import SemanticCache
44from redisvl .extensions .router import Route , SemanticRouter
55from redisvl .extensions .router .schema import RoutingConfig
6+ from redisvl .extensions .threshold_optimizer .base import EvalMetric
67from redisvl .extensions .threshold_optimizer .cache import CacheThresholdOptimizer
78from redisvl .extensions .threshold_optimizer .router import RouterThresholdOptimizer
89from redisvl .redis .connection import compare_versions
@@ -77,7 +78,7 @@ def test_data_optimization():
7778 ]
7879
7980
80- def test_routes_different_distance_thresholds_optimizer (
81+ def test_routes_different_distance_thresholds_optimizer_default (
8182 semantic_router , routes , redis_url , test_data_optimization
8283):
8384 redis_version = semantic_router ._index .client .info ()["redis_version" ]
@@ -110,7 +111,77 @@ def test_routes_different_distance_thresholds_optimizer(
110111 assert route .distance_threshold > zero_threshold
111112
112113
113- def test_optimize_threshold_cache (redis_url ):
def test_routes_different_distance_thresholds_optimizer_precision(
    semantic_router, routes, redis_url, test_data_optimization
):
    """Optimizing for precision must raise every route threshold above zero.

    Skipped when the connected Redis server is older than 7.0.0.
    """
    server_version = semantic_router._index.client.info()["redis_version"]
    if not compare_versions(server_version, "7.0.0"):
        pytest.skip("Not using a late enough version of Redis")

    zero_threshold = 0.0

    # Reset the first two routes so any later increase is attributable
    # to the optimizer run below.
    for position in (0, 1):
        routes[position].distance_threshold = zero_threshold

    router = SemanticRouter(
        name="test_routes_different_distance_optimizer",
        routes=routes,
        redis_url=redis_url,
        overwrite=True,
    )

    # "Szia" is hello in Hungarian and is not part of the test data, so
    # with zeroed thresholds it is expected to hit no route.
    unmatched = router.route_many("Szia", max_k=2)
    assert len(unmatched) == 0

    # Run the optimizer with precision as the evaluation metric.
    optimizer = RouterThresholdOptimizer(
        router, test_data_optimization, eval_metric="precision"
    )
    optimizer.optimize(max_iterations=10)

    # Every route must have moved past the zero starting point.
    for tuned_route in routes:
        assert tuned_route.distance_threshold > zero_threshold
147+
148+
def test_routes_different_distance_thresholds_optimizer_recall(
    semantic_router, routes, redis_url, test_data_optimization
):
    """Optimizing for recall must raise every route threshold above zero.

    Skipped when the connected Redis server is older than 7.0.0.
    """
    server_version = semantic_router._index.client.info()["redis_version"]
    if not compare_versions(server_version, "7.0.0"):
        pytest.skip("Not using a late enough version of Redis")

    zero_threshold = 0.0

    # Reset the first two routes so any later increase is attributable
    # to the optimizer run below.
    for position in (0, 1):
        routes[position].distance_threshold = zero_threshold

    router = SemanticRouter(
        name="test_routes_different_distance_optimizer",
        routes=routes,
        redis_url=redis_url,
        overwrite=True,
    )

    # "Szia" is hello in Hungarian and is not part of the test data, so
    # with zeroed thresholds it is expected to hit no route.
    unmatched = router.route_many("Szia", max_k=2)
    assert len(unmatched) == 0

    # Run the optimizer with recall as the evaluation metric.
    optimizer = RouterThresholdOptimizer(
        router, test_data_optimization, eval_metric="recall"
    )
    optimizer.optimize(max_iterations=10)

    # Every route must have moved past the zero starting point.
    for tuned_route in routes:
        assert tuned_route.distance_threshold > zero_threshold
182+
183+
184+ def test_optimize_threshold_cache_default (redis_url ):
114185 null_threshold = 0.0
115186 cache = SemanticCache (
116187 name = "test_optimize_threshold_cache" ,
@@ -132,3 +203,81 @@ def test_optimize_threshold_cache(redis_url):
132203 cache_optimizer .optimize ()
133204
134205 assert cache .distance_threshold > null_threshold
206+
207+
def test_optimize_threshold_cache_precision(redis_url):
    """Optimizing a cache for precision must lift its threshold above zero."""
    null_threshold = 0.0
    cache = SemanticCache(
        name="test_optimize_threshold_cache",
        redis_url=redis_url,
        distance_threshold=null_threshold,
    )

    # Seed the cache and keep the returned keys as the expected matches.
    paris_key = cache.store(
        prompt="what is the capital of france?", response="paris"
    )
    rabat_key = cache.store(
        prompt="what is the capital of morocco?", response="rabat"
    )

    # Two paraphrases of stored prompts plus one unrelated query.
    # NOTE(review): the empty "query_match" presumably marks "no entry
    # should match" -- confirm against the optimizer's input schema.
    labelled_queries = [
        {
            "query": "what actually is the capital of france?",
            "query_match": paris_key,
        },
        {
            "query": "what actually is the capital of morocco?",
            "query_match": rabat_key,
        },
        {
            "query": "What is the state bird of virginia?",
            "query_match": "",
        },
    ]

    optimizer = CacheThresholdOptimizer(
        cache, labelled_queries, eval_metric="precision"
    )
    optimizer.optimize()

    assert cache.distance_threshold > null_threshold
230+
231+
def test_optimize_threshold_cache_recall(redis_url):
    """Optimizing a cache for recall must lift its threshold above zero."""
    null_threshold = 0.0
    cache = SemanticCache(
        name="test_optimize_threshold_cache",
        redis_url=redis_url,
        distance_threshold=null_threshold,
    )

    # Seed the cache and keep the returned keys as the expected matches.
    paris_key = cache.store(
        prompt="what is the capital of france?", response="paris"
    )
    rabat_key = cache.store(
        prompt="what is the capital of morocco?", response="rabat"
    )

    # Two paraphrases of stored prompts plus one unrelated query.
    # NOTE(review): the empty "query_match" presumably marks "no entry
    # should match" -- confirm against the optimizer's input schema.
    labelled_queries = [
        {
            "query": "what actually is the capital of france?",
            "query_match": paris_key,
        },
        {
            "query": "what actually is the capital of morocco?",
            "query_match": rabat_key,
        },
        {
            "query": "What is the state bird of virginia?",
            "query_match": "",
        },
    ]

    optimizer = CacheThresholdOptimizer(
        cache, labelled_queries, eval_metric="recall"
    )
    optimizer.optimize()

    assert cache.distance_threshold > null_threshold
254+
255+
def test_eval_metric_from_string():
    """Check that EvalMetric.from_string maps each valid name to its member."""
    expected_members = [
        ("f1", EvalMetric.F1),
        ("precision", EvalMetric.PRECISION),
        ("recall", EvalMetric.RECALL),
        # Upper-case spellings must resolve to the same members.
        ("F1", EvalMetric.F1),
        ("PRECISION", EvalMetric.PRECISION),
    ]

    for metric_name, member in expected_members:
        assert EvalMetric.from_string(metric_name) == member
265+
266+
def test_eval_metric_invalid():
    """Check that EvalMetric.from_string rejects an unknown metric name."""
    unknown_metric = "invalid_metric"

    # The lookup itself must raise; no return value is inspected.
    with pytest.raises(ValueError):
        EvalMetric.from_string(unknown_metric)
271+
272+
def test_optimizer_with_invalid_metric(redis_url):
    """Check that building a cache optimizer with a bad metric raises ValueError."""
    cache = SemanticCache(name="test_invalid_metric", redis_url=redis_url)

    # A single labelled query; the constructor is expected to fail on the
    # metric name, not on this data.
    labelled_queries = [{"query": "test", "query_match": ""}]

    with pytest.raises(ValueError):
        CacheThresholdOptimizer(
            cache, labelled_queries, eval_metric="invalid_metric"
        )
0 commit comments