diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py
index f03458c32..612e758a6 100644
--- a/temporalio/contrib/openai_agents/_invoke_model_activity.py
+++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py
@@ -3,11 +3,12 @@
 Implements mapping of OpenAI datastructures to Pydantic friendly types.
 """
 
+import asyncio
 import enum
 import json
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Any, Optional, Union
+from typing import Any, NoReturn, Optional, Union
 
 from agents import (
     AgentOutputSchemaBase,
@@ -28,6 +29,7 @@
     UserError,
     WebSearchTool,
 )
+from agents.items import TResponseStreamEvent
 from openai import (
     APIStatusError,
     AsyncOpenAI,
@@ -148,6 +150,13 @@ class ActivityModelInput(TypedDict, total=False):
     prompt: Any | None
 
 
+class StreamActivityModelInput(ActivityModelInput):
+    """Input for the stream_model activity."""
+
+    signal: str
+    batch_latency_seconds: float
+
+
 class ModelActivity:
     """Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled. Disabling retries in your model of choice is recommended to allow activity retries to define the retry model.
@@ -165,52 +174,8 @@ async def invoke_model_activity(self, input: ActivityModelInput) -> ModelRespons
         """Activity that invokes a model with the given input."""
         model = self._model_provider.get_model(input.get("model_name"))
 
-        async def empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
-            return ""
-
-        async def empty_on_invoke_handoff(
-            ctx: RunContextWrapper[Any], input: str
-        ) -> Any:
-            return None
-
-        def make_tool(tool: ToolInput) -> Tool:
-            if isinstance(
-                tool,
-                (
-                    FileSearchTool,
-                    WebSearchTool,
-                    ImageGenerationTool,
-                    CodeInterpreterTool,
-                ),
-            ):
-                return tool
-            elif isinstance(tool, HostedMCPToolInput):
-                return HostedMCPTool(
-                    tool_config=tool.tool_config,
-                )
-            elif isinstance(tool, FunctionToolInput):
-                return FunctionTool(
-                    name=tool.name,
-                    description=tool.description,
-                    params_json_schema=tool.params_json_schema,
-                    on_invoke_tool=empty_on_invoke_tool,
-                    strict_json_schema=tool.strict_json_schema,
-                )
-            else:
-                raise UserError(f"Unknown tool type: {tool.name}")
-
-        tools = [make_tool(x) for x in input.get("tools", [])]
-        handoffs: list[Handoff[Any, Any]] = [
-            Handoff(
-                tool_name=x.tool_name,
-                tool_description=x.tool_description,
-                input_json_schema=x.input_json_schema,
-                agent_name=x.agent_name,
-                strict_json_schema=x.strict_json_schema,
-                on_invoke_handoff=empty_on_invoke_handoff,
-            )
-            for x in input.get("handoffs", [])
-        ]
+        tools = _make_tools(input)
+        handoffs = _make_handoffs(input)
 
         try:
             return await model.get_response(
@@ -226,40 +191,156 @@
                 prompt=input.get("prompt"),
             )
         except APIStatusError as e:
-            # Listen to server hints
-            retry_after = None
-            retry_after_ms_header = e.response.headers.get("retry-after-ms")
-            if retry_after_ms_header is not None:
-                retry_after = timedelta(milliseconds=float(retry_after_ms_header))
-
-            if retry_after is None:
-                retry_after_header = e.response.headers.get("retry-after")
-                if retry_after_header is not None:
-                    retry_after = timedelta(seconds=float(retry_after_header))
-
-            should_retry_header = e.response.headers.get("x-should-retry")
-            if should_retry_header == "true":
-                raise e
-            if should_retry_header == "false":
-                raise ApplicationError(
-                    "Non retryable OpenAI error",
-                    non_retryable=True,
-                    next_retry_delay=retry_after,
-                ) from e
-
-            # Specifically retryable status codes
-            if (
-                e.response.status_code in [408, 409, 429]
-                or e.response.status_code >= 500
-            ):
-                raise ApplicationError(
-                    f"Retryable OpenAI status code: {e.response.status_code}",
-                    non_retryable=False,
-                    next_retry_delay=retry_after,
-                ) from e
-
-            raise ApplicationError(
-                f"Non retryable OpenAI status code: {e.response.status_code}",
-                non_retryable=True,
-                next_retry_delay=retry_after,
-            ) from e
+            _handle_error(e)
+
+    @activity.defn
+    async def stream_model(self, input: StreamActivityModelInput) -> None:
+        """Activity that streams a model with the given input."""
+        model = self._model_provider.get_model(input.get("model_name"))
+
+        tools = _make_tools(input)
+        handoffs = _make_handoffs(input)
+
+        try:
+            handle = activity.client().get_workflow_handle(
+                workflow_id=activity.info().workflow_id
+            )
+            events = model.stream_response(
+                system_instructions=input.get("system_instructions"),
+                input=input["input"],
+                model_settings=input["model_settings"],
+                tools=tools,
+                output_schema=input.get("output_schema"),
+                handoffs=handoffs,
+                tracing=ModelTracing(input["tracing"]),
+                previous_response_id=input.get("previous_response_id"),
+                conversation_id=input.get("conversation_id"),
+                prompt=input.get("prompt"),
+            )
+
+            # Batch events with configurable latency
+            batch: list[TResponseStreamEvent] = []
+            last_signal_time = asyncio.get_event_loop().time()
+            batch_latency = input.get("batch_latency_seconds", 1.0)
+
+            async def send_batch():
+                nonlocal last_signal_time
+                if batch:
+                    await handle.signal(input["signal"], batch)
+                    batch.clear()
+                last_signal_time = asyncio.get_event_loop().time()
+
+            try:
+                while True:
+                    # If the batch window has elapsed, send the batch
+                    if (
+                        asyncio.get_event_loop().time() - last_signal_time
+                        >= batch_latency
+                    ):
+                        await send_batch()
+                    try:
+                        # Wait for the next event, but no longer than the time
+                        # remaining in the current batch window
+                        event = await asyncio.wait_for(
+                            anext(events),
+                            timeout=batch_latency
+                            - (asyncio.get_event_loop().time() - last_signal_time),
+                        )
+                        event.model_rebuild()
+                        batch.append(event)
+
+                    # If the wait timed out, the window has expired, so send the batch
+                    except asyncio.TimeoutError:
+                        await send_batch()
+            except StopAsyncIteration:
+                pass
+
+            # Send any remaining events in the batch
+            if batch:
+                await send_batch()
+
+        except APIStatusError as e:
+            _handle_error(e)
+
+
+async def _empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
+    return ""
+
+
+async def _empty_on_invoke_handoff(ctx: RunContextWrapper[Any], input: str) -> Any:
+    return None
+
+
+def _make_tool(tool: ToolInput) -> Tool:
+    if isinstance(
+        tool,
+        (
+            FileSearchTool,
+            WebSearchTool,
+            ImageGenerationTool,
+            CodeInterpreterTool,
+        ),
+    ):
+        return tool
+    elif isinstance(tool, HostedMCPToolInput):
+        return HostedMCPTool(
+            tool_config=tool.tool_config,
+        )
+    elif isinstance(tool, FunctionToolInput):
+        return FunctionTool(
+            name=tool.name,
+            description=tool.description,
+            params_json_schema=tool.params_json_schema,
+            on_invoke_tool=_empty_on_invoke_tool,
+            strict_json_schema=tool.strict_json_schema,
+        )
+    else:
+        raise UserError(f"Unknown tool type: {tool.name}")
+
+
+def _make_tools(input: ActivityModelInput) -> list[Tool]:
+    return [_make_tool(x) for x in input.get("tools", [])]
+
+
+def _make_handoffs(input: ActivityModelInput) -> list[Handoff[Any, Any]]:
+    return [
+        Handoff(
+            tool_name=x.tool_name,
+            tool_description=x.tool_description,
+            input_json_schema=x.input_json_schema,
+            agent_name=x.agent_name,
+            strict_json_schema=x.strict_json_schema,
+            on_invoke_handoff=_empty_on_invoke_handoff,
+        )
+        for x in input.get("handoffs", [])
+    ]
+
+
+def _handle_error(e: APIStatusError) -> NoReturn:
+    # Listen to server hints
+    retry_after = None
+    retry_after_ms_header = e.response.headers.get("retry-after-ms")
+    if retry_after_ms_header is not None:
+        retry_after = timedelta(milliseconds=float(retry_after_ms_header))
+
+    if retry_after is None:
+        retry_after_header = e.response.headers.get("retry-after")
+        if retry_after_header is not None:
+            retry_after = timedelta(seconds=float(retry_after_header))
+
+    should_retry_header = e.response.headers.get("x-should-retry")
+    if should_retry_header == "true":
+        raise e
+    if should_retry_header == "false":
+        raise ApplicationError(
+            "Non retryable OpenAI error",
+            non_retryable=True,
+            next_retry_delay=retry_after,
+        ) from e
+
+    # Specifically retryable status codes
+    if e.response.status_code in [408, 409, 429] or e.response.status_code >= 500:
+        raise ApplicationError(
+            f"Retryable OpenAI status code: {e.response.status_code}",
+            non_retryable=False,
+            next_retry_delay=retry_after,
+        ) from e
+
+    raise ApplicationError(
+        f"Non retryable OpenAI status code: {e.response.status_code}",
+        non_retryable=True,
+        next_retry_delay=retry_after,
+    ) from e
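The batching loop in stream_model above is the heart of the activity: it forwards stream events to the workflow as signals, flushing at most once per batch window so a long response does not become one signal per token. The same time-window batching pattern is shown in isolation below, as a minimal runnable sketch (batch_with_latency and the toy event source are illustrative names, not part of the SDK). The sketch keeps the pending read alive across windows with asyncio.wait, so a window expiring does not cancel the underlying generator.

import asyncio
from typing import AsyncIterator, Awaitable, Callable


async def batch_with_latency(
    events: AsyncIterator[str],
    send: Callable[[list[str]], Awaitable[None]],
    latency: float = 0.5,
) -> None:
    # Accumulate events and flush at most once per `latency` seconds.
    batch: list[str] = []
    deadline = asyncio.get_event_loop().time() + latency
    pending = None
    while True:
        if pending is None:
            pending = asyncio.ensure_future(anext(events))
        remaining = deadline - asyncio.get_event_loop().time()
        done, _ = await asyncio.wait({pending}, timeout=max(remaining, 0))
        if done:
            try:
                batch.append(pending.result())
            except StopAsyncIteration:
                break  # Source exhausted; flush whatever is left below.
            pending = None
        else:
            # Window expired with the next event still pending: flush.
            if batch:
                await send(batch.copy())
                batch.clear()
            deadline = asyncio.get_event_loop().time() + latency
    if batch:
        await send(batch)


async def main() -> None:
    async def gen() -> AsyncIterator[str]:
        for token in ["Hello", " ", "world"]:
            await asyncio.sleep(0.3)
            yield token

    async def send(batch: list[str]) -> None:
        print("flush:", batch)

    await batch_with_latency(gen(), send, latency=0.5)


asyncio.run(main())

Running it prints two flushes: the token that arrived within the first 0.5 s window, then the remainder.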
diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py
index a8065d207..b261306a6 100644
--- a/temporalio/contrib/openai_agents/_openai_runner.py
+++ b/temporalio/contrib/openai_agents/_openai_runner.py
@@ -15,7 +15,7 @@
     Tool,
     TResponseInputItem,
 )
-from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner
+from agents.run import DEFAULT_AGENT_RUNNER, AgentRunner
 
 from temporalio import workflow
 from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
@@ -98,66 +98,15 @@ async def run(
             **kwargs,
         )
 
-        tool_types = typing.get_args(Tool)
-        for t in starting_agent.tools:
-            if not isinstance(t, tool_types):
-                raise ValueError(
-                    "Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool."
-                )
+        _check_preconditions(starting_agent, **kwargs)
 
-        if starting_agent.mcp_servers:
-            from temporalio.contrib.openai_agents._mcp import (
-                _StatefulMCPServerReference,
-                _StatelessMCPServerReference,
-            )
-
-            for s in starting_agent.mcp_servers:
-                if not isinstance(
-                    s,
-                    (
-                        _StatelessMCPServerReference,
-                        _StatefulMCPServerReference,
-                    ),
-                ):
-                    raise ValueError(
-                        f"Unknown mcp_server type {type(s)} may not work durably."
-                    )
-
-        context = kwargs.get("context")
-        max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
-        hooks = kwargs.get("hooks")
-        run_config = kwargs.get("run_config")
-        previous_response_id = kwargs.get("previous_response_id")
-        session = kwargs.get("session")
-
-        if isinstance(session, SQLiteSession):
-            raise ValueError("Temporal workflows don't support SQLite sessions.")
-
-        if run_config is None:
-            run_config = RunConfig()
-
-        if run_config.model:
-            if not isinstance(run_config.model, str):
-                raise ValueError(
-                    "Temporal workflows require a model name to be a string in the run config."
-                )
-            run_config = dataclasses.replace(
-                run_config,
-                model=_TemporalModelStub(
-                    run_config.model, model_params=self.model_params, agent=None
-                ),
-            )
+        kwargs["run_config"] = self._process_run_config(kwargs.get("run_config"))
 
         try:
             return await self._runner.run(
                 starting_agent=_convert_agent(self.model_params, starting_agent, None),
                 input=input,
-                context=context,
-                max_turns=max_turns,
-                hooks=hooks,
-                run_config=run_config,
-                previous_response_id=previous_response_id,
-                session=session,
+                **kwargs,
             )
         except AgentsException as e:
             # In order for workflow failures to properly fail the workflow, we need to rewrap them in
@@ -199,7 +148,45 @@ def run_streamed(
             input,
             **kwargs,
         )
-        raise RuntimeError("Temporal workflows do not support streaming.")
+
+        _check_preconditions(starting_agent, **kwargs)
+
+        kwargs["run_config"] = self._process_run_config(kwargs.get("run_config"))
+
+        try:
+            return self._runner.run_streamed(
+                starting_agent=_convert_agent(self.model_params, starting_agent, None),
+                input=input,
+                **kwargs,
+            )
+        except AgentsException as e:
+            # In order for workflow failures to properly fail the workflow, we need to rewrap them in
+            # a Temporal error
+            if e.__cause__ and workflow.is_failure_exception(e.__cause__):
+                reraise = AgentsWorkflowError(
+                    f"Workflow failure exception in Agents Framework: {e}"
+                )
+                reraise.__traceback__ = e.__traceback__
+                raise reraise from e.__cause__
+            else:
+                raise e
+
+    def _process_run_config(self, run_config: RunConfig | None) -> RunConfig:
+        if run_config is None:
+            run_config = RunConfig()
+
+        if run_config.model:
+            if not isinstance(run_config.model, str):
+                raise ValueError(
+                    "Temporal workflows require a model name to be a string in the run config."
+                )
+            run_config = dataclasses.replace(
+                run_config,
+                model=_TemporalModelStub(
+                    run_config.model, model_params=self.model_params, agent=None
+                ),
+            )
+        return run_config
 
 
 def _model_name(agent: Agent[Any]) -> str | None:
@@ -209,3 +196,34 @@ def _model_name(agent: Agent[Any]) -> str | None:
             "Temporal workflows require a model name to be a string in the agent."
         )
     return name
+
+
+def _check_preconditions(starting_agent: Agent[TContext], **kwargs: Any) -> None:
+    tool_types = typing.get_args(Tool)
+    for t in starting_agent.tools:
+        if not isinstance(t, tool_types):
+            raise ValueError(
+                "Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool."
+            )
+
+    if starting_agent.mcp_servers:
+        from temporalio.contrib.openai_agents._mcp import (
+            _StatefulMCPServerReference,
+            _StatelessMCPServerReference,
+        )
+
+        for s in starting_agent.mcp_servers:
+            if not isinstance(
+                s,
+                (
+                    _StatelessMCPServerReference,
+                    _StatefulMCPServerReference,
+                ),
+            ):
+                raise ValueError(
+                    f"Unknown mcp_server type {type(s)} may not work durably."
+                )
+
+    session = kwargs.get("session")
+    if isinstance(session, SQLiteSession):
+        raise ValueError("Temporal workflows don't support SQLite sessions.")
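run_streamed now flows through the same checks as run: tools must be real Tool instances (activities wrapped via openai_agents.workflow.activity_as_tool), sessions must not be SQLite-backed, and any run-config model must be a plain string so it can be replaced with a _TemporalModelStub. Under those rules, a workflow using the streamed path might look roughly like this (the workflow class, agent, and model name are illustrative):

from agents import Agent, RunConfig, Runner
from openai.types.responses import ResponseTextDeltaEvent
from temporalio import workflow


@workflow.defn
class StreamingSummaryWorkflow:
    @workflow.run
    async def run(self, prompt: str) -> str | None:
        agent = Agent[None](name="Assistant", instructions="Answer briefly.")
        # The model must be referenced by name (a string); the runner swaps in
        # a stub that executes the call as the stream_model activity.
        result = Runner.run_streamed(
            starting_agent=agent,
            input=prompt,
            run_config=RunConfig(model="gpt-4o-mini"),
        )
        text = ""
        async for event in result.stream_events():
            if event.type == "raw_response_event" and isinstance(
                event.data, ResponseTextDeltaEvent
            ):
                text += event.data.delta  # deltas arrive in batched signals
        return result.final_output

Inside a workflow the default agents runner is replaced by the Temporal runner, so Runner.run_streamed here should dispatch to TemporalOpenAIRunner.run_streamed above rather than the stock implementation.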
diff --git a/temporalio/contrib/openai_agents/_temporal_model_stub.py b/temporalio/contrib/openai_agents/_temporal_model_stub.py
index f84488541..7116cc982 100644
--- a/temporalio/contrib/openai_agents/_temporal_model_stub.py
+++ b/temporalio/contrib/openai_agents/_temporal_model_stub.py
@@ -1,10 +1,13 @@
 from __future__ import annotations
 
+import asyncio
 import logging
-from typing import Optional
+from asyncio import FIRST_COMPLETED
+from typing import Optional, Tuple
 
 from temporalio import workflow
 from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
+from temporalio.contrib.openai_agents._temporal_trace_provider import _workflow_uuid
 
 logger = logging.getLogger(__name__)
@@ -40,6 +43,7 @@
     HostedMCPToolInput,
     ModelActivity,
     ModelTracingInput,
+    StreamActivityModelInput,
     ToolInput,
 )
@@ -57,6 +61,7 @@ def __init__(
         self.model_name = model_name
         self.model_params = model_params
         self.agent = agent
+        self.stream_events: list[TResponseStreamEvent] = []
 
     async def get_response(
         self,
@@ -72,88 +77,25 @@
         conversation_id: str | None,
         prompt: ResponsePromptParam | None,
     ) -> ModelResponse:
-        def make_tool_info(tool: Tool) -> ToolInput:
-            if isinstance(
-                tool,
-                (
-                    FileSearchTool,
-                    WebSearchTool,
-                    ImageGenerationTool,
-                    CodeInterpreterTool,
-                ),
-            ):
-                return tool
-            elif isinstance(tool, HostedMCPTool):
-                return HostedMCPToolInput(tool_config=tool.tool_config)
-            elif isinstance(tool, FunctionTool):
-                return FunctionToolInput(
-                    name=tool.name,
-                    description=tool.description,
-                    params_json_schema=tool.params_json_schema,
-                    strict_json_schema=tool.strict_json_schema,
-                )
-            else:
-                raise ValueError(f"Unsupported tool type: {tool.name}")
-
-        tool_infos = [make_tool_info(x) for x in tools]
-        handoff_infos = [
-            HandoffInput(
-                tool_name=x.tool_name,
-                tool_description=x.tool_description,
-                input_json_schema=x.input_json_schema,
-                agent_name=x.agent_name,
-                strict_json_schema=x.strict_json_schema,
-            )
-            for x in handoffs
-        ]
-        if output_schema is not None and not isinstance(
-            output_schema, AgentOutputSchema
-        ):
-            raise TypeError(
-                f"Only AgentOutputSchema is supported by Temporal Model, got {type(output_schema).__name__}"
-            )
-        agent_output_schema = output_schema
-        output_schema_input = (
-            None
-            if agent_output_schema is None
-            else AgentOutputSchemaInput(
-                output_type_name=agent_output_schema.name(),
-                is_wrapped=agent_output_schema._is_wrapped,
-                output_schema=agent_output_schema.json_schema()
-                if not agent_output_schema.is_plain_text()
-                else None,
-                strict_json_schema=agent_output_schema.is_strict_json_schema(),
-            )
-        )
+        tool_inputs = _make_tool_inputs(tools)
+        handoff_inputs = _make_handoff_inputs(handoffs)
+        output_schema_input = _make_output_schema_input(output_schema)
 
         activity_input = ActivityModelInput(
             model_name=self.model_name,
             system_instructions=system_instructions,
             input=input,
             model_settings=model_settings,
-            tools=tool_infos,
+            tools=tool_inputs,
             output_schema=output_schema_input,
-            handoffs=handoff_infos,
+            handoffs=handoff_inputs,
             tracing=ModelTracingInput(tracing.value),
             previous_response_id=previous_response_id,
             conversation_id=conversation_id,
             prompt=prompt,
         )
 
-        if self.model_params.summary_override:
-            summary = (
-                self.model_params.summary_override
-                if isinstance(self.model_params.summary_override, str)
-                else (
-                    self.model_params.summary_override.provide(
-                        self.agent, system_instructions, input
-                    )
-                )
-            )
-        elif self.agent:
-            summary = self.agent.name
-        else:
-            summary = None
+        summary = self._make_summary(system_instructions, input)
 
         if self.model_params.use_local_activity:
             return await workflow.execute_local_activity_method(
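get_response now derives the activity summary through _make_summary: an explicit summary_override on ModelActivityParameters wins (either a string, or an object whose provide method receives the agent, instructions, and input), falling back to the agent name. A minimal configuration sketch, assuming ModelActivityParameters is importable from the package root:

from datetime import timedelta

from temporalio.contrib.openai_agents import ModelActivityParameters

# Every model invocation activity will carry this summary in the UI
# instead of the default (the agent's name).
params = ModelActivityParameters(
    start_to_close_timeout=timedelta(seconds=30),
    summary_override="LLM call",
)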
@@ -196,7 +138,90 @@
         conversation_id: str | None,
         prompt: ResponsePromptParam | None,
     ) -> AsyncIterator[TResponseStreamEvent]:
-        raise NotImplementedError("Temporal model doesn't support streams yet")
+        if self.model_params.use_local_activity:
+            raise ValueError("Streaming is not available with local activities.")
+
+        tool_inputs = _make_tool_inputs(tools)
+        handoff_inputs = _make_handoff_inputs(handoffs)
+        output_schema_input = _make_output_schema_input(output_schema)
+
+        summary = self._make_summary(system_instructions, input)
+
+        stream_queue: asyncio.Queue[TResponseStreamEvent | None] = asyncio.Queue()
+
+        async def handle_stream_event(events: list[TResponseStreamEvent]):
+            for event in events:
+                await stream_queue.put(event)
+
+        signal_name = "model_stream_signal"
+        workflow.set_signal_handler(signal_name, handle_stream_event)
+
+        activity_input = StreamActivityModelInput(
+            model_name=self.model_name,
+            system_instructions=system_instructions,
+            input=input,
+            model_settings=model_settings,
+            tools=tool_inputs,
+            output_schema=output_schema_input,
+            handoffs=handoff_inputs,
+            tracing=ModelTracingInput(tracing.value),
+            previous_response_id=previous_response_id,
+            conversation_id=conversation_id,
+            prompt=prompt,
+            signal=signal_name,
+            batch_latency_seconds=1.0,
+        )
+
+        handle = workflow.start_activity_method(
+            ModelActivity.stream_model,
+            args=[activity_input],
+            summary=summary,
+            task_queue=self.model_params.task_queue,
+            schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
+            schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
+            start_to_close_timeout=self.model_params.start_to_close_timeout,
+            heartbeat_timeout=self.model_params.heartbeat_timeout,
+            retry_policy=self.model_params.retry_policy,
+            cancellation_type=self.model_params.cancellation_type,
+            versioning_intent=self.model_params.versioning_intent,
+            priority=self.model_params.priority,
+        )
+
+        async def monitor_activity():
+            try:
+                await handle
+            finally:
+                await stream_queue.put(None)  # Signal end of stream
+
+        monitor_task = asyncio.create_task(monitor_activity())
+
+        async def generator() -> AsyncIterator[TResponseStreamEvent]:
+            while True:
+                item = await stream_queue.get()
+                if item is None:
+                    monitor_task.cancel()
+                    return
+                yield item
+
+        return generator()
+
+    def _make_summary(
+        self, system_instructions: str | None, input: str | list[TResponseInputItem]
+    ) -> str | None:
+        if self.model_params.summary_override:
+            return (
+                self.model_params.summary_override
+                if isinstance(self.model_params.summary_override, str)
+                else (
+                    self.model_params.summary_override.provide(
+                        self.agent, system_instructions, input
+                    )
+                )
+            )
+        elif self.agent:
+            return self.agent.name
+        else:
+            return None
 
 
 def _extract_summary(input: str | list[TResponseInputItem]) -> str:
@@ -228,3 +253,67 @@
     except Exception as e:
         logger.error(f"Error getting summary: {e}")
         return ""
+
+
+def _make_tool_input(tool: Tool) -> ToolInput:
+    if isinstance(
+        tool,
+        (
+            FileSearchTool,
+            WebSearchTool,
+            ImageGenerationTool,
+            CodeInterpreterTool,
+        ),
+    ):
+        return tool
+    elif isinstance(tool, HostedMCPTool):
+        return HostedMCPToolInput(tool_config=tool.tool_config)
+    elif isinstance(tool, FunctionTool):
+        return FunctionToolInput(
+            name=tool.name,
+            description=tool.description,
+            params_json_schema=tool.params_json_schema,
+            strict_json_schema=tool.strict_json_schema,
+        )
+    else:
+        raise ValueError(f"Unsupported tool type: {tool.name}")
+
+
+def _make_tool_inputs(tools: list[Tool]) -> list[ToolInput]:
+    return [_make_tool_input(x) for x in tools]
+
+
+def _make_handoff_inputs(handoffs: list[Handoff]) -> list[HandoffInput]:
+    return [
+        HandoffInput(
+            tool_name=x.tool_name,
+            tool_description=x.tool_description,
+            input_json_schema=x.input_json_schema,
+            agent_name=x.agent_name,
+            strict_json_schema=x.strict_json_schema,
+        )
+        for x in handoffs
+    ]
+
+
+def _make_output_schema_input(
+    output_schema: AgentOutputSchemaBase | None,
+) -> AgentOutputSchemaInput | None:
+    if output_schema is not None and not isinstance(output_schema, AgentOutputSchema):
+        raise TypeError(
+            f"Only AgentOutputSchema is supported by Temporal Model, got {type(output_schema).__name__}"
+        )
+
+    agent_output_schema = output_schema
+    return (
+        None
+        if agent_output_schema is None
+        else AgentOutputSchemaInput(
+            output_type_name=agent_output_schema.name(),
+            is_wrapped=agent_output_schema._is_wrapped,
+            output_schema=agent_output_schema.json_schema()
+            if not agent_output_schema.is_plain_text()
+            else None,
+            strict_json_schema=agent_output_schema.is_strict_json_schema(),
+        )
+    )
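On the workflow side, stream_response bridges the activity back into the run: a dynamically registered signal handler feeds an asyncio.Queue, and an async generator drains the queue until the None sentinel that monitor_activity enqueues when the activity completes. The same producer/consumer shape in plain asyncio, with all names illustrative:

import asyncio
from typing import AsyncIterator


async def main() -> None:
    queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def on_signal(events: list[str]) -> None:
        # Mirrors handle_stream_event: enqueue each batched event in order.
        for event in events:
            await queue.put(event)

    async def drain() -> AsyncIterator[str]:
        # Mirrors generator(): yield until the end-of-stream sentinel.
        while True:
            item = await queue.get()
            if item is None:
                return
            yield item

    await on_signal(["Hello", " there", "!"])  # a batch arriving via signal
    await queue.put(None)  # what monitor_activity does on completion
    async for token in drain():
        print(token, end="")
    print()


asyncio.run(main())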
diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py
index 41ae419f7..5e42882cf 100644
--- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py
+++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py
@@ -221,7 +221,11 @@ def add_activities(
         if not register_activities:
             return activities or []
 
-        new_activities = [ModelActivity(model_provider).invoke_model_activity]
+        model_activity = ModelActivity(model_provider)
+        new_activities = [
+            model_activity.invoke_model_activity,
+            model_activity.stream_model,
+        ]
 
         server_names = [server.name for server in mcp_server_providers]
         if len(server_names) != len(set(server_names)):
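Plugin users get both activities registered automatically by add_activities above; anyone registering ModelActivity on a worker by hand now needs to include stream_model as well, or streamed runs will not find their activity. A sketch under that assumption — the client, workflow class, and task queue name are placeholders of your own:

from temporalio.worker import Worker
from temporalio.contrib.openai_agents import ModelActivity

model_activity = ModelActivity(model_provider)  # your ModelProvider, or the default
worker = Worker(
    client,  # an existing temporalio.client.Client
    task_queue="openai-agents-task-queue",
    workflows=[MyAgentWorkflow],  # hypothetical workflow class
    activities=[
        model_activity.invoke_model_activity,
        model_activity.stream_model,  # new: required for streamed runs
    ],
)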
diff --git a/temporalio/contrib/openai_agents/testing.py b/temporalio/contrib/openai_agents/testing.py
index 4acab196a..1bfdb1e78 100644
--- a/temporalio/contrib/openai_agents/testing.py
+++ b/temporalio/contrib/openai_agents/testing.py
@@ -17,9 +17,14 @@
 )
 from agents.items import TResponseOutputItem, TResponseStreamEvent
 from openai.types.responses import (
+    Response,
+    ResponseCompletedEvent,
+    ResponseContentPartDoneEvent,
     ResponseFunctionToolCall,
+    ResponseOutputItemDoneEvent,
     ResponseOutputMessage,
     ResponseOutputText,
+    ResponseTextDeltaEvent,
 )
 
 from temporalio.client import Client
@@ -109,6 +114,87 @@ def output_message(text: str) -> ModelResponse:
         )
 
 
+class EventBuilders:
+    """Builders for creating stream events for testing.
+
+    .. warning::
+        This API is experimental and may change in the future.
+    """
+
+    @staticmethod
+    def text_delta(text: str) -> ResponseTextDeltaEvent:
+        """Create a TResponseStreamEvent with a text delta.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return ResponseTextDeltaEvent(
+            content_index=0,
+            delta=text,
+            item_id="",
+            logprobs=[],
+            output_index=0,
+            sequence_number=0,
+            type="response.output_text.delta",
+        )
+
+    @staticmethod
+    def content_part_done(text: str) -> TResponseStreamEvent:
+        """Create a TResponseStreamEvent for content part completion.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return ResponseContentPartDoneEvent(
+            content_index=0,
+            item_id="",
+            output_index=0,
+            sequence_number=0,
+            type="response.content_part.done",
+            part=ResponseOutputText(
+                text=text,
+                annotations=[],
+                type="output_text",
+            ),
+        )
+
+    @staticmethod
+    def output_item_done(text: str) -> TResponseStreamEvent:
+        """Create a TResponseStreamEvent for output item completion.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return ResponseOutputItemDoneEvent(
+            output_index=0,
+            sequence_number=0,
+            type="response.output_item.done",
+            item=ResponseBuilders.response_output_message(text),
+        )
+
+    @staticmethod
+    def response_completion(text: str) -> TResponseStreamEvent:
+        """Create a TResponseStreamEvent for response completion.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return ResponseCompletedEvent(
+            response=Response(
+                id="",
+                created_at=0.0,
+                object="response",
+                model="",
+                parallel_tool_calls=False,
+                tool_choice="none",
+                tools=[],
+                output=[ResponseBuilders.response_output_message(text)],
+            ),
+            sequence_number=0,
+            type="response.completed",
+        )
+
+
 class TestModelProvider(ModelProvider):
     """Test model provider which simply returns the given module.
@@ -144,13 +230,19 @@ class TestModel(Model):
 
     __test__ = False
 
-    def __init__(self, fn: Callable[[], ModelResponse]) -> None:
+    def __init__(
+        self,
+        fn: Callable[[], ModelResponse] | None,
+        *,
+        streaming_fn: Callable[[], AsyncIterator[TResponseStreamEvent]] | None = None,
+    ) -> None:
         """Initialize a test model with a callable.
 
         .. warning::
             This API is experimental and may change in the future.
         """
         self.fn = fn
+        self.streaming_fn = streaming_fn
 
     async def get_response(
         self,
@@ -164,6 +256,8 @@ async def get_response(
         **kwargs,
     ) -> ModelResponse:
         """Get a response from the mocked model, by calling the callable passed to the constructor."""
+        if self.fn is None:
+            raise ValueError("No non-streaming function provided")
         return self.fn()
 
     def stream_response(
@@ -177,8 +271,10 @@ def stream_response(
         tracing: ModelTracing,
         **kwargs,
     ) -> AsyncIterator[TResponseStreamEvent]:
-        """Get a streamed response from the model. Unimplemented."""
-        raise NotImplementedError()
+        """Get a streamed response from the model."""
+        if self.streaming_fn is None:
+            raise ValueError("No streaming function provided")
+        return self.streaming_fn()
 
     @staticmethod
     def returning_responses(responses: list[ModelResponse]) -> "TestModel":
@@ -190,6 +286,42 @@ def returning_responses(responses: list[ModelResponse]) -> "TestModel":
         i = iter(responses)
         return TestModel(lambda: next(i))
 
+    @staticmethod
+    def streaming_events(events: list[TResponseStreamEvent]) -> "TestModel":
+        """Create a mock model which sequentially streams events from a list.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+
+        async def generator():
+            for event in events:
+                yield event
+
+        return TestModel(None, streaming_fn=lambda: generator())
+
+    @staticmethod
+    def streaming_events_with_ending(
+        events: list[ResponseTextDeltaEvent],
+    ) -> "TestModel":
+        """Create a mock model which sequentially streams events from a list, appending completion markers.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        content = ""
+        for event in events:
+            content += event.delta
+
+        return TestModel.streaming_events(
+            events
+            + [
+                EventBuilders.content_part_done(content),
+                EventBuilders.output_item_done(content),
+                EventBuilders.response_completion(content),
+            ]
+        )
+
 
 class AgentEnvironment:
     """Testing environment for OpenAI agents with Temporal integration.
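The two TestModel streaming constructors compose with EventBuilders: streaming_events takes a fully scripted event list, while streaming_events_with_ending concatenates the deltas to synthesize the closing content-part, output-item, and completion events. For example, both models below stream the same "Hi!" response:

from temporalio.contrib.openai_agents.testing import EventBuilders, TestModel

deltas = [EventBuilders.text_delta("Hi"), EventBuilders.text_delta("!")]

# Explicit script: deltas plus hand-built completion events.
explicit = TestModel.streaming_events(
    deltas
    + [
        EventBuilders.content_part_done("Hi!"),
        EventBuilders.output_item_done("Hi!"),
        EventBuilders.response_completion("Hi!"),
    ]
)

# Convenience form: the same completion events appended automatically.
concise = TestModel.streaming_events_with_ending(deltas)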
+ """ + content = "" + for event in events: + content += event.delta + + return TestModel.streaming_events( + events + + [ + EventBuilders.content_part_done(content), + EventBuilders.output_item_done(content), + EventBuilders.response_completion(content), + ] + ) + class AgentEnvironment: """Testing environment for OpenAI agents with Temporal integration. diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 381a213cd..1a47c3cf6 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -72,6 +72,7 @@ ResponseInputTextParam, ResponseOutputMessage, ResponseOutputText, + ResponseTextDeltaEvent, ) from openai.types.responses.response_file_search_tool_call import Result from openai.types.responses.response_function_web_search import ActionSearch @@ -101,6 +102,7 @@ ) from temporalio.contrib.openai_agents.testing import ( AgentEnvironment, + EventBuilders, ResponseBuilders, TestModel, TestModelProvider, @@ -2635,3 +2637,57 @@ async def test_split_workers(client: Client): execution_timeout=timedelta(seconds=120), ) assert result == "test" + + +@workflow.defn +class StreamingHelloWorldAgent: + @workflow.run + async def run(self, prompt: str) -> str | None: + agent = Agent[None]( + name="Assistant", + instructions="You are a helpful assistant.", + ) + + result = None + for _ in range(2): + result = Runner.run_streamed(starting_agent=agent, input=prompt) + async for event in result.stream_events(): + if event.type == "raw_response_event" and isinstance( + event.data, ResponseTextDeltaEvent + ): + print(event.data.delta, end="", flush=True) + + return result.final_output if result else None + + +def streaming_hello_model(): + return TestModel.streaming_events_with_ending( + [ + EventBuilders.text_delta("Hello"), + EventBuilders.text_delta(" there"), + EventBuilders.text_delta("!"), + ] + ) + + +async def test_streaming(client: Client): + async with AgentEnvironment( + model=streaming_hello_model(), + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30), + ), + ) as env: + client = env.applied_on_client(client) + + async with new_worker( + client, StreamingHelloWorldAgent, max_cached_workflows=0 + ) as worker: + handle = await client.start_workflow( + StreamingHelloWorldAgent.run, + "Say hello.", + id=f"hello-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=50), + ) + result = await handle.result() + assert result == "Hello there!"
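The test surfaces deltas with print, which is only useful on a console. A hypothetical variant could instead accumulate the deltas and expose the partial text through a query handler, so clients can poll for incremental output while the run is in flight (this reuses the imports already present in the test module above):

@workflow.defn
class QueryableStreamingAgent:
    def __init__(self) -> None:
        self._text = ""

    @workflow.query
    def partial_text(self) -> str:
        # Clients poll this query to render tokens as they arrive.
        return self._text

    @workflow.run
    async def run(self, prompt: str) -> str | None:
        agent = Agent[None](
            name="Assistant",
            instructions="You are a helpful assistant.",
        )
        result = Runner.run_streamed(starting_agent=agent, input=prompt)
        async for event in result.stream_events():
            if event.type == "raw_response_event" and isinstance(
                event.data, ResponseTextDeltaEvent
            ):
                self._text += event.data.delta
        return result.final_output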