Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/stac_auth_proxy/handlers/reverse_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def _prepare_headers(self, request: Request) -> MutableHeaders:
proxy_proto = headers.get("X-Forwarded-Proto", request.url.scheme)
proxy_host = headers.get("X-Forwarded-Host", request.url.netloc)
proxy_path = headers.get("X-Forwarded-Path", request.base_url.path)
proxy_port = headers.get("X-Forwarded-Port")

headers.setdefault(
"Forwarded",
f"for={proxy_client};host={proxy_host};proto={proxy_proto};path={proxy_path}",
Expand All @@ -57,6 +59,9 @@ def _prepare_headers(self, request: Request) -> MutableHeaders:
headers.setdefault("X-Forwarded-Path", proxy_path)
headers.setdefault("X-Forwarded-Proto", proxy_proto)

if proxy_port:
headers["X-Forwarded-Port"] = proxy_port

# Set host to the upstream host
if self.override_host:
headers["Host"] = self.client.base_url.netloc.decode("utf-8")
Expand Down
62 changes: 62 additions & 0 deletions tests/test_reverse_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,13 @@ async def test_basic_headers(
assert headers["X-Forwarded-Host"] == "localhost:8000"
assert headers["X-Forwarded-Proto"] == "http"
assert headers["X-Forwarded-Path"] == "/"
assert "X-Forwarded-Port" not in headers
else:
assert "X-Forwarded-For" not in headers
assert "X-Forwarded-Host" not in headers
assert "X-Forwarded-Proto" not in headers
assert "X-Forwarded-Path" not in headers
assert "X-Forwarded-Port" not in headers


@pytest.mark.parametrize("legacy_headers", [False, True])
Expand All @@ -113,11 +115,13 @@ async def test_forwarded_headers_with_client(mock_request, legacy_headers):
assert headers["X-Forwarded-Host"] == "localhost:8000"
assert headers["X-Forwarded-Proto"] == "http"
assert headers["X-Forwarded-Path"] == "/"
assert "X-Forwarded-Port" not in headers
else:
assert "X-Forwarded-For" not in headers
assert "X-Forwarded-Host" not in headers
assert "X-Forwarded-Proto" not in headers
assert "X-Forwarded-Path" not in headers
assert "X-Forwarded-Port" not in headers


@pytest.mark.parametrize("legacy_headers", [False, True])
Expand Down Expand Up @@ -183,6 +187,7 @@ async def test_nginx_proxy_headers_preserved(legacy_headers):
assert headers["X-Forwarded-Host"] == "api.example.com"
assert headers["X-Forwarded-Proto"] == "https"
assert headers["X-Forwarded-Path"] == "/api/v1"
assert "X-Forwarded-Port" not in headers


@pytest.mark.parametrize(
Expand Down Expand Up @@ -282,3 +287,60 @@ async def test_nginx_headers_behavior(scope_overrides, headers, expected_forward
assert f"{key}={expected_value}" in forwarded, (
f"Expected {key}={expected_value} in {forwarded}"
)


@pytest.mark.parametrize("legacy_headers", [False, True])
@pytest.mark.asyncio
async def test_x_forwarded_port_forwarding(legacy_headers):
"""Test that X-Forwarded-Port header is always forwarded if present in incoming request."""
headers = [
(b"host", b"localhost:8000"),
(b"x-forwarded-port", b"80"),
(b"x-forwarded-host", b"api.example.com"),
(b"x-forwarded-proto", b"http"),
]
request = create_request(headers=headers)
handler = ReverseProxyHandler(
upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers
)
result_headers = handler._prepare_headers(request)
assert result_headers["X-Forwarded-Port"] == "80"

if legacy_headers:
assert result_headers["X-Forwarded-Host"] == "api.example.com"
assert result_headers["X-Forwarded-Proto"] == "http"
else:
assert result_headers["X-Forwarded-Host"] == "api.example.com"
assert result_headers["X-Forwarded-Proto"] == "http"


@pytest.mark.asyncio
async def test_x_forwarded_port_not_added_when_missing():
"""Test that X-Forwarded-Port is not added when not present in incoming request."""
headers = [
(b"host", b"localhost:8000"),
(b"x-forwarded-host", b"api.example.com"),
(b"x-forwarded-proto", b"http"),
]
request = create_request(headers=headers)
handler = ReverseProxyHandler(
upstream="http://upstream-api.com", legacy_forwarded_headers=True
)
result_headers = handler._prepare_headers(request)
assert "X-Forwarded-Port" not in result_headers


@pytest.mark.parametrize("legacy_headers", [False, True])
@pytest.mark.asyncio
async def test_x_forwarded_port_always_forwarded(legacy_headers):
"""Test that X-Forwarded-Port is forwarded regardless of legacy_forwarded_headers setting."""
headers = [
(b"host", b"localhost:8000"),
(b"x-forwarded-port", b"443"),
]
request = create_request(headers=headers)
handler = ReverseProxyHandler(
upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers
)
result_headers = handler._prepare_headers(request)
assert result_headers["X-Forwarded-Port"] == "443"
Loading