diff --git a/src/stac_auth_proxy/handlers/reverse_proxy.py b/src/stac_auth_proxy/handlers/reverse_proxy.py index 3cae7137..5ace457f 100644 --- a/src/stac_auth_proxy/handlers/reverse_proxy.py +++ b/src/stac_auth_proxy/handlers/reverse_proxy.py @@ -45,6 +45,8 @@ def _prepare_headers(self, request: Request) -> MutableHeaders: proxy_proto = headers.get("X-Forwarded-Proto", request.url.scheme) proxy_host = headers.get("X-Forwarded-Host", request.url.netloc) proxy_path = headers.get("X-Forwarded-Path", request.base_url.path) + proxy_port = headers.get("X-Forwarded-Port") + headers.setdefault( "Forwarded", f"for={proxy_client};host={proxy_host};proto={proxy_proto};path={proxy_path}", @@ -57,6 +59,9 @@ def _prepare_headers(self, request: Request) -> MutableHeaders: headers.setdefault("X-Forwarded-Path", proxy_path) headers.setdefault("X-Forwarded-Proto", proxy_proto) + if proxy_port: + headers["X-Forwarded-Port"] = proxy_port + # Set host to the upstream host if self.override_host: headers["Host"] = self.client.base_url.netloc.decode("utf-8") diff --git a/tests/test_reverse_proxy.py b/tests/test_reverse_proxy.py index bb840268..193a4b19 100644 --- a/tests/test_reverse_proxy.py +++ b/tests/test_reverse_proxy.py @@ -82,11 +82,13 @@ async def test_basic_headers( assert headers["X-Forwarded-Host"] == "localhost:8000" assert headers["X-Forwarded-Proto"] == "http" assert headers["X-Forwarded-Path"] == "/" + assert "X-Forwarded-Port" not in headers else: assert "X-Forwarded-For" not in headers assert "X-Forwarded-Host" not in headers assert "X-Forwarded-Proto" not in headers assert "X-Forwarded-Path" not in headers + assert "X-Forwarded-Port" not in headers @pytest.mark.parametrize("legacy_headers", [False, True]) @@ -113,11 +115,13 @@ async def test_forwarded_headers_with_client(mock_request, legacy_headers): assert headers["X-Forwarded-Host"] == "localhost:8000" assert headers["X-Forwarded-Proto"] == "http" assert headers["X-Forwarded-Path"] == "/" + assert "X-Forwarded-Port" not in headers else: assert "X-Forwarded-For" not in headers assert "X-Forwarded-Host" not in headers assert "X-Forwarded-Proto" not in headers assert "X-Forwarded-Path" not in headers + assert "X-Forwarded-Port" not in headers @pytest.mark.parametrize("legacy_headers", [False, True]) @@ -183,6 +187,7 @@ async def test_nginx_proxy_headers_preserved(legacy_headers): assert headers["X-Forwarded-Host"] == "api.example.com" assert headers["X-Forwarded-Proto"] == "https" assert headers["X-Forwarded-Path"] == "/api/v1" + assert "X-Forwarded-Port" not in headers @pytest.mark.parametrize( @@ -282,3 +287,60 @@ async def test_nginx_headers_behavior(scope_overrides, headers, expected_forward assert f"{key}={expected_value}" in forwarded, ( f"Expected {key}={expected_value} in {forwarded}" ) + + +@pytest.mark.parametrize("legacy_headers", [False, True]) +@pytest.mark.asyncio +async def test_x_forwarded_port_forwarding(legacy_headers): + """Test that X-Forwarded-Port header is always forwarded if present in incoming request.""" + headers = [ + (b"host", b"localhost:8000"), + (b"x-forwarded-port", b"80"), + (b"x-forwarded-host", b"api.example.com"), + (b"x-forwarded-proto", b"http"), + ] + request = create_request(headers=headers) + handler = ReverseProxyHandler( + upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers + ) + result_headers = handler._prepare_headers(request) + assert result_headers["X-Forwarded-Port"] == "80" + + if legacy_headers: + assert result_headers["X-Forwarded-Host"] == "api.example.com" + assert result_headers["X-Forwarded-Proto"] == "http" + else: + assert result_headers["X-Forwarded-Host"] == "api.example.com" + assert result_headers["X-Forwarded-Proto"] == "http" + + +@pytest.mark.asyncio +async def test_x_forwarded_port_not_added_when_missing(): + """Test that X-Forwarded-Port is not added when not present in incoming request.""" + headers = [ + (b"host", b"localhost:8000"), + (b"x-forwarded-host", b"api.example.com"), + (b"x-forwarded-proto", b"http"), + ] + request = create_request(headers=headers) + handler = ReverseProxyHandler( + upstream="http://upstream-api.com", legacy_forwarded_headers=True + ) + result_headers = handler._prepare_headers(request) + assert "X-Forwarded-Port" not in result_headers + + +@pytest.mark.parametrize("legacy_headers", [False, True]) +@pytest.mark.asyncio +async def test_x_forwarded_port_always_forwarded(legacy_headers): + """Test that X-Forwarded-Port is forwarded regardless of legacy_forwarded_headers setting.""" + headers = [ + (b"host", b"localhost:8000"), + (b"x-forwarded-port", b"443"), + ] + request = create_request(headers=headers) + handler = ReverseProxyHandler( + upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers + ) + result_headers = handler._prepare_headers(request) + assert result_headers["X-Forwarded-Port"] == "443"