Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
115 changes: 114 additions & 1 deletion aperag/service/collection_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,90 @@ async def create_search(

# Process search results
docs = result.get(end_node_id, {}).docs

# Apply reranking if we have documents
if docs:
# Find a rerank model with the "default_for_rerank" tag
rerank_model = await self._get_default_rerank_model(user)

if rerank_model:
# Verify if the model has an API key configured
api_key = await async_db_ops.query_provider_api_key(rerank_model["provider"], user)

if api_key:
# Add rerank node to the flow and execute it
try:
rerank_node_id = "rerank"

# Add the rerank node to the flow
flow.nodes[rerank_node_id] = NodeInstance(
id=rerank_node_id,
type="rerank",
input_values={
"model": rerank_model["model"],
"model_service_provider": rerank_model["provider"],
"custom_llm_provider": rerank_model.get("custom_llm_provider", rerank_model["provider"]),
"docs": "{{ nodes." + end_node_id + ".output.docs }}" # Use reference to merge node output
}
)

# Add an edge from the merge node to the rerank node
flow.edges.append(Edge(source=end_node_id, target=rerank_node_id))

# Execute the flow with the rerank node
result, _ = await engine.execute_flow(flow, {"query": query, "user": user})
if result and rerank_node_id in result:
docs = result[rerank_node_id].docs

except Exception as e:
# Log error but continue with original results
import logging
logger = logging.getLogger(__name__)
logger.error(f"Reranking failed: {str(e)}")
else:
# Log warning that the model doesn't have an API key configured
import logging
logger = logging.getLogger(__name__)
logger.warning(f"Rerank model {rerank_model['model']} with provider {rerank_model['provider']} has no API key configured. Using fallback rerank strategy.")

# Apply fallback rerank strategy
# 1. Place graph search results first (they typically have better quality)
# 2. Sort remaining vector and fulltext results by score in descending order
graph_results = []
other_results = []

for doc in docs:
recall_type = doc.metadata.get("recall_type", "")
if recall_type == "graph_search":
graph_results.append(doc)
else:
other_results.append(doc)

# Sort other results by score in descending order
other_results.sort(key=lambda x: x.score if x.score is not None else 0, reverse=True)

# Combine results with graph results first
docs = graph_results + other_results
else:
# Apply fallback strategy if no rerank model is available
# 1. Place graph search results first (they typically have better quality)
# 2. Sort remaining vector and fulltext results by score in descending order
graph_results = []
other_results = []

for doc in docs:
recall_type = doc.metadata.get("recall_type", "")
if recall_type == "graph_search":
graph_results.append(doc)
else:
other_results.append(doc)

# Sort other results by score in descending order
other_results.sort(key=lambda x: x.score if x.score is not None else 0, reverse=True)

# Combine results with graph results first
docs = graph_results + other_results

items = []
for idx, doc in enumerate(docs):
items.append(
Expand Down Expand Up @@ -290,6 +374,35 @@ async def create_search(
created=record.gmt_created.isoformat(),
)

async def _get_default_rerank_model(self, user: str) -> "dict | None":
    """Find the user's default rerank model.

    Queries the models available to *user* that carry the
    "default_for_rerank" tag and returns the first rerank model found.

    Args:
        user: Identifier of the user whose available models are queried.

    Returns:
        A dict with keys ``provider`` (provider name), ``model`` (model
        name), and ``custom_llm_provider`` (may be None if the provider
        entry does not define one), or ``None`` when no tagged rerank
        model is available.
    """
    # NOTE(review): imports are kept function-local, presumably to avoid
    # circular imports between service modules — confirm before hoisting.
    from aperag.schema.view_models import TagFilterCondition
    from aperag.service.llm_available_model_service import llm_available_model_service

    # Restrict the lookup to models explicitly tagged as the rerank default.
    tag_filter = [TagFilterCondition(tags=["default_for_rerank"], operation="AND")]

    model_configs = await llm_available_model_service.get_available_models(
        user, view_models.TagFilterRequest(tag_filters=tag_filter)
    )

    # Return the first rerank model offered by any matching provider.
    for provider in model_configs.items:
        # Truthiness covers both an empty list and a missing/None attribute value.
        if provider.rerank:
            return {
                "provider": provider.name,
                "model": provider.rerank[0]["model"],
                "custom_llm_provider": provider.rerank[0].get("custom_llm_provider"),
            }

    # No provider exposes a tagged rerank model.
    return None

async def list_searches(self, user: str, collection_id: str) -> view_models.SearchResultList:
from aperag.exceptions import CollectionNotFoundException

Expand Down Expand Up @@ -346,4 +459,4 @@ async def test_mineru_token(self, token: str) -> dict:

# Create a global service instance for easy access
# This uses the global db_ops instance and doesn't require session management in views
collection_service = CollectionService()
collection_service = CollectionService()