Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 64 additions & 8 deletions src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,45 @@
logger = logging.getLogger(__name__)


def _extract_hostname(netloc: str) -> str:
"""Extract hostname from netloc."""
if ":" in netloc:
if netloc.startswith("["):
# IPv6 with port: [::1]:8080
end_bracket = netloc.rfind("]")
if end_bracket != -1:
return netloc[: end_bracket + 1]
return netloc.split(":", 1)[0]
return netloc


def _netlocs_match(netloc1: str, scheme1: str, netloc2: str, scheme2: str) -> bool:
    """
    Check whether two netlocs refer to the same host and port.

    Hostnames are compared case-insensitively. Ports must match exactly; a
    netloc without a usable explicit port is assumed to use the default port
    of its scheme (443 for https, 80 for anything else).

    Args:
        netloc1: Authority component of the first URL, e.g. ``localhost:8080``.
        scheme1: Scheme of the first URL, used to pick its default port.
        netloc2: Authority component of the second URL.
        scheme2: Scheme of the second URL, used to pick its default port.

    Returns:
        True when both the hostname and the effective port are equal.
    """
    # Drop any "user:password@" prefix up front: its ":" would otherwise be
    # mistaken for the port separator and hide the real port.
    authority1 = netloc1.rpartition("@")[2]
    authority2 = netloc2.rpartition("@")[2]

    host1 = _extract_hostname(authority1).lower()
    host2 = _extract_hostname(authority2).lower()
    if host1 != host2:
        return False

    def _get_port(authority: str, scheme: str) -> int:
        """Return the explicit port of *authority*, or the scheme default."""
        if authority.startswith("["):
            # Bracketed IPv6 literal: a port is only present when the closing
            # bracket is immediately followed by ":".
            after_bracket = authority.rpartition("]")[2]
            port_text = after_bracket[1:] if after_bracket.startswith(":") else ""
        else:
            port_text = authority.partition(":")[2]

        # Accept plain ASCII digits only; int() alone would also take signs,
        # underscores and surrounding whitespace ("+80", "8_0", " 80").
        if port_text.isascii() and port_text.isdigit():
            return int(port_text)

        # Missing or malformed port: fall back to the scheme's default.
        return 443 if scheme.lower() == "https" else 80

    return _get_port(authority1, scheme1) == _get_port(authority2, scheme2)


@dataclass
class ProcessLinksMiddleware(JsonResponseMiddleware):
"""
Expand Down Expand Up @@ -70,10 +109,20 @@ def _update_link(

parsed_link = urlparse(link["href"])

if parsed_link.netloc not in [
request_url.netloc,
upstream_url.netloc,
]:
if not (
_netlocs_match(
parsed_link.netloc,
parsed_link.scheme,
request_url.netloc,
request_url.scheme,
)
or _netlocs_match(
parsed_link.netloc,
parsed_link.scheme,
upstream_url.netloc,
upstream_url.scheme,
)
):
logger.debug(
"Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)",
link["href"],
Expand All @@ -94,10 +143,17 @@ def _update_link(
return

# Replace the upstream host with the client's host
if parsed_link.netloc == upstream_url.netloc:
parsed_link = parsed_link._replace(netloc=request_url.netloc)._replace(
scheme=request_url.scheme
)
link_matches_upstream = _netlocs_match(
parsed_link.netloc,
parsed_link.scheme,
upstream_url.netloc,
upstream_url.scheme,
)
parsed_link = parsed_link._replace(netloc=request_url.netloc)
if link_matches_upstream:
# Link hostname matches upstream: also replace scheme with request URL's scheme
parsed_link = parsed_link._replace(scheme=request_url.scheme)
# If link matches request hostname, scheme is preserved (handles https://localhost:443 -> http://localhost)

# Remove the upstream prefix from the link path
if upstream_url.path != "/" and parsed_link.path.startswith(upstream_url.path):
Expand Down
177 changes: 177 additions & 0 deletions tests/test_process_links.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,3 +597,180 @@ def test_transform_with_forwarded_headers(headers, expected_base_url):
# but not include the forwarded path in the response URLs
assert transformed["links"][0]["href"] == f"{expected_base_url}/proxy/collections"
assert transformed["links"][1]["href"] == f"{expected_base_url}/proxy"


@pytest.mark.parametrize(
    "upstream_url,root_path,request_host,input_links,expected_links",
    [
        # Basic localhost:PORT rewriting (common port 8080)
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "data", "href": "http://localhost:8080/collections"},
            ],
            [
                "http://localhost/stac/collections",
            ],
        ),
        # Standard HTTP port
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "http://localhost:80/collections"},
            ],
            [
                "http://localhost/stac/collections",
            ],
        ),
        # HTTPS port
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "https://localhost:443/collections"},
            ],
            [
                "https://localhost/stac/collections",
            ],
        ),
        # Arbitrary port
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "http://localhost:3000/collections"},
            ],
            [
                "http://localhost/stac/collections",
            ],
        ),
        # Multiple links with different ports
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "http://localhost:8080/collections"},
                {"rel": "root", "href": "http://localhost:80/"},
                {
                    "rel": "items",
                    "href": "https://localhost:443/collections/test/items",
                },
            ],
            [
                "http://localhost/stac/collections",
                "http://localhost/stac/",
                "https://localhost/stac/collections/test/items",
            ],
        ),
        # localhost:PORT with upstream path
        (
            "http://eoapi-stac:8080/api",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "http://localhost:8080/api/collections"},
            ],
            [
                "http://localhost/stac/collections",
            ],
        ),
        # Request host with an explicit port: the link's own port is dropped
        # and the request host (including its port) is used in the rewrite
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost:80",
            [
                {"rel": "self", "href": "http://localhost:8080/collections"},
            ],
            [
                "http://localhost:80/stac/collections",
            ],
        ),
    ],
)
def test_transform_localhost_with_port(
    upstream_url, root_path, request_host, input_links, expected_links
):
    """Test transforming links with localhost:PORT (any port number)."""
    middleware = ProcessLinksMiddleware(
        app=None, upstream_url=upstream_url, root_path=root_path
    )
    request_scope = {
        "type": "http",
        "path": "/test",
        "headers": [
            (b"host", request_host.encode()),
            (b"content-type", b"application/json"),
        ],
    }

    data = {"links": input_links}
    transformed = middleware.transform_json(data, Request(request_scope))

    # Compare the full list so that missing or extra links fail the test,
    # not just mismatched hrefs at the expected positions.
    actual_links = [link["href"] for link in transformed["links"]]
    assert actual_links == expected_links


def test_localhost_with_port_preserves_other_hostnames():
    """Links whose host is neither the request host nor upstream stay as-is."""
    foreign_rels = ("external", "other")
    foreign_hrefs = (
        "http://example.com:8080/collections",
        "http://other-host:3000/collections",
    )

    proxy = ProcessLinksMiddleware(
        app=None,
        upstream_url="http://eoapi-stac:8080",
        root_path="/stac",
    )
    headers = [
        (b"host", b"localhost"),
        (b"content-type", b"application/json"),
    ]
    scope = {"type": "http", "path": "/test", "headers": headers}
    payload = {
        "links": [
            {"rel": rel, "href": href}
            for rel, href in zip(foreign_rels, foreign_hrefs)
        ]
    }

    result = proxy.transform_json(payload, Request(scope))

    # Every foreign href must come back exactly as it went in
    for position, untouched_href in enumerate(foreign_hrefs):
        assert result["links"][position]["href"] == untouched_href


def test_localhost_with_port_upstream_service_name_still_works():
    """A link on the upstream service host is rewritten to the request host."""
    headers = [
        (b"host", b"localhost"),
        (b"content-type", b"application/json"),
    ]
    scope = {"type": "http", "path": "/test", "headers": headers}
    upstream_link = {"rel": "self", "href": "http://eoapi-stac:8080/collections"}
    payload = {"links": [upstream_link]}

    proxy = ProcessLinksMiddleware(
        app=None,
        upstream_url="http://eoapi-stac:8080",
        root_path="/stac",
    )
    result = proxy.transform_json(payload, Request(scope))

    # The upstream netloc is swapped for the request host and the root path
    # is prepended
    rewritten_href = result["links"][0]["href"]
    assert rewritten_href == "http://localhost/stac/collections"
Loading