Skip to content

Commit

Permalink
Add more type checks
Browse files Browse the repository at this point in the history
  • Loading branch information
antonmedv committed Aug 9, 2018
1 parent aae3b0f commit 653a058
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 57 deletions.
16 changes: 16 additions & 0 deletions eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,22 @@ var evalTests = []evalTest{
}{1, 1},
true,
},
{
`A == B`,
struct {
A float64
B interface{}
}{1, new(interface{})},
false,
},
{
`A == B`,
struct {
A interface{}
B float64
}{new(interface{}), 1},
false,
},
{
`[true][A]`,
&struct{ A int }{0},
Expand Down
129 changes: 91 additions & 38 deletions type.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (n unaryNode) Type(table typesTable) (Type, error) {

switch n.operator {
case "!", "not":
if isBoolType(ntype) {
if isBoolType(ntype) || isInterfaceType(ntype) {
return boolType, nil
}
return nil, fmt.Errorf(`invalid operation: %v (mismatched type %v)`, n, ntype)
Expand All @@ -80,8 +80,15 @@ func (n binaryNode) Type(table typesTable) (Type, error) {
return boolType, nil
}
return nil, fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, n, ltype, rtype)

case "or", "||", "and", "&&":
if isBoolType(ltype) && isBoolType(rtype) {
if (isBoolType(ltype) || isInterfaceType(ltype)) && (isBoolType(rtype) || isInterfaceType(rtype)) {
return boolType, nil
}
return nil, fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, n, ltype, rtype)

case "|", "^", "&", "<", ">", ">=", "<=", "+", "-", "*", "/", "%", "**", "..":
if (isNumberType(ltype) || isInterfaceType(ltype)) && (isNumberType(rtype) || isInterfaceType(rtype)) {
return boolType, nil
}
return nil, fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, n, ltype, rtype)
Expand All @@ -91,7 +98,19 @@ func (n binaryNode) Type(table typesTable) (Type, error) {
}

func (n matchesNode) Type(table typesTable) (Type, error) {
return boolType, nil
var err error
ltype, err := n.left.Type(table)
if err != nil {
return nil, err
}
rtype, err := n.right.Type(table)
if err != nil {
return nil, err
}
if (isStringType(ltype) || isInterfaceType(ltype)) && (isStringType(rtype) || isInterfaceType(rtype)) {
return boolType, nil
}
return nil, fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, n, ltype, rtype)
}

func (n propertyNode) Type(table typesTable) (Type, error) {
Expand Down Expand Up @@ -141,8 +160,14 @@ func (n methodNode) Type(table typesTable) (Type, error) {
}

func (n builtinNode) Type(table typesTable) (Type, error) {
for _, node := range n.arguments {
_, err := node.Type(table)
if err != nil {
return nil, err
}
}
if _, ok := builtins[n.name]; ok {
return nil, nil
return interfaceType, nil
}
return nil, fmt.Errorf("%v undefined", n)
}
Expand All @@ -167,7 +192,7 @@ func (n conditionalNode) Type(table typesTable) (Type, error) {
if err != nil {
return nil, err
}
if !isBoolType(ctype) {
if !isBoolType(ctype) && !isInterfaceType(ctype) {
return nil, fmt.Errorf("non-bool %v (type %v) used as condition", n.cond, ctype)
}
_, err = n.exp1.Type(table)
Expand Down Expand Up @@ -216,60 +241,88 @@ func (n pairNode) Type(table typesTable) (Type, error) {

// helper funcs for reflect

func isComparable(ltype Type, rtype Type) bool {
ltype = dereference(ltype)
if ltype == nil {
return true
}
rtype = dereference(rtype)
if rtype == nil {
return true
func isComparable(l Type, r Type) bool {
l = dereference(l)
r = dereference(r)

if l == nil || r == nil {
return true // It is possible to compare with nil.
}

if canBeNumberType(ltype) && canBeNumberType(rtype) {
if isNumberType(l) && isNumberType(r) {
return true
} else if ltype.Kind() == reflect.Interface {
} else if l.Kind() == reflect.Interface {
return true
} else if rtype.Kind() == reflect.Interface {
} else if r.Kind() == reflect.Interface {
return true
} else if ltype == rtype {
} else if l == r {
return true
}
return false
}

func isBoolType(ntype Type) bool {
ntype = dereference(ntype)
if ntype == nil {
return false
func isInterfaceType(t Type) bool {
t = dereference(t)
if t != nil {
switch t.Kind() {
case reflect.Interface:
return true
}
}
return false
}

switch ntype.Kind() {
case reflect.Interface:
return true
case reflect.Bool:
return true
func isNumberType(t Type) bool {
t = dereference(t)
if t != nil {
switch t.Kind() {
case reflect.Float32, reflect.Float64:
fallthrough
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fallthrough
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
}
}
return false
}

func fieldType(ntype Type, name string) (Type, bool) {
ntype = dereference(ntype)
if ntype == nil {
return nil, false
func isBoolType(t Type) bool {
t = dereference(t)
if t != nil {
switch t.Kind() {
case reflect.Bool:
return true
}
}
return false
}

switch ntype.Kind() {
case reflect.Interface:
return interfaceType, true
case reflect.Struct:
if t, ok := ntype.FieldByName(name); ok {
return t.Type, true
func isStringType(t Type) bool {
t = dereference(t)
if t != nil {
switch t.Kind() {
case reflect.String:
return true
}
case reflect.Map:
return ntype.Elem(), true
}
return false
}

func fieldType(ntype Type, name string) (Type, bool) {
ntype = dereference(ntype)
if ntype != nil {
switch ntype.Kind() {
case reflect.Interface:
return interfaceType, true
case reflect.Struct:
if t, ok := ntype.FieldByName(name); ok {
return t.Type, true
}
case reflect.Map:
return ntype.Elem(), true
}
}
return nil, false
}

Expand Down
100 changes: 96 additions & 4 deletions type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ var typeTests = []typeTest{
"Fn(Any)",
"Foo.Fn()",
"true ? Any : Any",
"Ok && Any",
"Str matches 'ok'",
"Str matches Any",
"Any matches Any",
"len([])",
"true == false",
"nil",
Expand All @@ -39,6 +43,7 @@ var typeTests = []typeTest{
"Num == Abc",
"Abc == Num",
"1 == 2 and true or Ok",
"Int == Any",
"IntPtr == Int",
"!OkPtr == Ok",
"1 == NumPtr",
Expand All @@ -47,6 +52,20 @@ var typeTests = []typeTest{
"nil == nil",
"nil == IntPtr",
"Foo2p.Bar.Baz",
"Int | Num",
"Int ^ Num",
"Int & Num",
"Int < Num",
"Int > Num",
"Int >= Num",
"Int <= Num",
"Int + Num",
"Int - Num",
"Int * Num",
"Int / Num",
"Int % Num",
"Int ** Num",
"Int .. Num",
}

var typeErrorTests = []typeErrorTest{
Expand Down Expand Up @@ -110,6 +129,10 @@ var typeErrorTests = []typeErrorTest{
"Map['str'].Not",
`Map["str"].Not undefined (type *expr_test.foo has no field Not)`,
},
{
"Ok && IntPtr",
"invalid operation: (Ok && IntPtr) (mismatched types bool and *int)",
},
{
"No ? Any.Ok : Any.Not",
"unknown name No",
Expand All @@ -123,8 +146,16 @@ var typeErrorTests = []typeErrorTest{
"unknown name No",
},
{
"Any ? Any : Any",
"non-bool Any (type map[string]interface {}) used as condition",
"Many ? Any : Any",
"non-bool Many (type map[string]interface {}) used as condition",
},
{
"Str matches Int",
"invalid operation: (Str matches Int) (mismatched types string and int)",
},
{
"Int matches Str",
"invalid operation: (Int matches Str) (mismatched types int and string)",
},
{
"!Not",
Expand Down Expand Up @@ -166,6 +197,66 @@ var typeErrorTests = []typeErrorTest{
"not IntPtr",
"invalid operation: not IntPtr (mismatched type *int)",
},
{
"len(Not)",
"unknown name Not",
},
{
"Int | Ok",
"invalid operation: (Int | Ok) (mismatched types int and bool)",
},
{
"Int ^ Ok",
"invalid operation: (Int ^ Ok) (mismatched types int and bool)",
},
{
"Int & Ok",
"invalid operation: (Int & Ok) (mismatched types int and bool)",
},
{
"Int < Ok",
"invalid operation: (Int < Ok) (mismatched types int and bool)",
},
{
"Int > Ok",
"invalid operation: (Int > Ok) (mismatched types int and bool)",
},
{
"Int >= Ok",
"invalid operation: (Int >= Ok) (mismatched types int and bool)",
},
{
"Int <= Ok",
"invalid operation: (Int <= Ok) (mismatched types int and bool)",
},
{
"Int + Ok",
"invalid operation: (Int + Ok) (mismatched types int and bool)",
},
{
"Int - Ok",
"invalid operation: (Int - Ok) (mismatched types int and bool)",
},
{
"Int * Ok",
"invalid operation: (Int * Ok) (mismatched types int and bool)",
},
{
"Int / Ok",
"invalid operation: (Int / Ok) (mismatched types int and bool)",
},
{
"Int % Ok",
"invalid operation: (Int % Ok) (mismatched types int and bool)",
},
{
"Int ** Ok",
"invalid operation: (Int ** Ok) (mismatched types int and bool)",
},
{
"Int .. Ok",
"invalid operation: (Int .. Ok) (mismatched types int and bool)",
},
}

type abc interface {
Expand All @@ -183,9 +274,10 @@ type payload struct {
Abc abc
Foo *foo
Arr []*foo
Irr []interface{}
Map map[string]*foo
Any map[string]interface{}
Any interface{}
Irr []interface{}
Many map[string]interface{}
Fn func()
Ok bool
Num float64
Expand Down
16 changes: 1 addition & 15 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,7 @@ func cast(v interface{}) (float64, error) {

func canBeNumber(v interface{}) bool {
if v != nil {
return canBeNumberType(reflect.TypeOf(v))
}
return false
}

func canBeNumberType(t Type) bool {
if t != nil {
switch t.Kind() {
case reflect.Float32, reflect.Float64:
fallthrough
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fallthrough
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
}
return isNumberType(reflect.TypeOf(v))
}
return false
}
Expand Down

0 comments on commit 653a058

Please sign in to comment.