Skip to content

Commit e52b5a9

Browse files
authored
fix: improve link processing (#95)
We expect that upstream STAC APIs will set the origin of URLs contained in the `href` property of the `links` array so that they point to the STAC Auth Proxy URL, based on the [`Forwarded` header](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Forwarded) we provide. However, in certain configurations, upstream APIs don't make this change (e.g. when the proxy headers are not properly forwarded to the upstream STAC API by a reverse proxy). This PR updates how we rewrite links so that links whose origin and path match those of the upstream STAC API are rewritten to point to the origin (and any path prefix) of the STAC Auth Proxy.
1 parent 0f35ff4 commit e52b5a9

File tree

5 files changed

+793
-143
lines changed

5 files changed

+793
-143
lines changed

src/stac_auth_proxy/__main__.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,20 @@
33
import uvicorn
from uvicorn.config import LOGGING_CONFIG

# Extend uvicorn's default logging config with a DEBUG-level logger for this
# package. Built as a new dict (spreading LOGGING_CONFIG) rather than mutating
# the shared module-level LOGGING_CONFIG in place.
_log_config = {
    **LOGGING_CONFIG,
    "loggers": {
        **LOGGING_CONFIG["loggers"],
        __package__: {
            "level": "DEBUG",
            "handlers": ["default"],
        },
    },
}

# Run the app via its factory so `reload=True` can re-import it on change.
uvicorn.run(
    f"{__package__}.app:create_app",
    host="0.0.0.0",
    port=8000,
    log_config=_log_config,
    reload=True,
    factory=True,
)

src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py

Lines changed: 74 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
import re
55
from dataclasses import dataclass
66
from typing import Any, Optional
7-
from urllib.parse import urlparse, urlunparse
7+
from urllib.parse import ParseResult, urlparse, urlunparse
88

99
from starlette.datastructures import Headers
1010
from starlette.requests import Request
1111
from starlette.types import ASGIApp, Scope
1212

1313
from ..utils.middleware import JsonResponseMiddleware
14+
from ..utils.requests import get_base_url
1415
from ..utils.stac import get_links
1516

1617
logger = logging.getLogger(__name__)
@@ -40,37 +41,81 @@ def should_transform_response(self, request: Request, scope: Scope) -> bool:
4041

4142
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
4243
"""Update links in the response to include root_path."""
43-
for link in get_links(data):
44-
href = link.get("href")
45-
if not href:
46-
continue
44+
# Get the client's actual base URL (accounting for load balancers/proxies)
45+
req_base_url = get_base_url(request)
46+
parsed_req_url = urlparse(req_base_url)
47+
parsed_upstream_url = urlparse(self.upstream_url)
4748

49+
for link in get_links(data):
4850
try:
49-
parsed_link = urlparse(href)
50-
51-
# Ignore links that are not for this proxy
52-
if parsed_link.netloc != request.headers.get("host"):
53-
continue
54-
55-
# Remove the upstream_url path from the link if it exists
56-
parsed_upstream_url = urlparse(self.upstream_url)
57-
if parsed_upstream_url.path != "/" and parsed_link.path.startswith(
58-
parsed_upstream_url.path
59-
):
60-
parsed_link = parsed_link._replace(
61-
path=parsed_link.path[len(parsed_upstream_url.path) :]
62-
)
63-
64-
# Add the root_path to the link if it exists
65-
if self.root_path:
66-
parsed_link = parsed_link._replace(
67-
path=f"{self.root_path}{parsed_link.path}"
68-
)
69-
70-
link["href"] = urlunparse(parsed_link)
51+
self._update_link(link, parsed_req_url, parsed_upstream_url)
7152
except Exception as e:
7253
logger.error(
73-
"Failed to parse link href %r, (ignoring): %s", href, str(e)
54+
"Failed to parse link href %r, (ignoring): %s",
55+
link.get("href"),
56+
str(e),
7457
)
75-
7658
return data
59+
60+
def _update_link(
61+
self, link: dict[str, Any], request_url: ParseResult, upstream_url: ParseResult
62+
) -> None:
63+
"""
64+
Ensure that link hrefs that are local to upstream url are rewritten as local to
65+
the proxy.
66+
"""
67+
if "href" not in link:
68+
logger.warning("Link %r has no href", link)
69+
return
70+
71+
parsed_link = urlparse(link["href"])
72+
73+
if parsed_link.netloc not in [
74+
request_url.netloc,
75+
upstream_url.netloc,
76+
]:
77+
logger.debug(
78+
"Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)",
79+
link["href"],
80+
request_url.netloc,
81+
upstream_url.netloc,
82+
)
83+
return
84+
85+
# If the link path is not a descendant of the upstream path, don't transform it
86+
if upstream_url.path != "/" and not parsed_link.path.startswith(
87+
upstream_url.path
88+
):
89+
logger.debug(
90+
"Ignoring link %s because it is not descendant of upstream path (%s)",
91+
link["href"],
92+
upstream_url.path,
93+
)
94+
return
95+
96+
# Replace the upstream host with the client's host
97+
if parsed_link.netloc == upstream_url.netloc:
98+
parsed_link = parsed_link._replace(netloc=request_url.netloc)._replace(
99+
scheme=request_url.scheme
100+
)
101+
102+
# Rewrite the link path
103+
if upstream_url.path != "/" and parsed_link.path.startswith(upstream_url.path):
104+
parsed_link = parsed_link._replace(
105+
path=parsed_link.path[len(upstream_url.path) :]
106+
)
107+
108+
# Add the root_path to the link if it exists
109+
if self.root_path:
110+
parsed_link = parsed_link._replace(
111+
path=f"{self.root_path}{parsed_link.path}"
112+
)
113+
114+
logger.debug(
115+
"Rewriting %r link %r to %r",
116+
link.get("rel"),
117+
link["href"],
118+
urlunparse(parsed_link),
119+
)
120+
121+
link["href"] = urlunparse(parsed_link)

src/stac_auth_proxy/utils/requests.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
"""Utility functions for working with HTTP requests."""
22

33
import json
4+
import logging
45
import re
56
from dataclasses import dataclass, field
6-
from typing import Optional, Sequence
7+
from typing import Dict, Optional, Sequence
78
from urllib.parse import urlparse
89

10+
from starlette.requests import Request
11+
912
from ..config import EndpointMethods
1013

14+
logger = logging.getLogger(__name__)
15+
1116

1217
def extract_variables(url: str) -> dict:
1318
"""
@@ -90,3 +95,110 @@ def build_server_timing_header(
9095
if current_value:
9196
return f"{current_value}, {metric}"
9297
return metric
98+
99+
100+
def parse_forwarded_header(forwarded_header: str) -> Dict[str, str]:
    """
    Parse a ``Forwarded`` header (RFC 7239) into a dict of its parameters.

    Parameters are semicolon-separated ``key=value`` pairs; values may be
    double-quoted. For the ``for`` parameter, only the first of a
    comma-separated list of values is kept; for every other key the last
    occurrence wins. Returns an empty dict if parsing fails.

    Example:
        >>> parse_forwarded_header("for=192.0.2.43; by=203.0.113.60; proto=https; host=api.example.com")
        {'for': '192.0.2.43', 'by': '203.0.113.60', 'proto': 'https', 'host': 'api.example.com'}

    """
    # NOTE(review): RFC 7239 also allows multiple comma-separated forwarded
    # *elements* (one per proxy hop); this parser flattens them rather than
    # parsing element-by-element — confirm that is acceptable for all
    # expected deployments.
    parsed: Dict[str, str] = {}

    try:
        for segment in forwarded_header.split(";"):
            key, sep, value = segment.strip().partition("=")
            if not sep:
                continue
            key = key.strip()
            value = value.strip().strip('"')
            if key == "for":
                # Keep only the first "for" entry (the originating client).
                parsed.setdefault(key, value.split(",")[0].strip())
            else:
                parsed[key] = value
    except Exception as e:
        logger.warning(f"Failed to parse Forwarded header '{forwarded_header}': {e}")
        return {}

    return parsed
142+
143+
144+
def get_base_url(request: Request) -> str:
    """
    Reconstruct the base URL the client originally requested.

    Honours the standard ``Forwarded`` header (RFC 7239) first, then the
    legacy ``X-Forwarded-Host`` / ``X-Forwarded-Proto`` / ``X-Forwarded-Path``
    headers, and finally falls back to the request's own URL, so the result
    is correct when deployed behind load balancers or reverse proxies.

    Args:
        request: The Starlette request object

    Returns:
        The reconstructed client base URL, e.g. ``"https://api.example.com/"``

    """
    # Defaults come from the request itself; forwarded headers override them.
    scheme = request.url.scheme
    netloc = request.url.netloc
    # Forwarded headers carry no path, so the app's base path is always used
    # unless X-Forwarded-Path is present.
    path = request.base_url.path

    # Preferred source: the standard Forwarded header (RFC 7239).
    forwarded = request.headers.get("Forwarded")
    if forwarded:
        try:
            info = parse_forwarded_header(forwarded)
            # Only trust the header when parsing yielded something useful.
            if "proto" in info or "host" in info:
                return (
                    f"{info.get('proto', scheme)}://"
                    f"{info.get('host', netloc)}{path}"
                )
        except Exception as e:
            logger.warning(f"Failed to parse Forwarded header: {e}")

    # Legacy fallback: de-facto standard X-Forwarded-* headers.
    x_host = request.headers.get("X-Forwarded-Host")
    if x_host:
        scheme = request.headers.get("X-Forwarded-Proto") or scheme
        netloc = x_host
        path = request.headers.get("X-Forwarded-Path") or path

    return f"{scheme}://{netloc}{path}"

0 commit comments

Comments
 (0)