diff --git a/starlette_plus/core.py b/starlette_plus/core.py index d01b9d9..5e60545 100644 --- a/starlette_plus/core.py +++ b/starlette_plus/core.py @@ -112,12 +112,22 @@ def limit( exempt: ExemptCallable | None = None, ) -> T_LimitDecorator: def decorator(coro: Callable[..., RouteCoro] | _Route) -> LimitDecorator: - limits: RateLimitData = {"rate": rate, "per": per, "bucket": bucket, "priority": priority, "exempt": exempt} + limits: RateLimitData = { + "rate": rate, + "per": per, + "bucket": bucket, + "priority": priority, + "exempt": exempt, + "is_global": False, + } if isinstance(coro, _Route): coro._limits.append(limits) else: - setattr(coro, "__limits__", [limits]) + try: + coro.__limits__.append(limits) # type: ignore + except AttributeError: + setattr(coro, "__limits__", [limits]) return coro diff --git a/starlette_plus/middleware/ratelimiter.py b/starlette_plus/middleware/ratelimiter.py index 43e18a2..7dbd74b 100644 --- a/starlette_plus/middleware/ratelimiter.py +++ b/starlette_plus/middleware/ratelimiter.py @@ -51,6 +51,10 @@ def __init__( self.app: ASGIApp = app self._ignore_local: bool = ignore_localhost + + for limit in global_limits: + limit["is_global"] = True + self._global_limits: list[RateLimitData] = global_limits self._store: Store = Store(redis=redis) @@ -82,11 +86,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: route = r break - route_limits: list[RateLimitData] = sorted(getattr(route, "limits", []), key=lambda x: x["priority"]) + route_limits: list[RateLimitData] = sorted(getattr(route, "limits", []), key=lambda x: x.get("priority", 0)) + for data in route_limits: + # Ensure routes are never treated as global limits... + data["is_global"] = False for limit in self._global_limits + route_limits: is_exempt: bool = False - exempt: ExemptCallable | None = limit["exempt"] + exempt: ExemptCallable | None = limit.get("exempt", None) if exempt is not None: is_exempt: bool = await exempt(request) @@ -94,16 +101,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if is_exempt: continue - bucket: BucketType = limit["bucket"] + bucket: BucketType = limit.get("bucket", "ip") if bucket == "ip": if not request.client and not forwarded: logger.warning("Could not determine the IP address while ratelimiting! Ignoring...") return await self.app(scope, receive, send) # forwarded or client.host will exist at this point... - key: str = forwarded.split(",")[0] if forwarded else request.client.host # type: ignore + ip: str = forwarded.split(",")[0] if forwarded else request.client.host # type: ignore + if not limit.get("is_global", False) and route: + key = f"{route.name}@{route.path}::{limit['rate']}.{limit['per']}.ip" + else: + key = ip - if self._ignore_local and key in ("127.0.0.1", "::1", "localhost", "0.0.0.0"): + if self._ignore_local and ip in ("127.0.0.1", "::1", "localhost", "0.0.0.0"): return await self.app(scope, receive, send) else: key: str | None = await bucket(request) diff --git a/starlette_plus/types_/__init__.py b/starlette_plus/types_/__init__.py index d320675..8d8d801 100644 --- a/starlette_plus/types_/__init__.py +++ b/starlette_plus/types_/__init__.py @@ -12,4 +12,5 @@ See the License for the specific language governing permissions and limitations under the License. """ + from .limiter import RateLimitData as RateLimitData diff --git a/starlette_plus/types_/limiter.py b/starlette_plus/types_/limiter.py index 7c9bb9e..51b847d 100644 --- a/starlette_plus/types_/limiter.py +++ b/starlette_plus/types_/limiter.py @@ -14,7 +14,7 @@ """ from collections.abc import Awaitable, Callable -from typing import Literal, TypeAlias, TypedDict +from typing import Literal, NotRequired, TypeAlias, TypedDict from starlette.requests import Request from starlette.responses import Response @@ -30,6 +30,7 @@ class RateLimitData(TypedDict): rate: int per: float - bucket: BucketType - priority: int - exempt: ExemptCallable | None + bucket: NotRequired[BucketType] + priority: NotRequired[int] + exempt: NotRequired[ExemptCallable | None] + is_global: NotRequired[bool]