Skip to content

Commit d3a0717

Browse files
committed
more tests
1 parent 705cd2e commit d3a0717

File tree

2 files changed

+151
-3
lines changed

2 files changed

+151
-3
lines changed

redisvl/extensions/threshold_optimizer/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ class EvalMetric(Enum):
1111
F1 = "f1"
1212
PRECISION = "precision"
1313
RECALL = "recall"
14-
ACCURACY = "accuracy"
1514

1615
def __str__(self) -> str:
1716
return self.value

tests/integration/test_threshold_optimizer.py

Lines changed: 151 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from redisvl.extensions.llmcache import SemanticCache
44
from redisvl.extensions.router import Route, SemanticRouter
55
from redisvl.extensions.router.schema import RoutingConfig
6+
from redisvl.extensions.threshold_optimizer.base import EvalMetric
67
from redisvl.extensions.threshold_optimizer.cache import CacheThresholdOptimizer
78
from redisvl.extensions.threshold_optimizer.router import RouterThresholdOptimizer
89
from redisvl.redis.connection import compare_versions
@@ -77,7 +78,7 @@ def test_data_optimization():
7778
]
7879

7980

80-
def test_routes_different_distance_thresholds_optimizer(
81+
def test_routes_different_distance_thresholds_optimizer_default(
8182
semantic_router, routes, redis_url, test_data_optimization
8283
):
8384
redis_version = semantic_router._index.client.info()["redis_version"]
@@ -110,7 +111,77 @@ def test_routes_different_distance_thresholds_optimizer(
110111
assert route.distance_threshold > zero_threshold
111112

112113

113-
def test_optimize_threshold_cache(redis_url):
114+
def test_routes_different_distance_thresholds_optimizer_precision(
115+
semantic_router, routes, redis_url, test_data_optimization
116+
):
117+
redis_version = semantic_router._index.client.info()["redis_version"]
118+
if not compare_versions(redis_version, "7.0.0"):
119+
pytest.skip("Not using a late enough version of Redis")
120+
121+
zero_threshold = 0.0
122+
123+
# Reset both routes to a zero threshold so that any update made by the optimizer is detectable
124+
routes[0].distance_threshold = zero_threshold
125+
routes[1].distance_threshold = zero_threshold
126+
127+
router = SemanticRouter(
128+
name="test_routes_different_distance_optimizer",
129+
routes=routes,
130+
redis_url=redis_url,
131+
overwrite=True,
132+
)
133+
134+
# "Szia" is Hungarian for "hello" and does not appear in our test data
135+
matches = router.route_many("Szia", max_k=2)
136+
assert len(matches) == 0
137+
138+
# now run optimizer
139+
router_optimizer = RouterThresholdOptimizer(
140+
router, test_data_optimization, eval_metric="precision"
141+
)
142+
router_optimizer.optimize(max_iterations=10)
143+
144+
# test that it updated thresholds beyond the null case
145+
for route in routes:
146+
assert route.distance_threshold > zero_threshold
147+
148+
149+
def test_routes_different_distance_thresholds_optimizer_recall(
150+
semantic_router, routes, redis_url, test_data_optimization
151+
):
152+
redis_version = semantic_router._index.client.info()["redis_version"]
153+
if not compare_versions(redis_version, "7.0.0"):
154+
pytest.skip("Not using a late enough version of Redis")
155+
156+
zero_threshold = 0.0
157+
158+
# Reset both routes to a zero threshold so that any update made by the optimizer is detectable
159+
routes[0].distance_threshold = zero_threshold
160+
routes[1].distance_threshold = zero_threshold
161+
162+
router = SemanticRouter(
163+
name="test_routes_different_distance_optimizer",
164+
routes=routes,
165+
redis_url=redis_url,
166+
overwrite=True,
167+
)
168+
169+
# "Szia" is Hungarian for "hello" and does not appear in our test data
170+
matches = router.route_many("Szia", max_k=2)
171+
assert len(matches) == 0
172+
173+
# now run optimizer
174+
router_optimizer = RouterThresholdOptimizer(
175+
router, test_data_optimization, eval_metric="recall"
176+
)
177+
router_optimizer.optimize(max_iterations=10)
178+
179+
# test that it updated thresholds beyond the null case
180+
for route in routes:
181+
assert route.distance_threshold > zero_threshold
182+
183+
184+
def test_optimize_threshold_cache_default(redis_url):
114185
null_threshold = 0.0
115186
cache = SemanticCache(
116187
name="test_optimize_threshold_cache",
@@ -132,3 +203,81 @@ def test_optimize_threshold_cache(redis_url):
132203
cache_optimizer.optimize()
133204

134205
assert cache.distance_threshold > null_threshold
206+
207+
208+
def test_optimize_threshold_cache_precision(redis_url):
209+
null_threshold = 0.0
210+
cache = SemanticCache(
211+
name="test_optimize_threshold_cache",
212+
redis_url=redis_url,
213+
distance_threshold=null_threshold,
214+
)
215+
216+
paris_key = cache.store(prompt="what is the capital of france?", response="paris")
217+
rabat_key = cache.store(prompt="what is the capital of morocco?", response="rabat")
218+
219+
test_dict = [
220+
{"query": "what actually is the capital of france?", "query_match": paris_key},
221+
{"query": "what actually is the capital of morocco?", "query_match": rabat_key},
222+
{"query": "What is the state bird of virginia?", "query_match": ""},
223+
]
224+
225+
cache_optimizer = CacheThresholdOptimizer(cache, test_dict, eval_metric="precision")
226+
227+
cache_optimizer.optimize()
228+
229+
assert cache.distance_threshold > null_threshold
230+
231+
232+
def test_optimize_threshold_cache_recall(redis_url):
233+
null_threshold = 0.0
234+
cache = SemanticCache(
235+
name="test_optimize_threshold_cache",
236+
redis_url=redis_url,
237+
distance_threshold=null_threshold,
238+
)
239+
240+
paris_key = cache.store(prompt="what is the capital of france?", response="paris")
241+
rabat_key = cache.store(prompt="what is the capital of morocco?", response="rabat")
242+
243+
test_dict = [
244+
{"query": "what actually is the capital of france?", "query_match": paris_key},
245+
{"query": "what actually is the capital of morocco?", "query_match": rabat_key},
246+
{"query": "What is the state bird of virginia?", "query_match": ""},
247+
]
248+
249+
cache_optimizer = CacheThresholdOptimizer(cache, test_dict, eval_metric="recall")
250+
251+
cache_optimizer.optimize()
252+
253+
assert cache.distance_threshold > null_threshold
254+
255+
256+
def test_eval_metric_from_string():
257+
"""Test that EvalMetric.from_string works for valid metrics."""
258+
assert EvalMetric.from_string("f1") == EvalMetric.F1
259+
assert EvalMetric.from_string("precision") == EvalMetric.PRECISION
260+
assert EvalMetric.from_string("recall") == EvalMetric.RECALL
261+
262+
# Test case insensitivity
263+
assert EvalMetric.from_string("F1") == EvalMetric.F1
264+
assert EvalMetric.from_string("PRECISION") == EvalMetric.PRECISION
265+
266+
267+
def test_eval_metric_invalid():
268+
"""Test that EvalMetric.from_string raises ValueError for invalid metrics."""
269+
with pytest.raises(ValueError):
270+
EvalMetric.from_string("invalid_metric")
271+
272+
273+
def test_optimizer_with_invalid_metric(redis_url):
274+
"""Test that optimizers raise ValueError when initialized with invalid metric."""
275+
cache = SemanticCache(
276+
name="test_invalid_metric",
277+
redis_url=redis_url,
278+
)
279+
280+
test_dict = [{"query": "test", "query_match": ""}]
281+
282+
with pytest.raises(ValueError):
283+
CacheThresholdOptimizer(cache, test_dict, eval_metric="invalid_metric")

0 commit comments

Comments
 (0)