From 3b4b2f935d26ada3f7f45976873089ed5aa07f6b Mon Sep 17 00:00:00 2001
From: Terts Diepraam <terts.diepraam@gmail.com>
Date: Fri, 17 Jan 2025 16:54:20 +0100
Subject: [PATCH 1/3] add some basic testing infrastructure

---
 examples/simple.roto          | 29 +++++++++++++-----
 examples/simple.rs            |  3 ++
 src/ast.rs                    |  7 +++++
 src/codegen/mod.rs            | 50 +++++++++++++++++++++++++++++--
 src/codegen/tests.rs          |  1 -
 src/lower/ir.rs               |  2 +-
 src/lower/mod.rs              | 56 +++++++++++++++++++++++++++++++++--
 src/parser/mod.rs             | 12 +++++++-
 src/parser/token.rs           |  3 ++
 src/pipeline.rs               |  4 +++
 src/typechecker/filter_map.rs | 25 ++++++++++++++++
 src/typechecker/mod.rs        |  6 ++++
 src/typechecker/tests.rs      | 38 ++++++++++++------------
 13 files changed, 202 insertions(+), 34 deletions(-)

diff --git a/examples/simple.roto b/examples/simple.roto
index dfcddcaf..5c44aa99 100644
--- a/examples/simple.roto
+++ b/examples/simple.roto
@@ -1,12 +1,25 @@
+function is_zero(x: IpAddr) -> bool {
+    x == 0.0.0.0
+}
+
 filter-map main(x: IpAddr) {
-    define {
-        y = 0.0.0.0;
+    if is_zero(x) {
+        accept
+    } else {
+        reject
+    }
+}
+
+test is_zero_true {
+    if is_zero(1.1.1.1) {
+        reject
     }
-    apply {
-        if x == y {
-            accept
-        } else {
-            reject
-        }
+    accept
+}
+
+test is_zero_false {
+    if not is_zero(0.0.0.0) {
+        reject
     }
+    accept
 }
diff --git a/examples/simple.rs b/examples/simple.rs
index 7528973f..1d099065 100644
--- a/examples/simple.rs
+++ b/examples/simple.rs
@@ -31,5 +31,8 @@ fn main() -> Result<(), roto::RotoReport> {
     let res = func.call(&mut (), "1.1.1.1".parse().unwrap());
     println!("main(1.1.1.1) = {res:?}");
 
+    println!();
+    let _ = compiled.run_tests(());
+
     Ok(())
 }
diff --git a/src/ast.rs b/src/ast.rs
index 00740920..08b7bc4a 100644
--- a/src/ast.rs
+++ b/src/ast.rs
@@ -23,6 +23,7 @@ pub enum Declaration {
     OutputStream(OutputStream),
     Record(RecordTypeDeclaration),
     Function(FunctionDeclaration),
+    Test(Test),
 }
 
 #[derive(Clone, Debug)]
@@ -58,6 +59,12 @@ pub struct FunctionDeclaration {
     pub body: Meta<Block>,
 }
 
+#[derive(Clone, Debug)]
+pub struct Test {
+    pub ident: Meta<Identifier>,
+    pub body: Meta<Block>,
+}
+
 /// A block of multiple statements
 #[derive(Clone, Debug)]
 pub struct Block {
diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs
index 09577996..cadd4334 100644
--- a/src/codegen/mod.rs
+++ b/src/codegen/mod.rs
@@ -22,7 +22,7 @@ use crate::{
         RuntimeConstant,
     },
     typechecker::{info::TypeInfo, scope::ScopeRef, types},
-    IrValue,
+    IrValue, Verdict,
 };
 use check::{
     check_roto_type_reflect, return_type_by_ref, FunctionRetrievalError,
@@ -555,7 +555,7 @@ impl<'c> FuncGen<'c> {
             args.next().unwrap(),
         );
 
-        if dbg!(return_ptr) {
+        if return_ptr {
             self.def(
                 self.module.variable_map[&Var {
                     scope: self.scope,
@@ -1014,6 +1014,52 @@ impl<'c> FuncGen<'c> {
 }
 
 impl Module {
+    pub fn run_tests<Ctx: 'static>(
+        &mut self,
+        mut ctx: Ctx,
+    ) -> Result<(), ()> {
+        let tests: Vec<_> = self
+            .functions
+            .keys()
+            .filter(|x| x.starts_with("test#"))
+            .map(Clone::clone)
+            .collect();
+
+        let total = tests.len();
+        let total_width = total.to_string().len();
+        let mut successes = 0;
+        let mut failures = 0;
+
+        for (n, test) in tests.into_iter().enumerate() {
+            let n = n + 1;
+            let test_display = test.strip_prefix("test#").unwrap();
+            print!("Test {n:>total_width$} / {total}: {test_display}... ");
+            let test_fn = self
+                .get_function::<Ctx, (), Verdict<(), ()>>(&test)
+                .unwrap();
+
+            match test_fn.call(&mut ctx) {
+                Verdict::Accept(()) => {
+                    successes += 1;
+                    println!("\x1B[92mok\x1B[m");
+                }
+                Verdict::Reject(()) => {
+                    failures += 1;
+                    println!("\x1B[91mfail\x1B[m");
+                }
+            }
+        }
+        println!(
+            "Ran {total} tests, {successes} succeeded, {failures} failed"
+        );
+
+        if failures == 0 {
+            Result::Err(())
+        } else {
+            Result::Ok(())
+        }
+    }
+
     pub fn get_function<Ctx: 'static, Params: RotoParams, Return: Reflect>(
         &mut self,
         name: &str,
diff --git a/src/codegen/tests.rs b/src/codegen/tests.rs
index 801c6d51..4b75c0c5 100644
--- a/src/codegen/tests.rs
+++ b/src/codegen/tests.rs
@@ -53,7 +53,6 @@ fn accept() {
         .expect("No function found (or mismatched types)");
 
     let res = f.call(&mut ());
-    dbg!(std::mem::size_of::<Verdict<(), ()>>());
     assert_eq!(res, Verdict::Accept(()));
 }
 
diff --git a/src/lower/ir.rs b/src/lower/ir.rs
index 04f6921c..802d203d 100644
--- a/src/lower/ir.rs
+++ b/src/lower/ir.rs
@@ -360,8 +360,8 @@ impl<'a> IrPrinter<'a> {
             } => format!(
                 "{}: {ty} = {}({}, {})",
                 self.var(to),
-                self.operand(ctx),
                 self.ident(func),
+                self.operand(ctx),
                 args.iter()
                     .map(|a| format!(
                         "{} = {}",
diff --git a/src/lower/mod.rs b/src/lower/mod.rs
index 2237e4dc..f1270227 100644
--- a/src/lower/mod.rs
+++ b/src/lower/mod.rs
@@ -179,6 +179,21 @@ impl<'r> Lowerer<'r> {
                         .function(ident, params, ret, body),
                     );
                 }
+                ast::Declaration::Test(ast::Test { ident, body }) => {
+                    functions.push(
+                        Lowerer::new(
+                            type_info,
+                            runtime_functions,
+                            &Meta {
+                                node: format!("test#{}", ident).into(),
+                                id: ident.id,
+                            },
+                            label_store,
+                            runtime,
+                        )
+                        .test(ident, body),
+                    );
+                }
                 // Ignore the rest
                 _ => {}
             }
@@ -240,8 +255,6 @@ impl<'r> Lowerer<'r> {
             x => (Some(self.lower_type(&x)), false),
         };
 
-        dbg!(return_ptr);
-
         let ir_signature = ir::Signature {
             parameters: parameter_types
                 .iter()
@@ -335,6 +348,45 @@ impl<'r> Lowerer<'r> {
         }
     }
 
+    fn test(
+        mut self,
+        ident: &Meta<Identifier>,
+        body: &ast::Block,
+    ) -> Function {
+        let label = self.label_store.new_label(self.function_name);
+        self.new_block(label);
+
+        let unit = Box::new(Type::Primitive(Primitive::Unit));
+
+        let return_type = Type::Verdict(unit.clone(), unit);
+
+        let signature = Signature {
+            kind: FunctionKind::Free,
+            parameter_types: Vec::new(),
+            return_type: return_type.clone(),
+        };
+
+        let ir_signature = ir::Signature {
+            parameters: Vec::new(),
+            return_ptr: true, // TODO: check this
+            return_type: None,
+        };
+
+        let last = self.block(body);
+
+        self.add(Instruction::Return(last));
+
+        Function {
+            name: format!("test#{}", ident.node).into(),
+            scope: self.function_scope,
+            entry_block: label,
+            blocks: self.blocks,
+            public: true,
+            signature,
+            ir_signature,
+        }
+    }
+
     /// Lower a block
     ///
     /// Returns either the value of the expression or the place where the
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index a5c9837f..a952ae5a 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -1,4 +1,6 @@
-use crate::ast::{Declaration, FunctionDeclaration, Identifier, SyntaxTree};
+use crate::ast::{
+    Declaration, FunctionDeclaration, Identifier, SyntaxTree, Test,
+};
 use logos::{Lexer, SpannedIter};
 use std::{fmt::Display, iter::Peekable};
 use token::Token;
@@ -336,6 +338,7 @@ impl<'source, 'spans> Parser<'source, 'spans> {
                 Declaration::Record(self.record_type_assignment()?)
             }
             Token::Function => Declaration::Function(self.function()?),
+            Token::Test => Declaration::Test(self.test()?),
             _ => {
                 let (token, span) = self.next()?;
                 return Err(ParseError::expected(
@@ -372,6 +375,13 @@ impl<'source, 'spans> Parser<'source, 'spans> {
             ret,
         })
     }
+
+    fn test(&mut self) -> ParseResult<Test> {
+        self.take(Token::Test)?;
+        let ident = self.identifier()?;
+        let body = self.block()?;
+        Ok(Test { ident, body })
+    }
 }
 
 /// # Parsing identifiers
diff --git a/src/parser/token.rs b/src/parser/token.rs
index be799a5b..04bf74b9 100644
--- a/src/parser/token.rs
+++ b/src/parser/token.rs
@@ -115,6 +115,8 @@ pub enum Token<'s> {
     Some,
     #[token("table")]
     Table,
+    #[token("test")]
+    Test,
     #[token("through")]
     Through,
     #[token("type")]
@@ -211,6 +213,7 @@ impl Display for Token<'_> {
             Token::Rib => "rib",
             Token::Some => "some",
             Token::Table => "table",
+            Token::Test => "test",
             Token::Through => "through",
             Token::Type => "type",
             Token::UpTo => "up-to",
diff --git a/src/pipeline.rs b/src/pipeline.rs
index ca3851e8..e37675be 100644
--- a/src/pipeline.rs
+++ b/src/pipeline.rs
@@ -447,6 +447,10 @@ impl Lowered {
 }
 
 impl Compiled {
+    pub fn run_tests<Ctx: 'static>(&mut self, ctx: Ctx) -> Result<(), ()> {
+        self.module.run_tests(ctx)
+    }
+
     pub fn get_function<Ctx: 'static, Params: RotoParams, Return: Reflect>(
         &mut self,
         name: &str,
diff --git a/src/typechecker/filter_map.rs b/src/typechecker/filter_map.rs
index dbcb136d..37ab5556 100644
--- a/src/typechecker/filter_map.rs
+++ b/src/typechecker/filter_map.rs
@@ -114,6 +114,31 @@ impl TypeChecker<'_> {
         Ok(())
     }
 
+    pub fn test(
+        &mut self,
+        scope: LocalScopeRef,
+        test: &ast::Test,
+    ) -> TypeResult<()> {
+        let ast::Test { ident, body } = test;
+
+        let scope = self
+            .scope_graph
+            .wrap(scope, ScopeType::Function(ident.node));
+
+        self.type_info
+            .function_scopes
+            .insert(ident.id, scope.into());
+
+        let unit = Box::new(Type::Primitive(Primitive::Unit));
+        let ret = Type::Verdict(unit.clone(), unit);
+        let ctx = Context {
+            expected_type: ret.clone(),
+            function_return_type: Some(ret),
+        };
+        self.block(scope, &ctx, body)?;
+        Ok(())
+    }
+
     pub fn function_type(
         &mut self,
         dec: &ast::FunctionDeclaration,
diff --git a/src/typechecker/mod.rs b/src/typechecker/mod.rs
index 10ca4bc9..61a2fd6f 100644
--- a/src/typechecker/mod.rs
+++ b/src/typechecker/mod.rs
@@ -304,6 +304,12 @@ impl TypeChecker<'_> {
             }
         }
 
+        for expr in &tree.declarations {
+            if let ast::Declaration::Test(t) = expr {
+                checker.test(module_scope, t)?
+            }
+        }
+
         Ok(checker.type_info)
     }
 
diff --git a/src/typechecker/tests.rs b/src/typechecker/tests.rs
index d5a980e6..93bd5388 100644
--- a/src/typechecker/tests.rs
+++ b/src/typechecker/tests.rs
@@ -257,7 +257,7 @@ fn integer_inference() {
         "
         type Foo { x: u8 }
 
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let foo = Foo { x: 5 };
             accept
         }
@@ -270,7 +270,7 @@ fn integer_inference() {
         type Foo { x: u8 }
         type Bar { x: u8 }
 
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let a = 5;
             let foo = Foo { x: a };
             accept
@@ -284,7 +284,7 @@ fn integer_inference() {
         type Foo { x: u8 }
         type Bar { x: u8 }
 
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let a = 5;
             let foo = Foo { x: a };
             let bar = Bar { x: a };
@@ -299,7 +299,7 @@ fn integer_inference() {
         type Foo { x: u8 }
         type Bar { x: u32 }
 
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let a = 5;
             let foo = Foo { x: a };
             let bar = Bar { x: a };
@@ -314,7 +314,7 @@ fn integer_inference() {
         type Foo { x: u8 }
         type Bar { x: u32 }
 
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let foo = Foo { x: 5 };
             let bar = Bar { x: 5 };
             accept
@@ -331,7 +331,7 @@ fn assign_field_to_other_record() {
         type Foo { x: u8 }
         type Bar { x: u8 }
 
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let foo = Foo { x: 5 };
             let bar = Bar { x: foo.x };
             accept
@@ -345,7 +345,7 @@ fn assign_field_to_other_record() {
         type Foo { x: u8 }
         type Bar { x: u8 }
 
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let foo = Foo { x: 5 };
             let bar = Bar { x: foo.y };
             accept
@@ -359,7 +359,7 @@ fn assign_field_to_other_record() {
         type Foo { x: u8 }
         type Bar { x: u32 }
 
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let foo = Foo { x: 5 };
             let bar = Bar { x: foo.x };
             accept
@@ -373,7 +373,7 @@ fn assign_field_to_other_record() {
 fn ip_addr_method() {
     let s = src!(
         "
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let p = 10.10.10.10;
             let is_four = 1 + p.is_ipv4();
             accept
@@ -384,7 +384,7 @@ fn ip_addr_method() {
 
     let s = src!(
         "
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let p = 10.10.10.10;
             let is_four = true && p.is_ipv4();
             accept
@@ -399,7 +399,7 @@ fn ip_addr_method() {
 fn ip_addr_method_of_method_return_type() {
     let s = src!(
         "
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let p = 10.10.10.10;
             let x = p.to_canonical().is_ipv4();
             accept
@@ -410,7 +410,7 @@ fn ip_addr_method_of_method_return_type() {
 
     let s = src!(
         "
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let p = 10.10.10.10;
             let x = p.is_ipv4().to_canonical();
             accept
@@ -425,7 +425,7 @@ fn ip_addr_method_of_method_return_type() {
 fn prefix_method() {
     let s = src!(
         "
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let p = 10.10.10.10/20;
             let add = p.address();
             accept
@@ -444,7 +444,7 @@ fn logical_expr() {
             (10 == 10) || (10 == 11)
         }
 
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let p = 10.10.10.10/10;
             accept
         }
@@ -458,7 +458,7 @@ fn logical_expr() {
             (10 == 10) || ("hello" == 11)
         }
 
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             let p = 10.10.10.10/10;
             accept
         }
@@ -482,7 +482,7 @@ fn send_output_stream() {
             });
         }
         
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             accept
         }
     "#
@@ -503,7 +503,7 @@ fn send_output_stream() {
             });
         }
 
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             accept
         }
     "#
@@ -529,7 +529,7 @@ fn send_output_stream() {
             });
         }
 
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             accept
         }
     "#
@@ -555,7 +555,7 @@ fn send_output_stream() {
             });
         }
         
-        filter-map test(r: u32) {
+        filter-map my_map(r: u32) {
             accept
         }
     "#

From 69f9b16a0ebc22237b009740e7798f5156c29b25 Mon Sep 17 00:00:00 2001
From: Terts Diepraam <terts.diepraam@gmail.com>
Date: Fri, 17 Jan 2025 19:11:29 +0100
Subject: [PATCH 2/3] merge some logic between filter-maps, functions and tests

---
 examples/simple.rs            |   7 ++
 src/ast.rs                    |   2 +-
 src/codegen/check.rs          |  12 --
 src/codegen/mod.rs            |  16 +--
 src/codegen/tests.rs          |  63 +++++++++++
 src/lower/mod.rs              | 204 ++++++++++------------------------
 src/parser/filter_map.rs      |   4 +-
 src/typechecker/filter_map.rs |   4 +-
 src/typechecker/info.rs       |  14 +++
 9 files changed, 152 insertions(+), 174 deletions(-)

diff --git a/examples/simple.rs b/examples/simple.rs
index 1d099065..71d42784 100644
--- a/examples/simple.rs
+++ b/examples/simple.rs
@@ -31,6 +31,13 @@ fn main() -> Result<(), roto::RotoReport> {
     let res = func.call(&mut (), "1.1.1.1".parse().unwrap());
     println!("main(1.1.1.1) = {res:?}");
 
+    let is_zero = compiled
+        .get_function::<(), (IpAddr,), bool>("is_zero")
+        .unwrap();
+
+    let res = is_zero.call(&mut (), "0.0.0.0".parse().unwrap());
+    println!("is_zero(0.0.0.0) = {res:?}");
+
     println!();
     let _ = compiled.run_tests(());
 
diff --git a/src/ast.rs b/src/ast.rs
index 08b7bc4a..967837ef 100644
--- a/src/ast.rs
+++ b/src/ast.rs
@@ -47,7 +47,7 @@ pub struct FilterMap {
     pub filter_type: FilterType,
     pub ident: Meta<Identifier>,
     pub params: Meta<Params>,
-    pub block: Meta<Block>,
+    pub body: Meta<Block>,
 }
 
 /// A function declaration, including the [`Block`] forming its definition
diff --git a/src/codegen/check.rs b/src/codegen/check.rs
index 75599321..78e2eace 100644
--- a/src/codegen/check.rs
+++ b/src/codegen/check.rs
@@ -158,18 +158,6 @@ fn check_roto_type(
     }
 }
 
-pub fn return_type_by_ref(registry: &TypeRegistry, rust_ty: TypeId) -> bool {
-    let Some(rust_ty) = registry.get(rust_ty) else {
-        return false;
-    };
-
-    #[allow(clippy::match_like_matches_macro)]
-    match rust_ty.description {
-        TypeDescription::Verdict(_, _) => true,
-        _ => todo!(),
-    }
-}
-
 /// Parameters of a Roto function
 ///
 /// This trait allows for checking the types against Roto types and converting
diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs
index cadd4334..f8f7d220 100644
--- a/src/codegen/mod.rs
+++ b/src/codegen/mod.rs
@@ -16,17 +16,12 @@ use crate::{
         value::IrType,
         IrFunction,
     },
-    runtime::{
-        context::ContextDescription,
-        ty::{Reflect, GLOBAL_TYPE_REGISTRY},
-        RuntimeConstant,
-    },
+    runtime::{context::ContextDescription, ty::Reflect, RuntimeConstant},
     typechecker::{info::TypeInfo, scope::ScopeRef, types},
     IrValue, Verdict,
 };
 use check::{
-    check_roto_type_reflect, return_type_by_ref, FunctionRetrievalError,
-    RotoParams, TypeMismatch,
+    check_roto_type_reflect, FunctionRetrievalError, RotoParams, TypeMismatch,
 };
 use cranelift::{
     codegen::{
@@ -1054,9 +1049,9 @@ impl Module {
         );
 
         if failures == 0 {
-            Result::Err(())
-        } else {
             Result::Ok(())
+        } else {
+            Result::Err(())
         }
     }
 
@@ -1097,9 +1092,8 @@ impl Module {
             )
         })?;
 
-        let registry = GLOBAL_TYPE_REGISTRY.lock().unwrap();
         let return_by_ref =
-            return_type_by_ref(&registry, TypeId::of::<Return>());
+            self.type_info.is_reference_type(&sig.return_type);
 
         let func_ptr = self.inner.0.get_finalized_function(id);
         Ok(TypedFunc {
diff --git a/src/codegen/tests.rs b/src/codegen/tests.rs
index 4b75c0c5..cfe820f5 100644
--- a/src/codegen/tests.rs
+++ b/src/codegen/tests.rs
@@ -901,3 +901,66 @@ fn use_context() {
     let output = f.call(&mut ctx);
     assert_eq!(output, Verdict::Accept(11));
 }
+
+#[test]
+fn use_a_roto_function() {
+    let s = src!(
+        "
+        function double(x: i32) -> i32 {
+            2 * x
+        }"
+    );
+
+    let mut p = compile(s);
+    let f = p.get_function::<(), (i32,), i32>("double").unwrap();
+    let output = f.call(&mut (), 2);
+    assert_eq!(output, 4);
+
+    let output = f.call(&mut (), 16);
+    assert_eq!(output, 32);
+}
+
+#[test]
+fn use_a_test() {
+    let s = src!(
+        "
+        function double(x: i32) -> i32 {
+            x # oops! not correct
+        }
+        
+        test check_double {
+            if double(4) != 8 {
+                reject;
+            }
+            if double(16) != 32 {
+                reject;
+            }
+            accept
+        }
+        "
+    );
+
+    let mut p = compile(s);
+    p.run_tests(()).unwrap_err();
+
+    let s = src!(
+        "
+        function double(x: i32) -> i32 {
+            2 * x
+        }
+        
+        test check_double {
+            if double(4) != 8 {
+                reject;
+            }
+            if double(16) != 32 {
+                reject;
+            }
+            accept
+        }
+        "
+    );
+
+    let mut p = compile(s);
+    p.run_tests(()).unwrap();
+}
diff --git a/src/lower/mod.rs b/src/lower/mod.rs
index f1270227..d8a23983 100644
--- a/src/lower/mod.rs
+++ b/src/lower/mod.rs
@@ -162,36 +162,31 @@ impl<'r> Lowerer<'r> {
                         .filter_map(x),
                     );
                 }
-                ast::Declaration::Function(ast::FunctionDeclaration {
-                    ident,
-                    params,
-                    body,
-                    ret,
-                }) => {
+                ast::Declaration::Function(x) => {
                     functions.push(
                         Lowerer::new(
                             type_info,
                             runtime_functions,
-                            ident,
+                            &x.ident,
                             label_store,
                             runtime,
                         )
-                        .function(ident, params, ret, body),
+                        .function(x),
                     );
                 }
-                ast::Declaration::Test(ast::Test { ident, body }) => {
+                ast::Declaration::Test(x) => {
                     functions.push(
                         Lowerer::new(
                             type_info,
                             runtime_functions,
                             &Meta {
-                                node: format!("test#{}", ident).into(),
-                                id: ident.id,
+                                node: format!("test#{}", x.ident).into(),
+                                id: x.ident.id,
                             },
                             label_store,
                             runtime,
                         )
-                        .test(ident, body),
+                        .test(x),
                     );
                 }
                 // Ignore the rest
@@ -203,95 +198,50 @@ impl<'r> Lowerer<'r> {
     }
 
     /// Lower a filter-map
-    fn filter_map(mut self, fm: &ast::FilterMap) -> Function {
+    fn filter_map(self, fm: &ast::FilterMap) -> Function {
         let ast::FilterMap {
             ident,
-            block,
+            body,
             params,
             ..
         } = fm;
 
-        let label = self.label_store.new_label(self.function_name);
-        self.new_block(label);
-
-        let parameter_types: Vec<_> = params
-            .0
-            .iter()
-            .map(|(x, _)| {
-                let ty = self.type_info.type_of(x);
-                (self.type_info.resolved_name(x), ty)
-            })
-            .collect();
-
-        for (def, ty) in &parameter_types {
-            let (scope, name) = def.to_scope_and_name();
-            let var = Var {
-                scope,
-                kind: VarKind::Explicit(name),
-            };
-            self.stack_slots.push((var, ty.clone()))
-        }
-
-        let return_type = self.type_info.type_of(block);
-        let last = self.block(block);
-
-        self.add(Instruction::Return(last));
-
-        let signature = Signature {
-            kind: FunctionKind::Free,
-            parameter_types: parameter_types
-                .iter()
-                .cloned()
-                .map(|x| x.1)
-                .collect(),
-            return_type: return_type.clone(),
-        };
+        let return_type = self.type_info.type_of(body);
+        self.function_like(ident, params, &return_type, body)
+    }
 
-        let (return_type, return_ptr) = match return_type {
-            x if self.type_info.size_of(&x, self.runtime) == 0 => {
-                (None, false)
-            }
-            x if self.is_reference_type(&x) => (None, true),
-            x => (Some(self.lower_type(&x)), false),
-        };
+    fn function(self, function: &ast::FunctionDeclaration) -> Function {
+        let return_type = function
+            .ret
+            .as_ref()
+            .map(|t| self.type_info.resolve(&Type::Name(**t)))
+            .unwrap_or(Type::Primitive(Primitive::Unit));
+
+        self.function_like(
+            &function.ident,
+            &function.params,
+            &return_type,
+            &function.body,
+        )
+    }
 
-        let ir_signature = ir::Signature {
-            parameters: parameter_types
-                .iter()
-                .map(|(def, ty)| {
-                    (def.to_scope_and_name().1, self.lower_type(ty))
-                })
-                .collect(),
-            return_ptr,
-            return_type,
+    fn test(self, test: &ast::Test) -> Function {
+        let ident = Meta {
+            node: format!("test#{}", *test.ident).into(),
+            id: test.ident.id,
         };
-
-        Function {
-            name: ident.node,
-            scope: self.function_scope,
-            entry_block: label,
-            signature,
-            ir_signature,
-            blocks: self.blocks,
-            public: true,
-        }
+        let unit = Box::new(Type::Primitive(Primitive::Unit));
+        let return_type = Type::Verdict(unit.clone(), unit);
+        let params = ast::Params(Vec::new());
+        self.function_like(&ident, &params, &return_type, &test.body)
     }
 
-    /// Lower a function
-    ///
-    /// We compile functions differently than turing complete languages,
-    /// because we don't have a full stack. Instead, the arguments just
-    /// have fixed locations. The interpreter does have to keep track
-    /// where the function is called and hence where it has to return to
-    ///
-    /// Ultimately, it should be possible to inline all functions, so
-    /// that we don't even need a return instruction. However, this could
-    /// also be done by an optimizing step.
-    fn function(
+    /// Lower a function-like construct (i.e. a function, filter-map or test)
+    fn function_like(
         mut self,
         ident: &Meta<Identifier>,
-        params: &Meta<ast::Params>,
-        return_type: &Option<Meta<Identifier>>,
+        params: &ast::Params,
+        return_type: &Type,
         body: &ast::Block,
     ) -> Function {
         let label = self.label_store.new_label(self.function_name);
@@ -304,9 +254,14 @@ impl<'r> Lowerer<'r> {
             parameter_types.push((self.type_info.resolved_name(x), ty));
         }
 
-        let return_type = return_type
-            .as_ref()
-            .map(|t| self.type_info.resolve(&Type::Name(**t)));
+        for (def, ty) in &parameter_types {
+            let (scope, name) = def.to_scope_and_name();
+            let var = Var {
+                scope,
+                kind: VarKind::Explicit(name),
+            };
+            self.stack_slots.push((var, ty.clone()))
+        }
 
         let signature = Signature {
             kind: FunctionKind::Free,
@@ -315,12 +270,16 @@ impl<'r> Lowerer<'r> {
                 .cloned()
                 .map(|x| x.1)
                 .collect(),
-            return_type: return_type
-                .clone()
-                .unwrap_or(Type::Primitive(Primitive::Unit)),
+            return_type: return_type.clone(),
         };
 
-        let return_type = return_type.map(|t| self.lower_type(&t));
+        let (return_type, return_ptr) = match return_type {
+            x if self.type_info.size_of(x, self.runtime) == 0 => {
+                (None, false)
+            }
+            x if self.is_reference_type(x) => (None, true),
+            x => (Some(self.lower_type(x)), false),
+        };
 
         let ir_signature = ir::Signature {
             parameters: parameter_types
@@ -329,7 +288,7 @@ impl<'r> Lowerer<'r> {
                     (def.to_scope_and_name().1, self.lower_type(ty))
                 })
                 .collect(),
-            return_ptr: false, // TODO: check this
+            return_ptr,
             return_type,
         };
 
@@ -342,45 +301,6 @@ impl<'r> Lowerer<'r> {
             scope: self.function_scope,
             entry_block: label,
             blocks: self.blocks,
-            public: false,
-            signature,
-            ir_signature,
-        }
-    }
-
-    fn test(
-        mut self,
-        ident: &Meta<Identifier>,
-        body: &ast::Block,
-    ) -> Function {
-        let label = self.label_store.new_label(self.function_name);
-        self.new_block(label);
-
-        let unit = Box::new(Type::Primitive(Primitive::Unit));
-
-        let return_type = Type::Verdict(unit.clone(), unit);
-
-        let signature = Signature {
-            kind: FunctionKind::Free,
-            parameter_types: Vec::new(),
-            return_type: return_type.clone(),
-        };
-
-        let ir_signature = ir::Signature {
-            parameters: Vec::new(),
-            return_ptr: true, // TODO: check this
-            return_type: None,
-        };
-
-        let last = self.block(body);
-
-        self.add(Instruction::Return(last));
-
-        Function {
-            name: format!("test#{}", ident.node).into(),
-            scope: self.function_scope,
-            entry_block: label,
-            blocks: self.blocks,
             public: true,
             signature,
             ir_signature,
@@ -1224,19 +1144,11 @@ impl<'r> Lowerer<'r> {
     }
 
     fn is_reference_type(&mut self, ty: &Type) -> bool {
-        let ty = self.type_info.resolve(ty);
-        matches!(
-            ty,
-            Type::Record(..)
-                | Type::RecordVar(..)
-                | Type::NamedRecord(..)
-                | Type::Enum(..)
-                | Type::Verdict(..)
-                | Type::Primitive(Primitive::IpAddr | Primitive::Prefix)
-                | Type::BuiltIn(..)
-        )
+        self.type_info.is_reference_type(ty)
     }
 
+    // TODO: This should return Option<IrType> so that zero-sized
+    // types can be put in here.
     fn lower_type(&mut self, ty: &Type) -> IrType {
         let ty = self.type_info.resolve(ty);
         match ty {
diff --git a/src/parser/filter_map.rs b/src/parser/filter_map.rs
index 935003e3..d519ca66 100644
--- a/src/parser/filter_map.rs
+++ b/src/parser/filter_map.rs
@@ -26,13 +26,13 @@ impl Parser<'_, '_> {
 
         let ident = self.identifier()?;
         let params = self.params()?;
-        let block = self.block()?;
+        let body = self.block()?;
 
         Ok(FilterMap {
             filter_type,
             ident,
             params,
-            block,
+            body,
         })
     }
 
diff --git a/src/typechecker/filter_map.rs b/src/typechecker/filter_map.rs
index 37ab5556..f39922fb 100644
--- a/src/typechecker/filter_map.rs
+++ b/src/typechecker/filter_map.rs
@@ -20,7 +20,7 @@ impl TypeChecker<'_> {
             filter_type,
             ident,
             params,
-            block,
+            body,
         } = filter_map;
 
         let scope = self
@@ -44,7 +44,7 @@ impl TypeChecker<'_> {
             function_return_type: Some(ty.clone()),
         };
 
-        self.block(scope, &ctx, block)?;
+        self.block(scope, &ctx, body)?;
 
         if let Type::Var(x) = self.resolve_type(&a) {
             self.unify(
diff --git a/src/typechecker/info.rs b/src/typechecker/info.rs
index 3e26836f..7481e0e9 100644
--- a/src/typechecker/info.rs
+++ b/src/typechecker/info.rs
@@ -96,6 +96,20 @@ impl TypeInfo {
         self.function_scopes[&x.into()]
     }
 
+    pub fn is_reference_type(&mut self, ty: &Type) -> bool {
+        let ty = self.resolve(ty);
+        matches!(
+            ty,
+            Type::Record(..)
+                | Type::RecordVar(..)
+                | Type::NamedRecord(..)
+                | Type::Enum(..)
+                | Type::Verdict(..)
+                | Type::Primitive(Primitive::IpAddr | Primitive::Prefix)
+                | Type::BuiltIn(..)
+        )
+    }
+
     pub fn offset_of(
         &mut self,
         record: &Type,

From d3c1daff81522db4a14423ea333de66cf0ad2a94 Mon Sep 17 00:00:00 2001
From: Terts Diepraam <terts.diepraam@gmail.com>
Date: Mon, 20 Jan 2025 13:32:39 +0100
Subject: [PATCH 3/3] change lower_type function to return option

---
 src/codegen/mod.rs      |  8 ++---
 src/lower/match_expr.rs | 33 ++++++++++---------
 src/lower/mod.rs        | 72 +++++++++++++++++------------------------
 src/pipeline.rs         |  1 +
 src/typechecker/info.rs |  5 ++-
 5 files changed, 56 insertions(+), 63 deletions(-)

diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs
index f8f7d220..4bf38d72 100644
--- a/src/codegen/mod.rs
+++ b/src/codegen/mod.rs
@@ -172,6 +172,7 @@ call_impl!(A, B, C, D, E, F, G);
 pub struct FunctionInfo {
     id: FuncId,
     signature: types::Signature,
+    return_by_ref: bool,
 }
 
 struct ModuleBuilder {
@@ -354,6 +355,7 @@ impl ModuleBuilder {
 
         let mut sig = self.inner.make_signature();
 
+        // This is the parameter for the context
         sig.params
             .push(AbiParam::new(self.cranelift_type(&IrType::Pointer)));
 
@@ -383,6 +385,7 @@ impl ModuleBuilder {
             name.to_string(),
             FunctionInfo {
                 id: func_id,
+                return_by_ref: ir_signature.return_ptr,
                 signature: signature.clone(),
             },
         );
@@ -1092,13 +1095,10 @@ impl Module {
             )
         })?;
 
-        let return_by_ref =
-            self.type_info.is_reference_type(&sig.return_type);
-
         let func_ptr = self.inner.0.get_finalized_function(id);
         Ok(TypedFunc {
             func: func_ptr,
-            return_by_ref,
+            return_by_ref: function_info.return_by_ref,
             _module: self.inner.clone(),
             _ty: PhantomData,
         })
diff --git a/src/lower/match_expr.rs b/src/lower/match_expr.rs
index ffd34bfe..2d19d1b5 100644
--- a/src/lower/match_expr.rs
+++ b/src/lower/match_expr.rs
@@ -213,12 +213,13 @@ impl Lowerer<'_> {
             let ty = self.type_info.type_of(&arm.body);
             let val = self.block(&arm.body);
             if let Some(val) = val {
-                let ty = self.lower_type(&ty);
-                self.add(Instruction::Assign {
-                    to: out.clone(),
-                    val,
-                    ty,
-                });
+                if let Some(ty) = self.lower_type(&ty) {
+                    self.add(Instruction::Assign {
+                        to: out.clone(),
+                        val,
+                        ty,
+                    });
+                }
                 any_assigned = true;
             }
             self.add(Instruction::Jump(continue_lbl));
@@ -267,15 +268,17 @@ impl Lowerer<'_> {
                     1 + self.type_info.padding_of(&ty, 1, self.runtime);
                 let val =
                     self.read_field(examinee.clone().into(), offset, &ty);
-                let ty = self.lower_type(&ty);
-                self.add(Instruction::Assign {
-                    to: Var {
-                        scope,
-                        kind: VarKind::Explicit(ident),
-                    },
-                    val,
-                    ty,
-                });
+                if let Some(val) = val {
+                    let ty = self.lower_type(&ty).unwrap();
+                    self.add(Instruction::Assign {
+                        to: Var {
+                            scope,
+                            kind: VarKind::Explicit(ident),
+                        },
+                        val,
+                        ty,
+                    });
+                }
             }
 
             let ident = Identifier::from(format!("guard_{}", i + 1));
diff --git a/src/lower/mod.rs b/src/lower/mod.rs
index d8a23983..0fbabefe 100644
--- a/src/lower/mod.rs
+++ b/src/lower/mod.rs
@@ -274,18 +274,16 @@ impl<'r> Lowerer<'r> {
         };
 
         let (return_type, return_ptr) = match return_type {
-            x if self.type_info.size_of(x, self.runtime) == 0 => {
-                (None, false)
-            }
             x if self.is_reference_type(x) => (None, true),
-            x => (Some(self.lower_type(x)), false),
+            x => (self.lower_type(x), false),
         };
 
         let ir_signature = ir::Signature {
             parameters: parameter_types
                 .iter()
-                .map(|(def, ty)| {
-                    (def.to_scope_and_name().1, self.lower_type(ty))
+                .filter_map(|(def, ty)| {
+                    let ty = self.lower_type(ty)?;
+                    Some((def.to_scope_and_name().1, ty))
                 })
                 .collect(),
             return_ptr,
@@ -327,9 +325,8 @@ impl<'r> Lowerer<'r> {
                 let def = self.type_info.resolved_name(ident);
                 let (scope, name) = def.to_scope_and_name();
                 let ty = self.type_info.type_of(ident);
-                if self.type_info.size_of(&ty, self.runtime) > 0 {
+                if let Some(ty) = self.lower_type(&ty) {
                     let val = val.unwrap();
-                    let ty = self.lower_type(&ty);
                     self.add(Instruction::Assign {
                         to: Var {
                             scope,
@@ -477,15 +474,9 @@ impl<'r> Lowerer<'r> {
                             }
                         }
 
-                        let to = if self.type_info.size_of(&ret, self.runtime)
-                            > 0
-                        {
-                            let ty = self.type_info.type_of(id);
-                            let ty = self.lower_type(&ty);
-                            Some((self.new_tmp(), ty))
-                        } else {
-                            None
-                        };
+                        let to = self
+                            .lower_type(&ret)
+                            .map(|ty| (self.new_tmp(), ty));
 
                         let ctx = Var {
                             scope: self.function_scope,
@@ -633,11 +624,7 @@ impl<'r> Lowerer<'r> {
                         **field,
                         self.runtime,
                     );
-                    if self.type_info.size_of(&ty, self.runtime) > 0 {
-                        Some(self.read_field(op, offset, &ty))
-                    } else {
-                        None
-                    }
+                    self.read_field(op, offset, &ty)
                 }
             }
             ast::Expr::Var(x) => {
@@ -650,12 +637,12 @@ impl<'r> Lowerer<'r> {
                             scope: self.function_scope,
                             kind: VarKind::Context,
                         };
-                        Some(self.read_field(var.into(), offset as u32, &ty))
+                        self.read_field(var.into(), offset as u32, &ty)
                     }
                     DefinitionRef::Constant(ident) => {
                         let var = self.new_tmp();
                         let ty = self.type_info.type_of(id);
-                        let ty = self.lower_type(&ty);
+                        let ty = self.lower_type(&ty)?;
                         self.add(Instruction::LoadConstant {
                             to: var.clone(),
                             name: ident,
@@ -933,7 +920,7 @@ impl<'r> Lowerer<'r> {
                         right,
                     }),
                     (ast::BinOp::Div, _, ty) => {
-                        let ty = self.lower_type(&ty);
+                        let ty = self.lower_type(&ty)?;
                         self.add(Instruction::Div {
                             to: place.clone(),
                             ty,
@@ -979,7 +966,7 @@ impl<'r> Lowerer<'r> {
                 self.new_block(lbl_then);
                 if let Some(op) = self.block(if_true) {
                     let ty = self.type_info.type_of(if_true);
-                    let ty = self.lower_type(&ty);
+                    let ty = self.lower_type(&ty)?;
                     self.add(Instruction::Assign {
                         to: res.clone(),
                         val: op,
@@ -996,7 +983,7 @@ impl<'r> Lowerer<'r> {
                     self.new_block(lbl_else);
                     if let Some(op) = self.block(if_false) {
                         let ty = self.type_info.type_of(if_false);
-                        let ty = self.lower_type(&ty);
+                        let ty = self.lower_type(&ty)?;
                         self.add(Instruction::Assign {
                             to: res.clone(),
                             val: op,
@@ -1113,7 +1100,7 @@ impl<'r> Lowerer<'r> {
         from: Operand,
         offset: u32,
         ty: &Type,
-    ) -> Operand {
+    ) -> Option<Operand> {
         let ty = self.type_info.resolve(ty);
         let to = self.new_tmp();
 
@@ -1126,13 +1113,14 @@ impl<'r> Lowerer<'r> {
         } else {
             let tmp = self.new_tmp();
 
+            let ty = self.lower_type(&ty)?;
+
             self.add(Instruction::Offset {
                 to: tmp.clone(),
                 from,
                 offset,
             });
 
-            let ty = self.lower_type(&ty);
             self.add(Instruction::Read {
                 to: to.clone(),
                 from: tmp.into(),
@@ -1140,18 +1128,19 @@ impl<'r> Lowerer<'r> {
             });
         }
 
-        to.into()
+        Some(to.into())
     }
 
     fn is_reference_type(&mut self, ty: &Type) -> bool {
-        self.type_info.is_reference_type(ty)
+        self.type_info.is_reference_type(ty, self.runtime)
     }
 
-    // TODO: This should return Option<IrType> so that zero-sized
-    // types can be put in here.
-    fn lower_type(&mut self, ty: &Type) -> IrType {
+    fn lower_type(&mut self, ty: &Type) -> Option<IrType> {
         let ty = self.type_info.resolve(ty);
-        match ty {
+        if self.type_info.size_of(&ty, self.runtime) == 0 {
+            return None;
+        }
+        Some(match ty {
             Type::Primitive(Primitive::Bool) => IrType::Bool,
             Type::Primitive(Primitive::U8) => IrType::U8,
             Type::Primitive(Primitive::U16) => IrType::U16,
@@ -1162,12 +1151,11 @@ impl<'r> Lowerer<'r> {
             Type::Primitive(Primitive::I32) => IrType::I32,
             Type::Primitive(Primitive::I64) => IrType::I64,
             Type::Primitive(Primitive::Asn) => IrType::Asn,
-            Type::Primitive(Primitive::IpAddr) => IrType::Pointer,
             Type::IntVar(_) => IrType::I32,
             Type::BuiltIn(_, _) => IrType::ExtPointer,
             x if self.is_reference_type(&x) => IrType::Pointer,
             _ => panic!("could not lower: {ty:?}"),
-        }
+        })
     }
 
     fn call_runtime_function(
@@ -1189,9 +1177,9 @@ impl<'r> Lowerer<'r> {
         let mut params = Vec::new();
         params.push(IrType::Pointer);
 
-        for ty in parameter_types {
-            params.push(self.lower_type(ty))
-        }
+        params.extend(
+            parameter_types.iter().filter_map(|ty| self.lower_type(ty)),
+        );
 
         let args = std::iter::once(Operand::Place(out_ptr.clone()))
             .chain(args)
@@ -1214,10 +1202,8 @@ impl<'r> Lowerer<'r> {
 
         if self.is_reference_type(return_type) {
             Some(out_ptr.into())
-        } else if size > 0 {
-            Some(self.read_field(out_ptr.into(), 0, return_type))
         } else {
-            None
+            self.read_field(out_ptr.into(), 0, return_type)
         }
     }
 
diff --git a/src/pipeline.rs b/src/pipeline.rs
index e37675be..40dff39b 100644
--- a/src/pipeline.rs
+++ b/src/pipeline.rs
@@ -447,6 +447,7 @@ impl Lowered {
 }
 
 impl Compiled {
+    #[allow(clippy::result_unit_err)]
     pub fn run_tests<Ctx: 'static>(&mut self, ctx: Ctx) -> Result<(), ()> {
         self.module.run_tests(ctx)
     }
diff --git a/src/typechecker/info.rs b/src/typechecker/info.rs
index 7481e0e9..e2d36a78 100644
--- a/src/typechecker/info.rs
+++ b/src/typechecker/info.rs
@@ -96,8 +96,11 @@ impl TypeInfo {
         self.function_scopes[&x.into()]
     }
 
-    pub fn is_reference_type(&mut self, ty: &Type) -> bool {
+    pub fn is_reference_type(&mut self, ty: &Type, rt: &Runtime) -> bool {
         let ty = self.resolve(ty);
+        if self.size_of(&ty, rt) == 0 {
+            return false;
+        }
         matches!(
             ty,
             Type::Record(..)