From 42f9c06959944d9847ccf76c6f493c616159da92 Mon Sep 17 00:00:00 2001 From: Anton Telyshev Date: Sun, 3 Mar 2024 16:57:03 +0300 Subject: [PATCH 1/3] bool-compare: support custom types --- analyzer/analyzer_test.go | 4 + .../bool_compare_custom_types_test.go | 114 ++++++++++++++++++ .../bool_compare_custom_types_test.go.golden | 114 ++++++++++++++++++ .../builtin-override/bool_override_test.go | 19 +++ .../bool-compare-custom-types/types/bool.go | 5 + internal/checkers/bool_compare.go | 105 +++++++++++++--- 6 files changed, 344 insertions(+), 17 deletions(-) create mode 100644 analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go create mode 100644 analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go.golden create mode 100644 analyzer/testdata/src/bool-compare-custom-types/builtin-override/bool_override_test.go create mode 100644 analyzer/testdata/src/bool-compare-custom-types/types/bool.go diff --git a/analyzer/analyzer_test.go b/analyzer/analyzer_test.go index 6cd2d2ce..0efc971c 100644 --- a/analyzer/analyzer_test.go +++ b/analyzer/analyzer_test.go @@ -21,6 +21,10 @@ func TestTestifyLint(t *testing.T) { dir: "base-test", flags: map[string]string{"disable-all": "true", "enable": checkers.NewBoolCompare().Name()}, }, + { + dir: "bool-compare-custom-types", + flags: map[string]string{"disable-all": "true", "enable": checkers.NewBoolCompare().Name()}, + }, { dir: "checkers-priority", flags: map[string]string{"enable-all": "true"}, diff --git a/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go b/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go new file mode 100644 index 00000000..f5f50e1b --- /dev/null +++ b/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go @@ -0,0 +1,114 @@ +package boolcomparecustomtypes_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "bool-compare-custom-types/types" +) + +type MyBool bool + +func TestBoolCompareChecker_CustomTypes(t *testing.T) { + var b MyBool + { + assert.Equal(t, false, b) // want "bool-compare: use assert\\.False" + assert.EqualValues(t, false, b) // want "bool-compare: use assert\\.False" + assert.Exactly(t, false, b) + + assert.Equal(t, true, b) // want "bool-compare: use assert\\.True" + assert.EqualValues(t, true, b) // want "bool-compare: use assert\\.True" + assert.Exactly(t, true, b) + + assert.NotEqual(t, false, b) // want "bool-compare: use assert\\.True" + assert.NotEqualValues(t, false, b) // want "bool-compare: use assert\\.True" + + assert.NotEqual(t, true, b) // want "bool-compare: use assert\\.False" + assert.NotEqualValues(t, true, b) // want "bool-compare: use assert\\.False" + + assert.True(t, b == true) // want "bool-compare: need to simplify the assertion" + assert.True(t, b != false) // want "bool-compare: need to simplify the assertion" + assert.True(t, b == false) // want "bool-compare: use assert\\.False" + assert.True(t, b != true) // want "bool-compare: use assert\\.False" + + assert.False(t, b == true) // want "bool-compare: need to simplify the assertion" + assert.False(t, b != false) // want "bool-compare: need to simplify the assertion" + assert.False(t, b == false) // want "bool-compare: use assert\\.True" + assert.False(t, b != true) // want "bool-compare: use assert\\.True" + } + + var extB types.Bool + { + assert.Equal(t, false, extB) // want "bool-compare: use assert\\.False" + assert.EqualValues(t, false, extB) // want "bool-compare: use assert\\.False" + assert.Exactly(t, false, extB) + + assert.Equal(t, true, extB) // want "bool-compare: use assert\\.True" + assert.EqualValues(t, true, extB) // want "bool-compare: use assert\\.True" + assert.Exactly(t, true, extB) + + assert.NotEqual(t, false, extB) // want "bool-compare: use assert\\.True" + assert.NotEqualValues(t, false, extB) // want "bool-compare: use assert\\.True" + + assert.NotEqual(t, true, extB) // want "bool-compare: use assert\\.False" + assert.NotEqualValues(t, true, extB) // want "bool-compare: use assert\\.False" + + assert.True(t, extB == true) // want "bool-compare: need to simplify the assertion" + assert.True(t, extB != false) // want "bool-compare: need to simplify the assertion" + assert.True(t, extB == false) // want "bool-compare: use assert\\.False" + assert.True(t, extB != true) // want "bool-compare: use assert\\.False" + + assert.False(t, extB == true) // want "bool-compare: need to simplify the assertion" + assert.False(t, extB != false) // want "bool-compare: need to simplify the assertion" + assert.False(t, extB == false) // want "bool-compare: use assert\\.True" + assert.False(t, extB != true) // want "bool-compare: use assert\\.True" + } + + var extSuperB types.SuperBool + { + assert.Equal(t, false, extSuperB) // want "bool-compare: use assert\\.False" + assert.EqualValues(t, false, extSuperB) // want "bool-compare: use assert\\.False" + assert.Exactly(t, false, extSuperB) + + assert.Equal(t, true, extSuperB) // want "bool-compare: use assert\\.True" + assert.EqualValues(t, true, extSuperB) // want "bool-compare: use assert\\.True" + assert.Exactly(t, true, extSuperB) + + assert.NotEqual(t, false, extSuperB) // want "bool-compare: use assert\\.True" + assert.NotEqualValues(t, false, extSuperB) // want "bool-compare: use assert\\.True" + + assert.NotEqual(t, true, extSuperB) // want "bool-compare: use assert\\.False" + assert.NotEqualValues(t, true, extSuperB) // want "bool-compare: use assert\\.False" + + assert.True(t, extSuperB == true) // want "bool-compare: need to simplify the assertion" + assert.True(t, extSuperB != false) // want "bool-compare: need to simplify the assertion" + assert.True(t, extSuperB == false) // want "bool-compare: use assert\\.False" + assert.True(t, extSuperB != true) // want "bool-compare: use assert\\.False" + + assert.False(t, extSuperB == true) // want "bool-compare: need to simplify the assertion" + assert.False(t, extSuperB != false) // want "bool-compare: need to simplify the assertion" + assert.False(t, extSuperB == false) // want "bool-compare: use assert\\.True" + assert.False(t, extSuperB != true) // want "bool-compare: use assert\\.True" + } + + // Crazy cases: + { + assert.Equal(t, true, types.Bool(extSuperB)) // want "bool-compare: use assert\\.True" + assert.Equal(t, true, types.SuperBool(b)) // want "bool-compare: use assert\\.True" + assert.Equal(t, true, bool(types.SuperBool(b))) // want "bool-compare: use assert\\.True" + assert.True(t, !bool(types.SuperBool(b))) // want "bool-compare: use assert\\.False" + assert.False(t, !bool(types.SuperBool(b))) // want "bool-compare: use assert\\.True" + } +} + +func TestBoolCompareChecker_CustomTypes_Format(t *testing.T) { + var predicate MyBool + assert.Equal(t, true, predicate) // want "bool-compare: use assert\\.True" + assert.Equal(t, true, predicate, "msg") // want "bool-compare: use assert\\.True" + assert.Equal(t, true, predicate, "msg with arg %d", 42) // want "bool-compare: use assert\\.True" + assert.Equal(t, true, predicate, "msg with args %d %s", 42, "42") // want "bool-compare: use assert\\.True" + assert.Equalf(t, true, predicate, "msg") // want "bool-compare: use assert\\.Truef" + assert.Equalf(t, true, predicate, "msg with arg %d", 42) // want "bool-compare: use assert\\.Truef" + assert.Equalf(t, true, predicate, "msg with args %d %s", 42, "42") // want "bool-compare: use assert\\.Truef" +} diff --git a/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go.golden b/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go.golden new file mode 100644 index 00000000..8cc552e6 --- /dev/null +++ b/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go.golden @@ -0,0 +1,114 @@ +package boolcomparecustomtypes_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "bool-compare-custom-types/types" +) + +type MyBool bool + +func TestBoolCompareChecker_CustomTypes(t *testing.T) { + var b MyBool + { + assert.False(t, bool(b)) // want "bool-compare: use assert\\.False" + assert.False(t, bool(b)) // want "bool-compare: use assert\\.False" + assert.Exactly(t, false, b) + + assert.True(t, bool(b)) // want "bool-compare: use assert\\.True" + assert.True(t, bool(b)) // want "bool-compare: use assert\\.True" + assert.Exactly(t, true, b) + + assert.True(t, bool(b)) // want "bool-compare: use assert\\.True" + assert.True(t, bool(b)) // want "bool-compare: use assert\\.True" + + assert.False(t, bool(b)) // want "bool-compare: use assert\\.False" + assert.False(t, bool(b)) // want "bool-compare: use assert\\.False" + + assert.True(t, bool(b)) // want "bool-compare: need to simplify the assertion" + assert.True(t, bool(b)) // want "bool-compare: need to simplify the assertion" + assert.False(t, bool(b)) // want "bool-compare: use assert\\.False" + assert.False(t, bool(b)) // want "bool-compare: use assert\\.False" + + assert.False(t, bool(b)) // want "bool-compare: need to simplify the assertion" + assert.False(t, bool(b)) // want "bool-compare: need to simplify the assertion" + assert.True(t, bool(b)) // want "bool-compare: use assert\\.True" + assert.True(t, bool(b)) // want "bool-compare: use assert\\.True" + } + + var extB types.Bool + { + assert.False(t, bool(extB)) // want "bool-compare: use assert\\.False" + assert.False(t, bool(extB)) // want "bool-compare: use assert\\.False" + assert.Exactly(t, false, extB) + + assert.True(t, bool(extB)) // want "bool-compare: use assert\\.True" + assert.True(t, bool(extB)) // want "bool-compare: use assert\\.True" + assert.Exactly(t, true, extB) + + assert.True(t, bool(extB)) // want "bool-compare: use assert\\.True" + assert.True(t, bool(extB)) // want "bool-compare: use assert\\.True" + + assert.False(t, bool(extB)) // want "bool-compare: use assert\\.False" + assert.False(t, bool(extB)) // want "bool-compare: use assert\\.False" + + assert.True(t, bool(extB)) // want "bool-compare: need to simplify the assertion" + assert.True(t, bool(extB)) // want "bool-compare: need to simplify the assertion" + assert.False(t, bool(extB)) // want "bool-compare: use assert\\.False" + assert.False(t, bool(extB)) // want "bool-compare: use assert\\.False" + + assert.False(t, bool(extB)) // want "bool-compare: need to simplify the assertion" + assert.False(t, bool(extB)) // want "bool-compare: need to simplify the assertion" + assert.True(t, bool(extB)) // want "bool-compare: use assert\\.True" + assert.True(t, bool(extB)) // want "bool-compare: use assert\\.True" + } + + var extSuperB types.SuperBool + { + assert.False(t, bool(extSuperB)) // want "bool-compare: use assert\\.False" + assert.False(t, bool(extSuperB)) // want "bool-compare: use assert\\.False" + assert.Exactly(t, false, extSuperB) + + assert.True(t, bool(extSuperB)) // want "bool-compare: use assert\\.True" + assert.True(t, bool(extSuperB)) // want "bool-compare: use assert\\.True" + assert.Exactly(t, true, extSuperB) + + assert.True(t, bool(extSuperB)) // want "bool-compare: use assert\\.True" + assert.True(t, bool(extSuperB)) // want "bool-compare: use assert\\.True" + + assert.False(t, bool(extSuperB)) // want "bool-compare: use assert\\.False" + assert.False(t, bool(extSuperB)) // want "bool-compare: use assert\\.False" + + assert.True(t, bool(extSuperB)) // want "bool-compare: need to simplify the assertion" + assert.True(t, bool(extSuperB)) // want "bool-compare: need to simplify the assertion" + assert.False(t, bool(extSuperB)) // want "bool-compare: use assert\\.False" + assert.False(t, bool(extSuperB)) // want "bool-compare: use assert\\.False" + + assert.False(t, bool(extSuperB)) // want "bool-compare: need to simplify the assertion" + assert.False(t, bool(extSuperB)) // want "bool-compare: need to simplify the assertion" + assert.True(t, bool(extSuperB)) // want "bool-compare: use assert\\.True" + assert.True(t, bool(extSuperB)) // want "bool-compare: use assert\\.True" + } + + // Crazy cases: + { + assert.True(t, bool(types.Bool(extSuperB))) // want "bool-compare: use assert\\.True" + assert.True(t, bool(types.SuperBool(b))) // want "bool-compare: use assert\\.True" + assert.True(t, bool(types.SuperBool(b))) // want "bool-compare: use assert\\.True" + assert.False(t, bool(types.SuperBool(b))) // want "bool-compare: use assert\\.False" + assert.True(t, bool(types.SuperBool(b))) // want "bool-compare: use assert\\.True" + } +} + +func TestBoolCompareChecker_CustomTypes_Format(t *testing.T) { + var predicate MyBool + assert.True(t, bool(predicate)) // want "bool-compare: use assert\\.True" + assert.True(t, bool(predicate), "msg") // want "bool-compare: use assert\\.True" + assert.True(t, bool(predicate), "msg with arg %d", 42) // want "bool-compare: use assert\\.True" + assert.True(t, bool(predicate), "msg with args %d %s", 42, "42") // want "bool-compare: use assert\\.True" + assert.Truef(t, bool(predicate), "msg") // want "bool-compare: use assert\\.Truef" + assert.Truef(t, bool(predicate), "msg with arg %d", 42) // want "bool-compare: use assert\\.Truef" + assert.Truef(t, bool(predicate), "msg with args %d %s", 42, "42") // want "bool-compare: use assert\\.Truef" +} diff --git a/analyzer/testdata/src/bool-compare-custom-types/builtin-override/bool_override_test.go b/analyzer/testdata/src/bool-compare-custom-types/builtin-override/bool_override_test.go new file mode 100644 index 00000000..1b22331d --- /dev/null +++ b/analyzer/testdata/src/bool-compare-custom-types/builtin-override/bool_override_test.go @@ -0,0 +1,19 @@ +package boolcomparecustomtypes_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type bool int + +func TestBoolCompareChecker_BoolOverride(t *testing.T) { + var mimic bool + assert.Equal(t, false, mimic) + assert.Equal(t, false, mimic) + assert.EqualValues(t, false, mimic) + assert.Exactly(t, false, mimic) + assert.NotEqual(t, false, mimic) + assert.NotEqualValues(t, false, mimic) +} diff --git a/analyzer/testdata/src/bool-compare-custom-types/types/bool.go b/analyzer/testdata/src/bool-compare-custom-types/types/bool.go new file mode 100644 index 00000000..8d733b8a --- /dev/null +++ b/analyzer/testdata/src/bool-compare-custom-types/types/bool.go @@ -0,0 +1,5 @@ +package types + +type SuperBool Bool + +type Bool bool diff --git a/internal/checkers/bool_compare.go b/internal/checkers/bool_compare.go index c8db9420..0cad22a5 100644 --- a/internal/checkers/bool_compare.go +++ b/internal/checkers/bool_compare.go @@ -32,7 +32,14 @@ func NewBoolCompare() BoolCompare { return BoolCompare{} } func (BoolCompare) Name() string { return "bool-compare" } func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis.Diagnostic { - newUseFnDiagnostic := func(proposed string, survivingArg ast.Node, replaceStart, replaceEnd token.Pos) *analysis.Diagnostic { + newBoolCast := func(e ast.Expr) ast.Expr { + return &ast.CallExpr{Fun: &ast.Ident{Name: "bool"}, Args: []ast.Expr{e}} + } + + newUseFnDiagnostic := func(proposed string, survivingArg ast.Expr, replaceStart, replaceEnd token.Pos) *analysis.Diagnostic { + if !isBuiltinBool(pass, survivingArg) { + survivingArg = newBoolCast(survivingArg) + } return newUseFunctionDiagnostic(checker.Name(), call, proposed, newSuggestedFuncReplacement(call, proposed, analysis.TextEdit{ Pos: replaceStart, @@ -42,15 +49,18 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. ) } - newUseTrueDiagnostic := func(survivingArg ast.Node, replaceStart, replaceEnd token.Pos) *analysis.Diagnostic { + newUseTrueDiagnostic := func(survivingArg ast.Expr, replaceStart, replaceEnd token.Pos) *analysis.Diagnostic { return newUseFnDiagnostic("True", survivingArg, replaceStart, replaceEnd) } - newUseFalseDiagnostic := func(survivingArg ast.Node, replaceStart, replaceEnd token.Pos) *analysis.Diagnostic { + newUseFalseDiagnostic := func(survivingArg ast.Expr, replaceStart, replaceEnd token.Pos) *analysis.Diagnostic { return newUseFnDiagnostic("False", survivingArg, replaceStart, replaceEnd) } - newNeedSimplifyDiagnostic := func(survivingArg ast.Node, replaceStart, replaceEnd token.Pos) *analysis.Diagnostic { + newNeedSimplifyDiagnostic := func(survivingArg ast.Expr, replaceStart, replaceEnd token.Pos) *analysis.Diagnostic { + if !isBuiltinBool(pass, survivingArg) { + survivingArg = newBoolCast(survivingArg) + } return newDiagnostic(checker.Name(), call, "need to simplify the assertion", &analysis.SuggestedFix{ Message: "Simplify the assertion", @@ -70,7 +80,10 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. } arg1, arg2 := call.Args[0], call.Args[1] - if isEmptyInterface(pass, arg1) || isEmptyInterface(pass, arg2) { + if anyCondSatisfaction(pass, isEmptyInterface, arg1, arg2) { + return nil + } + if anyCondSatisfaction(pass, isBoolOverride, arg1, arg2) { return nil } @@ -80,10 +93,18 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. switch { case xor(t1, t2): survivingArg, _ := anyVal([]bool{t1, t2}, arg2, arg1) + if call.Fn.NameFTrimmed == "Exactly" && !isBuiltinBool(pass, survivingArg) { + // NOTE(a.telyshev): `Exactly` assumes no type casting. + return nil + } return newUseTrueDiagnostic(survivingArg, arg1.Pos(), arg2.End()) case xor(f1, f2): survivingArg, _ := anyVal([]bool{f1, f2}, arg2, arg1) + if call.Fn.NameFTrimmed == "Exactly" && !isBuiltinBool(pass, survivingArg) { + // NOTE(a.telyshev): `Exactly` assumes no type casting. + return nil + } return newUseFalseDiagnostic(survivingArg, arg1.Pos(), arg2.End()) } @@ -93,7 +114,10 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. } arg1, arg2 := call.Args[0], call.Args[1] - if isEmptyInterface(pass, arg1) || isEmptyInterface(pass, arg2) { + if anyCondSatisfaction(pass, isEmptyInterface, arg1, arg2) { + return nil + } + if anyCondSatisfaction(pass, isBoolOverride, arg1, arg2) { return nil } @@ -120,8 +144,15 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. arg1, ok1 := isComparisonWithTrue(pass, expr, token.EQL) arg2, ok2 := isComparisonWithFalse(pass, expr, token.NEQ) + if anyCondSatisfaction(pass, isEmptyInterface, arg1, arg2) { + return nil + } + if anyCondSatisfaction(pass, isBoolOverride, arg1, arg2) { + return nil + } + survivingArg, ok := anyVal([]bool{ok1, ok2}, arg1, arg2) - if ok && !isEmptyInterface(pass, survivingArg) { + if ok { return newNeedSimplifyDiagnostic(survivingArg, expr.Pos(), expr.End()) } } @@ -131,8 +162,15 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. arg2, ok2 := isComparisonWithFalse(pass, expr, token.EQL) arg3, ok3 := isNegation(expr) + if anyCondSatisfaction(pass, isEmptyInterface, arg1, arg2, arg3) { + return nil + } + if anyCondSatisfaction(pass, isBoolOverride, arg1, arg2, arg3) { + return nil + } + survivingArg, ok := anyVal([]bool{ok1, ok2, ok3}, arg1, arg2, arg3) - if ok && !isEmptyInterface(pass, survivingArg) { + if ok { return newUseFalseDiagnostic(survivingArg, expr.Pos(), expr.End()) } } @@ -147,8 +185,15 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. arg1, ok1 := isComparisonWithTrue(pass, expr, token.EQL) arg2, ok2 := isComparisonWithFalse(pass, expr, token.NEQ) + if anyCondSatisfaction(pass, isEmptyInterface, arg1, arg2) { + return nil + } + if anyCondSatisfaction(pass, isBoolOverride, arg1, arg2) { + return nil + } + survivingArg, ok := anyVal([]bool{ok1, ok2}, arg1, arg2) - if ok && !isEmptyInterface(pass, survivingArg) { + if ok { return newNeedSimplifyDiagnostic(survivingArg, expr.Pos(), expr.End()) } } @@ -158,8 +203,15 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. arg2, ok2 := isComparisonWithFalse(pass, expr, token.EQL) arg3, ok3 := isNegation(expr) + if anyCondSatisfaction(pass, isEmptyInterface, arg1, arg2, arg3) { + return nil + } + if anyCondSatisfaction(pass, isBoolOverride, arg1, arg2, arg3) { + return nil + } + survivingArg, ok := anyVal([]bool{ok1, ok2, ok3}, arg1, arg2, arg3) - if ok && !isEmptyInterface(pass, survivingArg) { + if ok { return newUseTrueDiagnostic(survivingArg, expr.Pos(), expr.End()) } } @@ -167,6 +219,26 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. return nil } +func isEmptyInterface(pass *analysis.Pass, expr ast.Expr) bool { + t, ok := pass.TypesInfo.Types[expr] + if !ok { + return false + } + + iface, ok := t.Type.Underlying().(*types.Interface) + return ok && iface.NumMethods() == 0 +} + +func isBuiltinBool(pass *analysis.Pass, e ast.Expr) bool { + basicType, ok := pass.TypesInfo.TypeOf(e).(*types.Basic) + return ok && basicType.Kind() == types.Bool +} + +func isBoolOverride(pass *analysis.Pass, e ast.Expr) bool { + namedType, ok := pass.TypesInfo.TypeOf(e).(*types.Named) + return ok && namedType.Obj().Name() == "bool" +} + var ( falseObj = types.Universe.Lookup("false") trueObj = types.Universe.Lookup("true") @@ -237,12 +309,11 @@ func anyVal[T any](bools []bool, vals ...T) (T, bool) { return _default, false } -func isEmptyInterface(pass *analysis.Pass, expr ast.Expr) bool { - t, ok := pass.TypesInfo.Types[expr] - if !ok { - return false +func anyCondSatisfaction(pass *analysis.Pass, p predicate, vals ...ast.Expr) bool { + for _, v := range vals { + if p(pass, v) { + return true + } } - - iface, ok := t.Type.Underlying().(*types.Interface) - return ok && iface.NumMethods() == 0 + return false } From 4ecbfc9edf3c914c6204a41eb077388e5c6588c4 Mon Sep 17 00:00:00 2001 From: Anton Telyshev Date: Sun, 3 Mar 2024 17:20:20 +0300 Subject: [PATCH 2/3] bool-compare: ignore-custom-types flag --- README.md | 16 +++ analyzer/analyzer_test.go | 8 ++ analyzer/checkers_factory.go | 3 + analyzer/checkers_factory_test.go | 14 +++ .../bool_compare_custom_types_test.go | 3 +- .../bool_compare_custom_types_test.go.golden | 3 +- .../bool_compare_ignore_custom_types_test.go | 102 ++++++++++++++++++ ...compare_ignore_custom_types_test.go.golden | 102 ++++++++++++++++++ internal/checkers/bool_compare.go | 19 +++- internal/config/config.go | 8 ++ internal/config/config_test.go | 17 +-- 11 files changed, 282 insertions(+), 13 deletions(-) create mode 100644 analyzer/testdata/src/bool-compare-ignore-custom-types/bool_compare_ignore_custom_types_test.go create mode 100644 analyzer/testdata/src/bool-compare-ignore-custom-types/bool_compare_ignore_custom_types_test.go.golden diff --git a/README.md b/README.md index 6fdc560c..8ca0ff4d 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,22 @@ import ( **Enabled by default**: true.
**Reason**: Code simplification. +Also `bool-compare` supports user defined types like + +```go +type Bool bool +``` + +And fixes assertions via casting variable to builtin `bool`: + +```go +var predicate Bool +❌ assert.Equal(t, false, predicate) +✅ assert.False(t, bool(predicate)) +``` + +To turn off this behavior use the `--bool-compare.ignore-custom-types` flag. + --- ### compares diff --git a/analyzer/analyzer_test.go b/analyzer/analyzer_test.go index 0efc971c..be598774 100644 --- a/analyzer/analyzer_test.go +++ b/analyzer/analyzer_test.go @@ -25,6 +25,14 @@ func TestTestifyLint(t *testing.T) { dir: "bool-compare-custom-types", flags: map[string]string{"disable-all": "true", "enable": checkers.NewBoolCompare().Name()}, }, + { + dir: "bool-compare-ignore-custom-types", + flags: map[string]string{ + "disable-all": "true", + "enable": checkers.NewBoolCompare().Name(), + "bool-compare.ignore-custom-types": "true", + }, + }, { dir: "checkers-priority", flags: map[string]string{"enable-all": "true"}, diff --git a/analyzer/checkers_factory.go b/analyzer/checkers_factory.go index de1d6017..77573e39 100644 --- a/analyzer/checkers_factory.go +++ b/analyzer/checkers_factory.go @@ -49,6 +49,9 @@ func newCheckers(cfg config.Config) ([]checkers.RegularChecker, []checkers.Advan } switch c := ch.(type) { + case *checkers.BoolCompare: + c.SetIgnoreCustomTypes(cfg.BoolCompare.IgnoreCustomTypes) + case *checkers.ExpectedActual: c.SetExpVarPattern(cfg.ExpectedActual.ExpVarPattern.Regexp) diff --git a/analyzer/checkers_factory_test.go b/analyzer/checkers_factory_test.go index 660b89d0..e3c7141f 100644 --- a/analyzer/checkers_factory_test.go +++ b/analyzer/checkers_factory_test.go @@ -136,6 +136,20 @@ func Test_newCheckers(t *testing.T) { checkers.NewRequireError().Name(), }), }, + { + name: "bool-compare ignore custom types", + cfg: config.Config{ + DisableAll: true, + EnabledCheckers: config.KnownCheckersValue{checkers.NewBoolCompare().Name()}, + BoolCompare: config.BoolCompareConfig{ + IgnoreCustomTypes: true, + }, + }, + expRegular: []checkers.RegularChecker{ + checkers.NewBoolCompare().SetIgnoreCustomTypes(true), + }, + expAdvanced: []checkers.AdvancedChecker{}, + }, { name: "expected-actual pattern defined", cfg: config.Config{ diff --git a/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go b/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go index f5f50e1b..b0f693bb 100644 --- a/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go +++ b/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go @@ -3,9 +3,8 @@ package boolcomparecustomtypes_test import ( "testing" - "github.com/stretchr/testify/assert" - "bool-compare-custom-types/types" + "github.com/stretchr/testify/assert" ) type MyBool bool diff --git a/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go.golden b/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go.golden index 8cc552e6..c52ce221 100644 --- a/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go.golden +++ b/analyzer/testdata/src/bool-compare-custom-types/bool_compare_custom_types_test.go.golden @@ -3,9 +3,8 @@ package boolcomparecustomtypes_test import ( "testing" - "github.com/stretchr/testify/assert" - "bool-compare-custom-types/types" + "github.com/stretchr/testify/assert" ) type MyBool bool diff --git a/analyzer/testdata/src/bool-compare-ignore-custom-types/bool_compare_ignore_custom_types_test.go b/analyzer/testdata/src/bool-compare-ignore-custom-types/bool_compare_ignore_custom_types_test.go new file mode 100644 index 00000000..fac04552 --- /dev/null +++ b/analyzer/testdata/src/bool-compare-ignore-custom-types/bool_compare_ignore_custom_types_test.go @@ -0,0 +1,102 @@ +package boolcomparecustomtypes_test + +import ( + "testing" + + "bool-compare-custom-types/types" + "github.com/stretchr/testify/assert" +) + +type MyBool bool + +func TestBoolCompareChecker_CustomTypes(t *testing.T) { + var b MyBool + { + assert.Equal(t, false, b) + assert.EqualValues(t, false, b) + assert.Exactly(t, false, b) + + assert.Equal(t, true, b) + assert.EqualValues(t, true, b) + assert.Exactly(t, true, b) + + assert.NotEqual(t, false, b) + assert.NotEqualValues(t, false, b) + + assert.NotEqual(t, true, b) + assert.NotEqualValues(t, true, b) + + assert.True(t, b == true) + assert.True(t, b != false) + assert.True(t, b == false) + assert.True(t, b != true) + + assert.False(t, b == true) + assert.False(t, b != false) + assert.False(t, b == false) + assert.False(t, b != true) + } + + var extB types.Bool + { + assert.Equal(t, false, extB) + assert.EqualValues(t, false, extB) + assert.Exactly(t, false, extB) + + assert.Equal(t, true, extB) + assert.EqualValues(t, true, extB) + assert.Exactly(t, true, extB) + + assert.NotEqual(t, false, extB) + assert.NotEqualValues(t, false, extB) + + assert.NotEqual(t, true, extB) + assert.NotEqualValues(t, true, extB) + + assert.True(t, extB == true) + assert.True(t, extB != false) + assert.True(t, extB == false) + assert.True(t, extB != true) + + assert.False(t, extB == true) + assert.False(t, extB != false) + assert.False(t, extB == false) + assert.False(t, extB != true) + } + + var extSuperB types.SuperBool + { + assert.Equal(t, false, extSuperB) + assert.EqualValues(t, false, extSuperB) + assert.Exactly(t, false, extSuperB) + + assert.Equal(t, true, extSuperB) + assert.EqualValues(t, true, extSuperB) + assert.Exactly(t, true, extSuperB) + + assert.NotEqual(t, false, extSuperB) + assert.NotEqualValues(t, false, extSuperB) + + assert.NotEqual(t, true, extSuperB) + assert.NotEqualValues(t, true, extSuperB) + + assert.True(t, extSuperB == true) + assert.True(t, extSuperB != false) + assert.True(t, extSuperB == false) + assert.True(t, extSuperB != true) + + assert.False(t, extSuperB == true) + assert.False(t, extSuperB != false) + assert.False(t, extSuperB == false) + assert.False(t, extSuperB != true) + } + + // Crazy cases: + { + assert.Equal(t, true, types.Bool(extSuperB)) + assert.Equal(t, true, types.SuperBool(b)) + assert.Equal(t, true, bool(types.SuperBool(b))) // want "bool-compare: use assert\\.True" + assert.True(t, !bool(types.SuperBool(b))) // want "bool-compare: use assert\\.False" + assert.False(t, !bool(types.SuperBool(b))) // want "bool-compare: use assert\\.True" + } +} diff --git a/analyzer/testdata/src/bool-compare-ignore-custom-types/bool_compare_ignore_custom_types_test.go.golden b/analyzer/testdata/src/bool-compare-ignore-custom-types/bool_compare_ignore_custom_types_test.go.golden new file mode 100644 index 00000000..5a69d646 --- /dev/null +++ b/analyzer/testdata/src/bool-compare-ignore-custom-types/bool_compare_ignore_custom_types_test.go.golden @@ -0,0 +1,102 @@ +package boolcomparecustomtypes_test + +import ( + "testing" + + "bool-compare-custom-types/types" + "github.com/stretchr/testify/assert" +) + +type MyBool bool + +func TestBoolCompareChecker_CustomTypes(t *testing.T) { + var b MyBool + { + assert.Equal(t, false, b) + assert.EqualValues(t, false, b) + assert.Exactly(t, false, b) + + assert.Equal(t, true, b) + assert.EqualValues(t, true, b) + assert.Exactly(t, true, b) + + assert.NotEqual(t, false, b) + assert.NotEqualValues(t, false, b) + + assert.NotEqual(t, true, b) + assert.NotEqualValues(t, true, b) + + assert.True(t, b == true) + assert.True(t, b != false) + assert.True(t, b == false) + assert.True(t, b != true) + + assert.False(t, b == true) + assert.False(t, b != false) + assert.False(t, b == false) + assert.False(t, b != true) + } + + var extB types.Bool + { + assert.Equal(t, false, extB) + assert.EqualValues(t, false, extB) + assert.Exactly(t, false, extB) + + assert.Equal(t, true, extB) + assert.EqualValues(t, true, extB) + assert.Exactly(t, true, extB) + + assert.NotEqual(t, false, extB) + assert.NotEqualValues(t, false, extB) + + assert.NotEqual(t, true, extB) + assert.NotEqualValues(t, true, extB) + + assert.True(t, extB == true) + assert.True(t, extB != false) + assert.True(t, extB == false) + assert.True(t, extB != true) + + assert.False(t, extB == true) + assert.False(t, extB != false) + assert.False(t, extB == false) + assert.False(t, extB != true) + } + + var extSuperB types.SuperBool + { + assert.Equal(t, false, extSuperB) + assert.EqualValues(t, false, extSuperB) + assert.Exactly(t, false, extSuperB) + + assert.Equal(t, true, extSuperB) + assert.EqualValues(t, true, extSuperB) + assert.Exactly(t, true, extSuperB) + + assert.NotEqual(t, false, extSuperB) + assert.NotEqualValues(t, false, extSuperB) + + assert.NotEqual(t, true, extSuperB) + assert.NotEqualValues(t, true, extSuperB) + + assert.True(t, extSuperB == true) + assert.True(t, extSuperB != false) + assert.True(t, extSuperB == false) + assert.True(t, extSuperB != true) + + assert.False(t, extSuperB == true) + assert.False(t, extSuperB != false) + assert.False(t, extSuperB == false) + assert.False(t, extSuperB != true) + } + + // Crazy cases: + { + assert.Equal(t, true, types.Bool(extSuperB)) + assert.Equal(t, true, types.SuperBool(b)) + assert.True(t, bool(types.SuperBool(b))) // want "bool-compare: use assert\\.True" + assert.False(t, bool(types.SuperBool(b))) // want "bool-compare: use assert\\.False" + assert.True(t, bool(types.SuperBool(b))) // want "bool-compare: use assert\\.True" + } +} diff --git a/internal/checkers/bool_compare.go b/internal/checkers/bool_compare.go index 0cad22a5..6ede3d28 100644 --- a/internal/checkers/bool_compare.go +++ b/internal/checkers/bool_compare.go @@ -25,11 +25,18 @@ import ( // // assert.False(t, result) // assert.True(t, result) -type BoolCompare struct{} // +type BoolCompare struct { + ignoreCustomTypes bool +} // NewBoolCompare constructs BoolCompare checker. -func NewBoolCompare() BoolCompare { return BoolCompare{} } -func (BoolCompare) Name() string { return "bool-compare" } +func NewBoolCompare() *BoolCompare { return new(BoolCompare) } +func (BoolCompare) Name() string { return "bool-compare" } + +func (checker *BoolCompare) SetIgnoreCustomTypes(v bool) *BoolCompare { + checker.ignoreCustomTypes = v + return checker +} func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis.Diagnostic { newBoolCast := func(e ast.Expr) ast.Expr { @@ -38,6 +45,9 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. newUseFnDiagnostic := func(proposed string, survivingArg ast.Expr, replaceStart, replaceEnd token.Pos) *analysis.Diagnostic { if !isBuiltinBool(pass, survivingArg) { + if checker.ignoreCustomTypes { + return nil + } survivingArg = newBoolCast(survivingArg) } return newUseFunctionDiagnostic(checker.Name(), call, proposed, @@ -59,6 +69,9 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. newNeedSimplifyDiagnostic := func(survivingArg ast.Expr, replaceStart, replaceEnd token.Pos) *analysis.Diagnostic { if !isBuiltinBool(pass, survivingArg) { + if checker.ignoreCustomTypes { + return nil + } survivingArg = newBoolCast(survivingArg) } return newDiagnostic(checker.Name(), call, "need to simplify the assertion", diff --git a/internal/config/config.go b/internal/config/config.go index 6dcfbdb5..9e559e8a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -34,11 +34,17 @@ type Config struct { DisableAll bool EnabledCheckers KnownCheckersValue + BoolCompare BoolCompareConfig ExpectedActual ExpectedActualConfig RequireError RequireErrorConfig SuiteExtraAssertCall SuiteExtraAssertCallConfig } +// BoolCompareConfig implements configuration of checkers.BoolCompare. +type BoolCompareConfig struct { + IgnoreCustomTypes bool +} + // ExpectedActualConfig implements configuration of checkers.ExpectedActual. type ExpectedActualConfig struct { ExpVarPattern RegexpValue @@ -91,6 +97,8 @@ func BindToFlags(cfg *Config, fs *flag.FlagSet) { fs.BoolVar(&cfg.DisableAll, "disable-all", false, "disable all checkers") fs.Var(&cfg.EnabledCheckers, "enable", "comma separated list of enabled checkers (in addition to enabled by default)") + fs.BoolVar(&cfg.BoolCompare.IgnoreCustomTypes, "bool-compare.ignore-custom-types", false, + "ignore user defined types (over builtin `bool`)") fs.Var(&cfg.ExpectedActual.ExpVarPattern, "expected-actual.pattern", "regexp for expected variable name") fs.Var(&cfg.RequireError.FnPattern, "require-error.fn-pattern", "regexp for error assertions that should only be analyzed") fs.Var(NewEnumValue(suiteExtraAssertCallModeAsString, &cfg.SuiteExtraAssertCall.Mode), diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 06322d47..3e63fc9f 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -23,6 +23,9 @@ func TestNewDefault(t *testing.T) { if len(cfg.EnabledCheckers) != 0 { t.Fatal() } + if cfg.BoolCompare.IgnoreCustomTypes != false { + t.Fatal() + } if cfg.ExpectedActual.ExpVarPattern.String() != checkers.DefaultExpectedVarPattern.String() { t.Fatal() } @@ -144,12 +147,14 @@ func TestBindToFlags(t *testing.T) { config.BindToFlags(&cfg, fs) for flagName, defaultVal := range map[string]string{ - "enable-all": "false", - "disable": "", - "disable-all": "false", - "enable": "", - "expected-actual.pattern": cfg.ExpectedActual.ExpVarPattern.String(), - "suite-extra-assert-call.mode": "remove", + "enable-all": "false", + "disable": "", + "disable-all": "false", + "enable": "", + "bool-compare.ignore-custom-types": "false", + "expected-actual.pattern": cfg.ExpectedActual.ExpVarPattern.String(), + "require-error.fn-pattern": cfg.RequireError.FnPattern.String(), + "suite-extra-assert-call.mode": "remove", } { t.Run(flagName, func(t *testing.T) { if v := fs.Lookup(flagName).DefValue; v != defaultVal { From 249d25279a010709bca4515da8914159f86554b5 Mon Sep 17 00:00:00 2001 From: Anton Telyshev Date: Sun, 3 Mar 2024 17:50:52 +0300 Subject: [PATCH 3/3] self-review fixes --- internal/checkers/bool_compare.go | 36 ++++--------------------------- internal/config/config.go | 2 +- 2 files changed, 5 insertions(+), 33 deletions(-) diff --git a/internal/checkers/bool_compare.go b/internal/checkers/bool_compare.go index 6ede3d28..43907123 100644 --- a/internal/checkers/bool_compare.go +++ b/internal/checkers/bool_compare.go @@ -157,15 +157,8 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. arg1, ok1 := isComparisonWithTrue(pass, expr, token.EQL) arg2, ok2 := isComparisonWithFalse(pass, expr, token.NEQ) - if anyCondSatisfaction(pass, isEmptyInterface, arg1, arg2) { - return nil - } - if anyCondSatisfaction(pass, isBoolOverride, arg1, arg2) { - return nil - } - survivingArg, ok := anyVal([]bool{ok1, ok2}, arg1, arg2) - if ok { + if ok && !isEmptyInterface(pass, survivingArg) { return newNeedSimplifyDiagnostic(survivingArg, expr.Pos(), expr.End()) } } @@ -175,15 +168,8 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. arg2, ok2 := isComparisonWithFalse(pass, expr, token.EQL) arg3, ok3 := isNegation(expr) - if anyCondSatisfaction(pass, isEmptyInterface, arg1, arg2, arg3) { - return nil - } - if anyCondSatisfaction(pass, isBoolOverride, arg1, arg2, arg3) { - return nil - } - survivingArg, ok := anyVal([]bool{ok1, ok2, ok3}, arg1, arg2, arg3) - if ok { + if ok && !isEmptyInterface(pass, survivingArg) { return newUseFalseDiagnostic(survivingArg, expr.Pos(), expr.End()) } } @@ -198,15 +184,8 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. arg1, ok1 := isComparisonWithTrue(pass, expr, token.EQL) arg2, ok2 := isComparisonWithFalse(pass, expr, token.NEQ) - if anyCondSatisfaction(pass, isEmptyInterface, arg1, arg2) { - return nil - } - if anyCondSatisfaction(pass, isBoolOverride, arg1, arg2) { - return nil - } - survivingArg, ok := anyVal([]bool{ok1, ok2}, arg1, arg2) - if ok { + if ok && !isEmptyInterface(pass, survivingArg) { return newNeedSimplifyDiagnostic(survivingArg, expr.Pos(), expr.End()) } } @@ -216,15 +195,8 @@ func (checker BoolCompare) Check(pass *analysis.Pass, call *CallMeta) *analysis. arg2, ok2 := isComparisonWithFalse(pass, expr, token.EQL) arg3, ok3 := isNegation(expr) - if anyCondSatisfaction(pass, isEmptyInterface, arg1, arg2, arg3) { - return nil - } - if anyCondSatisfaction(pass, isBoolOverride, arg1, arg2, arg3) { - return nil - } - survivingArg, ok := anyVal([]bool{ok1, ok2, ok3}, arg1, arg2, arg3) - if ok { + if ok && !isEmptyInterface(pass, survivingArg) { return newUseTrueDiagnostic(survivingArg, expr.Pos(), expr.End()) } } diff --git a/internal/config/config.go b/internal/config/config.go index 9e559e8a..7eba0ea3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -98,7 +98,7 @@ func BindToFlags(cfg *Config, fs *flag.FlagSet) { fs.Var(&cfg.EnabledCheckers, "enable", "comma separated list of enabled checkers (in addition to enabled by default)") fs.BoolVar(&cfg.BoolCompare.IgnoreCustomTypes, "bool-compare.ignore-custom-types", false, - "ignore user defined types (over builtin `bool`)") + "ignore user defined types (over builtin bool)") fs.Var(&cfg.ExpectedActual.ExpVarPattern, "expected-actual.pattern", "regexp for expected variable name") fs.Var(&cfg.RequireError.FnPattern, "require-error.fn-pattern", "regexp for error assertions that should only be analyzed") fs.Var(NewEnumValue(suiteExtraAssertCallModeAsString, &cfg.SuiteExtraAssertCall.Mode),