Skip to content

Commit 99e488c

Browse files
authored
Merge pull request #7 from Buzzvil/improve_usability
Simplify inspection steps
2 parents f4f3ec8 + 649a906 commit 99e488c

File tree

4 files changed

+25
-149
lines changed

4 files changed

+25
-149
lines changed

README.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ go install github.com/Buzzvil/recovergoroutine
1515
recovergoroutine -recover="" ./...
1616

1717
# -recover string
18-
# Custom recover method name. Currently, it is difficult to determine
19-
# if a CustomRecover function declared in another package is valid,
20-
# so this option can be used to resolve it.
18+
# Custom recovery method name. You can use this option
19+
# when you want to call a method defined in a struct or
20+
# use CustomRecover declared in an external package.
2121
```
2222

2323
Check out the test cases for validation [examples](./test/src/faildata/failcode.go).

recovergoroutine/recovergoroutine.go

+17-96
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,12 @@ package recovergoroutine
22

33
import (
44
"flag"
5-
"fmt"
65
"go/ast"
7-
"go/parser"
8-
"go/types"
9-
"reflect"
10-
116
"golang.org/x/tools/go/analysis"
127
)
138

9+
type message string
10+
1411
var customRecover string
1512

1613
func NewAnalyzer() *analysis.Analyzer {
@@ -25,8 +22,7 @@ func NewAnalyzer() *analysis.Analyzer {
2522
&customRecover,
2623
"recover",
2724
"",
28-
"It is difficult to determine if a CustomRecover function declared in another package is valid,"+
29-
" so this option can be used to resolve it.",
25+
"You can use this option when you want to call a method defined in a struct or use CustomRecover declared in an external package.",
3026
)
3127

3228
return analyzer
@@ -41,12 +37,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
4137
return true
4238
}
4339

44-
ok, err := safeGoStmt(goStmt, pass)
45-
if err != nil {
46-
runErr = err
47-
return false
48-
}
49-
40+
ok, msg := safeGoStmt(goStmt)
5041
if ok {
5142
return true
5243
}
@@ -55,7 +46,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
5546
Pos: goStmt.Pos(),
5647
End: 0,
5748
Category: "goroutine",
58-
Message: "goroutine must have recover",
49+
Message: string(msg),
5950
})
6051

6152
return false
@@ -65,43 +56,28 @@ func run(pass *analysis.Pass) (interface{}, error) {
6556
return nil, runErr
6657
}
6758

68-
func safeGoStmt(goStmt *ast.GoStmt, pass *analysis.Pass) (bool, error) {
59+
func safeGoStmt(goStmt *ast.GoStmt) (bool, message) {
6960
fn := goStmt.Call
7061
switch fun := fn.Fun.(type) {
71-
case *ast.SelectorExpr:
72-
return safeSelectorExpr(fun, pass, safeFunc)
7362
case *ast.FuncLit:
74-
return safeFunc(fun, pass)
75-
case *ast.Ident:
76-
if fun.Obj == nil {
77-
return false, nil
78-
}
79-
80-
funcDecl, ok := fun.Obj.Decl.(*ast.FuncDecl)
81-
if !ok {
82-
return false, nil
63+
if !safeFunc(fun) {
64+
return false, "goroutine must have recover"
8365
}
84-
85-
return safeFunc(funcDecl, pass)
66+
return true, ""
8667
}
8768

88-
return false, fmt.Errorf("unexpected goroutine function type: %v", reflect.TypeOf(fn.Fun).String())
69+
return false, "use function literals when using goroutines"
8970
}
9071

91-
func safeFunc(node ast.Node, pass *analysis.Pass) (bool, error) {
72+
func safeFunc(node ast.Node) bool {
9273
result := false
93-
var err error
9474
ast.Inspect(node, func(node ast.Node) bool {
9575
deferStmt, ok := node.(*ast.DeferStmt)
9676
if !ok {
9777
return true
9878
}
9979

100-
ok, err = hasRecover(deferStmt.Call, pass)
101-
if err != nil {
102-
return false
103-
}
104-
80+
ok = hasRecover(deferStmt.Call)
10581
if ok {
10682
result = true
10783
return false
@@ -110,12 +86,11 @@ func safeFunc(node ast.Node, pass *analysis.Pass) (bool, error) {
11086
return !result
11187
})
11288

113-
return result, err
89+
return result
11490
}
11591

116-
func hasRecover(expr ast.Node, pass *analysis.Pass) (bool, error) {
92+
func hasRecover(expr ast.Node) bool {
11793
var result bool
118-
var err error
11994
ast.Inspect(expr, func(node ast.Node) bool {
12095
switch n := node.(type) {
12196
case *ast.CallExpr:
@@ -128,69 +103,15 @@ func hasRecover(expr ast.Node, pass *analysis.Pass) (bool, error) {
128103
return true
129104
}
130105

131-
var ok bool
132-
ok, err = safeSelectorExpr(n, pass, hasRecover)
133-
if err != nil {
134-
return false
135-
}
136-
137-
if ok || n.Sel.Name == customRecover {
106+
if n.Sel.Name == customRecover {
138107
result = true
139108
return false
140109
}
141110
}
142111
return true
143112
})
144113

145-
return result, err
146-
}
147-
148-
func safeSelectorExpr(
149-
expr *ast.SelectorExpr,
150-
pass *analysis.Pass,
151-
methodChecker func(node ast.Node, pass *analysis.Pass) (bool, error),
152-
) (bool, error) {
153-
ident, ok := expr.X.(*ast.Ident)
154-
if !ok {
155-
return false, nil
156-
}
157-
158-
methodName := expr.Sel.Name
159-
objType := pass.TypesInfo.ObjectOf(ident)
160-
pointerType, ok := objType.Type().(*types.Pointer)
161-
if !ok {
162-
return false, nil
163-
}
164-
165-
named, ok := pointerType.Elem().(*types.Named)
166-
if !ok {
167-
return false, nil
168-
}
169-
170-
result := false
171-
for i := 0; i < named.NumMethods(); i++ {
172-
if named.Method(i).Name() != methodName {
173-
continue
174-
}
175-
176-
fset := pass.Fset
177-
position := fset.Position(named.Method(i).Pos())
178-
file, err := parser.ParseFile(fset, position.Filename, nil, 0)
179-
if err != nil {
180-
return false, fmt.Errorf("parse file: %w", err)
181-
}
182-
183-
for _, decl := range file.Decls {
184-
if funcDecl, ok := decl.(*ast.FuncDecl); ok {
185-
if funcDecl.Name.Name == methodName {
186-
result, err = methodChecker(funcDecl, pass)
187-
break
188-
}
189-
}
190-
}
191-
}
192-
193-
return result, nil
114+
return result
194115
}
195116

196117
func isRecover(callExpr *ast.CallExpr) bool {
@@ -199,7 +120,7 @@ func isRecover(callExpr *ast.CallExpr) bool {
199120
return false
200121
}
201122

202-
return ident.Name == "recover"
123+
return ident.Name == "recover" || ident.Name == customRecover
203124
}
204125

205126
func isCustomRecover(callExpr *ast.CallExpr) bool {

test/src/custom/recover.go

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
package custom

test/src/succdata/succcode.go

+4-50
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package succdata
22

33
func whenASTFuncLit() {
4+
go func() {
5+
defer recover()
6+
}()
7+
48
go func() {
59
defer func() {
610
if r := recover(); r != nil {
@@ -23,54 +27,4 @@ func whenASTFuncLit() {
2327

2428
defer rec()
2529
}()
26-
27-
go func() {
28-
defer customRecover()
29-
}()
30-
31-
}
32-
33-
func whenIdent() {
34-
go runGoroutine()
35-
go nestedFunc1()
36-
}
37-
38-
func whenCallMethod() {
39-
foo := &Foo{}
40-
go foo.run()
41-
go func() {
42-
defer foo.Recover()
43-
}()
44-
}
45-
46-
func runGoroutine() {
47-
defer func() {
48-
recover()
49-
}()
50-
}
51-
52-
func nestedFunc1() {
53-
// must have recover in parent caller
54-
nestedFunc2()
55-
defer func() {
56-
recover()
57-
}()
58-
}
59-
60-
func nestedFunc2() {}
61-
62-
func customRecover() {
63-
recover()
64-
}
65-
66-
type Foo struct{}
67-
68-
func (a *Foo) run() {
69-
defer func() {
70-
recover()
71-
}()
72-
}
73-
74-
func (a *Foo) Recover() {
75-
recover()
7630
}

0 commit comments

Comments
 (0)