Skip to content

Commit 06a5270

Browse files
committed
Implement code generation for infix float operations
Nice to see the type information that is passed to the code generation actually being used.
1 parent 3887227 commit 06a5270

File tree

3 files changed

+77
-36
lines changed

3 files changed

+77
-36
lines changed

src/TypeChecker.hs

+10-7
Original file line numberDiff line numberDiff line change
@@ -423,15 +423,18 @@ inferInfixType ::
423423
-> (Text -> CompileError)
424424
-> Either CompileError TypedExpression
425425
inferInfixType state op a b compileError =
426-
let expected =
427-
case op of
428-
StringAdd -> Str
429-
_ -> Num
426+
let validInfix a b =
427+
case (op, b, typeEq a b) of
428+
(StringAdd, Str, True) -> Just Str
429+
(StringAdd, _, _) -> Nothing
430+
(_, Num, True) -> Just Num
431+
(_, Float', True) -> Just Float'
432+
(_, _, _) -> Nothing
430433
types = (,) <$> inferType state a <*> inferType state b
431434
checkInfix (a, b) =
432-
if typeOf a `typeEq` expected && typeOf b `typeEq` expected
433-
then Right (TypeChecker.Infix expected op a b)
434-
else Left $
435+
case validInfix (typeOf a) (typeOf b) of
436+
Just returnType -> Right (TypeChecker.Infix returnType op a b)
437+
Nothing -> Left $
435438
compileError
436439
("No function exists with type " <> printType (typeOf a) <> " " <>
437440
operatorToString op <>

src/Wasm.hs

+60-29
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,15 @@ data Module =
3535
Module [TopLevel]
3636
BytesAllocated
3737

38+
data WasmType
39+
= I32
40+
| F32
41+
deriving (Show, Eq, G.Generic)
42+
3843
data Declaration =
3944
Declaration F.Ident
4045
[F.Ident]
46+
WasmType
4147
Expression
4248
deriving (Show, Eq, G.Generic)
4349

@@ -59,7 +65,7 @@ data Expression
5965
| If Expression
6066
Expression
6167
(Maybe Expression)
62-
| Sequence (NE.NonEmpty Expression)
68+
| Sequence WasmType (NE.NonEmpty Expression)
6369
deriving (Show, Eq)
6470

6571
data Locals =
@@ -163,29 +169,38 @@ allocateBytes (Module topLevel bytes) extraBytes =
163169
Module topLevel (bytes + extraBytes)
164170

165171
compileDeclaration :: Module -> TypedDeclaration -> Module
166-
compileDeclaration m (TypedDeclaration name args _ fexpr) =
172+
compileDeclaration m (TypedDeclaration name args fType fexpr) =
167173
let parameters = concatMap (fst <$> assignments) (fst <$> args)
168174
deconstruction = concatMap (snd <$> assignments) (fst <$> args)
169175
locals = Locals (Set.fromList parameters)
170176
(m', expr') = compileExpression m locals fexpr
177+
wasmType = forestTypeToWasmType fType
171178
func =
172179
Func $
173180
Declaration
174181
name
175182
parameters
176-
(Sequence $ NE.fromList (deconstruction <> [expr']))
183+
wasmType
184+
(Sequence wasmType $ NE.fromList (deconstruction <> [expr']))
177185
in addTopLevel m' [func]
178186

179187
compileInlineDeclaration ::
180188
Module -> Locals -> TypedDeclaration -> (Maybe Expression, Module)
181-
compileInlineDeclaration m (Locals l) (TypedDeclaration name args _ fexpr) =
189+
compileInlineDeclaration m (Locals l) (TypedDeclaration name args forestType fexpr) =
182190
let parameters = concatMap (fst <$> assignments) (fst <$> args)
183191
locals = Locals (Set.union l (Set.fromList parameters))
184192
(m', expr') = compileExpression m locals fexpr
185193
in case args of
186194
[] -> (Just $ SetLocal name expr', m')
187195
_ ->
188-
(Nothing, addTopLevel m' [Func $ Declaration name parameters expr'])
196+
(Nothing, addTopLevel m' [Func $ Declaration name parameters (forestTypeToWasmType forestType) expr'])
197+
198+
forestTypeToWasmType :: T.Type -> WasmType
199+
forestTypeToWasmType fType =
200+
case fType of
201+
Num -> I32
202+
Float' -> F32
203+
_ -> I32
189204

190205
compileExpressions ::
191206
Module -> NonEmpty TypedExpression -> (Module, [Expression])
@@ -221,9 +236,9 @@ compileInfix m locals operator a b =
221236
let (m', aExpr) = compileExpression m locals a
222237
(m'', bExpr) = compileExpression m' locals b
223238
name = (F.Ident $ F.NonEmptyString 's' "tring_add")
224-
in case operator of
225-
F.StringAdd -> (m'', NamedCall name [aExpr, bExpr])
226-
_ -> (m'', Call (funcForOperator operator) [aExpr, bExpr])
239+
in case (operator, T.typeOf b) of
240+
(F.StringAdd, T.Str) -> (m'', NamedCall name [aExpr, bExpr])
241+
(_, t) -> (m'', Call (funcForOperator operator t) [aExpr, bExpr])
227242

228243
compileApply ::
229244
Module
@@ -235,7 +250,7 @@ compileApply m locals left right =
235250
case left of
236251
T.Apply _ (T.Identifier _ name _) r' ->
237252
let (m', exprs) = compileExpressions m [right, r']
238-
in (m', Sequence $ NE.fromList (exprs <> [NamedCall name []]))
253+
in (m', Sequence I32 $ NE.fromList (exprs <> [NamedCall name []]))
239254
T.Identifier _ name _ ->
240255
let (m', r) = compileExpression m locals right
241256
in (m', NamedCall name [r])
@@ -261,7 +276,7 @@ compileLet m locals@(Locals l) declarations fexpr =
261276
NE.toList $ (\(TypedDeclaration name _ _ _) -> name) <$> declarations
262277
locals' = Locals $ Set.union l (Set.fromList names)
263278
(m'', expr') = compileExpression m' locals' fexpr
264-
in (m'', Sequence $ NE.fromList (declarationExpressions <> [expr']))
279+
in (m'', Sequence I32 $ NE.fromList (declarationExpressions <> [expr']))
265280

266281
compileCase ::
267282
Module
@@ -305,7 +320,7 @@ compileCase m locals caseFexpr patterns =
305320
compileADTConstruction ::
306321
(Functor t, Foldable t) => Int -> t (F.Argument, b) -> Expression
307322
compileADTConstruction tag args =
308-
Sequence
323+
Sequence I32
309324
(NE.fromList
310325
([ SetLocal
311326
(ident "address")
@@ -361,7 +376,7 @@ compileDeconstructionAssignment i a n =
361376
(Call
362377
(ident "i32.load")
363378
[Call (ident "i32.add") [GetLocal i, Const $ n * 4]]))
364-
_ -> Sequence []
379+
_ -> Sequence I32 []
365380

366381
compileCaseExpression ::
367382
Module -> Locals -> T.TypedExpression -> (Module, Expression)
@@ -397,7 +412,7 @@ compileArgument m caseFexpr arg =
397412
localName (TAIdentifier _ ident') = Just ident'
398413
localName _ = Nothing
399414
locals = addLocals (mapMaybe localName args) noLocals
400-
in (m, Sequence (NE.fromList (assignments <> [Const tag])), locals)
415+
in (m, Sequence I32 (NE.fromList (assignments <> [Const tag])), locals)
401416
where
402417
caseLocal =
403418
case caseFexpr of
@@ -407,15 +422,24 @@ compileArgument m caseFexpr arg =
407422
eq32 :: F.Ident
408423
eq32 = F.Ident $ F.NonEmptyString 'i' "32.eq"
409424

410-
funcForOperator :: F.OperatorExpr -> F.Ident
411-
funcForOperator operator =
412-
F.Ident . uncurry F.NonEmptyString $
413-
case operator of
414-
F.Add -> ('i', "32.add")
415-
F.Subtract -> ('i', "32.sub")
416-
F.Multiply -> ('i', "32.mul")
417-
F.Divide -> ('i', "32.div_s")
418-
F.StringAdd -> ('s', "tring_add")
425+
funcForOperator :: F.OperatorExpr -> T.Type -> F.Ident
426+
funcForOperator operator t =
427+
let
428+
wasmType =
429+
case t of
430+
Num -> "i32"
431+
Float' -> "f32"
432+
_ -> error $ "tried to get a funcForOperator for a non numeric type: " <> (Text.unpack $ T.printType t)
433+
op =
434+
case (operator, t) of
435+
(F.Add, _) -> "add"
436+
(F.Subtract, _) -> "sub"
437+
(F.Multiply, _) -> "mul"
438+
(F.Divide, Float') -> "div"
439+
(F.Divide, _) -> "div_s"
440+
_ -> error $ "tried to get a funcForOperator for a non numeric type: " <> (Text.unpack $ T.printType t)
441+
in
442+
ident (wasmType <> "." <> op)
419443

420444
printWasm :: Module -> Text
421445
printWasm (Module expressions bytesAllocated) =
@@ -439,10 +463,10 @@ printMemory bytes =
439463
printWasmTopLevel :: TopLevel -> Text
440464
printWasmTopLevel topLevel =
441465
case topLevel of
442-
Func (Declaration name args body) ->
466+
Func (Declaration name args wasmType body) ->
443467
Text.unlines
444468
[ "(export \"" <> F.s name <> "\" (func $" <> F.s name <> "))"
445-
, printDeclaration (Declaration name args body)
469+
, printDeclaration (Declaration name args wasmType body)
446470
]
447471
Data offset str ->
448472
"(data (i32.const " <> showT offset <> ") \"" <>
@@ -459,8 +483,8 @@ printWasmTopLevel topLevel =
459483
printWasmExpr :: Expression -> Text
460484
printWasmExpr expr =
461485
case expr of
462-
Sequence exprs ->
463-
"(block (result i32)\n" <> indent2 (Text.intercalate "\n" $ NE.toList (printWasmExpr <$> exprs)) <> "\n)"
486+
Sequence wasmType exprs ->
487+
"(block (result " <> printWasmType wasmType <> ")\n" <> indent2 (Text.intercalate "\n" $ NE.toList (printWasmExpr <$> exprs)) <> "\n)"
464488
Const n -> "(i32.const " <> showT n <> ")"
465489
FloatConst n -> "(f32.const " <> showT n <> ")"
466490
GetLocal name -> "(get_local $" <> F.s name <> ")"
@@ -482,13 +506,20 @@ printWasmExpr expr =
482506
] <>
483507
[indent2 $ maybe "(i32.const 0)" printWasmExpr b, ")"])
484508

509+
510+
printWasmType :: WasmType -> Text
511+
printWasmType wasmType =
512+
case wasmType of
513+
I32 -> "i32"
514+
F32 -> "f32"
515+
485516
printDeclaration :: Declaration -> Text
486-
printDeclaration (Declaration name args body) =
517+
printDeclaration (Declaration name args wasmType body) =
487518
Text.intercalate
488519
"\n"
489520
[ "(func $" <> F.s name <>
490521
Text.unwords (fmap (\x -> " (param $" <> x <> " i32)") (F.s <$> args)) <>
491-
" (result i32) " <>
522+
" (result " <> printWasmType wasmType <> ") " <>
492523
Text.unwords (printLocal <$> locals body)
493524
, indent2 $ Text.unlines ["(return", indent2 $ printWasmExpr body, ")"]
494525
, ")"
@@ -498,7 +529,7 @@ printDeclaration (Declaration name args body) =
498529
locals expr' =
499530
case expr' of
500531
SetLocal name _ -> [F.s name]
501-
Sequence exprs -> concatMap locals $ NE.toList exprs
532+
Sequence _ exprs -> concatMap locals $ NE.toList exprs
502533
If expr expr' mexpr ->
503534
locals expr <> locals expr' <> maybe [] locals mexpr
504535
Call _ exprs -> concatMap locals exprs

test/integration.rb

+7
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,13 @@ def run_tests
187187

188188
testCode('adt_deconstruction_function', code, 30)
189189

190+
code = <<~FOREST
191+
main :: Float
192+
main = 5.0 / 2.0 * 4.0
193+
FOREST
194+
195+
testCode('float_infix_ops', code, 10)
196+
190197
puts 'Integration tests ran successfully!'
191198
end
192199

0 commit comments

Comments
 (0)