From 02855333817cbfc551cf0c1bafe6984a2b4dfcc0 Mon Sep 17 00:00:00 2001
From: Johann Schleier-Smith
Date: Wed, 10 Sep 2025 11:44:17 -0700
Subject: [PATCH 1/3] sessions on top of postgres

---
 openai_agents/memory/README.md                |  34 ++
 openai_agents/memory/connection_state.py      |  43 ++
 openai_agents/memory/db_utils.py              |  84 +++
 openai_agents/memory/postgres_session.py      | 286 ++++++++++
 .../memory/run_postgres_session_worker.py     |  71 +++
 .../memory/run_postgres_session_workflow.py   |  34 ++
 .../workflows/postgres_session_workflow.py    |  80 +++
 pyproject.toml                                |  11 +-
 tests/conftest.py                             |  38 ++
 .../memory/test_idempotence_util.py           | 324 +++++++++++
 .../memory/test_postgres_session.py           | 504 ++++++++++++++++++
 uv.lock                                       |  77 ++-
 12 files changed, 1565 insertions(+), 21 deletions(-)
 create mode 100644 openai_agents/memory/README.md
 create mode 100644 openai_agents/memory/connection_state.py
 create mode 100644 openai_agents/memory/db_utils.py
 create mode 100644 openai_agents/memory/postgres_session.py
 create mode 100644 openai_agents/memory/run_postgres_session_worker.py
 create mode 100644 openai_agents/memory/run_postgres_session_workflow.py
 create mode 100644 openai_agents/memory/workflows/postgres_session_workflow.py
 create mode 100644 tests/openai_agents/memory/test_idempotence_util.py
 create mode 100644 tests/openai_agents/memory/test_postgres_session.py

diff --git a/openai_agents/memory/README.md b/openai_agents/memory/README.md
new file mode 100644
index 00000000..1b51b974
--- /dev/null
+++ b/openai_agents/memory/README.md
@@ -0,0 +1,34 @@
+# Session Memory Examples
+
+Session memory examples for the OpenAI Agents SDK integrated with Temporal workflows.
+
+*Adapted from [OpenAI Agents SDK session memory examples](https://github.com/openai/openai-agents-python/tree/main/examples/memory)*
+
+Before running these examples, be sure to review the [prerequisites and background on the integration](../README.md).
+
+## Running the Examples
+
+### PostgreSQL Session Memory
+
+This example uses a PostgreSQL database to store session data.
+
+You can use the standard PostgreSQL environment variables to configure the database connection.
+These include `PGDATABASE`, `PGUSER`, `PGPASSWORD`, `PGHOST`, and `PGPORT`.
+We also support the `DATABASE_URL` environment variable.
+
+To confirm that your environment is configured correctly, just run the `psql` command after setting the environment variables.
+For example:
+```bash
+PGDATABASE=postgres psql
+```
+
+Start the worker:
+```bash
+PGDATABASE=postgres uv run openai_agents/memory/run_postgres_session_worker.py
+```
+
+Then run the workflow:
+
+```bash
+uv run openai_agents/memory/run_postgres_session_workflow.py
+```
diff --git a/openai_agents/memory/connection_state.py b/openai_agents/memory/connection_state.py
new file mode 100644
index 00000000..f17fc3f4
--- /dev/null
+++ b/openai_agents/memory/connection_state.py
@@ -0,0 +1,43 @@
+"""Worker-level database connection state management.
+
+WARNING: This implementation uses global state and is not safe for concurrent
+testing (e.g., pytest-xdist). Run tests sequentially to avoid race conditions.
+"""
+
+from typing import Optional
+
+import asyncpg
+
+
+# Module-level connection state
+_connection: Optional[asyncpg.Connection] = None
+
+
+def set_worker_connection(connection: asyncpg.Connection) -> None:
+    """Set the worker-level database connection."""
+    global _connection
+    _connection = connection
+
+
+def get_worker_connection() -> asyncpg.Connection:
+    """Get the worker-level database connection.
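+
+    Example (illustrative only; assumes set_worker_connection() was called
+    during worker startup):
+
+        conn = get_worker_connection()
+        row = await conn.fetchrow("SELECT 1")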
+ + Raises: + RuntimeError: If no connection has been set. + """ + if _connection is None: + raise RuntimeError( + "No worker-level database connection has been set. " + "Call set_worker_connection() before using activities." + ) + return _connection + + +def clear_worker_connection() -> None: + """Clear the worker-level database connection.""" + global _connection + _connection = None + + +def has_worker_connection() -> bool: + """Check if a worker-level connection is available.""" + return _connection is not None diff --git a/openai_agents/memory/db_utils.py b/openai_agents/memory/db_utils.py new file mode 100644 index 00000000..a2e16a2b --- /dev/null +++ b/openai_agents/memory/db_utils.py @@ -0,0 +1,84 @@ +import json +import asyncpg +from typing import Callable, Awaitable, TypeVar +from temporalio import activity +from pydantic import BaseModel + +T = TypeVar("T") + + +class IdempotenceHelper(BaseModel): + table_name: str + + def __init__(self, table_name: str): + super().__init__(table_name=table_name) + self.table_name = table_name + + async def create_table(self, conn: asyncpg.Connection) -> None: + await conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + run_id UUID NOT NULL, + activity_id TEXT NOT NULL, + operation_started_at TIMESTAMP NOT NULL, + operation_completed_at TIMESTAMP NULL, + operation_result TEXT NULL, + PRIMARY KEY (run_id, activity_id) + ) + """ + ) + + async def idempotent_update( + self, + conn: asyncpg.Connection, + operation: Callable[[asyncpg.Connection], Awaitable[T]], + ) -> T | None: + """Insert idempotence row; on conflict, read and return existing result. + + The operation must be an async callable of the form: + async def op(conn: asyncpg.Connection) -> T + """ + activity_info = activity.info() + run_id = activity_info.workflow_run_id + activity_id = activity_info.activity_id + + async with conn.transaction(): + did_insert = await conn.fetchrow( + ( + f"INSERT INTO {self.table_name} " + f"(run_id, activity_id, operation_started_at) " + f"VALUES ($1, $2, NOW()) " + f"ON CONFLICT (run_id, activity_id) DO NOTHING " + f"RETURNING 1" + ), + run_id, + activity_id, + ) + + if did_insert: + res = await operation(conn) + + if hasattr(res, "model_dump_json"): + op_result = res.model_dump_json() + else: + op_result = json.dumps(res) + + await conn.execute( + f"UPDATE {self.table_name} SET operation_completed_at = NOW(), operation_result = $1 WHERE run_id = $2 AND activity_id = $3", + op_result, + run_id, + activity_id, + ) + return res + else: + row = await conn.fetchrow( + f"SELECT operation_result FROM {self.table_name} WHERE run_id = $1 AND activity_id = $2", + run_id, + activity_id, + ) + if not row or row["operation_result"] is None: + return None + try: + return json.loads(row["operation_result"]) + except Exception: + return row["operation_result"] diff --git a/openai_agents/memory/postgres_session.py b/openai_agents/memory/postgres_session.py new file mode 100644 index 00000000..8936a03c --- /dev/null +++ b/openai_agents/memory/postgres_session.py @@ -0,0 +1,286 @@ +from __future__ import annotations + +import json +from typing import Any + +import asyncpg +from temporalio import activity, workflow +from pydantic import BaseModel +from agents.memory.session import SessionABC +from agents.items import TResponseInputItem +from openai_agents.memory.db_utils import IdempotenceHelper +from typing import Callable + +_connection_factory: Callable[[], asyncpg.Connection] | None = None + + +def _convert_to_json_serializable(obj: Any) -> Any: + 
"""Recursively convert objects to JSON serializable format.""" + if obj is None or isinstance(obj, (str, int, float, bool)): + return obj + + # Handle Pydantic models + if hasattr(obj, "model_dump"): + return _convert_to_json_serializable(obj.model_dump()) + elif hasattr(obj, "dict"): + return _convert_to_json_serializable(obj.dict()) + + # Handle dictionaries + if isinstance(obj, dict): + return {key: _convert_to_json_serializable(value) for key, value in obj.items()} + + # Handle lists, tuples, sets + if isinstance(obj, (list, tuple, set)): + return [_convert_to_json_serializable(item) for item in obj] + + # Handle other iterables (including ValidatorIterator) + if hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes)): + try: + return [_convert_to_json_serializable(item) for item in obj] + except Exception: + # If iteration fails, try to convert to string + return str(obj) + + # Handle objects with __dict__ + if hasattr(obj, "__dict__"): + return _convert_to_json_serializable(obj.__dict__) + + # Fallback to string representation + return str(obj) + + +class PostgresSessionConfig(BaseModel): + messages_table: str = "session_messages" + sessions_table: str = "session" + operation_id_sequence: str = "session_operation_id_sequence" + idempotence_table: str = "activity_idempotence" + + +async def init_schema(conn: asyncpg.Connection, config: PostgresSessionConfig) -> None: + """Initialize the PostgreSQL schema.""" + async with conn.transaction(): + # Create sessions table + sessions_ddl = f""" + CREATE TABLE IF NOT EXISTS {config.sessions_table} ( + session_id TEXT NOT NULL, + created_at TIMESTAMP DEFAULT NOW(), + updated_at TIMESTAMP DEFAULT NOW(), + PRIMARY KEY (session_id) + ) + """ + await conn.execute(sessions_ddl) + + # Create operation_id sequence + operation_id_ddl = f""" + CREATE SEQUENCE IF NOT EXISTS {config.operation_id_sequence} START 1 + """ + await conn.execute(operation_id_ddl) + + # Create messages table + messages_ddl = f""" + CREATE TABLE IF NOT EXISTS {config.messages_table} ( + session_id TEXT NOT NULL, + operation_id INTEGER NOT NULL DEFAULT nextval('{config.operation_id_sequence}'), + message_data TEXT NOT NULL, + created_at TIMESTAMP DEFAULT NOW(), + deleted_at TIMESTAMP NULL, + PRIMARY KEY (session_id, operation_id), + FOREIGN KEY (session_id) + REFERENCES {config.sessions_table} (session_id) + ON DELETE CASCADE + ) + """ + await conn.execute(messages_ddl) + + +class PostgresSessionGetItemsRequest(BaseModel): + config: PostgresSessionConfig + session_id: str + limit: int | None = None + + +class PostgresSessionGetItemsResponse(BaseModel): + items: list[TResponseInputItem] + + +@activity.defn +async def postgres_session_get_items_activity( + request: PostgresSessionGetItemsRequest, +) -> PostgresSessionGetItemsResponse: + """Get items from the session in operation_id order.""" + activity.heartbeat() + + conn = PostgresSession._get_connection() + + if request.limit is None: + query = f""" + SELECT message_data FROM {request.config.messages_table} + WHERE session_id = $1 AND deleted_at IS NULL + ORDER BY operation_id ASC + """ + rows = await conn.fetch(query, request.session_id) + else: + query = f""" + SELECT t.message_data FROM ( + SELECT message_data, operation_id FROM {request.config.messages_table} + WHERE session_id = $1 AND deleted_at IS NULL + ORDER BY operation_id DESC + LIMIT $2 + ) AS t ORDER BY t.operation_id ASC + """ + rows = await conn.fetch(query, request.session_id, request.limit) + + return PostgresSessionGetItemsResponse( + 
items=[json.loads(row["message_data"]) for row in rows] + ) + + +class PostgresSessionAddItemsRequest(BaseModel): + config: PostgresSessionConfig + session_id: str + items: list[TResponseInputItem] + + +@activity.defn +async def postgres_session_add_items_activity( + request: PostgresSessionAddItemsRequest, +) -> None: + """Add items to the session.""" + + conn = PostgresSession._get_connection() + + async def add_items(conn: asyncpg.Connection): + # Ensure session exists + await conn.execute( + f"INSERT INTO {request.config.sessions_table} (session_id) VALUES ($1) ON CONFLICT (session_id) DO NOTHING", + request.session_id, + ) + for item in request.items: + # Use recursive conversion to handle nested objects + item_dict = _convert_to_json_serializable(item) + + await conn.execute( + f"INSERT INTO {request.config.messages_table} (session_id, message_data) VALUES ($1, $2)", + request.session_id, + json.dumps(item_dict), + ) + + idempotence_helper = IdempotenceHelper(table_name=request.config.idempotence_table) + await idempotence_helper.idempotent_update(conn, add_items) + + +class PostgresSessionPopItemRequest(BaseModel): + config: PostgresSessionConfig + session_id: str + + +class PostgresSessionPopItemResponse(BaseModel): + item: TResponseInputItem | None + + +@activity.defn +async def postgres_session_pop_item_activity( + request: PostgresSessionPopItemRequest, +) -> PostgresSessionPopItemResponse: + """Pop item from the session.""" + conn = PostgresSession._get_connection() + + async def pop_item(conn: asyncpg.Connection): + row = await conn.fetchrow( + f"WITH updated AS (UPDATE {request.config.messages_table} SET deleted_at = NOW() WHERE session_id = $1 AND operation_id = (SELECT operation_id FROM {request.config.messages_table} WHERE session_id = $1 AND deleted_at IS NULL ORDER BY operation_id DESC LIMIT 1) RETURNING message_data) SELECT message_data FROM updated", + request.session_id, + ) + if row: + return PostgresSessionPopItemResponse(item=json.loads(row["message_data"])) + else: + return PostgresSessionPopItemResponse(item=None) + + idempotence_helper = IdempotenceHelper(table_name=request.config.idempotence_table) + return await idempotence_helper.idempotent_update(conn, pop_item) + + +class PostgresSessionClearSessionRequest(BaseModel): + config: PostgresSessionConfig + session_id: str + + +@activity.defn +async def postgres_session_clear_session_activity( + request: PostgresSessionClearSessionRequest, +) -> None: + """Clear all items for this session.""" + conn = PostgresSession._get_connection() + + async def clear_session(conn: asyncpg.Connection): + await conn.execute( + f"UPDATE {request.config.messages_table} SET deleted_at = NOW() WHERE session_id = $1 AND deleted_at IS NULL", + request.session_id, + ) + + idempotence_helper = IdempotenceHelper(table_name=request.config.idempotence_table) + await idempotence_helper.idempotent_update(conn, clear_session) + + +class PostgresSession(SessionABC): + """PostgreSQL-based implementation of session storage using operation_id ordering.""" + + def __init__( + self, + session_id: str, + config: PostgresSessionConfig, + ): + self.session_id = session_id + self.config = config + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session.""" + result = await workflow.execute_activity( + postgres_session_get_items_activity, + PostgresSessionGetItemsRequest( + config=self.config, session_id=self.session_id, limit=limit + ), + 
start_to_close_timeout=workflow.timedelta(seconds=30), + ) + return result.items + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """Add new items to the conversation history.""" + await workflow.execute_activity( + postgres_session_add_items_activity, + PostgresSessionAddItemsRequest( + config=self.config, session_id=self.session_id, items=items + ), + start_to_close_timeout=workflow.timedelta(seconds=30), + ) + + async def pop_item(self) -> TResponseInputItem | None: + """Remove and return the most recent item from the session.""" + result = await workflow.execute_activity( + postgres_session_pop_item_activity, + PostgresSessionPopItemRequest( + config=self.config, session_id=self.session_id + ), + start_to_close_timeout=workflow.timedelta(seconds=30), + ) + return result.item + + async def clear_session(self) -> None: + """Clear all items for this session.""" + await workflow.execute_activity( + postgres_session_clear_session_activity, + PostgresSessionClearSessionRequest( + config=self.config, session_id=self.session_id + ), + start_to_close_timeout=workflow.timedelta(seconds=30), + ) + + @staticmethod + def set_connection_factory(factory: Callable[[], asyncpg.Connection]): + global _connection_factory + _connection_factory = factory + + @staticmethod + def _get_connection(): + if _connection_factory is None: + raise ValueError("Connection factory not set") + return _connection_factory() diff --git a/openai_agents/memory/run_postgres_session_worker.py b/openai_agents/memory/run_postgres_session_worker.py new file mode 100644 index 00000000..9a3f9e75 --- /dev/null +++ b/openai_agents/memory/run_postgres_session_worker.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import asyncio +import os +import asyncpg +from datetime import timedelta + +from temporalio.client import Client +from temporalio.contrib.openai_agents import ModelActivityParameters, OpenAIAgentsPlugin +from temporalio.worker import Worker + +from openai_agents.memory.workflows.postgres_session_workflow import ( + PostgresSessionWorkflow, +) +from openai_agents.memory.postgres_session import ( + PostgresSessionConfig, + init_schema, + PostgresSession, + postgres_session_get_items_activity, + postgres_session_add_items_activity, + postgres_session_pop_item_activity, + postgres_session_clear_session_activity, +) +from openai_agents.memory.db_utils import IdempotenceHelper + + +async def main(): + db_connection = await asyncpg.connect(os.getenv("DATABASE_URL")) + + # Database setup + postgres_session_config = PostgresSessionConfig( + messages_table="session_messages", + sessions_table="session", + operation_id_sequence="session_operation_id_sequence", + ) + PostgresSession.set_connection_factory(lambda: db_connection) + await init_schema(db_connection, config=postgres_session_config) + idempotence_helper = IdempotenceHelper(table_name="activity_idempotence") + await idempotence_helper.create_table(db_connection) + PostgresSession.set_connection_factory(lambda: db_connection) + + # Create client connected to server at the given address + client = await Client.connect( + "localhost:7233", + plugins=[ + OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ) + ), + ], + ) + + worker = Worker( + client, + task_queue="openai-agents-memory-task-queue", + workflows=[ + PostgresSessionWorkflow, + ], + activities=[ + postgres_session_get_items_activity, + postgres_session_add_items_activity, + postgres_session_pop_item_activity, + 
postgres_session_clear_session_activity,
+        ],
+    )
+    await worker.run()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/openai_agents/memory/run_postgres_session_workflow.py b/openai_agents/memory/run_postgres_session_workflow.py
new file mode 100644
index 00000000..facfa55a
--- /dev/null
+++ b/openai_agents/memory/run_postgres_session_workflow.py
@@ -0,0 +1,34 @@
+import asyncio
+import uuid
+
+from temporalio.client import Client
+from temporalio.contrib.openai_agents import OpenAIAgentsPlugin
+
+from openai_agents.memory.workflows.postgres_session_workflow import (
+    PostgresSessionWorkflow,
+)
+
+
+async def main():
+    # Create client connected to server at the given address
+    client = await Client.connect(
+        "localhost:7233",
+        plugins=[
+            OpenAIAgentsPlugin(),
+        ],
+    )
+
+    # Execute a workflow
+    result = await client.execute_workflow(
+        PostgresSessionWorkflow.run,
+        f"openai-session-workflow-{uuid.uuid4()}",
+        id=f"openai-session-workflow-{uuid.uuid4()}",
+        task_queue="openai-agents-memory-task-queue",
+    )
+
+    # Print the workflow output
+    print(result)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/openai_agents/memory/workflows/postgres_session_workflow.py b/openai_agents/memory/workflows/postgres_session_workflow.py
new file mode 100644
index 00000000..d5e239bd
--- /dev/null
+++ b/openai_agents/memory/workflows/postgres_session_workflow.py
@@ -0,0 +1,80 @@
+from io import StringIO
+from agents import Agent, Runner
+from temporalio import workflow
+from openai_agents.memory.postgres_session import PostgresSession, PostgresSessionConfig
+
+
+@workflow.defn
+class PostgresSessionWorkflow:
+    @workflow.run
+    async def run(self, session_id: str) -> str:
+        # Create string buffer to capture all output
+        output = StringIO()
+        # Create a PostgreSQL session instance that will persist across runs
+        postgres_config = PostgresSessionConfig(
+            messages_table="session_messages",
+            sessions_table="session",
+            operation_id_sequence="session_operation_id_sequence",
+        )
+        session = PostgresSession(session_id=session_id, config=postgres_config)
+
+        # Create an agent
+        agent = Agent(
+            name="Assistant",
+            instructions="Reply very concisely.",
+        )
+
+        output.write("=== Session Example ===\n")
+        output.write("The agent will remember previous messages automatically.\n\n")
+
+        # First turn
+        output.write("First turn:\n")
+        output.write("User: What city is the Golden Gate Bridge in?\n")
+        result = await Runner.run(
+            agent,
+            "What city is the Golden Gate Bridge in?",
+            session=session,
+        )
+        output.write(f"Assistant: {result.final_output}\n\n")
+
+        # Second turn - the agent will remember the previous conversation
+        output.write("Second turn:\n")
+        output.write("User: What state is it in?\n")
+        result = await Runner.run(agent, "What state is it in?", session=session)
+        output.write(f"Assistant: {result.final_output}\n\n")
+
+        # Third turn - continuing the conversation
+        output.write("Third turn:\n")
+        output.write("User: What's the population of that state?\n")
+        result = await Runner.run(
+            agent,
+            "What's the population of that state?",
+            session=session,
+        )
+        output.write(f"Assistant: {result.final_output}\n\n")
+
+        output.write("=== Conversation Complete ===\n")
+        output.write(
+            "Notice how the agent remembered the context from previous turns!\n"
+        )
+        output.write("Sessions automatically handle conversation history.\n")
+
+        # Demonstrate the limit parameter - get only the latest 2 items
+        output.write("\n=== Latest Items Demo ===\n")
+        latest_items = await 
session.get_items(limit=2)
+        output.write("Latest 2 items:\n")
+        for i, msg in enumerate(latest_items, 1):
+            role = msg.get("role", "unknown")
+            content = msg.get("content", "")
+            output.write(f"  {i}. {role}: {content}\n")
+
+        output.write(
+            f"\nFetched {len(latest_items)} items out of the full conversation history.\n"
+        )
+
+        # Get all items to show the difference
+        all_items = await session.get_items()
+        output.write(f"Total items in session: {len(all_items)}\n")
+
+        # Return the buffered output as a string
+        return output.getvalue()
diff --git a/pyproject.toml b/pyproject.toml
index fa5a0300..b6beb87c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,10 @@ authors = [{ name = "Temporal Technologies Inc", email = "sdk@temporal.io" }]
 requires-python = ">=3.10"
 readme = "README.md"
 license = "MIT"
-dependencies = ["temporalio>=1.15.0,<2"]
+dependencies = [
+    "asyncpg>=0.30.0",
+    "temporalio>=1.15.0,<2",
+]
 
 [project.urls]
 Homepage = "https://github.com/temporalio/samples-python"
@@ -55,8 +58,8 @@ open-telemetry = [
     "opentelemetry-exporter-otlp-proto-grpc",
 ]
 openai-agents = [
-    "openai-agents[litellm] >= 0.2.3",
-    "temporalio[openai-agents] >= 1.15.0",
+    "openai-agents[litellm] >= 0.3.0",
+    "temporalio[openai-agents] >= 1.16.0",
 ]
 pydantic-converter = ["pydantic>=2.10.6,<3"]
 sentry = ["sentry-sdk>=2.13.0"]
@@ -143,4 +146,4 @@ ignore_errors = true
 
 [[tool.mypy.overrides]]
 module = "opentelemetry.*"
-ignore_errors = true
\ No newline at end of file
+ignore_errors = true
diff --git a/tests/conftest.py b/tests/conftest.py
index e63a059b..92e8a932 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,11 @@
 import asyncio
 import multiprocessing
+import os
 import sys
+import uuid
 from typing import AsyncGenerator
 
+import asyncpg
 import pytest
 import pytest_asyncio
 from temporalio.client import Client
@@ -58,3 +61,38 @@ async def env(request) -> AsyncGenerator[WorkflowEnvironment, None]:
 @pytest_asyncio.fixture
 async def client(env: WorkflowEnvironment) -> Client:
     return env.client
+
+
+@pytest_asyncio.fixture
+async def db_connection():
+    """Create a PostgreSQL connection with a unique schema for each test.
+
+    Sets up a temporary schema with UUID naming, sets it as the default schema
+    for the connection, and cleans up with CASCADE on teardown.
+    """
+    # Generate unique schema name
+    schema_name = f"test_{uuid.uuid4().hex}"
+
+    # Create connection
+    # Note that we read the DATABASE_URL from the environment because asyncpg does not read this
+    # environment variable. It does read other postgres environment variables such as PGHOST,
+    # PGPORT, PGDATABASE, PGUSER, and PGPASSWORD, so you can still use those if you do not set
+    # DATABASE_URL.
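+    # For example (hypothetical local setup):
+    #   DATABASE_URL=postgresql://postgres:postgres@localhost:5432/postgres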
+ conn = await asyncpg.connect(os.getenv("DATABASE_URL")) + + try: + # Create the schema + await conn.execute(f"CREATE SCHEMA {schema_name}") + + # Set the schema as default for this connection + await conn.execute(f"SET search_path TO {schema_name}") + + yield conn + + finally: + # Clean up: drop schema with cascade + try: + await conn.execute(f"DROP SCHEMA IF EXISTS {schema_name} CASCADE") + except Exception: + pass # Best effort cleanup + await conn.close() diff --git a/tests/openai_agents/memory/test_idempotence_util.py b/tests/openai_agents/memory/test_idempotence_util.py new file mode 100644 index 00000000..6df505a0 --- /dev/null +++ b/tests/openai_agents/memory/test_idempotence_util.py @@ -0,0 +1,324 @@ +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker +from temporalio import activity + +import uuid +from datetime import timedelta +from fractions import Fraction + +import asyncpg +from temporalio.common import RetryPolicy +from openai_agents.memory.connection_state import ( + set_worker_connection, + get_worker_connection, + clear_worker_connection, +) +from openai_agents.memory.db_utils import IdempotenceHelper +from pydantic import BaseModel +from datetime import datetime +from temporalio.contrib.pydantic import pydantic_data_converter + + +@workflow.defn +class FailureFreeTestWorkflow: + @workflow.run + async def run(self): + res1 = await workflow.execute_activity( + read_only_operation, + start_to_close_timeout=timedelta(seconds=10), + retry_policy=RetryPolicy( + maximum_attempts=1, + ), + ) + assert res1 == 1 + + await workflow.execute_activity( + write_operation, + start_to_close_timeout=timedelta(seconds=10), + retry_policy=RetryPolicy( + maximum_attempts=1, + ), + ) + res2 = await workflow.execute_activity( + read_test_data, + start_to_close_timeout=timedelta(seconds=10), + retry_policy=RetryPolicy( + maximum_attempts=1, + ), + ) + assert len(res2) == 1 and res2[0] == 456 + + +@workflow.defn +class TestRetriedWriteWorkflow: + @workflow.run + async def run(self): + await workflow.execute_activity( + fail_mid_transaction_activity, + start_to_close_timeout=timedelta(seconds=10), + retry_policy=RetryPolicy( + maximum_attempts=2, + ), + ) + res = await workflow.execute_activity( + read_test_data, + start_to_close_timeout=timedelta(seconds=10), + retry_policy=RetryPolicy( + maximum_attempts=1, + ), + ) + assert len(res) == 2 + assert set(res) == {1, 2} + + res2 = await workflow.execute_activity( + update_and_fail_activity, + start_to_close_timeout=timedelta(seconds=10), + retry_policy=RetryPolicy( + maximum_attempts=2, + ), + ) + assert res2 == 6 + + res3 = await workflow.execute_activity( + read_test_data, + start_to_close_timeout=timedelta(seconds=10), + retry_policy=RetryPolicy( + maximum_attempts=1, + ), + ) + assert len(res3) == 3 + assert set(res3) == {1, 2, 3} + + +class MyActivityArgs(BaseModel): + insert_x: int + insert_y: Fraction + should_fail: bool + + +class MyPydanticModel(BaseModel): + x: int + y: Fraction + z: datetime + + +@workflow.defn +class TestPydanticModelWorkflow: + @workflow.run + async def run(self): + res1 = await workflow.execute_activity( + write_pydantic_model_activity, + MyActivityArgs(insert_x=1, should_fail=False, insert_y=Fraction(1, 3)), + start_to_close_timeout=timedelta(seconds=10), + retry_policy=RetryPolicy( + maximum_attempts=1, + ), + ) + assert isinstance(res1, MyPydanticModel) + assert res1.x == 1 + assert res1.y == Fraction(1, 3) + assert res1.z is not None + assert 
isinstance(res1.z, datetime) + + res2 = await workflow.execute_activity( + write_pydantic_model_activity, + MyActivityArgs(insert_x=2, should_fail=True, insert_y=Fraction(4, 5)), + start_to_close_timeout=timedelta(seconds=10), + retry_policy=RetryPolicy( + maximum_attempts=2, + ), + ) + assert isinstance(res2, MyPydanticModel) + assert res2.x == 2 + assert res2.y == Fraction(4, 5) + assert res2.z is not None + assert isinstance(res2.z, datetime) + + +@activity.defn +async def write_pydantic_model_activity(args: MyActivityArgs) -> MyPydanticModel: + conn = get_worker_connection() + + async def query(conn): + await conn.execute( + "INSERT INTO test (x, y, z) VALUES ($1, $2, NOW())", + args.insert_x, + str(args.insert_y), + ) + result = await conn.fetchrow( + "SELECT x, y, z FROM test WHERE x = $1 LIMIT 1", args.insert_x + ) + return MyPydanticModel(x=result[0], y=result[1], z=result[2]) + + idempotence_helper = IdempotenceHelper(table_name="activity_idempotence") + res = await idempotence_helper.idempotent_update(conn, query) + if args.should_fail and activity.info().attempt == 1: + raise Exception("Test exception") + return res + + +@activity.defn +async def read_only_operation(): + conn = get_worker_connection() + + # Read-only operation + async def query(conn): + return (await conn.fetchrow("SELECT 1"))[0] + + idempotence_helper = IdempotenceHelper(table_name="activity_idempotence") + record = await idempotence_helper.idempotent_update(conn, query) + return record + + +@activity.defn +async def write_operation(): + conn = get_worker_connection() + + async def query(conn): + await conn.execute("INSERT INTO test (x) VALUES (456)") + + idempotence_helper = IdempotenceHelper(table_name="activity_idempotence") + await idempotence_helper.idempotent_update(conn, query) + + +@activity.defn +async def read_test_data(): + conn = get_worker_connection() + # Not using the idempotence helper here because we are just validating + results = await conn.fetch("SELECT x FROM test") + return [r[0] for r in results] + + +@activity.defn +async def fail_mid_transaction_activity(): + conn = get_worker_connection() + + async def query(conn): + await conn.execute("INSERT INTO test (x) VALUES (1)") + if activity.info().attempt == 1: + raise Exception("Test exception") + await conn.execute("INSERT INTO test (x) VALUES (2)") + + idempotence_helper = IdempotenceHelper(table_name="activity_idempotence") + await idempotence_helper.idempotent_update(conn, query) + + +@activity.defn +async def update_and_fail_activity(): + # This activity updates the test table and fails after making the update and committing the transaction + # but before returning the result. This means it needs to read the result from the idempotence table. 
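+    # On the retried attempt the INSERT is skipped (the idempotence row already
+    # exists), and the stored operation_result is deserialized and returned.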
+
+    conn = get_worker_connection()
+
+    async def query(conn):
+        await conn.execute("INSERT INTO test (x) VALUES (3)")
+        result = await conn.fetchrow("SELECT SUM(x) FROM test")
+        return result[0]
+
+    idempotence_helper = IdempotenceHelper(table_name="activity_idempotence")
+    res = await idempotence_helper.idempotent_update(conn, query)
+    if activity.info().attempt == 1:
+        raise Exception("Test exception")
+    return res
+
+
+async def test_idempotence_util(client: Client, db_connection: asyncpg.Connection):
+    # Set the worker-level connection
+    set_worker_connection(db_connection)
+
+    # Set up a test table
+    await db_connection.execute("CREATE TABLE IF NOT EXISTS test (x INT)")
+
+    # Set up the idempotence table
+    idempotence_helper = IdempotenceHelper(table_name="activity_idempotence")
+    await idempotence_helper.create_table(db_connection)
+
+    try:
+        async with Worker(
+            client,
+            task_queue=f"test-idempotence-task-queue-{uuid.uuid4()}",
+            workflows=[FailureFreeTestWorkflow],
+            activities=[read_only_operation, write_operation, read_test_data],
+        ) as worker:
+            workflow_handle = await client.start_workflow(
+                FailureFreeTestWorkflow.run,
+                id=f"test-idempotence-workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+                execution_timeout=timedelta(seconds=10),
+            )
+            await workflow_handle.result()
+    finally:
+        # Clean up worker-level connection
+        clear_worker_connection()
+
+
+async def test_idempotence_util_retried_write_activity(
+    client: Client, db_connection: asyncpg.Connection
+):
+    # Set the worker-level connection
+    set_worker_connection(db_connection)
+
+    # Set up a test table
+    await db_connection.execute("CREATE TABLE IF NOT EXISTS test (x INT)")
+
+    # Set up the idempotence table
+    idempotence_helper = IdempotenceHelper(table_name="activity_idempotence")
+    await idempotence_helper.create_table(db_connection)
+
+    try:
+        async with Worker(
+            client,
+            task_queue=f"test-idempotence-task-queue-{uuid.uuid4()}",
+            workflows=[TestRetriedWriteWorkflow],
+            activities=[
+                fail_mid_transaction_activity,
+                read_test_data,
+                update_and_fail_activity,
+            ],
+        ) as worker:
+            workflow_handle = await client.start_workflow(
+                TestRetriedWriteWorkflow.run,
+                id=f"test-idempotence-workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+                execution_timeout=timedelta(seconds=10),
+            )
+            await workflow_handle.result()
+    finally:
+        # Clean up worker-level connection
+        clear_worker_connection()
+
+
+async def test_pydantic_model_result(client: Client, db_connection: asyncpg.Connection):
+    new_config = client.config()
+    new_config["data_converter"] = pydantic_data_converter
+    client = Client(**new_config)
+
+    # Set the worker-level connection
+    set_worker_connection(db_connection)
+
+    # Set up a test table
+    await db_connection.execute(
+        "CREATE TABLE IF NOT EXISTS test (x INT, y TEXT, z TIMESTAMP DEFAULT NOW())"
+    )
+
+    # Set up the idempotence table
+    idempotence_helper = IdempotenceHelper(table_name="activity_idempotence")
+    await idempotence_helper.create_table(db_connection)
+
+    try:
+        async with Worker(
+            client,
+            task_queue=f"test-idempotence-task-queue-{uuid.uuid4()}",
+            workflows=[TestPydanticModelWorkflow],
+            activities=[write_pydantic_model_activity],
+        ) as worker:
+            workflow_handle = await client.start_workflow(
+                TestPydanticModelWorkflow.run,
+                id=f"test-idempotence-workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+                execution_timeout=timedelta(seconds=10),
+            )
+            await workflow_handle.result()
+    finally:
+        # Clean up worker-level connection
+        clear_worker_connection()
diff --git 
a/tests/openai_agents/memory/test_postgres_session.py b/tests/openai_agents/memory/test_postgres_session.py new file mode 100644 index 00000000..14ac715d --- /dev/null +++ b/tests/openai_agents/memory/test_postgres_session.py @@ -0,0 +1,504 @@ +import asyncpg +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker + +import uuid +from datetime import timedelta + +from openai_agents.memory.db_utils import IdempotenceHelper +from openai_agents.memory.postgres_session import PostgresSessionConfig +from openai_agents.memory.postgres_session import PostgresSession +from openai_agents.memory.postgres_session import TResponseInputItem +from openai_agents.memory.postgres_session import ( + init_schema, + postgres_session_pop_item_activity, + postgres_session_add_items_activity, + postgres_session_clear_session_activity, + postgres_session_get_items_activity, +) +from pydantic import BaseModel +from temporalio.contrib.pydantic import pydantic_data_converter + + +class BasicSessionWorkflowConfig(BaseModel): + session_id: str + config: PostgresSessionConfig + + +@workflow.defn +class BasicSessionWorkflow: + @workflow.run + async def run(self, config: BasicSessionWorkflowConfig): + session = PostgresSession(session_id=config.session_id, config=config.config) + # Test adding and retrieving items + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + await session.add_items(items) + retrieved = await session.get_items() + + assert len(retrieved) == 2 + assert retrieved[0].get("role") == "user" + assert retrieved[0].get("content") == "Hello" + assert retrieved[1].get("role") == "assistant" + assert retrieved[1].get("content") == "Hi there!" 
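+        # Note: items round-trip through JSON in the activities, so they come
+        # back as plain dicts rather than typed objects.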
+ + # Test clearing session + await session.clear_session() + retrieved_after_clear = await session.get_items() + assert len(retrieved_after_clear) == 0 + + +async def test_session_workflow(client: Client, db_connection: asyncpg.Connection): + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + + postgres_session_config = PostgresSessionConfig( + messages_table="session_messages", + sessions_table="session", + operation_id_sequence="session_operation_id_sequence", + ) + PostgresSession.set_connection_factory(lambda: db_connection) + await init_schema(db_connection, config=postgres_session_config) + + idempotence_helper = IdempotenceHelper(table_name="activity_idempotence") + await idempotence_helper.create_table(db_connection) + + async with Worker( + client, + task_queue=f"basic-session-workflow-{uuid.uuid4()}", + workflows=[BasicSessionWorkflow], + activities=[ + postgres_session_pop_item_activity, + postgres_session_add_items_activity, + postgres_session_clear_session_activity, + postgres_session_get_items_activity, + ], + ) as worker: + workflow_handle = await client.start_workflow( + BasicSessionWorkflow.run, + BasicSessionWorkflowConfig( + session_id=f"test-session-{uuid.uuid4()}", + config=postgres_session_config, + ), + id=f"basic-session-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + await workflow_handle.result() + + +# Pop item workflow and tests +@workflow.defn +class PopItemWorkflow: + @workflow.run + async def run(self, config: BasicSessionWorkflowConfig): + session = PostgresSession(session_id=config.session_id, config=config.config) + + # Test popping from empty session + popped = await session.pop_item() + assert popped is None + + # Add items + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + await session.add_items(items) + + # Verify all items are there + retrieved = await session.get_items() + assert len(retrieved) == 3 + + # Pop the most recent item + popped = await session.pop_item() + assert popped is not None + assert popped.get("role") == "user" + assert popped.get("content") == "How are you?" + + # Verify item was removed + retrieved_after_pop = await session.get_items() + assert len(retrieved_after_pop) == 2 + assert retrieved_after_pop[-1].get("content") == "Hi there!" + + # Pop another item + popped2 = await session.pop_item() + assert popped2 is not None + assert popped2.get("role") == "assistant" + assert popped2.get("content") == "Hi there!" 
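+        # pop_item removes the most recent item first (highest operation_id),
+        # so items come back in LIFO order.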
+ + # Pop the last item + popped3 = await session.pop_item() + assert popped3 is not None + assert popped3.get("role") == "user" + assert popped3.get("content") == "Hello" + + # Try to pop from empty session again + popped4 = await session.pop_item() + assert popped4 is None + + # Verify session is empty + final_items = await session.get_items() + assert len(final_items) == 0 + + +async def test_postgres_session_pop_item( + client: Client, db_connection: asyncpg.Connection +): + """Test PostgresSession pop_item functionality.""" + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + + postgres_session_config = PostgresSessionConfig( + messages_table="session_messages", + sessions_table="session", + operation_id_sequence="session_operation_id_sequence", + ) + PostgresSession.set_connection_factory(lambda: db_connection) + await init_schema(db_connection, config=postgres_session_config) + + idempotence_helper = IdempotenceHelper(table_name="activity_idempotence") + await idempotence_helper.create_table(db_connection) + + async with Worker( + client, + task_queue=f"pop-item-workflow-{uuid.uuid4()}", + workflows=[PopItemWorkflow], + activities=[ + postgres_session_pop_item_activity, + postgres_session_add_items_activity, + postgres_session_clear_session_activity, + postgres_session_get_items_activity, + ], + ) as worker: + workflow_handle = await client.start_workflow( + PopItemWorkflow.run, + BasicSessionWorkflowConfig( + session_id=f"pop-test-{uuid.uuid4()}", config=postgres_session_config + ), + id=f"pop-item-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + await workflow_handle.result() + + +# Test different sessions workflow +@workflow.defn +class DifferentSessionsWorkflow: + @workflow.run + async def run(self, config: BasicSessionWorkflowConfig): + # Create two sessions with different IDs + session_1_id = f"session_1_{config.session_id}" + session_2_id = f"session_2_{config.session_id}" + + session_1 = PostgresSession(session_id=session_1_id, config=config.config) + session_2 = PostgresSession(session_id=session_2_id, config=config.config) + + # Add items to both sessions + items_1: list[TResponseInputItem] = [ + {"role": "user", "content": "Session 1 message"}, + ] + items_2: list[TResponseInputItem] = [ + {"role": "user", "content": "Session 2 message 1"}, + {"role": "user", "content": "Session 2 message 2"}, + ] + + await session_1.add_items(items_1) + await session_2.add_items(items_2) + + # Pop from session 2 + popped = await session_2.pop_item() + assert popped is not None + assert popped.get("content") == "Session 2 message 2" + + # Verify session 1 is unaffected + session_1_items = await session_1.get_items() + assert len(session_1_items) == 1 + assert session_1_items[0].get("content") == "Session 1 message" + + # Verify session 2 has one item left + session_2_items = await session_2.get_items() + assert len(session_2_items) == 1 + assert session_2_items[0].get("content") == "Session 2 message 1" + + +async def test_postgres_session_pop_different_sessions( + client: Client, db_connection: asyncpg.Connection +): + """Test that pop_item only affects the specified session.""" + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + + postgres_session_config = PostgresSessionConfig( + messages_table="session_messages", + sessions_table="session", + 
operation_id_sequence="session_operation_id_sequence", + ) + PostgresSession.set_connection_factory(lambda: db_connection) + await init_schema(db_connection, config=postgres_session_config) + + idempotence_helper = IdempotenceHelper(table_name="activity_idempotence") + await idempotence_helper.create_table(db_connection) + + async with Worker( + client, + task_queue=f"different-sessions-workflow-{uuid.uuid4()}", + workflows=[DifferentSessionsWorkflow], + activities=[ + postgres_session_pop_item_activity, + postgres_session_add_items_activity, + postgres_session_clear_session_activity, + postgres_session_get_items_activity, + ], + ) as worker: + workflow_handle = await client.start_workflow( + DifferentSessionsWorkflow.run, + BasicSessionWorkflowConfig( + session_id=f"diff-sessions-test-{uuid.uuid4()}", + config=postgres_session_config, + ), + id=f"different-sessions-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + await workflow_handle.result() + + +# Test get_items with limit workflow +@workflow.defn +class GetItemsWithLimitWorkflow: + @workflow.run + async def run(self, config: BasicSessionWorkflowConfig): + session = PostgresSession(session_id=config.session_id, config=config.config) + + # Add multiple items + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Message 1"}, + {"role": "assistant", "content": "Response 1"}, + {"role": "user", "content": "Message 2"}, + {"role": "assistant", "content": "Response 2"}, + {"role": "user", "content": "Message 3"}, + {"role": "assistant", "content": "Response 3"}, + ] + await session.add_items(items) + + # Test getting all items (default behavior) + all_items = await session.get_items() + assert len(all_items) == 6 + assert all_items[0].get("content") == "Message 1" + assert all_items[-1].get("content") == "Response 3" + + # Test getting latest 2 items + latest_2 = await session.get_items(limit=2) + assert len(latest_2) == 2 + assert latest_2[0].get("content") == "Message 3" + assert latest_2[1].get("content") == "Response 3" + + # Test getting latest 4 items + latest_4 = await session.get_items(limit=4) + assert len(latest_4) == 4 + assert latest_4[0].get("content") == "Message 2" + assert latest_4[1].get("content") == "Response 2" + assert latest_4[2].get("content") == "Message 3" + assert latest_4[3].get("content") == "Response 3" + + # Test getting more items than available + latest_10 = await session.get_items(limit=10) + assert len(latest_10) == 6 # Should return all available items + assert latest_10[0].get("content") == "Message 1" + assert latest_10[-1].get("content") == "Response 3" + + # Test getting 0 items + latest_0 = await session.get_items(limit=0) + assert len(latest_0) == 0 + + +async def test_postgres_session_get_items_with_limit( + client: Client, db_connection: asyncpg.Connection +): + """Test PostgresSession get_items with limit parameter.""" + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + + postgres_session_config = PostgresSessionConfig( + messages_table="session_messages", + sessions_table="session", + operation_id_sequence="session_operation_id_sequence", + ) + PostgresSession.set_connection_factory(lambda: db_connection) + await init_schema(db_connection, config=postgres_session_config) + + idempotence_helper = IdempotenceHelper(table_name="activity_idempotence") + await idempotence_helper.create_table(db_connection) + + async with Worker( + client, + 
task_queue=f"get-items-limit-workflow-{uuid.uuid4()}", + workflows=[GetItemsWithLimitWorkflow], + activities=[ + postgres_session_pop_item_activity, + postgres_session_add_items_activity, + postgres_session_clear_session_activity, + postgres_session_get_items_activity, + ], + ) as worker: + workflow_handle = await client.start_workflow( + GetItemsWithLimitWorkflow.run, + BasicSessionWorkflowConfig( + session_id=f"limit-test-{uuid.uuid4()}", config=postgres_session_config + ), + id=f"get-items-limit-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + await workflow_handle.result() + + +# Test unicode content workflow +@workflow.defn +class UnicodeContentWorkflow: + @workflow.run + async def run(self, config: BasicSessionWorkflowConfig): + session = PostgresSession(session_id=config.session_id, config=config.config) + + # Add unicode content to the session + items: list[TResponseInputItem] = [ + {"role": "user", "content": "こんにちは"}, + {"role": "assistant", "content": "😊👍"}, + {"role": "user", "content": "Привет"}, + ] + await session.add_items(items) + + # Retrieve items and verify unicode content + retrieved = await session.get_items() + assert retrieved[0].get("content") == "こんにちは" + assert retrieved[1].get("content") == "😊👍" + assert retrieved[2].get("content") == "Привет" + + +async def test_postgres_session_unicode_content( + client: Client, db_connection: asyncpg.Connection +): + """Test that session correctly stores and retrieves unicode/non-ASCII content.""" + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + + postgres_session_config = PostgresSessionConfig( + messages_table="session_messages", + sessions_table="session", + operation_id_sequence="session_operation_id_sequence", + ) + PostgresSession.set_connection_factory(lambda: db_connection) + await init_schema(db_connection, config=postgres_session_config) + + idempotence_helper = IdempotenceHelper(table_name="activity_idempotence") + await idempotence_helper.create_table(db_connection) + + async with Worker( + client, + task_queue=f"unicode-content-workflow-{uuid.uuid4()}", + workflows=[UnicodeContentWorkflow], + activities=[ + postgres_session_pop_item_activity, + postgres_session_add_items_activity, + postgres_session_clear_session_activity, + postgres_session_get_items_activity, + ], + ) as worker: + workflow_handle = await client.start_workflow( + UnicodeContentWorkflow.run, + BasicSessionWorkflowConfig( + session_id=f"unicode-test-{uuid.uuid4()}", + config=postgres_session_config, + ), + id=f"unicode-content-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + await workflow_handle.result() + + +# Test special characters and SQL injection workflow +@workflow.defn +class SpecialCharactersWorkflow: + @workflow.run + async def run(self, config: BasicSessionWorkflowConfig): + session = PostgresSession(session_id=config.session_id, config=config.config) + + # Add items with special characters and SQL keywords + items: list[TResponseInputItem] = [ + {"role": "user", "content": "O'Reilly"}, + {"role": "assistant", "content": "DROP TABLE sessions;"}, + { + "role": "user", + "content": '"SELECT * FROM users WHERE name = \\"admin\\";";', + }, + {"role": "assistant", "content": "Robert'); DROP TABLE students;--"}, + {"role": "user", "content": "Normal message"}, + ] + await session.add_items(items) + + # Retrieve all items and verify they are stored 
correctly + retrieved = await session.get_items() + assert len(retrieved) == len(items) + assert retrieved[0].get("content") == "O'Reilly" + assert retrieved[1].get("content") == "DROP TABLE sessions;" + assert ( + retrieved[2].get("content") + == '"SELECT * FROM users WHERE name = \\"admin\\";";' + ) + assert retrieved[3].get("content") == "Robert'); DROP TABLE students;--" + assert retrieved[4].get("content") == "Normal message" + + +async def test_postgres_session_special_characters_and_sql_injection( + client: Client, db_connection: asyncpg.Connection +): + """Test that session safely stores and retrieves items with special characters and SQL keywords.""" + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + + postgres_session_config = PostgresSessionConfig( + messages_table="session_messages", + sessions_table="session", + operation_id_sequence="session_operation_id_sequence", + ) + PostgresSession.set_connection_factory(lambda: db_connection) + await init_schema(db_connection, config=postgres_session_config) + + idempotence_helper = IdempotenceHelper(table_name="activity_idempotence") + await idempotence_helper.create_table(db_connection) + + async with Worker( + client, + task_queue=f"special-chars-workflow-{uuid.uuid4()}", + workflows=[SpecialCharactersWorkflow], + activities=[ + postgres_session_pop_item_activity, + postgres_session_add_items_activity, + postgres_session_clear_session_activity, + postgres_session_get_items_activity, + ], + ) as worker: + workflow_handle = await client.start_workflow( + SpecialCharactersWorkflow.run, + BasicSessionWorkflowConfig( + session_id=f"special-chars-test-{uuid.uuid4()}", + config=postgres_session_config, + ), + id=f"special-chars-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + await workflow_handle.result() diff --git a/uv.lock b/uv.lock index 9b1309d1..92612855 100644 --- a/uv.lock +++ b/uv.lock @@ -150,6 +150,49 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/fa/e01228c2938de91d47b307831c62ab9e4001e747789d0b05baf779a6488c/async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028", size = 5721, upload-time = "2023-08-10T16:35:55.203Z" }, ] +[[package]] +name = "asyncpg" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-timeout", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2f/4c/7c991e080e106d854809030d8584e15b2e996e26f16aee6d757e387bc17d/asyncpg-0.30.0.tar.gz", hash = "sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851", size = 957746, upload-time = "2024-10-20T00:30:41.127Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/07/1650a8c30e3a5c625478fa8aafd89a8dd7d85999bf7169b16f54973ebf2c/asyncpg-0.30.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bfb4dd5ae0699bad2b233672c8fc5ccbd9ad24b89afded02341786887e37927e", size = 673143, upload-time = "2024-10-20T00:29:08.846Z" }, + { url = "https://files.pythonhosted.org/packages/a0/9a/568ff9b590d0954553c56806766914c149609b828c426c5118d4869111d3/asyncpg-0.30.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dc1f62c792752a49f88b7e6f774c26077091b44caceb1983509edc18a2222ec0", size = 645035, upload-time = "2024-10-20T00:29:12.02Z" }, + { url = 
"https://files.pythonhosted.org/packages/de/11/6f2fa6c902f341ca10403743701ea952bca896fc5b07cc1f4705d2bb0593/asyncpg-0.30.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3152fef2e265c9c24eec4ee3d22b4f4d2703d30614b0b6753e9ed4115c8a146f", size = 2912384, upload-time = "2024-10-20T00:29:13.644Z" }, + { url = "https://files.pythonhosted.org/packages/83/83/44bd393919c504ffe4a82d0aed8ea0e55eb1571a1dea6a4922b723f0a03b/asyncpg-0.30.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7255812ac85099a0e1ffb81b10dc477b9973345793776b128a23e60148dd1af", size = 2947526, upload-time = "2024-10-20T00:29:15.871Z" }, + { url = "https://files.pythonhosted.org/packages/08/85/e23dd3a2b55536eb0ded80c457b0693352262dc70426ef4d4a6fc994fa51/asyncpg-0.30.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:578445f09f45d1ad7abddbff2a3c7f7c291738fdae0abffbeb737d3fc3ab8b75", size = 2895390, upload-time = "2024-10-20T00:29:19.346Z" }, + { url = "https://files.pythonhosted.org/packages/9b/26/fa96c8f4877d47dc6c1864fef5500b446522365da3d3d0ee89a5cce71a3f/asyncpg-0.30.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c42f6bb65a277ce4d93f3fba46b91a265631c8df7250592dd4f11f8b0152150f", size = 3015630, upload-time = "2024-10-20T00:29:21.186Z" }, + { url = "https://files.pythonhosted.org/packages/34/00/814514eb9287614188a5179a8b6e588a3611ca47d41937af0f3a844b1b4b/asyncpg-0.30.0-cp310-cp310-win32.whl", hash = "sha256:aa403147d3e07a267ada2ae34dfc9324e67ccc4cdca35261c8c22792ba2b10cf", size = 568760, upload-time = "2024-10-20T00:29:22.769Z" }, + { url = "https://files.pythonhosted.org/packages/f0/28/869a7a279400f8b06dd237266fdd7220bc5f7c975348fea5d1e6909588e9/asyncpg-0.30.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb622c94db4e13137c4c7f98834185049cc50ee01d8f657ef898b6407c7b9c50", size = 625764, upload-time = "2024-10-20T00:29:25.882Z" }, + { url = "https://files.pythonhosted.org/packages/4c/0e/f5d708add0d0b97446c402db7e8dd4c4183c13edaabe8a8500b411e7b495/asyncpg-0.30.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5e0511ad3dec5f6b4f7a9e063591d407eee66b88c14e2ea636f187da1dcfff6a", size = 674506, upload-time = "2024-10-20T00:29:27.988Z" }, + { url = "https://files.pythonhosted.org/packages/6a/a0/67ec9a75cb24a1d99f97b8437c8d56da40e6f6bd23b04e2f4ea5d5ad82ac/asyncpg-0.30.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:915aeb9f79316b43c3207363af12d0e6fd10776641a7de8a01212afd95bdf0ed", size = 645922, upload-time = "2024-10-20T00:29:29.391Z" }, + { url = "https://files.pythonhosted.org/packages/5c/d9/a7584f24174bd86ff1053b14bb841f9e714380c672f61c906eb01d8ec433/asyncpg-0.30.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c198a00cce9506fcd0bf219a799f38ac7a237745e1d27f0e1f66d3707c84a5a", size = 3079565, upload-time = "2024-10-20T00:29:30.832Z" }, + { url = "https://files.pythonhosted.org/packages/a0/d7/a4c0f9660e333114bdb04d1a9ac70db690dd4ae003f34f691139a5cbdae3/asyncpg-0.30.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3326e6d7381799e9735ca2ec9fd7be4d5fef5dcbc3cb555d8a463d8460607956", size = 3109962, upload-time = "2024-10-20T00:29:33.114Z" }, + { url = "https://files.pythonhosted.org/packages/3c/21/199fd16b5a981b1575923cbb5d9cf916fdc936b377e0423099f209e7e73d/asyncpg-0.30.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51da377487e249e35bd0859661f6ee2b81db11ad1f4fc036194bc9cb2ead5056", size = 3064791, upload-time = "2024-10-20T00:29:34.677Z" }, + { url = 
"https://files.pythonhosted.org/packages/77/52/0004809b3427534a0c9139c08c87b515f1c77a8376a50ae29f001e53962f/asyncpg-0.30.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bc6d84136f9c4d24d358f3b02be4b6ba358abd09f80737d1ac7c444f36108454", size = 3188696, upload-time = "2024-10-20T00:29:36.389Z" }, + { url = "https://files.pythonhosted.org/packages/52/cb/fbad941cd466117be58b774a3f1cc9ecc659af625f028b163b1e646a55fe/asyncpg-0.30.0-cp311-cp311-win32.whl", hash = "sha256:574156480df14f64c2d76450a3f3aaaf26105869cad3865041156b38459e935d", size = 567358, upload-time = "2024-10-20T00:29:37.915Z" }, + { url = "https://files.pythonhosted.org/packages/3c/0a/0a32307cf166d50e1ad120d9b81a33a948a1a5463ebfa5a96cc5606c0863/asyncpg-0.30.0-cp311-cp311-win_amd64.whl", hash = "sha256:3356637f0bd830407b5597317b3cb3571387ae52ddc3bca6233682be88bbbc1f", size = 629375, upload-time = "2024-10-20T00:29:39.987Z" }, + { url = "https://files.pythonhosted.org/packages/4b/64/9d3e887bb7b01535fdbc45fbd5f0a8447539833b97ee69ecdbb7a79d0cb4/asyncpg-0.30.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c902a60b52e506d38d7e80e0dd5399f657220f24635fee368117b8b5fce1142e", size = 673162, upload-time = "2024-10-20T00:29:41.88Z" }, + { url = "https://files.pythonhosted.org/packages/6e/eb/8b236663f06984f212a087b3e849731f917ab80f84450e943900e8ca4052/asyncpg-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aca1548e43bbb9f0f627a04666fedaca23db0a31a84136ad1f868cb15deb6e3a", size = 637025, upload-time = "2024-10-20T00:29:43.352Z" }, + { url = "https://files.pythonhosted.org/packages/cc/57/2dc240bb263d58786cfaa60920779af6e8d32da63ab9ffc09f8312bd7a14/asyncpg-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2a2ef565400234a633da0eafdce27e843836256d40705d83ab7ec42074efb3", size = 3496243, upload-time = "2024-10-20T00:29:44.922Z" }, + { url = "https://files.pythonhosted.org/packages/f4/40/0ae9d061d278b10713ea9021ef6b703ec44698fe32178715a501ac696c6b/asyncpg-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1292b84ee06ac8a2ad8e51c7475aa309245874b61333d97411aab835c4a2f737", size = 3575059, upload-time = "2024-10-20T00:29:46.891Z" }, + { url = "https://files.pythonhosted.org/packages/c3/75/d6b895a35a2c6506952247640178e5f768eeb28b2e20299b6a6f1d743ba0/asyncpg-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f5712350388d0cd0615caec629ad53c81e506b1abaaf8d14c93f54b35e3595a", size = 3473596, upload-time = "2024-10-20T00:29:49.201Z" }, + { url = "https://files.pythonhosted.org/packages/c8/e7/3693392d3e168ab0aebb2d361431375bd22ffc7b4a586a0fc060d519fae7/asyncpg-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:db9891e2d76e6f425746c5d2da01921e9a16b5a71a1c905b13f30e12a257c4af", size = 3641632, upload-time = "2024-10-20T00:29:50.768Z" }, + { url = "https://files.pythonhosted.org/packages/32/ea/15670cea95745bba3f0352341db55f506a820b21c619ee66b7d12ea7867d/asyncpg-0.30.0-cp312-cp312-win32.whl", hash = "sha256:68d71a1be3d83d0570049cd1654a9bdfe506e794ecc98ad0873304a9f35e411e", size = 560186, upload-time = "2024-10-20T00:29:52.394Z" }, + { url = "https://files.pythonhosted.org/packages/7e/6b/fe1fad5cee79ca5f5c27aed7bd95baee529c1bf8a387435c8ba4fe53d5c1/asyncpg-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a0292c6af5c500523949155ec17b7fe01a00ace33b68a476d6b5059f9630305", size = 621064, upload-time = "2024-10-20T00:29:53.757Z" }, + { url = 
"https://files.pythonhosted.org/packages/3a/22/e20602e1218dc07692acf70d5b902be820168d6282e69ef0d3cb920dc36f/asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70", size = 670373, upload-time = "2024-10-20T00:29:55.165Z" }, + { url = "https://files.pythonhosted.org/packages/3d/b3/0cf269a9d647852a95c06eb00b815d0b95a4eb4b55aa2d6ba680971733b9/asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3", size = 634745, upload-time = "2024-10-20T00:29:57.14Z" }, + { url = "https://files.pythonhosted.org/packages/8e/6d/a4f31bf358ce8491d2a31bfe0d7bcf25269e80481e49de4d8616c4295a34/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33", size = 3512103, upload-time = "2024-10-20T00:29:58.499Z" }, + { url = "https://files.pythonhosted.org/packages/96/19/139227a6e67f407b9c386cb594d9628c6c78c9024f26df87c912fabd4368/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4", size = 3592471, upload-time = "2024-10-20T00:30:00.354Z" }, + { url = "https://files.pythonhosted.org/packages/67/e4/ab3ca38f628f53f0fd28d3ff20edff1c975dd1cb22482e0061916b4b9a74/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4", size = 3496253, upload-time = "2024-10-20T00:30:02.794Z" }, + { url = "https://files.pythonhosted.org/packages/ef/5f/0bf65511d4eeac3a1f41c54034a492515a707c6edbc642174ae79034d3ba/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba", size = 3662720, upload-time = "2024-10-20T00:30:04.501Z" }, + { url = "https://files.pythonhosted.org/packages/e7/31/1513d5a6412b98052c3ed9158d783b1e09d0910f51fbe0e05f56cc370bc4/asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590", size = 560404, upload-time = "2024-10-20T00:30:06.537Z" }, + { url = "https://files.pythonhosted.org/packages/c8/a4/cec76b3389c4c5ff66301cd100fe88c318563ec8a520e0b2e792b5b84972/asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e", size = 621623, upload-time = "2024-10-20T00:30:09.024Z" }, +] + [[package]] name = "attrs" version = "25.3.0" @@ -1458,7 +1501,7 @@ wheels = [ [[package]] name = "openai" -version = "1.97.1" +version = "1.107.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1470,14 +1513,14 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a6/57/1c471f6b3efb879d26686d31582997615e969f3bb4458111c9705e56332e/openai-1.97.1.tar.gz", hash = "sha256:a744b27ae624e3d4135225da9b1c89c107a2a7e5bc4c93e5b7b5214772ce7a4e", size = 494267, upload-time = "2025-07-22T13:10:12.607Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a0/66/61b0c63b68df8a22f8763d7d632ea7255edb4021dca1859f4359a5659b85/openai-1.107.2.tar.gz", hash = "sha256:a11fe8d4318e98e94309308dd3a25108dec4dfc1b606f9b1c5706e8d88bdd3cb", size = 564155, upload-time = "2025-09-12T19:52:21.159Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/ee/35/412a0e9c3f0d37c94ed764b8ac7adae2d834dbd20e69f6aca582118e0f55/openai-1.97.1-py3-none-any.whl", hash = "sha256:4e96bbdf672ec3d44968c9ea39d2c375891db1acc1794668d8149d5fa6000606", size = 764380, upload-time = "2025-07-22T13:10:10.689Z" }, + { url = "https://files.pythonhosted.org/packages/d3/65/e51a77a368eed7b9cc22ce394087ab43f13fa2884724729b716adf2da389/openai-1.107.2-py3-none-any.whl", hash = "sha256:d159d4f3ee3d9c717b248c5d69fe93d7773a80563c8b1ca8e9cad789d3cf0260", size = 946937, upload-time = "2025-09-12T19:52:19.355Z" }, ] [[package]] name = "openai-agents" -version = "0.2.3" +version = "0.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "griffe" }, @@ -1488,9 +1531,9 @@ dependencies = [ { name = "types-requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e3/17/1f9eefb99fde956e5912a00fbdd03d50ebc734cc45a80b8fe4007d3813c2/openai_agents-0.2.3.tar.gz", hash = "sha256:95d4ad194c5c0cf1a40038cb701eee8ecdaaf7698d87bb13e3c2c5cff80c4b4d", size = 1464947, upload-time = "2025-07-21T19:34:20.595Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/49/78c902865ceb0432dae52f4a3feee8bcb57ce5be724c52e0119754fd421b/openai_agents-0.3.0.tar.gz", hash = "sha256:4d5d1a4f43cdc35b55c41ae4f31157cf5ff2e2c89563cde58616f8a77fde932d", size = 1700443, upload-time = "2025-09-11T19:20:09.742Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/a7/d6bdf69a54c15d237a2be979981f33dab8f5da53f9bc2e734fb2b58592ca/openai_agents-0.2.3-py3-none-any.whl", hash = "sha256:15c5602de7076a5df6d11f07a18ffe0cf4f6811f6135b301acdd1998398a6d5c", size = 161393, upload-time = "2025-07-21T19:34:18.883Z" }, + { url = "https://files.pythonhosted.org/packages/46/3b/58ee42582716645aa1a5c30c1e337dc9a7433f7f0b5ed84ef02a35368abb/openai_agents-0.3.0-py3-none-any.whl", hash = "sha256:16de8a28729ae9e27faad7ce146a4b74acf05c9eeca3fe23299f6e621a3893ed", size = 185007, upload-time = "2025-09-11T19:20:08.304Z" }, ] [package.optional-dependencies] @@ -2596,8 +2639,8 @@ wheels = [ [[package]] name = "temporalio" -version = "1.15.0" -source = { registry = "https://pypi.org/simple" } +version = "1.17.0" +source = { registry = "../../b/wheels" } dependencies = [ { name = "nexus-rpc" }, { name = "protobuf" }, @@ -2605,13 +2648,9 @@ dependencies = [ { name = "types-protobuf" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0b/af/1a3619fc62333d0acbdf90cfc5ada97e68e8c0f79610363b2dbb30871d83/temporalio-1.15.0.tar.gz", hash = "sha256:a4bc6ca01717880112caab75d041713aacc8263dc66e41f5019caef68b344fa0", size = 1684485, upload-time = "2025-07-29T03:44:09.071Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0e/2d/0153f2bc459e0cb59d41d4dd71da46bf9a98ca98bc37237576c258d6696b/temporalio-1.15.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:74bc5cc0e6bdc161a43015538b0821b8713f5faa716c4209971c274b528e0d47", size = 12703607, upload-time = "2025-07-29T03:43:30.083Z" }, - { url = "https://files.pythonhosted.org/packages/e4/39/1b867ec698c8987aef3b7a7024b5c0c732841112fa88d021303d0fc69bea/temporalio-1.15.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:ee8001304dae5723d79797516cfeebe04b966fdbdf348e658fce3b43afdda3cd", size = 12232853, upload-time = "2025-07-29T03:43:38.909Z" }, - { url = 
"https://files.pythonhosted.org/packages/5e/3e/647d9a7c8b2f638f639717404c0bcbdd7d54fddd7844fdb802e3f40dc55f/temporalio-1.15.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8febd1ac36720817e69c2176aa4aca14a97fe0b83f0d2449c0c730b8f0174d02", size = 12636700, upload-time = "2025-07-29T03:43:49.066Z" }, - { url = "https://files.pythonhosted.org/packages/9a/13/7aa9ec694fec9fba39efdbf61d892bccf7d2b1aa3d9bd359544534c1d309/temporalio-1.15.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:202d81a42cafaed9ccc7ccbea0898838e3b8bf92fee65394f8790f37eafbaa63", size = 12860186, upload-time = "2025-07-29T03:43:57.644Z" }, - { url = "https://files.pythonhosted.org/packages/9f/2b/ba962401324892236148046dbffd805d4443d6df7a7dc33cc7964b566bf9/temporalio-1.15.0-cp39-abi3-win_amd64.whl", hash = "sha256:aae5b18d7c9960238af0f3ebf6b7e5959e05f452106fc0d21a8278d78724f780", size = 12932800, upload-time = "2025-07-29T03:44:06.271Z" }, + { path = "temporalio-1.17.0-cp39-abi3-macosx_11_0_arm64.whl" }, + { path = "temporalio-1.17.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, ] [package.optional-dependencies] @@ -2628,6 +2667,7 @@ name = "temporalio-samples" version = "0.1a1" source = { editable = "." } dependencies = [ + { name = "asyncpg" }, { name = "temporalio" }, ] @@ -2697,7 +2737,10 @@ trio-async = [ ] [package.metadata] -requires-dist = [{ name = "temporalio", specifier = ">=1.15.0,<2" }] +requires-dist = [ + { name = "asyncpg", specifier = ">=0.30.0" }, + { name = "temporalio", specifier = ">=1.15.0,<2" }, +] [package.metadata.requires-dev] bedrock = [{ name = "boto3", specifier = ">=1.34.92,<2" }] @@ -2744,8 +2787,8 @@ open-telemetry = [ { name = "temporalio", extras = ["opentelemetry"] }, ] openai-agents = [ - { name = "openai-agents", extras = ["litellm"], specifier = ">=0.2.3" }, - { name = "temporalio", extras = ["openai-agents"], specifier = ">=1.15.0" }, + { name = "openai-agents", extras = ["litellm"], specifier = ">=0.3.0" }, + { name = "temporalio", extras = ["openai-agents"], specifier = ">=1.16.0" }, ] pydantic-converter = [{ name = "pydantic", specifier = ">=2.10.6,<3" }] sentry = [{ name = "sentry-sdk", specifier = ">=2.13.0" }] From 1196551748ac268b44997d11d1097a8f1b185dc7 Mon Sep 17 00:00:00 2001 From: Johann Schleier-Smith Date: Sun, 14 Sep 2025 08:11:28 -0700 Subject: [PATCH 2/3] move test connection management to tests --- openai_agents/memory/connection_state.py | 43 ------------------- tests/conftest.py | 38 ---------------- tests/openai_agents/memory/conftest.py | 40 +++++++++++++++++ .../memory/test_idempotence_util.py | 39 ++++++++++++++--- 4 files changed, 74 insertions(+), 86 deletions(-) delete mode 100644 openai_agents/memory/connection_state.py create mode 100644 tests/openai_agents/memory/conftest.py diff --git a/openai_agents/memory/connection_state.py b/openai_agents/memory/connection_state.py deleted file mode 100644 index f17fc3f4..00000000 --- a/openai_agents/memory/connection_state.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Worker-level database connection state management. - -WARNING: This implementation uses global state and is not safe for concurrent -testing (e.g., pytest-xdist). Run tests sequentially to avoid race conditions. 
-""" - -import asyncpg -from typing import Optional - - -# Module-level connection state -_connection: Optional[asyncpg.Connection] = None - - -def set_worker_connection(connection: asyncpg.Connection) -> None: - """Set the worker-level database connection.""" - global _connection - _connection = connection - - -def get_worker_connection() -> asyncpg.Connection: - """Get the worker-level database connection. - - Raises: - RuntimeError: If no connection has been set. - """ - if _connection is None: - raise RuntimeError( - "No worker-level database connection has been set. " - "Call set_worker_connection() before using activities." - ) - return _connection - - -def clear_worker_connection() -> None: - """Clear the worker-level database connection.""" - global _connection - _connection = None - - -def has_worker_connection() -> bool: - """Check if a worker-level connection is available.""" - return _connection is not None diff --git a/tests/conftest.py b/tests/conftest.py index 92e8a932..e63a059b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,8 @@ import asyncio import multiprocessing -import os import sys -import uuid from typing import AsyncGenerator -import asyncpg import pytest import pytest_asyncio from temporalio.client import Client @@ -61,38 +58,3 @@ async def env(request) -> AsyncGenerator[WorkflowEnvironment, None]: @pytest_asyncio.fixture async def client(env: WorkflowEnvironment) -> Client: return env.client - - -@pytest_asyncio.fixture -async def db_connection(): - """Create a PostgreSQL connection with a unique schema for each test. - - Sets up a temporary schema with UUID naming, sets it as the default schema - for the connection, and cleans up with CASCADE on teardown. - """ - # Generate unique schema name - schema_name = f"test_{uuid.uuid4().hex}" - - # Create connection - # Note that we read the DATABASE_URL from the environment because asyncpg does not read this - # environment variable. It does read other postgres environment variables such as PGHOST, - # PGPORT, PGDATABASE, PGUSER, and PGPASSWORD, so you can still use those if you do not set - # DATABASE_URL. - conn = await asyncpg.connect(os.getenv("DATABASE_URL")) - - try: - # Create the schema - await conn.execute(f"CREATE SCHEMA {schema_name}") - - # Set the schema as default for this connection - await conn.execute(f"SET search_path TO {schema_name}") - - yield conn - - finally: - # Clean up: drop schema with cascade - try: - await conn.execute(f"DROP SCHEMA IF EXISTS {schema_name} CASCADE") - except Exception: - pass # Best effort cleanup - await conn.close() diff --git a/tests/openai_agents/memory/conftest.py b/tests/openai_agents/memory/conftest.py new file mode 100644 index 00000000..e75f63cd --- /dev/null +++ b/tests/openai_agents/memory/conftest.py @@ -0,0 +1,40 @@ +import os +import uuid + +import asyncpg +import pytest_asyncio + + +@pytest_asyncio.fixture +async def db_connection(): + """Create a PostgreSQL connection with a unique schema for each test. + + Sets up a temporary schema with UUID naming, sets it as the default schema + for the connection, and cleans up with CASCADE on teardown. + """ + # Generate unique schema name + schema_name = f"test_{uuid.uuid4().hex}" + + # Create connection + # Note that we read the DATABASE_URL from the environment because asyncpg does not read this + # environment variable. It does read other postgres environment variables such as PGHOST, + # PGPORT, PGDATABASE, PGUSER, and PGPASSWORD, so you can still use those if you do not set + # DATABASE_URL. 
+ conn = await asyncpg.connect(os.getenv("DATABASE_URL")) + + try: + # Create the schema + await conn.execute(f"CREATE SCHEMA {schema_name}") + + # Set the schema as default for this connection + await conn.execute(f"SET search_path TO {schema_name}") + + yield conn + + finally: + # Clean up: drop schema with cascade + try: + await conn.execute(f"DROP SCHEMA IF EXISTS {schema_name} CASCADE") + except Exception: + pass # Best effort cleanup + await conn.close() diff --git a/tests/openai_agents/memory/test_idempotence_util.py b/tests/openai_agents/memory/test_idempotence_util.py index 6df505a0..6e7deda6 100644 --- a/tests/openai_agents/memory/test_idempotence_util.py +++ b/tests/openai_agents/memory/test_idempotence_util.py @@ -9,15 +9,44 @@ import asyncpg from temporalio.common import RetryPolicy -from openai_agents.memory.connection_state import ( - set_worker_connection, - get_worker_connection, - clear_worker_connection, -) from openai_agents.memory.db_utils import IdempotenceHelper from pydantic import BaseModel from datetime import datetime from temporalio.contrib.pydantic import pydantic_data_converter +from typing import Optional + + +# WARNING: This implementation uses global state and is not safe for concurrent +# testing (e.g., pytest-xdist). Run tests sequentially to avoid race conditions. + +# Module-level connection state +_connection: Optional[asyncpg.Connection] = None + + +def set_worker_connection(connection: asyncpg.Connection) -> None: + """Set the worker-level database connection.""" + global _connection + _connection = connection + + +def get_worker_connection() -> asyncpg.Connection: + """Get the worker-level database connection. + + Raises: + RuntimeError: If no connection has been set. + """ + if _connection is None: + raise RuntimeError( + "No worker-level database connection has been set. " + "Call set_worker_connection() before using activities." 
+        )
+    return _connection
+
+
+def clear_worker_connection() -> None:
+    """Clear the worker-level database connection."""
+    global _connection
+    _connection = None
 
 
 @workflow.defn
From b5e76a1576d1f12e54de9067f843d3ba9d2fceb6 Mon Sep 17 00:00:00 2001
From: Johann Schleier-Smith
Date: Sun, 14 Sep 2025 08:28:09 -0700
Subject: [PATCH 3/3] cleanup

Move schema initialization and activity registration onto PostgresSession,
and derive the worker's idempotence table from PostgresSessionConfig.
---
 openai_agents/memory/postgres_session.py          | 83 +++++++++++--------
 .../memory/run_postgres_session_worker.py         | 26 +++---
 2 files changed, 58 insertions(+), 51 deletions(-)

diff --git a/openai_agents/memory/postgres_session.py b/openai_agents/memory/postgres_session.py
index 8936a03c..3423d808 100644
--- a/openai_agents/memory/postgres_session.py
+++ b/openai_agents/memory/postgres_session.py
@@ -56,42 +56,6 @@ class PostgresSessionConfig(BaseModel):
     idempotence_table: str = "activity_idempotence"
 
 
-async def init_schema(conn: asyncpg.Connection, config: PostgresSessionConfig) -> None:
-    """Initialize the PostgreSQL schema."""
-    async with conn.transaction():
-        # Create sessions table
-        sessions_ddl = f"""
-        CREATE TABLE IF NOT EXISTS {config.sessions_table} (
-            session_id TEXT NOT NULL,
-            created_at TIMESTAMP DEFAULT NOW(),
-            updated_at TIMESTAMP DEFAULT NOW(),
-            PRIMARY KEY (session_id)
-        )
-        """
-        await conn.execute(sessions_ddl)
-
-        # Create operation_id sequence
-        operation_id_ddl = f"""
-        CREATE SEQUENCE IF NOT EXISTS {config.operation_id_sequence} START 1
-        """
-        await conn.execute(operation_id_ddl)
-
-        # Create messages table
-        messages_ddl = f"""
-        CREATE TABLE IF NOT EXISTS {config.messages_table} (
-            session_id TEXT NOT NULL,
-            operation_id INTEGER NOT NULL DEFAULT nextval('{config.operation_id_sequence}'),
-            message_data TEXT NOT NULL,
-            created_at TIMESTAMP DEFAULT NOW(),
-            deleted_at TIMESTAMP NULL,
-            PRIMARY KEY (session_id, operation_id),
-            FOREIGN KEY (session_id)
-                REFERENCES {config.sessions_table} (session_id)
-                ON DELETE CASCADE
-        )
-        """
-        await conn.execute(messages_ddl)
-
 
 class PostgresSessionGetItemsRequest(BaseModel):
     config: PostgresSessionConfig
@@ -284,3 +248,50 @@ def _get_connection():
         if _connection_factory is None:
             raise ValueError("Connection factory not set")
         return _connection_factory()
+
+    @staticmethod
+    async def init_schema(config: PostgresSessionConfig) -> None:
+        """Initialize the PostgreSQL schema."""
+        conn = PostgresSession._get_connection()
+        async with conn.transaction():
+            # Create sessions table
+            sessions_ddl = f"""
+            CREATE TABLE IF NOT EXISTS {config.sessions_table} (
+                session_id TEXT NOT NULL,
+                created_at TIMESTAMP DEFAULT NOW(),
+                updated_at TIMESTAMP DEFAULT NOW(),
+                PRIMARY KEY (session_id)
+            )
+            """
+            await conn.execute(sessions_ddl)
+
+            # Create operation_id sequence
+            operation_id_ddl = f"""
+            CREATE SEQUENCE IF NOT EXISTS {config.operation_id_sequence} START 1
+            """
+            await conn.execute(operation_id_ddl)
+
+            # Create messages table
+            messages_ddl = f"""
+            CREATE TABLE IF NOT EXISTS {config.messages_table} (
+                session_id TEXT NOT NULL,
+                operation_id INTEGER NOT NULL DEFAULT nextval('{config.operation_id_sequence}'),
+                message_data TEXT NOT NULL,
+                created_at TIMESTAMP DEFAULT NOW(),
+                deleted_at TIMESTAMP NULL,
+                PRIMARY KEY (session_id, operation_id),
+                FOREIGN KEY (session_id)
+                    REFERENCES {config.sessions_table} (session_id)
+                    ON DELETE CASCADE
+            )
+            """
+            await conn.execute(messages_ddl)
+
+    @staticmethod
+    def get_activities() -> list[Callable]:
+        return [
+            postgres_session_get_items_activity,
+            postgres_session_add_items_activity,
+            postgres_session_pop_item_activity,
+            postgres_session_clear_session_activity,
+        ]
\ No newline at end of file
diff --git a/openai_agents/memory/run_postgres_session_worker.py b/openai_agents/memory/run_postgres_session_worker.py
index 9a3f9e75..96b592cd 100644
--- a/openai_agents/memory/run_postgres_session_worker.py
+++ b/openai_agents/memory/run_postgres_session_worker.py
@@ -14,12 +14,7 @@
 )
 from openai_agents.memory.postgres_session import (
     PostgresSessionConfig,
-    init_schema,
     PostgresSession,
-    postgres_session_get_items_activity,
-    postgres_session_add_items_activity,
-    postgres_session_pop_item_activity,
-    postgres_session_clear_session_activity,
 )
 from openai_agents.memory.db_utils import IdempotenceHelper
 
@@ -32,12 +27,18 @@ async def main():
         messages_table="session_messages",
         sessions_table="session",
         operation_id_sequence="session_operation_id_sequence",
+        idempotence_table="activity_idempotence",
     )
-    PostgresSession.set_connection_factory(lambda: db_connection)
-    await init_schema(db_connection, config=postgres_session_config)
-    idempotence_helper = IdempotenceHelper(table_name="activity_idempotence")
+
+    # Create the idempotence table, which ensures that activities remain idempotent
+    # with respect to database modifications.
+    idempotence_helper = IdempotenceHelper(table_name=postgres_session_config.idempotence_table)
     await idempotence_helper.create_table(db_connection)
+
+    # Configure the PostgresSession with the database connection,
+    # then initialize the schema.
+    PostgresSession.set_connection_factory(lambda: db_connection)
+    await PostgresSession.init_schema(config=postgres_session_config)
 
     # Create client connected to server at the given address
     client = await Client.connect(
@@ -53,16 +54,11 @@
 
     worker = Worker(
         client,
-        task_queue="openai-agents-memory-task-queue",
+        task_queue="openai-postgres-session-task-queue",
         workflows=[
             PostgresSessionWorkflow,
         ],
-        activities=[
-            postgres_session_get_items_activity,
-            postgres_session_add_items_activity,
-            postgres_session_pop_item_activity,
-            postgres_session_clear_session_activity,
-        ],
+        activities=PostgresSession.get_activities(),
     )
 
     await worker.run()