diff --git a/snuba/web/rpc/v1/endpoint_get_trace.py b/snuba/web/rpc/v1/endpoint_get_trace.py index 9d7ed0d5c7..4cdb3897a3 100644 --- a/snuba/web/rpc/v1/endpoint_get_trace.py +++ b/snuba/web/rpc/v1/endpoint_get_trace.py @@ -1,8 +1,10 @@ +import random import uuid from datetime import datetime from operator import attrgetter from typing import Any, Dict, Iterable, Type +import sentry_sdk from google.protobuf.json_format import MessageToDict from google.protobuf.timestamp_pb2 import Timestamp from sentry_protos.snuba.v1.endpoint_get_trace_pb2 import ( @@ -11,6 +13,7 @@ ) from sentry_protos.snuba.v1.trace_item_attribute_pb2 import AttributeKey, AttributeValue +from snuba import state from snuba.attribution.appid import AppID from snuba.attribution.attribution_info import AttributionInfo from snuba.datasets.entities.entity_key import EntityKey @@ -46,6 +49,7 @@ "trace_id", "sampling_factor", ] +APPLY_FINAL_ROLLOUT_PERCENTAGE_CONFIG_KEY = "EndpointGetTrace.apply_final_rollout_percentage" def _build_query(request: GetTraceRequest, item: GetTraceRequest.TraceItem) -> Query: @@ -146,11 +150,28 @@ def _build_query(request: GetTraceRequest, item: GetTraceRequest.TraceItem) -> Q ], ) + if random.random() < _get_apply_final_rollout_percentage(): + query.set_final(True) + + span = sentry_sdk.get_current_span() + if span: + span.set_data("is_final", query.get_final()) + treeify_or_and_conditions(query) return query +def _get_apply_final_rollout_percentage() -> float: + return ( + state.get_float_config( + APPLY_FINAL_ROLLOUT_PERCENTAGE_CONFIG_KEY, + 0.0, + ) + or 0.0 + ) + + def _build_snuba_request( request: GetTraceRequest, item: GetTraceRequest.TraceItem, diff --git a/tests/web/rpc/v1/test_endpoint_get_trace.py b/tests/web/rpc/v1/test_endpoint_get_trace.py index 3f662e48bb..8acf908644 100644 --- a/tests/web/rpc/v1/test_endpoint_get_trace.py +++ b/tests/web/rpc/v1/test_endpoint_get_trace.py @@ -19,9 +19,15 @@ from sentry_protos.snuba.v1.trace_item_attribute_pb2 import AttributeKey, AttributeValue from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue, TraceItem +from snuba import state from snuba.datasets.storages.factory import get_storage from snuba.datasets.storages.storage_key import StorageKey -from snuba.web.rpc.v1.endpoint_get_trace import EndpointGetTrace, _value_to_attribute +from snuba.web.rpc.v1.endpoint_get_trace import ( + APPLY_FINAL_ROLLOUT_PERCENTAGE_CONFIG_KEY, + EndpointGetTrace, + _build_query, + _value_to_attribute, +) from tests.base import BaseApiTest from tests.helpers import write_raw_unprocessed_events from tests.web.rpc.v1.test_utils import SERVER_NAME, gen_item_message @@ -92,9 +98,7 @@ def get_attributes( name=key, type=_PROTOBUF_TO_SENTRY_PROTOS[value_type][1], ) - args = { - _PROTOBUF_TO_SENTRY_PROTOS[value_type][0]: getattr(value, value_type) - } + args = {_PROTOBUF_TO_SENTRY_PROTOS[value_type][0]: getattr(value, value_type)} else: continue @@ -132,9 +136,7 @@ def test_without_data(self) -> None: ), trace_id=uuid.uuid4().hex, ) - response = self.app.post( - "/rpc/EndpointGetTrace/v1", data=message.SerializeToString() - ) + response = self.app.post("/rpc/EndpointGetTrace/v1", data=message.SerializeToString()) error_proto = ErrorProto() if response.status_code != 200: error_proto.ParseFromString(response.data) @@ -255,6 +257,55 @@ def test_with_specific_attributes(self, setup_teardown: Any) -> None: ) assert MessageToDict(response) == MessageToDict(expected_response) + def test_build_query_with_final(store_outcomes_data: Any) -> None: + ts = Timestamp(seconds=int(_BASE_TIME.timestamp())) + three_hours_later = int((_BASE_TIME + timedelta(hours=3)).timestamp()) + item = GetTraceRequest.TraceItem( + item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN, + attributes=[ + AttributeKey( + name="server_name", + type=AttributeKey.Type.TYPE_STRING, + ), + AttributeKey( + name="sentry.parent_span_id", + type=AttributeKey.Type.TYPE_STRING, + ), + ], + ) + + message = GetTraceRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=1, + cogs_category="something", + referrer="something", + start_timestamp=ts, + end_timestamp=Timestamp(seconds=three_hours_later), + request_id=_REQUEST_ID, + ), + trace_id=_TRACE_ID, + items=[item], + ) + + state.set_config( + APPLY_FINAL_ROLLOUT_PERCENTAGE_CONFIG_KEY, + 1.0, + ) + + query = _build_query(message, item) + + assert query.get_final() == True + + state.set_config( + APPLY_FINAL_ROLLOUT_PERCENTAGE_CONFIG_KEY, + 0.0, + ) + + query = _build_query(message, item) + + assert query.get_final() == False + def generate_spans_and_timestamps() -> tuple[list[TraceItem], list[Timestamp]]: timestamps: list[Timestamp] = [] @@ -264,8 +315,7 @@ def generate_spans_and_timestamps() -> tuple[list[TraceItem], list[Timestamp]]: span.ParseFromString(payload) timestamp = Timestamp() timestamp.FromNanoseconds( - int(span.attributes["sentry.start_timestamp_precise"].double_value * 1e6) - * 1000 + int(span.attributes["sentry.start_timestamp_precise"].double_value * 1e6) * 1000 ) timestamps.append(timestamp) spans.append(span) @@ -274,6 +324,12 @@ def generate_spans_and_timestamps() -> tuple[list[TraceItem], list[Timestamp]]: def get_span_id(span: TraceItem) -> str: # cut the 0x prefix - return hex(int.from_bytes(span.item_id, byteorder="little", signed=False,))[ + return hex( + int.from_bytes( + span.item_id, + byteorder="little", + signed=False, + ) + )[ 2: ].rjust(16, "0")