diff --git a/fastmcp_slim/fastmcp/server/middleware/rate_limiting.py b/fastmcp_slim/fastmcp/server/middleware/rate_limiting.py index 703fd393a..9ae959a83 100644 --- a/fastmcp_slim/fastmcp/server/middleware/rate_limiting.py +++ b/fastmcp_slim/fastmcp/server/middleware/rate_limiting.py @@ -1,7 +1,7 @@ """Rate limiting middleware for protecting FastMCP servers from abuse.""" import time -from collections import defaultdict, deque +from collections import OrderedDict, deque from collections.abc import Callable from typing import Any @@ -116,6 +116,7 @@ def __init__( burst_capacity: int | None = None, get_client_id: Callable[[MiddlewareContext], str] | None = None, global_limit: bool = False, + max_clients: int = 10000, ): """Initialize rate limiting middleware. @@ -124,18 +125,20 @@ def __init__( burst_capacity: Maximum burst capacity. If None, defaults to 2x max_requests_per_second get_client_id: Function to extract client ID from context. If None, uses global limiting global_limit: If True, apply limit globally; if False, per-client + max_clients: Maximum number of per-client limiters to track. Must be >= 1. + When reached, the least recently used client is evicted to make + room for a new client. """ + if max_clients < 1: + raise ValueError(f"max_clients must be >= 1, got {max_clients}") self.max_requests_per_second = max_requests_per_second self.burst_capacity = burst_capacity or int(max_requests_per_second * 2) self.get_client_id = get_client_id self.global_limit = global_limit + self._max_clients = max_clients - # Storage for rate limiters per client - self.limiters: dict[str, TokenBucketRateLimiter] = defaultdict( - lambda: TokenBucketRateLimiter( - self.burst_capacity, self.max_requests_per_second - ) - ) + # Per-client limiters stored in LRU order (oldest at front, newest at back) + self._client_limiters: OrderedDict[str, TokenBucketRateLimiter] = OrderedDict() # Global rate limiter if self.global_limit: @@ -143,6 +146,26 @@ def __init__( self.burst_capacity, self.max_requests_per_second ) + def _get_limiter(self, client_id: str) -> TokenBucketRateLimiter: + """Get or create a rate limiter for a client, with LRU eviction. + + When the cache is full, the least-recently-used client is evicted. + This is safe because an evicted client was inactive long enough for + max_clients other clients to be more recent — by which point their + token bucket would have refilled to capacity anyway. + """ + if client_id in self._client_limiters: + self._client_limiters.move_to_end(client_id) + return self._client_limiters[client_id] + + limiter = TokenBucketRateLimiter( + self.burst_capacity, self.max_requests_per_second + ) + self._client_limiters[client_id] = limiter + if len(self._client_limiters) > self._max_clients: + self._client_limiters.popitem(last=False) + return limiter + def _get_client_identifier(self, context: MiddlewareContext) -> str: """Get client identifier for rate limiting.""" if self.get_client_id: @@ -159,7 +182,7 @@ async def on_request(self, context: MiddlewareContext, call_next: CallNext) -> A else: # Per-client rate limiting client_id = self._get_client_identifier(context) - limiter = self.limiters[client_id] + limiter = self._get_limiter(client_id) allowed = await limiter.consume() if not allowed: raise RateLimitError(f"Rate limit exceeded for client: {client_id}") @@ -193,6 +216,7 @@ def __init__( max_requests: int, window_minutes: int = 1, get_client_id: Callable[[MiddlewareContext], str] | None = None, + max_clients: int = 10000, ): """Initialize sliding window rate limiting middleware. @@ -200,16 +224,39 @@ def __init__( max_requests: Maximum requests allowed in the time window window_minutes: Time window in minutes get_client_id: Function to extract client ID from context + max_clients: Maximum number of per-client limiters to track. Must be >= 1. + When reached, the least recently used client is evicted to make + room for a new client. """ + if max_clients < 1: + raise ValueError(f"max_clients must be >= 1, got {max_clients}") self.max_requests = max_requests self.window_seconds = window_minutes * 60 self.get_client_id = get_client_id + self._max_clients = max_clients - # Storage for rate limiters per client - self.limiters: dict[str, SlidingWindowRateLimiter] = defaultdict( - lambda: SlidingWindowRateLimiter(self.max_requests, self.window_seconds) + self._client_limiters: OrderedDict[str, SlidingWindowRateLimiter] = ( + OrderedDict() ) + def _get_limiter(self, client_id: str) -> SlidingWindowRateLimiter: + """Get or create a rate limiter for a client, with LRU eviction. + + When the cache is full, the least-recently-used client is evicted. + This is safe because an evicted client was inactive long enough for + max_clients other clients to be more recent — by which point their + sliding window would have expired anyway. + """ + if client_id in self._client_limiters: + self._client_limiters.move_to_end(client_id) + return self._client_limiters[client_id] + + limiter = SlidingWindowRateLimiter(self.max_requests, self.window_seconds) + self._client_limiters[client_id] = limiter + if len(self._client_limiters) > self._max_clients: + self._client_limiters.popitem(last=False) + return limiter + def _get_client_identifier(self, context: MiddlewareContext) -> str: """Get client identifier for rate limiting.""" if self.get_client_id: @@ -219,7 +266,7 @@ def _get_client_identifier(self, context: MiddlewareContext) -> str: async def on_request(self, context: MiddlewareContext, call_next: CallNext) -> Any: """Apply sliding window rate limiting to requests.""" client_id = self._get_client_identifier(context) - limiter = self.limiters[client_id] + limiter = self._get_limiter(client_id) allowed = await limiter.is_allowed() if not allowed: diff --git a/tests/server/middleware/test_rate_limiting.py b/tests/server/middleware/test_rate_limiting.py index 59fb074f1..153fde269 100644 --- a/tests/server/middleware/test_rate_limiting.py +++ b/tests/server/middleware/test_rate_limiting.py @@ -270,6 +270,73 @@ async def test_on_request_rate_limited(self, mock_context, mock_call_next): await middleware.on_request(mock_context, mock_call_next) +class TestRateLimiterLRUEviction: + """LRU bounding of per-client limiters (issue #4053).""" + + def test_token_bucket_evicts_lru_when_full(self): + mw = RateLimitingMiddleware(get_client_id=lambda ctx: ctx, max_clients=2) + mw._get_limiter("a") + mw._get_limiter("b") + # Cache full at 2; adding "c" evicts the LRU ("a"), not the newcomer. + mw._get_limiter("c") + assert list(mw._client_limiters) == ["b", "c"] + + def test_token_bucket_move_to_end_on_access_protects_active_client(self): + mw = RateLimitingMiddleware(get_client_id=lambda ctx: ctx, max_clients=2) + mw._get_limiter("a") + mw._get_limiter("b") + # Touch "a" so it is most-recently-used; "b" is now LRU. + mw._get_limiter("a") + mw._get_limiter("c") + assert list(mw._client_limiters) == ["a", "c"] + + def test_token_bucket_reuses_same_limiter_instance(self): + mw = RateLimitingMiddleware(get_client_id=lambda ctx: ctx) + first = mw._get_limiter("a") + assert mw._get_limiter("a") is first + + def test_token_bucket_bound_never_exceeds_max_clients(self): + mw = RateLimitingMiddleware(get_client_id=lambda ctx: ctx, max_clients=5) + for i in range(100): + mw._get_limiter(f"client-{i}") + assert len(mw._client_limiters) == 5 + # Only the 5 most-recent survive. + assert list(mw._client_limiters) == [f"client-{i}" for i in range(95, 100)] + + def test_sliding_window_evicts_lru_when_full(self): + mw = SlidingWindowRateLimitingMiddleware( + max_requests=10, get_client_id=lambda ctx: ctx, max_clients=2 + ) + mw._get_limiter("a") + mw._get_limiter("b") + mw._get_limiter("c") + assert list(mw._client_limiters) == ["b", "c"] + + def test_sliding_window_bound_never_exceeds_max_clients(self): + mw = SlidingWindowRateLimitingMiddleware( + max_requests=10, get_client_id=lambda ctx: ctx, max_clients=3 + ) + for i in range(50): + mw._get_limiter(f"c{i}") + assert len(mw._client_limiters) == 3 + + @pytest.mark.parametrize("bad", [0, -1, -100]) + def test_token_bucket_rejects_non_positive_max_clients(self, bad): + with pytest.raises(ValueError, match="max_clients must be >= 1"): + RateLimitingMiddleware(max_clients=bad) + + @pytest.mark.parametrize("bad", [0, -1]) + def test_sliding_window_rejects_non_positive_max_clients(self, bad): + with pytest.raises(ValueError, match="max_clients must be >= 1"): + SlidingWindowRateLimitingMiddleware(max_requests=10, max_clients=bad) + + def test_max_clients_one_keeps_only_newest(self): + mw = RateLimitingMiddleware(get_client_id=lambda ctx: ctx, max_clients=1) + mw._get_limiter("a") + mw._get_limiter("b") + assert list(mw._client_limiters) == ["b"] + + class TestRateLimitError: """Test rate limit error."""