From 3b4b2f935d26ada3f7f45976873089ed5aa07f6b Mon Sep 17 00:00:00 2001 From: Terts Diepraam <terts.diepraam@gmail.com> Date: Fri, 17 Jan 2025 16:54:20 +0100 Subject: [PATCH 1/3] add some basic testing infrastructure --- examples/simple.roto | 29 +++++++++++++----- examples/simple.rs | 3 ++ src/ast.rs | 7 +++++ src/codegen/mod.rs | 50 +++++++++++++++++++++++++++++-- src/codegen/tests.rs | 1 - src/lower/ir.rs | 2 +- src/lower/mod.rs | 56 +++++++++++++++++++++++++++++++++-- src/parser/mod.rs | 12 +++++++- src/parser/token.rs | 3 ++ src/pipeline.rs | 4 +++ src/typechecker/filter_map.rs | 25 ++++++++++++++++ src/typechecker/mod.rs | 6 ++++ src/typechecker/tests.rs | 38 ++++++++++++------------ 13 files changed, 202 insertions(+), 34 deletions(-) diff --git a/examples/simple.roto b/examples/simple.roto index dfcddcaf..5c44aa99 100644 --- a/examples/simple.roto +++ b/examples/simple.roto @@ -1,12 +1,25 @@ +function is_zero(x: IpAddr) -> bool { + x == 0.0.0.0 +} + filter-map main(x: IpAddr) { - define { - y = 0.0.0.0; + if is_zero(x) { + accept + } else { + reject + } +} + +test is_zero_true { + if is_zero(1.1.1.1) { + reject } - apply { - if x == y { - accept - } else { - reject - } + accept +} + +test is_zero_false { + if not is_zero(0.0.0.0) { + reject } + accept } diff --git a/examples/simple.rs b/examples/simple.rs index 7528973f..1d099065 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -31,5 +31,8 @@ fn main() -> Result<(), roto::RotoReport> { let res = func.call(&mut (), "1.1.1.1".parse().unwrap()); println!("main(1.1.1.1) = {res:?}"); + println!(); + let _ = compiled.run_tests(()); + Ok(()) } diff --git a/src/ast.rs b/src/ast.rs index 00740920..08b7bc4a 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -23,6 +23,7 @@ pub enum Declaration { OutputStream(OutputStream), Record(RecordTypeDeclaration), Function(FunctionDeclaration), + Test(Test), } #[derive(Clone, Debug)] @@ -58,6 +59,12 @@ pub struct FunctionDeclaration { pub body: Meta<Block>, } +#[derive(Clone, Debug)] +pub struct Test { + pub ident: Meta<Identifier>, + pub body: Meta<Block>, +} + /// A block of multiple statements #[derive(Clone, Debug)] pub struct Block { diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index 09577996..cadd4334 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -22,7 +22,7 @@ use crate::{ RuntimeConstant, }, typechecker::{info::TypeInfo, scope::ScopeRef, types}, - IrValue, + IrValue, Verdict, }; use check::{ check_roto_type_reflect, return_type_by_ref, FunctionRetrievalError, @@ -555,7 +555,7 @@ impl<'c> FuncGen<'c> { args.next().unwrap(), ); - if dbg!(return_ptr) { + if return_ptr { self.def( self.module.variable_map[&Var { scope: self.scope, @@ -1014,6 +1014,52 @@ impl<'c> FuncGen<'c> { } impl Module { + pub fn run_tests<Ctx: 'static>( + &mut self, + mut ctx: Ctx, + ) -> Result<(), ()> { + let tests: Vec<_> = self + .functions + .keys() + .filter(|x| x.starts_with("test#")) + .map(Clone::clone) + .collect(); + + let total = tests.len(); + let total_width = total.to_string().len(); + let mut successes = 0; + let mut failures = 0; + + for (n, test) in tests.into_iter().enumerate() { + let n = n + 1; + let test_display = test.strip_prefix("test#").unwrap(); + print!("Test {n:>total_width$} / {total}: {test_display}... "); + let test_fn = self + .get_function::<Ctx, (), Verdict<(), ()>>(&test) + .unwrap(); + + match test_fn.call(&mut ctx) { + Verdict::Accept(()) => { + successes += 1; + println!("\x1B[92mok\x1B[m"); + } + Verdict::Reject(()) => { + failures += 1; + println!("\x1B[91mfail\x1B[m"); + } + } + } + println!( + "Ran {total} tests, {successes} succeeded, {failures} failed" + ); + + if failures == 0 { + Result::Err(()) + } else { + Result::Ok(()) + } + } + pub fn get_function<Ctx: 'static, Params: RotoParams, Return: Reflect>( &mut self, name: &str, diff --git a/src/codegen/tests.rs b/src/codegen/tests.rs index 801c6d51..4b75c0c5 100644 --- a/src/codegen/tests.rs +++ b/src/codegen/tests.rs @@ -53,7 +53,6 @@ fn accept() { .expect("No function found (or mismatched types)"); let res = f.call(&mut ()); - dbg!(std::mem::size_of::<Verdict<(), ()>>()); assert_eq!(res, Verdict::Accept(())); } diff --git a/src/lower/ir.rs b/src/lower/ir.rs index 04f6921c..802d203d 100644 --- a/src/lower/ir.rs +++ b/src/lower/ir.rs @@ -360,8 +360,8 @@ impl<'a> IrPrinter<'a> { } => format!( "{}: {ty} = {}({}, {})", self.var(to), - self.operand(ctx), self.ident(func), + self.operand(ctx), args.iter() .map(|a| format!( "{} = {}", diff --git a/src/lower/mod.rs b/src/lower/mod.rs index 2237e4dc..f1270227 100644 --- a/src/lower/mod.rs +++ b/src/lower/mod.rs @@ -179,6 +179,21 @@ impl<'r> Lowerer<'r> { .function(ident, params, ret, body), ); } + ast::Declaration::Test(ast::Test { ident, body }) => { + functions.push( + Lowerer::new( + type_info, + runtime_functions, + &Meta { + node: format!("test#{}", ident).into(), + id: ident.id, + }, + label_store, + runtime, + ) + .test(ident, body), + ); + } // Ignore the rest _ => {} } @@ -240,8 +255,6 @@ impl<'r> Lowerer<'r> { x => (Some(self.lower_type(&x)), false), }; - dbg!(return_ptr); - let ir_signature = ir::Signature { parameters: parameter_types .iter() @@ -335,6 +348,45 @@ impl<'r> Lowerer<'r> { } } + fn test( + mut self, + ident: &Meta<Identifier>, + body: &ast::Block, + ) -> Function { + let label = self.label_store.new_label(self.function_name); + self.new_block(label); + + let unit = Box::new(Type::Primitive(Primitive::Unit)); + + let return_type = Type::Verdict(unit.clone(), unit); + + let signature = Signature { + kind: FunctionKind::Free, + parameter_types: Vec::new(), + return_type: return_type.clone(), + }; + + let ir_signature = ir::Signature { + parameters: Vec::new(), + return_ptr: true, // TODO: check this + return_type: None, + }; + + let last = self.block(body); + + self.add(Instruction::Return(last)); + + Function { + name: format!("test#{}", ident.node).into(), + scope: self.function_scope, + entry_block: label, + blocks: self.blocks, + public: true, + signature, + ir_signature, + } + } + /// Lower a block /// /// Returns either the value of the expression or the place where the diff --git a/src/parser/mod.rs b/src/parser/mod.rs index a5c9837f..a952ae5a 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1,4 +1,6 @@ -use crate::ast::{Declaration, FunctionDeclaration, Identifier, SyntaxTree}; +use crate::ast::{ + Declaration, FunctionDeclaration, Identifier, SyntaxTree, Test, +}; use logos::{Lexer, SpannedIter}; use std::{fmt::Display, iter::Peekable}; use token::Token; @@ -336,6 +338,7 @@ impl<'source, 'spans> Parser<'source, 'spans> { Declaration::Record(self.record_type_assignment()?) } Token::Function => Declaration::Function(self.function()?), + Token::Test => Declaration::Test(self.test()?), _ => { let (token, span) = self.next()?; return Err(ParseError::expected( @@ -372,6 +375,13 @@ impl<'source, 'spans> Parser<'source, 'spans> { ret, }) } + + fn test(&mut self) -> ParseResult<Test> { + self.take(Token::Test)?; + let ident = self.identifier()?; + let body = self.block()?; + Ok(Test { ident, body }) + } } /// # Parsing identifiers diff --git a/src/parser/token.rs b/src/parser/token.rs index be799a5b..04bf74b9 100644 --- a/src/parser/token.rs +++ b/src/parser/token.rs @@ -115,6 +115,8 @@ pub enum Token<'s> { Some, #[token("table")] Table, + #[token("test")] + Test, #[token("through")] Through, #[token("type")] @@ -211,6 +213,7 @@ impl Display for Token<'_> { Token::Rib => "rib", Token::Some => "some", Token::Table => "table", + Token::Test => "test", Token::Through => "through", Token::Type => "type", Token::UpTo => "up-to", diff --git a/src/pipeline.rs b/src/pipeline.rs index ca3851e8..e37675be 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -447,6 +447,10 @@ impl Lowered { } impl Compiled { + pub fn run_tests<Ctx: 'static>(&mut self, ctx: Ctx) -> Result<(), ()> { + self.module.run_tests(ctx) + } + pub fn get_function<Ctx: 'static, Params: RotoParams, Return: Reflect>( &mut self, name: &str, diff --git a/src/typechecker/filter_map.rs b/src/typechecker/filter_map.rs index dbcb136d..37ab5556 100644 --- a/src/typechecker/filter_map.rs +++ b/src/typechecker/filter_map.rs @@ -114,6 +114,31 @@ impl TypeChecker<'_> { Ok(()) } + pub fn test( + &mut self, + scope: LocalScopeRef, + test: &ast::Test, + ) -> TypeResult<()> { + let ast::Test { ident, body } = test; + + let scope = self + .scope_graph + .wrap(scope, ScopeType::Function(ident.node)); + + self.type_info + .function_scopes + .insert(ident.id, scope.into()); + + let unit = Box::new(Type::Primitive(Primitive::Unit)); + let ret = Type::Verdict(unit.clone(), unit); + let ctx = Context { + expected_type: ret.clone(), + function_return_type: Some(ret), + }; + self.block(scope, &ctx, body)?; + Ok(()) + } + pub fn function_type( &mut self, dec: &ast::FunctionDeclaration, diff --git a/src/typechecker/mod.rs b/src/typechecker/mod.rs index 10ca4bc9..61a2fd6f 100644 --- a/src/typechecker/mod.rs +++ b/src/typechecker/mod.rs @@ -304,6 +304,12 @@ impl TypeChecker<'_> { } } + for expr in &tree.declarations { + if let ast::Declaration::Test(t) = expr { + checker.test(module_scope, t)? + } + } + Ok(checker.type_info) } diff --git a/src/typechecker/tests.rs b/src/typechecker/tests.rs index d5a980e6..93bd5388 100644 --- a/src/typechecker/tests.rs +++ b/src/typechecker/tests.rs @@ -257,7 +257,7 @@ fn integer_inference() { " type Foo { x: u8 } - filter-map test(r: u32) { + filter-map my_map(r: u32) { let foo = Foo { x: 5 }; accept } @@ -270,7 +270,7 @@ fn integer_inference() { type Foo { x: u8 } type Bar { x: u8 } - filter-map test(r: u32) { + filter-map my_map(r: u32) { let a = 5; let foo = Foo { x: a }; accept @@ -284,7 +284,7 @@ fn integer_inference() { type Foo { x: u8 } type Bar { x: u8 } - filter-map test(r: u32) { + filter-map my_map(r: u32) { let a = 5; let foo = Foo { x: a }; let bar = Bar { x: a }; @@ -299,7 +299,7 @@ fn integer_inference() { type Foo { x: u8 } type Bar { x: u32 } - filter-map test(r: u32) { + filter-map my_map(r: u32) { let a = 5; let foo = Foo { x: a }; let bar = Bar { x: a }; @@ -314,7 +314,7 @@ fn integer_inference() { type Foo { x: u8 } type Bar { x: u32 } - filter-map test(r: u32) { + filter-map my_map(r: u32) { let foo = Foo { x: 5 }; let bar = Bar { x: 5 }; accept @@ -331,7 +331,7 @@ fn assign_field_to_other_record() { type Foo { x: u8 } type Bar { x: u8 } - filter-map test(r: u32) { + filter-map my_map(r: u32) { let foo = Foo { x: 5 }; let bar = Bar { x: foo.x }; accept @@ -345,7 +345,7 @@ fn assign_field_to_other_record() { type Foo { x: u8 } type Bar { x: u8 } - filter-map test(r: u32) { + filter-map my_map(r: u32) { let foo = Foo { x: 5 }; let bar = Bar { x: foo.y }; accept @@ -359,7 +359,7 @@ fn assign_field_to_other_record() { type Foo { x: u8 } type Bar { x: u32 } - filter-map test(r: u32) { + filter-map my_map(r: u32) { let foo = Foo { x: 5 }; let bar = Bar { x: foo.x }; accept @@ -373,7 +373,7 @@ fn assign_field_to_other_record() { fn ip_addr_method() { let s = src!( " - filter-map test(r: u32) { + filter-map my_map(r: u32) { let p = 10.10.10.10; let is_four = 1 + p.is_ipv4(); accept @@ -384,7 +384,7 @@ fn ip_addr_method() { let s = src!( " - filter-map test(r: u32) { + filter-map my_map(r: u32) { let p = 10.10.10.10; let is_four = true && p.is_ipv4(); accept @@ -399,7 +399,7 @@ fn ip_addr_method() { fn ip_addr_method_of_method_return_type() { let s = src!( " - filter-map test(r: u32) { + filter-map my_map(r: u32) { let p = 10.10.10.10; let x = p.to_canonical().is_ipv4(); accept @@ -410,7 +410,7 @@ fn ip_addr_method_of_method_return_type() { let s = src!( " - filter-map test(r: u32) { + filter-map my_map(r: u32) { let p = 10.10.10.10; let x = p.is_ipv4().to_canonical(); accept @@ -425,7 +425,7 @@ fn ip_addr_method_of_method_return_type() { fn prefix_method() { let s = src!( " - filter-map test(r: u32) { + filter-map my_map(r: u32) { let p = 10.10.10.10/20; let add = p.address(); accept @@ -444,7 +444,7 @@ fn logical_expr() { (10 == 10) || (10 == 11) } - filter-map test(r: u32) { + filter-map my_map(r: u32) { let p = 10.10.10.10/10; accept } @@ -458,7 +458,7 @@ fn logical_expr() { (10 == 10) || ("hello" == 11) } - filter-map test(r: u32) { + filter-map my_map(r: u32) { let p = 10.10.10.10/10; accept } @@ -482,7 +482,7 @@ fn send_output_stream() { }); } - filter-map test(r: u32) { + filter-map my_map(r: u32) { accept } "# @@ -503,7 +503,7 @@ fn send_output_stream() { }); } - filter-map test(r: u32) { + filter-map my_map(r: u32) { accept } "# @@ -529,7 +529,7 @@ fn send_output_stream() { }); } - filter-map test(r: u32) { + filter-map my_map(r: u32) { accept } "# @@ -555,7 +555,7 @@ fn send_output_stream() { }); } - filter-map test(r: u32) { + filter-map my_map(r: u32) { accept } "# From 69f9b16a0ebc22237b009740e7798f5156c29b25 Mon Sep 17 00:00:00 2001 From: Terts Diepraam <terts.diepraam@gmail.com> Date: Fri, 17 Jan 2025 19:11:29 +0100 Subject: [PATCH 2/3] merge some logic between filter-maps, functions and tests --- examples/simple.rs | 7 ++ src/ast.rs | 2 +- src/codegen/check.rs | 12 -- src/codegen/mod.rs | 16 +-- src/codegen/tests.rs | 63 +++++++++++ src/lower/mod.rs | 204 ++++++++++------------------------ src/parser/filter_map.rs | 4 +- src/typechecker/filter_map.rs | 4 +- src/typechecker/info.rs | 14 +++ 9 files changed, 152 insertions(+), 174 deletions(-) diff --git a/examples/simple.rs b/examples/simple.rs index 1d099065..71d42784 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -31,6 +31,13 @@ fn main() -> Result<(), roto::RotoReport> { let res = func.call(&mut (), "1.1.1.1".parse().unwrap()); println!("main(1.1.1.1) = {res:?}"); + let is_zero = compiled + .get_function::<(), (IpAddr,), bool>("is_zero") + .unwrap(); + + let res = is_zero.call(&mut (), "0.0.0.0".parse().unwrap()); + println!("is_zero(0.0.0.0) = {res:?}"); + println!(); let _ = compiled.run_tests(()); diff --git a/src/ast.rs b/src/ast.rs index 08b7bc4a..967837ef 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -47,7 +47,7 @@ pub struct FilterMap { pub filter_type: FilterType, pub ident: Meta<Identifier>, pub params: Meta<Params>, - pub block: Meta<Block>, + pub body: Meta<Block>, } /// A function declaration, including the [`Block`] forming its definition diff --git a/src/codegen/check.rs b/src/codegen/check.rs index 75599321..78e2eace 100644 --- a/src/codegen/check.rs +++ b/src/codegen/check.rs @@ -158,18 +158,6 @@ fn check_roto_type( } } -pub fn return_type_by_ref(registry: &TypeRegistry, rust_ty: TypeId) -> bool { - let Some(rust_ty) = registry.get(rust_ty) else { - return false; - }; - - #[allow(clippy::match_like_matches_macro)] - match rust_ty.description { - TypeDescription::Verdict(_, _) => true, - _ => todo!(), - } -} - /// Parameters of a Roto function /// /// This trait allows for checking the types against Roto types and converting diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index cadd4334..f8f7d220 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -16,17 +16,12 @@ use crate::{ value::IrType, IrFunction, }, - runtime::{ - context::ContextDescription, - ty::{Reflect, GLOBAL_TYPE_REGISTRY}, - RuntimeConstant, - }, + runtime::{context::ContextDescription, ty::Reflect, RuntimeConstant}, typechecker::{info::TypeInfo, scope::ScopeRef, types}, IrValue, Verdict, }; use check::{ - check_roto_type_reflect, return_type_by_ref, FunctionRetrievalError, - RotoParams, TypeMismatch, + check_roto_type_reflect, FunctionRetrievalError, RotoParams, TypeMismatch, }; use cranelift::{ codegen::{ @@ -1054,9 +1049,9 @@ impl Module { ); if failures == 0 { - Result::Err(()) - } else { Result::Ok(()) + } else { + Result::Err(()) } } @@ -1097,9 +1092,8 @@ impl Module { ) })?; - let registry = GLOBAL_TYPE_REGISTRY.lock().unwrap(); let return_by_ref = - return_type_by_ref(®istry, TypeId::of::<Return>()); + self.type_info.is_reference_type(&sig.return_type); let func_ptr = self.inner.0.get_finalized_function(id); Ok(TypedFunc { diff --git a/src/codegen/tests.rs b/src/codegen/tests.rs index 4b75c0c5..cfe820f5 100644 --- a/src/codegen/tests.rs +++ b/src/codegen/tests.rs @@ -901,3 +901,66 @@ fn use_context() { let output = f.call(&mut ctx); assert_eq!(output, Verdict::Accept(11)); } + +#[test] +fn use_a_roto_function() { + let s = src!( + " + function double(x: i32) -> i32 { + 2 * x + }" + ); + + let mut p = compile(s); + let f = p.get_function::<(), (i32,), i32>("double").unwrap(); + let output = f.call(&mut (), 2); + assert_eq!(output, 4); + + let output = f.call(&mut (), 16); + assert_eq!(output, 32); +} + +#[test] +fn use_a_test() { + let s = src!( + " + function double(x: i32) -> i32 { + x # oops! not correct + } + + test check_double { + if double(4) != 8 { + reject; + } + if double(16) != 32 { + reject; + } + accept + } + " + ); + + let mut p = compile(s); + p.run_tests(()).unwrap_err(); + + let s = src!( + " + function double(x: i32) -> i32 { + 2 * x + } + + test check_double { + if double(4) != 8 { + reject; + } + if double(16) != 32 { + reject; + } + accept + } + " + ); + + let mut p = compile(s); + p.run_tests(()).unwrap(); +} diff --git a/src/lower/mod.rs b/src/lower/mod.rs index f1270227..d8a23983 100644 --- a/src/lower/mod.rs +++ b/src/lower/mod.rs @@ -162,36 +162,31 @@ impl<'r> Lowerer<'r> { .filter_map(x), ); } - ast::Declaration::Function(ast::FunctionDeclaration { - ident, - params, - body, - ret, - }) => { + ast::Declaration::Function(x) => { functions.push( Lowerer::new( type_info, runtime_functions, - ident, + &x.ident, label_store, runtime, ) - .function(ident, params, ret, body), + .function(x), ); } - ast::Declaration::Test(ast::Test { ident, body }) => { + ast::Declaration::Test(x) => { functions.push( Lowerer::new( type_info, runtime_functions, &Meta { - node: format!("test#{}", ident).into(), - id: ident.id, + node: format!("test#{}", x.ident).into(), + id: x.ident.id, }, label_store, runtime, ) - .test(ident, body), + .test(x), ); } // Ignore the rest @@ -203,95 +198,50 @@ impl<'r> Lowerer<'r> { } /// Lower a filter-map - fn filter_map(mut self, fm: &ast::FilterMap) -> Function { + fn filter_map(self, fm: &ast::FilterMap) -> Function { let ast::FilterMap { ident, - block, + body, params, .. } = fm; - let label = self.label_store.new_label(self.function_name); - self.new_block(label); - - let parameter_types: Vec<_> = params - .0 - .iter() - .map(|(x, _)| { - let ty = self.type_info.type_of(x); - (self.type_info.resolved_name(x), ty) - }) - .collect(); - - for (def, ty) in ¶meter_types { - let (scope, name) = def.to_scope_and_name(); - let var = Var { - scope, - kind: VarKind::Explicit(name), - }; - self.stack_slots.push((var, ty.clone())) - } - - let return_type = self.type_info.type_of(block); - let last = self.block(block); - - self.add(Instruction::Return(last)); - - let signature = Signature { - kind: FunctionKind::Free, - parameter_types: parameter_types - .iter() - .cloned() - .map(|x| x.1) - .collect(), - return_type: return_type.clone(), - }; + let return_type = self.type_info.type_of(body); + self.function_like(ident, params, &return_type, body) + } - let (return_type, return_ptr) = match return_type { - x if self.type_info.size_of(&x, self.runtime) == 0 => { - (None, false) - } - x if self.is_reference_type(&x) => (None, true), - x => (Some(self.lower_type(&x)), false), - }; + fn function(self, function: &ast::FunctionDeclaration) -> Function { + let return_type = function + .ret + .as_ref() + .map(|t| self.type_info.resolve(&Type::Name(**t))) + .unwrap_or(Type::Primitive(Primitive::Unit)); + + self.function_like( + &function.ident, + &function.params, + &return_type, + &function.body, + ) + } - let ir_signature = ir::Signature { - parameters: parameter_types - .iter() - .map(|(def, ty)| { - (def.to_scope_and_name().1, self.lower_type(ty)) - }) - .collect(), - return_ptr, - return_type, + fn test(self, test: &ast::Test) -> Function { + let ident = Meta { + node: format!("test#{}", *test.ident).into(), + id: test.ident.id, }; - - Function { - name: ident.node, - scope: self.function_scope, - entry_block: label, - signature, - ir_signature, - blocks: self.blocks, - public: true, - } + let unit = Box::new(Type::Primitive(Primitive::Unit)); + let return_type = Type::Verdict(unit.clone(), unit); + let params = ast::Params(Vec::new()); + self.function_like(&ident, ¶ms, &return_type, &test.body) } - /// Lower a function - /// - /// We compile functions differently than turing complete languages, - /// because we don't have a full stack. Instead, the arguments just - /// have fixed locations. The interpreter does have to keep track - /// where the function is called and hence where it has to return to - /// - /// Ultimately, it should be possible to inline all functions, so - /// that we don't even need a return instruction. However, this could - /// also be done by an optimizing step. - fn function( + /// Lower a function-like construct (i.e. a function, filter-map or test) + fn function_like( mut self, ident: &Meta<Identifier>, - params: &Meta<ast::Params>, - return_type: &Option<Meta<Identifier>>, + params: &ast::Params, + return_type: &Type, body: &ast::Block, ) -> Function { let label = self.label_store.new_label(self.function_name); @@ -304,9 +254,14 @@ impl<'r> Lowerer<'r> { parameter_types.push((self.type_info.resolved_name(x), ty)); } - let return_type = return_type - .as_ref() - .map(|t| self.type_info.resolve(&Type::Name(**t))); + for (def, ty) in ¶meter_types { + let (scope, name) = def.to_scope_and_name(); + let var = Var { + scope, + kind: VarKind::Explicit(name), + }; + self.stack_slots.push((var, ty.clone())) + } let signature = Signature { kind: FunctionKind::Free, @@ -315,12 +270,16 @@ impl<'r> Lowerer<'r> { .cloned() .map(|x| x.1) .collect(), - return_type: return_type - .clone() - .unwrap_or(Type::Primitive(Primitive::Unit)), + return_type: return_type.clone(), }; - let return_type = return_type.map(|t| self.lower_type(&t)); + let (return_type, return_ptr) = match return_type { + x if self.type_info.size_of(x, self.runtime) == 0 => { + (None, false) + } + x if self.is_reference_type(x) => (None, true), + x => (Some(self.lower_type(x)), false), + }; let ir_signature = ir::Signature { parameters: parameter_types @@ -329,7 +288,7 @@ impl<'r> Lowerer<'r> { (def.to_scope_and_name().1, self.lower_type(ty)) }) .collect(), - return_ptr: false, // TODO: check this + return_ptr, return_type, }; @@ -342,45 +301,6 @@ impl<'r> Lowerer<'r> { scope: self.function_scope, entry_block: label, blocks: self.blocks, - public: false, - signature, - ir_signature, - } - } - - fn test( - mut self, - ident: &Meta<Identifier>, - body: &ast::Block, - ) -> Function { - let label = self.label_store.new_label(self.function_name); - self.new_block(label); - - let unit = Box::new(Type::Primitive(Primitive::Unit)); - - let return_type = Type::Verdict(unit.clone(), unit); - - let signature = Signature { - kind: FunctionKind::Free, - parameter_types: Vec::new(), - return_type: return_type.clone(), - }; - - let ir_signature = ir::Signature { - parameters: Vec::new(), - return_ptr: true, // TODO: check this - return_type: None, - }; - - let last = self.block(body); - - self.add(Instruction::Return(last)); - - Function { - name: format!("test#{}", ident.node).into(), - scope: self.function_scope, - entry_block: label, - blocks: self.blocks, public: true, signature, ir_signature, @@ -1224,19 +1144,11 @@ impl<'r> Lowerer<'r> { } fn is_reference_type(&mut self, ty: &Type) -> bool { - let ty = self.type_info.resolve(ty); - matches!( - ty, - Type::Record(..) - | Type::RecordVar(..) - | Type::NamedRecord(..) - | Type::Enum(..) - | Type::Verdict(..) - | Type::Primitive(Primitive::IpAddr | Primitive::Prefix) - | Type::BuiltIn(..) - ) + self.type_info.is_reference_type(ty) } + // TODO: This should return Option<IrType> so that zero-sized + // types can be put in here. fn lower_type(&mut self, ty: &Type) -> IrType { let ty = self.type_info.resolve(ty); match ty { diff --git a/src/parser/filter_map.rs b/src/parser/filter_map.rs index 935003e3..d519ca66 100644 --- a/src/parser/filter_map.rs +++ b/src/parser/filter_map.rs @@ -26,13 +26,13 @@ impl Parser<'_, '_> { let ident = self.identifier()?; let params = self.params()?; - let block = self.block()?; + let body = self.block()?; Ok(FilterMap { filter_type, ident, params, - block, + body, }) } diff --git a/src/typechecker/filter_map.rs b/src/typechecker/filter_map.rs index 37ab5556..f39922fb 100644 --- a/src/typechecker/filter_map.rs +++ b/src/typechecker/filter_map.rs @@ -20,7 +20,7 @@ impl TypeChecker<'_> { filter_type, ident, params, - block, + body, } = filter_map; let scope = self @@ -44,7 +44,7 @@ impl TypeChecker<'_> { function_return_type: Some(ty.clone()), }; - self.block(scope, &ctx, block)?; + self.block(scope, &ctx, body)?; if let Type::Var(x) = self.resolve_type(&a) { self.unify( diff --git a/src/typechecker/info.rs b/src/typechecker/info.rs index 3e26836f..7481e0e9 100644 --- a/src/typechecker/info.rs +++ b/src/typechecker/info.rs @@ -96,6 +96,20 @@ impl TypeInfo { self.function_scopes[&x.into()] } + pub fn is_reference_type(&mut self, ty: &Type) -> bool { + let ty = self.resolve(ty); + matches!( + ty, + Type::Record(..) + | Type::RecordVar(..) + | Type::NamedRecord(..) + | Type::Enum(..) + | Type::Verdict(..) + | Type::Primitive(Primitive::IpAddr | Primitive::Prefix) + | Type::BuiltIn(..) + ) + } + pub fn offset_of( &mut self, record: &Type, From d3c1daff81522db4a14423ea333de66cf0ad2a94 Mon Sep 17 00:00:00 2001 From: Terts Diepraam <terts.diepraam@gmail.com> Date: Mon, 20 Jan 2025 13:32:39 +0100 Subject: [PATCH 3/3] change lower_type function to return option --- src/codegen/mod.rs | 8 ++--- src/lower/match_expr.rs | 33 ++++++++++--------- src/lower/mod.rs | 72 +++++++++++++++++------------------------ src/pipeline.rs | 1 + src/typechecker/info.rs | 5 ++- 5 files changed, 56 insertions(+), 63 deletions(-) diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index f8f7d220..4bf38d72 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -172,6 +172,7 @@ call_impl!(A, B, C, D, E, F, G); pub struct FunctionInfo { id: FuncId, signature: types::Signature, + return_by_ref: bool, } struct ModuleBuilder { @@ -354,6 +355,7 @@ impl ModuleBuilder { let mut sig = self.inner.make_signature(); + // This is the parameter for the context sig.params .push(AbiParam::new(self.cranelift_type(&IrType::Pointer))); @@ -383,6 +385,7 @@ impl ModuleBuilder { name.to_string(), FunctionInfo { id: func_id, + return_by_ref: ir_signature.return_ptr, signature: signature.clone(), }, ); @@ -1092,13 +1095,10 @@ impl Module { ) })?; - let return_by_ref = - self.type_info.is_reference_type(&sig.return_type); - let func_ptr = self.inner.0.get_finalized_function(id); Ok(TypedFunc { func: func_ptr, - return_by_ref, + return_by_ref: function_info.return_by_ref, _module: self.inner.clone(), _ty: PhantomData, }) diff --git a/src/lower/match_expr.rs b/src/lower/match_expr.rs index ffd34bfe..2d19d1b5 100644 --- a/src/lower/match_expr.rs +++ b/src/lower/match_expr.rs @@ -213,12 +213,13 @@ impl Lowerer<'_> { let ty = self.type_info.type_of(&arm.body); let val = self.block(&arm.body); if let Some(val) = val { - let ty = self.lower_type(&ty); - self.add(Instruction::Assign { - to: out.clone(), - val, - ty, - }); + if let Some(ty) = self.lower_type(&ty) { + self.add(Instruction::Assign { + to: out.clone(), + val, + ty, + }); + } any_assigned = true; } self.add(Instruction::Jump(continue_lbl)); @@ -267,15 +268,17 @@ impl Lowerer<'_> { 1 + self.type_info.padding_of(&ty, 1, self.runtime); let val = self.read_field(examinee.clone().into(), offset, &ty); - let ty = self.lower_type(&ty); - self.add(Instruction::Assign { - to: Var { - scope, - kind: VarKind::Explicit(ident), - }, - val, - ty, - }); + if let Some(val) = val { + let ty = self.lower_type(&ty).unwrap(); + self.add(Instruction::Assign { + to: Var { + scope, + kind: VarKind::Explicit(ident), + }, + val, + ty, + }); + } } let ident = Identifier::from(format!("guard_{}", i + 1)); diff --git a/src/lower/mod.rs b/src/lower/mod.rs index d8a23983..0fbabefe 100644 --- a/src/lower/mod.rs +++ b/src/lower/mod.rs @@ -274,18 +274,16 @@ impl<'r> Lowerer<'r> { }; let (return_type, return_ptr) = match return_type { - x if self.type_info.size_of(x, self.runtime) == 0 => { - (None, false) - } x if self.is_reference_type(x) => (None, true), - x => (Some(self.lower_type(x)), false), + x => (self.lower_type(x), false), }; let ir_signature = ir::Signature { parameters: parameter_types .iter() - .map(|(def, ty)| { - (def.to_scope_and_name().1, self.lower_type(ty)) + .filter_map(|(def, ty)| { + let ty = self.lower_type(ty)?; + Some((def.to_scope_and_name().1, ty)) }) .collect(), return_ptr, @@ -327,9 +325,8 @@ impl<'r> Lowerer<'r> { let def = self.type_info.resolved_name(ident); let (scope, name) = def.to_scope_and_name(); let ty = self.type_info.type_of(ident); - if self.type_info.size_of(&ty, self.runtime) > 0 { + if let Some(ty) = self.lower_type(&ty) { let val = val.unwrap(); - let ty = self.lower_type(&ty); self.add(Instruction::Assign { to: Var { scope, @@ -477,15 +474,9 @@ impl<'r> Lowerer<'r> { } } - let to = if self.type_info.size_of(&ret, self.runtime) - > 0 - { - let ty = self.type_info.type_of(id); - let ty = self.lower_type(&ty); - Some((self.new_tmp(), ty)) - } else { - None - }; + let to = self + .lower_type(&ret) + .map(|ty| (self.new_tmp(), ty)); let ctx = Var { scope: self.function_scope, @@ -633,11 +624,7 @@ impl<'r> Lowerer<'r> { **field, self.runtime, ); - if self.type_info.size_of(&ty, self.runtime) > 0 { - Some(self.read_field(op, offset, &ty)) - } else { - None - } + self.read_field(op, offset, &ty) } } ast::Expr::Var(x) => { @@ -650,12 +637,12 @@ impl<'r> Lowerer<'r> { scope: self.function_scope, kind: VarKind::Context, }; - Some(self.read_field(var.into(), offset as u32, &ty)) + self.read_field(var.into(), offset as u32, &ty) } DefinitionRef::Constant(ident) => { let var = self.new_tmp(); let ty = self.type_info.type_of(id); - let ty = self.lower_type(&ty); + let ty = self.lower_type(&ty)?; self.add(Instruction::LoadConstant { to: var.clone(), name: ident, @@ -933,7 +920,7 @@ impl<'r> Lowerer<'r> { right, }), (ast::BinOp::Div, _, ty) => { - let ty = self.lower_type(&ty); + let ty = self.lower_type(&ty)?; self.add(Instruction::Div { to: place.clone(), ty, @@ -979,7 +966,7 @@ impl<'r> Lowerer<'r> { self.new_block(lbl_then); if let Some(op) = self.block(if_true) { let ty = self.type_info.type_of(if_true); - let ty = self.lower_type(&ty); + let ty = self.lower_type(&ty)?; self.add(Instruction::Assign { to: res.clone(), val: op, @@ -996,7 +983,7 @@ impl<'r> Lowerer<'r> { self.new_block(lbl_else); if let Some(op) = self.block(if_false) { let ty = self.type_info.type_of(if_false); - let ty = self.lower_type(&ty); + let ty = self.lower_type(&ty)?; self.add(Instruction::Assign { to: res.clone(), val: op, @@ -1113,7 +1100,7 @@ impl<'r> Lowerer<'r> { from: Operand, offset: u32, ty: &Type, - ) -> Operand { + ) -> Option<Operand> { let ty = self.type_info.resolve(ty); let to = self.new_tmp(); @@ -1126,13 +1113,14 @@ impl<'r> Lowerer<'r> { } else { let tmp = self.new_tmp(); + let ty = self.lower_type(&ty)?; + self.add(Instruction::Offset { to: tmp.clone(), from, offset, }); - let ty = self.lower_type(&ty); self.add(Instruction::Read { to: to.clone(), from: tmp.into(), @@ -1140,18 +1128,19 @@ impl<'r> Lowerer<'r> { }); } - to.into() + Some(to.into()) } fn is_reference_type(&mut self, ty: &Type) -> bool { - self.type_info.is_reference_type(ty) + self.type_info.is_reference_type(ty, self.runtime) } - // TODO: This should return Option<IrType> so that zero-sized - // types can be put in here. - fn lower_type(&mut self, ty: &Type) -> IrType { + fn lower_type(&mut self, ty: &Type) -> Option<IrType> { let ty = self.type_info.resolve(ty); - match ty { + if self.type_info.size_of(&ty, self.runtime) == 0 { + return None; + } + Some(match ty { Type::Primitive(Primitive::Bool) => IrType::Bool, Type::Primitive(Primitive::U8) => IrType::U8, Type::Primitive(Primitive::U16) => IrType::U16, @@ -1162,12 +1151,11 @@ impl<'r> Lowerer<'r> { Type::Primitive(Primitive::I32) => IrType::I32, Type::Primitive(Primitive::I64) => IrType::I64, Type::Primitive(Primitive::Asn) => IrType::Asn, - Type::Primitive(Primitive::IpAddr) => IrType::Pointer, Type::IntVar(_) => IrType::I32, Type::BuiltIn(_, _) => IrType::ExtPointer, x if self.is_reference_type(&x) => IrType::Pointer, _ => panic!("could not lower: {ty:?}"), - } + }) } fn call_runtime_function( @@ -1189,9 +1177,9 @@ impl<'r> Lowerer<'r> { let mut params = Vec::new(); params.push(IrType::Pointer); - for ty in parameter_types { - params.push(self.lower_type(ty)) - } + params.extend( + parameter_types.iter().filter_map(|ty| self.lower_type(ty)), + ); let args = std::iter::once(Operand::Place(out_ptr.clone())) .chain(args) @@ -1214,10 +1202,8 @@ impl<'r> Lowerer<'r> { if self.is_reference_type(return_type) { Some(out_ptr.into()) - } else if size > 0 { - Some(self.read_field(out_ptr.into(), 0, return_type)) } else { - None + self.read_field(out_ptr.into(), 0, return_type) } } diff --git a/src/pipeline.rs b/src/pipeline.rs index e37675be..40dff39b 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -447,6 +447,7 @@ impl Lowered { } impl Compiled { + #[allow(clippy::result_unit_err)] pub fn run_tests<Ctx: 'static>(&mut self, ctx: Ctx) -> Result<(), ()> { self.module.run_tests(ctx) } diff --git a/src/typechecker/info.rs b/src/typechecker/info.rs index 7481e0e9..e2d36a78 100644 --- a/src/typechecker/info.rs +++ b/src/typechecker/info.rs @@ -96,8 +96,11 @@ impl TypeInfo { self.function_scopes[&x.into()] } - pub fn is_reference_type(&mut self, ty: &Type) -> bool { + pub fn is_reference_type(&mut self, ty: &Type, rt: &Runtime) -> bool { let ty = self.resolve(ty); + if self.size_of(&ty, rt) == 0 { + return false; + } matches!( ty, Type::Record(..)