|
1 | 1 | """Utility functions for working with HTTP requests."""
|
2 | 2 |
|
3 | 3 | import json
|
| 4 | +import logging |
4 | 5 | import re
|
5 | 6 | from dataclasses import dataclass, field
|
6 |
| -from typing import Optional, Sequence |
| 7 | +from typing import Dict, Optional, Sequence |
7 | 8 | from urllib.parse import urlparse
|
8 | 9 |
|
| 10 | +from starlette.requests import Request |
| 11 | + |
9 | 12 | from ..config import EndpointMethods
|
10 | 13 |
|
| 14 | +logger = logging.getLogger(__name__) |
| 15 | + |
11 | 16 |
|
12 | 17 | def extract_variables(url: str) -> dict:
|
13 | 18 | """
|
@@ -90,3 +95,110 @@ def build_server_timing_header(
|
90 | 95 | if current_value:
|
91 | 96 | return f"{current_value}, {metric}"
|
92 | 97 | return metric
|
| 98 | + |
| 99 | + |
def _split_outside_quotes(text: str, separator: str) -> list:
    """Split ``text`` on ``separator``, ignoring separators inside double quotes."""
    parts = []
    current = []
    in_quotes = False
    escaped = False
    for char in text:
        if escaped:
            # The character after a backslash inside a quoted-string is literal.
            escaped = False
        elif in_quotes and char == "\\":
            escaped = True
        elif char == '"':
            in_quotes = not in_quotes
        elif char == separator and not in_quotes:
            parts.append("".join(current))
            current = []
            continue
        current.append(char)
    parts.append("".join(current))
    return parts


def _unquote_forwarded_value(value: str) -> str:
    """Remove the surrounding double quotes of a quoted-string and resolve escapes."""
    if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
        return re.sub(r"\\(.)", r"\1", value[1:-1])
    # Unbalanced quotes are malformed; strip them leniently instead of failing.
    return value.strip('"')


def parse_forwarded_header(forwarded_header: str) -> Dict[str, str]:
    """
    Parse the Forwarded header according to RFC 7239.

    The header is a comma-separated list of elements (one per proxy hop, the
    first one being closest to the client). Each element is a
    semicolon-separated list of ``key=value`` pairs. For every parameter the
    value from the first element that supplies it is returned, so the result
    describes the request as the client sent it.

    Args:
        forwarded_header: The Forwarded header value

    Returns:
        Dictionary containing parsed forwarded information (proto, host, for,
        by, etc.). Keys are lower-cased because parameter names are
        case-insensitive; surrounding quotes are removed from values. Pairs
        with an empty key or value are ignored. An empty dictionary is
        returned when nothing could be parsed.

    Example:
        >>> parse_forwarded_header("for=192.0.2.43; by=203.0.113.60; proto=https; host=api.example.com")
        {'for': '192.0.2.43', 'by': '203.0.113.60', 'proto': 'https', 'host': 'api.example.com'}

    """
    forwarded_info: Dict[str, str] = {}

    # Best-effort contract: a non-string value yields an empty result, not an error.
    if not isinstance(forwarded_header, str):
        logger.warning(
            "Failed to parse Forwarded header %r: expected a string",
            forwarded_header,
        )
        return forwarded_info

    # Elements (proxy hops) are comma separated; pairs inside one are semicolon
    # separated. Quoted values (e.g. IPv6 addresses) may contain either character.
    for element in _split_outside_quotes(forwarded_header, ","):
        for pair in _split_outside_quotes(element, ";"):
            key, separator, value = pair.partition("=")
            if not separator:
                continue
            key = key.strip().lower()
            value = _unquote_forwarded_value(value.strip())
            if key and value:
                # First occurrence wins: it belongs to the hop closest to the client.
                forwarded_info.setdefault(key, value)

    return forwarded_info
| 142 | + |
| 143 | + |
def _first_forwarded_value(header_value: Optional[str]) -> Optional[str]:
    """Return the first non-empty entry of a comma-separated header value, if any."""
    if not header_value:
        return None
    # Each proxy appends its own entry; the first one is closest to the client.
    return header_value.split(",", 1)[0].strip() or None


def get_base_url(request: Request) -> str:
    """
    Get the request's base URL, accounting for forwarded headers from load balancers/proxies.

    This function handles both the standard Forwarded header (RFC 7239) and legacy
    X-Forwarded-* headers to reconstruct the original client URL when the service
    is deployed behind load balancers or reverse proxies.

    Resolution order:

    1. ``Forwarded``: used when it supplies a non-empty ``proto`` and/or ``host``.
       The header carries no path, so the request's own base path is used.
    2. ``X-Forwarded-Proto`` / ``X-Forwarded-Host`` / ``X-Forwarded-Path``: each
       one overrides its URL component independently; for comma-separated
       proto/host lists the first entry is used.
    3. Any component not supplied by a header comes from the request itself.

    These headers are client-controllable unless the fronting proxy overwrites
    them, so the result is only as trustworthy as the deployment's proxy setup.

    Args:
        request: The Starlette request object

    Returns:
        The reconstructed client base URL

    Example:
        >>> # With Forwarded header
        >>> request.headers = {"Forwarded": "for=192.0.2.43; proto=https; host=api.example.com"}
        >>> get_base_url(request)
        "https://api.example.com/"

        >>> # With X-Forwarded-* headers
        >>> request.headers = {"X-Forwarded-Host": "api.example.com", "X-Forwarded-Proto": "https"}
        >>> get_base_url(request)
        "https://api.example.com/"

    """
    default_scheme = request.url.scheme
    default_netloc = request.url.netloc
    default_path = request.base_url.path

    # Check for standard Forwarded header first (RFC 7239).
    # parse_forwarded_header is best-effort and returns {} when nothing is usable.
    forwarded_header = request.headers.get("Forwarded")
    if forwarded_header:
        forwarded_info = parse_forwarded_header(forwarded_header)
        proto = forwarded_info.get("proto")
        host = forwarded_info.get("host")
        # Only trust the Forwarded header if it gave us something to build a URL from.
        if proto or host:
            scheme = proto or default_scheme
            netloc = host or default_netloc
            return f"{scheme}://{netloc}{default_path}"

    # Fall back to legacy X-Forwarded-* headers, each applied independently so a
    # proxy that only sets X-Forwarded-Proto (and preserves Host) is still honoured.
    scheme = (
        _first_forwarded_value(request.headers.get("X-Forwarded-Proto"))
        or default_scheme
    )
    netloc = (
        _first_forwarded_value(request.headers.get("X-Forwarded-Host"))
        or default_netloc
    )

    forwarded_path = (request.headers.get("X-Forwarded-Path") or "").strip()
    if forwarded_path and not forwarded_path.startswith("/"):
        # Avoid gluing the path onto the host ("example.comapi").
        forwarded_path = f"/{forwarded_path}"
    path = forwarded_path or default_path

    return f"{scheme}://{netloc}{path}"
0 commit comments