Skip to content

Commit 497a620

Browse files
committed
Implement compute_best_matches with WeightedCombSum
1 parent 3daf2be commit 497a620

File tree

1 file changed

+46
-32
lines changed
  • nucliadb/src/nucliadb/search/search/chat/ask.py

1 file changed

+46
-32
lines changed

nucliadb/src/nucliadb/search/search/chat/ask.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from pydantic_core import ValidationError
3434

3535
from nucliadb.common.datamanagers.exceptions import KnowledgeBoxNotFound
36+
from nucliadb.common.external_index_providers.base import ScoredTextBlock
37+
from nucliadb.common.ids import ParagraphId
3638
from nucliadb.models.responses import HTTPClientError
3739
from nucliadb.search import logger, predict
3840
from nucliadb.search.predict import (
@@ -63,6 +65,7 @@
6365
from nucliadb.search.search.metrics import RAGMetrics
6466
from nucliadb.search.search.query_parser.fetcher import Fetcher
6567
from nucliadb.search.search.query_parser.parsers.ask import fetcher_for_ask, parse_ask
68+
from nucliadb.search.search.rank_fusion import WeightedCombSum
6669
from nucliadb.search.search.rerankers import (
6770
get_reranker,
6871
)
@@ -861,6 +864,10 @@ async def retrieval_in_resource(
861864
)
862865

863866

867+
class _FindParagraph(ScoredTextBlock):
    """Adapter used by `compute_best_matches` to run find results through rank fusion.

    Wraps a retrieval `FindParagraph` as a `ScoredTextBlock` (the input type
    expected by the rank-fusion algorithms) while keeping a reference to the
    original paragraph so it can be recovered after fusion.
    """

    # The FindParagraph this scored block was built from; read back after
    # rank fusion to build the final RetrievalMatch list.
    original: FindParagraph
869+
870+
864871
def compute_best_matches(
865872
main_results: KnowledgeboxFindResults,
866873
prequeries_results: Optional[list[PreQueryResult]] = None,
@@ -878,42 +885,49 @@ def compute_best_matches(
878885
`main_query_weight` is the weight given to the paragraphs matching the main query when calculating the final score.
879886
"""
880887

881-
def iter_paragraphs(results: KnowledgeboxFindResults):
888+
def extract_paragraphs(results: KnowledgeboxFindResults) -> list[_FindParagraph]:
889+
paragraphs = []
882890
for resource in results.resources.values():
883891
for field in resource.fields.values():
884892
for paragraph in field.paragraphs.values():
885-
yield paragraph
886-
887-
total_weights = main_query_weight + sum(prequery.weight for prequery, _ in prequeries_results or [])
888-
paragraph_id_to_match: dict[str, RetrievalMatch] = {}
889-
for paragraph in iter_paragraphs(main_results):
890-
normalized_weight = main_query_weight / total_weights
891-
rmatch = RetrievalMatch(
892-
paragraph=paragraph,
893-
weighted_score=paragraph.score * normalized_weight,
894-
)
895-
paragraph_id_to_match[paragraph.id] = rmatch
896-
897-
for prequery, prequery_results in prequeries_results or []:
898-
for paragraph in iter_paragraphs(prequery_results):
899-
normalized_weight = prequery.weight / total_weights
900-
weighted_score = paragraph.score * normalized_weight
901-
if paragraph.id in paragraph_id_to_match:
902-
rmatch = paragraph_id_to_match[paragraph.id]
903-
# If a paragraph is matched in various prequeries, the final score is the
904-
# sum of the weighted scores
905-
rmatch.weighted_score += weighted_score
906-
else:
907-
paragraph_id_to_match[paragraph.id] = RetrievalMatch(
908-
paragraph=paragraph,
909-
weighted_score=weighted_score,
910-
)
893+
paragraphs.append(
894+
_FindParagraph(
895+
paragraph_id=ParagraphId.from_string(paragraph.id),
896+
score=paragraph.score,
897+
score_type=paragraph.score_type,
898+
original=paragraph,
899+
)
900+
)
901+
return paragraphs
911902

912-
return sorted(
913-
paragraph_id_to_match.values(),
914-
key=lambda match: match.weighted_score,
915-
reverse=True,
916-
)
903+
weights = {
904+
"main": main_query_weight,
905+
}
906+
total_weight = main_query_weight
907+
find_results = {
908+
"main": extract_paragraphs(main_results),
909+
}
910+
total_elements = len(find_results["main"])
911+
for i, (prequery, prequery_results) in enumerate(prequeries_results or []):
912+
weights[f"prequery-{i}"] = prequery.weight
913+
total_weight += prequery.weight
914+
prequery_paragraphs = extract_paragraphs(prequery_results)
915+
find_results[f"prequery-{i}"] = prequery_paragraphs
916+
total_elements += len(prequery_paragraphs)
917+
918+
normalized_weights = {key: value / total_weight for key, value in weights.items()}
919+
920+
# window does nothing here
921+
rank_fusion = WeightedCombSum(window=0, weights=normalized_weights)
922+
923+
merged = []
924+
for item in rank_fusion.fuse(find_results):
925+
match = RetrievalMatch(
926+
paragraph=item.original,
927+
weighted_score=item.score,
928+
)
929+
merged.append(match)
930+
return merged
917931

918932

919933
def calculate_prequeries_for_json_schema(

0 commit comments

Comments
 (0)