66
77from redisvl .extensions .router .semantic import SemanticRouter
88from redisvl .utils .optimize .base import BaseThresholdOptimizer , EvalMetric
9- from redisvl .utils .optimize .schema import TestData
9+ from redisvl .utils .optimize .schema import LabeledData
1010from redisvl .utils .optimize .utils import NULL_RESPONSE_KEY , _format_qrels
1111
1212
13- def _generate_run_router (test_data : List [TestData ], router : SemanticRouter ) -> Run :
13+ def _generate_run_router (test_data : List [LabeledData ], router : SemanticRouter ) -> Run :
1414 """Format router results into format for ranx Run"""
1515 run_dict : Dict [Any , Any ] = {}
1616
@@ -26,7 +26,7 @@ def _generate_run_router(test_data: List[TestData], router: SemanticRouter) -> R
2626
2727
2828def _eval_router (
29- router : SemanticRouter , test_data : List [TestData ], qrels : Qrels , eval_metric : str
29+ router : SemanticRouter , test_data : List [LabeledData ], qrels : Qrels , eval_metric : str
3030) -> float :
3131 """Evaluate acceptable metric given run and qrels data"""
3232 run = _generate_run_router (test_data , router )
@@ -55,7 +55,7 @@ def _router_random_search(
5555
5656def _random_search_opt_router (
5757 router : SemanticRouter ,
58- test_data : List [TestData ],
58+ test_data : List [LabeledData ],
5959 qrels : Qrels ,
6060 eval_metric : EvalMetric ,
6161 ** kwargs : Any ,
@@ -67,12 +67,15 @@ def _random_search_opt_router(
6767 best_thresholds = router .route_thresholds
6868
6969 max_iterations = kwargs .get ("max_iterations" , 20 )
70+ search_step = kwargs .get ("search_step" , 0.10 )
7071
7172 for _ in range (max_iterations ):
7273 route_names = router .route_names
7374 route_thresholds = router .route_thresholds
7475 thresholds = _router_random_search (
75- route_names = route_names , route_thresholds = route_thresholds
76+ route_names = route_names ,
77+ route_thresholds = route_thresholds ,
78+ search_step = search_step ,
7679 )
7780 router .update_route_thresholds (thresholds )
7881 score = _eval_router (router , test_data , qrels , eval_metric .value )
0 commit comments