diff --git a/check_test.go b/check_test.go index 1261998..a9c7712 100644 --- a/check_test.go +++ b/check_test.go @@ -167,6 +167,50 @@ func main() { assert.Len(t, errs, 0) } +// TestNoMissingInterface tests that we correctly detect exhaustive case when +// there is an interface which implements the interface we are going to check. +func TestNoMissingInterface(t *testing.T) { + code := ` +package main + +//go-sumtype:decl T + +type T interface { + sealedT() +} + +type A struct {} +func (a *A) sealedT() {} + +type U interface { + T + sealedU() +} + +type B struct {} +func (b *B) sealedT() {} +func (b *B) sealedU() {} + +type C struct {} +func (c *C) sealedT() {} +func (c *C) sealedU() {} + +func main() { + switch T(nil).(type) { + case *A, *B, *C: + } + switch T(nil).(type) { + case *A, U: + } +} +` + tmpdir, pkgs := setupPackages(t, code) + defer teardownPackage(t, tmpdir) + + errs := run(pkgs) + assert.Len(t, errs, 0) +} + // TestNotSealed tests that we report an error if one tries to declare a sum // type with an unsealed interface. func TestNotSealed(t *testing.T) { diff --git a/def.go b/def.go index 01067f5..62ff7ce 100644 --- a/def.go +++ b/def.go @@ -107,6 +107,10 @@ func newSumTypeDef(pkg *types.Package, decl sumTypeDecl) (*sumTypeDef, error) { if types.Identical(ty.Underlying(), iface) { continue } + _, ok = ty.Underlying().(*types.Interface) + if ok { + continue + } if types.Implements(ty, iface) || types.Implements(types.NewPointer(ty), iface) { def.Variants = append(def.Variants, obj) } @@ -131,6 +135,11 @@ func (def *sumTypeDef) missing(tys []types.Type) []types.Object { if types.Identical(varty, ty) { found = true } + if iface, ok := ty.Underlying().(*types.Interface); ok { + if types.Implements(varty, iface) || types.Implements(types.NewPointer(varty), iface) { + found = true + } + } } if !found { missing = append(missing, v)