Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion fastpb_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

// Impl implements Protocol.
var Impl impl
var Impl Protocol = impl{}

// When encoding length-prefixed fields, we speculatively set aside some number of bytes
// for the length, encode the data, and then encode the length (shifting the data if necessary
Expand All @@ -44,6 +44,11 @@ func SetSpanCache(enable bool) {
spanCacheEnable = enable
}

// SetImpl replaces the specific codec implementation to support function hijacking etc...
func SetImpl(impl Protocol) {
Impl = impl
}

type impl struct{}

// WriteMessage implements TLV(tag, length, value) and V(value).
Expand Down
99 changes: 99 additions & 0 deletions fastpb_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package fastpb
import (
"fmt"
"testing"
"unicode/utf8"

"google.golang.org/protobuf/encoding/protowire"
)
Expand Down Expand Up @@ -493,3 +494,101 @@ func AssertConsumeTag(name string, got_n protowire.Number, got_t protowire.Type,
panic(fmt.Errorf("%s ConsumeTag num[%d]type[%d] != except[%d][%d]", name, got_n, got_t, exp_n, exp_t))
}
}

type customizedImpl struct {
Protocol
}

func (impl customizedImpl) WriteString(buf []byte, number int32, value string) (n int) {
if number != SkipTagNumber {
n += AppendTag(buf[n:], protowire.Number(number), protowire.BytesType)
}
if !utf8.ValidString(value) {
n += AppendString(buf[n:], "")
} else {
n += AppendString(buf[n:], value)
}
return n
}

func (impl customizedImpl) SizeString(number int32, value string) (n int) {
if number != SkipTagNumber {
n += protowire.SizeVarint(protowire.EncodeTag(protowire.Number(number), protowire.BytesType))
}
if !utf8.ValidString(value) {
// empty string
n += 1
} else {
n += protowire.SizeBytes(len(value))
}
return n
}

func Test_InjectCustomizedImpl(t *testing.T) {
orig := Impl
defer func() {
SetImpl(orig)
}()
SetImpl(customizedImpl{
Protocol: orig,
})

// utf-8 valid
// write
var num int32 = 255
value := "hello world"
exceptSize := 14
size := Impl.SizeString(num, value)
if size != exceptSize {
panic(fmt.Errorf("SizeString[%d] != except[%d]", size, exceptSize))
}
buf := make([]byte, 64)
exceptWs := "fa0f0b68656c6c6f20776f726c64"
wn := Impl.WriteString(buf, num, value)
ws := fmt.Sprintf("%x", buf[:wn])
if wn != size || ws != exceptWs {
panic(fmt.Errorf("WriteString[%d][%s] != except[%d][%s]", wn, ws, size, exceptWs))
}

// read
_type := protowire.BytesType
gotRn, gotRt, offset := protowire.ConsumeTag(buf)
AssertConsumeTag("ReadString", gotRn, gotRt, num, _type)
rv, rn, err := Impl.ReadString(buf[offset:], int8(_type))
if err != nil {
panic(err)
}
rn += offset
if rn != wn || rv != value {
panic(fmt.Errorf("ReadString[%d][%s] != except[%d][%s]", rn, rv, wn, value))
}

// utf-8 invalid
// write
value = "'\xff'"
exceptSize = 3
size = Impl.SizeString(num, value)
if size != exceptSize {
panic(fmt.Errorf("SizeString[%d] != except[%d]", size, exceptSize))
}
buf = make([]byte, 64)
exceptWs = "fa0f00"
wn = Impl.WriteString(buf, num, value)
ws = fmt.Sprintf("%x", buf[:wn])
if wn != exceptSize || ws != exceptWs {
panic(fmt.Errorf("WriteString[%d][%s] != except[%d][%s]", wn, ws, size, exceptWs))
}

// read
_type = protowire.BytesType
gotRn, gotRt, offset = protowire.ConsumeTag(buf)
AssertConsumeTag("ReadString", gotRn, gotRt, num, _type)
rv, rn, err = Impl.ReadString(buf[offset:], int8(_type))
if err != nil {
panic(err)
}
rn += offset
if rn != wn || rv != "" {
panic(fmt.Errorf("ReadString[%d][%s] != except[%d][%s]", rn, rv, wn, value))
}
}
Loading