Skip to content

Commit

Permalink
lang: unification: Improve type unification algorithm
Browse files Browse the repository at this point in the history
The simple type unification algorithm suffered from some serious
performance and memory problems when used with certain code bases. This
adds some crucial optimizations that improve performance drastically.
  • Loading branch information
purpleidea committed Apr 24, 2019
1 parent 97d60ac commit d70bbfb
Show file tree
Hide file tree
Showing 17 changed files with 770 additions and 27 deletions.
1 change: 0 additions & 1 deletion examples/lang/states0.mcl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import "world"

$ns = "estate"
$exchanged = world.kvlookup($ns)

$state = maplookup($exchanged, $hostname, "default")

if $state == "one" || $state == "default" {
Expand Down
8 changes: 7 additions & 1 deletion lang/gapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,13 @@ func (obj *GAPI) Cli(cliInfo *gapi.CliInfo) (*gapi.Deploy, error) {
}
}
logf("running type unification...")
if err := unification.Unify(interpolated, unification.SimpleInvariantSolverLogger(unificationLogf)); err != nil {
unifier := &unification.Unifier{
AST: interpolated,
Solver: unification.SimpleInvariantSolverLogger(unificationLogf),
Debug: debug,
Logf: unificationLogf,
}
if err := unifier.Unify(); err != nil {
return nil, errwrap.Wrapf(err, "could not unify types")
}

Expand Down
1 change: 1 addition & 0 deletions lang/interfaces/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
// often since we usually know which kind of node we want.
type Node interface {
// Apply accepts a callback fn, which takes a Node and returns an error,
// and itself returns an error.
Apply(fn func(Node) error) error
//Parent() Node // TODO: should we implement this?
}

// Stmt represents a statement node in the language. A stmt could be a resource,
Expand Down
9 changes: 9 additions & 0 deletions lang/interfaces/unification.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package interfaces

import (
"fmt"

"github.com/purpleidea/mgmt/lang/types"
)

// Invariant represents a constraint that is described by the Expr's and Stmt's,
Expand All @@ -27,4 +29,11 @@ import (
type Invariant interface {
// TODO: should we add any other methods to this type?
// Embedding fmt.Stringer requires every invariant to provide a String()
// method.
fmt.Stringer

// ExprList returns the list of valid expressions in this invariant.
ExprList() []Expr

// Matches returns whether an invariant matches the existing solution.
// If it is inconsistent, then it errors.
// The solved argument maps each Expr to a *types.Type.
Matches(solved map[Expr]*types.Type) (bool, error)
}
24 changes: 21 additions & 3 deletions lang/interpret_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,13 @@ func TestAstFunc0(t *testing.T) {
logf := func(format string, v ...interface{}) {
t.Logf(fmt.Sprintf("test #%d", index)+": unification: "+format, v...)
}
err = unification.Unify(iast, unification.SimpleInvariantSolverLogger(logf))
unifier := &unification.Unifier{
AST: iast,
Solver: unification.SimpleInvariantSolverLogger(logf),
Debug: testing.Verbose(),
Logf: logf,
}
err = unifier.Unify()
if !fail && err != nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: could not unify types: %+v", index, err)
Expand Down Expand Up @@ -822,7 +828,13 @@ func TestAstFunc1(t *testing.T) {
xlogf := func(format string, v ...interface{}) {
logf("unification: "+format, v...)
}
err = unification.Unify(iast, unification.SimpleInvariantSolverLogger(xlogf))
unifier := &unification.Unifier{
AST: iast,
Solver: unification.SimpleInvariantSolverLogger(xlogf),
Debug: testing.Verbose(),
Logf: xlogf,
}
err = unifier.Unify()
if !fail && err != nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: could not unify types: %+v", index, err)
Expand Down Expand Up @@ -1216,7 +1228,13 @@ func TestAstFunc2(t *testing.T) {
xlogf := func(format string, v ...interface{}) {
logf("unification: "+format, v...)
}
err = unification.Unify(iast, unification.SimpleInvariantSolverLogger(xlogf))
unifier := &unification.Unifier{
AST: iast,
Solver: unification.SimpleInvariantSolverLogger(xlogf),
Debug: testing.Verbose(),
Logf: xlogf,
}
err = unifier.Unify()
if !fail && err != nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: could not unify types: %+v", index, err)
Expand Down
11 changes: 11 additions & 0 deletions lang/interpret_test/TestAstFunc1/doubleinclude.graph
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Edge: str("hey") -> var(foo) # foo
Edge: str("hey") -> var(foo) # foo
Edge: str("t1") -> var(a) # a
Edge: str("t2") -> var(a) # a
Vertex: str("hey")
Vertex: str("t1")
Vertex: str("t2")
Vertex: var(a)
Vertex: var(a)
Vertex: var(foo)
Vertex: var(foo)
8 changes: 8 additions & 0 deletions lang/interpret_test/TestAstFunc1/doubleinclude/main.mcl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
include c1("t1")
include c1("t2")
class c1($a) {
test $a {
stringptr => $foo,
}
}
$foo = "hey"
32 changes: 32 additions & 0 deletions lang/interpret_test/TestAstFunc1/polydoubleinclude.graph
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
Edge: call:len(var(b)) -> call:fmt.printf(str("len is: %d"), call:len(var(b))) # b
Edge: call:len(var(b)) -> call:fmt.printf(str("len is: %d"), call:len(var(b))) # b
Edge: int(-37) -> list(int(13), int(42), int(0), int(-37)) # 3
Edge: int(0) -> list(int(13), int(42), int(0), int(-37)) # 2
Edge: int(13) -> list(int(13), int(42), int(0), int(-37)) # 0
Edge: int(42) -> list(int(13), int(42), int(0), int(-37)) # 1
Edge: list(int(13), int(42), int(0), int(-37)) -> var(b) # b
Edge: str("hello") -> var(b) # b
Edge: str("len is: %d") -> call:fmt.printf(str("len is: %d"), call:len(var(b))) # a
Edge: str("len is: %d") -> call:fmt.printf(str("len is: %d"), call:len(var(b))) # a
Edge: str("t1") -> var(a) # a
Edge: str("t2") -> var(a) # a
Edge: var(b) -> call:len(var(b)) # 0
Edge: var(b) -> call:len(var(b)) # 0
Vertex: call:fmt.printf(str("len is: %d"), call:len(var(b)))
Vertex: call:fmt.printf(str("len is: %d"), call:len(var(b)))
Vertex: call:len(var(b))
Vertex: call:len(var(b))
Vertex: int(-37)
Vertex: int(0)
Vertex: int(13)
Vertex: int(42)
Vertex: list(int(13), int(42), int(0), int(-37))
Vertex: str("hello")
Vertex: str("len is: %d")
Vertex: str("len is: %d")
Vertex: str("t1")
Vertex: str("t2")
Vertex: var(a)
Vertex: var(a)
Vertex: var(b)
Vertex: var(b)
10 changes: 10 additions & 0 deletions lang/interpret_test/TestAstFunc1/polydoubleinclude/main.mcl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import "fmt"

# note that the class can have two separate types for $b
include c1("t1", "hello") # len is 5
include c1("t2", [13, 42, 0, -37,]) # len is 4
class c1($a, $b) {
test $a {
anotherstr => fmt.printf("len is: %d", len($b)),
}
}
88 changes: 88 additions & 0 deletions lang/interpret_test/TestAstFunc1/slow_unification0.graph
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
Edge: call:_operator(str("=="), var(state), str("default")) -> call:_operator(str("||"), call:_operator(str("=="), var(state), str("one")), call:_operator(str("=="), var(state), str("default"))) # b
Edge: call:_operator(str("=="), var(state), str("one")) -> call:_operator(str("||"), call:_operator(str("=="), var(state), str("one")), call:_operator(str("=="), var(state), str("default"))) # a
Edge: call:maplookup(var(exchanged), var(hostname), str("default")) -> var(state) # state
Edge: call:maplookup(var(exchanged), var(hostname), str("default")) -> var(state) # state
Edge: call:maplookup(var(exchanged), var(hostname), str("default")) -> var(state) # state
Edge: call:maplookup(var(exchanged), var(hostname), str("default")) -> var(state) # state
Edge: call:world.kvlookup(var(ns)) -> var(exchanged) # exchanged
Edge: str("") -> var(hostname) # hostname
Edge: str("==") -> call:_operator(str("=="), var(state), str("default")) # x
Edge: str("==") -> call:_operator(str("=="), var(state), str("one")) # x
Edge: str("==") -> call:_operator(str("=="), var(state), str("three")) # x
Edge: str("==") -> call:_operator(str("=="), var(state), str("two")) # x
Edge: str("default") -> call:_operator(str("=="), var(state), str("default")) # b
Edge: str("default") -> call:maplookup(var(exchanged), var(hostname), str("default")) # default
Edge: str("estate") -> var(ns) # ns
Edge: str("estate") -> var(ns) # ns
Edge: str("estate") -> var(ns) # ns
Edge: str("estate") -> var(ns) # ns
Edge: str("estate") -> var(ns) # ns
Edge: str("estate") -> var(ns) # ns
Edge: str("estate") -> var(ns) # ns
Edge: str("estate") -> var(ns) # ns
Edge: str("estate") -> var(ns) # ns
Edge: str("estate") -> var(ns) # ns
Edge: str("one") -> call:_operator(str("=="), var(state), str("one")) # b
Edge: str("three") -> call:_operator(str("=="), var(state), str("three")) # b
Edge: str("two") -> call:_operator(str("=="), var(state), str("two")) # b
Edge: str("||") -> call:_operator(str("||"), call:_operator(str("=="), var(state), str("one")), call:_operator(str("=="), var(state), str("default"))) # x
Edge: var(exchanged) -> call:maplookup(var(exchanged), var(hostname), str("default")) # map
Edge: var(hostname) -> call:maplookup(var(exchanged), var(hostname), str("default")) # key
Edge: var(ns) -> call:world.kvlookup(var(ns)) # namespace
Edge: var(state) -> call:_operator(str("=="), var(state), str("default")) # a
Edge: var(state) -> call:_operator(str("=="), var(state), str("one")) # a
Edge: var(state) -> call:_operator(str("=="), var(state), str("three")) # a
Edge: var(state) -> call:_operator(str("=="), var(state), str("two")) # a
Vertex: call:_operator(str("=="), var(state), str("default"))
Vertex: call:_operator(str("=="), var(state), str("one"))
Vertex: call:_operator(str("=="), var(state), str("three"))
Vertex: call:_operator(str("=="), var(state), str("two"))
Vertex: call:_operator(str("||"), call:_operator(str("=="), var(state), str("one")), call:_operator(str("=="), var(state), str("default")))
Vertex: call:maplookup(var(exchanged), var(hostname), str("default"))
Vertex: call:world.kvlookup(var(ns))
Vertex: str("")
Vertex: str("/tmp/mgmt/state")
Vertex: str("/tmp/mgmt/state")
Vertex: str("/tmp/mgmt/state")
Vertex: str("/usr/bin/sleep 1s")
Vertex: str("/usr/bin/sleep 1s")
Vertex: str("/usr/bin/sleep 1s")
Vertex: str("==")
Vertex: str("==")
Vertex: str("==")
Vertex: str("==")
Vertex: str("default")
Vertex: str("default")
Vertex: str("estate")
Vertex: str("one")
Vertex: str("one")
Vertex: str("state: one\n")
Vertex: str("state: three\n")
Vertex: str("state: two\n")
Vertex: str("three")
Vertex: str("three")
Vertex: str("timer")
Vertex: str("timer")
Vertex: str("timer")
Vertex: str("timer")
Vertex: str("timer")
Vertex: str("timer")
Vertex: str("two")
Vertex: str("two")
Vertex: str("||")
Vertex: var(exchanged)
Vertex: var(hostname)
Vertex: var(ns)
Vertex: var(ns)
Vertex: var(ns)
Vertex: var(ns)
Vertex: var(ns)
Vertex: var(ns)
Vertex: var(ns)
Vertex: var(ns)
Vertex: var(ns)
Vertex: var(ns)
Vertex: var(state)
Vertex: var(state)
Vertex: var(state)
Vertex: var(state)
52 changes: 52 additions & 0 deletions lang/interpret_test/TestAstFunc1/slow_unification0/main.mcl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# state machine that previously experienced unusably slow type unification
import "world"

$ns = "estate"
$exchanged = world.kvlookup($ns)
$state = maplookup($exchanged, $hostname, "default")

if $state == "one" || $state == "default" {

file "/tmp/mgmt/state" {
content => "state: one\n",
}

exec "timer" {
cmd => "/usr/bin/sleep 1s",
}
kv "${ns}" {
key => $ns,
value => "two",
}
Exec["timer"] -> Kv["${ns}"]
}
if $state == "two" {

file "/tmp/mgmt/state" {
content => "state: two\n",
}

exec "timer" {
cmd => "/usr/bin/sleep 1s",
}
kv "${ns}" {
key => $ns,
value => "three",
}
Exec["timer"] -> Kv["${ns}"]
}
if $state == "three" {

file "/tmp/mgmt/state" {
content => "state: three\n",
}

exec "timer" {
cmd => "/usr/bin/sleep 1s",
}
kv "${ns}" {
key => $ns,
value => "one",
}
Exec["timer"] -> Kv["${ns}"]
}
8 changes: 7 additions & 1 deletion lang/lang.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,13 @@ func (obj *Lang) Init() error {
}
}
obj.Logf("running type unification...")
if err := unification.Unify(obj.ast, unification.SimpleInvariantSolverLogger(logf)); err != nil {
unifier := &unification.Unifier{
AST: obj.ast,
Solver: unification.SimpleInvariantSolverLogger(logf),
Debug: obj.Debug,
Logf: logf,
}
if err := unifier.Unify(); err != nil {
return errwrap.Wrapf(err, "could not unify types")
}

Expand Down
15 changes: 14 additions & 1 deletion lang/structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2977,6 +2977,15 @@ type StmtInclude struct {
// Nevertheless, it is a useful facility for operations that might only apply to
// a select number of node types, since they won't need extra noop iterators...
func (obj *StmtInclude) Apply(fn func(interfaces.Node) error) error {
// If the class exists, then descend into it, because at this point, the
// copy of the original class that is stored here, is the effective
// class that we care about for type unification, and everything else...
// It's not clear if this is needed, but it's probably not harmful atm.
if obj.class != nil {
if err := obj.class.Apply(fn); err != nil {
return err
}
}
if obj.Args != nil {
for _, x := range obj.Args {
if err := x.Apply(fn); err != nil {
Expand Down Expand Up @@ -4890,7 +4899,11 @@ func (obj *ExprFunc) String() string {
if obj.Return != nil {
s += fmt.Sprintf(" %s", obj.Return.String())
}
s += fmt.Sprintf(" { %s }", obj.Body.String())
if obj.Body == nil {
s += fmt.Sprintf(" { ??? }") // TODO: why does this happen?
} else {
s += fmt.Sprintf(" { %s }", obj.Body.String())
}
return s
}

Expand Down
Loading

0 comments on commit d70bbfb

Please sign in to comment.