feat: implement rerank step in collection search API #1162
@@ -258,6 +258,73 @@ async def create_search(

        # Process search results
        docs = result.get(end_node_id, {}).docs

        # Apply reranking if we have documents
        if docs:
            # Find a rerank model with the "default_for_rerank" tag
            rerank_model = await self._get_default_rerank_model(user)

            if rerank_model:
                # Verify if the model has an API key configured
                api_key = await async_db_ops.query_provider_api_key(rerank_model["provider"], user)

                if api_key:
                    # Add rerank node to the existing flow instead of creating a separate FlowInstance
                    try:
                        rerank_node_id = "rerank"

                        # Add the rerank node to the existing nodes dictionary
                        flow.nodes[rerank_node_id] = NodeInstance(
                            id=rerank_node_id,
                            type="rerank",
                            input_values={
                                "model": rerank_model["model"],
                                "model_service_provider": rerank_model["provider"],
                                "custom_llm_provider": rerank_model.get("custom_llm_provider", rerank_model["provider"]),
                                "docs": docs,
                            },
                        )

                        # Add an edge from the merge node to the rerank node
                        flow.edges.append(Edge(source=end_node_id, target=rerank_node_id))

                        # Re-execute the flow with the added rerank node
                        result, _ = await engine.execute_flow(flow, {"query": query, "user": user})

                        if result and rerank_node_id in result:
                            docs = result[rerank_node_id].docs

                    except Exception as e:
                        # Log error but continue with original results
                        import logging

                        logger = logging.getLogger(__name__)
                        logger.error(f"Reranking failed: {str(e)}")
                else:
                    # Log warning that the model doesn't have an API key configured
                    import logging

                    logger = logging.getLogger(__name__)
                    logger.warning(f"Rerank model {rerank_model['model']} with provider {rerank_model['provider']} has no API key configured. Skipping reranking.")

            else:
                # Apply fallback strategy if no rerank model is available
                # 1. Place graph search results first (they typically have better quality)
                # 2. Sort remaining vector and fulltext results by score in descending order
                graph_results = []
                other_results = []

                for doc in docs:
                    recall_type = doc.metadata.get("recall_type", "")
                    if recall_type == "graph_search":
                        graph_results.append(doc)
                    else:
                        other_results.append(doc)

                # Sort other results by score in descending order
                other_results.sort(key=lambda x: x.score if x.score is not None else 0, reverse=True)

                # Combine results with graph results first
                docs = graph_results + other_results

        items = []
        for idx, doc in enumerate(docs):
            items.append(
@@ -290,6 +357,35 @@ async def create_search(
            created=record.gmt_created.isoformat(),
        )

    async def _get_default_rerank_model(self, user: str) -> dict:
        """
        Find a rerank model with the "default_for_rerank" tag.
        Returns the first model found or None if no model is available.
        """
        from aperag.schema.view_models import TagFilterCondition
        from aperag.service.llm_available_model_service import llm_available_model_service

        # Create a tag filter to find models with the "default_for_rerank" tag
        tag_filter = [TagFilterCondition(tags=["default_for_rerank"], operation="AND")]

        # Get available models with the filter
        model_configs = await llm_available_model_service.get_available_models(
            user, view_models.TagFilterRequest(tag_filters=tag_filter)
        )

        # Look for rerank models in the result
        for provider in model_configs.items:
            if provider.rerank and len(provider.rerank) > 0:
                # Return the first rerank model found
                return {
                    "provider": provider.name,
                    "model": provider.rerank[0]["model"],
                    "custom_llm_provider": provider.rerank[0].get("custom_llm_provider"),
                }

        # No rerank model found
        return None

    async def list_searches(self, user: str, collection_id: str) -> view_models.SearchResultList:
        from aperag.exceptions import CollectionNotFoundException
@@ -346,4 +442,4 @@ async def test_mineru_token(self, token: str) -> dict:

# Create a global service instance for easy access
# This uses the global db_ops instance and doesn't require session management in views
collection_service = CollectionService()
The rerank node shouldn't just be added after the fact; it should form a DAG after the merge node. Our flow process already implements the search + rerank combination; you can refer to rag_flow2.yaml.
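A minimal sketch of the suggested shape, not the project's actual API: the rerank node is wired into the flow DAG before execution, so a single `execute_flow()` call runs search → merge → rerank instead of executing the flow, mutating it, and executing it again. It reuses the `NodeInstance` and `Edge` types from this PR; the helper name `add_rerank_to_flow` and the assumption that the engine routes the merge node's documents downstream along the edge (as the wiring in rag_flow2.yaml presumably does) are illustrative only.

```python
def add_rerank_to_flow(flow, merge_node_id: str, rerank_model: dict) -> str:
    """Attach a rerank node after the merge node so the flow stays a single DAG."""
    rerank_node_id = "rerank"
    flow.nodes[rerank_node_id] = NodeInstance(  # NodeInstance/Edge as used in this PR
        id=rerank_node_id,
        type="rerank",
        input_values={
            "model": rerank_model["model"],
            "model_service_provider": rerank_model["provider"],
            "custom_llm_provider": rerank_model.get("custom_llm_provider", rerank_model["provider"]),
            # No "docs" literal here: the engine is assumed to pass the merge
            # node's documents to this node along the merge -> rerank edge.
        },
    )
    # merge -> rerank keeps the graph a DAG, with rerank as the new terminal node.
    flow.edges.append(Edge(source=merge_node_id, target=rerank_node_id))
    return rerank_node_id
```

With the node declared up front, `create_search` could call `engine.execute_flow(flow, ...)` once and read the documents from `result[rerank_node_id]`, falling back to the merge node's output when no rerank model or API key is configured.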