Skip to content

Commit cb2d46a

Browse files
kenneivesclaude
andcommitted
Add rate limit response headers to all rate-limited endpoints
Responses from rate-limited endpoints now include: - X-RateLimit-Limit: max requests per window - X-RateLimit-Remaining: requests left in current window - X-RateLimit-Reset: seconds until window resets Implemented via HTTP middleware that reads rate limit state. 429 responses also include headers. 4 new tests. 428 tests passing. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d800ce8 commit cb2d46a

3 files changed

Lines changed: 155 additions & 4 deletions

File tree

src/api/rate_limit.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ def check(self, key: str, limit: int, window_seconds: int = 60) -> bool:
3434
self._windows[key].append(time.time())
3535
return True
3636

37+
def get_remaining(self, key: str, limit: int, window_seconds: int = 60) -> int:
38+
"""Return how many requests remain in the current window."""
39+
self._clean_window(key, window_seconds)
40+
used = len(self._windows.get(key, []))
41+
return max(0, limit - used)
42+
43+
def get_reset_time(self, key: str, window_seconds: int = 60) -> int:
44+
"""Return seconds until the oldest request in the window expires."""
45+
if key not in self._windows or not self._windows[key]:
46+
return 0
47+
oldest = min(self._windows[key])
48+
return max(0, int(window_seconds - (time.time() - oldest)))
49+
3750

3851
# Global instance
3952
_limiter = InMemoryRateLimiter()
@@ -62,14 +75,28 @@ def _rate_limit_response(
6275
}
6376

6477

78+
def _set_rate_limit_headers(
79+
request: Request, key: str, limit: int, window: int = 60,
80+
) -> None:
81+
"""Store rate limit info on request state for middleware to add to response."""
82+
remaining = _limiter.get_remaining(key, limit, window)
83+
reset = _limiter.get_reset_time(key, window)
84+
request.state.rate_limit_limit = limit
85+
request.state.rate_limit_remaining = remaining
86+
request.state.rate_limit_reset = reset
87+
88+
6589
async def rate_limit_reads(request: Request) -> None:
6690
ip = _get_client_ip(request)
6791
limit = settings.rate_limit_reads_per_minute
68-
if not _limiter.check(f"read:{ip}", limit):
92+
key = f"read:{ip}"
93+
if not _limiter.check(key, limit):
6994
raise HTTPException(
7095
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
7196
detail="Rate limit exceeded",
97+
headers=_rate_limit_response(0, limit),
7298
)
99+
_set_rate_limit_headers(request, key, limit)
73100
# Per-entity limit (2x IP limit to be generous)
74101
entity_id = _get_entity_id(request)
75102
if entity_id:
@@ -78,32 +105,41 @@ async def rate_limit_reads(request: Request) -> None:
78105
raise HTTPException(
79106
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
80107
detail="Rate limit exceeded",
108+
headers=_rate_limit_response(0, entity_limit),
81109
)
82110

83111

84112
async def rate_limit_writes(request: Request) -> None:
85113
ip = _get_client_ip(request)
86114
limit = settings.rate_limit_writes_per_minute
87-
if not _limiter.check(f"write:{ip}", limit):
115+
key = f"write:{ip}"
116+
if not _limiter.check(key, limit):
88117
raise HTTPException(
89118
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
90119
detail="Rate limit exceeded",
120+
headers=_rate_limit_response(0, limit),
91121
)
122+
_set_rate_limit_headers(request, key, limit)
92123
entity_id = _get_entity_id(request)
93124
if entity_id:
94125
entity_limit = limit * 2
95126
if not _limiter.check(f"write:entity:{entity_id}", entity_limit):
96127
raise HTTPException(
97128
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
98129
detail="Rate limit exceeded",
130+
headers=_rate_limit_response(0, entity_limit),
99131
)
100132

101133

102134
async def rate_limit_auth(request: Request) -> None:
103135
"""Stricter limit for auth endpoints."""
104136
ip = _get_client_ip(request)
105-
if not _limiter.check(f"auth:{ip}", settings.rate_limit_auth_per_minute):
137+
limit = settings.rate_limit_auth_per_minute
138+
key = f"auth:{ip}"
139+
if not _limiter.check(key, limit):
106140
raise HTTPException(
107141
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
108142
detail="Too many attempts. Try again later.",
143+
headers=_rate_limit_response(0, limit),
109144
)
145+
_set_rate_limit_headers(request, key, limit)

src/main.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from fastapi import FastAPI
3+
from fastapi import FastAPI, Request, Response
44
from fastapi.middleware.cors import CORSMiddleware
55

66
from src.api.account_router import router as account_router
@@ -77,6 +77,19 @@
7777
allow_headers=["*"],
7878
)
7979

80+
@app.middleware("http")
81+
async def rate_limit_headers_middleware(request: Request, call_next) -> Response:
82+
"""Add rate limit headers to responses when available."""
83+
response: Response = await call_next(request)
84+
if hasattr(request.state, "rate_limit_limit"):
85+
response.headers["X-RateLimit-Limit"] = str(request.state.rate_limit_limit)
86+
response.headers["X-RateLimit-Remaining"] = str(
87+
request.state.rate_limit_remaining
88+
)
89+
response.headers["X-RateLimit-Reset"] = str(request.state.rate_limit_reset)
90+
return response
91+
92+
8093
app.include_router(account_router, prefix=settings.api_v1_prefix)
8194
app.include_router(activity_router, prefix=settings.api_v1_prefix)
8295
app.include_router(admin_router, prefix=settings.api_v1_prefix)

tests/test_rate_limit_headers.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
import pytest_asyncio
5+
from httpx import ASGITransport, AsyncClient
6+
7+
from src.database import get_db
8+
from src.main import app
9+
10+
11+
@pytest_asyncio.fixture
12+
async def client(db):
13+
async def override_get_db():
14+
yield db
15+
16+
app.dependency_overrides[get_db] = override_get_db
17+
transport = ASGITransport(app=app)
18+
async with AsyncClient(transport=transport, base_url="http://test") as ac:
19+
yield ac
20+
app.dependency_overrides.clear()
21+
22+
23+
REGISTER_URL = "/api/v1/auth/register"
24+
LOGIN_URL = "/api/v1/auth/login"
25+
26+
USER = {
27+
"email": "ratelimit@example.com",
28+
"password": "Str0ngP@ss",
29+
"display_name": "RateLimitUser",
30+
}
31+
32+
33+
def _auth(token: str) -> dict:
34+
return {"Authorization": f"Bearer {token}"}
35+
36+
37+
async def _setup_user(client: AsyncClient) -> str:
38+
await client.post(REGISTER_URL, json=USER)
39+
resp = await client.post(
40+
LOGIN_URL, json={"email": USER["email"], "password": USER["password"]}
41+
)
42+
return resp.json()["access_token"]
43+
44+
45+
@pytest.mark.asyncio
46+
async def test_rate_limit_headers_on_read(client: AsyncClient, db):
47+
"""Rate-limited read endpoints include rate limit headers."""
48+
token = await _setup_user(client)
49+
50+
resp = await client.get(
51+
"/api/v1/feed/posts",
52+
headers=_auth(token),
53+
)
54+
assert resp.status_code == 200
55+
assert "x-ratelimit-limit" in resp.headers
56+
assert "x-ratelimit-remaining" in resp.headers
57+
assert "x-ratelimit-reset" in resp.headers
58+
assert int(resp.headers["x-ratelimit-limit"]) > 0
59+
assert int(resp.headers["x-ratelimit-remaining"]) >= 0
60+
61+
62+
@pytest.mark.asyncio
63+
async def test_rate_limit_headers_on_write(client: AsyncClient, db):
64+
"""Rate-limited write endpoints include rate limit headers."""
65+
token = await _setup_user(client)
66+
67+
resp = await client.post(
68+
"/api/v1/feed/posts",
69+
json={"content": "Testing rate limit headers"},
70+
headers=_auth(token),
71+
)
72+
assert resp.status_code == 201
73+
assert "x-ratelimit-limit" in resp.headers
74+
assert int(resp.headers["x-ratelimit-remaining"]) >= 0
75+
76+
77+
@pytest.mark.asyncio
78+
async def test_rate_limit_remaining_decreases(client: AsyncClient, db):
79+
"""Remaining count decreases with each request."""
80+
token = await _setup_user(client)
81+
82+
resp1 = await client.get(
83+
"/api/v1/feed/posts",
84+
headers=_auth(token),
85+
)
86+
remaining1 = int(resp1.headers["x-ratelimit-remaining"])
87+
88+
resp2 = await client.get(
89+
"/api/v1/feed/posts",
90+
headers=_auth(token),
91+
)
92+
remaining2 = int(resp2.headers["x-ratelimit-remaining"])
93+
94+
assert remaining2 < remaining1
95+
96+
97+
@pytest.mark.asyncio
98+
async def test_no_rate_limit_headers_on_unrated(client: AsyncClient, db):
99+
"""Endpoints without rate limiting don't have rate limit headers."""
100+
resp = await client.get("/health")
101+
assert resp.status_code == 200
102+
assert "x-ratelimit-limit" not in resp.headers

0 commit comments

Comments
 (0)