8 changes: 6 additions & 2 deletions backend/graphrag_app/api/data.py
@@ -50,13 +50,15 @@
response_model=StorageNameList,
responses={status.HTTP_200_OK: {"model": StorageNameList}},
)
async def get_all_data_containers():
async def get_all_data_containers(
container_store_client=Depends(get_cosmos_container_store_client),
):
"""
Retrieve a list of all data containers.
"""
items = []
try:
container_store_client = get_cosmos_container_store_client()
# container_store_client = get_cosmos_container_store_client()
for item in container_store_client.read_all_items():
if item["type"] == "data":
items.append(item["human_readable_name"])
@@ -184,6 +186,8 @@ async def upload_files(
)
return BaseResponse(status="Success.")
except Exception as e:
# import traceback
# traceback.print_exc()
logger = load_pipeline_logger()
logger.error(
message="Error uploading files.",
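The interesting change in `data.py` is that `get_all_data_containers` no longer constructs its own Cosmos client; it receives one through `Depends(get_cosmos_container_store_client)`. Besides matching FastAPI conventions, this makes the route easy to exercise without a live Cosmos DB account. A minimal sketch of a test-time override, assuming the import paths shown below (neither path appears in this diff, and `FakeContainerStore` is purely illustrative):

```python
from fastapi.testclient import TestClient

# Both import paths are assumptions for illustration; the diff does not show
# where the FastAPI app or the dependency function live.
from graphrag_app.main import app
from graphrag_app.utils.common import get_cosmos_container_store_client


class FakeContainerStore:
    """Stands in for the Cosmos container-store client during tests."""

    def read_all_items(self):
        return [{"type": "data", "human_readable_name": "my-dataset"}]


# Because the route now receives the client via Depends(), tests can swap it
# out with dependency_overrides instead of patching module globals.
app.dependency_overrides[get_cosmos_container_store_client] = FakeContainerStore

client = TestClient(app)
print(client.get("/data").json())  # "/data" is an assumed mount path for this router
```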
44 changes: 42 additions & 2 deletions backend/graphrag_app/api/graph.py
@@ -3,7 +3,9 @@

import os
import traceback
from io import BytesIO

import networkx as nx
from fastapi import (
APIRouter,
Depends,
@@ -13,6 +15,7 @@
from fastapi.responses import StreamingResponse

from graphrag_app.logger.load_logger import load_pipeline_logger
from graphrag_app.typing.models import GraphDataResponse
from graphrag_app.utils.azure_clients import AzureClientManager
from graphrag_app.utils.common import (
sanitize_name,
@@ -37,6 +40,8 @@
async def get_graphml_file(
container_name, sanitized_container_name: str = Depends(sanitize_name)
):
logger = load_pipeline_logger()

# validate graphml file existence
azure_client_manager = AzureClientManager()
graphml_filename = "graph.graphml"
@@ -50,10 +55,12 @@ async def get_graphml_file(
return StreamingResponse(
blob_stream,
media_type="application/octet-stream",
headers={"Content-Disposition": f"attachment; filename={graphml_filename}"},
headers={
"Content-Disposition": f"attachment; filename={graphml_filename}",
"filename": graphml_filename,
},
)
except Exception as e:
logger = load_pipeline_logger()
logger.error(
message="Could not fetch graphml file",
cause=e,
@@ -63,3 +70,36 @@ async def get_graphml_file(
status_code=500,
detail=f"Could not fetch graphml file for '{container_name}'.",
)


@graph_route.get(
"/stats/{index_name}",
summary="Retrieve basic graph statistics, number of nodes and edges",
response_model=GraphDataResponse,
responses={200: {"model": GraphDataResponse}},
response_description="Retrieve the number of nodes and edges from the index graph",
)
async def retrieve_graph_stats(index_name: str):
logger = load_pipeline_logger()

# validate index_name and graphml file existence
sanitized_index_name = sanitize_name(index_name)
graphml_filename = "graph.graphml"
graphml_filepath = f"output/{graphml_filename}" # expected file location of the graph based on the workflow
validate_index_file_exist(sanitized_index_name, graphml_filepath)

try:
azure_client_manager = AzureClientManager()
storage_client = azure_client_manager.get_blob_service_client().get_blob_client(
container=sanitized_index_name, blob=graphml_filepath
)
blob_data = storage_client.download_blob().readall()
bytes_io = BytesIO(blob_data)
g = nx.read_graphml(bytes_io)
return GraphDataResponse(nodes=len(g.nodes), edges=len(g.edges))
except Exception:
logger.error("Could not retrieve graph data file")
raise HTTPException(
status_code=500,
detail=f"Could not retrieve graph statistics for index '{index_name}'.",
)
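The new `/stats/{index_name}` route reduces to three steps: download `output/graph.graphml` from the index's blob container, parse it with networkx, and return the node and edge counts as a `GraphDataResponse`. A self-contained sketch of that round trip, substituting an in-memory graph for Azure Blob Storage:

```python
from io import BytesIO

import networkx as nx

# Build a tiny graph and serialize it the way the indexing pipeline stores graph.graphml.
g = nx.Graph()
g.add_edges_from([("entity-a", "entity-b"), ("entity-b", "entity-c")])

buffer = BytesIO()
nx.write_graphml(g, buffer)
blob_data = buffer.getvalue()  # stands in for storage_client.download_blob().readall()

# Mirrors the endpoint body: load the bytes and count nodes and edges.
loaded = nx.read_graphml(BytesIO(blob_data))
print({"nodes": len(loaded.nodes), "edges": len(loaded.edges)})  # {'nodes': 3, 'edges': 2}
```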
49 changes: 38 additions & 11 deletions backend/graphrag_app/api/index.py
@@ -14,6 +14,7 @@
UploadFile,
status,
)
from graphrag.config.enums import IndexingMethod
from kubernetes import (
client as kubernetes_client,
)
@@ -57,8 +58,12 @@ async def schedule_index_job(
index_container_name: str,
entity_extraction_prompt: UploadFile | None = None,
entity_summarization_prompt: UploadFile | None = None,
community_summarization_prompt: UploadFile | None = None,
community_summarization_graph_prompt: UploadFile | None = None,
community_summarization_text_prompt: UploadFile | None = None,
indexing_method: IndexingMethod = IndexingMethod.Standard.value,
):
indexing_method = IndexingMethod(indexing_method).value

azure_client_manager = AzureClientManager()
blob_service_client = azure_client_manager.get_blob_service_client()
pipelinejob = PipelineJob()
@@ -87,9 +92,14 @@
if entity_summarization_prompt
else None
)
community_summarization_prompt_content = (
community_summarization_prompt.file.read().decode("utf-8")
if community_summarization_prompt
community_summarization_graph_content = (
community_summarization_graph_prompt.file.read().decode("utf-8")
if community_summarization_graph_prompt
else None
)
community_summarization_text_content = (
community_summarization_text_prompt.file.read().decode("utf-8")
if community_summarization_text_prompt
else None
)

@@ -120,9 +130,14 @@
) = []
existing_job._entity_extraction_prompt = entity_extraction_prompt_content
existing_job._entity_summarization_prompt = entity_summarization_prompt_content
existing_job._community_summarization_prompt = (
community_summarization_prompt_content
existing_job.community_summarization_graph_prompt = (
community_summarization_graph_content
)
existing_job.community_summarization_text_prompt = (
community_summarization_text_content
)
existing_job._indexing_method = indexing_method

existing_job._epoch_request_time = int(time())
existing_job.update_db()
else:
@@ -132,7 +147,9 @@
human_readable_storage_name=storage_container_name,
entity_extraction_prompt=entity_extraction_prompt_content,
entity_summarization_prompt=entity_summarization_prompt_content,
community_summarization_prompt=community_summarization_prompt_content,
community_summarization_graph_prompt=community_summarization_graph_content,
community_summarization_text_prompt=community_summarization_text_content,
indexing_method=indexing_method,
status=PipelineJobState.SCHEDULED,
)

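With this change a caller supplies two community-summarization prompts, one for graph-derived reports and one for text-derived reports, plus an optional `indexing_method`. A rough sketch of what scheduling a job could look like from the client side; the base URL, the `/index` mount path, and whether these values travel as query parameters or form fields are all assumptions — only the parameter names come from the signature above:

```python
from pathlib import Path

import requests

files = {
    "entity_extraction_prompt": ("extract.txt", Path("prompts/extract.txt").read_bytes()),
    "entity_summarization_prompt": ("summarize.txt", Path("prompts/summarize.txt").read_bytes()),
    # The single community_summarization_prompt is now split in two:
    "community_summarization_graph_prompt": ("community_graph.txt", Path("prompts/community_graph.txt").read_bytes()),
    "community_summarization_text_prompt": ("community_text.txt", Path("prompts/community_text.txt").read_bytes()),
}
params = {
    "storage_container_name": "my-dataset",
    "index_container_name": "my-index",
    # Must be a valid graphrag IndexingMethod value; "standard" is assumed here
    # and is coerced server-side via IndexingMethod(indexing_method).value.
    "indexing_method": "standard",
}

response = requests.post("http://localhost:8080/index", params=params, files=files)
print(response.status_code, response.json())
```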
@@ -155,7 +172,7 @@ async def get_all_index_names(
try:
for item in container_store_client.read_all_items():
if item["type"] == "index":
items.append(item["human_readable_name"])
items.append(item["human_readable_index_name"])
except Exception as e:
logger = load_pipeline_logger()
logger.error(
@@ -245,9 +262,19 @@ async def delete_index(
credential=DefaultAzureCredential(),
audience=os.environ["AI_SEARCH_AUDIENCE"],
)
ai_search_index_name = f"{sanitized_container_name}_description_embedding"
if ai_search_index_name in index_client.list_index_names():
index_client.delete_index(ai_search_index_name)

index_names = index_client.list_index_names()
ai_search_index_report_name = f"{sanitized_container_name}-community-full_content"
if ai_search_index_report_name in index_names:
index_client.delete_index(ai_search_index_report_name)

ai_search_index_description_name = f"{sanitized_container_name}-entity-description"
if ai_search_index_description_name in index_names:
index_client.delete_index(ai_search_index_description_name)

ai_search_index_text_name = f"{sanitized_container_name}-text_unit-text"
if ai_search_index_text_name in index_names:
index_client.delete_index(ai_search_index_text_name)

except Exception as e:
logger = load_pipeline_logger()
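`delete_index` previously removed a single `{container}_description_embedding` search index; it now has to clean up three vector indexes named after a `{container}-{object}-{field}` pattern. The three `if` blocks above could equally be written as a loop over that naming convention. A sketch of that alternative formulation (not what the PR actually ships):

```python
from azure.search.documents.indexes import SearchIndexClient

# Suffixes taken from the three index names in the diff above.
VECTOR_INDEX_SUFFIXES = (
    "community-full_content",
    "entity-description",
    "text_unit-text",
)


def delete_vector_indexes(index_client: SearchIndexClient, sanitized_container_name: str) -> None:
    """Delete every AI Search vector index associated with a data container."""
    existing = set(index_client.list_index_names())
    for suffix in VECTOR_INDEX_SUFFIXES:
        name = f"{sanitized_container_name}-{suffix}"
        if name in existing:
            index_client.delete_index(name)
```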
41 changes: 23 additions & 18 deletions backend/graphrag_app/api/prompt_tuning.py
@@ -6,14 +6,16 @@
from pathlib import Path

import graphrag.api as api
import yaml
from fastapi import (
APIRouter,
Depends,
HTTPException,
status,
)
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.load_config import load_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.logger.rich_progress import RichProgressLogger
from graphrag.prompt_tune.types import DocSelectionType

from graphrag_app.logger.load_logger import load_pipeline_logger
from graphrag_app.utils.azure_clients import AzureClientManager
@@ -32,7 +34,7 @@
)
async def generate_prompts(
container_name: str,
limit: int = 5,
limit: int = 15,
sanitized_container_name: str = Depends(sanitize_name),
):
"""
@@ -48,21 +50,31 @@ async def generate_prompts(
detail=f"Storage container '{container_name}' does not exist.",
)

# load pipeline configuration file (settings.yaml) for input data and other settings
ROOT_DIR = Path(__file__).resolve().parent.parent.parent
with (ROOT_DIR / "scripts/settings.yaml").open("r") as f:
data = yaml.safe_load(f)
data["input"]["container_name"] = sanitized_container_name
graphrag_config = create_graphrag_config(values=data, root_dir=".")
# load custom pipeline settings
ROOT_DIR = Path(__file__).resolve().parent.parent.parent / "scripts/settings.yaml"

# layer the custom settings on top of the default configuration settings of graphrag
graphrag_config: GraphRagConfig = load_config(
root_dir=ROOT_DIR.parent,
config_filepath=ROOT_DIR
)
graphrag_config.input.container_name = sanitized_container_name

# generate prompts
try:
prompts: tuple[str, str, str] = await api.generate_indexing_prompts(
config=graphrag_config,
logger=RichProgressLogger(prefix=sanitized_container_name),
root=".",
limit=limit,
selection_method="random",
selection_method=DocSelectionType.AUTO,
)
prompt_content = {
"entity_extraction_prompt": prompts[0],
"entity_summarization_prompt": prompts[1],
"community_summarization_prompt": prompts[2],
}
return prompt_content # returns a fastapi.responses.JSONResponse object
except Exception as e:
logger = load_pipeline_logger()
error_details = {
@@ -77,11 +89,4 @@
raise HTTPException(
status_code=500,
detail=f"Error generating prompts for data in '{container_name}'. Please try a lower limit.",
)

prompt_content = {
"entity_extraction_prompt": prompts[0],
"entity_summarization_prompt": prompts[1],
"community_summarization_prompt": prompts[2],
}
return prompt_content # returns a fastapi.responses.JSONResponse object
)
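`generate_prompts` no longer hand-parses `settings.yaml` with `yaml.safe_load` and `create_graphrag_config`; instead `load_config` layers the repo's `scripts/settings.yaml` over graphrag's defaults, document selection moves from `random` to `DocSelectionType.AUTO`, and the default `limit` rises from 5 to 15. A standalone sketch of the same flow using only the calls visible in this diff (the settings path is illustrative):

```python
import asyncio
from pathlib import Path

import graphrag.api as api
from graphrag.config.load_config import load_config
from graphrag.logger.rich_progress import RichProgressLogger
from graphrag.prompt_tune.types import DocSelectionType


async def tune_prompts(container_name: str, limit: int = 15) -> tuple[str, str, str]:
    settings_path = Path("backend/scripts/settings.yaml")  # assumed checkout layout
    # Layer the custom settings over graphrag's default configuration.
    config = load_config(root_dir=settings_path.parent, config_filepath=settings_path)
    config.input.container_name = container_name

    # Returns (entity extraction, entity summarization, community summarization) prompts.
    return await api.generate_indexing_prompts(
        config=config,
        logger=RichProgressLogger(prefix=container_name),
        root=".",
        limit=limit,
        selection_method=DocSelectionType.AUTO,
    )


if __name__ == "__main__":
    prompts = asyncio.run(tune_prompts("my-sanitized-container"))
    print(prompts[0][:200])
```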