Skip to content

Commit

Permalink
Allow multiple limits per route and limiter fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
EvieePy committed Apr 26, 2024
1 parent bd1b280 commit 8c9362e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 11 deletions.
14 changes: 12 additions & 2 deletions starlette_plus/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,22 @@ def limit(
exempt: ExemptCallable | None = None,
) -> T_LimitDecorator:
def decorator(coro: Callable[..., RouteCoro] | _Route) -> LimitDecorator:
limits: RateLimitData = {"rate": rate, "per": per, "bucket": bucket, "priority": priority, "exempt": exempt}
limits: RateLimitData = {
"rate": rate,
"per": per,
"bucket": bucket,
"priority": priority,
"exempt": exempt,
"is_global": False,
}

if isinstance(coro, _Route):
coro._limits.append(limits)
else:
setattr(coro, "__limits__", [limits])
try:
coro.__limits__.append(limits) # type: ignore
except AttributeError:
setattr(coro, "__limits__", [limits])

return coro

Expand Down
21 changes: 16 additions & 5 deletions starlette_plus/middleware/ratelimiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def __init__(
self.app: ASGIApp = app

self._ignore_local: bool = ignore_localhost

for limit in global_limits:
limit["is_global"] = True

self._global_limits: list[RateLimitData] = global_limits

self._store: Store = Store(redis=redis)
Expand Down Expand Up @@ -82,28 +86,35 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
route = r
break

route_limits: list[RateLimitData] = sorted(getattr(route, "limits", []), key=lambda x: x["priority"])
route_limits: list[RateLimitData] = sorted(getattr(route, "limits", []), key=lambda x: x.get("priority", 0))
for data in route_limits:
# Ensure routes are never treated as global limits...
data["is_global"] = False

for limit in self._global_limits + route_limits:
is_exempt: bool = False
exempt: ExemptCallable | None = limit["exempt"]
exempt: ExemptCallable | None = limit.get("exempt", None)

if exempt is not None:
is_exempt: bool = await exempt(request)

if is_exempt:
continue

bucket: BucketType = limit["bucket"]
bucket: BucketType = limit.get("bucket", "ip")
if bucket == "ip":
if not request.client and not forwarded:
logger.warning("Could not determine the IP address while ratelimiting! Ignoring...")
return await self.app(scope, receive, send)

# forwarded or client.host will exist at this point...
key: str = forwarded.split(",")[0] if forwarded else request.client.host # type: ignore
ip: str = forwarded.split(",")[0] if forwarded else request.client.host # type: ignore
if not limit.get("is_global", False) and route:
key = f"{route.name}@{route.path}::{limit['rate']}.{limit['per']}.ip"
else:
key = ip

if self._ignore_local and key in ("127.0.0.1", "::1", "localhost", "0.0.0.0"):
if self._ignore_local and ip in ("127.0.0.1", "::1", "localhost", "0.0.0.0"):
return await self.app(scope, receive, send)
else:
key: str | None = await bucket(request)
Expand Down
1 change: 1 addition & 0 deletions starlette_plus/types_/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from .limiter import RateLimitData as RateLimitData
9 changes: 5 additions & 4 deletions starlette_plus/types_/limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

from collections.abc import Awaitable, Callable
from typing import Literal, TypeAlias, TypedDict
from typing import Literal, NotRequired, TypeAlias, TypedDict

from starlette.requests import Request
from starlette.responses import Response
Expand All @@ -30,6 +30,7 @@
class RateLimitData(TypedDict):
rate: int
per: float
bucket: BucketType
priority: int
exempt: ExemptCallable | None
bucket: NotRequired[BucketType]
priority: NotRequired[int]
exempt: NotRequired[ExemptCallable | None]
is_global: NotRequired[bool]

0 comments on commit 8c9362e

Please sign in to comment.