Skip to content

Commit 32a26ee

Browse files
committed
First draft of nexus interceptors and otel support
1 parent f9fdb88 commit 32a26ee

File tree

6 files changed

+493
-41
lines changed

6 files changed

+493
-41
lines changed

temporalio/contrib/opentelemetry.py

Lines changed: 205 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import dataclasses
56
from contextlib import contextmanager
67
from dataclasses import dataclass
78
from typing import (
@@ -14,9 +15,11 @@
1415
Optional,
1516
Sequence,
1617
Type,
18+
TypeVar,
1719
cast,
1820
)
1921

22+
import nexusrpc.handler
2023
import opentelemetry.baggage.propagation
2124
import opentelemetry.context
2225
import opentelemetry.context.context
@@ -25,6 +28,7 @@
2528
import opentelemetry.trace
2629
import opentelemetry.trace.propagation.tracecontext
2730
import opentelemetry.util.types
31+
from nexusrpc.handler import StartOperationResultAsync, StartOperationResultSync
2832
from opentelemetry.context import Context
2933
from opentelemetry.trace import Status, StatusCode
3034
from typing_extensions import Protocol, TypeAlias, TypedDict
@@ -135,6 +139,11 @@ def workflow_interceptor_class(
135139
)
136140
return TracingWorkflowInboundInterceptor
137141

142+
def intercept_nexus_operation(
    self, next: temporalio.worker.NexusOperationInboundInterceptor
) -> temporalio.worker.NexusOperationInboundInterceptor:
    """Wrap the next Nexus operation inbound interceptor with tracing.

    Args:
        next: The interceptor the returned tracing interceptor delegates to.

    Returns:
        A ``_TracingNexusOperationInboundInterceptor`` constructed from
        ``next`` and this interceptor (passed as its root).
    """
    tracing_interceptor = _TracingNexusOperationInboundInterceptor(next, self)
    return tracing_interceptor
146+
138147
def _context_to_headers(
139148
self, headers: Mapping[str, temporalio.api.common.v1.Payload]
140149
) -> Mapping[str, temporalio.api.common.v1.Payload]:
@@ -201,6 +210,45 @@ def _start_as_current_span(
201210
if token and context is opentelemetry.context.get_current():
202211
opentelemetry.context.detach(token)
203212

213+
@contextmanager
def _start_as_current_span_nexus(
    self,
    name: str,
    *,
    attributes: opentelemetry.util.types.Attributes,
    input_headers: Mapping[str, str],
    kind: opentelemetry.trace.SpanKind,
    context: Optional[Context] = None,
) -> Iterator[_CarrierDict]:
    """Run the managed body inside a new current span for a Nexus call.

    A copy of ``input_headers`` is made, the active trace context is injected
    into it through ``self.text_map_propagator``, and that carrier is yielded
    to the caller.

    Args:
        name: Span name.
        attributes: Attributes to set on the span.
        input_headers: Headers copied into the yielded carrier before
            injection.
        kind: Span kind.
        context: Optional context; it is attached for the duration of the
            block when truthy, and always passed to the tracer as the span's
            ``context`` argument.

    Yields:
        The carrier holding ``input_headers`` plus the injected entries.

    Raises:
        Exception: Anything raised by the managed body is re-raised after the
            span status has been updated.
    """
    # Only a truthy context is attached; otherwise there is nothing to detach.
    attach_token = None
    if context:
        attach_token = opentelemetry.context.attach(context)

    try:
        # Status is set by hand below, so the tracer must not do it for us.
        with self.tracer.start_as_current_span(
            name,
            attributes=attributes,
            kind=kind,
            context=context,
            set_status_on_exception=False,
        ) as active_span:
            carrier: _CarrierDict = dict(input_headers)
            self.text_map_propagator.inject(carrier)
            try:
                yield carrier
            except Exception as exc:
                # Benign application errors are not recorded as span errors.
                is_benign = (
                    isinstance(exc, ApplicationError)
                    and exc.category == ApplicationErrorCategory.BENIGN
                )
                if not is_benign:
                    error_status = Status(
                        status_code=StatusCode.ERROR,
                        description=f"{type(exc).__name__}: {exc}",
                    )
                    active_span.set_status(error_status)
                raise
    finally:
        # Detach only if the context we attached is still the current one.
        if attach_token and context is opentelemetry.context.get_current():
            opentelemetry.context.detach(attach_token)
251+
204252
def _completed_workflow_span(
205253
self, params: _CompletedWorkflowSpanParams
206254
) -> Optional[_CarrierDict]:
@@ -347,6 +395,74 @@ async def execute_activity(
347395
return await super().execute_activity(input)
348396

349397

398+
class _NexusTracing:
    """Helpers for copying a tracing carrier into Nexus headers."""

    _ContextT = TypeVar("_ContextT", bound=nexusrpc.handler.OperationContext)

    # TODO(amazzeo): not sure what to do if value happens to be a list.
    # _CarrierDict represents HTTP-style headers (Map[str, List[str] | str]),
    # whereas Nexus headers are just Map[str, str].
    def _carrier_to_nexus_headers(
        self, carrier: _CarrierDict, initial: Mapping[str, str] | None = None
    ) -> Mapping[str, str]:
        """Flatten ``carrier`` into string-valued headers on top of ``initial``.

        Args:
            carrier: Carrier whose values are either strings or lists.
            initial: Optional base headers; copied, never mutated. Carrier
                entries overwrite entries with the same key.

        Returns:
            A new dict in which list values have been joined with ``","``.
        """
        headers: Dict[str, str] = dict(initial) if initial else {}
        for key, value in carrier.items():
            headers[key] = ",".join(value) if isinstance(value, list) else value
        return headers

    def _operation_ctx_with_carrier(
        self, ctx: _ContextT, carrier: _CarrierDict
    ) -> _ContextT:
        """Return a copy of ``ctx`` whose headers also include ``carrier``.

        Args:
            ctx: Operation context; copied with ``dataclasses.replace``.
            carrier: Carrier merged over ``ctx.headers``.

        Returns:
            The copied context carrying the merged headers.
        """
        merged_headers = self._carrier_to_nexus_headers(carrier, ctx.headers)
        return dataclasses.replace(ctx, headers=merged_headers)
421+
422+
423+
class _TracingNexusOperationInboundInterceptor(
    temporalio.worker.NexusOperationInboundInterceptor, _NexusTracing
):
    """Inbound Nexus interceptor that runs each handler call inside a span."""

    def __init__(
        self,
        next: temporalio.worker.NexusOperationInboundInterceptor,
        root: TracingInterceptor,
    ) -> None:
        """Remember the next interceptor in the chain and the root interceptor.

        Args:
            next: Interceptor the start/cancel calls are forwarded to.
            root: Interceptor supplying the propagator and span helper.
        """
        self._root = root
        self._next = next

    def _context_from_nexus_headers(self, headers: Mapping[str, str]):
        """Extract a context from ``headers`` using the root's propagator."""
        propagator = self._root.text_map_propagator
        return propagator.extract(headers)

    async def start_operation(
        self, input: temporalio.worker.NexusOperationStartInput
    ) -> StartOperationResultSync[Any] | StartOperationResultAsync:
        """Forward a start-operation call to the next interceptor inside a span.

        The span is of kind SERVER and carries the request ID as an attribute.
        ``input.ctx`` is replaced with a copy whose headers include the
        carrier yielded by the span helper before the call is forwarded.
        """
        span_name = (
            f"RunStartNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}"
        )
        parent_context = self._context_from_nexus_headers(input.ctx.headers)
        span_attributes = {"temporalNexusRequestId": input.ctx.request_id}
        with self._root._start_as_current_span_nexus(
            span_name,
            context=parent_context,
            attributes=span_attributes,
            input_headers=input.ctx.headers,
            kind=opentelemetry.trace.SpanKind.SERVER,
        ) as traced_headers:
            input.ctx = self._operation_ctx_with_carrier(input.ctx, traced_headers)
            return await self._next.start_operation(input)

    async def cancel_operation(
        self, input: temporalio.worker.NexusOperationCancelInput
    ) -> None:
        """Forward a cancel-operation call to the next interceptor inside a span.

        The span is of kind SERVER and has no attributes. ``input.ctx`` is
        replaced with a copy whose headers include the carrier yielded by the
        span helper before the call is forwarded.
        """
        span_name = (
            f"RunCancelNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}"
        )
        parent_context = self._context_from_nexus_headers(input.ctx.headers)
        with self._root._start_as_current_span_nexus(
            span_name,
            context=parent_context,
            attributes={},
            input_headers=input.ctx.headers,
            kind=opentelemetry.trace.SpanKind.SERVER,
        ) as traced_headers:
            input.ctx = self._operation_ctx_with_carrier(input.ctx, traced_headers)
            return await self._next.cancel_operation(input)
464+
465+
350466
class _InputWithHeaders(Protocol):
351467
headers: Mapping[str, temporalio.api.common.v1.Payload]
352468

@@ -417,7 +533,7 @@ async def execute_workflow(
417533
"""
418534
with self._top_level_workflow_context(success_is_complete=True):
419535
# Entrypoint of workflow should be `server` in OTel
420-
self._completed_span(
536+
self._completed_span_grpc(
421537
f"RunWorkflow:{temporalio.workflow.info().workflow_type}",
422538
kind=opentelemetry.trace.SpanKind.SERVER,
423539
)
@@ -436,7 +552,7 @@ async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> Non
436552
[link_context_header]
437553
)[0]
438554
with self._top_level_workflow_context(success_is_complete=False):
439-
self._completed_span(
555+
self._completed_span_grpc(
440556
f"HandleSignal:{input.signal}",
441557
link_context_carrier=link_context_carrier,
442558
kind=opentelemetry.trace.SpanKind.SERVER,
@@ -468,7 +584,7 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
468584
token = opentelemetry.context.attach(context)
469585
try:
470586
# This won't be created if there was no context header
471-
self._completed_span(
587+
self._completed_span_grpc(
472588
f"HandleQuery:{input.query}",
473589
link_context_carrier=link_context_carrier,
474590
# Create even on replay for queries
@@ -497,7 +613,7 @@ def handle_update_validator(
497613
[link_context_header]
498614
)[0]
499615
with self._top_level_workflow_context(success_is_complete=False):
500-
self._completed_span(
616+
self._completed_span_grpc(
501617
f"ValidateUpdate:{input.update}",
502618
link_context_carrier=link_context_carrier,
503619
kind=opentelemetry.trace.SpanKind.SERVER,
@@ -517,7 +633,7 @@ async def handle_update_handler(
517633
[link_context_header]
518634
)[0]
519635
with self._top_level_workflow_context(success_is_complete=False):
520-
self._completed_span(
636+
self._completed_span_grpc(
521637
f"HandleUpdate:{input.update}",
522638
link_context_carrier=link_context_carrier,
523639
kind=opentelemetry.trace.SpanKind.SERVER,
@@ -566,7 +682,7 @@ def _top_level_workflow_context(
566682
finally:
567683
# Create a completed span before detaching context
568684
if exception or (success and success_is_complete):
569-
self._completed_span(
685+
self._completed_span_grpc(
570686
f"CompleteWorkflow:{temporalio.workflow.info().workflow_type}",
571687
exception=exception,
572688
kind=opentelemetry.trace.SpanKind.INTERNAL,
@@ -598,7 +714,7 @@ def _context_carrier_to_headers(
598714
}
599715
return headers
600716

601-
def _completed_span(
717+
def _completed_span_grpc(
602718
self,
603719
span_name: str,
604720
*,
@@ -609,19 +725,75 @@ def _completed_span(
609725
exception: Optional[Exception] = None,
610726
kind: opentelemetry.trace.SpanKind = opentelemetry.trace.SpanKind.INTERNAL,
611727
) -> None:
728+
updated_context_carrier = self._completed_span(
729+
span_name=span_name,
730+
link_context_carrier=link_context_carrier,
731+
new_span_even_on_replay=new_span_even_on_replay,
732+
additional_attributes=additional_attributes,
733+
exception=exception,
734+
kind=kind,
735+
)
736+
737+
# Add to outbound if needed
738+
if add_to_outbound and updated_context_carrier:
739+
add_to_outbound.headers = self._context_carrier_to_headers(
740+
updated_context_carrier, add_to_outbound.headers
741+
)
742+
743+
def _completed_span_nexus(
    self,
    span_name: str,
    *,
    outbound_headers: Mapping[str, str],
    link_context_carrier: Optional[_CarrierDict] = None,
    new_span_even_on_replay: bool = False,
    additional_attributes: opentelemetry.util.types.Attributes = None,
    exception: Optional[Exception] = None,
    kind: opentelemetry.trace.SpanKind = opentelemetry.trace.SpanKind.INTERNAL,
) -> _CarrierDict | None:
    """Create a completed span and merge its carrier over outbound headers.

    All span-related arguments are forwarded unchanged to
    ``self._completed_span``.

    Args:
        span_name: Name of the span.
        outbound_headers: Headers to start from; copied, never mutated.
        link_context_carrier: Forwarded to ``_completed_span``.
        new_span_even_on_replay: Forwarded to ``_completed_span``.
        additional_attributes: Forwarded to ``_completed_span``.
        exception: Forwarded to ``_completed_span``.
        kind: Forwarded to ``_completed_span``.

    Returns:
        A new dict holding ``outbound_headers`` with the span's carrier
        entries layered on top when ``_completed_span`` produced a non-empty
        carrier; otherwise just the copy of ``outbound_headers``.
    """
    # NOTE(review): both paths return a dict, so the ``| None`` in the return
    # annotation is never exercised by this method.
    span_carrier = self._completed_span(
        span_name=span_name,
        link_context_carrier=link_context_carrier,
        new_span_even_on_replay=new_span_even_on_replay,
        additional_attributes=additional_attributes,
        exception=exception,
        kind=kind,
    )

    merged = dict(outbound_headers)
    if span_carrier:
        merged.update(span_carrier)
    return merged
767+
768+
def _completed_span(
769+
self,
770+
span_name: str,
771+
*,
772+
link_context_carrier: Optional[_CarrierDict] = None,
773+
new_span_even_on_replay: bool = False,
774+
additional_attributes: opentelemetry.util.types.Attributes = None,
775+
exception: Optional[Exception] = None,
776+
kind: opentelemetry.trace.SpanKind = opentelemetry.trace.SpanKind.INTERNAL,
777+
) -> _CarrierDict | None:
612778
# If we are replaying and they don't want a span on replay, no span
613779
if temporalio.workflow.unsafe.is_replaying() and not new_span_even_on_replay:
614780
return None
615781

616782
# Create the span. First serialize current context to carrier.
617783
new_context_carrier: _CarrierDict = {}
618784
self.text_map_propagator.inject(new_context_carrier)
785+
619786
# Invoke
620-
info = temporalio.workflow.info()
621-
attributes: Dict[str, opentelemetry.util.types.AttributeValue] = {
622-
"temporalWorkflowID": info.workflow_id,
623-
"temporalRunID": info.run_id,
624-
}
787+
# TODO(amazzeo): I think this try/except is necessary once non-workflow callers
788+
# are added to Nexus
789+
attributes: Dict[str, opentelemetry.util.types.AttributeValue] = {}
790+
try:
791+
info = temporalio.workflow.info()
792+
attributes["temporalWorkflowID"] = info.workflow_id
793+
attributes["temporalRunID"] = info.run_id
794+
except temporalio.exceptions.TemporalError:
795+
pass
796+
625797
if additional_attributes:
626798
attributes.update(additional_attributes)
627799
updated_context_carrier = self._extern_functions[
@@ -641,11 +813,7 @@ def _completed_span(
641813
)
642814
)
643815

644-
# Add to outbound if needed
645-
if add_to_outbound and updated_context_carrier:
646-
add_to_outbound.headers = self._context_carrier_to_headers(
647-
updated_context_carrier, add_to_outbound.headers
648-
)
816+
return updated_context_carrier
649817

650818
def _set_on_context(
651819
self, context: opentelemetry.context.Context
@@ -654,7 +822,7 @@ def _set_on_context(
654822

655823

656824
class _TracingWorkflowOutboundInterceptor(
657-
temporalio.worker.WorkflowOutboundInterceptor
825+
temporalio.worker.WorkflowOutboundInterceptor, _NexusTracing
658826
):
659827
def __init__(
660828
self,
@@ -673,7 +841,7 @@ async def signal_child_workflow(
673841
self, input: temporalio.worker.SignalChildWorkflowInput
674842
) -> None:
675843
# Create new span and put on outbound input
676-
self.root._completed_span(
844+
self.root._completed_span_grpc(
677845
f"SignalChildWorkflow:{input.signal}",
678846
add_to_outbound=input,
679847
kind=opentelemetry.trace.SpanKind.SERVER,
@@ -684,7 +852,7 @@ async def signal_external_workflow(
684852
self, input: temporalio.worker.SignalExternalWorkflowInput
685853
) -> None:
686854
# Create new span and put on outbound input
687-
self.root._completed_span(
855+
self.root._completed_span_grpc(
688856
f"SignalExternalWorkflow:{input.signal}",
689857
add_to_outbound=input,
690858
kind=opentelemetry.trace.SpanKind.CLIENT,
@@ -695,7 +863,7 @@ def start_activity(
695863
self, input: temporalio.worker.StartActivityInput
696864
) -> temporalio.workflow.ActivityHandle:
697865
# Create new span and put on outbound input
698-
self.root._completed_span(
866+
self.root._completed_span_grpc(
699867
f"StartActivity:{input.activity}",
700868
add_to_outbound=input,
701869
kind=opentelemetry.trace.SpanKind.CLIENT,
@@ -706,7 +874,7 @@ async def start_child_workflow(
706874
self, input: temporalio.worker.StartChildWorkflowInput
707875
) -> temporalio.workflow.ChildWorkflowHandle:
708876
# Create new span and put on outbound input
709-
self.root._completed_span(
877+
self.root._completed_span_grpc(
710878
f"StartChildWorkflow:{input.workflow}",
711879
add_to_outbound=input,
712880
kind=opentelemetry.trace.SpanKind.CLIENT,
@@ -717,13 +885,26 @@ def start_local_activity(
717885
self, input: temporalio.worker.StartLocalActivityInput
718886
) -> temporalio.workflow.ActivityHandle:
719887
# Create new span and put on outbound input
720-
self.root._completed_span(
888+
self.root._completed_span_grpc(
721889
f"StartActivity:{input.activity}",
722890
add_to_outbound=input,
723891
kind=opentelemetry.trace.SpanKind.CLIENT,
724892
)
725893
return super().start_local_activity(input)
726894

895+
async def start_nexus_operation(
    self, input: temporalio.worker.StartNexusOperationInput[Any, Any]
) -> temporalio.workflow.NexusOperationHandle[Any]:
    """Record a CLIENT span for the outgoing Nexus operation, then start it.

    The carrier returned by the root's ``_completed_span_nexus`` is, when
    non-empty, converted to Nexus headers and assigned to ``input.headers``
    before the call is delegated to the parent implementation.

    Args:
        input: Start-operation input; its ``headers`` may be replaced.

    Returns:
        The handle returned by the parent ``start_nexus_operation``.
    """
    span_name = f"StartNexusOperation:{input.service}/{input.operation_name}"
    carrier = self.root._completed_span_nexus(
        span_name,
        kind=opentelemetry.trace.SpanKind.CLIENT,
        outbound_headers=input.headers or {},
    )
    if carrier:
        input.headers = self._carrier_to_nexus_headers(carrier, input.headers)

    return await super().start_nexus_operation(input)
907+
727908

728909
class workflow:
729910
"""Contains static methods that are safe to call from within a workflow.
@@ -760,6 +941,6 @@ def completed_span(
760941
"""
761942
interceptor = TracingWorkflowInboundInterceptor._from_context()
762943
if interceptor:
763-
interceptor._completed_span(
944+
interceptor._completed_span_grpc(
764945
name, additional_attributes=attributes, exception=exception
765946
)

0 commit comments

Comments
 (0)