Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions apisix/plugins/ai-rate-limiting.lua
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ local ipairs = ipairs
local type = type
local core = require("apisix.core")
local limit_count = require("apisix.plugins.limit-count.init")
local policy_to_additional_properties = require("apisix.utils.redis-schema").schema

local plugin_name = "ai-rate-limiting"

Expand Down Expand Up @@ -56,6 +57,12 @@ local schema = {
rejected_msg = {
type = "string", minLength = 1
},
policy = {
type = "string",
enum = {"local", "redis", "redis-cluster"},
default = "local",
},
allow_degradation = {type = "boolean", default = false},
},
dependencies = {
limit = {"time_window"},
Expand All @@ -68,6 +75,24 @@ local schema = {
{
required = {"instances"}
}
},
["if"] = {
properties = {
policy = {
enum = {"redis"},
},
},
},
["then"] = policy_to_additional_properties.redis,
["else"] = {
["if"] = {
properties = {
policy = {
enum = {"redis-cluster"},
},
},
},
["then"] = policy_to_additional_properties["redis-cluster"],
}
}

Expand Down Expand Up @@ -99,7 +124,8 @@ local function transform_limit_conf(plugin_conf, instance_conf, instance_name)
limit = instance_conf.limit
time_window = instance_conf.time_window
end
return {

local limit_conf = {
_vid = key,

key = key,
Expand All @@ -109,15 +135,36 @@ local function transform_limit_conf(plugin_conf, instance_conf, instance_name)
rejected_msg = plugin_conf.rejected_msg,
show_limit_quota_header = plugin_conf.show_limit_quota_header,
-- limit-count need these fields
policy = "local",
policy = plugin_conf.policy or "local",
key_type = "constant",
allow_degradation = false,
allow_degradation = plugin_conf.allow_degradation or false,
sync_interval = -1,

limit_header = "X-AI-RateLimit-Limit-" .. name,
remaining_header = "X-AI-RateLimit-Remaining-" .. name,
reset_header = "X-AI-RateLimit-Reset-" .. name,
}

-- Pass through Redis configuration if policy is redis or redis-cluster
if plugin_conf.policy == "redis" then
limit_conf.redis_host = plugin_conf.redis_host
limit_conf.redis_port = plugin_conf.redis_port
limit_conf.redis_username = plugin_conf.redis_username
limit_conf.redis_password = plugin_conf.redis_password
limit_conf.redis_database = plugin_conf.redis_database
limit_conf.redis_timeout = plugin_conf.redis_timeout
limit_conf.redis_ssl = plugin_conf.redis_ssl
limit_conf.redis_ssl_verify = plugin_conf.redis_ssl_verify
elseif plugin_conf.policy == "redis-cluster" then
limit_conf.redis_cluster_nodes = plugin_conf.redis_cluster_nodes
limit_conf.redis_cluster_name = plugin_conf.redis_cluster_name
limit_conf.redis_password = plugin_conf.redis_password
limit_conf.redis_timeout = plugin_conf.redis_timeout
limit_conf.redis_cluster_ssl = plugin_conf.redis_cluster_ssl
limit_conf.redis_cluster_ssl_verify = plugin_conf.redis_cluster_ssl_verify
end

return limit_conf
end


Expand Down
61 changes: 34 additions & 27 deletions apisix/plugins/limit-count/limit-count-redis-cluster.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
--

local redis_cluster = require("apisix.utils.rediscluster")
local core = require("apisix.core")
local ngx = ngx
local get_phase = ngx.get_phase
local setmetatable = setmetatable
local tostring = tostring
local util = require("apisix.plugins.limit-count.util")
local ngx_timer_at = ngx.timer.at

local _M = {}

Expand All @@ -28,17 +30,6 @@ local mt = {
}


local script = core.string.compress_script([=[
assert(tonumber(ARGV[3]) >= 1, "cost must be at least 1")
local ttl = redis.call('ttl', KEYS[1])
if ttl < 0 then
redis.call('set', KEYS[1], ARGV[1] - ARGV[3], 'EX', ARGV[2])
return {ARGV[1] - ARGV[3], ARGV[2]}
end
return {redis.call('incrby', KEYS[1], 0 - ARGV[3]), ttl}
]=])


function _M.new(plugin_name, limit, window, conf)
local red_cli, err = redis_cluster.new(conf, "plugin-limit-count-redis-cluster-slot-lock")
if not red_cli then
Expand All @@ -57,26 +48,42 @@ function _M.new(plugin_name, limit, window, conf)
end


function _M.incoming(self, key, cost)
local red = self.red_cli
local limit = self.limit
local window = self.window
key = self.plugin_name .. tostring(key)
local function log_phase_incoming_thread(premature, self, key, cost)
return util.redis_log_phase_incoming(self, self.red_cli, key, cost)
end


local ttl = 0
local res, err = red:eval(script, 1, key, limit, window, cost or 1)
local function log_phase_incoming(self, key, cost, dry_run)
if dry_run then
return true
end

if err then
return nil, err, ttl
local ok, err = ngx_timer_at(0, log_phase_incoming_thread, self, key, cost)
if not ok then
core.log.error("failed to create timer: ", err)
return nil, err
end

local remaining = res[1]
ttl = res[2]
return ok
end


function _M.incoming(self, key, cost, dry_run)
if get_phase() == "log" then
local ok, err = log_phase_incoming(self, key, cost, dry_run)
if not ok then
return nil, err, 0
end

if remaining < 0 then
return nil, "rejected", ttl
return 0, self.limit, self.window
end
return 0, remaining, ttl

local commit = true
if dry_run ~= nil then
commit = not dry_run
end

return util.redis_incoming(self, self.red_cli, key, commit, cost)
end


Expand Down
77 changes: 49 additions & 28 deletions apisix/plugins/limit-count/limit-count-redis.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
--
local redis = require("apisix.utils.redis")
local core = require("apisix.core")
local ngx = ngx
local get_phase = ngx.get_phase
local assert = assert
local setmetatable = setmetatable
local tostring = tostring
local util = require("apisix.plugins.limit-count.util")
local ngx_timer_at = ngx.timer.at


local _M = {version = 0.3}
Expand All @@ -29,17 +32,6 @@ local mt = {
}


local script = core.string.compress_script([=[
assert(tonumber(ARGV[3]) >= 1, "cost must be at least 1")
local ttl = redis.call('ttl', KEYS[1])
if ttl < 0 then
redis.call('set', KEYS[1], ARGV[1] - ARGV[3], 'EX', ARGV[2])
return {ARGV[1] - ARGV[3], ARGV[2]}
end
return {redis.call('incrby', KEYS[1], 0 - ARGV[3]), ttl}
]=])


function _M.new(plugin_name, limit, window, conf)
assert(limit > 0 and window > 0)

Expand All @@ -52,37 +44,66 @@ function _M.new(plugin_name, limit, window, conf)
return setmetatable(self, mt)
end

function _M.incoming(self, key, cost)

local function log_phase_incoming_thread(premature, self, key, cost)
local conf = self.conf
local red, err = redis.new(conf)
if not red then
return red, err, 0
return red, err
end
return util.redis_log_phase_incoming(self, red, key, cost)
end

local limit = self.limit
local window = self.window
local res
key = self.plugin_name .. tostring(key)

local ttl = 0
res, err = red:eval(script, 1, key, limit, window, cost or 1)
local function log_phase_incoming(self, key, cost, dry_run)
if dry_run then
return true
end

if err then
return nil, err, ttl
local ok, err = ngx_timer_at(0, log_phase_incoming_thread, self, key, cost)
if not ok then
core.log.error("failed to create timer: ", err)
return nil, err
end

local remaining = res[1]
ttl = res[2]
return ok
end


function _M.incoming(self, key, cost, dry_run)
if get_phase() == "log" then
local ok, err = log_phase_incoming(self, key, cost, dry_run)
if not ok then
return nil, err, 0
end

-- best-effort result because lua-resty-redis is not allowed in log phase
return 0, self.limit, self.window
end

local conf = self.conf
local red, err = redis.new(conf)
if not red then
return red, err, 0
end

local commit = true
if dry_run ~= nil then
commit = not dry_run
end

local delay, remaining, ttl = util.redis_incoming(self, red, key, commit, cost)
if not delay then
local err = remaining
return nil, err, ttl or 0
end

local ok, err = red:set_keepalive(10000, 100)
if not ok then
return nil, err, ttl
end

if remaining < 0 then
return nil, "rejected", ttl
end
return 0, remaining, ttl
return delay, remaining, ttl
end


Expand Down
79 changes: 79 additions & 0 deletions apisix/plugins/limit-count/util.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local tostring = tostring
local tonumber = tonumber
local _M = {version = 0.1}

local commit_script = core.string.compress_script([=[
assert(tonumber(ARGV[3]) >= 0, "cost must be at least 0")
local ttl = redis.call('ttl', KEYS[1])
if ttl < 0 then
redis.call('set', KEYS[1], ARGV[1] - ARGV[3], 'EX', ARGV[2])
return {ARGV[1] - ARGV[3], ARGV[2]}
end
return {redis.call('incrby', KEYS[1], 0 - ARGV[3]), ttl}
]=])

function _M.redis_incoming(self, red, key, commit, cost)
local limit = self.limit
local window = self.window
key = self.plugin_name .. tostring(key)

local requested_cost = cost or 1
local script_cost = commit and requested_cost or 0
local res, err = red:eval(commit_script, 1, key, limit, window, script_cost)

if err then
return nil, err, 0
end

local stored_remaining = tonumber(res[1])
if stored_remaining == nil then
stored_remaining = limit - script_cost
end
local ttl = tonumber(res[2]) or window

local remaining
if commit then
remaining = stored_remaining
else
remaining = stored_remaining - requested_cost
end

if remaining < 0 then
return nil, "rejected", ttl
end

return 0, remaining, ttl
end

function _M.redis_log_phase_incoming(self, red, key, cost)
local limit = self.limit
local window = self.window
key = self.plugin_name .. tostring(key)

local res, err = red:eval(commit_script, 1, key, limit, window, cost or 1)
if err then
return nil, err
end

return res[1]
end

return _M

Loading