20
20
class LlvmBackend (Visitor ):
21
21
locals : List [defaultdict ]
22
22
externs : Dict [str , ir .Function ]
23
- constructors : Dict [str , ir .Function ]
24
23
methods : Dict [str , Dict [str , ir .Function ]]
25
24
structs : Dict [str , ir .LiteralStructType ]
26
25
# (class name, method name) -> idx in vtable, type
@@ -58,14 +57,16 @@ def initializeOffsets(self):
58
57
self .methods [cls ] = {}
59
58
orderedMethods = self .ts .getOrderedMethods (cls )
60
59
vtable = []
61
- for idx , (methName , methType , _ ) in enumerate (orderedMethods ):
60
+ for idx , (methName , methType , defCls ) in enumerate (orderedMethods ):
62
61
funcType = methType .getLLVMType ()
63
62
self .methodOffsets [(cls , methName )] = (idx , funcType )
64
- func = ir .Function (self .module , funcType ,
65
- cls + "__" + methName )
63
+ if defCls == cls :
64
+ func = ir .Function (self .module , funcType ,
65
+ cls + "__" + methName )
66
+ self .methods [cls ][methName ] = func
67
+ for methName , _ , defCls in orderedMethods :
68
+ func = self .methods [defCls ][methName ]
66
69
self .methods [cls ][methName ] = func
67
- for methName , _ , _ in orderedMethods :
68
- func = self .methods [cls ][methName ]
69
70
vtable .append (func )
70
71
t = self .getClassVtableType (cls )
71
72
self .global_constant ('__' + cls + '__vtable' ,
@@ -220,8 +221,8 @@ def Program(self, node: Program):
220
221
self .visitStmtList (node .statements )
221
222
222
223
self .builder .branch (end_program )
223
- program_block = self .builder .block
224
224
self .builder .position_at_start (end_program )
225
+ assert not end_program .is_terminated
225
226
self .builder .ret_void ()
226
227
self .exitScope ()
227
228
@@ -255,13 +256,14 @@ def declareFunc(self, node: FuncDef):
255
256
ir .Function (self .module , funcType , funcname )
256
257
257
258
def FuncDef (self , node : FuncDef ):
259
+ fname = node .getIdentifier ().name
258
260
if node .isMethod :
259
261
func = self .module .get_global (
260
- self .currentClass + "__" + node . getIdentifier (). name )
262
+ self .currentClass + "__" + fname )
261
263
else :
262
- func = self .module .get_global (node . getIdentifier (). name )
264
+ func = self .module .get_global (fname )
263
265
self .returnType = node .type .returnType
264
- shouldReturnValue = not self .returnType . isNone ()
266
+ implicitReturn = self .returnType not in { IntType (), BoolType (), StrType (), NoneType ()}
265
267
self .enterScope ()
266
268
bb_entry = func .append_basic_block ('entry' )
267
269
self .builder = ir .IRBuilder (bb_entry )
@@ -274,12 +276,14 @@ def FuncDef(self, node: FuncDef):
274
276
for d in node .declarations :
275
277
self .visit (d )
276
278
self .visitStmtList (node .statements )
277
- # implicitly return None if possible
278
- if shouldReturnValue is not None and (
279
- len (node .statements ) == 0 or
280
- not isinstance (node .statements [- 1 ], ReturnStmt )
281
- ):
282
- self .builder .ret (self .NoneLiteral (None ))
279
+ # implicitly return None if needed, close all blocks
280
+ for block in func .blocks :
281
+ self .builder .position_at_end (block )
282
+ if not block .is_terminated :
283
+ if implicitReturn :
284
+ self .builder .ret (self .NoneLiteral (None ))
285
+ else :
286
+ self .builder .unreachable ()
283
287
self .exitScope ()
284
288
return func
285
289
@@ -536,10 +540,13 @@ def visitArg(self, funcType: FuncType, paramIdx: int, arg: Expr):
536
540
# unwrap if necessary, re-wrap
537
541
saved_block = self .builder .block
538
542
val = self .visit (arg )
543
+ # print(val)
539
544
addr = self .builder .alloca (
540
- node .var .t .getLLVMType ())
545
+ arg .inferredType .getLLVMType ())
546
+ # print(addr)
541
547
wrapper = self .builder .alloca (
542
- node .var .t .getLLVMType ().as_pointer (), None , "wrapper" )
548
+ arg .inferredType .getLLVMType ().as_pointer (), None , "wrapper" )
549
+ # print(wrapper)
543
550
self .builder .position_at_end (saved_block )
544
551
self .builder .store (val , addr )
545
552
self .builder .store (addr , wrapper )
@@ -659,6 +666,7 @@ def whileHelper(self, condFn, bodyFn):
659
666
self .builder .position_at_start (end_block )
660
667
661
668
def ReturnStmt (self , node : ReturnStmt ):
669
+ assert not self .builder .block .is_terminated
662
670
if self .returnType .isNone ():
663
671
self .builder .ret (self .NoneLiteral (None ))
664
672
else :
@@ -713,14 +721,14 @@ def ifHelper(self, condFn, thenFn, elseFn=None, returnType=None):
713
721
714
722
self .builder .position_at_start (then_block )
715
723
then_val = thenFn ()
716
- if not self . builder . block .is_terminated :
724
+ if not then_block .is_terminated :
717
725
self .builder .branch (merge_block )
718
726
then_block = self .builder .block
719
727
720
728
if elseFn is not None :
721
729
self .builder .position_at_start (else_block )
722
730
else_val = elseFn ()
723
- if not self . builder . block .is_terminated :
731
+ if not else_block .is_terminated :
724
732
self .builder .branch (merge_block )
725
733
else_block = self .builder .block
726
734
@@ -749,7 +757,8 @@ def MethodCallExpr(self, node: MethodCallExpr):
749
757
750
758
call_args = [self .builder .bitcast (obj , voidptr_t )]
751
759
for i in range (len (node .args )):
752
- call_args .append (self .visitArg (node .method .inferredType , i , node .args [i ]))
760
+ call_args .append (self .visitArg (
761
+ node .method .inferredType , i , node .args [i ]))
753
762
return self .builder .call (callee_func , call_args , 'callmethodtmp' )
754
763
755
764
# LITERALS
0 commit comments