From 12e411f97507fdcf5375c922e516ff2a04e9e741 Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Fri, 8 Jul 2022 16:09:21 -0500 Subject: [PATCH 01/14] start on work for generics in the AST tree --- mockgen/parse.go | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/mockgen/parse.go b/mockgen/parse.go index 21c0d70a..a4d5e48d 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -340,12 +340,25 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode for _, m := range embeddedIface.Methods { iface.AddMethod(m) } - case *ast.SelectorExpr: + case *ast.SelectorExpr, *ast.IndexExpr: + + var ( + ident *ast.Ident + selIdent *ast.Ident + ) + + if se, ok := v.(*ast.SelectorExpr); ok { + ident, selIdent = se.X.(*ast.Ident), se.Sel + } else { + ie := v.(*ast.IndexExpr) + se := ie.X.(*ast.SelectorExpr) + ident, selIdent = se.X.(*ast.Ident), se.Sel + } // Embedded interface in another package. - filePkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() + filePkg, sel := ident.String(), selIdent.String() embeddedPkg, ok := p.imports[filePkg] if !ok { - return nil, p.errorf(v.X.Pos(), "unknown package %s", filePkg) + return nil, p.errorf(ident.Pos(), "unknown package %s", filePkg) } var embeddedIface *model.Interface @@ -383,6 +396,16 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode for _, m := range embeddedIface.Methods { iface.AddMethod(m) } + // case *ast.IndexExpr: + // // Embedded generic interface in another package + // idxExpr := v.X + // ident := idxExpr.(*ast.Ident) + // filePkg, sel := ident.String(), ident.Name + // fmt.Printf("filePkg=%s, sel=%s", filePkg, sel) + // _, ok := p.imports[filePkg] + // if !ok { + // return nil, p.errorf(v.X.Pos(), "unknown package %s", filePkg) + // } default: return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) } From 98733d82c75527cec10315989444ce79a56b40d7 Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Mon, 11 Jul 2022 13:12:23 -0500 Subject: [PATCH 02/14] work out interface with generic params and swapping out typenames per usage of generic icace --- mockgen/model/model.go | 118 +++++++++++++++++++++++++++++++++++++++++ mockgen/parse.go | 88 +++++++++++++++++++++++------- 2 files changed, 187 insertions(+), 19 deletions(-) diff --git a/mockgen/model/model.go b/mockgen/model/model.go index 94d7f4ba..6a4a060a 100644 --- a/mockgen/model/model.go +++ b/mockgen/model/model.go @@ -61,6 +61,36 @@ type Interface struct { TypeParams []*Parameter } +// TypeParamIndexByName returns the index of the type parameter matching on name. If none matching, returns -1, nil. +// +// This is especially useful for generics where interface is something like this: +// Doer[T any, K any]{ +// Start(T) +// Add(K) error +// Stop() []K +// } +// +// But it is used like this: +// type MyDoer = Doer[types.SomeType, otherPkg.SomeOtherThing] +// or as an embedded interface: +// type MyDoer interface { +// Doer[types.SomeType, otherPkg.SomeOtherThing] +// } +// +// If parsing the Add method for an implementation of this interface, +// we need to be able to swap out K for whatever the actual type is. +// K will be at index 1 of the interface's TypeParams, +// but it will also be at index 1 of the actual type params in either the +// definition of the interface being mocked or the generic interface it embeds. +func (intf *Interface) TypeParamIndexByName(name string) int { + for i, p := range intf.TypeParams { + if p.Name == name { + return i + } + } + return -1 +} + // Print writes the interface name and its methods. func (intf *Interface) Print(w io.Writer) { _, _ = fmt.Fprintf(w, "interface %s\n", intf.Name) @@ -92,6 +122,22 @@ type Method struct { Variadic *Parameter // may be nil } +func (m *Method) Clone() *Method { + mm := &Method{ + Name: m.Name, + In: make([]*Parameter, 0), + Out: make([]*Parameter, 0), + Variadic: m.Variadic.clone(), + } + for _, in := range m.In { + mm.In = append(mm.In, in.clone()) + } + for _, out := range m.Out { + mm.Out = append(mm.Out, out.clone()) + } + return mm +} + // Print writes the method name and its signature. func (m *Method) Print(w io.Writer) { _, _ = fmt.Fprintf(w, " - method %s\n", m.Name) @@ -131,6 +177,16 @@ type Parameter struct { Type Type } +func (p *Parameter) clone() *Parameter { + if p == nil { + return nil + } + return &Parameter{ + Name: p.Name, + Type: p.Type.clone(), + } +} + // Print writes a method parameter. func (p *Parameter) Print(w io.Writer) { n := p.Name @@ -144,6 +200,7 @@ func (p *Parameter) Print(w io.Writer) { type Type interface { String(pm map[string]string, pkgOverride string) string addImports(im map[string]bool) + clone() Type } func init() { @@ -180,6 +237,13 @@ func (at *ArrayType) String(pm map[string]string, pkgOverride string) string { func (at *ArrayType) addImports(im map[string]bool) { at.Type.addImports(im) } +func (at *ArrayType) clone() Type { + return &ArrayType{ + Len: at.Len, + Type: at.Type.clone(), + } +} + // ChanType is a channel type. type ChanType struct { Dir ChanDir // 0, 1 or 2 @@ -199,6 +263,13 @@ func (ct *ChanType) String(pm map[string]string, pkgOverride string) string { func (ct *ChanType) addImports(im map[string]bool) { ct.Type.addImports(im) } +func (ct *ChanType) clone() Type { + return &ChanType{ + Dir: ct.Dir, + Type: ct.Type.clone(), + } +} + // ChanDir is a channel direction. type ChanDir int @@ -247,6 +318,21 @@ func (ft *FuncType) addImports(im map[string]bool) { } } +func (ft *FuncType) clone() Type { + ftt := &FuncType{ + In: make([]*Parameter, 0), + Out: make([]*Parameter, 0), + Variadic: ft.Variadic.clone(), + } + for _, in := range ft.In { + ftt.In = append(ftt.In, in.clone()) + } + for _, out := range ft.Out { + ftt.Out = append(ftt.Out, out.clone()) + } + return ftt +} + // MapType is a map type. type MapType struct { Key, Value Type @@ -261,6 +347,13 @@ func (mt *MapType) addImports(im map[string]bool) { mt.Value.addImports(im) } +func (mt *MapType) clone() Type { + return &MapType{ + Key: mt.Key, + Value: mt.Value.clone(), + } +} + // NamedType is an exported type in a package. type NamedType struct { Package string // may be empty @@ -287,6 +380,15 @@ func (nt *NamedType) addImports(im map[string]bool) { nt.TypeParams.addImports(im) } +func (nt *NamedType) clone() Type { + ntt := &NamedType{ + Package: nt.Package, + Type: nt.Type, + TypeParams: nt.TypeParams.clone().(*TypeParametersType), + } + return ntt +} + // PointerType is a pointer to another type. type PointerType struct { Type Type @@ -297,11 +399,16 @@ func (pt *PointerType) String(pm map[string]string, pkgOverride string) string { } func (pt *PointerType) addImports(im map[string]bool) { pt.Type.addImports(im) } +func (pt *PointerType) clone() Type { + return &PointerType{Type: pt.Type.clone()} +} + // PredeclaredType is a predeclared type such as "int". type PredeclaredType string func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) } func (pt PredeclaredType) addImports(map[string]bool) {} +func (pt PredeclaredType) clone() Type { return PredeclaredType(pt) } // TypeParametersType contains type paramters for a NamedType. type TypeParametersType struct { @@ -333,6 +440,17 @@ func (tp *TypeParametersType) addImports(im map[string]bool) { } } +func (tp *TypeParametersType) clone() Type { + if tp == nil { + return nil + } + tpt := &TypeParametersType{} + for _, t := range tp.TypeParameters { + tpt.TypeParameters = append(tpt.TypeParameters, t.clone()) + } + return tpt +} + // The following code is intended to be called by the program generated by ../reflect.go. // InterfaceFromInterfaceType returns a pointer to an interface for the diff --git a/mockgen/parse.go b/mockgen/parse.go index a4d5e48d..3b0ea66c 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -340,26 +340,46 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode for _, m := range embeddedIface.Methods { iface.AddMethod(m) } - case *ast.SelectorExpr, *ast.IndexExpr: - + case *ast.SelectorExpr, *ast.IndexExpr, *ast.IndexListExpr: + // Embedded interface in another package. + // *ast.SelectorExpr for embedded legacy iface + // *ast.IndexExpr for embedded generic iface with single index e.g. DoSomething[T] + // *ast.IndexListExpr for embedded generic iface with multiple indexes e.g. DoSomething[T, K] var ( - ident *ast.Ident - selIdent *ast.Ident + ident *ast.Ident + selIdent *ast.Ident + path string + typeParams []model.Type //unsure ) if se, ok := v.(*ast.SelectorExpr); ok { ident, selIdent = se.X.(*ast.Ident), se.Sel - } else { - ie := v.(*ast.IndexExpr) + } else if ie, ok := v.(*ast.IndexExpr); ok { se := ie.X.(*ast.SelectorExpr) ident, selIdent = se.X.(*ast.Ident), se.Sel + typParam, err := p.parseType(pkg, ie.Index, tps) + if err != nil { + return nil, err + } + typeParams = append(typeParams, typParam) + } else { + ile := v.(*ast.IndexListExpr) + se := ile.X.(*ast.SelectorExpr) + ident, selIdent = se.X.(*ast.Ident), se.Sel + for i := range ile.Indices { + typParam, err := p.parseType(pkg, ile.Indices[i], tps) + if err != nil { + return nil, err + } + typeParams = append(typeParams, typParam) + } } - // Embedded interface in another package. filePkg, sel := ident.String(), selIdent.String() embeddedPkg, ok := p.imports[filePkg] if !ok { return nil, p.errorf(ident.Pos(), "unknown package %s", filePkg) } + path = embeddedPkg.Path() var embeddedIface *model.Interface var err error @@ -370,7 +390,6 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode return nil, err } } else { - path := embeddedPkg.Path() parser := embeddedPkg.Parser() if parser == nil { ip, err := p.parsePackage(path) @@ -393,19 +412,50 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode } // Copy the methods. // TODO: apply shadowing rules. + isGeneric := len(typeParams) > 0 for _, m := range embeddedIface.Methods { - iface.AddMethod(m) + if !isGeneric { + // trivial part - no generic params to consider + iface.AddMethod(m) + continue + } + // non-trivial part - we have to match up the as-used type params with the as-defined + // defined as DoSomething[T any, K any] + // used as DoSomething[somPkg.SomeType, int64] + // meaning methods may be like in definition: + // Do(T) (K, error) + // but need to be like this in implementation: + // Do(somePkg.SomeType) (int64, error) + gm := m.Clone() // clone so we can change without changing source def + + // overwrite all typed params for incoming/outgoing params + // to get the implementor-specified typing over the definition-specified typing + + for _, pim := range gm.In { + if nt, ok := pim.Type.(*model.NamedType); ok && nt.TypeParams != nil { + for i, tp := range nt.TypeParams.TypeParameters { + if srcParamIdx := embeddedIface.TypeParamIndexByName(tp.String(nil, "")); srcParamIdx > -1 && srcParamIdx < len(typeParams) { + dstParamTyp := typeParams[srcParamIdx] + nt.TypeParams.TypeParameters[i] = dstParamTyp + } + } + } + } + for _, out := range gm.Out { + if nt, ok := out.Type.(*model.NamedType); ok && nt.TypeParams != nil { + for i, tp := range nt.TypeParams.TypeParameters { + if srcParamIdx := embeddedIface.TypeParamIndexByName(tp.String(nil, "")); srcParamIdx > -1 && srcParamIdx < len(typeParams) { + dstParamTyp := typeParams[srcParamIdx] + nt.TypeParams.TypeParameters[i] = dstParamTyp + } + } + } + } + + //TODO, anything else we need to do here to support generics? + iface.AddMethod(gm) + } - // case *ast.IndexExpr: - // // Embedded generic interface in another package - // idxExpr := v.X - // ident := idxExpr.(*ast.Ident) - // filePkg, sel := ident.String(), ident.Name - // fmt.Printf("filePkg=%s, sel=%s", filePkg, sel) - // _, ok := p.imports[filePkg] - // if !ok { - // return nil, p.errorf(v.X.Pos(), "unknown package %s", filePkg) - // } default: return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) } From 5c6c39b6a38c7618af8e93071078c6eacc990aad Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Mon, 11 Jul 2022 14:57:28 -0500 Subject: [PATCH 03/14] handle embed of external or local generic iface --- mockgen/internal/tests/generics/generics.go | 6 + .../internal/tests/generics/other/other.go | 6 + mockgen/model/model.go | 12 +- mockgen/parse.go | 142 +++++++++++++----- 4 files changed, 122 insertions(+), 44 deletions(-) diff --git a/mockgen/internal/tests/generics/generics.go b/mockgen/internal/tests/generics/generics.go index 0b389622..6f0cc815 100644 --- a/mockgen/internal/tests/generics/generics.go +++ b/mockgen/internal/tests/generics/generics.go @@ -38,3 +38,9 @@ type StructType struct{} type StructType2 struct{} type AliasType Baz[other.Three] + +type EmbeddingIface interface { + Bar[other.Three, error] + other.Otherer[StructType, other.Five] + LocalFunc() error +} diff --git a/mockgen/internal/tests/generics/other/other.go b/mockgen/internal/tests/generics/other/other.go index 9265422b..374aac19 100644 --- a/mockgen/internal/tests/generics/other/other.go +++ b/mockgen/internal/tests/generics/other/other.go @@ -9,3 +9,9 @@ type Three struct{} type Four struct{} type Five interface{} + +type Otherer[T any, R any] interface { + DoT(T) error + DoR(R) error + MakeThem() (T, R, error) +} diff --git a/mockgen/model/model.go b/mockgen/model/model.go index 6a4a060a..fc8cda85 100644 --- a/mockgen/model/model.go +++ b/mockgen/model/model.go @@ -381,10 +381,16 @@ func (nt *NamedType) addImports(im map[string]bool) { } func (nt *NamedType) clone() Type { + if nt == nil { + return nil + } + ntt := &NamedType{ - Package: nt.Package, - Type: nt.Type, - TypeParams: nt.TypeParams.clone().(*TypeParametersType), + Package: nt.Package, + Type: nt.Type, + } + if nt.TypeParams != nil { + ntt.TypeParams = nt.TypeParams.clone().(*TypeParametersType) } return ntt } diff --git a/mockgen/parse.go b/mockgen/parse.go index 3b0ea66c..e5ff06b1 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -340,46 +340,13 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode for _, m := range embeddedIface.Methods { iface.AddMethod(m) } - case *ast.SelectorExpr, *ast.IndexExpr, *ast.IndexListExpr: + case *ast.SelectorExpr: // Embedded interface in another package. - // *ast.SelectorExpr for embedded legacy iface - // *ast.IndexExpr for embedded generic iface with single index e.g. DoSomething[T] - // *ast.IndexListExpr for embedded generic iface with multiple indexes e.g. DoSomething[T, K] - var ( - ident *ast.Ident - selIdent *ast.Ident - path string - typeParams []model.Type //unsure - ) - - if se, ok := v.(*ast.SelectorExpr); ok { - ident, selIdent = se.X.(*ast.Ident), se.Sel - } else if ie, ok := v.(*ast.IndexExpr); ok { - se := ie.X.(*ast.SelectorExpr) - ident, selIdent = se.X.(*ast.Ident), se.Sel - typParam, err := p.parseType(pkg, ie.Index, tps) - if err != nil { - return nil, err - } - typeParams = append(typeParams, typParam) - } else { - ile := v.(*ast.IndexListExpr) - se := ile.X.(*ast.SelectorExpr) - ident, selIdent = se.X.(*ast.Ident), se.Sel - for i := range ile.Indices { - typParam, err := p.parseType(pkg, ile.Indices[i], tps) - if err != nil { - return nil, err - } - typeParams = append(typeParams, typParam) - } - } - filePkg, sel := ident.String(), selIdent.String() + filePkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() embeddedPkg, ok := p.imports[filePkg] if !ok { - return nil, p.errorf(ident.Pos(), "unknown package %s", filePkg) + return nil, p.errorf(v.X.Pos(), "unknown package %s", filePkg) } - path = embeddedPkg.Path() var embeddedIface *model.Interface var err error @@ -390,6 +357,7 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode return nil, err } } else { + path := embeddedPkg.Path() parser := embeddedPkg.Parser() if parser == nil { ip, err := p.parsePackage(path) @@ -412,13 +380,105 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode } // Copy the methods. // TODO: apply shadowing rules. - isGeneric := len(typeParams) > 0 for _, m := range embeddedIface.Methods { - if !isGeneric { - // trivial part - no generic params to consider - iface.AddMethod(m) - continue + iface.AddMethod(m) + } + case *ast.IndexExpr, *ast.IndexListExpr: + // generic embedded interface + // may or may not be external pkg + // *ast.IndexExpr for embedded generic iface with single index e.g. DoSomething[T] + // *ast.IndexListExpr for embedded generic iface with multiple indexes e.g. DoSomething[T, K] + var ( + ident *ast.Ident + selIdent *ast.Ident // selector identity only used in external import + path string + typeParams []model.Type //unsure + ) + if ie, ok := v.(*ast.IndexExpr); ok { + // singular type param + if se, ok := ie.X.(*ast.SelectorExpr); ok { + ident, selIdent = se.X.(*ast.Ident), se.Sel + } else { + ident = ie.X.(*ast.Ident) + } + typParam, err := p.parseType(pkg, ie.Index, tps) + if err != nil { + return nil, err + } + typeParams = append(typeParams, typParam) + } else { + // multiple type params + ile := v.(*ast.IndexListExpr) + if se, ok := ile.X.(*ast.SelectorExpr); ok { + ident, selIdent = se.X.(*ast.Ident), se.Sel + } else { + ident = ile.X.(*ast.Ident) + } + for i := range ile.Indices { + typParam, err := p.parseType(pkg, ile.Indices[i], tps) + if err != nil { + return nil, err + } + typeParams = append(typeParams, typParam) + } + } + + var ( + embeddedIface *model.Interface + err error + ) + + if selIdent == nil { + // trivial part: defined in this pkg + embeddedIfaceType := p.auxInterfaces.Get(pkg, ident.Name) + if embeddedIfaceType == nil { + embeddedIfaceType = p.importedInterfaces.Get(pkg, ident.Name) + } + embeddedIface, err = p.parseInterface(ident.Name, pkg, embeddedIfaceType) + if err != nil { + return nil, err + } + } else { + // non-trivial part: defined in external pkg + filePkg, sel := ident.String(), selIdent.String() + embeddedPkg, ok := p.imports[filePkg] + if !ok { + return nil, p.errorf(ident.Pos(), "unknown package %s", filePkg) + } + path = embeddedPkg.Path() + + embeddedIfaceType := p.auxInterfaces.Get(filePkg, sel) + if embeddedIfaceType != nil { + embeddedIface, err = p.parseInterface(sel, filePkg, embeddedIfaceType) + if err != nil { + return nil, err + } + } else { + parser := embeddedPkg.Parser() + if parser == nil { + ip, err := p.parsePackage(path) + if err != nil { + return nil, p.errorf(v.Pos(), "could not parse package %s: %v", path, err) + } + parser = ip + p.imports[filePkg] = importedPkg{ + path: embeddedPkg.Path(), + parser: parser, + } + } + if embeddedIfaceType = parser.importedInterfaces.Get(path, sel); embeddedIfaceType == nil { + return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel) + } + embeddedIface, err = parser.parseInterface(sel, path, embeddedIfaceType) + if err != nil { + return nil, err + } } + } + + // Copy the methods. + // TODO: apply shadowing rules. + for _, m := range embeddedIface.Methods { // non-trivial part - we have to match up the as-used type params with the as-defined // defined as DoSomething[T any, K any] // used as DoSomething[somPkg.SomeType, int64] From 1f03af7e746adec670ff3682335c2ee81bfa0eb9 Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Mon, 11 Jul 2022 15:50:50 -0500 Subject: [PATCH 04/14] comment tweaks/cleanup --- mockgen/model/model.go | 10 +++++++++- mockgen/parse.go | 7 ++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/mockgen/model/model.go b/mockgen/model/model.go index fc8cda85..5ed9462c 100644 --- a/mockgen/model/model.go +++ b/mockgen/model/model.go @@ -61,7 +61,7 @@ type Interface struct { TypeParams []*Parameter } -// TypeParamIndexByName returns the index of the type parameter matching on name. If none matching, returns -1, nil. +// TypeParamIndexByName returns the index of the type parameter matching on name. If none matching, returns -1. // // This is especially useful for generics where interface is something like this: // Doer[T any, K any]{ @@ -71,9 +71,11 @@ type Interface struct { // } // // But it is used like this: +// [ T , K ] // type MyDoer = Doer[types.SomeType, otherPkg.SomeOtherThing] // or as an embedded interface: // type MyDoer interface { +// [ T , K ] // Doer[types.SomeType, otherPkg.SomeOtherThing] // } // @@ -122,6 +124,12 @@ type Method struct { Variadic *Parameter // may be nil } +// Clone makes a deep clone of a Method. +// +// This is useful specifically for generics so that generic parameters +// from source interface methods (e.g. Iface[T any, R any]) +// can be swapped out with actualized types from a referencing entity +// (e.g. type OtherIface = Iface[external.Foo, Baz]). func (m *Method) Clone() *Method { mm := &Method{ Name: m.Name, diff --git a/mockgen/parse.go b/mockgen/parse.go index e5ff06b1..cd039881 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -392,10 +392,9 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode ident *ast.Ident selIdent *ast.Ident // selector identity only used in external import path string - typeParams []model.Type //unsure + typeParams []model.Type // normalize to slice whether IndexExpr or IndexListExpr to make it consistent to work with ) if ie, ok := v.(*ast.IndexExpr); ok { - // singular type param if se, ok := ie.X.(*ast.SelectorExpr); ok { ident, selIdent = se.X.(*ast.Ident), se.Sel } else { @@ -407,7 +406,6 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode } typeParams = append(typeParams, typParam) } else { - // multiple type params ile := v.(*ast.IndexListExpr) if se, ok := ile.X.(*ast.SelectorExpr); ok { ident, selIdent = se.X.(*ast.Ident), se.Sel @@ -429,7 +427,7 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode ) if selIdent == nil { - // trivial part: defined in this pkg + // trivial part: defined in source pkg embeddedIfaceType := p.auxInterfaces.Get(pkg, ident.Name) if embeddedIfaceType == nil { embeddedIfaceType = p.importedInterfaces.Get(pkg, ident.Name) @@ -512,7 +510,6 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode } } - //TODO, anything else we need to do here to support generics? iface.AddMethod(gm) } From ac74355d2d11f438c3e699151e998d0fa1a60da1 Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Wed, 13 Jul 2022 14:16:46 -0500 Subject: [PATCH 05/14] consolidate logic to parser interface to model - imported or same pkg --- mockgen/parse.go | 188 ++++++++++++++++++++--------------------------- 1 file changed, 80 insertions(+), 108 deletions(-) diff --git a/mockgen/parse.go b/mockgen/parse.go index cd039881..8458a7ae 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -304,37 +304,9 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode iface.AddMethod(m) case *ast.Ident: // Embedded interface in this package. - embeddedIfaceType := p.auxInterfaces.Get(pkg, v.String()) - if embeddedIfaceType == nil { - embeddedIfaceType = p.importedInterfaces.Get(pkg, v.String()) - } - - var embeddedIface *model.Interface - if embeddedIfaceType != nil { - var err error - embeddedIface, err = p.parseInterface(v.String(), pkg, embeddedIfaceType) - if err != nil { - return nil, err - } - } else { - // This is built-in error interface. - if v.String() == model.ErrorInterface.Name { - embeddedIface = &model.ErrorInterface - } else { - ip, err := p.parsePackage(pkg) - if err != nil { - return nil, p.errorf(v.Pos(), "could not parse package %s: %v", pkg, err) - } - - if embeddedIfaceType = ip.importedInterfaces.Get(pkg, v.String()); embeddedIfaceType == nil { - return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", pkg, v.String()) - } - - embeddedIface, err = ip.parseInterface(v.String(), pkg, embeddedIfaceType) - if err != nil { - return nil, err - } - } + embeddedIface, err := p.retrieveEmbeddedIfaceModel(pkg, v.String(), v.Pos(), false) + if err != nil { + return nil, err } // Copy the methods. for _, m := range embeddedIface.Methods { @@ -343,40 +315,9 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode case *ast.SelectorExpr: // Embedded interface in another package. filePkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() - embeddedPkg, ok := p.imports[filePkg] - if !ok { - return nil, p.errorf(v.X.Pos(), "unknown package %s", filePkg) - } - - var embeddedIface *model.Interface - var err error - embeddedIfaceType := p.auxInterfaces.Get(filePkg, sel) - if embeddedIfaceType != nil { - embeddedIface, err = p.parseInterface(sel, filePkg, embeddedIfaceType) - if err != nil { - return nil, err - } - } else { - path := embeddedPkg.Path() - parser := embeddedPkg.Parser() - if parser == nil { - ip, err := p.parsePackage(path) - if err != nil { - return nil, p.errorf(v.Pos(), "could not parse package %s: %v", path, err) - } - parser = ip - p.imports[filePkg] = importedPkg{ - path: embeddedPkg.Path(), - parser: parser, - } - } - if embeddedIfaceType = parser.importedInterfaces.Get(path, sel); embeddedIfaceType == nil { - return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel) - } - embeddedIface, err = parser.parseInterface(sel, path, embeddedIfaceType) - if err != nil { - return nil, err - } + embeddedIface, err := p.retrieveEmbeddedIfaceModel(filePkg, sel, v.X.Pos(), true) + if err != nil { + return nil, err } // Copy the methods. // TODO: apply shadowing rules. @@ -389,9 +330,9 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode // *ast.IndexExpr for embedded generic iface with single index e.g. DoSomething[T] // *ast.IndexListExpr for embedded generic iface with multiple indexes e.g. DoSomething[T, K] var ( - ident *ast.Ident - selIdent *ast.Ident // selector identity only used in external import - path string + ident *ast.Ident + selIdent *ast.Ident // selector identity only used in external import + // path string typeParams []model.Type // normalize to slice whether IndexExpr or IndexListExpr to make it consistent to work with ) if ie, ok := v.(*ast.IndexExpr); ok { @@ -427,50 +368,13 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode ) if selIdent == nil { - // trivial part: defined in source pkg - embeddedIfaceType := p.auxInterfaces.Get(pkg, ident.Name) - if embeddedIfaceType == nil { - embeddedIfaceType = p.importedInterfaces.Get(pkg, ident.Name) - } - embeddedIface, err = p.parseInterface(ident.Name, pkg, embeddedIfaceType) - if err != nil { + if embeddedIface, err = p.retrieveEmbeddedIfaceModel(pkg, ident.Name, ident.Pos(), false); err != nil { return nil, err } } else { - // non-trivial part: defined in external pkg filePkg, sel := ident.String(), selIdent.String() - embeddedPkg, ok := p.imports[filePkg] - if !ok { - return nil, p.errorf(ident.Pos(), "unknown package %s", filePkg) - } - path = embeddedPkg.Path() - - embeddedIfaceType := p.auxInterfaces.Get(filePkg, sel) - if embeddedIfaceType != nil { - embeddedIface, err = p.parseInterface(sel, filePkg, embeddedIfaceType) - if err != nil { - return nil, err - } - } else { - parser := embeddedPkg.Parser() - if parser == nil { - ip, err := p.parsePackage(path) - if err != nil { - return nil, p.errorf(v.Pos(), "could not parse package %s: %v", path, err) - } - parser = ip - p.imports[filePkg] = importedPkg{ - path: embeddedPkg.Path(), - parser: parser, - } - } - if embeddedIfaceType = parser.importedInterfaces.Get(path, sel); embeddedIfaceType == nil { - return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel) - } - embeddedIface, err = parser.parseInterface(sel, path, embeddedIfaceType) - if err != nil { - return nil, err - } + if embeddedIface, err = p.retrieveEmbeddedIfaceModel(filePkg, sel, ident.Pos(), true); err != nil { + return nil, err } } @@ -520,6 +424,74 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode return iface, nil } +func (p *fileParser) retrieveEmbeddedIfaceModel(pkg, ifaceName string, pos token.Pos, isImport bool) (m *model.Interface, err error) { + var ( + typ *namedInterface + importPkg importedPackage + ) + + if isImport { + var ok bool + if importPkg, ok = p.imports[pkg]; !ok { + err = p.errorf(pos, "unknown package %s", pkg) + return + } + } + + typ = p.auxInterfaces.Get(pkg, ifaceName) + if typ == nil { + typ = p.importedInterfaces.Get(pkg, ifaceName) + } + if typ != nil { + m, err = p.parseInterface(ifaceName, pkg, typ) + return + } + if ifaceName == model.ErrorInterface.Name { + // built-in error interface + m = &model.ErrorInterface + return + } + // parse from pkg (may be current pkg, may be imported pkg) + // so need to get the proper parser for the pkg + var ifaceParser *fileParser + + if importPkg != nil { + // imported pkg + if ifaceParser = importPkg.Parser(); ifaceParser == nil { + path := importPkg.Path() + if ifaceParser, err = p.parsePackage(path); err != nil { + err = p.errorf(pos, "could not parse package %s: %v", path, err) + return + } + p.imports[pkg] = importedPkg{ + path: importPkg.Path(), + parser: ifaceParser, + } + } + typ = ifaceParser.importedInterfaces.Get(importPkg.Path(), ifaceName) + } + + if ifaceParser == nil { + // this pkg + if ifaceParser, err = p.parsePackage(pkg); err != nil { + err = p.errorf(pos, "could not parse package %s: %v", pkg, err) + return + } + typ = ifaceParser.importedInterfaces.Get(pkg, ifaceName) + } + + if typ == nil { + err = p.errorf(pos, "unknown embedded interface %s.%s", pkg, ifaceName) + return + } + + // at this point, whether iface is of imported pkg or same pkg, + // the ifaceParser is appropriate and knows how to parse the iface + m, err = ifaceParser.parseInterface(ifaceName, pkg, typ) + + return +} + func (p *fileParser) parseFunc(pkg string, f *ast.FuncType, tps map[string]bool) (inParam []*model.Parameter, variadic *model.Parameter, outParam []*model.Parameter, err error) { if f.Params != nil { regParams := f.Params.List From 69c4064122e71de12265d936210ec20fe790e0dd Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Thu, 11 Aug 2022 21:12:30 -0500 Subject: [PATCH 06/14] try to split out generics-specific to generic_go118.go --- mockgen/generic_go118.go | 99 +++++++++++++++++++++++++++++++++++++ mockgen/generic_notgo118.go | 4 ++ mockgen/parse.go | 93 ++-------------------------------- 3 files changed, 106 insertions(+), 90 deletions(-) diff --git a/mockgen/generic_go118.go b/mockgen/generic_go118.go index b29db9a8..69b582db 100644 --- a/mockgen/generic_go118.go +++ b/mockgen/generic_go118.go @@ -86,3 +86,102 @@ func getIdentTypeParams(decl interface{}) string { sb.WriteString("]") return sb.String() } + +func (p *fileParser) parseEmbeddedGenericIface(iface *model.Interface, field *ast.Field, pkg string, tps map[string]bool) (wasGeneric bool, err error) { + switch v := field.Type.(type) { + case *ast.IndexExpr, *ast.IndexListExpr: + wasGeneric = true + // generic embedded interface + // may or may not be external pkg + // *ast.IndexExpr for embedded generic iface with single index e.g. DoSomething[T] + // *ast.IndexListExpr for embedded generic iface with multiple indexes e.g. DoSomething[T, K] + var ( + ident *ast.Ident + selIdent *ast.Ident // selector identity only used in external import + // path string + typeParams []model.Type // normalize to slice whether IndexExpr or IndexListExpr to make it consistent to work with + ) + if ie, ok := v.(*ast.IndexExpr); ok { + if se, ok := ie.X.(*ast.SelectorExpr); ok { + ident, selIdent = se.X.(*ast.Ident), se.Sel + } else { + ident = ie.X.(*ast.Ident) + } + var typParam model.Type + if typParam, err = p.parseType(pkg, ie.Index, tps); err != nil { + return + } + typeParams = append(typeParams, typParam) + } else { + ile := v.(*ast.IndexListExpr) + if se, ok := ile.X.(*ast.SelectorExpr); ok { + ident, selIdent = se.X.(*ast.Ident), se.Sel + } else { + ident = ile.X.(*ast.Ident) + } + var typParam model.Type + for i := range ile.Indices { + if typParam, err = p.parseType(pkg, ile.Indices[i], tps); err != nil { + return + } + typeParams = append(typeParams, typParam) + } + } + + var ( + embeddedIface *model.Interface + ) + + if selIdent == nil { + if embeddedIface, err = p.retrieveEmbeddedIfaceModel(pkg, ident.Name, ident.Pos(), false); err != nil { + return + } + } else { + filePkg, sel := ident.String(), selIdent.String() + if embeddedIface, err = p.retrieveEmbeddedIfaceModel(filePkg, sel, ident.Pos(), true); err != nil { + return + } + } + + // Copy the methods. + // TODO: apply shadowing rules. + for _, m := range embeddedIface.Methods { + // non-trivial part - we have to match up the as-used type params with the as-defined + // defined as DoSomething[T any, K any] + // used as DoSomething[somPkg.SomeType, int64] + // meaning methods may be like in definition: + // Do(T) (K, error) + // but need to be like this in implementation: + // Do(somePkg.SomeType) (int64, error) + gm := m.Clone() // clone so we can change without changing source def + + // overwrite all typed params for incoming/outgoing params + // to get the implementor-specified typing over the definition-specified typing + + for _, pim := range gm.In { + if nt, ok := pim.Type.(*model.NamedType); ok && nt.TypeParams != nil { + for i, tp := range nt.TypeParams.TypeParameters { + if srcParamIdx := embeddedIface.TypeParamIndexByName(tp.String(nil, "")); srcParamIdx > -1 && srcParamIdx < len(typeParams) { + dstParamTyp := typeParams[srcParamIdx] + nt.TypeParams.TypeParameters[i] = dstParamTyp + } + } + } + } + for _, out := range gm.Out { + if nt, ok := out.Type.(*model.NamedType); ok && nt.TypeParams != nil { + for i, tp := range nt.TypeParams.TypeParameters { + if srcParamIdx := embeddedIface.TypeParamIndexByName(tp.String(nil, "")); srcParamIdx > -1 && srcParamIdx < len(typeParams) { + dstParamTyp := typeParams[srcParamIdx] + nt.TypeParams.TypeParameters[i] = dstParamTyp + } + } + } + } + + iface.AddMethod(gm) + } + } + + return +} diff --git a/mockgen/generic_notgo118.go b/mockgen/generic_notgo118.go index 8fe48c17..87c8d715 100644 --- a/mockgen/generic_notgo118.go +++ b/mockgen/generic_notgo118.go @@ -34,3 +34,7 @@ func (p *fileParser) parseGenericType(pkg string, typ ast.Expr, tps map[string]b func getIdentTypeParams(decl interface{}) string { return "" } + +func (p *fileParser) parseEmbeddedGenericIface(iface *model.Interface, field *ast.Field, pkg string, tps map[string]bool) (wasGeneric bool, err error) { + return false, nil +} diff --git a/mockgen/parse.go b/mockgen/parse.go index 8458a7ae..1192cd01 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -324,100 +324,13 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode for _, m := range embeddedIface.Methods { iface.AddMethod(m) } - case *ast.IndexExpr, *ast.IndexListExpr: - // generic embedded interface - // may or may not be external pkg - // *ast.IndexExpr for embedded generic iface with single index e.g. DoSomething[T] - // *ast.IndexListExpr for embedded generic iface with multiple indexes e.g. DoSomething[T, K] - var ( - ident *ast.Ident - selIdent *ast.Ident // selector identity only used in external import - // path string - typeParams []model.Type // normalize to slice whether IndexExpr or IndexListExpr to make it consistent to work with - ) - if ie, ok := v.(*ast.IndexExpr); ok { - if se, ok := ie.X.(*ast.SelectorExpr); ok { - ident, selIdent = se.X.(*ast.Ident), se.Sel - } else { - ident = ie.X.(*ast.Ident) - } - typParam, err := p.parseType(pkg, ie.Index, tps) + default: + if wasEmbeddedGeneric, err := p.parseEmbeddedGenericIface(iface, field, pkg, tps); wasEmbeddedGeneric { if err != nil { return nil, err } - typeParams = append(typeParams, typParam) - } else { - ile := v.(*ast.IndexListExpr) - if se, ok := ile.X.(*ast.SelectorExpr); ok { - ident, selIdent = se.X.(*ast.Ident), se.Sel - } else { - ident = ile.X.(*ast.Ident) - } - for i := range ile.Indices { - typParam, err := p.parseType(pkg, ile.Indices[i], tps) - if err != nil { - return nil, err - } - typeParams = append(typeParams, typParam) - } - } - - var ( - embeddedIface *model.Interface - err error - ) - - if selIdent == nil { - if embeddedIface, err = p.retrieveEmbeddedIfaceModel(pkg, ident.Name, ident.Pos(), false); err != nil { - return nil, err - } - } else { - filePkg, sel := ident.String(), selIdent.String() - if embeddedIface, err = p.retrieveEmbeddedIfaceModel(filePkg, sel, ident.Pos(), true); err != nil { - return nil, err - } + return iface, nil } - - // Copy the methods. - // TODO: apply shadowing rules. - for _, m := range embeddedIface.Methods { - // non-trivial part - we have to match up the as-used type params with the as-defined - // defined as DoSomething[T any, K any] - // used as DoSomething[somPkg.SomeType, int64] - // meaning methods may be like in definition: - // Do(T) (K, error) - // but need to be like this in implementation: - // Do(somePkg.SomeType) (int64, error) - gm := m.Clone() // clone so we can change without changing source def - - // overwrite all typed params for incoming/outgoing params - // to get the implementor-specified typing over the definition-specified typing - - for _, pim := range gm.In { - if nt, ok := pim.Type.(*model.NamedType); ok && nt.TypeParams != nil { - for i, tp := range nt.TypeParams.TypeParameters { - if srcParamIdx := embeddedIface.TypeParamIndexByName(tp.String(nil, "")); srcParamIdx > -1 && srcParamIdx < len(typeParams) { - dstParamTyp := typeParams[srcParamIdx] - nt.TypeParams.TypeParameters[i] = dstParamTyp - } - } - } - } - for _, out := range gm.Out { - if nt, ok := out.Type.(*model.NamedType); ok && nt.TypeParams != nil { - for i, tp := range nt.TypeParams.TypeParameters { - if srcParamIdx := embeddedIface.TypeParamIndexByName(tp.String(nil, "")); srcParamIdx > -1 && srcParamIdx < len(typeParams) { - dstParamTyp := typeParams[srcParamIdx] - nt.TypeParams.TypeParameters[i] = dstParamTyp - } - } - } - } - - iface.AddMethod(gm) - - } - default: return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) } } From ebb9abd0a15cd609d98e609cbbd5683e8a7a0ed8 Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Thu, 11 Aug 2022 23:14:25 -0500 Subject: [PATCH 07/14] further solidify parsing generic types on inputs/outputs and add new scenarios --- mockgen/generic_go118.go | 76 ++++++++++++++++----- mockgen/internal/tests/generics/generics.go | 2 + 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/mockgen/generic_go118.go b/mockgen/generic_go118.go index 69b582db..0f3691ba 100644 --- a/mockgen/generic_go118.go +++ b/mockgen/generic_go118.go @@ -11,6 +11,7 @@ package main import ( + "fmt" "go/ast" "strings" @@ -158,24 +159,24 @@ func (p *fileParser) parseEmbeddedGenericIface(iface *model.Interface, field *as // overwrite all typed params for incoming/outgoing params // to get the implementor-specified typing over the definition-specified typing - for _, pim := range gm.In { - if nt, ok := pim.Type.(*model.NamedType); ok && nt.TypeParams != nil { - for i, tp := range nt.TypeParams.TypeParameters { - if srcParamIdx := embeddedIface.TypeParamIndexByName(tp.String(nil, "")); srcParamIdx > -1 && srcParamIdx < len(typeParams) { - dstParamTyp := typeParams[srcParamIdx] - nt.TypeParams.TypeParameters[i] = dstParamTyp - } - } + for pinIdx, pin := range gm.In { + switch t := pin.Type.(type) { + case *model.NamedType: + p.populateTypedParamsFromNamedType(t, embeddedIface, typeParams) + case model.PredeclaredType: + p.populateTypedParamsFromPredeclaredType(t, pinIdx, gm.In, embeddedIface, typeParams) + case *model.PointerType: + p.populateTypedParamsFromPointerType(t, embeddedIface, typeParams) } } - for _, out := range gm.Out { - if nt, ok := out.Type.(*model.NamedType); ok && nt.TypeParams != nil { - for i, tp := range nt.TypeParams.TypeParameters { - if srcParamIdx := embeddedIface.TypeParamIndexByName(tp.String(nil, "")); srcParamIdx > -1 && srcParamIdx < len(typeParams) { - dstParamTyp := typeParams[srcParamIdx] - nt.TypeParams.TypeParameters[i] = dstParamTyp - } - } + for outIdx, out := range gm.Out { + switch t := out.Type.(type) { + case *model.NamedType: + p.populateTypedParamsFromNamedType(t, embeddedIface, typeParams) + case model.PredeclaredType: + p.populateTypedParamsFromPredeclaredType(t, outIdx, gm.Out, embeddedIface, typeParams) + case *model.PointerType: + p.populateTypedParamsFromPointerType(t, embeddedIface, typeParams) } } @@ -185,3 +186,46 @@ func (p *fileParser) parseEmbeddedGenericIface(iface *model.Interface, field *as return } + +func (p *fileParser) populateTypedParamsFromNamedType(nt *model.NamedType, iface *model.Interface, knownTypeParams []model.Type) { + if nt.TypeParams == nil { + return + } + + for i, tp := range nt.TypeParams.TypeParameters { + switch tpt := tp.(type) { + case *model.PointerType: + p.populateTypedParamsFromPointerType(tpt, iface, knownTypeParams) + default: + if srcParamIdx := iface.TypeParamIndexByName(tp.String(nil, "")); srcParamIdx > -1 && srcParamIdx < len(knownTypeParams) { + dstParamTyp := knownTypeParams[srcParamIdx] + nt.TypeParams.TypeParameters[i] = dstParamTyp + } + } + } +} + +func (p *fileParser) populateTypedParamsFromPredeclaredType(pt model.PredeclaredType, paramIdx int, inOrOutParams []*model.Parameter, iface *model.Interface, knownTypParams []model.Type) { + if srcParamIdx := iface.TypeParamIndexByName(pt.String(nil, "")); srcParamIdx > -1 { + dstParamTyp := knownTypParams[srcParamIdx] + inOrOutParams[paramIdx] = &model.Parameter{ + Name: "", + Type: dstParamTyp, + } + } +} + +func (p *fileParser) populateTypedParamsFromPointerType(pt *model.PointerType, iface *model.Interface, knownTypeParams []model.Type) { + switch t := pt.Type.(type) { + case model.PredeclaredType: + parms := make([]*model.Parameter, 1) + p.populateTypedParamsFromPredeclaredType(t, 0, parms, iface, knownTypeParams) + if parms[0] != nil { + pt.Type = parms[0].Type + } + case *model.NamedType: + p.populateTypedParamsFromNamedType(t, iface, knownTypeParams) + default: + fmt.Println("unhandled model PointerType") + } +} diff --git a/mockgen/internal/tests/generics/generics.go b/mockgen/internal/tests/generics/generics.go index 6f0cc815..93df22d5 100644 --- a/mockgen/internal/tests/generics/generics.go +++ b/mockgen/internal/tests/generics/generics.go @@ -25,6 +25,8 @@ type Bar[T any, R any] interface { Seventeen() (*Foo[other.Three, other.Four], error) Eighteen() (Iface[*other.Five], error) Nineteen() AliasType + Twenty(*other.One[T]) *other.Two[T, R] + TwentyOne(*string) *other.Two[*T, *R] } type Foo[T any, R any] struct{} From beda6b92842995efe90e431a19f72af5c266143c Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Thu, 11 Aug 2022 23:44:35 -0500 Subject: [PATCH 08/14] more tweaks --- .../generics/source/mock_generics_test.go | 411 +++++++++++++++++- mockgen/parse.go | 2 +- 2 files changed, 410 insertions(+), 3 deletions(-) diff --git a/mockgen/internal/tests/generics/source/mock_generics_test.go b/mockgen/internal/tests/generics/source/mock_generics_test.go index 0223e311..d7c16509 100644 --- a/mockgen/internal/tests/generics/source/mock_generics_test.go +++ b/mockgen/internal/tests/generics/source/mock_generics_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: generics.go +// Source: /Users/bgore/git/bdg-gomock/mockgen/internal/tests/generics/generics.go // Package source is a generated GoMock package. package source @@ -7,9 +7,9 @@ package source import ( reflect "reflect" + "github.com/golang/mock/mockgen/internal/tests/generics/other" gomock "github.com/golang/mock/gomock" generics "github.com/golang/mock/mockgen/internal/tests/generics" - other "github.com/golang/mock/mockgen/internal/tests/generics/other" ) // MockBar is a mock of Bar interface. @@ -291,6 +291,34 @@ func (mr *MockBarMockRecorder[T, R]) Twelve() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Twelve", reflect.TypeOf((*MockBar[T, R])(nil).Twelve)) } +// Twenty mocks base method. +func (m *MockBar[T, R]) Twenty(arg0 *other.One[T]) *other.Two[T, R] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Twenty", arg0) + ret0, _ := ret[0].(*other.Two[T, R]) + return ret0 +} + +// Twenty indicates an expected call of Twenty. +func (mr *MockBarMockRecorder[T, R]) Twenty(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Twenty", reflect.TypeOf((*MockBar[T, R])(nil).Twenty), arg0) +} + +// TwentyOne mocks base method. +func (m *MockBar[T, R]) TwentyOne(arg0 *string) *other.Two[*T, *R] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TwentyOne", arg0) + ret0, _ := ret[0].(*other.Two[*T, *R]) + return ret0 +} + +// TwentyOne indicates an expected call of TwentyOne. +func (mr *MockBarMockRecorder[T, R]) TwentyOne(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TwentyOne", reflect.TypeOf((*MockBar[T, R])(nil).TwentyOne), arg0) +} + // Two mocks base method. func (m *MockBar[T, R]) Two(arg0 T) string { m.ctrl.T.Helper() @@ -327,3 +355,382 @@ func NewMockIface[T any](ctrl *gomock.Controller) *MockIface[T] { func (m *MockIface[T]) EXPECT() *MockIfaceMockRecorder[T] { return m.recorder } + +// MockEmbeddingIface is a mock of EmbeddingIface interface. +type MockEmbeddingIface struct { + ctrl *gomock.Controller + recorder *MockEmbeddingIfaceMockRecorder +} + +// MockEmbeddingIfaceMockRecorder is the mock recorder for MockEmbeddingIface. +type MockEmbeddingIfaceMockRecorder struct { + mock *MockEmbeddingIface +} + +// NewMockEmbeddingIface creates a new mock instance. +func NewMockEmbeddingIface(ctrl *gomock.Controller) *MockEmbeddingIface { + mock := &MockEmbeddingIface{ctrl: ctrl} + mock.recorder = &MockEmbeddingIfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEmbeddingIface) EXPECT() *MockEmbeddingIfaceMockRecorder { + return m.recorder +} + +// DoR mocks base method. +func (m *MockEmbeddingIface) DoR(arg0 other.Five) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DoR", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// DoR indicates an expected call of DoR. +func (mr *MockEmbeddingIfaceMockRecorder) DoR(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoR", reflect.TypeOf((*MockEmbeddingIface)(nil).DoR), arg0) +} + +// DoT mocks base method. +func (m *MockEmbeddingIface) DoT(arg0 generics.StructType) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DoT", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// DoT indicates an expected call of DoT. +func (mr *MockEmbeddingIfaceMockRecorder) DoT(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoT", reflect.TypeOf((*MockEmbeddingIface)(nil).DoT), arg0) +} + +// Eight mocks base method. +func (m *MockEmbeddingIface) Eight(arg0 other.Three) other.Two[other.Three, error] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Eight", arg0) + ret0, _ := ret[0].(other.Two[other.Three, error]) + return ret0 +} + +// Eight indicates an expected call of Eight. +func (mr *MockEmbeddingIfaceMockRecorder) Eight(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Eight", reflect.TypeOf((*MockEmbeddingIface)(nil).Eight), arg0) +} + +// Eighteen mocks base method. +func (m *MockEmbeddingIface) Eighteen() (generics.Iface[*other.Five], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Eighteen") + ret0, _ := ret[0].(generics.Iface[*other.Five]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Eighteen indicates an expected call of Eighteen. +func (mr *MockEmbeddingIfaceMockRecorder) Eighteen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Eighteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Eighteen)) +} + +// Eleven mocks base method. +func (m *MockEmbeddingIface) Eleven() (*other.One[other.Three], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Eleven") + ret0, _ := ret[0].(*other.One[other.Three]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Eleven indicates an expected call of Eleven. +func (mr *MockEmbeddingIfaceMockRecorder) Eleven() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Eleven", reflect.TypeOf((*MockEmbeddingIface)(nil).Eleven)) +} + +// Fifteen mocks base method. +func (m *MockEmbeddingIface) Fifteen() (generics.Iface[generics.StructType], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Fifteen") + ret0, _ := ret[0].(generics.Iface[generics.StructType]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Fifteen indicates an expected call of Fifteen. +func (mr *MockEmbeddingIfaceMockRecorder) Fifteen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fifteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Fifteen)) +} + +// Five mocks base method. +func (m *MockEmbeddingIface) Five(arg0 other.Three) generics.Baz[other.Three] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Five", arg0) + ret0, _ := ret[0].(generics.Baz[other.Three]) + return ret0 +} + +// Five indicates an expected call of Five. +func (mr *MockEmbeddingIfaceMockRecorder) Five(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Five", reflect.TypeOf((*MockEmbeddingIface)(nil).Five), arg0) +} + +// Four mocks base method. +func (m *MockEmbeddingIface) Four(arg0 other.Three) generics.Foo[other.Three, error] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Four", arg0) + ret0, _ := ret[0].(generics.Foo[other.Three, error]) + return ret0 +} + +// Four indicates an expected call of Four. +func (mr *MockEmbeddingIfaceMockRecorder) Four(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Four", reflect.TypeOf((*MockEmbeddingIface)(nil).Four), arg0) +} + +// Fourteen mocks base method. +func (m *MockEmbeddingIface) Fourteen() (*generics.Foo[generics.StructType, generics.StructType2], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Fourteen") + ret0, _ := ret[0].(*generics.Foo[generics.StructType, generics.StructType2]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Fourteen indicates an expected call of Fourteen. +func (mr *MockEmbeddingIfaceMockRecorder) Fourteen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fourteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Fourteen)) +} + +// LocalFunc mocks base method. +func (m *MockEmbeddingIface) LocalFunc() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalFunc") + ret0, _ := ret[0].(error) + return ret0 +} + +// LocalFunc indicates an expected call of LocalFunc. +func (mr *MockEmbeddingIfaceMockRecorder) LocalFunc() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalFunc", reflect.TypeOf((*MockEmbeddingIface)(nil).LocalFunc)) +} + +// MakeThem mocks base method. +func (m *MockEmbeddingIface) MakeThem() (generics.StructType, other.Five, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MakeThem") + ret0, _ := ret[0].(generics.StructType) + ret1, _ := ret[1].(other.Five) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// MakeThem indicates an expected call of MakeThem. +func (mr *MockEmbeddingIfaceMockRecorder) MakeThem() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeThem", reflect.TypeOf((*MockEmbeddingIface)(nil).MakeThem)) +} + +// Nine mocks base method. +func (m *MockEmbeddingIface) Nine(arg0 generics.Iface[other.Three]) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Nine", arg0) +} + +// Nine indicates an expected call of Nine. +func (mr *MockEmbeddingIfaceMockRecorder) Nine(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Nine", reflect.TypeOf((*MockEmbeddingIface)(nil).Nine), arg0) +} + +// Nineteen mocks base method. +func (m *MockEmbeddingIface) Nineteen() generics.AliasType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Nineteen") + ret0, _ := ret[0].(generics.AliasType) + return ret0 +} + +// Nineteen indicates an expected call of Nineteen. +func (mr *MockEmbeddingIfaceMockRecorder) Nineteen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Nineteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Nineteen)) +} + +// One mocks base method. +func (m *MockEmbeddingIface) One(arg0 string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "One", arg0) + ret0, _ := ret[0].(string) + return ret0 +} + +// One indicates an expected call of One. +func (mr *MockEmbeddingIfaceMockRecorder) One(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "One", reflect.TypeOf((*MockEmbeddingIface)(nil).One), arg0) +} + +// Seven mocks base method. +func (m *MockEmbeddingIface) Seven(arg0 other.Three) other.One[other.Three] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Seven", arg0) + ret0, _ := ret[0].(other.One[other.Three]) + return ret0 +} + +// Seven indicates an expected call of Seven. +func (mr *MockEmbeddingIfaceMockRecorder) Seven(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seven", reflect.TypeOf((*MockEmbeddingIface)(nil).Seven), arg0) +} + +// Seventeen mocks base method. +func (m *MockEmbeddingIface) Seventeen() (*generics.Foo[other.Three, other.Four], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Seventeen") + ret0, _ := ret[0].(*generics.Foo[other.Three, other.Four]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Seventeen indicates an expected call of Seventeen. +func (mr *MockEmbeddingIfaceMockRecorder) Seventeen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seventeen", reflect.TypeOf((*MockEmbeddingIface)(nil).Seventeen)) +} + +// Six mocks base method. +func (m *MockEmbeddingIface) Six(arg0 other.Three) *generics.Baz[other.Three] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Six", arg0) + ret0, _ := ret[0].(*generics.Baz[other.Three]) + return ret0 +} + +// Six indicates an expected call of Six. +func (mr *MockEmbeddingIfaceMockRecorder) Six(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Six", reflect.TypeOf((*MockEmbeddingIface)(nil).Six), arg0) +} + +// Sixteen mocks base method. +func (m *MockEmbeddingIface) Sixteen() (generics.Baz[other.Three], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Sixteen") + ret0, _ := ret[0].(generics.Baz[other.Three]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Sixteen indicates an expected call of Sixteen. +func (mr *MockEmbeddingIfaceMockRecorder) Sixteen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sixteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Sixteen)) +} + +// Ten mocks base method. +func (m *MockEmbeddingIface) Ten(arg0 *other.Three) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Ten", arg0) +} + +// Ten indicates an expected call of Ten. +func (mr *MockEmbeddingIfaceMockRecorder) Ten(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ten", reflect.TypeOf((*MockEmbeddingIface)(nil).Ten), arg0) +} + +// Thirteen mocks base method. +func (m *MockEmbeddingIface) Thirteen() (generics.Baz[generics.StructType], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Thirteen") + ret0, _ := ret[0].(generics.Baz[generics.StructType]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Thirteen indicates an expected call of Thirteen. +func (mr *MockEmbeddingIfaceMockRecorder) Thirteen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Thirteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Thirteen)) +} + +// Three mocks base method. +func (m *MockEmbeddingIface) Three(arg0 other.Three) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Three", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Three indicates an expected call of Three. +func (mr *MockEmbeddingIfaceMockRecorder) Three(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Three", reflect.TypeOf((*MockEmbeddingIface)(nil).Three), arg0) +} + +// Twelve mocks base method. +func (m *MockEmbeddingIface) Twelve() (*other.Two[other.Three, error], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Twelve") + ret0, _ := ret[0].(*other.Two[other.Three, error]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Twelve indicates an expected call of Twelve. +func (mr *MockEmbeddingIfaceMockRecorder) Twelve() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Twelve", reflect.TypeOf((*MockEmbeddingIface)(nil).Twelve)) +} + +// Twenty mocks base method. +func (m *MockEmbeddingIface) Twenty(arg0 *other.One[other.Three]) *other.Two[other.Three, error] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Twenty", arg0) + ret0, _ := ret[0].(*other.Two[other.Three, error]) + return ret0 +} + +// Twenty indicates an expected call of Twenty. +func (mr *MockEmbeddingIfaceMockRecorder) Twenty(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Twenty", reflect.TypeOf((*MockEmbeddingIface)(nil).Twenty), arg0) +} + +// TwentyOne mocks base method. +func (m *MockEmbeddingIface) TwentyOne(arg0 *string) *other.Two[*other.Three, *error] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TwentyOne", arg0) + ret0, _ := ret[0].(*other.Two[*other.Three, *error]) + return ret0 +} + +// TwentyOne indicates an expected call of TwentyOne. +func (mr *MockEmbeddingIfaceMockRecorder) TwentyOne(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TwentyOne", reflect.TypeOf((*MockEmbeddingIface)(nil).TwentyOne), arg0) +} + +// Two mocks base method. +func (m *MockEmbeddingIface) Two(arg0 other.Three) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Two", arg0) + ret0, _ := ret[0].(string) + return ret0 +} + +// Two indicates an expected call of Two. +func (mr *MockEmbeddingIfaceMockRecorder) Two(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Two", reflect.TypeOf((*MockEmbeddingIface)(nil).Two), arg0) +} diff --git a/mockgen/parse.go b/mockgen/parse.go index 1192cd01..62ee6a68 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -329,7 +329,7 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode if err != nil { return nil, err } - return iface, nil + continue } return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) } From 26cf1829ccf4d38a96018c976c07e37d441b1fd5 Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Fri, 12 Aug 2022 18:30:37 -0500 Subject: [PATCH 09/14] regenerate source/mock_generics_test.go --- mockgen/internal/tests/generics/source/mock_generics_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mockgen/internal/tests/generics/source/mock_generics_test.go b/mockgen/internal/tests/generics/source/mock_generics_test.go index d7c16509..785784bc 100644 --- a/mockgen/internal/tests/generics/source/mock_generics_test.go +++ b/mockgen/internal/tests/generics/source/mock_generics_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: /Users/bgore/git/bdg-gomock/mockgen/internal/tests/generics/generics.go +// Source: generics.go // Package source is a generated GoMock package. package source @@ -7,9 +7,9 @@ package source import ( reflect "reflect" - "github.com/golang/mock/mockgen/internal/tests/generics/other" gomock "github.com/golang/mock/gomock" generics "github.com/golang/mock/mockgen/internal/tests/generics" + other "github.com/golang/mock/mockgen/internal/tests/generics/other" ) // MockBar is a mock of Bar interface. From aa62cf4979eb64982cbec7d034cb1d0fadff3720 Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Fri, 12 Aug 2022 21:25:52 -0500 Subject: [PATCH 10/14] fix embedded iface external pkg parser --- mockgen/parse.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mockgen/parse.go b/mockgen/parse.go index 62ee6a68..3469029a 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -398,6 +398,10 @@ func (p *fileParser) retrieveEmbeddedIfaceModel(pkg, ifaceName string, pos token return } + if importPkg != nil { + pkg = importPkg.Path() + } + // at this point, whether iface is of imported pkg or same pkg, // the ifaceParser is appropriate and knows how to parse the iface m, err = ifaceParser.parseInterface(ifaceName, pkg, typ) From e11b27601f1f115ca952ddfcc9dcf89adb11323d Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Fri, 14 Oct 2022 22:21:47 -0500 Subject: [PATCH 11/14] handle some additional cases for parameters dealing with []*T, map[T] etc... --- mockgen/generic_go118.go | 127 +++++++++++------- .../internal/tests/generics/other/other.go | 5 + 2 files changed, 84 insertions(+), 48 deletions(-) diff --git a/mockgen/generic_go118.go b/mockgen/generic_go118.go index 0f3691ba..8c86e096 100644 --- a/mockgen/generic_go118.go +++ b/mockgen/generic_go118.go @@ -11,7 +11,6 @@ package main import ( - "fmt" "go/ast" "strings" @@ -160,23 +159,13 @@ func (p *fileParser) parseEmbeddedGenericIface(iface *model.Interface, field *as // to get the implementor-specified typing over the definition-specified typing for pinIdx, pin := range gm.In { - switch t := pin.Type.(type) { - case *model.NamedType: - p.populateTypedParamsFromNamedType(t, embeddedIface, typeParams) - case model.PredeclaredType: - p.populateTypedParamsFromPredeclaredType(t, pinIdx, gm.In, embeddedIface, typeParams) - case *model.PointerType: - p.populateTypedParamsFromPointerType(t, embeddedIface, typeParams) + if genType, hasGeneric := p.getTypedParamForGeneric(pin.Type, embeddedIface, typeParams); hasGeneric { + gm.In[pinIdx].Type = genType } } for outIdx, out := range gm.Out { - switch t := out.Type.(type) { - case *model.NamedType: - p.populateTypedParamsFromNamedType(t, embeddedIface, typeParams) - case model.PredeclaredType: - p.populateTypedParamsFromPredeclaredType(t, outIdx, gm.Out, embeddedIface, typeParams) - case *model.PointerType: - p.populateTypedParamsFromPointerType(t, embeddedIface, typeParams) + if genType, hasGeneric := p.getTypedParamForGeneric(out.Type, embeddedIface, typeParams); hasGeneric { + gm.Out[outIdx].Type = genType } } @@ -187,45 +176,87 @@ func (p *fileParser) parseEmbeddedGenericIface(iface *model.Interface, field *as return } -func (p *fileParser) populateTypedParamsFromNamedType(nt *model.NamedType, iface *model.Interface, knownTypeParams []model.Type) { - if nt.TypeParams == nil { - return - } - - for i, tp := range nt.TypeParams.TypeParameters { - switch tpt := tp.(type) { - case *model.PointerType: - p.populateTypedParamsFromPointerType(tpt, iface, knownTypeParams) - default: +// getTypedParamForGeneric is recursive func to hydrate all generic types within a model.Type +// so they get populated instead with the actual desired target types +func (p *fileParser) getTypedParamForGeneric(t model.Type, iface *model.Interface, knownTypeParams []model.Type) (model.Type, bool) { + switch typ := t.(type) { + case *model.ArrayType: + if gType, wasGeneric := p.getTypedParamForGeneric(typ.Type, iface, knownTypeParams); wasGeneric { + typ.Type = gType + return typ, true + } + case *model.ChanType: + if gType, wasGeneric := p.getTypedParamForGeneric(typ.Type, iface, knownTypeParams); wasGeneric { + typ.Type = gType + return typ, true + } + case *model.FuncType: + hasGeneric := false + for inIdx, inParam := range typ.In { + if genType, ok := p.getTypedParamForGeneric(inParam.Type, iface, knownTypeParams); ok { + hasGeneric = true + typ.In[inIdx].Type = genType + } + } + for outIdx, outParam := range typ.Out { + if genType, ok := p.getTypedParamForGeneric(outParam.Type, iface, knownTypeParams); ok { + hasGeneric = true + typ.Out[outIdx].Type = genType + } + } + if typ.Variadic != nil { + if genType, ok := p.getTypedParamForGeneric(typ.Variadic.Type, iface, knownTypeParams); ok { + hasGeneric = true + typ.Variadic.Type = genType + } + } + if hasGeneric { + return typ, true + } + case *model.MapType: + var ( + keyTyp, valTyp model.Type + wasKeyGeneric, wasValGeneric bool + ) + if keyTyp, wasKeyGeneric = p.getTypedParamForGeneric(typ.Key, iface, knownTypeParams); wasKeyGeneric { + typ.Key = keyTyp + } + if valTyp, wasValGeneric = p.getTypedParamForGeneric(typ.Value, iface, knownTypeParams); wasValGeneric { + typ.Value = valTyp + } + if wasKeyGeneric || wasValGeneric { + return typ, true + } + case *model.NamedType: + if typ.TypeParams == nil { + return nil, false + } + hasGeneric := false + for i, tp := range typ.TypeParams.TypeParameters { + // it will either be a type with name matching a generic parameter + // or it will be something like ptr or slice etc... if srcParamIdx := iface.TypeParamIndexByName(tp.String(nil, "")); srcParamIdx > -1 && srcParamIdx < len(knownTypeParams) { + hasGeneric = true dstParamTyp := knownTypeParams[srcParamIdx] - nt.TypeParams.TypeParameters[i] = dstParamTyp + typ.TypeParams.TypeParameters[i] = dstParamTyp + } else if _, ok := p.getTypedParamForGeneric(tp, iface, knownTypeParams); ok { + hasGeneric = true } } - } -} - -func (p *fileParser) populateTypedParamsFromPredeclaredType(pt model.PredeclaredType, paramIdx int, inOrOutParams []*model.Parameter, iface *model.Interface, knownTypParams []model.Type) { - if srcParamIdx := iface.TypeParamIndexByName(pt.String(nil, "")); srcParamIdx > -1 { - dstParamTyp := knownTypParams[srcParamIdx] - inOrOutParams[paramIdx] = &model.Parameter{ - Name: "", - Type: dstParamTyp, + if hasGeneric { + return typ, true } - } -} - -func (p *fileParser) populateTypedParamsFromPointerType(pt *model.PointerType, iface *model.Interface, knownTypeParams []model.Type) { - switch t := pt.Type.(type) { case model.PredeclaredType: - parms := make([]*model.Parameter, 1) - p.populateTypedParamsFromPredeclaredType(t, 0, parms, iface, knownTypeParams) - if parms[0] != nil { - pt.Type = parms[0].Type + if srcParamIdx := iface.TypeParamIndexByName(typ.String(nil, "")); srcParamIdx > -1 { + dstParamTyp := knownTypeParams[srcParamIdx] + return dstParamTyp, true + } + case *model.PointerType: + if gType, hasGeneric := p.getTypedParamForGeneric(typ.Type, iface, knownTypeParams); hasGeneric { + typ.Type = gType + return typ, true } - case *model.NamedType: - p.populateTypedParamsFromNamedType(t, iface, knownTypeParams) - default: - fmt.Println("unhandled model PointerType") } + + return nil, false } diff --git a/mockgen/internal/tests/generics/other/other.go b/mockgen/internal/tests/generics/other/other.go index 374aac19..a90fcc6b 100644 --- a/mockgen/internal/tests/generics/other/other.go +++ b/mockgen/internal/tests/generics/other/other.go @@ -14,4 +14,9 @@ type Otherer[T any, R any] interface { DoT(T) error DoR(R) error MakeThem() (T, R, error) + GetThem() ([]T, error) + GetThemPtr() ([]*T, error) + GetThemMapped() ([]map[int64]*T, error) + GetMap() (map[bool]T, error) + AddChan(chan T) error } From 362cdf8af55ec6add907cef43227f6d16ac00cc5 Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Mon, 17 Oct 2022 08:50:37 -0500 Subject: [PATCH 12/14] handle variadic generic --- mockgen/generic_go118.go | 5 +++++ mockgen/internal/tests/generics/other/other.go | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/mockgen/generic_go118.go b/mockgen/generic_go118.go index 8c86e096..22f66c4d 100644 --- a/mockgen/generic_go118.go +++ b/mockgen/generic_go118.go @@ -168,6 +168,11 @@ func (p *fileParser) parseEmbeddedGenericIface(iface *model.Interface, field *as gm.Out[outIdx].Type = genType } } + if gm.Variadic != nil { + if vGenType, hasGeneric := p.getTypedParamForGeneric(gm.Variadic.Type, embeddedIface, typeParams); hasGeneric { + gm.Variadic.Type = vGenType + } + } iface.AddMethod(gm) } diff --git a/mockgen/internal/tests/generics/other/other.go b/mockgen/internal/tests/generics/other/other.go index a90fcc6b..39db6c23 100644 --- a/mockgen/internal/tests/generics/other/other.go +++ b/mockgen/internal/tests/generics/other/other.go @@ -13,7 +13,7 @@ type Five interface{} type Otherer[T any, R any] interface { DoT(T) error DoR(R) error - MakeThem() (T, R, error) + MakeThem(...T) (R, error) GetThem() ([]T, error) GetThemPtr() ([]*T, error) GetThemMapped() ([]map[int64]*T, error) From 6ad3ba6430d5ab4e403b129a53a989c6ee9ca734 Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Mon, 17 Oct 2022 10:58:23 -0500 Subject: [PATCH 13/14] regenerate test mock --- .../generics/source/mock_generics_test.go | 97 +++++++++++++++++-- 1 file changed, 87 insertions(+), 10 deletions(-) diff --git a/mockgen/internal/tests/generics/source/mock_generics_test.go b/mockgen/internal/tests/generics/source/mock_generics_test.go index 785784bc..2c1fefe1 100644 --- a/mockgen/internal/tests/generics/source/mock_generics_test.go +++ b/mockgen/internal/tests/generics/source/mock_generics_test.go @@ -7,9 +7,9 @@ package source import ( reflect "reflect" + other "github.com/golang/mock/mockgen/internal/tests/generics/other" gomock "github.com/golang/mock/gomock" generics "github.com/golang/mock/mockgen/internal/tests/generics" - other "github.com/golang/mock/mockgen/internal/tests/generics/other" ) // MockBar is a mock of Bar interface. @@ -379,6 +379,20 @@ func (m *MockEmbeddingIface) EXPECT() *MockEmbeddingIfaceMockRecorder { return m.recorder } +// AddChan mocks base method. +func (m *MockEmbeddingIface) AddChan(arg0 chan generics.StructType) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddChan", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddChan indicates an expected call of AddChan. +func (mr *MockEmbeddingIfaceMockRecorder) AddChan(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddChan", reflect.TypeOf((*MockEmbeddingIface)(nil).AddChan), arg0) +} + // DoR mocks base method. func (m *MockEmbeddingIface) DoR(arg0 other.Five) error { m.ctrl.T.Helper() @@ -509,6 +523,66 @@ func (mr *MockEmbeddingIfaceMockRecorder) Fourteen() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fourteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Fourteen)) } +// GetMap mocks base method. +func (m *MockEmbeddingIface) GetMap() (map[bool]generics.StructType, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMap") + ret0, _ := ret[0].(map[bool]generics.StructType) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMap indicates an expected call of GetMap. +func (mr *MockEmbeddingIfaceMockRecorder) GetMap() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMap", reflect.TypeOf((*MockEmbeddingIface)(nil).GetMap)) +} + +// GetThem mocks base method. +func (m *MockEmbeddingIface) GetThem() ([]generics.StructType, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetThem") + ret0, _ := ret[0].([]generics.StructType) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetThem indicates an expected call of GetThem. +func (mr *MockEmbeddingIfaceMockRecorder) GetThem() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetThem", reflect.TypeOf((*MockEmbeddingIface)(nil).GetThem)) +} + +// GetThemMapped mocks base method. +func (m *MockEmbeddingIface) GetThemMapped() ([]map[int64]*generics.StructType, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetThemMapped") + ret0, _ := ret[0].([]map[int64]*generics.StructType) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetThemMapped indicates an expected call of GetThemMapped. +func (mr *MockEmbeddingIfaceMockRecorder) GetThemMapped() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetThemMapped", reflect.TypeOf((*MockEmbeddingIface)(nil).GetThemMapped)) +} + +// GetThemPtr mocks base method. +func (m *MockEmbeddingIface) GetThemPtr() ([]*generics.StructType, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetThemPtr") + ret0, _ := ret[0].([]*generics.StructType) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetThemPtr indicates an expected call of GetThemPtr. +func (mr *MockEmbeddingIfaceMockRecorder) GetThemPtr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetThemPtr", reflect.TypeOf((*MockEmbeddingIface)(nil).GetThemPtr)) +} + // LocalFunc mocks base method. func (m *MockEmbeddingIface) LocalFunc() error { m.ctrl.T.Helper() @@ -524,19 +598,22 @@ func (mr *MockEmbeddingIfaceMockRecorder) LocalFunc() *gomock.Call { } // MakeThem mocks base method. -func (m *MockEmbeddingIface) MakeThem() (generics.StructType, other.Five, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MakeThem") - ret0, _ := ret[0].(generics.StructType) - ret1, _ := ret[1].(other.Five) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 +func (m *MockEmbeddingIface) MakeThem(arg0 ...generics.StructType) (other.Five, error) { + m.ctrl.T.Helper() + varargs := []interface{}{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "MakeThem", varargs...) + ret0, _ := ret[0].(other.Five) + ret1, _ := ret[1].(error) + return ret0, ret1 } // MakeThem indicates an expected call of MakeThem. -func (mr *MockEmbeddingIfaceMockRecorder) MakeThem() *gomock.Call { +func (mr *MockEmbeddingIfaceMockRecorder) MakeThem(arg0 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeThem", reflect.TypeOf((*MockEmbeddingIface)(nil).MakeThem)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeThem", reflect.TypeOf((*MockEmbeddingIface)(nil).MakeThem), arg0...) } // Nine mocks base method. From 53072b5db1e2ede3d47c4b2f41961a895f69cb03 Mon Sep 17 00:00:00 2001 From: Bradley Gore Date: Mon, 17 Oct 2022 11:00:00 -0500 Subject: [PATCH 14/14] sort import --- mockgen/internal/tests/generics/source/mock_generics_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mockgen/internal/tests/generics/source/mock_generics_test.go b/mockgen/internal/tests/generics/source/mock_generics_test.go index 2c1fefe1..de6e1779 100644 --- a/mockgen/internal/tests/generics/source/mock_generics_test.go +++ b/mockgen/internal/tests/generics/source/mock_generics_test.go @@ -7,9 +7,9 @@ package source import ( reflect "reflect" - other "github.com/golang/mock/mockgen/internal/tests/generics/other" gomock "github.com/golang/mock/gomock" generics "github.com/golang/mock/mockgen/internal/tests/generics" + other "github.com/golang/mock/mockgen/internal/tests/generics/other" ) // MockBar is a mock of Bar interface.