diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d79a899ba..54d4e1ab5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- `opentelemetry-instrumentation-aiohttp-server`: Support passing `TracerProvider` when instrumenting. + ([#3819](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3819)) + + ## Version 1.39.0/0.60b0 (2025-12-03) ### Added diff --git a/instrumentation/opentelemetry-instrumentation-aiohttp-server/src/opentelemetry/instrumentation/aiohttp_server/__init__.py b/instrumentation/opentelemetry-instrumentation-aiohttp-server/src/opentelemetry/instrumentation/aiohttp_server/__init__.py index 30f967d39f..2873475d55 100644 --- a/instrumentation/opentelemetry-instrumentation-aiohttp-server/src/opentelemetry/instrumentation/aiohttp_server/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-aiohttp-server/src/opentelemetry/instrumentation/aiohttp_server/__init__.py @@ -25,8 +25,14 @@ from opentelemetry.instrumentation.aiohttp_server import ( AioHttpServerInstrumentor ) + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.sampling import ParentBased, TraceIdRatioBased - AioHttpServerInstrumentor().instrument() + # Optional: configure non-default TracerProvider, resource, sampler + resource = Resource(attributes={"service.name": "my-aiohttp-service"}) + sampler = ParentBased(root=TraceIdRatioBased(rate=0.25)) # sample 25% of traces + AioHttpServerInstrumentor().instrument(tracer_provider=TracerProvider(resource=resource, sampler=sampler)) async def hello(request): return web.Response(text="Hello, world") @@ -152,6 +158,7 @@ async def hello(request): import urllib from timeit import default_timer +from typing import Optional from aiohttp import web from multidict import CIMultiDictProxy @@ -387,73 +394,95 @@ def keys(self, carrier: dict) -> list: getter = AiohttpGetter() -@web.middleware -async def middleware(request, handler): - """Middleware for aiohttp implementing tracing logic""" - if not is_http_instrumentation_enabled() or _excluded_urls.url_disabled( - request.url.path - ): - return await handler(request) - - span_name = get_default_span_name(request) - - request_attrs = collect_request_attributes(request) - duration_attrs = _parse_duration_attrs(request_attrs) - active_requests_count_attrs = _parse_active_request_count_attrs( - request_attrs +def create_aiohttp_middleware( + tracer_provider: Optional[trace.TracerProvider] = None, +): + _tracer = ( + tracer_provider.get_tracer(__name__, __version__) + if tracer_provider + else tracer ) - duration_histogram = meter.create_histogram( - name=MetricInstruments.HTTP_SERVER_DURATION, - unit="ms", - description="Measures the duration of inbound HTTP requests.", - ) + @web.middleware + async def _middleware(request, handler): + """Middleware for aiohttp implementing tracing logic""" + if ( + not is_http_instrumentation_enabled() + or _excluded_urls.url_disabled(request.url.path) + ): + return await handler(request) + + span_name = get_default_span_name(request) + + request_attrs = collect_request_attributes(request) + duration_attrs = _parse_duration_attrs(request_attrs) + active_requests_count_attrs = _parse_active_request_count_attrs( + request_attrs + ) - active_requests_counter = meter.create_up_down_counter( - name=MetricInstruments.HTTP_SERVER_ACTIVE_REQUESTS, - unit="requests", - description="measures the number of concurrent HTTP requests those are currently in flight", - ) + duration_histogram = meter.create_histogram( + name=MetricInstruments.HTTP_SERVER_DURATION, + unit="ms", + description="Measures the duration of inbound HTTP requests.", + ) - with tracer.start_as_current_span( - span_name, - context=extract(request, getter=getter), - kind=trace.SpanKind.SERVER, - ) as span: - if span.is_recording(): - request_headers_attributes = collect_request_headers_attributes( - request - ) - request_attrs.update(request_headers_attributes) - span.set_attributes(request_attrs) - start = default_timer() - active_requests_counter.add(1, active_requests_count_attrs) - try: - resp = await handler(request) - set_status_code(span, resp.status) + active_requests_counter = meter.create_up_down_counter( + name=MetricInstruments.HTTP_SERVER_ACTIVE_REQUESTS, + unit="requests", + description="measures the number of concurrent HTTP requests those are currently in flight", + ) + + with _tracer.start_as_current_span( + span_name, + context=extract(request, getter=getter), + kind=trace.SpanKind.SERVER, + ) as span: if span.is_recording(): - response_headers_attributes = ( - collect_response_headers_attributes(resp) + request_headers_attributes = ( + collect_request_headers_attributes(request) ) - span.set_attributes(response_headers_attributes) - except web.HTTPException as ex: - set_status_code(span, ex.status_code) - raise - finally: - duration = max((default_timer() - start) * 1000, 0) - duration_histogram.record(duration, duration_attrs) - active_requests_counter.add(-1, active_requests_count_attrs) - return resp - - -class _InstrumentedApplication(web.Application): - """Insert tracing middleware""" - - def __init__(self, *args, **kwargs): - middlewares = kwargs.pop("middlewares", []) - middlewares.insert(0, middleware) - kwargs["middlewares"] = middlewares - super().__init__(*args, **kwargs) + request_attrs.update(request_headers_attributes) + span.set_attributes(request_attrs) + start = default_timer() + active_requests_counter.add(1, active_requests_count_attrs) + try: + resp = await handler(request) + set_status_code(span, resp.status) + if span.is_recording(): + response_headers_attributes = ( + collect_response_headers_attributes(resp) + ) + span.set_attributes(response_headers_attributes) + except web.HTTPException as ex: + set_status_code(span, ex.status_code) + raise + finally: + duration = max((default_timer() - start) * 1000, 0) + duration_histogram.record(duration, duration_attrs) + active_requests_counter.add(-1, active_requests_count_attrs) + return resp + + return _middleware + + +middleware = create_aiohttp_middleware() # for backwards compatibility + + +def create_instrumented_application( + tracer_provider: Optional[trace.TracerProvider] = None, +): + _middleware = create_aiohttp_middleware(tracer_provider=tracer_provider) + + class _InstrumentedApplication(web.Application): + """Insert tracing middleware""" + + def __init__(self, *args, **kwargs): + middlewares = kwargs.pop("middlewares", []) + middlewares.insert(0, _middleware) + kwargs["middlewares"] = middlewares + super().__init__(*args, **kwargs) + + return _InstrumentedApplication class AioHttpServerInstrumentor(BaseInstrumentor): @@ -464,6 +493,7 @@ class AioHttpServerInstrumentor(BaseInstrumentor): """ def _instrument(self, **kwargs): + tracer_provider = kwargs.get("tracer_provider", None) # update global values at instrument time so we can test them global _excluded_urls # pylint: disable=global-statement _excluded_urls = get_excluded_urls("AIOHTTP_SERVER") @@ -475,6 +505,10 @@ def _instrument(self, **kwargs): meter = metrics.get_meter(__name__, __version__) self._original_app = web.Application + + _InstrumentedApplication = create_instrumented_application( + tracer_provider=tracer_provider + ) setattr(web, "Application", _InstrumentedApplication) def _uninstrument(self, **kwargs): diff --git a/instrumentation/opentelemetry-instrumentation-aiohttp-server/tests/test_aiohttp_server_integration.py b/instrumentation/opentelemetry-instrumentation-aiohttp-server/tests/test_aiohttp_server_integration.py index 1528edf012..1baa9e7079 100644 --- a/instrumentation/opentelemetry-instrumentation-aiohttp-server/tests/test_aiohttp_server_integration.py +++ b/instrumentation/opentelemetry-instrumentation-aiohttp-server/tests/test_aiohttp_server_integration.py @@ -26,6 +26,7 @@ AioHttpServerInstrumentor, ) from opentelemetry.instrumentation.utils import suppress_http_instrumentation +from opentelemetry.sdk.trace.sampling import ParentBased, TraceIdRatioBased from opentelemetry.semconv._incubating.attributes.http_attributes import ( HTTP_METHOD, HTTP_STATUS_CODE, @@ -101,9 +102,9 @@ def fixture_suppress(): @pytest_asyncio.fixture(name="server_fixture") async def fixture_server_fixture(tracer, aiohttp_server, suppress): - _, memory_exporter = tracer + tracer_provider, memory_exporter = tracer - AioHttpServerInstrumentor().instrument() + AioHttpServerInstrumentor().instrument(tracer_provider=tracer_provider) app = aiohttp.web.Application() app.add_routes([aiohttp.web.get("/test-path", default_handler)]) @@ -228,20 +229,6 @@ async def handler(request): memory_exporter.clear() -def _get_sorted_metrics(metrics_data): - resource_metrics = metrics_data.resource_metrics if metrics_data else [] - - all_metrics = [] - for metrics in resource_metrics: - for scope_metrics in metrics.scope_metrics: - all_metrics.extend(scope_metrics.metrics) - - return sorted( - all_metrics, - key=lambda m: m.name, - ) - - @pytest.mark.asyncio @pytest.mark.parametrize( "env_var", @@ -281,6 +268,59 @@ async def handler(request): AioHttpServerInstrumentor().uninstrument() +@pytest.mark.asyncio +@pytest.mark.parametrize( + "tracer", + [ + TestBase().create_tracer_provider( + sampler=ParentBased(TraceIdRatioBased(0.05)) + ) + ], +) +async def test_non_global_tracer_provider( + tracer, + server_fixture, + aiohttp_client, +): + n_requests = 1000 + collection_ratio = 0.05 + n_expected_trace_ids = n_requests * collection_ratio + + _, memory_exporter = tracer + server, _ = server_fixture + + assert len(memory_exporter.get_finished_spans()) == 0 + + client = await aiohttp_client(server) + for _ in range(n_requests): + await client.get("/test-path") + + trace_ids = { + span.context.trace_id + for span in memory_exporter.get_finished_spans() + if span.context is not None + } + assert ( + 0.5 * n_expected_trace_ids + <= len(trace_ids) + <= 1.5 * n_expected_trace_ids + ) + + +def _get_sorted_metrics(metrics_data): + resource_metrics = metrics_data.resource_metrics if metrics_data else [] + + all_metrics = [] + for metrics in resource_metrics: + for scope_metrics in metrics.scope_metrics: + all_metrics.extend(scope_metrics.metrics) + + return sorted( + all_metrics, + key=lambda m: m.name, + ) + + @pytest.mark.asyncio async def test_custom_request_headers(tracer, aiohttp_server, monkeypatch): # pylint: disable=too-many-locals