Skip to content

Commit 0e3a528

Browse files
committed
binary tree llvm - consistent block termination and implicit returns
1 parent b9a8067 commit 0e3a528

File tree

3 files changed

+129
-95
lines changed

3 files changed

+129
-95
lines changed

compiler/llvm_backend.py

+30-21
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
class LlvmBackend(Visitor):
2121
locals: List[defaultdict]
2222
externs: Dict[str, ir.Function]
23-
constructors: Dict[str, ir.Function]
2423
methods: Dict[str, Dict[str, ir.Function]]
2524
structs: Dict[str, ir.LiteralStructType]
2625
# (class name, method name) -> idx in vtable, type
@@ -58,14 +57,16 @@ def initializeOffsets(self):
5857
self.methods[cls] = {}
5958
orderedMethods = self.ts.getOrderedMethods(cls)
6059
vtable = []
61-
for idx, (methName, methType, _) in enumerate(orderedMethods):
60+
for idx, (methName, methType, defCls) in enumerate(orderedMethods):
6261
funcType = methType.getLLVMType()
6362
self.methodOffsets[(cls, methName)] = (idx, funcType)
64-
func = ir.Function(self.module, funcType,
65-
cls + "__" + methName)
63+
if defCls == cls:
64+
func = ir.Function(self.module, funcType,
65+
cls + "__" + methName)
66+
self.methods[cls][methName] = func
67+
for methName, _, defCls in orderedMethods:
68+
func = self.methods[defCls][methName]
6669
self.methods[cls][methName] = func
67-
for methName, _, _ in orderedMethods:
68-
func = self.methods[cls][methName]
6970
vtable.append(func)
7071
t = self.getClassVtableType(cls)
7172
self.global_constant('__' + cls + '__vtable',
@@ -220,8 +221,8 @@ def Program(self, node: Program):
220221
self.visitStmtList(node.statements)
221222

222223
self.builder.branch(end_program)
223-
program_block = self.builder.block
224224
self.builder.position_at_start(end_program)
225+
assert not end_program.is_terminated
225226
self.builder.ret_void()
226227
self.exitScope()
227228

@@ -255,13 +256,14 @@ def declareFunc(self, node: FuncDef):
255256
ir.Function(self.module, funcType, funcname)
256257

257258
def FuncDef(self, node: FuncDef):
259+
fname = node.getIdentifier().name
258260
if node.isMethod:
259261
func = self.module.get_global(
260-
self.currentClass + "__" + node.getIdentifier().name)
262+
self.currentClass + "__" + fname)
261263
else:
262-
func = self.module.get_global(node.getIdentifier().name)
264+
func = self.module.get_global(fname)
263265
self.returnType = node.type.returnType
264-
shouldReturnValue = not self.returnType.isNone()
266+
implicitReturn = self.returnType not in {IntType(), BoolType(), StrType(), NoneType()}
265267
self.enterScope()
266268
bb_entry = func.append_basic_block('entry')
267269
self.builder = ir.IRBuilder(bb_entry)
@@ -274,12 +276,14 @@ def FuncDef(self, node: FuncDef):
274276
for d in node.declarations:
275277
self.visit(d)
276278
self.visitStmtList(node.statements)
277-
# implicitly return None if possible
278-
if shouldReturnValue is not None and (
279-
len(node.statements) == 0 or
280-
not isinstance(node.statements[-1], ReturnStmt)
281-
):
282-
self.builder.ret(self.NoneLiteral(None))
279+
# implicitly return None if needed, close all blocks
280+
for block in func.blocks:
281+
self.builder.position_at_end(block)
282+
if not block.is_terminated:
283+
if implicitReturn:
284+
self.builder.ret(self.NoneLiteral(None))
285+
else:
286+
self.builder.unreachable()
283287
self.exitScope()
284288
return func
285289

@@ -536,10 +540,13 @@ def visitArg(self, funcType: FuncType, paramIdx: int, arg: Expr):
536540
# unwrap if necessary, re-wrap
537541
saved_block = self.builder.block
538542
val = self.visit(arg)
543+
# print(val)
539544
addr = self.builder.alloca(
540-
node.var.t.getLLVMType())
545+
arg.inferredType.getLLVMType())
546+
# print(addr)
541547
wrapper = self.builder.alloca(
542-
node.var.t.getLLVMType().as_pointer(), None, "wrapper")
548+
arg.inferredType.getLLVMType().as_pointer(), None, "wrapper")
549+
# print(wrapper)
543550
self.builder.position_at_end(saved_block)
544551
self.builder.store(val, addr)
545552
self.builder.store(addr, wrapper)
@@ -659,6 +666,7 @@ def whileHelper(self, condFn, bodyFn):
659666
self.builder.position_at_start(end_block)
660667

661668
def ReturnStmt(self, node: ReturnStmt):
669+
assert not self.builder.block.is_terminated
662670
if self.returnType.isNone():
663671
self.builder.ret(self.NoneLiteral(None))
664672
else:
@@ -713,14 +721,14 @@ def ifHelper(self, condFn, thenFn, elseFn=None, returnType=None):
713721

714722
self.builder.position_at_start(then_block)
715723
then_val = thenFn()
716-
if not self.builder.block.is_terminated:
724+
if not then_block.is_terminated:
717725
self.builder.branch(merge_block)
718726
then_block = self.builder.block
719727

720728
if elseFn is not None:
721729
self.builder.position_at_start(else_block)
722730
else_val = elseFn()
723-
if not self.builder.block.is_terminated:
731+
if not else_block.is_terminated:
724732
self.builder.branch(merge_block)
725733
else_block = self.builder.block
726734

@@ -749,7 +757,8 @@ def MethodCallExpr(self, node: MethodCallExpr):
749757

750758
call_args = [self.builder.bitcast(obj, voidptr_t)]
751759
for i in range(len(node.args)):
752-
call_args.append(self.visitArg(node.method.inferredType, i, node.args[i]))
760+
call_args.append(self.visitArg(
761+
node.method.inferredType, i, node.args[i]))
753762
return self.builder.call(callee_func, call_args, 'callmethodtmp')
754763

755764
# LITERALS

foobar.py

+87-59
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,87 @@
1-
class A:
2-
y: int = 1
3-
4-
def __init__(self: A):
5-
pass
6-
7-
def t(self: A):
8-
global x
9-
x = 1
10-
11-
12-
class B(A):
13-
z: int = 0
14-
15-
def __init__(self: B):
16-
self.z = 5
17-
self.y = 5
18-
19-
def t(self: B):
20-
global x
21-
x = 2
22-
23-
def setZ(self: B, z: int):
24-
self.z = z
25-
26-
27-
x: int = 0
28-
c1: A = None
29-
c2: B = None
30-
c3: A = None
31-
32-
# constructors, getters, setters
33-
c1 = A()
34-
assert c1.y == 1
35-
c2 = B()
36-
assert c2.y == 5
37-
assert c2.z == 5
38-
c3 = B()
39-
assert c3.y == 5
40-
41-
c2.y = 0
42-
assert c2.y == 0
43-
44-
# methods, dynamic dispatch
45-
46-
c2.setZ(2)
47-
assert c2.z == 2
48-
49-
x = 0
50-
c1.t()
51-
assert x == 1
52-
53-
x = 0
54-
c2.t()
55-
assert x == 2
56-
57-
x = 0
58-
c3.t()
59-
assert x == 2
1+
# Binary-search trees
2+
class TreeNode(object):
3+
value: int = 0
4+
left: "TreeNode" = None
5+
right: "TreeNode" = None
6+
7+
def insert(self: "TreeNode", x: int) -> bool:
8+
if x < self.value:
9+
if self.left is None:
10+
self.left = makeNode(x)
11+
return True
12+
else:
13+
return self.left.insert(x)
14+
elif x > self.value:
15+
if self.right is None:
16+
self.right = makeNode(x)
17+
return True
18+
else:
19+
return self.right.insert(x)
20+
return False
21+
22+
def contains(self: "TreeNode", x: int) -> bool:
23+
if x < self.value:
24+
if self.left is None:
25+
return False
26+
else:
27+
return self.left.contains(x)
28+
elif x > self.value:
29+
if self.right is None:
30+
return False
31+
else:
32+
return self.right.contains(x)
33+
else:
34+
return True
35+
36+
37+
class Tree(object):
38+
root: TreeNode = None
39+
size: int = 0
40+
41+
def insert(self: "Tree", x: int) -> object:
42+
if self.root is None:
43+
self.root = makeNode(x)
44+
self.size = 1
45+
else:
46+
if self.root.insert(x):
47+
self.size = self.size + 1
48+
49+
def contains(self: "Tree", x: int) -> bool:
50+
if self.root is None:
51+
return False
52+
else:
53+
return self.root.contains(x)
54+
55+
56+
def makeNode(x: int) -> TreeNode:
57+
b: TreeNode = None
58+
b = TreeNode()
59+
b.value = x
60+
return b
61+
62+
63+
# Input parameters
64+
n: int = 100
65+
c: int = 4
66+
67+
# Data
68+
t: Tree = None
69+
i: int = 0
70+
k: int = 37813
71+
72+
# Crunch
73+
t = Tree()
74+
while i < n:
75+
t.insert(k)
76+
k = (k * 37813) % 37831
77+
if i % c != 0:
78+
t.insert(i)
79+
i = i + 1
80+
81+
assert t.size == 175
82+
assert t.contains(15)
83+
assert t.contains(23)
84+
assert t.contains(42)
85+
assert not t.contains(4)
86+
assert not t.contains(8)
87+
assert not t.contains(16)

test.py

+12-15
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@
1717

1818

1919
disabled_llvm_tests = [
20-
"/binary_tree.",
21-
"/doubling_vector.",
2220
"/nonlocal.",
23-
"/exponent."
2421
]
2522

2623
disabled_jvm_tests = []
@@ -41,15 +38,15 @@ def should_skip(disabled_tests: List[str], test: Path) -> bool:
4138

4239

4340
def run_all_tests():
44-
run_parse_tests()
45-
run_typecheck_tests()
46-
run_python_backend_tests()
47-
run_closure_tests()
48-
run_jvm_tests()
49-
run_cil_tests()
50-
run_wasm_tests()
51-
run_llvm_tests()
52-
# test_eval_llvm()
41+
# run_parse_tests()
42+
# run_typecheck_tests()
43+
# run_python_backend_tests()
44+
# run_closure_tests()
45+
# run_jvm_tests()
46+
# run_cil_tests()
47+
# run_wasm_tests()
48+
# run_llvm_tests()
49+
test_eval_llvm()
5350

5451

5552
def run_parse_tests():
@@ -647,7 +644,7 @@ def eval_llvm(module):
647644

648645

649646
def test_eval_llvm():
650-
run_llvm_test("foobar.py", True)
647+
run_llvm_test("foobar.py", "foobar.ll")
651648

652649

653650
def run_llvm_test(test, debug):
@@ -665,8 +662,8 @@ def run_llvm_test(test, debug):
665662
assert len(compiler.typechecker.errors) == 0
666663
module = compiler.emitLLVM(chocopy_ast)
667664
if debug:
668-
print("Module output:")
669-
print(str(module))
665+
with open(debug, "w") as f:
666+
f.write(str(module))
670667
eval_llvm(module)
671668
return True
672669
except Exception as e:

0 commit comments

Comments
 (0)