Skip to content

Commit

Permalink
feat: use pathrouter instead of httprouter
Browse files Browse the repository at this point in the history
  • Loading branch information
vizee committed Mar 4, 2024
1 parent 7468733 commit c873a75
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 119 deletions.
4 changes: 2 additions & 2 deletions engine/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"net/http"
"sync"

"github.com/julienschmidt/httprouter"
"github.com/vizee/gapi/internal/slices"
"github.com/vizee/pathrouter"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
Expand Down Expand Up @@ -41,7 +41,7 @@ func NewBuilder() *Builder {
},
},
}
b.engine.router.Store(httprouter.New())
b.engine.routers.Store(&namedList[pathrouter.Router[*grpcRoute]]{})
return b
}

Expand Down
13 changes: 2 additions & 11 deletions engine/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,11 @@ import (
"net/http"
"net/url"

"github.com/julienschmidt/httprouter"
"github.com/vizee/gapi/internal/ioutil"
"github.com/vizee/pathrouter"
)

type Params httprouter.Params

func (ps Params) Get(name string) (string, bool) {
for i := range ps {
if ps[i].Key == name {
return ps[i].Value, true
}
}
return "", false
}
type Params = pathrouter.Params

type Context struct {
req *http.Request
Expand Down
4 changes: 1 addition & 3 deletions engine/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"net/http"
"reflect"
"testing"

"github.com/julienschmidt/httprouter"
)

func TestContext_Get(t *testing.T) {
Expand Down Expand Up @@ -44,7 +42,7 @@ func TestContext_reset(t *testing.T) {
ctx := &Context{
req: &http.Request{},
resp: nil,
params: []httprouter.Param{},
params: Params{},
query: map[string][]string{},
values: map[string]string{},
body: []byte{},
Expand Down
75 changes: 38 additions & 37 deletions engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
"sync"
"sync/atomic"

"github.com/julienschmidt/httprouter"
"github.com/vizee/gapi/log"
"github.com/vizee/gapi/metadata"
"github.com/vizee/pathrouter"
"google.golang.org/grpc"
)

Expand All @@ -32,7 +32,7 @@ type Engine struct {
notFound HandleFunc
ctxpool *sync.Pool

router atomic.Pointer[httprouter.Router]
routers atomic.Pointer[namedList[pathrouter.Router[*grpcRoute]]]
clients map[string]*grpc.ClientConn
routeLock sync.Mutex
}
Expand Down Expand Up @@ -63,7 +63,7 @@ func (e *Engine) ClearRouter() {
e.routeLock.Lock()
clients := e.clients
e.clients = nil
e.router.Store(nil)
e.routers.Store(nil)
e.routeLock.Unlock()
for _, cc := range clients {
cc.Close()
Expand All @@ -77,7 +77,9 @@ type routesSliceIter struct {

func (it *routesSliceIter) NextRoute() *metadata.Route {
if it.i < len(it.rs) {
return it.rs[it.i]
r := it.rs[it.i]
it.i++
return r
}
return nil
}
Expand All @@ -86,16 +88,6 @@ func (e *Engine) RebuildRouter(routes []*metadata.Route, ignoreError bool) error
return RebuildEngineRouter(e, &routesSliceIter{rs: routes}, ignoreError)
}

func registerRoute(router *httprouter.Router, method string, path string, handle httprouter.Handle) (err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("router.Handle: %v", e)
}
}()
router.Handle(method, path, handle)
return
}

func (e *Engine) Execute(w http.ResponseWriter, req *http.Request, params Params, chain []HandleFunc, handle HandleFunc) {
ctx := e.ctxpool.Get().(*Context)
ctx.req = req
Expand All @@ -118,21 +110,32 @@ func (e *Engine) NotFound(w http.ResponseWriter, req *http.Request) {
e.Execute(w, req, nil, e.uses, e.notFound)
}

var resultsPool sync.Pool

func getMatchResult() *pathrouter.MatchResult[*grpcRoute] {
res := resultsPool.Get()
if res != nil {
return res.(*pathrouter.MatchResult[*grpcRoute])
}
return &pathrouter.MatchResult[*grpcRoute]{}
}

func (e *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
log.Debugf("Route %s %s", req.Method, req.URL.Path)

router := e.router.Load()
if router != nil {
path := req.URL.Path
handle, ps, tsr := router.Lookup(req.Method, path)
if handle != nil {
handle(w, req, ps)
return
} else if tsr && path != "/" {
log.Debugf("Trailing slash redirect %s", req.URL.Path)
req.URL.Path = path + "/"
http.Redirect(w, req, req.URL.String(), http.StatusMovedPermanently)
return
routers := e.routers.Load()
if routers != nil {
rotuer, ok := routers.lookup(req.Method)
if ok {
res := getMatchResult()
matched := rotuer.Match(req.URL.Path, res)
if matched {
e.Execute(w, req, res.Params, res.Value.mws, res.Value.handle)
}
resultsPool.Put(res)
if matched {
return
}
}
}

Expand All @@ -159,7 +162,7 @@ func RebuildEngineRouter[R RouteIter](e *Engine, routeIter R, ignoreError bool)

// 在同一次 router 构建中尽可能复用重复的 chain,在大量路由的情况下会带来一些内存节约
chainCache := make(map[string][]HandleFunc)
router := httprouter.New()
routers := &namedList[pathrouter.Router[*grpcRoute]]{}
for {
route := routeIter.NextRoute()
if route == nil {
Expand Down Expand Up @@ -190,22 +193,20 @@ func RebuildEngineRouter[R RouteIter](e *Engine, routeIter R, ignoreError bool)
clients[route.Call.Server] = client
}

middlewares, err := e.generateMiddlewareChain(chainCache, route.Use)
mws, err := e.generateMiddlewareChain(chainCache, route.Use)
if err != nil {
if ignoreError {
continue
}
return fmt.Errorf("middleware of %s: %v", route.Path, err)
}

gr := &grpcRoute{
engine: e,
middlewares: middlewares,
call: route.Call,
ch: ch,
client: client,
}
err = registerRoute(router, route.Method, route.Path, gr.handleRoute)
err = routers.get(route.Method).Add(route.Path, &grpcRoute{
mws: mws,
call: route.Call,
ch: ch,
client: client,
})
if err != nil {
if ignoreError {
log.Warnf("registerRoute(%s %s): %v", route.Method, route.Path, err)
Expand All @@ -215,7 +216,7 @@ func RebuildEngineRouter[R RouteIter](e *Engine, routeIter R, ignoreError bool)
}
}

e.router.Store(router)
e.routers.Store(routers)
e.clients = clients
for server, cc := range old {
if clients[server] == nil {
Expand Down
29 changes: 28 additions & 1 deletion engine/engine_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
package engine

import (
"fmt"
"net/http"
"strings"
"testing"

"github.com/vizee/gapi/log"
"github.com/vizee/gapi/metadata"
)

type logger struct {
}

// Debugf implements log.Logger.
func (*logger) Debugf(format string, args ...any) {
fmt.Printf(format+"\n", args...)
}

// Errorf implements log.Logger.
func (*logger) Errorf(format string, args ...any) {
fmt.Printf(format+"\n", args...)
}

// Warnf implements log.Logger.
func (*logger) Warnf(format string, args ...any) {
fmt.Printf(format+"\n", args...)
}

func TestEngine_RebuildRouter(t *testing.T) {
log.SetLogger(&logger{})

builder := NewBuilder()
builder.RegisterHandler("mock-handler", &mockHandler{})
builder.RegisterMiddleware("auth", func(ctx *Context) error {
Expand All @@ -29,9 +51,14 @@ func TestEngine_RebuildRouter(t *testing.T) {
return nil
})
e := builder.Build()
e.RebuildRouter([]*metadata.Route{

err := e.RebuildRouter([]*metadata.Route{
{Method: "POST", Path: "/add", Use: []string{"auth"}, Call: mockAddCall()},
}, true)
if err != nil {
t.Fatal(err)
}

req, err := http.NewRequest("POST", "http://localhost/add?uid=1", strings.NewReader(`{"a":1,"b":2}`))
if err != nil {
t.Fatal(err)
Expand Down
30 changes: 30 additions & 0 deletions engine/named_list.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package engine

type namedItem[T any] struct {
name string
v T
}

type namedList[T any] struct {
items []namedItem[T]
}

func (l *namedList[T]) lookup(name string) (*T, bool) {
for i := range l.items {
if l.items[i].name == name {
return &l.items[i].v, true
}
}
return nil, false
}

func (l *namedList[T]) get(name string) *T {
v, ok := l.lookup(name)
if ok {
return v
}
l.items = append(l.items, namedItem[T]{
name: name,
})
return &l.items[len(l.items)-1].v
}
20 changes: 7 additions & 13 deletions engine/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package engine

import (
"context"
"net/http"

"github.com/julienschmidt/httprouter"
"github.com/vizee/gapi/metadata"
"google.golang.org/grpc"
)
Expand All @@ -27,12 +25,13 @@ func (*passthroughCodec) Name() string {
return "passthrough"
}

var passthroughCodecOpt = grpc.ForceCodec(&passthroughCodec{})

type grpcRoute struct {
engine *Engine
middlewares []HandleFunc
call *metadata.Call
ch CallHandler
client *grpc.ClientConn
mws []HandleFunc
call *metadata.Call
ch CallHandler
client *grpc.ClientConn
}

func (r *grpcRoute) handle(ctx *Context) error {
Expand All @@ -49,7 +48,7 @@ func (r *grpcRoute) handle(ctx *Context) error {
callctx, cancel = context.WithTimeout(callctx, call.Timeout)
}
var respData []byte
err = r.client.Invoke(callctx, call.Method, reqData, &respData, grpc.ForceCodec(&passthroughCodec{}))
err = r.client.Invoke(callctx, call.Method, reqData, &respData, passthroughCodecOpt)
if cancel != nil {
cancel()
}
Expand All @@ -59,8 +58,3 @@ func (r *grpcRoute) handle(ctx *Context) error {

return r.ch.WriteResponse(call, ctx, respData)
}

func (r *grpcRoute) handleRoute(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
// 封装闭包可能带来一点点内存开销
r.engine.Execute(w, req, Params(params), r.middlewares, r.handle)
}
29 changes: 0 additions & 29 deletions engine/route_test.go

This file was deleted.

Loading

0 comments on commit c873a75

Please sign in to comment.