-
Notifications
You must be signed in to change notification settings - Fork 0
Fixes to llm output parsing when using LLM based ranking #16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -104,6 +104,44 @@ async def send_prompt( | |
| return content, usage | ||
|
|
||
|
|
||
| def _preprocess_json_string(text: str) -> str: | ||
| """ | ||
| Pre-process JSON string to fix common LLM formatting errors. | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This function actually modifies the LLM output, so I think a more appropriate name for it would be |
||
| Parameters | ||
| ---------- | ||
| text : str | ||
| Raw JSON text that may have formatting issues | ||
|
|
||
| Returns | ||
| ------- | ||
| str | ||
| Cleaned JSON text | ||
| """ | ||
| # Strip whitespace | ||
| text = text.strip() | ||
|
|
||
| # Fix common array termination issues like: ["item1", "item2".] | ||
| # Replace ".] with "] | ||
| text = re.sub(r'"\s*\.\s*\]', '"]', text) | ||
|
|
||
| # Fix missing closing quotes before array end: ["item1", "item2] | ||
| # Find patterns like: "something] where ] should be "] | ||
| text = re.sub(r'([^"])\]', r'\1"]', text) | ||
| # But undo if we just added "" which would be wrong | ||
| text = text.replace('"""]', '"]') | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To fix the errors that the regex defined on line 130 is likely to introduce, we should use something like: |
||
|
|
||
| # Fix patterns like: "COLUMN:." or "TABLE:" (empty after colon) | ||
| # Remove items that are just "TYPE:." or "TYPE:" | ||
| text = re.sub(r'"\s*[A-Z]+\s*:\s*\.?\s*"', '""', text) | ||
|
|
||
| # Remove empty strings from arrays | ||
| text = re.sub(r',\s*""\s*,', ",", text) # middle | ||
| text = re.sub(r'\[\s*""\s*,', "[", text) # start | ||
| text = re.sub(r',\s*""\s*\]', "]", text) # end | ||
| return re.sub(r'\[\s*""\s*\]', "[]", text) # only item | ||
|
|
||
|
|
||
| def extract_json(text: str) -> dict[str, Any] | None: | ||
| """ | ||
| Extract JSON object from text with code blocks. | ||
|
|
@@ -121,13 +159,17 @@ def extract_json(text: str) -> dict[str, Any] | None: | |
| try: | ||
| if "```json" in text: | ||
| res = re.findall(r"```json([\s\S]*?)```", text) | ||
| json_res = json.loads(res[0]) | ||
| json_text = res[0] | ||
| elif "```" in text: | ||
| res = re.findall(r"```([\s\S]*?)```", text) | ||
| json_res = json.loads(res[0]) | ||
| json_text = res[0] | ||
| else: | ||
| json_res = json.loads(text) | ||
| return json_res | ||
| json_text = text | ||
|
|
||
| # Pre-process to fix common formatting errors | ||
| json_text = _preprocess_json_string(json_text) | ||
|
|
||
| return json.loads(json_text) | ||
| except Exception as e: | ||
| logger.warning(f"Failed to extract json from: {text}, error={e}") | ||
| return None | ||
|
|
@@ -172,6 +214,7 @@ def extract_object(text: str) -> Any | None: | |
| if obj is None: | ||
| obj = eval_literal(text) | ||
| if obj is None: | ||
| logger.error(f"Failed to extract object: {text}") | ||
| # Only log at debug level since callers typically handle None gracefully | ||
| logger.debug(f"Failed to extract object: {text}") | ||
| obj = None | ||
| return obj | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,7 @@ | |
| from src.pipe.llm_util import extract_object | ||
| from src.pipe.rank_schema_prompts.v1 import RANK_SCHEMA_ITEMS_V1 | ||
| from src.pipe.schema_repo import DatabaseSchemaRepo | ||
| from src.utils.logging import logger | ||
|
|
||
|
|
||
| class RankSchemaItems(PromptProcessor): | ||
|
|
@@ -31,8 +32,72 @@ def __init__( | |
| super().__init__(prop_name, openai_config=openai_config, model=model) | ||
| self.schema_repo = DatabaseSchemaRepo(tables_path) | ||
|
|
||
| def _process_output(self, row: dict[str, Any], output: str) -> Any: | ||
| return extract_object(output) | ||
| def _sanitize_schema_item(self, item: str) -> str | None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here, |
||
| """ | ||
| Sanitize a schema item reference to ensure proper formatting. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| item : str | ||
| Schema item reference (e.g., "TABLE:[name]" or "COLUMN:[table].[col]") | ||
|
|
||
| Returns | ||
| ------- | ||
| str or None | ||
| Sanitized schema item or None if invalid | ||
| """ | ||
| if not isinstance(item, str) or ":" not in item: | ||
| return None | ||
|
|
||
| parts = item.split(":", 1) | ||
| if len(parts) != 2: | ||
| return None | ||
|
|
||
| item_type, item_ref = parts | ||
|
|
||
| # Skip empty references | ||
| if not item_ref or item_ref.strip() in ["", ".", "[.]"]: | ||
| return None | ||
|
|
||
| # Ensure all opening brackets have closing brackets | ||
| bracket_count = item_ref.count("[") - item_ref.count("]") | ||
| if bracket_count > 0: | ||
| # Add missing closing brackets | ||
| item_ref = item_ref + ("]" * bracket_count) | ||
| elif bracket_count < 0: | ||
| # More closing than opening - invalid | ||
| return None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Having the same number of opening and closing brackets is not enough here; the item should exactly match our formats of |
||
|
|
||
| return f"{item_type}:{item_ref}" | ||
|
|
||
| def _process_output(self, row: dict[str, Any], output: str) -> list[str]: | ||
| result = extract_object(output) | ||
|
|
||
| # Handle None or invalid output | ||
| if result is None or not isinstance(result, list): | ||
| logger.warning( | ||
| f"LLM returned invalid schema items for question_id={row.get('question_id')}, " | ||
| f"falling back to all schema items" | ||
| ) | ||
| # Fallback: return all schema items | ||
| return self.extract_schema_items(row) | ||
|
|
||
| # Sanitize and filter out invalid items | ||
| sanitized_items = [] | ||
| for item in result: | ||
| sanitized = self._sanitize_schema_item(item) | ||
| if sanitized: | ||
| sanitized_items.append(sanitized) | ||
|
|
||
| # If sanitization removed everything, fallback to all items | ||
| if not sanitized_items: | ||
| logger.warning( | ||
| f"All LLM schema items were invalid for question_id={row.get('question_id')}, " | ||
| f"falling back to all schema items" | ||
| ) | ||
| return self.extract_schema_items(row) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be more helpful to log this warning for every individual invalid item so we can catch them more easily. In that case, it would be better to add this warning somewhere after line 90, inside the |
||
|
|
||
| return sanitized_items | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we are going to change |
||
|
|
||
| def extract_schema_items(self, row: dict[str, Any]) -> list[str]: | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,40 +1,52 @@ | ||
| """Schema ranking prompt template version 1.""" | ||
|
|
||
| RANK_SCHEMA_ITEMS_V1 = """ | ||
| You are given: | ||
| 1. A natural language question. | ||
| 2. A list of schema items of an underlying database. Each schema item is either | ||
| "TABLE:[table_name]" or "COLUMN:[table_name].[column_name] | ||
|
|
||
| Task: | ||
| Filter the given list and return a subset of these items that are most relevant to the given question. | ||
| You can include at most 4 tables and at most 5 columns for each table. | ||
|
|
||
| Example: | ||
| Question: “What is the name of the instructor who has the lowest salary?” | ||
| Schema Items: | ||
| RANK_SCHEMA_ITEMS_V1 = """You are a database schema analyzer. Your task is to identify which schema items are relevant for answering a given question. | ||
|
|
||
| ## Input Format | ||
| You will receive: | ||
| 1. A natural language question | ||
| 2. A list of schema items (tables and columns) from a database | ||
|
|
||
| Each schema item follows this format: | ||
| - Tables: "TABLE:[table_name]" | ||
| - Columns: "COLUMN:[table_name].[column_name]" | ||
|
|
||
| ## Task | ||
| Select the schema items needed to answer the question. Choose: | ||
| - Maximum 4 tables | ||
| - Maximum 5 columns per table | ||
|
|
||
| ## Output Requirements | ||
| 1. Return a valid JSON array of strings | ||
| 2. Select items EXACTLY as they appear in the input list - do not modify them | ||
| 3. Include only items that are relevant to answering the question | ||
| 4. Ensure the output is valid JSON (properly quoted and bracketed) | ||
|
|
||
| ## Example | ||
|
|
||
| Input Question: "What is the name of the instructor who has the lowest salary?" | ||
|
|
||
| Input Schema Items: | ||
| [ | ||
| "TABLE:[department]", | ||
| "COLUMN:[department].[name]", | ||
| "TABLE:[instructor]", | ||
| "COLUMN:[instructor].[name]" | ||
| "COLUMN:[instructor].[salary]" | ||
| "COLUMN:[instructor].[name]", | ||
| "COLUMN:[instructor].[salary]", | ||
| "COLUMN:[instructor].[age]" | ||
| ] | ||
|
|
||
| Output: | ||
| Expected Output: | ||
| [ | ||
| "TABLE:[instructor]", | ||
| "COLUMN:[instructor].[name]" | ||
| "COLUMN:[instructor].[name]", | ||
| "COLUMN:[instructor].[salary]" | ||
| ] | ||
|
|
||
| Now filter the following list of Schema Items based on the given question. | ||
| ## Your Turn | ||
|
|
||
| Question: {question} | ||
|
|
||
| Schema Items: {schema_items} | ||
|
|
||
| Instructions: | ||
| - Output only a valid list of strings. | ||
| - Do not include any additional text, explanations, or formatting.: | ||
| - All strings should be in double quotes. | ||
| """ | ||
| Output:""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also recommend adding tests to verify that these new classes and functions work correctly after the refactoring, covering edge cases and any issues the changes might introduce.