Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ license = "MIT"
license-files = ["LICENSE"]
keywords = ["temporal", "workflow"]
dependencies = [
"nexus-rpc==1.2.0",
"nexus-rpc @ git+https://github.com/nexus-rpc/sdk-python@interceptors",
"protobuf>=3.20,<7.0.0",
"python-dateutil>=2.8.2,<3 ; python_version < '3.11'",
"types-protobuf>=3.20",
Expand Down
134 changes: 121 additions & 13 deletions temporalio/contrib/opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,25 @@

from __future__ import annotations

import dataclasses
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Generic,
Iterator,
Mapping,
NoReturn,
Optional,
Sequence,
Type,
TypeVar,
cast,
)

import nexusrpc.handler
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does both import nexusrpc.handler and from nexusrpc.handler import X. Should decide which approach. I personally prefer import nexusrpc.handler and qualify at use site.

import opentelemetry.baggage.propagation
import opentelemetry.context
import opentelemetry.context.context
Expand Down Expand Up @@ -56,6 +60,8 @@

_CarrierDict: TypeAlias = Dict[str, opentelemetry.propagators.textmap.CarrierValT]

_ContextT = TypeVar("_ContextT", bound=nexusrpc.handler.OperationContext)


class TracingInterceptor(temporalio.client.Interceptor, temporalio.worker.Interceptor):
"""Interceptor that supports client and worker OpenTelemetry span creation
Expand Down Expand Up @@ -135,6 +141,14 @@ def workflow_interceptor_class(
)
return TracingWorkflowInboundInterceptor

def intercept_nexus_operation(
self, next: temporalio.worker.NexusOperationInboundInterceptor
) -> temporalio.worker.NexusOperationInboundInterceptor:
"""Implementation of
:py:meth:`temporalio.worker.Interceptor.intercept_nexus_operation`.
"""
return _TracingNexusOperationInboundInterceptor(next, self)

def _context_to_headers(
self, headers: Mapping[str, temporalio.api.common.v1.Payload]
) -> Mapping[str, temporalio.api.common.v1.Payload]:
Expand Down Expand Up @@ -168,9 +182,10 @@ def _start_as_current_span(
name: str,
*,
attributes: opentelemetry.util.types.Attributes,
input: Optional[_InputWithHeaders] = None,
input_with_headers: _InputWithHeaders | None = None,
input_with_ctx: _InputWithOperationContext | None = None,
kind: opentelemetry.trace.SpanKind,
context: Optional[Context] = None,
context: Context | None = None,
) -> Iterator[None]:
token = opentelemetry.context.attach(context) if context else None
try:
Expand All @@ -181,8 +196,19 @@ def _start_as_current_span(
context=context,
set_status_on_exception=False,
) as span:
if input:
input.headers = self._context_to_headers(input.headers)
if input_with_headers:
input_with_headers.headers = self._context_to_headers(
input_with_headers.headers
)
if input_with_ctx:
carrier: _CarrierDict = {}
self.text_map_propagator.inject(carrier)
input_with_ctx.ctx = dataclasses.replace(
input_with_ctx.ctx,
headers=_carrier_to_nexus_headers(
carrier, input_with_ctx.ctx.headers
),
)
try:
yield None
except Exception as exc:
Expand Down Expand Up @@ -260,7 +286,7 @@ async def start_workflow(
with self.root._start_as_current_span(
f"{prefix}:{input.workflow}",
attributes={"temporalWorkflowID": input.id},
input=input,
input_with_headers=input,
kind=opentelemetry.trace.SpanKind.CLIENT,
):
return await super().start_workflow(input)
Expand All @@ -269,7 +295,7 @@ async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> A
with self.root._start_as_current_span(
f"QueryWorkflow:{input.query}",
attributes={"temporalWorkflowID": input.id},
input=input,
input_with_headers=input,
kind=opentelemetry.trace.SpanKind.CLIENT,
):
return await super().query_workflow(input)
Expand All @@ -280,7 +306,7 @@ async def signal_workflow(
with self.root._start_as_current_span(
f"SignalWorkflow:{input.signal}",
attributes={"temporalWorkflowID": input.id},
input=input,
input_with_headers=input,
kind=opentelemetry.trace.SpanKind.CLIENT,
):
return await super().signal_workflow(input)
Expand All @@ -291,7 +317,7 @@ async def start_workflow_update(
with self.root._start_as_current_span(
f"StartWorkflowUpdate:{input.update}",
attributes={"temporalWorkflowID": input.id},
input=input,
input_with_headers=input,
kind=opentelemetry.trace.SpanKind.CLIENT,
):
return await super().start_workflow_update(input)
Expand All @@ -308,7 +334,7 @@ async def start_update_with_start_workflow(
with self.root._start_as_current_span(
f"StartUpdateWithStartWorkflow:{input.start_workflow_input.workflow}",
attributes=attrs,
input=input.start_workflow_input,
input_with_headers=input.start_workflow_input,
kind=opentelemetry.trace.SpanKind.CLIENT,
):
otel_header = input.start_workflow_input.headers.get(self.root.header_key)
Expand Down Expand Up @@ -347,10 +373,60 @@ async def execute_activity(
return await super().execute_activity(input)


class _TracingNexusOperationInboundInterceptor(
temporalio.worker.NexusOperationInboundInterceptor
):
def __init__(
self,
next: temporalio.worker.NexusOperationInboundInterceptor,
root: TracingInterceptor,
) -> None:
self._next = next
self._root = root

def _context_from_nexus_headers(self, headers: Mapping[str, str]):
return self._root.text_map_propagator.extract(headers)

async def execute_nexus_operation_start(
self, input: temporalio.worker.ExecuteNexusOperationStartInput
) -> (
nexusrpc.handler.StartOperationResultSync[Any]
| nexusrpc.handler.StartOperationResultAsync
):
with self._root._start_as_current_span(
f"RunStartNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}",
context=self._context_from_nexus_headers(input.ctx.headers),
attributes={},
input_with_ctx=input,
kind=opentelemetry.trace.SpanKind.SERVER,
):
return await self._next.execute_nexus_operation_start(input)

async def execute_nexus_operation_cancel(
self, input: temporalio.worker.ExecuteNexusOperationCancelInput
) -> None:
with self._root._start_as_current_span(
f"RunCancelNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}",
context=self._context_from_nexus_headers(input.ctx.headers),
attributes={},
input_with_ctx=input,
kind=opentelemetry.trace.SpanKind.SERVER,
):
return await self._next.execute_nexus_operation_cancel(input)


class _InputWithHeaders(Protocol):
headers: Mapping[str, temporalio.api.common.v1.Payload]


class _InputWithStringHeaders(Protocol):
headers: Mapping[str, str] | None


class _InputWithOperationContext(Generic[_ContextT], Protocol):
ctx: _ContextT


class _WorkflowExternFunctions(TypedDict):
__temporal_opentelemetry_completed_span: Callable[
[_CompletedWorkflowSpanParams], Optional[_CarrierDict]
Expand Down Expand Up @@ -604,6 +680,7 @@ def _completed_span(
*,
link_context_carrier: Optional[_CarrierDict] = None,
add_to_outbound: Optional[_InputWithHeaders] = None,
add_to_outbound_str: Optional[_InputWithStringHeaders] = None,
new_span_even_on_replay: bool = False,
additional_attributes: opentelemetry.util.types.Attributes = None,
exception: Optional[Exception] = None,
Expand All @@ -616,12 +693,14 @@ def _completed_span(
# Create the span. First serialize current context to carrier.
new_context_carrier: _CarrierDict = {}
self.text_map_propagator.inject(new_context_carrier)

# Invoke
info = temporalio.workflow.info()
attributes: Dict[str, opentelemetry.util.types.AttributeValue] = {
"temporalWorkflowID": info.workflow_id,
"temporalRunID": info.run_id,
}

if additional_attributes:
attributes.update(additional_attributes)
updated_context_carrier = self._extern_functions[
Expand All @@ -642,10 +721,16 @@ def _completed_span(
)

# Add to outbound if needed
if add_to_outbound and updated_context_carrier:
add_to_outbound.headers = self._context_carrier_to_headers(
updated_context_carrier, add_to_outbound.headers
)
if updated_context_carrier:
if add_to_outbound:
add_to_outbound.headers = self._context_carrier_to_headers(
updated_context_carrier, add_to_outbound.headers
)

if add_to_outbound_str:
add_to_outbound_str.headers = _carrier_to_nexus_headers(
updated_context_carrier, add_to_outbound_str.headers
)

def _set_on_context(
self, context: opentelemetry.context.Context
Expand Down Expand Up @@ -724,6 +809,29 @@ def start_local_activity(
)
return super().start_local_activity(input)

async def start_nexus_operation(
self, input: temporalio.worker.StartNexusOperationInput[Any, Any]
) -> temporalio.workflow.NexusOperationHandle[Any]:
self.root._completed_span(
f"StartNexusOperation:{input.service}/{input.operation_name}",
kind=opentelemetry.trace.SpanKind.CLIENT,
add_to_outbound_str=input,
)

return await super().start_nexus_operation(input)


def _carrier_to_nexus_headers(
carrier: _CarrierDict, initial: Mapping[str, str] | None = None
) -> Mapping[str, str]:
out = {**initial} if initial else {}
for k, v in carrier.items():
if isinstance(v, list):
out[k] = ",".join(v)
else:
out[k] = v
return out


class workflow:
"""Contains static methods that are safe to call from within a workflow.
Expand Down
6 changes: 6 additions & 0 deletions temporalio/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
ActivityOutboundInterceptor,
ContinueAsNewInput,
ExecuteActivityInput,
ExecuteNexusOperationCancelInput,
ExecuteNexusOperationStartInput,
ExecuteWorkflowInput,
HandleQueryInput,
HandleSignalInput,
HandleUpdateInput,
Interceptor,
NexusOperationInboundInterceptor,
SignalChildWorkflowInput,
SignalExternalWorkflowInput,
StartActivityInput,
Expand Down Expand Up @@ -80,6 +83,7 @@
"ActivityOutboundInterceptor",
"WorkflowInboundInterceptor",
"WorkflowOutboundInterceptor",
"NexusOperationInboundInterceptor",
"Plugin",
# Interceptor input
"ContinueAsNewInput",
Expand All @@ -95,6 +99,8 @@
"StartLocalActivityInput",
"StartNexusOperationInput",
"WorkflowInterceptorClassInput",
"ExecuteNexusOperationStartInput",
"ExecuteNexusOperationCancelInput",
# Advanced activity classes
"SharedStateManager",
"SharedHeartbeatSender",
Expand Down
61 changes: 61 additions & 0 deletions temporalio/worker/_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ def workflow_interceptor_class(
"""
return None

def intercept_nexus_operation(
self, next: NexusOperationInboundInterceptor
) -> NexusOperationInboundInterceptor:
"""Method called for intercepting a Nexus operation.

Args:
next: The underlying inbound this interceptor
should delegate to.

Returns:
The new interceptor that should be used for the Nexus operation.
"""
return next


@dataclass(frozen=True)
class WorkflowInterceptorClassInput:
Expand Down Expand Up @@ -470,3 +484,50 @@ async def start_nexus_operation(
) -> temporalio.workflow.NexusOperationHandle[OutputT]:
"""Called for every :py:func:`temporalio.workflow.start_nexus_operation` call."""
return await self.next.start_nexus_operation(input)


@dataclass
class ExecuteNexusOperationStartInput:
"""Input for :pyt:meth:`NexusOperationInboundInterceptor.start_operation"""

ctx: nexusrpc.handler.StartOperationContext
input: Any


@dataclass
class ExecuteNexusOperationCancelInput:
"""Input for :pyt:meth:`NexusOperationInboundInterceptor.cancel_operation"""

ctx: nexusrpc.handler.CancelOperationContext
token: str


class NexusOperationInboundInterceptor:
"""Inbound interceptor to wrap Nexus operation starting and cancelling.

This should be extended by any Nexus operation inbound interceptors.
"""

def __init__(self, next: NexusOperationInboundInterceptor) -> None:
"""Create the inbound interceptor.

Args:
next: The next interceptor in the chain. The default implementation
of all calls is to delegate to the next interceptor.
"""
self.next = next

async def execute_nexus_operation_start(
self, input: ExecuteNexusOperationStartInput
) -> (
nexusrpc.handler.StartOperationResultSync[Any]
| nexusrpc.handler.StartOperationResultAsync
):
"""Called to start a Nexus operation"""
return await self.next.execute_nexus_operation_start(input)

async def execute_nexus_operation_cancel(
self, input: ExecuteNexusOperationCancelInput
) -> None:
"""Called to cancel an in progress Nexus operation"""
return await self.next.execute_nexus_operation_cancel(input)
Loading
Loading