diff --git a/generator/go.go b/generator/go.go index e02cf6a..cfe611c 100644 --- a/generator/go.go +++ b/generator/go.go @@ -16,6 +16,7 @@ import ( "os" "path/filepath" "runtime" + "sort" "strconv" "strings" @@ -23,10 +24,11 @@ import ( ) var ( - flagGoBinarystring = flag.Bool("go.binarystring", false, "Always use string for binary instead of []byte") - flagGoJSONEnumnum = flag.Bool("go.json.enumnum", false, "For JSON marshal enums by number instead of name") - flagGoPointers = flag.Bool("go.pointers", false, "Make all fields pointers") - flagGoImportPrefix = flag.String("go.importprefix", "", "Prefix for thrift-generated go package imports") + flagGoBinarystring = flag.Bool("go.binarystring", false, "Always use string for binary instead of []byte") + flagGoJSONEnumnum = flag.Bool("go.json.enumnum", false, "For JSON marshal enums by number instead of name") + flagGoPointers = flag.Bool("go.pointers", false, "Make all fields pointers") + flagGoImportPrefix = flag.String("go.importprefix", "", "Prefix for thrift-generated go package imports") + flagGoGenerateMethods = flag.Bool("go.generate", false, "Add testing/quick compatible Generate methods to enum types") ) var ( @@ -422,6 +424,24 @@ func (e *%s) UnmarshalJSON(b []byte) error { } `, enumName, enumName, enumName, enumName) + if *flagGoGenerateMethods { + valueStrings := make([]string, 0, len(enum.Values)) + for _, val := range enum.Values { + valueStrings = append(valueStrings, strconv.FormatInt(int64(val.Value), 10)) + } + sort.Strings(valueStrings) + valueStringsName := strings.ToLower(enumName) + "Values" + + g.write(out, ` +var %s = []int32{%s} + +func (e *%s) Generate(rand *rand.Rand, size int) reflect.Value { + v := %s(%s[rand.Intn(%d)]) + return reflect.ValueOf(&v) +} +`, valueStringsName, strings.Join(valueStrings, ", "), enumName, enumName, valueStringsName, len(valueNames)) + } + return nil } @@ -621,6 +641,10 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p imports := []string{"fmt"} if len(thrift.Enums) > 0 { imports = append(imports, "strconv") + + if *flagGoGenerateMethods { + imports = append(imports, "math/rand", "reflect") + } } if len(thrift.Includes) > 0 { for _, path := range thrift.Includes {