22
33from __future__ import annotations
44
5+ import dataclasses
56from contextlib import contextmanager
67from dataclasses import dataclass
78from typing import (
1415 Optional ,
1516 Sequence ,
1617 Type ,
18+ TypeVar ,
1719 cast ,
1820)
1921
22+ import nexusrpc .handler
2023import opentelemetry .baggage .propagation
2124import opentelemetry .context
2225import opentelemetry .context .context
2528import opentelemetry .trace
2629import opentelemetry .trace .propagation .tracecontext
2730import opentelemetry .util .types
31+ from nexusrpc .handler import StartOperationResultAsync , StartOperationResultSync
2832from opentelemetry .context import Context
2933from opentelemetry .trace import Status , StatusCode
3034from typing_extensions import Protocol , TypeAlias , TypedDict
@@ -135,6 +139,11 @@ def workflow_interceptor_class(
135139 )
136140 return TracingWorkflowInboundInterceptor
137141
142+ def intercept_nexus_operation (
143+ self , next : temporalio .worker .NexusOperationInboundInterceptor
144+ ) -> temporalio .worker .NexusOperationInboundInterceptor :
145+ return _TracingNexusOperationInboundInterceptor (next , self )
146+
138147 def _context_to_headers (
139148 self , headers : Mapping [str , temporalio .api .common .v1 .Payload ]
140149 ) -> Mapping [str , temporalio .api .common .v1 .Payload ]:
@@ -201,6 +210,45 @@ def _start_as_current_span(
201210 if token and context is opentelemetry .context .get_current ():
202211 opentelemetry .context .detach (token )
203212
213+ @contextmanager
214+ def _start_as_current_span_nexus (
215+ self ,
216+ name : str ,
217+ * ,
218+ attributes : opentelemetry .util .types .Attributes ,
219+ input_headers : Mapping [str , str ],
220+ kind : opentelemetry .trace .SpanKind ,
221+ context : Optional [Context ] = None ,
222+ ) -> Iterator [_CarrierDict ]:
223+ token = opentelemetry .context .attach (context ) if context else None
224+ try :
225+ with self .tracer .start_as_current_span (
226+ name ,
227+ attributes = attributes ,
228+ kind = kind ,
229+ context = context ,
230+ set_status_on_exception = False ,
231+ ) as span :
232+ new_headers : _CarrierDict = {** input_headers }
233+ self .text_map_propagator .inject (new_headers )
234+ try :
235+ yield new_headers
236+ except Exception as exc :
237+ if (
238+ not isinstance (exc , ApplicationError )
239+ or exc .category != ApplicationErrorCategory .BENIGN
240+ ):
241+ span .set_status (
242+ Status (
243+ status_code = StatusCode .ERROR ,
244+ description = f"{ type (exc ).__name__ } : { exc } " ,
245+ )
246+ )
247+ raise
248+ finally :
249+ if token and context is opentelemetry .context .get_current ():
250+ opentelemetry .context .detach (token )
251+
204252 def _completed_workflow_span (
205253 self , params : _CompletedWorkflowSpanParams
206254 ) -> Optional [_CarrierDict ]:
@@ -347,6 +395,74 @@ async def execute_activity(
347395 return await super ().execute_activity (input )
348396
349397
398+ class _NexusTracing :
399+ _ContextT = TypeVar ("_ContextT" , bound = nexusrpc .handler .OperationContext )
400+
401+ # TODO(amazzeo): not sure what to do if value happens to be a list
402+ # _CarrierDict represents http headers Map[str, List[str] | str]
403+ # but nexus headers are just Map[str, str]
404+ def _carrier_to_nexus_headers (
405+ self , carrier : _CarrierDict , initial : Mapping [str , str ] | None = None
406+ ) -> Mapping [str , str ]:
407+ out = {** initial } if initial else {}
408+ for k , v in carrier .items ():
409+ if isinstance (v , list ):
410+ out [k ] = "," .join (v )
411+ else :
412+ out [k ] = v
413+ return out
414+
415+ def _operation_ctx_with_carrier (
416+ self , ctx : _ContextT , carrier : _CarrierDict
417+ ) -> _ContextT :
418+ return dataclasses .replace (
419+ ctx , headers = self ._carrier_to_nexus_headers (carrier , ctx .headers )
420+ )
421+
422+
423+ class _TracingNexusOperationInboundInterceptor (
424+ temporalio .worker .NexusOperationInboundInterceptor , _NexusTracing
425+ ):
426+ def __init__ (
427+ self ,
428+ next : temporalio .worker .NexusOperationInboundInterceptor ,
429+ root : TracingInterceptor ,
430+ ) -> None :
431+ self ._next = next
432+ self ._root = root
433+
434+ def _context_from_nexus_headers (self , headers : Mapping [str , str ]):
435+ return self ._root .text_map_propagator .extract (headers )
436+
437+ async def start_operation (
438+ self , input : temporalio .worker .NexusOperationStartInput
439+ ) -> StartOperationResultSync [Any ] | StartOperationResultAsync :
440+ with self ._root ._start_as_current_span_nexus (
441+ f"RunStartNexusOperationHandler:{ input .ctx .service } /{ input .ctx .operation } " ,
442+ context = self ._context_from_nexus_headers (input .ctx .headers ),
443+ attributes = {
444+ "temporalNexusRequestId" : input .ctx .request_id ,
445+ },
446+ input_headers = input .ctx .headers ,
447+ kind = opentelemetry .trace .SpanKind .SERVER ,
448+ ) as new_headers :
449+ input .ctx = self ._operation_ctx_with_carrier (input .ctx , new_headers )
450+ return await self ._next .start_operation (input )
451+
452+ async def cancel_operation (
453+ self , input : temporalio .worker .NexusOperationCancelInput
454+ ) -> None :
455+ with self ._root ._start_as_current_span_nexus (
456+ f"RunCancelNexusOperationHandler:{ input .ctx .service } /{ input .ctx .operation } " ,
457+ context = self ._context_from_nexus_headers (input .ctx .headers ),
458+ attributes = {},
459+ input_headers = input .ctx .headers ,
460+ kind = opentelemetry .trace .SpanKind .SERVER ,
461+ ) as new_headers :
462+ input .ctx = self ._operation_ctx_with_carrier (input .ctx , new_headers )
463+ return await self ._next .cancel_operation (input )
464+
465+
350466class _InputWithHeaders (Protocol ):
351467 headers : Mapping [str , temporalio .api .common .v1 .Payload ]
352468
@@ -417,7 +533,7 @@ async def execute_workflow(
417533 """
418534 with self ._top_level_workflow_context (success_is_complete = True ):
419535 # Entrypoint of workflow should be `server` in OTel
420- self ._completed_span (
536+ self ._completed_span_grpc (
421537 f"RunWorkflow:{ temporalio .workflow .info ().workflow_type } " ,
422538 kind = opentelemetry .trace .SpanKind .SERVER ,
423539 )
@@ -436,7 +552,7 @@ async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> Non
436552 [link_context_header ]
437553 )[0 ]
438554 with self ._top_level_workflow_context (success_is_complete = False ):
439- self ._completed_span (
555+ self ._completed_span_grpc (
440556 f"HandleSignal:{ input .signal } " ,
441557 link_context_carrier = link_context_carrier ,
442558 kind = opentelemetry .trace .SpanKind .SERVER ,
@@ -468,7 +584,7 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
468584 token = opentelemetry .context .attach (context )
469585 try :
470586 # This won't be created if there was no context header
471- self ._completed_span (
587+ self ._completed_span_grpc (
472588 f"HandleQuery:{ input .query } " ,
473589 link_context_carrier = link_context_carrier ,
474590 # Create even on replay for queries
@@ -497,7 +613,7 @@ def handle_update_validator(
497613 [link_context_header ]
498614 )[0 ]
499615 with self ._top_level_workflow_context (success_is_complete = False ):
500- self ._completed_span (
616+ self ._completed_span_grpc (
501617 f"ValidateUpdate:{ input .update } " ,
502618 link_context_carrier = link_context_carrier ,
503619 kind = opentelemetry .trace .SpanKind .SERVER ,
@@ -517,7 +633,7 @@ async def handle_update_handler(
517633 [link_context_header ]
518634 )[0 ]
519635 with self ._top_level_workflow_context (success_is_complete = False ):
520- self ._completed_span (
636+ self ._completed_span_grpc (
521637 f"HandleUpdate:{ input .update } " ,
522638 link_context_carrier = link_context_carrier ,
523639 kind = opentelemetry .trace .SpanKind .SERVER ,
@@ -566,7 +682,7 @@ def _top_level_workflow_context(
566682 finally :
567683 # Create a completed span before detaching context
568684 if exception or (success and success_is_complete ):
569- self ._completed_span (
685+ self ._completed_span_grpc (
570686 f"CompleteWorkflow:{ temporalio .workflow .info ().workflow_type } " ,
571687 exception = exception ,
572688 kind = opentelemetry .trace .SpanKind .INTERNAL ,
@@ -598,7 +714,7 @@ def _context_carrier_to_headers(
598714 }
599715 return headers
600716
601- def _completed_span (
717+ def _completed_span_grpc (
602718 self ,
603719 span_name : str ,
604720 * ,
@@ -609,19 +725,75 @@ def _completed_span(
609725 exception : Optional [Exception ] = None ,
610726 kind : opentelemetry .trace .SpanKind = opentelemetry .trace .SpanKind .INTERNAL ,
611727 ) -> None :
728+ updated_context_carrier = self ._completed_span (
729+ span_name = span_name ,
730+ link_context_carrier = link_context_carrier ,
731+ new_span_even_on_replay = new_span_even_on_replay ,
732+ additional_attributes = additional_attributes ,
733+ exception = exception ,
734+ kind = kind ,
735+ )
736+
737+ # Add to outbound if needed
738+ if add_to_outbound and updated_context_carrier :
739+ add_to_outbound .headers = self ._context_carrier_to_headers (
740+ updated_context_carrier , add_to_outbound .headers
741+ )
742+
743+ def _completed_span_nexus (
744+ self ,
745+ span_name : str ,
746+ * ,
747+ outbound_headers : Mapping [str , str ],
748+ link_context_carrier : Optional [_CarrierDict ] = None ,
749+ new_span_even_on_replay : bool = False ,
750+ additional_attributes : opentelemetry .util .types .Attributes = None ,
751+ exception : Optional [Exception ] = None ,
752+ kind : opentelemetry .trace .SpanKind = opentelemetry .trace .SpanKind .INTERNAL ,
753+ ) -> _CarrierDict | None :
754+ new_carrier = self ._completed_span (
755+ span_name = span_name ,
756+ link_context_carrier = link_context_carrier ,
757+ new_span_even_on_replay = new_span_even_on_replay ,
758+ additional_attributes = additional_attributes ,
759+ exception = exception ,
760+ kind = kind ,
761+ )
762+
763+ if new_carrier :
764+ return {** outbound_headers , ** new_carrier }
765+ else :
766+ return {** outbound_headers }
767+
768+ def _completed_span (
769+ self ,
770+ span_name : str ,
771+ * ,
772+ link_context_carrier : Optional [_CarrierDict ] = None ,
773+ new_span_even_on_replay : bool = False ,
774+ additional_attributes : opentelemetry .util .types .Attributes = None ,
775+ exception : Optional [Exception ] = None ,
776+ kind : opentelemetry .trace .SpanKind = opentelemetry .trace .SpanKind .INTERNAL ,
777+ ) -> _CarrierDict | None :
612778 # If we are replaying and they don't want a span on replay, no span
613779 if temporalio .workflow .unsafe .is_replaying () and not new_span_even_on_replay :
614780 return None
615781
616782 # Create the span. First serialize current context to carrier.
617783 new_context_carrier : _CarrierDict = {}
618784 self .text_map_propagator .inject (new_context_carrier )
785+
619786 # Invoke
620- info = temporalio .workflow .info ()
621- attributes : Dict [str , opentelemetry .util .types .AttributeValue ] = {
622- "temporalWorkflowID" : info .workflow_id ,
623- "temporalRunID" : info .run_id ,
624- }
787+ # TODO(amazzeo): I think this try/except is necessary once non-workflow callers
788+ # are added to Nexus
789+ attributes : Dict [str , opentelemetry .util .types .AttributeValue ] = {}
790+ try :
791+ info = temporalio .workflow .info ()
792+ attributes ["temporalWorkflowID" ] = info .workflow_id
793+ attributes ["temporalRunID" ] = info .run_id
794+ except temporalio .exceptions .TemporalError :
795+ pass
796+
625797 if additional_attributes :
626798 attributes .update (additional_attributes )
627799 updated_context_carrier = self ._extern_functions [
@@ -641,11 +813,7 @@ def _completed_span(
641813 )
642814 )
643815
644- # Add to outbound if needed
645- if add_to_outbound and updated_context_carrier :
646- add_to_outbound .headers = self ._context_carrier_to_headers (
647- updated_context_carrier , add_to_outbound .headers
648- )
816+ return updated_context_carrier
649817
650818 def _set_on_context (
651819 self , context : opentelemetry .context .Context
@@ -654,7 +822,7 @@ def _set_on_context(
654822
655823
656824class _TracingWorkflowOutboundInterceptor (
657- temporalio .worker .WorkflowOutboundInterceptor
825+ temporalio .worker .WorkflowOutboundInterceptor , _NexusTracing
658826):
659827 def __init__ (
660828 self ,
@@ -673,7 +841,7 @@ async def signal_child_workflow(
673841 self , input : temporalio .worker .SignalChildWorkflowInput
674842 ) -> None :
675843 # Create new span and put on outbound input
676- self .root ._completed_span (
844+ self .root ._completed_span_grpc (
677845 f"SignalChildWorkflow:{ input .signal } " ,
678846 add_to_outbound = input ,
679847 kind = opentelemetry .trace .SpanKind .SERVER ,
@@ -684,7 +852,7 @@ async def signal_external_workflow(
684852 self , input : temporalio .worker .SignalExternalWorkflowInput
685853 ) -> None :
686854 # Create new span and put on outbound input
687- self .root ._completed_span (
855+ self .root ._completed_span_grpc (
688856 f"SignalExternalWorkflow:{ input .signal } " ,
689857 add_to_outbound = input ,
690858 kind = opentelemetry .trace .SpanKind .CLIENT ,
@@ -695,7 +863,7 @@ def start_activity(
695863 self , input : temporalio .worker .StartActivityInput
696864 ) -> temporalio .workflow .ActivityHandle :
697865 # Create new span and put on outbound input
698- self .root ._completed_span (
866+ self .root ._completed_span_grpc (
699867 f"StartActivity:{ input .activity } " ,
700868 add_to_outbound = input ,
701869 kind = opentelemetry .trace .SpanKind .CLIENT ,
@@ -706,7 +874,7 @@ async def start_child_workflow(
706874 self , input : temporalio .worker .StartChildWorkflowInput
707875 ) -> temporalio .workflow .ChildWorkflowHandle :
708876 # Create new span and put on outbound input
709- self .root ._completed_span (
877+ self .root ._completed_span_grpc (
710878 f"StartChildWorkflow:{ input .workflow } " ,
711879 add_to_outbound = input ,
712880 kind = opentelemetry .trace .SpanKind .CLIENT ,
@@ -717,13 +885,26 @@ def start_local_activity(
717885 self , input : temporalio .worker .StartLocalActivityInput
718886 ) -> temporalio .workflow .ActivityHandle :
719887 # Create new span and put on outbound input
720- self .root ._completed_span (
888+ self .root ._completed_span_grpc (
721889 f"StartActivity:{ input .activity } " ,
722890 add_to_outbound = input ,
723891 kind = opentelemetry .trace .SpanKind .CLIENT ,
724892 )
725893 return super ().start_local_activity (input )
726894
895+ async def start_nexus_operation (
896+ self , input : temporalio .worker .StartNexusOperationInput [Any , Any ]
897+ ) -> temporalio .workflow .NexusOperationHandle [Any ]:
898+ new_carrier = self .root ._completed_span_nexus (
899+ f"StartNexusOperation:{ input .service } /{ input .operation_name } " ,
900+ kind = opentelemetry .trace .SpanKind .CLIENT ,
901+ outbound_headers = input .headers if input .headers else {},
902+ )
903+ if new_carrier :
904+ input .headers = self ._carrier_to_nexus_headers (new_carrier , input .headers )
905+
906+ return await super ().start_nexus_operation (input )
907+
727908
728909class workflow :
729910 """Contains static methods that are safe to call from within a workflow.
@@ -760,6 +941,6 @@ def completed_span(
760941 """
761942 interceptor = TracingWorkflowInboundInterceptor ._from_context ()
762943 if interceptor :
763- interceptor ._completed_span (
944+ interceptor ._completed_span_grpc (
764945 name , additional_attributes = attributes , exception = exception
765946 )
0 commit comments