5454from typing import Callable , Collection , Iterable , Optional
5555from urllib .parse import urlparse
5656
57- from requests .models import Response
57+ from requests .models import PreparedRequest , Response
5858from requests .sessions import Session
5959from requests .structures import CaseInsensitiveDict
6060
8585
8686_excluded_urls_from_env = get_excluded_urls ("REQUESTS" )
8787
88+ _RequestHookT = Optional [Callable [[Span , PreparedRequest ], None ]]
89+ _ResponseHookT = Optional [Callable [[Span , PreparedRequest ], None ]]
90+
8891
8992# pylint: disable=unused-argument
9093# pylint: disable=R0915
9194def _instrument (
9295 tracer : Tracer ,
9396 duration_histogram : Histogram ,
94- span_callback : Optional [ Callable [[ Span , Response ], str ]] = None ,
95- name_callback : Optional [ Callable [[ str , str ], str ]] = None ,
97+ request_hook : _RequestHookT = None ,
98+ response_hook : _ResponseHookT = None ,
9699 excluded_urls : Iterable [str ] = None ,
97100):
98101 """Enables tracing of all requests calls that go through
@@ -106,29 +109,9 @@ def _instrument(
106109 # before v1.0.0, Dec 17, 2012, see
107110 # https://github.com/psf/requests/commit/4e5c4a6ab7bb0195dececdd19bb8505b872fe120)
108111
109- wrapped_request = Session .request
110112 wrapped_send = Session .send
111113
112- @functools .wraps (wrapped_request )
113- def instrumented_request (self , method , url , * args , ** kwargs ):
114- if excluded_urls and excluded_urls .url_disabled (url ):
115- return wrapped_request (self , method , url , * args , ** kwargs )
116-
117- def get_or_create_headers ():
118- headers = kwargs .get ("headers" )
119- if headers is None :
120- headers = {}
121- kwargs ["headers" ] = headers
122-
123- return headers
124-
125- def call_wrapped ():
126- return wrapped_request (self , method , url , * args , ** kwargs )
127-
128- return _instrumented_requests_call (
129- method , url , call_wrapped , get_or_create_headers
130- )
131-
114+ # pylint: disable-msg=too-many-locals,too-many-branches
132115 @functools .wraps (wrapped_send )
133116 def instrumented_send (self , request , ** kwargs ):
134117 if excluded_urls and excluded_urls .url_disabled (request .url ):
@@ -142,32 +125,17 @@ def get_or_create_headers():
142125 )
143126 return request .headers
144127
145- def call_wrapped ():
146- return wrapped_send (self , request , ** kwargs )
147-
148- return _instrumented_requests_call (
149- request .method , request .url , call_wrapped , get_or_create_headers
150- )
151-
152- # pylint: disable-msg=too-many-locals,too-many-branches
153- def _instrumented_requests_call (
154- method : str , url : str , call_wrapped , get_or_create_headers
155- ):
156128 if context .get_value (
157129 _SUPPRESS_INSTRUMENTATION_KEY
158130 ) or context .get_value (_SUPPRESS_HTTP_INSTRUMENTATION_KEY ):
159- return call_wrapped ( )
131+ return wrapped_send ( self , request , ** kwargs )
160132
161133 # See
162134 # https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/http.md#http-client
163- method = method .upper ()
164- span_name = ""
165- if name_callback is not None :
166- span_name = name_callback (method , url )
167- if not span_name or not isinstance (span_name , str ):
168- span_name = get_default_span_name (method )
135+ method = request .method .upper ()
136+ span_name = get_default_span_name (method )
169137
170- url = remove_url_credentials (url )
138+ url = remove_url_credentials (request . url )
171139
172140 span_attributes = {
173141 SpanAttributes .HTTP_METHOD : method ,
@@ -195,6 +163,8 @@ def _instrumented_requests_call(
195163 span_name , kind = SpanKind .CLIENT , attributes = span_attributes
196164 ) as span , set_ip_on_next_http_connection (span ):
197165 exception = None
166+ if callable (request_hook ):
167+ request_hook (span , request )
198168
199169 headers = get_or_create_headers ()
200170 inject (headers )
@@ -206,7 +176,7 @@ def _instrumented_requests_call(
206176 start_time = default_timer ()
207177
208178 try :
209- result = call_wrapped ( ) # *** PROCEED
179+ result = wrapped_send ( self , request , ** kwargs ) # *** PROCEED
210180 except Exception as exc : # pylint: disable=W0703
211181 exception = exc
212182 result = getattr (exc , "response" , None )
@@ -236,8 +206,8 @@ def _instrumented_requests_call(
236206 "1.1" if version == 11 else "1.0"
237207 )
238208
239- if span_callback is not None :
240- span_callback (span , result )
209+ if callable ( response_hook ) :
210+ response_hook (span , request , result )
241211
242212 duration_histogram .record (elapsed_time , attributes = metric_labels )
243213
@@ -246,9 +216,6 @@ def _instrumented_requests_call(
246216
247217 return result
248218
249- instrumented_request .opentelemetry_instrumentation_requests_applied = True
250- Session .request = instrumented_request
251-
252219 instrumented_send .opentelemetry_instrumentation_requests_applied = True
253220 Session .send = instrumented_send
254221
@@ -295,10 +262,8 @@ def _instrument(self, **kwargs):
295262 Args:
296263 **kwargs: Optional arguments
297264 ``tracer_provider``: a TracerProvider, defaults to global
298- ``span_callback``: An optional callback invoked before returning the http response. Invoked with Span and requests.Response
299- ``name_callback``: Callback which calculates a generic span name for an
300- outgoing HTTP request based on the method and url.
301- Optional: Defaults to get_default_span_name.
265+ ``request_hook``: An optional callback that is invoked right after a span is created.
266+ ``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
302267 ``excluded_urls``: A string containing a comma-delimited
303268 list of regexes used to exclude URLs from tracking
304269 """
@@ -319,8 +284,8 @@ def _instrument(self, **kwargs):
319284 _instrument (
320285 tracer ,
321286 duration_histogram ,
322- span_callback = kwargs .get ("span_callback " ),
323- name_callback = kwargs .get ("name_callback " ),
287+ request_hook = kwargs .get ("request_hook " ),
288+ response_hook = kwargs .get ("response_hook " ),
324289 excluded_urls = _excluded_urls_from_env
325290 if excluded_urls is None
326291 else parse_excluded_urls (excluded_urls ),
0 commit comments