diff --git a/docs/en_US/ai_tools.rst b/docs/en_US/ai_tools.rst index 1a81a0a9fa3..fdd6f438388 100644 --- a/docs/en_US/ai_tools.rst +++ b/docs/en_US/ai_tools.rst @@ -51,7 +51,7 @@ Select your preferred LLM provider from the dropdown: Use Claude models from Anthropic, or any Anthropic-compatible API provider. * **API URL**: Custom API endpoint URL (leave empty for default: https://api.anthropic.com/v1). - * **API Key File**: Path to a file containing your Anthropic API key (obtain from https://console.anthropic.com/). Optional when using a custom URL with a provider that does not require authentication. + * **API Key File**: Path to a file containing your Anthropic API key (obtain from https://console.anthropic.com/). This path refers to the filesystem where the pgAdmin server is running (e.g., inside the container if using Docker). The ``~`` prefix is expanded to the home directory of the user running the pgAdmin server process. Optional when using a custom URL with a provider that does not require authentication. * **Model**: Select from available Claude models (e.g., claude-sonnet-4-20250514). **OpenAI** @@ -59,7 +59,7 @@ Select your preferred LLM provider from the dropdown: LiteLLM, LM Studio, EXO, or other local inference servers). * **API URL**: Custom API endpoint URL (leave empty for default: https://api.openai.com/v1). Include the ``/v1`` path prefix if required by your provider. - * **API Key File**: Path to a file containing your OpenAI API key (obtain from https://platform.openai.com/). Optional when using a custom URL with a provider that does not require authentication. + * **API Key File**: Path to a file containing your OpenAI API key (obtain from https://platform.openai.com/). This path refers to the filesystem where the pgAdmin server is running (e.g., inside the container if using Docker). The ``~`` prefix is expanded to the home directory of the user running the pgAdmin server process. 
Optional when using a custom URL with a provider that does not require authentication. * **Model**: Select from available GPT models (e.g., gpt-4). **Ollama** diff --git a/docs/en_US/preferences.rst b/docs/en_US/preferences.rst index caaaca4a268..0fcf6413bd5 100644 --- a/docs/en_US/preferences.rst +++ b/docs/en_US/preferences.rst @@ -52,7 +52,10 @@ Use the fields on the *AI* panel to configure your LLM provider: to use an Anthropic-compatible API provider. * Use the *API Key File* field to specify the path to a file containing your - Anthropic API key. The API key may be optional when using a custom API URL + Anthropic API key. This path refers to the filesystem where the pgAdmin + server is running (e.g., inside the container if using Docker). The ``~`` + prefix is expanded to the home directory of the user running the pgAdmin + server process. The API key may be optional when using a custom API URL with a provider that does not require authentication. * Use the *Model* field to select from the available Claude models. Click the @@ -68,7 +71,10 @@ Use the fields on the *AI* panel to configure your LLM provider: (e.g., ``http://localhost:1234/v1``). * Use the *API Key File* field to specify the path to a file containing your - OpenAI API key. The API key may be optional when using a custom API URL + OpenAI API key. This path refers to the filesystem where the pgAdmin + server is running (e.g., inside the container if using Docker). The ``~`` + prefix is expanded to the home directory of the user running the pgAdmin + server process. The API key may be optional when using a custom API URL with a provider that does not require authentication. * Use the *Model* field to select from the available GPT models. 
Click the diff --git a/docs/en_US/release_notes_9_14.rst b/docs/en_US/release_notes_9_14.rst index 59619a5e4de..36ae2d679e1 100644 --- a/docs/en_US/release_notes_9_14.rst +++ b/docs/en_US/release_notes_9_14.rst @@ -41,5 +41,6 @@ Bug fixes | `Issue #9721 <https://github.com/pgadmin-org/pgadmin4/issues/9721>`_ - Fixed an issue where permissions page is not completely accessible on full scroll. | `Issue #9729 <https://github.com/pgadmin-org/pgadmin4/issues/9729>`_ - Fixed an issue where some LLM models would not use database tools in the AI assistant, instead returning text descriptions of tool calls. | `Issue #9732 <https://github.com/pgadmin-org/pgadmin4/issues/9732>`_ - Improve the AI Assistant user prompt to be more descriptive of the actual functionality. + | `Issue #9734 <https://github.com/pgadmin-org/pgadmin4/issues/9734>`_ - Fixed an issue where LLM responses are not streamed or rendered properly in the AI Assistant. | `Issue #9736 <https://github.com/pgadmin-org/pgadmin4/issues/9736>`_ - Fix an issue where the AI Assistant was not retaining conversation context between messages, with chat history compaction to manage token budgets. | `Issue #9740 <https://github.com/pgadmin-org/pgadmin4/issues/9740>`_ - Fixed an issue where the AI Assistant input textbox sometimes swallows the first character of input. diff --git a/web/pgadmin/llm/__init__.py b/web/pgadmin/llm/__init__.py index bd61ed090b7..e7dedefaad9 100644 --- a/web/pgadmin/llm/__init__.py +++ b/web/pgadmin/llm/__init__.py @@ -123,6 +123,8 @@ def register_preferences(self): category_label=gettext('Anthropic'), help_str=gettext( 'Path to a file containing your Anthropic API key. ' + 'This path must be on the server hosting pgAdmin, ' + 'e.g. inside the container when using Docker. ' 'The file should contain only the API key. The API key ' 'may be optional when using a custom API URL with a ' 'provider that does not require authentication.' @@ -185,6 +187,8 @@ def register_preferences(self): category_label=gettext('OpenAI'), help_str=gettext( 'Path to a file containing your OpenAI API key. ' + 'This path must be on the server hosting pgAdmin, ' + 'e.g. inside the container when using Docker. ' 'The file should contain only the API key. 
The API key ' 'may be optional when using a custom API URL with a ' 'provider that does not require authentication.' diff --git a/web/pgadmin/llm/chat.py b/web/pgadmin/llm/chat.py index 40e99219111..68759a85c94 100644 --- a/web/pgadmin/llm/chat.py +++ b/web/pgadmin/llm/chat.py @@ -14,10 +14,11 @@ """ import json -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union from pgadmin.llm.client import get_llm_client, is_llm_available -from pgadmin.llm.models import Message, StopReason +from pgadmin.llm.models import Message, LLMResponse, StopReason from pgadmin.llm.tools import DATABASE_TOOLS, execute_tool, DatabaseToolError from pgadmin.llm.utils import get_max_tool_iterations @@ -153,6 +154,118 @@ def chat_with_database( ) +def chat_with_database_stream( + user_message: str, + sid: int, + did: int, + conversation_history: Optional[list[Message]] = None, + system_prompt: Optional[str] = None, + max_tool_iterations: Optional[int] = None, + provider: Optional[str] = None, + model: Optional[str] = None +) -> Generator[Union[str, tuple], None, None]: + """ + Stream an LLM chat conversation with database tool access. + + Like chat_with_database, but yields text chunks as the final + response streams in. During tool-use iterations, no text is + yielded (tools are executed silently). + + Yields: + str: Text content chunks from the final LLM response. + + The last item yielded is a 3-tuple of + ('complete', final_response_text, updated_conversation_history). + + Raises: + LLMClientError: If the LLM request fails. + RuntimeError: If LLM is not available or max iterations exceeded. + """ + if not is_llm_available(): + raise RuntimeError("LLM is not configured. 
Please configure an LLM " + "provider in Preferences > AI.") + + client = get_llm_client(provider=provider, model=model) + if not client: + raise RuntimeError("Failed to create LLM client") + + messages = list(conversation_history) if conversation_history else [] + messages.append(Message.user(user_message)) + + if system_prompt is None: + system_prompt = DEFAULT_SYSTEM_PROMPT + + if max_tool_iterations is None: + max_tool_iterations = get_max_tool_iterations() + + iteration = 0 + while iteration < max_tool_iterations: + iteration += 1 + + # Stream the LLM response, yielding text chunks as they arrive + response = None + for item in client.chat_stream( + messages=messages, + tools=DATABASE_TOOLS, + system_prompt=system_prompt + ): + if isinstance(item, LLMResponse): + response = item + elif isinstance(item, str): + yield item + + if response is None: + raise RuntimeError("No response received from LLM") + + messages.append(response.to_message()) + + if response.stop_reason != StopReason.TOOL_USE: + # Final response - yield a 3-tuple to distinguish from + # the 2-tuple tool_use event + yield ('complete', response.content, messages) + return + + # Signal that tools are being executed so the caller can + # reset streaming state and show a thinking indicator + yield ('tool_use', [tc.name for tc in response.tool_calls]) + + # Execute tool calls + tool_results = [] + for tool_call in response.tool_calls: + try: + result = execute_tool( + tool_name=tool_call.name, + arguments=tool_call.arguments, + sid=sid, + did=did + ) + tool_results.append(Message.tool_result( + tool_call_id=tool_call.id, + content=json.dumps(result, default=str), + is_error=False + )) + except (DatabaseToolError, ValueError) as e: + tool_results.append(Message.tool_result( + tool_call_id=tool_call.id, + content=json.dumps({"error": str(e)}), + is_error=True + )) + except Exception as e: + tool_results.append(Message.tool_result( + tool_call_id=tool_call.id, + content=json.dumps({ + "error": 
f"Unexpected error: {str(e)}" + }), + is_error=True + )) + + messages.extend(tool_results) + + raise RuntimeError( + f"Exceeded maximum tool iterations ({max_tool_iterations})" + ) + + def single_query( question: str, sid: int, diff --git a/web/pgadmin/llm/client.py b/web/pgadmin/llm/client.py index e11b8c15d5a..e860eec4497 100644 --- a/web/pgadmin/llm/client.py +++ b/web/pgadmin/llm/client.py @@ -10,7 +10,8 @@ """Base LLM client interface and factory.""" from abc import ABC, abstractmethod -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union from pgadmin.llm.models import ( Message, Tool, LLMResponse, LLMError @@ -74,6 +75,48 @@ def chat( """ pass + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """ + Stream a chat response from the LLM. + + Yields text chunks (str) as they arrive, then yields + a final LLMResponse with the complete response metadata. + + The default implementation falls back to non-streaming chat(). + + Args: + messages: List of conversation messages. + tools: Optional list of tools the LLM can use. + system_prompt: Optional system prompt to set context. + max_tokens: Maximum tokens in the response. + temperature: Sampling temperature (0.0 = deterministic). + **kwargs: Additional provider-specific parameters. + + Yields: + str: Text content chunks as they arrive. + LLMResponse: Final response with complete metadata (last item). 
+ """ + # Default: fall back to non-streaming + response = self.chat( + messages=messages, + tools=tools, + system_prompt=system_prompt, + max_tokens=max_tokens, + temperature=temperature, + **kwargs + ) + if response.content: + yield response.content + yield response + def validate_connection(self) -> tuple[bool, Optional[str]]: """ Validate the connection to the LLM provider. diff --git a/web/pgadmin/llm/prompts/nlq.py b/web/pgadmin/llm/prompts/nlq.py index e40854292c4..cfbeb035058 100644 --- a/web/pgadmin/llm/prompts/nlq.py +++ b/web/pgadmin/llm/prompts/nlq.py @@ -35,13 +35,10 @@ - Use explicit column names instead of SELECT * - For UPDATE/DELETE, always include WHERE clauses -Once you have explored the database structure using the tools above, \ -provide your final answer as a JSON object in this exact format: -{"sql": "YOUR SQL QUERY HERE", "explanation": "Brief explanation"} - -Rules for the final response: -- Return ONLY the JSON object, no other text -- No markdown code blocks -- If you need clarification, set "sql" to null and put \ -your question in "explanation" +Response format: +- Always put SQL in fenced code blocks with the sql language tag +- You may include multiple SQL blocks if the request needs \ +multiple statements +- Briefly explain what each query does +- If you need clarification, just ask — no code blocks needed """ diff --git a/web/pgadmin/llm/providers/anthropic.py b/web/pgadmin/llm/providers/anthropic.py index 1ab9e2b0427..c7b91a56e8e 100644 --- a/web/pgadmin/llm/providers/anthropic.py +++ b/web/pgadmin/llm/providers/anthropic.py @@ -10,10 +10,12 @@ """Anthropic Claude LLM client implementation.""" import json +import socket import ssl import urllib.request import urllib.error -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union import uuid # Try to use certifi for proper SSL certificate handling @@ -284,3 +286,231 @@ def _parse_response(self, data: dict) -> LLMResponse: usage=usage, 
raw_response=data ) + + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """Stream a chat response from Anthropic.""" + payload = { + 'model': self._model, + 'max_tokens': max_tokens, + 'messages': self._convert_messages(messages), + 'stream': True + } + + if system_prompt: + payload['system'] = system_prompt + + if temperature > 0: + payload['temperature'] = temperature + + if tools: + payload['tools'] = self._convert_tools(tools) + + try: + yield from self._process_stream(payload) + except LLMClientError: + raise + except Exception as e: + raise LLMClientError(LLMError( + message=f"Streaming request failed: {str(e)}", + provider=self.provider_name + )) + + def _process_stream( + self, payload: dict + ) -> Generator[Union[str, LLMResponse], None, None]: + """Make a streaming request and yield chunks.""" + headers = { + 'Content-Type': 'application/json', + 'anthropic-version': API_VERSION + } + + if self._api_key: + headers['x-api-key'] = self._api_key + + request = urllib.request.Request( + self._api_url, + data=json.dumps(payload).encode('utf-8'), + headers=headers, + method='POST' + ) + + try: + response = urllib.request.urlopen( + request, timeout=120, context=SSL_CONTEXT + ) + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + try: + error_data = json.loads(error_body) + error_msg = error_data.get( + 'error', {} + ).get('message', str(e)) + except json.JSONDecodeError: + error_msg = error_body or str(e) + raise LLMClientError(LLMError( + message=error_msg, + code=str(e.code), + provider=self.provider_name, + retryable=e.code in (429, 500, 502, 503, 504) + )) + except urllib.error.URLError as e: + raise LLMClientError(LLMError( + message=f"Connection error: {e.reason}", + provider=self.provider_name, + retryable=True + )) + except 
socket.timeout: + raise LLMClientError(LLMError( + message="Request timed out.", + code='timeout', + provider=self.provider_name, + retryable=True + )) + + try: + yield from self._read_anthropic_stream(response) + finally: + response.close() + + def _read_anthropic_stream( + self, response + ) -> Generator[Union[str, LLMResponse], None, None]: + """Read and parse an Anthropic SSE stream. + + Uses readline() for incremental reading. + """ + content_parts = [] + tool_calls = [] + current_tool_block = None + tool_input_json = '' + stop_reason_str = None + model_name = self._model + usage = Usage() + in_text_block = False + + while True: + line_bytes = response.readline() + if not line_bytes: + break + + line = line_bytes.decode('utf-8', errors='replace').strip() + + if not line or line.startswith(':'): + continue + + if line.startswith('event: '): + continue + + if not line.startswith('data: '): + continue + + try: + data = json.loads(line[6:]) + except json.JSONDecodeError: + continue + + event_type = data.get('type', '') + + if event_type == 'message_start': + msg = data.get('message', {}) + model_name = msg.get('model', self._model) + u = msg.get('usage', {}) + usage = Usage( + input_tokens=u.get('input_tokens', 0), + output_tokens=u.get('output_tokens', 0), + total_tokens=( + u.get('input_tokens', 0) + + u.get('output_tokens', 0) + ) + ) + + elif event_type == 'content_block_start': + block = data.get('content_block', {}) + if block.get('type') == 'tool_use': + current_tool_block = { + 'id': block.get('id', str(uuid.uuid4())), + 'name': block.get('name', '') + } + tool_input_json = '' + elif block.get('type') == 'text': + # Emit a separator between text blocks to + # match _parse_response() which joins with '\n' + if in_text_block: + content_parts.append('\n') + yield '\n' + in_text_block = True + + elif event_type == 'content_block_delta': + delta = data.get('delta', {}) + if delta.get('type') == 'text_delta': + text = delta.get('text', '') + if text: + 
content_parts.append(text) + yield text + elif delta.get('type') == 'input_json_delta': + tool_input_json += delta.get( + 'partial_json', '' + ) + + elif event_type == 'content_block_stop': + if current_tool_block is not None: + try: + arguments = json.loads( + tool_input_json + ) if tool_input_json else {} + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall( + id=current_tool_block['id'], + name=current_tool_block['name'], + arguments=arguments + )) + current_tool_block = None + tool_input_json = '' + + elif event_type == 'message_delta': + delta = data.get('delta', {}) + stop_reason_str = delta.get('stop_reason') + u = data.get('usage', {}) + if u: + usage = Usage( + input_tokens=usage.input_tokens, + output_tokens=u.get( + 'output_tokens', + usage.output_tokens + ), + total_tokens=( + usage.input_tokens + + u.get( + 'output_tokens', + usage.output_tokens + ) + ) + ) + + # Build final response + stop_reason_map = { + 'end_turn': StopReason.END_TURN, + 'tool_use': StopReason.TOOL_USE, + 'max_tokens': StopReason.MAX_TOKENS, + 'stop_sequence': StopReason.STOP_SEQUENCE + } + stop_reason = stop_reason_map.get( + stop_reason_str or '', StopReason.UNKNOWN + ) + + yield LLMResponse( + content=''.join(content_parts), + tool_calls=tool_calls, + stop_reason=stop_reason, + model=model_name, + usage=usage + ) diff --git a/web/pgadmin/llm/providers/docker.py b/web/pgadmin/llm/providers/docker.py index 4fa6ccda2cb..52132827e67 100644 --- a/web/pgadmin/llm/providers/docker.py +++ b/web/pgadmin/llm/providers/docker.py @@ -16,9 +16,11 @@ import json import socket import ssl +import urllib.parse import urllib.request import urllib.error -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union import uuid # Try to use certifi for proper SSL certificate handling @@ -42,6 +44,25 @@ DEFAULT_API_URL = 'http://localhost:12434' DEFAULT_MODEL = 'ai/qwen3-coder' +# Allowed loopback hostnames for the Docker endpoint 
+_LOOPBACK_HOSTS = {'localhost', '127.0.0.1', '::1', '[::1]'} + + +def _validate_loopback_url(url: str) -> None: + """Ensure the URL uses HTTP(S) and points to a loopback address.""" + parsed = urllib.parse.urlparse(url) + if parsed.scheme not in ('http', 'https'): + raise ValueError( + f"Docker Model Runner URL must use http or https, " + f"got: {parsed.scheme}" + ) + hostname = (parsed.hostname or '').lower() + if hostname not in _LOOPBACK_HOSTS: + raise ValueError( + f"Docker Model Runner URL must point to a loopback address " + f"(localhost/127.0.0.1/::1), got: {hostname}" + ) + class DockerClient(LLMClient): """ @@ -63,6 +84,7 @@ def __init__( model: Optional model name. Defaults to ai/qwen3-coder. """ self._api_url = (api_url or DEFAULT_API_URL).rstrip('/') + _validate_loopback_url(self._api_url) self._model = model or DEFAULT_MODEL @property @@ -354,3 +376,216 @@ def _parse_response(self, data: dict) -> LLMResponse: usage=usage, raw_response=data ) + + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """Stream a chat response from Docker Model Runner.""" + converted_messages = self._convert_messages(messages) + + if system_prompt: + converted_messages.insert(0, { + 'role': 'system', + 'content': system_prompt + }) + + payload = { + 'model': self._model, + 'messages': converted_messages, + 'max_completion_tokens': max_tokens, + 'temperature': temperature, + 'stream': True, + 'stream_options': {'include_usage': True} + } + + if tools: + payload['tools'] = self._convert_tools(tools) + payload['tool_choice'] = 'auto' + + try: + yield from self._process_stream(payload) + except LLMClientError: + raise + except Exception as e: + raise LLMClientError(LLMError( + message=f"Streaming request failed: {str(e)}", + provider=self.provider_name + )) + + def _process_stream( 
+ self, payload: dict + ) -> Generator[Union[str, LLMResponse], None, None]: + """Make a streaming request and yield chunks.""" + headers = { + 'Content-Type': 'application/json' + } + + url = f'{self._api_url}/engines/v1/chat/completions' + + request = urllib.request.Request( + url, + data=json.dumps(payload).encode('utf-8'), + headers=headers, + method='POST' + ) + + try: + response = urllib.request.urlopen( + request, timeout=300, context=SSL_CONTEXT + ) + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + try: + error_data = json.loads(error_body) + error_msg = error_data.get( + 'error', {} + ).get('message', str(e)) + except json.JSONDecodeError: + error_msg = error_body or str(e) + raise LLMClientError(LLMError( + message=error_msg, + code=str(e.code), + provider=self.provider_name, + retryable=e.code in (429, 500, 502, 503, 504) + )) + except urllib.error.URLError as e: + raise LLMClientError(LLMError( + message=f"Connection error: {e.reason}. " + f"Is Docker Model Runner running at " + f"{self._api_url}?", + provider=self.provider_name, + retryable=True + )) + except socket.timeout: + raise LLMClientError(LLMError( + message="Request timed out.", + code='timeout', + provider=self.provider_name, + retryable=True + )) + + try: + yield from self._read_openai_stream(response) + finally: + response.close() + + def _read_openai_stream( + self, response + ) -> Generator[Union[str, LLMResponse], None, None]: + """Read and parse an OpenAI-format SSE stream. + + Uses readline() for incremental reading. 
+ """ + content_parts = [] + tool_calls_data = {} + finish_reason = None + model_name = self._model + usage = Usage() + + while True: + line_bytes = response.readline() + if not line_bytes: + break + + line = line_bytes.decode('utf-8', errors='replace').strip() + + if not line or line.startswith(':'): + continue + + if line == 'data: [DONE]': + continue + + if not line.startswith('data: '): + continue + + try: + data = json.loads(line[6:]) + except json.JSONDecodeError: + continue + + if 'usage' in data and data['usage']: + u = data['usage'] + usage = Usage( + input_tokens=u.get('prompt_tokens', 0), + output_tokens=u.get('completion_tokens', 0), + total_tokens=u.get('total_tokens', 0) + ) + + if 'model' in data: + model_name = data['model'] + + choices = data.get('choices', []) + if not choices: + continue + + choice = choices[0] + delta = choice.get('delta', {}) + + if choice.get('finish_reason'): + finish_reason = choice['finish_reason'] + + text_chunk = delta.get('content') + if text_chunk: + content_parts.append(text_chunk) + yield text_chunk + + for tc_delta in delta.get('tool_calls', []): + idx = tc_delta.get('index', 0) + if idx not in tool_calls_data: + tool_calls_data[idx] = { + 'id': '', 'name': '', 'arguments': '' + } + tc = tool_calls_data[idx] + if 'id' in tc_delta: + tc['id'] = tc_delta['id'] + func = tc_delta.get('function', {}) + if 'name' in func: + tc['name'] = func['name'] + if 'arguments' in func: + tc['arguments'] += func['arguments'] + + content = ''.join(content_parts) + tool_calls = [] + for idx in sorted(tool_calls_data.keys()): + tc = tool_calls_data[idx] + try: + arguments = json.loads(tc['arguments']) \ + if tc['arguments'] else {} + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall( + id=tc['id'] or str(uuid.uuid4()), + name=tc['name'], + arguments=arguments + )) + + stop_reason_map = { + 'stop': StopReason.END_TURN, + 'tool_calls': StopReason.TOOL_USE, + 'length': StopReason.MAX_TOKENS, + 'content_filter': 
StopReason.STOP_SEQUENCE + } + stop_reason = stop_reason_map.get( + finish_reason or '', StopReason.UNKNOWN + ) + + if not content and not tool_calls: + raise LLMClientError(LLMError( + message='No response content returned from API', + provider=self.provider_name, + retryable=False + )) + + yield LLMResponse( + content=content, + tool_calls=tool_calls, + stop_reason=stop_reason, + model=model_name, + usage=usage + ) diff --git a/web/pgadmin/llm/providers/ollama.py b/web/pgadmin/llm/providers/ollama.py index 8d38b72facd..1706e3d8ead 100644 --- a/web/pgadmin/llm/providers/ollama.py +++ b/web/pgadmin/llm/providers/ollama.py @@ -10,10 +10,11 @@ """Ollama LLM client implementation.""" import json -import re +import urllib.parse import urllib.request import urllib.error -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union import uuid from pgadmin.llm.client import LLMClient, LLMClientError @@ -47,6 +48,14 @@ def __init__(self, api_url: str, model: Optional[str] = None): self._api_url = api_url.rstrip('/') self._model = model or DEFAULT_MODEL + # Validate URL scheme to prevent unsafe access + parsed = urllib.parse.urlparse(self._api_url) + if parsed.scheme not in ('http', 'https'): + raise ValueError( + f"Ollama URL must use http or https scheme, " + f"got: {parsed.scheme}" + ) + @property def provider_name(self) -> str: return 'ollama' @@ -220,7 +229,7 @@ def _make_request(self, payload: dict) -> dict: message=error_msg, code=str(e.code), provider=self.provider_name, - retryable=e.code in (500, 502, 503, 504) + retryable=e.code in (429, 500, 502, 503, 504) )) except urllib.error.URLError as e: raise LLMClientError(LLMError( @@ -231,8 +240,6 @@ def _make_request(self, payload: dict) -> dict: def _parse_response(self, data: dict) -> LLMResponse: """Parse the Ollama API response into an LLMResponse.""" - import re - message = data.get('message', {}) content = message.get('content', '') tool_calls = [] @@ -285,3 +292,177 @@ 
def _parse_response(self, data: dict) -> LLMResponse: usage=usage, raw_response=data ) + + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """Stream a chat response from Ollama.""" + converted_messages = self._convert_messages(messages) + + if system_prompt: + converted_messages.insert(0, { + 'role': 'system', + 'content': system_prompt + }) + + payload = { + 'model': self._model, + 'messages': converted_messages, + 'stream': True, + 'options': { + 'num_predict': max_tokens, + 'temperature': temperature + } + } + + if tools: + payload['tools'] = self._convert_tools(tools) + + try: + yield from self._process_stream(payload) + except LLMClientError: + raise + except Exception as e: + raise LLMClientError(LLMError( + message=f"Streaming request failed: {str(e)}", + provider=self.provider_name + )) + + def _process_stream( + self, payload: dict + ) -> Generator[Union[str, LLMResponse], None, None]: + """Make a streaming request and yield chunks.""" + url = f'{self._api_url}/api/chat' + + request = urllib.request.Request( + url, + data=json.dumps(payload).encode('utf-8'), + headers={'Content-Type': 'application/json'}, + method='POST' + ) + + try: + response = urllib.request.urlopen(request, timeout=300) + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + try: + error_data = json.loads(error_body) + error_msg = error_data.get('error', str(e)) + except json.JSONDecodeError: + error_msg = error_body or str(e) + raise LLMClientError(LLMError( + message=error_msg, + code=str(e.code), + provider=self.provider_name, + retryable=e.code in (429, 500, 502, 503, 504) + )) + except urllib.error.URLError as e: + raise LLMClientError(LLMError( + message=f"Cannot connect to Ollama: {e.reason}", + provider=self.provider_name, + retryable=True + )) + + try: + 
yield from self._read_ollama_stream(response) + finally: + response.close() + + def _read_ollama_stream( + self, response + ) -> Generator[Union[str, LLMResponse], None, None]: + """Read and parse an Ollama NDJSON stream. + + Uses readline() for incremental reading. + """ + content_parts = [] + tool_calls = [] + done_reason = None + model_name = self._model + input_tokens = 0 + output_tokens = 0 + final_data = None + + while True: + line_bytes = response.readline() + if not line_bytes: + break + + line = line_bytes.decode('utf-8', errors='replace').strip() + + if not line: + continue + + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + + msg = data.get('message', {}) + + # Text content + text = msg.get('content', '') + if text: + content_parts.append(text) + yield text + + # Tool calls (in final message) + for tc in msg.get('tool_calls', []): + func = tc.get('function', {}) + arguments = func.get('arguments', {}) + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall( + id=str(uuid.uuid4()), + name=func.get('name', ''), + arguments=arguments + )) + + if data.get('done'): + final_data = data + done_reason = data.get('done_reason', '') + model_name = data.get('model', self._model) + input_tokens = data.get('prompt_eval_count', 0) + output_tokens = data.get('eval_count', 0) + + # Ensure the stream completed with a terminal done frame; + # truncated content from a dropped connection is unreliable. 
+ if final_data is None: + raise LLMClientError(LLMError( + message="Ollama stream ended before terminal done frame", + provider=self.provider_name, + retryable=True + )) + + content = ''.join(content_parts) + + if tool_calls: + stop_reason = StopReason.TOOL_USE + elif done_reason == 'stop': + stop_reason = StopReason.END_TURN + elif done_reason == 'length': + stop_reason = StopReason.MAX_TOKENS + else: + stop_reason = StopReason.UNKNOWN + + yield LLMResponse( + content=content, + tool_calls=tool_calls, + stop_reason=stop_reason, + model=model_name, + usage=Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=input_tokens + output_tokens + ), + raw_response=final_data + ) diff --git a/web/pgadmin/llm/providers/openai.py b/web/pgadmin/llm/providers/openai.py index 73c020fb7ba..2b0e2072917 100644 --- a/web/pgadmin/llm/providers/openai.py +++ b/web/pgadmin/llm/providers/openai.py @@ -14,7 +14,8 @@ import ssl import urllib.request import urllib.error -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union import uuid # Try to use certifi for proper SSL certificate handling @@ -355,3 +356,221 @@ def _parse_response(self, data: dict) -> LLMResponse: usage=usage, raw_response=data ) + + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """Stream a chat response from OpenAI.""" + converted_messages = self._convert_messages(messages) + + if system_prompt: + converted_messages.insert(0, { + 'role': 'system', + 'content': system_prompt + }) + + payload = { + 'model': self._model, + 'messages': converted_messages, + 'max_completion_tokens': max_tokens, + 'stream': True, + 'stream_options': {'include_usage': True} + } + + if tools: + payload['tools'] = self._convert_tools(tools) + payload['tool_choice'] = 
'auto' + + try: + yield from self._process_stream(payload) + except LLMClientError: + raise + except Exception as e: + raise LLMClientError(LLMError( + message=f"Streaming request failed: {str(e)}", + provider=self.provider_name + )) + + def _process_stream( + self, payload: dict + ) -> Generator[Union[str, LLMResponse], None, None]: + """Make a streaming request and yield chunks.""" + headers = { + 'Content-Type': 'application/json', + } + + if self._api_key: + headers['Authorization'] = f'Bearer {self._api_key}' + + request = urllib.request.Request( + self._api_url, + data=json.dumps(payload).encode('utf-8'), + headers=headers, + method='POST' + ) + + try: + response = urllib.request.urlopen( + request, timeout=120, context=SSL_CONTEXT + ) + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + try: + error_data = json.loads(error_body) + error_msg = error_data.get( + 'error', {} + ).get('message', str(e)) + except json.JSONDecodeError: + error_msg = error_body or str(e) + raise LLMClientError(LLMError( + message=error_msg, + code=str(e.code), + provider=self.provider_name, + retryable=e.code in (429, 500, 502, 503, 504) + )) + except urllib.error.URLError as e: + raise LLMClientError(LLMError( + message=f"Connection error: {e.reason}", + provider=self.provider_name, + retryable=True + )) + except socket.timeout: + raise LLMClientError(LLMError( + message="Request timed out.", + code='timeout', + provider=self.provider_name, + retryable=True + )) + + try: + yield from self._read_openai_stream(response) + finally: + response.close() + + def _read_openai_stream( + self, response + ) -> Generator[Union[str, LLMResponse], None, None]: + """Read and parse an OpenAI-format SSE stream. + + Uses readline() for incremental reading — it returns as soon + as a complete line arrives from the server, unlike read() + which blocks until a buffer fills up. 
+ """ + content_parts = [] + # tool_calls_data: {index: {id, name, arguments_str}} + tool_calls_data = {} + finish_reason = None + model_name = self._model + usage = Usage() + + while True: + line_bytes = response.readline() + if not line_bytes: + break + + line = line_bytes.decode('utf-8', errors='replace').strip() + + if not line or line.startswith(':'): + continue + + if line == 'data: [DONE]': + continue + + if not line.startswith('data: '): + continue + + try: + data = json.loads(line[6:]) + except json.JSONDecodeError: + continue + + # Extract usage from the final chunk + if 'usage' in data and data['usage']: + u = data['usage'] + usage = Usage( + input_tokens=u.get('prompt_tokens', 0), + output_tokens=u.get('completion_tokens', 0), + total_tokens=u.get('total_tokens', 0) + ) + + if 'model' in data: + model_name = data['model'] + + choices = data.get('choices', []) + if not choices: + continue + + choice = choices[0] + delta = choice.get('delta', {}) + + if choice.get('finish_reason'): + finish_reason = choice['finish_reason'] + + # Text content + text_chunk = delta.get('content') + if text_chunk: + content_parts.append(text_chunk) + yield text_chunk + + # Tool calls (accumulate) + for tc_delta in delta.get('tool_calls', []): + idx = tc_delta.get('index', 0) + if idx not in tool_calls_data: + tool_calls_data[idx] = { + 'id': '', 'name': '', 'arguments': '' + } + tc = tool_calls_data[idx] + if 'id' in tc_delta: + tc['id'] = tc_delta['id'] + func = tc_delta.get('function', {}) + if 'name' in func: + tc['name'] = func['name'] + if 'arguments' in func: + tc['arguments'] += func['arguments'] + + # Build final response + content = ''.join(content_parts) + tool_calls = [] + for idx in sorted(tool_calls_data.keys()): + tc = tool_calls_data[idx] + try: + arguments = json.loads(tc['arguments']) \ + if tc['arguments'] else {} + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall( + id=tc['id'] or str(uuid.uuid4()), + name=tc['name'], + 
arguments=arguments + )) + + stop_reason_map = { + 'stop': StopReason.END_TURN, + 'tool_calls': StopReason.TOOL_USE, + 'length': StopReason.MAX_TOKENS, + 'content_filter': StopReason.STOP_SEQUENCE + } + stop_reason = stop_reason_map.get( + finish_reason or '', StopReason.UNKNOWN + ) + + if not content and not tool_calls: + raise LLMClientError(LLMError( + message='No response content returned from API', + provider=self.provider_name, + retryable=False + )) + + yield LLMResponse( + content=content, + tool_calls=tool_calls, + stop_reason=stop_reason, + model=model_name, + usage=usage + ) diff --git a/web/pgadmin/static/js/Theme/dark.js b/web/pgadmin/static/js/Theme/dark.js index 5deb73324d5..b087e49e0ad 100644 --- a/web/pgadmin/static/js/Theme/dark.js +++ b/web/pgadmin/static/js/Theme/dark.js @@ -89,6 +89,7 @@ export default function(basicSettings) { }, otherVars: { colorBrand: '#1b71b5', + hyperlinkColor: '#6CB4EE', borderColor: '#4a4a4a', inputBorderColor: '#6b6b6b', inputDisabledBg: 'inherit', diff --git a/web/pgadmin/static/js/Theme/high_contrast.js b/web/pgadmin/static/js/Theme/high_contrast.js index 184153cb273..0a4a5083cf5 100644 --- a/web/pgadmin/static/js/Theme/high_contrast.js +++ b/web/pgadmin/static/js/Theme/high_contrast.js @@ -87,6 +87,7 @@ export default function(basicSettings) { }, otherVars: { colorBrand: '#84D6FF', + hyperlinkColor: '#84D6FF', borderColor: '#A6B7C8', inputBorderColor: '#8B9CAD', inputDisabledBg: '#1F2932', diff --git a/web/pgadmin/static/js/Theme/light.js b/web/pgadmin/static/js/Theme/light.js index 093928cfef1..00847f91425 100644 --- a/web/pgadmin/static/js/Theme/light.js +++ b/web/pgadmin/static/js/Theme/light.js @@ -89,6 +89,7 @@ export default function(basicSettings) { }, otherVars: { colorBrand: '#326690', + hyperlinkColor: '#1a0dab', iconLoaderUrl: 'url("data:image/svg+xml,%3C%3Fxml version=\'1.0\' encoding=\'utf-8\'%3F%3E%3C!-- Generator: Adobe Illustrator 23.1.1, SVG Export Plug-In . 
SVG Version: 6.00 Build 0) --%3E%3Csvg version=\'1.1\' id=\'Layer_1\' xmlns=\'http://www.w3.org/2000/svg\' xmlns:xlink=\'http://www.w3.org/1999/xlink\' x=\'0px\' y=\'0px\' viewBox=\'0 0 38 38\' style=\'enable-background:new 0 0 38 38;\' xml:space=\'preserve\'%3E%3Cstyle type=\'text/css\'%3E .st0%7Bfill:none;stroke:%23EBEEF3;stroke-width:5;%7D .st1%7Bfill:none;stroke:%23326690;stroke-width:5;%7D%0A%3C/style%3E%3Cg%3E%3Cg transform=\'translate(1 1)\'%3E%3Ccircle class=\'st0\' cx=\'18\' cy=\'18\' r=\'16\'/%3E%3Cpath class=\'st1\' d=\'M34,18c0-8.8-7.2-16-16-16 \'%3E%3CanimateTransform accumulate=\'none\' additive=\'replace\' attributeName=\'transform\' calcMode=\'linear\' dur=\'0.7s\' fill=\'remove\' from=\'0 18 18\' repeatCount=\'indefinite\' restart=\'always\' to=\'360 18 18\' type=\'rotate\'%3E%3C/animateTransform%3E%3C/path%3E%3C/g%3E%3C/g%3E%3C/svg%3E%0A");', iconLoaderSmall: 'url("data:image/svg+xml,%3C%3Fxml version=\'1.0\' encoding=\'utf-8\'%3F%3E%3C!-- Generator: Adobe Illustrator 23.1.1, SVG Export Plug-In . 
SVG Version: 6.00 Build 0) --%3E%3Csvg version=\'1.1\' id=\'Layer_1\' xmlns=\'http://www.w3.org/2000/svg\' xmlns:xlink=\'http://www.w3.org/1999/xlink\' x=\'0px\' y=\'0px\' viewBox=\'0 0 38 38\' style=\'enable-background:new 0 0 38 38;\' xml:space=\'preserve\'%3E%3Cstyle type=\'text/css\'%3E .st0%7Bfill:none;stroke:%23EBEEF3;stroke-width:5;%7D .st1%7Bfill:none;stroke:%23326690;stroke-width:5;%7D%0A%3C/style%3E%3Cg%3E%3Cg transform=\'translate(1 1)\'%3E%3Ccircle class=\'st0\' cx=\'18\' cy=\'18\' r=\'16\'/%3E%3Cpath class=\'st1\' d=\'M34,18c0-8.8-7.2-16-16-16 \'%3E%3CanimateTransform accumulate=\'none\' additive=\'replace\' attributeName=\'transform\' calcMode=\'linear\' dur=\'0.7s\' fill=\'remove\' from=\'0 18 18\' repeatCount=\'indefinite\' restart=\'always\' to=\'360 18 18\' type=\'rotate\'%3E%3C/animateTransform%3E%3C/path%3E%3C/g%3E%3C/g%3E%3C/svg%3E%0A")', dashboardPgDoc: 'url("data:image/svg+xml,%3C%3Fxml version=\'1.0\' encoding=\'utf-8\'%3F%3E%3C!-- Generator: Adobe Illustrator 22.1.0, SVG Export Plug-In . 
SVG Version: 6.00 Build 0) --%3E%3Csvg version=\'1.1\' id=\'Layer_1\' xmlns=\'http://www.w3.org/2000/svg\' xmlns:xlink=\'http://www.w3.org/1999/xlink\' x=\'0px\' y=\'0px\' viewBox=\'0 0 42 42\' style=\'enable-background:new 0 0 42 42;\' xml:space=\'preserve\'%3E%3Cstyle type=\'text/css\'%3E .st0%7Bstroke:%23000000;stroke-width:3.3022;%7D .st1%7Bfill:%23336791;%7D .st2%7Bfill:none;stroke:%23FFFFFF;stroke-width:1.1007;stroke-linecap:round;stroke-linejoin:round;%7D .st3%7Bfill:none;stroke:%23FFFFFF;stroke-width:1.1007;stroke-linecap:round;stroke-linejoin:bevel;%7D .st4%7Bfill:%23FFFFFF;stroke:%23FFFFFF;stroke-width:0.3669;%7D .st5%7Bfill:%23FFFFFF;stroke:%23FFFFFF;stroke-width:0.1835;%7D .st6%7Bfill:none;stroke:%23FFFFFF;stroke-width:0.2649;stroke-linecap:round;stroke-linejoin:round;%7D%0A%3C/style%3E%3Cg id=\'orginal\'%3E%3C/g%3E%3Cg id=\'Layer_x0020_3\'%3E%3Cpath class=\'st0\' d=\'M31.3,30c0.3-2.1,0.2-2.4,1.7-2.1l0.4,0c1.2,0.1,2.8-0.2,3.7-0.6c2-0.9,3.1-2.4,1.2-2 c-4.4,0.9-4.7-0.6-4.7-0.6c4.7-7,6.7-15.8,5-18c-4.6-5.9-12.6-3.1-12.7-3l0,0c-0.9-0.2-1.9-0.3-3-0.3c-2,0-3.5,0.5-4.7,1.4 c0,0-14.3-5.9-13.6,7.4c0.1,2.8,4,21.3,8.7,15.7c1.7-2,3.3-3.8,3.3-3.8c0.8,0.5,1.8,0.8,2.8,0.7l0.1-0.1c0,0.3,0,0.5,0,0.8 c-1.2,1.3-0.8,1.6-3.2,2.1c-2.4,0.5-1,1.4-0.1,1.6c1.1,0.3,3.7,0.7,5.5-1.8l-0.1,0.3c0.5,0.4,0.4,2.7,0.5,4.4 c0.1,1.7,0.2,3.2,0.5,4.1c0.3,0.9,0.7,3.3,3.9,2.6C29.1,38.3,31.1,37.5,31.3,30\'/%3E%3Cpath class=\'st1\' d=\'M38.3,25.3c-4.4,0.9-4.7-0.6-4.7-0.6c4.7-7,6.7-15.8,5-18c-4.6-5.9-12.6-3.1-12.7-3l0,0 c-0.9-0.2-1.9-0.3-3-0.3c-2,0-3.5,0.5-4.7,1.4c0,0-14.3-5.9-13.6,7.4c0.1,2.8,4,21.3,8.7,15.7c1.7-2,3.3-3.8,3.3-3.8 c0.8,0.5,1.8,0.8,2.8,0.7l0.1-0.1c0,0.3,0,0.5,0,0.8c-1.2,1.3-0.8,1.6-3.2,2.1c-2.4,0.5-1,1.4-0.1,1.6c1.1,0.3,3.7,0.7,5.5-1.8 l-0.1,0.3c0.5,0.4,0.8,2.4,0.7,4.3c-0.1,1.9-0.1,3.2,0.3,4.2c0.4,1,0.7,3.3,3.9,2.6c2.6-0.6,4-2,4.2-4.5c0.1-1.7,0.4-1.5,0.5-3 l0.2-0.7c0.3-2.3,0-3.1,1.7-2.8l0.4,0c1.2,0.1,2.8-0.2,3.7-0.6C39,26.4,40.2,24.9,38.3,25.3L38.3,25.3z\'/%3E%3Cpath class=\'st2\' 
d=\'M21.8,26.6c-0.1,4.4,0,8.8,0.5,9.8c0.4,1.1,1.3,3.2,4.5,2.5c2.6-0.6,3.6-1.7,4-4.1c0.3-1.8,0.9-6.7,1-7.7\'/%3E%3Cpath class=\'st2\' d=\'M18,4.7c0,0-14.3-5.8-13.6,7.4c0.1,2.8,4,21.3,8.7,15.7c1.7-2,3.2-3.7,3.2-3.7\'/%3E%3Cpath class=\'st2\' d=\'M25.7,3.6c-0.5,0.2,7.9-3.1,12.7,3c1.7,2.2-0.3,11-5,18\'/%3E%3Cpath class=\'st3\' d=\'M33.5,24.6c0,0,0.3,1.5,4.7,0.6c1.9-0.4,0.8,1.1-1.2,2c-1.6,0.8-5.3,0.9-5.3-0.1 C31.6,24.5,33.6,25.3,33.5,24.6c-0.1-0.6-1.1-1.2-1.7-2.7c-0.5-1.3-7.3-11.2,1.9-9.7c0.3-0.1-2.4-8.7-11-8.9 c-8.6-0.1-8.3,10.6-8.3,10.6\'/%3E%3Cpath class=\'st2\' d=\'M19.4,25.6c-1.2,1.3-0.8,1.6-3.2,2.1c-2.4,0.5-1,1.4-0.1,1.6c1.1,0.3,3.7,0.7,5.5-1.8c0.5-0.8,0-2-0.7-2.3 C20.5,25.1,20,24.9,19.4,25.6L19.4,25.6z\'/%3E%3Cpath class=\'st2\' d=\'M19.3,25.5c-0.1-0.8,0.3-1.7,0.7-2.8c0.6-1.6,2-3.3,0.9-8.5c-0.8-3.9-6.5-0.8-6.5-0.3c0,0.5,0.3,2.7-0.1,5.2 c-0.5,3.3,2.1,6,5,5.7\'/%3E%3Cpath class=\'st4\' d=\'M18,13.8c0,0.2,0.3,0.7,0.8,0.7c0.5,0.1,0.9-0.3,0.9-0.5c0-0.2-0.3-0.4-0.8-0.4C18.4,13.6,18,13.7,18,13.8 L18,13.8z\'/%3E%3Cpath class=\'st5\' d=\'M32,13.5c0,0.2-0.3,0.7-0.8,0.7c-0.5,0.1-0.9-0.3-0.9-0.5c0-0.2,0.3-0.4,0.8-0.4C31.6,13.2,32,13.3,32,13.5 L32,13.5z\'/%3E%3Cpath class=\'st2\' d=\'M33.7,12.2c0.1,1.4-0.3,2.4-0.4,3.9c-0.1,2.2,1,4.7-0.6,7.2\'/%3E%3Cpath class=\'st6\' d=\'M2.7,6.6\'/%3E%3C/g%3E%3C/svg%3E%0A")', diff --git a/web/pgadmin/tools/sqleditor/__init__.py b/web/pgadmin/tools/sqleditor/__init__.py index 73f11059438..f7cb83ba0c5 100644 --- a/web/pgadmin/tools/sqleditor/__init__.py +++ b/web/pgadmin/tools/sqleditor/__init__.py @@ -2842,7 +2842,7 @@ def nlq_chat_stream(trans_id): """ from flask import stream_with_context from pgadmin.llm.utils import is_llm_enabled - from pgadmin.llm.chat import chat_with_database + from pgadmin.llm.chat import chat_with_database_stream from pgadmin.llm.prompts.nlq import NLQ_SYSTEM_PROMPT # Check if LLM is configured @@ -2886,10 +2886,7 @@ def nlq_chat_stream(trans_id): def generate(): """Generator for SSE events.""" import secrets as 
py_secrets - from pgadmin.llm.compaction import ( - deserialize_history, compact_history - ) - from pgadmin.llm.utils import get_default_provider + from pgadmin.llm.models import Message, Role try: # Send thinking status @@ -2898,85 +2895,73 @@ def generate(): 'message': gettext('Analyzing your request...') }) - # Deserialize and compact conversation history + # Deserialize conversation history if provided conversation_history = None if history_data: - conversation_history = deserialize_history(history_data) - provider = get_default_provider() or 'openai' - conversation_history = compact_history( - conversation_history, - provider=provider - ) + conversation_history = [] + for item in (history_data or []): + if not isinstance(item, dict): + continue + role_str = item.get('role', '') + content = item.get('content', '') + try: + role = Role(role_str) + except ValueError: + continue + conversation_history.append( + Message(role=role, content=content) + ) - # Call the LLM with database tools and history - response_text, updated_history = chat_with_database( + # Stream the LLM response with database tools + response_text = '' + updated_messages = [] + for item in chat_with_database_stream( user_message=user_message, sid=trans_obj.sid, did=trans_obj.did, system_prompt=NLQ_SYSTEM_PROMPT, conversation_history=conversation_history - ) - - # Try to parse the response as JSON - sql = None - explanation = '' - - # First, try to extract JSON from markdown code blocks - json_text = response_text.strip() - - # Look for ```json ... ``` blocks - json_match = re.search( - r'```json\s*\n?(.*?)\n?```', - json_text, - re.DOTALL - ) - if json_match: - json_text = json_match.group(1).strip() - else: - # Also try to find a plain JSON object in the response - # Look for {"sql": ... 
} pattern anywhere in the text - sql_pattern = ( - r'\{["\']?sql["\']?\s*:\s*' - r'(?:null|"[^"]*"|\'[^\']*\').*?\}' - ) - plain_json_match = re.search(sql_pattern, json_text, re.DOTALL) - if plain_json_match: - json_text = plain_json_match.group(0) - - try: - result = json.loads(json_text) - sql = result.get('sql') - explanation = result.get('explanation', '') - except (json.JSONDecodeError, TypeError): - # If not valid JSON, try to extract SQL from the response - # Look for SQL code blocks first - sql_match = re.search( - r'```sql\s*\n?(.*?)\n?```', - response_text, - re.DOTALL - ) - if sql_match: - sql = sql_match.group(1).strip() - else: - # Check for malformed tool call text patterns - # Some models output tool calls as text instead of - # proper tool use blocks - tool_call_match = re.search( - r'\s*' - r'\s*(.*?)\s*', - response_text, - re.DOTALL - ) - if tool_call_match: - sql = tool_call_match.group(1).strip() - explanation = gettext( - 'Generated SQL query from your request.' + ): + if isinstance(item, str): + # Text chunk from streaming LLM response + yield _nlq_sse_event({ + 'type': 'text_delta', + 'content': item + }) + elif isinstance(item, tuple) and \ + item[0] == 'tool_use': + # Tool execution in progress - reset streaming + yield _nlq_sse_event({ + 'type': 'thinking', + 'message': gettext( + 'Querying the database...' 
) - else: - # No parseable JSON or SQL block found - # Treat the response as an explanation/error message - explanation = response_text.strip() - # Don't set sql - leave it as None + }) + elif isinstance(item, tuple) and \ + item[0] == 'complete': + # Final result: ('complete', response_text, messages) + response_text = item[1] + updated_messages = item[2] + + # Extract SQL from markdown code fences + sql_blocks = re.findall( + r'```(?:sql|pgsql|postgresql)\s*\n(.*?)```', + response_text, + re.DOTALL | re.IGNORECASE + ) + sql = ';\n\n'.join( + block.strip().rstrip(';') for block in sql_blocks + ) if sql_blocks else None + + # Fallback: try JSON format in case LLM ignored + # the markdown instruction + if sql is None: + try: + result = json.loads(response_text.strip()) + if isinstance(result, dict): + sql = result.get('sql') + except (json.JSONDecodeError, TypeError): + pass # Generate a conversation ID if not provided if not conversation_id: @@ -2984,24 +2969,29 @@ def generate(): else: new_conversation_id = conversation_id - # Serialize updated history for the frontend. - # Only include conversational messages (user + final - # assistant responses) to keep history size manageable. - # Internal tool call/result messages are ephemeral to - # each turn and don't need to round-trip. - from pgadmin.llm.compaction import filter_conversational - serialized_history = [ - m.to_dict() for m in - filter_conversational(updated_history) - ] if updated_history else [] - - # Send the final result + # Serialize the conversation history so the client can + # round-trip it on follow-up turns. Only keep user + # messages and final assistant responses (no tool calls). 
+ history = [] + for m in updated_messages: + if m.role == Role.USER: + history.append({ + 'role': m.role.value, + 'content': m.content, + }) + elif m.role == Role.ASSISTANT and not m.tool_calls: + history.append({ + 'role': m.role.value, + 'content': m.content, + }) + + # Send the final result with full response content yield _nlq_sse_event({ 'type': 'complete', 'sql': sql, - 'explanation': explanation, + 'content': response_text, 'conversation_id': new_conversation_id, - 'history': serialized_history + 'history': history }) except Exception as e: diff --git a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx index d484a04955f..838f60c8831 100644 --- a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx +++ b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx @@ -23,6 +23,8 @@ import AddIcon from '@mui/icons-material/Add'; import ClearAllIcon from '@mui/icons-material/ClearAll'; import AutoFixHighIcon from '@mui/icons-material/AutoFixHigh'; import { format as formatSQL } from 'sql-formatter'; +import { marked } from 'marked'; +import DOMPurify from 'dompurify'; import gettext from 'sources/gettext'; import url_for from 'sources/url_for'; import getApiInstance from '../../../../../../static/js/api_instance'; @@ -106,7 +108,6 @@ const SQLPreviewBox = styled(Box)(({ theme }) => ({ borderRadius: theme.spacing(0.5), overflow: 'auto', '& .cm-editor': { - minHeight: '60px', maxHeight: '250px', }, '& .cm-scroller': { @@ -132,17 +133,173 @@ const ThinkingIndicator = styled(Box)(({ theme }) => ({ color: theme.palette.text.secondary, })); +const MarkdownContent = styled(Box)(({ theme }) => ({ + fontSize: theme.typography.body2.fontSize, + lineHeight: theme.typography.body2.lineHeight, + '& p': { margin: `${theme.spacing(0.5)} 0` }, + '& p:first-of-type': { marginTop: 0 }, + '& p:last-of-type': { marginBottom: 0 }, + '& code': { 
+ backgroundColor: theme.palette.action.hover, + padding: '1px 4px', + borderRadius: 3, + fontSize: '0.85em', + fontFamily: 'monospace', + }, + '& pre': { + backgroundColor: theme.palette.action.hover, + padding: theme.spacing(1), + borderRadius: 4, + overflow: 'auto', + '& code': { + backgroundColor: 'transparent', + padding: 0, + }, + }, + '& h1, & h2, & h3, & h4, & h5, & h6': { + margin: `${theme.spacing(1)} 0 ${theme.spacing(0.5)} 0`, + lineHeight: 1.3, + }, + '& h1': { fontSize: '1.3em' }, + '& h2': { fontSize: '1.2em' }, + '& h3': { fontSize: '1.1em' }, + '& ul': { + margin: `${theme.spacing(0.5)} 0`, + paddingLeft: theme.spacing(2.5), + listStyleType: 'disc !important', + }, + '& ol': { + margin: `${theme.spacing(0.5)} 0`, + paddingLeft: theme.spacing(2.5), + listStyleType: 'decimal !important', + }, + '& li': { + margin: `${theme.spacing(0.25)} 0`, + display: 'list-item !important', + listStyle: 'inherit !important', + }, + '& ul ul': { listStyleType: 'circle !important' }, + '& ul ul ul': { listStyleType: 'square !important' }, + '& table': { + borderCollapse: 'collapse', + margin: `${theme.spacing(0.5)} 0`, + width: '100%', + }, + '& th, & td': { + border: `1px solid ${theme.otherVars.borderColor}`, + padding: `${theme.spacing(0.25)} ${theme.spacing(0.75)}`, + textAlign: 'left', + }, + '& th': { + backgroundColor: theme.palette.action.hover, + fontWeight: 600, + }, + '& blockquote': { + borderLeft: `3px solid ${theme.otherVars.borderColor}`, + margin: `${theme.spacing(0.5)} 0`, + paddingLeft: theme.spacing(1), + opacity: 0.85, + }, + '& strong': { fontWeight: 600 }, + '& a': { + color: theme.otherVars.hyperlinkColor, + textDecoration: 'underline', + }, +})); + // Message types const MESSAGE_TYPES = { USER: 'user', ASSISTANT: 'assistant', SQL: 'sql', THINKING: 'thinking', + STREAMING: 'streaming', ERROR: 'error', }; +/** + * Incrementally parse streaming markdown text into an ordered list of + * segments. 
Each segment is: + * { type: 'text', content: string } + * { type: 'code', language: string, content: string, complete: boolean } + * + * Handles ```language fenced code blocks. Segments appear in the order + * the LLM streams them so the renderer can map straight over the array. + */ +function parseMarkdownSegments(text) { + const segments = []; + let pos = 0; + + while (pos < text.length) { + const fenceIdx = text.indexOf('```', pos); + + if (fenceIdx === -1) { + // No more fences — rest is text + const content = text.substring(pos); + if (content) segments.push({ type: 'text', content }); + break; + } + + // Text before the fence + if (fenceIdx > pos) { + segments.push({ type: 'text', content: text.substring(pos, fenceIdx) }); + } + + // Parse opening fence line: ```language\n + const afterFence = text.substring(fenceIdx + 3); + const langMatch = /^([a-zA-Z]*)\n/.exec(afterFence); + if (!langMatch) { + // Language line not complete yet — wait for more tokens + break; + } + + const language = langMatch[1].toLowerCase(); + const codeStart = fenceIdx + 3 + langMatch[0].length; + + // Find closing fence + const closeIdx = text.indexOf('```', codeStart); + if (closeIdx === -1) { + // Still streaming code block content + segments.push({ + type: 'code', language, + content: text.substring(codeStart), + complete: false, + }); + break; + } + + // Complete code block — trim trailing newline before closing fence + let codeContent = text.substring(codeStart, closeIdx); + if (codeContent.endsWith('\n')) { + codeContent = codeContent.slice(0, -1); + } + segments.push({ + type: 'code', language, + content: codeContent, + complete: true, + }); + + // Move past closing ``` and optional trailing newline + pos = closeIdx + 3; + if (pos < text.length && text[pos] === '\n') pos++; + } + + return segments; +} + +/** + * Render a markdown text fragment to sanitized HTML. + * Uses marked for inline formatting (bold, italic, code, lists, tables, etc.) + * and DOMPurify to prevent XSS. 
+ */ +function renderMarkdownText(text) { + if (!text) return ''; + const html = marked.parse(text, { gfm: true, breaks: true }); + return DOMPurify.sanitize(html); +} + // Single chat message component -function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey }) { +function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey, formatSqlWithPrefs }) { if (message.type === MESSAGE_TYPES.USER) { return ( @@ -152,58 +309,117 @@ function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey }) } if (message.type === MESSAGE_TYPES.SQL) { + const segments = message.content + ? parseMarkdownSegments(message.content) : []; + + // Fallback for messages without markdown content (old format) + if (segments.length === 0 && message.sql) { + return ( + + + + + {gettext('Generated SQL')} + + + + onInsertSQL(message.sql)}> + + + + + onReplaceSQL(message.sql)}> + + + + + navigator.clipboard.writeText(message.sql)}> + + + + + + + + + + {message.explanation && ( + {message.explanation} + )} + + ); + } + + // Render markdown segments with action buttons on code blocks return ( - {message.explanation && ( - - {message.explanation} - - )} - - - - {gettext('Generated SQL')} - - - - onInsertSQL(message.sql)} - > - - - - - onReplaceSQL(message.sql)} - > - - - - - navigator.clipboard.writeText(message.sql)} - > - - - - - - - - - + {segments.map((seg, idx) => { + if (seg.type === 'text') { + const content = seg.content?.trim(); + if (!content) return null; + return ( + 0 ? 1 : 0 }} + dangerouslySetInnerHTML={{ __html: renderMarkdownText(content) }} + /> + ); + } + + if (seg.type === 'code') { + const isSql = ['sql', 'pgsql', 'postgresql'].includes(seg.language); + const formattedCode = isSql ? 
formatSqlWithPrefs(seg.content) : seg.content; + + return ( + + + + {seg.language || gettext('Code')} + + + {isSql && ( + <> + + onInsertSQL(formattedCode)}> + + + + + onReplaceSQL(formattedCode)}> + + + + + )} + + navigator.clipboard.writeText(formattedCode)}> + + + + + + + + + + ); + } + + return null; + })} ); } @@ -224,6 +440,105 @@ function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey }) ); } + if (message.type === MESSAGE_TYPES.STREAMING) { + const segments = parseMarkdownSegments(message.content); + const BlinkingCursor = ( + + ); + + // No segments parsed yet — show raw text or spinner + if (segments.length === 0) { + return ( + + {message.content ? ( + + {message.content} + {BlinkingCursor} + + ) : ( + + + + {gettext('Generating response...')} + + + )} + + ); + } + + // Render markdown segments in order + const lastIdx = segments.length - 1; + return ( + + {segments.map((seg, idx) => { + const isLast = idx === lastIdx; + const cursor = isLast && !seg.complete ? BlinkingCursor : null; + + if (seg.type === 'code') { + return ( + + + + {seg.complete + ? (seg.language || gettext('Code')) + : gettext('Generating...')} + + + + + {seg.content} + {cursor} + + + + ); + } + + const content = seg.content?.trim(); + if (!content && !cursor) return null; + return ( + 0 ? 
1 : 0 }}> + + {cursor} + + ); + })} + + ); + } + if (message.type === MESSAGE_TYPES.ERROR) { return ( - {message.content} + ); } @@ -272,6 +589,8 @@ export function NLQChatPanel() { const readerRef = useRef(null); const stoppedRef = useRef(false); const clearedRef = useRef(false); + const streamingTextRef = useRef(''); + const streamingIdRef = useRef(null); const eventBus = useContext(QueryToolEventsContext); const queryToolCtx = useContext(QueryToolContext); const editorPrefs = usePreferences().getPreferencesForModule('editor'); @@ -410,7 +729,6 @@ export function NLQChatPanel() { setMessages([]); setConversationId(null); setConversationHistory([]); - setIsLoading(false); }; // Stop the current request @@ -448,9 +766,11 @@ export function NLQChatPanel() { const handleSubmit = async () => { if (!inputValue.trim() || isLoading) return; - // Reset stopped and cleared flags + // Reset stopped, cleared flags and streaming state stoppedRef.current = false; clearedRef.current = false; + streamingTextRef.current = ''; + streamingIdRef.current = null; // Fetch latest LLM provider/model info before submitting fetchLlmInfo(); @@ -553,36 +873,70 @@ export function NLQChatPanel() { // Check if user manually stopped (but not cleared) if (stoppedRef.current && !clearedRef.current) { - setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), - { - type: MESSAGE_TYPES.ASSISTANT, - content: gettext('Generation stopped.'), - }, - ]); + const streamId = streamingIdRef.current; + // If we have partial streaming content, show it separately + // from the stop notice to avoid breaking open markdown fences + if (streamingTextRef.current) { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), + { + type: MESSAGE_TYPES.ASSISTANT, + content: streamingTextRef.current, + }, + { + type: MESSAGE_TYPES.ASSISTANT, + content: gettext('Generation stopped.'), + }, + ]); + } else { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== 
thinkingId), + { + type: MESSAGE_TYPES.ASSISTANT, + content: gettext('Generation stopped.'), + }, + ]); + } + streamingTextRef.current = ''; + streamingIdRef.current = null; } } catch (error) { clearTimeout(timeoutId); abortControllerRef.current = null; readerRef.current = null; + const streamId = streamingIdRef.current; // If conversation was cleared, ignore all late errors if (clearedRef.current) { // Do nothing - conversation was wiped } else if (error.name === 'AbortError') { // Check if this was a user-initiated stop or a timeout if (stoppedRef.current) { - // User manually stopped - setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), - { - type: MESSAGE_TYPES.ASSISTANT, - content: gettext('Generation stopped.'), - }, - ]); + // User manually stopped - show partial content separately + if (streamingTextRef.current) { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), + { + type: MESSAGE_TYPES.ASSISTANT, + content: streamingTextRef.current, + }, + { + type: MESSAGE_TYPES.ASSISTANT, + content: gettext('Generation stopped.'), + }, + ]); + } else { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId), + { + type: MESSAGE_TYPES.ASSISTANT, + content: gettext('Generation stopped.'), + }, + ]); + } } else { // Timeout occurred setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.ERROR, content: gettext('Request timed out. The query may be too complex. 
Please try a simpler request.'), @@ -591,13 +945,15 @@ export function NLQChatPanel() { } } else { setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.ERROR, content: gettext('Failed to generate SQL: ') + error.message, }, ]); } + streamingTextRef.current = ''; + streamingIdRef.current = null; } finally { setIsLoading(false); setThinkingMessageId(null); @@ -606,32 +962,82 @@ export function NLQChatPanel() { const handleSSEEvent = (event, thinkingId) => { switch (event.type) { - case 'thinking': - setMessages((prev) => - prev.map((m) => - m.id === thinkingId ? { ...m, content: event.message } : m - ) - ); + case 'thinking': { + const streamId = streamingIdRef.current; + if (streamId) { + // Transition from streaming back to thinking (tool use) + // Remove streaming message and re-add thinking indicator + streamingTextRef.current = ''; + streamingIdRef.current = null; + setMessages((prev) => [ + ...prev.filter((m) => m.id !== streamId), + { + type: MESSAGE_TYPES.THINKING, + content: event.message, + id: thinkingId, + }, + ]); + setThinkingMessageId(thinkingId); + } else { + setMessages((prev) => + prev.map((m) => + m.id === thinkingId ? 
{ ...m, content: event.message } : m + ) + ); + } break; + } - case 'sql': - case 'complete': - // If sql is null/empty, show as regular assistant message (e.g., clarification questions) - if (!event.sql) { + case 'text_delta': + streamingTextRef.current += event.content; + if (!streamingIdRef.current) { + // First text chunk: replace thinking with streaming message + streamingIdRef.current = Date.now(); setMessages((prev) => [ ...prev.filter((m) => m.id !== thinkingId), { - type: MESSAGE_TYPES.ASSISTANT, - content: event.explanation || gettext('I need more information to generate the SQL.'), + type: MESSAGE_TYPES.STREAMING, + content: streamingTextRef.current, + id: streamingIdRef.current, }, ]); } else { + // Update existing streaming message + const sid = streamingIdRef.current; + setMessages((prev) => + prev.map((m) => + m.id === sid ? { ...m, content: streamingTextRef.current } : m + ) + ); + } + break; + + case 'sql': + case 'complete': { + const streamId = streamingIdRef.current; + const content = event.content || event.explanation + || gettext('I need more information to generate the SQL.'); + // Use SQL type if there's SQL or any code fences in the response + const hasCodeBlocks = event.sql || (content && content.includes('```')); + if (hasCodeBlocks) { + // When SQL was extracted via JSON fallback (no fenced blocks), + // clear content so ChatMessage uses the sql-only render path + const msgContent = (content && content.includes('```')) + ? 
content : null; setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.SQL, - sql: formatSqlWithPrefs(event.sql), - explanation: event.explanation, + content: msgContent, + sql: event.sql, + }, + ]); + } else { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), + { + type: MESSAGE_TYPES.ASSISTANT, + content, }, ]); } @@ -641,18 +1047,26 @@ export function NLQChatPanel() { if (event.history) { setConversationHistory(event.history); } + // Reset streaming state + streamingTextRef.current = ''; + streamingIdRef.current = null; break; + } - case 'error': + case 'error': { + const streamId = streamingIdRef.current; setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.ERROR, content: event.message, }, ]); + streamingTextRef.current = ''; + streamingIdRef.current = null; break; } + } }; const handleKeyDown = (e) => { @@ -745,6 +1159,7 @@ export function NLQChatPanel() { onReplaceSQL={handleReplaceSQL} textColors={textColors} cmKey={cmKey} + formatSqlWithPrefs={formatSqlWithPrefs} /> )) )} diff --git a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py index 38feff067e1..360827abfbc 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py +++ b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py @@ -10,6 +10,7 @@ """Tests for the NLQ (Natural Language Query) chat endpoint.""" import json +import re from unittest.mock import patch, MagicMock from pgadmin.utils.route import BaseTestGenerator @@ -43,8 +44,9 @@ class NLQChatTestCase(BaseTestGenerator): message='Find all users', expected_error=False, mock_response=( - '{"sql": "SELECT * FROM users;", ' - '"explanation": "Gets all users"}' + 'Here are all users:\n\n' + '```sql\nSELECT * FROM users;\n```\n\n' + 'This 
retrieves all rows from the users table.' ) )), ('NLQ Chat - With History', dict( @@ -108,16 +110,15 @@ def runTest(self): ) patches.append(mock_check_trans) - # Mock chat_with_database — patch the source module because the - # endpoint uses a local import (from pgadmin.llm.chat import ...) - # inside the function body, so there is no module-level binding - # to patch at the use site. + # Mock chat_with_database_stream mock_chat_patcher = None - mock_chat_obj = None if hasattr(self, 'mock_response'): + def mock_stream_gen(*args, **kwargs): + yield self.mock_response + yield ('complete', self.mock_response, []) mock_chat_patcher = patch( - 'pgadmin.llm.chat.chat_with_database', - return_value=(self.mock_response, []) + 'pgadmin.llm.chat.chat_with_database_stream', + side_effect=mock_stream_gen ) patches.append(mock_chat_patcher) @@ -133,8 +134,6 @@ def runTest(self): for p in patches: m = p.start() started_mocks.append(m) - if p is mock_chat_patcher: - mock_chat_obj = m try: # Make request @@ -166,22 +165,9 @@ def runTest(self): self.assertIn('text/event-stream', response.content_type) # Consume the SSE stream so the generator executes - # fully (including the chat_with_database call) + # fully (including the chat_with_database_stream call) _ = response.data - # Verify history was passed to chat_with_database - if hasattr(self, 'history') and mock_chat_obj: - mock_chat_obj.assert_called_once() - call_kwargs = mock_chat_obj.call_args.kwargs - conv_hist = call_kwargs.get( - 'conversation_history', [] - ) - self.assertTrue( - len(conv_hist) > 0, - 'conversation_history should be non-empty ' - 'when history is provided' - ) - finally: # Stop all patches for p in patches: @@ -216,3 +202,248 @@ def runTest(self): def tearDown(self): pass + + +class NLQSqlExtractionTestCase(BaseTestGenerator): + """Test cases for SQL extraction from markdown responses""" + + scenarios = [ + ('SQL Extraction - Single SQL block', dict( + response_text=( + 'Here is the query:\n\n' + 
'```sql\nSELECT * FROM users;\n```\n\n' + 'This returns all users.' + ), + expected_sql='SELECT * FROM users' + )), + ('SQL Extraction - Multiple SQL blocks', dict( + response_text=( + 'First get users:\n\n' + '```sql\nSELECT * FROM users;\n```\n\n' + 'Then get orders:\n\n' + '```sql\nSELECT * FROM orders;\n```' + ), + expected_sql='SELECT * FROM users;\n\nSELECT * FROM orders' + )), + ('SQL Extraction - pgsql language tag', dict( + response_text='```pgsql\nSELECT 1;\n```', + expected_sql='SELECT 1' + )), + ('SQL Extraction - postgresql language tag', dict( + response_text='```postgresql\nSELECT 1;\n```', + expected_sql='SELECT 1' + )), + ('SQL Extraction - No SQL blocks', dict( + response_text=( + 'I cannot generate a query without ' + 'knowing your table structure.' + ), + expected_sql=None + )), + ('SQL Extraction - Non-SQL code block only', dict( + response_text=( + 'Here is some Python:\n\n' + '```python\nprint("hello")\n```' + ), + expected_sql=None + )), + ('SQL Extraction - JSON fallback', dict( + response_text='{"sql": "SELECT 1;", "explanation": "test"}', + expected_sql='SELECT 1;' + )), + ('SQL Extraction - Multiline SQL', dict( + response_text=( + '```sql\n' + 'SELECT u.name, o.total\n' + 'FROM users u\n' + 'JOIN orders o ON u.id = o.user_id\n' + 'WHERE o.total > 100;\n' + '```' + ), + expected_sql=( + 'SELECT u.name, o.total\n' + 'FROM users u\n' + 'JOIN orders o ON u.id = o.user_id\n' + 'WHERE o.total > 100' + ) + )), + ] + + def setUp(self): + pass + + def runTest(self): + """Test SQL extraction from markdown response text""" + response_text = self.response_text + + # Extract SQL using the same regex as the endpoint + sql_blocks = re.findall( + r'```(?:sql|pgsql|postgresql)\s*\n(.*?)```', + response_text, + re.DOTALL | re.IGNORECASE + ) + sql = ';\n\n'.join( + block.strip().rstrip(';') for block in sql_blocks + ) if sql_blocks else None + + # JSON fallback + if sql is None: + try: + result = json.loads(response_text.strip()) + if isinstance(result, 
dict): + sql = result.get('sql') + except (json.JSONDecodeError, TypeError): + pass + + self.assertEqual(sql, self.expected_sql) + + def tearDown(self): + pass + + +class NLQStreamingSSETestCase(BaseTestGenerator): + """Test cases for SSE event format in streaming responses""" + + scenarios = [ + ('SSE - Text with SQL produces complete event', dict( + mock_response=( + '```sql\nSELECT 1;\n```' + ), + check_complete_has_sql=True + )), + ('SSE - Text without SQL has no sql field', dict( + mock_response='I need more information about your schema.', + check_complete_has_sql=False + )), + ] + + def setUp(self): + pass + + def runTest(self): + """Test SSE events from NLQ streaming endpoint""" + trans_id = 12345 + + patches = [] + + mock_llm_enabled = patch( + 'pgadmin.llm.utils.is_llm_enabled', + return_value=True + ) + patches.append(mock_llm_enabled) + + mock_trans_obj = MagicMock() + mock_trans_obj.sid = 1 + mock_trans_obj.did = 1 + + mock_conn = MagicMock() + mock_conn.connected.return_value = True + + mock_session = {'sid': 1, 'did': 1} + + mock_check_trans = patch( + 'pgadmin.tools.sqleditor.check_transaction_status', + return_value=( + True, None, mock_conn, mock_trans_obj, mock_session + ) + ) + patches.append(mock_check_trans) + + def mock_stream_gen(*args, **kwargs): + # Yield text chunks + for chunk in [self.mock_response[i:i + 10] + for i in range(0, len(self.mock_response), 10)]: + yield chunk + # Yield final 3-tuple + yield ('complete', self.mock_response, []) + + mock_chat = patch( + 'pgadmin.llm.chat.chat_with_database_stream', + side_effect=mock_stream_gen + ) + patches.append(mock_chat) + + mock_csrf = patch( + 'pgadmin.authenticate.mfa.utils.mfa_required', + lambda f: f + ) + patches.append(mock_csrf) + + for p in patches: + p.start() + + try: + response = self.tester.post( + f'/sqleditor/nlq/chat/{trans_id}/stream', + data=json.dumps({'message': 'test query'}), + content_type='application/json', + follow_redirects=True + ) + + 
self.assertEqual(response.status_code, 200) + self.assertIn('text/event-stream', response.content_type) + + # Parse SSE events + events = [] + raw = response.data.decode('utf-8') + for line in raw.split('\n'): + if line.startswith('data: '): + try: + events.append(json.loads(line[6:])) + except json.JSONDecodeError: + pass + + # Should have at least one text_delta and one complete + event_types = [e.get('type') for e in events] + self.assertIn('text_delta', event_types) + self.assertIn('complete', event_types) + + # Check the complete event + complete_events = [ + e for e in events if e.get('type') == 'complete' + ] + self.assertEqual(len(complete_events), 1) + complete = complete_events[0] + + # Verify content is present + self.assertIn('content', complete) + self.assertEqual(complete['content'], self.mock_response) + + # Verify SQL extraction + if self.check_complete_has_sql: + self.assertIsNotNone(complete.get('sql')) + else: + self.assertIsNone(complete.get('sql')) + + finally: + for p in patches: + p.stop() + + def tearDown(self): + pass + + +class NLQPromptMarkdownFormatTestCase(BaseTestGenerator): + """Test that NLQ prompt instructs markdown code fences""" + + scenarios = [ + ('NLQ Prompt - Markdown format', dict()), + ] + + def setUp(self): + pass + + def runTest(self): + """Test NLQ prompt requires markdown SQL code fences""" + from pgadmin.llm.prompts.nlq import NLQ_SYSTEM_PROMPT + + # Prompt should instruct use of fenced code blocks + self.assertIn('fenced code block', NLQ_SYSTEM_PROMPT.lower()) + self.assertIn('sql', NLQ_SYSTEM_PROMPT.lower()) + + # Should NOT instruct JSON format + self.assertNotIn('"sql":', NLQ_SYSTEM_PROMPT) + self.assertNotIn('"explanation":', NLQ_SYSTEM_PROMPT) + + def tearDown(self): + pass