Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 59 additions & 12 deletions fastmcp_slim/fastmcp/server/middleware/rate_limiting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Rate limiting middleware for protecting FastMCP servers from abuse."""

import time
from collections import defaultdict, deque
from collections import OrderedDict, deque
from collections.abc import Callable
from typing import Any

Expand Down Expand Up @@ -116,6 +116,7 @@ def __init__(
burst_capacity: int | None = None,
get_client_id: Callable[[MiddlewareContext], str] | None = None,
global_limit: bool = False,
max_clients: int = 10000,
):
"""Initialize rate limiting middleware.

Expand All @@ -124,25 +125,47 @@ def __init__(
burst_capacity: Maximum burst capacity. If None, defaults to 2x max_requests_per_second
get_client_id: Function to extract client ID from context. If None, uses global limiting
global_limit: If True, apply limit globally; if False, per-client
max_clients: Maximum number of per-client limiters to track. Must be >= 1.
When reached, the least recently used client is evicted to make
room for a new client.
"""
if max_clients < 1:
raise ValueError(f"max_clients must be >= 1, got {max_clients}")
self.max_requests_per_second = max_requests_per_second
self.burst_capacity = burst_capacity or int(max_requests_per_second * 2)
self.get_client_id = get_client_id
self.global_limit = global_limit
self._max_clients = max_clients

# Storage for rate limiters per client
self.limiters: dict[str, TokenBucketRateLimiter] = defaultdict(
lambda: TokenBucketRateLimiter(
self.burst_capacity, self.max_requests_per_second
)
)
# Per-client limiters stored in LRU order (oldest at front, newest at back)
self._client_limiters: OrderedDict[str, TokenBucketRateLimiter] = OrderedDict()

# Global rate limiter
if self.global_limit:
self.global_limiter = TokenBucketRateLimiter(
self.burst_capacity, self.max_requests_per_second
)

def _get_limiter(self, client_id: str) -> TokenBucketRateLimiter:
    """Return the token-bucket limiter for ``client_id``, creating it on first use.

    Limiters live in ``self._client_limiters`` in least-recently-used order
    (oldest at the front, newest at the back). A hit promotes the client to
    most recently used. A miss inserts a fresh limiter and, if that pushes
    the table past ``self._max_clients``, drops the entry at the front.

    Eviction is count-based, not time-based: a client that is evicted and
    later returns is handed a brand-new limiter, so whatever it consumed
    before is forgotten. ``max_clients`` should therefore stay comfortably
    above the number of concurrently active clients; otherwise churn in the
    table can weaken enforcement.

    Args:
        client_id: Key identifying the client whose limiter is wanted.

    Returns:
        The limiter currently associated with ``client_id``.
    """
    limiters = self._client_limiters
    limiter = limiters.get(client_id)
    if limiter is None:
        limiter = TokenBucketRateLimiter(
            self.burst_capacity, self.max_requests_per_second
        )
        limiters[client_id] = limiter
        if len(limiters) > self._max_clients:
            # The new entry sits at the back, so the front entry is the LRU one.
            limiters.popitem(last=False)
    else:
        # Known client: mark it as most recently used.
        limiters.move_to_end(client_id)
    return limiter

def _get_client_identifier(self, context: MiddlewareContext) -> str:
"""Get client identifier for rate limiting."""
if self.get_client_id:
Expand All @@ -159,7 +182,7 @@ async def on_request(self, context: MiddlewareContext, call_next: CallNext) -> A
else:
# Per-client rate limiting
client_id = self._get_client_identifier(context)
limiter = self.limiters[client_id]
limiter = self._get_limiter(client_id)
allowed = await limiter.consume()
if not allowed:
raise RateLimitError(f"Rate limit exceeded for client: {client_id}")
Expand Down Expand Up @@ -193,23 +216,47 @@ def __init__(
max_requests: int,
window_minutes: int = 1,
get_client_id: Callable[[MiddlewareContext], str] | None = None,
max_clients: int = 10000,
):
"""Initialize sliding window rate limiting middleware.

Args:
max_requests: Maximum requests allowed in the time window
window_minutes: Time window in minutes
get_client_id: Function to extract client ID from context
max_clients: Maximum number of per-client limiters to track. Must be >= 1.
When reached, the least recently used client is evicted to make
room for a new client.
"""
if max_clients < 1:
raise ValueError(f"max_clients must be >= 1, got {max_clients}")
self.max_requests = max_requests
self.window_seconds = window_minutes * 60
self.get_client_id = get_client_id
self._max_clients = max_clients

# Storage for rate limiters per client
self.limiters: dict[str, SlidingWindowRateLimiter] = defaultdict(
lambda: SlidingWindowRateLimiter(self.max_requests, self.window_seconds)
self._client_limiters: OrderedDict[str, SlidingWindowRateLimiter] = (
OrderedDict()
)

def _get_limiter(self, client_id: str) -> SlidingWindowRateLimiter:
    """Return the sliding-window limiter for ``client_id``, creating it on first use.

    Limiters live in ``self._client_limiters`` in least-recently-used order
    (oldest at the front, newest at the back). A hit promotes the client to
    most recently used. A miss inserts a fresh limiter and, if that pushes
    the table past ``self._max_clients``, drops the entry at the front.

    Eviction is count-based, not time-based: a client that is evicted and
    later returns is handed a brand-new limiter, so the requests it made
    earlier in the window are forgotten. ``max_clients`` should therefore
    stay comfortably above the number of concurrently active clients;
    otherwise churn in the table can weaken enforcement.

    Args:
        client_id: Key identifying the client whose limiter is wanted.

    Returns:
        The limiter currently associated with ``client_id``.
    """
    limiters = self._client_limiters
    limiter = limiters.get(client_id)
    if limiter is None:
        limiter = SlidingWindowRateLimiter(self.max_requests, self.window_seconds)
        limiters[client_id] = limiter
        if len(limiters) > self._max_clients:
            # The new entry sits at the back, so the front entry is the LRU one.
            limiters.popitem(last=False)
    else:
        # Known client: mark it as most recently used.
        limiters.move_to_end(client_id)
    return limiter

def _get_client_identifier(self, context: MiddlewareContext) -> str:
"""Get client identifier for rate limiting."""
if self.get_client_id:
Expand All @@ -219,7 +266,7 @@ def _get_client_identifier(self, context: MiddlewareContext) -> str:
async def on_request(self, context: MiddlewareContext, call_next: CallNext) -> Any:
"""Apply sliding window rate limiting to requests."""
client_id = self._get_client_identifier(context)
limiter = self.limiters[client_id]
limiter = self._get_limiter(client_id)

allowed = await limiter.is_allowed()
if not allowed:
Expand Down
67 changes: 67 additions & 0 deletions tests/server/middleware/test_rate_limiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,73 @@ async def test_on_request_rate_limited(self, mock_context, mock_call_next):
await middleware.on_request(mock_context, mock_call_next)


class TestRateLimiterLRUEviction:
    """Per-client limiter tables stay bounded through LRU eviction (issue #4053)."""

    def test_token_bucket_evicts_lru_when_full(self):
        """A third client pushes out the oldest of two cached clients."""
        middleware = RateLimitingMiddleware(
            get_client_id=lambda ctx: ctx, max_clients=2
        )
        for name in ("a", "b", "c"):
            middleware._get_limiter(name)
        # "a" was least recently used, so it is dropped rather than newcomer "c".
        assert list(middleware._client_limiters.keys()) == ["b", "c"]

    def test_token_bucket_move_to_end_on_access_protects_active_client(self):
        """Re-accessing a client shields it from the next eviction."""
        middleware = RateLimitingMiddleware(
            get_client_id=lambda ctx: ctx, max_clients=2
        )
        # The second "a" lookup refreshes it, leaving "b" as the LRU entry
        # by the time "c" arrives.
        for name in ("a", "b", "a", "c"):
            middleware._get_limiter(name)
        assert list(middleware._client_limiters.keys()) == ["a", "c"]

    def test_token_bucket_reuses_same_limiter_instance(self):
        """Repeated lookups for one client return the identical object."""
        middleware = RateLimitingMiddleware(get_client_id=lambda ctx: ctx)
        original = middleware._get_limiter("a")
        again = middleware._get_limiter("a")
        assert again is original

    def test_token_bucket_bound_never_exceeds_max_clients(self):
        """After many distinct clients only the newest ``max_clients`` remain."""
        middleware = RateLimitingMiddleware(
            get_client_id=lambda ctx: ctx, max_clients=5
        )
        names = [f"client-{i}" for i in range(100)]
        for name in names:
            middleware._get_limiter(name)
        assert len(middleware._client_limiters) == 5
        # The survivors are exactly the five most recent clients, in order.
        assert list(middleware._client_limiters.keys()) == names[-5:]

    def test_sliding_window_evicts_lru_when_full(self):
        """Sliding-window middleware drops the oldest client when full."""
        middleware = SlidingWindowRateLimitingMiddleware(
            max_requests=10, get_client_id=lambda ctx: ctx, max_clients=2
        )
        for name in ("a", "b", "c"):
            middleware._get_limiter(name)
        assert list(middleware._client_limiters.keys()) == ["b", "c"]

    def test_sliding_window_bound_never_exceeds_max_clients(self):
        """Sliding-window table size is capped at ``max_clients``."""
        middleware = SlidingWindowRateLimitingMiddleware(
            max_requests=10, get_client_id=lambda ctx: ctx, max_clients=3
        )
        for name in [f"c{i}" for i in range(50)]:
            middleware._get_limiter(name)
        assert len(middleware._client_limiters) == 3

    @pytest.mark.parametrize("bad", [0, -1, -100])
    def test_token_bucket_rejects_non_positive_max_clients(self, bad):
        """Zero or negative ``max_clients`` is refused at construction."""
        with pytest.raises(ValueError, match="max_clients must be >= 1"):
            RateLimitingMiddleware(max_clients=bad)

    @pytest.mark.parametrize("bad", [0, -1])
    def test_sliding_window_rejects_non_positive_max_clients(self, bad):
        """Zero or negative ``max_clients`` is refused at construction."""
        with pytest.raises(ValueError, match="max_clients must be >= 1"):
            SlidingWindowRateLimitingMiddleware(max_requests=10, max_clients=bad)

    def test_max_clients_one_keeps_only_newest(self):
        """With room for a single client, each newcomer replaces the last."""
        middleware = RateLimitingMiddleware(
            get_client_id=lambda ctx: ctx, max_clients=1
        )
        for name in ("a", "b"):
            middleware._get_limiter(name)
        assert list(middleware._client_limiters.keys()) == ["b"]


class TestRateLimitError:
"""Test rate limit error."""

Expand Down
Loading