diff --git a/nucliadb/src/nucliadb/search/search/find.py b/nucliadb/src/nucliadb/search/search/find.py index 3fff362e48..96102433c0 100644 --- a/nucliadb/src/nucliadb/search/search/find.py +++ b/nucliadb/src/nucliadb/search/search/find.py @@ -17,41 +17,51 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # +import asyncio import logging from time import time +from typing import Optional -from nucliadb.common.external_index_providers.base import ExternalIndexManager +from nidx_protos.nodereader_pb2 import SearchRequest + +from nucliadb.common.external_index_providers.base import TextBlockMatch from nucliadb.common.external_index_providers.manager import get_external_index_manager +from nucliadb.common.external_index_providers.pinecone import PineconeIndexManager from nucliadb.common.models_utils import to_proto -from nucliadb.search.requesters.utils import Method, node_query from nucliadb.search.search.find_merge import ( - build_find_response, + _round, compose_find_resources, - hydrate_and_rerank, -) -from nucliadb.search.search.hydrator import ( - ResourceHydrationOptions, - TextBlockHydrationOptions, ) +from nucliadb.search.search.hydrator import ResourceHydrationOptions, TextBlockHydrationOptions from nucliadb.search.search.metrics import ( RAGMetrics, ) -from nucliadb.search.search.query_parser.models import ParsedQuery +from nucliadb.search.search.plan.cut import Cut +from nucliadb.search.search.plan.hydrate import HydrateResources, HydrateTextBlocks +from nucliadb.search.search.plan.models import IndexResult, Plan, PlanStep +from nucliadb.search.search.plan.query import ( + IndexPostFilter, + NidxQuery, + PineconeQuery, + ProtoIntoTextBlockMatches, +) +from nucliadb.search.search.plan.rank_fusion import RankFusion +from nucliadb.search.search.plan.rerank import Rerank +from nucliadb.search.search.plan.serialize import SerializeRelations +from nucliadb.search.search.plan.utils import CachedStep, Flatten, Group, OptionalStep, Parallel +from nucliadb.search.search.query_parser.models import ParsedQuery, UnitRetrieval from nucliadb.search.search.query_parser.parsers import parse_find from nucliadb.search.search.query_parser.parsers.unit_retrieval import legacy_convert_retrieval_to_proto -from nucliadb.search.search.rank_fusion import ( - get_rank_fusion, -) -from nucliadb.search.search.rerankers import ( - RerankingOptions, - get_reranker, -) +from nucliadb.search.search.rank_fusion import get_rank_fusion +from nucliadb.search.search.rerankers import Reranker, RerankingOptions, get_reranker from nucliadb.search.settings import settings +from nucliadb_models.resource import Resource from nucliadb_models.search import ( FindRequest, KnowledgeboxFindResults, MinScore, NucliaDBClientType, + Relations, ) from nucliadb_utils.utilities import get_audit @@ -65,38 +75,12 @@ async def find( x_nucliadb_user: str, x_forwarded_for: str, metrics: RAGMetrics = RAGMetrics(), -) -> tuple[KnowledgeboxFindResults, bool, ParsedQuery]: - external_index_manager = await get_external_index_manager(kbid=kbid) - if external_index_manager is not None: - return await _external_index_retrieval( - kbid, - item, - external_index_manager, - ) - else: - return await _index_node_retrieval( - kbid, item, x_ndb_client, x_nucliadb_user, x_forwarded_for, metrics - ) - - -async def _index_node_retrieval( - kbid: str, - item: FindRequest, - x_ndb_client: NucliaDBClientType, - x_nucliadb_user: str, - x_forwarded_for: str, - metrics: RAGMetrics = RAGMetrics(), ) -> tuple[KnowledgeboxFindResults, bool, ParsedQuery]: audit = get_audit() start_time = time() with metrics.time("query_parse"): parsed = await parse_find(kbid, item) - assert parsed.retrieval.rank_fusion is not None and parsed.retrieval.reranker is not None, ( - "find parser must provide rank fusion and reranker algorithms" - ) - rank_fusion = get_rank_fusion(parsed.retrieval.rank_fusion) - reranker = get_reranker(parsed.retrieval.reranker) ( pb_query, incomplete_results, @@ -104,26 +88,38 @@ async def _index_node_retrieval( rephrased_query, ) = await legacy_convert_retrieval_to_proto(parsed) - with metrics.time("node_query"): - results, query_incomplete_results, queried_shards = await node_query( - kbid, Method.SEARCH, pb_query - ) - incomplete_results = incomplete_results or query_incomplete_results + with metrics.time("query_plan"): + plan = await plan_find_request(kbid, item, parsed, pb_query) + + with metrics.time("node_query"), metrics.time("query_execute"): + text_blocks, resources, relations = await plan.execute() - # Rank fusion merge, cut, hydrate and rerank with metrics.time("results_merge"): - search_results = await build_find_response( - results, - retrieval=parsed.retrieval, - kbid=kbid, + # compose response + find_resources = compose_find_resources(text_blocks, resources) + + best_matches = [text_block.paragraph_id.full() for text_block in text_blocks] + + # XXX: we shouldn't need a min score that we haven't used. Previous + # implementations got this value from the proto request (i.e., default to 0) + min_score_bm25 = 0.0 + if parsed.retrieval.query.keyword is not None: + min_score_bm25 = parsed.retrieval.query.keyword.min_score + min_score_semantic = 0.0 + if parsed.retrieval.query.semantic is not None: + min_score_semantic = parsed.retrieval.query.semantic.min_score + + find_results = KnowledgeboxFindResults( query=pb_query.body, rephrased_query=rephrased_query, - show=item.show, - extracted=item.extracted, - field_type_filter=item.field_type_filter, - highlight=item.highlight, - rank_fusion_algorithm=rank_fusion, - reranker=reranker, + resources=find_resources, + best_matches=best_matches, + relations=relations, + total=plan.context.keyword_result_count, + page_number=0, # Bw/c with pagination + page_size=parsed.retrieval.top_k, + next_page=plan.context.next_page, + min_score=MinScore(bm25=_round(min_score_bm25), semantic=_round(min_score_semantic)), ) search_time = time() - start_time @@ -135,12 +131,12 @@ async def _index_node_retrieval( x_forwarded_for, pb_query, search_time, - len(search_results.resources), + len(find_results.resources), retrieval_rephrased_question=rephrased_query, ) - search_results.shards = queried_shards - search_results.autofilters = autofilters + find_results.shards = plan.context.nidx_queried_shards + find_results.autofilters = autofilters ndb_time = metrics.elapsed("node_query") + metrics.elapsed("results_merge") if metrics.elapsed("node_query") > settings.slow_node_query_log_threshold: @@ -168,67 +164,199 @@ async def _index_node_retrieval( }, ) - return search_results, incomplete_results, parsed + incomplete = incomplete_results or plan.context.nidx_incomplete + return find_results, incomplete, parsed + + +async def plan_find_request( + kbid: str, request: FindRequest, parsed: ParsedQuery, pb_query: SearchRequest +) -> Plan[tuple[list[TextBlockMatch], list[Resource], Optional[Relations]]]: + query_step, serialize_relations_step = await _plan_index_query(kbid, parsed, pb_query) + rank_fusion_step = _plan_rank_fusion(parsed.retrieval, uses=query_step) + hydrate_and_rerank_step = _plan_hydrate_and_rerank( + kbid, request, parsed, pb_query, uses=rank_fusion_step + ) + + plan = Plan( + Flatten( + Parallel( + hydrate_and_rerank_step, + serialize_relations_step, + ) + ) + ) + return plan -async def _external_index_retrieval( +async def _plan_index_query( kbid: str, - item: FindRequest, - external_index_manager: ExternalIndexManager, -) -> tuple[KnowledgeboxFindResults, bool, ParsedQuery]: + parsed: ParsedQuery, + pb_query: SearchRequest, +) -> tuple[PlanStep[IndexResult], OptionalStep[Relations]]: + """An index query is the first step in the plan. Depending on the KB + configuration, this can be done using our internal index (nidx) or an + external index (currently, only pinecone is supported). + """ - Parse the query, query the external index, and hydrate the results. + query_index: PlanStep[IndexResult] + serialize_relations: OptionalStep[Relations] + + external_index_manager = await get_external_index_manager(kbid=kbid) + if external_index_manager is not None: + assert isinstance(external_index_manager, PineconeIndexManager), ( + "every query index has it's own quirks and we only support Pinecone" + ) + query_index = PineconeQuery(kbid, pb_query, external_index_manager) + serialize_relations = OptionalStep(None) + + else: + nidx_query = CachedStep(NidxQuery(kbid, pb_query)) + query_index = ProtoIntoTextBlockMatches(uses=nidx_query) + serialize_relations = OptionalStep(SerializeRelations(parsed.retrieval, uses=nidx_query)) + + # XXX: Previous implementation got this value from the proto request (i.e., default to 0) + semantic_min_score = 0.0 + if parsed.retrieval.query.semantic: + semantic_min_score = parsed.retrieval.query.semantic.min_score + + # REVIEW: This shouldn't be needed if the index filters correctly + query_index = IndexPostFilter(query_index, semantic_min_score) + + return query_index, serialize_relations + + +def _plan_rank_fusion( + retrieval: UnitRetrieval, *, uses: PlanStep[IndexResult] +) -> PlanStep[list[TextBlockMatch]]: + """Rank fusion is applied to an index result. It merges results from + different indexes into a single list of matches. + + """ + assert retrieval.rank_fusion is not None, "find parser must provide a rank fusion algorithm" + rank_fusion_algorithm = get_rank_fusion(retrieval.rank_fusion) + return RankFusion(algorithm=rank_fusion_algorithm, source=uses) + + +def _plan_hydrate_and_rerank( + kbid: str, + request: FindRequest, + parsed: ParsedQuery, + pb_query: SearchRequest, + *, + uses: PlanStep[list[TextBlockMatch]], +) -> PlanStep[tuple[list[TextBlockMatch], list[Resource]]]: + """Given a list of matches, a reranker is used to improve the score accuracy + of the results. + + Reranking depends on matches to be hydrated, so a text block hydration step + is done before reranking. + + For performance, depending on the reranking parameters, resource hydration + is done in parallel. + """ - # Parse query - parsed = await parse_find(kbid, item) + + retrieval_step = uses + assert parsed.retrieval.reranker is not None, "find parser must provide a reranking algorithm" reranker = get_reranker(parsed.retrieval.reranker) - search_request, incomplete_results, _, rephrased_query = await legacy_convert_retrieval_to_proto( - parsed - ) - # Query index - query_results = await external_index_manager.query(search_request) # noqa - - # Hydrate and rerank results - text_blocks, resources, best_matches = await hydrate_and_rerank( - query_results.iter_matching_text_blocks(), - kbid, - resource_hydration_options=ResourceHydrationOptions( - show=item.show, - extracted=item.extracted, - field_type_filter=item.field_type_filter, - ), - text_block_hydration_options=TextBlockHydrationOptions(), - reranker=reranker, - reranking_options=RerankingOptions( - kbid=kbid, - query=search_request.body, - ), - top_k=parsed.retrieval.top_k, + pre_rerank_cut, rerank, post_rerank_cut = _plan_rerank(kbid, parsed, reranker, pb_query) + hydrate_text_blocks, hydrate_resources = _plan_hydration(kbid, request) + + # Reranking + + hydrate_and_rerank_step: PlanStep[tuple[list[TextBlockMatch], list[Resource]]] + + if post_rerank_cut is not None: + # To avoid hydrating unneeded resources, we first rerank, cut and then + # hydrate them + + reranked = CachedStep( + post_rerank_cut.plan( + rerank.plan( + hydrate_text_blocks.plan( + pre_rerank_cut.plan( + retrieval_step, + ), + ), + ), + ) + ) + hydrate_and_rerank_step = Group( + reranked, + hydrate_resources.plan(uses=reranked), + ) + else: + # As we don't need extra results, we can rerank and hydrate resources in + # parallel + + pre_rerank = CachedStep( + pre_rerank_cut.plan( + retrieval_step, + ), + ) + + hydrate_and_rerank_step = Parallel( + rerank.plan( + hydrate_text_blocks.plan( + pre_rerank, + ), + ), + hydrate_resources.plan( + pre_rerank, + ), + ) + + return hydrate_and_rerank_step + + +def _plan_hydration(kbid: str, request: FindRequest) -> tuple[HydrateTextBlocks, HydrateResources]: + max_hydration_ops = asyncio.Semaphore(50) + + text_block_hydration_options = TextBlockHydrationOptions( + highlight=request.highlight, + # TODO: ematches for text block hydration options + # ematches=search_response.paragraph.ematches, # type: ignore + ematches=[], ) - find_resources = compose_find_resources(text_blocks, resources) + hydrate_text_blocks = HydrateTextBlocks(kbid, text_block_hydration_options, max_hydration_ops) - results_min_score = MinScore( - bm25=0, - semantic=parsed.retrieval.query.semantic.min_score - if parsed.retrieval.query.semantic is not None - else 0.0, + resource_hydration_options = ResourceHydrationOptions( + show=request.show, extracted=request.extracted, field_type_filter=request.field_type_filter ) - retrieval_results = KnowledgeboxFindResults( - resources=find_resources, - query=item.query, - rephrased_query=rephrased_query, - total=0, - page_number=0, - page_size=item.top_k, - relations=None, # Not implemented for external indexes yet - autofilters=[], # Not implemented for external indexes yet - min_score=results_min_score, - best_matches=best_matches, - # These are not used for external indexes - shards=None, - nodes=None, + hydrate_resources = HydrateResources(kbid, resource_hydration_options, max_hydration_ops) + + return hydrate_text_blocks, hydrate_resources + + +def _plan_rerank( + kbid: str, parsed: ParsedQuery, reranker: Reranker, pb_query: SearchRequest +) -> tuple[Cut, Rerank, Optional[Cut]]: + """Reranking is done within two windows. We want to rerank N elements and + obtain the best K. This function returns the steps for reranking and cutting + before and after. + + """ + # cut before reranking + + if reranker.needs_extra_results: + # we assume pagination + predict reranker is forbidden and has been already + # enforced/validated by the query parsing. + assert reranker.window is not None, "Reranker definition must enforce this condition" + pre_rerank_cut = Cut(page=reranker.window) + post_rerank_cut = Cut(page=parsed.retrieval.top_k) + else: + pre_rerank_cut = Cut(page=parsed.retrieval.top_k) + post_rerank_cut = None + + # reranking + + reranking_options = RerankingOptions( + kbid=kbid, + # XXX: we are using keyword query for reranking, it doesn't work with semantic or graph only! + query=pb_query.body, ) + rerank = Rerank(kbid, reranker, reranking_options, parsed.retrieval.top_k) - return retrieval_results, incomplete_results, parsed + return pre_rerank_cut, rerank, post_rerank_cut diff --git a/nucliadb/src/nucliadb/search/search/find_merge.py b/nucliadb/src/nucliadb/search/search/find_merge.py index 3e2eccc98b..eccd536450 100644 --- a/nucliadb/src/nucliadb/search/search/find_merge.py +++ b/nucliadb/src/nucliadb/search/search/find_merge.py @@ -32,7 +32,6 @@ from nucliadb.common.external_index_providers.base import TextBlockMatch from nucliadb.common.ids import ParagraphId, VectorId from nucliadb.search import SERVICE_NAME, logger -from nucliadb.search.search.cut import cut_page from nucliadb.search.search.hydrator import ( ResourceHydrationOptions, TextBlockHydrationOptions, @@ -41,6 +40,7 @@ text_block_to_find_paragraph, ) from nucliadb.search.search.merge import merge_relations_results +from nucliadb.search.search.plan.cut import cut_page from nucliadb.search.search.query_parser.models import UnitRetrieval from nucliadb.search.search.rank_fusion import IndexSource, RankFusionAlgorithm from nucliadb.search.search.rerankers import ( diff --git a/nucliadb/src/nucliadb/search/search/merge.py b/nucliadb/src/nucliadb/search/search/merge.py index 29e85b9cb7..9975f90d18 100644 --- a/nucliadb/src/nucliadb/search/search/merge.py +++ b/nucliadb/src/nucliadb/search/search/merge.py @@ -38,13 +38,13 @@ from nucliadb.common.models_utils import from_proto from nucliadb.common.models_utils.from_proto import RelationTypePbMap from nucliadb.search.search import cache -from nucliadb.search.search.cut import cut_page from nucliadb.search.search.fetch import ( fetch_resources, get_labels_paragraph, get_labels_resource, get_seconds_paragraph, ) +from nucliadb.search.search.plan.cut import cut_page from nucliadb.search.search.query_parser.models import FulltextQuery, UnitRetrieval from nucliadb_models.common import FieldTypeName from nucliadb_models.labels import translate_system_to_alias_label diff --git a/nucliadb/src/nucliadb/search/search/cut.py b/nucliadb/src/nucliadb/search/search/plan/__init__.py similarity index 72% rename from nucliadb/src/nucliadb/search/search/cut.py rename to nucliadb/src/nucliadb/search/search/plan/__init__.py index 4db692e1c3..3b734776ac 100644 --- a/nucliadb/src/nucliadb/search/search/cut.py +++ b/nucliadb/src/nucliadb/search/search/plan/__init__.py @@ -17,14 +17,3 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # - -from typing import TypeVar - -T = TypeVar("T") - - -def cut_page(items: list[T], top_k: int) -> tuple[list[T], bool]: - """Return a slice of `items` representing the specified page and a boolean - indicating whether there is a next page or not""" - next_page = len(items) > top_k - return items[:top_k], next_page diff --git a/nucliadb/src/nucliadb/search/search/plan/cut.py b/nucliadb/src/nucliadb/search/search/plan/cut.py new file mode 100644 index 0000000000..7b71bd72b5 --- /dev/null +++ b/nucliadb/src/nucliadb/search/search/plan/cut.py @@ -0,0 +1,56 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# + +from typing import TypeVar, Union + +from typing_extensions import Self + +from nucliadb.common.external_index_providers.base import TextBlockMatch + +from .models import ExecutionContext, PlanStep + +T = TypeVar("T") + + +class Cut(PlanStep): + def __init__(self, page: int): + self.page = page + + def plan(self, uses: PlanStep[list[TextBlockMatch]]) -> Self: + self.source = uses + return self + + async def execute(self, context: ExecutionContext) -> list[TextBlockMatch]: + text_blocks = await self.source.execute(context) + page, next_page = cut_page(text_blocks, self.page) + context.next_page = next_page + return page + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return { + "Cut": self.source.explain(), + } + + +def cut_page(items: list[T], top_k: int) -> tuple[list[T], bool]: + """Return a slice of `items` representing the specified page and a boolean + indicating whether there is a next page or not""" + next_page = len(items) > top_k + return items[:top_k], next_page diff --git a/nucliadb/src/nucliadb/search/search/plan/hydrate.py b/nucliadb/src/nucliadb/search/search/plan/hydrate.py new file mode 100644 index 0000000000..abfe64999c --- /dev/null +++ b/nucliadb/src/nucliadb/search/search/plan/hydrate.py @@ -0,0 +1,116 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +import asyncio +from typing import Union + +from typing_extensions import Self + +from nucliadb.common.external_index_providers.base import TextBlockMatch +from nucliadb.search import SERVICE_NAME +from nucliadb.search.search.hydrator import ( + ResourceHydrationOptions, + TextBlockHydrationOptions, + hydrate_resource_metadata, + hydrate_text_block, +) +from nucliadb_models.resource import Resource + +from .models import ExecutionContext, PlanStep + + +class HydrateTextBlocks(PlanStep): + def __init__( + self, + kbid: str, + hydration_options: TextBlockHydrationOptions, + max_ops: asyncio.Semaphore, + ): + self.kbid = kbid + self.hydration_options = hydration_options + self.max_ops = max_ops + + def plan(self, uses: PlanStep[list[TextBlockMatch]]) -> Self: + self.source = uses + return self + + async def execute(self, context: ExecutionContext) -> list[TextBlockMatch]: + text_blocks = await self.source.execute(context) + + ops = [] + for text_block in text_blocks: + ops.append( + asyncio.create_task( + hydrate_text_block( + self.kbid, + text_block, + self.hydration_options, + concurrency_control=self.max_ops, + ) + ) + ) + + # TODO: metrics + hydrated = await asyncio.gather(*ops) + return hydrated + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return { + "HydrateTextBlocks": self.source.explain(), + } + + +class HydrateResources(PlanStep): + def __init__( + self, kbid: str, hydration_options: ResourceHydrationOptions, max_ops: asyncio.Semaphore + ): + self.kbid = kbid + self.hydration_options = hydration_options + self.max_ops = max_ops + + def plan(self, uses: PlanStep[list[TextBlockMatch]]) -> Self: + self.source = uses + return self + + async def execute(self, context: ExecutionContext) -> list[Resource]: + text_blocks = await self.source.execute(context) + + ops = {} + for text_block in text_blocks: + rid = text_block.paragraph_id.rid + + if rid not in ops: + ops[rid] = asyncio.create_task( + hydrate_resource_metadata( + self.kbid, + rid, + options=self.hydration_options, + concurrency_control=self.max_ops, + service_name=SERVICE_NAME, + ) + ) + # TODO: metrics + hydrated = await asyncio.gather(*ops.values()) + resources = [resource for resource in hydrated if resource is not None] + return resources + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return { + "HydrateResources": self.source.explain(), + } diff --git a/nucliadb/src/nucliadb/search/search/plan/models.py b/nucliadb/src/nucliadb/search/search/plan/models.py new file mode 100644 index 0000000000..ffe7d51b6c --- /dev/null +++ b/nucliadb/src/nucliadb/search/search/plan/models.py @@ -0,0 +1,103 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Generic, Optional, TypeVar, Union + +from nucliadb.common.external_index_providers.base import TextBlockMatch +from nucliadb_models.resource import Resource + +# execution context (filled while running the plan) + + +@dataclass +class ExecutionContext: + # nidx search failed on some shards and results aren't from the whole corpus + nidx_incomplete: bool = False + + # nidx shards queried during retrieval + nidx_queried_shards: Optional[list[str]] = None + + # whether a query with a greater top_k would have returned more results + next_page: bool = False + + # number of keyword results obtained from the index + keyword_result_count: int = 0 + + # list of exact matches during keyword search + keyword_exact_matches: Optional[list[str]] = None + + +# query plan step outputs + + +@dataclass +class IndexResult: + keyword: list[TextBlockMatch] + semantic: list[TextBlockMatch] + graph: list[TextBlockMatch] + + +@dataclass +class BestMatchesHydrated: + best_text_blocks: list[TextBlockMatch] + resources: list[Resource] + best_matches: list[str] + + +# query plan step logic + + +T = TypeVar("T") + + +class PlanStep(ABC, Generic[T]): + @abstractmethod + async def execute(self, context: ExecutionContext) -> T: + pass + + @abstractmethod + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return {self.__class__.__name__: "(explain not implemented)"} + + +class Plan(Generic[T]): + def __init__(self, step: PlanStep[T]): + self._context = ExecutionContext() + self.step = step + + async def execute(self) -> T: + return await self.step.execute(self._context) + + def explain(self) -> None: + def _print_plan(plan: Union[str, Union[str, dict]], offset: int = 0): + if isinstance(plan, dict): + for key, value in plan.items(): + print(" " * offset, "-", key) + _print_plan(value, offset + 2) + else: + print(" " * offset, "-", plan) + + plan = self.step.explain() + _print_plan(plan) + + @property + def context(self) -> ExecutionContext: + return self._context diff --git a/nucliadb/src/nucliadb/search/search/plan/query.py b/nucliadb/src/nucliadb/search/search/plan/query.py new file mode 100644 index 0000000000..adf963dbd0 --- /dev/null +++ b/nucliadb/src/nucliadb/search/search/plan/query.py @@ -0,0 +1,113 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +from typing import Union + +from nidx_protos.nodereader_pb2 import SearchRequest, SearchResponse + +from nucliadb.common.external_index_providers.pinecone import PineconeIndexManager +from nucliadb.search.requesters.utils import Method, node_query +from nucliadb.search.search.find_merge import ( + graph_results_to_text_block_matches, + keyword_results_to_text_block_matches, + merge_shard_responses, + semantic_results_to_text_block_matches, +) + +from .models import ExecutionContext, IndexResult, PlanStep + + +class NidxQuery(PlanStep): + """Perform a nidx search and return results as text block matches""" + + def __init__(self, kbid: str, pb_query: SearchRequest): + self.kbid = kbid + self.pb_query = pb_query + + async def execute(self, context: ExecutionContext) -> SearchResponse: + index_results, incomplete, queried_shards = await node_query( + self.kbid, Method.SEARCH, self.pb_query + ) + context.nidx_incomplete = incomplete + context.nidx_queried_shards = queried_shards + + search_response = merge_shard_responses(index_results) + + context.next_page = context.next_page or search_response.paragraph.next_page + context.keyword_result_count = context.next_page or search_response.paragraph.next_page + context.keyword_exact_matches = list(search_response.paragraph.ematches) + + return search_response + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return {"NidxQuery": "..."} + + +class ProtoIntoTextBlockMatches(PlanStep): + def __init__(self, uses: PlanStep[SearchResponse]): + self.proto_query = uses + + async def execute(self, context: ExecutionContext) -> IndexResult: + search_response = await self.proto_query.execute(context) + + keyword = keyword_results_to_text_block_matches(search_response.paragraph.results) + semantic = semantic_results_to_text_block_matches(search_response.vector.documents) + graph = graph_results_to_text_block_matches(search_response.graph) + + return IndexResult(keyword=keyword, semantic=semantic, graph=graph) + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return { + "ProtoIntoTextBlockMatches": self.proto_query.explain(), + } + + +class PineconeQuery(PlanStep): + def __init__(self, kbid: str, pb_query: SearchRequest, pinecone_index: PineconeIndexManager): + self.kbid = kbid + self.pb_query = pb_query + self.pinecone_index = pinecone_index + + async def execute(self, _context: ExecutionContext) -> IndexResult: + query_results = await self.pinecone_index.query(self.pb_query) + return IndexResult( + semantic=list(query_results.iter_matching_text_blocks()), + # keyword and graph are not supported in Pinecone. + keyword=[], + graph=[], + ) + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return "PineconeQuery" + + +class IndexPostFilter(PlanStep): + def __init__(self, index_query: PlanStep[IndexResult], semantic_min_score: float): + self.index_query = index_query + self.semantic_min_score = semantic_min_score + + async def execute(self, context: ExecutionContext) -> IndexResult: + results = await self.index_query.execute(context) + results.semantic = list(filter(lambda x: x.score >= self.semantic_min_score, results.semantic)) + return results + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return { + "IndexPostFilter": self.index_query.explain(), + } diff --git a/nucliadb/src/nucliadb/search/search/plan/rank_fusion.py b/nucliadb/src/nucliadb/search/search/plan/rank_fusion.py new file mode 100644 index 0000000000..03d0811c52 --- /dev/null +++ b/nucliadb/src/nucliadb/search/search/plan/rank_fusion.py @@ -0,0 +1,47 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +from typing import Union + +from nucliadb.common.external_index_providers.base import TextBlockMatch +from nucliadb.search.search.rank_fusion import IndexSource, RankFusionAlgorithm + +from .models import ExecutionContext, IndexResult, PlanStep + + +class RankFusion(PlanStep): + def __init__(self, algorithm: RankFusionAlgorithm, source: PlanStep[IndexResult]): + self.algorithm = algorithm + self.source = source + + async def execute(self, context: ExecutionContext) -> list[TextBlockMatch]: + index_results = await self.source.execute(context) + merged = self.algorithm.fuse( + { + IndexSource.KEYWORD: index_results.keyword, + IndexSource.SEMANTIC: index_results.semantic, + IndexSource.GRAPH: index_results.graph, + } + ) + return merged + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return { + "RankFusion": self.source.explain(), + } diff --git a/nucliadb/src/nucliadb/search/search/plan/rerank.py b/nucliadb/src/nucliadb/search/search/plan/rerank.py new file mode 100644 index 0000000000..351f664a43 --- /dev/null +++ b/nucliadb/src/nucliadb/search/search/plan/rerank.py @@ -0,0 +1,105 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +from typing import Union + +from typing_extensions import Self + +from nucliadb.common.external_index_providers.base import TextBlockMatch +from nucliadb.search.search.rerankers import ( + RerankableItem, + Reranker, + RerankingOptions, +) + +from .models import ExecutionContext, PlanStep + + +class Rerank(PlanStep): + def __init__(self, kbid: str, reranker: Reranker, reranking_options: RerankingOptions, top_k: int): + self.kbid = kbid + self.reranker = reranker + self.reranking_options = reranking_options + self.top_k = top_k + + def plan(self, uses: PlanStep[list[TextBlockMatch]]) -> Self: + self.source = uses + return self + + async def execute(self, context: ExecutionContext) -> list[TextBlockMatch]: + # we assume text blocks have been hydrated. It'd be nice to enforce this + # with types + text_blocks = await self.source.execute(context) + + text_blocks_by_id: dict[ + str, TextBlockMatch + ] = {} # useful for faster access to text blocks later + to_rerank = [] + for text_block in text_blocks: + paragraph_id = text_block.paragraph_id.full() + + # If we find multiple results (from different indexes) with different + # metadata, this statement will only get the metadata from the first on + # the list. We assume metadata is the same on all indexes, otherwise + # this would be a BUG + text_blocks_by_id.setdefault(paragraph_id, text_block) + + to_rerank.append( + RerankableItem( + id=paragraph_id, + score=text_block.score, + score_type=text_block.score_type, + content=text_block.text or "", # TODO: add a warning, this shouldn't usually happen + ) + ) + + reranked = await self.reranker.rerank(to_rerank, self.reranking_options) + + # after reranking, we can cut to the number of results the user wants, so we + # don't work with unnecessary stuff + reranked = reranked[: self.top_k] + + # now get back the text block matches + matches = [] + for item in reranked: + paragraph_id = item.id + score = item.score + score_type = item.score_type + + text_block = text_blocks_by_id[paragraph_id] + text_block.score = score + text_block.score_type = score_type + + matches.append((paragraph_id, score)) + + # TODO: we shouldn't need to sort here again + matches.sort(key=lambda x: x[1], reverse=True) + + best_text_blocks = [] + for order, (paragraph_id, _) in enumerate(matches): + text_block = text_blocks_by_id[paragraph_id] + text_block.order = order + best_text_blocks.append(text_block) + + return best_text_blocks + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return { + "Rerank": self.source.explain(), + } diff --git a/nucliadb/src/nucliadb/search/search/plan/serialize.py b/nucliadb/src/nucliadb/search/search/plan/serialize.py new file mode 100644 index 0000000000..a2481e8784 --- /dev/null +++ b/nucliadb/src/nucliadb/search/search/plan/serialize.py @@ -0,0 +1,51 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +from typing import Union + +from nidx_protos.nodereader_pb2 import SearchResponse + +from nucliadb.search.search.merge import merge_relations_results +from nucliadb.search.search.query_parser.models import UnitRetrieval +from nucliadb_models.search import ( + Relations, +) + +from .models import ExecutionContext, PlanStep + + +class SerializeRelations(PlanStep): + def __init__(self, retrieval: UnitRetrieval, uses: PlanStep[SearchResponse]): + self.retrieval = retrieval + self.proto_query = uses + + async def execute(self, context: ExecutionContext) -> Relations: + if self.retrieval.query.relation is not None: + search_response = await self.proto_query.execute(context) + entry_points = self.retrieval.query.relation.entry_points + relations = await merge_relations_results([search_response.graph], entry_points) + else: + relations = Relations(entities={}) + + return relations + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return { + "Cached": self.proto_query.explain(), + } diff --git a/nucliadb/src/nucliadb/search/search/plan/utils.py b/nucliadb/src/nucliadb/search/search/plan/utils.py new file mode 100644 index 0000000000..f5f042ef11 --- /dev/null +++ b/nucliadb/src/nucliadb/search/search/plan/utils.py @@ -0,0 +1,170 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +import asyncio +from typing import Optional, TypeVar, Union + +from typing_extensions import TypeIs + +from .models import ExecutionContext, PlanStep + +T = TypeVar("T") +A = TypeVar("A") +B = TypeVar("B") +C = TypeVar("C") + + +# We use a class as cache miss marker to allow None values in the cache and to +# make mypy happy with typing +class NotCached: + pass + + +def is_cached(field: Union[T, NotCached]) -> TypeIs[T]: + return not isinstance(field, NotCached) + + +# class DiamondStep(PlanStep[tuple[A, B]]): +# def __init__(self, ) + + +class CachedStep(PlanStep[T]): + """Executes a plan step and caches the result. Following calls to `execute` + will return the cached result. + + This is useful to build diamonds in a tree. I.e., we want this: + A + // \\ + B C + \\ // + D + + and implement it like this: + A + // \\ + B C + | | + D' D' + + where D' is cached D step + + """ + + def __init__(self, step: PlanStep[T]): + self.step = step + self.lock = asyncio.Lock() + self.cache: Union[T, NotCached] = NotCached() + + async def execute(self, context: ExecutionContext) -> T: + async with self.lock: + if not is_cached(self.cache): + print("not cached") + self.cache = await self.step.execute(context) + print("caching") + print("self is", self) + else: + print("is cached") + return self.cache + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return { + "Cached": self.step.explain(), + } + + +class Parallel(PlanStep[tuple[A, B]]): + """A parallel step executes multiple steps in different asyncio tasks and + gathers and returns all results. This is useful when multiple steps can be + done concurrently. + + """ + + def __init__(self, a: PlanStep[A], b: PlanStep[B]): + self.a = a + self.b = b + + async def execute(self, context: ExecutionContext) -> tuple[A, B]: + a, b = await asyncio.gather( + self.a.execute(context), + self.b.execute(context), + ) + return (a, b) + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return { + "Parallel": { + "A": self.a.explain(), + "B": self.b.explain(), + } + } + + +class Group(PlanStep[tuple[A, B]]): + """Executes a series of steps in order and return all its values. This is + useful when we need to group results from multiple steps that depend on each + other. + + """ + + def __init__(self, a: PlanStep[A], b: PlanStep[B]): + self.a = a + self.b = b + + async def execute(self, context: ExecutionContext) -> tuple[A, B]: + a = await self.a.execute(context) + b = await self.b.execute(context) + return (a, b) + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return { + "Group": { + "A": self.a.explain(), + "B": self.b.explain(), + } + } + + +# TODO: implement a more generic Flatten +class Flatten(PlanStep[tuple[A, B, C]]): + def __init__(self, step: PlanStep[tuple[tuple[A, B], C]]): + self.step = step + + async def execute(self, context: ExecutionContext) -> tuple[A, B, C]: + (a, b), c = await self.step.execute(context) + return (a, b, c) + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + return {"Flatten": self.step.explain()} + + +class OptionalStep(PlanStep[Optional[T]]): + def __init__(self, step: Optional[PlanStep[T]]): + self.step = step + + async def execute(self, context: ExecutionContext) -> Optional[T]: + if self.step is None: + return None + else: + return await self.step.execute(context) + + def explain(self) -> Union[str, dict[str, Union[str, dict]]]: + if self.step is None: + return {"Optional": "None"} + else: + return {"Optional": self.step.explain()} diff --git a/nucliadb/tests/nucliadb/integration/search/graph/test_find_graph.py b/nucliadb/tests/nucliadb/integration/search/graph/test_find_graph.py index b9b3cb4298..3f25171c70 100644 --- a/nucliadb/tests/nucliadb/integration/search/graph/test_find_graph.py +++ b/nucliadb/tests/nucliadb/integration/search/graph/test_find_graph.py @@ -23,7 +23,7 @@ from nidx_protos.nodereader_pb2 import SearchRequest from pytest_mock import MockerFixture -from nucliadb.search.search import find +from nucliadb.search.search.plan import query @pytest.mark.deploy_modes("standalone") @@ -34,7 +34,7 @@ async def test_find_graph_request( ): """Validate how /find prepares a graph search""" kbid = standalone_knowledgebox - spy = mocker.spy(find, "node_query") + spy = mocker.spy(query, "node_query") # graph_query but missing features=graph resp = await nucliadb_reader.post( @@ -110,7 +110,7 @@ async def test_find_graph_feature( """ kbid = standalone_knowledgebox - spy = mocker.spy(find, "build_find_response") + spy = mocker.spy(query, "merge_shard_responses") resp = await nucliadb_reader.post( f"/kb/{kbid}/find", diff --git a/nucliadb/tests/nucliadb/integration/search/test_search.py b/nucliadb/tests/nucliadb/integration/search/test_search.py index daef1631af..d872ceeea7 100644 --- a/nucliadb/tests/nucliadb/integration/search/test_search.py +++ b/nucliadb/tests/nucliadb/integration/search/test_search.py @@ -838,9 +838,9 @@ async def test_search_user_relations( ): kbid = standalone_knowledgebox - from nucliadb.search.search import find + from nucliadb.search.search.plan import query - spy = mocker.spy(find, "node_query") + spy = mocker.spy(query, "node_query") with patch.object(predict_mock, "detect_entities", AsyncMock(return_value=[])): resp = await nucliadb_reader.post( f"/kb/{kbid}/find", diff --git a/nucliadb/tests/nucliadb/integration/test_vectorsets.py b/nucliadb/tests/nucliadb/integration/test_vectorsets.py index 30cd379216..069f24fb07 100644 --- a/nucliadb/tests/nucliadb/integration/test_vectorsets.py +++ b/nucliadb/tests/nucliadb/integration/test_vectorsets.py @@ -144,7 +144,7 @@ def set_predict_default_vectorset(query_info: QueryInfo) -> QueryInfo: new=AsyncMock(side_effect=mock_node_query), ), patch( - "nucliadb.search.search.find.node_query", + "nucliadb.search.search.plan.query.node_query", new=AsyncMock(side_effect=mock_node_query), ), patch( @@ -198,7 +198,7 @@ async def mock_node_query(kbid: str, method, pb_query: nodereader_pb2.SearchRequ new=AsyncMock(side_effect=mock_node_query), ), patch( - "nucliadb.search.search.find.node_query", + "nucliadb.search.search.plan.query.node_query", new=AsyncMock(side_effect=mock_node_query), ), patch(