From cf13e7b7f8891daf94d6332336d9d42043e45506 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 9 Dec 2025 15:07:30 -0800 Subject: [PATCH 1/6] Add OpenAI streaming support --- .../openai_agents/_invoke_model_activity.py | 239 ++++++++++++------ .../contrib/openai_agents/_openai_runner.py | 130 ++++++---- .../openai_agents/_temporal_model_stub.py | 232 +++++++++++------ .../openai_agents/_temporal_openai_agents.py | 6 +- temporalio/contrib/openai_agents/testing.py | 138 +++++++++- tests/contrib/openai_agents/test_openai.py | 56 ++++ 6 files changed, 586 insertions(+), 215 deletions(-) diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index f03458c32..3befb690e 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -3,11 +3,12 @@ Implements mapping of OpenAI datastructures to Pydantic friendly types. """ +import asyncio import enum import json from dataclasses import dataclass from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, NoReturn, Optional, Union from agents import ( AgentOutputSchemaBase, @@ -28,6 +29,7 @@ UserError, WebSearchTool, ) +from agents.items import TResponseStreamEvent from openai import ( APIStatusError, AsyncOpenAI, @@ -148,6 +150,13 @@ class ActivityModelInput(TypedDict, total=False): prompt: Any | None +class StreamActivityModelInput(ActivityModelInput): + """Input for the stream_model activity.""" + + signal: str + batch_latency_seconds: float + + class ModelActivity: """Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled. Disabling retries in your model of choice is recommended to allow activity retries to define the retry model. 
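For orientation on the new input type: StreamActivityModelInput extends ActivityModelInput, so a streaming invocation carries the base model fields plus the signal routing. A minimal sketch of constructing one, where the field names come from the TypedDicts in this patch and the concrete values (model name, prompt, latency) are illustrative assumptions only:

```python
from agents import ModelSettings, ModelTracing

# StreamActivityModelInput and ModelTracingInput come from
# _invoke_model_activity.py above; the values below are hypothetical.
stream_input = StreamActivityModelInput(
    model_name="gpt-4o",  # assumed model name; any provider-known name works
    input="Say hello.",
    model_settings=ModelSettings(),
    tracing=ModelTracingInput(ModelTracing.DISABLED.value),
    signal="model_stream_signal",  # workflow signal that receives event batches
    batch_latency_seconds=1.0,  # flush accumulated events at most once per second
)
```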
@@ -165,52 +174,8 @@ async def invoke_model_activity(self, input: ActivityModelInput) -> ModelRespons """Activity that invokes a model with the given input.""" model = self._model_provider.get_model(input.get("model_name")) - async def empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str: - return "" - - async def empty_on_invoke_handoff( - ctx: RunContextWrapper[Any], input: str - ) -> Any: - return None - - def make_tool(tool: ToolInput) -> Tool: - if isinstance( - tool, - ( - FileSearchTool, - WebSearchTool, - ImageGenerationTool, - CodeInterpreterTool, - ), - ): - return tool - elif isinstance(tool, HostedMCPToolInput): - return HostedMCPTool( - tool_config=tool.tool_config, - ) - elif isinstance(tool, FunctionToolInput): - return FunctionTool( - name=tool.name, - description=tool.description, - params_json_schema=tool.params_json_schema, - on_invoke_tool=empty_on_invoke_tool, - strict_json_schema=tool.strict_json_schema, - ) - else: - raise UserError(f"Unknown tool type: {tool.name}") - - tools = [make_tool(x) for x in input.get("tools", [])] - handoffs: list[Handoff[Any, Any]] = [ - Handoff( - tool_name=x.tool_name, - tool_description=x.tool_description, - input_json_schema=x.input_json_schema, - agent_name=x.agent_name, - strict_json_schema=x.strict_json_schema, - on_invoke_handoff=empty_on_invoke_handoff, - ) - for x in input.get("handoffs", []) - ] + tools = _make_tools(input) + handoffs = _make_handoffs(input) try: return await model.get_response( @@ -226,40 +191,146 @@ def make_tool(tool: ToolInput) -> Tool: prompt=input.get("prompt"), ) except APIStatusError as e: - # Listen to server hints - retry_after = None - retry_after_ms_header = e.response.headers.get("retry-after-ms") - if retry_after_ms_header is not None: - retry_after = timedelta(milliseconds=float(retry_after_ms_header)) - - if retry_after is None: - retry_after_header = e.response.headers.get("retry-after") - if retry_after_header is not None: - retry_after = timedelta(seconds=float(retry_after_header)) - - should_retry_header = e.response.headers.get("x-should-retry") - if should_retry_header == "true": - raise e - if should_retry_header == "false": - raise ApplicationError( - "Non retryable OpenAI error", - non_retryable=True, - next_retry_delay=retry_after, - ) from e - - # Specifically retryable status codes - if ( - e.response.status_code in [408, 409, 429] - or e.response.status_code >= 500 - ): - raise ApplicationError( - f"Retryable OpenAI status code: {e.response.status_code}", - non_retryable=False, - next_retry_delay=retry_after, - ) from e - - raise ApplicationError( - f"Non retryable OpenAI status code: {e.response.status_code}", - non_retryable=True, - next_retry_delay=retry_after, - ) from e + _handle_error(e) + + @activity.defn + async def stream_model(self, input: StreamActivityModelInput) -> None: + """Activity that streams a model with the given input.""" + model = self._model_provider.get_model(input.get("model_name")) + + tools = _make_tools(input) + handoffs = _make_handoffs(input) + + try: + handle = activity.client().get_workflow_handle( + workflow_id=activity.info().workflow_id + ) + events = model.stream_response( + system_instructions=input.get("system_instructions"), + input=input["input"], + model_settings=input["model_settings"], + tools=tools, + output_schema=input.get("output_schema"), + handoffs=handoffs, + tracing=ModelTracing(input["tracing"]), + previous_response_id=input.get("previous_response_id"), + conversation_id=input.get("conversation_id"), + 
prompt=input.get("prompt"), + ) + + # Batch events with configurable latency + batch: list[TResponseStreamEvent] = [] + last_signal_time = asyncio.get_event_loop().time() + batch_latency = input.get("batch_latency_seconds", 1.0) + + async def send_batch(): + nonlocal last_signal_time + if batch: + await handle.signal(input["signal"], batch) + batch.clear() + last_signal_time = asyncio.get_event_loop().time() + + async for event in events: + event.model_rebuild() + batch.append(event) + + current_time = asyncio.get_event_loop().time() + if current_time - last_signal_time >= batch_latency: + await send_batch() + + # Send any remaining events in the batch + if batch: + await send_batch() + + except APIStatusError as e: + _handle_error(e) + + +async def _empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str: + return "" + + +async def _empty_on_invoke_handoff(ctx: RunContextWrapper[Any], input: str) -> Any: + return None + + +def _make_tool(tool: ToolInput) -> Tool: + if isinstance( + tool, + ( + FileSearchTool, + WebSearchTool, + ImageGenerationTool, + CodeInterpreterTool, + ), + ): + return tool + elif isinstance(tool, HostedMCPToolInput): + return HostedMCPTool( + tool_config=tool.tool_config, + ) + elif isinstance(tool, FunctionToolInput): + return FunctionTool( + name=tool.name, + description=tool.description, + params_json_schema=tool.params_json_schema, + on_invoke_tool=_empty_on_invoke_tool, + strict_json_schema=tool.strict_json_schema, + ) + else: + raise UserError(f"Unknown tool type: {tool.name}") + + +def _make_tools(input: ActivityModelInput) -> list[Tool]: + return [_make_tool(x) for x in input.get("tools", [])] + + +def _make_handoffs(input: ActivityModelInput) -> list[Handoff[Any, Any]]: + return [ + Handoff( + tool_name=x.tool_name, + tool_description=x.tool_description, + input_json_schema=x.input_json_schema, + agent_name=x.agent_name, + strict_json_schema=x.strict_json_schema, + on_invoke_handoff=_empty_on_invoke_handoff, + ) + for x in input.get("handoffs", []) + ] + + +def _handle_error(e: APIStatusError) -> NoReturn: + # Listen to server hints + retry_after = None + retry_after_ms_header = e.response.headers.get("retry-after-ms") + if retry_after_ms_header is not None: + retry_after = timedelta(milliseconds=float(retry_after_ms_header)) + + if retry_after is None: + retry_after_header = e.response.headers.get("retry-after") + if retry_after_header is not None: + retry_after = timedelta(seconds=float(retry_after_header)) + + should_retry_header = e.response.headers.get("x-should-retry") + if should_retry_header == "true": + raise e + if should_retry_header == "false": + raise ApplicationError( + "Non retryable OpenAI error", + non_retryable=True, + next_retry_delay=retry_after, + ) from e + + # Specifically retryable status codes + if e.response.status_code in [408, 409, 429] or e.response.status_code >= 500: + raise ApplicationError( + f"Retryable OpenAI status code: {e.response.status_code}", + non_retryable=False, + next_retry_delay=retry_after, + ) from e + + raise ApplicationError( + f"Non retryable OpenAI status code: {e.response.status_code}", + non_retryable=True, + next_retry_delay=retry_after, + ) from e diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index a8065d207..b261306a6 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -15,7 +15,7 @@ Tool, TResponseInputItem, ) -from agents.run import 
DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner +from agents.run import DEFAULT_AGENT_RUNNER, AgentRunner from temporalio import workflow from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters @@ -98,66 +98,15 @@ async def run( **kwargs, ) - tool_types = typing.get_args(Tool) - for t in starting_agent.tools: - if not isinstance(t, tool_types): - raise ValueError( - "Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool." - ) + _check_preconditions(starting_agent, **kwargs) - if starting_agent.mcp_servers: - from temporalio.contrib.openai_agents._mcp import ( - _StatefulMCPServerReference, - _StatelessMCPServerReference, - ) - - for s in starting_agent.mcp_servers: - if not isinstance( - s, - ( - _StatelessMCPServerReference, - _StatefulMCPServerReference, - ), - ): - raise ValueError( - f"Unknown mcp_server type {type(s)} may not work durably." - ) - - context = kwargs.get("context") - max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) - hooks = kwargs.get("hooks") - run_config = kwargs.get("run_config") - previous_response_id = kwargs.get("previous_response_id") - session = kwargs.get("session") - - if isinstance(session, SQLiteSession): - raise ValueError("Temporal workflows don't support SQLite sessions.") - - if run_config is None: - run_config = RunConfig() - - if run_config.model: - if not isinstance(run_config.model, str): - raise ValueError( - "Temporal workflows require a model name to be a string in the run config." - ) - run_config = dataclasses.replace( - run_config, - model=_TemporalModelStub( - run_config.model, model_params=self.model_params, agent=None - ), - ) + kwargs["run_config"] = self._process_run_config(kwargs.get("run_config")) try: return await self._runner.run( starting_agent=_convert_agent(self.model_params, starting_agent, None), input=input, - context=context, - max_turns=max_turns, - hooks=hooks, - run_config=run_config, - previous_response_id=previous_response_id, - session=session, + **kwargs, ) except AgentsException as e: # In order for workflow failures to properly fail the workflow, we need to rewrap them in @@ -199,7 +148,45 @@ def run_streamed( input, **kwargs, ) - raise RuntimeError("Temporal workflows do not support streaming.") + + _check_preconditions(starting_agent, **kwargs) + + kwargs["run_config"] = self._process_run_config(kwargs.get("run_config")) + + try: + return self._runner.run_streamed( + starting_agent=_convert_agent(self.model_params, starting_agent, None), + input=input, + **kwargs, + ) + except AgentsException as e: + # In order for workflow failures to properly fail the workflow, we need to rewrap them in + # a Temporal error + if e.__cause__ and workflow.is_failure_exception(e.__cause__): + reraise = AgentsWorkflowError( + f"Workflow failure exception in Agents Framework: {e}" + ) + reraise.__traceback__ = e.__traceback__ + raise reraise from e.__cause__ + else: + raise e + + def _process_run_config(self, run_config: RunConfig | None) -> RunConfig: + if run_config is None: + run_config = RunConfig() + + if run_config.model: + if not isinstance(run_config.model, str): + raise ValueError( + "Temporal workflows require a model name to be a string in the run config." 
+ ) + run_config = dataclasses.replace( + run_config, + model=_TemporalModelStub( + run_config.model, model_params=self.model_params, agent=None + ), + ) + return run_config def _model_name(agent: Agent[Any]) -> str | None: @@ -209,3 +196,34 @@ def _model_name(agent: Agent[Any]) -> str | None: "Temporal workflows require a model name to be a string in the agent." ) return name + + +def _check_preconditions(starting_agent: Agent[TContext], **kwargs: Any) -> None: + tool_types = typing.get_args(Tool) + for t in starting_agent.tools: + if not isinstance(t, tool_types): + raise ValueError( + "Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool." + ) + + if starting_agent.mcp_servers: + from temporalio.contrib.openai_agents._mcp import ( + _StatefulMCPServerReference, + _StatelessMCPServerReference, + ) + + for s in starting_agent.mcp_servers: + if not isinstance( + s, + ( + _StatelessMCPServerReference, + _StatefulMCPServerReference, + ), + ): + raise ValueError( + f"Unknown mcp_server type {type(s)} may not work durably." + ) + + session = kwargs.get("session") + if isinstance(session, SQLiteSession): + raise ValueError("Temporal workflows don't support SQLite sessions.") diff --git a/temporalio/contrib/openai_agents/_temporal_model_stub.py b/temporalio/contrib/openai_agents/_temporal_model_stub.py index f84488541..0f5b5e62a 100644 --- a/temporalio/contrib/openai_agents/_temporal_model_stub.py +++ b/temporalio/contrib/openai_agents/_temporal_model_stub.py @@ -1,10 +1,13 @@ from __future__ import annotations +import asyncio import logging -from typing import Optional +from asyncio import FIRST_COMPLETED +from typing import Optional, Tuple from temporalio import workflow from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._temporal_trace_provider import _workflow_uuid logger = logging.getLogger(__name__) @@ -40,6 +43,7 @@ HostedMCPToolInput, ModelActivity, ModelTracingInput, + StreamActivityModelInput, ToolInput, ) @@ -57,6 +61,7 @@ def __init__( self.model_name = model_name self.model_params = model_params self.agent = agent + self.stream_events: list[TResponseStreamEvent] = [] async def get_response( self, @@ -72,88 +77,25 @@ async def get_response( conversation_id: str | None, prompt: ResponsePromptParam | None, ) -> ModelResponse: - def make_tool_info(tool: Tool) -> ToolInput: - if isinstance( - tool, - ( - FileSearchTool, - WebSearchTool, - ImageGenerationTool, - CodeInterpreterTool, - ), - ): - return tool - elif isinstance(tool, HostedMCPTool): - return HostedMCPToolInput(tool_config=tool.tool_config) - elif isinstance(tool, FunctionTool): - return FunctionToolInput( - name=tool.name, - description=tool.description, - params_json_schema=tool.params_json_schema, - strict_json_schema=tool.strict_json_schema, - ) - else: - raise ValueError(f"Unsupported tool type: {tool.name}") - - tool_infos = [make_tool_info(x) for x in tools] - handoff_infos = [ - HandoffInput( - tool_name=x.tool_name, - tool_description=x.tool_description, - input_json_schema=x.input_json_schema, - agent_name=x.agent_name, - strict_json_schema=x.strict_json_schema, - ) - for x in handoffs - ] - if output_schema is not None and not isinstance( - output_schema, AgentOutputSchema - ): - raise TypeError( - f"Only AgentOutputSchema is supported by Temporal Model, got {type(output_schema).__name__}" - ) - agent_output_schema = output_schema - output_schema_input = ( - None - if 
agent_output_schema is None - else AgentOutputSchemaInput( - output_type_name=agent_output_schema.name(), - is_wrapped=agent_output_schema._is_wrapped, - output_schema=agent_output_schema.json_schema() - if not agent_output_schema.is_plain_text() - else None, - strict_json_schema=agent_output_schema.is_strict_json_schema(), - ) - ) + tool_inputs = _make_tool_inputs(tools) + handoff_inputs = _make_handoff_inputs(handoffs) + output_schema_input = _make_output_schema_input(output_schema) activity_input = ActivityModelInput( model_name=self.model_name, system_instructions=system_instructions, input=input, model_settings=model_settings, - tools=tool_infos, + tools=tool_inputs, output_schema=output_schema_input, - handoffs=handoff_infos, + handoffs=handoff_inputs, tracing=ModelTracingInput(tracing.value), previous_response_id=previous_response_id, conversation_id=conversation_id, prompt=prompt, ) - if self.model_params.summary_override: - summary = ( - self.model_params.summary_override - if isinstance(self.model_params.summary_override, str) - else ( - self.model_params.summary_override.provide( - self.agent, system_instructions, input - ) - ) - ) - elif self.agent: - summary = self.agent.name - else: - summary = None + summary = self._make_summary(system_instructions, input) if self.model_params.use_local_activity: return await workflow.execute_local_activity_method( @@ -196,7 +138,91 @@ def stream_response( conversation_id: str | None, prompt: ResponsePromptParam | None, ) -> AsyncIterator[TResponseStreamEvent]: - raise NotImplementedError("Temporal model doesn't support streams yet") + if self.model_params.use_local_activity: + raise ValueError("Streaming is not available with local activities.") + + tool_inputs = _make_tool_inputs(tools) + handoff_inputs = _make_handoff_inputs(handoffs) + output_schema_input = _make_output_schema_input(output_schema) + + summary = self._make_summary(system_instructions, input) + + stream_queue: asyncio.Queue[TResponseStreamEvent | None] = asyncio.Queue() + + async def handle_stream_event(events: list[TResponseStreamEvent]): + for event in events: + await stream_queue.put(event) + + signal_name = "model_stream_signal" + workflow.set_signal_handler(signal_name, handle_stream_event) + + activity_input = StreamActivityModelInput( + model_name=self.model_name, + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tool_inputs, + output_schema=output_schema_input, + handoffs=handoff_inputs, + tracing=ModelTracingInput(tracing.value), + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt, + signal=signal_name, + batch_latency_seconds=1.0, + ) + + handle = workflow.start_activity_method( + ModelActivity.stream_model, + activity_input, + summary=summary, + task_queue=self.model_params.task_queue, + schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, + schedule_to_start_timeout=self.model_params.schedule_to_start_timeout, + start_to_close_timeout=self.model_params.start_to_close_timeout, + heartbeat_timeout=self.model_params.heartbeat_timeout, + retry_policy=self.model_params.retry_policy, + cancellation_type=self.model_params.cancellation_type, + versioning_intent=self.model_params.versioning_intent, + priority=self.model_params.priority, + ) + + class SignalIterator(AsyncIterator): + def __init__(self): + self._monitor_task = asyncio.create_task(self._monitor_activity()) + + async def _monitor_activity(self): + try: + await handle + finally: + await 
stream_queue.put(None) # Signal end of stream + + async def __anext__(self): + item = await stream_queue.get() + if item is None: + self._monitor_task.cancel() + raise StopAsyncIteration + return item + + return SignalIterator() + + def _make_summary( + self, system_instructions: str | None, input: str | list[TResponseInputItem] + ) -> str | None: + if self.model_params.summary_override: + return ( + self.model_params.summary_override + if isinstance(self.model_params.summary_override, str) + else ( + self.model_params.summary_override.provide( + self.agent, system_instructions, input + ) + ) + ) + elif self.agent: + return self.agent.name + else: + return None def _extract_summary(input: str | list[TResponseInputItem]) -> str: @@ -228,3 +254,67 @@ def _extract_summary(input: str | list[TResponseInputItem]) -> str: except Exception as e: logger.error(f"Error getting summary: {e}") return "" + + +def _make_tool_input(tool: Tool) -> ToolInput: + if isinstance( + tool, + ( + FileSearchTool, + WebSearchTool, + ImageGenerationTool, + CodeInterpreterTool, + ), + ): + return tool + elif isinstance(tool, HostedMCPTool): + return HostedMCPToolInput(tool_config=tool.tool_config) + elif isinstance(tool, FunctionTool): + return FunctionToolInput( + name=tool.name, + description=tool.description, + params_json_schema=tool.params_json_schema, + strict_json_schema=tool.strict_json_schema, + ) + else: + raise ValueError(f"Unsupported tool type: {tool.name}") + + +def _make_tool_inputs(tools: list[Tool]) -> list[ToolInput]: + return [_make_tool_input(x) for x in tools] + + +def _make_handoff_inputs(handoffs: list[Handoff]) -> list[HandoffInput]: + return [ + HandoffInput( + tool_name=x.tool_name, + tool_description=x.tool_description, + input_json_schema=x.input_json_schema, + agent_name=x.agent_name, + strict_json_schema=x.strict_json_schema, + ) + for x in handoffs + ] + + +def _make_output_schema_input( + output_schema: AgentOutputSchemaBase | None, +) -> AgentOutputSchemaInput | None: + if output_schema is not None and not isinstance(output_schema, AgentOutputSchema): + raise TypeError( + f"Only AgentOutputSchema is supported by Temporal Model, got {type(output_schema).__name__}" + ) + + agent_output_schema = output_schema + return ( + None + if agent_output_schema is None + else AgentOutputSchemaInput( + output_type_name=agent_output_schema.name(), + is_wrapped=agent_output_schema._is_wrapped, + output_schema=agent_output_schema.json_schema() + if not agent_output_schema.is_plain_text() + else None, + strict_json_schema=agent_output_schema.is_strict_json_schema(), + ) + ) diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 41ae419f7..5e42882cf 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -221,7 +221,11 @@ def add_activities( if not register_activities: return activities or [] - new_activities = [ModelActivity(model_provider).invoke_model_activity] + model_activity = ModelActivity(model_provider) + new_activities = [ + model_activity.invoke_model_activity, + model_activity.stream_model, + ] server_names = [server.name for server in mcp_server_providers] if len(server_names) != len(set(server_names)): diff --git a/temporalio/contrib/openai_agents/testing.py b/temporalio/contrib/openai_agents/testing.py index 4acab196a..1bfdb1e78 100644 --- a/temporalio/contrib/openai_agents/testing.py +++ 
b/temporalio/contrib/openai_agents/testing.py
@@ -17,9 +17,14 @@
 )
 from agents.items import TResponseOutputItem, TResponseStreamEvent
 from openai.types.responses import (
+    Response,
+    ResponseCompletedEvent,
+    ResponseContentPartDoneEvent,
     ResponseFunctionToolCall,
+    ResponseOutputItemDoneEvent,
     ResponseOutputMessage,
     ResponseOutputText,
+    ResponseTextDeltaEvent,
 )
 
 from temporalio.client import Client
@@ -109,6 +114,87 @@ def output_message(text: str) -> ModelResponse:
         )
 
 
+class EventBuilders:
+    """Builders for creating stream events for testing.
+
+    .. warning::
+        This API is experimental and may change in the future.
+    """
+
+    @staticmethod
+    def text_delta(text: str) -> ResponseTextDeltaEvent:
+        """Create a TResponseStreamEvent with a text delta.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return ResponseTextDeltaEvent(
+            content_index=0,
+            delta=text,
+            item_id="",
+            logprobs=[],
+            output_index=0,
+            sequence_number=0,
+            type="response.output_text.delta",
+        )
+
+    @staticmethod
+    def content_part_done(text: str) -> TResponseStreamEvent:
+        """Create a TResponseStreamEvent for content part completion.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return ResponseContentPartDoneEvent(
+            content_index=0,
+            item_id="",
+            output_index=0,
+            sequence_number=0,
+            type="response.content_part.done",
+            part=ResponseOutputText(
+                text=text,
+                annotations=[],
+                type="output_text",
+            ),
+        )
+
+    @staticmethod
+    def output_item_done(text: str) -> TResponseStreamEvent:
+        """Create a TResponseStreamEvent for output item completion.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return ResponseOutputItemDoneEvent(
+            output_index=0,
+            sequence_number=0,
+            type="response.output_item.done",
+            item=ResponseBuilders.response_output_message(text),
+        )
+
+    @staticmethod
+    def response_completion(text: str) -> TResponseStreamEvent:
+        """Create a TResponseStreamEvent for response completion.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return ResponseCompletedEvent(
+            response=Response(
+                id="",
+                created_at=0.0,
+                object="response",
+                model="",
+                parallel_tool_calls=False,
+                tool_choice="none",
+                tools=[],
+                output=[ResponseBuilders.response_output_message(text)],
+            ),
+            sequence_number=0,
+            type="response.completed",
+        )
+
+
 class TestModelProvider(ModelProvider):
     """Test model provider which simply returns the given module.
 
@@ -144,13 +230,19 @@ class TestModel(Model):
 
     __test__ = False
 
-    def __init__(self, fn: Callable[[], ModelResponse]) -> None:
+    def __init__(
+        self,
+        fn: Callable[[], ModelResponse] | None,
+        *,
+        streaming_fn: Callable[[], AsyncIterator[TResponseStreamEvent]] | None = None,
+    ) -> None:
         """Initialize a test model with a callable.
 
         .. warning::
            This API is experimental and may change in the future.
         """
         self.fn = fn
+        self.streaming_fn = streaming_fn
 
     async def get_response(
         self,
@@ -164,6 +256,8 @@ async def get_response(
         **kwargs,
     ) -> ModelResponse:
         """Get a response from the mocked model, by calling the callable passed to the constructor."""
+        if self.fn is None:
+            raise ValueError("No non-streaming function provided")
         return self.fn()
 
     def stream_response(
         self,
@@ -177,8 +271,10 @@ def stream_response(
         tracing: ModelTracing,
         **kwargs,
     ) -> AsyncIterator[TResponseStreamEvent]:
-        """Get a streamed response from the model. Unimplemented."""
-        raise NotImplementedError()
+        """Get a streamed response from the model."""
+        if self.streaming_fn is None:
+            raise ValueError("No streaming function provided")
+        return self.streaming_fn()
 
     @staticmethod
     def returning_responses(responses: list[ModelResponse]) -> "TestModel":
@@ -190,6 +286,42 @@ def returning_responses(responses: list[ModelResponse]) -> "TestModel":
         i = iter(responses)
         return TestModel(lambda: next(i))
 
+    @staticmethod
+    def streaming_events(events: list[TResponseStreamEvent]) -> "TestModel":
+        """Create a mock model which streams the given events in order.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+
+        async def generator():
+            for event in events:
+                yield event
+
+        return TestModel(None, streaming_fn=lambda: generator())
+
+    @staticmethod
+    def streaming_events_with_ending(
+        events: list[ResponseTextDeltaEvent],
+    ) -> "TestModel":
+        """Create a mock model which streams the given text deltas, then appends ending markers.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        content = ""
+        for event in events:
+            content += event.delta
+
+        return TestModel.streaming_events(
+            events
+            + [
+                EventBuilders.content_part_done(content),
+                EventBuilders.output_item_done(content),
+                EventBuilders.response_completion(content),
+            ]
+        )
+
 
 class AgentEnvironment:
     """Testing environment for OpenAI agents with Temporal integration.
diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py
index 381a213cd..1a47c3cf6 100644
--- a/tests/contrib/openai_agents/test_openai.py
+++ b/tests/contrib/openai_agents/test_openai.py
@@ -72,6 +72,7 @@
     ResponseInputTextParam,
     ResponseOutputMessage,
     ResponseOutputText,
+    ResponseTextDeltaEvent,
 )
 from openai.types.responses.response_file_search_tool_call import Result
 from openai.types.responses.response_function_web_search import ActionSearch
@@ -101,6 +102,7 @@
 )
 from temporalio.contrib.openai_agents.testing import (
     AgentEnvironment,
+    EventBuilders,
     ResponseBuilders,
     TestModel,
     TestModelProvider,
 )
@@ -2635,3 +2637,57 @@ async def test_split_workers(client: Client):
         execution_timeout=timedelta(seconds=120),
     )
     assert result == "test"
+
+
+@workflow.defn
+class StreamingHelloWorldAgent:
+    @workflow.run
+    async def run(self, prompt: str) -> str | None:
+        agent = Agent[None](
+            name="Assistant",
+            instructions="You are a helpful assistant.",
+        )
+
+        result = None
+        for _ in range(2):
+            result = Runner.run_streamed(starting_agent=agent, input=prompt)
+            async for event in result.stream_events():
+                if event.type == "raw_response_event" and isinstance(
+                    event.data, ResponseTextDeltaEvent
+                ):
+                    print(event.data.delta, end="", flush=True)
+
+        return result.final_output if result else None
+
+
+def streaming_hello_model():
+    return TestModel.streaming_events_with_ending(
+        [
+            EventBuilders.text_delta("Hello"),
+            EventBuilders.text_delta(" there"),
+            EventBuilders.text_delta("!"),
+        ]
+    )
+
+
+async def test_streaming(client: Client):
+    async with AgentEnvironment(
+        model=streaming_hello_model(),
+        model_params=ModelActivityParameters(
+            start_to_close_timeout=timedelta(seconds=30),
+        ),
+    ) as env:
+        client = env.applied_on_client(client)
+
+        async with new_worker(
+            client, StreamingHelloWorldAgent, max_cached_workflows=0
+        ) as worker:
+            handle = await client.start_workflow(
+                StreamingHelloWorldAgent.run,
+                "Say hello.",
+                id=f"hello-workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
execution_timeout=timedelta(seconds=50), + ) + result = await handle.result() + assert result == "Hello there!" From a0c1e7e48ca88e62b3e989412702acbaa6a4dccf Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 11 Dec 2025 09:38:34 -0800 Subject: [PATCH 2/6] Improve batching --- .../openai_agents/_invoke_model_activity.py | 24 ++++++++++++----- .../openai_agents/_temporal_model_stub.py | 27 +++++++++---------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index 3befb690e..612e758a6 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -230,13 +230,23 @@ async def send_batch(): batch.clear() last_signal_time = asyncio.get_event_loop().time() - async for event in events: - event.model_rebuild() - batch.append(event) - - current_time = asyncio.get_event_loop().time() - if current_time - last_signal_time >= batch_latency: - await send_batch() + try: + while True: + # If latency has been passed, send the batch + if asyncio.get_event_loop().time() - last_signal_time >= batch_latency: + await send_batch() + try: + event = await asyncio.wait_for( + anext(events), timeout=asyncio.get_event_loop().time() - last_signal_time + ) + event.model_rebuild() + batch.append(event) + + # If the wait timed out, the latency has expired so send the batch + except asyncio.TimeoutError: + await send_batch() + except StopAsyncIteration: + pass # Send any remaining events in the batch if batch: diff --git a/temporalio/contrib/openai_agents/_temporal_model_stub.py b/temporalio/contrib/openai_agents/_temporal_model_stub.py index 0f5b5e62a..7116cc982 100644 --- a/temporalio/contrib/openai_agents/_temporal_model_stub.py +++ b/temporalio/contrib/openai_agents/_temporal_model_stub.py @@ -174,7 +174,7 @@ async def handle_stream_event(events: list[TResponseStreamEvent]): handle = workflow.start_activity_method( ModelActivity.stream_model, - activity_input, + args=[activity_input], summary=summary, task_queue=self.model_params.task_queue, schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, @@ -187,24 +187,23 @@ async def handle_stream_event(events: list[TResponseStreamEvent]): priority=self.model_params.priority, ) - class SignalIterator(AsyncIterator): - def __init__(self): - self._monitor_task = asyncio.create_task(self._monitor_activity()) + async def monitor_activity(): + try: + await handle + finally: + await stream_queue.put(None) # Signal end of stream - async def _monitor_activity(self): - try: - await handle - finally: - await stream_queue.put(None) # Signal end of stream + monitor_task = asyncio.create_task(monitor_activity()) - async def __anext__(self): + async def generator() -> AsyncIterator[TResponseStreamEvent]: + while True: item = await stream_queue.get() if item is None: - self._monitor_task.cancel() - raise StopAsyncIteration - return item + monitor_task.cancel() + return + yield item - return SignalIterator() + return generator() def _make_summary( self, system_instructions: str | None, input: str | list[TResponseInputItem] From 93bc0045b338e46b3cb2128dbdfb2dec60719330 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 18 Dec 2025 08:31:34 -0500 Subject: [PATCH 3/6] Fix batch timing and exception propagation --- pf.rules | 1 + .../openai_agents/_invoke_model_activity.py | 35 ++++----- .../openai_agents/_model_parameters.py | 3 + 
.../openai_agents/_temporal_model_stub.py | 9 +-- .../openai_agents/_temporal_openai_agents.py | 1 + temporalio/contrib/openai_agents/testing.py | 19 +++-- tests/contrib/openai_agents/test_openai.py | 76 ++++++++++++++++--- 7 files changed, 107 insertions(+), 37 deletions(-) create mode 100644 pf.rules diff --git a/pf.rules b/pf.rules new file mode 100644 index 000000000..50887dc11 --- /dev/null +++ b/pf.rules @@ -0,0 +1 @@ +block out quick to api.openai.com diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index 612e758a6..41c102433 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -6,6 +6,7 @@ import asyncio import enum import json +import traceback from dataclasses import dataclass from datetime import timedelta from typing import Any, NoReturn, Optional, Union @@ -220,34 +221,34 @@ async def stream_model(self, input: StreamActivityModelInput) -> None: # Batch events with configurable latency batch: list[TResponseStreamEvent] = [] - last_signal_time = asyncio.get_event_loop().time() batch_latency = input.get("batch_latency_seconds", 1.0) async def send_batch(): - nonlocal last_signal_time if batch: await handle.signal(input["signal"], batch) batch.clear() - last_signal_time = asyncio.get_event_loop().time() try: - while True: - # If latency has been passed, send the batch - if asyncio.get_event_loop().time() - last_signal_time >= batch_latency: - await send_batch() - try: - event = await asyncio.wait_for( - anext(events), timeout=asyncio.get_event_loop().time() - last_signal_time - ) - event.model_rebuild() + async def read_events(): + async for event in events: batch.append(event) - - # If the wait timed out, the latency has expired so send the batch - except asyncio.TimeoutError: + async def send_batches(): + while True: + await asyncio.sleep(batch_latency) await send_batch() - except StopAsyncIteration: + completed, pending = await asyncio.wait([read_events(), send_batches()], return_when=asyncio.FIRST_COMPLETED) + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + for task in completed: + await task + + except StopAsyncIteration as e: + traceback.print_exception(e.__class__, e, e.__traceback__) pass - # Send any remaining events in the batch if batch: await send_batch() diff --git a/temporalio/contrib/openai_agents/_model_parameters.py b/temporalio/contrib/openai_agents/_model_parameters.py index 3cab91c27..aed470653 100644 --- a/temporalio/contrib/openai_agents/_model_parameters.py +++ b/temporalio/contrib/openai_agents/_model_parameters.py @@ -69,3 +69,6 @@ class ModelActivityParameters: use_local_activity: bool = False """Whether to use a local activity. 
If changed during a workflow execution, that would break determinism.""" + + streaming_batch_latency_seconds: float = 1.0 + """Default batch latency for streaming events.""" \ No newline at end of file diff --git a/temporalio/contrib/openai_agents/_temporal_model_stub.py b/temporalio/contrib/openai_agents/_temporal_model_stub.py index 7116cc982..de2383e82 100644 --- a/temporalio/contrib/openai_agents/_temporal_model_stub.py +++ b/temporalio/contrib/openai_agents/_temporal_model_stub.py @@ -2,17 +2,14 @@ import asyncio import logging -from asyncio import FIRST_COMPLETED -from typing import Optional, Tuple from temporalio import workflow from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters -from temporalio.contrib.openai_agents._temporal_trace_provider import _workflow_uuid logger = logging.getLogger(__name__) from collections.abc import AsyncIterator -from typing import Any, Union, cast +from typing import Any from agents import ( Agent, @@ -169,7 +166,7 @@ async def handle_stream_event(events: list[TResponseStreamEvent]): conversation_id=conversation_id, prompt=prompt, signal=signal_name, - batch_latency_seconds=1.0, + batch_latency_seconds=self.model_params.streaming_batch_latency_seconds, ) handle = workflow.start_activity_method( @@ -199,7 +196,7 @@ async def generator() -> AsyncIterator[TResponseStreamEvent]: while True: item = await stream_queue.get() if item is None: - monitor_task.cancel() + await monitor_task return yield item diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 5e42882cf..3e664125b 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -182,6 +182,7 @@ def __init__( Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"] ] = (), register_activities: bool = True, + streaming_batch_latency_seconds: float = 1.0, ) -> None: """Initialize the OpenAI agents plugin. diff --git a/temporalio/contrib/openai_agents/testing.py b/temporalio/contrib/openai_agents/testing.py index 1bfdb1e78..9c3bce891 100644 --- a/temporalio/contrib/openai_agents/testing.py +++ b/temporalio/contrib/openai_agents/testing.py @@ -194,6 +194,19 @@ def response_completion(text: str) -> TResponseStreamEvent: type="response.completed", ) + @staticmethod + def ending(text: str) -> list[TResponseStreamEvent]: + """Create a list of TResponseStreamEvent for the end of a stream. + + .. warning:: + This API is experimental and may change in the future. + """ + return [ + EventBuilders.content_part_done(text), + EventBuilders.output_item_done(text), + EventBuilders.response_completion(text), + ] + class TestModelProvider(ModelProvider): """Test model provider which simply returns the given module. 
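As a usage sketch for the new helper: a complete synthetic stream can be assembled from the builders and handed to TestModel.streaming_events. All names come from this patch series; the delta text is made up for illustration:

```python
from temporalio.contrib.openai_agents.testing import EventBuilders, TestModel

# Two text deltas followed by the three terminal events that
# EventBuilders.ending produces for the concatenated content.
events = [
    EventBuilders.text_delta("Hi"),
    EventBuilders.text_delta(" there"),
    *EventBuilders.ending("Hi there"),  # content_part_done, output_item_done, response_completion
]
model = TestModel.streaming_events(events)
```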
@@ -315,11 +328,7 @@ def streaming_events_with_ending(
         for event in events:
             content += event.delta
 
-        return TestModel.streaming_events(
-            events
-            + [
-                EventBuilders.content_part_done(content),
-                EventBuilders.output_item_done(content),
-                EventBuilders.response_completion(content),
-            ]
-        )
+        return TestModel.streaming_events(events + EventBuilders.ending(content))
 
 
 class AgentEnvironment:
diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py
index 1a47c3cf6..0fa0eda3c 100644
--- a/tests/contrib/openai_agents/test_openai.py
+++ b/tests/contrib/openai_agents/test_openai.py
@@ -2641,6 +2641,9 @@ async def test_split_workers(client: Client):
 
 @workflow.defn
 class StreamingHelloWorldAgent:
+    def __init__(self):
+        self.events = []
+
     @workflow.run
     async def run(self, prompt: str) -> str | None:
         agent = Agent[None](
@@ -2648,17 +2651,19 @@ async def run(self, prompt: str) -> str | None:
             instructions="You are a helpful assistant.",
         )
 
-        result = None
-        for _ in range(2):
-            result = Runner.run_streamed(starting_agent=agent, input=prompt)
-            async for event in result.stream_events():
-                if event.type == "raw_response_event" and isinstance(
-                    event.data, ResponseTextDeltaEvent
-                ):
-                    print(event.data.delta, end="", flush=True)
+        result = Runner.run_streamed(starting_agent=agent, input=prompt)
+        async for event in result.stream_events():
+            if event.type == "raw_response_event" and isinstance(
+                event.data, ResponseTextDeltaEvent
+            ):
+                self.events.append(event.data.delta)
 
         return result.final_output if result else None
 
+    @workflow.query
+    def get_events(self) -> list[str]:
+        print("Querying events: ", self.events)
+        return self.events
 
 def streaming_hello_model():
     return TestModel.streaming_events_with_ending(
@@ -2684,10 +2689,63 @@ async def test_streaming(client: Client):
         ) as worker:
             handle = await client.start_workflow(
                 StreamingHelloWorldAgent.run,
-                "Say hello.",
+                args=["Say hello."],
                 id=f"hello-workflow-{uuid.uuid4()}",
                 task_queue=worker.task_queue,
                 execution_timeout=timedelta(seconds=50),
             )
             result = await handle.result()
             assert result == "Hello there!"
+            assert len(await handle.query(StreamingHelloWorldAgent.get_events)) == 3
+
+failed = False
+
+def streaming_failure_model():
+    async def generator():
+        try:
+            global failed
+            for event in [
+                EventBuilders.text_delta("Hello"),
+                EventBuilders.text_delta(" there"),
+                EventBuilders.text_delta("!"),
+            ]:
+                yield event
+                await asyncio.sleep(0.25)
+                if not failed:
+                    failed = True
+                    raise ValueError("Intentional failure")
+
+            for event in EventBuilders.ending("Hello there!"):
+                yield event
+        finally:
+            print("Leaving activity...")
+
+    return TestModel(None, streaming_fn=lambda: generator())
+
+
+async def test_streaming_failure(client: Client):
+    async with AgentEnvironment(
+        model=streaming_failure_model(),
+        model_params=ModelActivityParameters(
+            start_to_close_timeout=timedelta(seconds=30),
+            streaming_batch_latency_seconds=0.1,
+            retry_policy=RetryPolicy(
+                maximum_attempts=2
+            )
+        ),
+    ) as env:
+        client = env.applied_on_client(client)
+
+        async with new_worker(
+            client, StreamingHelloWorldAgent, max_cached_workflows=0
+        ) as worker:
+            handle = await client.start_workflow(
+                StreamingHelloWorldAgent.run,
+                args=["Say hello."],
+                id=f"hello-workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+                execution_timeout=timedelta(seconds=50),
+            )
+            result = await handle.result()
+            assert result == "Hello there!"
+            assert len(await handle.query(StreamingHelloWorldAgent.get_events)) == 6

From 65c2d99c8d6549af436d77fd28443fe97d782951 Mon Sep 17 00:00:00 2001
From: Tim Conley
Date: Thu, 18 Dec 2025 09:27:03 -0500
Subject: [PATCH 4/6] Add callback based option

---
 temporalio/contrib/openai_agents/__init__.py  |  8 +-
 .../openai_agents/_invoke_model_activity.py   | 69 ++++++++++---
 .../openai_agents/_model_parameters.py        | 20 +++-
 .../contrib/openai_agents/_openai_runner.py   | 33 +++++--
 .../openai_agents/_temporal_model_stub.py     | 92 +++++++++++-------
 .../openai_agents/_temporal_openai_agents.py  | 22 +++--
 temporalio/contrib/openai_agents/testing.py   | 14 ++-
 tests/contrib/openai_agents/test_openai.py    | 96 +++++++++++++------
 8 files changed, 258 insertions(+), 96 deletions(-)

diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py
index d49733e8b..b70bf19ed 100644
--- a/temporalio/contrib/openai_agents/__init__.py
+++ b/temporalio/contrib/openai_agents/__init__.py
@@ -12,7 +12,11 @@
     StatefulMCPServerProvider,
     StatelessMCPServerProvider,
 )
-from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
+from temporalio.contrib.openai_agents._model_parameters import (
+    ModelActivityParameters,
+    ModelSummaryProvider,
+    StreamingOptions,
+)
 from temporalio.contrib.openai_agents._temporal_openai_agents import (
     OpenAIAgentsPlugin,
     OpenAIPayloadConverter,
@@ -27,10 +31,12 @@
 __all__ = [
     "AgentsWorkflowError",
     "ModelActivityParameters",
+    "ModelSummaryProvider",
     "OpenAIAgentsPlugin",
     "OpenAIPayloadConverter",
     "StatelessMCPServerProvider",
     "StatefulMCPServerProvider",
+    "StreamingOptions",
     "testing",
     "workflow",
 ]
diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py
index 41c102433..1e35b5c75 100644
--- a/temporalio/contrib/openai_agents/_invoke_model_activity.py
+++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py
@@ -9,7 +9,7 @@
 import traceback
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Any, NoReturn, Optional, Union
+from typing import Any, Awaitable, Callable, NoReturn, Optional, Union
 
 from agents import (
     AgentOutputSchemaBase,
@@ -41,6 +41,7 @@
 from temporalio import activity, workflow
 from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater
+from temporalio.contrib.openai_agents._model_parameters import StreamingOptions
 from temporalio.exceptions import ApplicationError
@@ -151,11 +152,10 @@ class ActivityModelInput(TypedDict, total=False):
     prompt: Any | None
 
 
-class StreamActivityModelInput(ActivityModelInput):
+class ActivityModelInputWithSignal(ActivityModelInput):
     """Input for the stream_model activity."""
 
     signal: str
-    batch_latency_seconds: float
 
 
 class ModelActivity:
     """Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled. Disabling retries in your model of choice is recommended to allow activity retries to define the retry model.
     """
 
-    def __init__(self, model_provider: ModelProvider | None = None):
+    def __init__(
+        self, model_provider: ModelProvider | None, streaming_options: StreamingOptions
+    ):
         """Initialize the activity with a model provider."""
         self._model_provider = model_provider or OpenAIProvider(
             openai_client=AsyncOpenAI(max_retries=0)
         )
+        self._streaming_options = streaming_options
 
     @activity.defn
     @_auto_heartbeater
@@ -195,7 +198,7 @@ async def invoke_model_activity(self, input: ActivityModelInput) -> ModelRespons
             _handle_error(e)
 
     @activity.defn
-    async def stream_model(self, input: StreamActivityModelInput) -> None:
+    async def stream_model(self, input: ActivityModelInputWithSignal) -> None:
         """Activity that streams a model with the given input."""
         model = self._model_provider.get_model(input.get("model_name"))
 
@@ -221,22 +224,31 @@ async def stream_model(self, input: StreamActivityModelInput) -> None:
 
         # Batch events with configurable latency
         batch: list[TResponseStreamEvent] = []
-        batch_latency = input.get("batch_latency_seconds", 1.0)
+        batch_latency = self._streaming_options.signal_batch_latency_seconds
 
         async def send_batch():
             if batch:
                 await handle.signal(input["signal"], batch)
                 batch.clear()
 
+        async def read_events():
+            async for event in events:
+                event.model_rebuild()
+                batch.append(event)
+                if self._streaming_options.callback is not None:
+                    await self._streaming_options.callback(event)
+
+        async def send_batches():
+            while True:
+                await asyncio.sleep(batch_latency)
+                await send_batch()
+
         try:
-            async def read_events():
-                async for event in events:
-                    batch.append(event)
-            async def send_batches():
-                while True:
-                    await asyncio.sleep(batch_latency)
-                    await send_batch()
-            completed, pending = await asyncio.wait([read_events(), send_batches()], return_when=asyncio.FIRST_COMPLETED)
+            # asyncio.wait requires tasks, not bare coroutines, on Python 3.11+
+            read_task = asyncio.create_task(read_events())
+            send_task = asyncio.create_task(send_batches())
+            completed, pending = await asyncio.wait(
+                [read_task, send_task], return_when=asyncio.FIRST_COMPLETED
+            )
             for task in pending:
                 task.cancel()
                 try:
@@ -256,6 +266,37 @@ async def send_batches():
         except APIStatusError as e:
             _handle_error(e)
 
+    @activity.defn
+    async def batch_stream_model(
+        self, input: ActivityModelInput
+    ) -> list[TResponseStreamEvent]:
+        """Activity that streams a model with the given input, returning the collected events once the stream completes."""
+        model = self._model_provider.get_model(input.get("model_name"))
+
+        tools = _make_tools(input)
+        handoffs = _make_handoffs(input)
+
+        events = model.stream_response(
+            system_instructions=input.get("system_instructions"),
+            input=input["input"],
+            model_settings=input["model_settings"],
+            tools=tools,
+            output_schema=input.get("output_schema"),
+            handoffs=handoffs,
+            tracing=ModelTracing(input["tracing"]),
+            previous_response_id=input.get("previous_response_id"),
+            conversation_id=input.get("conversation_id"),
+            prompt=input.get("prompt"),
+        )
+        result = []
+        async for event in events:
+            event.model_rebuild()
+            result.append(event)
+            if self._streaming_options.callback is not None:
+                await self._streaming_options.callback(event)
+
+        return result
+
 
 async def _empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
     return ""
diff --git a/temporalio/contrib/openai_agents/_model_parameters.py b/temporalio/contrib/openai_agents/_model_parameters.py
index aed470653..199723c1a 100644
--- a/temporalio/contrib/openai_agents/_model_parameters.py
+++ b/temporalio/contrib/openai_agents/_model_parameters.py
@@ -4,9 +4,10 @@
 from collections.abc import Callable
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Any, Optional, Union
+from typing import Any,
Awaitable, Optional, Union from agents import Agent, TResponseInputItem +from agents.items import TResponseStreamEvent from temporalio.common import Priority, RetryPolicy from temporalio.workflow import ActivityCancellationType, VersioningIntent @@ -70,5 +71,18 @@ class ModelActivityParameters: use_local_activity: bool = False """Whether to use a local activity. If changed during a workflow execution, that would break determinism.""" - streaming_batch_latency_seconds: float = 1.0 - """Default batch latency for streaming events.""" \ No newline at end of file + +@dataclass +class StreamingOptions: + """Options applicable for use of run_streamed""" + + callback: Callable[[TResponseStreamEvent], Awaitable[None]] | None = None + """A callback function that will be invoked inside the activity on every stream event which occurs.""" + + use_signals: bool = False + """If true, the activity will use signals to provide events to the workflow as they occur. Ensure that the workflow + appropriately handles those signals during replay. If false, all the stream events will be delivered when the activity completes.""" + + signal_batch_latency_seconds: float = 1.0 + """Batch latency for sending signals. Lower values will result in lower stream event latency but higher + signal volume, and therefore cost.""" diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index b261306a6..7bac44ddb 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -18,7 +18,10 @@ from agents.run import DEFAULT_AGENT_RUNNER, AgentRunner from temporalio import workflow -from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._model_parameters import ( + ModelActivityParameters, + StreamingOptions, +) from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError @@ -26,6 +29,7 @@ # Recursively replace models in all agents def _convert_agent( model_params: ModelActivityParameters, + streaming_options: StreamingOptions, agent: Agent[Any], seen: dict[int, Agent] | None, ) -> Agent[Any]: @@ -49,13 +53,17 @@ def _convert_agent( new_handoffs: list[Agent | Handoff] = [] for handoff in agent.handoffs: if isinstance(handoff, Agent): - new_handoffs.append(_convert_agent(model_params, handoff, seen)) + new_handoffs.append( + _convert_agent(model_params, streaming_options, handoff, seen) + ) elif isinstance(handoff, Handoff): original_invoke = handoff.on_invoke_handoff async def on_invoke(context: RunContextWrapper[Any], args: str) -> Agent: handoff_agent = await original_invoke(context, args) - return _convert_agent(model_params, handoff_agent, seen) + return _convert_agent( + model_params, streaming_options, handoff_agent, seen + ) new_handoffs.append( dataclasses.replace(handoff, on_invoke_handoff=on_invoke) @@ -67,6 +75,7 @@ async def on_invoke(context: RunContextWrapper[Any], args: str) -> Agent: model_name=name, model_params=model_params, agent=agent, + streaming_options=streaming_options, ) new_agent.handoffs = new_handoffs return new_agent @@ -79,10 +88,13 @@ class TemporalOpenAIRunner(AgentRunner): """ - def __init__(self, model_params: ModelActivityParameters) -> None: + def __init__( + self, model_params: ModelActivityParameters, streaming_options: StreamingOptions + ) -> None: """Initialize the Temporal OpenAI Runner.""" self._runner = 
DEFAULT_AGENT_RUNNER or AgentRunner() self.model_params = model_params + self.streaming_options = streaming_options async def run( self, @@ -104,7 +116,9 @@ async def run( try: return await self._runner.run( - starting_agent=_convert_agent(self.model_params, starting_agent, None), + starting_agent=_convert_agent( + self.model_params, self.streaming_options, starting_agent, None + ), input=input, **kwargs, ) @@ -155,7 +169,9 @@ def run_streamed( try: return self._runner.run_streamed( - starting_agent=_convert_agent(self.model_params, starting_agent, None), + starting_agent=_convert_agent( + self.model_params, self.streaming_options, starting_agent, None + ), input=input, **kwargs, ) @@ -183,7 +199,10 @@ def _process_run_config(self, run_config: RunConfig | None) -> RunConfig: run_config = dataclasses.replace( run_config, model=_TemporalModelStub( - run_config.model, model_params=self.model_params, agent=None + run_config.model, + model_params=self.model_params, + streaming_options=self.streaming_options, + agent=None, ), ) return run_config diff --git a/temporalio/contrib/openai_agents/_temporal_model_stub.py b/temporalio/contrib/openai_agents/_temporal_model_stub.py index de2383e82..44887db15 100644 --- a/temporalio/contrib/openai_agents/_temporal_model_stub.py +++ b/temporalio/contrib/openai_agents/_temporal_model_stub.py @@ -4,7 +4,10 @@ import logging from temporalio import workflow -from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._model_parameters import ( + ModelActivityParameters, + StreamingOptions, +) logger = logging.getLogger(__name__) @@ -34,13 +37,13 @@ from temporalio.contrib.openai_agents._invoke_model_activity import ( ActivityModelInput, + ActivityModelInputWithSignal, AgentOutputSchemaInput, FunctionToolInput, HandoffInput, HostedMCPToolInput, ModelActivity, ModelTracingInput, - StreamActivityModelInput, ToolInput, ) @@ -54,11 +57,13 @@ def __init__( *, model_params: ModelActivityParameters, agent: Agent[Any] | None, + streaming_options: StreamingOptions, ) -> None: self.model_name = model_name self.model_params = model_params self.agent = agent self.stream_events: list[TResponseStreamEvent] = [] + self.streaming_options = streaming_options async def get_response( self, @@ -153,7 +158,7 @@ async def handle_stream_event(events: list[TResponseStreamEvent]): signal_name = "model_stream_signal" workflow.set_signal_handler(signal_name, handle_stream_event) - activity_input = StreamActivityModelInput( + activity_input = ActivityModelInput( model_name=self.model_name, system_instructions=system_instructions, input=input, @@ -165,42 +170,63 @@ async def handle_stream_event(events: list[TResponseStreamEvent]): previous_response_id=previous_response_id, conversation_id=conversation_id, prompt=prompt, - signal=signal_name, - batch_latency_seconds=self.model_params.streaming_batch_latency_seconds, ) + if self.streaming_options.use_signals: + handle = workflow.start_activity_method( + ModelActivity.stream_model, + args=[ + ActivityModelInputWithSignal(**activity_input, signal=signal_name) + ], + summary=summary, + task_queue=self.model_params.task_queue, + schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, + schedule_to_start_timeout=self.model_params.schedule_to_start_timeout, + start_to_close_timeout=self.model_params.start_to_close_timeout, + heartbeat_timeout=self.model_params.heartbeat_timeout, + retry_policy=self.model_params.retry_policy, + 
cancellation_type=self.model_params.cancellation_type, + versioning_intent=self.model_params.versioning_intent, + priority=self.model_params.priority, + ) - handle = workflow.start_activity_method( - ModelActivity.stream_model, - args=[activity_input], - summary=summary, - task_queue=self.model_params.task_queue, - schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, - schedule_to_start_timeout=self.model_params.schedule_to_start_timeout, - start_to_close_timeout=self.model_params.start_to_close_timeout, - heartbeat_timeout=self.model_params.heartbeat_timeout, - retry_policy=self.model_params.retry_policy, - cancellation_type=self.model_params.cancellation_type, - versioning_intent=self.model_params.versioning_intent, - priority=self.model_params.priority, - ) + async def monitor_activity(): + try: + await handle + finally: + await stream_queue.put(None) # Signal end of stream + + monitor_task = asyncio.create_task(monitor_activity()) - async def monitor_activity(): - try: - await handle - finally: - await stream_queue.put(None) # Signal end of stream + async def generator() -> AsyncIterator[TResponseStreamEvent]: + while True: + item = await stream_queue.get() + if item is None: + await monitor_task + return + yield item - monitor_task = asyncio.create_task(monitor_activity()) + return generator() + else: - async def generator() -> AsyncIterator[TResponseStreamEvent]: - while True: - item = await stream_queue.get() - if item is None: - await monitor_task - return - yield item + async def generator() -> AsyncIterator[TResponseStreamEvent]: + results = await workflow.execute_activity_method( + ModelActivity.batch_stream_model, + args=[activity_input], + summary=summary, + task_queue=self.model_params.task_queue, + schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, + schedule_to_start_timeout=self.model_params.schedule_to_start_timeout, + start_to_close_timeout=self.model_params.start_to_close_timeout, + heartbeat_timeout=self.model_params.heartbeat_timeout, + retry_policy=self.model_params.retry_policy, + cancellation_type=self.model_params.cancellation_type, + versioning_intent=self.model_params.versioning_intent, + priority=self.model_params.priority, + ) + for event in results: + yield event - return generator() + return generator() def _make_summary( self, system_instructions: str | None, input: str | list[TResponseInputItem] diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 3e664125b..6fb3d8f3b 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Callable, Sequence from contextlib import asynccontextmanager, contextmanager from datetime import timedelta -from typing import Optional, Union +from typing import Union from agents import ModelProvider, set_trace_provider from agents.run import get_default_agent_runner, set_default_agent_runner @@ -13,7 +13,10 @@ from agents.tracing.provider import DefaultTraceProvider from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity -from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._model_parameters import ( + ModelActivityParameters, + StreamingOptions, +) from temporalio.contrib.openai_agents._openai_runner import ( TemporalOpenAIRunner, ) @@ -47,6 +50,7 @@ def 
set_open_ai_agent_temporal_overrides(
     model_params: ModelActivityParameters,
     auto_close_tracing_in_workflows: bool = False,
+    streaming_options: StreamingOptions = StreamingOptions(),
 ):
     """Configure Temporal-specific overrides for OpenAI agents.
 
@@ -67,6 +71,7 @@ def set_open_ai_agent_temporal_overrides(
     Args:
         model_params: Configuration parameters for Temporal activity execution of model calls.
         auto_close_tracing_in_workflows: If set to true, close tracing spans immediately.
+        streaming_options: Options that apply when streaming model responses with run_streamed.
 
     Returns:
         A context manager that yields the configured TemporalTraceProvider.
@@ -78,7 +83,7 @@
     )
 
     try:
-        set_default_agent_runner(TemporalOpenAIRunner(model_params))
+        set_default_agent_runner(TemporalOpenAIRunner(model_params, streaming_options))
         set_trace_provider(provider)
         yield provider
     finally:
@@ -136,6 +141,7 @@ class OpenAIAgentsPlugin(SimplePlugin):
             The plugin will wrap each server in a TemporalMCPServer if needed and
             manage their connection lifecycles tied to the worker lifetime. This is
             the recommended way to use MCP servers with Temporal workflows.
+        streaming_options: Options that apply when streaming model responses with run_streamed.
 
     Example:
         >>> from temporalio.client import Client
@@ -182,7 +188,7 @@ def __init__(
             Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"]
         ] = (),
         register_activities: bool = True,
-        streaming_batch_latency_seconds: float = 1.0,
+        streaming_options: StreamingOptions = StreamingOptions(),
     ) -> None:
         """Initialize the OpenAI agents plugin.
 
@@ -198,6 +204,7 @@ def __init__(
             register_activities: Whether to register activities during the worker
                 execution. This can be disabled on some workers to allow a separation
                 of workflows and activities but should not be disabled on all workers,
                 or agents will not be able to progress.
+            streaming_options: Options that apply when streaming model responses with run_streamed.
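+
+        A minimal illustrative sketch of enabling streaming (all other
+        arguments left at their defaults; see StreamingOptions for the
+        available fields)::
+
+            plugin = OpenAIAgentsPlugin(
+                streaming_options=StreamingOptions(use_signals=True),
+            )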
""" if model_params is None: model_params = ModelActivityParameters() @@ -222,10 +229,11 @@ def add_activities( if not register_activities: return activities or [] - model_activity = ModelActivity(model_provider) + model_activity = ModelActivity(model_provider, streaming_options) new_activities = [ model_activity.invoke_model_activity, model_activity.stream_model, + model_activity.batch_stream_model, ] server_names = [server.name for server in mcp_server_providers] @@ -252,7 +260,9 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: @asynccontextmanager async def run_context() -> AsyncIterator[None]: - with set_open_ai_agent_temporal_overrides(model_params): + with set_open_ai_agent_temporal_overrides( + model_params, streaming_options=streaming_options + ): yield super().__init__( diff --git a/temporalio/contrib/openai_agents/testing.py b/temporalio/contrib/openai_agents/testing.py index 9c3bce891..03ac23d58 100644 --- a/temporalio/contrib/openai_agents/testing.py +++ b/temporalio/contrib/openai_agents/testing.py @@ -32,7 +32,10 @@ StatefulMCPServerProvider, StatelessMCPServerProvider, ) -from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._model_parameters import ( + ModelActivityParameters, + StreamingOptions, +) from temporalio.contrib.openai_agents._temporal_openai_agents import OpenAIAgentsPlugin __all__ = [ @@ -326,10 +329,7 @@ def streaming_events_with_ending( for event in events: content += event.delta - return TestModel.streaming_events( - events - + EventBuilders.ending(content) - ) + return TestModel.streaming_events(events + EventBuilders.ending(content)) class AgentEnvironment: @@ -367,6 +367,7 @@ def __init__( StatelessMCPServerProvider | StatefulMCPServerProvider ] = (), register_activities: bool = True, + streaming_options: StreamingOptions = StreamingOptions(), ) -> None: """Initialize the AgentEnvironment. @@ -383,6 +384,7 @@ def __init__( If both are provided, model_provider will be used. mcp_server_providers: Sequence of MCP servers to automatically register with the worker. register_activities: Whether to register activities during worker execution. + streaming_options: Options applicable for use of run_streamed. .. warning:: This API is experimental and may change in the future. 
@@ -396,6 +398,7 @@ def __init__( self._mcp_server_providers = mcp_server_providers self._register_activities = register_activities self._plugin: OpenAIAgentsPlugin | None = None + self.streaming_options = streaming_options async def __aenter__(self) -> "AgentEnvironment": """Enter the async context manager.""" @@ -405,6 +408,7 @@ async def __aenter__(self) -> "AgentEnvironment": model_provider=self._model_provider, mcp_server_providers=self._mcp_server_providers, register_activities=self._register_activities, + streaming_options=self.streaming_options, ) return self diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 0fa0eda3c..53b0089fd 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -8,6 +8,7 @@ from datetime import timedelta from typing import ( Any, + Awaitable, Optional, Union, cast, @@ -91,10 +92,11 @@ from temporalio.contrib import openai_agents from temporalio.contrib.openai_agents import ( ModelActivityParameters, + ModelSummaryProvider, StatefulMCPServerProvider, StatelessMCPServerProvider, + StreamingOptions, ) -from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider from temporalio.contrib.openai_agents._openai_runner import _convert_agent from temporalio.contrib.openai_agents._temporal_model_stub import ( _extract_summary, @@ -2553,7 +2555,9 @@ def override_get_activities() -> Sequence[Callable]: async def test_model_conversion_loops(): agent = init_agents() - converted = _convert_agent(ModelActivityParameters(), agent, None) + converted = _convert_agent( + ModelActivityParameters(), StreamingOptions(), agent, None + ) seat_booking_handoff = converted.handoffs[1] assert isinstance(seat_booking_handoff, Handoff) context: RunContextWrapper[AirlineAgentContext] = RunContextWrapper( @@ -2665,6 +2669,7 @@ def get_events(self) -> list[str]: print("Querying events: ", self.events) return self.events + def streaming_hello_model(): return TestModel.streaming_events_with_ending( [ @@ -2675,12 +2680,15 @@ def streaming_hello_model(): ) -async def test_streaming(client: Client): +async def test_signal_streaming(client: Client): async with AgentEnvironment( model=streaming_hello_model(), model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30), ), + streaming_options=StreamingOptions( + use_signals=True, + ), ) as env: client = env.applied_on_client(client) @@ -2698,40 +2706,39 @@ async def test_streaming(client: Client): assert result == "Hello there!" 
assert len(await handle.query(StreamingHelloWorldAgent.get_events)) == 3 + failed = False + def streaming_failure_model(): - async def generator(): - try: - global failed - for event in [ - EventBuilders.text_delta("Hello"), - EventBuilders.text_delta(" there"), - EventBuilders.text_delta("!"), - ]: - yield event - await asyncio.sleep(0.25) - if not failed: - failed = True - raise ValueError("Intentional failure") - - for event in EventBuilders.ending("Hello there!"): - yield event - finally: - print("Leaving activity...") + async def generator() -> AsyncIterator[TResponseStreamEvent]: + global failed + for event in [ + EventBuilders.text_delta("Hello"), + EventBuilders.text_delta(" there"), + EventBuilders.text_delta("!"), + ]: + yield event + await asyncio.sleep(0.25) + if not failed: + failed = True + raise ValueError("Intentional failure") + + for end_event in EventBuilders.ending("Hello there!"): + yield end_event return TestModel(None, streaming_fn=lambda: generator()) -async def test_streaming_failure(client: Client): +async def test_signal_streaming_failure(client: Client): async with AgentEnvironment( - # model=streaming_failure_model(), + model=streaming_failure_model(), model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30), - streaming_batch_latency_seconds=0.1, - retry_policy=RetryPolicy( - maximum_attempts=1 - ) + ), + streaming_options=StreamingOptions( + use_signals=True, + signal_batch_latency_seconds=0.1, ), ) as env: client = env.applied_on_client(client) @@ -2749,3 +2756,38 @@ async def test_streaming_failure(client: Client): result = await handle.result() assert result == "Hello there!" assert len(await handle.query(StreamingHelloWorldAgent.get_events)) == 6 + + +async def test_callback_streaming(client: Client): + events = [] + + async def callback(event: TResponseStreamEvent) -> None: + events.append(event) + + async with AgentEnvironment( + model=streaming_hello_model(), + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30), + ), + streaming_options=StreamingOptions( + callback=callback, + ), + ) as env: + client = env.applied_on_client(client) + + async with new_worker( + client, StreamingHelloWorldAgent, max_cached_workflows=0 + ) as worker: + handle = await client.start_workflow( + StreamingHelloWorldAgent.run, + args=["Say hello."], + id=f"hello-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=50), + ) + result = await handle.result() + assert result == "Hello there!" 
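+        # The workflow-side query sees only the three text deltas; the raw
+        # callback capture, asserted next, also receives the ending events.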
+        assert len(await handle.query(StreamingHelloWorldAgent.get_events)) == 3
+
+        # Unlike the workflow's stream, the raw callback capture is not
+        # filtered, so the ending marker events are included as well
+        assert len(events) == 6

From 2347cf20ad26eed87686078ac49cbba679a824ae Mon Sep 17 00:00:00 2001
From: Tim Conley
Date: Thu, 18 Dec 2025 09:36:02 -0500
Subject: [PATCH 5/6] Create tasks

---
 .../openai_agents/_invoke_model_activity.py   | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py
index 1e35b5c75..df562592f 100644
--- a/temporalio/contrib/openai_agents/_invoke_model_activity.py
+++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py
@@ -231,6 +231,11 @@ async def send_batch():
                 await handle.signal(input["signal"], batch)
                 batch.clear()
 
+        async def send_batches():
+            while True:
+                await asyncio.sleep(batch_latency)
+                await send_batch()
+
         async def read_events():
             async for event in events:
                 event.model_rebuild()
@@ -238,14 +243,13 @@ async def read_events():
             if self._streaming_options.callback is not None:
                 await self._streaming_options.callback(event)
 
-        async def send_batches():
-            while True:
-                await asyncio.sleep(batch_latency)
-                await send_batch()
-
         try:
             completed, pending = await asyncio.wait(
-                [read_events(), send_batches()], return_when=asyncio.FIRST_COMPLETED
+                [
+                    asyncio.create_task(read_events()),
+                    asyncio.create_task(send_batches()),
+                ],
+                return_when=asyncio.FIRST_COMPLETED,
             )
             for task in pending:
                 task.cancel()

From c14480b154f16da9fcda009ceef3dfa12b16a3f8 Mon Sep 17 00:00:00 2001
From: Tim Conley
Date: Thu, 18 Dec 2025 09:52:45 -0500
Subject: [PATCH 6/6] Add additional context

---
 .../contrib/openai_agents/_invoke_model_activity.py   | 6 ++++--
 temporalio/contrib/openai_agents/_model_parameters.py | 9 ++++++---
 tests/contrib/openai_agents/test_openai.py            | 2 +-
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py
index df562592f..704c12df9 100644
--- a/temporalio/contrib/openai_agents/_invoke_model_activity.py
+++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py
@@ -241,7 +241,9 @@ async def read_events():
             event.model_rebuild()
             batch.append(event)
             if self._streaming_options.callback is not None:
-                await self._streaming_options.callback(event)
+                await self._streaming_options.callback(
+                    input["model_settings"], event
+                )
 
         try:
             completed, pending = await asyncio.wait(
@@ -297,7 +299,7 @@ async def batch_stream_model(
             event.model_rebuild()
             result.append(event)
             if self._streaming_options.callback is not None:
-                await self._streaming_options.callback(event)
+                await self._streaming_options.callback(input["model_settings"], event)
 
         return result
 
diff --git a/temporalio/contrib/openai_agents/_model_parameters.py b/temporalio/contrib/openai_agents/_model_parameters.py
index 199723c1a..376f4d2db 100644
--- a/temporalio/contrib/openai_agents/_model_parameters.py
+++ b/temporalio/contrib/openai_agents/_model_parameters.py
@@ -6,7 +6,7 @@
 from datetime import timedelta
 from typing import Any, Awaitable, Optional, Union
 
-from agents import Agent, TResponseInputItem
+from agents import Agent, ModelSettings, TResponseInputItem
 from agents.items import TResponseStreamEvent
 
 from temporalio.common import Priority, RetryPolicy
@@ -76,8 +76,11 @@ class StreamingOptions:
     """Options applicable for use of run_streamed"""
run_streamed""" - callback: Callable[[TResponseStreamEvent], Awaitable[None]] | None = None - """A callback function that will be invoked inside the activity on every stream event which occurs.""" + callback: ( + Callable[[ModelSettings, TResponseStreamEvent], Awaitable[None]] | None + ) = None + """A callback function that will be invoked inside the activity on every stream event which occurs. + ModelSettings are provided so that the callback can distinguish what to do based on extra_args if desired.""" use_signals: bool = False """If true, the activity will use signals to provide events to the workflow as they occur. Ensure that the workflow diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 53b0089fd..e68525830 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -2761,7 +2761,7 @@ async def test_signal_streaming_failure(client: Client): async def test_callback_streaming(client: Client): events = [] - async def callback(event: TResponseStreamEvent) -> None: + async def callback(_: ModelSettings, event: TResponseStreamEvent) -> None: events.append(event) async with AgentEnvironment(