98 changes: 97 additions & 1 deletion aperag/service/collection_service.py
@@ -258,6 +258,73 @@ async def create_search(

# Process search results
docs = result.get(end_node_id, {}).docs

# Apply reranking if we have documents
if docs:
# Find a rerank model with the "default_for_rerank" tag
rerank_model = await self._get_default_rerank_model(user)

if rerank_model:
# Verify if the model has an API key configured
api_key = await async_db_ops.query_provider_api_key(rerank_model["provider"], user)

if api_key:
# Add rerank node to the existing flow instead of creating a separate FlowInstance
try:
# Add rerank node to the existing flow
rerank_node_id = "rerank"

# Add the rerank node to the existing nodes dictionary
flow.nodes[rerank_node_id] = NodeInstance(
id=rerank_node_id,
type="rerank",
input_values={
"model": rerank_model["model"],
"model_service_provider": rerank_model["provider"],
"custom_llm_provider": rerank_model.get("custom_llm_provider", rerank_model["provider"]),
"docs": docs
}
)

# Add an edge from the merge node to the rerank node
flow.edges.append(Edge(source=end_node_id, target=rerank_node_id))
Collaborator: The rerank node shouldn't just be added; it should form a DAG after the merge node. Our flow process has actually implemented the combination of search + rerank; you can refer to rag_flow2.yaml.
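
A minimal sketch of that shape, reusing the NodeInstance/Edge types, input keys, and engine.execute_flow call from this diff (whether the rerank node reads docs from the incoming edge rather than from input_values is an assumption here; rag_flow2.yaml would be the reference):

    # Sketch: build the full DAG once (merge -> rerank), then execute it once.
    rerank_node_id = "rerank"
    flow.nodes[rerank_node_id] = NodeInstance(
        id=rerank_node_id,
        type="rerank",
        input_values={
            "model": rerank_model["model"],
            "model_service_provider": rerank_model["provider"],
            "custom_llm_provider": rerank_model.get("custom_llm_provider", rerank_model["provider"]),
            # no "docs" here: the rerank node is assumed to consume the merge
            # node's output through the edge below, not a pre-collected list
        },
    )
    flow.edges.append(Edge(source=end_node_id, target=rerank_node_id))

    # single execution of the whole DAG; read the rerank node's output
    result, _ = await engine.execute_flow(flow, {"query": query, "user": user})
    docs = result[rerank_node_id].docs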


# Re-execute the flow with the added rerank node
result, _ = await engine.execute_flow(flow, {"query": query, "user": user})
Collaborator: You shouldn't re-execute the flow. Instead, you should check if there's a rerank node; if so, include it in the execution. If there's no rerank node, then use the default rerank strategy.
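
One way that check might look, assuming rerank nodes can be identified by the same type field used when they are created in this diff (type == "rerank"):

    # Sketch: detect an existing rerank node before the single execution.
    rerank_node_ids = [nid for nid, node in flow.nodes.items() if node.type == "rerank"]

    result, _ = await engine.execute_flow(flow, {"query": query, "user": user})

    if rerank_node_ids:
        # the flow's own rerank node produced the final ordering
        docs = result[rerank_node_ids[0]].docs
    else:
        # no rerank node: take the merge output and apply the default
        # rerank strategy (the graph-first, score-descending ordering below)
        docs = result.get(end_node_id, {}).docs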


if result and rerank_node_id in result:
docs = result[rerank_node_id].docs

except Exception as e:
# Log error but continue with original results
import logging
logger = logging.getLogger(__name__)
logger.error(f"Reranking failed: {str(e)}")
else:
# Log warning that the model doesn't have an API key configured
import logging
logger = logging.getLogger(__name__)
logger.warning(f"Rerank model {rerank_model['model']} with provider {rerank_model['provider']} has no API key configured. Skipping reranking.")
Collaborator: If there's no API key, the fallback rerank strategy should be used.
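
One way to do that is to lift the ordering from the else branch below into a small helper and call it from this branch as well; _fallback_rerank is a hypothetical name, and the ordering it applies is exactly the graph-first, score-descending logic already written below:

    def _fallback_rerank(self, docs):
        # Hypothetical helper: graph_search hits first, then everything
        # else by descending score (same rule as the else branch below).
        graph_results = [d for d in docs if d.metadata.get("recall_type", "") == "graph_search"]
        other_results = [d for d in docs if d.metadata.get("recall_type", "") != "graph_search"]
        other_results.sort(key=lambda d: d.score if d.score is not None else 0, reverse=True)
        return graph_results + other_results

The no-API-key branch above would then end with docs = self._fallback_rerank(docs) instead of leaving the merge output untouched.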

else:
# Apply fallback strategy if no rerank model is available
# 1. Place graph search results first (they typically have better quality)
# 2. Sort remaining vector and fulltext results by score in descending order
graph_results = []
other_results = []

for doc in docs:
recall_type = doc.metadata.get("recall_type", "")
if recall_type == "graph_search":
graph_results.append(doc)
else:
other_results.append(doc)

# Sort other results by score in descending order
other_results.sort(key=lambda x: x.score if x.score is not None else 0, reverse=True)

# Combine results with graph results first
docs = graph_results + other_results

items = []
for idx, doc in enumerate(docs):
items.append(
@@ -290,6 +357,35 @@ async def create_search(
created=record.gmt_created.isoformat(),
)

async def _get_default_rerank_model(self, user: str) -> dict:
"""
Find a rerank model with the "default_for_rerank" tag.
Returns the first model found or None if no model is available.
"""
from aperag.schema.view_models import TagFilterCondition
from aperag.service.llm_available_model_service import llm_available_model_service

# Create a tag filter to find models with "default_for_rerank" tag
tag_filter = [TagFilterCondition(tags=["default_for_rerank"], operation="AND")]

# Get available models with the filter
model_configs = await llm_available_model_service.get_available_models(
user, view_models.TagFilterRequest(tag_filters=tag_filter)
)

# Look for rerank models in the result
for provider in model_configs.items:
if provider.rerank and len(provider.rerank) > 0:
# Return the first rerank model found
return {
"provider": provider.name,
"model": provider.rerank[0]["model"],
"custom_llm_provider": provider.rerank[0].get("custom_llm_provider"),
}

# No rerank model found
return None

async def list_searches(self, user: str, collection_id: str) -> view_models.SearchResultList:
from aperag.exceptions import CollectionNotFoundException

@@ -346,4 +442,4 @@ async def test_mineru_token(self, token: str) -> dict:

# Create a global service instance for easy access
# This uses the global db_ops instance and doesn't require session management in views
collection_service = CollectionService()
collection_service = CollectionService()