Skip to content

Commit

Permalink
feat(generic_http_thrift): fail on nil value for required field
Browse files Browse the repository at this point in the history
  • Loading branch information
wasd96040501 committed Sep 16, 2024
1 parent 4e1dbe9 commit 5b5d873
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 16 deletions.
20 changes: 11 additions & 9 deletions pkg/generic/httpthrift_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,20 @@ type HTTPRequest = descriptor.HTTPRequest
type HTTPResponse = descriptor.HTTPResponse

type httpThriftCodec struct {
svcDsc atomic.Value // *idl
provider DescriptorProvider
binaryWithBase64 bool
convOpts conv.Options // used for dynamicgo conversion
convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on
dynamicgoEnabled bool
useRawBodyForHTTPResp bool
svcName string
svcDsc atomic.Value // *idl
provider DescriptorProvider
binaryWithBase64 bool
convOpts conv.Options // used for dynamicgo conversion
convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on
dynamicgoEnabled bool
useRawBodyForHTTPResp bool
failOnNilValueForRequiredField bool
svcName string
}

func newHTTPThriftCodec(p DescriptorProvider, opts *Options) *httpThriftCodec {
svc := <-p.Provide()
c := &httpThriftCodec{provider: p, binaryWithBase64: false, dynamicgoEnabled: false, useRawBodyForHTTPResp: opts.useRawBodyForHTTPResp, svcName: svc.Name}
c := &httpThriftCodec{provider: p, binaryWithBase64: false, dynamicgoEnabled: false, useRawBodyForHTTPResp: opts.useRawBodyForHTTPResp, failOnNilValueForRequiredField: opts.failOnNilValueForRequiredField, svcName: svc.Name}
if dp, ok := p.(GetProviderOption); ok && dp.Option().DynamicGoEnabled {
c.dynamicgoEnabled = true

Expand Down Expand Up @@ -95,6 +96,7 @@ func (c *httpThriftCodec) configureHTTPRequestWriter(writer *thrift.WriteHTTPReq
if c.dynamicgoEnabled {
writer.SetDynamicGo(&c.convOpts, &c.convOptsWithThriftBase)
}
writer.SetFailOnNilValueForRequiredField(c.failOnNilValueForRequiredField)
}

func (c *httpThriftCodec) configureHTTPResponseReader(reader *thrift.ReadHTTPResponse) {
Expand Down
35 changes: 35 additions & 0 deletions pkg/generic/httpthrift_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,41 @@ func TestHttpThriftCodecWithDynamicGo(t *testing.T) {
test.Assert(t, ok)
}

func TestHttpThriftCodecWithFailOnNilValueForRequired(t *testing.T) {
// without dynamicgo
p, err := NewThriftFileProvider("./http_test/idl/binary_echo.thrift")
test.Assert(t, err == nil)
gOpts := &Options{dynamicgoConvOpts: DefaultHTTPDynamicGoConvOpts, failOnNilValueForRequiredField: true}
htc := newHTTPThriftCodec(p, gOpts)
test.Assert(t, !htc.dynamicgoEnabled)
test.Assert(t, !htc.useRawBodyForHTTPResp)
test.Assert(t, htc.failOnNilValueForRequiredField)
test.DeepEqual(t, htc.convOpts, conv.Options{})
test.DeepEqual(t, htc.convOptsWithThriftBase, conv.Options{})
defer htc.Close()
test.Assert(t, htc.Name() == "HttpThrift")

req := &HTTPRequest{Request: getStdHttpRequest()}
// wrong
method, err := htc.getMethod("test")
test.Assert(t, err.Error() == "req is invalid, need descriptor.HTTPRequest" && method == nil)
// right
method, err = htc.getMethod(req)
test.Assert(t, err == nil && method.Name == "BinaryEcho")
test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone)
test.Assert(t, htc.svcName == "ExampleService")

rw := htc.getMessageReaderWriter()
_, ok := rw.(error)
test.Assert(t, !ok)

rw = htc.getMessageReaderWriter()
_, ok = rw.(thrift.MessageWriter)
test.Assert(t, ok)
_, ok = rw.(thrift.MessageReader)
test.Assert(t, ok)
}

func getStdHttpRequest() *http.Request {
body := map[string]interface{}{
"msg": []byte("hello"),
Expand Down
9 changes: 9 additions & 0 deletions pkg/generic/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type Options struct {
dynamicgoConvOpts conv.Options
// flag to set whether to store http resp body into HTTPResponse.RawBody
useRawBodyForHTTPResp bool
// will return error when field is required but input value is nil
failOnNilValueForRequiredField bool
}

type Option struct {
Expand Down Expand Up @@ -68,3 +70,10 @@ func UseRawBodyForHTTPResp(enable bool) Option {
opt.useRawBodyForHTTPResp = enable
}}
}

// will return error when field is required but input value is nil
func WithFailOnNilValueForRequiredField(enable bool) Option {
return Option{F: func(opt *Options) {
opt.failOnNilValueForRequiredField = enable
}}
}
17 changes: 11 additions & 6 deletions pkg/generic/thrift/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ func NewHTTPReaderWriter(svc *descriptor.ServiceDescriptor) *HTTPReaderWriter {

// WriteHTTPRequest implement of MessageWriter
type WriteHTTPRequest struct {
svc *descriptor.ServiceDescriptor
binaryWithBase64 bool
convOpts conv.Options // used for dynamicgo conversion
convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on
dynamicgoEnabled bool
svc *descriptor.ServiceDescriptor
binaryWithBase64 bool
convOpts conv.Options // used for dynamicgo conversion
convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on
dynamicgoEnabled bool
failOnNilValueForRequiredField bool // will return error when field is required but input value is nil
}

var (
Expand All @@ -71,6 +72,10 @@ func (w *WriteHTTPRequest) SetBinaryWithBase64(enable bool) {
w.binaryWithBase64 = enable
}

func (w *WriteHTTPRequest) SetFailOnNilValueForRequiredField(enable bool) {
w.failOnNilValueForRequiredField = enable
}

// SetDynamicGo ...
func (w *WriteHTTPRequest) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.Options) {
w.convOpts = *convOpts
Expand All @@ -94,7 +99,7 @@ func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out bufiox.Writer,
requestBase = nil
}
bw := thrift.NewBufferWriter(out)
err = wrapStructWriter(ctx, req, bw, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64})
err = wrapStructWriter(ctx, req, bw, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64, failOnNilValueForRequiredField: w.failOnNilValueForRequiredField})
bw.Recycle()
return err
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/generic/thrift/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type writerOption struct {
requestBase *base.Base // request base from metahandler
// decoding Base64 to binary
binaryWithBase64 bool
// will return error when field is required but input value is nil
failOnNilValueForRequiredField bool
}

type writer func(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error
Expand Down Expand Up @@ -778,6 +780,9 @@ func writeHTTPRequest(ctx context.Context, val interface{}, out *thrift.BufferWr

if v == nil {
if !field.Optional {
if opt != nil && opt.failOnNilValueForRequiredField {
return fmt.Errorf("value of field [%s] is nil", name)
}
if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil {
return err
}
Expand Down
32 changes: 31 additions & 1 deletion pkg/generic/thrift/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,36 @@ func Test_writeHTTPRequest(t *testing.T) {
},
false,
},
{
"writeStructRequiredFail",
args{
val: &descriptor.HTTPRequest{
Body: map[string]interface{}{"hello": nil},
},

t: &descriptor.TypeDescriptor{
Type: descriptor.STRUCT,
Key: &descriptor.TypeDescriptor{Type: descriptor.STRING},
Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING},
Struct: &descriptor.StructDescriptor{
Name: "Demo",
FieldsByName: map[string]*descriptor.FieldDescriptor{
"hello": {
Name: "hello",
ID: 1,
Required: true,
Type: &descriptor.TypeDescriptor{Type: descriptor.STRING},
HTTPMapping: descriptor.DefaultNewMapping("hello"),
},
},
},
},
opt: &writerOption{
failOnNilValueForRequiredField: true,
},
},
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -1434,7 +1464,7 @@ func getReqPbBody() (proto.Message, error) {
path := "main.proto"
content := `
package kitex.test.server;
message BizReq {
optional int32 user_id = 1;
optional string user_name = 2;
Expand Down

0 comments on commit 5b5d873

Please sign in to comment.