Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use go-cmp instead of reflect.DeepEqual #99

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions gotests.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ import (
"github.com/cweill/gotests/internal/output"
)

var (
cmpImport = &models.Import{
Name: "",
Path: `"github.com/google/go-cmp/cmp"`,
}
)

// Options provides custom filters and parameters for generating tests.
type Options struct {
Only *regexp.Regexp // Includes only functions that match.
Expand Down Expand Up @@ -142,8 +149,20 @@ func parseTestFile(p *goparser.Parser, testPath string, h *models.Header) (*mode
return nil, nil, fmt.Errorf("Parser.Parse test file: %v", err)
}
var testFuncs []string
cmpImportNeeded := false
for _, fun := range tr.Funcs {
testFuncs = append(testFuncs, fun.Name)
if cmpImportNeeded {
continue
}
for _, field := range fun.Parameters {
if !(field.IsWriter() || field.IsBasicType()) {
cmpImportNeeded = true
}
}
}
if cmpImportNeeded {
tr.Header.Imports = append(tr.Header.Imports, cmpImport)
}
tr.Header.Imports = append(tr.Header.Imports, h.Imports...)
h = tr.Header
Expand Down
3 changes: 2 additions & 1 deletion gotests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"testing"
"unicode"

"github.com/google/go-cmp/cmp"
"golang.org/x/tools/imports"
)

Expand Down Expand Up @@ -643,7 +644,7 @@ func TestGenerateTests(t *testing.T) {
continue
}
if got := string(gts[0].Output); got != tt.want {
t.Errorf("%q. GenerateTests(%v) = \n%v, want \n%v", tt.name, tt.args.srcPath, got, tt.want)
t.Errorf("%q. GenerateTests(%v) = diff=%v", tt.name, tt.args.srcPath, cmp.Diff(got, tt.want))
outputResult(t, tmp, tt.name, gts[0].Output)
}
}
Expand Down
43 changes: 22 additions & 21 deletions internal/render/bindata/esc.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions internal/render/templates/function.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ func {{.TestName}}(t *testing.T) {
{{- else if .IsBasicType}}
if {{if $f.OnlyReturnsOneValue}}{{Got .}} := {{template "inline" $f}}; {{end}} {{Got .}} != tt.{{Want .}} {
{{- else}}
if {{if $f.OnlyReturnsOneValue}}{{Got .}} := {{template "inline" $f}}; {{end}} !reflect.DeepEqual({{Got .}}, tt.{{Want .}}) {
if {{if $f.OnlyReturnsOneValue}}{{Got .}} := {{template "inline" $f}}; {{end}} !cmp.Equal({{Got .}}, tt.{{Want .}}) {
{{- end}}
t.Errorf("{{template "message" $f}} {{if $f.ReturnsMultiple}}{{Got .}} {{end}}= %v, want %v", {{template "inputs" $f}} {{Got .}}, tt.{{Want .}})
t.Errorf("{{template "message" $f}} {{if $f.ReturnsMultiple}}{{Got .}} {{end}}= %v, want %v{{ if (not ( or .IsWriter .IsBasicType))}}\ndiff=%v{{ end }}", {{template "inputs" $f}} {{Got .}}, tt.{{Want .}}
{{- if (not ( or .IsWriter .IsBasicType))}}, cmp.Diff({{Got .}}, tt.{{Want .}}){{ end }})
}
{{- end}}
{{- if .Subtests }} }) {{- end -}}
Expand Down
13 changes: 7 additions & 6 deletions testdata/goldens/custom_importer_fails.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package testdata

import (
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestFooFilter(t *testing.T) {
Expand All @@ -15,16 +16,16 @@ func TestFooFilter(t *testing.T) {
want []*Bar
wantErr bool
}{
// TODO: Add test cases.
// TODO: Add test cases.
}
for _, tt := range tests {
got, err := FooFilter(tt.args.strs)
if (err != nil) != tt.wantErr {
t.Errorf("%q. FooFilter() error = %v, wantErr %v", tt.name, err, tt.wantErr)
continue
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q. FooFilter() = %v, want %v", tt.name, got, tt.want)
if !cmp.Equal(got, tt.want) {
t.Errorf("%q. FooFilter() = %v, want %v\ndiff=%v", tt.name, got, tt.want, cmp.Diff(got, tt.want))
}
}
}
Expand All @@ -39,7 +40,7 @@ func TestBar_BarFilter(t *testing.T) {
args args
wantErr bool
}{
// TODO: Add test cases.
// TODO: Add test cases.
}
for _, tt := range tests {
b := &Bar{}
Expand All @@ -58,7 +59,7 @@ func Test_bazFilter(t *testing.T) {
args args
want float64
}{
// TODO: Add test cases.
// TODO: Add test cases.
}
for _, tt := range tests {
if got := bazFilter(tt.args.f); got != tt.want {
Expand Down
13 changes: 7 additions & 6 deletions testdata/goldens/existing_test_file.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package testdata

import (
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestBarBar100(t *testing.T) {
Expand Down Expand Up @@ -55,16 +56,16 @@ func TestFoo100(t *testing.T) {
want []*Bar
wantErr bool
}{
// TODO: Add test cases.
// TODO: Add test cases.
}
for _, tt := range tests {
got, err := Foo100(tt.args.strs)
if (err != nil) != tt.wantErr {
t.Errorf("%q. Foo100() error = %v, wantErr %v", tt.name, err, tt.wantErr)
continue
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q. Foo100() = %v, want %v", tt.name, got, tt.want)
if !cmp.Equal(got, tt.want) {
t.Errorf("%q. Foo100() = %v, want %v\ndiff=%v", tt.name, got, tt.want, cmp.Diff(got, tt.want))
}
}
}
Expand All @@ -79,7 +80,7 @@ func TestBar_Bar100(t *testing.T) {
args args
wantErr bool
}{
// TODO: Add test cases.
// TODO: Add test cases.
}
for _, tt := range tests {
b := &Bar{}
Expand All @@ -98,7 +99,7 @@ func Test_baz100(t *testing.T) {
args args
want float64
}{
// TODO: Add test cases.
// TODO: Add test cases.
}
for _, tt := range tests {
if got := baz100(tt.args.f); got != tt.want {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package testdata

import (
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestFoo25(t *testing.T) {
Expand All @@ -16,7 +17,7 @@ func TestFoo25(t *testing.T) {
want1 []byte
wantErr bool
}{
// TODO: Add test cases.
// TODO: Add test cases.
}
for _, tt := range tests {
got, got1, err := Foo25(tt.args.in0)
Expand All @@ -27,8 +28,8 @@ func TestFoo25(t *testing.T) {
if got != tt.want {
t.Errorf("%q. Foo25() got = %v, want %v", tt.name, got, tt.want)
}
if !reflect.DeepEqual(got1, tt.want1) {
t.Errorf("%q. Foo25() got1 = %v, want %v", tt.name, got1, tt.want1)
if !cmp.Equal(got1, tt.want1) {
t.Errorf("%q. Foo25() got1 = %v, want %v\ndiff=%v", tt.name, got1, tt.want1, cmp.Diff(got1, tt.want1))
}
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package testdata

import (
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestFoo23(t *testing.T) {
Expand All @@ -14,11 +15,11 @@ func TestFoo23(t *testing.T) {
args args
want chan string
}{
// TODO: Add test cases.
// TODO: Add test cases.
}
for _, tt := range tests {
if got := Foo23(tt.args.ch); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q. Foo23() = %v, want %v", tt.name, got, tt.want)
if got := Foo23(tt.args.ch); !cmp.Equal(got, tt.want) {
t.Errorf("%q. Foo23() = %v, want %v\ndiff=%v", tt.name, got, tt.want, cmp.Diff(got, tt.want))
}
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package testdata

import (
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestFoo16(t *testing.T) {
Expand All @@ -14,11 +15,11 @@ func TestFoo16(t *testing.T) {
args args
want Bazzar
}{
// TODO: Add test cases.
// TODO: Add test cases.
}
for _, tt := range tests {
if got := Foo16(tt.args.in); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q. Foo16() = %v, want %v", tt.name, got, tt.want)
if got := Foo16(tt.args.in); !cmp.Equal(got, tt.want) {
t.Errorf("%q. Foo16() = %v, want %v\ndiff=%v", tt.name, got, tt.want, cmp.Diff(got, tt.want))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package testdata

import (
"io"
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestFoo17(t *testing.T) {
Expand All @@ -15,11 +16,11 @@ func TestFoo17(t *testing.T) {
args args
want io.Reader
}{
// TODO: Add test cases.
// TODO: Add test cases.
}
for _, tt := range tests {
if got := Foo17(tt.args.r); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q. Foo17() = %v, want %v", tt.name, got, tt.want)
if got := Foo17(tt.args.r); !cmp.Equal(got, tt.want) {
t.Errorf("%q. Foo17() = %v, want %v\ndiff=%v", tt.name, got, tt.want, cmp.Diff(got, tt.want))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package testdata

import (
"os"
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestFoo18(t *testing.T) {
Expand All @@ -15,11 +16,11 @@ func TestFoo18(t *testing.T) {
args args
want *os.File
}{
// TODO: Add test cases.
// TODO: Add test cases.
}
for _, tt := range tests {
if got := Foo18(tt.args.t); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q. Foo18() = %v, want %v", tt.name, got, tt.want)
if got := Foo18(tt.args.t); !cmp.Equal(got, tt.want) {
t.Errorf("%q. Foo18() = %v, want %v\ndiff=%v", tt.name, got, tt.want, cmp.Diff(got, tt.want))
}
}
}
Loading