diff --git a/pyproject.toml b/pyproject.toml index e5eae2c21f..d9993a099a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -214,3 +214,6 @@ exclude = [ "grpc_test_service_pb2.py", "grpc_test_service_pb2_grpc.py", ] +per-file-ignores = [ + "sentry_sdk/integrations/spark/*:N802,N803", +] diff --git a/sentry_sdk/_compat.py b/sentry_sdk/_compat.py index fc04ed5859..8fc7f16ee7 100644 --- a/sentry_sdk/_compat.py +++ b/sentry_sdk/_compat.py @@ -1,12 +1,9 @@ +from __future__ import annotations import sys - from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Any - from typing import TypeVar - - T = TypeVar("T") PY38 = sys.version_info[0] == 3 and sys.version_info[1] >= 8 @@ -14,18 +11,15 @@ PY311 = sys.version_info[0] == 3 and sys.version_info[1] >= 11 -def with_metaclass(meta, *bases): - # type: (Any, *Any) -> Any +def with_metaclass(meta: Any, *bases: Any) -> Any: class MetaClass(type): - def __new__(metacls, name, this_bases, d): - # type: (Any, Any, Any, Any) -> Any + def __new__(metacls: Any, name: Any, this_bases: Any, d: Any) -> Any: return meta(name, bases, d) return type.__new__(MetaClass, "temporary_class", (), {}) -def check_uwsgi_thread_support(): - # type: () -> bool +def check_uwsgi_thread_support() -> bool: # We check two things here: # # 1. uWSGI doesn't run in threaded mode by default -- issue a warning if @@ -45,8 +39,7 @@ def check_uwsgi_thread_support(): from sentry_sdk.consts import FALSE_VALUES - def enabled(option): - # type: (str) -> bool + def enabled(option: str) -> bool: value = opt.get(option, False) if isinstance(value, bool): return value diff --git a/sentry_sdk/_init_implementation.py b/sentry_sdk/_init_implementation.py index 34e9d071e9..06e7f28d4f 100644 --- a/sentry_sdk/_init_implementation.py +++ b/sentry_sdk/_init_implementation.py @@ -1,23 +1,23 @@ +from __future__ import annotations + from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Any + import sentry_sdk from sentry_sdk.consts import ClientConstructor from sentry_sdk.opentelemetry.scope import setup_scope_context_management -if TYPE_CHECKING: - from typing import Any, Optional - -def _check_python_deprecations(): - # type: () -> None +def _check_python_deprecations() -> None: # Since we're likely to deprecate Python versions in the future, I'm keeping # this handy function around. Use this to detect the Python version used and # to output logger.warning()s if it's deprecated. pass -def _init(*args, **kwargs): - # type: (*Optional[str], **Any) -> None +def _init(*args: Optional[str], **kwargs: Any) -> None: """Initializes the SDK and optionally integrations. This takes the same arguments as the client constructor. diff --git a/sentry_sdk/_log_batcher.py b/sentry_sdk/_log_batcher.py index 87bebdb226..5486e91b26 100644 --- a/sentry_sdk/_log_batcher.py +++ b/sentry_sdk/_log_batcher.py @@ -1,37 +1,35 @@ +from __future__ import annotations import os import random import threading from datetime import datetime, timezone -from typing import Optional, List, Callable, TYPE_CHECKING, Any from sentry_sdk.utils import format_timestamp, safe_repr from sentry_sdk.envelope import Envelope, Item, PayloadRef +from typing import TYPE_CHECKING + if TYPE_CHECKING: from sentry_sdk._types import Log + from typing import Optional, List, Callable, Any class LogBatcher: MAX_LOGS_BEFORE_FLUSH = 100 FLUSH_WAIT_TIME = 5.0 - def __init__( - self, - capture_func, # type: Callable[[Envelope], None] - ): - # type: (...) 
-> None - self._log_buffer = [] # type: List[Log] + def __init__(self, capture_func: Callable[[Envelope], None]) -> None: + self._log_buffer: List[Log] = [] self._capture_func = capture_func self._running = True self._lock = threading.Lock() - self._flush_event = threading.Event() # type: threading.Event + self._flush_event = threading.Event() - self._flusher = None # type: Optional[threading.Thread] - self._flusher_pid = None # type: Optional[int] + self._flusher: Optional[threading.Thread] = None + self._flusher_pid: Optional[int] = None - def _ensure_thread(self): - # type: (...) -> bool + def _ensure_thread(self) -> bool: """For forking processes we might need to restart this thread. This ensures that our process actually has that thread running. """ @@ -63,18 +61,13 @@ def _ensure_thread(self): return True - def _flush_loop(self): - # type: (...) -> None + def _flush_loop(self) -> None: while self._running: self._flush_event.wait(self.FLUSH_WAIT_TIME + random.random()) self._flush_event.clear() self._flush() - def add( - self, - log, # type: Log - ): - # type: (...) -> None + def add(self, log: Log) -> None: if not self._ensure_thread() or self._flusher is None: return None @@ -83,8 +76,7 @@ def add( if len(self._log_buffer) >= self.MAX_LOGS_BEFORE_FLUSH: self._flush_event.set() - def kill(self): - # type: (...) -> None + def kill(self) -> None: if self._flusher is None: return @@ -92,15 +84,12 @@ def kill(self): self._flush_event.set() self._flusher = None - def flush(self): - # type: (...) -> None + def flush(self) -> None: self._flush() @staticmethod - def _log_to_transport_format(log): - # type: (Log) -> Any - def format_attribute(val): - # type: (int | float | str | bool) -> Any + def _log_to_transport_format(log: Log) -> Any: + def format_attribute(val: int | float | str | bool) -> Any: if isinstance(val, bool): return {"value": val, "type": "boolean"} if isinstance(val, int): @@ -128,8 +117,7 @@ def format_attribute(val): return res - def _flush(self): - # type: (...) 
-> Optional[Envelope] + def _flush(self) -> Optional[Envelope]: envelope = Envelope( headers={"sent_at": format_timestamp(datetime.now(timezone.utc))} diff --git a/sentry_sdk/_lru_cache.py b/sentry_sdk/_lru_cache.py index cbadd9723b..aec8883546 100644 --- a/sentry_sdk/_lru_cache.py +++ b/sentry_sdk/_lru_cache.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -8,17 +10,15 @@ class LRUCache: - def __init__(self, max_size): - # type: (int) -> None + def __init__(self, max_size: int) -> None: if max_size <= 0: raise AssertionError(f"invalid max_size: {max_size}") self.max_size = max_size - self._data = {} # type: dict[Any, Any] + self._data: dict[Any, Any] = {} self.hits = self.misses = 0 self.full = False - def set(self, key, value): - # type: (Any, Any) -> None + def set(self, key: Any, value: Any) -> None: current = self._data.pop(key, _SENTINEL) if current is not _SENTINEL: self._data[key] = value @@ -29,8 +29,7 @@ def set(self, key, value): self._data[key] = value self.full = len(self._data) >= self.max_size - def get(self, key, default=None): - # type: (Any, Any) -> Any + def get(self, key: Any, default: Any = None) -> Any: try: ret = self._data.pop(key) except KeyError: @@ -42,6 +41,5 @@ def get(self, key, default=None): return ret - def get_all(self): - # type: () -> list[tuple[Any, Any]] + def get_all(self) -> list[tuple[Any, Any]]: return list(self._data.items()) diff --git a/sentry_sdk/_queue.py b/sentry_sdk/_queue.py index a21c86ec0a..7a385cd861 100644 --- a/sentry_sdk/_queue.py +++ b/sentry_sdk/_queue.py @@ -81,6 +81,7 @@ if TYPE_CHECKING: from typing import Any + __all__ = ["EmptyError", "FullError", "Queue"] @@ -275,7 +276,7 @@ def get_nowait(self): # Initialize the queue representation def _init(self, maxsize): - self.queue = deque() # type: Any + self.queue: Any = deque() def _qsize(self): return len(self.queue) diff --git a/sentry_sdk/_types.py b/sentry_sdk/_types.py index 01778d5e20..41828cb9e3 100644 --- a/sentry_sdk/_types.py +++ b/sentry_sdk/_types.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING, TypeVar, Union @@ -18,32 +20,27 @@ class AnnotatedValue: __slots__ = ("value", "metadata") - def __init__(self, value, metadata): - # type: (Optional[Any], Dict[str, Any]) -> None + def __init__(self, value: Optional[Any], metadata: Dict[str, Any]) -> None: self.value = value self.metadata = metadata - def __eq__(self, other): - # type: (Any) -> bool + def __eq__(self, other: Any) -> bool: if not isinstance(other, AnnotatedValue): return False return self.value == other.value and self.metadata == other.metadata - def __str__(self): - # type: (AnnotatedValue) -> str + def __str__(self) -> str: return str({"value": str(self.value), "metadata": str(self.metadata)}) - def __len__(self): - # type: (AnnotatedValue) -> int + def __len__(self) -> int: if self.value is not None: return len(self.value) else: return 0 @classmethod - def removed_because_raw_data(cls): - # type: () -> AnnotatedValue + def removed_because_raw_data(cls) -> AnnotatedValue: """The value was removed because it could not be parsed. 
This is done for request body values that are not json nor a form.""" return AnnotatedValue( value="", @@ -58,8 +55,7 @@ def removed_because_raw_data(cls): ) @classmethod - def removed_because_over_size_limit(cls, value=""): - # type: (Any) -> AnnotatedValue + def removed_because_over_size_limit(cls, value: Any = "") -> AnnotatedValue: """ The actual value was removed because the size of the field exceeded the configured maximum size, for example specified with the max_request_body_size sdk option. @@ -77,8 +73,7 @@ def removed_because_over_size_limit(cls, value=""): ) @classmethod - def substituted_because_contains_sensitive_data(cls): - # type: () -> AnnotatedValue + def substituted_because_contains_sensitive_data(cls) -> AnnotatedValue: """The actual value was removed because it contained sensitive information.""" return AnnotatedValue( value=SENSITIVE_DATA_SUBSTITUTE, diff --git a/sentry_sdk/_werkzeug.py b/sentry_sdk/_werkzeug.py index 0fa3d611f1..8886d5cffa 100644 --- a/sentry_sdk/_werkzeug.py +++ b/sentry_sdk/_werkzeug.py @@ -32,12 +32,12 @@ SUCH DAMAGE. """ +from __future__ import annotations + from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Dict - from typing import Iterator - from typing import Tuple + from typing import Dict, Iterator, Tuple # @@ -47,8 +47,7 @@ # We need this function because Django does not give us a "pure" http header # dict. So we might as well use it for all WSGI integrations. # -def _get_headers(environ): - # type: (Dict[str, str]) -> Iterator[Tuple[str, str]] +def _get_headers(environ: Dict[str, str]) -> Iterator[Tuple[str, str]]: """ Returns only proper HTTP headers. """ @@ -67,8 +66,7 @@ def _get_headers(environ): # `get_host` comes from `werkzeug.wsgi.get_host` # https://github.com/pallets/werkzeug/blob/1.0.1/src/werkzeug/wsgi.py#L145 # -def get_host(environ, use_x_forwarded_for=False): - # type: (Dict[str, str], bool) -> str +def get_host(environ: Dict[str, str], use_x_forwarded_for: bool = False) -> str: """ Return the host for the given WSGI environment. 
""" diff --git a/sentry_sdk/ai/monitoring.py b/sentry_sdk/ai/monitoring.py index 5940fb5bc2..0d5faf587b 100644 --- a/sentry_sdk/ai/monitoring.py +++ b/sentry_sdk/ai/monitoring.py @@ -1,3 +1,4 @@ +from __future__ import annotations import inspect from functools import wraps @@ -15,22 +16,17 @@ _ai_pipeline_name = ContextVar("ai_pipeline_name", default=None) -def set_ai_pipeline_name(name): - # type: (Optional[str]) -> None +def set_ai_pipeline_name(name: Optional[str]) -> None: _ai_pipeline_name.set(name) -def get_ai_pipeline_name(): - # type: () -> Optional[str] +def get_ai_pipeline_name() -> Optional[str]: return _ai_pipeline_name.get() -def ai_track(description, **span_kwargs): - # type: (str, Any) -> Callable[..., Any] - def decorator(f): - # type: (Callable[..., Any]) -> Callable[..., Any] - def sync_wrapped(*args, **kwargs): - # type: (Any, Any) -> Any +def ai_track(description: str, **span_kwargs: Any) -> Callable[..., Any]: + def decorator(f: Callable[..., Any]) -> Callable[..., Any]: + def sync_wrapped(*args: Any, **kwargs: Any) -> Any: curr_pipeline = _ai_pipeline_name.get() op = span_kwargs.get("op", "ai.run" if curr_pipeline else "ai.pipeline") @@ -60,8 +56,7 @@ def sync_wrapped(*args, **kwargs): _ai_pipeline_name.set(None) return res - async def async_wrapped(*args, **kwargs): - # type: (Any, Any) -> Any + async def async_wrapped(*args: Any, **kwargs: Any) -> Any: curr_pipeline = _ai_pipeline_name.get() op = span_kwargs.get("op", "ai.run" if curr_pipeline else "ai.pipeline") @@ -100,9 +95,11 @@ async def async_wrapped(*args, **kwargs): def record_token_usage( - span, prompt_tokens=None, completion_tokens=None, total_tokens=None -): - # type: (Span, Optional[int], Optional[int], Optional[int]) -> None + span: Span, + prompt_tokens: Optional[int] = None, + completion_tokens: Optional[int] = None, + total_tokens: Optional[int] = None, +) -> None: ai_pipeline_name = get_ai_pipeline_name() if ai_pipeline_name: span.set_attribute(SPANDATA.AI_PIPELINE_NAME, ai_pipeline_name) diff --git a/sentry_sdk/ai/utils.py b/sentry_sdk/ai/utils.py index 5868606940..fc590da166 100644 --- a/sentry_sdk/ai/utils.py +++ b/sentry_sdk/ai/utils.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,8 +8,7 @@ from sentry_sdk.utils import logger -def _normalize_data(data): - # type: (Any) -> Any +def _normalize_data(data: Any) -> Any: # convert pydantic data (e.g. 
OpenAI v1+) to json compatible format if hasattr(data, "model_dump"): @@ -26,7 +26,6 @@ def _normalize_data(data): return data -def set_data_normalized(span, key, value): - # type: (Span, str, Any) -> None +def set_data_normalized(span: Span, key: str, value: Any) -> None: normalized = _normalize_data(value) span.set_attribute(key, normalized) diff --git a/sentry_sdk/api.py b/sentry_sdk/api.py index ff98261066..3aefc57f69 100644 --- a/sentry_sdk/api.py +++ b/sentry_sdk/api.py @@ -1,3 +1,4 @@ +from __future__ import annotations import inspect from contextlib import contextmanager @@ -20,21 +21,23 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from collections.abc import Mapping - - from typing import Any - from typing import Dict - from typing import Optional - from typing import Callable - from typing import TypeVar - from typing import Union - from typing import Generator - - import sentry_sdk + from typing import Any, Optional, Callable, TypeVar, Union, Generator T = TypeVar("T") F = TypeVar("F", bound=Callable[..., Any]) + from collections.abc import Mapping + from sentry_sdk.client import BaseClient + from sentry_sdk.tracing import Span + from sentry_sdk._types import ( + Event, + Hint, + LogLevelStr, + ExcInfo, + BreadcrumbHint, + Breadcrumb, + ) + # When changing this, update __all__ in __init__.py too __all__ = [ @@ -74,8 +77,7 @@ ] -def scopemethod(f): - # type: (F) -> F +def scopemethod(f: F) -> F: f.__doc__ = "%s\n\n%s" % ( "Alias for :py:meth:`sentry_sdk.Scope.%s`" % f.__name__, inspect.getdoc(getattr(Scope, f.__name__)), @@ -83,8 +85,7 @@ def scopemethod(f): return f -def clientmethod(f): - # type: (F) -> F +def clientmethod(f: F) -> F: f.__doc__ = "%s\n\n%s" % ( "Alias for :py:meth:`sentry_sdk.Client.%s`" % f.__name__, inspect.getdoc(getattr(Client, f.__name__)), @@ -93,13 +94,11 @@ def clientmethod(f): @scopemethod -def get_client(): - # type: () -> sentry_sdk.client.BaseClient +def get_client() -> BaseClient: return Scope.get_client() -def is_initialized(): - # type: () -> bool +def is_initialized() -> bool: """ .. versionadded:: 2.0.0 @@ -113,26 +112,22 @@ def is_initialized(): @scopemethod -def get_global_scope(): - # type: () -> BaseScope +def get_global_scope() -> BaseScope: return Scope.get_global_scope() @scopemethod -def get_isolation_scope(): - # type: () -> Scope +def get_isolation_scope() -> Scope: return Scope.get_isolation_scope() @scopemethod -def get_current_scope(): - # type: () -> Scope +def get_current_scope() -> Scope: return Scope.get_current_scope() @scopemethod -def last_event_id(): - # type: () -> Optional[str] +def last_event_id() -> Optional[str]: """ See :py:meth:`sentry_sdk.Scope.last_event_id` documentation regarding this method's limitations. @@ -142,23 +137,21 @@ def last_event_id(): @scopemethod def capture_event( - event, # type: sentry_sdk._types.Event - hint=None, # type: Optional[sentry_sdk._types.Hint] - scope=None, # type: Optional[Any] - **scope_kwargs, # type: Any -): - # type: (...) -> Optional[str] + event: Event, + hint: Optional[Hint] = None, + scope: Optional[Any] = None, + **scope_kwargs: Any, +) -> Optional[str]: return get_current_scope().capture_event(event, hint, scope=scope, **scope_kwargs) @scopemethod def capture_message( - message, # type: str - level=None, # type: Optional[sentry_sdk._types.LogLevelStr] - scope=None, # type: Optional[Any] - **scope_kwargs, # type: Any -): - # type: (...) 
-> Optional[str] + message: str, + level: Optional[LogLevelStr] = None, + scope: Optional[Any] = None, + **scope_kwargs: Any, +) -> Optional[str]: return get_current_scope().capture_message( message, level, scope=scope, **scope_kwargs ) @@ -166,23 +159,21 @@ def capture_message( @scopemethod def capture_exception( - error=None, # type: Optional[Union[BaseException, sentry_sdk._types.ExcInfo]] - scope=None, # type: Optional[Any] - **scope_kwargs, # type: Any -): - # type: (...) -> Optional[str] + error: Optional[Union[BaseException, ExcInfo]] = None, + scope: Optional[Any] = None, + **scope_kwargs: Any, +) -> Optional[str]: return get_current_scope().capture_exception(error, scope=scope, **scope_kwargs) @scopemethod def add_attachment( - bytes=None, # type: Union[None, bytes, Callable[[], bytes]] - filename=None, # type: Optional[str] - path=None, # type: Optional[str] - content_type=None, # type: Optional[str] - add_to_transactions=False, # type: bool -): - # type: (...) -> None + bytes: Union[None, bytes, Callable[[], bytes]] = None, + filename: Optional[str] = None, + path: Optional[str] = None, + content_type: Optional[str] = None, + add_to_transactions: bool = False, +) -> None: return get_isolation_scope().add_attachment( bytes, filename, path, content_type, add_to_transactions ) @@ -190,61 +181,52 @@ def add_attachment( @scopemethod def add_breadcrumb( - crumb=None, # type: Optional[sentry_sdk._types.Breadcrumb] - hint=None, # type: Optional[sentry_sdk._types.BreadcrumbHint] - **kwargs, # type: Any -): - # type: (...) -> None + crumb: Optional[Breadcrumb] = None, + hint: Optional[BreadcrumbHint] = None, + **kwargs: Any, +) -> None: return get_isolation_scope().add_breadcrumb(crumb, hint, **kwargs) @scopemethod -def set_tag(key, value): - # type: (str, Any) -> None +def set_tag(key: str, value: Any) -> None: return get_isolation_scope().set_tag(key, value) @scopemethod -def set_tags(tags): - # type: (Mapping[str, object]) -> None +def set_tags(tags: Mapping[str, object]) -> None: return get_isolation_scope().set_tags(tags) @scopemethod -def set_context(key, value): - # type: (str, Dict[str, Any]) -> None +def set_context(key: str, value: dict[str, Any]) -> None: return get_isolation_scope().set_context(key, value) @scopemethod -def set_extra(key, value): - # type: (str, Any) -> None +def set_extra(key: str, value: Any) -> None: return get_isolation_scope().set_extra(key, value) @scopemethod -def set_user(value): - # type: (Optional[Dict[str, Any]]) -> None +def set_user(value: Optional[dict[str, Any]]) -> None: return get_isolation_scope().set_user(value) @scopemethod -def set_level(value): - # type: (sentry_sdk._types.LogLevelStr) -> None +def set_level(value: LogLevelStr) -> None: return get_isolation_scope().set_level(value) @clientmethod def flush( - timeout=None, # type: Optional[float] - callback=None, # type: Optional[Callable[[int, float], None]] -): - # type: (...) -> None + timeout: Optional[float] = None, + callback: Optional[Callable[[int, float], None]] = None, +) -> None: return get_client().flush(timeout=timeout, callback=callback) -def start_span(**kwargs): - # type: (Any) -> sentry_sdk.tracing.Span +def start_span(**kwargs: Any) -> Span: """ Start and return a span. @@ -260,11 +242,7 @@ def start_span(**kwargs): return get_current_scope().start_span(**kwargs) -def start_transaction( - transaction=None, # type: Optional[sentry_sdk.tracing.Span] - **kwargs, # type: Any -): - # type: (...) 
-> sentry_sdk.tracing.Span +def start_transaction(transaction: Optional[Span] = None, **kwargs: Any) -> Span: """ .. deprecated:: 3.0.0 This function is deprecated and will be removed in a future release. @@ -303,24 +281,21 @@ def start_transaction( ) -def get_current_span(scope=None): - # type: (Optional[Scope]) -> Optional[sentry_sdk.tracing.Span] +def get_current_span(scope: Optional[Scope] = None) -> Optional[Span]: """ Returns the currently active span if there is one running, otherwise `None` """ return tracing_utils.get_current_span(scope) -def get_traceparent(): - # type: () -> Optional[str] +def get_traceparent() -> Optional[str]: """ Returns the traceparent either from the active span or from the scope. """ return get_current_scope().get_traceparent() -def get_baggage(): - # type: () -> Optional[str] +def get_baggage() -> Optional[str]: """ Returns Baggage either from the active span or from the scope. """ @@ -332,8 +307,7 @@ def get_baggage(): @contextmanager -def continue_trace(environ_or_headers): - # type: (Dict[str, Any]) -> Generator[None, None, None] +def continue_trace(environ_or_headers: dict[str, Any]) -> Generator[None, None, None]: """ Sets the propagation context from environment or headers to continue an incoming trace. """ @@ -343,13 +317,11 @@ def continue_trace(environ_or_headers): @scopemethod def start_session( - session_mode="application", # type: str -): - # type: (...) -> None + session_mode: str = "application", +) -> None: return get_isolation_scope().start_session(session_mode=session_mode) @scopemethod -def end_session(): - # type: () -> None +def end_session() -> None: return get_isolation_scope().end_session() diff --git a/sentry_sdk/attachments.py b/sentry_sdk/attachments.py index e5404f8658..1f2fe7bb30 100644 --- a/sentry_sdk/attachments.py +++ b/sentry_sdk/attachments.py @@ -1,3 +1,4 @@ +from __future__ import annotations import os import mimetypes @@ -31,13 +32,12 @@ class Attachment: def __init__( self, - bytes=None, # type: Union[None, bytes, Callable[[], bytes]] - filename=None, # type: Optional[str] - path=None, # type: Optional[str] - content_type=None, # type: Optional[str] - add_to_transactions=False, # type: bool - ): - # type: (...) 
-> None + bytes: Union[None, bytes, Callable[[], bytes]] = None, + filename: Optional[str] = None, + path: Optional[str] = None, + content_type: Optional[str] = None, + add_to_transactions: bool = False, + ) -> None: if bytes is None and path is None: raise TypeError("path or raw bytes required for attachment") if filename is None and path is not None: @@ -52,10 +52,9 @@ def __init__( self.content_type = content_type self.add_to_transactions = add_to_transactions - def to_envelope_item(self): - # type: () -> Item + def to_envelope_item(self) -> Item: """Returns an envelope item for this attachment.""" - payload = None # type: Union[None, PayloadRef, bytes] + payload: Union[None, PayloadRef, bytes] = None if self.bytes is not None: if callable(self.bytes): payload = self.bytes() @@ -70,6 +69,5 @@ def to_envelope_item(self): filename=self.filename, ) - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return "" % (self.filename,) diff --git a/sentry_sdk/client.py b/sentry_sdk/client.py index 0fe5a1d616..67c723c76c 100644 --- a/sentry_sdk/client.py +++ b/sentry_sdk/client.py @@ -1,3 +1,4 @@ +from __future__ import annotations import os import uuid import random @@ -5,7 +6,7 @@ from collections.abc import Mapping from datetime import datetime, timezone from importlib import import_module -from typing import TYPE_CHECKING, List, Dict, cast, overload +from typing import TYPE_CHECKING, overload import sentry_sdk from sentry_sdk._compat import check_uwsgi_thread_support @@ -48,13 +49,16 @@ from sentry_sdk.spotlight import setup_spotlight if TYPE_CHECKING: - from typing import Any - from typing import Callable - from typing import Optional - from typing import Sequence - from typing import Type - from typing import Union - from typing import TypeVar + from typing import ( + Any, + Callable, + Optional, + Sequence, + Type, + Union, + TypeVar, + Dict, + ) from sentry_sdk._types import Event, Hint, SDKInfo, Log from sentry_sdk.integrations import Integration @@ -64,22 +68,22 @@ from sentry_sdk.transport import Transport from sentry_sdk._log_batcher import LogBatcher - I = TypeVar("I", bound=Integration) # noqa: E741 + IntegrationType = TypeVar("IntegrationType", bound=Integration) # noqa: E741 + _client_init_debug = ContextVar("client_init_debug") -SDK_INFO = { +SDK_INFO: SDKInfo = { "name": "sentry.python", # SDK name will be overridden after integrations have been loaded with sentry_sdk.integrations.setup_integrations() "version": VERSION, "packages": [{"name": "pypi:sentry-sdk", "version": VERSION}], -} # type: SDKInfo +} -def _get_options(*args, **kwargs): - # type: (*Optional[str], **Any) -> Dict[str, Any] +def _get_options(*args: Optional[str], **kwargs: Any) -> Dict[str, Any]: if args and (isinstance(args[0], (bytes, str)) or args[0] is None): - dsn = args[0] # type: Optional[str] + dsn: Optional[str] = args[0] args = args[1:] else: dsn = None @@ -149,37 +153,31 @@ class BaseClient: The basic definition of a client that is used for sending data to Sentry. 
""" - spotlight = None # type: Optional[SpotlightClient] + spotlight: Optional[SpotlightClient] = None - def __init__(self, options=None): - # type: (Optional[Dict[str, Any]]) -> None - self.options = ( + def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: + self.options: Dict[str, Any] = ( options if options is not None else DEFAULT_OPTIONS - ) # type: Dict[str, Any] + ) - self.transport = None # type: Optional[Transport] - self.monitor = None # type: Optional[Monitor] - self.log_batcher = None # type: Optional[LogBatcher] + self.transport: Optional[Transport] = None + self.monitor: Optional[Monitor] = None + self.log_batcher: Optional[LogBatcher] = None - def __getstate__(self, *args, **kwargs): - # type: (*Any, **Any) -> Any + def __getstate__(self, *args: Any, **kwargs: Any) -> Any: return {"options": {}} - def __setstate__(self, *args, **kwargs): - # type: (*Any, **Any) -> None + def __setstate__(self, *args: Any, **kwargs: Any) -> None: pass @property - def dsn(self): - # type: () -> Optional[str] + def dsn(self) -> Optional[str]: return None - def should_send_default_pii(self): - # type: () -> bool + def should_send_default_pii(self) -> bool: return False - def is_active(self): - # type: () -> bool + def is_active(self) -> bool: """ .. versionadded:: 2.0.0 @@ -187,48 +185,40 @@ def is_active(self): """ return False - def capture_event(self, *args, **kwargs): - # type: (*Any, **Any) -> Optional[str] + def capture_event(self, *args: Any, **kwargs: Any) -> Optional[str]: return None - def _capture_experimental_log(self, log): - # type: (Log) -> None + def _capture_experimental_log(self, log: "Log") -> None: pass - def capture_session(self, *args, **kwargs): - # type: (*Any, **Any) -> None + def capture_session(self, *args: Any, **kwargs: Any) -> None: return None if TYPE_CHECKING: @overload - def get_integration(self, name_or_class): - # type: (str) -> Optional[Integration] - ... + def get_integration(self, name_or_class: str) -> Optional[Integration]: ... @overload - def get_integration(self, name_or_class): - # type: (type[I]) -> Optional[I] - ... + def get_integration( + self, name_or_class: type[IntegrationType] + ) -> Optional[IntegrationType]: ... - def get_integration(self, name_or_class): - # type: (Union[str, type[Integration]]) -> Optional[Integration] + def get_integration( + self, name_or_class: Union[str, type[Integration]] + ) -> Optional[Integration]: return None - def close(self, *args, **kwargs): - # type: (*Any, **Any) -> None + def close(self, *args: Any, **kwargs: Any) -> None: return None - def flush(self, *args, **kwargs): - # type: (*Any, **Any) -> None + def flush(self, *args: Any, **kwargs: Any) -> None: return None - def __enter__(self): - # type: () -> BaseClient + def __enter__(self) -> BaseClient: return self - def __exit__(self, exc_type, exc_value, tb): - # type: (Any, Any, Any) -> None + def __exit__(self, exc_type: Any, exc_value: Any, tb: Any) -> None: return None @@ -252,22 +242,20 @@ class _Client(BaseClient): Alias of :py:class:`sentry_sdk.Client`. 
(Was created for better intelisense support) """ - def __init__(self, *args, **kwargs): - # type: (*Any, **Any) -> None - super(_Client, self).__init__(options=get_options(*args, **kwargs)) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(_Client, self).__init__(options=_get_options(*args, **kwargs)) self._init_impl() - def __getstate__(self): - # type: () -> Any + def __getstate__(self) -> Any: return {"options": self.options} - def __setstate__(self, state): - # type: (Any) -> None + def __setstate__(self, state: Any) -> None: self.options = state["options"] self._init_impl() - def _setup_instrumentation(self, functions_to_trace): - # type: (Sequence[Dict[str, str]]) -> None + def _setup_instrumentation( + self, functions_to_trace: Sequence[Dict[str, str]] + ) -> None: """ Instruments the functions given in the list `functions_to_trace` with the `@sentry_sdk.tracing.trace` decorator. """ @@ -317,12 +305,10 @@ def _setup_instrumentation(self, functions_to_trace): e, ) - def _init_impl(self): - # type: () -> None + def _init_impl(self) -> None: old_debug = _client_init_debug.get(False) - def _capture_envelope(envelope): - # type: (Envelope) -> None + def _capture_envelope(envelope: Envelope) -> None: if self.transport is not None: self.transport.capture_envelope(envelope) @@ -423,8 +409,7 @@ def _capture_envelope(envelope): # need to check if it's safe to use them. check_uwsgi_thread_support() - def is_active(self): - # type: () -> bool + def is_active(self) -> bool: """ .. versionadded:: 2.0.0 @@ -432,8 +417,7 @@ def is_active(self): """ return True - def should_send_default_pii(self): - # type: () -> bool + def should_send_default_pii(self) -> bool: """ .. versionadded:: 2.0.0 @@ -442,28 +426,26 @@ def should_send_default_pii(self): return self.options.get("send_default_pii") or False @property - def dsn(self): - # type: () -> Optional[str] + def dsn(self) -> Optional[str]: """Returns the configured DSN as string.""" return self.options["dsn"] def _prepare_event( self, - event, # type: Event - hint, # type: Hint - scope, # type: Optional[Scope] - ): - # type: (...) 
-> Optional[Event] + event: Event, + hint: Hint, + scope: Optional[Scope], + ) -> Optional[Event]: - previous_total_spans = None # type: Optional[int] - previous_total_breadcrumbs = None # type: Optional[int] + previous_total_spans: Optional[int] = None + previous_total_breadcrumbs: Optional[int] = None if event.get("timestamp") is None: event["timestamp"] = datetime.now(timezone.utc) if scope is not None: is_transaction = event.get("type") == "transaction" - spans_before = len(cast(List[Dict[str, object]], event.get("spans", []))) + spans_before = len(event.get("spans", [])) event_ = scope.apply_to_event(event, hint, self.options) # one of the event/error processors returned None @@ -481,16 +463,14 @@ def _prepare_event( ) return None - event = event_ # type: Optional[Event] # type: ignore[no-redef] - spans_delta = spans_before - len( - cast(List[Dict[str, object]], event.get("spans", [])) - ) + event = event_ + spans_delta = spans_before - len(event.get("spans", [])) if is_transaction and spans_delta > 0 and self.transport is not None: self.transport.record_lost_event( "event_processor", data_category="span", quantity=spans_delta ) - dropped_spans = event.pop("_dropped_spans", 0) + spans_delta # type: int + dropped_spans: int = event.pop("_dropped_spans", 0) + spans_delta if dropped_spans > 0: previous_total_spans = spans_before + dropped_spans if scope._n_breadcrumbs_truncated > 0: @@ -562,14 +542,11 @@ def _prepare_event( # Postprocess the event here so that annotated types do # generally not surface in before_send if event is not None: - event = cast( - "Event", - serialize( - cast("Dict[str, Any]", event), - max_request_body_size=self.options.get("max_request_body_size"), - max_value_length=self.options.get("max_value_length"), - custom_repr=self.options.get("custom_repr"), - ), + event: Event = serialize( # type: ignore[no-redef] + event, + max_request_body_size=self.options.get("max_request_body_size"), + max_value_length=self.options.get("max_value_length"), + custom_repr=self.options.get("custom_repr"), ) before_send = self.options["before_send"] @@ -578,7 +555,7 @@ def _prepare_event( and event is not None and event.get("type") != "transaction" ): - new_event = None # type: Optional[Event] + new_event: Optional["Event"] = None with capture_internal_exceptions(): new_event = before_send(event, hint or {}) if new_event is None: @@ -595,7 +572,7 @@ def _prepare_event( if event.get("exception"): DedupeIntegration.reset_last_seen() - event = new_event # type: Optional[Event] # type: ignore[no-redef] + event = new_event before_send_transaction = self.options["before_send_transaction"] if ( @@ -604,7 +581,7 @@ def _prepare_event( and event.get("type") == "transaction" ): new_event = None - spans_before = len(cast(List[Dict[str, object]], event.get("spans", []))) + spans_before = len(event.get("spans", [])) with capture_internal_exceptions(): new_event = before_send_transaction(event, hint or {}) if new_event is None: @@ -619,20 +596,17 @@ def _prepare_event( quantity=spans_before + 1, # +1 for the transaction itself ) else: - spans_delta = spans_before - len( - cast(List[Dict[str, object]], new_event.get("spans", [])) - ) + spans_delta = spans_before - len(new_event.get("spans", [])) if spans_delta > 0 and self.transport is not None: self.transport.record_lost_event( reason="before_send", data_category="span", quantity=spans_delta ) - event = new_event # type: Optional[Event] # type: ignore[no-redef] + event = new_event return event - def _is_ignored_error(self, event, hint): - # type: 
(Event, Hint) -> bool + def _is_ignored_error(self, event: Event, hint: Hint) -> bool: exc_info = hint.get("exc_info") if exc_info is None: return False @@ -655,11 +629,10 @@ def _is_ignored_error(self, event, hint): def _should_capture( self, - event, # type: Event - hint, # type: Hint - scope=None, # type: Optional[Scope] - ): - # type: (...) -> bool + event: "Event", + hint: "Hint", + scope: Optional["Scope"] = None, + ) -> bool: # Transactions are sampled independent of error events. is_transaction = event.get("type") == "transaction" if is_transaction: @@ -677,10 +650,9 @@ def _should_capture( def _should_sample_error( self, - event, # type: Event - hint, # type: Hint - ): - # type: (...) -> bool + event: Event, + hint: Hint, + ) -> bool: error_sampler = self.options.get("error_sampler", None) if callable(error_sampler): @@ -725,10 +697,9 @@ def _should_sample_error( def _update_session_from_event( self, - session, # type: Session - event, # type: Event - ): - # type: (...) -> None + session: Session, + event: Event, + ) -> None: crashed = False errored = False @@ -764,11 +735,10 @@ def _update_session_from_event( def capture_event( self, - event, # type: Event - hint=None, # type: Optional[Hint] - scope=None, # type: Optional[Scope] - ): - # type: (...) -> Optional[str] + event: Event, + hint: Optional[Hint] = None, + scope: Optional[Scope] = None, + ) -> Optional[str]: """Captures an event. :param event: A ready-made event that can be directly sent to Sentry. @@ -779,7 +749,7 @@ def capture_event( :returns: An event ID. May be `None` if there is no DSN set or of if the SDK decided to discard the event for other reasons. In such situations setting `debug=True` on `init()` may help. """ - hint = dict(hint or ()) # type: Hint + hint: Hint = dict(hint or ()) if not self._should_capture(event, hint, scope): return None @@ -814,10 +784,10 @@ def capture_event( trace_context = event_opt.get("contexts", {}).get("trace") or {} dynamic_sampling_context = trace_context.pop("dynamic_sampling_context", {}) - headers = { + headers: dict[str, object] = { "event_id": event_opt["event_id"], "sent_at": format_timestamp(datetime.now(timezone.utc)), - } # type: dict[str, object] + } if dynamic_sampling_context: headers["trace"] = dynamic_sampling_context @@ -847,8 +817,7 @@ def capture_event( return return_value - def _capture_experimental_log(self, log): - # type: (Log) -> None + def _capture_experimental_log(self, log: Log) -> None: logs_enabled = self.options["_experiments"].get("enable_logs", False) if not logs_enabled: return @@ -914,10 +883,7 @@ def _capture_experimental_log(self, log): if self.log_batcher: self.log_batcher.add(log) - def capture_session( - self, session # type: Session - ): - # type: (...) -> None + def capture_session(self, session: Session) -> None: if not session.release: logger.info("Discarded session update because of missing release") else: @@ -926,19 +892,16 @@ def capture_session( if TYPE_CHECKING: @overload - def get_integration(self, name_or_class): - # type: (str) -> Optional[Integration] - ... + def get_integration(self, name_or_class: str) -> Optional[Integration]: ... @overload - def get_integration(self, name_or_class): - # type: (type[I]) -> Optional[I] - ... + def get_integration( + self, name_or_class: type[IntegrationType] + ) -> Optional[IntegrationType]: ... def get_integration( - self, name_or_class # type: Union[str, Type[Integration]] - ): - # type: (...) 
-> Optional[Integration] + self, name_or_class: Union[str, Type[Integration]] + ) -> Optional[Integration]: """Returns the integration for this client by name or class. If the client does not have that integration then `None` is returned. """ @@ -953,10 +916,9 @@ def get_integration( def close( self, - timeout=None, # type: Optional[float] - callback=None, # type: Optional[Callable[[int, float], None]] - ): - # type: (...) -> None + timeout: Optional[float] = None, + callback: Optional[Callable[[int, float], None]] = None, + ) -> None: """ Close the client and shut down the transport. Arguments have the same semantics as :py:meth:`Client.flush`. @@ -977,10 +939,9 @@ def close( def flush( self, - timeout=None, # type: Optional[float] - callback=None, # type: Optional[Callable[[int, float], None]] - ): - # type: (...) -> None + timeout: Optional[float] = None, + callback: Optional[Callable[[int, float], None]] = None, + ) -> None: """ Wait for the current events to be sent. @@ -998,17 +959,13 @@ def flush( self.transport.flush(timeout=timeout, callback=callback) - def __enter__(self): - # type: () -> _Client + def __enter__(self) -> _Client: return self - def __exit__(self, exc_type, exc_value, tb): - # type: (Any, Any, Any) -> None + def __exit__(self, exc_type: Any, exc_value: Any, tb: Any) -> None: self.close() -from typing import TYPE_CHECKING - if TYPE_CHECKING: # Make mypy, PyCharm and other static analyzers think `get_options` is a # type to have nicer autocompletion for params. diff --git a/sentry_sdk/consts.py b/sentry_sdk/consts.py index dbed277202..8f67b127df 100644 --- a/sentry_sdk/consts.py +++ b/sentry_sdk/consts.py @@ -1,7 +1,21 @@ +from __future__ import annotations import itertools from enum import Enum from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import ( + Optional, + Callable, + Union, + List, + Type, + Dict, + Any, + Sequence, + Tuple, + ) + # up top to prevent circular import due to integration import DEFAULT_MAX_VALUE_LENGTH = 1024 @@ -26,17 +40,6 @@ class CompressionAlgo(Enum): if TYPE_CHECKING: - import sentry_sdk - - from typing import Optional - from typing import Callable - from typing import Union - from typing import List - from typing import Type - from typing import Dict - from typing import Any - from typing import Sequence - from typing import Tuple from typing_extensions import Literal from typing_extensions import TypedDict @@ -52,6 +55,8 @@ class CompressionAlgo(Enum): TransactionProcessor, ) + import sentry_sdk + # Experiments are feature flags to enable and disable certain unstable SDK # functionality. Changing them from the defaults (`None`) in production # code is highly discouraged. 
They are not subject to any stability @@ -728,8 +733,7 @@ class TransactionSource(str, Enum): URL = "url" VIEW = "view" - def __str__(self): - # type: () -> str + def __str__(self) -> str: return self.value @@ -757,68 +761,73 @@ class ClientConstructor: def __init__( self, - dsn=None, # type: Optional[str] + dsn: Optional[str] = None, *, - max_breadcrumbs=DEFAULT_MAX_BREADCRUMBS, # type: int - release=None, # type: Optional[str] - environment=None, # type: Optional[str] - server_name=None, # type: Optional[str] - shutdown_timeout=2, # type: float - integrations=[], # type: Sequence[sentry_sdk.integrations.Integration] # noqa: B006 - in_app_include=[], # type: List[str] # noqa: B006 - in_app_exclude=[], # type: List[str] # noqa: B006 - default_integrations=True, # type: bool - dist=None, # type: Optional[str] - transport=None, # type: Optional[Union[sentry_sdk.transport.Transport, Type[sentry_sdk.transport.Transport], Callable[[Event], None]]] - transport_queue_size=DEFAULT_QUEUE_SIZE, # type: int - sample_rate=1.0, # type: float - send_default_pii=None, # type: Optional[bool] - http_proxy=None, # type: Optional[str] - https_proxy=None, # type: Optional[str] - ignore_errors=[], # type: Sequence[Union[type, str]] # noqa: B006 - max_request_body_size="medium", # type: str - socket_options=None, # type: Optional[List[Tuple[int, int, int | bytes]]] - keep_alive=None, # type: Optional[bool] - before_send=None, # type: Optional[EventProcessor] - before_breadcrumb=None, # type: Optional[BreadcrumbProcessor] - debug=None, # type: Optional[bool] - attach_stacktrace=False, # type: bool - ca_certs=None, # type: Optional[str] - traces_sample_rate=None, # type: Optional[float] - traces_sampler=None, # type: Optional[TracesSampler] - profiles_sample_rate=None, # type: Optional[float] - profiles_sampler=None, # type: Optional[TracesSampler] - profiler_mode=None, # type: Optional[ProfilerMode] - profile_lifecycle="manual", # type: Literal["manual", "trace"] - profile_session_sample_rate=None, # type: Optional[float] - auto_enabling_integrations=True, # type: bool - disabled_integrations=None, # type: Optional[Sequence[sentry_sdk.integrations.Integration]] - auto_session_tracking=True, # type: bool - send_client_reports=True, # type: bool - _experiments={}, # type: Experiments # noqa: B006 - proxy_headers=None, # type: Optional[Dict[str, str]] - before_send_transaction=None, # type: Optional[TransactionProcessor] - project_root=None, # type: Optional[str] - include_local_variables=True, # type: Optional[bool] - include_source_context=True, # type: Optional[bool] - trace_propagation_targets=[ # noqa: B006 - MATCH_ALL - ], # type: Optional[Sequence[str]] - functions_to_trace=[], # type: Sequence[Dict[str, str]] # noqa: B006 - event_scrubber=None, # type: Optional[sentry_sdk.scrubber.EventScrubber] - max_value_length=DEFAULT_MAX_VALUE_LENGTH, # type: int - enable_backpressure_handling=True, # type: bool - error_sampler=None, # type: Optional[Callable[[Event, Hint], Union[float, bool]]] - enable_db_query_source=True, # type: bool - db_query_source_threshold_ms=100, # type: int - spotlight=None, # type: Optional[Union[bool, str]] - cert_file=None, # type: Optional[str] - key_file=None, # type: Optional[str] - custom_repr=None, # type: Optional[Callable[..., Optional[str]]] - add_full_stack=DEFAULT_ADD_FULL_STACK, # type: bool - max_stack_frames=DEFAULT_MAX_STACK_FRAMES, # type: Optional[int] - ): - # type: (...) 
-> None + max_breadcrumbs: int = DEFAULT_MAX_BREADCRUMBS, + release: Optional[str] = None, + environment: Optional[str] = None, + server_name: Optional[str] = None, + shutdown_timeout: float = 2, + integrations: Sequence[sentry_sdk.integrations.Integration] = [], # noqa: B006 + in_app_include: List[str] = [], # noqa: B006 + in_app_exclude: List[str] = [], # noqa: B006 + default_integrations: bool = True, + dist: Optional[str] = None, + transport: Optional[ + Union[ + sentry_sdk.transport.Transport, + Type[sentry_sdk.transport.Transport], + Callable[[Event], None], + ] + ] = None, + transport_queue_size: int = DEFAULT_QUEUE_SIZE, + sample_rate: float = 1.0, + send_default_pii: Optional[bool] = None, + http_proxy: Optional[str] = None, + https_proxy: Optional[str] = None, + ignore_errors: Sequence[Union[type, str]] = [], # noqa: B006 + max_request_body_size: str = "medium", + socket_options: Optional[List[Tuple[int, int, int | bytes]]] = None, + keep_alive: Optional[bool] = None, + before_send: Optional[EventProcessor] = None, + before_breadcrumb: Optional[BreadcrumbProcessor] = None, + debug: Optional[bool] = None, + attach_stacktrace: bool = False, + ca_certs: Optional[str] = None, + traces_sample_rate: Optional[float] = None, + traces_sampler: Optional[TracesSampler] = None, + profiles_sample_rate: Optional[float] = None, + profiles_sampler: Optional[TracesSampler] = None, + profiler_mode: Optional[ProfilerMode] = None, + profile_lifecycle: Literal["manual", "trace"] = "manual", + profile_session_sample_rate: Optional[float] = None, + auto_enabling_integrations: bool = True, + disabled_integrations: Optional[ + Sequence[sentry_sdk.integrations.Integration] + ] = None, + auto_session_tracking: bool = True, + send_client_reports: bool = True, + _experiments: Experiments = {}, # noqa: B006 + proxy_headers: Optional[Dict[str, str]] = None, + before_send_transaction: Optional[TransactionProcessor] = None, + project_root: Optional[str] = None, + include_local_variables: Optional[bool] = True, + include_source_context: Optional[bool] = True, + trace_propagation_targets: Optional[Sequence[str]] = [MATCH_ALL], # noqa: B006 + functions_to_trace: Sequence[Dict[str, str]] = [], # noqa: B006 + event_scrubber: Optional[sentry_sdk.scrubber.EventScrubber] = None, + max_value_length: int = DEFAULT_MAX_VALUE_LENGTH, + enable_backpressure_handling: bool = True, + error_sampler: Optional[Callable[[Event, Hint], Union[float, bool]]] = None, + enable_db_query_source: bool = True, + db_query_source_threshold_ms: int = 100, + spotlight: Optional[Union[bool, str]] = None, + cert_file: Optional[str] = None, + key_file: Optional[str] = None, + custom_repr: Optional[Callable[..., Optional[str]]] = None, + add_full_stack: bool = DEFAULT_ADD_FULL_STACK, + max_stack_frames: Optional[int] = DEFAULT_MAX_STACK_FRAMES, + ) -> None: """Initialize the Sentry SDK with the given parameters. All parameters described here can be used in a call to `sentry_sdk.init()`. :param dsn: The DSN tells the SDK where to send the events. 
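A note on the `socket_options: Optional[List[Tuple[int, int, int | bytes]]]` annotation in the new signature above: the PEP 604 union `int | bytes` only evaluates successfully at runtime on Python 3.10+, but it is safe here because every touched module now starts with `from __future__ import annotations` (PEP 563), which stores annotations as strings instead of evaluating them at import time. A minimal sketch of that behaviour, using a hypothetical helper name chosen only for illustration:

from __future__ import annotations


def set_socket_option(value: int | bytes) -> None:
    # Under PEP 563 the annotation is kept as the string "int | bytes",
    # so nothing is evaluated at function-definition time on Python 3.8/3.9.
    print(type(value).__name__)


set_socket_option(1)        # prints "int"
set_socket_option(b"\x01")  # prints "bytes"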
@@ -1198,8 +1207,7 @@ def __init__( pass -def _get_default_options(): - # type: () -> dict[str, Any] +def _get_default_options() -> dict[str, Any]: import inspect a = inspect.getfullargspec(ClientConstructor.__init__) diff --git a/sentry_sdk/crons/api.py b/sentry_sdk/crons/api.py index 20e95685a7..cbe8b92834 100644 --- a/sentry_sdk/crons/api.py +++ b/sentry_sdk/crons/api.py @@ -1,3 +1,4 @@ +from __future__ import annotations import uuid import sentry_sdk @@ -10,17 +11,16 @@ def _create_check_in_event( - monitor_slug=None, # type: Optional[str] - check_in_id=None, # type: Optional[str] - status=None, # type: Optional[str] - duration_s=None, # type: Optional[float] - monitor_config=None, # type: Optional[MonitorConfig] -): - # type: (...) -> Event + monitor_slug: Optional[str] = None, + check_in_id: Optional[str] = None, + status: Optional[str] = None, + duration_s: Optional[float] = None, + monitor_config: Optional[MonitorConfig] = None, +) -> Event: options = sentry_sdk.get_client().options - check_in_id = check_in_id or uuid.uuid4().hex # type: str + check_in_id = check_in_id or uuid.uuid4().hex - check_in = { + check_in: Event = { "type": "check_in", "monitor_slug": monitor_slug, "check_in_id": check_in_id, @@ -28,7 +28,7 @@ def _create_check_in_event( "duration": duration_s, "environment": options.get("environment", None), "release": options.get("release", None), - } # type: Event + } if monitor_config: check_in["monitor_config"] = monitor_config @@ -37,13 +37,12 @@ def _create_check_in_event( def capture_checkin( - monitor_slug=None, # type: Optional[str] - check_in_id=None, # type: Optional[str] - status=None, # type: Optional[str] - duration=None, # type: Optional[float] - monitor_config=None, # type: Optional[MonitorConfig] -): - # type: (...) 
-> str + monitor_slug: Optional[str] = None, + check_in_id: Optional[str] = None, + status: Optional[str] = None, + duration: Optional[float] = None, + monitor_config: Optional[MonitorConfig] = None, +) -> str: check_in_event = _create_check_in_event( monitor_slug=monitor_slug, check_in_id=check_in_id, diff --git a/sentry_sdk/crons/decorator.py b/sentry_sdk/crons/decorator.py index 9af00e61c0..50078a2dba 100644 --- a/sentry_sdk/crons/decorator.py +++ b/sentry_sdk/crons/decorator.py @@ -1,3 +1,4 @@ +from __future__ import annotations from functools import wraps from inspect import iscoroutinefunction @@ -16,8 +17,6 @@ ParamSpec, Type, TypeVar, - Union, - cast, overload, ) from sentry_sdk._types import MonitorConfig @@ -55,13 +54,15 @@ def test(arg): ``` """ - def __init__(self, monitor_slug=None, monitor_config=None): - # type: (Optional[str], Optional[MonitorConfig]) -> None + def __init__( + self, + monitor_slug: Optional[str] = None, + monitor_config: Optional[MonitorConfig] = None, + ) -> None: self.monitor_slug = monitor_slug self.monitor_config = monitor_config - def __enter__(self): - # type: () -> None + def __enter__(self) -> None: self.start_timestamp = now() self.check_in_id = capture_checkin( monitor_slug=self.monitor_slug, @@ -69,8 +70,12 @@ def __enter__(self): monitor_config=self.monitor_config, ) - def __exit__(self, exc_type, exc_value, traceback): - # type: (Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]) -> None + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: duration_s = now() - self.start_timestamp if exc_type is None and exc_value is None and traceback is None: @@ -89,46 +94,39 @@ def __exit__(self, exc_type, exc_value, traceback): if TYPE_CHECKING: @overload - def __call__(self, fn): - # type: (Callable[P, Awaitable[Any]]) -> Callable[P, Awaitable[Any]] + def __call__( + self, fn: Callable[P, Awaitable[Any]] + ) -> Callable[P, Awaitable[Any]]: # Unfortunately, mypy does not give us any reliable way to type check the # return value of an Awaitable (i.e. async function) for this overload, # since calling iscouroutinefunction narrows the type to Callable[P, Awaitable[Any]]. ... @overload - def __call__(self, fn): - # type: (Callable[P, R]) -> Callable[P, R] - ... + def __call__(self, fn: Callable[P, R]) -> Callable[P, R]: ... def __call__( self, - fn, # type: Union[Callable[P, R], Callable[P, Awaitable[Any]]] - ): - # type: (...) -> Union[Callable[P, R], Callable[P, Awaitable[Any]]] + fn: Callable[..., Any], + ) -> Callable[..., Any]: if iscoroutinefunction(fn): return self._async_wrapper(fn) - else: - if TYPE_CHECKING: - fn = cast("Callable[P, R]", fn) return self._sync_wrapper(fn) - def _async_wrapper(self, fn): - # type: (Callable[P, Awaitable[Any]]) -> Callable[P, Awaitable[Any]] + def _async_wrapper( + self, fn: Callable[P, Awaitable[Any]] + ) -> Callable[P, Awaitable[Any]]: @wraps(fn) - async def inner(*args: "P.args", **kwargs: "P.kwargs"): - # type: (...) -> R + async def inner(*args: P.args, **kwargs: P.kwargs) -> R: with self: return await fn(*args, **kwargs) return inner - def _sync_wrapper(self, fn): - # type: (Callable[P, R]) -> Callable[P, R] + def _sync_wrapper(self, fn: Callable[P, R]) -> Callable[P, R]: @wraps(fn) - def inner(*args: "P.args", **kwargs: "P.kwargs"): - # type: (...) 
-> R + def inner(*args: P.args, **kwargs: P.kwargs) -> R: with self: return fn(*args, **kwargs) diff --git a/sentry_sdk/debug.py b/sentry_sdk/debug.py index c0c30fdd5d..5564bb5ea3 100644 --- a/sentry_sdk/debug.py +++ b/sentry_sdk/debug.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import logging @@ -8,22 +9,19 @@ class _DebugFilter(logging.Filter): - def filter(self, record): - # type: (LogRecord) -> bool + def filter(self, record: LogRecord) -> bool: if _client_init_debug.get(False): return True return get_client().options["debug"] -def init_debug_support(): - # type: () -> None +def init_debug_support() -> None: if not logger.handlers: configure_logger() -def configure_logger(): - # type: () -> None +def configure_logger() -> None: _handler = logging.StreamHandler(sys.stderr) _handler.setFormatter(logging.Formatter(" [sentry] %(levelname)s: %(message)s")) logger.addHandler(_handler) diff --git a/sentry_sdk/envelope.py b/sentry_sdk/envelope.py index 378028377b..c532191202 100644 --- a/sentry_sdk/envelope.py +++ b/sentry_sdk/envelope.py @@ -1,3 +1,4 @@ +from __future__ import annotations import io import json import mimetypes @@ -8,18 +9,11 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any - from typing import Optional - from typing import Union - from typing import Dict - from typing import List - from typing import Iterator - from sentry_sdk._types import Event, EventDataCategory + from typing import Any, Optional, Union, Dict, List, Iterator -def parse_json(data): - # type: (Union[bytes, str]) -> Any +def parse_json(data: Union[bytes, str]) -> Any: # on some python 3 versions this needs to be bytes if isinstance(data, bytes): data = data.decode("utf-8", "replace") @@ -35,10 +29,9 @@ class Envelope: def __init__( self, - headers=None, # type: Optional[Dict[str, Any]] - items=None, # type: Optional[List[Item]] - ): - # type: (...) -> None + headers: Optional[Dict[str, Any]] = None, + items: Optional[List[Item]] = None, + ) -> None: if headers is not None: headers = dict(headers) self.headers = headers or {} @@ -49,35 +42,22 @@ def __init__( self.items = items @property - def description(self): - # type: (...) -> str + def description(self) -> str: return "envelope with %s items (%s)" % ( len(self.items), ", ".join(x.data_category for x in self.items), ) - def add_event( - self, event # type: Event - ): - # type: (...) -> None + def add_event(self, event: Event) -> None: self.add_item(Item(payload=PayloadRef(json=event), type="event")) - def add_transaction( - self, transaction # type: Event - ): - # type: (...) -> None + def add_transaction(self, transaction: Event) -> None: self.add_item(Item(payload=PayloadRef(json=transaction), type="transaction")) - def add_profile( - self, profile # type: Any - ): - # type: (...) -> None + def add_profile(self, profile: Any) -> None: self.add_item(Item(payload=PayloadRef(json=profile), type="profile")) - def add_profile_chunk( - self, profile_chunk # type: Any - ): - # type: (...) -> None + def add_profile_chunk(self, profile_chunk: Any) -> None: self.add_item( Item( payload=PayloadRef(json=profile_chunk), @@ -86,72 +66,50 @@ def add_profile_chunk( ) ) - def add_checkin( - self, checkin # type: Any - ): - # type: (...) -> None + def add_checkin(self, checkin: Any) -> None: self.add_item(Item(payload=PayloadRef(json=checkin), type="check_in")) - def add_session( - self, session # type: Union[Session, Any] - ): - # type: (...) 
-> None + def add_session(self, session: Union[Session, Any]) -> None: if isinstance(session, Session): session = session.to_json() self.add_item(Item(payload=PayloadRef(json=session), type="session")) - def add_sessions( - self, sessions # type: Any - ): - # type: (...) -> None + def add_sessions(self, sessions: Any) -> None: self.add_item(Item(payload=PayloadRef(json=sessions), type="sessions")) - def add_item( - self, item # type: Item - ): - # type: (...) -> None + def add_item(self, item: Item) -> None: self.items.append(item) - def get_event(self): - # type: (...) -> Optional[Event] + def get_event(self) -> Optional[Event]: for items in self.items: event = items.get_event() if event is not None: return event return None - def get_transaction_event(self): - # type: (...) -> Optional[Event] + def get_transaction_event(self) -> Optional[Event]: for item in self.items: event = item.get_transaction_event() if event is not None: return event return None - def __iter__(self): - # type: (...) -> Iterator[Item] + def __iter__(self) -> Iterator[Item]: return iter(self.items) - def serialize_into( - self, f # type: Any - ): - # type: (...) -> None + def serialize_into(self, f: Any) -> None: f.write(json_dumps(self.headers)) f.write(b"\n") for item in self.items: item.serialize_into(f) - def serialize(self): - # type: (...) -> bytes + def serialize(self) -> bytes: out = io.BytesIO() self.serialize_into(out) return out.getvalue() @classmethod - def deserialize_from( - cls, f # type: Any - ): - # type: (...) -> Envelope + def deserialize_from(cls, f: Any) -> Envelope: headers = parse_json(f.readline()) items = [] while 1: @@ -162,31 +120,25 @@ def deserialize_from( return cls(headers=headers, items=items) @classmethod - def deserialize( - cls, bytes # type: bytes - ): - # type: (...) -> Envelope + def deserialize(cls, bytes: bytes) -> Envelope: return cls.deserialize_from(io.BytesIO(bytes)) - def __repr__(self): - # type: (...) -> str + def __repr__(self) -> str: return "" % (self.headers, self.items) class PayloadRef: def __init__( self, - bytes=None, # type: Optional[bytes] - path=None, # type: Optional[Union[bytes, str]] - json=None, # type: Optional[Any] - ): - # type: (...) -> None + bytes: Optional[bytes] = None, + path: Optional[Union[bytes, str]] = None, + json: Optional[Any] = None, + ) -> None: self.json = json self.bytes = bytes self.path = path - def get_bytes(self): - # type: (...) -> bytes + def get_bytes(self) -> bytes: if self.bytes is None: if self.path is not None: with capture_internal_exceptions(): @@ -197,8 +149,7 @@ def get_bytes(self): return self.bytes or b"" @property - def inferred_content_type(self): - # type: (...) -> str + def inferred_content_type(self) -> str: if self.json is not None: return "application/json" elif self.path is not None: @@ -210,20 +161,19 @@ def inferred_content_type(self): return ty return "application/octet-stream" - def __repr__(self): - # type: (...) 
-> str + def __repr__(self) -> str: return "" % (self.inferred_content_type,) class Item: def __init__( self, - payload, # type: Union[bytes, str, PayloadRef] - headers=None, # type: Optional[Dict[str, Any]] - type=None, # type: Optional[str] - content_type=None, # type: Optional[str] - filename=None, # type: Optional[str] - ): + payload: Union[bytes, str, PayloadRef], + headers: Optional[Dict[str, Any]] = None, + type: Optional[str] = None, + content_type: Optional[str] = None, + filename: Optional[str] = None, + ) -> None: if headers is not None: headers = dict(headers) elif headers is None: @@ -247,8 +197,7 @@ def __init__( self.payload = payload - def __repr__(self): - # type: (...) -> str + def __repr__(self) -> str: return "" % ( self.headers, self.payload, @@ -256,13 +205,11 @@ def __repr__(self): ) @property - def type(self): - # type: (...) -> Optional[str] + def type(self) -> Optional[str]: return self.headers.get("type") @property - def data_category(self): - # type: (...) -> EventDataCategory + def data_category(self) -> EventDataCategory: ty = self.headers.get("type") if ty == "session" or ty == "sessions": return "session" @@ -285,12 +232,10 @@ def data_category(self): else: return "default" - def get_bytes(self): - # type: (...) -> bytes + def get_bytes(self) -> bytes: return self.payload.get_bytes() - def get_event(self): - # type: (...) -> Optional[Event] + def get_event(self) -> Optional[Event]: """ Returns an error event if there is one. """ @@ -298,16 +243,12 @@ def get_event(self): return self.payload.json return None - def get_transaction_event(self): - # type: (...) -> Optional[Event] + def get_transaction_event(self) -> Optional[Event]: if self.type == "transaction" and self.payload.json is not None: return self.payload.json return None - def serialize_into( - self, f # type: Any - ): - # type: (...) -> None + def serialize_into(self, f: Any) -> None: headers = dict(self.headers) bytes = self.get_bytes() headers["length"] = len(bytes) @@ -316,17 +257,13 @@ def serialize_into( f.write(bytes) f.write(b"\n") - def serialize(self): - # type: (...) -> bytes + def serialize(self) -> bytes: out = io.BytesIO() self.serialize_into(out) return out.getvalue() @classmethod - def deserialize_from( - cls, f # type: Any - ): - # type: (...) -> Optional[Item] + def deserialize_from(cls, f: Any) -> Optional[Item]: line = f.readline().rstrip() if not line: return None @@ -346,8 +283,5 @@ def deserialize_from( return rv @classmethod - def deserialize( - cls, bytes # type: bytes - ): - # type: (...) -> Optional[Item] + def deserialize(cls, bytes: bytes) -> Optional[Item]: return cls.deserialize_from(io.BytesIO(bytes)) diff --git a/sentry_sdk/feature_flags.py b/sentry_sdk/feature_flags.py index efc92661e7..2f0660a80f 100644 --- a/sentry_sdk/feature_flags.py +++ b/sentry_sdk/feature_flags.py @@ -1,23 +1,22 @@ +from __future__ import annotations import copy import sentry_sdk from sentry_sdk._lru_cache import LRUCache from threading import Lock -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import TypedDict + from typing import Any, TypedDict FlagData = TypedDict("FlagData", {"flag": str, "result": bool}) - DEFAULT_FLAG_CAPACITY = 100 class FlagBuffer: - def __init__(self, capacity): - # type: (int) -> None + def __init__(self, capacity: int) -> None: self.capacity = capacity self.lock = Lock() @@ -25,26 +24,22 @@ def __init__(self, capacity): # directly you're on your own! 
self.__buffer = LRUCache(capacity) - def clear(self): - # type: () -> None + def clear(self) -> None: self.__buffer = LRUCache(self.capacity) - def __deepcopy__(self, memo): - # type: (dict[int, Any]) -> FlagBuffer + def __deepcopy__(self, memo: dict[int, Any]) -> FlagBuffer: with self.lock: buffer = FlagBuffer(self.capacity) buffer.__buffer = copy.deepcopy(self.__buffer, memo) return buffer - def get(self): - # type: () -> list[FlagData] + def get(self) -> list[FlagData]: with self.lock: return [ {"flag": key, "result": value} for key, value in self.__buffer.get_all() ] - def set(self, flag, result): - # type: (str, bool) -> None + def set(self, flag: str, result: bool) -> None: if isinstance(result, FlagBuffer): # If someone were to insert `self` into `self` this would create a circular dependency # on the lock. This is of course a deadlock. However, this is far outside the expected @@ -58,8 +53,7 @@ def set(self, flag, result): self.__buffer.set(flag, result) -def add_feature_flag(flag, result): - # type: (str, bool) -> None +def add_feature_flag(flag: str, result: bool) -> None: """ Records a flag and its value to be sent on subsequent error events. We recommend you do this on flag evaluations. Flags are buffered per Sentry scope. diff --git a/sentry_sdk/integrations/__init__.py b/sentry_sdk/integrations/__init__.py index f2d1a28522..5485ebe4c3 100644 --- a/sentry_sdk/integrations/__init__.py +++ b/sentry_sdk/integrations/__init__.py @@ -1,3 +1,4 @@ +from __future__ import annotations from abc import ABC, abstractmethod from threading import Lock @@ -23,20 +24,20 @@ _installer_lock = Lock() # Set of all integration identifiers we have attempted to install -_processed_integrations = set() # type: Set[str] +_processed_integrations: Set[str] = set() # Set of all integration identifiers we have actually installed -_installed_integrations = set() # type: Set[str] +_installed_integrations: Set[str] = set() def _generate_default_integrations_iterator( - integrations, # type: List[str] - auto_enabling_integrations, # type: List[str] -): - # type: (...) -> Callable[[bool], Iterator[Type[Integration]]] + integrations: List[str], + auto_enabling_integrations: List[str], +) -> Callable[[bool], Iterator[Type[Integration]]]: - def iter_default_integrations(with_auto_enabling_integrations): - # type: (bool) -> Iterator[Type[Integration]] + def iter_default_integrations( + with_auto_enabling_integrations: bool, + ) -> Iterator[Type[Integration]]: """Returns an iterator of the default integration classes:""" from importlib import import_module @@ -165,12 +166,13 @@ def iter_default_integrations(with_auto_enabling_integrations): def setup_integrations( - integrations, - with_defaults=True, - with_auto_enabling_integrations=False, - disabled_integrations=None, -): - # type: (Sequence[Integration], bool, bool, Optional[Sequence[Union[type[Integration], Integration]]]) -> Dict[str, Integration] + integrations: Sequence[Integration], + with_defaults: bool = True, + with_auto_enabling_integrations: bool = False, + disabled_integrations: Optional[ + Sequence[Union[type[Integration], Integration]] + ] = None, +) -> Dict[str, Integration]: """ Given a list of integration instances, this installs them all. 
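
The conversion in feature_flags.py and the integrations modules relies on the same two ingredients everywhere: `from __future__ import annotations` postpones evaluation of every annotation, and typing-only names move under `if TYPE_CHECKING:` so they are never imported at runtime. The following is a minimal, self-contained sketch (illustrative names, not SDK code) of why this leaves runtime behaviour unchanged while allowing modern annotation syntax on every supported interpreter:

# Standalone sketch of the annotation pattern applied throughout this diff:
# postponed evaluation plus TYPE_CHECKING-only imports.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for type checkers only; never imported at runtime.
    from typing import Any, Optional


class Buffer:
    def __init__(self, capacity: int) -> None:
        # With the __future__ import this annotation is never evaluated, so
        # list[Any] is safe even where built-in generics cannot be subscripted
        # and Any is not actually imported.
        self.items: list[Any] = []
        self.capacity = capacity

    def first(self) -> Optional[Any]:
        # Optional[Any] is likewise never evaluated at runtime.
        return self.items[0] if self.items else None


if __name__ == "__main__":
    buf = Buffer(capacity=10)
    buf.items.append("flag")
    print(Buffer.__init__.__annotations__)  # stored as plain strings
    print(buf.first())
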
@@ -239,8 +241,11 @@ def setup_integrations( return integrations -def _check_minimum_version(integration, version, package=None): - # type: (type[Integration], Optional[tuple[int, ...]], Optional[str]) -> None +def _check_minimum_version( + integration: type[Integration], + version: Optional[tuple[int, ...]], + package: Optional[str] = None, +) -> None: package = package or integration.identifier if version is None: @@ -276,13 +281,12 @@ class Integration(ABC): install = None """Legacy method, do not implement.""" - identifier = None # type: str + identifier: str """String unique ID of integration type""" @staticmethod @abstractmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: """ Initialize the integration. diff --git a/sentry_sdk/integrations/_asgi_common.py b/sentry_sdk/integrations/_asgi_common.py index 22aa17de0b..efa67bda05 100644 --- a/sentry_sdk/integrations/_asgi_common.py +++ b/sentry_sdk/integrations/_asgi_common.py @@ -1,26 +1,25 @@ +from __future__ import annotations import urllib from sentry_sdk.scope import should_send_default_pii from sentry_sdk.integrations._wsgi_common import _filter_headers -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Any from typing import Dict from typing import Optional from typing import Union - from typing_extensions import Literal from sentry_sdk.utils import AnnotatedValue -def _get_headers(asgi_scope): - # type: (Any) -> Dict[str, str] +def _get_headers(asgi_scope: Any) -> Dict[str, str]: """ Extract headers from the ASGI scope, in the format that the Sentry protocol expects. """ - headers = {} # type: Dict[str, str] + headers: Dict[str, str] = {} for raw_key, raw_value in asgi_scope.get("headers", {}): key = raw_key.decode("latin-1") value = raw_value.decode("latin-1") @@ -32,12 +31,16 @@ def _get_headers(asgi_scope): return headers -def _get_url(asgi_scope, default_scheme=None, host=None): - # type: (Dict[str, Any], Optional[Literal["ws", "http"]], Optional[Union[AnnotatedValue, str]]) -> str +def _get_url( + asgi_scope: Dict[str, Any], + host: Optional[Union[AnnotatedValue, str]] = None, +) -> str: """ Extract URL from the ASGI scope, without also including the querystring. """ - scheme = cast(str, asgi_scope.get("scheme", default_scheme)) + scheme = asgi_scope.get( + "scheme", "http" if asgi_scope.get("type") == "http" else "ws" + ) server = asgi_scope.get("server", None) path = asgi_scope.get("root_path", "") + asgi_scope.get("path", "") @@ -53,8 +56,7 @@ def _get_url(asgi_scope, default_scheme=None, host=None): return path -def _get_query(asgi_scope): - # type: (Any) -> Any +def _get_query(asgi_scope: Any) -> Any: """ Extract querystring from the ASGI scope, in the format that the Sentry protocol expects. """ @@ -64,8 +66,7 @@ def _get_query(asgi_scope): return urllib.parse.unquote(qs.decode("latin-1")) -def _get_ip(asgi_scope): - # type: (Any) -> str +def _get_ip(asgi_scope: Any) -> str: """ Extract IP Address from the ASGI scope based on request headers with fallback to scope client. """ @@ -83,12 +84,11 @@ def _get_ip(asgi_scope): return asgi_scope.get("client")[0] -def _get_request_data(asgi_scope): - # type: (Any) -> Dict[str, Any] +def _get_request_data(asgi_scope: Any) -> Dict[str, Any]: """ Returns data related to the HTTP request from the ASGI scope. 
""" - request_data = {} # type: Dict[str, Any] + request_data: Dict[str, Any] = {} ty = asgi_scope["type"] if ty in ("http", "websocket"): request_data["method"] = asgi_scope.get("method") @@ -96,9 +96,7 @@ def _get_request_data(asgi_scope): request_data["headers"] = headers = _filter_headers(_get_headers(asgi_scope)) request_data["query_string"] = _get_query(asgi_scope) - request_data["url"] = _get_url( - asgi_scope, "http" if ty == "http" else "ws", headers.get("host") - ) + request_data["url"] = _get_url(asgi_scope, headers.get("host")) client = asgi_scope.get("client") if client and should_send_default_pii(): diff --git a/sentry_sdk/integrations/_wsgi_common.py b/sentry_sdk/integrations/_wsgi_common.py index 2d4a5f7b73..625deb89a5 100644 --- a/sentry_sdk/integrations/_wsgi_common.py +++ b/sentry_sdk/integrations/_wsgi_common.py @@ -1,3 +1,4 @@ +from __future__ import annotations import json from copy import deepcopy @@ -50,8 +51,9 @@ ) -def request_body_within_bounds(client, content_length): - # type: (Optional[sentry_sdk.client.BaseClient], int) -> bool +def request_body_within_bounds( + client: Optional[sentry_sdk.client.BaseClient], content_length: int +) -> bool: if client is None: return False @@ -73,17 +75,15 @@ class RequestExtractor: # it. Only some child classes implement all methods that raise # NotImplementedError in this class. - def __init__(self, request): - # type: (Any) -> None + def __init__(self, request: Any) -> None: self.request = request - def extract_into_event(self, event): - # type: (Event) -> None + def extract_into_event(self, event: Event) -> None: client = sentry_sdk.get_client() if not client.is_active(): return - data = None # type: Optional[Union[AnnotatedValue, Dict[str, Any]]] + data: Optional[Union[AnnotatedValue, Dict[str, Any]]] = None content_length = self.content_length() request_info = event.get("request", {}) @@ -119,27 +119,22 @@ def extract_into_event(self, event): event["request"] = deepcopy(request_info) - def content_length(self): - # type: () -> int + def content_length(self) -> int: try: return int(self.env().get("CONTENT_LENGTH", 0)) except ValueError: return 0 - def cookies(self): - # type: () -> MutableMapping[str, Any] + def cookies(self) -> MutableMapping[str, Any]: raise NotImplementedError() - def raw_data(self): - # type: () -> Optional[Union[str, bytes]] + def raw_data(self) -> Optional[Union[str, bytes]]: raise NotImplementedError() - def form(self): - # type: () -> Optional[Dict[str, Any]] + def form(self) -> Optional[Dict[str, Any]]: raise NotImplementedError() - def parsed_body(self): - # type: () -> Optional[Dict[str, Any]] + def parsed_body(self) -> Optional[Dict[str, Any]]: try: form = self.form() except Exception: @@ -161,12 +156,10 @@ def parsed_body(self): return self.json() - def is_json(self): - # type: () -> bool + def is_json(self) -> bool: return _is_json_content_type(self.env().get("CONTENT_TYPE")) - def json(self): - # type: () -> Optional[Any] + def json(self) -> Optional[Any]: try: if not self.is_json(): return None @@ -190,21 +183,17 @@ def json(self): return None - def files(self): - # type: () -> Optional[Dict[str, Any]] + def files(self) -> Optional[Dict[str, Any]]: raise NotImplementedError() - def size_of_file(self, file): - # type: (Any) -> int + def size_of_file(self, file: Any) -> int: raise NotImplementedError() - def env(self): - # type: () -> Dict[str, Any] + def env(self) -> Dict[str, Any]: raise NotImplementedError() -def _is_json_content_type(ct): - # type: (Optional[str]) -> bool +def 
_is_json_content_type(ct: Optional[str]) -> bool: mt = (ct or "").split(";", 1)[0] return ( mt == "application/json" @@ -213,8 +202,9 @@ def _is_json_content_type(ct): ) -def _filter_headers(headers): - # type: (Mapping[str, str]) -> Mapping[str, Union[AnnotatedValue, str]] +def _filter_headers( + headers: Mapping[str, str], +) -> Mapping[str, Union[AnnotatedValue, str]]: if should_send_default_pii(): return headers @@ -228,8 +218,7 @@ def _filter_headers(headers): } -def _request_headers_to_span_attributes(headers): - # type: (dict[str, str]) -> dict[str, str] +def _request_headers_to_span_attributes(headers: dict[str, str]) -> dict[str, str]: attributes = {} headers = _filter_headers(headers) diff --git a/sentry_sdk/integrations/aiohttp.py b/sentry_sdk/integrations/aiohttp.py index 5e89658acd..5417baf4cf 100644 --- a/sentry_sdk/integrations/aiohttp.py +++ b/sentry_sdk/integrations/aiohttp.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import weakref from functools import wraps @@ -85,11 +86,10 @@ class AioHttpIntegration(Integration): def __init__( self, - transaction_style="handler_name", # type: str + transaction_style: str = "handler_name", *, - failed_request_status_codes=_DEFAULT_FAILED_REQUEST_STATUS_CODES, # type: Set[int] - ): - # type: (...) -> None + failed_request_status_codes: Set[int] = _DEFAULT_FAILED_REQUEST_STATUS_CODES, + ) -> None: if transaction_style not in TRANSACTION_STYLE_VALUES: raise ValueError( "Invalid value for transaction_style: %s (must be in %s)" @@ -99,8 +99,7 @@ def __init__( self._failed_request_status_codes = failed_request_status_codes @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = parse_version(AIOHTTP_VERSION) _check_minimum_version(AioHttpIntegration, version) @@ -117,8 +116,9 @@ def setup_once(): old_handle = Application._handle - async def sentry_app_handle(self, request, *args, **kwargs): - # type: (Any, Request, *Any, **Any) -> Any + async def sentry_app_handle( + self: Any, request: Request, *args: Any, **kwargs: Any + ) -> Any: integration = sentry_sdk.get_client().get_integration(AioHttpIntegration) if integration is None: return await old_handle(self, request, *args, **kwargs) @@ -172,8 +172,9 @@ async def sentry_app_handle(self, request, *args, **kwargs): old_urldispatcher_resolve = UrlDispatcher.resolve @wraps(old_urldispatcher_resolve) - async def sentry_urldispatcher_resolve(self, request): - # type: (UrlDispatcher, Request) -> UrlMappingMatchInfo + async def sentry_urldispatcher_resolve( + self: UrlDispatcher, request: Request + ) -> UrlMappingMatchInfo: rv = await old_urldispatcher_resolve(self, request) integration = sentry_sdk.get_client().get_integration(AioHttpIntegration) @@ -205,8 +206,7 @@ async def sentry_urldispatcher_resolve(self, request): old_client_session_init = ClientSession.__init__ @ensure_integration_enabled(AioHttpIntegration, old_client_session_init) - def init(*args, **kwargs): - # type: (Any, Any) -> None + def init(*args: Any, **kwargs: Any) -> None: client_trace_configs = list(kwargs.get("trace_configs") or ()) trace_config = create_trace_config() client_trace_configs.append(trace_config) @@ -217,11 +217,13 @@ def init(*args, **kwargs): ClientSession.__init__ = init -def create_trace_config(): - # type: () -> TraceConfig +def create_trace_config() -> TraceConfig: - async def on_request_start(session, trace_config_ctx, params): - # type: (ClientSession, SimpleNamespace, TraceRequestStartParams) -> None + async def on_request_start( + session: 
ClientSession, + trace_config_ctx: SimpleNamespace, + params: TraceRequestStartParams, + ) -> None: if sentry_sdk.get_client().get_integration(AioHttpIntegration) is None: return @@ -277,8 +279,11 @@ async def on_request_start(session, trace_config_ctx, params): trace_config_ctx.span = span trace_config_ctx.span_data = data - async def on_request_end(session, trace_config_ctx, params): - # type: (ClientSession, SimpleNamespace, TraceRequestEndParams) -> None + async def on_request_end( + session: ClientSession, + trace_config_ctx: SimpleNamespace, + params: TraceRequestEndParams, + ) -> None: if trace_config_ctx.span is None: return @@ -307,13 +312,13 @@ async def on_request_end(session, trace_config_ctx, params): return trace_config -def _make_request_processor(weak_request): - # type: (weakref.ReferenceType[Request]) -> EventProcessor +def _make_request_processor( + weak_request: weakref.ReferenceType[Request], +) -> EventProcessor: def aiohttp_processor( - event, # type: Event - hint, # type: dict[str, Tuple[type, BaseException, Any]] - ): - # type: (...) -> Event + event: Event, + hint: dict[str, Tuple[type, BaseException, Any]], + ) -> Event: request = weak_request() if request is None: return event @@ -342,8 +347,7 @@ def aiohttp_processor( return aiohttp_processor -def _capture_exception(): - # type: () -> ExcInfo +def _capture_exception() -> ExcInfo: exc_info = sys.exc_info() event, hint = event_from_exception( exc_info, @@ -357,8 +361,7 @@ def _capture_exception(): BODY_NOT_READ_MESSAGE = "[Can't show request body due to implementation details.]" -def get_aiohttp_request_data(request): - # type: (Request) -> Union[Optional[str], AnnotatedValue] +def get_aiohttp_request_data(request: Request) -> Union[Optional[str], AnnotatedValue]: bytes_body = request._read_bytes if bytes_body is not None: @@ -377,8 +380,7 @@ def get_aiohttp_request_data(request): return None -def _prepopulate_attributes(request): - # type: (Request) -> dict[str, Any] +def _prepopulate_attributes(request: Request) -> dict[str, Any]: """Construct initial span attributes that can be used in traces sampler.""" attributes = {} diff --git a/sentry_sdk/integrations/anthropic.py b/sentry_sdk/integrations/anthropic.py index 454b6f93ca..bc1126db12 100644 --- a/sentry_sdk/integrations/anthropic.py +++ b/sentry_sdk/integrations/anthropic.py @@ -1,3 +1,4 @@ +from __future__ import annotations from functools import wraps from typing import TYPE_CHECKING @@ -29,13 +30,11 @@ class AnthropicIntegration(Integration): identifier = "anthropic" origin = f"auto.ai.{identifier}" - def __init__(self, include_prompts=True): - # type: (AnthropicIntegration, bool) -> None + def __init__(self: AnthropicIntegration, include_prompts: bool = True) -> None: self.include_prompts = include_prompts @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = package_version("anthropic") _check_minimum_version(AnthropicIntegration, version) @@ -43,8 +42,7 @@ def setup_once(): AsyncMessages.create = _wrap_message_create_async(AsyncMessages.create) -def _capture_exception(exc): - # type: (Any) -> None +def _capture_exception(exc: Any) -> None: event, hint = event_from_exception( exc, client_options=sentry_sdk.get_client().options, @@ -53,8 +51,7 @@ def _capture_exception(exc): sentry_sdk.capture_event(event, hint=hint) -def _calculate_token_usage(result, span): - # type: (Messages, Span) -> None +def _calculate_token_usage(result: Messages, span: Span) -> None: input_tokens = 0 output_tokens = 0 if hasattr(result, 
"usage"): @@ -68,8 +65,7 @@ def _calculate_token_usage(result, span): record_token_usage(span, input_tokens, output_tokens, total_tokens) -def _get_responses(content): - # type: (list[Any]) -> list[dict[str, Any]] +def _get_responses(content: list[Any]) -> list[dict[str, Any]]: """ Get JSON of a Anthropic responses. """ @@ -85,8 +81,12 @@ def _get_responses(content): return responses -def _collect_ai_data(event, input_tokens, output_tokens, content_blocks): - # type: (MessageStreamEvent, int, int, list[str]) -> tuple[int, int, list[str]] +def _collect_ai_data( + event: MessageStreamEvent, + input_tokens: int, + output_tokens: int, + content_blocks: list[str], +) -> tuple[int, int, list[str]]: """ Count token usage and collect content blocks from the AI streaming response. """ @@ -112,9 +112,12 @@ def _collect_ai_data(event, input_tokens, output_tokens, content_blocks): def _add_ai_data_to_span( - span, integration, input_tokens, output_tokens, content_blocks -): - # type: (Span, AnthropicIntegration, int, int, list[str]) -> None + span: Span, + integration: AnthropicIntegration, + input_tokens: int, + output_tokens: int, + content_blocks: list[str], +) -> None: """ Add token usage and content blocks from the AI streaming response to the span. """ @@ -130,8 +133,7 @@ def _add_ai_data_to_span( span.set_attribute(SPANDATA.AI_STREAMING, True) -def _sentry_patched_create_common(f, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any +def _sentry_patched_create_common(f: Any, *args: Any, **kwargs: Any) -> Any: integration = kwargs.pop("integration") if integration is None: return f(*args, **kwargs) @@ -177,11 +179,10 @@ def _sentry_patched_create_common(f, *args, **kwargs): elif hasattr(result, "_iterator"): old_iterator = result._iterator - def new_iterator(): - # type: () -> Iterator[MessageStreamEvent] + def new_iterator() -> Iterator[MessageStreamEvent]: input_tokens = 0 output_tokens = 0 - content_blocks = [] # type: list[str] + content_blocks: list[str] = [] for event in old_iterator: input_tokens, output_tokens, content_blocks = _collect_ai_data( @@ -194,11 +195,10 @@ def new_iterator(): ) span.__exit__(None, None, None) - async def new_iterator_async(): - # type: () -> AsyncIterator[MessageStreamEvent] + async def new_iterator_async() -> AsyncIterator[MessageStreamEvent]: input_tokens = 0 output_tokens = 0 - content_blocks = [] # type: list[str] + content_blocks: list[str] = [] async for event in old_iterator: input_tokens, output_tokens, content_blocks = _collect_ai_data( @@ -223,10 +223,8 @@ async def new_iterator_async(): return result -def _wrap_message_create(f): - # type: (Any) -> Any - def _execute_sync(f, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any +def _wrap_message_create(f: Any) -> Any: + def _execute_sync(f: Any, *args: Any, **kwargs: Any) -> Any: gen = _sentry_patched_create_common(f, *args, **kwargs) try: @@ -246,8 +244,7 @@ def _execute_sync(f, *args, **kwargs): return e.value @wraps(f) - def _sentry_patched_create_sync(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _sentry_patched_create_sync(*args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(AnthropicIntegration) kwargs["integration"] = integration @@ -256,10 +253,8 @@ def _sentry_patched_create_sync(*args, **kwargs): return _sentry_patched_create_sync -def _wrap_message_create_async(f): - # type: (Any) -> Any - async def _execute_async(f, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any +def _wrap_message_create_async(f: Any) -> Any: + async def 
_execute_async(f: Any, *args: Any, **kwargs: Any) -> Any: gen = _sentry_patched_create_common(f, *args, **kwargs) try: @@ -279,8 +274,7 @@ async def _execute_async(f, *args, **kwargs): return e.value @wraps(f) - async def _sentry_patched_create_async(*args, **kwargs): - # type: (*Any, **Any) -> Any + async def _sentry_patched_create_async(*args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(AnthropicIntegration) kwargs["integration"] = integration diff --git a/sentry_sdk/integrations/argv.py b/sentry_sdk/integrations/argv.py index 315feefb4a..bf139bb219 100644 --- a/sentry_sdk/integrations/argv.py +++ b/sentry_sdk/integrations/argv.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import sentry_sdk @@ -16,11 +17,9 @@ class ArgvIntegration(Integration): identifier = "argv" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: @add_global_event_processor - def processor(event, hint): - # type: (Event, Optional[Hint]) -> Optional[Event] + def processor(event: Event, hint: Optional[Hint]) -> Optional[Event]: if sentry_sdk.get_client().get_integration(ArgvIntegration) is not None: extra = event.setdefault("extra", {}) # If some event processor decided to set extra to e.g. an diff --git a/sentry_sdk/integrations/ariadne.py b/sentry_sdk/integrations/ariadne.py index 1a95bc0145..77a3aa2d9d 100644 --- a/sentry_sdk/integrations/ariadne.py +++ b/sentry_sdk/integrations/ariadne.py @@ -1,3 +1,4 @@ +from __future__ import annotations from importlib import import_module import sentry_sdk @@ -33,8 +34,7 @@ class AriadneIntegration(Integration): identifier = "ariadne" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = package_version("ariadne") _check_minimum_version(AriadneIntegration, version) @@ -43,15 +43,15 @@ def setup_once(): _patch_graphql() -def _patch_graphql(): - # type: () -> None +def _patch_graphql() -> None: old_parse_query = ariadne_graphql.parse_query old_handle_errors = ariadne_graphql.handle_graphql_errors old_handle_query_result = ariadne_graphql.handle_query_result @ensure_integration_enabled(AriadneIntegration, old_parse_query) - def _sentry_patched_parse_query(context_value, query_parser, data): - # type: (Optional[Any], Optional[QueryParser], Any) -> DocumentNode + def _sentry_patched_parse_query( + context_value: Optional[Any], query_parser: Optional[QueryParser], data: Any + ) -> DocumentNode: event_processor = _make_request_event_processor(data) sentry_sdk.get_isolation_scope().add_event_processor(event_processor) @@ -59,8 +59,9 @@ def _sentry_patched_parse_query(context_value, query_parser, data): return result @ensure_integration_enabled(AriadneIntegration, old_handle_errors) - def _sentry_patched_handle_graphql_errors(errors, *args, **kwargs): - # type: (List[GraphQLError], Any, Any) -> GraphQLResult + def _sentry_patched_handle_graphql_errors( + errors: List[GraphQLError], *args: Any, **kwargs: Any + ) -> GraphQLResult: result = old_handle_errors(errors, *args, **kwargs) event_processor = _make_response_event_processor(result[1]) @@ -83,8 +84,9 @@ def _sentry_patched_handle_graphql_errors(errors, *args, **kwargs): return result @ensure_integration_enabled(AriadneIntegration, old_handle_query_result) - def _sentry_patched_handle_query_result(result, *args, **kwargs): - # type: (Any, Any, Any) -> GraphQLResult + def _sentry_patched_handle_query_result( + result: Any, *args: Any, **kwargs: Any + ) -> GraphQLResult: query_result = 
old_handle_query_result(result, *args, **kwargs) event_processor = _make_response_event_processor(query_result[1]) @@ -111,12 +113,10 @@ def _sentry_patched_handle_query_result(result, *args, **kwargs): ariadne_graphql.handle_query_result = _sentry_patched_handle_query_result # type: ignore -def _make_request_event_processor(data): - # type: (GraphQLSchema) -> EventProcessor +def _make_request_event_processor(data: GraphQLSchema) -> EventProcessor: """Add request data and api_target to events.""" - def inner(event, hint): - # type: (Event, dict[str, Any]) -> Event + def inner(event: Event, hint: dict[str, Any]) -> Event: if not isinstance(data, dict): return event @@ -143,12 +143,10 @@ def inner(event, hint): return inner -def _make_response_event_processor(response): - # type: (Dict[str, Any]) -> EventProcessor +def _make_response_event_processor(response: Dict[str, Any]) -> EventProcessor: """Add response data to the event's response context.""" - def inner(event, hint): - # type: (Event, dict[str, Any]) -> Event + def inner(event: Event, hint: dict[str, Any]) -> Event: with capture_internal_exceptions(): if should_send_default_pii() and response.get("errors"): contexts = event.setdefault("contexts", {}) diff --git a/sentry_sdk/integrations/arq.py b/sentry_sdk/integrations/arq.py index b7d3c67b46..e356d914e0 100644 --- a/sentry_sdk/integrations/arq.py +++ b/sentry_sdk/integrations/arq.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import sentry_sdk @@ -45,8 +46,7 @@ class ArqIntegration(Integration): origin = f"auto.queue.{identifier}" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: try: if isinstance(ARQ_VERSION, str): @@ -66,13 +66,13 @@ def setup_once(): ignore_logger("arq.worker") -def patch_enqueue_job(): - # type: () -> None +def patch_enqueue_job() -> None: old_enqueue_job = ArqRedis.enqueue_job original_kwdefaults = old_enqueue_job.__kwdefaults__ - async def _sentry_enqueue_job(self, function, *args, **kwargs): - # type: (ArqRedis, str, *Any, **Any) -> Optional[Job] + async def _sentry_enqueue_job( + self: ArqRedis, function: str, *args: Any, **kwargs: Any + ) -> Optional[Job]: integration = sentry_sdk.get_client().get_integration(ArqIntegration) if integration is None: return await old_enqueue_job(self, function, *args, **kwargs) @@ -89,12 +89,10 @@ async def _sentry_enqueue_job(self, function, *args, **kwargs): ArqRedis.enqueue_job = _sentry_enqueue_job -def patch_run_job(): - # type: () -> None +def patch_run_job() -> None: old_run_job = Worker.run_job - async def _sentry_run_job(self, job_id, score): - # type: (Worker, str, int) -> None + async def _sentry_run_job(self: Worker, job_id: str, score: int) -> None: integration = sentry_sdk.get_client().get_integration(ArqIntegration) if integration is None: return await old_run_job(self, job_id, score) @@ -123,8 +121,7 @@ async def _sentry_run_job(self, job_id, score): Worker.run_job = _sentry_run_job -def _capture_exception(exc_info): - # type: (ExcInfo) -> None +def _capture_exception(exc_info: ExcInfo) -> None: scope = sentry_sdk.get_current_scope() if scope.root_span is not None: @@ -142,10 +139,10 @@ def _capture_exception(exc_info): sentry_sdk.capture_event(event, hint=hint) -def _make_event_processor(ctx, *args, **kwargs): - # type: (Dict[Any, Any], *Any, **Any) -> EventProcessor - def event_processor(event, hint): - # type: (Event, Hint) -> Optional[Event] +def _make_event_processor( + ctx: Dict[Any, Any], *args: Any, **kwargs: Any +) -> EventProcessor: + def 
event_processor(event: Event, hint: Hint) -> Optional[Event]: with capture_internal_exceptions(): scope = sentry_sdk.get_current_scope() @@ -173,11 +170,9 @@ def event_processor(event, hint): return event_processor -def _wrap_coroutine(name, coroutine): - # type: (str, WorkerCoroutine) -> WorkerCoroutine +def _wrap_coroutine(name: str, coroutine: WorkerCoroutine) -> WorkerCoroutine: - async def _sentry_coroutine(ctx, *args, **kwargs): - # type: (Dict[Any, Any], *Any, **Any) -> Any + async def _sentry_coroutine(ctx: Dict[Any, Any], *args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(ArqIntegration) if integration is None: return await coroutine(ctx, *args, **kwargs) @@ -198,13 +193,11 @@ async def _sentry_coroutine(ctx, *args, **kwargs): return _sentry_coroutine -def patch_create_worker(): - # type: () -> None +def patch_create_worker() -> None: old_create_worker = arq.worker.create_worker @ensure_integration_enabled(ArqIntegration, old_create_worker) - def _sentry_create_worker(*args, **kwargs): - # type: (*Any, **Any) -> Worker + def _sentry_create_worker(*args: Any, **kwargs: Any) -> Worker: settings_cls = args[0] if isinstance(settings_cls, dict): @@ -243,16 +236,14 @@ def _sentry_create_worker(*args, **kwargs): arq.worker.create_worker = _sentry_create_worker -def _get_arq_function(func): - # type: (Union[str, Function, WorkerCoroutine]) -> Function +def _get_arq_function(func: Union[str, Function, WorkerCoroutine]) -> Function: arq_func = arq.worker.func(func) arq_func.coroutine = _wrap_coroutine(arq_func.name, arq_func.coroutine) return arq_func -def _get_arq_cron_job(cron_job): - # type: (CronJob) -> CronJob +def _get_arq_cron_job(cron_job: CronJob) -> CronJob: cron_job.coroutine = _wrap_coroutine(cron_job.name, cron_job.coroutine) return cron_job diff --git a/sentry_sdk/integrations/asgi.py b/sentry_sdk/integrations/asgi.py index 5769f88408..29a42afe3c 100644 --- a/sentry_sdk/integrations/asgi.py +++ b/sentry_sdk/integrations/asgi.py @@ -4,6 +4,7 @@ Based on Tom Christie's `sentry-asgi `. """ +from __future__ import annotations import asyncio import inspect from copy import deepcopy @@ -61,8 +62,7 @@ } -def _capture_exception(exc, mechanism_type="asgi"): - # type: (Any, str) -> None +def _capture_exception(exc: Any, mechanism_type: str = "asgi") -> None: event, hint = event_from_exception( exc, @@ -72,8 +72,7 @@ def _capture_exception(exc, mechanism_type="asgi"): sentry_sdk.capture_event(event, hint=hint) -def _looks_like_asgi3(app): - # type: (Any) -> bool +def _looks_like_asgi3(app: Any) -> bool: """ Try to figure out if an application object supports ASGI3. @@ -100,14 +99,13 @@ class SentryAsgiMiddleware: def __init__( self, - app, # type: Any - unsafe_context_data=False, # type: bool - transaction_style="endpoint", # type: str - mechanism_type="asgi", # type: str - span_origin=None, # type: Optional[str] - http_methods_to_capture=DEFAULT_HTTP_METHODS_TO_CAPTURE, # type: Tuple[str, ...] - ): - # type: (...) -> None + app: Any, + unsafe_context_data: bool = False, + transaction_style: str = "endpoint", + mechanism_type: str = "asgi", + span_origin: Optional[str] = None, + http_methods_to_capture: Tuple[str, ...] = DEFAULT_HTTP_METHODS_TO_CAPTURE, + ) -> None: """ Instrument an ASGI application with Sentry. 
Provides HTTP/websocket data to sent events and basic handling for exceptions bubbling up @@ -145,42 +143,41 @@ def __init__( self.http_methods_to_capture = http_methods_to_capture if _looks_like_asgi3(app): - self.__call__ = self._run_asgi3 # type: Callable[..., Any] + self.__call__: Callable[..., Any] = self._run_asgi3 else: self.__call__ = self._run_asgi2 - def _capture_lifespan_exception(self, exc): - # type: (Exception) -> None + def _capture_lifespan_exception(self, exc: Exception) -> None: """Capture exceptions raise in application lifespan handlers. The separate function is needed to support overriding in derived integrations that use different catching mechanisms. """ return _capture_exception(exc=exc, mechanism_type=self.mechanism_type) - def _capture_request_exception(self, exc): - # type: (Exception) -> None + def _capture_request_exception(self, exc: Exception) -> None: """Capture exceptions raised in incoming request handlers. The separate function is needed to support overriding in derived integrations that use different catching mechanisms. """ return _capture_exception(exc=exc, mechanism_type=self.mechanism_type) - def _run_asgi2(self, scope): - # type: (Any) -> Any - async def inner(receive, send): - # type: (Any, Any) -> Any + def _run_asgi2(self, scope: Any) -> Any: + async def inner(receive: Any, send: Any) -> Any: return await self._run_app(scope, receive, send, asgi_version=2) return inner - async def _run_asgi3(self, scope, receive, send): - # type: (Any, Any, Any) -> Any + async def _run_asgi3(self, scope: Any, receive: Any, send: Any) -> Any: return await self._run_app(scope, receive, send, asgi_version=3) async def _run_original_app( - self, scope, receive, send, asgi_version, is_lifespan=False - ): - # type: (Any, Any, Any, Any, int) -> Any + self, + scope: Any, + receive: Any, + send: Any, + asgi_version: Any, + is_lifespan: int = False, + ) -> Any: try: if asgi_version == 2: return await self.app(scope)(receive, send) @@ -194,8 +191,9 @@ async def _run_original_app( self._capture_request_exception(exc) raise exc from None - async def _run_app(self, scope, receive, send, asgi_version): - # type: (Any, Any, Any, int) -> Any + async def _run_app( + self, scope: Any, receive: Any, send: Any, asgi_version: int + ) -> Any: is_recursive_asgi_middleware = _asgi_middleware_applied.get(False) is_lifespan = scope["type"] == "lifespan" if is_recursive_asgi_middleware or is_lifespan: @@ -251,8 +249,9 @@ async def _run_app(self, scope, receive, send, asgi_version): logger.debug("[ASGI] Started transaction: %s", span) span.set_tag("asgi.type", ty) - async def _sentry_wrapped_send(event): - # type: (Dict[str, Any]) -> Any + async def _sentry_wrapped_send( + event: Dict[str, Any], + ) -> Any: is_http_response = ( event.get("type") == "http.response.start" and span is not None @@ -273,8 +272,9 @@ async def _sentry_wrapped_send(event): finally: _asgi_middleware_applied.set(False) - def event_processor(self, event, hint, asgi_scope): - # type: (Event, Hint, Any) -> Optional[Event] + def event_processor( + self, event: Event, hint: Hint, asgi_scope: Any + ) -> Optional[Event]: request_data = event.get("request", {}) request_data.update(_get_request_data(asgi_scope)) event["request"] = deepcopy(request_data) @@ -313,11 +313,11 @@ def event_processor(self, event, hint, asgi_scope): # data to your liking it's recommended to use the `before_send` callback # for that. 
- def _get_transaction_name_and_source(self, transaction_style, asgi_scope): - # type: (SentryAsgiMiddleware, str, Any) -> Tuple[str, str] + def _get_transaction_name_and_source( + self: SentryAsgiMiddleware, transaction_style: str, asgi_scope: Any + ) -> Tuple[str, str]: name = None source = SOURCE_FOR_STYLE[transaction_style] - ty = asgi_scope.get("type") if transaction_style == "endpoint": endpoint = asgi_scope.get("endpoint") @@ -327,7 +327,7 @@ def _get_transaction_name_and_source(self, transaction_style, asgi_scope): if endpoint: name = transaction_from_function(endpoint) or "" else: - name = _get_url(asgi_scope, "http" if ty == "http" else "ws", host=None) + name = _get_url(asgi_scope) source = TransactionSource.URL elif transaction_style == "url": @@ -339,7 +339,7 @@ def _get_transaction_name_and_source(self, transaction_style, asgi_scope): if path is not None: name = path else: - name = _get_url(asgi_scope, "http" if ty == "http" else "ws", host=None) + name = _get_url(asgi_scope) source = TransactionSource.URL if name is None: @@ -350,8 +350,7 @@ def _get_transaction_name_and_source(self, transaction_style, asgi_scope): return name, source -def _prepopulate_attributes(scope): - # type: (Any) -> dict[str, Any] +def _prepopulate_attributes(scope: Any) -> dict[str, Any]: """Unpack ASGI scope into serializable OTel attributes.""" scope = scope or {} diff --git a/sentry_sdk/integrations/asyncio.py b/sentry_sdk/integrations/asyncio.py index d287ce6118..4f44983e61 100644 --- a/sentry_sdk/integrations/asyncio.py +++ b/sentry_sdk/integrations/asyncio.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import sentry_sdk @@ -11,7 +12,7 @@ except ImportError: raise DidNotEnable("asyncio not available") -from typing import cast, TYPE_CHECKING +from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Any @@ -20,8 +21,7 @@ from sentry_sdk._types import ExcInfo -def get_name(coro): - # type: (Any) -> str +def get_name(coro: Any) -> str: return ( getattr(coro, "__qualname__", None) or getattr(coro, "__name__", None) @@ -29,18 +29,19 @@ def get_name(coro): ) -def patch_asyncio(): - # type: () -> None +def patch_asyncio() -> None: orig_task_factory = None try: loop = asyncio.get_running_loop() orig_task_factory = loop.get_task_factory() - def _sentry_task_factory(loop, coro, **kwargs): - # type: (asyncio.AbstractEventLoop, Coroutine[Any, Any, Any], Any) -> asyncio.Future[Any] + def _sentry_task_factory( + loop: asyncio.AbstractEventLoop, + coro: Coroutine[Any, Any, Any], + **kwargs: Any, + ) -> asyncio.Future[Any]: - async def _task_with_sentry_span_creation(): - # type: () -> Any + async def _task_with_sentry_span_creation() -> Any: result = None with sentry_sdk.isolation_scope(): @@ -79,9 +80,8 @@ async def _task_with_sentry_span_creation(): # Set the task name to include the original coroutine's name try: - cast("asyncio.Task[Any]", task).set_name( - f"{get_name(coro)} (Sentry-wrapped)" - ) + if isinstance(task, asyncio.Task): + task.set_name(f"{get_name(coro)} (Sentry-wrapped)") except AttributeError: # set_name might not be available in all Python versions pass @@ -100,8 +100,7 @@ async def _task_with_sentry_span_creation(): ) -def _capture_exception(): - # type: () -> ExcInfo +def _capture_exception() -> ExcInfo: exc_info = sys.exc_info() client = sentry_sdk.get_client() @@ -123,6 +122,5 @@ class AsyncioIntegration(Integration): origin = f"auto.function.{identifier}" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: patch_asyncio() 
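
The asyncio changes keep the integration's core technique intact: swap the running loop's task factory for one that wraps every coroutine before it is scheduled, and guard the `set_name` call with an `isinstance` check instead of a `cast`. Below is a self-contained sketch of that technique, independent of the SDK; the wrapper only prints where a real integration would open an isolation scope and span:

# Standalone sketch: wrap an event loop's task factory so every spawned
# coroutine runs inside a wrapper coroutine.
import asyncio


def install_wrapping_task_factory() -> None:
    loop = asyncio.get_running_loop()
    orig_factory = loop.get_task_factory()

    def task_factory(loop, coro, **kwargs):
        async def wrapped():
            # A real integration would enter a scope/span around this await.
            print(f"starting {getattr(coro, '__qualname__', coro)!r}")
            return await coro

        if orig_factory is not None:
            task = orig_factory(loop, wrapped(), **kwargs)
        else:
            task = asyncio.Task(wrapped(), loop=loop, **kwargs)

        # Mirror the diff: prefer an isinstance check over cast() before
        # calling set_name(), since a custom factory may not return a Task.
        if isinstance(task, asyncio.Task):
            task.set_name(f"{getattr(coro, '__qualname__', 'coroutine')} (wrapped)")
        return task

    loop.set_task_factory(task_factory)


async def job() -> str:
    await asyncio.sleep(0)
    return "done"


async def main() -> None:
    install_wrapping_task_factory()
    print(await asyncio.create_task(job()))


if __name__ == "__main__":
    asyncio.run(main())
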
diff --git a/sentry_sdk/integrations/atexit.py b/sentry_sdk/integrations/atexit.py index dfc6d08e1a..de60d15dcc 100644 --- a/sentry_sdk/integrations/atexit.py +++ b/sentry_sdk/integrations/atexit.py @@ -1,3 +1,4 @@ +from __future__ import annotations import os import sys import atexit @@ -12,15 +13,13 @@ from typing import Optional -def default_callback(pending, timeout): - # type: (int, int) -> None +def default_callback(pending: int, timeout: int) -> None: """This is the default shutdown callback that is set on the options. It prints out a message to stderr that informs the user that some events are still pending and the process is waiting for them to flush out. """ - def echo(msg): - # type: (str) -> None + def echo(msg: str) -> None: sys.stderr.write(msg + "\n") echo("Sentry is attempting to send %i pending events" % pending) @@ -32,18 +31,15 @@ def echo(msg): class AtexitIntegration(Integration): identifier = "atexit" - def __init__(self, callback=None): - # type: (Optional[Any]) -> None + def __init__(self, callback: Optional[Any] = None) -> None: if callback is None: callback = default_callback self.callback = callback @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: @atexit.register - def _shutdown(): - # type: () -> None + def _shutdown() -> None: client = sentry_sdk.get_client() integration = client.get_integration(AtexitIntegration) diff --git a/sentry_sdk/integrations/aws_lambda.py b/sentry_sdk/integrations/aws_lambda.py index 66d14b22a3..7d39cc3a78 100644 --- a/sentry_sdk/integrations/aws_lambda.py +++ b/sentry_sdk/integrations/aws_lambda.py @@ -1,3 +1,4 @@ +from __future__ import annotations import functools import json import re @@ -54,11 +55,9 @@ } -def _wrap_init_error(init_error): - # type: (F) -> F +def _wrap_init_error(init_error: F) -> F: @ensure_integration_enabled(AwsLambdaIntegration, init_error) - def sentry_init_error(*args, **kwargs): - # type: (*Any, **Any) -> Any + def sentry_init_error(*args: Any, **kwargs: Any) -> Any: client = sentry_sdk.get_client() with capture_internal_exceptions(): @@ -86,11 +85,11 @@ def sentry_init_error(*args, **kwargs): return sentry_init_error # type: ignore -def _wrap_handler(handler): - # type: (F) -> F +def _wrap_handler(handler: F) -> F: @functools.wraps(handler) - def sentry_handler(aws_event, aws_context, *args, **kwargs): - # type: (Any, Any, *Any, **Any) -> Any + def sentry_handler( + aws_event: Any, aws_context: Any, *args: Any, **kwargs: Any + ) -> Any: # Per https://docs.aws.amazon.com/lambda/latest/dg/python-handler.html, # `event` here is *likely* a dictionary, but also might be a number of @@ -192,8 +191,7 @@ def sentry_handler(aws_event, aws_context, *args, **kwargs): return sentry_handler # type: ignore -def _drain_queue(): - # type: () -> None +def _drain_queue() -> None: with capture_internal_exceptions(): client = sentry_sdk.get_client() integration = client.get_integration(AwsLambdaIntegration) @@ -207,13 +205,11 @@ class AwsLambdaIntegration(Integration): identifier = "aws_lambda" origin = f"auto.function.{identifier}" - def __init__(self, timeout_warning=False): - # type: (bool) -> None + def __init__(self, timeout_warning: bool = False) -> None: self.timeout_warning = timeout_warning @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: lambda_bootstrap = get_lambda_bootstrap() if not lambda_bootstrap: @@ -249,10 +245,8 @@ def sentry_handle_event_request( # type: ignore # Patch the runtime client to drain the queue. 
This should work # even when the SDK is initialized inside of the handler - def _wrap_post_function(f): - # type: (F) -> F - def inner(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _wrap_post_function(f: F) -> F: + def inner(*args: Any, **kwargs: Any) -> Any: _drain_queue() return f(*args, **kwargs) @@ -270,8 +264,7 @@ def inner(*args, **kwargs): ) -def get_lambda_bootstrap(): - # type: () -> Optional[Any] +def get_lambda_bootstrap() -> Optional[Any]: # Python 3.7: If the bootstrap module is *already imported*, it is the # one we actually want to use (no idea what's in __main__) @@ -307,12 +300,14 @@ def get_lambda_bootstrap(): return None -def _make_request_event_processor(aws_event, aws_context, configured_timeout): - # type: (Any, Any, Any) -> EventProcessor +def _make_request_event_processor( + aws_event: Any, aws_context: Any, configured_timeout: Any +) -> EventProcessor: start_time = datetime.now(timezone.utc) - def event_processor(sentry_event, hint, start_time=start_time): - # type: (Event, Hint, datetime) -> Optional[Event] + def event_processor( + sentry_event: Event, hint: Hint, start_time: datetime = start_time + ) -> Optional[Event]: remaining_time_in_milis = aws_context.get_remaining_time_in_millis() exec_duration = configured_timeout - remaining_time_in_milis @@ -375,8 +370,7 @@ def event_processor(sentry_event, hint, start_time=start_time): return event_processor -def _get_url(aws_event, aws_context): - # type: (Any, Any) -> str +def _get_url(aws_event: Any, aws_context: Any) -> str: path = aws_event.get("path", None) headers = aws_event.get("headers") @@ -392,8 +386,7 @@ def _get_url(aws_event, aws_context): return "awslambda:///{}".format(aws_context.function_name) -def _get_cloudwatch_logs_url(aws_context, start_time): - # type: (Any, datetime) -> str +def _get_cloudwatch_logs_url(aws_context: Any, start_time: datetime) -> str: """ Generates a CloudWatchLogs console URL based on the context object @@ -424,8 +417,7 @@ def _get_cloudwatch_logs_url(aws_context, start_time): return url -def _parse_formatted_traceback(formatted_tb): - # type: (list[str]) -> list[dict[str, Any]] +def _parse_formatted_traceback(formatted_tb: list[str]) -> list[dict[str, Any]]: frames = [] for frame in formatted_tb: match = re.match(r'File "(.+)", line (\d+), in (.+)', frame.strip()) @@ -446,8 +438,7 @@ def _parse_formatted_traceback(formatted_tb): return frames -def _event_from_error_json(error_json): - # type: (dict[str, Any]) -> Event +def _event_from_error_json(error_json: dict[str, Any]) -> Event: """ Converts the error JSON from AWS Lambda into a Sentry error event. This is not a full fletched event, but better than nothing. 
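
Several helpers in this file, such as `_make_request_event_processor`, follow one closure pattern: capture request context once, then return an `event_processor(event, hint)` callable that enriches every event passed through it. A minimal, SDK-independent sketch of that pattern follows; the `request_id` and `timeout_ms` parameters are illustrative, not the SDK's actual arguments:

# Standalone sketch of the event-processor closure pattern.
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional

Event = Dict[str, Any]
Hint = Dict[str, Any]
EventProcessor = Callable[[Event, Hint], Optional[Event]]


def make_request_event_processor(request_id: str, timeout_ms: int) -> EventProcessor:
    start_time = datetime.now(timezone.utc)

    def event_processor(event: Event, hint: Hint) -> Optional[Event]:
        # Enrich the event with context captured when the processor was built.
        elapsed_ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
        extra = event.setdefault("extra", {})
        extra["request_id"] = request_id
        extra["remaining_time_in_ms"] = max(timeout_ms - elapsed_ms, 0)
        return event

    return event_processor


if __name__ == "__main__":
    processor = make_request_event_processor("req-123", timeout_ms=3000)
    print(processor({"level": "error"}, {}))
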
@@ -455,7 +446,7 @@ def _event_from_error_json(error_json): This is an example of where AWS creates the error JSON: https://github.com/aws/aws-lambda-python-runtime-interface-client/blob/2.2.1/awslambdaric/bootstrap.py#L479 """ - event = { + event: Event = { "level": "error", "exception": { "values": [ @@ -474,13 +465,12 @@ def _event_from_error_json(error_json): } ], }, - } # type: Event + } return event -def _prepopulate_attributes(aws_event, aws_context): - # type: (Any, Any) -> dict[str, Any] +def _prepopulate_attributes(aws_event: Any, aws_context: Any) -> dict[str, Any]: attributes = { "cloud.provider": "aws", } diff --git a/sentry_sdk/integrations/beam.py b/sentry_sdk/integrations/beam.py index a2e4553f5a..fd37111be2 100644 --- a/sentry_sdk/integrations/beam.py +++ b/sentry_sdk/integrations/beam.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import types from functools import wraps @@ -35,8 +36,7 @@ class BeamIntegration(Integration): identifier = "beam" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: from apache_beam.transforms.core import DoFn, ParDo # type: ignore ignore_logger("root") @@ -52,8 +52,7 @@ def setup_once(): old_init = ParDo.__init__ - def sentry_init_pardo(self, fn, *args, **kwargs): - # type: (ParDo, Any, *Any, **Any) -> Any + def sentry_init_pardo(self: ParDo, fn: Any, *args: Any, **kwargs: Any) -> Any: # Do not monkey patch init twice if not getattr(self, "_sentry_is_patched", False): for func_name in function_patches: @@ -79,14 +78,12 @@ def sentry_init_pardo(self, fn, *args, **kwargs): ParDo.__init__ = sentry_init_pardo -def _wrap_inspect_call(cls, func_name): - # type: (Any, Any) -> Any +def _wrap_inspect_call(cls: Any, func_name: Any) -> Any: if not hasattr(cls, func_name): return None - def _inspect(self): - # type: (Any) -> Any + def _inspect(self: Any) -> Any: """ Inspect function overrides the way Beam gets argspec. """ @@ -113,15 +110,13 @@ def _inspect(self): return _inspect -def _wrap_task_call(func): - # type: (F) -> F +def _wrap_task_call(func: F) -> F: """ Wrap task call with a try catch to get exceptions. """ @wraps(func) - def _inner(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _inner(*args: Any, **kwargs: Any) -> Any: try: gen = func(*args, **kwargs) except Exception: @@ -136,8 +131,7 @@ def _inner(*args, **kwargs): @ensure_integration_enabled(BeamIntegration) -def _capture_exception(exc_info): - # type: (ExcInfo) -> None +def _capture_exception(exc_info: ExcInfo) -> None: """ Send Beam exception to Sentry. """ @@ -151,8 +145,7 @@ def _capture_exception(exc_info): sentry_sdk.capture_event(event, hint=hint) -def raise_exception(): - # type: () -> None +def raise_exception() -> None: """ Raise an exception. """ @@ -162,8 +155,7 @@ def raise_exception(): reraise(*exc_info) -def _wrap_generator_call(gen): - # type: (Iterator[T]) -> Iterator[T] +def _wrap_generator_call(gen: Iterator[T]) -> Iterator[T]: """ Wrap the generator to handle any failures. 
""" diff --git a/sentry_sdk/integrations/boto3.py b/sentry_sdk/integrations/boto3.py index 65239b7548..c5eddc5841 100644 --- a/sentry_sdk/integrations/boto3.py +++ b/sentry_sdk/integrations/boto3.py @@ -1,3 +1,4 @@ +from __future__ import annotations from functools import partial import sentry_sdk @@ -34,15 +35,15 @@ class Boto3Integration(Integration): origin = f"auto.http.{identifier}" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = parse_version(BOTOCORE_VERSION) _check_minimum_version(Boto3Integration, version, "botocore") orig_init = BaseClient.__init__ - def sentry_patched_init(self, *args, **kwargs): - # type: (Type[BaseClient], *Any, **Any) -> None + def sentry_patched_init( + self: Type[BaseClient], *args: Any, **kwargs: Any + ) -> None: orig_init(self, *args, **kwargs) meta = self.meta service_id = meta.service_model.service_id.hyphenize() @@ -57,8 +58,9 @@ def sentry_patched_init(self, *args, **kwargs): @ensure_integration_enabled(Boto3Integration) -def _sentry_request_created(service_id, request, operation_name, **kwargs): - # type: (str, AWSRequest, str, **Any) -> None +def _sentry_request_created( + service_id: str, request: AWSRequest, operation_name: str, **kwargs: Any +) -> None: description = "aws.%s.%s" % (service_id, operation_name) span = sentry_sdk.start_span( op=OP.HTTP_CLIENT, @@ -92,9 +94,10 @@ def _sentry_request_created(service_id, request, operation_name, **kwargs): request.context["_sentrysdk_span_data"] = data -def _sentry_after_call(context, parsed, **kwargs): - # type: (Dict[str, Any], Dict[str, Any], **Any) -> None - span = context.pop("_sentrysdk_span", None) # type: Optional[Span] +def _sentry_after_call( + context: Dict[str, Any], parsed: Dict[str, Any], **kwargs: Any +) -> None: + span: Optional[Span] = context.pop("_sentrysdk_span", None) # Span could be absent if the integration is disabled. if span is None: @@ -122,8 +125,7 @@ def _sentry_after_call(context, parsed, **kwargs): orig_read = body.read - def sentry_streaming_body_read(*args, **kwargs): - # type: (*Any, **Any) -> bytes + def sentry_streaming_body_read(*args: Any, **kwargs: Any) -> bytes: try: ret = orig_read(*args, **kwargs) if not ret: @@ -137,8 +139,7 @@ def sentry_streaming_body_read(*args, **kwargs): orig_close = body.close - def sentry_streaming_body_close(*args, **kwargs): - # type: (*Any, **Any) -> None + def sentry_streaming_body_close(*args: Any, **kwargs: Any) -> None: streaming_span.finish() orig_close(*args, **kwargs) @@ -147,9 +148,10 @@ def sentry_streaming_body_close(*args, **kwargs): span.__exit__(None, None, None) -def _sentry_after_call_error(context, exception, **kwargs): - # type: (Dict[str, Any], Type[BaseException], **Any) -> None - span = context.pop("_sentrysdk_span", None) # type: Optional[Span] +def _sentry_after_call_error( + context: Dict[str, Any], exception: Type[BaseException], **kwargs: Any +) -> None: + span: Optional[Span] = context.pop("_sentrysdk_span", None) # Span could be absent if the integration is disabled. 
if span is None: diff --git a/sentry_sdk/integrations/bottle.py b/sentry_sdk/integrations/bottle.py index 1fefcf0319..cdc36f50d6 100644 --- a/sentry_sdk/integrations/bottle.py +++ b/sentry_sdk/integrations/bottle.py @@ -1,3 +1,4 @@ +from __future__ import annotations import functools import sentry_sdk @@ -55,11 +56,10 @@ class BottleIntegration(Integration): def __init__( self, - transaction_style="endpoint", # type: str + transaction_style: str = "endpoint", *, - failed_request_status_codes=_DEFAULT_FAILED_REQUEST_STATUS_CODES, # type: Set[int] - ): - # type: (...) -> None + failed_request_status_codes: Set[int] = _DEFAULT_FAILED_REQUEST_STATUS_CODES, + ) -> None: if transaction_style not in TRANSACTION_STYLE_VALUES: raise ValueError( @@ -70,16 +70,16 @@ def __init__( self.failed_request_status_codes = failed_request_status_codes @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = parse_version(BOTTLE_VERSION) _check_minimum_version(BottleIntegration, version) old_app = Bottle.__call__ @ensure_integration_enabled(BottleIntegration, old_app) - def sentry_patched_wsgi_app(self, environ, start_response): - # type: (Any, Dict[str, str], Callable[..., Any]) -> _ScopedResponse + def sentry_patched_wsgi_app( + self: Any, environ: Dict[str, str], start_response: Callable[..., Any] + ) -> _ScopedResponse: middleware = SentryWsgiMiddleware( lambda *a, **kw: old_app(self, *a, **kw), span_origin=BottleIntegration.origin, @@ -92,8 +92,7 @@ def sentry_patched_wsgi_app(self, environ, start_response): old_handle = Bottle._handle @functools.wraps(old_handle) - def _patched_handle(self, environ): - # type: (Bottle, Dict[str, Any]) -> Any + def _patched_handle(self: Bottle, environ: Dict[str, Any]) -> Any: integration = sentry_sdk.get_client().get_integration(BottleIntegration) if integration is None: return old_handle(self, environ) @@ -112,16 +111,14 @@ def _patched_handle(self, environ): old_make_callback = Route._make_callback @functools.wraps(old_make_callback) - def patched_make_callback(self, *args, **kwargs): - # type: (Route, *object, **object) -> Any + def patched_make_callback(self: Route, *args: object, **kwargs: object) -> Any: prepared_callback = old_make_callback(self, *args, **kwargs) integration = sentry_sdk.get_client().get_integration(BottleIntegration) if integration is None: return prepared_callback - def wrapped_callback(*args, **kwargs): - # type: (*object, **object) -> Any + def wrapped_callback(*args: object, **kwargs: object) -> Any: try: res = prepared_callback(*args, **kwargs) except Exception as exception: @@ -142,38 +139,33 @@ def wrapped_callback(*args, **kwargs): class BottleRequestExtractor(RequestExtractor): - def env(self): - # type: () -> Dict[str, str] + def env(self) -> Dict[str, str]: return self.request.environ - def cookies(self): - # type: () -> Dict[str, str] + def cookies(self) -> Dict[str, str]: return self.request.cookies - def raw_data(self): - # type: () -> bytes + def raw_data(self) -> bytes: return self.request.body.read() - def form(self): - # type: () -> FormsDict + def form(self) -> FormsDict: if self.is_json(): return None return self.request.forms.decode() - def files(self): - # type: () -> Optional[Dict[str, str]] + def files(self) -> Optional[Dict[str, str]]: if self.is_json(): return None return self.request.files - def size_of_file(self, file): - # type: (FileUpload) -> int + def size_of_file(self, file: FileUpload) -> int: return file.content_length -def _set_transaction_name_and_source(event, 
transaction_style, request): - # type: (Event, str, Any) -> None +def _set_transaction_name_and_source( + event: Event, transaction_style: str, request: Any +) -> None: name = "" if transaction_style == "url": @@ -196,11 +188,11 @@ def _set_transaction_name_and_source(event, transaction_style, request): event["transaction_info"] = {"source": SOURCE_FOR_STYLE[transaction_style]} -def _make_request_event_processor(app, request, integration): - # type: (Bottle, LocalRequest, BottleIntegration) -> EventProcessor +def _make_request_event_processor( + app: Bottle, request: LocalRequest, integration: BottleIntegration +) -> EventProcessor: - def event_processor(event, hint): - # type: (Event, dict[str, Any]) -> Event + def event_processor(event: Event, hint: dict[str, Any]) -> Event: _set_transaction_name_and_source(event, integration.transaction_style, request) with capture_internal_exceptions(): @@ -211,8 +203,7 @@ def event_processor(event, hint): return event_processor -def _capture_exception(exception, handled): - # type: (BaseException, bool) -> None +def _capture_exception(exception: BaseException, handled: bool) -> None: event, hint = event_from_exception( exception, client_options=sentry_sdk.get_client().options, diff --git a/sentry_sdk/integrations/celery/__init__.py b/sentry_sdk/integrations/celery/__init__.py index f078f629da..de9bb45422 100644 --- a/sentry_sdk/integrations/celery/__init__.py +++ b/sentry_sdk/integrations/celery/__init__.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys from collections.abc import Mapping from functools import wraps @@ -62,11 +63,10 @@ class CeleryIntegration(Integration): def __init__( self, - propagate_traces=True, - monitor_beat_tasks=False, - exclude_beat_tasks=None, - ): - # type: (bool, bool, Optional[List[str]]) -> None + propagate_traces: bool = True, + monitor_beat_tasks: bool = False, + exclude_beat_tasks: Optional[List[str]] = None, + ) -> None: self.propagate_traces = propagate_traces self.monitor_beat_tasks = monitor_beat_tasks self.exclude_beat_tasks = exclude_beat_tasks @@ -76,8 +76,7 @@ def __init__( _setup_celery_beat_signals(monitor_beat_tasks) @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: _check_minimum_version(CeleryIntegration, CELERY_VERSION) _patch_build_tracer() @@ -97,16 +96,14 @@ def setup_once(): ignore_logger("celery.redirected") -def _set_status(status): - # type: (str) -> None +def _set_status(status: str) -> None: with capture_internal_exceptions(): span = sentry_sdk.get_current_span() if span is not None: span.set_status(status) -def _capture_exception(task, exc_info): - # type: (Any, ExcInfo) -> None +def _capture_exception(task: Any, exc_info: ExcInfo) -> None: client = sentry_sdk.get_client() if client.get_integration(CeleryIntegration) is None: return @@ -129,10 +126,10 @@ def _capture_exception(task, exc_info): sentry_sdk.capture_event(event, hint=hint) -def _make_event_processor(task, uuid, args, kwargs, request=None): - # type: (Any, Any, Any, Any, Optional[Any]) -> EventProcessor - def event_processor(event, hint): - # type: (Event, Hint) -> Optional[Event] +def _make_event_processor( + task: Any, uuid: Any, args: Any, kwargs: Any, request: Optional[Any] = None +) -> EventProcessor: + def event_processor(event: Event, hint: Hint) -> Optional[Event]: with capture_internal_exceptions(): tags = event.setdefault("tags", {}) @@ -158,8 +155,9 @@ def event_processor(event, hint): return event_processor -def _update_celery_task_headers(original_headers, span, 
monitor_beat_tasks): - # type: (dict[str, Any], Optional[Span], bool) -> dict[str, Any] +def _update_celery_task_headers( + original_headers: dict[str, Any], span: Optional[Span], monitor_beat_tasks: bool +) -> dict[str, Any]: """ Updates the headers of the Celery task with the tracing information and eventually Sentry Crons monitoring information for beat tasks. @@ -233,20 +231,16 @@ def _update_celery_task_headers(original_headers, span, monitor_beat_tasks): class NoOpMgr: - def __enter__(self): - # type: () -> None + def __enter__(self) -> None: return None - def __exit__(self, exc_type, exc_value, traceback): - # type: (Any, Any, Any) -> None + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: return None -def _wrap_task_run(f): - # type: (F) -> F +def _wrap_task_run(f: F) -> F: @wraps(f) - def apply_async(*args, **kwargs): - # type: (*Any, **Any) -> Any + def apply_async(*args: Any, **kwargs: Any) -> Any: # Note: kwargs can contain headers=None, so no setdefault! # Unsure which backend though. integration = sentry_sdk.get_client().get_integration(CeleryIntegration) @@ -262,7 +256,7 @@ def apply_async(*args, **kwargs): return f(*args, **kwargs) if isinstance(args[0], Task): - task_name = args[0].name # type: str + task_name: str = args[0].name elif len(args) > 1 and isinstance(args[1], str): task_name = args[1] else: @@ -270,7 +264,7 @@ def apply_async(*args, **kwargs): task_started_from_beat = sentry_sdk.get_isolation_scope()._name == "celery-beat" - span_mgr = ( + span_mgr: Union[Span, NoOpMgr] = ( sentry_sdk.start_span( op=OP.QUEUE_SUBMIT_CELERY, name=task_name, @@ -279,7 +273,7 @@ def apply_async(*args, **kwargs): ) if not task_started_from_beat else NoOpMgr() - ) # type: Union[Span, NoOpMgr] + ) with span_mgr as span: kwargs["headers"] = _update_celery_task_headers( @@ -290,8 +284,7 @@ def apply_async(*args, **kwargs): return apply_async # type: ignore -def _wrap_tracer(task, f): - # type: (Any, F) -> F +def _wrap_tracer(task: Any, f: F) -> F: # Need to wrap tracer for pushing the scope before prerun is sent, and # popping it after postrun is sent. @@ -301,8 +294,7 @@ def _wrap_tracer(task, f): # crashes. @wraps(f) @ensure_integration_enabled(CeleryIntegration, f) - def _inner(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _inner(*args: Any, **kwargs: Any) -> Any: with isolation_scope() as scope: scope._name = "celery" scope.clear_breadcrumbs() @@ -333,8 +325,7 @@ def _inner(*args, **kwargs): return _inner # type: ignore -def _set_messaging_destination_name(task, span): - # type: (Any, Span) -> None +def _set_messaging_destination_name(task: Any, span: Span) -> None: """Set "messaging.destination.name" tag for span""" with capture_internal_exceptions(): delivery_info = task.request.delivery_info @@ -346,8 +337,7 @@ def _set_messaging_destination_name(task, span): span.set_attribute(SPANDATA.MESSAGING_DESTINATION_NAME, routing_key) -def _wrap_task_call(task, f): - # type: (Any, F) -> F +def _wrap_task_call(task: Any, f: F) -> F: # Need to wrap task call because the exception is caught before we get to # see it. Also celery's reported stacktrace is untrustworthy. @@ -358,8 +348,7 @@ def _wrap_task_call(task, f): # to add @functools.wraps(f) here. 
# https://github.com/getsentry/sentry-python/issues/421 @ensure_integration_enabled(CeleryIntegration, f) - def _inner(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _inner(*args: Any, **kwargs: Any) -> Any: try: with sentry_sdk.start_span( op=OP.QUEUE_PROCESS, @@ -409,14 +398,12 @@ def _inner(*args, **kwargs): return _inner # type: ignore -def _patch_build_tracer(): - # type: () -> None +def _patch_build_tracer() -> None: import celery.app.trace as trace # type: ignore original_build_tracer = trace.build_tracer - def sentry_build_tracer(name, task, *args, **kwargs): - # type: (Any, Any, *Any, **Any) -> Any + def sentry_build_tracer(name: Any, task: Any, *args: Any, **kwargs: Any) -> Any: if not getattr(task, "_sentry_is_patched", False): # determine whether Celery will use __call__ or run and patch # accordingly @@ -435,20 +422,17 @@ def sentry_build_tracer(name, task, *args, **kwargs): trace.build_tracer = sentry_build_tracer -def _patch_task_apply_async(): - # type: () -> None +def _patch_task_apply_async() -> None: Task.apply_async = _wrap_task_run(Task.apply_async) -def _patch_celery_send_task(): - # type: () -> None +def _patch_celery_send_task() -> None: from celery import Celery Celery.send_task = _wrap_task_run(Celery.send_task) -def _patch_worker_exit(): - # type: () -> None +def _patch_worker_exit() -> None: # Need to flush queue before worker shutdown because a crashing worker will # call os._exit @@ -456,8 +440,7 @@ def _patch_worker_exit(): original_workloop = Worker.workloop - def sentry_workloop(*args, **kwargs): - # type: (*Any, **Any) -> Any + def sentry_workloop(*args: Any, **kwargs: Any) -> Any: try: return original_workloop(*args, **kwargs) finally: @@ -471,13 +454,11 @@ def sentry_workloop(*args, **kwargs): Worker.workloop = sentry_workloop -def _patch_producer_publish(): - # type: () -> None +def _patch_producer_publish() -> None: original_publish = Producer.publish @ensure_integration_enabled(CeleryIntegration, original_publish) - def sentry_publish(self, *args, **kwargs): - # type: (Producer, *Any, **Any) -> Any + def sentry_publish(self: Producer, *args: Any, **kwargs: Any) -> Any: kwargs_headers = kwargs.get("headers", {}) if not isinstance(kwargs_headers, Mapping): # Ensure kwargs_headers is a Mapping, so we can safely call get(). 
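Editor's aside on the header handling above (not part of the patch): both apply_async and sentry_publish normalize the Celery headers by hand because, as the inline comments note, callers can pass headers=None explicitly and non-mapping header objects exist, so dict.setdefault would silently keep the None. A minimal illustrative sketch, with invented values:

kwargs = {"headers": None}  # explicit None, the case the apply_async comment warns about

kwargs.setdefault("headers", {})       # no-op: the key exists, so the value stays None
assert kwargs["headers"] is None

headers = kwargs.get("headers") or {}  # normalize to a real dict before updating
headers["sentry-trace"] = "<placeholder trace value>"
kwargs["headers"] = headers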
@@ -521,8 +502,7 @@ def sentry_publish(self, *args, **kwargs): Producer.publish = sentry_publish -def _prepopulate_attributes(task, args, kwargs): - # type: (Any, *Any, **Any) -> dict[str, str] +def _prepopulate_attributes(task: Any, args: Any, kwargs: Any) -> dict[str, str]: attributes = { "celery.job.task": task.name, } diff --git a/sentry_sdk/integrations/celery/beat.py b/sentry_sdk/integrations/celery/beat.py index 4b7e45e6f0..b0c28f7bc8 100644 --- a/sentry_sdk/integrations/celery/beat.py +++ b/sentry_sdk/integrations/celery/beat.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.crons import capture_checkin, MonitorStatus from sentry_sdk.integrations import DidNotEnable @@ -42,8 +43,7 @@ RedBeatScheduler = None -def _get_headers(task): - # type: (Task) -> dict[str, Any] +def _get_headers(task: Task) -> dict[str, Any]: headers = task.request.get("headers") or {} # flatten nested headers @@ -56,12 +56,13 @@ def _get_headers(task): return headers -def _get_monitor_config(celery_schedule, app, monitor_name): - # type: (Any, Celery, str) -> MonitorConfig - monitor_config = {} # type: MonitorConfig - schedule_type = None # type: Optional[MonitorConfigScheduleType] - schedule_value = None # type: Optional[Union[str, int]] - schedule_unit = None # type: Optional[MonitorConfigScheduleUnit] +def _get_monitor_config( + celery_schedule: Any, app: Celery, monitor_name: str +) -> MonitorConfig: + monitor_config: MonitorConfig = {} + schedule_type: Optional[MonitorConfigScheduleType] = None + schedule_value: Optional[Union[str, int]] = None + schedule_unit: Optional[MonitorConfigScheduleUnit] = None if isinstance(celery_schedule, crontab): schedule_type = "crontab" @@ -113,8 +114,11 @@ def _get_monitor_config(celery_schedule, app, monitor_name): return monitor_config -def _apply_crons_data_to_schedule_entry(scheduler, schedule_entry, integration): - # type: (Any, Any, sentry_sdk.integrations.celery.CeleryIntegration) -> None +def _apply_crons_data_to_schedule_entry( + scheduler: Any, + schedule_entry: Any, + integration: sentry_sdk.integrations.celery.CeleryIntegration, +) -> None: """ Add Sentry Crons information to the schedule_entry headers. 
""" @@ -158,8 +162,7 @@ def _apply_crons_data_to_schedule_entry(scheduler, schedule_entry, integration): schedule_entry.options["headers"] = headers -def _wrap_beat_scheduler(original_function): - # type: (Callable[..., Any]) -> Callable[..., Any] +def _wrap_beat_scheduler(original_function: Callable[..., Any]) -> Callable[..., Any]: """ Makes sure that: - a new Sentry trace is started for each task started by Celery Beat and @@ -178,8 +181,7 @@ def _wrap_beat_scheduler(original_function): from sentry_sdk.integrations.celery import CeleryIntegration - def sentry_patched_scheduler(*args, **kwargs): - # type: (*Any, **Any) -> None + def sentry_patched_scheduler(*args: Any, **kwargs: Any) -> None: integration = sentry_sdk.get_client().get_integration(CeleryIntegration) if integration is None: return original_function(*args, **kwargs) @@ -197,29 +199,25 @@ def sentry_patched_scheduler(*args, **kwargs): return sentry_patched_scheduler -def _patch_beat_apply_entry(): - # type: () -> None +def _patch_beat_apply_entry() -> None: Scheduler.apply_entry = _wrap_beat_scheduler(Scheduler.apply_entry) -def _patch_redbeat_apply_async(): - # type: () -> None +def _patch_redbeat_apply_async() -> None: if RedBeatScheduler is None: return RedBeatScheduler.apply_async = _wrap_beat_scheduler(RedBeatScheduler.apply_async) -def _setup_celery_beat_signals(monitor_beat_tasks): - # type: (bool) -> None +def _setup_celery_beat_signals(monitor_beat_tasks: bool) -> None: if monitor_beat_tasks: task_success.connect(crons_task_success) task_failure.connect(crons_task_failure) task_retry.connect(crons_task_retry) -def crons_task_success(sender, **kwargs): - # type: (Task, dict[Any, Any]) -> None +def crons_task_success(sender: Task, **kwargs: dict[Any, Any]) -> None: logger.debug("celery_task_success %s", sender) headers = _get_headers(sender) @@ -243,8 +241,7 @@ def crons_task_success(sender, **kwargs): ) -def crons_task_failure(sender, **kwargs): - # type: (Task, dict[Any, Any]) -> None +def crons_task_failure(sender: Task, **kwargs: dict[Any, Any]) -> None: logger.debug("celery_task_failure %s", sender) headers = _get_headers(sender) @@ -268,8 +265,7 @@ def crons_task_failure(sender, **kwargs): ) -def crons_task_retry(sender, **kwargs): - # type: (Task, dict[Any, Any]) -> None +def crons_task_retry(sender: Task, **kwargs: dict[Any, Any]) -> None: logger.debug("celery_task_retry %s", sender) headers = _get_headers(sender) diff --git a/sentry_sdk/integrations/celery/utils.py b/sentry_sdk/integrations/celery/utils.py index a1961b15bc..eb96cb9016 100644 --- a/sentry_sdk/integrations/celery/utils.py +++ b/sentry_sdk/integrations/celery/utils.py @@ -1,13 +1,20 @@ +from __future__ import annotations import time -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Tuple + from typing import Any, Tuple, List from sentry_sdk._types import MonitorConfigScheduleUnit -def _now_seconds_since_epoch(): - # type: () -> float +TIME_UNITS: List[Tuple[MonitorConfigScheduleUnit, float]] = [ + ("day", 60 * 60 * 24.0), + ("hour", 60 * 60.0), + ("minute", 60.0), +] + + +def _now_seconds_since_epoch() -> float: # We cannot use `time.perf_counter()` when dealing with the duration # of a Celery task, because the start of a Celery task and # the end are recorded in different processes. 
@@ -16,28 +23,19 @@ def _now_seconds_since_epoch(): return time.time() -def _get_humanized_interval(seconds): - # type: (float) -> Tuple[int, MonitorConfigScheduleUnit] - TIME_UNITS = ( # noqa: N806 - ("day", 60 * 60 * 24.0), - ("hour", 60 * 60.0), - ("minute", 60.0), - ) - +def _get_humanized_interval(seconds: float) -> Tuple[int, MonitorConfigScheduleUnit]: seconds = float(seconds) for unit, divider in TIME_UNITS: if seconds >= divider: interval = int(seconds / divider) - return (interval, cast("MonitorConfigScheduleUnit", unit)) + return (interval, unit) return (int(seconds), "second") class NoOpMgr: - def __enter__(self): - # type: () -> None + def __enter__(self) -> None: return None - def __exit__(self, exc_type, exc_value, traceback): - # type: (Any, Any, Any) -> None + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: return None diff --git a/sentry_sdk/integrations/chalice.py b/sentry_sdk/integrations/chalice.py index 947e41ebf7..8a4e95ba00 100644 --- a/sentry_sdk/integrations/chalice.py +++ b/sentry_sdk/integrations/chalice.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys from functools import wraps @@ -32,8 +33,7 @@ class EventSourceHandler(ChaliceEventSourceHandler): # type: ignore - def __call__(self, event, context): - # type: (Any, Any) -> Any + def __call__(self, event: Any, context: Any) -> Any: client = sentry_sdk.get_client() with sentry_sdk.isolation_scope() as scope: @@ -56,11 +56,9 @@ def __call__(self, event, context): reraise(*exc_info) -def _get_view_function_response(app, view_function, function_args): - # type: (Any, F, Any) -> F +def _get_view_function_response(app: Any, view_function: F, function_args: Any) -> F: @wraps(view_function) - def wrapped_view_function(**function_args): - # type: (**Any) -> Any + def wrapped_view_function(**function_args: Any) -> Any: client = sentry_sdk.get_client() with sentry_sdk.isolation_scope() as scope: with capture_internal_exceptions(): @@ -99,8 +97,7 @@ class ChaliceIntegration(Integration): identifier = "chalice" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = parse_version(CHALICE_VERSION) @@ -116,8 +113,9 @@ def setup_once(): RestAPIEventHandler._get_view_function_response ) - def sentry_event_response(app, view_function, function_args): - # type: (Any, F, Dict[str, Any]) -> Any + def sentry_event_response( + app: Any, view_function: F, function_args: Dict[str, Any] + ) -> Any: wrapped_view_function = _get_view_function_response( app, view_function, function_args ) diff --git a/sentry_sdk/integrations/clickhouse_driver.py b/sentry_sdk/integrations/clickhouse_driver.py index 7c908b7d6d..5d89eb0e76 100644 --- a/sentry_sdk/integrations/clickhouse_driver.py +++ b/sentry_sdk/integrations/clickhouse_driver.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.consts import OP, SPANDATA from sentry_sdk.integrations import _check_minimum_version, Integration, DidNotEnable @@ -9,27 +10,13 @@ ensure_integration_enabled, ) -from typing import TYPE_CHECKING, cast, Any, Dict, TypeVar +from typing import TYPE_CHECKING -# Hack to get new Python features working in older versions -# without introducing a hard dependency on `typing_extensions` -# from: https://stackoverflow.com/a/71944042/300572 if TYPE_CHECKING: - from typing import ParamSpec, Callable -else: - # Fake ParamSpec - class ParamSpec: - def __init__(self, _): - self.args = None - self.kwargs = None + from typing import ParamSpec, Callable, Any, Dict, 
TypeVar - # Callable[anything] will return None - class _Callable: - def __getitem__(self, _): - return None - - # Make instances - Callable = _Callable() + P = ParamSpec("P") + T = TypeVar("T") try: @@ -72,10 +59,6 @@ def setup_once() -> None: ) -P = ParamSpec("P") -T = TypeVar("T") - - def _wrap_start(f: Callable[P, T]) -> Callable[P, T]: @ensure_integration_enabled(ClickhouseDriverIntegration, f) def _inner(*args: P.args, **kwargs: P.kwargs) -> T: @@ -93,8 +76,7 @@ def _inner(*args: P.args, **kwargs: P.kwargs) -> T: connection._sentry_span = span # type: ignore[attr-defined] - data = _get_db_data(connection) - data = cast("dict[str, Any]", data) + data: dict[str, Any] = _get_db_data(connection) data["db.query.text"] = query if query_id: @@ -117,7 +99,11 @@ def _inner(*args: P.args, **kwargs: P.kwargs) -> T: def _wrap_end(f: Callable[P, T]) -> Callable[P, T]: def _inner_end(*args: P.args, **kwargs: P.kwargs) -> T: res = f(*args, **kwargs) - client = cast("clickhouse_driver.client.Client", args[0]) + + client = args[0] + if not isinstance(client, clickhouse_driver.client.Client): + return res + connection = client.connection span = getattr(connection, "_sentry_span", None) @@ -150,9 +136,11 @@ def _inner_end(*args: P.args, **kwargs: P.kwargs) -> T: def _wrap_send_data(f: Callable[P, T]) -> Callable[P, T]: def _inner_send_data(*args: P.args, **kwargs: P.kwargs) -> T: - client = cast("clickhouse_driver.client.Client", args[0]) + client = args[0] + if not isinstance(client, clickhouse_driver.client.Client): + return f(*args, **kwargs) + connection = client.connection - db_params_data = cast("list[Any]", args[2]) span = getattr(connection, "_sentry_span", None) if span is not None: @@ -160,11 +148,13 @@ def _inner_send_data(*args: P.args, **kwargs: P.kwargs) -> T: _set_on_span(span, data) if should_send_default_pii(): - saved_db_data = getattr( + saved_db_data: dict[str, Any] = getattr( connection, "_sentry_db_data", {} - ) # type: dict[str, Any] - db_params = saved_db_data.get("db.params") or [] # type: list[Any] - db_params.extend(db_params_data) + ) + db_params: list[Any] = saved_db_data.get("db.params") or [] + db_params_data = args[2] + if isinstance(db_params_data, list): + db_params.extend(db_params_data) saved_db_data["db.params"] = db_params span.set_attribute("db.params", _serialize_span_attribute(db_params)) diff --git a/sentry_sdk/integrations/cloud_resource_context.py b/sentry_sdk/integrations/cloud_resource_context.py index ca5ae47e6b..607899a5a7 100644 --- a/sentry_sdk/integrations/cloud_resource_context.py +++ b/sentry_sdk/integrations/cloud_resource_context.py @@ -1,3 +1,4 @@ +from __future__ import annotations import json import urllib3 @@ -65,13 +66,11 @@ class CloudResourceContextIntegration(Integration): gcp_metadata = None - def __init__(self, cloud_provider=""): - # type: (str) -> None + def __init__(self, cloud_provider: str = "") -> None: CloudResourceContextIntegration.cloud_provider = cloud_provider @classmethod - def _is_aws(cls): - # type: () -> bool + def _is_aws(cls) -> bool: try: r = cls.http.request( "PUT", @@ -95,8 +94,7 @@ def _is_aws(cls): return False @classmethod - def _get_aws_context(cls): - # type: () -> Dict[str, str] + def _get_aws_context(cls) -> Dict[str, str]: ctx = { "cloud.provider": CLOUD_PROVIDER.AWS, "cloud.platform": CLOUD_PLATFORM.AWS_EC2, @@ -149,8 +147,7 @@ def _get_aws_context(cls): return ctx @classmethod - def _is_gcp(cls): - # type: () -> bool + def _is_gcp(cls) -> bool: try: r = cls.http.request( "GET", @@ -174,8 +171,7 @@ def 
_is_gcp(cls): return False @classmethod - def _get_gcp_context(cls): - # type: () -> Dict[str, str] + def _get_gcp_context(cls) -> Dict[str, str]: ctx = { "cloud.provider": CLOUD_PROVIDER.GCP, "cloud.platform": CLOUD_PLATFORM.GCP_COMPUTE_ENGINE, @@ -229,8 +225,7 @@ def _get_gcp_context(cls): return ctx @classmethod - def _get_cloud_provider(cls): - # type: () -> str + def _get_cloud_provider(cls) -> str: if cls._is_aws(): return CLOUD_PROVIDER.AWS @@ -240,8 +235,7 @@ def _get_cloud_provider(cls): return "" @classmethod - def _get_cloud_resource_context(cls): - # type: () -> Dict[str, str] + def _get_cloud_resource_context(cls) -> Dict[str, str]: cloud_provider = ( cls.cloud_provider if cls.cloud_provider != "" @@ -253,8 +247,7 @@ def _get_cloud_resource_context(cls): return {} @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: cloud_provider = CloudResourceContextIntegration.cloud_provider unsupported_cloud_provider = ( cloud_provider != "" and cloud_provider not in context_getters.keys() diff --git a/sentry_sdk/integrations/cohere.py b/sentry_sdk/integrations/cohere.py index 6fd9fc8150..2eb22c76e8 100644 --- a/sentry_sdk/integrations/cohere.py +++ b/sentry_sdk/integrations/cohere.py @@ -1,3 +1,4 @@ +from __future__ import annotations from functools import wraps from sentry_sdk import consts @@ -70,20 +71,17 @@ class CohereIntegration(Integration): identifier = "cohere" origin = f"auto.ai.{identifier}" - def __init__(self, include_prompts=True): - # type: (CohereIntegration, bool) -> None + def __init__(self: CohereIntegration, include_prompts: bool = True) -> None: self.include_prompts = include_prompts @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False) Client.embed = _wrap_embed(Client.embed) BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True) -def _capture_exception(exc): - # type: (Any) -> None +def _capture_exception(exc: Any) -> None: event, hint = event_from_exception( exc, client_options=sentry_sdk.get_client().options, @@ -92,11 +90,11 @@ def _capture_exception(exc): sentry_sdk.capture_event(event, hint=hint) -def _wrap_chat(f, streaming): - # type: (Callable[..., Any], bool) -> Callable[..., Any] +def _wrap_chat(f: Callable[..., Any], streaming: bool) -> Callable[..., Any]: - def collect_chat_response_fields(span, res, include_pii): - # type: (Span, NonStreamedChatResponse, bool) -> None + def collect_chat_response_fields( + span: Span, res: NonStreamedChatResponse, include_pii: bool + ) -> None: if include_pii: if hasattr(res, "text"): set_data_normalized( @@ -130,8 +128,7 @@ def collect_chat_response_fields(span, res, include_pii): set_data_normalized(span, SPANDATA.AI_WARNINGS, res.meta.warnings) @wraps(f) - def new_chat(*args, **kwargs): - # type: (*Any, **Any) -> Any + def new_chat(*args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(CohereIntegration) if ( @@ -185,8 +182,7 @@ def new_chat(*args, **kwargs): if streaming: old_iterator = res - def new_iterator(): - # type: () -> Iterator[StreamedChatResponse] + def new_iterator() -> Iterator[StreamedChatResponse]: with capture_internal_exceptions(): for x in old_iterator: @@ -220,12 +216,10 @@ def new_iterator(): return new_chat -def _wrap_embed(f): - # type: (Callable[..., Any]) -> Callable[..., Any] +def _wrap_embed(f: Callable[..., Any]) -> Callable[..., Any]: @wraps(f) - def new_embed(*args, **kwargs): - # type: (*Any, **Any) -> Any + def 
new_embed(*args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(CohereIntegration) if integration is None: return f(*args, **kwargs) diff --git a/sentry_sdk/integrations/dedupe.py b/sentry_sdk/integrations/dedupe.py index a115e35292..2434b531cb 100644 --- a/sentry_sdk/integrations/dedupe.py +++ b/sentry_sdk/integrations/dedupe.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.utils import ContextVar from sentry_sdk.integrations import Integration @@ -14,16 +15,13 @@ class DedupeIntegration(Integration): identifier = "dedupe" - def __init__(self): - # type: () -> None + def __init__(self) -> None: self._last_seen = ContextVar("last-seen") @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: @add_global_event_processor - def processor(event, hint): - # type: (Event, Optional[Hint]) -> Optional[Event] + def processor(event: Event, hint: Optional[Hint]) -> Optional[Event]: if hint is None: return event @@ -42,8 +40,7 @@ def processor(event, hint): return event @staticmethod - def reset_last_seen(): - # type: () -> None + def reset_last_seen() -> None: integration = sentry_sdk.get_client().get_integration(DedupeIntegration) if integration is None: return diff --git a/sentry_sdk/integrations/django/__init__.py b/sentry_sdk/integrations/django/__init__.py index e62ba63f70..14b42001c9 100644 --- a/sentry_sdk/integrations/django/__init__.py +++ b/sentry_sdk/integrations/django/__init__.py @@ -1,3 +1,4 @@ +from __future__ import annotations import functools import inspect import sys @@ -107,18 +108,17 @@ class DjangoIntegration(Integration): middleware_spans = None signals_spans = None cache_spans = None - signals_denylist = [] # type: list[signals.Signal] + signals_denylist: list[signals.Signal] = [] def __init__( self, - transaction_style="url", # type: str - middleware_spans=True, # type: bool - signals_spans=True, # type: bool - cache_spans=True, # type: bool - signals_denylist=None, # type: Optional[list[signals.Signal]] - http_methods_to_capture=DEFAULT_HTTP_METHODS_TO_CAPTURE, # type: tuple[str, ...] - ): - # type: (...) -> None + transaction_style: str = "url", + middleware_spans: bool = True, + signals_spans: bool = True, + cache_spans: bool = True, + signals_denylist: Optional[list[signals.Signal]] = None, + http_methods_to_capture: tuple[str, ...] 
= DEFAULT_HTTP_METHODS_TO_CAPTURE, + ) -> None: if transaction_style not in TRANSACTION_STYLE_VALUES: raise ValueError( "Invalid value for transaction_style: %s (must be in %s)" @@ -135,8 +135,7 @@ def __init__( self.http_methods_to_capture = tuple(map(str.upper, http_methods_to_capture)) @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: _check_minimum_version(DjangoIntegration, DJANGO_VERSION) install_sql_hook() @@ -151,8 +150,9 @@ def setup_once(): old_app = WSGIHandler.__call__ @ensure_integration_enabled(DjangoIntegration, old_app) - def sentry_patched_wsgi_handler(self, environ, start_response): - # type: (Any, Dict[str, str], Callable[..., Any]) -> _ScopedResponse + def sentry_patched_wsgi_handler( + self: Any, environ: Dict[str, str], start_response: Callable[..., Any] + ) -> _ScopedResponse: bound_old_app = old_app.__get__(self, WSGIHandler) from django.conf import settings @@ -182,8 +182,9 @@ def sentry_patched_wsgi_handler(self, environ, start_response): signals.got_request_exception.connect(_got_request_exception) @add_global_event_processor - def process_django_templates(event, hint): - # type: (Event, Optional[Hint]) -> Optional[Event] + def process_django_templates( + event: Event, hint: Optional[Hint] + ) -> Optional[Event]: if hint is None: return event @@ -225,8 +226,9 @@ def process_django_templates(event, hint): return event @add_global_repr_processor - def _django_queryset_repr(value, hint): - # type: (Any, Dict[str, Any]) -> Union[NotImplementedType, str] + def _django_queryset_repr( + value: Any, hint: Dict[str, Any] + ) -> Union[NotImplementedType, str]: try: # Django 1.6 can fail to import `QuerySet` when Django settings # have not yet been initialized. @@ -261,8 +263,7 @@ def _django_queryset_repr(value, hint): _DRF_PATCH_LOCK = threading.Lock() -def _patch_drf(): - # type: () -> None +def _patch_drf() -> None: """ Patch Django Rest Framework for more/better request data. DRF's request type is a wrapper around Django's request type. 
The attribute we're @@ -305,8 +306,9 @@ def _patch_drf(): old_drf_initial = APIView.initial @functools.wraps(old_drf_initial) - def sentry_patched_drf_initial(self, request, *args, **kwargs): - # type: (APIView, Any, *Any, **Any) -> Any + def sentry_patched_drf_initial( + self: APIView, request: Any, *args: Any, **kwargs: Any + ) -> Any: with capture_internal_exceptions(): request._request._sentry_drf_request_backref = weakref.ref( request @@ -317,8 +319,7 @@ def sentry_patched_drf_initial(self, request, *args, **kwargs): APIView.initial = sentry_patched_drf_initial -def _patch_channels(): - # type: () -> None +def _patch_channels() -> None: try: from channels.http import AsgiHandler # type: ignore except ImportError: @@ -342,8 +343,7 @@ def _patch_channels(): patch_channels_asgi_handler_impl(AsgiHandler) -def _patch_django_asgi_handler(): - # type: () -> None +def _patch_django_asgi_handler() -> None: try: from django.core.handlers.asgi import ASGIHandler except ImportError: @@ -364,8 +364,9 @@ def _patch_django_asgi_handler(): patch_django_asgi_handler_impl(ASGIHandler) -def _set_transaction_name_and_source(scope, transaction_style, request): - # type: (sentry_sdk.Scope, str, WSGIRequest) -> None +def _set_transaction_name_and_source( + scope: sentry_sdk.Scope, transaction_style: str, request: WSGIRequest +) -> None: try: transaction_name = None if transaction_style == "function_name": @@ -408,8 +409,7 @@ def _set_transaction_name_and_source(scope, transaction_style, request): pass -def _before_get_response(request): - # type: (WSGIRequest) -> None +def _before_get_response(request: WSGIRequest) -> None: integration = sentry_sdk.get_client().get_integration(DjangoIntegration) if integration is None: return @@ -425,8 +425,9 @@ def _before_get_response(request): ) -def _attempt_resolve_again(request, scope, transaction_style): - # type: (WSGIRequest, sentry_sdk.Scope, str) -> None +def _attempt_resolve_again( + request: WSGIRequest, scope: sentry_sdk.Scope, transaction_style: str +) -> None: """ Some django middlewares overwrite request.urlconf so we need to respect that contract, @@ -438,8 +439,7 @@ def _attempt_resolve_again(request, scope, transaction_style): _set_transaction_name_and_source(scope, transaction_style, request) -def _after_get_response(request): - # type: (WSGIRequest) -> None +def _after_get_response(request: WSGIRequest) -> None: integration = sentry_sdk.get_client().get_integration(DjangoIntegration) if integration is None or integration.transaction_style != "url": return @@ -448,8 +448,7 @@ def _after_get_response(request): _attempt_resolve_again(request, scope, integration.transaction_style) -def _patch_get_response(): - # type: () -> None +def _patch_get_response() -> None: """ patch get_response, because at that point we have the Django request object """ @@ -458,8 +457,9 @@ def _patch_get_response(): old_get_response = BaseHandler.get_response @functools.wraps(old_get_response) - def sentry_patched_get_response(self, request): - # type: (Any, WSGIRequest) -> Union[HttpResponse, BaseException] + def sentry_patched_get_response( + self: Any, request: WSGIRequest + ) -> Union[HttpResponse, BaseException]: _before_get_response(request) rv = old_get_response(self, request) _after_get_response(request) @@ -473,10 +473,10 @@ def sentry_patched_get_response(self, request): patch_get_response_async(BaseHandler, _before_get_response) -def _make_wsgi_request_event_processor(weak_request, integration): - # type: (Callable[[], WSGIRequest], DjangoIntegration) -> EventProcessor 
- def wsgi_request_event_processor(event, hint): - # type: (Event, dict[str, Any]) -> Event +def _make_wsgi_request_event_processor( + weak_request: Callable[[], WSGIRequest], integration: DjangoIntegration +) -> EventProcessor: + def wsgi_request_event_processor(event: Event, hint: dict[str, Any]) -> Event: # if the request is gone we are fine not logging the data from # it. This might happen if the processor is pushed away to # another thread. @@ -501,8 +501,7 @@ def wsgi_request_event_processor(event, hint): return wsgi_request_event_processor -def _got_request_exception(request=None, **kwargs): - # type: (WSGIRequest, **Any) -> None +def _got_request_exception(request: WSGIRequest = None, **kwargs: Any) -> None: client = sentry_sdk.get_client() integration = client.get_integration(DjangoIntegration) if integration is None: @@ -521,8 +520,7 @@ def _got_request_exception(request=None, **kwargs): class DjangoRequestExtractor(RequestExtractor): - def __init__(self, request): - # type: (Union[WSGIRequest, ASGIRequest]) -> None + def __init__(self, request: Union[WSGIRequest, ASGIRequest]) -> None: try: drf_request = request._sentry_drf_request_backref() if drf_request is not None: @@ -531,18 +529,16 @@ def __init__(self, request): pass self.request = request - def env(self): - # type: () -> Dict[str, str] + def env(self) -> Dict[str, str]: return self.request.META - def cookies(self): - # type: () -> Dict[str, Union[str, AnnotatedValue]] + def cookies(self) -> Dict[str, Union[str, AnnotatedValue]]: privacy_cookies = [ django_settings.CSRF_COOKIE_NAME, django_settings.SESSION_COOKIE_NAME, ] - clean_cookies = {} # type: Dict[str, Union[str, AnnotatedValue]] + clean_cookies: Dict[str, Union[str, AnnotatedValue]] = {} for key, val in self.request.COOKIES.items(): if key in privacy_cookies: clean_cookies[key] = SENSITIVE_DATA_SUBSTITUTE @@ -551,32 +547,26 @@ def cookies(self): return clean_cookies - def raw_data(self): - # type: () -> bytes + def raw_data(self) -> bytes: return self.request.body - def form(self): - # type: () -> QueryDict + def form(self) -> QueryDict: return self.request.POST - def files(self): - # type: () -> MultiValueDict + def files(self) -> MultiValueDict: return self.request.FILES - def size_of_file(self, file): - # type: (Any) -> int + def size_of_file(self, file: Any) -> int: return file.size - def parsed_body(self): - # type: () -> Optional[Dict[str, Any]] + def parsed_body(self) -> Optional[Dict[str, Any]]: try: return self.request.data except Exception: return RequestExtractor.parsed_body(self) -def _set_user_info(request, event): - # type: (WSGIRequest, Event) -> None +def _set_user_info(request: WSGIRequest, event: Event) -> None: user_info = event.setdefault("user", {}) user = getattr(request, "user", None) @@ -600,8 +590,7 @@ def _set_user_info(request, event): pass -def install_sql_hook(): - # type: () -> None +def install_sql_hook() -> None: """If installed this causes Django's queries to be captured.""" try: from django.db.backends.utils import CursorWrapper @@ -615,8 +604,7 @@ def install_sql_hook(): real_connect = BaseDatabaseWrapper.connect @ensure_integration_enabled(DjangoIntegration, real_execute) - def execute(self, sql, params=None): - # type: (CursorWrapper, Any, Optional[Any]) -> Any + def execute(self: CursorWrapper, sql: Any, params: Optional[Any] = None) -> Any: with record_sql_queries( cursor=self.cursor, query=sql, @@ -634,8 +622,7 @@ def execute(self, sql, params=None): return result @ensure_integration_enabled(DjangoIntegration, 
real_executemany) - def executemany(self, sql, param_list): - # type: (CursorWrapper, Any, List[Any]) -> Any + def executemany(self: CursorWrapper, sql: Any, param_list: List[Any]) -> Any: with record_sql_queries( cursor=self.cursor, query=sql, @@ -654,8 +641,7 @@ def executemany(self, sql, param_list): return result @ensure_integration_enabled(DjangoIntegration, real_connect) - def connect(self): - # type: (BaseDatabaseWrapper) -> None + def connect(self: BaseDatabaseWrapper) -> None: with capture_internal_exceptions(): sentry_sdk.add_breadcrumb(message="connect", category="query") @@ -674,8 +660,7 @@ def connect(self): ignore_logger("django.db.backends") -def _set_db_data(span, cursor_or_db): - # type: (Span, Any) -> None +def _set_db_data(span: Span, cursor_or_db: Any) -> None: db = cursor_or_db.db if hasattr(cursor_or_db, "db") else cursor_or_db vendor = db.vendor span.set_attribute(SPANDATA.DB_SYSTEM, vendor) diff --git a/sentry_sdk/integrations/django/asgi.py b/sentry_sdk/integrations/django/asgi.py index d37503d16d..1435fefa04 100644 --- a/sentry_sdk/integrations/django/asgi.py +++ b/sentry_sdk/integrations/django/asgi.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Instrumentation for Django 3.0 @@ -51,10 +53,8 @@ def markcoroutinefunction(func: "_F") -> "_F": return func -def _make_asgi_request_event_processor(request): - # type: (ASGIRequest) -> EventProcessor - def asgi_request_event_processor(event, hint): - # type: (Event, dict[str, Any]) -> Event +def _make_asgi_request_event_processor(request: ASGIRequest) -> EventProcessor: + def asgi_request_event_processor(event: Event, hint: dict[str, Any]) -> Event: # if the request is gone we are fine not logging the data from # it. This might happen if the processor is pushed away to # another thread. 
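The django/asgi.py hunk above follows the same pattern as every file in this patch: from __future__ import annotations makes all annotations lazy strings (PEP 563), so typing-only names can live under TYPE_CHECKING with no runtime import. A minimal sketch of the idiom (the function name is hypothetical, not part of the SDK):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never imported at runtime.
    from django.core.handlers.asgi import ASGIRequest


def request_path(request: ASGIRequest) -> str:
    # The annotation is stored as the string "ASGIRequest", so this module
    # imports cleanly even though ASGIRequest is not available at runtime.
    return request.path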
@@ -81,16 +81,16 @@ def asgi_request_event_processor(event, hint): return asgi_request_event_processor -def patch_django_asgi_handler_impl(cls): - # type: (Any) -> None +def patch_django_asgi_handler_impl(cls: Any) -> None: from sentry_sdk.integrations.django import DjangoIntegration old_app = cls.__call__ @functools.wraps(old_app) - async def sentry_patched_asgi_handler(self, scope, receive, send): - # type: (Any, Any, Any, Any) -> Any + async def sentry_patched_asgi_handler( + self: Any, scope: Any, receive: Any, send: Any + ) -> Any: integration = sentry_sdk.get_client().get_integration(DjangoIntegration) if integration is None: return await old_app(self, scope, receive, send) @@ -111,8 +111,7 @@ async def sentry_patched_asgi_handler(self, scope, receive, send): old_create_request = cls.create_request @ensure_integration_enabled(DjangoIntegration, old_create_request) - def sentry_patched_create_request(self, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any + def sentry_patched_create_request(self: Any, *args: Any, **kwargs: Any) -> Any: request, error_response = old_create_request(self, *args, **kwargs) scope = sentry_sdk.get_isolation_scope() scope.add_event_processor(_make_asgi_request_event_processor(request)) @@ -122,21 +121,20 @@ def sentry_patched_create_request(self, *args, **kwargs): cls.create_request = sentry_patched_create_request -def patch_get_response_async(cls, _before_get_response): - # type: (Any, Any) -> None +def patch_get_response_async(cls: Any, _before_get_response: Any) -> None: old_get_response_async = cls.get_response_async @functools.wraps(old_get_response_async) - async def sentry_patched_get_response_async(self, request): - # type: (Any, Any) -> Union[HttpResponse, BaseException] + async def sentry_patched_get_response_async( + self: Any, request: Any + ) -> Union[HttpResponse, BaseException]: _before_get_response(request) return await old_get_response_async(self, request) cls.get_response_async = sentry_patched_get_response_async -def patch_channels_asgi_handler_impl(cls): - # type: (Any) -> None +def patch_channels_asgi_handler_impl(cls: Any) -> None: import channels # type: ignore from sentry_sdk.integrations.django import DjangoIntegration @@ -145,8 +143,9 @@ def patch_channels_asgi_handler_impl(cls): old_app = cls.__call__ @functools.wraps(old_app) - async def sentry_patched_asgi_handler(self, receive, send): - # type: (Any, Any, Any) -> Any + async def sentry_patched_asgi_handler( + self: Any, receive: Any, send: Any + ) -> Any: integration = sentry_sdk.get_client().get_integration(DjangoIntegration) if integration is None: return await old_app(self, receive, send) @@ -168,13 +167,11 @@ async def sentry_patched_asgi_handler(self, receive, send): patch_django_asgi_handler_impl(cls) -def wrap_async_view(callback): - # type: (Any) -> Any +def wrap_async_view(callback: Any) -> Any: from sentry_sdk.integrations.django import DjangoIntegration @functools.wraps(callback) - async def sentry_wrapped_callback(request, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any + async def sentry_wrapped_callback(request: Any, *args: Any, **kwargs: Any) -> Any: current_scope = sentry_sdk.get_current_scope() if current_scope.root_span is not None: current_scope.root_span.update_active_thread() @@ -194,8 +191,7 @@ async def sentry_wrapped_callback(request, *args, **kwargs): return sentry_wrapped_callback -def _asgi_middleware_mixin_factory(_check_middleware_span): - # type: (Callable[..., Any]) -> Any +def _asgi_middleware_mixin_factory(_check_middleware_span: 
Callable[..., Any]) -> Any: """ Mixin class factory that generates a middleware mixin for handling requests in async mode. @@ -205,14 +201,12 @@ class SentryASGIMixin: if TYPE_CHECKING: _inner = None - def __init__(self, get_response): - # type: (Callable[..., Any]) -> None + def __init__(self, get_response: Callable[..., Any]) -> None: self.get_response = get_response self._acall_method = None self._async_check() - def _async_check(self): - # type: () -> None + def _async_check(self) -> None: """ If get_response is a coroutine function, turns us into async mode so a thread is not consumed during a whole request. @@ -221,16 +215,14 @@ def _async_check(self): if iscoroutinefunction(self.get_response): markcoroutinefunction(self) - def async_route_check(self): - # type: () -> bool + def async_route_check(self) -> bool: """ Function that checks if we are in async mode, and if we are forwards the handling of requests to __acall__ """ return iscoroutinefunction(self.get_response) - async def __acall__(self, *args, **kwargs): - # type: (*Any, **Any) -> Any + async def __acall__(self, *args: Any, **kwargs: Any) -> Any: f = self._acall_method if f is None: if hasattr(self._inner, "__acall__"): diff --git a/sentry_sdk/integrations/django/caching.py b/sentry_sdk/integrations/django/caching.py index 65bf2674e1..021ed398f7 100644 --- a/sentry_sdk/integrations/django/caching.py +++ b/sentry_sdk/integrations/django/caching.py @@ -1,3 +1,4 @@ +from __future__ import annotations import functools from typing import TYPE_CHECKING from sentry_sdk.integrations.redis.utils import _get_safe_key, _key_as_string @@ -28,22 +29,29 @@ ] -def _get_span_description(method_name, args, kwargs): - # type: (str, tuple[Any], dict[str, Any]) -> str +def _get_span_description( + method_name: str, args: tuple[Any], kwargs: dict[str, Any] +) -> str: return _key_as_string(_get_safe_key(method_name, args, kwargs)) -def _patch_cache_method(cache, method_name, address, port): - # type: (CacheHandler, str, Optional[str], Optional[int]) -> None +def _patch_cache_method( + cache: CacheHandler, method_name: str, address: Optional[str], port: Optional[int] +) -> None: from sentry_sdk.integrations.django import DjangoIntegration original_method = getattr(cache, method_name) @ensure_integration_enabled(DjangoIntegration, original_method) def _instrument_call( - cache, method_name, original_method, args, kwargs, address, port - ): - # type: (CacheHandler, str, Callable[..., Any], tuple[Any, ...], dict[str, Any], Optional[str], Optional[int]) -> Any + cache: CacheHandler, + method_name: str, + original_method: Callable[..., Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + address: Optional[str], + port: Optional[int], + ) -> Any: is_set_operation = method_name.startswith("set") is_get_operation = not is_set_operation @@ -91,8 +99,7 @@ def _instrument_call( return value @functools.wraps(original_method) - def sentry_method(*args, **kwargs): - # type: (*Any, **Any) -> Any + def sentry_method(*args: Any, **kwargs: Any) -> Any: return _instrument_call( cache, method_name, original_method, args, kwargs, address, port ) @@ -100,16 +107,16 @@ def sentry_method(*args, **kwargs): setattr(cache, method_name, sentry_method) -def _patch_cache(cache, address=None, port=None): - # type: (CacheHandler, Optional[str], Optional[int]) -> None +def _patch_cache( + cache: CacheHandler, address: Optional[str] = None, port: Optional[int] = None +) -> None: if not hasattr(cache, "_sentry_patched"): for method_name in METHODS_TO_INSTRUMENT: 
_patch_cache_method(cache, method_name, address, port) cache._sentry_patched = True -def _get_address_port(settings): - # type: (dict[str, Any]) -> tuple[Optional[str], Optional[int]] +def _get_address_port(settings: dict[str, Any]) -> tuple[Optional[str], Optional[int]]: location = settings.get("LOCATION") # TODO: location can also be an array of locations @@ -134,8 +141,7 @@ def _get_address_port(settings): return address, int(port) if port is not None else None -def patch_caching(): - # type: () -> None +def patch_caching() -> None: from sentry_sdk.integrations.django import DjangoIntegration if not hasattr(CacheHandler, "_sentry_patched"): @@ -143,8 +149,7 @@ def patch_caching(): original_get_item = CacheHandler.__getitem__ @functools.wraps(original_get_item) - def sentry_get_item(self, alias): - # type: (CacheHandler, str) -> Any + def sentry_get_item(self: CacheHandler, alias: str) -> Any: cache = original_get_item(self, alias) integration = sentry_sdk.get_client().get_integration(DjangoIntegration) @@ -166,8 +171,7 @@ def sentry_get_item(self, alias): original_create_connection = CacheHandler.create_connection @functools.wraps(original_create_connection) - def sentry_create_connection(self, alias): - # type: (CacheHandler, str) -> Any + def sentry_create_connection(self: CacheHandler, alias: str) -> Any: cache = original_create_connection(self, alias) integration = sentry_sdk.get_client().get_integration(DjangoIntegration) diff --git a/sentry_sdk/integrations/django/middleware.py b/sentry_sdk/integrations/django/middleware.py index 6640ac2919..bb16cc890a 100644 --- a/sentry_sdk/integrations/django/middleware.py +++ b/sentry_sdk/integrations/django/middleware.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Create spans from Django middleware invocations """ @@ -38,14 +40,12 @@ from .asgi import _asgi_middleware_mixin_factory -def patch_django_middlewares(): - # type: () -> None +def patch_django_middlewares() -> None: from django.core.handlers import base old_import_string = base.import_string - def sentry_patched_import_string(dotted_path): - # type: (str) -> Any + def sentry_patched_import_string(dotted_path: str) -> Any: rv = old_import_string(dotted_path) if _import_string_should_wrap_middleware.get(None): @@ -57,8 +57,7 @@ def sentry_patched_import_string(dotted_path): old_load_middleware = base.BaseHandler.load_middleware - def sentry_patched_load_middleware(*args, **kwargs): - # type: (Any, Any) -> Any + def sentry_patched_load_middleware(*args: Any, **kwargs: Any) -> Any: _import_string_should_wrap_middleware.set(True) try: return old_load_middleware(*args, **kwargs) @@ -68,12 +67,10 @@ def sentry_patched_load_middleware(*args, **kwargs): base.BaseHandler.load_middleware = sentry_patched_load_middleware -def _wrap_middleware(middleware, middleware_name): - # type: (Any, str) -> Any +def _wrap_middleware(middleware: Any, middleware_name: str) -> Any: from sentry_sdk.integrations.django import DjangoIntegration - def _check_middleware_span(old_method): - # type: (Callable[..., Any]) -> Optional[Span] + def _check_middleware_span(old_method: Callable[..., Any]) -> Optional[Span]: integration = sentry_sdk.get_client().get_integration(DjangoIntegration) if integration is None or not integration.middleware_spans: return None @@ -96,12 +93,10 @@ def _check_middleware_span(old_method): return middleware_span - def _get_wrapped_method(old_method): - # type: (F) -> F + def _get_wrapped_method(old_method: F) -> F: with capture_internal_exceptions(): - def 
sentry_wrapped_method(*args, **kwargs): - # type: (*Any, **Any) -> Any + def sentry_wrapped_method(*args: Any, **kwargs: Any) -> Any: middleware_span = _check_middleware_span(old_method) if middleware_span is None: @@ -131,8 +126,12 @@ class SentryWrappingMiddleware( middleware, "async_capable", False ) - def __init__(self, get_response=None, *args, **kwargs): - # type: (Optional[Callable[..., Any]], *Any, **Any) -> None + def __init__( + self, + get_response: Optional[Callable[..., Any]] = None, + *args: Any, + **kwargs: Any, + ) -> None: if get_response: self._inner = middleware(get_response, *args, **kwargs) else: @@ -144,8 +143,7 @@ def __init__(self, get_response=None, *args, **kwargs): # We need correct behavior for `hasattr()`, which we can only determine # when we have an instance of the middleware we're wrapping. - def __getattr__(self, method_name): - # type: (str) -> Any + def __getattr__(self, method_name: str) -> Any: if method_name not in ( "process_request", "process_view", @@ -160,8 +158,7 @@ def __getattr__(self, method_name): self.__dict__[method_name] = rv return rv - def __call__(self, *args, **kwargs): - # type: (*Any, **Any) -> Any + def __call__(self, *args: Any, **kwargs: Any) -> Any: if hasattr(self, "async_route_check") and self.async_route_check(): return self.__acall__(*args, **kwargs) diff --git a/sentry_sdk/integrations/django/signals_handlers.py b/sentry_sdk/integrations/django/signals_handlers.py index 6e398ddfc3..23c821c65b 100644 --- a/sentry_sdk/integrations/django/signals_handlers.py +++ b/sentry_sdk/integrations/django/signals_handlers.py @@ -1,3 +1,4 @@ +from __future__ import annotations from functools import wraps from django.dispatch import Signal @@ -13,8 +14,7 @@ from typing import Any, Union -def _get_receiver_name(receiver): - # type: (Callable[..., Any]) -> str +def _get_receiver_name(receiver: Callable[..., Any]) -> str: name = "" if hasattr(receiver, "__qualname__"): @@ -38,8 +38,7 @@ def _get_receiver_name(receiver): return name -def patch_signals(): - # type: () -> None +def patch_signals() -> None: """ Patch django signal receivers to create a span. 
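As background for the signal patching in the next hunk: the integrations consistently monkeypatch by capturing the original callable, wrapping it with functools.wraps, and reassigning the attribute. A stripped-down sketch of that idiom using an invented stand-in class (not Django's real Signal):

import functools


class Signal:
    # Invented stand-in for django.dispatch.Signal, for illustration only.
    def _live_receivers(self, sender):
        return [lambda **kwargs: "handled"]


original_live_receivers = Signal._live_receivers


@functools.wraps(original_live_receivers)
def patched_live_receivers(self, sender):
    receivers = original_live_receivers(self, sender)

    def wrap(receiver):
        @functools.wraps(receiver)
        def wrapper(*args, **kwargs):
            # The real integration opens a sentry_sdk span around the receiver call.
            return receiver(*args, **kwargs)

        return wrapper

    return [wrap(r) for r in receivers]


Signal._live_receivers = patched_live_receivers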
@@ -51,19 +50,21 @@ def patch_signals(): old_live_receivers = Signal._live_receivers @wraps(old_live_receivers) - def _sentry_live_receivers(self, sender): - # type: (Signal, Any) -> Union[tuple[list[Callable[..., Any]], list[Callable[..., Any]]], list[Callable[..., Any]]] + def _sentry_live_receivers(self: Signal, sender: Any) -> Union[ + tuple[list[Callable[..., Any]], list[Callable[..., Any]]], + list[Callable[..., Any]], + ]: if DJANGO_VERSION >= (5, 0): sync_receivers, async_receivers = old_live_receivers(self, sender) else: sync_receivers = old_live_receivers(self, sender) async_receivers = [] - def sentry_sync_receiver_wrapper(receiver): - # type: (Callable[..., Any]) -> Callable[..., Any] + def sentry_sync_receiver_wrapper( + receiver: Callable[..., Any], + ) -> Callable[..., Any]: @wraps(receiver) - def wrapper(*args, **kwargs): - # type: (Any, Any) -> Any + def wrapper(*args: Any, **kwargs: Any) -> Any: signal_name = _get_receiver_name(receiver) with sentry_sdk.start_span( op=OP.EVENT_DJANGO, diff --git a/sentry_sdk/integrations/django/templates.py b/sentry_sdk/integrations/django/templates.py index fd6e56b515..8299afa92f 100644 --- a/sentry_sdk/integrations/django/templates.py +++ b/sentry_sdk/integrations/django/templates.py @@ -1,3 +1,4 @@ +from __future__ import annotations import functools from django.template import TemplateSyntaxError @@ -18,8 +19,9 @@ from typing import Tuple -def get_template_frame_from_exception(exc_value): - # type: (Optional[BaseException]) -> Optional[Dict[str, Any]] +def get_template_frame_from_exception( + exc_value: Optional[BaseException], +) -> Optional[Dict[str, Any]]: # As of Django 1.9 or so the new template debug thing showed up. if hasattr(exc_value, "template_debug"): @@ -41,8 +43,7 @@ def get_template_frame_from_exception(exc_value): return None -def _get_template_name_description(template_name): - # type: (str) -> str +def _get_template_name_description(template_name: str) -> str: if isinstance(template_name, (list, tuple)): if template_name: return "[{}, ...]".format(template_name[0]) @@ -50,8 +51,7 @@ def _get_template_name_description(template_name): return template_name -def patch_templates(): - # type: () -> None +def patch_templates() -> None: from django.template.response import SimpleTemplateResponse from sentry_sdk.integrations.django import DjangoIntegration @@ -59,8 +59,7 @@ def patch_templates(): @property # type: ignore @ensure_integration_enabled(DjangoIntegration, real_rendered_content.fget) - def rendered_content(self): - # type: (SimpleTemplateResponse) -> str + def rendered_content(self: SimpleTemplateResponse) -> str: with sentry_sdk.start_span( op=OP.TEMPLATE_RENDER, name=_get_template_name_description(self.template_name), @@ -80,8 +79,13 @@ def rendered_content(self): @functools.wraps(real_render) @ensure_integration_enabled(DjangoIntegration, real_render) - def render(request, template_name, context=None, *args, **kwargs): - # type: (django.http.HttpRequest, str, Optional[Dict[str, Any]], *Any, **Any) -> django.http.HttpResponse + def render( + request: django.http.HttpRequest, + template_name: str, + context: Optional[Dict[str, Any]] = None, + *args: Any, + **kwargs: Any, + ) -> django.http.HttpResponse: # Inject trace meta tags into template context context = context or {} @@ -103,8 +107,7 @@ def render(request, template_name, context=None, *args, **kwargs): django.shortcuts.render = render -def _get_template_frame_from_debug(debug): - # type: (Dict[str, Any]) -> Dict[str, Any] +def 
_get_template_frame_from_debug(debug: Dict[str, Any]) -> Dict[str, Any]: if debug is None: return None @@ -135,8 +138,7 @@ def _get_template_frame_from_debug(debug): } -def _linebreak_iter(template_source): - # type: (str) -> Iterator[int] +def _linebreak_iter(template_source: str) -> Iterator[int]: yield 0 p = template_source.find("\n") while p >= 0: @@ -144,8 +146,9 @@ def _linebreak_iter(template_source): p = template_source.find("\n", p + 1) -def _get_template_frame_from_source(source): - # type: (Tuple[Origin, Tuple[int, int]]) -> Optional[Dict[str, Any]] +def _get_template_frame_from_source( + source: Tuple[Origin, Tuple[int, int]], +) -> Optional[Dict[str, Any]]: if not source: return None diff --git a/sentry_sdk/integrations/django/transactions.py b/sentry_sdk/integrations/django/transactions.py index 78b972bc37..3fe81f2029 100644 --- a/sentry_sdk/integrations/django/transactions.py +++ b/sentry_sdk/integrations/django/transactions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Copied from raven-python. @@ -27,8 +29,7 @@ from django.core.urlresolvers import get_resolver -def get_regex(resolver_or_pattern): - # type: (Union[URLPattern, URLResolver]) -> Pattern[str] +def get_regex(resolver_or_pattern: Union[URLPattern, URLResolver]) -> Pattern[str]: """Utility method for django's deprecated resolver.regex""" try: regex = resolver_or_pattern.regex @@ -48,10 +49,9 @@ class RavenResolver: _either_option_matcher = re.compile(r"\[([^\]]+)\|([^\]]+)\]") _camel_re = re.compile(r"([A-Z]+)([a-z])") - _cache = {} # type: Dict[URLPattern, str] + _cache: Dict[URLPattern, str] = {} - def _simplify(self, pattern): - # type: (Union[URLPattern, URLResolver]) -> str + def _simplify(self, pattern: Union[URLPattern, URLResolver]) -> str: r""" Clean up urlpattern regexes into something readable by humans: @@ -102,8 +102,12 @@ def _simplify(self, pattern): return result - def _resolve(self, resolver, path, parents=None): - # type: (URLResolver, str, Optional[List[URLResolver]]) -> Optional[str] + def _resolve( + self, + resolver: URLResolver, + path: str, + parents: Optional[List[URLResolver]] = None, + ) -> Optional[str]: match = get_regex(resolver).search(path) # Django < 2.0 @@ -142,10 +146,11 @@ def _resolve(self, resolver, path, parents=None): def resolve( self, - path, # type: str - urlconf=None, # type: Union[None, Tuple[URLPattern, URLPattern, URLResolver], Tuple[URLPattern]] - ): - # type: (...) 
-> Optional[str] + path: str, + urlconf: Union[ + None, Tuple[URLPattern, URLPattern, URLResolver], Tuple[URLPattern] + ] = None, + ) -> Optional[str]: resolver = get_resolver(urlconf) match = self._resolve(resolver, path) return match diff --git a/sentry_sdk/integrations/django/views.py b/sentry_sdk/integrations/django/views.py index 6240ac6bbb..4cc5a90cfc 100644 --- a/sentry_sdk/integrations/django/views.py +++ b/sentry_sdk/integrations/django/views.py @@ -1,3 +1,4 @@ +from __future__ import annotations import functools import sentry_sdk @@ -21,8 +22,7 @@ wrap_async_view = None # type: ignore -def patch_views(): - # type: () -> None +def patch_views() -> None: from django.core.handlers.base import BaseHandler from django.template.response import SimpleTemplateResponse @@ -32,8 +32,7 @@ def patch_views(): old_render = SimpleTemplateResponse.render @functools.wraps(old_render) - def sentry_patched_render(self): - # type: (SimpleTemplateResponse) -> Any + def sentry_patched_render(self: SimpleTemplateResponse) -> Any: with sentry_sdk.start_span( op=OP.VIEW_RESPONSE_RENDER, name="serialize response", @@ -43,8 +42,7 @@ def sentry_patched_render(self): return old_render(self) @functools.wraps(old_make_view_atomic) - def sentry_patched_make_view_atomic(self, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any + def sentry_patched_make_view_atomic(self: Any, *args: Any, **kwargs: Any) -> Any: callback = old_make_view_atomic(self, *args, **kwargs) # XXX: The wrapper function is created for every request. Find more @@ -71,13 +69,11 @@ def sentry_patched_make_view_atomic(self, *args, **kwargs): BaseHandler.make_view_atomic = sentry_patched_make_view_atomic -def _wrap_sync_view(callback): - # type: (Any) -> Any +def _wrap_sync_view(callback: Any) -> Any: from sentry_sdk.integrations.django import DjangoIntegration @functools.wraps(callback) - def sentry_wrapped_callback(request, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any + def sentry_wrapped_callback(request: Any, *args: Any, **kwargs: Any) -> Any: current_scope = sentry_sdk.get_current_scope() if current_scope.root_span is not None: current_scope.root_span.update_active_thread() diff --git a/sentry_sdk/integrations/dramatiq.py b/sentry_sdk/integrations/dramatiq.py index a756b4c669..76abf243bc 100644 --- a/sentry_sdk/integrations/dramatiq.py +++ b/sentry_sdk/integrations/dramatiq.py @@ -1,3 +1,4 @@ +from __future__ import annotations import json import sentry_sdk @@ -36,17 +37,14 @@ class DramatiqIntegration(Integration): identifier = "dramatiq" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: _patch_dramatiq_broker() -def _patch_dramatiq_broker(): - # type: () -> None +def _patch_dramatiq_broker() -> None: original_broker__init__ = Broker.__init__ - def sentry_patched_broker__init__(self, *args, **kw): - # type: (Broker, *Any, **Any) -> None + def sentry_patched_broker__init__(self: Broker, *args: Any, **kw: Any) -> None: integration = sentry_sdk.get_client().get_integration(DramatiqIntegration) try: @@ -85,8 +83,7 @@ class SentryMiddleware(Middleware): # type: ignore[misc] DramatiqIntegration. 
""" - def before_process_message(self, broker, message): - # type: (Broker, Message) -> None + def before_process_message(self, broker: Broker, message: Message) -> None: integration = sentry_sdk.get_client().get_integration(DramatiqIntegration) if integration is None: return @@ -99,8 +96,14 @@ def before_process_message(self, broker, message): scope.set_extra("dramatiq_message_id", message.message_id) scope.add_event_processor(_make_message_event_processor(message, integration)) - def after_process_message(self, broker, message, *, result=None, exception=None): - # type: (Broker, Message, Any, Optional[Any], Optional[Exception]) -> None + def after_process_message( + self: Broker, + broker: Message, + message: Any, + *, + result: Optional[Any] = None, + exception: Optional[Exception] = None, + ) -> None: integration = sentry_sdk.get_client().get_integration(DramatiqIntegration) if integration is None: return @@ -127,11 +130,11 @@ def after_process_message(self, broker, message, *, result=None, exception=None) message._scope_manager.__exit__(None, None, None) -def _make_message_event_processor(message, integration): - # type: (Message, DramatiqIntegration) -> Callable[[Event, Hint], Optional[Event]] +def _make_message_event_processor( + message: Message, integration: DramatiqIntegration +) -> Callable[[Event, Hint], Optional[Event]]: - def inner(event, hint): - # type: (Event, Hint) -> Optional[Event] + def inner(event: Event, hint: Hint) -> Optional[Event]: with capture_internal_exceptions(): DramatiqMessageExtractor(message).extract_into_event(event) @@ -141,16 +144,13 @@ def inner(event, hint): class DramatiqMessageExtractor: - def __init__(self, message): - # type: (Message) -> None + def __init__(self, message: Message) -> None: self.message_data = dict(message.asdict()) - def content_length(self): - # type: () -> int + def content_length(self) -> int: return len(json.dumps(self.message_data)) - def extract_into_event(self, event): - # type: (Event) -> None + def extract_into_event(self, event: Event) -> None: client = sentry_sdk.get_client() if not client.is_active(): return @@ -159,7 +159,7 @@ def extract_into_event(self, event): request_info = contexts.setdefault("dramatiq", {}) request_info["type"] = "dramatiq" - data = None # type: Optional[Union[AnnotatedValue, Dict[str, Any]]] + data: Optional[Union[AnnotatedValue, Dict[str, Any]]] = None if not request_body_within_bounds(client, self.content_length()): data = AnnotatedValue.removed_because_over_size_limit() else: diff --git a/sentry_sdk/integrations/excepthook.py b/sentry_sdk/integrations/excepthook.py index 61c7e460bf..ad3f7a82b6 100644 --- a/sentry_sdk/integrations/excepthook.py +++ b/sentry_sdk/integrations/excepthook.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import sentry_sdk @@ -28,8 +29,7 @@ class ExcepthookIntegration(Integration): always_run = False - def __init__(self, always_run=False): - # type: (bool) -> None + def __init__(self, always_run: bool = False) -> None: if not isinstance(always_run, bool): raise ValueError( @@ -39,15 +39,16 @@ def __init__(self, always_run=False): self.always_run = always_run @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: sys.excepthook = _make_excepthook(sys.excepthook) -def _make_excepthook(old_excepthook): - # type: (Excepthook) -> Excepthook - def sentry_sdk_excepthook(type_, value, traceback): - # type: (Type[BaseException], BaseException, Optional[TracebackType]) -> None +def _make_excepthook(old_excepthook: Excepthook) -> 
Excepthook: + def sentry_sdk_excepthook( + type_: Type[BaseException], + value: BaseException, + traceback: Optional[TracebackType], + ) -> None: integration = sentry_sdk.get_client().get_integration(ExcepthookIntegration) # Note: If we replace this with ensure_integration_enabled then @@ -70,8 +71,7 @@ def sentry_sdk_excepthook(type_, value, traceback): return sentry_sdk_excepthook -def _should_send(always_run=False): - # type: (bool) -> bool +def _should_send(always_run: bool = False) -> bool: if always_run: return True diff --git a/sentry_sdk/integrations/executing.py b/sentry_sdk/integrations/executing.py index 6e68b8c0c7..649af64e58 100644 --- a/sentry_sdk/integrations/executing.py +++ b/sentry_sdk/integrations/executing.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.integrations import Integration, DidNotEnable from sentry_sdk.scope import add_global_event_processor @@ -20,12 +21,10 @@ class ExecutingIntegration(Integration): identifier = "executing" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: @add_global_event_processor - def add_executing_info(event, hint): - # type: (Event, Optional[Hint]) -> Optional[Event] + def add_executing_info(event: Event, hint: Optional[Hint]) -> Optional[Event]: if sentry_sdk.get_client().get_integration(ExecutingIntegration) is None: return event diff --git a/sentry_sdk/integrations/falcon.py b/sentry_sdk/integrations/falcon.py index 9038c01a3f..622f8bb3a0 100644 --- a/sentry_sdk/integrations/falcon.py +++ b/sentry_sdk/integrations/falcon.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.consts import SOURCE_FOR_STYLE from sentry_sdk.integrations import _check_minimum_version, Integration, DidNotEnable @@ -33,30 +34,25 @@ falcon_app_class = falcon.App -_FALCON_UNSET = None # type: Optional[object] +_FALCON_UNSET: Optional[object] = None with capture_internal_exceptions(): from falcon.request import _UNSET as _FALCON_UNSET # type: ignore[import-not-found, no-redef] class FalconRequestExtractor(RequestExtractor): - def env(self): - # type: () -> Dict[str, Any] + def env(self) -> Dict[str, Any]: return self.request.env - def cookies(self): - # type: () -> Dict[str, Any] + def cookies(self) -> Dict[str, Any]: return self.request.cookies - def form(self): - # type: () -> None + def form(self) -> None: return None # No such concept in Falcon - def files(self): - # type: () -> None + def files(self) -> None: return None # No such concept in Falcon - def raw_data(self): - # type: () -> Optional[str] + def raw_data(self) -> Optional[str]: # As request data can only be read once we won't make this available # to Sentry. 
Just send back a dummy string in case there was a @@ -68,8 +64,7 @@ def raw_data(self): else: return None - def json(self): - # type: () -> Optional[Dict[str, Any]] + def json(self) -> Optional[Dict[str, Any]]: # fallback to cached_media = None if self.request._media is not available cached_media = None with capture_internal_exceptions(): @@ -90,8 +85,7 @@ def json(self): class SentryFalconMiddleware: """Captures exceptions in Falcon requests and send to Sentry""" - def process_request(self, req, resp, *args, **kwargs): - # type: (Any, Any, *Any, **Any) -> None + def process_request(self, req: Any, resp: Any, *args: Any, **kwargs: Any) -> None: integration = sentry_sdk.get_client().get_integration(FalconIntegration) if integration is None: return @@ -110,8 +104,7 @@ class FalconIntegration(Integration): transaction_style = "" - def __init__(self, transaction_style="uri_template"): - # type: (str) -> None + def __init__(self, transaction_style: str = "uri_template") -> None: if transaction_style not in TRANSACTION_STYLE_VALUES: raise ValueError( "Invalid value for transaction_style: %s (must be in %s)" @@ -120,8 +113,7 @@ def __init__(self, transaction_style="uri_template"): self.transaction_style = transaction_style @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = parse_version(FALCON_VERSION) _check_minimum_version(FalconIntegration, version) @@ -131,12 +123,10 @@ def setup_once(): _patch_prepare_middleware() -def _patch_wsgi_app(): - # type: () -> None +def _patch_wsgi_app() -> None: original_wsgi_app = falcon_app_class.__call__ - def sentry_patched_wsgi_app(self, env, start_response): - # type: (falcon.API, Any, Any) -> Any + def sentry_patched_wsgi_app(self: falcon.API, env: Any, start_response: Any) -> Any: integration = sentry_sdk.get_client().get_integration(FalconIntegration) if integration is None: return original_wsgi_app(self, env, start_response) @@ -151,13 +141,11 @@ def sentry_patched_wsgi_app(self, env, start_response): falcon_app_class.__call__ = sentry_patched_wsgi_app -def _patch_handle_exception(): - # type: () -> None +def _patch_handle_exception() -> None: original_handle_exception = falcon_app_class._handle_exception @ensure_integration_enabled(FalconIntegration, original_handle_exception) - def sentry_patched_handle_exception(self, *args): - # type: (falcon.API, *Any) -> Any + def sentry_patched_handle_exception(self: falcon.API, *args: Any) -> Any: # NOTE(jmagnusson): falcon 2.0 changed falcon.API._handle_exception # method signature from `(ex, req, resp, params)` to # `(req, resp, ex, params)` @@ -189,14 +177,12 @@ def sentry_patched_handle_exception(self, *args): falcon_app_class._handle_exception = sentry_patched_handle_exception -def _patch_prepare_middleware(): - # type: () -> None +def _patch_prepare_middleware() -> None: original_prepare_middleware = falcon_helpers.prepare_middleware def sentry_patched_prepare_middleware( - middleware=None, independent_middleware=False, asgi=False - ): - # type: (Any, Any, bool) -> Any + middleware: Any = None, independent_middleware: Any = False, asgi: bool = False + ) -> Any: if asgi: # We don't support ASGI Falcon apps, so we don't patch anything here return original_prepare_middleware(middleware, independent_middleware, asgi) @@ -212,8 +198,7 @@ def sentry_patched_prepare_middleware( falcon_helpers.prepare_middleware = sentry_patched_prepare_middleware -def _exception_leads_to_http_5xx(ex, response): - # type: (Exception, falcon.Response) -> bool +def 
_exception_leads_to_http_5xx(ex: Exception, response: falcon.Response) -> bool: is_server_error = isinstance(ex, falcon.HTTPError) and (ex.status or "").startswith( "5" ) @@ -224,13 +209,13 @@ def _exception_leads_to_http_5xx(ex, response): return (is_server_error or is_unhandled_error) and _has_http_5xx_status(response) -def _has_http_5xx_status(response): - # type: (falcon.Response) -> bool +def _has_http_5xx_status(response: falcon.Response) -> bool: return response.status.startswith("5") -def _set_transaction_name_and_source(event, transaction_style, request): - # type: (Event, str, falcon.Request) -> None +def _set_transaction_name_and_source( + event: Event, transaction_style: str, request: falcon.Request +) -> None: name_for_style = { "uri_template": request.uri_template, "path": request.path, @@ -239,11 +224,11 @@ def _set_transaction_name_and_source(event, transaction_style, request): event["transaction_info"] = {"source": SOURCE_FOR_STYLE[transaction_style]} -def _make_request_event_processor(req, integration): - # type: (falcon.Request, FalconIntegration) -> EventProcessor +def _make_request_event_processor( + req: falcon.Request, integration: FalconIntegration +) -> EventProcessor: - def event_processor(event, hint): - # type: (Event, dict[str, Any]) -> Event + def event_processor(event: Event, hint: dict[str, Any]) -> Event: _set_transaction_name_and_source(event, integration.transaction_style, req) with capture_internal_exceptions(): diff --git a/sentry_sdk/integrations/fastapi.py b/sentry_sdk/integrations/fastapi.py index 0e087e3975..10391fe934 100644 --- a/sentry_sdk/integrations/fastapi.py +++ b/sentry_sdk/integrations/fastapi.py @@ -1,3 +1,4 @@ +from __future__ import annotations import asyncio from copy import deepcopy from functools import wraps @@ -38,13 +39,13 @@ class FastApiIntegration(StarletteIntegration): identifier = "fastapi" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: patch_get_request_handler() -def _set_transaction_name_and_source(scope, transaction_style, request): - # type: (sentry_sdk.Scope, str, Any) -> None +def _set_transaction_name_and_source( + scope: sentry_sdk.Scope, transaction_style: str, request: Any +) -> None: name = "" if transaction_style == "endpoint": @@ -71,12 +72,10 @@ def _set_transaction_name_and_source(scope, transaction_style, request): ) -def patch_get_request_handler(): - # type: () -> None +def patch_get_request_handler() -> None: old_get_request_handler = fastapi.routing.get_request_handler - def _sentry_get_request_handler(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _sentry_get_request_handler(*args: Any, **kwargs: Any) -> Any: dependant = kwargs.get("dependant") if ( dependant @@ -86,8 +85,7 @@ def _sentry_get_request_handler(*args, **kwargs): old_call = dependant.call @wraps(old_call) - def _sentry_call(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _sentry_call(*args: Any, **kwargs: Any) -> Any: current_scope = sentry_sdk.get_current_scope() if current_scope.root_span is not None: current_scope.root_span.update_active_thread() @@ -102,8 +100,7 @@ def _sentry_call(*args, **kwargs): old_app = old_get_request_handler(*args, **kwargs) - async def _sentry_app(*args, **kwargs): - # type: (*Any, **Any) -> Any + async def _sentry_app(*args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(FastApiIntegration) if integration is None: return await old_app(*args, **kwargs) @@ -117,10 +114,10 @@ async def _sentry_app(*args, **kwargs): extractor = 
StarletteRequestExtractor(request) info = await extractor.extract_request_info() - def _make_request_event_processor(req, integration): - # type: (Any, Any) -> Callable[[Event, Dict[str, Any]], Event] - def event_processor(event, hint): - # type: (Event, Dict[str, Any]) -> Event + def _make_request_event_processor( + req: Any, integration: Any + ) -> Callable[[Event, Dict[str, Any]], Event]: + def event_processor(event: Event, hint: Dict[str, Any]) -> Event: # Extract information from request request_info = event.get("request", {}) diff --git a/sentry_sdk/integrations/flask.py b/sentry_sdk/integrations/flask.py index 9223eacd24..708bcd01f9 100644 --- a/sentry_sdk/integrations/flask.py +++ b/sentry_sdk/integrations/flask.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.consts import SOURCE_FOR_STYLE from sentry_sdk.integrations import _check_minimum_version, DidNotEnable, Integration @@ -57,10 +58,9 @@ class FlaskIntegration(Integration): def __init__( self, - transaction_style="endpoint", # type: str - http_methods_to_capture=DEFAULT_HTTP_METHODS_TO_CAPTURE, # type: tuple[str, ...] - ): - # type: (...) -> None + transaction_style: str = "endpoint", + http_methods_to_capture: tuple[str, ...] = DEFAULT_HTTP_METHODS_TO_CAPTURE, + ) -> None: if transaction_style not in TRANSACTION_STYLE_VALUES: raise ValueError( "Invalid value for transaction_style: %s (must be in %s)" @@ -70,8 +70,7 @@ def __init__( self.http_methods_to_capture = tuple(map(str.upper, http_methods_to_capture)) @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: try: from quart import Quart # type: ignore @@ -93,8 +92,9 @@ def setup_once(): old_app = Flask.__call__ - def sentry_patched_wsgi_app(self, environ, start_response): - # type: (Any, Dict[str, str], Callable[..., Any]) -> _ScopedResponse + def sentry_patched_wsgi_app( + self: Any, environ: Dict[str, str], start_response: Callable[..., Any] + ) -> _ScopedResponse: if sentry_sdk.get_client().get_integration(FlaskIntegration) is None: return old_app(self, environ, start_response) @@ -114,8 +114,9 @@ def sentry_patched_wsgi_app(self, environ, start_response): Flask.__call__ = sentry_patched_wsgi_app -def _add_sentry_trace(sender, template, context, **extra): - # type: (Flask, Any, Dict[str, Any], **Any) -> None +def _add_sentry_trace( + sender: Flask, template: Any, context: Dict[str, Any], **extra: Any +) -> None: if "sentry_trace" in context: return @@ -125,8 +126,9 @@ def _add_sentry_trace(sender, template, context, **extra): context["sentry_trace_meta"] = trace_meta -def _set_transaction_name_and_source(scope, transaction_style, request): - # type: (sentry_sdk.Scope, str, Request) -> None +def _set_transaction_name_and_source( + scope: sentry_sdk.Scope, transaction_style: str, request: Request +) -> None: try: name_for_style = { "url": request.url_rule.rule, @@ -140,8 +142,7 @@ def _set_transaction_name_and_source(scope, transaction_style, request): pass -def _request_started(app, **kwargs): - # type: (Flask, **Any) -> None +def _request_started(app: Flask, **kwargs: Any) -> None: integration = sentry_sdk.get_client().get_integration(FlaskIntegration) if integration is None: return @@ -160,47 +161,39 @@ def _request_started(app, **kwargs): class FlaskRequestExtractor(RequestExtractor): - def env(self): - # type: () -> Dict[str, str] + def env(self) -> Dict[str, str]: return self.request.environ - def cookies(self): - # type: () -> Dict[Any, Any] + def cookies(self) -> Dict[Any, Any]: return { k: v[0] 
if isinstance(v, list) and len(v) == 1 else v for k, v in self.request.cookies.items() } - def raw_data(self): - # type: () -> bytes + def raw_data(self) -> bytes: return self.request.get_data() - def form(self): - # type: () -> ImmutableMultiDict[str, Any] + def form(self) -> ImmutableMultiDict[str, Any]: return self.request.form - def files(self): - # type: () -> ImmutableMultiDict[str, Any] + def files(self) -> ImmutableMultiDict[str, Any]: return self.request.files - def is_json(self): - # type: () -> bool + def is_json(self) -> bool: return self.request.is_json - def json(self): - # type: () -> Any + def json(self) -> Any: return self.request.get_json(silent=True) - def size_of_file(self, file): - # type: (FileStorage) -> int + def size_of_file(self, file: FileStorage) -> int: return file.content_length -def _make_request_event_processor(app, request, integration): - # type: (Flask, Callable[[], Request], FlaskIntegration) -> EventProcessor +def _make_request_event_processor( + app: Flask, request: Callable[[], Request], integration: FlaskIntegration +) -> EventProcessor: - def inner(event, hint): - # type: (Event, dict[str, Any]) -> Event + def inner(event: Event, hint: dict[str, Any]) -> Event: # if the request is gone we are fine not logging the data from # it. This might happen if the processor is pushed away to @@ -221,8 +214,9 @@ def inner(event, hint): @ensure_integration_enabled(FlaskIntegration) -def _capture_exception(sender, exception, **kwargs): - # type: (Flask, Union[ValueError, BaseException], **Any) -> None +def _capture_exception( + sender: Flask, exception: Union[ValueError, BaseException], **kwargs: Any +) -> None: event, hint = event_from_exception( exception, client_options=sentry_sdk.get_client().options, @@ -232,8 +226,7 @@ def _capture_exception(sender, exception, **kwargs): sentry_sdk.capture_event(event, hint=hint) -def _add_user_to_event(event): - # type: (Event) -> None +def _add_user_to_event(event: Event) -> None: if flask_login is None: return diff --git a/sentry_sdk/integrations/gcp.py b/sentry_sdk/integrations/gcp.py index 97b72ff1ce..a347ce3ffe 100644 --- a/sentry_sdk/integrations/gcp.py +++ b/sentry_sdk/integrations/gcp.py @@ -1,3 +1,4 @@ +from __future__ import annotations import functools import sys from copy import deepcopy @@ -39,11 +40,11 @@ F = TypeVar("F", bound=Callable[..., Any]) -def _wrap_func(func): - # type: (F) -> F +def _wrap_func(func: F) -> F: @functools.wraps(func) - def sentry_func(functionhandler, gcp_event, *args, **kwargs): - # type: (Any, Any, *Any, **Any) -> Any + def sentry_func( + functionhandler: Any, gcp_event: Any, *args: Any, **kwargs: Any + ) -> Any: client = sentry_sdk.get_client() integration = client.get_integration(GcpIntegration) @@ -118,13 +119,11 @@ class GcpIntegration(Integration): identifier = "gcp" origin = f"auto.function.{identifier}" - def __init__(self, timeout_warning=False): - # type: (bool) -> None + def __init__(self, timeout_warning: bool = False) -> None: self.timeout_warning = timeout_warning @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: import __main__ as gcp_functions if not hasattr(gcp_functions, "worker_v1"): @@ -140,11 +139,11 @@ def setup_once(): ) -def _make_request_event_processor(gcp_event, configured_timeout, initial_time): - # type: (Any, Any, Any) -> EventProcessor +def _make_request_event_processor( + gcp_event: Any, configured_timeout: Any, initial_time: Any +) -> EventProcessor: - def event_processor(event, hint): - # type: (Event, Hint) -> 
Optional[Event] + def event_processor(event: Event, hint: Hint) -> Optional[Event]: final_time = datetime.now(timezone.utc) time_diff = final_time - initial_time @@ -195,8 +194,7 @@ def event_processor(event, hint): return event_processor -def _get_google_cloud_logs_url(final_time): - # type: (datetime) -> str +def _get_google_cloud_logs_url(final_time: datetime) -> str: """ Generates a Google Cloud Logs console URL based on the environment variables Arguments: @@ -238,8 +236,7 @@ def _get_google_cloud_logs_url(final_time): } -def _prepopulate_attributes(gcp_event): - # type: (Any) -> dict[str, Any] +def _prepopulate_attributes(gcp_event: Any) -> dict[str, Any]: attributes = { "cloud.provider": "gcp", } diff --git a/sentry_sdk/integrations/gnu_backtrace.py b/sentry_sdk/integrations/gnu_backtrace.py index dc3dc80fe0..c9d2e51409 100644 --- a/sentry_sdk/integrations/gnu_backtrace.py +++ b/sentry_sdk/integrations/gnu_backtrace.py @@ -1,3 +1,4 @@ +from __future__ import annotations import re import sentry_sdk @@ -38,17 +39,14 @@ class GnuBacktraceIntegration(Integration): identifier = "gnu_backtrace" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: @add_global_event_processor - def process_gnu_backtrace(event, hint): - # type: (Event, dict[str, Any]) -> Event + def process_gnu_backtrace(event: Event, hint: dict[str, Any]) -> Event: with capture_internal_exceptions(): return _process_gnu_backtrace(event, hint) -def _process_gnu_backtrace(event, hint): - # type: (Event, dict[str, Any]) -> Event +def _process_gnu_backtrace(event: Event, hint: dict[str, Any]) -> Event: if sentry_sdk.get_client().get_integration(GnuBacktraceIntegration) is None: return event diff --git a/sentry_sdk/integrations/gql.py b/sentry_sdk/integrations/gql.py index 5f4436f5b2..a43f04a062 100644 --- a/sentry_sdk/integrations/gql.py +++ b/sentry_sdk/integrations/gql.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.utils import ( event_from_exception, @@ -34,19 +35,17 @@ class GQLIntegration(Integration): identifier = "gql" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: gql_version = parse_version(gql.__version__) _check_minimum_version(GQLIntegration, gql_version) _patch_execute() -def _data_from_document(document): - # type: (DocumentNode) -> EventDataType +def _data_from_document(document: DocumentNode) -> EventDataType: try: operation_ast = get_operation_ast(document) - data = {"query": print_ast(document)} # type: EventDataType + data: EventDataType = {"query": print_ast(document)} if operation_ast is not None: data["variables"] = operation_ast.variable_definitions @@ -58,8 +57,7 @@ def _data_from_document(document): return dict() -def _transport_method(transport): - # type: (Union[Transport, AsyncTransport]) -> str +def _transport_method(transport: Union[Transport, AsyncTransport]) -> str: """ The RequestsHTTPTransport allows defining the HTTP method; all other transports use POST. 
@@ -70,8 +68,9 @@ def _transport_method(transport): return "POST" -def _request_info_from_transport(transport): - # type: (Union[Transport, AsyncTransport, None]) -> Dict[str, str] +def _request_info_from_transport( + transport: Union[Transport, AsyncTransport, None], +) -> Dict[str, str]: if transport is None: return {} @@ -87,13 +86,13 @@ def _request_info_from_transport(transport): return request_info -def _patch_execute(): - # type: () -> None +def _patch_execute() -> None: real_execute = gql.Client.execute @ensure_integration_enabled(GQLIntegration, real_execute) - def sentry_patched_execute(self, document, *args, **kwargs): - # type: (gql.Client, DocumentNode, Any, Any) -> Any + def sentry_patched_execute( + self: gql.Client, document: DocumentNode, *args: Any, **kwargs: Any + ) -> Any: scope = sentry_sdk.get_isolation_scope() scope.add_event_processor(_make_gql_event_processor(self, document)) @@ -112,10 +111,10 @@ def sentry_patched_execute(self, document, *args, **kwargs): gql.Client.execute = sentry_patched_execute -def _make_gql_event_processor(client, document): - # type: (gql.Client, DocumentNode) -> EventProcessor - def processor(event, hint): - # type: (Event, dict[str, Any]) -> Event +def _make_gql_event_processor( + client: gql.Client, document: DocumentNode +) -> EventProcessor: + def processor(event: Event, hint: dict[str, Any]) -> Event: try: errors = hint["exc_info"][1].errors except (AttributeError, KeyError): diff --git a/sentry_sdk/integrations/graphene.py b/sentry_sdk/integrations/graphene.py index 9269a4403c..c0ea3adb53 100644 --- a/sentry_sdk/integrations/graphene.py +++ b/sentry_sdk/integrations/graphene.py @@ -1,3 +1,4 @@ +from __future__ import annotations from contextlib import contextmanager import sentry_sdk @@ -31,22 +32,21 @@ class GrapheneIntegration(Integration): identifier = "graphene" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = package_version("graphene") _check_minimum_version(GrapheneIntegration, version) _patch_graphql() -def _patch_graphql(): - # type: () -> None +def _patch_graphql() -> None: old_graphql_sync = graphene_schema.graphql_sync old_graphql_async = graphene_schema.graphql @ensure_integration_enabled(GrapheneIntegration, old_graphql_sync) - def _sentry_patched_graphql_sync(schema, source, *args, **kwargs): - # type: (GraphQLSchema, Union[str, Source], Any, Any) -> ExecutionResult + def _sentry_patched_graphql_sync( + schema: GraphQLSchema, source: Union[str, Source], *args: Any, **kwargs: Any + ) -> ExecutionResult: scope = sentry_sdk.get_isolation_scope() scope.add_event_processor(_event_processor) @@ -68,8 +68,9 @@ def _sentry_patched_graphql_sync(schema, source, *args, **kwargs): return result - async def _sentry_patched_graphql_async(schema, source, *args, **kwargs): - # type: (GraphQLSchema, Union[str, Source], Any, Any) -> ExecutionResult + async def _sentry_patched_graphql_async( + schema: GraphQLSchema, source: Union[str, Source], *args: Any, **kwargs: Any + ) -> ExecutionResult: integration = sentry_sdk.get_client().get_integration(GrapheneIntegration) if integration is None: return await old_graphql_async(schema, source, *args, **kwargs) @@ -99,8 +100,7 @@ async def _sentry_patched_graphql_async(schema, source, *args, **kwargs): graphene_schema.graphql = _sentry_patched_graphql_async -def _event_processor(event, hint): - # type: (Event, Dict[str, Any]) -> Event +def _event_processor(event: Event, hint: Dict[str, Any]) -> Event: if should_send_default_pii(): request_info = 
event.setdefault("request", {}) request_info["api_target"] = "graphql" @@ -112,8 +112,9 @@ def _event_processor(event, hint): @contextmanager -def graphql_span(schema, source, kwargs): - # type: (GraphQLSchema, Union[str, Source], Dict[str, Any]) -> Generator[None, None, None] +def graphql_span( + schema: GraphQLSchema, source: Union[str, Source], kwargs: Dict[str, Any] +) -> Generator[None, None, None]: operation_name = kwargs.get("operation_name") operation_type = "query" diff --git a/sentry_sdk/integrations/grpc/__init__.py b/sentry_sdk/integrations/grpc/__init__.py index 4e15f95ae5..29f6100a50 100644 --- a/sentry_sdk/integrations/grpc/__init__.py +++ b/sentry_sdk/integrations/grpc/__init__.py @@ -1,3 +1,4 @@ +from __future__ import annotations from functools import wraps import grpc @@ -130,10 +131,10 @@ def patched_aio_server( # type: ignore **kwargs: P.kwargs, ) -> Server: server_interceptor = AsyncServerInterceptor() - interceptors = [ + interceptors: Sequence[grpc.ServerInterceptor] = [ server_interceptor, *(interceptors or []), - ] # type: Sequence[grpc.ServerInterceptor] + ] try: # We prefer interceptors as a list because of compatibility with diff --git a/sentry_sdk/integrations/grpc/aio/server.py b/sentry_sdk/integrations/grpc/aio/server.py index 91c2e9d74f..2538b89252 100644 --- a/sentry_sdk/integrations/grpc/aio/server.py +++ b/sentry_sdk/integrations/grpc/aio/server.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.consts import OP from sentry_sdk.integrations import DidNotEnable @@ -21,14 +22,19 @@ class ServerInterceptor(grpc.aio.ServerInterceptor): # type: ignore - def __init__(self, find_name=None): - # type: (ServerInterceptor, Callable[[ServicerContext], str] | None) -> None + def __init__( + self: ServerInterceptor, + find_name: Callable[[ServicerContext], str] | None = None, + ) -> None: self._find_method_name = find_name or self._find_name super().__init__() - async def intercept_service(self, continuation, handler_call_details): - # type: (ServerInterceptor, Callable[[HandlerCallDetails], Awaitable[RpcMethodHandler]], HandlerCallDetails) -> Optional[Awaitable[RpcMethodHandler]] + async def intercept_service( + self: ServerInterceptor, + continuation: Callable[[HandlerCallDetails], Awaitable[RpcMethodHandler]], + handler_call_details: HandlerCallDetails, + ) -> Optional[Awaitable[RpcMethodHandler]]: self._handler_call_details = handler_call_details handler = await continuation(handler_call_details) if handler is None: @@ -37,8 +43,7 @@ async def intercept_service(self, continuation, handler_call_details): if not handler.request_streaming and not handler.response_streaming: handler_factory = grpc.unary_unary_rpc_method_handler - async def wrapped(request, context): - # type: (Any, ServicerContext) -> Any + async def wrapped(request: Any, context: ServicerContext) -> Any: name = self._find_method_name(context) if not name: return await handler(request, context) @@ -66,24 +71,21 @@ async def wrapped(request, context): elif not handler.request_streaming and handler.response_streaming: handler_factory = grpc.unary_stream_rpc_method_handler - async def wrapped(request, context): # type: ignore - # type: (Any, ServicerContext) -> Any + async def wrapped(request: Any, context: ServicerContext) -> Any: # type: ignore async for r in handler.unary_stream(request, context): yield r elif handler.request_streaming and not handler.response_streaming: handler_factory = grpc.stream_unary_rpc_method_handler - async def wrapped(request, context): 
- # type: (Any, ServicerContext) -> Any + async def wrapped(request: Any, context: ServicerContext) -> Any: response = handler.stream_unary(request, context) return await response elif handler.request_streaming and handler.response_streaming: handler_factory = grpc.stream_stream_rpc_method_handler - async def wrapped(request, context): # type: ignore - # type: (Any, ServicerContext) -> Any + async def wrapped(request: Any, context: ServicerContext) -> Any: # type: ignore async for r in handler.stream_stream(request, context): yield r @@ -93,6 +95,5 @@ async def wrapped(request, context): # type: ignore response_serializer=handler.response_serializer, ) - def _find_name(self, context): - # type: (ServicerContext) -> str + def _find_name(self, context: ServicerContext) -> str: return self._handler_call_details.method diff --git a/sentry_sdk/integrations/grpc/client.py b/sentry_sdk/integrations/grpc/client.py index b7a1ddd85e..a3c434bd56 100644 --- a/sentry_sdk/integrations/grpc/client.py +++ b/sentry_sdk/integrations/grpc/client.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.consts import OP from sentry_sdk.integrations import DidNotEnable @@ -23,8 +24,12 @@ class ClientInterceptor( ): _is_intercepted = False - def intercept_unary_unary(self, continuation, client_call_details, request): - # type: (ClientInterceptor, Callable[[ClientCallDetails, Message], _UnaryOutcome], ClientCallDetails, Message) -> _UnaryOutcome + def intercept_unary_unary( + self: ClientInterceptor, + continuation: Callable[[ClientCallDetails, Message], _UnaryOutcome], + client_call_details: ClientCallDetails, + request: Message, + ) -> _UnaryOutcome: method = client_call_details.method with sentry_sdk.start_span( @@ -45,8 +50,14 @@ def intercept_unary_unary(self, continuation, client_call_details, request): return response - def intercept_unary_stream(self, continuation, client_call_details, request): - # type: (ClientInterceptor, Callable[[ClientCallDetails, Message], Union[Iterable[Any], UnaryStreamCall]], ClientCallDetails, Message) -> Union[Iterator[Message], Call] + def intercept_unary_stream( + self: ClientInterceptor, + continuation: Callable[ + [ClientCallDetails, Message], Union[Iterable[Any], UnaryStreamCall] + ], + client_call_details: ClientCallDetails, + request: Message, + ) -> Union[Iterator[Message], Call]: method = client_call_details.method with sentry_sdk.start_span( @@ -62,17 +73,16 @@ def intercept_unary_stream(self, continuation, client_call_details, request): client_call_details ) - response = continuation( - client_call_details, request - ) # type: UnaryStreamCall + response: UnaryStreamCall = continuation(client_call_details, request) # Setting code on unary-stream leads to execution getting stuck # span.set_attribute("code", response.code().name) return response @staticmethod - def _update_client_call_details_metadata_from_scope(client_call_details): - # type: (ClientCallDetails) -> ClientCallDetails + def _update_client_call_details_metadata_from_scope( + client_call_details: ClientCallDetails, + ) -> ClientCallDetails: metadata = ( list(client_call_details.metadata) if client_call_details.metadata else [] ) diff --git a/sentry_sdk/integrations/grpc/consts.py b/sentry_sdk/integrations/grpc/consts.py index 9fdb975caf..6ee9ed49ca 100644 --- a/sentry_sdk/integrations/grpc/consts.py +++ b/sentry_sdk/integrations/grpc/consts.py @@ -1 +1,3 @@ +from __future__ import annotations + SPAN_ORIGIN = "auto.grpc.grpc" diff --git 
a/sentry_sdk/integrations/grpc/server.py b/sentry_sdk/integrations/grpc/server.py index 582ef6e24a..2407bfecbe 100644 --- a/sentry_sdk/integrations/grpc/server.py +++ b/sentry_sdk/integrations/grpc/server.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.consts import OP from sentry_sdk.integrations import DidNotEnable @@ -18,20 +19,24 @@ class ServerInterceptor(grpc.ServerInterceptor): # type: ignore - def __init__(self, find_name=None): - # type: (ServerInterceptor, Optional[Callable[[ServicerContext], str]]) -> None + def __init__( + self: ServerInterceptor, + find_name: Optional[Callable[[ServicerContext], str]] = None, + ) -> None: self._find_method_name = find_name or ServerInterceptor._find_name super().__init__() - def intercept_service(self, continuation, handler_call_details): - # type: (ServerInterceptor, Callable[[HandlerCallDetails], RpcMethodHandler], HandlerCallDetails) -> RpcMethodHandler + def intercept_service( + self: ServerInterceptor, + continuation: Callable[[HandlerCallDetails], RpcMethodHandler], + handler_call_details: HandlerCallDetails, + ) -> RpcMethodHandler: handler = continuation(handler_call_details) if not handler or not handler.unary_unary: return handler - def behavior(request, context): - # type: (Message, ServicerContext) -> Message + def behavior(request: Message, context: ServicerContext) -> Message: with sentry_sdk.isolation_scope(): name = self._find_method_name(context) @@ -59,6 +64,5 @@ def behavior(request, context): ) @staticmethod - def _find_name(context): - # type: (ServicerContext) -> str + def _find_name(context: ServicerContext) -> str: return context._rpc_event.call_details.method.decode() diff --git a/sentry_sdk/integrations/httpx.py b/sentry_sdk/integrations/httpx.py index 61ce75734b..46e83d49a8 100644 --- a/sentry_sdk/integrations/httpx.py +++ b/sentry_sdk/integrations/httpx.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.consts import OP, SPANDATA, BAGGAGE_HEADER_NAME from sentry_sdk.integrations import Integration, DidNotEnable @@ -32,8 +33,7 @@ class HttpxIntegration(Integration): origin = f"auto.http.{identifier}" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: """ httpx has its own transport layer and can be customized when needed, so patch Client.send and AsyncClient.send to support both synchronous and async interfaces. 
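# A minimal usage sketch for the httpx integration patched in this file, assuming a
# placeholder DSN: once HttpxIntegration is active, outgoing Client.send /
# AsyncClient.send calls go through the wrappers installed below.
import sentry_sdk
from sentry_sdk.integrations.httpx import HttpxIntegration

sentry_sdk.init(
    dsn="https://examplePublicKey@o0.ingest.sentry.io/0",  # placeholder DSN
    integrations=[HttpxIntegration()],
)

import httpx
httpx.get("https://example.com/")  # this outgoing request should now be instrumented by the patched send()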
@@ -42,13 +42,11 @@ def setup_once(): _install_httpx_async_client() -def _install_httpx_client(): - # type: () -> None +def _install_httpx_client() -> None: real_send = Client.send @ensure_integration_enabled(HttpxIntegration, real_send) - def send(self, request, **kwargs): - # type: (Client, Request, **Any) -> Response + def send(self: Client, request: Request, **kwargs: Any) -> Response: parsed_url = None with capture_internal_exceptions(): parsed_url = parse_url(str(request.url), sanitize=False) @@ -112,12 +110,10 @@ def send(self, request, **kwargs): Client.send = send -def _install_httpx_async_client(): - # type: () -> None +def _install_httpx_async_client() -> None: real_send = AsyncClient.send - async def send(self, request, **kwargs): - # type: (AsyncClient, Request, **Any) -> Response + async def send(self: AsyncClient, request: Request, **kwargs: Any) -> Response: if sentry_sdk.get_client().get_integration(HttpxIntegration) is None: return await real_send(self, request, **kwargs) @@ -184,8 +180,9 @@ async def send(self, request, **kwargs): AsyncClient.send = send -def _add_sentry_baggage_to_headers(headers, sentry_baggage): - # type: (MutableMapping[str, str], str) -> None +def _add_sentry_baggage_to_headers( + headers: MutableMapping[str, str], sentry_baggage: str +) -> None: """Add the Sentry baggage to the headers. This function directly mutates the provided headers. The provided sentry_baggage diff --git a/sentry_sdk/integrations/huey.py b/sentry_sdk/integrations/huey.py index 1d1c498843..7bd766b413 100644 --- a/sentry_sdk/integrations/huey.py +++ b/sentry_sdk/integrations/huey.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys from datetime import datetime @@ -45,19 +46,16 @@ class HueyIntegration(Integration): origin = f"auto.queue.{identifier}" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: patch_enqueue() patch_execute() -def patch_enqueue(): - # type: () -> None +def patch_enqueue() -> None: old_enqueue = Huey.enqueue @ensure_integration_enabled(HueyIntegration, old_enqueue) - def _sentry_enqueue(self, task): - # type: (Huey, Task) -> Optional[Union[Result, ResultGroup]] + def _sentry_enqueue(self: Huey, task: Task) -> Optional[Union[Result, ResultGroup]]: with sentry_sdk.start_span( op=OP.QUEUE_SUBMIT_HUEY, name=task.name, @@ -77,10 +75,8 @@ def _sentry_enqueue(self, task): Huey.enqueue = _sentry_enqueue -def _make_event_processor(task): - # type: (Any) -> EventProcessor - def event_processor(event, hint): - # type: (Event, Hint) -> Optional[Event] +def _make_event_processor(task: Any) -> EventProcessor: + def event_processor(event: Event, hint: Hint) -> Optional[Event]: with capture_internal_exceptions(): tags = event.setdefault("tags", {}) @@ -107,8 +103,7 @@ def event_processor(event, hint): return event_processor -def _capture_exception(exc_info): - # type: (ExcInfo) -> None +def _capture_exception(exc_info: ExcInfo) -> None: scope = sentry_sdk.get_current_scope() if scope.root_span is not None: @@ -126,12 +121,10 @@ def _capture_exception(exc_info): scope.capture_event(event, hint=hint) -def _wrap_task_execute(func): - # type: (F) -> F +def _wrap_task_execute(func: F) -> F: @ensure_integration_enabled(HueyIntegration, func) - def _sentry_execute(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _sentry_execute(*args: Any, **kwargs: Any) -> Any: try: result = func(*args, **kwargs) except Exception: @@ -148,13 +141,13 @@ def _sentry_execute(*args, **kwargs): return _sentry_execute # type: ignore -def 
patch_execute(): - # type: () -> None +def patch_execute() -> None: old_execute = Huey._execute @ensure_integration_enabled(HueyIntegration, old_execute) - def _sentry_execute(self, task, timestamp=None): - # type: (Huey, Task, Optional[datetime]) -> Any + def _sentry_execute( + self: Huey, task: Task, timestamp: Optional[datetime] = None + ) -> Any: with sentry_sdk.isolation_scope() as scope: with capture_internal_exceptions(): scope._name = "huey" diff --git a/sentry_sdk/integrations/huggingface_hub.py b/sentry_sdk/integrations/huggingface_hub.py index d9a4c2563e..dc8cc68631 100644 --- a/sentry_sdk/integrations/huggingface_hub.py +++ b/sentry_sdk/integrations/huggingface_hub.py @@ -1,3 +1,4 @@ +from __future__ import annotations from functools import wraps from sentry_sdk import consts @@ -27,13 +28,11 @@ class HuggingfaceHubIntegration(Integration): identifier = "huggingface_hub" origin = f"auto.ai.{identifier}" - def __init__(self, include_prompts=True): - # type: (HuggingfaceHubIntegration, bool) -> None + def __init__(self: HuggingfaceHubIntegration, include_prompts: bool = True) -> None: self.include_prompts = include_prompts @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: huggingface_hub.inference._client.InferenceClient.text_generation = ( _wrap_text_generation( huggingface_hub.inference._client.InferenceClient.text_generation @@ -41,8 +40,7 @@ def setup_once(): ) -def _capture_exception(exc): - # type: (Any) -> None +def _capture_exception(exc: Any) -> None: event, hint = event_from_exception( exc, client_options=sentry_sdk.get_client().options, @@ -51,11 +49,9 @@ def _capture_exception(exc): sentry_sdk.capture_event(event, hint=hint) -def _wrap_text_generation(f): - # type: (Callable[..., Any]) -> Callable[..., Any] +def _wrap_text_generation(f: Callable[..., Any]) -> Callable[..., Any]: @wraps(f) - def new_text_generation(*args, **kwargs): - # type: (*Any, **Any) -> Any + def new_text_generation(*args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration) if integration is None: return f(*args, **kwargs) @@ -124,8 +120,7 @@ def new_text_generation(*args, **kwargs): if kwargs.get("details", False): # res is Iterable[TextGenerationStreamOutput] - def new_details_iterator(): - # type: () -> Iterable[ChatCompletionStreamOutput] + def new_details_iterator() -> Iterable[ChatCompletionStreamOutput]: with capture_internal_exceptions(): tokens_used = 0 data_buf: list[str] = [] @@ -153,8 +148,7 @@ def new_details_iterator(): else: # res is Iterable[str] - def new_iterator(): - # type: () -> Iterable[str] + def new_iterator() -> Iterable[str]: data_buf: list[str] = [] with capture_internal_exceptions(): for s in res: diff --git a/sentry_sdk/integrations/langchain.py b/sentry_sdk/integrations/langchain.py index 5f1524e445..fc7e54af8f 100644 --- a/sentry_sdk/integrations/langchain.py +++ b/sentry_sdk/integrations/langchain.py @@ -1,3 +1,4 @@ +from __future__ import annotations import itertools from collections import OrderedDict from functools import wraps @@ -60,37 +61,41 @@ class LangchainIntegration(Integration): max_spans = 1024 def __init__( - self, include_prompts=True, max_spans=1024, tiktoken_encoding_name=None - ): - # type: (LangchainIntegration, bool, int, Optional[str]) -> None + self: LangchainIntegration, + include_prompts: bool = True, + max_spans: int = 1024, + tiktoken_encoding_name: Optional[str] = None, + ) -> None: self.include_prompts = include_prompts self.max_spans = max_spans 
self.tiktoken_encoding_name = tiktoken_encoding_name @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: manager._configure = _wrap_configure(manager._configure) class WatchedSpan: - num_completion_tokens = 0 # type: int - num_prompt_tokens = 0 # type: int - no_collect_tokens = False # type: bool - children = [] # type: List[WatchedSpan] - is_pipeline = False # type: bool - - def __init__(self, span): - # type: (Span) -> None + num_completion_tokens: int = 0 + num_prompt_tokens: int = 0 + no_collect_tokens: bool = False + children: List[WatchedSpan] = [] + is_pipeline: bool = False + + def __init__(self, span: Span) -> None: self.span = span class SentryLangchainCallback(BaseCallbackHandler): # type: ignore[misc] """Base callback handler that can be used to handle callbacks from langchain.""" - def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=None): - # type: (int, bool, Optional[str]) -> None - self.span_map = OrderedDict() # type: OrderedDict[UUID, WatchedSpan] + def __init__( + self, + max_span_map_size: int, + include_prompts: bool, + tiktoken_encoding_name: Optional[str] = None, + ) -> None: + self.span_map: OrderedDict[UUID, WatchedSpan] = OrderedDict() self.max_span_map_size = max_span_map_size self.include_prompts = include_prompts @@ -100,21 +105,18 @@ def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=No self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name) - def count_tokens(self, s): - # type: (str) -> int + def count_tokens(self, s: str) -> int: if self.tiktoken_encoding is not None: return len(self.tiktoken_encoding.encode_ordinary(s)) return 0 - def gc_span_map(self): - # type: () -> None + def gc_span_map(self) -> None: while len(self.span_map) > self.max_span_map_size: run_id, watched_span = self.span_map.popitem(last=False) self._exit_span(watched_span, run_id) - def _handle_error(self, run_id, error): - # type: (UUID, Any) -> None + def _handle_error(self, run_id: UUID, error: Any) -> None: if not run_id or run_id not in self.span_map: return @@ -126,14 +128,17 @@ def _handle_error(self, run_id, error): span_data.span.finish() del self.span_map[run_id] - def _normalize_langchain_message(self, message): - # type: (BaseMessage) -> Any + def _normalize_langchain_message(self, message: BaseMessage) -> Any: parsed = {"content": message.content, "role": message.type} parsed.update(message.additional_kwargs) return parsed - def _create_span(self, run_id, parent_id, **kwargs): - # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan + def _create_span( + self: SentryLangchainCallback, + run_id: UUID, + parent_id: Optional[Any], + **kwargs: Any, + ) -> WatchedSpan: parent_watched_span = self.span_map.get(parent_id) if parent_id else None sentry_span = sentry_sdk.start_span( @@ -160,8 +165,9 @@ def _create_span(self, run_id, parent_id, **kwargs): self.gc_span_map() return watched_span - def _exit_span(self, span_data, run_id): - # type: (SentryLangchainCallback, WatchedSpan, UUID) -> None + def _exit_span( + self: SentryLangchainCallback, span_data: WatchedSpan, run_id: UUID + ) -> None: if span_data.is_pipeline: set_ai_pipeline_name(None) @@ -171,17 +177,16 @@ def _exit_span(self, span_data, run_id): del self.span_map[run_id] def on_llm_start( - self, - serialized, - prompts, + self: SentryLangchainCallback, + serialized: Dict[str, Any], + prompts: List[str], *, - run_id, - tags=None, - parent_run_id=None, - metadata=None, - **kwargs, - ): - # type: 
(SentryLangchainCallback, Dict[str, Any], List[str], UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Any) -> Any + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: """Run when LLM starts running.""" with capture_internal_exceptions(): if not run_id: @@ -202,8 +207,14 @@ def on_llm_start( if k in all_params: set_data_normalized(span, v, all_params[k]) - def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs): - # type: (SentryLangchainCallback, Dict[str, Any], List[List[BaseMessage]], UUID, Any) -> Any + def on_chat_model_start( + self: SentryLangchainCallback, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + **kwargs: Any, + ) -> Any: """Run when Chat Model starts running.""" with capture_internal_exceptions(): if not run_id: @@ -248,8 +259,9 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs): message.content ) + self.count_tokens(message.type) - def on_llm_new_token(self, token, *, run_id, **kwargs): - # type: (SentryLangchainCallback, str, UUID, Any) -> Any + def on_llm_new_token( + self: SentryLangchainCallback, token: str, *, run_id: UUID, **kwargs: Any + ) -> Any: """Run on new LLM token. Only available when streaming is enabled.""" with capture_internal_exceptions(): if not run_id or run_id not in self.span_map: @@ -259,8 +271,13 @@ def on_llm_new_token(self, token, *, run_id, **kwargs): return span_data.num_completion_tokens += self.count_tokens(token) - def on_llm_end(self, response, *, run_id, **kwargs): - # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any + def on_llm_end( + self: SentryLangchainCallback, + response: LLMResult, + *, + run_id: UUID, + **kwargs: Any, + ) -> Any: """Run when LLM ends running.""" with capture_internal_exceptions(): if not run_id: @@ -298,14 +315,25 @@ def on_llm_end(self, response, *, run_id, **kwargs): self._exit_span(span_data, run_id) - def on_llm_error(self, error, *, run_id, **kwargs): - # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any + def on_llm_error( + self: SentryLangchainCallback, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + **kwargs: Any, + ) -> Any: """Run when LLM errors.""" with capture_internal_exceptions(): self._handle_error(run_id, error) - def on_chain_start(self, serialized, inputs, *, run_id, **kwargs): - # type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Any) -> Any + def on_chain_start( + self: SentryLangchainCallback, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + **kwargs: Any, + ) -> Any: """Run when chain starts running.""" with capture_internal_exceptions(): if not run_id: @@ -325,8 +353,13 @@ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs): if metadata: set_data_normalized(watched_span.span, SPANDATA.AI_METADATA, metadata) - def on_chain_end(self, outputs, *, run_id, **kwargs): - # type: (SentryLangchainCallback, Dict[str, Any], UUID, Any) -> Any + def on_chain_end( + self: SentryLangchainCallback, + outputs: Dict[str, Any], + *, + run_id: UUID, + **kwargs: Any, + ) -> Any: """Run when chain ends running.""" with capture_internal_exceptions(): if not run_id or run_id not in self.span_map: @@ -337,13 +370,23 @@ def on_chain_end(self, outputs, *, run_id, **kwargs): return self._exit_span(span_data, run_id) - def on_chain_error(self, error, *, run_id, **kwargs): - 
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any + def on_chain_error( + self: SentryLangchainCallback, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + **kwargs: Any, + ) -> Any: """Run when chain errors.""" self._handle_error(run_id, error) - def on_agent_action(self, action, *, run_id, **kwargs): - # type: (SentryLangchainCallback, AgentAction, UUID, Any) -> Any + def on_agent_action( + self: SentryLangchainCallback, + action: AgentAction, + *, + run_id: UUID, + **kwargs: Any, + ) -> Any: with capture_internal_exceptions(): if not run_id: return @@ -359,8 +402,13 @@ def on_agent_action(self, action, *, run_id, **kwargs): watched_span.span, SPANDATA.AI_INPUT_MESSAGES, action.tool_input ) - def on_agent_finish(self, finish, *, run_id, **kwargs): - # type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any + def on_agent_finish( + self: SentryLangchainCallback, + finish: AgentFinish, + *, + run_id: UUID, + **kwargs: Any, + ) -> Any: with capture_internal_exceptions(): if not run_id: return @@ -374,8 +422,14 @@ def on_agent_finish(self, finish, *, run_id, **kwargs): ) self._exit_span(span_data, run_id) - def on_tool_start(self, serialized, input_str, *, run_id, **kwargs): - # type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Any) -> Any + def on_tool_start( + self: SentryLangchainCallback, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + **kwargs: Any, + ) -> Any: """Run when tool starts running.""" with capture_internal_exceptions(): if not run_id: @@ -398,8 +452,9 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs): watched_span.span, SPANDATA.AI_METADATA, kwargs.get("metadata") ) - def on_tool_end(self, output, *, run_id, **kwargs): - # type: (SentryLangchainCallback, str, UUID, Any) -> Any + def on_tool_end( + self: SentryLangchainCallback, output: str, *, run_id: UUID, **kwargs: Any + ) -> Any: """Run when tool ends running.""" with capture_internal_exceptions(): if not run_id or run_id not in self.span_map: @@ -412,24 +467,27 @@ def on_tool_end(self, output, *, run_id, **kwargs): set_data_normalized(span_data.span, SPANDATA.AI_RESPONSES, output) self._exit_span(span_data, run_id) - def on_tool_error(self, error, *args, run_id, **kwargs): - # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any + def on_tool_error( + self: SentryLangchainCallback, + error: Union[Exception, KeyboardInterrupt], + *args: Any, + run_id: UUID, + **kwargs: Any, + ) -> Any: """Run when tool errors.""" self._handle_error(run_id, error) -def _wrap_configure(f): - # type: (Callable[..., Any]) -> Callable[..., Any] +def _wrap_configure(f: Callable[..., Any]) -> Callable[..., Any]: @wraps(f) def new_configure( - callback_manager_cls, # type: type - inheritable_callbacks=None, # type: Callbacks - local_callbacks=None, # type: Callbacks - *args, # type: Any - **kwargs, # type: Any - ): - # type: (...) 
-> Any + callback_manager_cls: type, + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, + *args: Any, + **kwargs: Any, + ) -> Any: integration = sentry_sdk.get_client().get_integration(LangchainIntegration) if integration is None: diff --git a/sentry_sdk/integrations/launchdarkly.py b/sentry_sdk/integrations/launchdarkly.py index d3c423e7be..18081e617a 100644 --- a/sentry_sdk/integrations/launchdarkly.py +++ b/sentry_sdk/integrations/launchdarkly.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import TYPE_CHECKING from sentry_sdk.feature_flags import add_feature_flag @@ -20,8 +21,7 @@ class LaunchDarklyIntegration(Integration): identifier = "launchdarkly" - def __init__(self, ld_client=None): - # type: (LDClient | None) -> None + def __init__(self, ld_client: LDClient | None = None) -> None: """ :param client: An initialized LDClient instance. If a client is not provided, this integration will attempt to use the shared global instance. @@ -38,25 +38,28 @@ def __init__(self, ld_client=None): client.add_hook(LaunchDarklyHook()) @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: pass class LaunchDarklyHook(Hook): @property - def metadata(self): - # type: () -> Metadata + def metadata(self) -> Metadata: return Metadata(name="sentry-flag-auditor") - def after_evaluation(self, series_context, data, detail): - # type: (EvaluationSeriesContext, dict[Any, Any], EvaluationDetail) -> dict[Any, Any] + def after_evaluation( + self, + series_context: EvaluationSeriesContext, + data: dict[Any, Any], + detail: EvaluationDetail, + ) -> dict[Any, Any]: if isinstance(detail.value, bool): add_feature_flag(series_context.key, detail.value) return data - def before_evaluation(self, series_context, data): - # type: (EvaluationSeriesContext, dict[Any, Any]) -> dict[Any, Any] + def before_evaluation( + self, series_context: EvaluationSeriesContext, data: dict[Any, Any] + ) -> dict[Any, Any]: return data # No-op. diff --git a/sentry_sdk/integrations/litestar.py b/sentry_sdk/integrations/litestar.py index 267a1c89af..7b4740e74b 100644 --- a/sentry_sdk/integrations/litestar.py +++ b/sentry_sdk/integrations/litestar.py @@ -1,3 +1,4 @@ +from __future__ import annotations from collections.abc import Set import sentry_sdk from sentry_sdk.consts import OP, TransactionSource, SOURCE_FOR_STYLE @@ -52,13 +53,12 @@ class LitestarIntegration(Integration): def __init__( self, - failed_request_status_codes=_DEFAULT_FAILED_REQUEST_STATUS_CODES, # type: Set[int] + failed_request_status_codes: Set[int] = _DEFAULT_FAILED_REQUEST_STATUS_CODES, ) -> None: self.failed_request_status_codes = failed_request_status_codes @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: patch_app_init() patch_middlewares() patch_http_route_handle() @@ -75,8 +75,9 @@ def setup_once(): class SentryLitestarASGIMiddleware(SentryAsgiMiddleware): - def __init__(self, app, span_origin=LitestarIntegration.origin): - # type: (ASGIApp, str) -> None + def __init__( + self, app: ASGIApp, span_origin: str = LitestarIntegration.origin + ) -> None: super().__init__( app=app, @@ -86,8 +87,7 @@ def __init__(self, app, span_origin=LitestarIntegration.origin): span_origin=span_origin, ) - def _capture_request_exception(self, exc): - # type: (Exception) -> None + def _capture_request_exception(self, exc: Exception) -> None: """Avoid catching exceptions from request handlers. Those exceptions are already handled in Litestar.after_exception handler. 
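# A minimal setup sketch for the Litestar integration in this file, assuming a
# placeholder DSN; failed_request_status_codes is the constructor parameter shown
# above, and {500, 503} is only an example value.
import sentry_sdk
from sentry_sdk.integrations.litestar import LitestarIntegration

sentry_sdk.init(
    dsn="https://examplePublicKey@o0.ingest.sentry.io/0",  # placeholder DSN
    integrations=[LitestarIntegration(failed_request_status_codes={500, 503})],
)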
@@ -96,8 +96,7 @@ def _capture_request_exception(self, exc): pass -def patch_app_init(): - # type: () -> None +def patch_app_init() -> None: """ Replaces the Litestar class's `__init__` function in order to inject `after_exception` handlers and set the `SentryLitestarASGIMiddleware` as the outmost middleware in the stack. @@ -108,8 +107,7 @@ def patch_app_init(): old__init__ = Litestar.__init__ @ensure_integration_enabled(LitestarIntegration, old__init__) - def injection_wrapper(self, *args, **kwargs): - # type: (Litestar, *Any, **Any) -> None + def injection_wrapper(self: Litestar, *args: Any, **kwargs: Any) -> None: kwargs["after_exception"] = [ exception_handler, *(kwargs.get("after_exception") or []), @@ -123,13 +121,11 @@ def injection_wrapper(self, *args, **kwargs): Litestar.__init__ = injection_wrapper -def patch_middlewares(): - # type: () -> None +def patch_middlewares() -> None: old_resolve_middleware_stack = BaseRouteHandler.resolve_middleware @ensure_integration_enabled(LitestarIntegration, old_resolve_middleware_stack) - def resolve_middleware_wrapper(self): - # type: (BaseRouteHandler) -> list[Middleware] + def resolve_middleware_wrapper(self: BaseRouteHandler) -> list[Middleware]: return [ enable_span_for_middleware(middleware) for middleware in old_resolve_middleware_stack(self) @@ -138,8 +134,7 @@ def resolve_middleware_wrapper(self): BaseRouteHandler.resolve_middleware = resolve_middleware_wrapper -def enable_span_for_middleware(middleware): - # type: (Middleware) -> Middleware +def enable_span_for_middleware(middleware: Middleware) -> Middleware: if ( not hasattr(middleware, "__call__") # noqa: B004 or middleware is SentryLitestarASGIMiddleware @@ -147,12 +142,13 @@ def enable_span_for_middleware(middleware): return middleware if isinstance(middleware, DefineMiddleware): - old_call = middleware.middleware.__call__ # type: ASGIApp + old_call: ASGIApp = middleware.middleware.__call__ else: old_call = middleware.__call__ - async def _create_span_call(self, scope, receive, send): - # type: (MiddlewareProtocol, LitestarScope, Receive, Send) -> None + async def _create_span_call( + self: MiddlewareProtocol, scope: LitestarScope, receive: Receive, send: Send + ) -> None: if sentry_sdk.get_client().get_integration(LitestarIntegration) is None: return await old_call(self, scope, receive, send) @@ -166,8 +162,9 @@ async def _create_span_call(self, scope, receive, send): middleware_span.set_tag("litestar.middleware_name", middleware_name) # Creating spans for the "receive" callback - async def _sentry_receive(*args, **kwargs): - # type: (*Any, **Any) -> Union[HTTPReceiveMessage, WebSocketReceiveMessage] + async def _sentry_receive( + *args: Any, **kwargs: Any + ) -> Union[HTTPReceiveMessage, WebSocketReceiveMessage]: if sentry_sdk.get_client().get_integration(LitestarIntegration) is None: return await receive(*args, **kwargs) with sentry_sdk.start_span( @@ -184,8 +181,7 @@ async def _sentry_receive(*args, **kwargs): new_receive = _sentry_receive if not receive_patched else receive # Creating spans for the "send" callback - async def _sentry_send(message): - # type: (Message) -> None + async def _sentry_send(message: Message) -> None: if sentry_sdk.get_client().get_integration(LitestarIntegration) is None: return await send(message) with sentry_sdk.start_span( @@ -214,19 +210,19 @@ async def _sentry_send(message): return middleware -def patch_http_route_handle(): - # type: () -> None +def patch_http_route_handle() -> None: old_handle = HTTPRoute.handle - async def 
handle_wrapper(self, scope, receive, send): - # type: (HTTPRoute, HTTPScope, Receive, Send) -> None + async def handle_wrapper( + self: HTTPRoute, scope: HTTPScope, receive: Receive, send: Send + ) -> None: if sentry_sdk.get_client().get_integration(LitestarIntegration) is None: return await old_handle(self, scope, receive, send) sentry_scope = sentry_sdk.get_isolation_scope() - request = scope["app"].request_class( + request: Request[Any, Any] = scope["app"].request_class( scope=scope, receive=receive, send=send - ) # type: Request[Any, Any] + ) extracted_request_data = ConnectionDataExtractor( parse_body=True, parse_query=True )(request) @@ -234,8 +230,7 @@ async def handle_wrapper(self, scope, receive, send): request_data = await body - def event_processor(event, _): - # type: (Event, Hint) -> Event + def event_processor(event: Event, _: Hint) -> Event: route_handler = scope.get("route_handler") request_info = event.get("request", {}) @@ -279,8 +274,7 @@ def event_processor(event, _): HTTPRoute.handle = handle_wrapper -def retrieve_user_from_scope(scope): - # type: (LitestarScope) -> Optional[dict[str, Any]] +def retrieve_user_from_scope(scope: LitestarScope) -> Optional[dict[str, Any]]: scope_user = scope.get("user") if isinstance(scope_user, dict): return scope_user @@ -291,9 +285,8 @@ def retrieve_user_from_scope(scope): @ensure_integration_enabled(LitestarIntegration) -def exception_handler(exc, scope): - # type: (Exception, LitestarScope) -> None - user_info = None # type: Optional[dict[str, Any]] +def exception_handler(exc: Exception, scope: LitestarScope) -> None: + user_info: Optional[dict[str, Any]] = None if should_send_default_pii(): user_info = retrieve_user_from_scope(scope) if user_info and isinstance(user_info, dict): diff --git a/sentry_sdk/integrations/logging.py b/sentry_sdk/integrations/logging.py index f807a62966..9b363ba510 100644 --- a/sentry_sdk/integrations/logging.py +++ b/sentry_sdk/integrations/logging.py @@ -1,3 +1,4 @@ +from __future__ import annotations import logging import sys from datetime import datetime, timezone @@ -64,9 +65,8 @@ def ignore_logger( - name, # type: str -): - # type: (...) -> None + name: str, +) -> None: """This disables recording (both in breadcrumbs and as events) calls to a logger of a specific name. Among other uses, many of our integrations use this to prevent their actions being recorded as breadcrumbs. 
Exposed @@ -82,11 +82,10 @@ class LoggingIntegration(Integration): def __init__( self, - level=DEFAULT_LEVEL, - event_level=DEFAULT_EVENT_LEVEL, - sentry_logs_level=DEFAULT_LEVEL, - ): - # type: (Optional[int], Optional[int], Optional[int]) -> None + level: Optional[int] = DEFAULT_LEVEL, + event_level: Optional[int] = DEFAULT_EVENT_LEVEL, + sentry_logs_level: Optional[int] = DEFAULT_LEVEL, + ) -> None: self._handler = None self._breadcrumb_handler = None self._sentry_logs_handler = None @@ -100,8 +99,7 @@ def __init__( if event_level is not None: self._handler = EventHandler(level=event_level) - def _handle_record(self, record): - # type: (LogRecord) -> None + def _handle_record(self, record: LogRecord) -> None: if self._handler is not None and record.levelno >= self._handler.level: self._handler.handle(record) @@ -118,12 +116,10 @@ def _handle_record(self, record): self._sentry_logs_handler.handle(record) @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: old_callhandlers = logging.Logger.callHandlers - def sentry_patched_callhandlers(self, record): - # type: (Any, LogRecord) -> Any + def sentry_patched_callhandlers(self: Any, record: LogRecord) -> Any: # keeping a local reference because the # global might be discarded on shutdown ignored_loggers = _IGNORED_LOGGERS @@ -179,22 +175,19 @@ class _BaseHandler(logging.Handler): ) ) - def _can_record(self, record): - # type: (LogRecord) -> bool + def _can_record(self, record: LogRecord) -> bool: """Prevents ignored loggers from recording""" for logger in _IGNORED_LOGGERS: if fnmatch(record.name.strip(), logger): return False return True - def _logging_to_event_level(self, record): - # type: (LogRecord) -> str + def _logging_to_event_level(self, record: LogRecord) -> str: return LOGGING_TO_EVENT_LEVEL.get( record.levelno, record.levelname.lower() if record.levelname else "" ) - def _extra_from_record(self, record): - # type: (LogRecord) -> MutableMapping[str, object] + def _extra_from_record(self, record: LogRecord) -> MutableMapping[str, object]: return { k: v for k, v in vars(record).items() @@ -210,14 +203,12 @@ class EventHandler(_BaseHandler): Note that you do not have to use this class if the logging integration is enabled, which it is by default. """ - def emit(self, record): - # type: (LogRecord) -> Any + def emit(self, record: LogRecord) -> Any: with capture_internal_exceptions(): self.format(record) return self._emit(record) - def _emit(self, record): - # type: (LogRecord) -> None + def _emit(self, record: LogRecord) -> None: if not self._can_record(record): return @@ -304,14 +295,12 @@ class BreadcrumbHandler(_BaseHandler): Note that you do not have to use this class if the logging integration is enabled, which it is by default. 
""" - def emit(self, record): - # type: (LogRecord) -> Any + def emit(self, record: LogRecord) -> Any: with capture_internal_exceptions(): self.format(record) return self._emit(record) - def _emit(self, record): - # type: (LogRecord) -> None + def _emit(self, record: LogRecord) -> None: if not self._can_record(record): return @@ -319,8 +308,7 @@ def _emit(self, record): self._breadcrumb_from_record(record), hint={"log_record": record} ) - def _breadcrumb_from_record(self, record): - # type: (LogRecord) -> Dict[str, Any] + def _breadcrumb_from_record(self, record: LogRecord) -> Dict[str, Any]: return { "type": "log", "level": self._logging_to_event_level(record), @@ -338,8 +326,7 @@ class SentryLogsHandler(_BaseHandler): Note that you do not have to use this class if the logging integration is enabled, which it is by default. """ - def emit(self, record): - # type: (LogRecord) -> Any + def emit(self, record: LogRecord) -> Any: with capture_internal_exceptions(): self.format(record) if not self._can_record(record): @@ -354,13 +341,12 @@ def emit(self, record): self._capture_log_from_record(client, record) - def _capture_log_from_record(self, client, record): - # type: (BaseClient, LogRecord) -> None + def _capture_log_from_record(self, client: BaseClient, record: LogRecord) -> None: otel_severity_number, otel_severity_text = _log_level_to_otel( record.levelno, SEVERITY_TO_OTEL_SEVERITY ) project_root = client.options["project_root"] - attrs = self._extra_from_record(record) # type: Any + attrs: Any = self._extra_from_record(record) attrs["sentry.origin"] = "auto.logger.log" if isinstance(record.msg, str): attrs["sentry.message.template"] = record.msg diff --git a/sentry_sdk/integrations/loguru.py b/sentry_sdk/integrations/loguru.py index df3ecf161a..6bcec10d8e 100644 --- a/sentry_sdk/integrations/loguru.py +++ b/sentry_sdk/integrations/loguru.py @@ -1,3 +1,4 @@ +from __future__ import annotations import enum import sentry_sdk @@ -65,21 +66,20 @@ class LoggingLevels(enum.IntEnum): class LoguruIntegration(Integration): identifier = "loguru" - level = DEFAULT_LEVEL # type: Optional[int] - event_level = DEFAULT_EVENT_LEVEL # type: Optional[int] + level: Optional[int] = DEFAULT_LEVEL + event_level: Optional[int] = DEFAULT_EVENT_LEVEL breadcrumb_format = DEFAULT_FORMAT event_format = DEFAULT_FORMAT - sentry_logs_level = DEFAULT_LEVEL # type: Optional[int] + sentry_logs_level: Optional[int] = DEFAULT_LEVEL def __init__( self, - level=DEFAULT_LEVEL, - event_level=DEFAULT_EVENT_LEVEL, - breadcrumb_format=DEFAULT_FORMAT, - event_format=DEFAULT_FORMAT, - sentry_logs_level=DEFAULT_LEVEL, - ): - # type: (Optional[int], Optional[int], str | loguru.FormatFunction, str | loguru.FormatFunction, Optional[int]) -> None + level: Optional[int] = DEFAULT_LEVEL, + event_level: Optional[int] = DEFAULT_EVENT_LEVEL, + breadcrumb_format: str | loguru.FormatFunction = DEFAULT_FORMAT, + event_format: str | loguru.FormatFunction = DEFAULT_FORMAT, + sentry_logs_level: Optional[int] = DEFAULT_LEVEL, + ) -> None: LoguruIntegration.level = level LoguruIntegration.event_level = event_level LoguruIntegration.breadcrumb_format = breadcrumb_format @@ -87,8 +87,7 @@ def __init__( LoguruIntegration.sentry_logs_level = sentry_logs_level @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: if LoguruIntegration.level is not None: logger.add( LoguruBreadcrumbHandler(level=LoguruIntegration.level), @@ -111,8 +110,7 @@ def setup_once(): class _LoguruBaseHandler(_BaseHandler): - def __init__(self, *args, 
**kwargs): - # type: (*Any, **Any) -> None + def __init__(self, *args: Any, **kwargs: Any) -> None: if kwargs.get("level"): kwargs["level"] = SENTRY_LEVEL_FROM_LOGURU_LEVEL.get( kwargs.get("level", ""), DEFAULT_LEVEL @@ -120,8 +118,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def _logging_to_event_level(self, record): - # type: (LogRecord) -> str + def _logging_to_event_level(self, record: LogRecord) -> str: try: return SENTRY_LEVEL_FROM_LOGURU_LEVEL[ LoggingLevels(record.levelno).name @@ -142,8 +139,7 @@ class LoguruBreadcrumbHandler(_LoguruBaseHandler, BreadcrumbHandler): pass -def loguru_sentry_logs_handler(message): - # type: (Message) -> None +def loguru_sentry_logs_handler(message: Message) -> None: # This is intentionally a callable sink instead of a standard logging handler # since otherwise we wouldn't get direct access to message.record client = sentry_sdk.get_client() @@ -166,7 +162,7 @@ def loguru_sentry_logs_handler(message): record["level"].no, SEVERITY_TO_OTEL_SEVERITY ) - attrs = {"sentry.origin": "auto.logger.loguru"} # type: dict[str, Any] + attrs: dict[str, Any] = {"sentry.origin": "auto.logger.loguru"} project_root = client.options["project_root"] if record.get("file"): diff --git a/sentry_sdk/integrations/modules.py b/sentry_sdk/integrations/modules.py index ce3ee78665..a289ce1989 100644 --- a/sentry_sdk/integrations/modules.py +++ b/sentry_sdk/integrations/modules.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.integrations import Integration from sentry_sdk.scope import add_global_event_processor @@ -14,11 +15,9 @@ class ModulesIntegration(Integration): identifier = "modules" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: @add_global_event_processor - def processor(event, hint): - # type: (Event, Any) -> Event + def processor(event: Event, hint: Any) -> Event: if event.get("type") == "transaction": return event diff --git a/sentry_sdk/integrations/openai.py b/sentry_sdk/integrations/openai.py index a4467c9782..dc9c1c1100 100644 --- a/sentry_sdk/integrations/openai.py +++ b/sentry_sdk/integrations/openai.py @@ -1,3 +1,4 @@ +from __future__ import annotations from functools import wraps import sentry_sdk @@ -32,8 +33,11 @@ class OpenAIIntegration(Integration): identifier = "openai" origin = f"auto.ai.{identifier}" - def __init__(self, include_prompts=True, tiktoken_encoding_name=None): - # type: (OpenAIIntegration, bool, Optional[str]) -> None + def __init__( + self: OpenAIIntegration, + include_prompts: bool = True, + tiktoken_encoding_name: Optional[str] = None, + ) -> None: self.include_prompts = include_prompts self.tiktoken_encoding = None @@ -43,8 +47,7 @@ def __init__(self, include_prompts=True, tiktoken_encoding_name=None): self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name) @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: Completions.create = _wrap_chat_completion_create(Completions.create) Embeddings.create = _wrap_embeddings_create(Embeddings.create) @@ -53,15 +56,13 @@ def setup_once(): ) AsyncEmbeddings.create = _wrap_async_embeddings_create(AsyncEmbeddings.create) - def count_tokens(self, s): - # type: (OpenAIIntegration, str) -> int + def count_tokens(self: OpenAIIntegration, s: str) -> int: if self.tiktoken_encoding is not None: return len(self.tiktoken_encoding.encode_ordinary(s)) return 0 -def _capture_exception(exc): - # type: (Any) -> None +def _capture_exception(exc: Any) -> None: event, 
hint = event_from_exception( exc, client_options=sentry_sdk.get_client().options, @@ -71,12 +72,15 @@ def _capture_exception(exc): def _calculate_chat_completion_usage( - messages, response, span, streaming_message_responses, count_tokens -): - # type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]], Callable[..., Any]) -> None - completion_tokens = 0 # type: Optional[int] - prompt_tokens = 0 # type: Optional[int] - total_tokens = 0 # type: Optional[int] + messages: Iterable[ChatCompletionMessageParam], + response: Any, + span: Span, + streaming_message_responses: Optional[List[str]], + count_tokens: Callable[..., Any], +) -> None: + completion_tokens: Optional[int] = 0 + prompt_tokens: Optional[int] = 0 + total_tokens: Optional[int] = 0 if hasattr(response, "usage"): if hasattr(response.usage, "completion_tokens") and isinstance( response.usage.completion_tokens, int @@ -114,8 +118,7 @@ def _calculate_chat_completion_usage( record_token_usage(span, prompt_tokens, completion_tokens, total_tokens) -def _new_chat_completion_common(f, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any +def _new_chat_completion_common(f: Any, *args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(OpenAIIntegration) if integration is None: return f(*args, **kwargs) @@ -168,8 +171,7 @@ def _new_chat_completion_common(f, *args, **kwargs): old_iterator = res._iterator - def new_iterator(): - # type: () -> Iterator[ChatCompletionChunk] + def new_iterator() -> Iterator[ChatCompletionChunk]: with capture_internal_exceptions(): for x in old_iterator: if hasattr(x, "choices"): @@ -201,8 +203,7 @@ def new_iterator(): ) span.__exit__(None, None, None) - async def new_iterator_async(): - # type: () -> AsyncIterator[ChatCompletionChunk] + async def new_iterator_async() -> AsyncIterator[ChatCompletionChunk]: with capture_internal_exceptions(): async for x in old_iterator: if hasattr(x, "choices"): @@ -245,10 +246,8 @@ async def new_iterator_async(): return res -def _wrap_chat_completion_create(f): - # type: (Callable[..., Any]) -> Callable[..., Any] - def _execute_sync(f, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any +def _wrap_chat_completion_create(f: Callable[..., Any]) -> Callable[..., Any]: + def _execute_sync(f: Any, *args: Any, **kwargs: Any) -> Any: gen = _new_chat_completion_common(f, *args, **kwargs) try: @@ -268,8 +267,7 @@ def _execute_sync(f, *args, **kwargs): return e.value @wraps(f) - def _sentry_patched_create_sync(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _sentry_patched_create_sync(*args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(OpenAIIntegration) if integration is None or "messages" not in kwargs: # no "messages" means invalid call (in all versions of openai), let it return error @@ -280,10 +278,8 @@ def _sentry_patched_create_sync(*args, **kwargs): return _sentry_patched_create_sync -def _wrap_async_chat_completion_create(f): - # type: (Callable[..., Any]) -> Callable[..., Any] - async def _execute_async(f, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any +def _wrap_async_chat_completion_create(f: Callable[..., Any]) -> Callable[..., Any]: + async def _execute_async(f: Any, *args: Any, **kwargs: Any) -> Any: gen = _new_chat_completion_common(f, *args, **kwargs) try: @@ -303,8 +299,7 @@ async def _execute_async(f, *args, **kwargs): return e.value @wraps(f) - async def _sentry_patched_create_async(*args, **kwargs): - # type: (*Any, **Any) -> Any + async def 
_sentry_patched_create_async(*args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(OpenAIIntegration) if integration is None or "messages" not in kwargs: # no "messages" means invalid call (in all versions of openai), let it return error @@ -315,8 +310,7 @@ async def _sentry_patched_create_async(*args, **kwargs): return _sentry_patched_create_async -def _new_embeddings_create_common(f, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any +def _new_embeddings_create_common(f: Any, *args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(OpenAIIntegration) if integration is None: return f(*args, **kwargs) @@ -363,10 +357,8 @@ def _new_embeddings_create_common(f, *args, **kwargs): return response -def _wrap_embeddings_create(f): - # type: (Any) -> Any - def _execute_sync(f, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any +def _wrap_embeddings_create(f: Any) -> Any: + def _execute_sync(f: Any, *args: Any, **kwargs: Any) -> Any: gen = _new_embeddings_create_common(f, *args, **kwargs) try: @@ -386,8 +378,7 @@ def _execute_sync(f, *args, **kwargs): return e.value @wraps(f) - def _sentry_patched_create_sync(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _sentry_patched_create_sync(*args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(OpenAIIntegration) if integration is None: return f(*args, **kwargs) @@ -397,10 +388,8 @@ def _sentry_patched_create_sync(*args, **kwargs): return _sentry_patched_create_sync -def _wrap_async_embeddings_create(f): - # type: (Any) -> Any - async def _execute_async(f, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any +def _wrap_async_embeddings_create(f: Any) -> Any: + async def _execute_async(f: Any, *args: Any, **kwargs: Any) -> Any: gen = _new_embeddings_create_common(f, *args, **kwargs) try: @@ -420,8 +409,7 @@ async def _execute_async(f, *args, **kwargs): return e.value @wraps(f) - async def _sentry_patched_create_async(*args, **kwargs): - # type: (*Any, **Any) -> Any + async def _sentry_patched_create_async(*args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration(OpenAIIntegration) if integration is None: return await f(*args, **kwargs) diff --git a/sentry_sdk/integrations/openfeature.py b/sentry_sdk/integrations/openfeature.py index e2b33d83f2..613d7b5bd1 100644 --- a/sentry_sdk/integrations/openfeature.py +++ b/sentry_sdk/integrations/openfeature.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import TYPE_CHECKING from sentry_sdk.feature_flags import add_feature_flag @@ -18,20 +19,24 @@ class OpenFeatureIntegration(Integration): identifier = "openfeature" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: # Register the hook within the global openfeature hooks list. 
api.add_hooks(hooks=[OpenFeatureHook()]) class OpenFeatureHook(Hook): - def after(self, hook_context, details, hints): - # type: (HookContext, FlagEvaluationDetails[bool], HookHints) -> None + def after( + self, + hook_context: HookContext, + details: FlagEvaluationDetails[bool], + hints: HookHints, + ) -> None: if isinstance(details.value, bool): add_feature_flag(details.flag_key, details.value) - def error(self, hook_context, exception, hints): - # type: (HookContext, Exception, HookHints) -> None + def error( + self, hook_context: HookContext, exception: Exception, hints: HookHints + ) -> None: if isinstance(hook_context.default_value, bool): add_feature_flag(hook_context.flag_key, hook_context.default_value) diff --git a/sentry_sdk/integrations/pure_eval.py b/sentry_sdk/integrations/pure_eval.py index c1c3d63871..74cfa5a7c6 100644 --- a/sentry_sdk/integrations/pure_eval.py +++ b/sentry_sdk/integrations/pure_eval.py @@ -1,3 +1,4 @@ +from __future__ import annotations import ast import sentry_sdk @@ -35,12 +36,10 @@ class PureEvalIntegration(Integration): identifier = "pure_eval" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: @add_global_event_processor - def add_executing_info(event, hint): - # type: (Event, Optional[Hint]) -> Optional[Event] + def add_executing_info(event: Event, hint: Optional[Hint]) -> Optional[Event]: if sentry_sdk.get_client().get_integration(PureEvalIntegration) is None: return event @@ -81,8 +80,7 @@ def add_executing_info(event, hint): return event -def pure_eval_frame(frame): - # type: (FrameType) -> Dict[str, Any] +def pure_eval_frame(frame: FrameType) -> Dict[str, Any]: source = executing.Source.for_frame(frame) if not source.tree: return {} @@ -103,16 +101,14 @@ def pure_eval_frame(frame): evaluator = pure_eval.Evaluator.from_frame(frame) expressions = evaluator.interesting_expressions_grouped(scope) - def closeness(expression): - # type: (Tuple[List[Any], Any]) -> Tuple[int, int] + def closeness(expression: Tuple[List[Any], Any]) -> Tuple[int, int]: # Prioritise expressions with a node closer to the statement executed # without being after that statement # A higher return value is better - the expression will appear # earlier in the list of values and is less likely to be trimmed nodes, _value = expression - def start(n): - # type: (ast.expr) -> Tuple[int, int] + def start(n: ast.expr) -> Tuple[int, int]: return (n.lineno, n.col_offset) nodes_before_stmt = [ diff --git a/sentry_sdk/integrations/pymongo.py b/sentry_sdk/integrations/pymongo.py index 32cb294075..d7a29b3e8b 100644 --- a/sentry_sdk/integrations/pymongo.py +++ b/sentry_sdk/integrations/pymongo.py @@ -1,3 +1,4 @@ +from __future__ import annotations import copy import sentry_sdk @@ -41,8 +42,7 @@ ] -def _strip_pii(command): - # type: (Dict[str, Any]) -> Dict[str, Any] +def _strip_pii(command: Dict[str, Any]) -> Dict[str, Any]: for key in command: is_safe_field = key in SAFE_COMMAND_ATTRIBUTES if is_safe_field: @@ -84,8 +84,7 @@ def _strip_pii(command): return command -def _get_db_data(event): - # type: (Any) -> Dict[str, Any] +def _get_db_data(event: Any) -> Dict[str, Any]: data = {} data[SPANDATA.DB_SYSTEM] = "mongodb" @@ -106,16 +105,16 @@ def _get_db_data(event): class CommandTracer(monitoring.CommandListener): - def __init__(self): - # type: () -> None - self._ongoing_operations = {} # type: Dict[int, Span] + def __init__(self) -> None: + self._ongoing_operations: Dict[int, Span] = {} - def _operation_key(self, event): - # type: (Union[CommandFailedEvent, 
CommandStartedEvent, CommandSucceededEvent]) -> int + def _operation_key( + self, + event: Union[CommandFailedEvent, CommandStartedEvent, CommandSucceededEvent], + ) -> int: return event.request_id - def started(self, event): - # type: (CommandStartedEvent) -> None + def started(self, event: CommandStartedEvent) -> None: if sentry_sdk.get_client().get_integration(PyMongoIntegration) is None: return @@ -172,8 +171,7 @@ def started(self, event): self._ongoing_operations[self._operation_key(event)] = span.__enter__() - def failed(self, event): - # type: (CommandFailedEvent) -> None + def failed(self, event: CommandFailedEvent) -> None: if sentry_sdk.get_client().get_integration(PyMongoIntegration) is None: return @@ -184,8 +182,7 @@ def failed(self, event): except KeyError: return - def succeeded(self, event): - # type: (CommandSucceededEvent) -> None + def succeeded(self, event: CommandSucceededEvent) -> None: if sentry_sdk.get_client().get_integration(PyMongoIntegration) is None: return @@ -202,6 +199,5 @@ class PyMongoIntegration(Integration): origin = f"auto.db.{identifier}" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: monitoring.register(CommandTracer()) diff --git a/sentry_sdk/integrations/pyramid.py b/sentry_sdk/integrations/pyramid.py index a4d30e38a4..68a725451a 100644 --- a/sentry_sdk/integrations/pyramid.py +++ b/sentry_sdk/integrations/pyramid.py @@ -1,3 +1,4 @@ +from __future__ import annotations import functools import os import sys @@ -40,8 +41,7 @@ if getattr(Request, "authenticated_userid", None): - def authenticated_userid(request): - # type: (Request) -> Optional[Any] + def authenticated_userid(request: Request) -> Optional[Any]: return request.authenticated_userid else: @@ -58,8 +58,7 @@ class PyramidIntegration(Integration): transaction_style = "" - def __init__(self, transaction_style="route_name"): - # type: (str) -> None + def __init__(self, transaction_style: str = "route_name") -> None: if transaction_style not in TRANSACTION_STYLE_VALUES: raise ValueError( "Invalid value for transaction_style: %s (must be in %s)" @@ -68,15 +67,15 @@ def __init__(self, transaction_style="route_name"): self.transaction_style = transaction_style @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: from pyramid import router old_call_view = router._call_view @functools.wraps(old_call_view) - def sentry_patched_call_view(registry, request, *args, **kwargs): - # type: (Any, Request, *Any, **Any) -> Response + def sentry_patched_call_view( + registry: Any, request: Request, *args: Any, **kwargs: Any + ) -> Response: integration = sentry_sdk.get_client().get_integration(PyramidIntegration) if integration is None: return old_call_view(registry, request, *args, **kwargs) @@ -96,8 +95,9 @@ def sentry_patched_call_view(registry, request, *args, **kwargs): if hasattr(Request, "invoke_exception_view"): old_invoke_exception_view = Request.invoke_exception_view - def sentry_patched_invoke_exception_view(self, *args, **kwargs): - # type: (Request, *Any, **Any) -> Any + def sentry_patched_invoke_exception_view( + self: Request, *args: Any, **kwargs: Any + ) -> Any: rv = old_invoke_exception_view(self, *args, **kwargs) if ( @@ -116,10 +116,12 @@ def sentry_patched_invoke_exception_view(self, *args, **kwargs): old_wsgi_call = router.Router.__call__ @ensure_integration_enabled(PyramidIntegration, old_wsgi_call) - def sentry_patched_wsgi_call(self, environ, start_response): - # type: (Any, Dict[str, str], Callable[..., Any]) -> 
_ScopedResponse - def sentry_patched_inner_wsgi_call(environ, start_response): - # type: (Dict[str, Any], Callable[..., Any]) -> Any + def sentry_patched_wsgi_call( + self: Any, environ: Dict[str, str], start_response: Callable[..., Any] + ) -> _ScopedResponse: + def sentry_patched_inner_wsgi_call( + environ: Dict[str, Any], start_response: Callable[..., Any] + ) -> Any: try: return old_wsgi_call(self, environ, start_response) except Exception: @@ -137,8 +139,7 @@ def sentry_patched_inner_wsgi_call(environ, start_response): @ensure_integration_enabled(PyramidIntegration) -def _capture_exception(exc_info): - # type: (ExcInfo) -> None +def _capture_exception(exc_info: ExcInfo) -> None: if exc_info[0] is None or issubclass(exc_info[0], HTTPException): return @@ -151,8 +152,9 @@ def _capture_exception(exc_info): sentry_sdk.capture_event(event, hint=hint) -def _set_transaction_name_and_source(scope, transaction_style, request): - # type: (sentry_sdk.Scope, str, Request) -> None +def _set_transaction_name_and_source( + scope: sentry_sdk.Scope, transaction_style: str, request: Request +) -> None: try: name_for_style = { "route_name": request.matched_route.name, @@ -167,40 +169,33 @@ def _set_transaction_name_and_source(scope, transaction_style, request): class PyramidRequestExtractor(RequestExtractor): - def url(self): - # type: () -> str + def url(self) -> str: return self.request.path_url - def env(self): - # type: () -> Dict[str, str] + def env(self) -> Dict[str, str]: return self.request.environ - def cookies(self): - # type: () -> RequestCookies + def cookies(self) -> RequestCookies: return self.request.cookies - def raw_data(self): - # type: () -> str + def raw_data(self) -> str: return self.request.text - def form(self): - # type: () -> Dict[str, str] + def form(self) -> Dict[str, str]: return { key: value for key, value in self.request.POST.items() if not getattr(value, "filename", None) } - def files(self): - # type: () -> Dict[str, _FieldStorageWithFile] + def files(self) -> Dict[str, _FieldStorageWithFile]: return { key: value for key, value in self.request.POST.items() if getattr(value, "filename", None) } - def size_of_file(self, postdata): - # type: (_FieldStorageWithFile) -> int + def size_of_file(self, postdata: _FieldStorageWithFile) -> int: file = postdata.file try: return os.fstat(file.fileno()).st_size @@ -208,10 +203,10 @@ def size_of_file(self, postdata): return 0 -def _make_event_processor(weak_request, integration): - # type: (Callable[[], Request], PyramidIntegration) -> EventProcessor - def pyramid_event_processor(event, hint): - # type: (Event, Dict[str, Any]) -> Event +def _make_event_processor( + weak_request: Callable[[], Request], integration: PyramidIntegration +) -> EventProcessor: + def pyramid_event_processor(event: Event, hint: Dict[str, Any]) -> Event: request = weak_request() if request is None: return event diff --git a/sentry_sdk/integrations/quart.py b/sentry_sdk/integrations/quart.py index 68c1342216..eb7e117cc9 100644 --- a/sentry_sdk/integrations/quart.py +++ b/sentry_sdk/integrations/quart.py @@ -1,3 +1,4 @@ +from __future__ import annotations import asyncio import inspect from functools import wraps @@ -60,8 +61,7 @@ class QuartIntegration(Integration): transaction_style = "" - def __init__(self, transaction_style="endpoint"): - # type: (str) -> None + def __init__(self, transaction_style: str = "endpoint") -> None: if transaction_style not in TRANSACTION_STYLE_VALUES: raise ValueError( "Invalid value for transaction_style: %s (must be in %s)" @@ 
-70,8 +70,7 @@ def __init__(self, transaction_style="endpoint"): self.transaction_style = transaction_style @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: request_started.connect(_request_websocket_started) websocket_started.connect(_request_websocket_started) @@ -83,12 +82,12 @@ def setup_once(): patch_scaffold_route() -def patch_asgi_app(): - # type: () -> None +def patch_asgi_app() -> None: old_app = Quart.__call__ - async def sentry_patched_asgi_app(self, scope, receive, send): - # type: (Any, Any, Any, Any) -> Any + async def sentry_patched_asgi_app( + self: Any, scope: Any, receive: Any, send: Any + ) -> Any: if sentry_sdk.get_client().get_integration(QuartIntegration) is None: return await old_app(self, scope, receive, send) @@ -102,16 +101,13 @@ async def sentry_patched_asgi_app(self, scope, receive, send): Quart.__call__ = sentry_patched_asgi_app -def patch_scaffold_route(): - # type: () -> None +def patch_scaffold_route() -> None: old_route = Scaffold.route - def _sentry_route(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _sentry_route(*args: Any, **kwargs: Any) -> Any: old_decorator = old_route(*args, **kwargs) - def decorator(old_func): - # type: (Any) -> Any + def decorator(old_func: Any) -> Any: if inspect.isfunction(old_func) and not asyncio.iscoroutinefunction( old_func @@ -119,8 +115,7 @@ def decorator(old_func): @wraps(old_func) @ensure_integration_enabled(QuartIntegration, old_func) - def _sentry_func(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _sentry_func(*args: Any, **kwargs: Any) -> Any: current_scope = sentry_sdk.get_current_scope() if current_scope.root_span is not None: current_scope.root_span.update_active_thread() @@ -140,8 +135,9 @@ def _sentry_func(*args, **kwargs): Scaffold.route = _sentry_route -def _set_transaction_name_and_source(scope, transaction_style, request): - # type: (sentry_sdk.Scope, str, Request) -> None +def _set_transaction_name_and_source( + scope: sentry_sdk.Scope, transaction_style: str, request: Request +) -> None: try: name_for_style = { @@ -156,8 +152,7 @@ def _set_transaction_name_and_source(scope, transaction_style, request): pass -async def _request_websocket_started(app, **kwargs): - # type: (Quart, **Any) -> None +async def _request_websocket_started(app: Quart, **kwargs: Any) -> None: integration = sentry_sdk.get_client().get_integration(QuartIntegration) if integration is None: return @@ -178,10 +173,10 @@ async def _request_websocket_started(app, **kwargs): scope.add_event_processor(evt_processor) -def _make_request_event_processor(app, request, integration): - # type: (Quart, Request, QuartIntegration) -> EventProcessor - def inner(event, hint): - # type: (Event, dict[str, Any]) -> Event +def _make_request_event_processor( + app: Quart, request: Request, integration: QuartIntegration +) -> EventProcessor: + def inner(event: Event, hint: dict[str, Any]) -> Event: # if the request is gone we are fine not logging the data from # it. This might happen if the processor is pushed away to # another thread. 
@@ -207,8 +202,9 @@ def inner(event, hint): return inner -async def _capture_exception(sender, exception, **kwargs): - # type: (Quart, Union[ValueError, BaseException], **Any) -> None +async def _capture_exception( + sender: Quart, exception: Union[ValueError, BaseException], **kwargs: Any +) -> None: integration = sentry_sdk.get_client().get_integration(QuartIntegration) if integration is None: return @@ -222,8 +218,7 @@ async def _capture_exception(sender, exception, **kwargs): sentry_sdk.capture_event(event, hint=hint) -def _add_user_to_event(event): - # type: (Event) -> None +def _add_user_to_event(event: Event) -> None: if quart_auth is None: return diff --git a/sentry_sdk/integrations/ray.py b/sentry_sdk/integrations/ray.py index 1bb78c859f..46a9b65dcb 100644 --- a/sentry_sdk/integrations/ray.py +++ b/sentry_sdk/integrations/ray.py @@ -1,3 +1,4 @@ +from __future__ import annotations import inspect import sys @@ -29,8 +30,7 @@ DEFAULT_TRANSACTION_NAME = "unknown Ray function" -def _check_sentry_initialized(): - # type: () -> None +def _check_sentry_initialized() -> None: if sentry_sdk.get_client().is_active(): return @@ -39,13 +39,13 @@ def _check_sentry_initialized(): ) -def _patch_ray_remote(): - # type: () -> None +def _patch_ray_remote() -> None: old_remote = ray.remote @functools.wraps(old_remote) - def new_remote(f=None, *args, **kwargs): - # type: (Optional[Callable[..., Any]], *Any, **Any) -> Callable[..., Any] + def new_remote( + f: Optional[Callable[..., Any]] = None, *args: Any, **kwargs: Any + ) -> Callable[..., Any]: if inspect.isclass(f): # Ray Actors @@ -54,10 +54,10 @@ def new_remote(f=None, *args, **kwargs): # (Only Ray Tasks are supported) return old_remote(f, *args, **kwargs) - def wrapper(user_f): - # type: (Callable[..., Any]) -> Any - def new_func(*f_args, _tracing=None, **f_kwargs): - # type: (Any, Optional[dict[str, Any]], Any) -> Any + def wrapper(user_f: Callable[..., Any]) -> Any: + def new_func( + *f_args: Any, _tracing: Optional[dict[str, Any]] = None, **f_kwargs: Any + ) -> Any: _check_sentry_initialized() root_span_name = ( @@ -91,8 +91,9 @@ def new_func(*f_args, _tracing=None, **f_kwargs): rv = old_remote(*args, **kwargs)(new_func) old_remote_method = rv.remote - def _remote_method_with_header_propagation(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _remote_method_with_header_propagation( + *args: Any, **kwargs: Any + ) -> Any: """ Ray Client """ @@ -129,8 +130,7 @@ def _remote_method_with_header_propagation(*args, **kwargs): ray.remote = new_remote -def _capture_exception(exc_info, **kwargs): - # type: (ExcInfo, **Any) -> None +def _capture_exception(exc_info: ExcInfo, **kwargs: Any) -> None: client = sentry_sdk.get_client() event, hint = event_from_exception( @@ -149,8 +149,7 @@ class RayIntegration(Integration): origin = f"auto.queue.{identifier}" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = package_version("ray") _check_minimum_version(RayIntegration, version) diff --git a/sentry_sdk/integrations/redis/__init__.py b/sentry_sdk/integrations/redis/__init__.py index f443138295..1d0b39f1cb 100644 --- a/sentry_sdk/integrations/redis/__init__.py +++ b/sentry_sdk/integrations/redis/__init__.py @@ -1,3 +1,4 @@ +from __future__ import annotations from sentry_sdk.integrations import Integration, DidNotEnable from sentry_sdk.integrations.redis.consts import _DEFAULT_MAX_DATA_SIZE from sentry_sdk.integrations.redis.rb import _patch_rb @@ -15,14 +16,16 @@ class RedisIntegration(Integration): identifier = 
"redis" - def __init__(self, max_data_size=_DEFAULT_MAX_DATA_SIZE, cache_prefixes=None): - # type: (int, Optional[list[str]]) -> None + def __init__( + self, + max_data_size: int = _DEFAULT_MAX_DATA_SIZE, + cache_prefixes: Optional[list[str]] = None, + ) -> None: self.max_data_size = max_data_size self.cache_prefixes = cache_prefixes if cache_prefixes is not None else [] @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: try: from redis import StrictRedis, client except ImportError: diff --git a/sentry_sdk/integrations/redis/_async_common.py b/sentry_sdk/integrations/redis/_async_common.py index c3e23f8a99..ca23db3939 100644 --- a/sentry_sdk/integrations/redis/_async_common.py +++ b/sentry_sdk/integrations/redis/_async_common.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.consts import OP from sentry_sdk.integrations.redis.consts import SPAN_ORIGIN @@ -24,15 +25,16 @@ def patch_redis_async_pipeline( - pipeline_cls, is_cluster, get_command_args_fn, get_db_data_fn -): - # type: (Union[type[Pipeline[Any]], type[ClusterPipeline[Any]]], bool, Any, Callable[[Any], dict[str, Any]]) -> None + pipeline_cls: Union[type[Pipeline[Any]], type[ClusterPipeline[Any]]], + is_cluster: bool, + get_command_args_fn: Any, + get_db_data_fn: Callable[[Any], dict[str, Any]], +) -> None: old_execute = pipeline_cls.execute from sentry_sdk.integrations.redis import RedisIntegration - async def _sentry_execute(self, *args, **kwargs): - # type: (Any, *Any, **Any) -> Any + async def _sentry_execute(self: Any, *args: Any, **kwargs: Any) -> Any: if sentry_sdk.get_client().get_integration(RedisIntegration) is None: return await old_execute(self, *args, **kwargs) @@ -67,14 +69,18 @@ async def _sentry_execute(self, *args, **kwargs): pipeline_cls.execute = _sentry_execute # type: ignore -def patch_redis_async_client(cls, is_cluster, get_db_data_fn): - # type: (Union[type[StrictRedis[Any]], type[RedisCluster[Any]]], bool, Callable[[Any], dict[str, Any]]) -> None +def patch_redis_async_client( + cls: Union[type[StrictRedis[Any]], type[RedisCluster[Any]]], + is_cluster: bool, + get_db_data_fn: Callable[[Any], dict[str, Any]], +) -> None: old_execute_command = cls.execute_command from sentry_sdk.integrations.redis import RedisIntegration - async def _sentry_execute_command(self, name, *args, **kwargs): - # type: (Any, str, *Any, **Any) -> Any + async def _sentry_execute_command( + self: Any, name: str, *args: Any, **kwargs: Any + ) -> Any: integration = sentry_sdk.get_client().get_integration(RedisIntegration) if integration is None: return await old_execute_command(self, name, *args, **kwargs) diff --git a/sentry_sdk/integrations/redis/_sync_common.py b/sentry_sdk/integrations/redis/_sync_common.py index 7efdf764a7..e3d5b77323 100644 --- a/sentry_sdk/integrations/redis/_sync_common.py +++ b/sentry_sdk/integrations/redis/_sync_common.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.consts import OP from sentry_sdk.integrations.redis.consts import SPAN_ORIGIN @@ -22,18 +23,16 @@ def patch_redis_pipeline( - pipeline_cls, - is_cluster, - get_command_args_fn, - get_db_data_fn, -): - # type: (Any, bool, Any, Callable[[Any], dict[str, Any]]) -> None + pipeline_cls: Any, + is_cluster: bool, + get_command_args_fn: Any, + get_db_data_fn: Callable[[Any], dict[str, Any]], +) -> None: old_execute = pipeline_cls.execute from sentry_sdk.integrations.redis import RedisIntegration - def sentry_patched_execute(self, *args, **kwargs): 
- # type: (Any, *Any, **Any) -> Any + def sentry_patched_execute(self: Any, *args: Any, **kwargs: Any) -> Any: if sentry_sdk.get_client().get_integration(RedisIntegration) is None: return old_execute(self, *args, **kwargs) @@ -64,8 +63,9 @@ def sentry_patched_execute(self, *args, **kwargs): pipeline_cls.execute = sentry_patched_execute -def patch_redis_client(cls, is_cluster, get_db_data_fn): - # type: (Any, bool, Callable[[Any], dict[str, Any]]) -> None +def patch_redis_client( + cls: Any, is_cluster: bool, get_db_data_fn: Callable[[Any], dict[str, Any]] +) -> None: """ This function can be used to instrument custom redis client classes or subclasses. @@ -74,8 +74,9 @@ def patch_redis_client(cls, is_cluster, get_db_data_fn): from sentry_sdk.integrations.redis import RedisIntegration - def sentry_patched_execute_command(self, name, *args, **kwargs): - # type: (Any, str, *Any, **Any) -> Any + def sentry_patched_execute_command( + self: Any, name: str, *args: Any, **kwargs: Any + ) -> Any: integration = sentry_sdk.get_client().get_integration(RedisIntegration) if integration is None: return old_execute_command(self, name, *args, **kwargs) diff --git a/sentry_sdk/integrations/redis/modules/caches.py b/sentry_sdk/integrations/redis/modules/caches.py index 4ab33d2ea8..574c928f12 100644 --- a/sentry_sdk/integrations/redis/modules/caches.py +++ b/sentry_sdk/integrations/redis/modules/caches.py @@ -2,6 +2,7 @@ Code used for the Caches module in Sentry """ +from __future__ import annotations from sentry_sdk.consts import OP, SPANDATA from sentry_sdk.integrations.redis.utils import _get_safe_key, _key_as_string from sentry_sdk.utils import capture_internal_exceptions @@ -16,8 +17,7 @@ from typing import Any, Optional -def _get_op(name): - # type: (str) -> Optional[str] +def _get_op(name: str) -> Optional[str]: op = None if name.lower() in GET_COMMANDS: op = OP.CACHE_GET @@ -27,8 +27,12 @@ def _get_op(name): return op -def _compile_cache_span_properties(redis_command, args, kwargs, integration): - # type: (str, tuple[Any, ...], dict[str, Any], RedisIntegration) -> dict[str, Any] +def _compile_cache_span_properties( + redis_command: str, + args: tuple[Any, ...], + kwargs: dict[str, Any], + integration: RedisIntegration, +) -> dict[str, Any]: key = _get_safe_key(redis_command, args, kwargs) key_as_string = _key_as_string(key) keys_as_string = key_as_string.split(", ") @@ -61,8 +65,12 @@ def _compile_cache_span_properties(redis_command, args, kwargs, integration): return properties -def _get_cache_span_description(redis_command, args, kwargs, integration): - # type: (str, tuple[Any, ...], dict[str, Any], RedisIntegration) -> str +def _get_cache_span_description( + redis_command: str, + args: tuple[Any, ...], + kwargs: dict[str, Any], + integration: RedisIntegration, +) -> str: description = _key_as_string(_get_safe_key(redis_command, args, kwargs)) data_should_be_truncated = ( @@ -74,8 +82,9 @@ def _get_cache_span_description(redis_command, args, kwargs, integration): return description -def _get_cache_data(redis_client, properties, return_value): - # type: (Any, dict[str, Any], Optional[Any]) -> dict[str, Any] +def _get_cache_data( + redis_client: Any, properties: dict[str, Any], return_value: Optional[Any] +) -> dict[str, Any]: data = {} with capture_internal_exceptions(): diff --git a/sentry_sdk/integrations/redis/modules/queries.py b/sentry_sdk/integrations/redis/modules/queries.py index c070893ac8..312d48e2bd 100644 --- a/sentry_sdk/integrations/redis/modules/queries.py +++ 
b/sentry_sdk/integrations/redis/modules/queries.py @@ -2,6 +2,7 @@ Code used for the Queries module in Sentry """ +from __future__ import annotations from sentry_sdk.consts import OP, SPANDATA from sentry_sdk.integrations.redis.utils import _get_safe_command from sentry_sdk.utils import capture_internal_exceptions @@ -14,8 +15,9 @@ from typing import Any -def _compile_db_span_properties(integration, redis_command, args): - # type: (RedisIntegration, str, tuple[Any, ...]) -> dict[str, Any] +def _compile_db_span_properties( + integration: RedisIntegration, redis_command: str, args: tuple[Any, ...] +) -> dict[str, Any]: description = _get_db_span_description(integration, redis_command, args) properties = { @@ -26,8 +28,9 @@ def _compile_db_span_properties(integration, redis_command, args): return properties -def _get_db_span_description(integration, command_name, args): - # type: (RedisIntegration, str, tuple[Any, ...]) -> str +def _get_db_span_description( + integration: RedisIntegration, command_name: str, args: tuple[Any, ...] +) -> str: description = command_name with capture_internal_exceptions(): @@ -42,8 +45,7 @@ def _get_db_span_description(integration, command_name, args): return description -def _get_connection_data(connection_params): - # type: (dict[str, Any]) -> dict[str, Any] +def _get_connection_data(connection_params: dict[str, Any]) -> dict[str, Any]: data = { SPANDATA.DB_SYSTEM: "redis", } @@ -63,8 +65,7 @@ def _get_connection_data(connection_params): return data -def _get_db_data(redis_instance): - # type: (Redis[Any]) -> dict[str, Any] +def _get_db_data(redis_instance: Redis[Any]) -> dict[str, Any]: try: return _get_connection_data(redis_instance.connection_pool.connection_kwargs) except AttributeError: diff --git a/sentry_sdk/integrations/redis/rb.py b/sentry_sdk/integrations/redis/rb.py index 68d3c3a9d6..b6eab57171 100644 --- a/sentry_sdk/integrations/redis/rb.py +++ b/sentry_sdk/integrations/redis/rb.py @@ -4,12 +4,13 @@ https://github.com/getsentry/rb """ +from __future__ import annotations + from sentry_sdk.integrations.redis._sync_common import patch_redis_client from sentry_sdk.integrations.redis.modules.queries import _get_db_data -def _patch_rb(): - # type: () -> None +def _patch_rb() -> None: try: import rb.clients # type: ignore except ImportError: diff --git a/sentry_sdk/integrations/redis/redis.py b/sentry_sdk/integrations/redis/redis.py index 935a828c3d..f7332c906b 100644 --- a/sentry_sdk/integrations/redis/redis.py +++ b/sentry_sdk/integrations/redis/redis.py @@ -4,6 +4,8 @@ https://github.com/redis/redis-py """ +from __future__ import annotations + from sentry_sdk.integrations.redis._sync_common import ( patch_redis_client, patch_redis_pipeline, @@ -16,13 +18,11 @@ from typing import Any, Sequence -def _get_redis_command_args(command): - # type: (Any) -> Sequence[Any] +def _get_redis_command_args(command: Any) -> Sequence[Any]: return command[0] -def _patch_redis(StrictRedis, client): # noqa: N803 - # type: (Any, Any) -> None +def _patch_redis(StrictRedis: Any, client: Any) -> None: # noqa: N803 patch_redis_client( StrictRedis, is_cluster=False, diff --git a/sentry_sdk/integrations/redis/redis_cluster.py b/sentry_sdk/integrations/redis/redis_cluster.py index 5aab34ad64..3c4dfdea93 100644 --- a/sentry_sdk/integrations/redis/redis_cluster.py +++ b/sentry_sdk/integrations/redis/redis_cluster.py @@ -5,6 +5,8 @@ https://github.com/redis/redis-py/blob/master/redis/cluster.py """ +from __future__ import annotations + from sentry_sdk.integrations.redis._sync_common 
import ( patch_redis_client, patch_redis_pipeline, @@ -25,8 +27,9 @@ ) -def _get_async_cluster_db_data(async_redis_cluster_instance): - # type: (AsyncRedisCluster[Any]) -> dict[str, Any] +def _get_async_cluster_db_data( + async_redis_cluster_instance: AsyncRedisCluster[Any], +) -> dict[str, Any]: default_node = async_redis_cluster_instance.get_default_node() if default_node is not None and default_node.connection_kwargs is not None: return _get_connection_data(default_node.connection_kwargs) @@ -34,8 +37,9 @@ def _get_async_cluster_db_data(async_redis_cluster_instance): return {} -def _get_async_cluster_pipeline_db_data(async_redis_cluster_pipeline_instance): - # type: (AsyncClusterPipeline[Any]) -> dict[str, Any] +def _get_async_cluster_pipeline_db_data( + async_redis_cluster_pipeline_instance: AsyncClusterPipeline[Any], +) -> dict[str, Any]: with capture_internal_exceptions(): client = getattr(async_redis_cluster_pipeline_instance, "cluster_client", None) if client is None: @@ -50,8 +54,7 @@ def _get_async_cluster_pipeline_db_data(async_redis_cluster_pipeline_instance): return _get_async_cluster_db_data(client) -def _get_cluster_db_data(redis_cluster_instance): - # type: (RedisCluster[Any]) -> dict[str, Any] +def _get_cluster_db_data(redis_cluster_instance: RedisCluster[Any]) -> dict[str, Any]: default_node = redis_cluster_instance.get_default_node() if default_node is not None: @@ -64,8 +67,7 @@ def _get_cluster_db_data(redis_cluster_instance): return {} -def _patch_redis_cluster(): - # type: () -> None +def _patch_redis_cluster() -> None: """Patches the cluster module on redis SDK (as opposed to rediscluster library)""" try: from redis import RedisCluster, cluster diff --git a/sentry_sdk/integrations/redis/redis_py_cluster_legacy.py b/sentry_sdk/integrations/redis/redis_py_cluster_legacy.py index 53b545c21b..e658443e81 100644 --- a/sentry_sdk/integrations/redis/redis_py_cluster_legacy.py +++ b/sentry_sdk/integrations/redis/redis_py_cluster_legacy.py @@ -5,6 +5,8 @@ https://github.com/grokzen/redis-py-cluster """ +from __future__ import annotations + from sentry_sdk.integrations.redis._sync_common import ( patch_redis_client, patch_redis_pipeline, @@ -13,8 +15,7 @@ from sentry_sdk.integrations.redis.utils import _parse_rediscluster_command -def _patch_rediscluster(): - # type: () -> None +def _patch_rediscluster() -> None: try: import rediscluster # type: ignore except ImportError: diff --git a/sentry_sdk/integrations/redis/utils.py b/sentry_sdk/integrations/redis/utils.py index 6d9a2d6160..e109d3fe34 100644 --- a/sentry_sdk/integrations/redis/utils.py +++ b/sentry_sdk/integrations/redis/utils.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.consts import SPANDATA from sentry_sdk.integrations.redis.consts import ( @@ -26,8 +27,7 @@ ] -def _update_span(span, *data_bags): - # type: (Span, *dict[str, Any]) -> None +def _update_span(span: Span, *data_bags: dict[str, Any]) -> None: """ Set tags and data on the given span to data from the given data bags. """ @@ -39,8 +39,7 @@ def _update_span(span, *data_bags): span.set_attribute(key, value) -def _create_breadcrumb(message, *data_bags): - # type: (str, *dict[str, Any]) -> None +def _create_breadcrumb(message: str, *data_bags: dict[str, Any]) -> None: """ Create a breadcrumb containing the tags data from the given data bags. 
""" @@ -58,8 +57,7 @@ def _create_breadcrumb(message, *data_bags): ) -def _get_safe_command(name, args): - # type: (str, Sequence[Any]) -> str +def _get_safe_command(name: str, args: Sequence[Any]) -> str: command_parts = [name] for i, arg in enumerate(args): @@ -86,8 +84,7 @@ def _get_safe_command(name, args): return command -def _safe_decode(key): - # type: (Any) -> str +def _safe_decode(key: Any) -> str: if isinstance(key, bytes): try: return key.decode() @@ -97,8 +94,7 @@ def _safe_decode(key): return str(key) -def _key_as_string(key): - # type: (Any) -> str +def _key_as_string(key: Any) -> str: if isinstance(key, (dict, list, tuple)): key = ", ".join(_safe_decode(x) for x in key) elif isinstance(key, bytes): @@ -111,8 +107,9 @@ def _key_as_string(key): return key -def _get_safe_key(method_name, args, kwargs): - # type: (str, Optional[tuple[Any, ...]], Optional[dict[str, Any]]) -> Optional[tuple[str, ...]] +def _get_safe_key( + method_name: str, args: Optional[tuple[Any, ...]], kwargs: Optional[dict[str, Any]] +) -> Optional[tuple[str, ...]]: """ Gets the key (or keys) from the given method_name. The method_name could be a redis command or a django caching command @@ -142,17 +139,20 @@ def _get_safe_key(method_name, args, kwargs): return key -def _parse_rediscluster_command(command): - # type: (Any) -> Sequence[Any] +def _parse_rediscluster_command(command: Any) -> Sequence[Any]: return command.args -def _get_pipeline_data(is_cluster, get_command_args_fn, is_transaction, command_seq): - # type: (bool, Any, bool, Sequence[Any]) -> dict[str, Any] - data = { +def _get_pipeline_data( + is_cluster: bool, + get_command_args_fn: Any, + is_transaction: bool, + command_seq: Sequence[Any], +) -> dict[str, Any]: + data: dict[str, Any] = { "redis.is_cluster": is_cluster, "redis.transaction": is_transaction, - } # type: dict[str, Any] + } commands = [] for i, arg in enumerate(command_seq): @@ -168,11 +168,10 @@ def _get_pipeline_data(is_cluster, get_command_args_fn, is_transaction, command_ return data -def _get_client_data(is_cluster, name, *args): - # type: (bool, str, *Any) -> dict[str, Any] - data = { +def _get_client_data(is_cluster: bool, name: str, *args: Any) -> dict[str, Any]: + data: dict[str, Any] = { "redis.is_cluster": is_cluster, - } # type: dict[str, Any] + } if name: data["redis.command"] = name diff --git a/sentry_sdk/integrations/rq.py b/sentry_sdk/integrations/rq.py index 33910ed476..43a943c272 100644 --- a/sentry_sdk/integrations/rq.py +++ b/sentry_sdk/integrations/rq.py @@ -1,3 +1,4 @@ +from __future__ import annotations import weakref import sentry_sdk @@ -49,16 +50,16 @@ class RqIntegration(Integration): origin = f"auto.queue.{identifier}" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = parse_version(RQ_VERSION) _check_minimum_version(RqIntegration, version) old_perform_job = Worker.perform_job @ensure_integration_enabled(RqIntegration, old_perform_job) - def sentry_patched_perform_job(self, job, queue, *args, **kwargs): - # type: (Any, Job, Queue, *Any, **Any) -> bool + def sentry_patched_perform_job( + self: Any, job: Job, queue: Queue, *args: Any, **kwargs: Any + ) -> bool: with sentry_sdk.new_scope() as scope: try: transaction_name = job.func_name or DEFAULT_TRANSACTION_NAME @@ -95,8 +96,9 @@ def sentry_patched_perform_job(self, job, queue, *args, **kwargs): old_handle_exception = Worker.handle_exception - def sentry_patched_handle_exception(self, job, *exc_info, **kwargs): - # type: (Worker, Any, *Any, **Any) -> Any + def 
sentry_patched_handle_exception( + self: Worker, job: Any, *exc_info: Any, **kwargs: Any + ) -> Any: retry = ( hasattr(job, "retries_left") and job.retries_left @@ -113,8 +115,7 @@ def sentry_patched_handle_exception(self, job, *exc_info, **kwargs): old_enqueue_job = Queue.enqueue_job @ensure_integration_enabled(RqIntegration, old_enqueue_job) - def sentry_patched_enqueue_job(self, job, **kwargs): - # type: (Queue, Any, **Any) -> Any + def sentry_patched_enqueue_job(self: Queue, job: Any, **kwargs: Any) -> Any: job.meta["_sentry_trace_headers"] = dict( sentry_sdk.get_current_scope().iter_trace_propagation_headers() ) @@ -126,10 +127,8 @@ def sentry_patched_enqueue_job(self, job, **kwargs): ignore_logger("rq.worker") -def _make_event_processor(weak_job): - # type: (Callable[[], Job]) -> EventProcessor - def event_processor(event, hint): - # type: (Event, dict[str, Any]) -> Event +def _make_event_processor(weak_job: Callable[[], Job]) -> EventProcessor: + def event_processor(event: Event, hint: dict[str, Any]) -> Event: job = weak_job() if job is not None: with capture_internal_exceptions(): @@ -159,8 +158,7 @@ def event_processor(event, hint): return event_processor -def _capture_exception(exc_info, **kwargs): - # type: (ExcInfo, **Any) -> None +def _capture_exception(exc_info: ExcInfo, **kwargs: Any) -> None: client = sentry_sdk.get_client() event, hint = event_from_exception( @@ -172,8 +170,7 @@ def _capture_exception(exc_info, **kwargs): sentry_sdk.capture_event(event, hint=hint) -def _prepopulate_attributes(job, queue): - # type: (Job, Queue) -> dict[str, Any] +def _prepopulate_attributes(job: Job, queue: Queue) -> dict[str, Any]: attributes = { "messaging.system": "rq", "rq.job.id": job.id, diff --git a/sentry_sdk/integrations/rust_tracing.py b/sentry_sdk/integrations/rust_tracing.py index acfe9bd7f4..3af40bc1fd 100644 --- a/sentry_sdk/integrations/rust_tracing.py +++ b/sentry_sdk/integrations/rust_tracing.py @@ -30,9 +30,13 @@ Each native extension requires its own integration. 
""" +from __future__ import annotations import json from enum import Enum, auto -from typing import Any, Callable, Dict, Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any, Callable, Dict, Optional import sentry_sdk from sentry_sdk.integrations import Integration @@ -56,8 +60,7 @@ class EventTypeMapping(Enum): Event = auto() -def tracing_level_to_sentry_level(level): - # type: (str) -> sentry_sdk._types.LogLevelStr +def tracing_level_to_sentry_level(level: str) -> sentry_sdk._types.LogLevelStr: level = RustTracingLevel(level) if level in (RustTracingLevel.Trace, RustTracingLevel.Debug): return "debug" @@ -97,15 +100,15 @@ def process_event(event: Dict[str, Any]) -> None: logger = metadata.get("target") level = tracing_level_to_sentry_level(metadata.get("level")) - message = event.get("message") # type: sentry_sdk._types.Any + message: sentry_sdk._types.Any = event.get("message") contexts = extract_contexts(event) - sentry_event = { + sentry_event: sentry_sdk._types.Event = { "logger": logger, "level": level, "message": message, "contexts": contexts, - } # type: sentry_sdk._types.Event + } sentry_sdk.capture_event(sentry_event) diff --git a/sentry_sdk/integrations/sanic.py b/sentry_sdk/integrations/sanic.py index 06e30ffe31..1eed090332 100644 --- a/sentry_sdk/integrations/sanic.py +++ b/sentry_sdk/integrations/sanic.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import weakref from inspect import isawaitable @@ -59,8 +60,9 @@ class SanicIntegration(Integration): origin = f"auto.http.{identifier}" version = None - def __init__(self, unsampled_statuses=frozenset({404})): - # type: (Optional[Container[int]]) -> None + def __init__( + self, unsampled_statuses: Optional[Container[int]] = frozenset({404}) + ) -> None: """ The unsampled_statuses parameter can be used to specify for which HTTP statuses the transactions should not be sent to Sentry. 
By default, transactions are sent for all @@ -70,8 +72,7 @@ def __init__(self, unsampled_statuses=frozenset({404})): self._unsampled_statuses = unsampled_statuses or set() @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: SanicIntegration.version = parse_version(SANIC_VERSION) _check_minimum_version(SanicIntegration, SanicIntegration.version) @@ -103,56 +104,45 @@ def setup_once(): class SanicRequestExtractor(RequestExtractor): - def content_length(self): - # type: () -> int + def content_length(self) -> int: if self.request.body is None: return 0 return len(self.request.body) - def cookies(self): - # type: () -> Dict[str, str] + def cookies(self) -> Dict[str, str]: return dict(self.request.cookies) - def raw_data(self): - # type: () -> bytes + def raw_data(self) -> bytes: return self.request.body - def form(self): - # type: () -> RequestParameters + def form(self) -> RequestParameters: return self.request.form - def is_json(self): - # type: () -> bool + def is_json(self) -> bool: raise NotImplementedError() - def json(self): - # type: () -> Optional[Any] + def json(self) -> Optional[Any]: return self.request.json - def files(self): - # type: () -> RequestParameters + def files(self) -> RequestParameters: return self.request.files - def size_of_file(self, file): - # type: (Any) -> int + def size_of_file(self, file: Any) -> int: return len(file.body or ()) -def _setup_sanic(): - # type: () -> None +def _setup_sanic() -> None: Sanic._startup = _startup ErrorHandler.lookup = _sentry_error_handler_lookup -def _setup_legacy_sanic(): - # type: () -> None +def _setup_legacy_sanic() -> None: Sanic.handle_request = _legacy_handle_request Router.get = _legacy_router_get ErrorHandler.lookup = _sentry_error_handler_lookup -async def _startup(self): - # type: (Sanic) -> None +async def _startup(self: Sanic) -> None: # This happens about as early in the lifecycle as possible, just after the # Request object is created. The body has not yet been consumed. 
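    # The signal registrations just below wire the integration into Sanic's
    # request lifecycle: _context_enter (defined further down) stores a scope
    # manager on request.ctx and enters it for the incoming request, and
    # _context_exit exits that scope manager again once the response is done,
    # so every request is handled inside its own Sentry scope.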
self.signal("http.lifecycle.request")(_context_enter) @@ -171,8 +161,7 @@ async def _startup(self): await old_startup(self) -async def _context_enter(request): - # type: (Request) -> None +async def _context_enter(request: Request) -> None: request.ctx._sentry_do_integration = ( sentry_sdk.get_client().get_integration(SanicIntegration) is not None ) @@ -203,8 +192,9 @@ async def _context_enter(request): ).__enter__() -async def _context_exit(request, response=None): - # type: (Request, Optional[BaseHTTPResponse]) -> None +async def _context_exit( + request: Request, response: Optional[BaseHTTPResponse] = None +) -> None: with capture_internal_exceptions(): if not request.ctx._sentry_do_integration: return @@ -233,8 +223,7 @@ async def _context_exit(request, response=None): request.ctx._sentry_scope_manager.__exit__(None, None, None) -async def _set_transaction(request, route, **_): - # type: (Request, Route, **Any) -> None +async def _set_transaction(request: Request, route: Route, **_: Any) -> None: if request.ctx._sentry_do_integration: with capture_internal_exceptions(): scope = sentry_sdk.get_current_scope() @@ -242,8 +231,9 @@ async def _set_transaction(request, route, **_): scope.set_transaction_name(route_name, source=TransactionSource.COMPONENT) -def _sentry_error_handler_lookup(self, exception, *args, **kwargs): - # type: (Any, Exception, *Any, **Any) -> Optional[object] +def _sentry_error_handler_lookup( + self: Any, exception: Exception, *args: Any, **kwargs: Any +) -> Optional[object]: _capture_exception(exception) old_error_handler = old_error_handler_lookup(self, exception, *args, **kwargs) @@ -253,8 +243,9 @@ def _sentry_error_handler_lookup(self, exception, *args, **kwargs): if sentry_sdk.get_client().get_integration(SanicIntegration) is None: return old_error_handler - async def sentry_wrapped_error_handler(request, exception): - # type: (Request, Exception) -> Any + async def sentry_wrapped_error_handler( + request: Request, exception: Exception + ) -> Any: try: response = old_error_handler(request, exception) if isawaitable(response): @@ -276,8 +267,9 @@ async def sentry_wrapped_error_handler(request, exception): return sentry_wrapped_error_handler -async def _legacy_handle_request(self, request, *args, **kwargs): - # type: (Any, Request, *Any, **Any) -> Any +async def _legacy_handle_request( + self: Any, request: Request, *args: Any, **kwargs: Any +) -> Any: if sentry_sdk.get_client().get_integration(SanicIntegration) is None: return await old_handle_request(self, request, *args, **kwargs) @@ -294,8 +286,7 @@ async def _legacy_handle_request(self, request, *args, **kwargs): return response -def _legacy_router_get(self, *args): - # type: (Any, Union[Any, Request]) -> Any +def _legacy_router_get(self: Any, *args: Union[Any, Request]) -> Any: rv = old_router_get(self, *args) if sentry_sdk.get_client().get_integration(SanicIntegration) is not None: with capture_internal_exceptions(): @@ -325,8 +316,7 @@ def _legacy_router_get(self, *args): @ensure_integration_enabled(SanicIntegration) -def _capture_exception(exception): - # type: (Union[ExcInfo, BaseException]) -> None +def _capture_exception(exception: Union[ExcInfo, BaseException]) -> None: with capture_internal_exceptions(): event, hint = event_from_exception( exception, @@ -340,10 +330,8 @@ def _capture_exception(exception): sentry_sdk.capture_event(event, hint=hint) -def _make_request_processor(weak_request): - # type: (Callable[[], Request]) -> EventProcessor - def sanic_processor(event, hint): - # type: (Event, 
Optional[Hint]) -> Optional[Event] +def _make_request_processor(weak_request: Callable[[], Request]) -> EventProcessor: + def sanic_processor(event: Event, hint: Optional[Hint]) -> Optional[Event]: try: if hint and issubclass(hint["exc_info"][0], SanicException): diff --git a/sentry_sdk/integrations/serverless.py b/sentry_sdk/integrations/serverless.py index 760c07ffad..dd8fbe526d 100644 --- a/sentry_sdk/integrations/serverless.py +++ b/sentry_sdk/integrations/serverless.py @@ -1,47 +1,43 @@ +from __future__ import annotations import sys from functools import wraps import sentry_sdk from sentry_sdk.utils import event_from_exception, reraise -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, overload if TYPE_CHECKING: - from typing import Any + from typing import NoReturn from typing import Callable from typing import TypeVar + from typing import ParamSpec from typing import Union from typing import Optional - from typing import overload - F = TypeVar("F", bound=Callable[..., Any]) + T = TypeVar("T") + P = ParamSpec("P") -else: - def overload(x): - # type: (F) -> F - return x - - -@overload -def serverless_function(f, flush=True): - # type: (F, bool) -> F - pass +if TYPE_CHECKING: + @overload + def serverless_function(f: Callable[P, T], flush: bool = True) -> Callable[P, T]: + pass -@overload -def serverless_function(f=None, flush=True): # noqa: F811 - # type: (None, bool) -> Callable[[F], F] - pass + @overload + def serverless_function( + f: None = None, flush: bool = True + ) -> Callable[[Callable[P, T]], Callable[P, T]]: + pass -def serverless_function(f=None, flush=True): # noqa - # type: (Optional[F], bool) -> Union[F, Callable[[F], F]] - def wrapper(f): - # type: (F) -> F +def serverless_function( + f: Optional[Callable[P, T]] = None, flush: bool = True +) -> Union[Callable[P, T], Callable[[Callable[P, T]], Callable[P, T]]]: + def wrapper(f: Callable[P, T]) -> Callable[P, T]: @wraps(f) - def inner(*args, **kwargs): - # type: (*Any, **Any) -> Any + def inner(*args: P.args, **kwargs: P.kwargs) -> T: with sentry_sdk.isolation_scope() as scope: scope.clear_breadcrumbs() @@ -53,7 +49,7 @@ def inner(*args, **kwargs): if flush: sentry_sdk.flush() - return inner # type: ignore + return inner if f is None: return wrapper @@ -61,8 +57,7 @@ def inner(*args, **kwargs): return wrapper(f) -def _capture_and_reraise(): - # type: () -> None +def _capture_and_reraise() -> NoReturn: exc_info = sys.exc_info() client = sentry_sdk.get_client() if client.is_active(): diff --git a/sentry_sdk/integrations/socket.py b/sentry_sdk/integrations/socket.py index 544a63c0f0..67e2feb2b2 100644 --- a/sentry_sdk/integrations/socket.py +++ b/sentry_sdk/integrations/socket.py @@ -1,3 +1,4 @@ +from __future__ import annotations import socket import sentry_sdk @@ -17,8 +18,7 @@ class SocketIntegration(Integration): origin = f"auto.socket.{identifier}" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: """ patches two of the most used functions of socket: create_connection and getaddrinfo(dns resolver) """ @@ -26,8 +26,9 @@ def setup_once(): _patch_getaddrinfo() -def _get_span_description(host, port): - # type: (Union[bytes, str, None], Union[bytes, str, int, None]) -> str +def _get_span_description( + host: Union[bytes, str, None], port: Union[bytes, str, int, None] +) -> str: try: host = host.decode() # type: ignore @@ -43,16 +44,14 @@ def _get_span_description(host, port): return description -def _patch_create_connection(): - # type: () -> None +def 
_patch_create_connection() -> None: real_create_connection = socket.create_connection def create_connection( - address, - timeout=socket._GLOBAL_DEFAULT_TIMEOUT, # type: ignore - source_address=None, - ): - # type: (Tuple[Optional[str], int], Optional[float], Optional[Tuple[Union[bytearray, bytes, str], int]])-> socket.socket + address: Tuple[Optional[str], int], + timeout: Optional[float] = socket._GLOBAL_DEFAULT_TIMEOUT, # type: ignore + source_address: Optional[Tuple[Union[bytearray, bytes, str], int]] = None, + ) -> socket.socket: integration = sentry_sdk.get_client().get_integration(SocketIntegration) if integration is None: return real_create_connection(address, timeout, source_address) @@ -76,12 +75,25 @@ def create_connection( socket.create_connection = create_connection # type: ignore -def _patch_getaddrinfo(): - # type: () -> None +def _patch_getaddrinfo() -> None: real_getaddrinfo = socket.getaddrinfo - def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): - # type: (Union[bytes, str, None], Union[bytes, str, int, None], int, int, int, int) -> List[Tuple[AddressFamily, SocketKind, int, str, Union[Tuple[str, int], Tuple[str, int, int, int], Tuple[int, bytes]]]] + def getaddrinfo( + host: Union[bytes, str, None], + port: Union[bytes, str, int, None], + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> List[ + Tuple[ + AddressFamily, + SocketKind, + int, + str, + Union[Tuple[str, int], Tuple[str, int, int, int], Tuple[int, bytes]], + ] + ]: integration = sentry_sdk.get_client().get_integration(SocketIntegration) if integration is None: return real_getaddrinfo(host, port, family, type, proto, flags) diff --git a/sentry_sdk/integrations/spark/__init__.py b/sentry_sdk/integrations/spark/__init__.py index 10d94163c5..d9e8e3fa84 100644 --- a/sentry_sdk/integrations/spark/__init__.py +++ b/sentry_sdk/integrations/spark/__init__.py @@ -1,3 +1,4 @@ +from __future__ import annotations from sentry_sdk.integrations.spark.spark_driver import SparkIntegration from sentry_sdk.integrations.spark.spark_worker import SparkWorkerIntegration diff --git a/sentry_sdk/integrations/spark/spark_driver.py b/sentry_sdk/integrations/spark/spark_driver.py index fac985357f..a35883b60f 100644 --- a/sentry_sdk/integrations/spark/spark_driver.py +++ b/sentry_sdk/integrations/spark/spark_driver.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.integrations import Integration from sentry_sdk.utils import capture_internal_exceptions, ensure_integration_enabled @@ -16,13 +17,11 @@ class SparkIntegration(Integration): identifier = "spark" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: _setup_sentry_tracing() -def _set_app_properties(): - # type: () -> None +def _set_app_properties() -> None: """ Set properties in driver that propagate to worker processes, allowing for workers to have access to those properties. This allows worker integration to have access to app_name and application_id. 
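# A minimal sketch of the propagation this docstring describes (illustrative
# only; the property keys below are assumptions, not taken from this patch):
# the driver publishes values as Spark local properties, which the worker-side
# integration can read back from its TaskContext, e.g.
#
#     sc.setLocalProperty("sentry_app_name", sc.appName)
#     sc.setLocalProperty("sentry_application_id", sc.applicationId)
#     # on the executor side:
#     # pyspark.TaskContext.get().getLocalProperty("sentry_app_name")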
@@ -41,8 +40,7 @@ def _set_app_properties(): ) -def _start_sentry_listener(sc): - # type: (SparkContext) -> None +def _start_sentry_listener(sc: SparkContext) -> None: """ Start java gateway server to add custom `SparkListener` """ @@ -54,13 +52,11 @@ def _start_sentry_listener(sc): sc._jsc.sc().addSparkListener(listener) -def _add_event_processor(sc): - # type: (SparkContext) -> None +def _add_event_processor(sc: SparkContext) -> None: scope = sentry_sdk.get_isolation_scope() @scope.add_event_processor - def process_event(event, hint): - # type: (Event, Hint) -> Optional[Event] + def process_event(event: Event, hint: Hint) -> Optional[Event]: with capture_internal_exceptions(): if sentry_sdk.get_client().get_integration(SparkIntegration) is None: return event @@ -90,23 +86,22 @@ def process_event(event, hint): return event -def _activate_integration(sc): - # type: (SparkContext) -> None +def _activate_integration(sc: SparkContext) -> None: _start_sentry_listener(sc) _set_app_properties() _add_event_processor(sc) -def _patch_spark_context_init(): - # type: () -> None +def _patch_spark_context_init() -> None: from pyspark import SparkContext spark_context_init = SparkContext._do_init @ensure_integration_enabled(SparkIntegration, spark_context_init) - def _sentry_patched_spark_context_init(self, *args, **kwargs): - # type: (SparkContext, *Any, **Any) -> Optional[Any] + def _sentry_patched_spark_context_init( + self: SparkContext, *args: Any, **kwargs: Any + ) -> Optional[Any]: rv = spark_context_init(self, *args, **kwargs) _activate_integration(self) return rv @@ -114,8 +109,7 @@ def _sentry_patched_spark_context_init(self, *args, **kwargs): SparkContext._do_init = _sentry_patched_spark_context_init -def _setup_sentry_tracing(): - # type: () -> None +def _setup_sentry_tracing() -> None: from pyspark import SparkContext if SparkContext._active_spark_context is not None: @@ -125,102 +119,76 @@ def _setup_sentry_tracing(): class SparkListener: - def onApplicationEnd(self, applicationEnd): # noqa: N802,N803 - # type: (Any) -> None + def onApplicationEnd(self, applicationEnd: Any) -> None: pass - def onApplicationStart(self, applicationStart): # noqa: N802,N803 - # type: (Any) -> None + def onApplicationStart(self, applicationStart: Any) -> None: pass - def onBlockManagerAdded(self, blockManagerAdded): # noqa: N802,N803 - # type: (Any) -> None + def onBlockManagerAdded(self, blockManagerAdded: Any) -> None: pass - def onBlockManagerRemoved(self, blockManagerRemoved): # noqa: N802,N803 - # type: (Any) -> None + def onBlockManagerRemoved(self, blockManagerRemoved: Any) -> None: pass - def onBlockUpdated(self, blockUpdated): # noqa: N802,N803 - # type: (Any) -> None + def onBlockUpdated(self, blockUpdated: Any) -> None: pass - def onEnvironmentUpdate(self, environmentUpdate): # noqa: N802,N803 - # type: (Any) -> None + def onEnvironmentUpdate(self, environmentUpdate: Any) -> None: pass - def onExecutorAdded(self, executorAdded): # noqa: N802,N803 - # type: (Any) -> None + def onExecutorAdded(self, executorAdded: Any) -> None: pass - def onExecutorBlacklisted(self, executorBlacklisted): # noqa: N802,N803 - # type: (Any) -> None + def onExecutorBlacklisted(self, executorBlacklisted: Any) -> None: pass - def onExecutorBlacklistedForStage( # noqa: N802 - self, executorBlacklistedForStage # noqa: N803 - ): - # type: (Any) -> None + def onExecutorBlacklistedForStage(self, executorBlacklistedForStage: Any) -> None: pass - def onExecutorMetricsUpdate(self, executorMetricsUpdate): # noqa: N802,N803 - # type: 
(Any) -> None + def onExecutorMetricsUpdate(self, executorMetricsUpdate: Any) -> None: pass - def onExecutorRemoved(self, executorRemoved): # noqa: N802,N803 - # type: (Any) -> None + def onExecutorRemoved(self, executorRemoved: Any) -> None: pass - def onJobEnd(self, jobEnd): # noqa: N802,N803 - # type: (Any) -> None + def onJobEnd(self, jobEnd: Any) -> None: pass - def onJobStart(self, jobStart): # noqa: N802,N803 - # type: (Any) -> None + def onJobStart(self, jobStart: Any) -> None: pass - def onNodeBlacklisted(self, nodeBlacklisted): # noqa: N802,N803 - # type: (Any) -> None + def onNodeBlacklisted(self, nodeBlacklisted: Any) -> None: pass - def onNodeBlacklistedForStage(self, nodeBlacklistedForStage): # noqa: N802,N803 - # type: (Any) -> None + def onNodeBlacklistedForStage(self, nodeBlacklistedForStage: Any) -> None: pass - def onNodeUnblacklisted(self, nodeUnblacklisted): # noqa: N802,N803 - # type: (Any) -> None + def onNodeUnblacklisted(self, nodeUnblacklisted: Any) -> None: pass - def onOtherEvent(self, event): # noqa: N802,N803 - # type: (Any) -> None + def onOtherEvent(self, event: Any) -> None: pass - def onSpeculativeTaskSubmitted(self, speculativeTask): # noqa: N802,N803 - # type: (Any) -> None + def onSpeculativeTaskSubmitted(self, speculativeTask: Any) -> None: pass - def onStageCompleted(self, stageCompleted): # noqa: N802,N803 - # type: (Any) -> None + def onStageCompleted(self, stageCompleted: Any) -> None: pass - def onStageSubmitted(self, stageSubmitted): # noqa: N802,N803 - # type: (Any) -> None + def onStageSubmitted(self, stageSubmitted: Any) -> None: pass - def onTaskEnd(self, taskEnd): # noqa: N802,N803 - # type: (Any) -> None + def onTaskEnd(self, taskEnd: Any) -> None: pass - def onTaskGettingResult(self, taskGettingResult): # noqa: N802,N803 - # type: (Any) -> None + def onTaskGettingResult(self, taskGettingResult: Any) -> None: pass - def onTaskStart(self, taskStart): # noqa: N802,N803 - # type: (Any) -> None + def onTaskStart(self, taskStart: Any) -> None: pass - def onUnpersistRDD(self, unpersistRDD): # noqa: N802,N803 - # type: (Any) -> None + def onUnpersistRDD(self, unpersistRDD: Any) -> None: pass class Java: @@ -230,25 +198,22 @@ class Java: class SentryListener(SparkListener): def _add_breadcrumb( self, - level, # type: str - message, # type: str - data=None, # type: Optional[dict[str, Any]] - ): - # type: (...) 
-> None + level: str, + message: str, + data: Optional[dict[str, Any]] = None, + ) -> None: sentry_sdk.get_isolation_scope().add_breadcrumb( level=level, message=message, data=data ) - def onJobStart(self, jobStart): # noqa: N802,N803 - # type: (Any) -> None + def onJobStart(self, jobStart: Any) -> None: sentry_sdk.get_isolation_scope().clear_breadcrumbs() message = "Job {} Started".format(jobStart.jobId()) self._add_breadcrumb(level="info", message=message) _set_app_properties() - def onJobEnd(self, jobEnd): # noqa: N802,N803 - # type: (Any) -> None + def onJobEnd(self, jobEnd: Any) -> None: level = "" message = "" data = {"result": jobEnd.jobResult().toString()} @@ -262,8 +227,7 @@ def onJobEnd(self, jobEnd): # noqa: N802,N803 self._add_breadcrumb(level=level, message=message, data=data) - def onStageSubmitted(self, stageSubmitted): # noqa: N802,N803 - # type: (Any) -> None + def onStageSubmitted(self, stageSubmitted: Any) -> None: stage_info = stageSubmitted.stageInfo() message = "Stage {} Submitted".format(stage_info.stageId()) @@ -275,8 +239,7 @@ def onStageSubmitted(self, stageSubmitted): # noqa: N802,N803 self._add_breadcrumb(level="info", message=message, data=data) _set_app_properties() - def onStageCompleted(self, stageCompleted): # noqa: N802,N803 - # type: (Any) -> None + def onStageCompleted(self, stageCompleted: Any) -> None: from py4j.protocol import Py4JJavaError # type: ignore stage_info = stageCompleted.stageInfo() @@ -300,8 +263,7 @@ def onStageCompleted(self, stageCompleted): # noqa: N802,N803 self._add_breadcrumb(level=level, message=message, data=data) -def _get_attempt_id(stage_info): - # type: (Any) -> Optional[int] +def _get_attempt_id(stage_info: Any) -> Optional[int]: try: return stage_info.attemptId() except Exception: diff --git a/sentry_sdk/integrations/spark/spark_worker.py b/sentry_sdk/integrations/spark/spark_worker.py index 5340a0b350..ce42c752f5 100644 --- a/sentry_sdk/integrations/spark/spark_worker.py +++ b/sentry_sdk/integrations/spark/spark_worker.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import sentry_sdk @@ -23,15 +24,13 @@ class SparkWorkerIntegration(Integration): identifier = "spark_worker" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: import pyspark.daemon as original_daemon original_daemon.worker_main = _sentry_worker_main -def _capture_exception(exc_info): - # type: (ExcInfo) -> None +def _capture_exception(exc_info: ExcInfo) -> None: client = sentry_sdk.get_client() mechanism = {"type": "spark", "handled": False} @@ -53,22 +52,20 @@ def _capture_exception(exc_info): if rv: rv.reverse() hint = event_hint_with_exc_info(exc_info) - event = {"level": "error", "exception": {"values": rv}} # type: Event + event: Event = {"level": "error", "exception": {"values": rv}} _tag_task_context() sentry_sdk.capture_event(event, hint=hint) -def _tag_task_context(): - # type: () -> None +def _tag_task_context() -> None: from pyspark.taskcontext import TaskContext scope = sentry_sdk.get_isolation_scope() @scope.add_event_processor - def process_event(event, hint): - # type: (Event, Hint) -> Optional[Event] + def process_event(event: Event, hint: Hint) -> Optional[Event]: with capture_internal_exceptions(): integration = sentry_sdk.get_client().get_integration( SparkWorkerIntegration @@ -103,8 +100,7 @@ def process_event(event, hint): return event -def _sentry_worker_main(*args, **kwargs): - # type: (*Optional[Any], **Optional[Any]) -> None +def _sentry_worker_main(*args: Optional[Any], **kwargs: 
Optional[Any]) -> None: import pyspark.worker as original_worker try: diff --git a/sentry_sdk/integrations/sqlalchemy.py b/sentry_sdk/integrations/sqlalchemy.py index 4c4d8fde8c..658d10b3ca 100644 --- a/sentry_sdk/integrations/sqlalchemy.py +++ b/sentry_sdk/integrations/sqlalchemy.py @@ -1,3 +1,4 @@ +from __future__ import annotations from sentry_sdk.consts import SPANSTATUS, SPANDATA from sentry_sdk.integrations import _check_minimum_version, Integration, DidNotEnable from sentry_sdk.tracing_utils import add_query_source, record_sql_queries @@ -29,8 +30,7 @@ class SqlalchemyIntegration(Integration): origin = f"auto.db.{identifier}" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = parse_version(SQLALCHEMY_VERSION) _check_minimum_version(SqlalchemyIntegration, version) @@ -41,9 +41,14 @@ def setup_once(): @ensure_integration_enabled(SqlalchemyIntegration) def _before_cursor_execute( - conn, cursor, statement, parameters, context, executemany, *args -): - # type: (Any, Any, Any, Any, Any, bool, *Any) -> None + conn: Any, + cursor: Any, + statement: Any, + parameters: Any, + context: Any, + executemany: bool, + *args: Any, +) -> None: ctx_mgr = record_sql_queries( cursor, statement, @@ -62,13 +67,14 @@ def _before_cursor_execute( @ensure_integration_enabled(SqlalchemyIntegration) -def _after_cursor_execute(conn, cursor, statement, parameters, context, *args): - # type: (Any, Any, Any, Any, Any, *Any) -> None - ctx_mgr = getattr( +def _after_cursor_execute( + conn: Any, cursor: Any, statement: Any, parameters: Any, context: Any, *args: Any +) -> None: + ctx_mgr: Optional[ContextManager[Any]] = getattr( context, "_sentry_sql_span_manager", None - ) # type: Optional[ContextManager[Any]] + ) - span = getattr(context, "_sentry_sql_span", None) # type: Optional[Span] + span: Optional[Span] = getattr(context, "_sentry_sql_span", None) if span is not None: with capture_internal_exceptions(): add_query_source(span) @@ -78,13 +84,12 @@ def _after_cursor_execute(conn, cursor, statement, parameters, context, *args): ctx_mgr.__exit__(None, None, None) -def _handle_error(context, *args): - # type: (Any, *Any) -> None +def _handle_error(context: Any, *args: Any) -> None: execution_context = context.execution_context if execution_context is None: return - span = getattr(execution_context, "_sentry_sql_span", None) # type: Optional[Span] + span: Optional[Span] = getattr(execution_context, "_sentry_sql_span", None) if span is not None: span.set_status(SPANSTATUS.INTERNAL_ERROR) @@ -92,9 +97,9 @@ def _handle_error(context, *args): # _after_cursor_execute does not get called for crashing SQL stmts. Judging # from SQLAlchemy codebase it does seem like any error coming into this # handler is going to be fatal. 
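    # Because the failing statement never reaches _after_cursor_execute, the
    # span context manager stashed on the execution context is exited here
    # instead, so the query span is still closed when a statement errors out.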
- ctx_mgr = getattr( + ctx_mgr: Optional[ContextManager[Any]] = getattr( execution_context, "_sentry_sql_span_manager", None - ) # type: Optional[ContextManager[Any]] + ) if ctx_mgr is not None: execution_context._sentry_sql_span_manager = None @@ -102,8 +107,7 @@ def _handle_error(context, *args): # See: https://docs.sqlalchemy.org/en/20/dialects/index.html -def _get_db_system(name): - # type: (str) -> Optional[str] +def _get_db_system(name: str) -> Optional[str]: name = str(name) if "sqlite" in name: @@ -124,8 +128,7 @@ def _get_db_system(name): return None -def _set_db_data(span, conn): - # type: (Span, Any) -> None +def _set_db_data(span: Span, conn: Any) -> None: db_system = _get_db_system(conn.engine.name) if db_system is not None: span.set_attribute(SPANDATA.DB_SYSTEM, db_system) diff --git a/sentry_sdk/integrations/starlette.py b/sentry_sdk/integrations/starlette.py index e6016a3624..117a0b0031 100644 --- a/sentry_sdk/integrations/starlette.py +++ b/sentry_sdk/integrations/starlette.py @@ -1,3 +1,4 @@ +from __future__ import annotations import asyncio import functools from collections.abc import Set @@ -82,12 +83,11 @@ class StarletteIntegration(Integration): def __init__( self, - transaction_style="url", # type: str - failed_request_status_codes=_DEFAULT_FAILED_REQUEST_STATUS_CODES, # type: Set[int] - middleware_spans=True, # type: bool - http_methods_to_capture=DEFAULT_HTTP_METHODS_TO_CAPTURE, # type: tuple[str, ...] - ): - # type: (...) -> None + transaction_style: str = "url", + failed_request_status_codes: Set[int] = _DEFAULT_FAILED_REQUEST_STATUS_CODES, + middleware_spans: bool = True, + http_methods_to_capture: tuple[str, ...] = DEFAULT_HTTP_METHODS_TO_CAPTURE, + ) -> None: if transaction_style not in TRANSACTION_STYLE_VALUES: raise ValueError( "Invalid value for transaction_style: %s (must be in %s)" @@ -100,8 +100,7 @@ def __init__( self.failed_request_status_codes = failed_request_status_codes @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = parse_version(STARLETTE_VERSION) if version is None: @@ -117,12 +116,16 @@ def setup_once(): patch_templates() -def _enable_span_for_middleware(middleware_class): - # type: (Any) -> type +def _enable_span_for_middleware(middleware_class: Any) -> type: old_call = middleware_class.__call__ - async def _create_span_call(app, scope, receive, send, **kwargs): - # type: (Any, Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]], Any) -> None + async def _create_span_call( + app: Any, + scope: Dict[str, Any], + receive: Callable[[], Awaitable[Dict[str, Any]]], + send: Callable[[Dict[str, Any]], Awaitable[None]], + **kwargs: Any, + ) -> None: integration = sentry_sdk.get_client().get_integration(StarletteIntegration) if integration is None or not integration.middleware_spans: return await old_call(app, scope, receive, send, **kwargs) @@ -146,8 +149,7 @@ async def _create_span_call(app, scope, receive, send, **kwargs): middleware_span.set_tag("starlette.middleware_name", middleware_name) # Creating spans for the "receive" callback - async def _sentry_receive(*args, **kwargs): - # type: (*Any, **Any) -> Any + async def _sentry_receive(*args: Any, **kwargs: Any) -> Any: with sentry_sdk.start_span( op=OP.MIDDLEWARE_STARLETTE_RECEIVE, name=getattr(receive, "__qualname__", str(receive)), @@ -162,8 +164,7 @@ async def _sentry_receive(*args, **kwargs): new_receive = _sentry_receive if not receive_patched else receive # Creating spans for the "send" callback - 
async def _sentry_send(*args, **kwargs): - # type: (*Any, **Any) -> Any + async def _sentry_send(*args: Any, **kwargs: Any) -> Any: with sentry_sdk.start_span( op=OP.MIDDLEWARE_STARLETTE_SEND, name=getattr(send, "__qualname__", str(send)), @@ -192,8 +193,7 @@ async def _sentry_send(*args, **kwargs): @ensure_integration_enabled(StarletteIntegration) -def _capture_exception(exception, handled=False): - # type: (BaseException, **Any) -> None +def _capture_exception(exception: BaseException, handled: Any = False) -> None: event, hint = event_from_exception( exception, client_options=sentry_sdk.get_client().options, @@ -203,8 +203,7 @@ def _capture_exception(exception, handled=False): sentry_sdk.capture_event(event, hint=hint) -def patch_exception_middleware(middleware_class): - # type: (Any) -> None +def patch_exception_middleware(middleware_class: Any) -> None: """ Capture all exceptions in Starlette app and also extract user information. @@ -215,15 +214,15 @@ def patch_exception_middleware(middleware_class): if not_yet_patched: - def _sentry_middleware_init(self, *args, **kwargs): - # type: (Any, Any, Any) -> None + def _sentry_middleware_init(self: Any, *args: Any, **kwargs: Any) -> None: old_middleware_init(self, *args, **kwargs) # Patch existing exception handlers old_handlers = self._exception_handlers.copy() - async def _sentry_patched_exception_handler(self, *args, **kwargs): - # type: (Any, Any, Any) -> None + async def _sentry_patched_exception_handler( + self: Any, *args: Any, **kwargs: Any + ) -> None: integration = sentry_sdk.get_client().get_integration( StarletteIntegration ) @@ -261,8 +260,12 @@ async def _sentry_patched_exception_handler(self, *args, **kwargs): old_call = middleware_class.__call__ - async def _sentry_exceptionmiddleware_call(self, scope, receive, send): - # type: (Dict[str, Any], Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]]) -> None + async def _sentry_exceptionmiddleware_call( + self: Dict[str, Any], + scope: Dict[str, Any], + receive: Callable[[], Awaitable[Dict[str, Any]]], + send: Callable[[Dict[str, Any]], Awaitable[None]], + ) -> None: # Also add the user (that was eventually set by be Authentication middle # that was called before this middleware). This is done because the authentication # middleware sets the user in the scope and then (in the same function) @@ -281,8 +284,7 @@ async def _sentry_exceptionmiddleware_call(self, scope, receive, send): @ensure_integration_enabled(StarletteIntegration) -def _add_user_to_sentry_scope(scope): - # type: (Dict[str, Any]) -> None +def _add_user_to_sentry_scope(scope: Dict[str, Any]) -> None: """ Extracts user information from the ASGI scope and adds it to Sentry's scope. @@ -293,7 +295,7 @@ def _add_user_to_sentry_scope(scope): if not should_send_default_pii(): return - user_info = {} # type: Dict[str, Any] + user_info: Dict[str, Any] = {} starlette_user = scope["user"] username = getattr(starlette_user, "username", None) @@ -312,8 +314,7 @@ def _add_user_to_sentry_scope(scope): sentry_scope.set_user(user_info) -def patch_authentication_middleware(middleware_class): - # type: (Any) -> None +def patch_authentication_middleware(middleware_class: Any) -> None: """ Add user information to Sentry scope. 
""" @@ -323,16 +324,19 @@ def patch_authentication_middleware(middleware_class): if not_yet_patched: - async def _sentry_authenticationmiddleware_call(self, scope, receive, send): - # type: (Dict[str, Any], Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]]) -> None + async def _sentry_authenticationmiddleware_call( + self: Dict[str, Any], + scope: Dict[str, Any], + receive: Callable[[], Awaitable[Dict[str, Any]]], + send: Callable[[Dict[str, Any]], Awaitable[None]], + ) -> None: await old_call(self, scope, receive, send) _add_user_to_sentry_scope(scope) middleware_class.__call__ = _sentry_authenticationmiddleware_call -def patch_middlewares(): - # type: () -> None +def patch_middlewares() -> None: """ Patches Starlettes `Middleware` class to record spans for every middleware invoked. @@ -343,8 +347,9 @@ def patch_middlewares(): if not_yet_patched: - def _sentry_middleware_init(self, cls, *args, **kwargs): - # type: (Any, Any, Any, Any) -> None + def _sentry_middleware_init( + self: Any, cls: Any, *args: Any, **kwargs: Any + ) -> None: if cls == SentryAsgiMiddleware: return old_middleware_init(self, cls, *args, **kwargs) @@ -360,15 +365,15 @@ def _sentry_middleware_init(self, cls, *args, **kwargs): Middleware.__init__ = _sentry_middleware_init -def patch_asgi_app(): - # type: () -> None +def patch_asgi_app() -> None: """ Instrument Starlette ASGI app using the SentryAsgiMiddleware. """ old_app = Starlette.__call__ - async def _sentry_patched_asgi_app(self, scope, receive, send): - # type: (Starlette, StarletteScope, Receive, Send) -> None + async def _sentry_patched_asgi_app( + self: Starlette, scope: StarletteScope, receive: Receive, send: Send + ) -> None: integration = sentry_sdk.get_client().get_integration(StarletteIntegration) if integration is None: return await old_app(self, scope, receive, send) @@ -393,8 +398,7 @@ async def _sentry_patched_asgi_app(self, scope, receive, send): # This was vendored in from Starlette to support Starlette 0.19.1 because # this function was only introduced in 0.20.x -def _is_async_callable(obj): - # type: (Any) -> bool +def _is_async_callable(obj: Any) -> bool: while isinstance(obj, functools.partial): obj = obj.func @@ -403,19 +407,16 @@ def _is_async_callable(obj): ) -def patch_request_response(): - # type: () -> None +def patch_request_response() -> None: old_request_response = starlette.routing.request_response - def _sentry_request_response(func): - # type: (Callable[[Any], Any]) -> ASGIApp + def _sentry_request_response(func: Callable[[Any], Any]) -> ASGIApp: old_func = func is_coroutine = _is_async_callable(old_func) if is_coroutine: - async def _sentry_async_func(*args, **kwargs): - # type: (*Any, **Any) -> Any + async def _sentry_async_func(*args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration( StarletteIntegration ) @@ -434,10 +435,10 @@ async def _sentry_async_func(*args, **kwargs): extractor = StarletteRequestExtractor(request) info = await extractor.extract_request_info() - def _make_request_event_processor(req, integration): - # type: (Any, Any) -> Callable[[Event, dict[str, Any]], Event] - def event_processor(event, hint): - # type: (Event, Dict[str, Any]) -> Event + def _make_request_event_processor( + req: Any, integration: Any + ) -> Callable[[Event, dict[str, Any]], Event]: + def event_processor(event: Event, hint: Dict[str, Any]) -> Event: # Add info from request to event request_info = event.get("request", {}) @@ -464,8 +465,7 @@ def 
event_processor(event, hint): else: @functools.wraps(old_func) - def _sentry_sync_func(*args, **kwargs): - # type: (*Any, **Any) -> Any + def _sentry_sync_func(*args: Any, **kwargs: Any) -> Any: integration = sentry_sdk.get_client().get_integration( StarletteIntegration ) @@ -489,10 +489,10 @@ def _sentry_sync_func(*args, **kwargs): extractor = StarletteRequestExtractor(request) cookies = extractor.extract_cookies_from_request() - def _make_request_event_processor(req, integration): - # type: (Any, Any) -> Callable[[Event, dict[str, Any]], Event] - def event_processor(event, hint): - # type: (Event, dict[str, Any]) -> Event + def _make_request_event_processor( + req: Any, integration: Any + ) -> Callable[[Event, dict[str, Any]], Event]: + def event_processor(event: Event, hint: dict[str, Any]) -> Event: # Extract information from request request_info = event.get("request", {}) @@ -519,8 +519,7 @@ def event_processor(event, hint): starlette.routing.request_response = _sentry_request_response -def patch_templates(): - # type: () -> None +def patch_templates() -> None: # If markupsafe is not installed, then Jinja2 is not installed # (markupsafe is a dependency of Jinja2) @@ -540,10 +539,10 @@ def patch_templates(): if not_yet_patched: - def _sentry_jinja2templates_init(self, *args, **kwargs): - # type: (Jinja2Templates, *Any, **Any) -> None - def add_sentry_trace_meta(request): - # type: (Request) -> Dict[str, Any] + def _sentry_jinja2templates_init( + self: Jinja2Templates, *args: Any, **kwargs: Any + ) -> None: + def add_sentry_trace_meta(request: Request) -> Dict[str, Any]: trace_meta = Markup( sentry_sdk.get_current_scope().trace_propagation_meta() ) @@ -567,25 +566,26 @@ class StarletteRequestExtractor: (like form data or cookies) and adds it to the Sentry event. 
""" - request = None # type: Request + request: Request = None - def __init__(self, request): - # type: (StarletteRequestExtractor, Request) -> None + def __init__(self: StarletteRequestExtractor, request: Request) -> None: self.request = request - def extract_cookies_from_request(self): - # type: (StarletteRequestExtractor) -> Optional[Dict[str, Any]] - cookies = None # type: Optional[Dict[str, Any]] + def extract_cookies_from_request( + self: StarletteRequestExtractor, + ) -> Optional[Dict[str, Any]]: + cookies: Optional[Dict[str, Any]] = None if should_send_default_pii(): cookies = self.cookies() return cookies - async def extract_request_info(self): - # type: (StarletteRequestExtractor) -> Optional[Dict[str, Any]] + async def extract_request_info( + self: StarletteRequestExtractor, + ) -> Optional[Dict[str, Any]]: client = sentry_sdk.get_client() - request_info = {} # type: Dict[str, Any] + request_info: Dict[str, Any] = {} with capture_internal_exceptions(): # Add cookies @@ -629,19 +629,16 @@ async def extract_request_info(self): request_info["data"] = AnnotatedValue.removed_because_raw_data() return request_info - async def content_length(self): - # type: (StarletteRequestExtractor) -> Optional[int] + async def content_length(self: StarletteRequestExtractor) -> Optional[int]: if "content-length" in self.request.headers: return int(self.request.headers["content-length"]) return None - def cookies(self): - # type: (StarletteRequestExtractor) -> Dict[str, Any] + def cookies(self: StarletteRequestExtractor) -> Dict[str, Any]: return self.request.cookies - async def form(self): - # type: (StarletteRequestExtractor) -> Any + async def form(self: StarletteRequestExtractor) -> Any: if multipart is None: return None @@ -653,12 +650,10 @@ async def form(self): return await self.request.form() - def is_json(self): - # type: (StarletteRequestExtractor) -> bool + def is_json(self: StarletteRequestExtractor) -> bool: return _is_json_content_type(self.request.headers.get("content-type")) - async def json(self): - # type: (StarletteRequestExtractor) -> Optional[Dict[str, Any]] + async def json(self: StarletteRequestExtractor) -> Optional[Dict[str, Any]]: if not self.is_json(): return None try: @@ -667,8 +662,7 @@ async def json(self): return None -def _transaction_name_from_router(scope): - # type: (StarletteScope) -> Optional[str] +def _transaction_name_from_router(scope: StarletteScope) -> Optional[str]: router = scope.get("router") if not router: return None @@ -685,8 +679,9 @@ def _transaction_name_from_router(scope): return None -def _set_transaction_name_and_source(scope, transaction_style, request): - # type: (sentry_sdk.Scope, str, Any) -> None +def _set_transaction_name_and_source( + scope: sentry_sdk.Scope, transaction_style: str, request: Any +) -> None: name = None source = SOURCE_FOR_STYLE[transaction_style] @@ -708,8 +703,9 @@ def _set_transaction_name_and_source(scope, transaction_style, request): ) -def _get_transaction_from_middleware(app, asgi_scope, integration): - # type: (Any, Dict[str, Any], StarletteIntegration) -> Tuple[Optional[str], Optional[str]] +def _get_transaction_from_middleware( + app: Any, asgi_scope: Dict[str, Any], integration: StarletteIntegration +) -> Tuple[Optional[str], Optional[str]]: name = None source = None diff --git a/sentry_sdk/integrations/starlite.py b/sentry_sdk/integrations/starlite.py index 928c697373..e9b42f29b4 100644 --- a/sentry_sdk/integrations/starlite.py +++ b/sentry_sdk/integrations/starlite.py @@ -1,3 +1,4 @@ +from __future__ import 
annotations import sentry_sdk from sentry_sdk.consts import OP, SOURCE_FOR_STYLE, TransactionSource from sentry_sdk.integrations import DidNotEnable, Integration @@ -48,16 +49,16 @@ class StarliteIntegration(Integration): origin = f"auto.http.{identifier}" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: patch_app_init() patch_middlewares() patch_http_route_handle() class SentryStarliteASGIMiddleware(SentryAsgiMiddleware): - def __init__(self, app, span_origin=StarliteIntegration.origin): - # type: (ASGIApp, str) -> None + def __init__( + self, app: ASGIApp, span_origin: str = StarliteIntegration.origin + ) -> None: super().__init__( app=app, unsafe_context_data=False, @@ -67,8 +68,7 @@ def __init__(self, app, span_origin=StarliteIntegration.origin): ) -def patch_app_init(): - # type: () -> None +def patch_app_init() -> None: """ Replaces the Starlite class's `__init__` function in order to inject `after_exception` handlers and set the `SentryStarliteASGIMiddleware` as the outmost middleware in the stack. @@ -79,8 +79,7 @@ def patch_app_init(): old__init__ = Starlite.__init__ @ensure_integration_enabled(StarliteIntegration, old__init__) - def injection_wrapper(self, *args, **kwargs): - # type: (Starlite, *Any, **Any) -> None + def injection_wrapper(self: Starlite, *args: Any, **kwargs: Any) -> None: after_exception = kwargs.pop("after_exception", []) kwargs.update( after_exception=[ @@ -101,13 +100,11 @@ def injection_wrapper(self, *args, **kwargs): Starlite.__init__ = injection_wrapper -def patch_middlewares(): - # type: () -> None +def patch_middlewares() -> None: old_resolve_middleware_stack = BaseRouteHandler.resolve_middleware @ensure_integration_enabled(StarliteIntegration, old_resolve_middleware_stack) - def resolve_middleware_wrapper(self): - # type: (BaseRouteHandler) -> list[Middleware] + def resolve_middleware_wrapper(self: BaseRouteHandler) -> list[Middleware]: return [ enable_span_for_middleware(middleware) for middleware in old_resolve_middleware_stack(self) @@ -116,8 +113,7 @@ def resolve_middleware_wrapper(self): BaseRouteHandler.resolve_middleware = resolve_middleware_wrapper -def enable_span_for_middleware(middleware): - # type: (Middleware) -> Middleware +def enable_span_for_middleware(middleware: Middleware) -> Middleware: if ( not hasattr(middleware, "__call__") # noqa: B004 or middleware is SentryStarliteASGIMiddleware @@ -125,12 +121,13 @@ def enable_span_for_middleware(middleware): return middleware if isinstance(middleware, DefineMiddleware): - old_call = middleware.middleware.__call__ # type: ASGIApp + old_call: ASGIApp = middleware.middleware.__call__ else: old_call = middleware.__call__ - async def _create_span_call(self, scope, receive, send): - # type: (MiddlewareProtocol, StarliteScope, Receive, Send) -> None + async def _create_span_call( + self: MiddlewareProtocol, scope: StarliteScope, receive: Receive, send: Send + ) -> None: if sentry_sdk.get_client().get_integration(StarliteIntegration) is None: return await old_call(self, scope, receive, send) @@ -144,8 +141,9 @@ async def _create_span_call(self, scope, receive, send): middleware_span.set_tag("starlite.middleware_name", middleware_name) # Creating spans for the "receive" callback - async def _sentry_receive(*args, **kwargs): - # type: (*Any, **Any) -> Union[HTTPReceiveMessage, WebSocketReceiveMessage] + async def _sentry_receive( + *args: Any, **kwargs: Any + ) -> Union[HTTPReceiveMessage, WebSocketReceiveMessage]: if 
sentry_sdk.get_client().get_integration(StarliteIntegration) is None: return await receive(*args, **kwargs) with sentry_sdk.start_span( @@ -162,8 +160,7 @@ async def _sentry_receive(*args, **kwargs): new_receive = _sentry_receive if not receive_patched else receive # Creating spans for the "send" callback - async def _sentry_send(message): - # type: (Message) -> None + async def _sentry_send(message: Message) -> None: if sentry_sdk.get_client().get_integration(StarliteIntegration) is None: return await send(message) with sentry_sdk.start_span( @@ -192,19 +189,19 @@ async def _sentry_send(message): return middleware -def patch_http_route_handle(): - # type: () -> None +def patch_http_route_handle() -> None: old_handle = HTTPRoute.handle - async def handle_wrapper(self, scope, receive, send): - # type: (HTTPRoute, HTTPScope, Receive, Send) -> None + async def handle_wrapper( + self: HTTPRoute, scope: HTTPScope, receive: Receive, send: Send + ) -> None: if sentry_sdk.get_client().get_integration(StarliteIntegration) is None: return await old_handle(self, scope, receive, send) sentry_scope = sentry_sdk.get_isolation_scope() - request = scope["app"].request_class( + request: Request[Any, Any] = scope["app"].request_class( scope=scope, receive=receive, send=send - ) # type: Request[Any, Any] + ) extracted_request_data = ConnectionDataExtractor( parse_body=True, parse_query=True )(request) @@ -212,8 +209,7 @@ async def handle_wrapper(self, scope, receive, send): request_data = await body - def event_processor(event, _): - # type: (Event, Hint) -> Event + def event_processor(event: Event, _: Hint) -> Event: route_handler = scope.get("route_handler") request_info = event.get("request", {}) @@ -256,8 +252,7 @@ def event_processor(event, _): HTTPRoute.handle = handle_wrapper -def retrieve_user_from_scope(scope): - # type: (StarliteScope) -> Optional[dict[str, Any]] +def retrieve_user_from_scope(scope: StarliteScope) -> Optional[dict[str, Any]]: scope_user = scope.get("user") if not scope_user: return None @@ -276,9 +271,8 @@ def retrieve_user_from_scope(scope): @ensure_integration_enabled(StarliteIntegration) -def exception_handler(exc, scope, _): - # type: (Exception, StarliteScope, State) -> None - user_info = None # type: Optional[dict[str, Any]] +def exception_handler(exc: Exception, scope: StarliteScope, _: State) -> None: + user_info: Optional[dict[str, Any]] = None if should_send_default_pii(): user_info = retrieve_user_from_scope(scope) if user_info and isinstance(user_info, dict): diff --git a/sentry_sdk/integrations/statsig.py b/sentry_sdk/integrations/statsig.py index 1d84eb8aa2..9a62e3d18f 100644 --- a/sentry_sdk/integrations/statsig.py +++ b/sentry_sdk/integrations/statsig.py @@ -1,3 +1,4 @@ +from __future__ import annotations from functools import wraps from typing import Any, TYPE_CHECKING @@ -19,8 +20,7 @@ class StatsigIntegration(Integration): identifier = "statsig" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = parse_version(STATSIG_VERSION) _check_minimum_version(StatsigIntegration, version, "statsig") @@ -28,8 +28,9 @@ def setup_once(): old_check_gate = statsig_module.check_gate @wraps(old_check_gate) - def sentry_check_gate(user, gate, *args, **kwargs): - # type: (StatsigUser, str, *Any, **Any) -> Any + def sentry_check_gate( + user: StatsigUser, gate: str, *args: Any, **kwargs: Any + ) -> Any: enabled = old_check_gate(user, gate, *args, **kwargs) add_feature_flag(gate, enabled) return enabled diff --git 
a/sentry_sdk/integrations/stdlib.py b/sentry_sdk/integrations/stdlib.py index 2507eb7895..9058939c04 100644 --- a/sentry_sdk/integrations/stdlib.py +++ b/sentry_sdk/integrations/stdlib.py @@ -1,3 +1,4 @@ +from __future__ import annotations import os import subprocess import sys @@ -34,25 +35,23 @@ from sentry_sdk._types import Event, Hint -_RUNTIME_CONTEXT = { +_RUNTIME_CONTEXT: dict[str, object] = { "name": platform.python_implementation(), "version": "%s.%s.%s" % (sys.version_info[:3]), "build": sys.version, -} # type: dict[str, object] +} class StdlibIntegration(Integration): identifier = "stdlib" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: _install_httplib() _install_subprocess() @add_global_event_processor - def add_python_runtime_context(event, hint): - # type: (Event, Hint) -> Optional[Event] + def add_python_runtime_context(event: Event, hint: Hint) -> Optional[Event]: if sentry_sdk.get_client().get_integration(StdlibIntegration) is not None: contexts = event.setdefault("contexts", {}) if isinstance(contexts, dict) and "runtime" not in contexts: @@ -61,13 +60,13 @@ def add_python_runtime_context(event, hint): return event -def _install_httplib(): - # type: () -> None +def _install_httplib() -> None: real_putrequest = HTTPConnection.putrequest real_getresponse = HTTPConnection.getresponse - def putrequest(self, method, url, *args, **kwargs): - # type: (HTTPConnection, str, str, *Any, **Any) -> Any + def putrequest( + self: HTTPConnection, method: str, url: str, *args: Any, **kwargs: Any + ) -> Any: host = self.host port = self.port default_port = self.default_port @@ -134,8 +133,7 @@ def putrequest(self, method, url, *args, **kwargs): return rv - def getresponse(self, *args, **kwargs): - # type: (HTTPConnection, *Any, **Any) -> Any + def getresponse(self: HTTPConnection, *args: Any, **kwargs: Any) -> Any: span = getattr(self, "_sentrysdk_span", None) if span is None: @@ -167,8 +165,13 @@ def getresponse(self, *args, **kwargs): HTTPConnection.getresponse = getresponse # type: ignore[method-assign] -def _init_argument(args, kwargs, name, position, setdefault_callback=None): - # type: (List[Any], Dict[Any, Any], str, int, Optional[Callable[[Any], Any]]) -> Any +def _init_argument( + args: List[Any], + kwargs: Dict[Any, Any], + name: str, + position: int, + setdefault_callback: Optional[Callable[[Any], Any]] = None, +) -> Any: """ given (*args, **kwargs) of a function call, retrieve (and optionally set a default for) an argument by either name or position. @@ -198,13 +201,13 @@ def _init_argument(args, kwargs, name, position, setdefault_callback=None): return rv -def _install_subprocess(): - # type: () -> None +def _install_subprocess() -> None: old_popen_init = subprocess.Popen.__init__ @ensure_integration_enabled(StdlibIntegration, old_popen_init) - def sentry_patched_popen_init(self, *a, **kw): - # type: (subprocess.Popen[Any], *Any, **Any) -> None + def sentry_patched_popen_init( + self: subprocess.Popen[Any], *a: Any, **kw: Any + ) -> None: # Convert from tuple to list to be able to set values. 
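        # Popen.__init__ receives its positional arguments as an immutable
        # tuple; turning them into a list lets _init_argument() look arguments
        # up by name or position and overwrite them in place, e.g. so the
        # child's env can carry the SUBPROCESS_-prefixed trace headers that
        # get_subprocess_traceparent_headers() later reads back.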
a = list(a) @@ -279,8 +282,9 @@ def sentry_patched_popen_init(self, *a, **kw): old_popen_wait = subprocess.Popen.wait @ensure_integration_enabled(StdlibIntegration, old_popen_wait) - def sentry_patched_popen_wait(self, *a, **kw): - # type: (subprocess.Popen[Any], *Any, **Any) -> Any + def sentry_patched_popen_wait( + self: subprocess.Popen[Any], *a: Any, **kw: Any + ) -> Any: with sentry_sdk.start_span( op=OP.SUBPROCESS_WAIT, origin="auto.subprocess.stdlib.subprocess", @@ -294,8 +298,9 @@ def sentry_patched_popen_wait(self, *a, **kw): old_popen_communicate = subprocess.Popen.communicate @ensure_integration_enabled(StdlibIntegration, old_popen_communicate) - def sentry_patched_popen_communicate(self, *a, **kw): - # type: (subprocess.Popen[Any], *Any, **Any) -> Any + def sentry_patched_popen_communicate( + self: subprocess.Popen[Any], *a: Any, **kw: Any + ) -> Any: with sentry_sdk.start_span( op=OP.SUBPROCESS_COMMUNICATE, origin="auto.subprocess.stdlib.subprocess", @@ -307,6 +312,5 @@ def sentry_patched_popen_communicate(self, *a, **kw): subprocess.Popen.communicate = sentry_patched_popen_communicate # type: ignore -def get_subprocess_traceparent_headers(): - # type: () -> EnvironHeaders +def get_subprocess_traceparent_headers() -> EnvironHeaders: return EnvironHeaders(os.environ, prefix="SUBPROCESS_") diff --git a/sentry_sdk/integrations/strawberry.py b/sentry_sdk/integrations/strawberry.py index 274ae8d1c9..7994daded6 100644 --- a/sentry_sdk/integrations/strawberry.py +++ b/sentry_sdk/integrations/strawberry.py @@ -1,3 +1,4 @@ +from __future__ import annotations import functools import hashlib from inspect import isawaitable @@ -62,8 +63,7 @@ class StrawberryIntegration(Integration): identifier = "strawberry" origin = f"auto.graphql.{identifier}" - def __init__(self, async_execution=None): - # type: (Optional[bool]) -> None + def __init__(self, async_execution: Optional[bool] = None) -> None: if async_execution not in (None, False, True): raise ValueError( 'Invalid value for async_execution: "{}" (must be bool)'.format( @@ -73,8 +73,7 @@ def __init__(self, async_execution=None): self.async_execution = async_execution @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: version = package_version("strawberry-graphql") _check_minimum_version(StrawberryIntegration, version, "strawberry-graphql") @@ -82,13 +81,11 @@ def setup_once(): _patch_views() -def _patch_schema_init(): - # type: () -> None +def _patch_schema_init() -> None: old_schema_init = Schema.__init__ @functools.wraps(old_schema_init) - def _sentry_patched_schema_init(self, *args, **kwargs): - # type: (Schema, Any, Any) -> None + def _sentry_patched_schema_init(self: Schema, *args: Any, **kwargs: Any) -> None: integration = sentry_sdk.get_client().get_integration(StrawberryIntegration) if integration is None: return old_schema_init(self, *args, **kwargs) @@ -121,17 +118,15 @@ def _sentry_patched_schema_init(self, *args, **kwargs): class SentryAsyncExtension(SchemaExtension): def __init__( - self, + self: Any, *, - execution_context=None, - ): - # type: (Any, Optional[ExecutionContext]) -> None + execution_context: Optional[ExecutionContext] = None, + ) -> None: if execution_context: self.execution_context = execution_context @cached_property - def _resource_name(self): - # type: () -> str + def _resource_name(self) -> str: query_hash = self.hash_query(self.execution_context.query) # type: ignore if self.execution_context.operation_name: @@ -139,12 +134,10 @@ def _resource_name(self): return query_hash - 
def hash_query(self, query): - # type: (str) -> str + def hash_query(self, query: str) -> str: return hashlib.md5(query.encode("utf-8")).hexdigest() - def on_operation(self): - # type: () -> Generator[None, None, None] + def on_operation(self) -> Generator[None, None, None]: self._operation_name = self.execution_context.operation_name operation_type = "query" @@ -205,8 +198,7 @@ def on_operation(self): if root_span: root_span.op = op - def on_validate(self): - # type: () -> Generator[None, None, None] + def on_validate(self) -> Generator[None, None, None]: with sentry_sdk.start_span( op=OP.GRAPHQL_VALIDATE, name="validation", @@ -214,8 +206,7 @@ def on_validate(self): ): yield - def on_parse(self): - # type: () -> Generator[None, None, None] + def on_parse(self) -> Generator[None, None, None]: with sentry_sdk.start_span( op=OP.GRAPHQL_PARSE, name="parsing", @@ -223,12 +214,21 @@ def on_parse(self): ): yield - def should_skip_tracing(self, _next, info): - # type: (Callable[[Any, GraphQLResolveInfo, Any, Any], Any], GraphQLResolveInfo) -> bool + def should_skip_tracing( + self, + _next: Callable[[Any, GraphQLResolveInfo, Any, Any], Any], + info: GraphQLResolveInfo, + ) -> bool: return strawberry_should_skip_tracing(_next, info) - async def _resolve(self, _next, root, info, *args, **kwargs): - # type: (Callable[[Any, GraphQLResolveInfo, Any, Any], Any], Any, GraphQLResolveInfo, str, Any) -> Any + async def _resolve( + self, + _next: Callable[[Any, GraphQLResolveInfo, Any, Any], Any], + root: Any, + info: GraphQLResolveInfo, + *args: str, + **kwargs: Any, + ) -> Any: result = _next(root, info, *args, **kwargs) if isawaitable(result): @@ -236,8 +236,14 @@ async def _resolve(self, _next, root, info, *args, **kwargs): return result - async def resolve(self, _next, root, info, *args, **kwargs): - # type: (Callable[[Any, GraphQLResolveInfo, Any, Any], Any], Any, GraphQLResolveInfo, str, Any) -> Any + async def resolve( + self, + _next: Callable[[Any, GraphQLResolveInfo, Any, Any], Any], + root: Any, + info: GraphQLResolveInfo, + *args: str, + **kwargs: Any, + ) -> Any: if self.should_skip_tracing(_next, info): return await self._resolve(_next, root, info, *args, **kwargs) @@ -257,8 +263,14 @@ async def resolve(self, _next, root, info, *args, **kwargs): class SentrySyncExtension(SentryAsyncExtension): - def resolve(self, _next, root, info, *args, **kwargs): - # type: (Callable[[Any, Any, Any, Any], Any], Any, GraphQLResolveInfo, str, Any) -> Any + def resolve( + self, + _next: Callable[[Any, Any, Any, Any], Any], + root: Any, + info: GraphQLResolveInfo, + *args: str, + **kwargs: Any, + ) -> Any: if self.should_skip_tracing(_next, info): return _next(root, info, *args, **kwargs) @@ -277,24 +289,26 @@ def resolve(self, _next, root, info, *args, **kwargs): return _next(root, info, *args, **kwargs) -def _patch_views(): - # type: () -> None +def _patch_views() -> None: old_async_view_handle_errors = async_base_view.AsyncBaseHTTPView._handle_errors old_sync_view_handle_errors = sync_base_view.SyncBaseHTTPView._handle_errors - def _sentry_patched_async_view_handle_errors(self, errors, response_data): - # type: (Any, List[GraphQLError], GraphQLHTTPResponse) -> None + def _sentry_patched_async_view_handle_errors( + self: Any, errors: List[GraphQLError], response_data: GraphQLHTTPResponse + ) -> None: old_async_view_handle_errors(self, errors, response_data) _sentry_patched_handle_errors(self, errors, response_data) - def _sentry_patched_sync_view_handle_errors(self, errors, response_data): - # type: (Any, 
List[GraphQLError], GraphQLHTTPResponse) -> None + def _sentry_patched_sync_view_handle_errors( + self: Any, errors: List[GraphQLError], response_data: GraphQLHTTPResponse + ) -> None: old_sync_view_handle_errors(self, errors, response_data) _sentry_patched_handle_errors(self, errors, response_data) @ensure_integration_enabled(StrawberryIntegration) - def _sentry_patched_handle_errors(self, errors, response_data): - # type: (Any, List[GraphQLError], GraphQLHTTPResponse) -> None + def _sentry_patched_handle_errors( + self: Any, errors: List[GraphQLError], response_data: GraphQLHTTPResponse + ) -> None: if not errors: return @@ -322,18 +336,18 @@ def _sentry_patched_handle_errors(self, errors, response_data): ) -def _make_request_event_processor(execution_context): - # type: (ExecutionContext) -> EventProcessor +def _make_request_event_processor( + execution_context: ExecutionContext, +) -> EventProcessor: - def inner(event, hint): - # type: (Event, dict[str, Any]) -> Event + def inner(event: Event, hint: dict[str, Any]) -> Event: with capture_internal_exceptions(): if should_send_default_pii(): request_data = event.setdefault("request", {}) request_data["api_target"] = "graphql" if not request_data.get("data"): - data = {"query": execution_context.query} # type: dict[str, Any] + data: dict[str, Any] = {"query": execution_context.query} if execution_context.variables: data["variables"] = execution_context.variables if execution_context.operation_name: @@ -352,11 +366,11 @@ def inner(event, hint): return inner -def _make_response_event_processor(response_data): - # type: (GraphQLHTTPResponse) -> EventProcessor +def _make_response_event_processor( + response_data: GraphQLHTTPResponse, +) -> EventProcessor: - def inner(event, hint): - # type: (Event, dict[str, Any]) -> Event + def inner(event: Event, hint: dict[str, Any]) -> Event: with capture_internal_exceptions(): if should_send_default_pii(): contexts = event.setdefault("contexts", {}) @@ -367,8 +381,7 @@ def inner(event, hint): return inner -def _guess_if_using_async(extensions): - # type: (List[SchemaExtension]) -> bool +def _guess_if_using_async(extensions: List[SchemaExtension]) -> bool: return bool( {"starlette", "starlite", "litestar", "fastapi"} & set(_get_installed_modules()) ) diff --git a/sentry_sdk/integrations/sys_exit.py b/sentry_sdk/integrations/sys_exit.py index 2341e11359..ff1a97d5b6 100644 --- a/sentry_sdk/integrations/sys_exit.py +++ b/sentry_sdk/integrations/sys_exit.py @@ -1,3 +1,4 @@ +from __future__ import annotations import functools import sys @@ -24,23 +25,19 @@ class SysExitIntegration(Integration): identifier = "sys_exit" - def __init__(self, *, capture_successful_exits=False): - # type: (bool) -> None + def __init__(self, *, capture_successful_exits: bool = False) -> None: self._capture_successful_exits = capture_successful_exits @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: SysExitIntegration._patch_sys_exit() @staticmethod - def _patch_sys_exit(): - # type: () -> None - old_exit = sys.exit # type: Callable[[Union[str, int, None]], NoReturn] + def _patch_sys_exit() -> None: + old_exit: Callable[[Union[str, int, None]], NoReturn] = sys.exit @functools.wraps(old_exit) - def sentry_patched_exit(__status=0): - # type: (Union[str, int, None]) -> NoReturn + def sentry_patched_exit(__status: Union[str, int, None] = 0) -> NoReturn: # @ensure_integration_enabled ensures that this is non-None integration = sentry_sdk.get_client().get_integration(SysExitIntegration) if integration is 
None: @@ -60,8 +57,7 @@ def sentry_patched_exit(__status=0): sys.exit = sentry_patched_exit -def _capture_exception(exc): - # type: (SystemExit) -> None +def _capture_exception(exc: SystemExit) -> None: event, hint = event_from_exception( exc, client_options=sentry_sdk.get_client().options, diff --git a/sentry_sdk/integrations/threading.py b/sentry_sdk/integrations/threading.py index 8d0bb69f9d..1f4ea080f6 100644 --- a/sentry_sdk/integrations/threading.py +++ b/sentry_sdk/integrations/threading.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import warnings from functools import wraps @@ -28,13 +29,11 @@ class ThreadingIntegration(Integration): identifier = "threading" - def __init__(self, propagate_scope=True): - # type: (bool) -> None + def __init__(self, propagate_scope: bool = True) -> None: self.propagate_scope = propagate_scope @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: old_start = Thread.start try: @@ -47,8 +46,7 @@ def setup_once(): channels_version = None @wraps(old_start) - def sentry_start(self, *a, **kw): - # type: (Thread, *Any, **Any) -> Any + def sentry_start(self: Thread, *a: Any, **kw: Any) -> Any: integration = sentry_sdk.get_client().get_integration(ThreadingIntegration) if integration is None: return old_start(self, *a, **kw) @@ -98,13 +96,14 @@ def sentry_start(self, *a, **kw): Thread.start = sentry_start # type: ignore -def _wrap_run(isolation_scope_to_use, current_scope_to_use, old_run_func): - # type: (sentry_sdk.Scope, sentry_sdk.Scope, F) -> F +def _wrap_run( + isolation_scope_to_use: sentry_sdk.Scope, + current_scope_to_use: sentry_sdk.Scope, + old_run_func: F, +) -> F: @wraps(old_run_func) - def run(*a, **kw): - # type: (*Any, **Any) -> Any - def _run_old_run_func(): - # type: () -> Any + def run(*a: Any, **kw: Any) -> Any: + def _run_old_run_func() -> Any: try: self = current_thread() return old_run_func(self, *a, **kw) @@ -118,8 +117,7 @@ def _run_old_run_func(): return run # type: ignore -def _capture_exception(): - # type: () -> ExcInfo +def _capture_exception() -> ExcInfo: exc_info = sys.exc_info() client = sentry_sdk.get_client() diff --git a/sentry_sdk/integrations/tornado.py b/sentry_sdk/integrations/tornado.py index 07f3e6575c..cb5ceab061 100644 --- a/sentry_sdk/integrations/tornado.py +++ b/sentry_sdk/integrations/tornado.py @@ -1,3 +1,4 @@ +from __future__ import annotations import weakref import contextlib from inspect import iscoroutinefunction @@ -56,8 +57,7 @@ class TornadoIntegration(Integration): origin = f"auto.http.{identifier}" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: _check_minimum_version(TornadoIntegration, TORNADO_VERSION) if not HAS_REAL_CONTEXTVARS: @@ -77,16 +77,18 @@ def setup_once(): if awaitable: # Starting Tornado 6 RequestHandler._execute method is a standard Python coroutine (async/await) # In that case our method should be a coroutine function too - async def sentry_execute_request_handler(self, *args, **kwargs): - # type: (RequestHandler, *Any, **Any) -> Any + async def sentry_execute_request_handler( + self: RequestHandler, *args: Any, **kwargs: Any + ) -> Any: with _handle_request_impl(self): return await old_execute(self, *args, **kwargs) else: @coroutine # type: ignore - def sentry_execute_request_handler(self, *args, **kwargs): - # type: (RequestHandler, *Any, **Any) -> Any + def sentry_execute_request_handler( + self: RequestHandler, *args: Any, **kwargs: Any + ) -> Any: with _handle_request_impl(self): result = yield from 
old_execute(self, *args, **kwargs) return result @@ -95,8 +97,14 @@ def sentry_execute_request_handler(self, *args, **kwargs): old_log_exception = RequestHandler.log_exception - def sentry_log_exception(self, ty, value, tb, *args, **kwargs): - # type: (Any, type, BaseException, Any, *Any, **Any) -> Optional[Any] + def sentry_log_exception( + self: Any, + ty: type, + value: BaseException, + tb: Any, + *args: Any, + **kwargs: Any, + ) -> Optional[Any]: _capture_exception(ty, value, tb) return old_log_exception(self, ty, value, tb, *args, **kwargs) @@ -104,8 +112,7 @@ def sentry_log_exception(self, ty, value, tb, *args, **kwargs): @contextlib.contextmanager -def _handle_request_impl(self): - # type: (RequestHandler) -> Generator[None, None, None] +def _handle_request_impl(self: RequestHandler) -> Generator[None, None, None]: integration = sentry_sdk.get_client().get_integration(TornadoIntegration) if integration is None: @@ -136,8 +143,7 @@ def _handle_request_impl(self): @ensure_integration_enabled(TornadoIntegration) -def _capture_exception(ty, value, tb): - # type: (type, BaseException, Any) -> None +def _capture_exception(ty: type, value: BaseException, tb: Any) -> None: if isinstance(value, HTTPError): return @@ -150,10 +156,8 @@ def _capture_exception(ty, value, tb): sentry_sdk.capture_event(event, hint=hint) -def _make_event_processor(weak_handler): - # type: (Callable[[], RequestHandler]) -> EventProcessor - def tornado_processor(event, hint): - # type: (Event, dict[str, Any]) -> Event +def _make_event_processor(weak_handler: Callable[[], RequestHandler]) -> EventProcessor: + def tornado_processor(event: Event, hint: dict[str, Any]) -> Event: handler = weak_handler() if handler is None: return event @@ -192,42 +196,34 @@ def tornado_processor(event, hint): class TornadoRequestExtractor(RequestExtractor): - def content_length(self): - # type: () -> int + def content_length(self) -> int: if self.request.body is None: return 0 return len(self.request.body) - def cookies(self): - # type: () -> Dict[str, str] + def cookies(self) -> Dict[str, str]: return {k: v.value for k, v in self.request.cookies.items()} - def raw_data(self): - # type: () -> bytes + def raw_data(self) -> bytes: return self.request.body - def form(self): - # type: () -> Dict[str, Any] + def form(self) -> Dict[str, Any]: return { k: [v.decode("latin1", "replace") for v in vs] for k, vs in self.request.body_arguments.items() } - def is_json(self): - # type: () -> bool + def is_json(self) -> bool: return _is_json_content_type(self.request.headers.get("content-type")) - def files(self): - # type: () -> Dict[str, Any] + def files(self) -> Dict[str, Any]: return {k: v[0] for k, v in self.request.files.items() if v} - def size_of_file(self, file): - # type: (Any) -> int + def size_of_file(self, file: Any) -> int: return len(file.body or ()) -def _prepopulate_attributes(request): - # type: (HTTPServerRequest) -> dict[str, Any] +def _prepopulate_attributes(request: HTTPServerRequest) -> dict[str, Any]: # https://www.tornadoweb.org/en/stable/httputil.html#tornado.httputil.HTTPServerRequest attributes = {} diff --git a/sentry_sdk/integrations/trytond.py b/sentry_sdk/integrations/trytond.py index fd2c6f389f..91ed51180c 100644 --- a/sentry_sdk/integrations/trytond.py +++ b/sentry_sdk/integrations/trytond.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.integrations import _check_minimum_version, Integration from sentry_sdk.integrations.wsgi import SentryWsgiMiddleware @@ -15,11 +16,11 @@ class 
TrytondWSGIIntegration(Integration): identifier = "trytond_wsgi" origin = f"auto.http.{identifier}" - def __init__(self): # type: () -> None + def __init__(self) -> None: pass @staticmethod - def setup_once(): # type: () -> None + def setup_once() -> None: _check_minimum_version(TrytondWSGIIntegration, trytond_version) app.wsgi_app = SentryWsgiMiddleware( @@ -28,7 +29,7 @@ def setup_once(): # type: () -> None ) @ensure_integration_enabled(TrytondWSGIIntegration) - def error_handler(e): # type: (Exception) -> None + def error_handler(e: Exception) -> None: if isinstance(e, TrytonException): return else: diff --git a/sentry_sdk/integrations/typer.py b/sentry_sdk/integrations/typer.py index 8879d6d0d0..ab3a22a6ff 100644 --- a/sentry_sdk/integrations/typer.py +++ b/sentry_sdk/integrations/typer.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sentry_sdk from sentry_sdk.utils import ( capture_internal_exceptions, @@ -30,15 +31,16 @@ class TyperIntegration(Integration): identifier = "typer" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: typer.main.except_hook = _make_excepthook(typer.main.except_hook) # type: ignore -def _make_excepthook(old_excepthook): - # type: (Excepthook) -> Excepthook - def sentry_sdk_excepthook(type_, value, traceback): - # type: (Type[BaseException], BaseException, Optional[TracebackType]) -> None +def _make_excepthook(old_excepthook: Excepthook) -> Excepthook: + def sentry_sdk_excepthook( + type_: Type[BaseException], + value: BaseException, + traceback: Optional[TracebackType], + ) -> None: integration = sentry_sdk.get_client().get_integration(TyperIntegration) # Note: If we replace this with ensure_integration_enabled then diff --git a/sentry_sdk/integrations/unleash.py b/sentry_sdk/integrations/unleash.py index 6daa0a411f..6dc63cc5a8 100644 --- a/sentry_sdk/integrations/unleash.py +++ b/sentry_sdk/integrations/unleash.py @@ -1,3 +1,4 @@ +from __future__ import annotations from functools import wraps from typing import Any @@ -14,14 +15,14 @@ class UnleashIntegration(Integration): identifier = "unleash" @staticmethod - def setup_once(): - # type: () -> None + def setup_once() -> None: # Wrap and patch evaluation methods (class methods) old_is_enabled = UnleashClient.is_enabled @wraps(old_is_enabled) - def sentry_is_enabled(self, feature, *args, **kwargs): - # type: (UnleashClient, str, *Any, **Any) -> Any + def sentry_is_enabled( + self: UnleashClient, feature: str, *args: Any, **kwargs: Any + ) -> Any: enabled = old_is_enabled(self, feature, *args, **kwargs) # We have no way of knowing what type of unleash feature this is, so we have to treat diff --git a/sentry_sdk/integrations/wsgi.py b/sentry_sdk/integrations/wsgi.py index 88708d6080..37f6946dd4 100644 --- a/sentry_sdk/integrations/wsgi.py +++ b/sentry_sdk/integrations/wsgi.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys from functools import partial @@ -25,22 +26,26 @@ from typing import Callable from typing import Dict from typing import Iterator + from typing import Iterable from typing import Any from typing import Tuple + from typing import List from typing import Optional - from typing import TypeVar from typing import Protocol from sentry_sdk.utils import ExcInfo from sentry_sdk._types import Event, EventProcessor - WsgiResponseIter = TypeVar("WsgiResponseIter") - WsgiResponseHeaders = TypeVar("WsgiResponseHeaders") - WsgiExcInfo = TypeVar("WsgiExcInfo") + WsgiResponseIter = Iterable[bytes] + WsgiResponseHeaders = List[Tuple[str, str]] class 
StartResponse(Protocol): - def __call__(self, status, response_headers, exc_info=None): # type: ignore - # type: (str, WsgiResponseHeaders, Optional[WsgiExcInfo]) -> WsgiResponseIter + def __call__( + self, + status: str, + response_headers: WsgiResponseHeaders, + exc_info: Optional[ExcInfo] = None, + ) -> WsgiResponseIter: pass @@ -58,13 +63,11 @@ def __call__(self, status, response_headers, exc_info=None): # type: ignore } -def wsgi_decoding_dance(s, charset="utf-8", errors="replace"): - # type: (str, str, str) -> str +def wsgi_decoding_dance(s: str, charset: str = "utf-8", errors: str = "replace") -> str: return s.encode("latin1").decode(charset, errors) -def get_request_url(environ, use_x_forwarded_for=False): - # type: (Dict[str, str], bool) -> str +def get_request_url(environ: Dict[str, str], use_x_forwarded_for: bool = False) -> str: """Return the absolute URL without query string for the given WSGI environment.""" script_name = environ.get("SCRIPT_NAME", "").rstrip("/") @@ -88,19 +91,19 @@ class SentryWsgiMiddleware: def __init__( self, - app, # type: Callable[[Dict[str, str], Callable[..., Any]], Any] - use_x_forwarded_for=False, # type: bool - span_origin=None, # type: Optional[str] - http_methods_to_capture=DEFAULT_HTTP_METHODS_TO_CAPTURE, # type: Tuple[str, ...] - ): - # type: (...) -> None + app: Callable[[Dict[str, str], Callable[..., Any]], Any], + use_x_forwarded_for: bool = False, + span_origin: Optional[str] = None, + http_methods_to_capture: Tuple[str, ...] = DEFAULT_HTTP_METHODS_TO_CAPTURE, + ) -> None: self.app = app self.use_x_forwarded_for = use_x_forwarded_for self.span_origin = span_origin self.http_methods_to_capture = http_methods_to_capture - def __call__(self, environ, start_response): - # type: (Dict[str, str], Callable[..., Any]) -> _ScopedResponse + def __call__( + self, environ: Dict[str, str], start_response: Callable[..., Any] + ) -> _ScopedResponse: if _wsgi_middleware_applied.get(False): return self.app(environ, start_response) @@ -144,8 +147,12 @@ def __call__(self, environ, start_response): return _ScopedResponse(scope, response) - def _run_original_app(self, environ, start_response, span): - # type: (dict[str, str], StartResponse, Optional[Span]) -> Any + def _run_original_app( + self, + environ: dict[str, str], + start_response: StartResponse, + span: Optional[Span], + ) -> Any: try: return self.app( environ, @@ -159,14 +166,13 @@ def _run_original_app(self, environ, start_response, span): reraise(*_capture_exception()) -def _sentry_start_response( # type: ignore - old_start_response, # type: StartResponse - span, # type: Optional[Span] - status, # type: str - response_headers, # type: WsgiResponseHeaders - exc_info=None, # type: Optional[WsgiExcInfo] -): - # type: (...) -> WsgiResponseIter +def _sentry_start_response( + old_start_response: StartResponse, + span: Optional[Span], + status: str, + response_headers: WsgiResponseHeaders, + exc_info: Optional[ExcInfo] = None, +) -> WsgiResponseIter: with capture_internal_exceptions(): status_int = int(status.split(" ", 1)[0]) if span is not None: @@ -181,8 +187,7 @@ def _sentry_start_response( # type: ignore return old_start_response(status, response_headers, exc_info) -def _get_environ(environ): - # type: (Dict[str, str]) -> Iterator[Tuple[str, str]] +def _get_environ(environ: Dict[str, str]) -> Iterator[Tuple[str, str]]: """ Returns our explicitly included environment variables we want to capture (server name, port and remote addr if pii is enabled). 
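Editor's note, illustrative sketch only (not part of the patch): the wsgi.py hunks above replace the old unbound TypeVar aliases with concrete types and a typed StartResponse protocol. The snippet below mirrors that shape to show how a type checker consumes it; `hello_app` and the simplified `ExcInfo` stand-in are hypothetical and not taken from the diff.

from typing import Iterable, List, Optional, Protocol, Tuple

# Simplified stand-in for sentry_sdk.utils.ExcInfo, for illustration only.
ExcInfo = Tuple[type, BaseException, object]
WsgiResponseHeaders = List[Tuple[str, str]]
WsgiResponseIter = Iterable[bytes]


class StartResponse(Protocol):
    def __call__(
        self,
        status: str,
        response_headers: WsgiResponseHeaders,
        exc_info: Optional[ExcInfo] = None,
    ) -> WsgiResponseIter: ...


def hello_app(environ: dict, start_response: StartResponse) -> Iterable[bytes]:
    # Any callable with this signature satisfies the protocol structurally,
    # so a real WSGI server's start_response callable type-checks against it.
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"Hello, world!\n"]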
@@ -198,8 +203,7 @@ def _get_environ(environ): yield key, environ[key] -def get_client_ip(environ): - # type: (Dict[str, str]) -> Optional[Any] +def get_client_ip(environ: Dict[str, str]) -> Optional[Any]: """ Infer the user IP address from various headers. This cannot be used in security sensitive situations since the value may be forged from a client, @@ -218,8 +222,7 @@ def get_client_ip(environ): return environ.get("REMOTE_ADDR") -def _capture_exception(): - # type: () -> ExcInfo +def _capture_exception() -> ExcInfo: """ Captures the current exception and sends it to Sentry. Returns the ExcInfo tuple to it can be reraised afterwards. @@ -253,13 +256,11 @@ class _ScopedResponse: __slots__ = ("_response", "_scope") - def __init__(self, scope, response): - # type: (sentry_sdk.Scope, Iterator[bytes]) -> None + def __init__(self, scope: sentry_sdk.Scope, response: Iterator[bytes]) -> None: self._scope = scope self._response = response - def __iter__(self): - # type: () -> Iterator[bytes] + def __iter__(self) -> Iterator[bytes]: iterator = iter(self._response) while True: @@ -273,8 +274,7 @@ def __iter__(self): yield chunk - def close(self): - # type: () -> None + def close(self) -> None: with sentry_sdk.use_isolation_scope(self._scope): try: self._response.close() # type: ignore @@ -284,8 +284,9 @@ def close(self): reraise(*_capture_exception()) -def _make_wsgi_event_processor(environ, use_x_forwarded_for): - # type: (Dict[str, str], bool) -> EventProcessor +def _make_wsgi_event_processor( + environ: Dict[str, str], use_x_forwarded_for: bool +) -> EventProcessor: # It's a bit unfortunate that we have to extract and parse the request data # from the environ so eagerly, but there are a few good reasons for this. # @@ -305,8 +306,7 @@ def _make_wsgi_event_processor(environ, use_x_forwarded_for): env = dict(_get_environ(environ)) headers = _filter_headers(dict(_get_headers(environ))) - def event_processor(event, hint): - # type: (Event, Dict[str, Any]) -> Event + def event_processor(event: Event, hint: Dict[str, Any]) -> Event: with capture_internal_exceptions(): # if the code below fails halfway through we at least have some data request_info = event.setdefault("request", {}) @@ -327,8 +327,9 @@ def event_processor(event, hint): return event_processor -def _prepopulate_attributes(wsgi_environ, use_x_forwarded_for=False): - # type: (dict[str, str], bool) -> dict[str, str] +def _prepopulate_attributes( + wsgi_environ: dict[str, str], use_x_forwarded_for: bool = False +) -> dict[str, str]: """Extract span attributes from the WSGI environment.""" attributes = {} diff --git a/sentry_sdk/logger.py b/sentry_sdk/logger.py index c18cf91ff2..3d5d904312 100644 --- a/sentry_sdk/logger.py +++ b/sentry_sdk/logger.py @@ -1,7 +1,12 @@ # NOTE: this is the logger sentry exposes to users, not some generic logger. 
+from __future__ import annotations import functools import time -from typing import Any + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any from sentry_sdk import get_client from sentry_sdk.utils import safe_repr @@ -18,13 +23,14 @@ ] -def _capture_log(severity_text, severity_number, template, **kwargs): - # type: (str, int, str, **Any) -> None +def _capture_log( + severity_text: str, severity_number: int, template: str, **kwargs: Any +) -> None: client = get_client() - attrs = { + attrs: dict[str, str | bool | float | int] = { "sentry.message.template": template, - } # type: dict[str, str | bool | float | int] + } if "attributes" in kwargs: attrs.update(kwargs.pop("attributes")) for k, v in kwargs.items(): @@ -65,8 +71,7 @@ def _capture_log(severity_text, severity_number, template, **kwargs): fatal = functools.partial(_capture_log, "fatal", 21) -def _otel_severity_text(otel_severity_number): - # type: (int) -> str +def _otel_severity_text(otel_severity_number: int) -> str: for (lower, upper), severity in OTEL_RANGES: if lower <= otel_severity_number <= upper: return severity @@ -74,8 +79,7 @@ def _otel_severity_text(otel_severity_number): return "default" -def _log_level_to_otel(level, mapping): - # type: (int, dict[Any, int]) -> tuple[int, str] +def _log_level_to_otel(level: int, mapping: dict[Any, int]) -> tuple[int, str]: for py_level, otel_severity_number in sorted(mapping.items(), reverse=True): if level >= py_level: return otel_severity_number, _otel_severity_text(otel_severity_number) diff --git a/sentry_sdk/monitor.py b/sentry_sdk/monitor.py index 68d9017bf9..5d58b5491f 100644 --- a/sentry_sdk/monitor.py +++ b/sentry_sdk/monitor.py @@ -1,14 +1,15 @@ +from __future__ import annotations import os import time from threading import Thread, Lock -import sentry_sdk from sentry_sdk.utils import logger from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional + from sentry_sdk.transport import Transport MAX_DOWNSAMPLE_FACTOR = 10 @@ -23,21 +24,19 @@ class Monitor: name = "sentry.monitor" - def __init__(self, transport, interval=10): - # type: (sentry_sdk.transport.Transport, float) -> None - self.transport = transport # type: sentry_sdk.transport.Transport - self.interval = interval # type: float + def __init__(self, transport: Transport, interval: float = 10) -> None: + self.transport: Transport = transport + self.interval: float = interval self._healthy = True - self._downsample_factor = 0 # type: int + self._downsample_factor: int = 0 - self._thread = None # type: Optional[Thread] + self._thread: Optional[Thread] = None self._thread_lock = Lock() - self._thread_for_pid = None # type: Optional[int] + self._thread_for_pid: Optional[int] = None self._running = True - def _ensure_running(self): - # type: () -> None + def _ensure_running(self) -> None: """ Check that the monitor has an active thread to run in, or create one if not. @@ -52,8 +51,7 @@ def _ensure_running(self): if self._thread_for_pid == os.getpid() and self._thread is not None: return None - def _thread(): - # type: (...) 
-> None + def _thread() -> None: while self._running: time.sleep(self.interval) if self._running: @@ -74,13 +72,11 @@ def _thread(): return None - def run(self): - # type: () -> None + def run(self) -> None: self.check_health() self.set_downsample_factor() - def set_downsample_factor(self): - # type: () -> None + def set_downsample_factor(self) -> None: if self._healthy: if self._downsample_factor > 0: logger.debug( @@ -95,8 +91,7 @@ def set_downsample_factor(self): self._downsample_factor, ) - def check_health(self): - # type: () -> None + def check_health(self) -> None: """ Perform the actual health checks, currently only checks if the transport is rate-limited. @@ -104,21 +99,17 @@ def check_health(self): """ self._healthy = self.transport.is_healthy() - def is_healthy(self): - # type: () -> bool + def is_healthy(self) -> bool: self._ensure_running() return self._healthy @property - def downsample_factor(self): - # type: () -> int + def downsample_factor(self) -> int: self._ensure_running() return self._downsample_factor - def kill(self): - # type: () -> None + def kill(self) -> None: self._running = False - def __del__(self): - # type: () -> None + def __del__(self) -> None: self.kill() diff --git a/sentry_sdk/opentelemetry/consts.py b/sentry_sdk/opentelemetry/consts.py index 7f7afce9e2..9ce21b237e 100644 --- a/sentry_sdk/opentelemetry/consts.py +++ b/sentry_sdk/opentelemetry/consts.py @@ -1,5 +1,4 @@ from opentelemetry.context import create_key -from sentry_sdk.tracing_utils import Baggage # propagation keys @@ -13,9 +12,10 @@ SENTRY_USE_ISOLATION_SCOPE_KEY = create_key("sentry_use_isolation_scope") # trace state keys -TRACESTATE_SAMPLED_KEY = Baggage.SENTRY_PREFIX + "sampled" -TRACESTATE_SAMPLE_RATE_KEY = Baggage.SENTRY_PREFIX + "sample_rate" -TRACESTATE_SAMPLE_RAND_KEY = Baggage.SENTRY_PREFIX + "sample_rand" +SENTRY_PREFIX = "sentry-" +TRACESTATE_SAMPLED_KEY = SENTRY_PREFIX + "sampled" +TRACESTATE_SAMPLE_RATE_KEY = SENTRY_PREFIX + "sample_rate" +TRACESTATE_SAMPLE_RAND_KEY = SENTRY_PREFIX + "sample_rand" # misc OTEL_SENTRY_CONTEXT = "otel" diff --git a/sentry_sdk/opentelemetry/contextvars_context.py b/sentry_sdk/opentelemetry/contextvars_context.py index abd4c60d3f..34d7866f3c 100644 --- a/sentry_sdk/opentelemetry/contextvars_context.py +++ b/sentry_sdk/opentelemetry/contextvars_context.py @@ -1,4 +1,5 @@ -from typing import cast, TYPE_CHECKING +from __future__ import annotations +from typing import TYPE_CHECKING from opentelemetry.trace import get_current_span, set_span_in_context from opentelemetry.trace.span import INVALID_SPAN @@ -13,36 +14,37 @@ SENTRY_USE_CURRENT_SCOPE_KEY, SENTRY_USE_ISOLATION_SCOPE_KEY, ) +from sentry_sdk.opentelemetry.scope import PotelScope, validate_scopes if TYPE_CHECKING: - from typing import Optional from contextvars import Token - import sentry_sdk.opentelemetry.scope as scope class SentryContextVarsRuntimeContext(ContextVarsRuntimeContext): - def attach(self, context): - # type: (Context) -> Token[Context] - scopes = get_value(SENTRY_SCOPES_KEY, context) + def attach(self, context: Context) -> Token[Context]: + scopes = validate_scopes(get_value(SENTRY_SCOPES_KEY, context)) - should_fork_isolation_scope = context.pop( - SENTRY_FORK_ISOLATION_SCOPE_KEY, False + should_fork_isolation_scope = bool( + context.pop(SENTRY_FORK_ISOLATION_SCOPE_KEY, False) ) - should_fork_isolation_scope = cast("bool", should_fork_isolation_scope) should_use_isolation_scope = context.pop(SENTRY_USE_ISOLATION_SCOPE_KEY, None) - should_use_isolation_scope = cast( - 
"Optional[scope.PotelScope]", should_use_isolation_scope + should_use_isolation_scope = ( + should_use_isolation_scope + if isinstance(should_use_isolation_scope, PotelScope) + else None ) should_use_current_scope = context.pop(SENTRY_USE_CURRENT_SCOPE_KEY, None) - should_use_current_scope = cast( - "Optional[scope.PotelScope]", should_use_current_scope + should_use_current_scope = ( + should_use_current_scope + if isinstance(should_use_current_scope, PotelScope) + else None ) if scopes: - scopes = cast("tuple[scope.PotelScope, scope.PotelScope]", scopes) - (current_scope, isolation_scope) = scopes + current_scope = scopes[0] + isolation_scope = scopes[1] else: current_scope = sentry_sdk.get_current_scope() isolation_scope = sentry_sdk.get_isolation_scope() diff --git a/sentry_sdk/opentelemetry/propagator.py b/sentry_sdk/opentelemetry/propagator.py index 16a0d19cc9..f76dcc3906 100644 --- a/sentry_sdk/opentelemetry/propagator.py +++ b/sentry_sdk/opentelemetry/propagator.py @@ -1,4 +1,4 @@ -from typing import cast +from __future__ import annotations from opentelemetry import trace from opentelemetry.context import ( @@ -37,12 +37,12 @@ extract_sentrytrace_data, should_propagate_trace, ) +from sentry_sdk.opentelemetry.scope import validate_scopes from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Set - import sentry_sdk.opentelemetry.scope as scope class SentryPropagator(TextMapPropagator): @@ -50,8 +50,12 @@ class SentryPropagator(TextMapPropagator): Propagates tracing headers for Sentry's tracing system in a way OTel understands. """ - def extract(self, carrier, context=None, getter=default_getter): - # type: (CarrierT, Optional[Context], Getter[CarrierT]) -> Context + def extract( + self, + carrier: CarrierT, + context: Optional[Context] = None, + getter: Getter[CarrierT] = default_getter, + ) -> Context: if context is None: context = get_current() @@ -93,13 +97,16 @@ def extract(self, carrier, context=None, getter=default_getter): modified_context = trace.set_span_in_context(span, context) return modified_context - def inject(self, carrier, context=None, setter=default_setter): - # type: (CarrierT, Optional[Context], Setter[CarrierT]) -> None - scopes = get_value(SENTRY_SCOPES_KEY, context) + def inject( + self, + carrier: CarrierT, + context: Optional[Context] = None, + setter: Setter[CarrierT] = default_setter, + ) -> None: + scopes = validate_scopes(get_value(SENTRY_SCOPES_KEY, context)) if not scopes: return - scopes = cast("tuple[scope.PotelScope, scope.PotelScope]", scopes) (current_scope, _) = scopes span = current_scope.span @@ -114,6 +121,5 @@ def inject(self, carrier, context=None, setter=default_setter): setter.set(carrier, key, value) @property - def fields(self): - # type: () -> Set[str] + def fields(self) -> Set[str]: return {SENTRY_TRACE_HEADER_NAME, BAGGAGE_HEADER_NAME} diff --git a/sentry_sdk/opentelemetry/sampler.py b/sentry_sdk/opentelemetry/sampler.py index ab3defe3de..878b856f5a 100644 --- a/sentry_sdk/opentelemetry/sampler.py +++ b/sentry_sdk/opentelemetry/sampler.py @@ -1,5 +1,5 @@ +from __future__ import annotations from decimal import Decimal -from typing import cast from opentelemetry import trace from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult, Decision @@ -21,15 +21,16 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Optional, Sequence, Union + from typing import Any, Optional, Sequence from opentelemetry.context import Context from opentelemetry.trace import Link, SpanKind from 
opentelemetry.trace.span import SpanContext from opentelemetry.util.types import Attributes -def get_parent_sampled(parent_context, trace_id): - # type: (Optional[SpanContext], int) -> Optional[bool] +def get_parent_sampled( + parent_context: Optional[SpanContext], trace_id: int +) -> Optional[bool]: if parent_context is None: return None @@ -54,8 +55,9 @@ def get_parent_sampled(parent_context, trace_id): return None -def get_parent_sample_rate(parent_context, trace_id): - # type: (Optional[SpanContext], int) -> Optional[float] +def get_parent_sample_rate( + parent_context: Optional[SpanContext], trace_id: int +) -> Optional[float]: if parent_context is None: return None @@ -74,8 +76,9 @@ def get_parent_sample_rate(parent_context, trace_id): return None -def get_parent_sample_rand(parent_context, trace_id): - # type: (Optional[SpanContext], int) -> Optional[Decimal] +def get_parent_sample_rand( + parent_context: Optional[SpanContext], trace_id: int +) -> Optional[Decimal]: if parent_context is None: return None @@ -91,8 +94,12 @@ def get_parent_sample_rand(parent_context, trace_id): return None -def dropped_result(span_context, attributes, sample_rate=None, sample_rand=None): - # type: (SpanContext, Attributes, Optional[float], Optional[Decimal]) -> SamplingResult +def dropped_result( + span_context: SpanContext, + attributes: Attributes, + sample_rate: Optional[float] = None, + sample_rand: Optional[Decimal] = None, +) -> SamplingResult: """ React to a span getting unsampled and return a DROP SamplingResult. @@ -129,8 +136,12 @@ def dropped_result(span_context, attributes, sample_rate=None, sample_rand=None) ) -def sampled_result(span_context, attributes, sample_rate=None, sample_rand=None): - # type: (SpanContext, Attributes, Optional[float], Optional[Decimal]) -> SamplingResult +def sampled_result( + span_context: SpanContext, + attributes: Attributes, + sample_rate: Optional[float] = None, + sample_rand: Optional[Decimal] = None, +) -> SamplingResult: """ React to a span being sampled and return a sampled SamplingResult. @@ -151,8 +162,12 @@ def sampled_result(span_context, attributes, sample_rate=None, sample_rand=None) ) -def _update_trace_state(span_context, sampled, sample_rate=None, sample_rand=None): - # type: (SpanContext, bool, Optional[float], Optional[Decimal]) -> TraceState +def _update_trace_state( + span_context: SpanContext, + sampled: bool, + sample_rate: Optional[float] = None, + sample_rand: Optional[Decimal] = None, +) -> TraceState: trace_state = span_context.trace_state sampled = "true" if sampled else "false" @@ -175,15 +190,14 @@ def _update_trace_state(span_context, sampled, sample_rate=None, sample_rand=Non class SentrySampler(Sampler): def should_sample( self, - parent_context, # type: Optional[Context] - trace_id, # type: int - name, # type: str - kind=None, # type: Optional[SpanKind] - attributes=None, # type: Attributes - links=None, # type: Optional[Sequence[Link]] - trace_state=None, # type: Optional[TraceState] - ): - # type: (...) 
-> SamplingResult + parent_context: Optional[Context], + trace_id: int, + name: str, + kind: Optional[SpanKind] = None, + attributes: Attributes = None, + links: Optional[Sequence[Link]] = None, + trace_state: Optional[TraceState] = None, + ) -> SamplingResult: client = sentry_sdk.get_client() parent_span_context = trace.get_current_span(parent_context).get_span_context() @@ -209,13 +223,12 @@ def should_sample( sample_rand = parent_sample_rand else: # We are the head SDK and we need to generate a new sample_rand - sample_rand = cast(Decimal, _generate_sample_rand(str(trace_id), (0, 1))) + sample_rand = _generate_sample_rand(str(trace_id), (0, 1)) # Explicit sampled value provided at start_span - custom_sampled = cast( - "Optional[bool]", attributes.get(SentrySpanAttribute.CUSTOM_SAMPLED) - ) - if custom_sampled is not None: + custom_sampled = attributes.get(SentrySpanAttribute.CUSTOM_SAMPLED) + + if custom_sampled is not None and isinstance(custom_sampled, bool): if is_root_span: sample_rate = float(custom_sampled) if sample_rate > 0: @@ -262,7 +275,8 @@ def should_sample( sample_rate_to_propagate = sample_rate # If the sample rate is invalid, drop the span - if not is_valid_sample_rate(sample_rate, source=self.__class__.__name__): + sample_rate = is_valid_sample_rate(sample_rate, source=self.__class__.__name__) + if sample_rate is None: logger.warning( f"[Tracing.Sampler] Discarding {name} because of invalid sample rate." ) @@ -275,7 +289,6 @@ def should_sample( sample_rate_to_propagate = sample_rate # Compare sample_rand to sample_rate to make the final sampling decision - sample_rate = float(cast("Union[bool, float, int]", sample_rate)) sampled = sample_rand < Decimal.from_float(sample_rate) if sampled: @@ -307,9 +320,13 @@ def get_description(self) -> str: return self.__class__.__name__ -def create_sampling_context(name, attributes, parent_span_context, trace_id): - # type: (str, Attributes, Optional[SpanContext], int) -> dict[str, Any] - sampling_context = { +def create_sampling_context( + name: str, + attributes: Attributes, + parent_span_context: Optional[SpanContext], + trace_id: int, +) -> dict[str, Any]: + sampling_context: dict[str, Any] = { "transaction_context": { "name": name, "op": attributes.get(SentrySpanAttribute.OP) if attributes else None, @@ -318,7 +335,7 @@ def create_sampling_context(name, attributes, parent_span_context, trace_id): ), }, "parent_sampled": get_parent_sampled(parent_span_context, trace_id), - } # type: dict[str, Any] + } if attributes is not None: sampling_context.update(attributes) diff --git a/sentry_sdk/opentelemetry/scope.py b/sentry_sdk/opentelemetry/scope.py index 4db5e288e3..ec398093c7 100644 --- a/sentry_sdk/opentelemetry/scope.py +++ b/sentry_sdk/opentelemetry/scope.py @@ -1,4 +1,4 @@ -from typing import cast +from __future__ import annotations from contextlib import contextmanager import warnings @@ -24,9 +24,6 @@ SENTRY_USE_ISOLATION_SCOPE_KEY, TRACESTATE_SAMPLED_KEY, ) -from sentry_sdk.opentelemetry.contextvars_context import ( - SentryContextVarsRuntimeContext, -) from sentry_sdk.opentelemetry.utils import trace_state_from_baggage from sentry_sdk.scope import Scope, ScopeType from sentry_sdk.tracing import Span @@ -38,26 +35,21 @@ class PotelScope(Scope): @classmethod - def _get_scopes(cls): - # type: () -> Optional[Tuple[PotelScope, PotelScope]] + def _get_scopes(cls) -> Optional[Tuple[PotelScope, PotelScope]]: """ Returns the current scopes tuple on the otel context. Internal use only. 
""" - return cast( - "Optional[Tuple[PotelScope, PotelScope]]", get_value(SENTRY_SCOPES_KEY) - ) + return validate_scopes(get_value(SENTRY_SCOPES_KEY)) @classmethod - def get_current_scope(cls): - # type: () -> PotelScope + def get_current_scope(cls) -> PotelScope: """ Returns the current scope. """ return cls._get_current_scope() or _INITIAL_CURRENT_SCOPE @classmethod - def _get_current_scope(cls): - # type: () -> Optional[PotelScope] + def _get_current_scope(cls) -> Optional[PotelScope]: """ Returns the current scope without creating a new one. Internal use only. """ @@ -65,16 +57,14 @@ def _get_current_scope(cls): return scopes[0] if scopes else None @classmethod - def get_isolation_scope(cls): - # type: () -> PotelScope + def get_isolation_scope(cls) -> PotelScope: """ Returns the isolation scope. """ return cls._get_isolation_scope() or _INITIAL_ISOLATION_SCOPE @classmethod - def _get_isolation_scope(cls): - # type: () -> Optional[PotelScope] + def _get_isolation_scope(cls) -> Optional[PotelScope]: """ Returns the isolation scope without creating a new one. Internal use only. """ @@ -82,8 +72,9 @@ def _get_isolation_scope(cls): return scopes[1] if scopes else None @contextmanager - def continue_trace(self, environ_or_headers): - # type: (Dict[str, Any]) -> Generator[None, None, None] + def continue_trace( + self, environ_or_headers: Dict[str, Any] + ) -> Generator[None, None, None]: """ Sets the propagation context from environment or headers to continue an incoming trace. Any span started within this context manager will use the same trace_id, parent_span_id @@ -98,8 +89,7 @@ def continue_trace(self, environ_or_headers): with use_span(NonRecordingSpan(span_context)): yield - def _incoming_otel_span_context(self): - # type: () -> Optional[SpanContext] + def _incoming_otel_span_context(self) -> Optional[SpanContext]: if self._propagation_context is None: return None # If sentry-trace extraction didn't have a parent_span_id, we don't have an upstream header @@ -132,8 +122,7 @@ def _incoming_otel_span_context(self): return span_context - def start_transaction(self, **kwargs): - # type: (Any) -> Span + def start_transaction(self, **kwargs: Any) -> Span: """ .. deprecated:: 3.0.0 This function is deprecated and will be removed in a future release. 
@@ -146,8 +135,7 @@ def start_transaction(self, **kwargs): ) return self.start_span(**kwargs) - def start_span(self, **kwargs): - # type: (Any) -> Span + def start_span(self, **kwargs: Any) -> Span: return Span(**kwargs) @@ -155,8 +143,7 @@ def start_span(self, **kwargs): _INITIAL_ISOLATION_SCOPE = PotelScope(ty=ScopeType.ISOLATION) -def setup_initial_scopes(): - # type: () -> None +def setup_initial_scopes() -> None: global _INITIAL_CURRENT_SCOPE, _INITIAL_ISOLATION_SCOPE _INITIAL_CURRENT_SCOPE = PotelScope(ty=ScopeType.CURRENT) _INITIAL_ISOLATION_SCOPE = PotelScope(ty=ScopeType.ISOLATION) @@ -165,17 +152,18 @@ def setup_initial_scopes(): attach(set_value(SENTRY_SCOPES_KEY, scopes)) -def setup_scope_context_management(): - # type: () -> None +def setup_scope_context_management() -> None: import opentelemetry.context + from sentry_sdk.opentelemetry.contextvars_context import ( + SentryContextVarsRuntimeContext, + ) opentelemetry.context._RUNTIME_CONTEXT = SentryContextVarsRuntimeContext() setup_initial_scopes() @contextmanager -def isolation_scope(): - # type: () -> Generator[PotelScope, None, None] +def isolation_scope() -> Generator[PotelScope, None, None]: context = set_value(SENTRY_FORK_ISOLATION_SCOPE_KEY, True) token = attach(context) try: @@ -185,8 +173,7 @@ def isolation_scope(): @contextmanager -def new_scope(): - # type: () -> Generator[PotelScope, None, None] +def new_scope() -> Generator[PotelScope, None, None]: token = attach(get_current()) try: yield PotelScope.get_current_scope() @@ -195,8 +182,7 @@ def new_scope(): @contextmanager -def use_scope(scope): - # type: (PotelScope) -> Generator[PotelScope, None, None] +def use_scope(scope: PotelScope) -> Generator[PotelScope, None, None]: context = set_value(SENTRY_USE_CURRENT_SCOPE_KEY, scope) token = attach(context) @@ -207,8 +193,9 @@ def use_scope(scope): @contextmanager -def use_isolation_scope(isolation_scope): - # type: (PotelScope) -> Generator[PotelScope, None, None] +def use_isolation_scope( + isolation_scope: PotelScope, +) -> Generator[PotelScope, None, None]: context = set_value(SENTRY_USE_ISOLATION_SCOPE_KEY, isolation_scope) token = attach(context) @@ -216,3 +203,15 @@ def use_isolation_scope(isolation_scope): yield isolation_scope finally: detach(token) + + +def validate_scopes(scopes: Any) -> Optional[Tuple[PotelScope, PotelScope]]: + if ( + isinstance(scopes, tuple) + and len(scopes) == 2 + and isinstance(scopes[0], PotelScope) + and isinstance(scopes[1], PotelScope) + ): + return scopes + else: + return None diff --git a/sentry_sdk/opentelemetry/span_processor.py b/sentry_sdk/opentelemetry/span_processor.py index a148fb0f62..f35af99920 100644 --- a/sentry_sdk/opentelemetry/span_processor.py +++ b/sentry_sdk/opentelemetry/span_processor.py @@ -1,5 +1,5 @@ +from __future__ import annotations from collections import deque, defaultdict -from typing import cast from opentelemetry.trace import ( format_trace_id, @@ -52,30 +52,24 @@ class SentrySpanProcessor(SpanProcessor): Converts OTel spans into Sentry spans so they can be sent to the Sentry backend. 
""" - def __new__(cls): - # type: () -> SentrySpanProcessor + def __new__(cls) -> SentrySpanProcessor: if not hasattr(cls, "instance"): cls.instance = super().__new__(cls) return cls.instance - def __init__(self): - # type: () -> None - self._children_spans = defaultdict( - list - ) # type: DefaultDict[int, List[ReadableSpan]] - self._dropped_spans = defaultdict(lambda: 0) # type: DefaultDict[int, int] + def __init__(self) -> None: + self._children_spans: DefaultDict[int, List[ReadableSpan]] = defaultdict(list) + self._dropped_spans: DefaultDict[int, int] = defaultdict(lambda: 0) - def on_start(self, span, parent_context=None): - # type: (Span, Optional[Context]) -> None + def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None: if is_sentry_span(span): return self._add_root_span(span, get_current_span(parent_context)) self._start_profile(span) - def on_end(self, span): - # type: (ReadableSpan) -> None + def on_end(self, span: ReadableSpan) -> None: if is_sentry_span(span): return @@ -88,18 +82,15 @@ def on_end(self, span): self._append_child_span(span) # TODO-neel-potel not sure we need a clear like JS - def shutdown(self): - # type: () -> None + def shutdown(self) -> None: pass # TODO-neel-potel change default? this is 30 sec # TODO-neel-potel call this in client.flush - def force_flush(self, timeout_millis=30000): - # type: (int) -> bool + def force_flush(self, timeout_millis: int = 30000) -> bool: return True - def _add_root_span(self, span, parent_span): - # type: (Span, AbstractSpan) -> None + def _add_root_span(self, span: Span, parent_span: AbstractSpan) -> None: """ This is required to make Span.root_span work since we can't traverse back to the root purely with otel efficiently. @@ -112,8 +103,7 @@ def _add_root_span(self, span, parent_span): # root span points to itself set_sentry_meta(span, "root_span", span) - def _start_profile(self, span): - # type: (Span) -> None + def _start_profile(self, span: Span) -> None: try_autostart_continuous_profiler() profiler_id = get_profiler_id() @@ -148,14 +138,12 @@ def _start_profile(self, span): span.set_attribute(SPANDATA.PROFILER_ID, profiler_id) set_sentry_meta(span, "continuous_profile", continuous_profile) - def _stop_profile(self, span): - # type: (ReadableSpan) -> None + def _stop_profile(self, span: ReadableSpan) -> None: continuous_profiler = get_sentry_meta(span, "continuous_profile") if continuous_profiler: continuous_profiler.stop() - def _flush_root_span(self, span): - # type: (ReadableSpan) -> None + def _flush_root_span(self, span: ReadableSpan) -> None: transaction_event = self._root_span_to_transaction_event(span) if not transaction_event: return @@ -176,8 +164,7 @@ def _flush_root_span(self, span): sentry_sdk.capture_event(transaction_event) self._cleanup_references([span] + collected_spans) - def _append_child_span(self, span): - # type: (ReadableSpan) -> None + def _append_child_span(self, span: ReadableSpan) -> None: if not span.parent: return @@ -192,14 +179,13 @@ def _append_child_span(self, span): else: self._dropped_spans[span.parent.span_id] += 1 - def _collect_children(self, span): - # type: (ReadableSpan) -> tuple[List[ReadableSpan], int] + def _collect_children(self, span: ReadableSpan) -> tuple[List[ReadableSpan], int]: if not span.context: return [], 0 children = [] dropped_spans = 0 - bfs_queue = deque() # type: Deque[int] + bfs_queue: Deque[int] = deque() bfs_queue.append(span.context.span_id) while bfs_queue: @@ -215,8 +201,7 @@ def _collect_children(self, span): # we construct the 
event from scratch here # and not use the current Transaction class for easier refactoring - def _root_span_to_transaction_event(self, span): - # type: (ReadableSpan) -> Optional[Event] + def _root_span_to_transaction_event(self, span: ReadableSpan) -> Optional[Event]: if not span.context: return None @@ -250,23 +235,20 @@ def _root_span_to_transaction_event(self, span): } ) - profile = cast("Optional[Profile]", get_sentry_meta(span, "profile")) - if profile: + profile = get_sentry_meta(span, "profile") + if profile is not None and isinstance(profile, Profile): profile.__exit__(None, None, None) if profile.valid(): event["profile"] = profile return event - def _span_to_json(self, span): - # type: (ReadableSpan) -> Optional[dict[str, Any]] + def _span_to_json(self, span: ReadableSpan) -> Optional[dict[str, Any]]: if not span.context: return None - # This is a safe cast because dict[str, Any] is a superset of Event - span_json = cast( - "dict[str, Any]", self._common_span_transaction_attributes_as_json(span) - ) + # need to ignore the type here due to TypedDict nonsense + span_json: Optional[dict[str, Any]] = self._common_span_transaction_attributes_as_json(span) # type: ignore if span_json is None: return None @@ -299,15 +281,16 @@ def _span_to_json(self, span): return span_json - def _common_span_transaction_attributes_as_json(self, span): - # type: (ReadableSpan) -> Optional[Event] + def _common_span_transaction_attributes_as_json( + self, span: ReadableSpan + ) -> Optional[Event]: if not span.start_time or not span.end_time: return None - common_json = { + common_json: Event = { "start_timestamp": convert_from_otel_timestamp(span.start_time), "timestamp": convert_from_otel_timestamp(span.end_time), - } # type: Event + } tags = extract_span_attributes(span, SentrySpanAttribute.TAG) if tags: @@ -315,13 +298,11 @@ def _common_span_transaction_attributes_as_json(self, span): return common_json - def _cleanup_references(self, spans): - # type: (List[ReadableSpan]) -> None + def _cleanup_references(self, spans: List[ReadableSpan]) -> None: for span in spans: delete_sentry_meta(span) - def _log_debug_info(self): - # type: () -> None + def _log_debug_info(self) -> None: import pprint pprint.pprint( diff --git a/sentry_sdk/opentelemetry/tracing.py b/sentry_sdk/opentelemetry/tracing.py index 5002f71c50..a736a4a477 100644 --- a/sentry_sdk/opentelemetry/tracing.py +++ b/sentry_sdk/opentelemetry/tracing.py @@ -1,3 +1,4 @@ +from __future__ import annotations from opentelemetry import trace from opentelemetry.propagate import set_global_textmap from opentelemetry.sdk.trace import TracerProvider, Span, ReadableSpan @@ -10,16 +11,14 @@ from sentry_sdk.utils import logger -def patch_readable_span(): - # type: () -> None +def patch_readable_span() -> None: """ We need to pass through sentry specific metadata/objects from Span to ReadableSpan to work with them consistently in the SpanProcessor. """ old_readable_span = Span._readable_span - def sentry_patched_readable_span(self): - # type: (Span) -> ReadableSpan + def sentry_patched_readable_span(self: Span) -> ReadableSpan: readable_span = old_readable_span(self) readable_span._sentry_meta = getattr(self, "_sentry_meta", {}) # type: ignore[attr-defined] return readable_span @@ -27,8 +26,7 @@ def sentry_patched_readable_span(self): Span._readable_span = sentry_patched_readable_span # type: ignore[method-assign] -def setup_sentry_tracing(): - # type: () -> None +def setup_sentry_tracing() -> None: # TracerProvider can only be set once. 
If we're the first ones setting it, # there's no issue. If it already exists, we need to patch it. from opentelemetry.trace import _TRACER_PROVIDER diff --git a/sentry_sdk/opentelemetry/utils.py b/sentry_sdk/opentelemetry/utils.py index abee007a6b..114b1dfd36 100644 --- a/sentry_sdk/opentelemetry/utils.py +++ b/sentry_sdk/opentelemetry/utils.py @@ -1,5 +1,5 @@ +from __future__ import annotations import re -from typing import cast from datetime import datetime, timezone from urllib3.util import parse_url as urlparse @@ -30,9 +30,11 @@ from sentry_sdk._types import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Optional, Mapping, Sequence, Union + from typing import Any, Optional, Mapping, Sequence, Union, Type, TypeVar from sentry_sdk._types import OtelExtractedSpanData + T = TypeVar("T") + GRPC_ERROR_MAP = { "1": SPANSTATUS.CANCELLED, @@ -54,8 +56,7 @@ } -def is_sentry_span(span): - # type: (ReadableSpan) -> bool +def is_sentry_span(span: ReadableSpan) -> bool: """ Break infinite loop: HTTP requests to Sentry are caught by OTel and send again to Sentry. @@ -65,10 +66,8 @@ def is_sentry_span(span): if not span.attributes: return False - span_url = span.attributes.get(SpanAttributes.HTTP_URL, None) - span_url = cast("Optional[str]", span_url) - - if not span_url: + span_url = get_typed_attribute(span.attributes, SpanAttributes.HTTP_URL, str) + if span_url is None: return False dsn_url = None @@ -89,32 +88,30 @@ def is_sentry_span(span): return False -def convert_from_otel_timestamp(time): - # type: (int) -> datetime +def convert_from_otel_timestamp(time: int) -> datetime: """Convert an OTel nanosecond-level timestamp to a datetime.""" return datetime.fromtimestamp(time / 1e9, timezone.utc) -def convert_to_otel_timestamp(time): - # type: (Union[datetime, float]) -> int +def convert_to_otel_timestamp(time: Union[datetime, float]) -> int: """Convert a datetime to an OTel timestamp (with nanosecond precision).""" if isinstance(time, datetime): return int(time.timestamp() * 1e9) return int(time * 1e9) -def extract_transaction_name_source(span): - # type: (ReadableSpan) -> tuple[Optional[str], Optional[str]] +def extract_transaction_name_source( + span: ReadableSpan, +) -> tuple[Optional[str], Optional[str]]: if not span.attributes: return (None, None) return ( - cast("Optional[str]", span.attributes.get(SentrySpanAttribute.NAME)), - cast("Optional[str]", span.attributes.get(SentrySpanAttribute.SOURCE)), + get_typed_attribute(span.attributes, SentrySpanAttribute.NAME, str), + get_typed_attribute(span.attributes, SentrySpanAttribute.SOURCE, str), ) -def extract_span_data(span): - # type: (ReadableSpan) -> OtelExtractedSpanData +def extract_span_data(span: ReadableSpan) -> OtelExtractedSpanData: op = span.name description = span.name status, http_status = extract_span_status(span) @@ -122,15 +119,15 @@ def extract_span_data(span): if span.attributes is None: return (op, description, status, http_status, origin) - attribute_op = cast("Optional[str]", span.attributes.get(SentrySpanAttribute.OP)) + attribute_op = get_typed_attribute(span.attributes, SentrySpanAttribute.OP, str) op = attribute_op or op - description = cast( - "str", span.attributes.get(SentrySpanAttribute.DESCRIPTION) or description + description = ( + get_typed_attribute(span.attributes, SentrySpanAttribute.DESCRIPTION, str) + or description ) - origin = cast("Optional[str]", span.attributes.get(SentrySpanAttribute.ORIGIN)) + origin = get_typed_attribute(span.attributes, SentrySpanAttribute.ORIGIN, str) - http_method = 
span.attributes.get(SpanAttributes.HTTP_METHOD) - http_method = cast("Optional[str]", http_method) + http_method = get_typed_attribute(span.attributes, SpanAttributes.HTTP_METHOD, str) if http_method: return span_data_for_http_method(span) @@ -165,11 +162,10 @@ def extract_span_data(span): return (op, description, status, http_status, origin) -def span_data_for_http_method(span): - # type: (ReadableSpan) -> OtelExtractedSpanData +def span_data_for_http_method(span: ReadableSpan) -> OtelExtractedSpanData: span_attributes = span.attributes or {} - op = cast("Optional[str]", span_attributes.get(SentrySpanAttribute.OP)) + op = get_typed_attribute(span_attributes, SentrySpanAttribute.OP, str) if op is None: op = "http" @@ -184,10 +180,9 @@ def span_data_for_http_method(span): peer_name = span_attributes.get(SpanAttributes.NET_PEER_NAME) # TODO-neel-potel remove description completely - description = span_attributes.get( - SentrySpanAttribute.DESCRIPTION - ) or span_attributes.get(SentrySpanAttribute.NAME) - description = cast("Optional[str]", description) + description = get_typed_attribute( + span_attributes, SentrySpanAttribute.DESCRIPTION, str + ) or get_typed_attribute(span_attributes, SentrySpanAttribute.NAME, str) if description is None: description = f"{http_method}" @@ -199,7 +194,7 @@ def span_data_for_http_method(span): description = f"{http_method} {peer_name}" else: url = span_attributes.get(SpanAttributes.HTTP_URL) - url = cast("Optional[str]", url) + url = get_typed_attribute(span_attributes, SpanAttributes.HTTP_URL, str) if url: parsed_url = urlparse(url) @@ -210,28 +205,24 @@ def span_data_for_http_method(span): status, http_status = extract_span_status(span) - origin = cast("Optional[str]", span_attributes.get(SentrySpanAttribute.ORIGIN)) + origin = get_typed_attribute(span_attributes, SentrySpanAttribute.ORIGIN, str) return (op, description, status, http_status, origin) -def span_data_for_db_query(span): - # type: (ReadableSpan) -> OtelExtractedSpanData +def span_data_for_db_query(span: ReadableSpan) -> OtelExtractedSpanData: span_attributes = span.attributes or {} - op = cast("str", span_attributes.get(SentrySpanAttribute.OP, OP.DB)) - - statement = span_attributes.get(SpanAttributes.DB_STATEMENT, None) - statement = cast("Optional[str]", statement) + op = get_typed_attribute(span_attributes, SentrySpanAttribute.OP, str) or OP.DB + statement = get_typed_attribute(span_attributes, SpanAttributes.DB_STATEMENT, str) description = statement or span.name - origin = cast("Optional[str]", span_attributes.get(SentrySpanAttribute.ORIGIN)) + origin = get_typed_attribute(span_attributes, SentrySpanAttribute.ORIGIN, str) return (op, description, None, None, origin) -def extract_span_status(span): - # type: (ReadableSpan) -> tuple[Optional[str], Optional[int]] +def extract_span_status(span: ReadableSpan) -> tuple[Optional[str], Optional[int]]: span_attributes = span.attributes or {} status = span.status or None @@ -266,8 +257,19 @@ def extract_span_status(span): return (SPANSTATUS.UNKNOWN_ERROR, None) -def infer_status_from_attributes(span_attributes): - # type: (Mapping[str, str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]]) -> tuple[Optional[str], Optional[int]] +def infer_status_from_attributes( + span_attributes: Mapping[ + str, + str + | bool + | int + | float + | Sequence[str] + | Sequence[bool] + | Sequence[int] + | Sequence[float], + ], +) -> tuple[Optional[str], Optional[int]]: http_status = get_http_status_code(span_attributes) if 
http_status: @@ -280,10 +282,23 @@ def infer_status_from_attributes(span_attributes): return (None, None) -def get_http_status_code(span_attributes): - # type: (Mapping[str, str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]]) -> Optional[int] +def get_http_status_code( + span_attributes: Mapping[ + str, + str + | bool + | int + | float + | Sequence[str] + | Sequence[bool] + | Sequence[int] + | Sequence[float], + ], +) -> Optional[int]: try: - http_status = span_attributes.get(SpanAttributes.HTTP_RESPONSE_STATUS_CODE) + http_status = get_typed_attribute( + span_attributes, SpanAttributes.HTTP_RESPONSE_STATUS_CODE, int + ) except AttributeError: # HTTP_RESPONSE_STATUS_CODE was added in 1.21, so if we're on an older # OTel version SpanAttributes.HTTP_RESPONSE_STATUS_CODE will throw an @@ -292,19 +307,18 @@ def get_http_status_code(span_attributes): if http_status is None: # Fall back to the deprecated attribute - http_status = span_attributes.get(SpanAttributes.HTTP_STATUS_CODE) - - http_status = cast("Optional[int]", http_status) + http_status = get_typed_attribute( + span_attributes, SpanAttributes.HTTP_STATUS_CODE, int + ) return http_status -def extract_span_attributes(span, namespace): - # type: (ReadableSpan, str) -> dict[str, Any] +def extract_span_attributes(span: ReadableSpan, namespace: str) -> dict[str, Any]: """ Extract Sentry-specific span attributes and make them look the way Sentry expects. """ - extracted_attrs = {} # type: dict[str, Any] + extracted_attrs: dict[str, Any] = {} for attr, value in (span.attributes or {}).items(): if attr.startswith(namespace): @@ -314,8 +328,9 @@ def extract_span_attributes(span, namespace): return extracted_attrs -def get_trace_context(span, span_data=None): - # type: (ReadableSpan, Optional[OtelExtractedSpanData]) -> dict[str, Any] +def get_trace_context( + span: ReadableSpan, span_data: Optional[OtelExtractedSpanData] = None +) -> dict[str, Any]: if not span.context: return {} @@ -328,13 +343,13 @@ def get_trace_context(span, span_data=None): (op, _, status, _, origin) = span_data - trace_context = { + trace_context: dict[str, Any] = { "trace_id": trace_id, "span_id": span_id, "parent_span_id": parent_span_id, "op": op, "origin": origin or DEFAULT_SPAN_ORIGIN, - } # type: dict[str, Any] + } if status: trace_context["status"] = status @@ -350,8 +365,7 @@ def get_trace_context(span, span_data=None): return trace_context -def trace_state_from_baggage(baggage): - # type: (Baggage) -> TraceState +def trace_state_from_baggage(baggage: Baggage) -> TraceState: items = [] for k, v in baggage.sentry_items.items(): key = Baggage.SENTRY_PREFIX + quote(k) @@ -360,13 +374,11 @@ def trace_state_from_baggage(baggage): return TraceState(items) -def baggage_from_trace_state(trace_state): - # type: (TraceState) -> Baggage +def baggage_from_trace_state(trace_state: TraceState) -> Baggage: return Baggage(dsc_from_trace_state(trace_state)) -def serialize_trace_state(trace_state): - # type: (TraceState) -> str +def serialize_trace_state(trace_state: TraceState) -> str: sentry_items = [] for k, v in trace_state.items(): if Baggage.SENTRY_PREFIX_REGEX.match(k): @@ -374,8 +386,7 @@ def serialize_trace_state(trace_state): return ",".join(key + "=" + value for key, value in sentry_items) -def dsc_from_trace_state(trace_state): - # type: (TraceState) -> dict[str, str] +def dsc_from_trace_state(trace_state: TraceState) -> dict[str, str]: dsc = {} for k, v in trace_state.items(): if Baggage.SENTRY_PREFIX_REGEX.match(k): @@ -384,16 
+395,14 @@ def dsc_from_trace_state(trace_state): return dsc -def has_incoming_trace(trace_state): - # type: (TraceState) -> bool +def has_incoming_trace(trace_state: TraceState) -> bool: """ The existence of a sentry-trace_id in the baggage implies we continued an upstream trace. """ return (Baggage.SENTRY_PREFIX + "trace_id") in trace_state -def get_trace_state(span): - # type: (Union[AbstractSpan, ReadableSpan]) -> TraceState +def get_trace_state(span: Union[AbstractSpan, ReadableSpan]) -> TraceState: """ Get the existing trace_state with sentry items or populate it if we are the head SDK. @@ -451,34 +460,45 @@ def get_trace_state(span): return trace_state -def get_sentry_meta(span, key): - # type: (Union[AbstractSpan, ReadableSpan], str) -> Any +def get_sentry_meta(span: Union[AbstractSpan, ReadableSpan], key: str) -> Any: sentry_meta = getattr(span, "_sentry_meta", None) return sentry_meta.get(key) if sentry_meta else None -def set_sentry_meta(span, key, value): - # type: (Union[AbstractSpan, ReadableSpan], str, Any) -> None +def set_sentry_meta( + span: Union[AbstractSpan, ReadableSpan], key: str, value: Any +) -> None: sentry_meta = getattr(span, "_sentry_meta", {}) sentry_meta[key] = value span._sentry_meta = sentry_meta # type: ignore[union-attr] -def delete_sentry_meta(span): - # type: (Union[AbstractSpan, ReadableSpan]) -> None +def delete_sentry_meta(span: Union[AbstractSpan, ReadableSpan]) -> None: try: del span._sentry_meta # type: ignore[union-attr] except AttributeError: pass -def get_profile_context(span): - # type: (ReadableSpan) -> Optional[dict[str, str]] +def get_profile_context(span: ReadableSpan) -> Optional[dict[str, str]]: if not span.attributes: return None - profiler_id = cast("Optional[str]", span.attributes.get(SPANDATA.PROFILER_ID)) + profiler_id = get_typed_attribute(span.attributes, SPANDATA.PROFILER_ID, str) if profiler_id is None: return None return {"profiler_id": profiler_id} + + +def get_typed_attribute( + attributes: Mapping[str, Any], key: str, type: Type[T] +) -> Optional[T]: + """ + helper method to coerce types of attribute values + """ + value = attributes.get(key) + if value is not None and isinstance(value, type): + return value + else: + return None diff --git a/sentry_sdk/profiler/continuous_profiler.py b/sentry_sdk/profiler/continuous_profiler.py index 44e54e461c..27e4a42999 100644 --- a/sentry_sdk/profiler/continuous_profiler.py +++ b/sentry_sdk/profiler/continuous_profiler.py @@ -1,3 +1,4 @@ +from __future__ import annotations import atexit import os import random @@ -60,18 +61,19 @@ from gevent.monkey import get_original from gevent.threadpool import ThreadPool as _ThreadPool - ThreadPool = _ThreadPool # type: Optional[Type[_ThreadPool]] + ThreadPool: Optional[Type[_ThreadPool]] = _ThreadPool thread_sleep = get_original("time", "sleep") except ImportError: thread_sleep = time.sleep ThreadPool = None -_scheduler = None # type: Optional[ContinuousScheduler] +_scheduler: Optional[ContinuousScheduler] = None -def setup_continuous_profiler(options, sdk_info, capture_func): - # type: (Dict[str, Any], SDKInfo, Callable[[Envelope], None]) -> bool +def setup_continuous_profiler( + options: Dict[str, Any], sdk_info: SDKInfo, capture_func: Callable[[Envelope], None] +) -> bool: global _scheduler if _scheduler is not None: @@ -115,8 +117,7 @@ def setup_continuous_profiler(options, sdk_info, capture_func): return True -def try_autostart_continuous_profiler(): - # type: () -> None +def try_autostart_continuous_profiler() -> None: # TODO: deprecate this 
as it'll be replaced by the auto lifecycle option @@ -129,47 +130,43 @@ def try_autostart_continuous_profiler(): _scheduler.manual_start() -def try_profile_lifecycle_trace_start(): - # type: () -> Union[ContinuousProfile, None] +def try_profile_lifecycle_trace_start() -> Union[ContinuousProfile, None]: if _scheduler is None: return None return _scheduler.auto_start() -def start_profiler(): - # type: () -> None +def start_profiler() -> None: if _scheduler is None: return _scheduler.manual_start() -def stop_profiler(): - # type: () -> None +def stop_profiler() -> None: if _scheduler is None: return _scheduler.manual_stop() -def teardown_continuous_profiler(): - # type: () -> None +def teardown_continuous_profiler() -> None: stop_profiler() global _scheduler _scheduler = None -def get_profiler_id(): - # type: () -> Union[str, None] +def get_profiler_id() -> Union[str, None]: if _scheduler is None: return None return _scheduler.profiler_id -def determine_profile_session_sampling_decision(sample_rate): - # type: (Union[float, None]) -> bool +def determine_profile_session_sampling_decision( + sample_rate: Union[float, None], +) -> bool: # `None` is treated as `0.0` if not sample_rate: @@ -181,16 +178,20 @@ def determine_profile_session_sampling_decision(sample_rate): class ContinuousProfile: active: bool = True - def stop(self): - # type: () -> None + def stop(self) -> None: self.active = False class ContinuousScheduler: - mode = "unknown" # type: ContinuousProfilerMode - - def __init__(self, frequency, options, sdk_info, capture_func): - # type: (int, Dict[str, Any], SDKInfo, Callable[[Envelope], None]) -> None + mode: ContinuousProfilerMode = "unknown" + + def __init__( + self, + frequency: int, + options: Dict[str, Any], + sdk_info: SDKInfo, + capture_func: Callable[[Envelope], None], + ) -> None: self.interval = 1.0 / frequency self.options = options self.sdk_info = sdk_info @@ -203,17 +204,16 @@ def __init__(self, frequency, options, sdk_info, capture_func): ) self.sampler = self.make_sampler() - self.buffer = None # type: Optional[ProfileBuffer] - self.pid = None # type: Optional[int] + self.buffer: Optional[ProfileBuffer] = None + self.pid: Optional[int] = None self.running = False self.soft_shutdown = False - self.new_profiles = deque(maxlen=128) # type: Deque[ContinuousProfile] - self.active_profiles = set() # type: Set[ContinuousProfile] + self.new_profiles: Deque[ContinuousProfile] = deque(maxlen=128) + self.active_profiles: Set[ContinuousProfile] = set() - def is_auto_start_enabled(self): - # type: () -> bool + def is_auto_start_enabled(self) -> bool: # Ensure that the scheduler only autostarts once per process. 
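The `get_typed_attribute` helper introduced in `sentry_sdk/opentelemetry/utils.py` above replaces the scattered `typing.cast` calls with a runtime `isinstance` check, so a wrongly-typed attribute yields `None` instead of a value the type checker merely pretends is correct. A minimal standalone sketch of that behavior, using a plain dict in place of real OTel span attributes (the keys are illustrative):

```python
from typing import Any, Mapping, Optional, Type, TypeVar

T = TypeVar("T")


def get_typed_attribute(attributes: Mapping[str, Any], key: str, type: Type[T]) -> Optional[T]:
    """Return the attribute value only if it exists and already has the expected type."""
    value = attributes.get(key)
    if value is not None and isinstance(value, type):
        return value
    return None


attributes = {
    "http.url": "https://example.com/api",
    "http.response.status_code": "200",  # wrong type on purpose: stored as a string
}

assert get_typed_attribute(attributes, "http.url", str) == "https://example.com/api"
assert get_typed_attribute(attributes, "http.response.status_code", int) is None  # type mismatch
assert get_typed_attribute(attributes, "missing", str) is None  # absent key
```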
# This is necessary because many web servers use forks to spawn @@ -229,8 +229,7 @@ def is_auto_start_enabled(self): return experiments.get("continuous_profiling_auto_start") - def auto_start(self): - # type: () -> Union[ContinuousProfile, None] + def auto_start(self) -> Union[ContinuousProfile, None]: if not self.sampled: return None @@ -246,8 +245,7 @@ def auto_start(self): return profile - def manual_start(self): - # type: () -> None + def manual_start(self) -> None: if not self.sampled: return @@ -256,48 +254,40 @@ def manual_start(self): self.ensure_running() - def manual_stop(self): - # type: () -> None + def manual_stop(self) -> None: if self.lifecycle != "manual": return self.teardown() - def ensure_running(self): - # type: () -> None + def ensure_running(self) -> None: raise NotImplementedError - def teardown(self): - # type: () -> None + def teardown(self) -> None: raise NotImplementedError - def pause(self): - # type: () -> None + def pause(self) -> None: raise NotImplementedError - def reset_buffer(self): - # type: () -> None + def reset_buffer(self) -> None: self.buffer = ProfileBuffer( self.options, self.sdk_info, PROFILE_BUFFER_SECONDS, self.capture_func ) @property - def profiler_id(self): - # type: () -> Union[str, None] + def profiler_id(self) -> Union[str, None]: if self.buffer is None: return None return self.buffer.profiler_id - def make_sampler(self): - # type: () -> Callable[..., bool] + def make_sampler(self) -> Callable[..., bool]: cwd = os.getcwd() cache = LRUCache(max_size=256) if self.lifecycle == "trace": - def _sample_stack(*args, **kwargs): - # type: (*Any, **Any) -> bool + def _sample_stack(*args: Any, **kwargs: Any) -> bool: """ Take a sample of the stack on all the threads in the process. This should be called at a regular interval to collect samples. @@ -362,8 +352,7 @@ def _sample_stack(*args, **kwargs): else: - def _sample_stack(*args, **kwargs): - # type: (*Any, **Any) -> bool + def _sample_stack(*args: Any, **kwargs: Any) -> bool: """ Take a sample of the stack on all the threads in the process. This should be called at a regular interval to collect samples. @@ -389,8 +378,7 @@ def _sample_stack(*args, **kwargs): return _sample_stack - def run(self): - # type: () -> None + def run(self) -> None: last = time.perf_counter() while self.running: @@ -427,18 +415,22 @@ class ThreadContinuousScheduler(ContinuousScheduler): the sampler at a regular interval. """ - mode = "thread" # type: ContinuousProfilerMode + mode: ContinuousProfilerMode = "thread" name = "sentry.profiler.ThreadContinuousScheduler" - def __init__(self, frequency, options, sdk_info, capture_func): - # type: (int, Dict[str, Any], SDKInfo, Callable[[Envelope], None]) -> None + def __init__( + self, + frequency: int, + options: Dict[str, Any], + sdk_info: SDKInfo, + capture_func: Callable[[Envelope], None], + ) -> None: super().__init__(frequency, options, sdk_info, capture_func) - self.thread = None # type: Optional[threading.Thread] + self.thread: Optional[threading.Thread] = None self.lock = threading.Lock() - def ensure_running(self): - # type: () -> None + def ensure_running(self) -> None: self.soft_shutdown = False @@ -475,8 +467,7 @@ def ensure_running(self): self.running = False self.thread = None - def teardown(self): - # type: () -> None + def teardown(self) -> None: if self.running: self.running = False @@ -501,21 +492,25 @@ class GeventContinuousScheduler(ContinuousScheduler): results in a sample containing only the sampler's code. 
""" - mode = "gevent" # type: ContinuousProfilerMode + mode: ContinuousProfilerMode = "gevent" - def __init__(self, frequency, options, sdk_info, capture_func): - # type: (int, Dict[str, Any], SDKInfo, Callable[[Envelope], None]) -> None + def __init__( + self, + frequency: int, + options: Dict[str, Any], + sdk_info: SDKInfo, + capture_func: Callable[[Envelope], None], + ) -> None: if ThreadPool is None: raise ValueError("Profiler mode: {} is not available".format(self.mode)) super().__init__(frequency, options, sdk_info, capture_func) - self.thread = None # type: Optional[_ThreadPool] + self.thread: Optional[_ThreadPool] = None self.lock = threading.Lock() - def ensure_running(self): - # type: () -> None + def ensure_running(self) -> None: self.soft_shutdown = False @@ -548,8 +543,7 @@ def ensure_running(self): self.running = False self.thread = None - def teardown(self): - # type: () -> None + def teardown(self) -> None: if self.running: self.running = False @@ -564,8 +558,13 @@ def teardown(self): class ProfileBuffer: - def __init__(self, options, sdk_info, buffer_size, capture_func): - # type: (Dict[str, Any], SDKInfo, int, Callable[[Envelope], None]) -> None + def __init__( + self, + options: Dict[str, Any], + sdk_info: SDKInfo, + buffer_size: int, + capture_func: Callable[[Envelope], None], + ) -> None: self.options = options self.sdk_info = sdk_info self.buffer_size = buffer_size @@ -587,8 +586,7 @@ def __init__(self, options, sdk_info, buffer_size, capture_func): datetime.now(timezone.utc).timestamp() - self.start_monotonic_time ) - def write(self, monotonic_time, sample): - # type: (float, ExtractedSample) -> None + def write(self, monotonic_time: float, sample: ExtractedSample) -> None: if self.should_flush(monotonic_time): self.flush() self.chunk = ProfileChunk() @@ -596,15 +594,13 @@ def write(self, monotonic_time, sample): self.chunk.write(self.start_timestamp + monotonic_time, sample) - def should_flush(self, monotonic_time): - # type: (float) -> bool + def should_flush(self, monotonic_time: float) -> bool: # If the delta between the new monotonic time and the start monotonic time # exceeds the buffer size, it means we should flush the chunk return monotonic_time - self.start_monotonic_time >= self.buffer_size - def flush(self): - # type: () -> None + def flush(self) -> None: chunk = self.chunk.to_json(self.profiler_id, self.options, self.sdk_info) envelope = Envelope() envelope.add_profile_chunk(chunk) @@ -612,18 +608,16 @@ def flush(self): class ProfileChunk: - def __init__(self): - # type: () -> None + def __init__(self) -> None: self.chunk_id = uuid.uuid4().hex - self.indexed_frames = {} # type: Dict[FrameId, int] - self.indexed_stacks = {} # type: Dict[StackId, int] - self.frames = [] # type: List[ProcessedFrame] - self.stacks = [] # type: List[ProcessedStack] - self.samples = [] # type: List[ProcessedSample] + self.indexed_frames: Dict[FrameId, int] = {} + self.indexed_stacks: Dict[StackId, int] = {} + self.frames: List[ProcessedFrame] = [] + self.stacks: List[ProcessedStack] = [] + self.samples: List[ProcessedSample] = [] - def write(self, ts, sample): - # type: (float, ExtractedSample) -> None + def write(self, ts: float, sample: ExtractedSample) -> None: for tid, (stack_id, frame_ids, frames) in sample: try: # Check if the stack is indexed first, this lets us skip @@ -651,8 +645,9 @@ def write(self, ts, sample): # When this happens, we abandon the current sample as it's bad. 
capture_internal_exception(sys.exc_info()) - def to_json(self, profiler_id, options, sdk_info): - # type: (str, Dict[str, Any], SDKInfo) -> Dict[str, Any] + def to_json( + self, profiler_id: str, options: Dict[str, Any], sdk_info: SDKInfo + ) -> Dict[str, Any]: profile = { "frames": self.frames, "stacks": self.stacks, diff --git a/sentry_sdk/profiler/transaction_profiler.py b/sentry_sdk/profiler/transaction_profiler.py index 095ce2f2f9..f60dd95a87 100644 --- a/sentry_sdk/profiler/transaction_profiler.py +++ b/sentry_sdk/profiler/transaction_profiler.py @@ -25,6 +25,7 @@ SOFTWARE. """ +from __future__ import annotations import atexit import os import platform @@ -99,7 +100,7 @@ from gevent.monkey import get_original from gevent.threadpool import ThreadPool as _ThreadPool - ThreadPool = _ThreadPool # type: Optional[Type[_ThreadPool]] + ThreadPool: Optional[Type[_ThreadPool]] = _ThreadPool thread_sleep = get_original("time", "sleep") except ImportError: thread_sleep = time.sleep @@ -107,7 +108,7 @@ ThreadPool = None -_scheduler = None # type: Optional[Scheduler] +_scheduler: Optional[Scheduler] = None # The minimum number of unique samples that must exist in a profile to be @@ -115,8 +116,7 @@ PROFILE_MINIMUM_SAMPLES = 2 -def has_profiling_enabled(options): - # type: (Dict[str, Any]) -> bool +def has_profiling_enabled(options: Dict[str, Any]) -> bool: profiles_sampler = options["profiles_sampler"] if profiles_sampler is not None: return True @@ -128,8 +128,7 @@ def has_profiling_enabled(options): return False -def setup_profiler(options): - # type: (Dict[str, Any]) -> bool +def setup_profiler(options: Dict[str, Any]) -> bool: global _scheduler if _scheduler is not None: @@ -172,8 +171,7 @@ def setup_profiler(options): return True -def teardown_profiler(): - # type: () -> None +def teardown_profiler() -> None: global _scheduler @@ -189,40 +187,38 @@ def teardown_profiler(): class Profile: def __init__( self, - sampled, # type: Optional[bool] - start_ns, # type: int - scheduler=None, # type: Optional[Scheduler] - ): - # type: (...) -> None + sampled: Optional[bool], + start_ns: int, + scheduler: Optional[Scheduler] = None, + ) -> None: self.scheduler = _scheduler if scheduler is None else scheduler - self.event_id = uuid.uuid4().hex # type: str + self.event_id: str = uuid.uuid4().hex - self.sampled = sampled # type: Optional[bool] + self.sampled: Optional[bool] = sampled # Various framework integrations are capable of overwriting the active thread id. # If it is set to `None` at the end of the profile, we fall back to the default. 
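`has_profiling_enabled` above considers the transaction profiler enabled when either a `profiles_sampler` callable or a profiles sample rate is configured. A hedged sketch of what that looks like on the caller's side (placeholder DSN; option names as documented for the Python SDK):

```python
import sentry_sdk

sentry_sdk.init(
    dsn="https://examplePublicKey@o0.ingest.sentry.io/0",  # placeholder DSN
    traces_sample_rate=1.0,     # profiles are only collected for sampled transactions
    profiles_sample_rate=0.5,   # enables the transaction profiler for half of those
    # Alternative: decide per transaction based on the sampling context.
    # profiles_sampler=lambda sampling_context: 0.5,
)
```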
- self._default_active_thread_id = get_current_thread_meta()[0] or 0 # type: int - self.active_thread_id = None # type: Optional[int] + self._default_active_thread_id: int = get_current_thread_meta()[0] or 0 + self.active_thread_id: Optional[int] = None try: - self.start_ns = start_ns # type: int + self.start_ns: int = start_ns except AttributeError: self.start_ns = 0 - self.stop_ns = 0 # type: int - self.active = False # type: bool + self.stop_ns: int = 0 + self.active: bool = False - self.indexed_frames = {} # type: Dict[FrameId, int] - self.indexed_stacks = {} # type: Dict[StackId, int] - self.frames = [] # type: List[ProcessedFrame] - self.stacks = [] # type: List[ProcessedStack] - self.samples = [] # type: List[ProcessedSample] + self.indexed_frames: Dict[FrameId, int] = {} + self.indexed_stacks: Dict[StackId, int] = {} + self.frames: List[ProcessedFrame] = [] + self.stacks: List[ProcessedStack] = [] + self.samples: List[ProcessedSample] = [] self.unique_samples = 0 - def update_active_thread_id(self): - # type: () -> None + def update_active_thread_id(self) -> None: self.active_thread_id = get_current_thread_meta()[0] logger.debug( "[Profiling] updating active thread id to {tid}".format( @@ -230,8 +226,7 @@ def update_active_thread_id(self): ) ) - def _set_initial_sampling_decision(self, sampling_context): - # type: (SamplingContext) -> None + def _set_initial_sampling_decision(self, sampling_context: SamplingContext) -> None: """ Sets the profile's sampling decision according to the following precedence rules: @@ -281,7 +276,8 @@ def _set_initial_sampling_decision(self, sampling_context): self.sampled = False return - if not is_valid_sample_rate(sample_rate, source="Profiling"): + sample_rate = is_valid_sample_rate(sample_rate, source="Profiling") + if sample_rate is None: logger.warning( "[Profiling] Discarding profile because of invalid sample rate." ) @@ -291,19 +287,18 @@ def _set_initial_sampling_decision(self, sampling_context): # Now we roll the dice. random.random is inclusive of 0, but not of 1, # so strict < is safe here. 
In case sample_rate is a boolean, cast it # to a float (True becomes 1.0 and False becomes 0.0) - self.sampled = random.random() < float(sample_rate) + self.sampled = random.random() < sample_rate if self.sampled: logger.debug("[Profiling] Initializing profile") else: logger.debug( "[Profiling] Discarding profile because it's not included in the random sample (sample rate = {sample_rate})".format( - sample_rate=float(sample_rate) + sample_rate=sample_rate ) ) - def start(self): - # type: () -> None + def start(self) -> None: if not self.sampled or self.active: return @@ -314,8 +309,7 @@ def start(self): self.start_ns = time.perf_counter_ns() self.scheduler.start_profiling(self) - def stop(self): - # type: () -> None + def stop(self) -> None: if not self.sampled or not self.active: return @@ -324,8 +318,7 @@ def stop(self): self.active = False self.stop_ns = time.perf_counter_ns() - def __enter__(self): - # type: () -> Profile + def __enter__(self) -> Profile: scope = sentry_sdk.get_isolation_scope() old_profile = scope.profile scope.profile = self @@ -336,8 +329,9 @@ def __enter__(self): return self - def __exit__(self, ty, value, tb): - # type: (Optional[Any], Optional[Any], Optional[Any]) -> None + def __exit__( + self, ty: Optional[Any], value: Optional[Any], tb: Optional[Any] + ) -> None: self.stop() scope, old_profile = self._context_manager_state @@ -345,8 +339,7 @@ def __exit__(self, ty, value, tb): scope.profile = old_profile - def write(self, ts, sample): - # type: (int, ExtractedSample) -> None + def write(self, ts: int, sample: ExtractedSample) -> None: if not self.active: return @@ -389,18 +382,17 @@ def write(self, ts, sample): # When this happens, we abandon the current sample as it's bad. capture_internal_exception(sys.exc_info()) - def process(self): - # type: () -> ProcessedProfile + def process(self) -> ProcessedProfile: # This collects the thread metadata at the end of a profile. Doing it # this way means that any threads that terminate before the profile ends # will not have any metadata associated with it. 
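The dice roll above depends on `random.random()` returning values in `[0.0, 1.0)`, so the strict `<` gives the intended edge behavior: a rate of 0.0 never samples and a rate of 1.0 always does. A self-contained sketch of just that decision:

```python
import random


def roll_sampling_decision(sample_rate: float) -> bool:
    # random.random() is inclusive of 0 but exclusive of 1, so `<` is safe at both edges.
    return random.random() < float(sample_rate)


assert not any(roll_sampling_decision(0.0) for _ in range(1000))  # 0.0: never sampled
assert all(roll_sampling_decision(1.0) for _ in range(1000))      # 1.0: always sampled
```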
- thread_metadata = { + thread_metadata: Dict[str, ProcessedThreadMetadata] = { str(thread.ident): { "name": str(thread.name), } for thread in threading.enumerate() - } # type: Dict[str, ProcessedThreadMetadata] + } return { "frames": self.frames, @@ -409,8 +401,7 @@ def process(self): "thread_metadata": thread_metadata, } - def to_json(self, event_opt, options): - # type: (Event, Dict[str, Any]) -> Dict[str, Any] + def to_json(self, event_opt: Event, options: Dict[str, Any]) -> Dict[str, Any]: profile = self.process() set_in_app_in_frames( @@ -460,8 +451,7 @@ def to_json(self, event_opt, options): ], } - def valid(self): - # type: () -> bool + def valid(self) -> bool: client = sentry_sdk.get_client() if not client.is_active(): return False @@ -488,39 +478,35 @@ def valid(self): class Scheduler(ABC): - mode = "unknown" # type: ProfilerMode + mode: ProfilerMode = "unknown" - def __init__(self, frequency): - # type: (int) -> None + def __init__(self, frequency: int) -> None: self.interval = 1.0 / frequency self.sampler = self.make_sampler() # cap the number of new profiles at any time so it does not grow infinitely - self.new_profiles = deque(maxlen=128) # type: Deque[Profile] - self.active_profiles = set() # type: Set[Profile] + self.new_profiles: Deque[Profile] = deque(maxlen=128) + self.active_profiles: Set[Profile] = set() - def __enter__(self): - # type: () -> Scheduler + def __enter__(self) -> Scheduler: self.setup() return self - def __exit__(self, ty, value, tb): - # type: (Optional[Any], Optional[Any], Optional[Any]) -> None + def __exit__( + self, ty: Optional[Any], value: Optional[Any], tb: Optional[Any] + ) -> None: self.teardown() @abstractmethod - def setup(self): - # type: () -> None + def setup(self) -> None: pass @abstractmethod - def teardown(self): - # type: () -> None + def teardown(self) -> None: pass - def ensure_running(self): - # type: () -> None + def ensure_running(self) -> None: """ Ensure the scheduler is running. By default, this method is a no-op. The method should be overridden by any implementation for which it is @@ -528,19 +514,16 @@ def ensure_running(self): """ return None - def start_profiling(self, profile): - # type: (Profile) -> None + def start_profiling(self, profile: Profile) -> None: self.ensure_running() self.new_profiles.append(profile) - def make_sampler(self): - # type: () -> Callable[..., None] + def make_sampler(self) -> Callable[..., None]: cwd = os.getcwd() cache = LRUCache(max_size=256) - def _sample_stack(*args, **kwargs): - # type: (*Any, **Any) -> None + def _sample_stack(*args: Any, **kwargs: Any) -> None: """ Take a sample of the stack on all the threads in the process. This should be called at a regular interval to collect samples. @@ -611,32 +594,28 @@ class ThreadScheduler(Scheduler): the sampler at a regular interval. 
""" - mode = "thread" # type: ProfilerMode + mode: ProfilerMode = "thread" name = "sentry.profiler.ThreadScheduler" - def __init__(self, frequency): - # type: (int) -> None + def __init__(self, frequency: int) -> None: super().__init__(frequency=frequency) # used to signal to the thread that it should stop self.running = False - self.thread = None # type: Optional[threading.Thread] - self.pid = None # type: Optional[int] + self.thread: Optional[threading.Thread] = None + self.pid: Optional[int] = None self.lock = threading.Lock() - def setup(self): - # type: () -> None + def setup(self) -> None: pass - def teardown(self): - # type: () -> None + def teardown(self) -> None: if self.running: self.running = False if self.thread is not None: self.thread.join() - def ensure_running(self): - # type: () -> None + def ensure_running(self) -> None: """ Check that the profiler has an active thread to run in, and start one if that's not the case. @@ -674,8 +653,7 @@ def ensure_running(self): self.thread = None return - def run(self): - # type: () -> None + def run(self) -> None: last = time.perf_counter() while self.running: @@ -707,11 +685,10 @@ class GeventScheduler(Scheduler): results in a sample containing only the sampler's code. """ - mode = "gevent" # type: ProfilerMode + mode: ProfilerMode = "gevent" name = "sentry.profiler.GeventScheduler" - def __init__(self, frequency): - # type: (int) -> None + def __init__(self, frequency: int) -> None: if ThreadPool is None: raise ValueError("Profiler mode: {} is not available".format(self.mode)) @@ -720,27 +697,24 @@ def __init__(self, frequency): # used to signal to the thread that it should stop self.running = False - self.thread = None # type: Optional[_ThreadPool] - self.pid = None # type: Optional[int] + self.thread: Optional[_ThreadPool] = None + self.pid: Optional[int] = None # This intentionally uses the gevent patched threading.Lock. # The lock will be required when first trying to start profiles # as we need to spawn the profiler thread from the greenlets. 
self.lock = threading.Lock() - def setup(self): - # type: () -> None + def setup(self) -> None: pass - def teardown(self): - # type: () -> None + def teardown(self) -> None: if self.running: self.running = False if self.thread is not None: self.thread.join() - def ensure_running(self): - # type: () -> None + def ensure_running(self) -> None: pid = os.getpid() # is running on the right process @@ -767,8 +741,7 @@ def ensure_running(self): self.thread = None return - def run(self): - # type: () -> None + def run(self) -> None: last = time.perf_counter() while self.running: diff --git a/sentry_sdk/profiler/utils.py b/sentry_sdk/profiler/utils.py index 3554cddb5d..40d667dce2 100644 --- a/sentry_sdk/profiler/utils.py +++ b/sentry_sdk/profiler/utils.py @@ -1,3 +1,4 @@ +from __future__ import annotations import os from collections import deque @@ -63,14 +64,12 @@ if PY311: - def get_frame_name(frame): - # type: (FrameType) -> str + def get_frame_name(frame: FrameType) -> str: return frame.f_code.co_qualname else: - def get_frame_name(frame): - # type: (FrameType) -> str + def get_frame_name(frame: FrameType) -> str: f_code = frame.f_code co_varnames = f_code.co_varnames @@ -117,13 +116,11 @@ def get_frame_name(frame): return name -def frame_id(raw_frame): - # type: (FrameType) -> FrameId +def frame_id(raw_frame: FrameType) -> FrameId: return (raw_frame.f_code.co_filename, raw_frame.f_lineno, get_frame_name(raw_frame)) -def extract_frame(fid, raw_frame, cwd): - # type: (FrameId, FrameType, str) -> ProcessedFrame +def extract_frame(fid: FrameId, raw_frame: FrameType, cwd: str) -> ProcessedFrame: abs_path = raw_frame.f_code.co_filename try: @@ -152,12 +149,11 @@ def extract_frame(fid, raw_frame, cwd): def extract_stack( - raw_frame, # type: Optional[FrameType] - cache, # type: LRUCache - cwd, # type: str - max_stack_depth=MAX_STACK_DEPTH, # type: int -): - # type: (...) -> ExtractedStack + raw_frame: Optional[FrameType], + cache: LRUCache, + cwd: str, + max_stack_depth: int = MAX_STACK_DEPTH, +) -> ExtractedStack: """ Extracts the stack starting the specified frame. The extracted stack assumes the specified frame is the top of the stack, and works back @@ -167,7 +163,7 @@ def extract_stack( only the first `MAX_STACK_DEPTH` frames will be returned. 
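The Python 3.11 branch of `get_frame_name` above can rely on `co_qualname`, which already carries the class prefix that the pre-3.11 branch has to reconstruct by hand. A tiny illustration using only the standard library (requires Python 3.11+):

```python
import sys


class Worker:
    def run(self) -> tuple[str, str]:
        frame = sys._getframe()
        # co_qualname includes the enclosing class, co_name does not.
        return frame.f_code.co_qualname, frame.f_code.co_name


print(Worker().run())  # ('Worker.run', 'run')
```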
""" - raw_frames = deque(maxlen=max_stack_depth) # type: Deque[FrameType] + raw_frames: Deque[FrameType] = deque(maxlen=max_stack_depth) while raw_frame is not None: f_back = raw_frame.f_back diff --git a/sentry_sdk/scope.py b/sentry_sdk/scope.py index b5e3d2d040..86e7609f40 100644 --- a/sentry_sdk/scope.py +++ b/sentry_sdk/scope.py @@ -1,3 +1,4 @@ +from __future__ import annotations import os import sys import warnings @@ -43,22 +44,25 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from collections.abc import Mapping, MutableMapping + from typing import ( + Any, + Callable, + Deque, + Dict, + Generator, + Iterator, + List, + Optional, + ParamSpec, + Tuple, + TypeVar, + Union, + Self, + ) - from typing import Any - from typing import Callable - from typing import Deque - from typing import Dict - from typing import Generator - from typing import Iterator - from typing import List - from typing import Optional - from typing import ParamSpec - from typing import Tuple - from typing import TypeVar - from typing import Union - from typing import Self + from collections.abc import Mapping, MutableMapping + import sentry_sdk from sentry_sdk._types import ( Breadcrumb, BreadcrumbHint, @@ -71,8 +75,6 @@ Type, ) - import sentry_sdk - P = ParamSpec("P") R = TypeVar("R") @@ -84,7 +86,7 @@ # In case this is a http server (think web framework) with multiple users # the data will be added to events of all users. # Typically this is used for process wide data such as the release. -_global_scope = None # type: Optional[Scope] +_global_scope: Optional[Scope] = None # Holds data for the active request. # This is used to isolate data for different requests or users. @@ -96,7 +98,7 @@ # This can be used to manually add additional data to a span. _current_scope = ContextVar("current_scope", default=None) -global_event_processors = [] # type: List[EventProcessor] +global_event_processors: List[EventProcessor] = [] class ScopeType(Enum): @@ -106,21 +108,17 @@ class ScopeType(Enum): MERGED = "merged" -def add_global_event_processor(processor): - # type: (EventProcessor) -> None +def add_global_event_processor(processor: EventProcessor) -> None: global_event_processors.append(processor) -def _attr_setter(fn): - # type: (Any) -> Any +def _attr_setter(fn: Any) -> Any: return property(fset=fn, doc=fn.__doc__) -def _disable_capture(fn): - # type: (F) -> F +def _disable_capture(fn: F) -> F: @wraps(fn) - def wrapper(self, *args, **kwargs): - # type: (Any, *Dict[str, Any], **Any) -> Any + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: if not self._should_capture: return try: @@ -172,31 +170,29 @@ class Scope: "_flags", ) - def __init__(self, ty=None): - # type: (Optional[ScopeType]) -> None + def __init__(self, ty: Optional[ScopeType] = None) -> None: self._type = ty - self._event_processors = [] # type: List[EventProcessor] - self._error_processors = [] # type: List[ErrorProcessor] + self._event_processors: List[EventProcessor] = [] + self._error_processors: List[ErrorProcessor] = [] - self._name = None # type: Optional[str] - self._propagation_context = None # type: Optional[PropagationContext] - self._n_breadcrumbs_truncated = 0 # type: int + self._name: Optional[str] = None + self._propagation_context: Optional[PropagationContext] = None + self._n_breadcrumbs_truncated: int = 0 - self.client = NonRecordingClient() # type: sentry_sdk.client.BaseClient + self.client: sentry_sdk.client.BaseClient = NonRecordingClient() self.clear() incoming_trace_information = self._load_trace_data_from_env() 
self.generate_propagation_context(incoming_data=incoming_trace_information) - def __copy__(self): - # type: () -> Self + def __copy__(self) -> Self: """ Returns a copy of this scope. This also creates a copy of all referenced data structures. """ - rv = object.__new__(self.__class__) # type: Self + rv: Self = object.__new__(self.__class__) rv._type = self._type rv.client = self.client @@ -232,8 +228,7 @@ def __copy__(self): return rv @classmethod - def get_current_scope(cls): - # type: () -> Scope + def get_current_scope(cls) -> Scope: """ .. versionadded:: 2.0.0 @@ -247,16 +242,14 @@ def get_current_scope(cls): return current_scope @classmethod - def _get_current_scope(cls): - # type: () -> Optional[Scope] + def _get_current_scope(cls) -> Optional[Scope]: """ Returns the current scope without creating a new one. Internal use only. """ return _current_scope.get() @classmethod - def set_current_scope(cls, new_current_scope): - # type: (Scope) -> None + def set_current_scope(cls, new_current_scope: Scope) -> None: """ .. versionadded:: 2.0.0 @@ -266,8 +259,7 @@ def set_current_scope(cls, new_current_scope): _current_scope.set(new_current_scope) @classmethod - def get_isolation_scope(cls): - # type: () -> Scope + def get_isolation_scope(cls) -> Scope: """ .. versionadded:: 2.0.0 @@ -281,16 +273,14 @@ def get_isolation_scope(cls): return isolation_scope @classmethod - def _get_isolation_scope(cls): - # type: () -> Optional[Scope] + def _get_isolation_scope(cls) -> Optional[Scope]: """ Returns the isolation scope without creating a new one. Internal use only. """ return _isolation_scope.get() @classmethod - def set_isolation_scope(cls, new_isolation_scope): - # type: (Scope) -> None + def set_isolation_scope(cls, new_isolation_scope: Scope) -> None: """ .. versionadded:: 2.0.0 @@ -300,8 +290,7 @@ def set_isolation_scope(cls, new_isolation_scope): _isolation_scope.set(new_isolation_scope) @classmethod - def get_global_scope(cls): - # type: () -> Scope + def get_global_scope(cls) -> Scope: """ .. versionadded:: 2.0.0 @@ -314,8 +303,7 @@ def get_global_scope(cls): return _global_scope @classmethod - def last_event_id(cls): - # type: () -> Optional[str] + def last_event_id(cls) -> Optional[str]: """ .. versionadded:: 2.2.0 @@ -330,8 +318,11 @@ def last_event_id(cls): """ return cls.get_isolation_scope()._last_event_id - def _merge_scopes(self, additional_scope=None, additional_scope_kwargs=None): - # type: (Optional[Scope], Optional[Dict[str, Any]]) -> Self + def _merge_scopes( + self, + additional_scope: Optional[Scope] = None, + additional_scope_kwargs: Optional[Dict[str, Any]] = None, + ) -> Self: """ Merges global, isolation and current scope into a new scope and adds the given additional scope or additional scope kwargs to it. @@ -366,8 +357,7 @@ def _merge_scopes(self, additional_scope=None, additional_scope_kwargs=None): return final_scope @classmethod - def get_client(cls): - # type: () -> sentry_sdk.client.BaseClient + def get_client(cls) -> sentry_sdk.client.BaseClient: """ .. 
versionadded:: 2.0.0 @@ -393,18 +383,18 @@ def get_client(cls): if client is not None and client.is_active(): return client - try: - client = _global_scope.client # type: ignore - except AttributeError: - client = None + if _global_scope: + try: + client = _global_scope.client + except AttributeError: + client = None if client is not None and client.is_active(): return client return NonRecordingClient() - def set_client(self, client=None): - # type: (Optional[sentry_sdk.client.BaseClient]) -> None + def set_client(self, client: Optional[sentry_sdk.client.BaseClient] = None) -> None: """ .. versionadded:: 2.0.0 @@ -416,8 +406,7 @@ def set_client(self, client=None): """ self.client = client if client is not None else NonRecordingClient() - def fork(self): - # type: () -> Self + def fork(self) -> Self: """ .. versionadded:: 2.0.0 @@ -426,8 +415,7 @@ def fork(self): forked_scope = copy(self) return forked_scope - def _load_trace_data_from_env(self): - # type: () -> Optional[Dict[str, str]] + def _load_trace_data_from_env(self) -> Optional[Dict[str, str]]: """ Load Sentry trace id and baggage from environment variables. Can be disabled by setting SENTRY_USE_ENVIRONMENT to "false". @@ -453,15 +441,15 @@ def _load_trace_data_from_env(self): return incoming_trace_information or None - def set_new_propagation_context(self): - # type: () -> None + def set_new_propagation_context(self) -> None: """ Creates a new propagation context and sets it as `_propagation_context`. Overwriting existing one. """ self._propagation_context = PropagationContext() - def generate_propagation_context(self, incoming_data=None): - # type: (Optional[Dict[str, str]]) -> None + def generate_propagation_context( + self, incoming_data: Optional[dict[str, str]] = None + ) -> None: """ Makes sure the propagation context is set on the scope. If there is `incoming_data` overwrite existing propagation context. @@ -476,16 +464,14 @@ def generate_propagation_context(self, incoming_data=None): if self._propagation_context is None: self.set_new_propagation_context() - def get_dynamic_sampling_context(self): - # type: () -> Optional[Dict[str, str]] + def get_dynamic_sampling_context(self) -> Optional[Dict[str, str]]: """ Returns the Dynamic Sampling Context from the baggage or populates one. """ baggage = self.get_baggage() return baggage.dynamic_sampling_context() if baggage else None - def get_traceparent(self, *args, **kwargs): - # type: (Any, Any) -> Optional[str] + def get_traceparent(self, *args: Any, **kwargs: Any) -> Optional[str]: """ Returns the Sentry "sentry-trace" header (aka the traceparent) from the currently active span or the scopes Propagation Context. @@ -507,8 +493,7 @@ def get_traceparent(self, *args, **kwargs): # Fall back to isolation scope's traceparent. It always has one return self.get_isolation_scope().get_traceparent() - def get_baggage(self, *args, **kwargs): - # type: (Any, Any) -> Optional[Baggage] + def get_baggage(self, *args: Any, **kwargs: Any) -> Optional[Baggage]: """ Returns the Sentry "baggage" header containing trace information from the currently active span or the scopes Propagation Context. @@ -534,25 +519,23 @@ def get_baggage(self, *args, **kwargs): # Fall back to isolation scope's baggage. It always has one return self.get_isolation_scope().get_baggage() - def get_trace_context(self): - # type: () -> Any + def get_trace_context(self) -> Any: """ Returns the Sentry "trace" context from the Propagation Context. 
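The propagation getters typed above (`get_traceparent`, `get_baggage`) are what you would use to forward trace headers by hand when no integration does it for you. A hedged sketch; the header names match the SDK's `sentry-trace`/`baggage` constants, and the outgoing request itself is left as a comment:

```python
import sentry_sdk

scope = sentry_sdk.get_current_scope()

headers = {}
traceparent = scope.get_traceparent()
if traceparent is not None:
    headers["sentry-trace"] = traceparent

baggage = scope.get_baggage()
if baggage is not None:
    headers["baggage"] = baggage.serialize()

# e.g. requests.get("https://downstream.example/api", headers=headers)
print(headers)
```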
""" if self._propagation_context is None: return None - trace_context = { + trace_context: Dict[str, Any] = { "trace_id": self._propagation_context.trace_id, "span_id": self._propagation_context.span_id, "parent_span_id": self._propagation_context.parent_span_id, "dynamic_sampling_context": self.get_dynamic_sampling_context(), - } # type: Dict[str, Any] + } return trace_context - def trace_propagation_meta(self, *args, **kwargs): - # type: (*Any, **Any) -> str + def trace_propagation_meta(self, *args: Any, **kwargs: Any) -> str: """ Return meta tags which should be injected into HTML templates to allow propagation of trace information. @@ -575,8 +558,7 @@ def trace_propagation_meta(self, *args, **kwargs): return meta - def iter_headers(self): - # type: () -> Iterator[Tuple[str, str]] + def iter_headers(self) -> Iterator[Tuple[str, str]]: """ Creates a generator which returns the `sentry-trace` and `baggage` headers from the Propagation Context. """ @@ -589,8 +571,9 @@ def iter_headers(self): if baggage is not None: yield BAGGAGE_HEADER_NAME, baggage.serialize() - def iter_trace_propagation_headers(self, *args, **kwargs): - # type: (Any, Any) -> Generator[Tuple[str, str], None, None] + def iter_trace_propagation_headers( + self, *args: Any, **kwargs: Any + ) -> Generator[Tuple[str, str], None, None]: """ Return HTTP headers which allow propagation of trace data. @@ -624,8 +607,7 @@ def iter_trace_propagation_headers(self, *args, **kwargs): for header in isolation_scope.iter_headers(): yield header - def get_active_propagation_context(self): - # type: () -> Optional[PropagationContext] + def get_active_propagation_context(self) -> Optional[PropagationContext]: if self._propagation_context is not None: return self._propagation_context @@ -639,37 +621,35 @@ def get_active_propagation_context(self): return None - def clear(self): - # type: () -> None + def clear(self) -> None: """Clears the entire scope.""" - self._level = None # type: Optional[LogLevelStr] - self._fingerprint = None # type: Optional[List[str]] - self._transaction = None # type: Optional[str] - self._transaction_info = {} # type: MutableMapping[str, str] - self._user = None # type: Optional[Dict[str, Any]] + self._level: Optional[LogLevelStr] = None + self._fingerprint: Optional[List[str]] = None + self._transaction: Optional[str] = None + self._transaction_info: MutableMapping[str, str] = {} + self._user: Optional[Dict[str, Any]] = None - self._tags = {} # type: Dict[str, Any] - self._contexts = {} # type: Dict[str, Dict[str, Any]] - self._extras = {} # type: MutableMapping[str, Any] - self._attachments = [] # type: List[Attachment] + self._tags: Dict[str, Any] = {} + self._contexts: Dict[str, Dict[str, Any]] = {} + self._extras: MutableMapping[str, Any] = {} + self._attachments: List[Attachment] = [] self.clear_breadcrumbs() - self._should_capture = True # type: bool + self._should_capture: bool = True - self._span = None # type: Optional[Span] - self._session = None # type: Optional[Session] - self._force_auto_session_tracking = None # type: Optional[bool] + self._span: Optional[Span] = None + self._session: Optional[Session] = None + self._force_auto_session_tracking: Optional[bool] = None - self._profile = None # type: Optional[Profile] + self._profile: Optional[Profile] = None self._propagation_context = None # self._last_event_id is only applicable to isolation scopes - self._last_event_id = None # type: Optional[str] - self._flags = None # type: Optional[FlagBuffer] + self._last_event_id: Optional[str] = None + 
self._flags: Optional[FlagBuffer] = None - def set_level(self, value): - # type: (LogLevelStr) -> None + def set_level(self, value: LogLevelStr) -> None: """ Sets the level for the scope. @@ -678,22 +658,19 @@ def set_level(self, value): self._level = value @_attr_setter - def fingerprint(self, value): - # type: (Optional[List[str]]) -> None + def fingerprint(self, value: Optional[List[str]]) -> None: """When set this overrides the default fingerprint.""" self._fingerprint = value @property - def root_span(self): - # type: () -> Optional[Span] + def root_span(self) -> Optional[Span]: """Return the root span in the scope, if any.""" if self._span is None: return None return self._span.root_span - def set_transaction_name(self, name, source=None): - # type: (str, Optional[str]) -> None + def set_transaction_name(self, name: str, source: Optional[str] = None) -> None: """Set the transaction name and optionally the transaction source.""" self._transaction = name @@ -706,17 +683,14 @@ def set_transaction_name(self, name, source=None): self._transaction_info["source"] = source @property - def transaction_name(self): - # type: () -> Optional[str] + def transaction_name(self) -> Optional[str]: return self._transaction @property - def transaction_source(self): - # type: () -> Optional[str] + def transaction_source(self) -> Optional[str]: return self._transaction_info.get("source") - def set_user(self, value): - # type: (Optional[Dict[str, Any]]) -> None + def set_user(self, value: Optional[Dict[str, Any]]) -> None: """Sets a user for the scope.""" self._user = value session = self.get_isolation_scope()._session @@ -724,24 +698,20 @@ def set_user(self, value): session.update(user=value) @property - def span(self): - # type: () -> Optional[Span] + def span(self) -> Optional[Span]: """Get current tracing span.""" return self._span @property - def profile(self): - # type: () -> Optional[Profile] + def profile(self) -> Optional[Profile]: return self._profile @profile.setter - def profile(self, profile): - # type: (Optional[Profile]) -> None + def profile(self, profile: Optional[Profile]) -> None: self._profile = profile - def set_tag(self, key, value): - # type: (str, Any) -> None + def set_tag(self, key: str, value: Any) -> None: """ Sets a tag for a key to a specific value. @@ -751,8 +721,7 @@ def set_tag(self, key, value): """ self._tags[key] = value - def set_tags(self, tags): - # type: (Mapping[str, object]) -> None + def set_tags(self, tags: Mapping[str, object]) -> None: """Sets multiple tags at once. This method updates multiple tags at once. The tags are passed as a dictionary @@ -770,8 +739,7 @@ def set_tags(self, tags): """ self._tags.update(tags) - def remove_tag(self, key): - # type: (str) -> None + def remove_tag(self, key: str) -> None: """ Removes a specific tag. @@ -781,53 +749,46 @@ def remove_tag(self, key): def set_context( self, - key, # type: str - value, # type: Dict[str, Any] - ): - # type: (...) -> None + key: str, + value: Dict[str, Any], + ) -> None: """ Binds a context at a certain key to a specific value. """ self._contexts[key] = value def remove_context( - self, key # type: str - ): - # type: (...) -> None + self, + key: str, + ) -> None: """Removes a context.""" self._contexts.pop(key, None) def set_extra( self, - key, # type: str - value, # type: Any - ): - # type: (...) -> None + key: str, + value: Any, + ) -> None: """Sets an extra key to a specific value.""" self._extras[key] = value - def remove_extra( - self, key # type: str - ): - # type: (...) 
-> None + def remove_extra(self, key: str) -> None: """Removes a specific extra key.""" self._extras.pop(key, None) - def clear_breadcrumbs(self): - # type: () -> None + def clear_breadcrumbs(self) -> None: """Clears breadcrumb buffer.""" - self._breadcrumbs = deque() # type: Deque[Breadcrumb] + self._breadcrumbs: Deque[Breadcrumb] = deque() self._n_breadcrumbs_truncated = 0 def add_attachment( self, - bytes=None, # type: Union[None, bytes, Callable[[], bytes]] - filename=None, # type: Optional[str] - path=None, # type: Optional[str] - content_type=None, # type: Optional[str] - add_to_transactions=False, # type: bool - ): - # type: (...) -> None + bytes: Union[None, bytes, Callable[[], bytes]] = None, + filename: Optional[str] = None, + path: Optional[str] = None, + content_type: Optional[str] = None, + add_to_transactions: bool = False, + ) -> None: """Adds an attachment to future events sent from this scope. The parameters are the same as for the :py:class:`sentry_sdk.attachments.Attachment` constructor. @@ -842,8 +803,12 @@ def add_attachment( ) ) - def add_breadcrumb(self, crumb=None, hint=None, **kwargs): - # type: (Optional[Breadcrumb], Optional[BreadcrumbHint], Any) -> None + def add_breadcrumb( + self, + crumb: Optional[Breadcrumb] = None, + hint: Optional[BreadcrumbHint] = None, + **kwargs: Any, + ) -> None: """ Adds a breadcrumb. @@ -861,12 +826,12 @@ def add_breadcrumb(self, crumb=None, hint=None, **kwargs): before_breadcrumb = client.options.get("before_breadcrumb") max_breadcrumbs = client.options.get("max_breadcrumbs", DEFAULT_MAX_BREADCRUMBS) - crumb = dict(crumb or ()) # type: Breadcrumb + crumb: Breadcrumb = dict(crumb or ()) crumb.update(kwargs) if not crumb: return - hint = dict(hint or ()) # type: Hint + hint: Hint = dict(hint or ()) if crumb.get("timestamp") is None: crumb["timestamp"] = datetime.now(timezone.utc) @@ -887,8 +852,7 @@ def add_breadcrumb(self, crumb=None, hint=None, **kwargs): self._breadcrumbs.popleft() self._n_breadcrumbs_truncated += 1 - def start_transaction(self, **kwargs): - # type: (Any) -> Union[NoOpSpan, Span] + def start_transaction(self, **kwargs: Any) -> Union[NoOpSpan, Span]: """ .. deprecated:: 3.0.0 This function is deprecated and will be removed in a future release. @@ -901,8 +865,7 @@ def start_transaction(self, **kwargs): ) return NoOpSpan(**kwargs) - def start_span(self, **kwargs): - # type: (Any) -> Union[NoOpSpan, Span] + def start_span(self, **kwargs: Any) -> Union[NoOpSpan, Span]: """ Start a span whose parent is the currently active span, if any. @@ -915,16 +878,22 @@ def start_span(self, **kwargs): return NoOpSpan(**kwargs) @contextmanager - def continue_trace(self, environ_or_headers): - # type: (Dict[str, Any]) -> Generator[None, None, None] + def continue_trace( + self, environ_or_headers: Dict[str, Any] + ) -> Generator[None, None, None]: """ Sets the propagation context from environment or headers to continue an incoming trace. """ self.generate_propagation_context(environ_or_headers) yield - def capture_event(self, event, hint=None, scope=None, **scope_kwargs): - # type: (Event, Optional[Hint], Optional[Scope], Any) -> Optional[str] + def capture_event( + self, + event: Event, + hint: Optional[Hint] = None, + scope: Optional[Scope] = None, + **scope_kwargs: Any, + ) -> Optional[str]: """ Captures an event. 
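The breadcrumb buffer and the `capture_*` methods that follow are easiest to read with a concrete caller in mind. A short usage sketch against the top-level API (category and message are made up; the event id stays `None` until a client is configured):

```python
import sentry_sdk

sentry_sdk.add_breadcrumb(category="auth", message="user logged in", level="info")

try:
    1 / 0
except ZeroDivisionError:
    # Applies scope data (breadcrumbs, tags, user, ...) and runs the registered processors.
    event_id = sentry_sdk.capture_exception()
    print(event_id)
```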
@@ -955,8 +924,13 @@ def capture_event(self, event, hint=None, scope=None, **scope_kwargs): return event_id - def capture_message(self, message, level=None, scope=None, **scope_kwargs): - # type: (str, Optional[LogLevelStr], Optional[Scope], Any) -> Optional[str] + def capture_message( + self, + message: str, + level: Optional[LogLevelStr] = None, + scope: Optional[Scope] = None, + **scope_kwargs: Any, + ) -> Optional[str]: """ Captures a message. @@ -979,15 +953,19 @@ def capture_message(self, message, level=None, scope=None, **scope_kwargs): if level is None: level = "info" - event = { + event: Event = { "message": message, "level": level, - } # type: Event + } return self.capture_event(event, scope=scope, **scope_kwargs) - def capture_exception(self, error=None, scope=None, **scope_kwargs): - # type: (Optional[Union[BaseException, ExcInfo]], Optional[Scope], Any) -> Optional[str] + def capture_exception( + self, + error: Optional[Union[BaseException, ExcInfo]] = None, + scope: Optional[Scope] = None, + **scope_kwargs: Any, + ) -> Optional[str]: """Captures an exception. :param error: An exception to capture. If `None`, `sys.exc_info()` will be used. @@ -1020,8 +998,7 @@ def capture_exception(self, error=None, scope=None, **scope_kwargs): return None - def start_session(self, *args, **kwargs): - # type: (*Any, **Any) -> None + def start_session(self, *args: Any, **kwargs: Any) -> None: """Starts a new session.""" session_mode = kwargs.pop("session_mode", "application") @@ -1035,8 +1012,7 @@ def start_session(self, *args, **kwargs): session_mode=session_mode, ) - def end_session(self, *args, **kwargs): - # type: (*Any, **Any) -> None + def end_session(self, *args: Any, **kwargs: Any) -> None: """Ends the current session if there is one.""" session = self._session self._session = None @@ -1045,8 +1021,7 @@ def end_session(self, *args, **kwargs): session.close() self.get_client().capture_session(session) - def stop_auto_session_tracking(self, *args, **kwargs): - # type: (*Any, **Any) -> None + def stop_auto_session_tracking(self, *args: Any, **kwargs: Any) -> None: """Stops automatic session tracking. This temporarily session tracking for the current scope when called. @@ -1055,18 +1030,14 @@ def stop_auto_session_tracking(self, *args, **kwargs): self.end_session() self._force_auto_session_tracking = False - def resume_auto_session_tracking(self): - # type: (...) -> None + def resume_auto_session_tracking(self) -> None: """Resumes automatic session tracking for the current scope if disabled earlier. This requires that generally automatic session tracking is enabled. """ self._force_auto_session_tracking = None - def add_event_processor( - self, func # type: EventProcessor - ): - # type: (...) -> None + def add_event_processor(self, func: EventProcessor) -> None: """Register a scope local event processor on the scope. :param func: This function behaves like `before_send.` @@ -1082,10 +1053,9 @@ def add_event_processor( def add_error_processor( self, - func, # type: ErrorProcessor - cls=None, # type: Optional[Type[BaseException]] - ): - # type: (...) -> None + func: ErrorProcessor, + cls: Optional[Type[BaseException]] = None, + ) -> None: """Register a scope local error processor on the scope. :param func: A callback that works similar to an event processor but is invoked with the original exception info triple as second argument. @@ -1096,8 +1066,7 @@ def add_error_processor( cls_ = cls # For mypy. 
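Regarding `add_error_processor` (and the inner-function rename in this hunk): a processor registered with a `cls` filter only runs when the captured exception is an instance of that class. A hedged sketch of registering one on the current scope (the tag is illustrative):

```python
from typing import Any, Optional

import sentry_sdk


def tag_value_errors(event: dict, exc_info: Any) -> Optional[dict]:
    # Only invoked for ValueError (and subclasses) thanks to the `cls` filter below.
    event.setdefault("tags", {})["error_class"] = "ValueError"
    return event  # returning None would drop the event instead


sentry_sdk.get_current_scope().add_error_processor(tag_value_errors, cls=ValueError)
```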
real_func = func - def func(event, exc_info): - # type: (Event, ExcInfo) -> Optional[Event] + def wrapped_func(event: Event, exc_info: ExcInfo) -> Optional[Event]: try: is_inst = isinstance(exc_info[1], cls_) except Exception: @@ -1106,15 +1075,17 @@ def func(event, exc_info): return real_func(event, exc_info) return event - self._error_processors.append(func) + self._error_processors.append(wrapped_func) - def _apply_level_to_event(self, event, hint, options): - # type: (Event, Hint, Optional[Dict[str, Any]]) -> None + def _apply_level_to_event( + self, event: Event, hint: Hint, options: Optional[Dict[str, Any]] + ) -> None: if self._level is not None: event["level"] = self._level - def _apply_breadcrumbs_to_event(self, event, hint, options): - # type: (Event, Hint, Optional[Dict[str, Any]]) -> None + def _apply_breadcrumbs_to_event( + self, event: Event, hint: Hint, options: Optional[Dict[str, Any]] + ) -> None: event.setdefault("breadcrumbs", {}) # This check is just for mypy - @@ -1136,38 +1107,45 @@ def _apply_breadcrumbs_to_event(self, event, hint, options): logger.debug("Error when sorting breadcrumbs", exc_info=err) pass - def _apply_user_to_event(self, event, hint, options): - # type: (Event, Hint, Optional[Dict[str, Any]]) -> None + def _apply_user_to_event( + self, event: Event, hint: Hint, options: Optional[Dict[str, Any]] + ) -> None: if event.get("user") is None and self._user is not None: event["user"] = self._user - def _apply_transaction_name_to_event(self, event, hint, options): - # type: (Event, Hint, Optional[Dict[str, Any]]) -> None + def _apply_transaction_name_to_event( + self, event: Event, hint: Hint, options: Optional[Dict[str, Any]] + ) -> None: if event.get("transaction") is None and self._transaction is not None: event["transaction"] = self._transaction - def _apply_transaction_info_to_event(self, event, hint, options): - # type: (Event, Hint, Optional[Dict[str, Any]]) -> None + def _apply_transaction_info_to_event( + self, event: Event, hint: Hint, options: Optional[Dict[str, Any]] + ) -> None: if event.get("transaction_info") is None and self._transaction_info is not None: event["transaction_info"] = self._transaction_info - def _apply_fingerprint_to_event(self, event, hint, options): - # type: (Event, Hint, Optional[Dict[str, Any]]) -> None + def _apply_fingerprint_to_event( + self, event: Event, hint: Hint, options: Optional[Dict[str, Any]] + ) -> None: if event.get("fingerprint") is None and self._fingerprint is not None: event["fingerprint"] = self._fingerprint - def _apply_extra_to_event(self, event, hint, options): - # type: (Event, Hint, Optional[Dict[str, Any]]) -> None + def _apply_extra_to_event( + self, event: Event, hint: Hint, options: Optional[Dict[str, Any]] + ) -> None: if self._extras: event.setdefault("extra", {}).update(self._extras) - def _apply_tags_to_event(self, event, hint, options): - # type: (Event, Hint, Optional[Dict[str, Any]]) -> None + def _apply_tags_to_event( + self, event: Event, hint: Hint, options: Optional[Dict[str, Any]] + ) -> None: if self._tags: event.setdefault("tags", {}).update(self._tags) - def _apply_contexts_to_event(self, event, hint, options): - # type: (Event, Hint, Optional[Dict[str, Any]]) -> None + def _apply_contexts_to_event( + self, event: Event, hint: Hint, options: Optional[Dict[str, Any]] + ) -> None: if self._contexts: event.setdefault("contexts", {}).update(self._contexts) @@ -1176,7 +1154,8 @@ def _apply_contexts_to_event(self, event, hint, options): # Add "trace" context if contexts.get("trace") 
is None: if ( - has_tracing_enabled(options) + options is not None + and has_tracing_enabled(options) and self._span is not None and self._span.is_valid ): @@ -1184,21 +1163,20 @@ def _apply_contexts_to_event(self, event, hint, options): else: contexts["trace"] = self.get_trace_context() - def _apply_flags_to_event(self, event, hint, options): - # type: (Event, Hint, Optional[Dict[str, Any]]) -> None + def _apply_flags_to_event( + self, event: Event, hint: Hint, options: Optional[Dict[str, Any]] + ) -> None: flags = self.flags.get() if len(flags) > 0: event.setdefault("contexts", {}).setdefault("flags", {}).update( {"values": flags} ) - def _drop(self, cause, ty): - # type: (Any, str) -> Optional[Any] + def _drop(self, cause: Any, ty: str) -> Optional[Any]: logger.info("%s (%s) dropped event", ty, cause) return None - def run_error_processors(self, event, hint): - # type: (Event, Hint) -> Optional[Event] + def run_error_processors(self, event: Event, hint: Hint) -> Optional[Event]: """ Runs the error processors on the event and returns the modified event. """ @@ -1219,8 +1197,7 @@ def run_error_processors(self, event, hint): return event - def run_event_processors(self, event, hint): - # type: (Event, Hint) -> Optional[Event] + def run_event_processors(self, event: Event, hint: Hint) -> Optional[Event]: """ Runs the event processors on the event and returns the modified event. """ @@ -1240,7 +1217,7 @@ def run_event_processors(self, event, hint): ) for event_processor in event_processors: - new_event = event # type: Optional[Event] + new_event: Optional[Event] = event with capture_internal_exceptions(): new_event = event_processor(event, hint) if new_event is None: @@ -1252,11 +1229,10 @@ def run_event_processors(self, event, hint): @_disable_capture def apply_to_event( self, - event, # type: Event - hint, # type: Hint - options=None, # type: Optional[Dict[str, Any]] - ): - # type: (...) -> Optional[Event] + event: Event, + hint: Hint, + options: Optional[Dict[str, Any]] = None, + ) -> Optional[Event]: """Applies the information contained on the scope to the given event.""" ty = event.get("type") is_transaction = ty == "transaction" @@ -1302,8 +1278,7 @@ def apply_to_event( return event - def update_from_scope(self, scope): - # type: (Scope) -> None + def update_from_scope(self, scope: Scope) -> None: """Update the scope with another scope's data.""" if scope._level is not None: self._level = scope._level @@ -1346,14 +1321,13 @@ def update_from_scope(self, scope): def update_from_kwargs( self, - user=None, # type: Optional[Any] - level=None, # type: Optional[LogLevelStr] - extras=None, # type: Optional[Dict[str, Any]] - contexts=None, # type: Optional[Dict[str, Dict[str, Any]]] - tags=None, # type: Optional[Dict[str, str]] - fingerprint=None, # type: Optional[List[str]] - ): - # type: (...) 
-> None + user: Optional[Any] = None, + level: Optional[LogLevelStr] = None, + extras: Optional[Dict[str, Any]] = None, + contexts: Optional[Dict[str, Dict[str, Any]]] = None, + tags: Optional[Dict[str, str]] = None, + fingerprint: Optional[List[str]] = None, + ) -> None: """Update the scope's attributes.""" if level is not None: self._level = level @@ -1368,8 +1342,7 @@ def update_from_kwargs( if fingerprint is not None: self._fingerprint = fingerprint - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return "<%s id=%s name=%s type=%s>" % ( self.__class__.__name__, hex(id(self)), @@ -1378,8 +1351,7 @@ def __repr__(self): ) @property - def flags(self): - # type: () -> FlagBuffer + def flags(self) -> FlagBuffer: if self._flags is None: max_flags = ( self.get_client().options["_experiments"].get("max_flags") @@ -1390,8 +1362,7 @@ def flags(self): @contextmanager -def new_scope(): - # type: () -> Generator[Scope, None, None] +def new_scope() -> Generator[Scope, None, None]: """ .. versionadded:: 2.0.0 @@ -1428,8 +1399,7 @@ def new_scope(): @contextmanager -def use_scope(scope): - # type: (Scope) -> Generator[Scope, None, None] +def use_scope(scope: Scope) -> Generator[Scope, None, None]: """ .. versionadded:: 2.0.0 @@ -1466,8 +1436,7 @@ def use_scope(scope): @contextmanager -def isolation_scope(): - # type: () -> Generator[Scope, None, None] +def isolation_scope() -> Generator[Scope, None, None]: """ .. versionadded:: 2.0.0 @@ -1515,8 +1484,7 @@ def isolation_scope(): @contextmanager -def use_isolation_scope(isolation_scope): - # type: (Scope) -> Generator[Scope, None, None] +def use_isolation_scope(isolation_scope: Scope) -> Generator[Scope, None, None]: """ .. versionadded:: 2.0.0 @@ -1561,14 +1529,10 @@ def use_isolation_scope(isolation_scope): capture_internal_exception(sys.exc_info()) -def should_send_default_pii(): - # type: () -> bool +def should_send_default_pii() -> bool: """Shortcut for `Scope.get_client().should_send_default_pii()`.""" return Scope.get_client().should_send_default_pii() # Circular imports from sentry_sdk.client import NonRecordingClient - -if TYPE_CHECKING: - import sentry_sdk.client diff --git a/sentry_sdk/scrubber.py b/sentry_sdk/scrubber.py index b0576c7e95..a8fcd9b8ba 100644 --- a/sentry_sdk/scrubber.py +++ b/sentry_sdk/scrubber.py @@ -1,14 +1,15 @@ +from __future__ import annotations from sentry_sdk.utils import ( capture_internal_exceptions, AnnotatedValue, iter_event_frames, ) -from typing import TYPE_CHECKING, cast, List, Dict +from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import List, Optional from sentry_sdk._types import Event - from typing import Optional DEFAULT_DENYLIST = [ @@ -60,9 +61,12 @@ class EventScrubber: def __init__( - self, denylist=None, recursive=False, send_default_pii=False, pii_denylist=None - ): - # type: (Optional[List[str]], bool, bool, Optional[List[str]]) -> None + self, + denylist: Optional[List[str]] = None, + recursive: bool = False, + send_default_pii: bool = False, + pii_denylist: Optional[List[str]] = None, + ) -> None: """ A scrubber that goes through the event payload and removes sensitive data configured through denylists. @@ -82,8 +86,7 @@ def __init__( self.denylist = [x.lower() for x in self.denylist] self.recursive = recursive - def scrub_list(self, lst): - # type: (object) -> None + def scrub_list(self, lst: object) -> None: """ If a list is passed to this method, the method recursively searches the list and any nested lists for any dictionaries. 
The method calls scrub_dict on all dictionaries @@ -97,8 +100,7 @@ def scrub_list(self, lst): self.scrub_dict(v) # no-op unless v is a dict self.scrub_list(v) # no-op unless v is a list - def scrub_dict(self, d): - # type: (object) -> None + def scrub_dict(self, d: object) -> None: """ If a dictionary is passed to this method, the method scrubs the dictionary of any sensitive data. The method calls itself recursively on any nested dictionaries ( @@ -117,8 +119,7 @@ def scrub_dict(self, d): self.scrub_dict(v) # no-op unless v is a dict self.scrub_list(v) # no-op unless v is a list - def scrub_request(self, event): - # type: (Event) -> None + def scrub_request(self, event: Event) -> None: with capture_internal_exceptions(): if "request" in event: if "headers" in event["request"]: @@ -128,20 +129,17 @@ def scrub_request(self, event): if "data" in event["request"]: self.scrub_dict(event["request"]["data"]) - def scrub_extra(self, event): - # type: (Event) -> None + def scrub_extra(self, event: Event) -> None: with capture_internal_exceptions(): if "extra" in event: self.scrub_dict(event["extra"]) - def scrub_user(self, event): - # type: (Event) -> None + def scrub_user(self, event: Event) -> None: with capture_internal_exceptions(): if "user" in event: self.scrub_dict(event["user"]) - def scrub_breadcrumbs(self, event): - # type: (Event) -> None + def scrub_breadcrumbs(self, event: Event) -> None: with capture_internal_exceptions(): if "breadcrumbs" in event: if ( @@ -152,23 +150,21 @@ def scrub_breadcrumbs(self, event): if "data" in value: self.scrub_dict(value["data"]) - def scrub_frames(self, event): - # type: (Event) -> None + def scrub_frames(self, event: Event) -> None: with capture_internal_exceptions(): for frame in iter_event_frames(event): if "vars" in frame: self.scrub_dict(frame["vars"]) - def scrub_spans(self, event): - # type: (Event) -> None + def scrub_spans(self, event: Event) -> None: with capture_internal_exceptions(): if "spans" in event: - for span in cast(List[Dict[str, object]], event["spans"]): - if "data" in span: - self.scrub_dict(span["data"]) + if not isinstance(event["spans"], AnnotatedValue): + for span in event["spans"]: + if "data" in span: + self.scrub_dict(span["data"]) - def scrub_event(self, event): - # type: (Event) -> None + def scrub_event(self, event: Event) -> None: self.scrub_request(event) self.scrub_extra(event) self.scrub_user(event) diff --git a/sentry_sdk/serializer.py b/sentry_sdk/serializer.py index bc8e38c631..bd629c2927 100644 --- a/sentry_sdk/serializer.py +++ b/sentry_sdk/serializer.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import math from collections.abc import Mapping, Sequence, Set @@ -26,7 +27,7 @@ from typing import Type from typing import Union - from sentry_sdk._types import NotImplementedType + from sentry_sdk._types import NotImplementedType, Event Span = Dict[str, Any] @@ -55,29 +56,25 @@ CYCLE_MARKER = "" -global_repr_processors = [] # type: List[ReprProcessor] +global_repr_processors: List[ReprProcessor] = [] -def add_global_repr_processor(processor): - # type: (ReprProcessor) -> None +def add_global_repr_processor(processor: ReprProcessor) -> None: global_repr_processors.append(processor) class Memo: __slots__ = ("_ids", "_objs") - def __init__(self): - # type: () -> None - self._ids = {} # type: Dict[int, Any] - self._objs = [] # type: List[Any] + def __init__(self) -> None: + self._ids: Dict[int, Any] = {} + self._objs: List[Any] = [] - def memoize(self, obj): - # type: (Any) -> ContextManager[bool] + 
def memoize(self, obj: Any) -> ContextManager[bool]: self._objs.append(obj) return self - def __enter__(self): - # type: () -> bool + def __enter__(self) -> bool: obj = self._objs[-1] if id(obj) in self._ids: return True @@ -87,16 +84,14 @@ def __enter__(self): def __exit__( self, - ty, # type: Optional[Type[BaseException]] - value, # type: Optional[BaseException] - tb, # type: Optional[TracebackType] - ): - # type: (...) -> None + ty: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: self._ids.pop(id(self._objs.pop()), None) -def serialize(event, **kwargs): - # type: (Dict[str, Any], **Any) -> Dict[str, Any] +def serialize(event: Union[Dict[str, Any], Event], **kwargs: Any) -> Dict[str, Any]: """ A very smart serializer that takes a dict and emits a json-friendly dict. Currently used for serializing the final Event and also prematurely while fetching the stack @@ -117,18 +112,15 @@ def serialize(event, **kwargs): """ memo = Memo() - path = [] # type: List[Segment] - meta_stack = [] # type: List[Dict[str, Any]] + path: List[Segment] = [] + meta_stack: List[Dict[str, Any]] = [] - keep_request_bodies = ( - kwargs.pop("max_request_body_size", None) == "always" - ) # type: bool - max_value_length = kwargs.pop("max_value_length", None) # type: Optional[int] + keep_request_bodies: bool = kwargs.pop("max_request_body_size", None) == "always" + max_value_length: Optional[int] = kwargs.pop("max_value_length", None) is_vars = kwargs.pop("is_vars", False) - custom_repr = kwargs.pop("custom_repr", None) # type: Callable[..., Optional[str]] + custom_repr: Callable[..., Optional[str]] = kwargs.pop("custom_repr", None) - def _safe_repr_wrapper(value): - # type: (Any) -> str + def _safe_repr_wrapper(value: Any) -> str: try: repr_value = None if custom_repr is not None: @@ -137,8 +129,7 @@ def _safe_repr_wrapper(value): except Exception: return safe_repr(value) - def _annotate(**meta): - # type: (**Any) -> None + def _annotate(**meta: Any) -> None: while len(meta_stack) <= len(path): try: segment = path[len(meta_stack) - 1] @@ -150,8 +141,7 @@ def _annotate(**meta): meta_stack[-1].setdefault("", {}).update(meta) - def _is_databag(): - # type: () -> Optional[bool] + def _is_databag() -> Optional[bool]: """ A databag is any value that we need to trim. True for stuff like vars, request bodies, breadcrumbs and extra. @@ -179,8 +169,7 @@ def _is_databag(): return False - def _is_request_body(): - # type: () -> Optional[bool] + def _is_request_body() -> Optional[bool]: try: if path[0] == "request" and path[1] == "data": return True @@ -190,15 +179,14 @@ def _is_request_body(): return False def _serialize_node( - obj, # type: Any - is_databag=None, # type: Optional[bool] - is_request_body=None, # type: Optional[bool] - should_repr_strings=None, # type: Optional[bool] - segment=None, # type: Optional[Segment] - remaining_breadth=None, # type: Optional[Union[int, float]] - remaining_depth=None, # type: Optional[Union[int, float]] - ): - # type: (...) 
-> Any + obj: Any, + is_databag: Optional[bool] = None, + is_request_body: Optional[bool] = None, + should_repr_strings: Optional[bool] = None, + segment: Optional[Segment] = None, + remaining_breadth: Optional[Union[int, float]] = None, + remaining_depth: Optional[Union[int, float]] = None, + ) -> Any: if segment is not None: path.append(segment) @@ -227,22 +215,20 @@ def _serialize_node( path.pop() del meta_stack[len(path) + 1 :] - def _flatten_annotated(obj): - # type: (Any) -> Any + def _flatten_annotated(obj: Any) -> Any: if isinstance(obj, AnnotatedValue): _annotate(**obj.metadata) obj = obj.value return obj def _serialize_node_impl( - obj, - is_databag, - is_request_body, - should_repr_strings, - remaining_depth, - remaining_breadth, - ): - # type: (Any, Optional[bool], Optional[bool], Optional[bool], Optional[Union[float, int]], Optional[Union[float, int]]) -> Any + obj: Any, + is_databag: Optional[bool], + is_request_body: Optional[bool], + should_repr_strings: Optional[bool], + remaining_depth: Optional[Union[float, int]], + remaining_breadth: Optional[Union[float, int]], + ) -> Any: if isinstance(obj, AnnotatedValue): should_repr_strings = False if should_repr_strings is None: @@ -306,7 +292,7 @@ def _serialize_node_impl( # might mutate our dictionary while we're still iterating over it. obj = dict(obj.items()) - rv_dict = {} # type: Dict[str, Any] + rv_dict: Dict[str, Any] = {} i = 0 for k, v in obj.items(): diff --git a/sentry_sdk/session.py b/sentry_sdk/session.py index c1d422c115..e392bc354b 100644 --- a/sentry_sdk/session.py +++ b/sentry_sdk/session.py @@ -1,3 +1,4 @@ +from __future__ import annotations import uuid from datetime import datetime, timezone @@ -6,23 +7,15 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional - from typing import Union - from typing import Any - from typing import Dict - from sentry_sdk._types import SessionStatus + from typing import Optional, Union, Any, Dict -def _minute_trunc(ts): - # type: (datetime) -> datetime +def _minute_trunc(ts: datetime) -> datetime: return ts.replace(second=0, microsecond=0) -def _make_uuid( - val, # type: Union[str, uuid.UUID] -): - # type: (...) -> uuid.UUID +def _make_uuid(val: Union[str, uuid.UUID]) -> uuid.UUID: if isinstance(val, uuid.UUID): return val return uuid.UUID(val) @@ -31,21 +24,20 @@ def _make_uuid( class Session: def __init__( self, - sid=None, # type: Optional[Union[str, uuid.UUID]] - did=None, # type: Optional[str] - timestamp=None, # type: Optional[datetime] - started=None, # type: Optional[datetime] - duration=None, # type: Optional[float] - status=None, # type: Optional[SessionStatus] - release=None, # type: Optional[str] - environment=None, # type: Optional[str] - user_agent=None, # type: Optional[str] - ip_address=None, # type: Optional[str] - errors=None, # type: Optional[int] - user=None, # type: Optional[Any] - session_mode="application", # type: str - ): - # type: (...) 
-> None + sid: Optional[Union[str, uuid.UUID]] = None, + did: Optional[str] = None, + timestamp: Optional[datetime] = None, + started: Optional[datetime] = None, + duration: Optional[float] = None, + status: Optional[SessionStatus] = None, + release: Optional[str] = None, + environment: Optional[str] = None, + user_agent: Optional[str] = None, + ip_address: Optional[str] = None, + errors: Optional[int] = None, + user: Optional[Any] = None, + session_mode: str = "application", + ) -> None: if sid is None: sid = uuid.uuid4() if started is None: @@ -53,14 +45,14 @@ def __init__( if status is None: status = "ok" self.status = status - self.did = None # type: Optional[str] + self.did: Optional[str] = None self.started = started - self.release = None # type: Optional[str] - self.environment = None # type: Optional[str] - self.duration = None # type: Optional[float] - self.user_agent = None # type: Optional[str] - self.ip_address = None # type: Optional[str] - self.session_mode = session_mode # type: str + self.release: Optional[str] = None + self.environment: Optional[str] = None + self.duration: Optional[float] = None + self.user_agent: Optional[str] = None + self.ip_address: Optional[str] = None + self.session_mode: str = session_mode self.errors = 0 self.update( @@ -77,26 +69,24 @@ def __init__( ) @property - def truncated_started(self): - # type: (...) -> datetime + def truncated_started(self) -> datetime: return _minute_trunc(self.started) def update( self, - sid=None, # type: Optional[Union[str, uuid.UUID]] - did=None, # type: Optional[str] - timestamp=None, # type: Optional[datetime] - started=None, # type: Optional[datetime] - duration=None, # type: Optional[float] - status=None, # type: Optional[SessionStatus] - release=None, # type: Optional[str] - environment=None, # type: Optional[str] - user_agent=None, # type: Optional[str] - ip_address=None, # type: Optional[str] - errors=None, # type: Optional[int] - user=None, # type: Optional[Any] - ): - # type: (...) -> None + sid: Optional[Union[str, uuid.UUID]] = None, + did: Optional[str] = None, + timestamp: Optional[datetime] = None, + started: Optional[datetime] = None, + duration: Optional[float] = None, + status: Optional[SessionStatus] = None, + release: Optional[str] = None, + environment: Optional[str] = None, + user_agent: Optional[str] = None, + ip_address: Optional[str] = None, + errors: Optional[int] = None, + user: Optional[Any] = None, + ) -> None: # If a user is supplied we pull some data form it if user: if ip_address is None: @@ -129,19 +119,13 @@ def update( if status is not None: self.status = status - def close( - self, status=None # type: Optional[SessionStatus] - ): - # type: (...) -> Any + def close(self, status: Optional[SessionStatus] = None) -> Any: if status is None and self.status == "ok": status = "exited" if status is not None: self.update(status=status) - def get_json_attrs( - self, with_user_info=True # type: Optional[bool] - ): - # type: (...) -> Any + def get_json_attrs(self, with_user_info: bool = True) -> Any: attrs = {} if self.release is not None: attrs["release"] = self.release @@ -154,15 +138,14 @@ def get_json_attrs( attrs["user_agent"] = self.user_agent return attrs - def to_json(self): - # type: (...) 
-> Any - rv = { + def to_json(self) -> Any: + rv: Dict[str, Any] = { "sid": str(self.sid), "init": True, "started": format_timestamp(self.started), "timestamp": format_timestamp(self.timestamp), "status": self.status, - } # type: Dict[str, Any] + } if self.errors: rv["errors"] = self.errors if self.did is not None: diff --git a/sentry_sdk/sessions.py b/sentry_sdk/sessions.py index 162023a54a..84c046043a 100644 --- a/sentry_sdk/sessions.py +++ b/sentry_sdk/sessions.py @@ -1,3 +1,4 @@ +from __future__ import annotations import os import time from threading import Thread, Lock @@ -11,16 +12,17 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any - from typing import Callable - from typing import Dict - from typing import Generator - from typing import List - from typing import Optional + from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Generator, + ) -def _is_auto_session_tracking_enabled(scope): - # type: (sentry_sdk.Scope) -> bool +def _is_auto_session_tracking_enabled(scope: sentry_sdk.Scope) -> bool: """ Utility function to find out if session tracking is enabled. """ @@ -34,8 +36,9 @@ def _is_auto_session_tracking_enabled(scope): @contextmanager -def track_session(scope, session_mode="application"): - # type: (sentry_sdk.Scope, str) -> Generator[None, None, None] +def track_session( + scope: sentry_sdk.Scope, session_mode: str = "application" +) -> Generator[None, None, None]: """ Start a new session in the provided scope, assuming session tracking is enabled. This is a no-op context manager if session tracking is not enabled. @@ -55,30 +58,27 @@ def track_session(scope, session_mode="application"): MAX_ENVELOPE_ITEMS = 100 -def make_aggregate_envelope(aggregate_states, attrs): - # type: (Any, Any) -> Any +def make_aggregate_envelope(aggregate_states: Any, attrs: Any) -> Any: return {"attrs": dict(attrs), "aggregates": list(aggregate_states.values())} class SessionFlusher: def __init__( self, - capture_func, # type: Callable[[Envelope], None] - flush_interval=60, # type: int - ): - # type: (...) -> None + capture_func: Callable[[Envelope], None], + flush_interval: int = 60, + ) -> None: self.capture_func = capture_func self.flush_interval = flush_interval - self.pending_sessions = [] # type: List[Any] - self.pending_aggregates = {} # type: Dict[Any, Any] - self._thread = None # type: Optional[Thread] + self.pending_sessions: List[Any] = [] + self.pending_aggregates: Dict[Any, Any] = {} + self._thread: Optional[Thread] = None self._thread_lock = Lock() self._aggregate_lock = Lock() - self._thread_for_pid = None # type: Optional[int] + self._thread_for_pid: Optional[int] = None self._running = True - def flush(self): - # type: (...) -> None + def flush(self) -> None: pending_sessions = self.pending_sessions self.pending_sessions = [] @@ -104,8 +104,7 @@ def flush(self): if len(envelope.items) > 0: self.capture_func(envelope) - def _ensure_running(self): - # type: (...) -> None + def _ensure_running(self) -> None: """ Check that we have an active thread to run in, or create one if not. @@ -119,8 +118,7 @@ def _ensure_running(self): if self._thread_for_pid == os.getpid() and self._thread is not None: return None - def _thread(): - # type: (...) -> None + def _thread() -> None: while self._running: time.sleep(self.flush_interval) if self._running: @@ -141,10 +139,7 @@ def _thread(): return None - def add_aggregate_session( - self, session # type: Session - ): - # type: (...) 
-> None + def add_aggregate_session(self, session: Session) -> None: # NOTE on `session.did`: # the protocol can deal with buckets that have a distinct-id, however # in practice we expect the python SDK to have an extremely high cardinality @@ -172,20 +167,15 @@ def add_aggregate_session( else: state["exited"] = state.get("exited", 0) + 1 - def add_session( - self, session # type: Session - ): - # type: (...) -> None + def add_session(self, session: Session) -> None: if session.session_mode == "request": self.add_aggregate_session(session) else: self.pending_sessions.append(session.to_json()) self._ensure_running() - def kill(self): - # type: (...) -> None + def kill(self) -> None: self._running = False - def __del__(self): - # type: (...) -> None + def __del__(self) -> None: self.kill() diff --git a/sentry_sdk/spotlight.py b/sentry_sdk/spotlight.py index 4ac427b9c1..976879dc84 100644 --- a/sentry_sdk/spotlight.py +++ b/sentry_sdk/spotlight.py @@ -1,3 +1,4 @@ +from __future__ import annotations import io import logging import os @@ -12,11 +13,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any - from typing import Callable - from typing import Dict - from typing import Optional - from typing import Self + from typing import Any, Callable, Dict, Optional from sentry_sdk.utils import ( logger as sentry_logger, @@ -34,14 +31,12 @@ class SpotlightClient: - def __init__(self, url): - # type: (str) -> None + def __init__(self, url: str) -> None: self.url = url self.http = urllib3.PoolManager() self.fails = 0 - def capture_envelope(self, envelope): - # type: (Envelope) -> None + def capture_envelope(self, envelope: Envelope) -> None: body = io.BytesIO() envelope.serialize_into(body) try: @@ -90,11 +85,10 @@ def capture_envelope(self, envelope): ) class SpotlightMiddleware(MiddlewareMixin): # type: ignore[misc] - _spotlight_script = None # type: Optional[str] - _spotlight_url = None # type: Optional[str] + _spotlight_script: Optional[str] = None + _spotlight_url: Optional[str] = None - def __init__(self, get_response): - # type: (Self, Callable[..., HttpResponse]) -> None + def __init__(self, get_response: Callable[..., HttpResponse]) -> None: super().__init__(get_response) import sentry_sdk.api @@ -111,8 +105,7 @@ def __init__(self, get_response): self._spotlight_url = urllib.parse.urljoin(spotlight_client.url, "../") @property - def spotlight_script(self): - # type: (Self) -> Optional[str] + def spotlight_script(self) -> Optional[str]: if self._spotlight_url is not None and self._spotlight_script is None: try: spotlight_js_url = urllib.parse.urljoin( @@ -136,8 +129,9 @@ def spotlight_script(self): return self._spotlight_script - def process_response(self, _request, response): - # type: (Self, HttpRequest, HttpResponse) -> Optional[HttpResponse] + def process_response( + self, _request: HttpRequest, response: HttpResponse + ) -> Optional[HttpResponse]: content_type_header = tuple( p.strip() for p in response.headers.get("Content-Type", "").lower().split(";") @@ -181,8 +175,9 @@ def process_response(self, _request, response): return response - def process_exception(self, _request, exception): - # type: (Self, HttpRequest, Exception) -> Optional[HttpResponseServerError] + def process_exception( + self, _request: HttpRequest, exception: Exception + ) -> Optional[HttpResponseServerError]: if not settings.DEBUG or not self._spotlight_url: return None @@ -207,8 +202,7 @@ def process_exception(self, _request, exception): settings = None -def setup_spotlight(options): - # 
type: (Dict[str, Any]) -> Optional[SpotlightClient] +def setup_spotlight(options: Dict[str, Any]) -> Optional[SpotlightClient]: _handler = logging.StreamHandler(sys.stderr) _handler.setFormatter(logging.Formatter(" [spotlight] %(levelname)s: %(message)s")) logger.addHandler(_handler) diff --git a/sentry_sdk/tracing.py b/sentry_sdk/tracing.py index f15f07065a..6ab3486ef7 100644 --- a/sentry_sdk/tracing.py +++ b/sentry_sdk/tracing.py @@ -1,3 +1,4 @@ +from __future__ import annotations from datetime import datetime import json import warnings @@ -45,29 +46,26 @@ should_be_treated_as_error, ) -from typing import TYPE_CHECKING, cast - +from typing import TYPE_CHECKING, overload if TYPE_CHECKING: - from collections.abc import Callable - from typing import Any - from typing import Dict - from typing import Iterator - from typing import Optional - from typing import overload - from typing import ParamSpec - from typing import Tuple - from typing import Union - from typing import TypeVar + from typing import ( + Callable, + Any, + Dict, + Iterator, + Optional, + ParamSpec, + Tuple, + Union, + TypeVar, + ) + from sentry_sdk._types import SamplingContext + from sentry_sdk.tracing_utils import Baggage P = ParamSpec("P") R = TypeVar("R") - from sentry_sdk._types import ( - SamplingContext, - ) - - from sentry_sdk.tracing_utils import Baggage _FLAGS_CAPACITY = 10 _OTEL_VERSION = parse_version(otel_version) @@ -76,88 +74,65 @@ class NoOpSpan: - def __init__(self, **kwargs): - # type: (Any) -> None + def __init__(self, **kwargs: Any) -> None: pass - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return "<%s>" % self.__class__.__name__ @property - def root_span(self): - # type: () -> Optional[Span] + def root_span(self) -> Optional[Span]: return None - def start_child(self, **kwargs): - # type: (**Any) -> NoOpSpan + def start_child(self, **kwargs: Any) -> NoOpSpan: return NoOpSpan() - def to_traceparent(self): - # type: () -> str + def to_traceparent(self) -> str: return "" - def to_baggage(self): - # type: () -> Optional[Baggage] + def to_baggage(self) -> Optional[Baggage]: return None - def get_baggage(self): - # type: () -> Optional[Baggage] + def get_baggage(self) -> Optional[Baggage]: return None - def iter_headers(self): - # type: () -> Iterator[Tuple[str, str]] + def iter_headers(self) -> Iterator[Tuple[str, str]]: return iter(()) - def set_tag(self, key, value): - # type: (str, Any) -> None + def set_tag(self, key: str, value: Any) -> None: pass - def set_data(self, key, value): - # type: (str, Any) -> None + def set_data(self, key: str, value: Any) -> None: pass - def set_status(self, value): - # type: (str) -> None + def set_status(self, value: str) -> None: pass - def set_http_status(self, http_status): - # type: (int) -> None + def set_http_status(self, http_status: int) -> None: pass - def is_success(self): - # type: () -> bool + def is_success(self) -> bool: return True - def to_json(self): - # type: () -> Dict[str, Any] + def to_json(self) -> Dict[str, Any]: return {} - def get_trace_context(self): - # type: () -> Any + def get_trace_context(self) -> Any: return {} - def get_profile_context(self): - # type: () -> Any + def get_profile_context(self) -> Any: return {} - def finish( - self, - end_timestamp=None, # type: Optional[Union[float, datetime]] - ): - # type: (...) 
-> None + def finish(self, end_timestamp: Optional[Union[float, datetime]] = None) -> None: pass - def set_context(self, key, value): - # type: (str, dict[str, Any]) -> None + def set_context(self, key: str, value: dict[str, Any]) -> None: pass - def init_span_recorder(self, maxlen): - # type: (int) -> None + def init_span_recorder(self, maxlen: int) -> None: pass - def _set_initial_sampling_decision(self, sampling_context): - # type: (SamplingContext) -> None + def _set_initial_sampling_decision(self, sampling_context: SamplingContext) -> None: pass @@ -169,21 +144,20 @@ class Span: def __init__( self, *, - op=None, # type: Optional[str] - description=None, # type: Optional[str] - status=None, # type: Optional[str] - sampled=None, # type: Optional[bool] - start_timestamp=None, # type: Optional[Union[datetime, float]] - origin=None, # type: Optional[str] - name=None, # type: Optional[str] - source=TransactionSource.CUSTOM, # type: str - attributes=None, # type: Optional[dict[str, Any]] - only_if_parent=False, # type: bool - parent_span=None, # type: Optional[Span] - otel_span=None, # type: Optional[OtelSpan] - span=None, # type: Optional[Span] - ): - # type: (...) -> None + op: Optional[str] = None, + description: Optional[str] = None, + status: Optional[str] = None, + sampled: Optional[bool] = None, + start_timestamp: Optional[Union[datetime, float]] = None, + origin: Optional[str] = None, + name: Optional[str] = None, + source: str = TransactionSource.CUSTOM, + attributes: Optional[dict[str, Any]] = None, + only_if_parent: bool = False, + parent_span: Optional[Span] = None, + otel_span: Optional[OtelSpan] = None, + span: Optional[Span] = None, + ) -> None: """ If otel_span is passed explicitly, just acts as a proxy. @@ -248,14 +222,12 @@ def __init__( self.update_active_thread() - def __eq__(self, other): - # type: (object) -> bool + def __eq__(self, other: object) -> bool: if not isinstance(other, Span): return False return self._otel_span == other._otel_span - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return ( "<%s(op=%r, name:%r, trace_id=%r, span_id=%r, parent_span_id=%r, sampled=%r, origin=%r)>" % ( @@ -270,25 +242,23 @@ def __repr__(self): ) ) - def activate(self): - # type: () -> None + def activate(self) -> None: ctx = otel_trace.set_span_in_context(self._otel_span) # set as the implicit current context self._ctx_token = context.attach(ctx) - def deactivate(self): - # type: () -> None + def deactivate(self) -> None: if self._ctx_token: context.detach(self._ctx_token) del self._ctx_token - def __enter__(self): - # type: () -> Span + def __enter__(self) -> Span: self.activate() return self - def __exit__(self, ty, value, tb): - # type: (Optional[Any], Optional[Any], Optional[Any]) -> None + def __exit__( + self, ty: Optional[Any], value: Optional[Any], tb: Optional[Any] + ) -> None: if value is not None and should_be_treated_as_error(ty, value): self.set_status(SPANSTATUS.INTERNAL_ERROR) else: @@ -303,41 +273,34 @@ def __exit__(self, ty, value, tb): self.deactivate() @property - def description(self): - # type: () -> Optional[str] + def description(self) -> Optional[str]: return self.get_attribute(SentrySpanAttribute.DESCRIPTION) @description.setter - def description(self, value): - # type: (Optional[str]) -> None + def description(self, value: Optional[str]) -> None: self.set_attribute(SentrySpanAttribute.DESCRIPTION, value) @property - def origin(self): - # type: () -> Optional[str] + def origin(self) -> Optional[str]: return 
self.get_attribute(SentrySpanAttribute.ORIGIN) @origin.setter - def origin(self, value): - # type: (Optional[str]) -> None + def origin(self, value: Optional[str]) -> None: self.set_attribute(SentrySpanAttribute.ORIGIN, value) @property - def root_span(self): - # type: () -> Optional[Span] - root_otel_span = cast( - "Optional[OtelSpan]", get_sentry_meta(self._otel_span, "root_span") + def root_span(self) -> Optional[Span]: + root_otel_span: Optional[OtelSpan] = get_sentry_meta( + self._otel_span, "root_span" ) return Span(otel_span=root_otel_span) if root_otel_span else None @property - def is_root_span(self): - # type: () -> bool + def is_root_span(self) -> bool: return self.root_span == self @property - def parent_span_id(self): - # type: () -> Optional[str] + def parent_span_id(self) -> Optional[str]: if ( not isinstance(self._otel_span, ReadableSpan) or self._otel_span.parent is None @@ -346,70 +309,58 @@ def parent_span_id(self): return format_span_id(self._otel_span.parent.span_id) @property - def trace_id(self): - # type: () -> str + def trace_id(self) -> str: return format_trace_id(self._otel_span.get_span_context().trace_id) @property - def span_id(self): - # type: () -> str + def span_id(self) -> str: return format_span_id(self._otel_span.get_span_context().span_id) @property - def is_valid(self): - # type: () -> bool + def is_valid(self) -> bool: return self._otel_span.get_span_context().is_valid and isinstance( self._otel_span, ReadableSpan ) @property - def sampled(self): - # type: () -> Optional[bool] + def sampled(self) -> Optional[bool]: return self._otel_span.get_span_context().trace_flags.sampled @property - def sample_rate(self): - # type: () -> Optional[float] + def sample_rate(self) -> Optional[float]: sample_rate = self._otel_span.get_span_context().trace_state.get( TRACESTATE_SAMPLE_RATE_KEY ) return float(sample_rate) if sample_rate is not None else None @property - def op(self): - # type: () -> Optional[str] + def op(self) -> Optional[str]: return self.get_attribute(SentrySpanAttribute.OP) @op.setter - def op(self, value): - # type: (Optional[str]) -> None + def op(self, value: Optional[str]) -> None: self.set_attribute(SentrySpanAttribute.OP, value) @property - def name(self): - # type: () -> Optional[str] + def name(self) -> Optional[str]: return self.get_attribute(SentrySpanAttribute.NAME) @name.setter - def name(self, value): - # type: (Optional[str]) -> None + def name(self, value: Optional[str]) -> None: self.set_attribute(SentrySpanAttribute.NAME, value) @property - def source(self): - # type: () -> str + def source(self) -> str: return ( self.get_attribute(SentrySpanAttribute.SOURCE) or TransactionSource.CUSTOM ) @source.setter - def source(self, value): - # type: (str) -> None + def source(self, value: str) -> None: self.set_attribute(SentrySpanAttribute.SOURCE, value) @property - def start_timestamp(self): - # type: () -> Optional[datetime] + def start_timestamp(self) -> Optional[datetime]: if not isinstance(self._otel_span, ReadableSpan): return None @@ -420,8 +371,7 @@ def start_timestamp(self): return convert_from_otel_timestamp(start_time) @property - def timestamp(self): - # type: () -> Optional[datetime] + def timestamp(self) -> Optional[datetime]: if not isinstance(self._otel_span, ReadableSpan): return None @@ -431,17 +381,14 @@ def timestamp(self): return convert_from_otel_timestamp(end_time) - def start_child(self, **kwargs): - # type: (**Any) -> Span + def start_child(self, **kwargs: Any) -> Span: return Span(parent_span=self, **kwargs) - def 
iter_headers(self): - # type: () -> Iterator[Tuple[str, str]] + def iter_headers(self) -> Iterator[Tuple[str, str]]: yield SENTRY_TRACE_HEADER_NAME, self.to_traceparent() yield BAGGAGE_HEADER_NAME, serialize_trace_state(self.trace_state) - def to_traceparent(self): - # type: () -> str + def to_traceparent(self) -> str: if self.sampled is True: sampled = "1" elif self.sampled is False: @@ -456,24 +403,19 @@ def to_traceparent(self): return traceparent @property - def trace_state(self): - # type: () -> TraceState + def trace_state(self) -> TraceState: return get_trace_state(self._otel_span) - def to_baggage(self): - # type: () -> Baggage + def to_baggage(self) -> Baggage: return self.get_baggage() - def get_baggage(self): - # type: () -> Baggage + def get_baggage(self) -> Baggage: return baggage_from_trace_state(self.trace_state) - def set_tag(self, key, value): - # type: (str, Any) -> None + def set_tag(self, key: str, value: Any) -> None: self.set_attribute(f"{SentrySpanAttribute.TAG}.{key}", value) - def set_data(self, key, value): - # type: (str, Any) -> None + def set_data(self, key: str, value: Any) -> None: warnings.warn( "`Span.set_data` is deprecated. Please use `Span.set_attribute` instead.", DeprecationWarning, @@ -483,8 +425,7 @@ def set_data(self, key, value): # TODO-neel-potel we cannot add dicts here self.set_attribute(key, value) - def get_attribute(self, name): - # type: (str) -> Optional[Any] + def get_attribute(self, name: str) -> Optional[Any]: if ( not isinstance(self._otel_span, ReadableSpan) or not self._otel_span.attributes @@ -492,8 +433,7 @@ def get_attribute(self, name): return None return self._otel_span.attributes.get(name) - def set_attribute(self, key, value): - # type: (str, Any) -> None + def set_attribute(self, key: str, value: Any) -> None: # otel doesn't support None as values, preferring to not set the key # at all instead if value is None: @@ -505,8 +445,7 @@ def set_attribute(self, key, value): self._otel_span.set_attribute(key, serialized_value) @property - def status(self): - # type: () -> Optional[str] + def status(self) -> Optional[str]: """ Return the Sentry `SPANSTATUS` corresponding to the underlying OTel status. 
Because differences in possible values in OTel `StatusCode` and @@ -523,8 +462,7 @@ def status(self): else: return SPANSTATUS.UNKNOWN_ERROR - def set_status(self, status): - # type: (str) -> None + def set_status(self, status: str) -> None: if status == SPANSTATUS.OK: otel_status = StatusCode.OK otel_description = None @@ -537,37 +475,31 @@ def set_status(self, status): else: self._otel_span.set_status(Status(otel_status, otel_description)) - def set_thread(self, thread_id, thread_name): - # type: (Optional[int], Optional[str]) -> None + def set_thread(self, thread_id: Optional[int], thread_name: Optional[str]) -> None: if thread_id is not None: self.set_attribute(SPANDATA.THREAD_ID, str(thread_id)) if thread_name is not None: self.set_attribute(SPANDATA.THREAD_NAME, thread_name) - def update_active_thread(self): - # type: () -> None + def update_active_thread(self) -> None: thread_id, thread_name = get_current_thread_meta() self.set_thread(thread_id, thread_name) - def set_http_status(self, http_status): - # type: (int) -> None + def set_http_status(self, http_status: int) -> None: self.set_attribute(SPANDATA.HTTP_STATUS_CODE, http_status) self.set_status(get_span_status_from_http_code(http_status)) - def is_success(self): - # type: () -> bool + def is_success(self) -> bool: return self.status == SPANSTATUS.OK - def finish(self, end_timestamp=None): - # type: (Optional[Union[float, datetime]]) -> None + def finish(self, end_timestamp: Optional[Union[float, datetime]] = None) -> None: if end_timestamp is not None: self._otel_span.end(convert_to_otel_timestamp(end_timestamp)) else: self._otel_span.end() - def to_json(self): - # type: () -> dict[str, Any] + def to_json(self) -> dict[str, Any]: """ Only meant for testing. Not used internally anymore. """ @@ -575,21 +507,18 @@ def to_json(self): return {} return json.loads(self._otel_span.to_json()) - def get_trace_context(self): - # type: () -> dict[str, Any] + def get_trace_context(self) -> dict[str, Any]: if not isinstance(self._otel_span, ReadableSpan): return {} return get_trace_context(self._otel_span) - def set_context(self, key, value): - # type: (str, Any) -> None + def set_context(self, key: str, value: Any) -> None: # TODO-neel-potel we cannot add dicts here self.set_attribute(f"{SentrySpanAttribute.CONTEXT}.{key}", value) - def set_flag(self, flag, value): - # type: (str, bool) -> None + def set_flag(self, flag: str, value: bool) -> None: flag_count = self.get_attribute("_flag.count") or 0 if flag_count < _FLAGS_CAPACITY: self.set_attribute(f"flag.evaluation.{flag}", value) @@ -603,18 +532,17 @@ def set_flag(self, flag, value): if TYPE_CHECKING: @overload - def trace(func=None): - # type: (None) -> Callable[[Callable[P, R]], Callable[P, R]] + def trace(func: None = None) -> Callable[[Callable[P, R]], Callable[P, R]]: pass @overload - def trace(func): - # type: (Callable[P, R]) -> Callable[P, R] + def trace(func: Callable[P, R]) -> Callable[P, R]: pass -def trace(func=None): - # type: (Optional[Callable[P, R]]) -> Union[Callable[P, R], Callable[[Callable[P, R]], Callable[P, R]]] +def trace( + func: Optional[Callable[P, R]] = None, +) -> Union[Callable[P, R], Callable[[Callable[P, R]], Callable[P, R]]]: """ Decorator to start a child span under the existing current transaction. If there is no current transaction, then nothing will be traced. 
diff --git a/sentry_sdk/tracing_utils.py b/sentry_sdk/tracing_utils.py index 140ce57139..fecb82e09e 100644 --- a/sentry_sdk/tracing_utils.py +++ b/sentry_sdk/tracing_utils.py @@ -1,5 +1,5 @@ +from __future__ import annotations import contextlib -import decimal import inspect import os import re @@ -37,12 +37,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any - from typing import Dict - from typing import Generator - from typing import Optional - from typing import Union from types import FrameType + from typing import Any, Dict, Generator, Optional, Union SENTRY_TRACE_REGEX = re.compile( @@ -69,23 +65,19 @@ class EnvironHeaders(Mapping): # type: ignore def __init__( self, - environ, # type: Mapping[str, str] - prefix="HTTP_", # type: str - ): - # type: (...) -> None + environ: Mapping[str, str], + prefix: str = "HTTP_", + ) -> None: self.environ = environ self.prefix = prefix - def __getitem__(self, key): - # type: (str) -> Optional[Any] + def __getitem__(self, key: str) -> Optional[Any]: return self.environ[self.prefix + key.replace("-", "_").upper()] - def __len__(self): - # type: () -> int + def __len__(self) -> int: return sum(1 for _ in iter(self)) - def __iter__(self): - # type: () -> Generator[str, None, None] + def __iter__(self) -> Generator[str, None, None]: for k in self.environ: if not isinstance(k, str): continue @@ -97,8 +89,7 @@ def __iter__(self): yield k[len(self.prefix) :] -def has_tracing_enabled(options): - # type: (Optional[Dict[str, Any]]) -> bool +def has_tracing_enabled(options: dict[str, Any]) -> bool: """ Returns True if either traces_sample_rate or traces_sampler is defined. @@ -114,16 +105,14 @@ def has_tracing_enabled(options): @contextlib.contextmanager def record_sql_queries( - cursor, # type: Any - query, # type: Any - params_list, # type: Any - paramstyle, # type: Optional[str] - executemany, # type: bool - record_cursor_repr=False, # type: bool - span_origin=None, # type: Optional[str] -): - # type: (...) -> Generator[sentry_sdk.tracing.Span, None, None] - + cursor: Any, + query: Any, + params_list: Any, + paramstyle: Optional[str], + executemany: bool, + record_cursor_repr: bool = False, + span_origin: Optional[str] = None, +) -> Generator[sentry_sdk.tracing.Span, None, None]: # TODO: Bring back capturing of params by default if sentry_sdk.get_client().options["_experiments"].get("record_sql_params", False): if not params_list or params_list == [None]: @@ -161,8 +150,7 @@ def record_sql_queries( yield span -def _get_frame_module_abs_path(frame): - # type: (FrameType) -> Optional[str] +def _get_frame_module_abs_path(frame: FrameType) -> Optional[str]: try: return frame.f_code.co_filename except Exception: @@ -170,14 +158,13 @@ def _get_frame_module_abs_path(frame): def _should_be_included( - is_sentry_sdk_frame, # type: bool - namespace, # type: Optional[str] - in_app_include, # type: Optional[list[str]] - in_app_exclude, # type: Optional[list[str]] - abs_path, # type: Optional[str] - project_root, # type: Optional[str] -): - # type: (...) 
-> bool + is_sentry_sdk_frame: bool, + namespace: Optional[str], + in_app_include: Optional[list[str]], + in_app_exclude: Optional[list[str]], + abs_path: Optional[str], + project_root: Optional[str], +) -> bool: # in_app_include takes precedence over in_app_exclude should_be_included = _module_in_list(namespace, in_app_include) should_be_excluded = _is_external_source(abs_path) or _module_in_list( @@ -189,8 +176,7 @@ def _should_be_included( ) -def add_query_source(span): - # type: (sentry_sdk.tracing.Span) -> None +def add_query_source(span: sentry_sdk.tracing.Span) -> None: """ Adds OTel compatible source code information to the span """ @@ -220,12 +206,12 @@ def add_query_source(span): in_app_exclude = client.options.get("in_app_exclude") # Find the correct frame - frame = sys._getframe() # type: Union[FrameType, None] + frame: Optional[FrameType] = sys._getframe() while frame is not None: abs_path = _get_frame_module_abs_path(frame) try: - namespace = frame.f_globals.get("__name__") # type: Optional[str] + namespace: Optional[str] = frame.f_globals.get("__name__") except Exception: namespace = None @@ -283,8 +269,9 @@ def add_query_source(span): span.set_attribute(SPANDATA.CODE_FUNCTION, frame.f_code.co_name) -def extract_sentrytrace_data(header): - # type: (Optional[str]) -> Optional[Dict[str, Union[str, bool, None]]] +def extract_sentrytrace_data( + header: Optional[str], +) -> Optional[Dict[str, Union[str, bool, None]]]: """ Given a `sentry-trace` header string, return a dictionary of data. """ @@ -315,8 +302,7 @@ def extract_sentrytrace_data(header): } -def _format_sql(cursor, sql): - # type: (Any, str) -> Optional[str] +def _format_sql(cursor: Any, sql: str) -> Optional[str]: real_sql = None @@ -350,13 +336,12 @@ class PropagationContext: def __init__( self, - trace_id=None, # type: Optional[str] - span_id=None, # type: Optional[str] - parent_span_id=None, # type: Optional[str] - parent_sampled=None, # type: Optional[bool] - baggage=None, # type: Optional[Baggage] - ): - # type: (...) 
-> None + trace_id: Optional[str] = None, + span_id: Optional[str] = None, + parent_span_id: Optional[str] = None, + parent_sampled: Optional[bool] = None, + baggage: Optional[Baggage] = None, + ) -> None: self._trace_id = trace_id """The trace id of the Sentry trace.""" @@ -376,13 +361,13 @@ def __init__( """Baggage object used for dynamic sampling decisions.""" @property - def dynamic_sampling_context(self): - # type: () -> Optional[Dict[str, str]] + def dynamic_sampling_context(self) -> Optional[Dict[str, str]]: return self.baggage.dynamic_sampling_context() if self.baggage else None @classmethod - def from_incoming_data(cls, incoming_data): - # type: (Dict[str, Any]) -> Optional[PropagationContext] + def from_incoming_data( + cls, incoming_data: Dict[str, Any] + ) -> Optional[PropagationContext]: propagation_context = None normalized_data = normalize_incoming_data(incoming_data) @@ -405,8 +390,7 @@ def from_incoming_data(cls, incoming_data): return propagation_context @property - def trace_id(self): - # type: () -> str + def trace_id(self) -> str: """The trace id of the Sentry trace.""" if not self._trace_id: self._trace_id = uuid.uuid4().hex @@ -414,13 +398,11 @@ def trace_id(self): return self._trace_id @trace_id.setter - def trace_id(self, value): - # type: (str) -> None + def trace_id(self, value: str) -> None: self._trace_id = value @property - def span_id(self): - # type: () -> str + def span_id(self) -> str: """The span id of the currently executed span.""" if not self._span_id: self._span_id = uuid.uuid4().hex[16:] @@ -428,12 +410,10 @@ def span_id(self): return self._span_id @span_id.setter - def span_id(self, value): - # type: (str) -> None + def span_id(self, value: str) -> None: self._span_id = value - def to_traceparent(self): - # type: () -> str + def to_traceparent(self) -> str: if self.parent_sampled is True: sampled = "1" elif self.parent_sampled is False: @@ -447,8 +427,7 @@ def to_traceparent(self): return traceparent - def update(self, other_dict): - # type: (Dict[str, Any]) -> None + def update(self, other_dict: Dict[str, Any]) -> None: """ Updates the PropagationContext with data from the given dictionary. """ @@ -458,8 +437,7 @@ def update(self, other_dict): except AttributeError: pass - def _fill_sample_rand(self): - # type: () -> None + def _fill_sample_rand(self) -> None: """ Ensure that there is a valid sample_rand value in the baggage. @@ -522,16 +500,14 @@ def _fill_sample_rand(self): self.baggage.sentry_items["sample_rand"] = f"{sample_rand:.6f}" # noqa: E231 - def _sample_rand(self): - # type: () -> Optional[str] + def _sample_rand(self) -> Optional[str]: """Convenience method to get the sample_rand value from the baggage.""" if self.baggage is None: return None return self.baggage.sentry_items.get("sample_rand") - def __repr__(self): - # type: (...) -> str + def __repr__(self) -> str: return "".format( self._trace_id, self._span_id, @@ -558,10 +534,10 @@ class Baggage: def __init__( self, - sentry_items, # type: Dict[str, str] - third_party_items="", # type: str - mutable=True, # type: bool - ): + sentry_items: Dict[str, str], + third_party_items: str = "", + mutable: bool = True, + ) -> None: self.sentry_items = sentry_items self.third_party_items = third_party_items self.mutable = mutable @@ -569,9 +545,8 @@ def __init__( @classmethod def from_incoming_header( cls, - header, # type: Optional[str] - ): - # type: (...) 
-> Baggage + header: Optional[str], + ) -> Baggage: """ freeze if incoming header already has sentry baggage """ @@ -597,10 +572,8 @@ def from_incoming_header( return Baggage(sentry_items, third_party_items, mutable) @classmethod - def from_options(cls, scope): - # type: (sentry_sdk.scope.Scope) -> Optional[Baggage] - - sentry_items = {} # type: Dict[str, str] + def from_options(cls, scope: sentry_sdk.scope.Scope) -> Optional[Baggage]: + sentry_items: Dict[str, str] = {} third_party_items = "" mutable = False @@ -629,12 +602,10 @@ def from_options(cls, scope): return Baggage(sentry_items, third_party_items, mutable) - def freeze(self): - # type: () -> None + def freeze(self) -> None: self.mutable = False - def dynamic_sampling_context(self): - # type: () -> Dict[str, str] + def dynamic_sampling_context(self) -> Dict[str, str]: header = {} for key, item in self.sentry_items.items(): @@ -642,8 +613,7 @@ def dynamic_sampling_context(self): return header - def serialize(self, include_third_party=False): - # type: (bool) -> str + def serialize(self, include_third_party: bool = False) -> str: items = [] for key, val in self.sentry_items.items(): @@ -657,8 +627,7 @@ def serialize(self, include_third_party=False): return ",".join(items) @staticmethod - def strip_sentry_baggage(header): - # type: (str) -> str + def strip_sentry_baggage(header: str) -> str: """Remove Sentry baggage from the given header. Given a Baggage header, return a new Baggage header with all Sentry baggage items removed. @@ -671,13 +640,11 @@ def strip_sentry_baggage(header): ) ) - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return f'' -def should_propagate_trace(client, url): - # type: (sentry_sdk.client.BaseClient, str) -> bool +def should_propagate_trace(client: sentry_sdk.client.BaseClient, url: str) -> bool: """ Returns True if url matches trace_propagation_targets configured in the given client. Otherwise, returns False. """ @@ -689,8 +656,7 @@ def should_propagate_trace(client, url): return match_regex_list(url, trace_propagation_targets, substring_matching=True) -def normalize_incoming_data(incoming_data): - # type: (Dict[str, Any]) -> Dict[str, Any] +def normalize_incoming_data(incoming_data: Dict[str, Any]) -> Dict[str, Any]: """ Normalizes incoming data so the keys are all lowercase with dashes instead of underscores and stripped from known prefixes. """ @@ -705,8 +671,7 @@ def normalize_incoming_data(incoming_data): return data -def start_child_span_decorator(func): - # type: (Any) -> Any +def start_child_span_decorator(func: Any) -> Any: """ Decorator to add child spans for functions. 
@@ -716,9 +681,7 @@ def start_child_span_decorator(func): if inspect.iscoroutinefunction(func): @wraps(func) - async def func_with_tracing(*args, **kwargs): - # type: (*Any, **Any) -> Any - + async def func_with_tracing(*args: Any, **kwargs: Any) -> Any: span = get_current_span() if span is None: @@ -737,7 +700,9 @@ async def func_with_tracing(*args, **kwargs): return await func(*args, **kwargs) try: - func_with_tracing.__signature__ = inspect.signature(func) # type: ignore[attr-defined] + func_with_tracing.__signature__ = inspect.signature( # type: ignore[attr-defined] + func + ) except Exception: pass @@ -745,9 +710,7 @@ async def func_with_tracing(*args, **kwargs): else: @wraps(func) - def func_with_tracing(*args, **kwargs): - # type: (*Any, **Any) -> Any - + def func_with_tracing(*args: Any, **kwargs: Any) -> Any: span = get_current_span() if span is None: @@ -766,15 +729,18 @@ def func_with_tracing(*args, **kwargs): return func(*args, **kwargs) try: - func_with_tracing.__signature__ = inspect.signature(func) # type: ignore[attr-defined] + func_with_tracing.__signature__ = inspect.signature( # type: ignore[attr-defined] + func + ) except Exception: pass return func_with_tracing -def get_current_span(scope=None): - # type: (Optional[sentry_sdk.Scope]) -> Optional[sentry_sdk.tracing.Span] +def get_current_span( + scope: Optional[sentry_sdk.scope.Scope] = None, +) -> Optional[sentry_sdk.tracing.Span]: """ Returns the currently active span if there is one running, otherwise `None` """ @@ -784,10 +750,9 @@ def get_current_span(scope=None): def _generate_sample_rand( - trace_id, # type: Optional[str] - interval=(0.0, 1.0), # type: tuple[float, float] -): - # type: (...) -> Optional[decimal.Decimal] + trace_id: Optional[str], + interval: tuple[float, float] = (0.0, 1.0), +) -> Decimal: """Generate a sample_rand value from a trace ID. The generated value will be pseudorandomly chosen from the provided @@ -817,8 +782,9 @@ def _generate_sample_rand( ) -def _sample_rand_range(parent_sampled, sample_rate): - # type: (Optional[bool], Optional[float]) -> tuple[float, float] +def _sample_rand_range( + parent_sampled: Optional[bool], sample_rate: Optional[float] +) -> tuple[float, float]: """ Compute the lower (inclusive) and upper (exclusive) bounds of the range of values that a generated sample_rand value must fall into, given the parent_sampled and @@ -832,8 +798,7 @@ def _sample_rand_range(parent_sampled, sample_rate): return sample_rate, 1.0 -def get_span_status_from_http_code(http_status_code): - # type: (int) -> str +def get_span_status_from_http_code(http_status_code: int) -> str: """ Returns the Sentry status corresponding to the given HTTP status code. 
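Illustrative sketch, not part of the patch itself: exercising the `Baggage` helpers whose signatures are annotated above; the header value below is invented for the example.

from sentry_sdk.tracing_utils import Baggage

header = "sentry-trace_id=771a43a4192642f0b136d5159a501700,other-vendor=value"
baggage = Baggage.from_incoming_header(header)  # frozen, since incoming Sentry items are present
print(baggage.sentry_items["trace_id"])         # "771a43a4192642f0b136d5159a501700"
print(Baggage.strip_sentry_baggage(header))     # "other-vendor=value"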
diff --git a/sentry_sdk/transport.py b/sentry_sdk/transport.py index ec48f49be4..fddd04bccc 100644 --- a/sentry_sdk/transport.py +++ b/sentry_sdk/transport.py @@ -1,3 +1,4 @@ +from __future__ import annotations from abc import ABC, abstractmethod import io import os @@ -14,6 +15,14 @@ except ImportError: brotli = None +try: + import httpcore + import h2 # noqa: F401 + + HTTP2_ENABLED = True +except ImportError: + HTTP2_ENABLED = False + import urllib3 import certifi @@ -22,20 +31,23 @@ from sentry_sdk.worker import BackgroundWorker from sentry_sdk.envelope import Envelope, Item, PayloadRef -from typing import TYPE_CHECKING, cast, List, Dict +from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any - from typing import Callable - from typing import DefaultDict - from typing import Iterable - from typing import Mapping - from typing import Optional - from typing import Self - from typing import Tuple - from typing import Type - from typing import Union - + from typing import ( + List, + Dict, + Any, + Callable, + DefaultDict, + Iterable, + Mapping, + Optional, + Tuple, + Type, + Union, + Self, + ) from urllib3.poolmanager import PoolManager from urllib3.poolmanager import ProxyManager @@ -62,10 +74,9 @@ class Transport(ABC): A transport is used to send an event to sentry. """ - parsed_dsn = None # type: Optional[Dsn] + parsed_dsn: Optional[Dsn] = None - def __init__(self, options=None): - # type: (Self, Optional[Dict[str, Any]]) -> None + def __init__(self: Self, options: Optional[Dict[str, Any]] = None) -> None: self.options = options if options and options["dsn"] is not None and options["dsn"]: self.parsed_dsn = Dsn(options["dsn"]) @@ -73,8 +84,7 @@ def __init__(self, options=None): self.parsed_dsn = None @abstractmethod - def capture_envelope(self, envelope): - # type: (Self, Envelope) -> None + def capture_envelope(self: Self, envelope: Envelope) -> None: """ Send an envelope to Sentry. @@ -85,11 +95,10 @@ def capture_envelope(self, envelope): pass def flush( - self, - timeout, - callback=None, - ): - # type: (Self, float, Optional[Any]) -> None + self: Self, + timeout: float, + callback: Optional[Any] = None, + ) -> None: """ Wait `timeout` seconds for the current events to be sent out. @@ -98,8 +107,7 @@ def flush( """ return None - def kill(self): - # type: (Self) -> None + def kill(self: Self) -> None: """ Forcefully kills the transport. @@ -109,14 +117,13 @@ def kill(self): return None def record_lost_event( - self, - reason, # type: str - data_category=None, # type: Optional[EventDataCategory] - item=None, # type: Optional[Item] + self: Self, + reason: str, + data_category: Optional[EventDataCategory] = None, + item: Optional[Item] = None, *, - quantity=1, # type: int - ): - # type: (...) -> None + quantity: int = 1, + ) -> None: """This increments a counter for event loss by reason and data category by the given positive-int quantity (default 1). 
@@ -133,20 +140,19 @@ def record_lost_event( """ return None - def is_healthy(self): - # type: (Self) -> bool + def is_healthy(self: Self) -> bool: return True - def __del__(self): - # type: (Self) -> None + def __del__(self: Self) -> None: try: self.kill() except Exception: pass -def _parse_rate_limits(header, now=None): - # type: (str, Optional[datetime]) -> Iterable[Tuple[Optional[EventDataCategory], datetime]] +def _parse_rate_limits( + header: str, now: Optional[datetime] = None +) -> Iterable[Tuple[Optional[str], datetime]]: if now is None: now = datetime.now(timezone.utc) @@ -157,7 +163,6 @@ def _parse_rate_limits(header, now=None): retry_after = now + timedelta(seconds=int(retry_after_val)) for category in categories and categories.split(";") or (None,): - category = cast("Optional[EventDataCategory]", category) yield category, retry_after except (LookupError, ValueError): continue @@ -168,21 +173,20 @@ class BaseHttpTransport(Transport): TIMEOUT = 30 # seconds - def __init__(self, options): - # type: (Self, Dict[str, Any]) -> None + def __init__(self: Self, options: Dict[str, Any]) -> None: from sentry_sdk.consts import VERSION Transport.__init__(self, options) assert self.parsed_dsn is not None - self.options = options # type: Dict[str, Any] + self.options: Dict[str, Any] = options self._worker = BackgroundWorker(queue_size=options["transport_queue_size"]) self._auth = self.parsed_dsn.to_auth("sentry.python/%s" % VERSION) - self._disabled_until = {} # type: Dict[Optional[EventDataCategory], datetime] + self._disabled_until: Dict[Optional[str], datetime] = {} # We only use this Retry() class for the `get_retry_after` method it exposes self._retry = urllib3.util.Retry() - self._discarded_events = defaultdict( - int - ) # type: DefaultDict[Tuple[EventDataCategory, str], int] + self._discarded_events: DefaultDict[Tuple[EventDataCategory, str], int] = ( + defaultdict(int) + ) self._last_client_report_sent = time.time() self._pool = self._make_pool() @@ -227,14 +231,13 @@ def __init__(self, options): self._compression_level = 4 def record_lost_event( - self, - reason, # type: str - data_category=None, # type: Optional[EventDataCategory] - item=None, # type: Optional[Item] + self: Self, + reason: str, + data_category: Optional[EventDataCategory] = None, + item: Optional[Item] = None, *, - quantity=1, # type: int - ): - # type: (...) -> None + quantity: int = 1, + ) -> None: if not self.options["send_client_reports"]: return @@ -247,9 +250,7 @@ def record_lost_event( event = item.get_transaction_event() or {} # +1 for the transaction itself - span_count = ( - len(cast(List[Dict[str, object]], event.get("spans") or [])) + 1 - ) + span_count = len(event.get("spans") or []) + 1 self.record_lost_event(reason, "span", quantity=span_count) elif data_category == "attachment": @@ -262,12 +263,12 @@ def record_lost_event( self._discarded_events[data_category, reason] += quantity - def _get_header_value(self, response, header): - # type: (Self, Any, str) -> Optional[str] + def _get_header_value(self: Self, response: Any, header: str) -> Optional[str]: return response.headers.get(header) - def _update_rate_limits(self, response): - # type: (Self, Union[urllib3.BaseHTTPResponse, httpcore.Response]) -> None + def _update_rate_limits( + self: Self, response: Union[urllib3.BaseHTTPResponse, httpcore.Response] + ) -> None: # new sentries with more rate limit insights. We honor this header # no matter of the status code to update our internal rate limits. 
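# --- Editor's note: illustrative sketch, not part of the patch. How the rate-limit
# parser above is consumed. The header format assumed here is comma-separated
# "<retry_after_seconds>:<category;category>:<scope>" entries; an empty category
# list rate-limits every category, which the parser yields as None.
from datetime import datetime, timezone

from sentry_sdk.transport import _parse_rate_limits

now = datetime(2024, 1, 1, tzinfo=timezone.utc)
header = "60:error;transaction:organization, 2700::organization"

for category, retry_after in _parse_rate_limits(header, now=now):
    print(category, retry_after)
# error 2024-01-01 00:01:00+00:00
# transaction 2024-01-01 00:01:00+00:00
# None 2024-01-01 00:45:00+00:00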
@@ -292,16 +293,13 @@ def _update_rate_limits(self, response): ) def _send_request( - self, - body, - headers, - endpoint_type=EndpointType.ENVELOPE, - envelope=None, - ): - # type: (Self, bytes, Dict[str, str], EndpointType, Optional[Envelope]) -> None - - def record_loss(reason): - # type: (str) -> None + self: Self, + body: bytes, + headers: Dict[str, str], + endpoint_type: EndpointType = EndpointType.ENVELOPE, + envelope: Optional[Envelope] = None, + ) -> None: + def record_loss(reason: str) -> None: if envelope is None: self.record_lost_event(reason, data_category="error") else: @@ -348,12 +346,12 @@ def record_loss(reason): finally: response.close() - def on_dropped_event(self, _reason): - # type: (Self, str) -> None + def on_dropped_event(self: Self, _reason: str) -> None: return None - def _fetch_pending_client_report(self, force=False, interval=60): - # type: (Self, bool, int) -> Optional[Item] + def _fetch_pending_client_report( + self: Self, force: bool = False, interval: int = 60 + ) -> Optional[Item]: if not self.options["send_client_reports"]: return None @@ -383,37 +381,30 @@ def _fetch_pending_client_report(self, force=False, interval=60): type="client_report", ) - def _flush_client_reports(self, force=False): - # type: (Self, bool) -> None + def _flush_client_reports(self: Self, force: bool = False) -> None: client_report = self._fetch_pending_client_report(force=force, interval=60) if client_report is not None: self.capture_envelope(Envelope(items=[client_report])) - def _check_disabled(self, category): - # type: (str) -> bool - def _disabled(bucket): - # type: (Any) -> bool + def _check_disabled(self: Self, category: EventDataCategory) -> bool: + def _disabled(bucket: Optional[EventDataCategory]) -> bool: ts = self._disabled_until.get(bucket) return ts is not None and ts > datetime.now(timezone.utc) return _disabled(category) or _disabled(None) - def _is_rate_limited(self): - # type: (Self) -> bool + def _is_rate_limited(self: Self) -> bool: return any( ts > datetime.now(timezone.utc) for ts in self._disabled_until.values() ) - def _is_worker_full(self): - # type: (Self) -> bool + def _is_worker_full(self: Self) -> bool: return self._worker.full() - def is_healthy(self): - # type: (Self) -> bool + def is_healthy(self: Self) -> bool: return not (self._is_worker_full() or self._is_rate_limited()) - def _send_envelope(self, envelope): - # type: (Self, Envelope) -> None + def _send_envelope(self: Self, envelope: Envelope) -> None: # remove all items from the envelope which are over quota new_items = [] @@ -465,8 +456,9 @@ def _send_envelope(self, envelope): ) return None - def _serialize_envelope(self, envelope): - # type: (Self, Envelope) -> tuple[Optional[str], io.BytesIO] + def _serialize_envelope( + self: Self, envelope: Envelope + ) -> tuple[Optional[str], io.BytesIO]: content_encoding = None body = io.BytesIO() if self._compression_level == 0 or self._compression_algo is None: @@ -487,12 +479,10 @@ def _serialize_envelope(self, envelope): return content_encoding, body - def _get_pool_options(self): - # type: (Self) -> Dict[str, Any] + def _get_pool_options(self: Self) -> Dict[str, Any]: raise NotImplementedError() - def _in_no_proxy(self, parsed_dsn): - # type: (Self, Dsn) -> bool + def _in_no_proxy(self: Self, parsed_dsn: Dsn) -> bool: no_proxy = getproxies().get("no") if not no_proxy: return False @@ -502,26 +492,28 @@ def _in_no_proxy(self, parsed_dsn): return True return False - def _make_pool(self): - # type: (Self) -> Union[PoolManager, ProxyManager, 
httpcore.SOCKSProxy, httpcore.HTTPProxy, httpcore.ConnectionPool] + def _make_pool( + self: Self, + ) -> Union[ + PoolManager, + ProxyManager, + httpcore.SOCKSProxy, + httpcore.HTTPProxy, + httpcore.ConnectionPool, + ]: raise NotImplementedError() def _request( - self, - method, - endpoint_type, - body, - headers, - ): - # type: (Self, str, EndpointType, Any, Mapping[str, str]) -> Union[urllib3.BaseHTTPResponse, httpcore.Response] + self: Self, + method: str, + endpoint_type: EndpointType, + body: Any, + headers: Mapping[str, str], + ) -> Union[urllib3.BaseHTTPResponse, httpcore.Response]: raise NotImplementedError() - def capture_envelope( - self, envelope # type: Envelope - ): - # type: (...) -> None - def send_envelope_wrapper(): - # type: () -> None + def capture_envelope(self: Self, envelope: Envelope) -> None: + def send_envelope_wrapper() -> None: with capture_internal_exceptions(): self._send_envelope(envelope) self._flush_client_reports() @@ -532,19 +524,17 @@ def send_envelope_wrapper(): self.record_lost_event("queue_overflow", item=item) def flush( - self, - timeout, - callback=None, - ): - # type: (Self, float, Optional[Callable[[int, float], None]]) -> None + self: Self, + timeout: float, + callback: Optional[Callable[[int, float], None]] = None, + ) -> None: logger.debug("Flushing HTTP transport") if timeout > 0: self._worker.submit(lambda: self._flush_client_reports(force=True)) self._worker.flush(timeout, callback) - def kill(self): - # type: (Self) -> None + def kill(self: Self) -> None: logger.debug("Killing HTTP transport") self._worker.kill() @@ -553,8 +543,7 @@ class HttpTransport(BaseHttpTransport): if TYPE_CHECKING: _pool: Union[PoolManager, ProxyManager] - def _get_pool_options(self): - # type: (Self) -> Dict[str, Any] + def _get_pool_options(self: Self) -> Dict[str, Any]: num_pools = self.options.get("_experiments", {}).get("transport_num_pools") options = { @@ -563,7 +552,7 @@ def _get_pool_options(self): "timeout": urllib3.Timeout(total=self.TIMEOUT), } - socket_options = None # type: Optional[List[Tuple[int, int, int | bytes]]] + socket_options: Optional[List[Tuple[int, int, int | bytes]]] = None if self.options["socket_options"] is not None: socket_options = self.options["socket_options"] @@ -596,8 +585,7 @@ def _get_pool_options(self): return options - def _make_pool(self): - # type: (Self) -> Union[PoolManager, ProxyManager] + def _make_pool(self: Self) -> Union[PoolManager, ProxyManager]: if self.parsed_dsn is None: raise ValueError("Cannot create HTTP-based transport without valid DSN") @@ -643,13 +631,12 @@ def _make_pool(self): return urllib3.PoolManager(**opts) def _request( - self, - method, - endpoint_type, - body, - headers, - ): - # type: (Self, str, EndpointType, Any, Mapping[str, str]) -> urllib3.BaseHTTPResponse + self: Self, + method: str, + endpoint_type: EndpointType, + body: Any, + headers: Mapping[str, str], + ) -> urllib3.BaseHTTPResponse: return self._pool.request( method, self._auth.get_api_url(endpoint_type), @@ -658,14 +645,10 @@ def _request( ) -try: - import httpcore - import h2 # noqa: F401 -except ImportError: +if not HTTP2_ENABLED: # Sorry, no Http2Transport for you class Http2Transport(HttpTransport): - def __init__(self, options): - # type: (Self, Dict[str, Any]) -> None + def __init__(self: Self, options: Dict[str, Any]) -> None: super().__init__(options) logger.warning( "You tried to use HTTP2Transport but don't have httpcore[http2] installed. Falling back to HTTPTransport." 
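# --- Editor's note: illustrative sketch, not part of the patch. The same
# optional-dependency guard this diff moves to the top of transport.py, plus a
# hypothetical opt-in using the experimental flag read by make_transport() below.
try:
    import httpcore  # noqa: F401
    import h2  # noqa: F401

    HTTP2_ENABLED = True
except ImportError:
    HTTP2_ENABLED = False

# Hypothetical usage (flag name taken from make_transport; DSN is a placeholder):
# import sentry_sdk
# sentry_sdk.init(
#     dsn="https://examplePublicKey@o0.ingest.sentry.io/0",
#     _experiments={"transport_http2": True},
# )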
@@ -683,8 +666,7 @@ class Http2Transport(BaseHttpTransport): # type: ignore httpcore.SOCKSProxy, httpcore.HTTPProxy, httpcore.ConnectionPool ] - def _get_header_value(self, response, header): - # type: (Self, httpcore.Response, str) -> Optional[str] + def _get_header_value(self: Self, response: Any, header: str) -> Optional[str]: return next( ( val.decode("ascii") @@ -695,13 +677,12 @@ def _get_header_value(self, response, header): ) def _request( - self, - method, - endpoint_type, - body, - headers, - ): - # type: (Self, str, EndpointType, Any, Mapping[str, str]) -> httpcore.Response + self: Self, + method: str, + endpoint_type: EndpointType, + body: Any, + headers: Mapping[str, str], + ) -> httpcore.Response: response = self._pool.request( method, self._auth.get_api_url(endpoint_type), @@ -718,13 +699,12 @@ def _request( ) return response - def _get_pool_options(self): - # type: (Self) -> Dict[str, Any] - options = { + def _get_pool_options(self: Self) -> Dict[str, Any]: + options: Dict[str, Any] = { "http2": self.parsed_dsn is not None and self.parsed_dsn.scheme == "https", "retries": 3, - } # type: Dict[str, Any] + } socket_options = ( self.options["socket_options"] @@ -755,8 +735,9 @@ def _get_pool_options(self): return options - def _make_pool(self): - # type: (Self) -> Union[httpcore.SOCKSProxy, httpcore.HTTPProxy, httpcore.ConnectionPool] + def _make_pool( + self: Self, + ) -> Union[httpcore.SOCKSProxy, httpcore.HTTPProxy, httpcore.ConnectionPool]: if self.parsed_dsn is None: raise ValueError("Cannot create HTTP-based transport without valid DSN") proxy = None @@ -799,16 +780,15 @@ def _make_pool(self): return httpcore.ConnectionPool(**opts) -def make_transport(options): - # type: (Dict[str, Any]) -> Optional[Transport] +def make_transport(options: Dict[str, Any]) -> Optional[Transport]: ref_transport = options["transport"] use_http2_transport = options.get("_experiments", {}).get("transport_http2", False) # By default, we use the http transport class - transport_cls = ( + transport_cls: Type[Transport] = ( Http2Transport if use_http2_transport else HttpTransport - ) # type: Type[Transport] + ) if isinstance(ref_transport, Transport): return ref_transport diff --git a/sentry_sdk/utils.py b/sentry_sdk/utils.py index 1420f41501..746d1eae54 100644 --- a/sentry_sdk/utils.py +++ b/sentry_sdk/utils.py @@ -1,3 +1,4 @@ +from __future__ import annotations import base64 import json import linecache @@ -34,21 +35,19 @@ ) from sentry_sdk._types import Annotated, AnnotatedValue, SENSITIVE_DATA_SUBSTITUTE -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, overload if TYPE_CHECKING: from types import FrameType, TracebackType from typing import ( Any, Callable, - cast, ContextManager, Dict, Iterator, List, NoReturn, Optional, - overload, ParamSpec, Set, Tuple, @@ -95,8 +94,7 @@ """ -def env_to_bool(value, *, strict=False): - # type: (Any, Optional[bool]) -> bool | None +def env_to_bool(value: Any, *, strict: Optional[bool] = False) -> Optional[bool]: """Casts an ENV variable value to boolean using the constants defined above. In strict mode, it may return None if the value doesn't match any of the predefined values. 
""" @@ -111,14 +109,12 @@ def env_to_bool(value, *, strict=False): return None if strict else bool(value) -def json_dumps(data): - # type: (Any) -> bytes +def json_dumps(data: Any) -> bytes: """Serialize data into a compact JSON representation encoded as UTF-8.""" return json.dumps(data, allow_nan=False, separators=(",", ":")).encode("utf-8") -def get_git_revision(): - # type: () -> Optional[str] +def get_git_revision() -> Optional[str]: try: with open(os.path.devnull, "w+") as null: # prevent command prompt windows from popping up on windows @@ -145,8 +141,7 @@ def get_git_revision(): return revision -def get_default_release(): - # type: () -> Optional[str] +def get_default_release() -> Optional[str]: """Try to guess a default release.""" release = os.environ.get("SENTRY_RELEASE") if release: @@ -169,8 +164,7 @@ def get_default_release(): return None -def get_sdk_name(installed_integrations): - # type: (List[str]) -> str +def get_sdk_name(installed_integrations: List[str]) -> str: """Return the SDK name including the name of the used web framework.""" # Note: I can not use for example sentry_sdk.integrations.django.DjangoIntegration.identifier @@ -208,12 +202,15 @@ def get_sdk_name(installed_integrations): class CaptureInternalException: __slots__ = () - def __enter__(self): - # type: () -> ContextManager[Any] + def __enter__(self) -> ContextManager[Any]: return self - def __exit__(self, ty, value, tb): - # type: (Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]) -> bool + def __exit__( + self, + ty: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> bool: if ty is not None and value is not None: capture_internal_exception((ty, value, tb)) @@ -223,13 +220,11 @@ def __exit__(self, ty, value, tb): _CAPTURE_INTERNAL_EXCEPTION = CaptureInternalException() -def capture_internal_exceptions(): - # type: () -> ContextManager[Any] +def capture_internal_exceptions() -> ContextManager[Any]: return _CAPTURE_INTERNAL_EXCEPTION -def capture_internal_exception(exc_info): - # type: (ExcInfo) -> None +def capture_internal_exception(exc_info: ExcInfo) -> None: """ Capture an exception that is likely caused by a bug in the SDK itself. @@ -240,13 +235,11 @@ def capture_internal_exception(exc_info): logger.error("Internal error in sentry_sdk", exc_info=exc_info) -def to_timestamp(value): - # type: (datetime) -> float +def to_timestamp(value: datetime) -> float: return (value - epoch).total_seconds() -def format_timestamp(value): - # type: (datetime) -> str +def format_timestamp(value: datetime) -> str: """Formats a timestamp in RFC 3339 format. Any datetime objects with a non-UTC timezone are converted to UTC, so that all timestamps are formatted in UTC. 
@@ -258,8 +251,9 @@ def format_timestamp(value): return utctime.strftime("%Y-%m-%dT%H:%M:%S.%fZ") -def event_hint_with_exc_info(exc_info=None): - # type: (Optional[ExcInfo]) -> Dict[str, Optional[ExcInfo]] +def event_hint_with_exc_info( + exc_info: Optional[ExcInfo] = None, +) -> Dict[str, Optional[ExcInfo]]: """Creates a hint with the exc info filled in.""" if exc_info is None: exc_info = sys.exc_info() @@ -277,8 +271,7 @@ class BadDsn(ValueError): class Dsn: """Represents a DSN.""" - def __init__(self, value): - # type: (Union[Dsn, str]) -> None + def __init__(self, value: Union[Dsn, str]) -> None: if isinstance(value, Dsn): self.__dict__ = dict(value.__dict__) return @@ -294,7 +287,7 @@ def __init__(self, value): self.host = parts.hostname if parts.port is None: - self.port = self.scheme == "https" and 443 or 80 # type: int + self.port: int = self.scheme == "https" and 443 or 80 else: self.port = parts.port @@ -314,16 +307,14 @@ def __init__(self, value): self.path = "/".join(path) + "/" @property - def netloc(self): - # type: () -> str + def netloc(self) -> str: """The netloc part of a DSN.""" rv = self.host if (self.scheme, self.port) not in (("http", 80), ("https", 443)): rv = "%s:%s" % (rv, self.port) return rv - def to_auth(self, client=None): - # type: (Optional[Any]) -> Auth + def to_auth(self, client: Optional[Any] = None) -> Auth: """Returns the auth info object for this dsn.""" return Auth( scheme=self.scheme, @@ -335,8 +326,7 @@ def to_auth(self, client=None): client=client, ) - def __str__(self): - # type: () -> str + def __str__(self) -> str: return "%s://%s%s@%s%s%s" % ( self.scheme, self.public_key, @@ -352,16 +342,15 @@ class Auth: def __init__( self, - scheme, - host, - project_id, - public_key, - secret_key=None, - version=7, - client=None, - path="/", - ): - # type: (str, str, str, str, Optional[str], int, Optional[Any], str) -> None + scheme: str, + host: str, + project_id: str, + public_key: str, + secret_key: Optional[str] = None, + version: int = 7, + client: Optional[Any] = None, + path: str = "/", + ) -> None: self.scheme = scheme self.host = host self.path = path @@ -371,10 +360,7 @@ def __init__( self.version = version self.client = client - def get_api_url( - self, type=EndpointType.ENVELOPE # type: EndpointType - ): - # type: (...) 
-> str + def get_api_url(self, type: EndpointType = EndpointType.ENVELOPE) -> str: """Returns the API url for storing events.""" return "%s://%s%sapi/%s/%s/" % ( self.scheme, @@ -384,8 +370,7 @@ def get_api_url( type.value, ) - def to_header(self): - # type: () -> str + def to_header(self) -> str: """Returns the auth header a string.""" rv = [("sentry_key", self.public_key), ("sentry_version", self.version)] if self.client is not None: @@ -395,21 +380,18 @@ def to_header(self): return "Sentry " + ", ".join("%s=%s" % (key, value) for key, value in rv) -def get_type_name(cls): - # type: (Optional[type]) -> Optional[str] +def get_type_name(cls: Optional[type]) -> Optional[str]: return getattr(cls, "__qualname__", None) or getattr(cls, "__name__", None) -def get_type_module(cls): - # type: (Optional[type]) -> Optional[str] +def get_type_module(cls: Optional[type]) -> Optional[str]: mod = getattr(cls, "__module__", None) if mod not in (None, "builtins", "__builtins__"): return mod return None -def should_hide_frame(frame): - # type: (FrameType) -> bool +def should_hide_frame(frame: FrameType) -> bool: try: mod = frame.f_globals["__name__"] if mod.startswith("sentry_sdk."): @@ -427,9 +409,8 @@ def should_hide_frame(frame): return False -def iter_stacks(tb): - # type: (Optional[TracebackType]) -> Iterator[TracebackType] - tb_ = tb # type: Optional[TracebackType] +def iter_stacks(tb: Optional[TracebackType]) -> Iterator[TracebackType]: + tb_: Optional[TracebackType] = tb while tb_ is not None: if not should_hide_frame(tb_.tb_frame): yield tb_ @@ -437,18 +418,17 @@ def iter_stacks(tb): def get_lines_from_file( - filename, # type: str - lineno, # type: int - max_length=None, # type: Optional[int] - loader=None, # type: Optional[Any] - module=None, # type: Optional[str] -): - # type: (...) -> Tuple[List[Annotated[str]], Optional[Annotated[str]], List[Annotated[str]]] + filename: str, + lineno: int, + max_length: Optional[int] = None, + loader: Optional[Any] = None, + module: Optional[str] = None, +) -> Tuple[List[Annotated[str]], Optional[Annotated[str]], List[Annotated[str]]]: context_lines = 5 source = None if loader is not None and hasattr(loader, "get_source"): try: - source_str = loader.get_source(module) # type: Optional[str] + source_str: Optional[str] = loader.get_source(module) except (ImportError, IOError): source_str = None if source_str is not None: @@ -483,13 +463,12 @@ def get_lines_from_file( def get_source_context( - frame, # type: FrameType - tb_lineno, # type: Optional[int] - max_value_length=None, # type: Optional[int] -): - # type: (...) 
-> Tuple[List[Annotated[str]], Optional[Annotated[str]], List[Annotated[str]]] + frame: FrameType, + tb_lineno: Optional[int], + max_value_length: Optional[int] = None, +) -> Tuple[List[Annotated[str]], Optional[Annotated[str]], List[Annotated[str]]]: try: - abs_path = frame.f_code.co_filename # type: Optional[str] + abs_path: Optional[str] = frame.f_code.co_filename except Exception: abs_path = None try: @@ -510,24 +489,23 @@ def get_source_context( return [], None, [] -def safe_str(value): - # type: (Any) -> str +def safe_str(value: Any) -> str: try: return str(value) except Exception: return safe_repr(value) -def safe_repr(value): - # type: (Any) -> str +def safe_repr(value: Any) -> str: try: return repr(value) except Exception: return "" -def filename_for_module(module, abs_path): - # type: (Optional[str], Optional[str]) -> Optional[str] +def filename_for_module( + module: Optional[str], abs_path: Optional[str] +) -> Optional[str]: if not abs_path or not module: return abs_path @@ -551,14 +529,13 @@ def filename_for_module(module, abs_path): def serialize_frame( - frame, - tb_lineno=None, - include_local_variables=True, - include_source_context=True, - max_value_length=None, - custom_repr=None, -): - # type: (FrameType, Optional[int], bool, bool, Optional[int], Optional[Callable[..., Optional[str]]]) -> Dict[str, Any] + frame: FrameType, + tb_lineno: Optional[int] = None, + include_local_variables: bool = True, + include_source_context: bool = True, + max_value_length: Optional[int] = None, + custom_repr: Optional[Callable[..., Optional[str]]] = None, +) -> Dict[str, Any]: f_code = getattr(frame, "f_code", None) if not f_code: abs_path = None @@ -574,13 +551,13 @@ def serialize_frame( if tb_lineno is None: tb_lineno = frame.f_lineno - rv = { + rv: Dict[str, Any] = { "filename": filename_for_module(module, abs_path) or None, "abs_path": os.path.abspath(abs_path) if abs_path else None, "function": function or "", "module": module, "lineno": tb_lineno, - } # type: Dict[str, Any] + } if include_source_context: rv["pre_context"], rv["context_line"], rv["post_context"] = get_source_context( @@ -598,15 +575,14 @@ def serialize_frame( def current_stacktrace( - include_local_variables=True, # type: bool - include_source_context=True, # type: bool - max_value_length=None, # type: Optional[int] -): - # type: (...) -> Dict[str, Any] + include_local_variables: bool = True, + include_source_context: bool = True, + max_value_length: Optional[int] = None, +) -> Dict[str, Any]: __tracebackhide__ = True frames = [] - f = sys._getframe() # type: Optional[FrameType] + f: Optional[FrameType] = sys._getframe() while f is not None: if not should_hide_frame(f): frames.append( @@ -624,24 +600,22 @@ def current_stacktrace( return {"frames": frames} -def get_errno(exc_value): - # type: (BaseException) -> Optional[Any] +def get_errno(exc_value: BaseException) -> Optional[Any]: return getattr(exc_value, "errno", None) -def get_error_message(exc_value): - # type: (Optional[BaseException]) -> str - message = ( +def get_error_message(exc_value: Optional[BaseException]) -> str: + message: str = ( getattr(exc_value, "message", "") or getattr(exc_value, "detail", "") or safe_str(exc_value) - ) # type: str + ) # __notes__ should be a list of strings when notes are added # via add_note, but can be anything else if __notes__ is set # directly. We only support strings in __notes__, since that # is the correct use. 
- notes = getattr(exc_value, "__notes__", None) # type: object + notes: object = getattr(exc_value, "__notes__", None) if isinstance(notes, list) and len(notes) > 0: message += "\n" + "\n".join(note for note in notes if isinstance(note, str)) @@ -649,24 +623,23 @@ def get_error_message(exc_value): def single_exception_from_error_tuple( - exc_type, # type: Optional[type] - exc_value, # type: Optional[BaseException] - tb, # type: Optional[TracebackType] - client_options=None, # type: Optional[Dict[str, Any]] - mechanism=None, # type: Optional[Dict[str, Any]] - exception_id=None, # type: Optional[int] - parent_id=None, # type: Optional[int] - source=None, # type: Optional[str] - full_stack=None, # type: Optional[list[dict[str, Any]]] -): - # type: (...) -> Dict[str, Any] + exc_type: Optional[type], + exc_value: Optional[BaseException], + tb: Optional[TracebackType], + client_options: Optional[Dict[str, Any]] = None, + mechanism: Optional[Dict[str, Any]] = None, + exception_id: Optional[int] = None, + parent_id: Optional[int] = None, + source: Optional[str] = None, + full_stack: Optional[list[dict[str, Any]]] = None, +) -> Dict[str, Any]: """ Creates a dict that goes into the events `exception.values` list and is ingestible by Sentry. See the Exception Interface documentation for more details: https://develop.sentry.dev/sdk/event-payloads/exception/ """ - exception_value = {} # type: Dict[str, Any] + exception_value: Dict[str, Any] = {} exception_value["mechanism"] = ( mechanism.copy() if mechanism else {"type": "generic", "handled": True} ) @@ -715,7 +688,7 @@ def single_exception_from_error_tuple( max_value_length = client_options["max_value_length"] custom_repr = client_options.get("custom_repr") - frames = [ + frames: List[Dict[str, Any]] = [ serialize_frame( tb.tb_frame, tb_lineno=tb.tb_lineno, @@ -727,7 +700,7 @@ def single_exception_from_error_tuple( # Process at most MAX_STACK_FRAMES + 1 frames, to avoid hanging on # processing a super-long stacktrace. for tb, _ in zip(iter_stacks(tb), range(MAX_STACK_FRAMES + 1)) - ] # type: List[Dict[str, Any]] + ] if len(frames) > MAX_STACK_FRAMES: # If we have more frames than the limit, we remove the stacktrace completely. @@ -755,12 +728,11 @@ def single_exception_from_error_tuple( if HAS_CHAINED_EXCEPTIONS: - def walk_exception_chain(exc_info): - # type: (ExcInfo) -> Iterator[ExcInfo] + def walk_exception_chain(exc_info: ExcInfo) -> Iterator[ExcInfo]: exc_type, exc_value, tb = exc_info seen_exceptions = [] - seen_exception_ids = set() # type: Set[int] + seen_exception_ids: Set[int] = set() while ( exc_type is not None @@ -787,23 +759,21 @@ def walk_exception_chain(exc_info): else: - def walk_exception_chain(exc_info): - # type: (ExcInfo) -> Iterator[ExcInfo] + def walk_exception_chain(exc_info: ExcInfo) -> Iterator[ExcInfo]: yield exc_info def exceptions_from_error( - exc_type, # type: Optional[type] - exc_value, # type: Optional[BaseException] - tb, # type: Optional[TracebackType] - client_options=None, # type: Optional[Dict[str, Any]] - mechanism=None, # type: Optional[Dict[str, Any]] - exception_id=0, # type: int - parent_id=0, # type: int - source=None, # type: Optional[str] - full_stack=None, # type: Optional[list[dict[str, Any]]] -): - # type: (...) 
-> Tuple[int, List[Dict[str, Any]]] + exc_type: Optional[type], + exc_value: Optional[BaseException], + tb: Optional[TracebackType], + client_options: Optional[Dict[str, Any]] = None, + mechanism: Optional[Dict[str, Any]] = None, + exception_id: int = 0, + parent_id: int = 0, + source: Optional[str] = None, + full_stack: Optional[list[dict[str, Any]]] = None, +) -> Tuple[int, List[Dict[str, Any]]]: """ Converts the given exception information into the Sentry structured "exception" format. This will return a list of exceptions (a flattened tree of exceptions) in the @@ -838,7 +808,9 @@ def exceptions_from_error( exception_source = None # Add any causing exceptions, if present. - should_suppress_context = hasattr(exc_value, "__suppress_context__") and exc_value.__suppress_context__ # type: ignore + should_suppress_context = ( + hasattr(exc_value, "__suppress_context__") and exc_value.__suppress_context__ # type: ignore[union-attr] + ) # Note: __suppress_context__ is True if the exception is raised with the `from` keyword. if should_suppress_context: # Explicitly chained exceptions (Like: raise NewException() from OriginalException()) @@ -862,7 +834,7 @@ def exceptions_from_error( ) if has_implicit_causing_exception: exception_source = "__context__" - causing_exception = exc_value.__context__ # type: ignore + causing_exception = exc_value.__context__ # type: ignore if causing_exception: # Some frameworks (e.g. FastAPI) wrap the causing exception in an @@ -912,12 +884,11 @@ def exceptions_from_error( def exceptions_from_error_tuple( - exc_info, # type: ExcInfo - client_options=None, # type: Optional[Dict[str, Any]] - mechanism=None, # type: Optional[Dict[str, Any]] - full_stack=None, # type: Optional[list[dict[str, Any]]] -): - # type: (...) -> List[Dict[str, Any]] + exc_info: ExcInfo, + client_options: Optional[Dict[str, Any]] = None, + mechanism: Optional[Dict[str, Any]] = None, + full_stack: Optional[list[dict[str, Any]]] = None, +) -> List[Dict[str, Any]]: """ Convert Python's exception information into Sentry's structured "exception" format in the event. 
See https://develop.sentry.dev/sdk/data-model/event-payloads/exception/ @@ -946,16 +917,14 @@ def exceptions_from_error_tuple( return exceptions -def to_string(value): - # type: (str) -> str +def to_string(value: Any) -> str: try: return str(value) except UnicodeDecodeError: return repr(value)[1:-1] -def iter_event_stacktraces(event): - # type: (Event) -> Iterator[Annotated[Dict[str, Any]]] +def iter_event_stacktraces(event: Event) -> Iterator[Annotated[Dict[str, Any]]]: if "stacktrace" in event: yield event["stacktrace"] if "threads" in event: @@ -968,8 +937,7 @@ def iter_event_stacktraces(event): yield exception["stacktrace"] -def iter_event_frames(event): - # type: (Event) -> Iterator[Dict[str, Any]] +def iter_event_frames(event: Event) -> Iterator[Dict[str, Any]]: for stacktrace in iter_event_stacktraces(event): if isinstance(stacktrace, AnnotatedValue): stacktrace = stacktrace.value or {} @@ -978,8 +946,12 @@ def iter_event_frames(event): yield frame -def handle_in_app(event, in_app_exclude=None, in_app_include=None, project_root=None): - # type: (Event, Optional[List[str]], Optional[List[str]], Optional[str]) -> Event +def handle_in_app( + event: Event, + in_app_exclude: Optional[List[str]] = None, + in_app_include: Optional[List[str]] = None, + project_root: Optional[str] = None, +) -> Event: for stacktrace in iter_event_stacktraces(event): if isinstance(stacktrace, AnnotatedValue): stacktrace = stacktrace.value or {} @@ -994,8 +966,12 @@ def handle_in_app(event, in_app_exclude=None, in_app_include=None, project_root= return event -def set_in_app_in_frames(frames, in_app_exclude, in_app_include, project_root=None): - # type: (Any, Optional[List[str]], Optional[List[str]], Optional[str]) -> Optional[Any] +def set_in_app_in_frames( + frames: Any, + in_app_exclude: Optional[List[str]], + in_app_include: Optional[List[str]], + project_root: Optional[str] = None, +) -> Optional[Any]: if not frames: return None @@ -1033,8 +1009,7 @@ def set_in_app_in_frames(frames, in_app_exclude, in_app_include, project_root=No return frames -def exc_info_from_error(error): - # type: (Union[BaseException, ExcInfo]) -> ExcInfo +def exc_info_from_error(error: Union[BaseException, ExcInfo]) -> ExcInfo: if isinstance(error, tuple) and len(error) == 3: exc_type, exc_value, tb = error elif isinstance(error, BaseException): @@ -1052,18 +1027,17 @@ def exc_info_from_error(error): else: raise ValueError("Expected Exception object to report, got %s!" % type(error)) - exc_info = (exc_type, exc_value, tb) - - if TYPE_CHECKING: - # This cast is safe because exc_type and exc_value are either both - # None or both not None. - exc_info = cast(ExcInfo, exc_info) - - return exc_info + if exc_type is not None and exc_value is not None: + return (exc_type, exc_value, tb) + else: + return (None, None, None) -def merge_stack_frames(frames, full_stack, client_options): - # type: (List[Dict[str, Any]], List[Dict[str, Any]], Optional[Dict[str, Any]]) -> List[Dict[str, Any]] +def merge_stack_frames( + frames: List[Dict[str, Any]], + full_stack: List[Dict[str, Any]], + client_options: Optional[Dict[str, Any]], +) -> List[Dict[str, Any]]: """ Add the missing frames from full_stack to frames and return the merged list. """ @@ -1103,11 +1077,10 @@ def merge_stack_frames(frames, full_stack, client_options): def event_from_exception( - exc_info, # type: Union[BaseException, ExcInfo] - client_options=None, # type: Optional[Dict[str, Any]] - mechanism=None, # type: Optional[Dict[str, Any]] -): - # type: (...) 
-> Tuple[Event, Dict[str, Any]] + exc_info: Union[BaseException, ExcInfo], + client_options: Optional[Dict[str, Any]] = None, + mechanism: Optional[Dict[str, Any]] = None, +) -> Tuple[Event, Dict[str, Any]]: exc_info = exc_info_from_error(exc_info) hint = event_hint_with_exc_info(exc_info) @@ -1132,8 +1105,7 @@ def event_from_exception( ) -def _module_in_list(name, items): - # type: (Optional[str], Optional[List[str]]) -> bool +def _module_in_list(name: Optional[str], items: Optional[List[str]]) -> bool: if name is None: return False @@ -1147,8 +1119,7 @@ def _module_in_list(name, items): return False -def _is_external_source(abs_path): - # type: (Optional[str]) -> bool +def _is_external_source(abs_path: Optional[str]) -> bool: # check if frame is in 'site-packages' or 'dist-packages' if abs_path is None: return False @@ -1159,8 +1130,7 @@ def _is_external_source(abs_path): return external_source -def _is_in_project_root(abs_path, project_root): - # type: (Optional[str], Optional[str]) -> bool +def _is_in_project_root(abs_path: Optional[str], project_root: Optional[str]) -> bool: if abs_path is None or project_root is None: return False @@ -1171,8 +1141,7 @@ def _is_in_project_root(abs_path, project_root): return False -def _truncate_by_bytes(string, max_bytes): - # type: (str, int) -> str +def _truncate_by_bytes(string: str, max_bytes: int) -> str: """ Truncate a UTF-8-encodable string to the last full codepoint so that it fits in max_bytes. """ @@ -1181,16 +1150,16 @@ def _truncate_by_bytes(string, max_bytes): return truncated + "..." -def _get_size_in_bytes(value): - # type: (str) -> Optional[int] +def _get_size_in_bytes(value: str) -> Optional[int]: try: return len(value.encode("utf-8")) except (UnicodeEncodeError, UnicodeDecodeError): return None -def strip_string(value, max_length=None): - # type: (str, Optional[int]) -> Union[AnnotatedValue, str] +def strip_string( + value: str, max_length: Optional[int] = None +) -> Union[AnnotatedValue, str]: if not value: return value @@ -1218,8 +1187,7 @@ def strip_string(value, max_length=None): ) -def parse_version(version): - # type: (str) -> Optional[Tuple[int, ...]] +def parse_version(version: str) -> Optional[Tuple[int, ...]]: """ Parses a version string into a tuple of integers. This uses the parsing loging from PEP 440: @@ -1263,15 +1231,14 @@ def parse_version(version): try: release = pattern.match(version).groupdict()["release"] # type: ignore - release_tuple = tuple(map(int, release.split(".")[:3])) # type: Tuple[int, ...] + release_tuple: Tuple[int, ...] = tuple(map(int, release.split(".")[:3])) except (TypeError, ValueError, AttributeError): return None return release_tuple -def _is_contextvars_broken(): - # type: () -> bool +def _is_contextvars_broken() -> bool: """ Returns whether gevent/eventlet have patched the stdlib in a way where thread locals are now more "correct" than contextvars. 
""" @@ -1322,32 +1289,27 @@ def _is_contextvars_broken(): return False -def _make_threadlocal_contextvars(local): - # type: (type) -> type +def _make_threadlocal_contextvars(local: type) -> type: class ContextVar: # Super-limited impl of ContextVar - def __init__(self, name, default=None): - # type: (str, Any) -> None + def __init__(self, name: str, default: Optional[Any] = None) -> None: self._name = name self._default = default self._local = local() self._original_local = local() - def get(self, default=None): - # type: (Any) -> Any + def get(self, default: Optional[Any] = None) -> Any: return getattr(self._local, "value", default or self._default) - def set(self, value): - # type: (Any) -> Any + def set(self, value: Any) -> Any: token = str(random.getrandbits(64)) original_value = self.get() setattr(self._original_local, token, original_value) self._local.value = value return token - def reset(self, token): - # type: (Any) -> None + def reset(self, token: Any) -> None: self._local.value = getattr(self._original_local, token) # delete the original value (this way it works in Python 3.6+) del self._original_local.__dict__[token] @@ -1355,8 +1317,7 @@ def reset(self, token): return ContextVar -def _get_contextvars(): - # type: () -> Tuple[bool, type] +def _get_contextvars() -> Tuple[bool, type]: """ Figure out the "right" contextvars installation to use. Returns a `contextvars.ContextVar`-like class with a limited API. @@ -1391,10 +1352,9 @@ def _get_contextvars(): """ -def qualname_from_function(func): - # type: (Callable[..., Any]) -> Optional[str] +def qualname_from_function(func: Callable[..., Any]) -> Optional[str]: """Return the qualified name of func. Works with regular function, lambda, partial and partialmethod.""" - func_qualname = None # type: Optional[str] + func_qualname: Optional[str] = None # Python 2 try: @@ -1435,8 +1395,7 @@ def qualname_from_function(func): return func_qualname -def transaction_from_function(func): - # type: (Callable[..., Any]) -> Optional[str] +def transaction_from_function(func: Callable[..., Any]) -> Optional[str]: return qualname_from_function(func) @@ -1454,19 +1413,16 @@ class TimeoutThread(threading.Thread): waiting_time and raises a custom ServerlessTimeout exception. """ - def __init__(self, waiting_time, configured_timeout): - # type: (float, int) -> None + def __init__(self, waiting_time: float, configured_timeout: int) -> None: threading.Thread.__init__(self) self.waiting_time = waiting_time self.configured_timeout = configured_timeout self._stop_event = threading.Event() - def stop(self): - # type: () -> None + def stop(self) -> None: self._stop_event.set() - def run(self): - # type: () -> None + def run(self) -> None: self._stop_event.wait(self.waiting_time) @@ -1487,8 +1443,7 @@ def run(self): ) -def to_base64(original): - # type: (str) -> Optional[str] +def to_base64(original: str) -> Optional[str]: """ Convert a string to base64, via UTF-8. Returns None on invalid input. """ @@ -1504,8 +1459,7 @@ def to_base64(original): return base64_string -def from_base64(base64_string): - # type: (str) -> Optional[str] +def from_base64(base64_string: str) -> Optional[str]: """ Convert a string from base64, via UTF-8. Returns None on invalid input. 
""" @@ -1529,8 +1483,12 @@ def from_base64(base64_string): Components = namedtuple("Components", ["scheme", "netloc", "path", "query", "fragment"]) -def sanitize_url(url, remove_authority=True, remove_query_values=True, split=False): - # type: (str, bool, bool, bool) -> Union[str, Components] +def sanitize_url( + url: str, + remove_authority: bool = True, + remove_query_values: bool = True, + split: bool = False, +) -> Union[str, Components]: """ Removes the authority and query parameter values from a given URL. """ @@ -1576,8 +1534,7 @@ def sanitize_url(url, remove_authority=True, remove_query_values=True, split=Fal ParsedUrl = namedtuple("ParsedUrl", ["url", "query", "fragment"]) -def parse_url(url, sanitize=True): - # type: (str, bool) -> ParsedUrl +def parse_url(url: str, sanitize: bool = True) -> ParsedUrl: """ Splits a URL into a url (including path), query and fragment. If sanitize is True, the query parameters will be sanitized to remove sensitive data. The autority (username and password) @@ -1604,11 +1561,11 @@ def parse_url(url, sanitize=True): ) -def is_valid_sample_rate(rate, source): - # type: (Any, str) -> bool +def is_valid_sample_rate(rate: Any, source: str) -> Optional[float]: """ Checks the given sample rate to make sure it is valid type and value (a boolean or a number between 0 and 1, inclusive). + Returns the final float value to use if valid. """ # both booleans and NaN are instances of Real, so a) checking for Real @@ -1620,7 +1577,7 @@ def is_valid_sample_rate(rate, source): source=source, rate=rate, type=type(rate) ) ) - return False + return None # in case rate is a boolean, it will get cast to 1 if it's True and 0 if it's False rate = float(rate) @@ -1630,13 +1587,14 @@ def is_valid_sample_rate(rate, source): source=source, rate=rate ) ) - return False + return None - return True + return rate -def match_regex_list(item, regex_list=None, substring_matching=False): - # type: (str, Optional[List[str]], bool) -> bool +def match_regex_list( + item: str, regex_list: Optional[List[str]] = None, substring_matching: bool = False +) -> bool: if regex_list is None: return False @@ -1651,8 +1609,7 @@ def match_regex_list(item, regex_list=None, substring_matching=False): return False -def is_sentry_url(client, url): - # type: (sentry_sdk.client.BaseClient, str) -> bool +def is_sentry_url(client: sentry_sdk.client.BaseClient, url: str) -> bool: """ Determines whether the given URL matches the Sentry DSN. 
""" @@ -1664,8 +1621,7 @@ def is_sentry_url(client, url): ) -def _generate_installed_modules(): - # type: () -> Iterator[Tuple[str, str]] +def _generate_installed_modules() -> Iterator[Tuple[str, str]]: try: from importlib import metadata @@ -1693,21 +1649,18 @@ def _generate_installed_modules(): yield _normalize_module_name(info.key), info.version -def _normalize_module_name(name): - # type: (str) -> str +def _normalize_module_name(name: str) -> str: return name.lower() -def _get_installed_modules(): - # type: () -> Dict[str, str] +def _get_installed_modules() -> Dict[str, str]: global _installed_modules if _installed_modules is None: _installed_modules = dict(_generate_installed_modules()) return _installed_modules -def package_version(package): - # type: (str) -> Optional[Tuple[int, ...]] +def package_version(package: str) -> Optional[Tuple[int, ...]]: installed_packages = _get_installed_modules() version = installed_packages.get(package) if version is None: @@ -1716,43 +1669,35 @@ def package_version(package): return parse_version(version) -def reraise(tp, value, tb=None): - # type: (Optional[Type[BaseException]], Optional[BaseException], Optional[Any]) -> NoReturn +def reraise( + tp: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[Any] = None, +) -> NoReturn: assert value is not None if value.__traceback__ is not tb: raise value.with_traceback(tb) raise value -def _no_op(*_a, **_k): - # type: (*Any, **Any) -> None - """No-op function for ensure_integration_enabled.""" - pass - - if TYPE_CHECKING: @overload def ensure_integration_enabled( - integration, # type: type[sentry_sdk.integrations.Integration] - original_function, # type: Callable[P, R] - ): - # type: (...) -> Callable[[Callable[P, R]], Callable[P, R]] - ... + integration: type[sentry_sdk.integrations.Integration], + original_function: Callable[P, R], + ) -> Callable[[Callable[P, R]], Callable[P, R]]: ... @overload def ensure_integration_enabled( - integration, # type: type[sentry_sdk.integrations.Integration] - ): - # type: (...) -> Callable[[Callable[P, None]], Callable[P, None]] - ... + integration: type[sentry_sdk.integrations.Integration], + ) -> Callable[[Callable[P, None]], Callable[P, None]]: ... def ensure_integration_enabled( - integration, # type: type[sentry_sdk.integrations.Integration] - original_function=_no_op, # type: Union[Callable[P, R], Callable[P, None]] -): - # type: (...) -> Callable[[Callable[P, R]], Callable[P, R]] + integration: type[sentry_sdk.integrations.Integration], + original_function: Optional[Callable[P, R]] = None, +) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]: """ Ensures a given integration is enabled prior to calling a Sentry-patched function. @@ -1774,30 +1719,25 @@ def patch_my_function(): return my_function() ``` """ - if TYPE_CHECKING: - # Type hint to ensure the default function has the right typing. The overloads - # ensure the default _no_op function is only used when R is None. - original_function = cast(Callable[P, R], original_function) - - def patcher(sentry_patched_function): - # type: (Callable[P, R]) -> Callable[P, R] - def runner(*args: "P.args", **kwargs: "P.kwargs"): - # type: (...) 
-> R - if sentry_sdk.get_client().get_integration(integration) is None: - return original_function(*args, **kwargs) - return sentry_patched_function(*args, **kwargs) + def patcher(sentry_patched_function: Callable[P, R]) -> Callable[P, Optional[R]]: + def runner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]: + if sentry_sdk.get_client().get_integration(integration) is not None: + return sentry_patched_function(*args, **kwargs) + elif original_function is not None: + return original_function(*args, **kwargs) + else: + return None - if original_function is _no_op: + if original_function: + return wraps(original_function)(runner) + else: return wraps(sentry_patched_function)(runner) - return wraps(original_function)(runner) - return patcher -def now(): - # type: () -> float +def now() -> float: return time.perf_counter() @@ -1808,23 +1748,21 @@ def now(): # it's not great that the signatures are different, get_hub can't return None # consider adding an if TYPE_CHECKING to change the signature to Optional[GeventHub] - def get_gevent_hub(): # type: ignore[misc] - # type: () -> Optional[GeventHub] + def get_gevent_hub() -> Optional[GeventHub]: # type: ignore[misc] return None - def is_module_patched(mod_name): - # type: (str) -> bool + def is_module_patched(mod_name: str) -> bool: # unable to import from gevent means no modules have been patched return False -def is_gevent(): - # type: () -> bool +def is_gevent() -> bool: return is_module_patched("threading") or is_module_patched("_thread") -def get_current_thread_meta(thread=None): - # type: (Optional[threading.Thread]) -> Tuple[Optional[int], Optional[str]] +def get_current_thread_meta( + thread: Optional[threading.Thread] = None, +) -> Tuple[Optional[int], Optional[str]]: """ Try to get the id of the current thread, with various fall backs. 
""" @@ -1874,8 +1812,7 @@ def get_current_thread_meta(thread=None): return None, None -def _serialize_span_attribute(value): - # type: (Any) -> Optional[AttributeValue] +def _serialize_span_attribute(value: Any) -> Optional[AttributeValue]: """Serialize an object so that it's OTel-compatible and displays nicely in Sentry.""" # check for allowed primitives if isinstance(value, (int, str, float, bool)): @@ -1902,8 +1839,7 @@ def _serialize_span_attribute(value): ISO_TZ_SEPARATORS = frozenset(("+", "-")) -def datetime_from_isoformat(value): - # type: (str) -> datetime +def datetime_from_isoformat(value: str) -> datetime: try: result = datetime.fromisoformat(value) except (AttributeError, ValueError): @@ -1924,8 +1860,7 @@ def datetime_from_isoformat(value): return result.astimezone(timezone.utc) -def should_be_treated_as_error(ty, value): - # type: (Any, Any) -> bool +def should_be_treated_as_error(ty: Any, value: Any) -> bool: if ty == SystemExit and hasattr(value, "code") and value.code in (0, None): # https://docs.python.org/3/library/exceptions.html#SystemExit return False @@ -1933,8 +1868,7 @@ def should_be_treated_as_error(ty, value): return True -def http_client_status_to_breadcrumb_level(status_code): - # type: (Optional[int]) -> str +def http_client_status_to_breadcrumb_level(status_code: Optional[int]) -> str: if status_code is not None: if 500 <= status_code <= 599: return "error" @@ -1944,8 +1878,9 @@ def http_client_status_to_breadcrumb_level(status_code): return "info" -def set_thread_info_from_span(data, span): - # type: (Dict[str, Any], sentry_sdk.tracing.Span) -> None +def set_thread_info_from_span( + data: Dict[str, Any], span: sentry_sdk.tracing.Span +) -> None: if span.get_attribute(SPANDATA.THREAD_ID) is not None: data[SPANDATA.THREAD_ID] = span.get_attribute(SPANDATA.THREAD_ID) if span.get_attribute(SPANDATA.THREAD_NAME) is not None: diff --git a/sentry_sdk/worker.py b/sentry_sdk/worker.py index b04ea582bc..d911e15623 100644 --- a/sentry_sdk/worker.py +++ b/sentry_sdk/worker.py @@ -1,3 +1,4 @@ +from __future__ import annotations import os import threading @@ -9,38 +10,32 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any - from typing import Optional - from typing import Callable + from typing import Any, Optional, Callable _TERMINATOR = object() class BackgroundWorker: - def __init__(self, queue_size=DEFAULT_QUEUE_SIZE): - # type: (int) -> None - self._queue = Queue(queue_size) # type: Queue + def __init__(self, queue_size: int = DEFAULT_QUEUE_SIZE) -> None: + self._queue: Queue = Queue(queue_size) self._lock = threading.Lock() - self._thread = None # type: Optional[threading.Thread] - self._thread_for_pid = None # type: Optional[int] + self._thread: Optional[threading.Thread] = None + self._thread_for_pid: Optional[int] = None @property - def is_alive(self): - # type: () -> bool + def is_alive(self) -> bool: if self._thread_for_pid != os.getpid(): return False if not self._thread: return False return self._thread.is_alive() - def _ensure_thread(self): - # type: () -> None + def _ensure_thread(self) -> None: if not self.is_alive: self.start() - def _timed_queue_join(self, timeout): - # type: (float) -> bool + def _timed_queue_join(self, timeout: float) -> bool: deadline = time() + timeout queue = self._queue @@ -57,8 +52,7 @@ def _timed_queue_join(self, timeout): finally: queue.all_tasks_done.release() - def start(self): - # type: () -> None + def start(self) -> None: with self._lock: if not self.is_alive: self._thread = threading.Thread( @@ 
-74,8 +68,7 @@ def start(self): # send out events. self._thread = None - def kill(self): - # type: () -> None + def kill(self) -> None: """ Kill worker thread. Returns immediately. Not useful for waiting on shutdown for events, use `flush` for that. @@ -91,20 +84,17 @@ def kill(self): self._thread = None self._thread_for_pid = None - def flush(self, timeout, callback=None): - # type: (float, Optional[Any]) -> None + def flush(self, timeout: float, callback: Optional[Any] = None) -> None: logger.debug("background worker got flush request") with self._lock: if self.is_alive and timeout > 0.0: self._wait_flush(timeout, callback) logger.debug("background worker flushed") - def full(self): - # type: () -> bool + def full(self) -> bool: return self._queue.full() - def _wait_flush(self, timeout, callback): - # type: (float, Optional[Any]) -> None + def _wait_flush(self, timeout: float, callback: Optional[Any]) -> None: initial_timeout = min(0.1, timeout) if not self._timed_queue_join(initial_timeout): pending = self._queue.qsize() + 1 @@ -116,8 +106,7 @@ def _wait_flush(self, timeout, callback): pending = self._queue.qsize() + 1 logger.error("flush timed out, dropped %s events", pending) - def submit(self, callback): - # type: (Callable[[], None]) -> bool + def submit(self, callback: Callable[[], None]) -> bool: self._ensure_thread() try: self._queue.put_nowait(callback) @@ -125,8 +114,7 @@ def submit(self, callback): except FullError: return False - def _target(self): - # type: () -> None + def _target(self) -> None: while True: callback = self._queue.get() try: diff --git a/tests/integrations/logging/test_logging.py b/tests/integrations/logging/test_logging.py index d1f5d448b6..931c58d04f 100644 --- a/tests/integrations/logging/test_logging.py +++ b/tests/integrations/logging/test_logging.py @@ -259,8 +259,8 @@ def test_logging_captured_warnings(sentry_init, capture_events, recwarn): assert events[1]["logentry"]["params"] == [] # Using recwarn suppresses the "third" warning in the test output - assert len(recwarn) == 1 - assert str(recwarn[0].message) == "third" + third_warnings = [w for w in recwarn if str(w.message) == "third"] + assert len(third_warnings) == 1 def test_ignore_logger(sentry_init, capture_events): diff --git a/tests/test_utils.py b/tests/test_utils.py index e5bad4fa72..963b937380 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -493,7 +493,7 @@ def test_accepts_valid_sample_rate(rate): with mock.patch.object(logger, "warning", mock.Mock()): result = is_valid_sample_rate(rate, source="Testing") assert logger.warning.called is False - assert result is True + assert result == float(rate) @pytest.mark.parametrize( @@ -514,7 +514,7 @@ def test_warns_on_invalid_sample_rate(rate, StringContaining): # noqa: N803 with mock.patch.object(logger, "warning", mock.Mock()): result = is_valid_sample_rate(rate, source="Testing") logger.warning.assert_any_call(StringContaining("Given sample rate is invalid")) - assert result is False + assert result is None @pytest.mark.parametrize(