Skip to content

Commit

Permalink
use lua script
Browse files Browse the repository at this point in the history
  • Loading branch information
long2ice committed May 18, 2021
1 parent 6f3c85a commit 1fd34df
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 14 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

## 0.1

### 0.1.4

- Now use `lua` script.
- **Break change**: You shoud call `FastAPILimiter.init` with `async`.

```python
await FastAPILimiter.init(redis)
```

### 0.1.3

- Support multiple rate strategy for one route. (#3)
Expand Down
27 changes: 25 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

## Introduction

FastAPI-Limiter is a rate limiting tool for [fastapi](https://github.com/tiangolo/fastapi) routes.
FastAPI-Limiter is a rate limiting tool for [fastapi](https://github.com/tiangolo/fastapi) routes with lua script.

## Requirements

Expand Down Expand Up @@ -40,7 +40,7 @@ app = FastAPI()
@app.on_event("startup")
async def startup():
redis = await aioredis.create_redis_pool("redis://localhost")
FastAPILimiter.init(redis)
await FastAPILimiter.init(redis)


@app.get("/", dependencies=[Depends(RateLimiter(times=2, seconds=5))])
Expand Down Expand Up @@ -114,6 +114,29 @@ async def multiple():

Not that you should note the dependencies orders, keep lower of result of `seconds/times` at the first.

## Lua script

The lua script used.

```lua
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = ARGV[2]

local current = tonumber(redis.call('get', key) or "0")
if current > 0 then
if current + 1 > limit then
return redis.call("PTTL", key)
else
redis.call("INCR", key)
return 0
end
else
redis.call("SET", key, 1, "px", expire_time)
return 0
end
```

## License

This project is licensed under the
Expand Down
9 changes: 7 additions & 2 deletions examples/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@

@app.on_event("startup")
async def startup():
redis = await aioredis.create_redis_pool("redis://localhost")
FastAPILimiter.init(redis)
redis = await aioredis.create_redis_pool("redis://localhost", encoding="utf8")
await FastAPILimiter.init(redis)


@app.on_event("shutdown")
async def shutdown():
await FastAPILimiter.close()


@app.get("/", dependencies=[Depends(RateLimiter(times=2, seconds=5))])
Expand Down
26 changes: 24 additions & 2 deletions fastapi_limiter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ async def default_callback(request: Request, response: Response, pexpire: int):
:return:
"""
expire = ceil(pexpire / 1000)

raise HTTPException(
HTTP_429_TOO_MANY_REQUESTS, "Too Many Requests", headers={"Retry-After": str(expire)}
)
Expand All @@ -35,11 +34,28 @@ async def default_callback(request: Request, response: Response, pexpire: int):
class FastAPILimiter:
redis: aioredis.Redis = None
prefix: str = None
lua_sha: str = None
identifier: Callable = None
callback: Callable = None
lua_script = """local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = ARGV[2]
local current = tonumber(redis.call('get', key) or "0")
if current > 0 then
if current + 1 > limit then
return redis.call("PTTL",key)
else
redis.call("INCR", key)
return 0
end
else
redis.call("SET", key, 1,"px",expire_time)
return 0
end"""

@classmethod
def init(
async def init(
cls,
redis: aioredis.Redis,
prefix: str = "fastapi-limiter",
Expand All @@ -50,3 +66,9 @@ def init(
cls.prefix = prefix
cls.identifier = identifier
cls.callback = callback
cls.lua_sha = await redis.script_load(cls.lua_script)

@classmethod
async def close(cls):
cls.redis.close()
await cls.redis.wait_closed()
11 changes: 4 additions & 7 deletions fastapi_limiter/depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,8 @@ async def __call__(self, request: Request, response: Response):
redis = FastAPILimiter.redis
rate_key = await identifier(request)
key = f"{FastAPILimiter.prefix}:{rate_key}:{index}"
tr = redis.multi_exec()
tr.incrby(key, 1)
tr.pttl(key)
num, pexpire = await tr.execute()
if num == 1:
await redis.pexpire(key, self.milliseconds)
if num > self.times:
pexpire = await redis.evalsha(
FastAPILimiter.lua_sha, keys=[key], args=[self.times, self.milliseconds]
)
if pexpire != 0:
return await callback(request, response, pexpire)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ packages = [
]
readme = "README.md"
repository = "https://github.com/long2ice/fastapi-limiter.git"
version = "0.1.3"
version = "0.1.4"

[tool.poetry.dependencies]
aioredis = "*"
Expand Down

0 comments on commit 1fd34df

Please sign in to comment.