From 1821c94c335acbac7f9a483e771faf25fd60d9cb Mon Sep 17 00:00:00 2001 From: Matt Jones Date: Wed, 23 Sep 2015 10:49:52 -0700 Subject: [PATCH 1/2] generate Generate methods for enums (compatible with testing/quick) --- generator/go.go | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/generator/go.go b/generator/go.go index e02cf6a..cb6b38b 100644 --- a/generator/go.go +++ b/generator/go.go @@ -16,6 +16,7 @@ import ( "os" "path/filepath" "runtime" + "sort" "strconv" "strings" @@ -422,6 +423,22 @@ func (e *%s) UnmarshalJSON(b []byte) error { } `, enumName, enumName, enumName, enumName) + valueStrings := make([]string, 0, len(enum.Values)) + for _, val := range enum.Values { + valueStrings = append(valueStrings, strconv.FormatInt(int64(val.Value), 10)) + } + sort.Strings(valueStrings) + valueStringsName := strings.ToLower(enumName) + "Values" + + g.write(out, ` +var %s = []int32{%s} + +func (e *%s) Generate(rand *rand.Rand, size int) reflect.Value { + v := %s(%s[rand.Intn(%d)]) + return reflect.ValueOf(&v) +} +`, valueStringsName, strings.Join(valueStrings, ", "), enumName, enumName, valueStringsName, len(valueNames)) + return nil } @@ -620,7 +637,7 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p // Imports imports := []string{"fmt"} if len(thrift.Enums) > 0 { - imports = append(imports, "strconv") + imports = append(imports, "strconv", "math/rand", "reflect") } if len(thrift.Includes) > 0 { for _, path := range thrift.Includes { From 2631e8fb4143f8f91ae2b7e8f9dd49ce42b8bf8a Mon Sep 17 00:00:00 2001 From: Matt Jones Date: Tue, 29 Sep 2015 11:17:47 -0700 Subject: [PATCH 2/2] add a flag for Enum Generate methods --- generator/go.go | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/generator/go.go b/generator/go.go index cb6b38b..cfe611c 100644 --- a/generator/go.go +++ b/generator/go.go @@ -24,10 +24,11 @@ import ( ) var ( - flagGoBinarystring = flag.Bool("go.binarystring", false, "Always use string for binary instead of []byte") - flagGoJSONEnumnum = flag.Bool("go.json.enumnum", false, "For JSON marshal enums by number instead of name") - flagGoPointers = flag.Bool("go.pointers", false, "Make all fields pointers") - flagGoImportPrefix = flag.String("go.importprefix", "", "Prefix for thrift-generated go package imports") + flagGoBinarystring = flag.Bool("go.binarystring", false, "Always use string for binary instead of []byte") + flagGoJSONEnumnum = flag.Bool("go.json.enumnum", false, "For JSON marshal enums by number instead of name") + flagGoPointers = flag.Bool("go.pointers", false, "Make all fields pointers") + flagGoImportPrefix = flag.String("go.importprefix", "", "Prefix for thrift-generated go package imports") + flagGoGenerateMethods = flag.Bool("go.generate", false, "Add testing/quick compatible Generate methods to enum types") ) var ( @@ -423,14 +424,15 @@ func (e *%s) UnmarshalJSON(b []byte) error { } `, enumName, enumName, enumName, enumName) - valueStrings := make([]string, 0, len(enum.Values)) - for _, val := range enum.Values { - valueStrings = append(valueStrings, strconv.FormatInt(int64(val.Value), 10)) - } - sort.Strings(valueStrings) - valueStringsName := strings.ToLower(enumName) + "Values" + if *flagGoGenerateMethods { + valueStrings := make([]string, 0, len(enum.Values)) + for _, val := range enum.Values { + valueStrings = append(valueStrings, strconv.FormatInt(int64(val.Value), 10)) + } + sort.Strings(valueStrings) + valueStringsName := strings.ToLower(enumName) + "Values" - g.write(out, ` + g.write(out, ` var %s = []int32{%s} func (e *%s) Generate(rand *rand.Rand, size int) reflect.Value { @@ -438,6 +440,7 @@ func (e *%s) Generate(rand *rand.Rand, size int) reflect.Value { return reflect.ValueOf(&v) } `, valueStringsName, strings.Join(valueStrings, ", "), enumName, enumName, valueStringsName, len(valueNames)) + } return nil } @@ -637,7 +640,11 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p // Imports imports := []string{"fmt"} if len(thrift.Enums) > 0 { - imports = append(imports, "strconv", "math/rand", "reflect") + imports = append(imports, "strconv") + + if *flagGoGenerateMethods { + imports = append(imports, "math/rand", "reflect") + } } if len(thrift.Includes) > 0 { for _, path := range thrift.Includes {