From cdebd0eb1a088eb2d3e7feaba76393cac8e50802 Mon Sep 17 00:00:00 2001 From: meifakun Date: Mon, 24 Jun 2019 14:31:29 +0800 Subject: [PATCH] add custom context support --- context.go | 12 ++++++++++++ echo.go | 15 ++++++++++++++- echo_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ router.go | 21 +++++++++++++++++++-- 4 files changed, 96 insertions(+), 3 deletions(-) diff --git a/context.go b/context.go index 98cf50bc0..cbf8ad4e1 100644 --- a/context.go +++ b/context.go @@ -187,6 +187,12 @@ type ( // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. // See `Echo#ServeHTTP()` Reset(r *http.Request, w http.ResponseWriter) + + // Underlying returns the underlying context. + Underlying() Context + + // release managed resources + Free() } context struct { @@ -604,3 +610,9 @@ func (c *context) Reset(r *http.Request, w http.ResponseWriter) { // NOTE: Don't reset because it has to have length c.echo.maxParam at all times // c.pvalues = nil } + +func (c *context) Underlying() Context { + return nil +} + +func (c *context) Free() {} diff --git a/echo.go b/echo.go index 032bd0026..aa7c7fdf4 100644 --- a/echo.go +++ b/echo.go @@ -72,6 +72,7 @@ type ( router *Router notFoundHandler HandlerFunc pool sync.Pool + newCtx func(r *http.Request, w http.ResponseWriter) Context Server *http.Server TLSServer *http.Server Listener net.Listener @@ -311,8 +312,19 @@ func New() (e *Echo) { return } +func (e *Echo) SetNewContext(newCtx func(r *http.Request, w http.ResponseWriter) Context) { + e.newCtx = newCtx +} + // NewContext returns a Context instance. func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context { + if e.newCtx != nil { + return e.newCtx(r, w) + } + return e.NewNativeContext(r, w) +} + +func (e *Echo) NewNativeContext(r *http.Request, w http.ResponseWriter) Context { return &context{ request: r, response: NewResponse(w, e), @@ -562,13 +574,14 @@ func (e *Echo) AcquireContext() Context { // ReleaseContext returns the `Context` instance back to the pool. // You must call it after `AcquireContext()`. func (e *Echo) ReleaseContext(c Context) { + c.Free() e.pool.Put(c) } // ServeHTTP implements `http.Handler` interface, which serves HTTP requests. func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Acquire context - c := e.pool.Get().(*context) + c := e.pool.Get().(Context) c.Reset(r, w) h := NotFoundHandler diff --git a/echo_test.go b/echo_test.go index dec713ece..c60e6c169 100644 --- a/echo_test.go +++ b/echo_test.go @@ -578,3 +578,54 @@ func TestEchoShutdown(t *testing.T) { err := <-errCh assert.Equal(t, err.Error(), "http: Server closed") } + +type MyContext struct { + Context + + MyField1 int + MyField2 int + MyField3 int +} + +func (ctx *MyContext) Underlying() Context { + return ctx.Context +} + +func (ctx *MyContext) Free() { + ctx.MyField1 = 0 + ctx.MyField2 = 0 + ctx.MyField3 = 0 +} + +func TestEchoWrapContext(t *testing.T) { + e := New() + e.SetNewContext(func(r *http.Request, w http.ResponseWriter) Context { + return &MyContext{Context: e.NewNativeContext(r, w)} + }) + toEchoFunc := func(h func(ctx *MyContext) error) HandlerFunc { + return func(ctx Context) error { + return h(ctx.(*MyContext)) + } + } + + e.Pre(func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + ctx := c.(*MyContext) + ctx.MyField1 = 1 + ctx.MyField2 = 2 + ctx.MyField3 = 3 + return next(c) + } + }) + + e.GET("/users", toEchoFunc(func(c *MyContext) error { + assert.Equal(t, c.MyField1, 1) + assert.Equal(t, c.MyField2, 2) + assert.Equal(t, c.MyField3, 3) + return c.String(http.StatusOK, "ok") + })) + + c, b := request(http.MethodGet, "/users", e) + assert.Equal(t, http.StatusOK, c) + assert.Equal(t, "ok", b) +} diff --git a/router.go b/router.go index 73f0b68b9..4670047d7 100644 --- a/router.go +++ b/router.go @@ -1,6 +1,9 @@ package echo -import "net/http" +import ( + "errors" + "net/http" +) type ( // Router is the registry of all registered routes for an `Echo` instance for @@ -296,7 +299,21 @@ func (n *node) checkMethodNotAllowed() HandlerFunc { // - Reset it `Context#Reset()` // - Return it `Echo#ReleaseContext()`. func (r *Router) Find(method, path string, c Context) { - ctx := c.(*context) + ctx, isNativeCtx := c.(*context) + if !isNativeCtx { + for { + c = c.Underlying() + if c == nil { + panic(errors.New("must has underlying native context")) + } + + ctx, isNativeCtx = c.(*context) + if isNativeCtx { + break + } + } + } + ctx.path = path cn := r.tree // Current node as root