forked from trpc-group/trpc-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfasthttp.go
154 lines (136 loc) · 4.8 KB
/
fasthttp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
//
//
// Tencent is pleased to support the open source community by making tRPC available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company.
// All rights reserved.
//
// If you have downloaded a copy of the tRPC source code from Tencent,
// please note that tRPC source code is licensed under the Apache 2.0 License,
// A copy of the Apache 2.0 License is included in this file.
//
//
package restful
import (
"bytes"
"context"
"unsafe"
"github.com/valyala/fasthttp"
"google.golang.org/protobuf/proto"
"trpc.group/trpc-go/trpc-go/errs"
)
// FastHTTPHeaderMatcher maps the headers of a fasthttp request onto the tRPC
// stub context.
//
// It receives the base context, the fasthttp request context and the names of
// the service and method being invoked, and returns the context that is used as
// the stub context for the rest of the call. A non-nil error aborts handling of
// the request.
type FastHTTPHeaderMatcher func(ctx context.Context, requestCtx *fasthttp.RequestCtx,
	serviceName, methodName string) (context.Context, error)
// DefaultFastHTTPHeaderMatcher is the default FastHTTPHeaderMatcher.
//
// It does not inspect the fasthttp request at all: it only derives the stub
// context from ctx via withNewMessage, passing the service and method names
// through, and it never reports an error.
var DefaultFastHTTPHeaderMatcher = func(
	ctx context.Context,
	_ *fasthttp.RequestCtx,
	service, method string,
) (context.Context, error) {
	stubCtx := withNewMessage(ctx, service, method)
	return stubCtx, nil
}
// FastHTTPRespHandler is the custom response handler used when serving over
// fasthttp.
//
// It is given the stub context, the fasthttp request context, the response
// message and the response payload bytes, and is responsible for writing the
// fasthttp response. A returned error signals that the response could not be
// written.
type FastHTTPRespHandler func(ctx context.Context, requestCtx *fasthttp.RequestCtx,
	resp proto.Message, body []byte) error
// DefaultFastHTTPRespHandler is the default FastHTTPRespHandler.
//
// It picks the outbound compressor from the request's Content-Encoding and
// Accept-Encoding headers and the outbound serializer from its Content-Type and
// Accept headers. It then sets the matching Content-Encoding (only when a
// compressor is chosen) and Content-Type response headers, and sets the status
// code returned by GetStatusCodeOnSucceed(stubCtx). Finally it writes body,
// compressed when a compressor was chosen, unless the status is
// 204 No Content or 304 Not Modified.
//
// Errors from creating the compressing writer, writing the body, or closing
// the compressing writer are returned to the caller instead of being dropped.
func DefaultFastHTTPRespHandler(stubCtx context.Context, requestCtx *fasthttp.RequestCtx,
	protoResp proto.Message, body []byte) (retErr error) {
	// compress
	writer := requestCtx.Response.BodyWriter()
	// fasthttp doesn't support getting multiple values of one key from http headers.
	// ctx.Request.Header.Peek is equivalent to req.Header.Get from Go net/http.
	_, c := compressorForTranscoding(
		[]string{bytes2str(requestCtx.Request.Header.Peek(headerContentEncoding))},
		[]string{bytes2str(requestCtx.Request.Header.Peek(headerAcceptEncoding))},
	)
	if c != nil {
		writeCloser, err := c.Compress(writer)
		if err != nil {
			return err
		}
		// Compressing writers commonly flush buffered data on Close, so a Close
		// failure can mean a truncated body. Report it unless an earlier error
		// is already being returned.
		defer func() {
			if closeErr := writeCloser.Close(); retErr == nil {
				retErr = closeErr
			}
		}()
		requestCtx.Response.Header.Set(headerContentEncoding, c.ContentEncoding())
		writer = writeCloser
	}

	// set response content-type
	_, s := serializerForTranscoding(
		[]string{bytes2str(requestCtx.Request.Header.Peek(headerContentType))},
		[]string{bytes2str(requestCtx.Request.Header.Peek(headerAccept))},
	)
	requestCtx.Response.Header.Set(headerContentType, s.ContentType())

	// set status code
	statusCode := GetStatusCodeOnSucceed(stubCtx)
	requestCtx.SetStatusCode(statusCode)

	// write body; 204 and 304 responses must not carry one
	if statusCode != fasthttp.StatusNoContent && statusCode != fasthttp.StatusNotModified {
		if _, err := writer.Write(body); err != nil {
			return err
		}
	}
	return nil
}
// bytes2str converts b to a string without copying the underlying bytes.
//
// The returned string shares memory with b: it is only valid while b's backing
// array is alive and unmodified. Mutating b after the call (or letting a pooled
// buffer be reused) silently changes the string's contents, so callers must not
// retain the result beyond the lifetime of b.
func bytes2str(b []byte) string {
	// Reinterpret the slice header in place; its leading data-pointer and length
	// words are read as a string header.
	sliceHeader := unsafe.Pointer(&b)
	asString := (*string)(sliceHeader)
	return *asString
}
// HandleRequestCtx is the fasthttp request handler of the Router.
//
// It walks the transcoders registered for the request method and handles the
// request with the first one whose pattern matches the request path:
//   - FastHTTPHeaderMatcher builds the stub context;
//   - inbound/outbound compressors and serializers are chosen from the
//     Content-Encoding/Accept-Encoding and Content-Type/Accept headers;
//   - query parameters, path field values and the request body are loaded into
//     a transcodeParams taken from paramsPool;
//   - the transcoder runs and its result is passed to FastHTTPRespHandler.
//
// Failures go to FastHTTPErrHandler:
//   - a header-matching error is wrapped as RetServerDecodeFail;
//   - a transcode error is passed through unchanged;
//   - a response-handler error is wrapped as RetServerEncodeFail;
//   - a path matching no pattern yields RetServerNoFunc.
//
// NOTE(review): the method, path, query keys and query values are converted
// with the zero-copy bytes2str, so they alias fasthttp's request buffers. Per
// fasthttp's RequestHandler contract, those buffers must not be referenced after
// the handler returns. Confirm that transcode and service implementations copy
// anything derived from them that they retain.
func (r *Router) HandleRequestCtx(ctx *fasthttp.RequestCtx) {
	newCtx := context.Background()
	for _, tr := range r.transcoders[bytes2str(ctx.Method())] {
		fieldValues, err := tr.pat.Match(bytes2str(ctx.Path()))
		if err != nil {
			continue // try the next pattern registered for this method
		}

		// header matching
		stubCtx, err := r.opts.FastHTTPHeaderMatcher(newCtx, ctx, r.opts.ServiceName, tr.name)
		if err != nil {
			// A matcher may legitimately return (nil, err); never hand a nil
			// context to the error handler.
			errCtx := stubCtx
			if errCtx == nil {
				errCtx = newCtx
			}
			r.opts.FastHTTPErrHandler(errCtx, ctx, errs.New(errs.RetServerDecodeFail, err.Error()))
			return
		}

		// get inbound/outbound Compressor & Serializer
		reqCompressor, respCompressor := compressorForTranscoding(
			[]string{bytes2str(ctx.Request.Header.Peek(headerContentEncoding))},
			[]string{bytes2str(ctx.Request.Header.Peek(headerAcceptEncoding))},
		)
		reqSerializer, respSerializer := serializerForTranscoding(
			[]string{bytes2str(ctx.Request.Header.Peek(headerContentType))},
			[]string{bytes2str(ctx.Request.Header.Peek(headerAccept))},
		)

		// get query params; repeated keys accumulate into one slice
		form := make(map[string][]string)
		ctx.QueryArgs().VisitAll(func(key, value []byte) {
			k := bytes2str(key)
			form[k] = append(form[k], bytes2str(value))
		})

		// set transcoding params
		params := paramsPool.Get().(*transcodeParams)
		params.reqCompressor = reqCompressor
		params.respCompressor = respCompressor
		params.reqSerializer = reqSerializer
		params.respSerializer = respSerializer
		params.body = bytes.NewBuffer(ctx.PostBody())
		params.fieldValues = fieldValues
		params.form = form

		// transcode, then either write the response or report the failure
		resp, body, err := tr.transcode(stubCtx, params)
		if err != nil {
			r.opts.FastHTTPErrHandler(stubCtx, ctx, err)
		} else if respErr := r.opts.FastHTTPRespHandler(stubCtx, ctx, resp, body); respErr != nil {
			r.opts.FastHTTPErrHandler(stubCtx, ctx, errs.New(errs.RetServerEncodeFail, respErr.Error()))
		}

		// Single cleanup path: return params to paramsPool (and the context's
		// message via putBackCtxMessage) only after the handlers above are done.
		putBackCtxMessage(stubCtx)
		putBackParams(params)
		return
	}
	r.opts.FastHTTPErrHandler(newCtx, ctx, errs.New(errs.RetServerNoFunc, "failed to match any pattern"))
}