From ec5282a0d4531fa91d69d9ee06431ae8e551e2d0 Mon Sep 17 00:00:00 2001 From: tdakkota Date: Tue, 13 Sep 2022 14:31:44 +0300 Subject: [PATCH] feat(middleware): add match wrappers --- middleware/match.go | 79 +++++++++++++++++++++++++++++++++++ middleware/middleware.go | 8 ++-- middleware/middleware_test.go | 5 +-- 3 files changed, 85 insertions(+), 7 deletions(-) create mode 100644 middleware/match.go diff --git a/middleware/match.go b/middleware/match.go new file mode 100644 index 000000000..c12960ebd --- /dev/null +++ b/middleware/match.go @@ -0,0 +1,79 @@ +package middleware + +import ( + "regexp" + + "github.com/ogen-go/ogen/internal/xmaps" +) + +// OperationID calls the next middleware if request operation ID matches the given operationID. +func OperationID(m Middleware, operationID ...string) Middleware { + switch len(operationID) { + case 0: + return justCallNext + case 1: + val := operationID[0] + return func(req Request, next Next) (Response, error) { + if req.OperationID == val { + return m(req, next) + } + return next(req) + } + default: + set := xmaps.BuildSet(operationID...) + return func(req Request, next Next) (Response, error) { + if _, ok := set[req.OperationID]; ok { + return m(req, next) + } + return next(req) + } + } +} + +// OperationName calls the next middleware if request operation name matches the given operationName. +func OperationName(m Middleware, operationName ...string) Middleware { + switch len(operationName) { + case 0: + return justCallNext + case 1: + val := operationName[0] + return func(req Request, next Next) (Response, error) { + if req.OperationName == val { + return m(req, next) + } + return next(req) + } + default: + set := xmaps.BuildSet(operationName...) + return func(req Request, next Next) (Response, error) { + if _, ok := set[req.OperationName]; ok { + return m(req, next) + } + return next(req) + } + } +} + +// PathRegex calls the next middleware if request path matches the given regex. +func PathRegex(re *regexp.Regexp, m Middleware) Middleware { + if re == nil { + return justCallNext + } + + return func(req Request, next Next) (Response, error) { + if re.MatchString(req.Raw.URL.Path) { + return m(req, next) + } + return next(req) + } +} + +// BodyType calls the next middleware if request body type matches the given type. +func BodyType[T any](m Middleware) Middleware { + return func(req Request, next Next) (Response, error) { + if _, ok := req.Body.(T); ok { + return m(req, next) + } + return next(req) + } +} diff --git a/middleware/middleware.go b/middleware/middleware.go index ef6f90403..162b936c7 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -35,12 +35,14 @@ type ( Middleware func(req Request, next Next) (Response, error) ) +func justCallNext(req Request, next Next) (Response, error) { + return next(req) +} + // ChainMiddlewares chains middlewares into a single middleware, which will be executed in the order they are passed. func ChainMiddlewares(m ...Middleware) Middleware { if len(m) == 0 { - return func(req Request, next Next) (Response, error) { - return next(req) - } + return justCallNext } tail := ChainMiddlewares(m[1:]...) return func(req Request, next Next) (Response, error) { diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index e4dee0dda..b3a252bef 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -58,14 +58,11 @@ func TestChainMiddlewares(t *testing.T) { func BenchmarkChainMiddlewares(b *testing.B) { const N = 20 - noop := func(req Request, next Next) (Response, error) { - return next(req) - } var ( chain = ChainMiddlewares(func() (r []Middleware) { for i := 0; i < N; i++ { - r = append(r, noop) + r = append(r, justCallNext) } return r }()...)