From d18878d2707b26cf7f9fec21158e3225f34ecda6 Mon Sep 17 00:00:00 2001 From: akarim23131 Date: Mon, 21 Apr 2025 22:05:24 +1000 Subject: [PATCH 1/3] Added raw_chunks parameter to search methods --- graphrag/cli/main.py | 11 + graphrag/cli/query.py | 254 +++++++++++++++--- graphrag/query/factory.py | 40 +-- .../structured_search/drift_search/search.py | 56 +++- .../structured_search/global_search/search.py | 33 +++ .../structured_search/local_search/search.py | 20 +- 6 files changed, 364 insertions(+), 50 deletions(-) diff --git a/graphrag/cli/main.py b/graphrag/cli/main.py index 610427871f..91031370d4 100644 --- a/graphrag/cli/main.py +++ b/graphrag/cli/main.py @@ -375,6 +375,14 @@ def _prompt_tune_cli( def _query_cli( method: Annotated[SearchMethod, typer.Option(help="The query algorithm to use.")], query: Annotated[str, typer.Option(help="The query to execute.")], + raw_chunks: Annotated[ + bool, + typer.Option( + "--raw-chunks", + help="Show raw chunks retrieved from vector store before final response.", + is_flag=True + ), + ] = False, config: Annotated[ Path | None, typer.Option( @@ -451,6 +459,7 @@ def _query_cli( response_type=response_type, streaming=streaming, query=query, + raw_chunks=raw_chunks, # Added for local search ) case SearchMethod.GLOBAL: run_global_search( @@ -462,6 +471,7 @@ def _query_cli( response_type=response_type, streaming=streaming, query=query, + raw_chunks=raw_chunks, # Added for global search ) case SearchMethod.DRIFT: run_drift_search( @@ -472,6 +482,7 @@ def _query_cli( streaming=streaming, response_type=response_type, query=query, + raw_chunks=raw_chunks, # Added for drift search ) case SearchMethod.BASIC: run_basic_search( diff --git a/graphrag/cli/query.py b/graphrag/cli/query.py index 937ca69bbf..86038f198e 100644 --- a/graphrag/cli/query.py +++ b/graphrag/cli/query.py @@ -21,7 +21,168 @@ logger = PrintProgressLogger("") - +class RawChunksCallback(NoopQueryCallbacks): + def on_context(self, context: Any) -> None: + try: + # For DRIFT search's three-step process + if isinstance(context, dict) and 'initial_context' in context: + print("\n=== DRIFT SEARCH RAW CHUNKS ===") + + # Step 1: Primer Search + print("\nSTEP 1 - PRIMER SEARCH:") + if hasattr(context['initial_context'], 'context_chunks'): + chunks = context['initial_context'].context_chunks + if isinstance(chunks, dict) and 'reports' in chunks: + for i, report in enumerate(chunks['reports'], 1): + print(f"\nReport {i}:") + print(f"Title: {report.get('title', 'N/A')}") + print(f"Text: {report.get('text', 'N/A')}") + else: + print(chunks) + + # Step 2: Follow-up Searches + print("\nSTEP 2 - FOLLOW-UP SEARCHES:") + if 'followup_contexts' in context: + for i, followup in enumerate(context['followup_contexts'], 1): + print(f"\nFollow-up {i}:") + if hasattr(followup, 'query'): + print(f"Question: {followup.query}") + if hasattr(followup, 'context_chunks'): + print("Retrieved Context:") + if isinstance(followup.context_chunks, dict): + for key, value in followup.context_chunks.items(): + print(f"\n{key}: {value}") + else: + print(followup.context_chunks) + + # Step 3: Final Synthesis + print("\nSTEP 3 - FINAL SYNTHESIS:") + if 'final_context' in context and hasattr(context['final_context'], 'context_chunks'): + final_chunks = context['final_context'].context_chunks + if isinstance(final_chunks, dict): + for key, value in final_chunks.items(): + print(f"\n{key}: {value}") + else: + print(final_chunks) + + print("\n=== END DRIFT SEARCH RAW CHUNKS ===\n") + + + # For Global and Local searches + else: + print("\n=== RAW CHUNKS FROM VECTOR STORE ===") + + # First try to access context_records if available + if hasattr(context, 'context_records'): + records = context.context_records + if isinstance(records, dict): + # Handle reports + if 'reports' in records: + print("\nReports:") + for i, report in enumerate(records['reports'], 1): + print(f"\nReport {i}:") + if isinstance(report, dict): + if 'title' in report: + print(f"Title: {report['title']}") + if 'text' in report: + print(f"Text: {report['text']}") + if 'content' in report: + print(f"Content: {report['content']}") + + # Handle text units + if 'text_units' in records: + print("\nText Units:") + for i, unit in enumerate(records['text_units'], 1): + print(f"\nText Unit {i}:") + if isinstance(unit, dict): + if 'text' in unit: + print(f"Text: {unit['text']}") + if 'source' in unit: + print(f"Source: {unit['source']}") + + # Handle relationships + if 'relationships' in records: + print("\nRelationships:") + for i, rel in enumerate(records['relationships'], 1): + print(f"\nRelationship {i}: {rel}") + + # Fallback to direct attributes if context_records not available + else: + # Handle reports + if hasattr(context, 'reports'): + print("\nReports:") + for i, report in enumerate(context.reports, 1): + print(f"\nReport {i}:") + if isinstance(report, dict): + if 'title' in report: + print(f"Title: {report['title']}") + if 'text' in report: + print(f"Text: {report['text']}") + if 'content' in report: + print(f"Content: {report['content']}") + + # Handle text units + if hasattr(context, 'text_units'): + print("\nText Units:") + for i, unit in enumerate(context.text_units, 1): + print(f"\nText Unit {i}:") + if isinstance(unit, dict): + if 'text' in unit: + print(f"Text: {unit['text']}") + if 'source' in unit: + print(f"Source: {unit['source']}") + + # Handle relationships + if hasattr(context, 'relationships'): + print("\nRelationships:") + for i, rel in enumerate(context.relationships, 1): + print(f"\nRelationship {i}: {rel}") + + # Final fallback to context_chunks + if not (hasattr(context, 'context_records') or + hasattr(context, 'reports') or + hasattr(context, 'text_units') or + hasattr(context, 'relationships')): + if hasattr(context, 'context_chunks'): + print("\nContext Chunks:") + chunks = context.context_chunks + if isinstance(chunks, dict): + for key, value in chunks.items(): + print(f"\n{key}:") + if isinstance(value, list): + for i, item in enumerate(value, 1): + if isinstance(item, dict): + print(f"\nItem {i}:") + for k, v in item.items(): + print(f"{k}: {v}") + else: + print(f"\nItem {i}: {item}") + else: + print(value) + elif isinstance(chunks, list): + for i, chunk in enumerate(chunks, 1): + if isinstance(chunk, dict): + print(f"\nChunk {i}:") + for k, v in chunk.items(): + print(f"{k}: {v}") + else: + print(f"\nChunk {i}: {chunk}") + + # If nothing was found, print debug info + if not any([hasattr(context, attr) for attr in ['context_records', 'reports', 'text_units', 'relationships', 'context_chunks']]): + # print("\nDebug Info:") + # print(f"Context type: {type(context)}") + # print(f"Available attributes: {dir(context)}") + print(f"Raw context: {context}") + + print("\n=== END RAW CHUNKS ===\n") + + except Exception as e: + print(f"\nError displaying chunks: {str(e)}") + print(f"Context type: {type(context)}") + print(f"Context attributes: {dir(context)}") + + def run_global_search( config_filepath: Path | None, data_dir: Path | None, @@ -31,16 +192,24 @@ def run_global_search( response_type: str, streaming: bool, query: str, + raw_chunks: bool = False ): """Perform a global search with a given query. Loads index files required for global search and calls the Query API. """ + #print(f"\nDEBUG: run_global_search called with raw_chunks={raw_chunks}") + root = root_dir.resolve() cli_overrides = {} if data_dir: cli_overrides["output.base_dir"] = str(data_dir) config = load_config(root, config_filepath, cli_overrides) + + # Initialize callbacks list + callbacks = [] + if raw_chunks: + callbacks.append(RawChunksCallback()) dataframe_dict = _resolve_output_files( config=config, @@ -75,6 +244,7 @@ def run_global_search( response_type=response_type, streaming=streaming, query=query, + callbacks=callbacks ) ) logger.success(f"Global Search Response:\n{response}") @@ -88,7 +258,6 @@ def run_global_search( final_community_reports: pd.DataFrame = dataframe_dict["community_reports"] if streaming: - async def run_streaming_search(): full_response = "" context_data = {} @@ -97,8 +266,8 @@ def on_context(context: Any) -> None: nonlocal context_data context_data = context - callbacks = NoopQueryCallbacks() - callbacks.on_context = on_context + global_callbacks = callbacks + [NoopQueryCallbacks()] # Combine with existing callbacks + global_callbacks[-1].on_context = on_context async for stream_chunk in api.global_search_streaming( config=config, @@ -109,7 +278,7 @@ def on_context(context: Any) -> None: dynamic_community_selection=dynamic_community_selection, response_type=response_type, query=query, - callbacks=[callbacks], + callbacks=global_callbacks, # Use combined callbacks ): full_response += stream_chunk print(stream_chunk, end="") # noqa: T201 @@ -129,6 +298,7 @@ def on_context(context: Any) -> None: dynamic_community_selection=dynamic_community_selection, response_type=response_type, query=query, + callbacks=callbacks ) ) logger.success(f"Global Search Response:\n{response}") @@ -137,6 +307,8 @@ def on_context(context: Any) -> None: return response, context_data + + def run_local_search( config_filepath: Path | None, data_dir: Path | None, @@ -145,17 +317,26 @@ def run_local_search( response_type: str, streaming: bool, query: str, + raw_chunks: bool = False, ): """Perform a local search with a given query. Loads index files required for local search and calls the Query API. """ + # Add debug print at start of function + print(f"\nDEBUG: run_local_search called with raw_chunks={raw_chunks}") + root = root_dir.resolve() cli_overrides = {} if data_dir: cli_overrides["output.base_dir"] = str(data_dir) config = load_config(root, config_filepath, cli_overrides) - + + # Initialize callbacks list + callbacks = [] + if raw_chunks: + callbacks.append(RawChunksCallback()) + dataframe_dict = _resolve_output_files( config=config, output_list=[ @@ -169,6 +350,7 @@ def run_local_search( "covariates", ], ) + # Call the Multi-Index Local Search API if dataframe_dict["multi-index"]: final_entities_list = dataframe_dict["entities"] @@ -202,6 +384,7 @@ def run_local_search( response_type=response_type, streaming=streaming, query=query, + callbacks=callbacks, ) ) logger.success(f"Local Search Response:\n{response}") @@ -218,7 +401,6 @@ def run_local_search( final_covariates: pd.DataFrame | None = dataframe_dict["covariates"] if streaming: - async def run_streaming_search(): full_response = "" context_data = {} @@ -227,8 +409,8 @@ def on_context(context: Any) -> None: nonlocal context_data context_data = context - callbacks = NoopQueryCallbacks() - callbacks.on_context = on_context + local_callbacks = callbacks + [NoopQueryCallbacks()] + local_callbacks[-1].on_context = on_context async for stream_chunk in api.local_search_streaming( config=config, @@ -241,30 +423,31 @@ def on_context(context: Any) -> None: community_level=community_level, response_type=response_type, query=query, - callbacks=[callbacks], + callbacks=local_callbacks, ): full_response += stream_chunk - print(stream_chunk, end="") # noqa: T201 - sys.stdout.flush() # flush output buffer to display text immediately - print() # noqa: T201 + print(stream_chunk, end="") + sys.stdout.flush() + print() return full_response, context_data return asyncio.run(run_streaming_search()) - # not streaming - response, context_data = asyncio.run( - api.local_search( - config=config, - entities=final_entities, - communities=final_communities, - community_reports=final_community_reports, - text_units=final_text_units, - relationships=final_relationships, - covariates=final_covariates, - community_level=community_level, - response_type=response_type, - query=query, + else: + response, context_data = asyncio.run( + api.local_search( + config=config, + entities=final_entities, + communities=final_communities, + community_reports=final_community_reports, + text_units=final_text_units, + relationships=final_relationships, + covariates=final_covariates, + community_level=community_level, + response_type=response_type, + query=query, + callbacks=callbacks, + ) ) - ) logger.success(f"Local Search Response:\n{response}") # NOTE: we return the response and context data here purely as a complete demonstration of the API. # External users should use the API directly to get the response and context data. @@ -279,16 +462,24 @@ def run_drift_search( response_type: str, streaming: bool, query: str, + raw_chunks: bool = False # Added raw_chunks parameter ): """Perform a local search with a given query. Loads index files required for local search and calls the Query API. """ + print(f"\nDEBUG: run_drift_search called with raw_chunks={raw_chunks}") + root = root_dir.resolve() cli_overrides = {} if data_dir: cli_overrides["output.base_dir"] = str(data_dir) config = load_config(root, config_filepath, cli_overrides) + + # Initialize callbacks list + callbacks = [] + if raw_chunks: + callbacks.append(RawChunksCallback()) dataframe_dict = _resolve_output_files( config=config, @@ -327,6 +518,7 @@ def run_drift_search( response_type=response_type, streaming=streaming, query=query, + callbacks=callbacks # Added callbacks parameter ) ) logger.success(f"DRIFT Search Response:\n{response}") @@ -351,8 +543,9 @@ def on_context(context: Any) -> None: nonlocal context_data context_data = context - callbacks = NoopQueryCallbacks() - callbacks.on_context = on_context + drift_callbacks = callbacks + [NoopQueryCallbacks()] # Combine with existing callbacks + drift_callbacks[-1].on_context = on_context + async for stream_chunk in api.drift_search_streaming( config=config, @@ -364,7 +557,7 @@ def on_context(context: Any) -> None: community_level=community_level, response_type=response_type, query=query, - callbacks=[callbacks], + callbacks=drift_callbacks, # Use combined callbacks ): full_response += stream_chunk print(stream_chunk, end="") # noqa: T201 @@ -386,6 +579,7 @@ def on_context(context: Any) -> None: community_level=community_level, response_type=response_type, query=query, + callbacks=callbacks # Added callbacks parameter ) ) logger.success(f"DRIFT Search Response:\n{response}") diff --git a/graphrag/query/factory.py b/graphrag/query/factory.py index decc3f0c3d..6f68165d88 100644 --- a/graphrag/query/factory.py +++ b/graphrag/query/factory.py @@ -45,6 +45,7 @@ def get_local_search_engine( description_embedding_store: BaseVectorStore, system_prompt: str | None = None, callbacks: list[QueryCallbacks] | None = None, + raw_chunks: bool = True, ) -> LocalSearch: """Create a local search engine based on data + configuration.""" model_settings = config.get_language_model_config(config.local_search.chat_model_id) @@ -77,23 +78,26 @@ def get_local_search_engine( ls_config = config.local_search + # Create context builder without raw_chunks + context_builder = LocalSearchMixedContext( + community_reports=reports, + text_units=text_units, + entities=entities, + relationships=relationships, + covariates=covariates, + entity_text_embeddings=description_embedding_store, + embedding_vectorstore_key=EntityVectorStoreKey.ID, + text_embedder=embedding_model, + token_encoder=token_encoder + ) + return LocalSearch( model=chat_model, system_prompt=system_prompt, - context_builder=LocalSearchMixedContext( - community_reports=reports, - text_units=text_units, - entities=entities, - relationships=relationships, - covariates=covariates, - entity_text_embeddings=description_embedding_store, - embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE - text_embedder=embedding_model, - token_encoder=token_encoder, - ), + context_builder=context_builder, # Use the created context_builder token_encoder=token_encoder, model_params={ - "max_tokens": ls_config.llm_max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000=1500) + "max_tokens": ls_config.llm_max_tokens, "temperature": ls_config.temperature, "top_p": ls_config.top_p, "n": ls_config.n, @@ -109,14 +113,14 @@ def get_local_search_engine( "include_relationship_weight": True, "include_community_rank": False, "return_candidate_context": False, - "embedding_vectorstore_key": EntityVectorStoreKey.ID, # set this to EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids - "max_tokens": ls_config.max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) + "embedding_vectorstore_key": EntityVectorStoreKey.ID, + "max_tokens": ls_config.max_tokens, }, response_type=response_type, callbacks=callbacks, + raw_chunks=raw_chunks # Only pass raw_chunks here ) - def get_global_search_engine( config: GraphRagConfig, reports: list[CommunityReport], @@ -128,6 +132,7 @@ def get_global_search_engine( reduce_system_prompt: str | None = None, general_knowledge_inclusion_prompt: str | None = None, callbacks: list[QueryCallbacks] | None = None, + raw_chunks: bool = True, # Added raw_chunks parameter ) -> GlobalSearch: """Create a global search engine based on data + configuration.""" # TODO: Global search should select model based on config?? @@ -207,6 +212,7 @@ def get_global_search_engine( concurrent_coroutines=gs_config.concurrency, response_type=response_type, callbacks=callbacks, + raw_chunks=raw_chunks # Added raw_chunks parameter ) @@ -221,6 +227,7 @@ def get_drift_search_engine( local_system_prompt: str | None = None, reduce_system_prompt: str | None = None, callbacks: list[QueryCallbacks] | None = None, + raw_chunks: bool = True, # Added raw_chunks parameter ) -> DRIFTSearch: """Create a local search engine based on data + configuration.""" chat_model_settings = config.get_language_model_config( @@ -272,6 +279,7 @@ def get_drift_search_engine( ), token_encoder=token_encoder, callbacks=callbacks, + raw_chunks=raw_chunks # Added raw_chunks parameter ) @@ -322,7 +330,7 @@ def get_basic_search_engine( token_encoder=token_encoder, ), token_encoder=token_encoder, - model_params={ + llm_params={ "max_tokens": ls_config.llm_max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000=1500) "temperature": ls_config.temperature, "top_p": ls_config.top_p, diff --git a/graphrag/query/structured_search/drift_search/search.py b/graphrag/query/structured_search/drift_search/search.py index 14e12120cb..a2bd616665 100644 --- a/graphrag/query/structured_search/drift_search/search.py +++ b/graphrag/query/structured_search/drift_search/search.py @@ -38,6 +38,7 @@ def __init__( token_encoder: tiktoken.Encoding | None = None, query_state: QueryState | None = None, callbacks: list[QueryCallbacks] | None = None, + raw_chunks: bool = True, # Added raw_chunks parameter ): """ Initialize the DRIFTSearch class. @@ -60,6 +61,7 @@ def __init__( token_encoder=token_encoder, ) self.callbacks = callbacks or [] + self.raw_chunks = raw_chunks # Store raw_chunks parameter self.local_search = self.init_local_search() def init_local_search(self) -> LocalSearch: @@ -83,7 +85,7 @@ def init_local_search(self) -> LocalSearch: "max_tokens": self.context_builder.config.local_search_max_data_tokens, } - model_params = { + llm_params = { "max_tokens": self.context_builder.config.local_search_llm_max_gen_tokens, "temperature": self.context_builder.config.local_search_temperature, "response_format": {"type": "json_object"}, @@ -94,10 +96,11 @@ def init_local_search(self) -> LocalSearch: system_prompt=self.context_builder.local_system_prompt, context_builder=self.context_builder.local_mixed_context, token_encoder=self.token_encoder, - model_params=model_params, + model_params=llm_params, context_builder_params=local_context_params, response_type="multiple paragraphs", callbacks=self.callbacks, + raw_chunks=self.raw_chunks, # Pass raw_chunks to LocalSearch ) def _process_primer_results( @@ -202,8 +205,17 @@ async def search( # Check if query state is empty if not self.query_state.graph: - # Prime the search with the primer + if self.raw_chunks: + print("\n=== STEP 1: PRIMER SEARCH ===") + print("Query:", query) + primer_context, token_ct = await self.context_builder.build_context(query) + + if self.raw_chunks: + print("\nPrimer Context:") + print(primer_context) + + llm_calls["build_context"] = token_ct["llm_calls"] prompt_tokens["build_context"] = token_ct["prompt_tokens"] output_tokens["build_context"] = token_ct["prompt_tokens"] @@ -211,6 +223,13 @@ async def search( primer_response = await self.primer.search( query=query, top_k_reports=primer_context ) + + if self.raw_chunks: + print("\nPrimer Response:") + print(primer_response.response) + print("=== END PRIMER SEARCH ===\n") + + llm_calls["primer"] = primer_response.llm_calls prompt_tokens["primer"] = primer_response.prompt_tokens output_tokens["primer"] = primer_response.output_tokens @@ -224,6 +243,10 @@ async def search( epochs = 0 llm_call_offset = 0 while epochs < self.context_builder.config.n_depth: + + if self.raw_chunks: + print(f"\n=== STEP 2: ACTION SEARCH (Epoch {epochs + 1}) ===") + actions = self.query_state.rank_incomplete_actions() if len(actions) == 0: log.info("No more actions to take. Exiting DRIFT loop.") @@ -232,10 +255,25 @@ async def search( llm_call_offset += ( len(actions) - self.context_builder.config.drift_k_followups ) + + if self.raw_chunks: + print(f"\nProcessing {len(actions)} actions:") + for i, action in enumerate(actions, 1): + print(f"\nAction {i}:") + print(f"Query: {action.query}") + print(f"Follow-ups: {action.follow_ups}") + # Process actions results = await self._search_step( global_query=query, search_engine=self.local_search, actions=actions ) + + if self.raw_chunks: + print("\nAction Results:") + for i, result in enumerate(results, 1): + print(f"\nResult {i}:") + print(result.response if hasattr(result, 'response') else result) + print(f"=== END ACTION SEARCH (Epoch {epochs + 1}) ===\n") # Update query state for action in results: @@ -258,6 +296,13 @@ async def search( reduced_response = response_state if reduce: + + if self.raw_chunks: + print("\n=== STEP 3: REDUCTION ===") + print("Response state to be reduced:") + print(response_state) + + # Reduce response_state to a single comprehensive response for callback in self.callbacks: callback.on_reduce_response_start(response_state) @@ -271,6 +316,11 @@ async def search( max_tokens=self.context_builder.config.reduce_max_tokens, temperature=self.context_builder.config.reduce_temperature, ) + + if self.raw_chunks: + print("\nReduced Response:") + print(reduced_response) + print("=== END REDUCTION ===\n") for callback in self.callbacks: callback.on_reduce_response_end(reduced_response) diff --git a/graphrag/query/structured_search/global_search/search.py b/graphrag/query/structured_search/global_search/search.py index f2e82af899..ef7937d50d 100644 --- a/graphrag/query/structured_search/global_search/search.py +++ b/graphrag/query/structured_search/global_search/search.py @@ -75,6 +75,7 @@ def __init__( reduce_llm_params: dict[str, Any] = DEFAULT_REDUCE_LLM_PARAMS, context_builder_params: dict[str, Any] | None = None, concurrent_coroutines: int = 32, + raw_chunks: bool = True, # Added raw_chunks parameter ): super().__init__( model=model, @@ -91,6 +92,7 @@ def __init__( ) self.callbacks = callbacks or [] self.max_data_tokens = max_data_tokens + self.raw_chunks = raw_chunks # Store raw_chunks parameter self.map_llm_params = map_llm_params self.reduce_llm_params = reduce_llm_params @@ -157,6 +159,20 @@ async def search( conversation_history=conversation_history, **self.context_builder_params, ) + + # Print raw chunks if enabled + if self.raw_chunks: + print("\n=== CONTEXT SENT TO LLM (GLOBAL SEARCH) ===") + print("\nInitial Context Chunks:") + print(context_result.context_chunks) + print("\nCommunity Reports:") + if hasattr(context_result, 'community_reports'): + for i, report in enumerate(context_result.community_reports, 1): + print(f"\nReport {i}:") + print(report) + print("=== END INITIAL CONTEXT ===\n") + + llm_calls["build_context"] = context_result.llm_calls prompt_tokens["build_context"] = context_result.prompt_tokens output_tokens["build_context"] = context_result.output_tokens @@ -170,6 +186,14 @@ async def search( ) for data in context_result.context_chunks ]) + + # Print map responses if raw_chunks is enabled + if self.raw_chunks: + print("\n=== MAP RESPONSES ===") + for i, response in enumerate(map_responses, 1): + print(f"\nBatch {i} Response:") + print(response.response) + print("=== END MAP RESPONSES ===\n") for callback in self.callbacks: callback.on_map_response_end(map_responses) @@ -185,6 +209,15 @@ async def search( query=query, **self.reduce_llm_params, ) + + # Print reduce context if raw_chunks is enabled + if self.raw_chunks: + print("\n=== REDUCE CONTEXT ===") + print("\nReduce Input:") + print(reduce_response.context_text) + print("=== END REDUCE CONTEXT ===\n") + + llm_calls["reduce"] = reduce_response.llm_calls prompt_tokens["reduce"] = reduce_response.prompt_tokens output_tokens["reduce"] = reduce_response.output_tokens diff --git a/graphrag/query/structured_search/local_search/search.py b/graphrag/query/structured_search/local_search/search.py index ed55eb2876..4ce303c630 100644 --- a/graphrag/query/structured_search/local_search/search.py +++ b/graphrag/query/structured_search/local_search/search.py @@ -43,6 +43,7 @@ def __init__( callbacks: list[QueryCallbacks] | None = None, model_params: dict[str, Any] = DEFAULT_LLM_PARAMS, context_builder_params: dict | None = None, + raw_chunks: bool = True, ): super().__init__( model=model, @@ -54,6 +55,7 @@ def __init__( self.system_prompt = system_prompt or LOCAL_SEARCH_SYSTEM_PROMPT self.callbacks = callbacks or [] self.response_type = response_type + self.raw_chunks = raw_chunks async def search( self, @@ -71,6 +73,13 @@ async def search( **kwargs, **self.context_builder_params, ) + + if self.raw_chunks: + print("\n=== CONTEXT SENT TO LLM ===") + print(f"Context chunks used for LLM prompt:") + print(context_result.context_chunks) + print("=== END CONTEXT ===\n") + llm_calls["build_context"] = context_result.llm_calls prompt_tokens["build_context"] = context_result.prompt_tokens output_tokens["build_context"] = context_result.output_tokens @@ -95,6 +104,10 @@ async def search( full_response = "" + # Call callbacks with context before formatting prompt + for callback in self.callbacks: + callback.on_context(context_result) + async for response in self.model.achat_stream( prompt=query, history=history_messages, @@ -150,9 +163,14 @@ async def stream_search( **self.context_builder_params, ) log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query) + + search_prompt = self.system_prompt.format( - context_data=context_result.context_chunks, response_type=self.response_type + context_data=context_result.context_chunks, + response_type=self.response_type ) + + history_messages = [ {"role": "system", "content": search_prompt}, ] From 0e42763abf00b2c05b781471bcd1a0a14cd16b15 Mon Sep 17 00:00:00 2001 From: akarim23131 Date: Wed, 23 Apr 2025 10:41:39 +1000 Subject: [PATCH 2/3] Resolved conflict:kept model_params for consistency --- graphrag/query/structured_search/drift_search/search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphrag/query/structured_search/drift_search/search.py b/graphrag/query/structured_search/drift_search/search.py index a2bd616665..572aa0ea10 100644 --- a/graphrag/query/structured_search/drift_search/search.py +++ b/graphrag/query/structured_search/drift_search/search.py @@ -85,7 +85,7 @@ def init_local_search(self) -> LocalSearch: "max_tokens": self.context_builder.config.local_search_max_data_tokens, } - llm_params = { + model_params = { "max_tokens": self.context_builder.config.local_search_llm_max_gen_tokens, "temperature": self.context_builder.config.local_search_temperature, "response_format": {"type": "json_object"}, @@ -96,7 +96,7 @@ def init_local_search(self) -> LocalSearch: system_prompt=self.context_builder.local_system_prompt, context_builder=self.context_builder.local_mixed_context, token_encoder=self.token_encoder, - model_params=llm_params, + model_params=model_params, context_builder_params=local_context_params, response_type="multiple paragraphs", callbacks=self.callbacks, From 7a778fee05f859c35208ecdccb289fa8a1bee436 Mon Sep 17 00:00:00 2001 From: akarim23131 Date: Wed, 23 Apr 2025 12:15:13 +1000 Subject: [PATCH 3/3] Add get_openai_model_parameters_from_config function --- .../language_model/providers/fnllm/utils.py | 9 ++++ graphrag/query/factory.py | 47 +++++++++---------- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/graphrag/language_model/providers/fnllm/utils.py b/graphrag/language_model/providers/fnllm/utils.py index a493089160..6458e6fe0a 100644 --- a/graphrag/language_model/providers/fnllm/utils.py +++ b/graphrag/language_model/providers/fnllm/utils.py @@ -130,3 +130,12 @@ def run_coroutine_sync(coroutine: Coroutine[Any, Any, T]) -> T: _thr.start() future = asyncio.run_coroutine_threadsafe(coroutine, _loop) return future.result() + +def get_openai_model_parameters_from_config(model_settings): + """Get OpenAI model parameters from config.""" + return { + "max_tokens": model_settings.max_tokens, + "temperature": model_settings.temperature, + "top_p": model_settings.top_p, + "n": model_settings.n, + } \ No newline at end of file diff --git a/graphrag/query/factory.py b/graphrag/query/factory.py index 6f68165d88..01a3faa679 100644 --- a/graphrag/query/factory.py +++ b/graphrag/query/factory.py @@ -14,6 +14,9 @@ from graphrag.data_model.relationship import Relationship from graphrag.data_model.text_unit import TextUnit from graphrag.language_model.manager import ModelManager +from graphrag.language_model.providers.fnllm.utils import ( + get_openai_model_parameters_from_config, +) from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey from graphrag.query.structured_search.basic_search.basic_context import ( BasicSearchContext, @@ -36,10 +39,10 @@ def get_local_search_engine( config: GraphRagConfig, - reports: list[CommunityReport], - text_units: list[TextUnit], - entities: list[Entity], - relationships: list[Relationship], + reports: dict[str, list[CommunityReport]], + text_units: dict[str, list[TextUnit]], + entities: dict[str, list[Entity]], + relationships: dict[str, list[Relationship]], covariates: dict[str, list[Covariate]], response_type: str, description_embedding_store: BaseVectorStore, @@ -78,30 +81,24 @@ def get_local_search_engine( ls_config = config.local_search - # Create context builder without raw_chunks - context_builder = LocalSearchMixedContext( - community_reports=reports, - text_units=text_units, - entities=entities, - relationships=relationships, - covariates=covariates, - entity_text_embeddings=description_embedding_store, - embedding_vectorstore_key=EntityVectorStoreKey.ID, - text_embedder=embedding_model, - token_encoder=token_encoder - ) + model_params = get_openai_model_parameters_from_config(model_settings) return LocalSearch( model=chat_model, system_prompt=system_prompt, - context_builder=context_builder, # Use the created context_builder + context_builder=LocalSearchMixedContext( + community_reports=reports, + text_units=text_units, + entities=entities, + relationships=relationships, + covariates=covariates, + entity_text_embeddings=description_embedding_store, + embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE + text_embedder=embedding_model, + token_encoder=token_encoder, + ), token_encoder=token_encoder, - model_params={ - "max_tokens": ls_config.llm_max_tokens, - "temperature": ls_config.temperature, - "top_p": ls_config.top_p, - "n": ls_config.n, - }, + model_params=model_params, context_builder_params={ "text_unit_prop": ls_config.text_unit_prop, "community_prop": ls_config.community_prop, @@ -113,8 +110,8 @@ def get_local_search_engine( "include_relationship_weight": True, "include_community_rank": False, "return_candidate_context": False, - "embedding_vectorstore_key": EntityVectorStoreKey.ID, - "max_tokens": ls_config.max_tokens, + "embedding_vectorstore_key": EntityVectorStoreKey.ID, # set this to EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids + "max_tokens": ls_config.max_tokens # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) }, response_type=response_type, callbacks=callbacks,