33
33
from pydantic_core import ValidationError
34
34
35
35
from nucliadb .common .datamanagers .exceptions import KnowledgeBoxNotFound
36
+ from nucliadb .common .external_index_providers .base import ScoredTextBlock
37
+ from nucliadb .common .ids import ParagraphId
36
38
from nucliadb .models .responses import HTTPClientError
37
39
from nucliadb .search import logger , predict
38
40
from nucliadb .search .predict import (
63
65
from nucliadb .search .search .metrics import RAGMetrics
64
66
from nucliadb .search .search .query_parser .fetcher import Fetcher
65
67
from nucliadb .search .search .query_parser .parsers .ask import fetcher_for_ask , parse_ask
68
+ from nucliadb .search .search .rank_fusion import WeightedCombSum
66
69
from nucliadb .search .search .rerankers import (
67
70
get_reranker ,
68
71
)
@@ -861,6 +864,10 @@ async def retrieval_in_resource(
861
864
)
862
865
863
866
867
class _FindParagraph(ScoredTextBlock):
    """Scored text block that keeps a reference to its source find paragraph.

    Lets a `FindParagraph` be handled as a `ScoredTextBlock` while the
    original paragraph object stays reachable through `original`.
    """

    # The find result paragraph this scored block was created from.
    original: FindParagraph
869
+
870
+
864
871
def compute_best_matches (
865
872
main_results : KnowledgeboxFindResults ,
866
873
prequeries_results : Optional [list [PreQueryResult ]] = None ,
@@ -878,42 +885,49 @@ def compute_best_matches(
878
885
`main_query_weight` is the weight given to the paragraphs matching the main query when calculating the final score.
879
886
"""
880
887
881
- def iter_paragraphs (results : KnowledgeboxFindResults ):
888
def extract_paragraphs(results: KnowledgeboxFindResults) -> list[_FindParagraph]:
    """Flatten find results into a list of scored paragraph wrappers.

    Walks every resource, every field of each resource and every paragraph
    of each field, in the iteration order of the underlying mappings, and
    wraps each paragraph in a `_FindParagraph` that keeps the original
    paragraph alongside its parsed id, score and score type.
    """

    def _wrap(found) -> _FindParagraph:
        # Parse the string id into a ParagraphId and keep the source object
        # so it can be recovered from the wrapper later.
        return _FindParagraph(
            paragraph_id=ParagraphId.from_string(found.id),
            score=found.score,
            score_type=found.score_type,
            original=found,
        )

    wrapped: list[_FindParagraph] = []
    for resource in results.resources.values():
        for field in resource.fields.values():
            wrapped.extend(_wrap(found) for found in field.paragraphs.values())
    return wrapped
911
902
912
- return sorted (
913
- paragraph_id_to_match .values (),
914
- key = lambda match : match .weighted_score ,
915
- reverse = True ,
916
- )
903
+ weights = {
904
+ "main" : main_query_weight ,
905
+ }
906
+ total_weight = main_query_weight
907
+ find_results = {
908
+ "main" : extract_paragraphs (main_results ),
909
+ }
910
+ total_elements = len (find_results ["main" ])
911
+ for i , (prequery , prequery_results ) in enumerate (prequeries_results or []):
912
+ weights [f"prequery-{ i } " ] = prequery .weight
913
+ total_weight += prequery .weight
914
+ prequery_paragraphs = extract_paragraphs (prequery_results )
915
+ find_results [f"prequery-{ i } " ] = prequery_paragraphs
916
+ total_elements += len (prequery_paragraphs )
917
+
918
+ normalized_weights = {key : value / total_weight for key , value in weights .items ()}
919
+
920
+ # window does nothing here
921
+ rank_fusion = WeightedCombSum (window = 0 , weights = normalized_weights )
922
+
923
+ merged = []
924
+ for item in rank_fusion .fuse (find_results ):
925
+ match = RetrievalMatch (
926
+ paragraph = item .original ,
927
+ weighted_score = item .score ,
928
+ )
929
+ merged .append (match )
930
+ return merged
917
931
918
932
919
933
def calculate_prequeries_for_json_schema (
0 commit comments