-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathcheck.go
183 lines (171 loc) · 4.98 KB
/
check.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
package main
import (
"fmt"
"go/ast"
"go/token"
"go/types"
"sort"
"strings"
"golang.org/x/tools/go/packages"
)
// inexhaustiveError is returned from check for each occurrence of inexhaustive
// case analysis in a Go type switch statement.
type inexhaustiveError struct {
	// Pos is the source position of the offending type switch statement.
	Pos token.Position
	// Def is the sum type definition the type switch was matched against.
	Def sumTypeDef
	// Missing holds the variants of Def that the type switch does not cover.
	Missing []types.Object
}
// Error implements the error interface. The message consists of the position
// of the type switch, the name of the sum type it switches over, and the
// sorted, comma-separated names of the variants that have no case.
func (e inexhaustiveError) Error() string {
	const format = "%s: exhaustiveness check failed for sum type '%s': missing cases for %s"

	missing := strings.Join(e.Names(), ", ")
	return fmt.Sprintf(format, e.Pos, e.Def.Decl.TypeName, missing)
}
// Names collects the name of every missing variant and returns them in
// ascending string order. The result is nil when there are no missing
// variants.
func (e inexhaustiveError) Names() []string {
	var names []string
	for i := range e.Missing {
		names = append(names, e.Missing[i].Name())
	}
	sort.Strings(names)
	return names
}
// check walks every syntax tree of the given package and runs the
// exhaustiveness check on each type switch statement it encounters, using the
// given sum type definitions. One error is collected per inexhaustive type
// switch; the result is nil when none are found.
func check(pkg *packages.Package, defs []sumTypeDef) []error {
	var found []error

	// visit always returns true so that type switches nested inside other
	// statements (including other type switches) are reached as well.
	visit := func(node ast.Node) bool {
		if ts, isTypeSwitch := node.(*ast.TypeSwitchStmt); isTypeSwitch {
			if err := checkSwitch(pkg, defs, ts); err != nil {
				found = append(found, err)
			}
		}
		return true
	}

	for i := range pkg.Syntax {
		ast.Inspect(pkg.Syntax[i], visit)
	}
	return found
}
// checkSwitch runs the exhaustiveness check on a single type switch
// statement. When the switch is over a known sum type and leaves some of its
// variants uncovered, an inexhaustiveError naming those variants is returned.
// Otherwise the result is nil.
//
// Note that a default case that is not recognized as always panicking
// disables the exhaustiveness check for that switch.
func checkSwitch(
	pkg *packages.Package,
	defs []sumTypeDef,
	swtch *ast.TypeSwitchStmt,
) error {
	def, missing := missingVariantsInSwitch(pkg, defs, swtch)
	if len(missing) == 0 {
		// Return a literal nil so the caller never sees a non-nil error
		// interface wrapping an empty value.
		return nil
	}
	return inexhaustiveError{
		Pos:     pkg.Fset.Position(swtch.Pos()),
		Def:     *def,
		Missing: missing,
	}
}
// missingVariantsInSwitch determines which variants of a sum type are not
// covered by the given type switch statement. It returns the sum type
// definition matching the switched-on expression together with the missing
// variants.
//
// When the expression's type matches none of the given definitions, both
// results are nil. When the switch has a default clause that is not
// recognized as always panicking, the definition is returned with no missing
// variants.
func missingVariantsInSwitch(
	pkg *packages.Package,
	defs []sumTypeDef,
	swtch *ast.TypeSwitchStmt,
) (*sumTypeDef, []types.Object) {
	sumDef := findDef(defs, pkg.TypesInfo.TypeOf(findTypeAssertExpr(swtch)))
	if sumDef == nil {
		// Not a switch over a known sum type; nothing to check.
		return nil, nil
	}

	caseExprs, sawDefault := switchVariants(swtch)
	if sawDefault {
		if !defaultClauseAlwaysPanics(swtch) {
			// A catch-all case makes the switch exhaustive by construction.
			return sumDef, nil
		}
	}

	// Resolve each case expression to its type before asking the
	// definition which variants are absent.
	var covered []types.Type
	for i := range caseExprs {
		covered = append(covered, pkg.TypesInfo.TypeOf(caseExprs[i]))
	}
	return sumDef, sumDef.missing(covered)
}
// switchVariants gathers the case expressions of a type switch in source
// order, flattening clauses that list several expressions. hasDefault reports
// whether the switch contains a default clause, i.e. a clause whose
// expression list is nil.
func switchVariants(swtch *ast.TypeSwitchStmt) (exprs []ast.Expr, hasDefault bool) {
	for i := range swtch.Body.List {
		cc := swtch.Body.List[i].(*ast.CaseClause)
		if cc.List != nil {
			exprs = append(exprs, cc.List...)
			continue
		}
		hasDefault = true
	}
	return exprs, hasDefault
}
// defaultClauseAlwaysPanics reports whether the default clause of the given
// type switch is recognized as one that always panics.
//
// Detection is purely syntactic and best-effort: the first default clause
// must consist of exactly one statement, and that statement must be a call
// whose callee is the bare identifier panic. Any other shape yields false,
// so default clauses that panic in some other way are not recognized.
//
// The switch must have a default clause; if it has none, this function
// panics.
func defaultClauseAlwaysPanics(swtch *ast.TypeSwitchStmt) bool {
	// Locate the first clause with a nil expression list (the default).
	var dflt *ast.CaseClause
	for i := 0; i < len(swtch.Body.List) && dflt == nil; i++ {
		if cc := swtch.Body.List[i].(*ast.CaseClause); cc.List == nil {
			dflt = cc
		}
	}
	if dflt == nil {
		panic("switch statement has no default clause")
	}

	if len(dflt.Body) != 1 {
		return false
	}
	if stmt, isExprStmt := dflt.Body[0].(*ast.ExprStmt); isExprStmt {
		if call, isCall := stmt.X.(*ast.CallExpr); isCall {
			if callee, isIdent := call.Fun.(*ast.Ident); isIdent {
				return callee.Name == "panic"
			}
		}
	}
	return false
}
// findTypeAssertExpr returns the operand of the type assertion in a type
// switch statement's header. Both header forms are handled: the assigning
// form `switch v := x.(type)` and the bare form `switch x.(type)`; in each
// case the expression x is returned.
func findTypeAssertExpr(swtch *ast.TypeSwitchStmt) ast.Expr {
	var guard ast.Expr
	switch stmt := swtch.Assign.(type) {
	case *ast.AssignStmt:
		guard = stmt.Rhs[0]
	default:
		// The only other header form is a bare expression statement.
		guard = swtch.Assign.(*ast.ExprStmt).X
	}
	assertion := guard.(*ast.TypeAssertExpr)
	return assertion.X
}
// findDef returns the sum type definition corresponding to the given type. If
// no such sum type definition exists, then nil is returned.
//
// A definition matches when its Ty field is identical (per types.Identical)
// to the underlying type of needle. The returned pointer refers to the
// element inside defs, not to a copy.
//
// needle may be nil, which is what types.Info.TypeOf yields for an expression
// it has no recorded type for. A nil needle matches no definition, so nil is
// returned instead of panicking on the Underlying call.
func findDef(defs []sumTypeDef, needle types.Type) *sumTypeDef {
	if needle == nil {
		return nil
	}
	// The underlying type does not depend on the definition being
	// compared, so resolve it once up front.
	underlying := needle.Underlying()
	for i := range defs {
		if types.Identical(underlying, defs[i].Ty) {
			return &defs[i]
		}
	}
	return nil
}