Skip to content

Commit b9576fe

Browse files
authored
Merge pull request #790 from LandonTClipp/issue_787
Allow types defined as instantiated generic interfaces to generate mocks
2 parents fb63d00 + 9738b5b commit b9576fe

10 files changed

+211
-32
lines changed

cmd/mockery.go

-5
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,6 @@ func (r *RootApp) Run() error {
244244
log.Error().Err(err).Msg("unable to parse packages")
245245
return err
246246
}
247-
log.Info().Msg("done parsing, loading")
248-
if err := parser.Load(); err != nil {
249-
log.Err(err).Msgf("failed to load parser")
250-
return nil
251-
}
252247
log.Info().Msg("done loading, visiting interface nodes")
253248
for _, iface := range parser.Interfaces() {
254249
ifaceLog := log.

mocks/github.com/vektra/mockery/v2/pkg/fixtures/GenericInterface.go

+78
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mocks/github.com/vektra/mockery/v2/pkg/fixtures/InstantiatedGenericInterface.go

+78
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package test
2+
3+
type GenericInterface[M any] interface {
4+
Func(arg *M) int
5+
}
6+
7+
type InstantiatedGenericInterface GenericInterface[float32]

pkg/fixtures/variadic.go

-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,3 @@ type VariadicFunction = func(args1 string, args2 ...interface{}) interface{}
55
type Variadic interface {
66
VariadicFunction(str string, vFunc VariadicFunction) error
77
}
8-

pkg/generator_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func (s *GeneratorSuite) getInterfaceFromFile(interfacePath, interfaceName strin
4040
)
4141

4242
s.Require().NoError(
43-
s.parser.Load(),
43+
s.parser.Load(context.Background()),
4444
)
4545

4646
iface, err := s.parser.Find(interfaceName)

pkg/outputter_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ packages:
262262
m.config.Config = confPath.String()
263263

264264
require.NoError(t, parser.ParsePackages(ctx, []string{tt.packagePath}))
265-
require.NoError(t, parser.Load())
265+
require.NoError(t, parser.Load(context.Background()))
266266
for _, intf := range parser.Interfaces() {
267267
t.Logf("generating interface: %s %s", intf.QualifiedName, intf.Name)
268268
require.NoError(t, m.Generate(ctx, intf))

pkg/parse.go

+40-18
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,27 @@ import (
1616
"golang.org/x/tools/go/packages"
1717
)
1818

19-
type parserEntry struct {
19+
type fileEntry struct {
2020
fileName string
2121
pkg *packages.Package
2222
syntax *ast.File
2323
interfaces []string
2424
}
2525

26+
func (f *fileEntry) ParseInterfaces(ctx context.Context) {
27+
nv := NewNodeVisitor(ctx)
28+
ast.Walk(nv, f.syntax)
29+
f.interfaces = nv.DeclaredInterfaces()
30+
}
31+
2632
type packageLoadEntry struct {
2733
pkgs []*packages.Package
2834
err error
2935
}
3036

3137
type Parser struct {
32-
entries []*parserEntry
33-
entriesByFileName map[string]*parserEntry
38+
files []*fileEntry
39+
entriesByFileName map[string]*fileEntry
3440
parserPackages []*types.Package
3541
conf packages.Config
3642
packageLoadCache map[string]packageLoadEntry
@@ -52,7 +58,7 @@ func NewParser(buildTags []string) *Parser {
5258
}
5359
return &Parser{
5460
parserPackages: make([]*types.Package, 0),
55-
entriesByFileName: map[string]*parserEntry{},
61+
entriesByFileName: map[string]*fileEntry{},
5662
conf: conf,
5763
packageLoadCache: map[string]packageLoadEntry{},
5864
}
@@ -86,18 +92,21 @@ func (p *Parser) ParsePackages(ctx context.Context, packageNames []string) error
8692
Str("package", pkg.PkgPath).
8793
Str("file", file).
8894
Msgf("found file")
89-
entry := parserEntry{
95+
entry := fileEntry{
9096
fileName: file,
9197
pkg: pkg,
9298
syntax: pkg.Syntax[fileIdx],
9399
}
94-
p.entries = append(p.entries, &entry)
100+
entry.ParseInterfaces(ctx)
101+
p.files = append(p.files, &entry)
95102
p.entriesByFileName[file] = &entry
96103
}
97104
}
98105
return nil
99106
}
100107

108+
// DEPRECATED: Parse is part of the deprecated, legacy mockery behavior. This is not
109+
// used when the packages feature is enabled.
101110
func (p *Parser) Parse(ctx context.Context, path string) error {
102111
// To support relative paths to mock targets w/ vendor deps, we need to provide eventual
103112
// calls to build.Context.Import with an absolute path. It needs to be absolute because
@@ -164,30 +173,28 @@ func (p *Parser) Parse(ctx context.Context, path string) error {
164173
if _, ok := p.entriesByFileName[f]; ok {
165174
continue
166175
}
167-
entry := parserEntry{
176+
entry := fileEntry{
168177
fileName: f,
169178
pkg: pkg,
170179
syntax: pkg.Syntax[idx],
171180
}
172-
p.entries = append(p.entries, &entry)
181+
p.files = append(p.files, &entry)
173182
p.entriesByFileName[f] = &entry
174183
}
175184
}
176185

177186
return nil
178187
}
179188

180-
func (p *Parser) Load() error {
181-
for _, entry := range p.entries {
182-
nv := NewNodeVisitor()
183-
ast.Walk(nv, entry.syntax)
184-
entry.interfaces = nv.DeclaredInterfaces()
189+
func (p *Parser) Load(ctx context.Context) error {
190+
for _, entry := range p.files {
191+
entry.ParseInterfaces(ctx)
185192
}
186193
return nil
187194
}
188195

189196
func (p *Parser) Find(name string) (*Interface, error) {
190-
for _, entry := range p.entries {
197+
for _, entry := range p.files {
191198
for _, iface := range entry.interfaces {
192199
if iface == name {
193200
list := p.packageInterfaces(entry.pkg.Types, entry.fileName, []string{name}, nil)
@@ -202,7 +209,7 @@ func (p *Parser) Find(name string) (*Interface, error) {
202209

203210
func (p *Parser) Interfaces() []*Interface {
204211
ifaces := make(sortableIFaceList, 0)
205-
for _, entry := range p.entries {
212+
for _, entry := range p.files {
206213
declaredIfaces := entry.interfaces
207214
ifaces = p.packageInterfaces(entry.pkg.Types, entry.fileName, declaredIfaces, ifaces)
208215
}
@@ -314,12 +321,15 @@ func (s sortableIFaceList) Less(i, j int) bool {
314321
}
315322

316323
type NodeVisitor struct {
317-
declaredInterfaces []string
324+
declaredInterfaces []string
325+
genericInstantiationInterface map[string]any
326+
ctx context.Context
318327
}
319328

320-
func NewNodeVisitor() *NodeVisitor {
329+
func NewNodeVisitor(ctx context.Context) *NodeVisitor {
321330
return &NodeVisitor{
322331
declaredInterfaces: make([]string, 0),
332+
ctx: ctx,
323333
}
324334
}
325335

@@ -328,11 +338,23 @@ func (nv *NodeVisitor) DeclaredInterfaces() []string {
328338
}
329339

330340
func (nv *NodeVisitor) Visit(node ast.Node) ast.Visitor {
341+
log := zerolog.Ctx(nv.ctx)
342+
331343
switch n := node.(type) {
332344
case *ast.TypeSpec:
345+
log := log.With().
346+
Str("node-name", n.Name.Name).
347+
Str("node-type", fmt.Sprintf("%T", n.Type)).
348+
Logger()
349+
333350
switch n.Type.(type) {
334-
case *ast.InterfaceType, *ast.FuncType:
351+
case *ast.InterfaceType, *ast.FuncType, *ast.IndexExpr:
352+
log.Debug().
353+
Str("node-type", fmt.Sprintf("%T", n.Type)).
354+
Msg("found node with acceptable type for mocking")
335355
nv.declaredInterfaces = append(nv.declaredInterfaces, n.Name.Name)
356+
default:
357+
log.Debug().Msg("Found node with unacceptable type for mocking. Rejecting.")
336358
}
337359
}
338360
return nv

pkg/parse_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func TestFileParse(t *testing.T) {
1616
err := parser.Parse(ctx, testFile)
1717
assert.NoError(t, err)
1818

19-
err = parser.Load()
19+
err = parser.Load(context.Background())
2020
assert.NoError(t, err)
2121

2222
node, err := parser.Find("Requester")
@@ -38,7 +38,7 @@ func TestBuildTagInFilename(t *testing.T) {
3838
err = parser.Parse(ctx, getFixturePath("buildtag", "filename", "iface_freebsd.go"))
3939
assert.NoError(t, err)
4040

41-
err = parser.Load()
41+
err = parser.Load(context.Background())
4242
assert.NoError(t, err) // Expect "redeclared in this block" if tags aren't respected
4343

4444
nodes := parser.Interfaces()
@@ -60,7 +60,7 @@ func TestBuildTagInComment(t *testing.T) {
6060
err = parser.Parse(ctx, getFixturePath("buildtag", "comment", "freebsd_iface.go"))
6161
assert.NoError(t, err)
6262

63-
err = parser.Load()
63+
err = parser.Load(context.Background())
6464
assert.NoError(t, err) // Expect "redeclared in this block" if tags aren't respected
6565

6666
nodes := parser.Interfaces()
@@ -78,7 +78,7 @@ func TestCustomBuildTag(t *testing.T) {
7878
err = parser.Parse(ctx, getFixturePath("buildtag", "comment", "custom2_iface.go"))
7979
assert.NoError(t, err)
8080

81-
err = parser.Load()
81+
err = parser.Load(context.Background())
8282
assert.NoError(t, err) // Expect "redeclared in this block" if tags aren't respected
8383

8484
found := false
@@ -94,6 +94,6 @@ func TestCustomBuildTag(t *testing.T) {
9494
func TestParsePackages(t *testing.T) {
9595
parser := NewParser([]string{})
9696
require.NoError(t, parser.ParsePackages(context.Background(), []string{"github.com/vektra/mockery/v2/pkg/fixtures"}))
97-
assert.NotEqual(t, 0, len(parser.entries))
97+
assert.NotEqual(t, 0, len(parser.files))
9898

9999
}

0 commit comments

Comments
 (0)