Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions sentry_sdk/integrations/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sentry_sdk.integrations.starlette import (
StarletteIntegration,
StarletteRequestExtractor,
_patch_request,
)
except DidNotEnable:
raise DidNotEnable("Starlette is not installed")
Expand Down Expand Up @@ -103,11 +104,16 @@ async def _sentry_app(*args, **kwargs):
return await old_app(*args, **kwargs)

request = args[0]
_patch_request(request)

_set_transaction_name_and_source(
sentry_sdk.get_current_scope(), integration.transaction_style, request
)
sentry_scope = sentry_sdk.get_isolation_scope()
sentry_scope._name = FastApiIntegration.identifier

response = await old_app(*args, **kwargs)

extractor = StarletteRequestExtractor(request)
info = await extractor.extract_request_info()

Expand All @@ -129,12 +135,11 @@ def event_processor(event, hint):

return event_processor

sentry_scope._name = FastApiIntegration.identifier
sentry_scope.add_event_processor(
_make_request_event_processor(request, integration)
)

return await old_app(*args, **kwargs)
return response

return _sentry_app

Expand Down
50 changes: 48 additions & 2 deletions sentry_sdk/integrations/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,36 @@ def _is_async_callable(obj):
)


def _patch_request(request):
_original_body = request.body
_original_json = request.json
_original_form = request.form

def restore_original_methods():
request.body = _original_body
request.json = _original_json
request.form = _original_form

async def sentry_body():
request.scope.setdefault("state", {})["sentry_sdk.is_body_cached"] = True
restore_original_methods()
return await _original_body()

async def sentry_json():
request.scope.setdefault("state", {})["sentry_sdk.is_body_cached"] = True
restore_original_methods()
return await _original_json()

async def sentry_form():
request.scope.setdefault("state", {})["sentry_sdk.is_body_cached"] = True
restore_original_methods()
return await _original_form()

request.body = sentry_body
request.json = sentry_json
request.form = sentry_form


def patch_request_response():
# type: () -> None
old_request_response = starlette.routing.request_response
Expand All @@ -442,6 +472,7 @@ async def _sentry_async_func(*args, **kwargs):
return await old_func(*args, **kwargs)

request = args[0]
_patch_request(request)

_set_transaction_name_and_source(
sentry_sdk.get_current_scope(),
Expand All @@ -450,6 +481,10 @@ async def _sentry_async_func(*args, **kwargs):
)

sentry_scope = sentry_sdk.get_isolation_scope()
sentry_scope._name = StarletteIntegration.identifier

response = await old_func(*args, **kwargs)

extractor = StarletteRequestExtractor(request)
info = await extractor.extract_request_info()

Expand All @@ -471,12 +506,11 @@ def event_processor(event, hint):

return event_processor

sentry_scope._name = StarletteIntegration.identifier
sentry_scope.add_event_processor(
_make_request_event_processor(request, integration)
)

return await old_func(*args, **kwargs)
return response

func = _sentry_async_func

Expand Down Expand Up @@ -623,6 +657,18 @@ async def extract_request_info(self):
request_info["data"] = AnnotatedValue.removed_because_over_size_limit()
return request_info

# Avoid hangs by not parsing body when ASGI stream is consumed
is_body_cached = (
"state" in self.request.scope
and "sentry_sdk.is_body_cached" in self.request.scope["state"]
and self.request.scope["state"]["sentry_sdk.is_body_cached"]
)
if self.request.is_disconnected() and not is_body_cached:
request_info["data"] = (
AnnotatedValue.removed_because_body_consumed_and_not_cached()
)
return request_info

# Add JSON body, if it is a JSON request
json = await self.json()
if json:
Expand Down
26 changes: 26 additions & 0 deletions tests/integrations/fastapi/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from tests.integrations.conftest import parametrize_test_configurable_status_codes
from tests.integrations.starlette import test_starlette

BODY_JSON = {"some": "json", "for": "testing", "nested": {"numbers": 123}}


def fastapi_app_factory():
app = FastAPI()
Expand Down Expand Up @@ -72,6 +74,29 @@ async def _thread_ids_async():
return app


def test_stream_available_in_handler(sentry_init):
sentry_init(
integrations=[StarletteIntegration(), FastApiIntegration()],
)

app = FastAPI()

@app.post("/consume")
async def _consume_stream_body(request):
# Avoid cache by constructing new request
wrapped_request = Request(request.scope, request.receive)

assert await wrapped_request.json() == BODY_JSON

return {"status": "ok"}

client = TestClient(app)
client.post(
"/consume",
json=BODY_JSON,
)


@pytest.mark.asyncio
async def test_response(sentry_init, capture_events):
# FastAPI is heavily based on Starlette so we also need
Expand Down Expand Up @@ -223,6 +248,7 @@ def test_active_thread_id(sentry_init, capture_envelopes, teardown_profiling, en
@pytest.mark.asyncio
async def test_original_request_not_scrubbed(sentry_init, capture_events):
sentry_init(
default_integrations=False,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=1.0,
)
Expand Down
66 changes: 65 additions & 1 deletion tests/integrations/starlette/test_starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
AuthenticationError,
SimpleUser,
)
from starlette.requests import Request
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
Expand Down Expand Up @@ -435,6 +436,7 @@ async def test_starletterequestextractor_extract_request_info(sentry_init):
side_effect = [_mock_receive(msg) for msg in JSON_RECEIVE_MESSAGES]
starlette_request._receive = mock.Mock(side_effect=side_effect)

starlette_request.scope["state"] = {"sentry_sdk.is_body_cached": True}
extractor = StarletteRequestExtractor(starlette_request)

request_info = await extractor.extract_request_info()
Expand All @@ -447,6 +449,37 @@ async def test_starletterequestextractor_extract_request_info(sentry_init):
assert request_info["data"] == BODY_JSON


@pytest.mark.asyncio
async def test_starletterequestextractor_extract_request_info_not_cached(sentry_init):
sentry_init(
send_default_pii=True,
integrations=[StarletteIntegration()],
)
scope = SCOPE.copy()
scope["headers"] = [
[b"content-type", b"application/json"],
[b"content-length", str(len(json.dumps(BODY_JSON))).encode()],
[b"cookie", b"yummy_cookie=choco; tasty_cookie=strawberry"],
]

starlette_request = starlette.requests.Request(scope)

# Mocking async `_receive()` that works in Python 3.7+
side_effect = [_mock_receive(msg) for msg in JSON_RECEIVE_MESSAGES]
starlette_request._receive = mock.Mock(side_effect=side_effect)

extractor = StarletteRequestExtractor(starlette_request)

request_info = await extractor.extract_request_info()

assert request_info
assert request_info["cookies"] == {
"tasty_cookie": "strawberry",
"yummy_cookie": "choco",
}
assert request_info["data"].metadata == {"rem": [["!consumed", "x"]]}


@pytest.mark.asyncio
async def test_starletterequestextractor_extract_request_info_no_pii(sentry_init):
sentry_init(
Expand All @@ -466,6 +499,7 @@ async def test_starletterequestextractor_extract_request_info_no_pii(sentry_init
side_effect = [_mock_receive(msg) for msg in JSON_RECEIVE_MESSAGES]
starlette_request._receive = mock.Mock(side_effect=side_effect)

starlette_request.scope["state"] = {"sentry_sdk.is_body_cached": True}
extractor = StarletteRequestExtractor(starlette_request)

request_info = await extractor.extract_request_info()
Expand All @@ -475,6 +509,32 @@ async def test_starletterequestextractor_extract_request_info_no_pii(sentry_init
assert request_info["data"] == BODY_JSON


def test_stream_available_in_handler(sentry_init):
sentry_init(
integrations=[StarletteIntegration()],
)

async def _consume_stream_body(request):
# Avoid cache by constructing new request
wrapped_request = Request(request.scope, request.receive)

assert await wrapped_request.json() == BODY_JSON

return starlette.responses.JSONResponse({"status": "ok"})

app = starlette.applications.Starlette(
routes=[
starlette.routing.Route("/consume", _consume_stream_body, methods=["POST"]),
],
)

client = TestClient(app)
client.post(
"/consume",
json=BODY_JSON,
)


@pytest.mark.parametrize(
"url,transaction_style,expected_transaction,expected_source",
[
Expand Down Expand Up @@ -942,7 +1002,11 @@ def test_active_thread_id(sentry_init, capture_envelopes, teardown_profiling, en


def test_original_request_not_scrubbed(sentry_init, capture_events):
sentry_init(integrations=[StarletteIntegration()])
sentry_init(
default_integrations=False,
integrations=[StarletteIntegration()],
traces_sample_rate=1.0,
)

events = capture_events()

Expand Down
Loading