From 4d3e40327fe3c975bea36c2e195b5f49fcece09a Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 19 Aug 2025 12:27:40 -0700 Subject: [PATCH 1/4] POC For a simple plugin --- README.md | 227 +++++++++--------- temporalio/client.py | 18 +- .../openai_agents/_temporal_openai_agents.py | 10 +- temporalio/plugin.py | 160 ++++++++++++ temporalio/testing/_workflow.py | 4 +- temporalio/worker/__init__.py | 4 +- temporalio/worker/_plugin.py | 8 +- temporalio/worker/_replayer.py | 4 +- temporalio/worker/_worker.py | 10 +- tests/test_client.py | 2 +- tests/test_plugins.py | 56 ++++- 11 files changed, 355 insertions(+), 148 deletions(-) create mode 100644 temporalio/plugin.py diff --git a/README.md b/README.md index 3f6a37ef1..71a7436ea 100644 --- a/README.md +++ b/README.md @@ -1509,32 +1509,34 @@ authentication, modifying connection parameters, or adding custom behavior durin Here's an example of a client plugin that adds custom authentication: ```python -from temporalio.client import Plugin, ClientConfig +from temporalio.client import LowLevelPlugin, ClientConfig import temporalio.service -class AuthenticationPlugin(Plugin): - def __init__(self, api_key: str): - self.api_key = api_key - - def init_client_plugin(self, next: Plugin) -> None: - self.next_client_plugin = next - - def configure_client(self, config: ClientConfig) -> ClientConfig: - # Modify client configuration - config["namespace"] = "my-secure-namespace" - return self.next_client_plugin.configure_client(config) - - async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: - # Add authentication to the connection - config.api_key = self.api_key - return await self.next_client_plugin.connect_service_client(config) + +class AuthenticationPlugin(LowLevelPlugin): + def __init__(self, api_key: str): + self.api_key = api_key + + def init_client_plugin(self, next: LowLevelPlugin) -> None: + self.next_client_plugin = next + + def configure_client(self, config: ClientConfig) -> ClientConfig: + # Modify client configuration + config["namespace"] = "my-secure-namespace" + return self.next_client_plugin.configure_client(config) + + async def connect_service_client( + self, config: temporalio.service.ConnectConfig + ) -> temporalio.service.ServiceClient: + # Add authentication to the connection + config.api_key = self.api_key + return await self.next_client_plugin.connect_service_client(config) + # Use the plugin when connecting client = await Client.connect( - "my-server.com:7233", - plugins=[AuthenticationPlugin("my-api-key")] + "my-server.com:7233", + plugins=[AuthenticationPlugin("my-api-key")] ) ``` @@ -1551,53 +1553,59 @@ Here's an example of a worker plugin that adds custom monitoring: import temporalio from contextlib import asynccontextmanager from typing import AsyncIterator -from temporalio.worker import Plugin, WorkerConfig, ReplayerConfig, Worker, Replayer, WorkflowReplayResult +from temporalio.worker import LowLevelPlugin, WorkerConfig, ReplayerConfig, Worker, Replayer, WorkflowReplayResult import logging -class MonitoringPlugin(Plugin): - def __init__(self): - self.logger = logging.getLogger(__name__) - - def init_worker_plugin(self, next: Plugin) -> None: - self.next_worker_plugin = next - - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - # Modify worker configuration - original_task_queue = config["task_queue"] - config["task_queue"] = f"monitored-{original_task_queue}" - self.logger.info(f"Worker created for task queue: {config['task_queue']}") - return self.next_worker_plugin.configure_worker(config) - - async def run_worker(self, worker: Worker) -> None: - self.logger.info("Starting worker execution") - try: - await self.next_worker_plugin.run_worker(worker) - finally: - self.logger.info("Worker execution completed") - - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - return self.next_worker_plugin.configure_replayer(config) - - @asynccontextmanager - async def run_replayer( - self, - replayer: Replayer, - histories: AsyncIterator[temporalio.client.WorkflowHistory], - ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: - self.logger.info("Starting replay execution") - try: - async with self.next_worker_plugin.run_replayer(replayer, histories) as results: - yield results - finally: - self.logger.info("Replay execution completed") + +class MonitoringPlugin(LowLevelPlugin): + def __init__(self): + self.logger = logging.getLogger(__name__) + + def init_worker_plugin(self, next: LowLevelPlugin) -> None: + self.next_worker_plugin = next + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + # Modify worker configuration + original_task_queue = config["task_queue"] + config["task_queue"] = f"monitored-{original_task_queue}" + self.logger.info(f"Worker created for task queue: {config['task_queue']}") + return self.next_worker_plugin.configure_worker(config) + + async def run_worker(self, worker: Worker) -> None: + self.logger.info("Starting worker execution") + try: + await self.next_worker_plugin.run_worker(worker) + finally: + self.logger.info("Worker execution completed") + + +def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: + return self.next_worker_plugin.configure_replayer(config) + + +@asynccontextmanager + + +async def run_replayer( + self, + replayer: Replayer, + histories: AsyncIterator[temporalio.client.WorkflowHistory], +) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: + self.logger.info("Starting replay execution") + try: + async with self.next_worker_plugin.run_replayer(replayer, histories) as results: + yield results + finally: + self.logger.info("Replay execution completed") + # Use the plugin when creating a worker worker = Worker( - client, - task_queue="my-task-queue", - workflows=[MyWorkflow], - activities=[my_activity], - plugins=[MonitoringPlugin()] + client, + task_queue="my-task-queue", + workflows=[MyWorkflow], + activities=[my_activity], + plugins=[MonitoringPlugin()] ) ``` @@ -1607,60 +1615,63 @@ For plugins that need to work with both clients and workers, you can implement b import temporalio from contextlib import AbstractAsyncContextManager from typing import AsyncIterator -from temporalio.client import Plugin as ClientPlugin, ClientConfig -from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig, ReplayerConfig, Worker, Replayer, WorkflowReplayResult +from temporalio.client import LowLevelPlugin as ClientPlugin, ClientConfig +from temporalio.worker import LowLevelPlugin as WorkerPlugin, WorkerConfig, ReplayerConfig, Worker, Replayer, + +WorkflowReplayResult class UnifiedPlugin(ClientPlugin, WorkerPlugin): - def init_client_plugin(self, next: ClientPlugin) -> None: - self.next_client_plugin = next - - def init_worker_plugin(self, next: WorkerPlugin) -> None: - self.next_worker_plugin = next - - def configure_client(self, config: ClientConfig) -> ClientConfig: - # Client-side customization - config["data_converter"] = pydantic_data_converter - return self.next_client_plugin.configure_client(config) - - async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: - # Add authentication to the connection - config.api_key = self.api_key - return await self.next_client_plugin.connect_service_client(config) - - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - # Worker-side customization - return self.next_worker_plugin.configure_worker(config) - - async def run_worker(self, worker: Worker) -> None: - print("Starting unified worker") - await self.next_worker_plugin.run_worker(worker) - - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - config["data_converter"] = pydantic_data_converter - return config - - async def run_replayer( - self, - replayer: Replayer, - histories: AsyncIterator[temporalio.client.WorkflowHistory], - ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: - return self.next_worker_plugin.run_replayer(replayer, histories) - + def init_client_plugin(self, next: ClientPlugin) -> None: + self.next_client_plugin = next + + def init_worker_plugin(self, next: WorkerPlugin) -> None: + self.next_worker_plugin = next + + def configure_client(self, config: ClientConfig) -> ClientConfig: + # Client-side customization + config["data_converter"] = pydantic_data_converter + return self.next_client_plugin.configure_client(config) + + async def connect_service_client( + self, config: temporalio.service.ConnectConfig + ) -> temporalio.service.ServiceClient: + # Add authentication to the connection + config.api_key = self.api_key + return await self.next_client_plugin.connect_service_client(config) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + # Worker-side customization + return self.next_worker_plugin.configure_worker(config) + + async def run_worker(self, worker: Worker) -> None: + print("Starting unified worker") + await self.next_worker_plugin.run_worker(worker) + + def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: + config["data_converter"] = pydantic_data_converter + return config + + async def run_replayer( + self, + replayer: Replayer, + histories: AsyncIterator[temporalio.client.WorkflowHistory], + ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: + return self.next_worker_plugin.run_replayer(replayer, histories) + + # Create client with the unified plugin client = await Client.connect( - "localhost:7233", - plugins=[UnifiedPlugin()] + "localhost:7233", + plugins=[UnifiedPlugin()] ) # Worker will automatically inherit the plugin from the client worker = Worker( - client, - task_queue="my-task-queue", - workflows=[MyWorkflow], - activities=[my_activity] + client, + task_queue="my-task-queue", + workflows=[MyWorkflow], + activities=[my_activity] ) ``` diff --git a/temporalio/client.py b/temporalio/client.py index 71bdf3dee..44e70f1b0 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -109,7 +109,7 @@ async def connect( namespace: str = "default", api_key: Optional[str] = None, data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, - plugins: Sequence[Plugin] = [], + plugins: Sequence[LowLevelPlugin] = [], interceptors: Sequence[Interceptor] = [], default_workflow_query_reject_condition: Optional[ temporalio.common.QueryRejectCondition @@ -190,7 +190,7 @@ async def connect( http_connect_proxy_config=http_connect_proxy_config, ) - root_plugin: Plugin = _RootPlugin() + root_plugin: LowLevelPlugin = _RootPlugin() for plugin in reversed(plugins): plugin.init_client_plugin(root_plugin) root_plugin = plugin @@ -213,7 +213,7 @@ def __init__( *, namespace: str = "default", data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, - plugins: Sequence[Plugin] = [], + plugins: Sequence[LowLevelPlugin] = [], interceptors: Sequence[Interceptor] = [], default_workflow_query_reject_condition: Optional[ temporalio.common.QueryRejectCondition @@ -235,7 +235,7 @@ def __init__( plugins=plugins, ) - root_plugin: Plugin = _RootPlugin() + root_plugin: LowLevelPlugin = _RootPlugin() for plugin in reversed(plugins): plugin.init_client_plugin(root_plugin) root_plugin = plugin @@ -1540,7 +1540,7 @@ class ClientConfig(TypedDict, total=False): Optional[temporalio.common.QueryRejectCondition] ] header_codec_behavior: Required[HeaderCodecBehavior] - plugins: Required[Sequence[Plugin]] + plugins: Required[Sequence[LowLevelPlugin]] class WorkflowHistoryEventFilterType(IntEnum): @@ -7367,7 +7367,7 @@ async def _decode_user_metadata( ) -class Plugin(abc.ABC): +class LowLevelPlugin(abc.ABC): """Base class for client plugins that can intercept and modify client behavior. Plugins allow customization of client creation and service connection processes @@ -7387,7 +7387,7 @@ def name(self) -> str: return type(self).__module__ + "." + type(self).__qualname__ @abstractmethod - def init_client_plugin(self, next: Plugin) -> None: + def init_client_plugin(self, next: LowLevelPlugin) -> None: """Initialize this plugin in the plugin chain. This method sets up the chain of responsibility pattern by providing a reference @@ -7433,8 +7433,8 @@ async def connect_service_client( """ -class _RootPlugin(Plugin): - def init_client_plugin(self, next: Plugin) -> None: +class _RootPlugin(LowLevelPlugin): + def init_client_plugin(self, next: LowLevelPlugin) -> None: raise NotImplementedError() def configure_client(self, config: ClientConfig) -> ClientConfig: diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 73b9723d0..8340c0cc2 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -24,7 +24,7 @@ import temporalio.client import temporalio.worker -from temporalio.client import ClientConfig, Plugin +from temporalio.client import ClientConfig, LowLevelPlugin from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner @@ -150,7 +150,9 @@ def __init__(self) -> None: super().__init__(ToJsonOptions(exclude_unset=True)) -class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): +class OpenAIAgentsPlugin( + temporalio.client.LowLevelPlugin, temporalio.worker.LowLevelPlugin +): """Temporal plugin for integrating OpenAI agents with Temporal workflows. .. warning:: @@ -233,7 +235,7 @@ def __init__( self._model_params = model_params self._model_provider = model_provider - def init_client_plugin(self, next: temporalio.client.Plugin) -> None: + def init_client_plugin(self, next: temporalio.client.LowLevelPlugin) -> None: """Set the next client plugin""" self.next_client_plugin = next @@ -243,7 +245,7 @@ async def connect_service_client( """No modifications to service client""" return await self.next_client_plugin.connect_service_client(config) - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: + def init_worker_plugin(self, next: temporalio.worker.LowLevelPlugin) -> None: """Set the next worker plugin""" self.next_worker_plugin = next diff --git a/temporalio/plugin.py b/temporalio/plugin.py new file mode 100644 index 000000000..708253943 --- /dev/null +++ b/temporalio/plugin.py @@ -0,0 +1,160 @@ +import abc +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, AsyncIterator, Callable, Sequence, Type, TypedDict, TypeVar + +import temporalio.client +import temporalio.converter +import temporalio.worker +from temporalio.client import ClientConfig, WorkflowHistory +from temporalio.worker import ( + Replayer, + ReplayerConfig, + Worker, + WorkerConfig, + WorkflowReplayResult, + WorkflowRunner, +) + + +class PluginConfig(TypedDict, total=False): + data_converter: temporalio.converter.DataConverter + client_interceptors: Sequence[temporalio.client.Interceptor] + worker_interceptors: Sequence[temporalio.worker.Interceptor] + activities: Sequence[Callable] + nexus_service_handlers: Sequence[Any] + workflows: Sequence[Type] + workflow_runner: WorkflowRunner + + +class Plugin( + temporalio.client.LowLevelPlugin, temporalio.worker.LowLevelPlugin, abc.ABC +): + def init_worker_plugin(self, next: temporalio.worker.LowLevelPlugin) -> None: + self.next_worker_plugin = next + + def init_client_plugin(self, next: temporalio.client.LowLevelPlugin) -> None: + self.next_client_plugin = next + + def configure_client(self, config: ClientConfig) -> ClientConfig: + plugin_config = self.configuration() + + new_converter = plugin_config.get("data_converter") + if new_converter: + if not config["data_converter"] == temporalio.converter.default(): + config["data_converter"] = self.resolve_collision( + config["data_converter"], new_converter + ) + else: + config["data_converter"] = new_converter + + client_interceptors = plugin_config.get("client_interceptors") + if client_interceptors: + config["interceptors"] = list(config.get("interceptors", [])) + list( + client_interceptors + ) + return self.next_client_plugin.configure_client(config) + + async def connect_service_client( + self, config: temporalio.service.ConnectConfig + ) -> temporalio.service.ServiceClient: + return await self.next_client_plugin.connect_service_client(config) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + plugin_config = self.configuration() + + worker_interceptors = plugin_config.get("worker_interceptors") + if worker_interceptors: + config["interceptors"] = list(config.get("interceptors", [])) + list( + worker_interceptors + ) + + activities = plugin_config.get("activities") + if activities: + config["activities"] = list(config.get("activities", [])) + list(activities) + + nexus_service_handlers = plugin_config.get("nexus_service_handlers") + if nexus_service_handlers: + config["nexus_service_handlers"] = list( + config.get("nexus_service_handlers", []) + ) + list(nexus_service_handlers) + + workflows = plugin_config.get("workflows") + if workflows: + config["workflows"] = list(config.get("workflows", [])) + list(workflows) + + workflow_runner = plugin_config.get("workflow_runner") + if workflow_runner: + old_runner = config.get("workflow_runner") + if old_runner: + config["workflow_runner"] = self.resolve_collision( + old_runner, workflow_runner + ) + else: + config["workflow_runner"] = workflow_runner + return config + + def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: + plugin_config = self.configuration() + + new_converter = plugin_config.get("data_converter") + if new_converter: + old_converter = config.get("data_converter") + if old_converter and not old_converter == temporalio.converter.default(): + config["data_converter"] = self.resolve_collision( + old_converter, new_converter + ) + else: + config["data_converter"] = new_converter + + worker_interceptors = plugin_config.get("worker_interceptors") + if worker_interceptors: + config["interceptors"] = list(config.get("interceptors", [])) + list( + worker_interceptors + ) + + workflows = plugin_config.get("workflows") + if workflows: + config["workflows"] = list(config.get("workflows", [])) + list(workflows) + + workflow_runner = plugin_config.get("workflow_runner") + if workflow_runner: + old_runner = config.get("workflow_runner") + if old_runner: + config["workflow_runner"] = self.resolve_collision( + old_runner, workflow_runner + ) + else: + config["workflow_runner"] = workflow_runner + return config + + async def run_worker(self, worker: Worker) -> None: + async with self.run_context(): + await self.next_worker_plugin.run_worker(worker) + + @asynccontextmanager + async def run_replayer( + self, + replayer: Replayer, + histories: AsyncIterator[WorkflowHistory], + ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: + async with self.next_worker_plugin.run_replayer(replayer, histories) as results: + yield results + + @abc.abstractmethod + def run_context(self) -> AbstractAsyncContextManager[None]: + raise NotImplementedError() + + @abc.abstractmethod + def configuration(self) -> PluginConfig: + raise NotImplementedError() + + T = TypeVar("T") + + def resolve_collision( + self, + old: T, + new: T, + ) -> T: + """How to handle cases where an option is already set by the user or an earlier plugin. + The default implementation is to fail, but it can be overridden.""" + raise ValueError(f"{old} is already set, plugin cannot reasonable override.") diff --git a/temporalio/testing/_workflow.py b/temporalio/testing/_workflow.py index d0eda5580..09c70fb6f 100644 --- a/temporalio/testing/_workflow.py +++ b/temporalio/testing/_workflow.py @@ -79,7 +79,7 @@ async def start_local( namespace: str = "default", data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, interceptors: Sequence[temporalio.client.Interceptor] = [], - plugins: Sequence[temporalio.client.Plugin] = [], + plugins: Sequence[temporalio.client.LowLevelPlugin] = [], default_workflow_query_reject_condition: Optional[ temporalio.common.QueryRejectCondition ] = None, @@ -239,7 +239,7 @@ async def start_time_skipping( *, data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, interceptors: Sequence[temporalio.client.Interceptor] = [], - plugins: Sequence[temporalio.client.Plugin] = [], + plugins: Sequence[temporalio.client.LowLevelPlugin] = [], default_workflow_query_reject_condition: Optional[ temporalio.common.QueryRejectCondition ] = None, diff --git a/temporalio/worker/__init__.py b/temporalio/worker/__init__.py index 08686dcb3..8823f1201 100644 --- a/temporalio/worker/__init__.py +++ b/temporalio/worker/__init__.py @@ -21,7 +21,7 @@ WorkflowInterceptorClassInput, WorkflowOutboundInterceptor, ) -from ._plugin import Plugin +from ._plugin import LowLevelPlugin from ._replayer import ( Replayer, ReplayerConfig, @@ -79,7 +79,7 @@ "ActivityOutboundInterceptor", "WorkflowInboundInterceptor", "WorkflowOutboundInterceptor", - "Plugin", + "LowLevelPlugin", # Interceptor input "ContinueAsNewInput", "ExecuteActivityInput", diff --git a/temporalio/worker/_plugin.py b/temporalio/worker/_plugin.py index 0e696a2dd..12b1467f0 100644 --- a/temporalio/worker/_plugin.py +++ b/temporalio/worker/_plugin.py @@ -16,7 +16,7 @@ ) -class Plugin(abc.ABC): +class LowLevelPlugin(abc.ABC): """Base class for worker plugins that can intercept and modify worker behavior. Plugins allow customization of worker creation and execution processes @@ -35,7 +35,7 @@ def name(self) -> str: return type(self).__module__ + "." + type(self).__qualname__ @abc.abstractmethod - def init_worker_plugin(self, next: Plugin) -> None: + def init_worker_plugin(self, next: LowLevelPlugin) -> None: """Initialize this plugin in the plugin chain. This method sets up the chain of responsibility pattern by providing a reference @@ -98,8 +98,8 @@ def run_replayer( """Hook called when running a replayer to allow interception of execution.""" -class _RootPlugin(Plugin): - def init_worker_plugin(self, next: Plugin) -> None: +class _RootPlugin(LowLevelPlugin): + def init_worker_plugin(self, next: LowLevelPlugin) -> None: raise NotImplementedError() def configure_worker(self, config: WorkerConfig) -> WorkerConfig: diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index 240429bf7..9e26c39ae 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -43,7 +43,7 @@ def __init__( namespace: str = "ReplayNamespace", data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, interceptors: Sequence[Interceptor] = [], - plugins: Sequence[temporalio.worker.Plugin] = [], + plugins: Sequence[temporalio.worker.LowLevelPlugin] = [], build_id: Optional[str] = None, identity: Optional[str] = None, workflow_failure_exception_types: Sequence[Type[BaseException]] = [], @@ -84,7 +84,7 @@ def __init__( ) # Apply plugin configuration - root_plugin: temporalio.worker.Plugin = _RootPlugin() + root_plugin: temporalio.worker.LowLevelPlugin = _RootPlugin() for plugin in reversed(plugins): plugin.init_worker_plugin(root_plugin) root_plugin = plugin diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index f93848496..ccaa94ac5 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -38,7 +38,7 @@ from ._activity import SharedStateManager, _ActivityWorker from ._interceptor import Interceptor from ._nexus import _NexusWorker -from ._plugin import Plugin, _RootPlugin +from ._plugin import LowLevelPlugin, _RootPlugin from ._tuning import WorkerTuner from ._workflow import _WorkflowWorker from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner @@ -111,7 +111,7 @@ def __init__( nexus_task_executor: Optional[concurrent.futures.Executor] = None, workflow_runner: WorkflowRunner = SandboxedWorkflowRunner(), unsandboxed_workflow_runner: WorkflowRunner = UnsandboxedWorkflowRunner(), - plugins: Sequence[Plugin] = [], + plugins: Sequence[LowLevelPlugin] = [], interceptors: Sequence[Interceptor] = [], build_id: Optional[str] = None, identity: Optional[str] = None, @@ -362,8 +362,8 @@ def __init__( ) plugins_from_client = cast( - List[Plugin], - [p for p in client.config()["plugins"] if isinstance(p, Plugin)], + List[LowLevelPlugin], + [p for p in client.config()["plugins"] if isinstance(p, LowLevelPlugin)], ) for client_plugin in plugins_from_client: if type(client_plugin) in [type(p) for p in plugins]: @@ -372,7 +372,7 @@ def __init__( ) plugins = plugins_from_client + list(plugins) - root_plugin: Plugin = _RootPlugin() + root_plugin: LowLevelPlugin = _RootPlugin() for plugin in reversed(plugins): plugin.init_worker_plugin(root_plugin) root_plugin = plugin diff --git a/tests/test_client.py b/tests/test_client.py index 9c33e9e1c..c4aa1262c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -44,8 +44,8 @@ ClientConfig, CloudOperationsClient, Interceptor, + LowLevelPlugin, OutboundInterceptor, - Plugin, QueryWorkflowInput, RPCError, RPCStatusCode, diff --git a/tests/test_plugins.py b/tests/test_plugins.py index eb08bba2d..030dbf578 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -7,10 +7,12 @@ import pytest import temporalio.client +import temporalio.plugin import temporalio.worker from temporalio import workflow -from temporalio.client import Client, ClientConfig, OutboundInterceptor, Plugin +from temporalio.client import Client, ClientConfig, LowLevelPlugin, OutboundInterceptor from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.plugin import PluginConfig from temporalio.testing import WorkflowEnvironment from temporalio.worker import ( Replayer, @@ -33,11 +35,11 @@ def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor: return super().intercept_client(next) -class MyClientPlugin(temporalio.client.Plugin): +class MyClientPlugin(temporalio.client.LowLevelPlugin): def __init__(self): self.interceptor = TestClientInterceptor() - def init_client_plugin(self, next: Plugin) -> None: + def init_client_plugin(self, next: LowLevelPlugin) -> None: self.next_client_plugin = next def configure_client(self, config: ClientConfig) -> ClientConfig: @@ -72,11 +74,13 @@ async def test_client_plugin(client: Client, env: WorkflowEnvironment): assert new_client.service_client.config.api_key == "replaced key" -class MyCombinedPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: +class MyCombinedPlugin( + temporalio.client.LowLevelPlugin, temporalio.worker.LowLevelPlugin +): + def init_worker_plugin(self, next: temporalio.worker.LowLevelPlugin) -> None: self.next_worker_plugin = next - def init_client_plugin(self, next: temporalio.client.Plugin) -> None: + def init_client_plugin(self, next: temporalio.client.LowLevelPlugin) -> None: self.next_client_plugin = next def configure_client(self, config: ClientConfig) -> ClientConfig: @@ -105,8 +109,8 @@ def run_replayer( return self.next_worker_plugin.run_replayer(replayer, histories) -class MyWorkerPlugin(temporalio.worker.Plugin): - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: +class MyWorkerPlugin(temporalio.worker.LowLevelPlugin): + def init_worker_plugin(self, next: temporalio.worker.LowLevelPlugin) -> None: self.next_worker_plugin = next def configure_worker(self, config: WorkerConfig) -> WorkerConfig: @@ -192,11 +196,13 @@ async def test_worker_sandbox_restrictions(client: Client) -> None: ) -class ReplayCheckPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: +class ReplayCheckPlugin( + temporalio.client.LowLevelPlugin, temporalio.worker.LowLevelPlugin +): + def init_worker_plugin(self, next: temporalio.worker.LowLevelPlugin) -> None: self.next_worker_plugin = next - def init_client_plugin(self, next: temporalio.client.Plugin) -> None: + def init_client_plugin(self, next: temporalio.client.LowLevelPlugin) -> None: self.next_client_plugin = next def configure_client(self, config: ClientConfig) -> ClientConfig: @@ -256,3 +262,31 @@ async def test_replay(client: Client) -> None: assert replayer.config().get("data_converter") == pydantic_data_converter await replayer.replay_workflow(await handle.fetch_history()) + + +class SimplePlugin(temporalio.plugin.Plugin): + @asynccontextmanager + async def run_context(self) -> AsyncIterator[None]: + yield + + def configuration(self) -> PluginConfig: + return PluginConfig( + data_converter=pydantic_data_converter, + workflows=[HelloWorkflow], + ) + + +async def test_simple_plugin(client: Client) -> None: + plugin = SimplePlugin() + new_config = client.config() + new_config["plugins"] = [plugin] + client = Client(**new_config) + + async with new_worker(client) as worker: + handle = await client.start_workflow( + HelloWorkflow.run, + "Tim", + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + await handle.result() From 0226d95dfa2cb94ceb760d18722f84f0b6499654 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 19 Aug 2025 12:38:41 -0700 Subject: [PATCH 2/4] Demo with openai plugin --- .../openai_agents/_temporal_openai_agents.py | 86 +++---------------- 1 file changed, 14 insertions(+), 72 deletions(-) diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 8340c0cc2..448354e2c 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -1,6 +1,6 @@ """Initialize Temporal OpenAI Agents overrides.""" -from contextlib import asynccontextmanager, contextmanager +from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager from datetime import timedelta from typing import AsyncIterator, Callable, Optional, Union @@ -23,6 +23,7 @@ from openai.types.responses import ResponsePromptParam import temporalio.client +import temporalio.plugin import temporalio.worker from temporalio.client import ClientConfig, LowLevelPlugin from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity @@ -41,6 +42,7 @@ from temporalio.converter import ( DataConverter, ) +from temporalio.plugin import PluginConfig from temporalio.worker import ( Replayer, ReplayerConfig, @@ -150,9 +152,7 @@ def __init__(self) -> None: super().__init__(ToJsonOptions(exclude_unset=True)) -class OpenAIAgentsPlugin( - temporalio.client.LowLevelPlugin, temporalio.worker.LowLevelPlugin -): +class OpenAIAgentsPlugin(temporalio.plugin.Plugin): """Temporal plugin for integrating OpenAI agents with Temporal workflows. .. warning:: @@ -235,19 +235,14 @@ def __init__( self._model_params = model_params self._model_provider = model_provider - def init_client_plugin(self, next: temporalio.client.LowLevelPlugin) -> None: - """Set the next client plugin""" - self.next_client_plugin = next - - async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: - """No modifications to service client""" - return await self.next_client_plugin.connect_service_client(config) - - def init_worker_plugin(self, next: temporalio.worker.LowLevelPlugin) -> None: - """Set the next worker plugin""" - self.next_worker_plugin = next + def configuration(self) -> PluginConfig: + return PluginConfig( + data_converter=DataConverter( + payload_converter_class=_OpenAIPayloadConverter + ), + worker_interceptors=[OpenAIAgentsTracingInterceptor()], + activities=[ModelActivity(self._model_provider).invoke_model_activity], + ) def configure_client(self, config: ClientConfig) -> ClientConfig: """Configure the Temporal client for OpenAI agents integration. @@ -266,60 +261,7 @@ def configure_client(self, config: ClientConfig) -> ClientConfig: ) return self.next_client_plugin.configure_client(config) - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - """Configure the Temporal worker for OpenAI agents integration. - - This method adds the necessary interceptors and activities for OpenAI - agent execution: - - Adds tracing interceptors for OpenAI agent interactions - - Registers model execution activities - - Args: - config: The worker configuration to modify. - - Returns: - The modified worker configuration. - """ - config["interceptors"] = list(config.get("interceptors") or []) + [ - OpenAIAgentsTracingInterceptor() - ] - config["activities"] = list(config.get("activities") or []) + [ - ModelActivity(self._model_provider).invoke_model_activity - ] - return self.next_worker_plugin.configure_worker(config) - - async def run_worker(self, worker: Worker) -> None: - """Run the worker with OpenAI agents temporal overrides. - - This method sets up the necessary runtime overrides for OpenAI agents - to work within the Temporal worker context, including custom runners - and trace providers. - - Args: - worker: The worker instance to run. - """ - with set_open_ai_agent_temporal_overrides(self._model_params): - await self.next_worker_plugin.run_worker(worker) - - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - """Configure the replayer for OpenAI Agents.""" - config["interceptors"] = list(config.get("interceptors") or []) + [ - OpenAIAgentsTracingInterceptor() - ] - config["data_converter"] = DataConverter( - payload_converter_class=_OpenAIPayloadConverter - ) - return config - @asynccontextmanager - async def run_replayer( - self, - replayer: Replayer, - histories: AsyncIterator[temporalio.client.WorkflowHistory], - ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: - """Set the OpenAI Overrides during replay""" + async def run_context(self) -> AsyncIterator[None]: with set_open_ai_agent_temporal_overrides(self._model_params): - async with self.next_worker_plugin.run_replayer( - replayer, histories - ) as results: - yield results + yield From 6ec4dc7265eefc2fcc8cabf59ce87e31326220be Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 19 Aug 2025 13:11:08 -0700 Subject: [PATCH 3/4] Fix run context in replayer --- temporalio/plugin.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/temporalio/plugin.py b/temporalio/plugin.py index 708253943..dff4486ee 100644 --- a/temporalio/plugin.py +++ b/temporalio/plugin.py @@ -137,8 +137,9 @@ async def run_replayer( replayer: Replayer, histories: AsyncIterator[WorkflowHistory], ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: - async with self.next_worker_plugin.run_replayer(replayer, histories) as results: - yield results + async with self.run_context(): + async with self.next_worker_plugin.run_replayer(replayer, histories) as results: + yield results @abc.abstractmethod def run_context(self) -> AbstractAsyncContextManager[None]: From 07c5ea7b3202cb8814f03cf64427e129baaebfaf Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 19 Aug 2025 13:13:35 -0700 Subject: [PATCH 4/4] Lint --- temporalio/plugin.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/temporalio/plugin.py b/temporalio/plugin.py index dff4486ee..b6c03809e 100644 --- a/temporalio/plugin.py +++ b/temporalio/plugin.py @@ -138,7 +138,9 @@ async def run_replayer( histories: AsyncIterator[WorkflowHistory], ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: async with self.run_context(): - async with self.next_worker_plugin.run_replayer(replayer, histories) as results: + async with self.next_worker_plugin.run_replayer( + replayer, histories + ) as results: yield results @abc.abstractmethod