Skip to content

Commit 546dfc9

Browse files
committed
Add double type-check when a patch is present
1 parent 8ec0158 commit 546dfc9

File tree

3 files changed

+133
-61
lines changed

3 files changed

+133
-61
lines changed

checker/checker.go

+56-58
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,7 @@ import (
1010
"github.com/antonmedv/expr/parser"
1111
)
1212

13-
func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) {
14-
defer func() {
15-
if r := recover(); r != nil {
16-
if h, ok := r.(file.Error); ok {
17-
err = fmt.Errorf("%v", h.Format(tree.Source))
18-
} else {
19-
err = fmt.Errorf("%v", r)
20-
}
21-
}
22-
}()
23-
13+
func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) {
2414
v := &visitor{
2515
collections: make([]reflect.Type, 0),
2616
}
@@ -32,7 +22,7 @@ func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) {
3222
v.defaultType = config.DefaultType
3323
}
3424

35-
t = v.visit(tree.Node)
25+
t := v.visit(tree.Node)
3626

3727
if v.expect != reflect.Invalid {
3828
switch v.expect {
@@ -47,7 +37,11 @@ func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) {
4737
}
4838
}
4939

50-
return
40+
if v.err != nil {
41+
return t, fmt.Errorf("%v", v.err.Format(tree.Source))
42+
}
43+
44+
return t, nil
5145
}
5246

5347
type visitor struct {
@@ -57,6 +51,7 @@ type visitor struct {
5751
collections []reflect.Type
5852
strict bool
5953
defaultType reflect.Type
54+
err *file.Error
6055
}
6156

6257
func (v *visitor) visit(node ast.Node) reflect.Type {
@@ -111,14 +106,17 @@ func (v *visitor) visit(node ast.Node) reflect.Type {
111106
return t
112107
}
113108

114-
func (v *visitor) error(node ast.Node, format string, args ...interface{}) file.Error {
115-
return file.Error{
116-
Location: node.Location(),
117-
Message: fmt.Sprintf(format, args...),
109+
func (v *visitor) error(node ast.Node, format string, args ...interface{}) reflect.Type {
110+
if v.err == nil { // show first error
111+
v.err = &file.Error{
112+
Location: node.Location(),
113+
Message: fmt.Sprintf(format, args...),
114+
}
118115
}
116+
return interfaceType // interface represent undefined type
119117
}
120118

121-
func (v *visitor) NilNode(node *ast.NilNode) reflect.Type {
119+
func (v *visitor) NilNode(*ast.NilNode) reflect.Type {
122120
return nilType
123121
}
124122

@@ -135,22 +133,22 @@ func (v *visitor) IdentifierNode(node *ast.IdentifierNode) reflect.Type {
135133
}
136134
return interfaceType
137135
}
138-
panic(v.error(node, "unknown name %v", node.Value))
136+
return v.error(node, "unknown name %v", node.Value)
139137
}
140138

141-
func (v *visitor) IntegerNode(node *ast.IntegerNode) reflect.Type {
139+
func (v *visitor) IntegerNode(*ast.IntegerNode) reflect.Type {
142140
return integerType
143141
}
144142

145-
func (v *visitor) FloatNode(node *ast.FloatNode) reflect.Type {
143+
func (v *visitor) FloatNode(*ast.FloatNode) reflect.Type {
146144
return floatType
147145
}
148146

149-
func (v *visitor) BoolNode(node *ast.BoolNode) reflect.Type {
147+
func (v *visitor) BoolNode(*ast.BoolNode) reflect.Type {
150148
return boolType
151149
}
152150

153-
func (v *visitor) StringNode(node *ast.StringNode) reflect.Type {
151+
func (v *visitor) StringNode(*ast.StringNode) reflect.Type {
154152
return stringType
155153
}
156154

@@ -170,10 +168,10 @@ func (v *visitor) UnaryNode(node *ast.UnaryNode) reflect.Type {
170168
}
171169

172170
default:
173-
panic(v.error(node, "unknown operator (%v)", node.Operator))
171+
return v.error(node, "unknown operator (%v)", node.Operator)
174172
}
175173

176-
panic(v.error(node, `invalid operation: %v (mismatched type %v)`, node.Operator, t))
174+
return v.error(node, `invalid operation: %v (mismatched type %v)`, node.Operator, t)
177175
}
178176

179177
func (v *visitor) BinaryNode(node *ast.BinaryNode) reflect.Type {
@@ -255,11 +253,11 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) reflect.Type {
255253
}
256254

257255
default:
258-
panic(v.error(node, "unknown operator (%v)", node.Operator))
256+
return v.error(node, "unknown operator (%v)", node.Operator)
259257

260258
}
261259

262-
panic(v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r))
260+
return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r)
263261
}
264262

265263
func (v *visitor) MatchesNode(node *ast.MatchesNode) reflect.Type {
@@ -270,7 +268,7 @@ func (v *visitor) MatchesNode(node *ast.MatchesNode) reflect.Type {
270268
return boolType
271269
}
272270

273-
panic(v.error(node, `invalid operation: matches (mismatched types %v and %v)`, l, r))
271+
return v.error(node, `invalid operation: matches (mismatched types %v and %v)`, l, r)
274272
}
275273

276274
func (v *visitor) PropertyNode(node *ast.PropertyNode) reflect.Type {
@@ -280,7 +278,7 @@ func (v *visitor) PropertyNode(node *ast.PropertyNode) reflect.Type {
280278
return t
281279
}
282280

283-
panic(v.error(node, "type %v has no field %v", t, node.Property))
281+
return v.error(node, "type %v has no field %v", t, node.Property)
284282
}
285283

286284
func (v *visitor) IndexNode(node *ast.IndexNode) reflect.Type {
@@ -289,12 +287,12 @@ func (v *visitor) IndexNode(node *ast.IndexNode) reflect.Type {
289287

290288
if t, ok := indexType(t); ok {
291289
if !isInteger(i) && !isString(i) {
292-
panic(v.error(node, "invalid operation: cannot use %v as index to %v", i, t))
290+
return v.error(node, "invalid operation: cannot use %v as index to %v", i, t)
293291
}
294292
return t
295293
}
296294

297-
panic(v.error(node, "invalid operation: type %v does not support indexing", t))
295+
return v.error(node, "invalid operation: type %v does not support indexing", t)
298296
}
299297

300298
func (v *visitor) SliceNode(node *ast.SliceNode) reflect.Type {
@@ -306,19 +304,19 @@ func (v *visitor) SliceNode(node *ast.SliceNode) reflect.Type {
306304
if node.From != nil {
307305
from := v.visit(node.From)
308306
if !isInteger(from) {
309-
panic(v.error(node.From, "invalid operation: non-integer slice index %v", from))
307+
return v.error(node.From, "invalid operation: non-integer slice index %v", from)
310308
}
311309
}
312310
if node.To != nil {
313311
to := v.visit(node.To)
314312
if !isInteger(to) {
315-
panic(v.error(node.To, "invalid operation: non-integer slice index %v", to))
313+
return v.error(node.To, "invalid operation: non-integer slice index %v", to)
316314
}
317315
}
318316
return t
319317
}
320318

321-
panic(v.error(node, "invalid operation: cannot slice %v", t))
319+
return v.error(node, "invalid operation: cannot slice %v", t)
322320
}
323321

324322
func (v *visitor) FunctionNode(node *ast.FunctionNode) reflect.Type {
@@ -349,7 +347,7 @@ func (v *visitor) FunctionNode(node *ast.FunctionNode) reflect.Type {
349347
}
350348
return interfaceType
351349
}
352-
panic(v.error(node, "unknown func %v", node.Name))
350+
return v.error(node, "unknown func %v", node.Name)
353351
}
354352

355353
func (v *visitor) MethodNode(node *ast.MethodNode) reflect.Type {
@@ -359,7 +357,7 @@ func (v *visitor) MethodNode(node *ast.MethodNode) reflect.Type {
359357
return v.checkFunc(fn, method, node, node.Method, node.Arguments)
360358
}
361359
}
362-
panic(v.error(node, "type %v has no method %v", t, node.Method))
360+
return v.error(node, "type %v has no method %v", t, node.Method)
363361
}
364362

365363
// checkFunc checks func arguments and returns "return type" of func or method.
@@ -369,10 +367,10 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name st
369367
}
370368

371369
if fn.NumOut() == 0 {
372-
panic(v.error(node, "func %v doesn't return value", name))
370+
return v.error(node, "func %v doesn't return value", name)
373371
}
374372
if fn.NumOut() != 1 {
375-
panic(v.error(node, "func %v returns more then one value", name))
373+
return v.error(node, "func %v returns more then one value", name)
376374
}
377375

378376
numIn := fn.NumIn()
@@ -385,14 +383,14 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name st
385383

386384
if fn.IsVariadic() {
387385
if len(arguments) < numIn-1 {
388-
panic(v.error(node, "not enough arguments to call %v", name))
386+
return v.error(node, "not enough arguments to call %v", name)
389387
}
390388
} else {
391389
if len(arguments) > numIn {
392-
panic(v.error(node, "too many arguments to call %v", name))
390+
return v.error(node, "too many arguments to call %v", name)
393391
}
394392
if len(arguments) < numIn {
395-
panic(v.error(node, "not enough arguments to call %v", name))
393+
return v.error(node, "not enough arguments to call %v", name)
396394
}
397395
}
398396

@@ -426,7 +424,7 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name st
426424
}
427425

428426
if !t.AssignableTo(in) {
429-
panic(v.error(arg, "cannot use %v as argument (type %v) to call %v ", t, in, name))
427+
return v.error(arg, "cannot use %v as argument (type %v) to call %v ", t, in, name)
430428
}
431429
}
432430

@@ -441,12 +439,12 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type {
441439
if isArray(param) || isMap(param) || isString(param) {
442440
return integerType
443441
}
444-
panic(v.error(node, "invalid argument for len (type %v)", param))
442+
return v.error(node, "invalid argument for len (type %v)", param)
445443

446444
case "all", "none", "any", "one":
447445
collection := v.visit(node.Arguments[0])
448446
if !isArray(collection) {
449-
panic(v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection))
447+
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
450448
}
451449

452450
v.collections = append(v.collections, collection)
@@ -458,16 +456,16 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type {
458456
closure.NumIn() == 1 && isInterface(closure.In(0)) {
459457

460458
if !isBool(closure.Out(0)) {
461-
panic(v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String()))
459+
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
462460
}
463461
return boolType
464462
}
465-
panic(v.error(node.Arguments[1], "closure should has one input and one output param"))
463+
return v.error(node.Arguments[1], "closure should has one input and one output param")
466464

467465
case "filter":
468466
collection := v.visit(node.Arguments[0])
469467
if !isArray(collection) {
470-
panic(v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection))
468+
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
471469
}
472470

473471
v.collections = append(v.collections, collection)
@@ -479,16 +477,16 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type {
479477
closure.NumIn() == 1 && isInterface(closure.In(0)) {
480478

481479
if !isBool(closure.Out(0)) {
482-
panic(v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String()))
480+
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
483481
}
484482
return arrayType
485483
}
486-
panic(v.error(node.Arguments[1], "closure should has one input and one output param"))
484+
return v.error(node.Arguments[1], "closure should has one input and one output param")
487485

488486
case "map":
489487
collection := v.visit(node.Arguments[0])
490488
if !isArray(collection) {
491-
panic(v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection))
489+
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
492490
}
493491

494492
v.collections = append(v.collections, collection)
@@ -501,12 +499,12 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type {
501499

502500
return reflect.SliceOf(closure.Out(0))
503501
}
504-
panic(v.error(node.Arguments[1], "closure should has one input and one output param"))
502+
return v.error(node.Arguments[1], "closure should has one input and one output param")
505503

506504
case "count":
507505
collection := v.visit(node.Arguments[0])
508506
if !isArray(collection) {
509-
panic(v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection))
507+
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
510508
}
511509

512510
v.collections = append(v.collections, collection)
@@ -517,15 +515,15 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type {
517515
closure.NumOut() == 1 &&
518516
closure.NumIn() == 1 && isInterface(closure.In(0)) {
519517
if !isBool(closure.Out(0)) {
520-
panic(v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String()))
518+
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
521519
}
522520

523521
return integerType
524522
}
525-
panic(v.error(node.Arguments[1], "closure should has one input and one output param"))
523+
return v.error(node.Arguments[1], "closure should has one input and one output param")
526524

527525
default:
528-
panic(v.error(node, "unknown builtin %v", node.Name))
526+
return v.error(node, "unknown builtin %v", node.Name)
529527
}
530528
}
531529

@@ -536,21 +534,21 @@ func (v *visitor) ClosureNode(node *ast.ClosureNode) reflect.Type {
536534

537535
func (v *visitor) PointerNode(node *ast.PointerNode) reflect.Type {
538536
if len(v.collections) == 0 {
539-
panic(v.error(node, "cannot use pointer accessor outside closure"))
537+
return v.error(node, "cannot use pointer accessor outside closure")
540538
}
541539

542540
collection := v.collections[len(v.collections)-1]
543541

544542
if t, ok := indexType(collection); ok {
545543
return t
546544
}
547-
panic(v.error(node, "cannot use %v as array", collection))
545+
return v.error(node, "cannot use %v as array", collection)
548546
}
549547

550548
func (v *visitor) ConditionalNode(node *ast.ConditionalNode) reflect.Type {
551549
c := v.visit(node.Cond)
552550
if !isBool(c) {
553-
panic(v.error(node.Cond, "non-bool expression (type %v) used as condition", c))
551+
return v.error(node.Cond, "non-bool expression (type %v) used as condition", c)
554552
}
555553

556554
t1 := v.visit(node.Exp1)

expr.go

+12-3
Original file line numberDiff line numberDiff line change
@@ -143,15 +143,24 @@ func Compile(input string, ops ...Option) (*vm.Program, error) {
143143
}
144144

145145
_, err = checker.Check(tree, config)
146-
if err != nil {
146+
147+
// If we have a patch to apply, it may fix out error and
148+
// second type check is needed. Otherwise it is an error.
149+
if err != nil && len(config.Visitors) == 0 {
147150
return nil, err
148151
}
149152

150153
// Patch operators before Optimize, as we may also mark it as ConstExpr.
151154
compiler.PatchOperators(&tree.Node, config)
152155

153-
for _, v := range config.Visitors {
154-
ast.Walk(&tree.Node, v)
156+
if len(config.Visitors) >= 0 {
157+
for _, v := range config.Visitors {
158+
ast.Walk(&tree.Node, v)
159+
}
160+
_, err = checker.Check(tree, config)
161+
if err != nil {
162+
return nil, err
163+
}
155164
}
156165

157166
if config.Optimize {

0 commit comments

Comments
 (0)