From 346493888e28197c115973723efd5aa02781a98c Mon Sep 17 00:00:00 2001 From: Hchen Date: Sat, 11 Jan 2025 00:34:45 +0800 Subject: [PATCH] fix: oneof type only write/size once --- protoc-gen-fastpb/generator/generator.go | 92 +++++++++++++----------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/protoc-gen-fastpb/generator/generator.go b/protoc-gen-fastpb/generator/generator.go index 365d269..d9b09c4 100644 --- a/protoc-gen-fastpb/generator/generator.go +++ b/protoc-gen-fastpb/generator/generator.go @@ -218,9 +218,16 @@ func (f *fgMessage) GenFastWrite(g *protogen.GeneratedFile) { g.P(fmt.Sprintf("func (x *%s) FastWrite(buf []byte) (offset int) {", f.name())) // switch case g.P("if x == nil { return offset }") - for i := range f.m.Fields { - number := f.m.Fields[i].Desc.Number() - g.P(fmt.Sprintf("offset += x.fastWriteField%d(buf[offset:])", number)) + for _, field := range f.m.Fields { + number := field.Desc.Number() + if field.Oneof != nil && !field.Oneof.Desc.IsSynthetic() { + g.P(fmt.Sprintf("if _, ok := x.GetKind().(*%s); ok {", field.GoIdent.GoName)) + g.P(fmt.Sprintf("offset += x.fastWriteField%d(buf[offset:])", number)) + g.P(`return offset`) + g.P(`}`) + } else { + g.P(fmt.Sprintf("offset += x.fastWriteField%d(buf[offset:])", number)) + } } g.P(`return offset`) g.P(`}`) @@ -231,9 +238,16 @@ func (f *fgMessage) GenFastSize(g *protogen.GeneratedFile) { g.P(fmt.Sprintf("func (x *%s) Size() (n int) {", f.name())) // switch case g.P("if x == nil { return n }") - for i := range f.m.Fields { - number := f.m.Fields[i].Desc.Number() - g.P(fmt.Sprintf("n += x.sizeField%d()", number)) + for _, field := range f.m.Fields { + number := field.Desc.Number() + if field.Oneof != nil && !field.Oneof.Desc.IsSynthetic() { + g.P(fmt.Sprintf("if _, ok := x.GetKind().(*%s); ok {", field.GoIdent.GoName)) + g.P(fmt.Sprintf("n += x.sizeField%d()", number)) + g.P(`return n`) + g.P(`}`) + } else { + g.P(fmt.Sprintf("n += x.sizeField%d()", number)) + } } g.P(`return n`) g.P(`}`) @@ -288,25 +302,22 @@ func (f *fgField) GenFastRead(g *protogen.GeneratedFile) { func (f *fgField) GenFastWrite(g *protogen.GeneratedFile) { g.P(fmt.Sprintf("func (x *%s) fastWriteField%s(buf []byte) (offset int) {", f.parentName(), f.number)) - - setter := fmt.Sprintf("x.%s", f.name()) - getSetter := fmt.Sprintf("x.Get%s()", f.name()) - // oneof need replace setter - if f.oneofType != "" { - setter = fmt.Sprintf("x.Get%s()", f.name()) - } - switch { - case f.f.Desc.Kind() == protoreflect.MessageKind, f.isPointer: - g.P(fmt.Sprintf("if %s == nil { return offset }", setter)) - case f.f.Desc.IsMap() || f.f.Desc.IsList() || f.f.Desc.Kind() == protoreflect.BytesKind: - g.P(fmt.Sprintf("if len(%s) == 0 { return offset }", setter)) - case f.f.Desc.Kind() == protoreflect.BoolKind: - g.P(fmt.Sprintf("if !%s { return offset }", setter)) - case f.f.Desc.Kind() == protoreflect.StringKind: - g.P(fmt.Sprintf(`if %s == "" { return offset }`, setter)) - default: - g.P(fmt.Sprintf("if %s == 0 { return offset }", setter)) + if f.oneofType == "" { + setter := fmt.Sprintf("x.%s", f.name()) + switch { + case f.f.Desc.Kind() == protoreflect.MessageKind, f.isPointer: + g.P(fmt.Sprintf("if %s == nil { return offset }", setter)) + case f.f.Desc.IsMap() || f.f.Desc.IsList() || f.f.Desc.Kind() == protoreflect.BytesKind: + g.P(fmt.Sprintf("if len(%s) == 0 { return offset }", setter)) + case f.f.Desc.Kind() == protoreflect.BoolKind: + g.P(fmt.Sprintf("if !%s { return offset }", setter)) + case f.f.Desc.Kind() == protoreflect.StringKind: + g.P(fmt.Sprintf(`if %s == "" { return offset }`, setter)) + default: + g.P(fmt.Sprintf("if %s == 0 { return offset }", setter)) + } } + getSetter := fmt.Sprintf("x.Get%s()", f.name()) f.body.bodyFastWrite(g, getSetter, f.number) g.P("return offset") g.P("}") @@ -315,25 +326,22 @@ func (f *fgField) GenFastWrite(g *protogen.GeneratedFile) { func (f *fgField) GenFastSize(g *protogen.GeneratedFile) { g.P(fmt.Sprintf("func (x *%s) sizeField%s() (n int) {", f.parentName(), f.number)) - - setter := fmt.Sprintf("x.%s", f.name()) - getSetter := fmt.Sprintf("x.Get%s()", f.name()) - // oneof need replace setter - if f.oneofType != "" { - setter = fmt.Sprintf("x.Get%s()", f.name()) - } - switch { - case f.f.Desc.Kind() == protoreflect.MessageKind, f.isPointer: - g.P(fmt.Sprintf("if %s == nil { return n }", setter)) - case f.f.Desc.IsMap() || f.f.Desc.IsList() || f.f.Desc.Kind() == protoreflect.BytesKind: - g.P(fmt.Sprintf("if len(%s) == 0 { return n }", setter)) - case f.f.Desc.Kind() == protoreflect.BoolKind: - g.P(fmt.Sprintf("if !%s { return n }", setter)) - case f.f.Desc.Kind() == protoreflect.StringKind: - g.P(fmt.Sprintf(`if %s == "" { return n }`, setter)) - default: - g.P(fmt.Sprintf("if %s == 0 { return n }", setter)) + if f.oneofType == "" { + setter := fmt.Sprintf("x.%s", f.name()) + switch { + case f.f.Desc.Kind() == protoreflect.MessageKind, f.isPointer: + g.P(fmt.Sprintf("if %s == nil { return n }", setter)) + case f.f.Desc.IsMap() || f.f.Desc.IsList() || f.f.Desc.Kind() == protoreflect.BytesKind: + g.P(fmt.Sprintf("if len(%s) == 0 { return n }", setter)) + case f.f.Desc.Kind() == protoreflect.BoolKind: + g.P(fmt.Sprintf("if !%s { return n }", setter)) + case f.f.Desc.Kind() == protoreflect.StringKind: + g.P(fmt.Sprintf(`if %s == "" { return n }`, setter)) + default: + g.P(fmt.Sprintf("if %s == 0 { return n }", setter)) + } } + getSetter := fmt.Sprintf("x.Get%s()", f.name()) f.body.bodyFastSize(g, getSetter, f.number) g.P("return n") g.P("}")