Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 165 additions & 84 deletions temporalio/contrib/openai_agents/_invoke_model_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
Implements mapping of OpenAI datastructures to Pydantic friendly types.
"""

import asyncio
import enum
import json
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Optional, Union
from typing import Any, NoReturn, Optional, Union

from agents import (
AgentOutputSchemaBase,
Expand All @@ -28,6 +29,7 @@
UserError,
WebSearchTool,
)
from agents.items import TResponseStreamEvent
from openai import (
APIStatusError,
AsyncOpenAI,
Expand Down Expand Up @@ -148,6 +150,13 @@ class ActivityModelInput(TypedDict, total=False):
prompt: Any | None


class StreamActivityModelInput(ActivityModelInput):
    """Payload accepted by the ``stream_model`` activity.

    Carries every key of ``ActivityModelInput`` plus the two streaming
    settings declared below.
    """

    # Name of the workflow signal used to deliver each batch of stream events.
    signal: str
    # Batching window in seconds; ``stream_model`` reads it with a fallback
    # of 1.0 when the key is missing.
    # NOTE(review): whether this key is actually optional depends on the
    # totality of this TypedDict subclass - confirm against the base class.
    batch_latency_seconds: float


class ModelActivity:
"""Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled.
Disabling retries in your model of choice is recommended to allow activity retries to define the retry model.
Expand All @@ -165,52 +174,8 @@ async def invoke_model_activity(self, input: ActivityModelInput) -> ModelRespons
"""Activity that invokes a model with the given input."""
model = self._model_provider.get_model(input.get("model_name"))

async def empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
return ""

async def empty_on_invoke_handoff(
ctx: RunContextWrapper[Any], input: str
) -> Any:
return None

def make_tool(tool: ToolInput) -> Tool:
if isinstance(
tool,
(
FileSearchTool,
WebSearchTool,
ImageGenerationTool,
CodeInterpreterTool,
),
):
return tool
elif isinstance(tool, HostedMCPToolInput):
return HostedMCPTool(
tool_config=tool.tool_config,
)
elif isinstance(tool, FunctionToolInput):
return FunctionTool(
name=tool.name,
description=tool.description,
params_json_schema=tool.params_json_schema,
on_invoke_tool=empty_on_invoke_tool,
strict_json_schema=tool.strict_json_schema,
)
else:
raise UserError(f"Unknown tool type: {tool.name}")

tools = [make_tool(x) for x in input.get("tools", [])]
handoffs: list[Handoff[Any, Any]] = [
Handoff(
tool_name=x.tool_name,
tool_description=x.tool_description,
input_json_schema=x.input_json_schema,
agent_name=x.agent_name,
strict_json_schema=x.strict_json_schema,
on_invoke_handoff=empty_on_invoke_handoff,
)
for x in input.get("handoffs", [])
]
tools = _make_tools(input)
handoffs = _make_handoffs(input)

try:
return await model.get_response(
Expand All @@ -226,40 +191,156 @@ def make_tool(tool: ToolInput) -> Tool:
prompt=input.get("prompt"),
)
except APIStatusError as e:
# Listen to server hints
retry_after = None
retry_after_ms_header = e.response.headers.get("retry-after-ms")
if retry_after_ms_header is not None:
retry_after = timedelta(milliseconds=float(retry_after_ms_header))

if retry_after is None:
retry_after_header = e.response.headers.get("retry-after")
if retry_after_header is not None:
retry_after = timedelta(seconds=float(retry_after_header))

should_retry_header = e.response.headers.get("x-should-retry")
if should_retry_header == "true":
raise e
if should_retry_header == "false":
raise ApplicationError(
"Non retryable OpenAI error",
non_retryable=True,
next_retry_delay=retry_after,
) from e

# Specifically retryable status codes
if (
e.response.status_code in [408, 409, 429]
or e.response.status_code >= 500
):
raise ApplicationError(
f"Retryable OpenAI status code: {e.response.status_code}",
non_retryable=False,
next_retry_delay=retry_after,
) from e

raise ApplicationError(
f"Non retryable OpenAI status code: {e.response.status_code}",
non_retryable=True,
next_retry_delay=retry_after,
) from e
_handle_error(e)

@activity.defn
async def stream_model(self, input: StreamActivityModelInput) -> None:
    """Activity that streams a model with the given input.

    Events yielded by ``model.stream_response`` are buffered and delivered
    to the workflow identified by ``activity.info().workflow_id`` as signals
    named ``input["signal"]``; every signal carries a list of events.

    A non-empty buffer is flushed once ``batch_latency_seconds`` (1.0 when
    the key is absent) have elapsed since the previous flush, and whatever
    is left is flushed when the stream is exhausted.

    Raises:
        Whatever ``_handle_error`` raises when the model call fails with an
        ``APIStatusError``.
    """
    model = self._model_provider.get_model(input.get("model_name"))

    tools = _make_tools(input)
    handoffs = _make_handoffs(input)

    # get_running_loop() is the supported way to reach the loop from inside
    # a coroutine; get_event_loop() is deprecated for this use.
    loop = asyncio.get_running_loop()
    batch_latency = input.get("batch_latency_seconds", 1.0)
    batch: list[TResponseStreamEvent] = []
    last_signal_time = loop.time()
    # In-flight request for the next stream event.  It is kept across
    # timeouts so that a flush never cancels (and thereby aborts) the
    # underlying stream.
    pending: Optional["asyncio.Future[Any]"] = None

    async def send_batch() -> None:
        nonlocal last_signal_time
        if batch:
            await handle.signal(input["signal"], batch)
            batch.clear()
        last_signal_time = loop.time()

    try:
        handle = activity.client().get_workflow_handle(
            workflow_id=activity.info().workflow_id
        )
        events = model.stream_response(
            system_instructions=input.get("system_instructions"),
            input=input["input"],
            model_settings=input["model_settings"],
            tools=tools,
            output_schema=input.get("output_schema"),
            handoffs=handoffs,
            tracing=ModelTracing(input["tracing"]),
            previous_response_id=input.get("previous_response_id"),
            conversation_id=input.get("conversation_id"),
            prompt=input.get("prompt"),
        )

        while True:
            if pending is None:
                pending = asyncio.ensure_future(anext(events))

            if batch:
                # Time left before the buffered events become overdue.
                remaining = batch_latency - (loop.time() - last_signal_time)
                if remaining <= 0:
                    await send_batch()
                    continue
                timeout: Optional[float] = remaining
            else:
                # Nothing buffered, so there is no deadline to honour.
                timeout = None

            # asyncio.wait() leaves the task running on timeout, unlike
            # asyncio.wait_for(), which would cancel the pending read.
            done, _ = await asyncio.wait({pending}, timeout=timeout)
            if not done:
                await send_batch()
                continue

            finished, pending = pending, None
            try:
                event = finished.result()
            except StopAsyncIteration:
                break
            event.model_rebuild()
            batch.append(event)

        # Send any remaining events in the batch
        await send_batch()

    except APIStatusError as e:
        _handle_error(e)
    finally:
        # Do not leave a dangling read behind on error or cancellation.
        if pending is not None and not pending.done():
            pending.cancel()


async def _empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
    """Placeholder tool callback: ignores both arguments and yields an empty string."""
    no_output = ""
    return no_output


async def _empty_on_invoke_handoff(ctx: RunContextWrapper[Any], input: str) -> Any:
    """Placeholder handoff callback: ignores both arguments and yields ``None``."""
    no_result = None
    return no_result


def _make_tool(tool: ToolInput) -> Tool:
    """Turn one tool description from the activity input into a ``Tool``.

    ``FileSearchTool``, ``WebSearchTool``, ``ImageGenerationTool`` and
    ``CodeInterpreterTool`` instances are returned as they are.  A
    ``HostedMCPToolInput`` is rebuilt as a ``HostedMCPTool`` and a
    ``FunctionToolInput`` as a ``FunctionTool`` whose ``on_invoke_tool`` is
    the placeholder ``_empty_on_invoke_tool``.

    Raises:
        UserError: if ``tool`` is none of the types above.
    """
    passthrough_types = (
        FileSearchTool,
        WebSearchTool,
        ImageGenerationTool,
        CodeInterpreterTool,
    )
    if isinstance(tool, passthrough_types):
        return tool

    if isinstance(tool, HostedMCPToolInput):
        return HostedMCPTool(tool_config=tool.tool_config)

    if isinstance(tool, FunctionToolInput):
        return FunctionTool(
            name=tool.name,
            description=tool.description,
            params_json_schema=tool.params_json_schema,
            on_invoke_tool=_empty_on_invoke_tool,
            strict_json_schema=tool.strict_json_schema,
        )

    raise UserError(f"Unknown tool type: {tool.name}")


def _make_tools(input: ActivityModelInput) -> list[Tool]:
    """Build the tool list for a model call from the ``"tools"`` entry of ``input``.

    A missing ``"tools"`` key yields an empty list.
    """
    tool_inputs = input.get("tools", [])
    return list(map(_make_tool, tool_inputs))


def _make_handoffs(input: ActivityModelInput) -> list[Handoff[Any, Any]]:
    """Build ``Handoff`` objects from the ``"handoffs"`` entry of ``input``.

    Each handoff copies its descriptive fields from the input entry and uses
    the placeholder ``_empty_on_invoke_handoff`` as its callback.  A missing
    ``"handoffs"`` key yields an empty list.
    """
    built: list[Handoff[Any, Any]] = []
    for spec in input.get("handoffs", []):
        handoff = Handoff(
            tool_name=spec.tool_name,
            tool_description=spec.tool_description,
            input_json_schema=spec.input_json_schema,
            agent_name=spec.agent_name,
            strict_json_schema=spec.strict_json_schema,
            on_invoke_handoff=_empty_on_invoke_handoff,
        )
        built.append(handoff)
    return built


def _handle_error(e: APIStatusError) -> NoReturn:
    """Re-raise an OpenAI status error in the form the activity should fail with.

    Decision order:

    1. ``x-should-retry: true``  -> the original error ``e`` is re-raised.
    2. ``x-should-retry: false`` -> non-retryable ``ApplicationError``.
    3. Status 408, 409, 429 or >= 500 -> retryable ``ApplicationError``.
    4. Anything else -> non-retryable ``ApplicationError``.

    Every ``ApplicationError`` carries the delay requested through the
    ``retry-after-ms`` / ``retry-after`` response headers (``None`` when the
    headers are absent or unusable) as ``next_retry_delay`` and is chained to
    ``e``.

    Raises:
        APIStatusError: ``e`` itself, in case 1.
        ApplicationError: in cases 2-4.
    """
    headers = e.response.headers

    # Listen to server hints.  Parsing is defensive so that a malformed
    # header can never replace the real API error with a ValueError.
    retry_after = _parse_retry_after(headers)

    should_retry_header = headers.get("x-should-retry")
    if should_retry_header == "true":
        raise e
    if should_retry_header == "false":
        raise ApplicationError(
            "Non retryable OpenAI error",
            non_retryable=True,
            next_retry_delay=retry_after,
        ) from e

    # Specifically retryable status codes
    status_code = e.response.status_code
    if status_code in [408, 409, 429] or status_code >= 500:
        raise ApplicationError(
            f"Retryable OpenAI status code: {status_code}",
            non_retryable=False,
            next_retry_delay=retry_after,
        ) from e

    raise ApplicationError(
        f"Non retryable OpenAI status code: {status_code}",
        non_retryable=True,
        next_retry_delay=retry_after,
    ) from e


def _parse_retry_after(headers: Any) -> Optional[timedelta]:
    """Return the retry delay requested by the response headers, if usable.

    ``retry-after-ms`` (milliseconds) takes precedence over ``retry-after``.
    ``retry-after`` may be a number of seconds or an HTTP-date.  Values that
    cannot be parsed, are not finite, or lie in the past yield ``None``.
    """
    from datetime import datetime, timezone
    from email.utils import parsedate_to_datetime

    def numeric_delay(raw: str, unit: str) -> Optional[timedelta]:
        try:
            delay = timedelta(**{unit: float(raw)})
        except (TypeError, ValueError, OverflowError):
            return None
        return delay if delay >= timedelta(0) else None

    raw_ms = headers.get("retry-after-ms")
    if raw_ms is not None:
        delay = numeric_delay(raw_ms, "milliseconds")
        if delay is not None:
            return delay

    raw = headers.get("retry-after")
    if raw is None:
        return None

    delay = numeric_delay(raw, "seconds")
    if delay is not None:
        return delay

    # Not a number: the header may carry an HTTP-date instead.
    try:
        when = parsedate_to_datetime(raw)
    except (TypeError, ValueError):
        return None
    if when.tzinfo is None:
        when = when.replace(tzinfo=timezone.utc)
    delay = when - datetime.now(timezone.utc)
    return delay if delay > timedelta(0) else None
Loading
Loading