diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 90489cbb..d7792d6b 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -148,6 +148,9 @@ pub enum SimpleLine { LocationReport { location: SourceLineNumber, }, + Goto { + label: String, + }, } pub fn simplify_program(mut program: Program) -> SimpleProgram { @@ -690,6 +693,11 @@ fn simplify_lines( location: *location, }); } + Line::Goto { label } => { + res.push(SimpleLine::Goto { + label: label.clone(), + }); + } } } @@ -898,7 +906,7 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { on_new_expr(index, &internal_vars, &mut external_vars); on_new_expr(value, &internal_vars, &mut external_vars); } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} + Line::Panic | Line::Break | Line::LocationReport { .. } | Line::Goto { .. } => {} } } @@ -943,6 +951,7 @@ pub fn inline_lines( args: &BTreeMap, res: &[Var], inlining_count: usize, + epilogue_label: Option, ) { let inline_condition = |condition: &mut Boolean| { let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition; @@ -964,7 +973,13 @@ pub fn inline_lines( Line::Match { value, arms } => { inline_expr(value, args, inlining_count); for (_, statements) in arms { - inline_lines(statements, args, res, inlining_count); + inline_lines( + statements, + args, + res, + inlining_count, + epilogue_label.clone(), + ); } } Line::Assignment { var, value } => { @@ -978,8 +993,20 @@ pub fn inline_lines( } => { inline_condition(condition); - inline_lines(then_branch, args, res, inlining_count); - inline_lines(else_branch, args, res, inlining_count); + inline_lines( + then_branch, + args, + res, + inlining_count, + epilogue_label.clone(), + ); + inline_lines( + else_branch, + args, + res, + inlining_count, + epilogue_label.clone(), + ); } Line::FunctionCall { args: func_args, @@ -1002,16 +1029,22 @@ pub fn inline_lines( for expr in return_data.iter_mut() { inline_expr(expr, args, inlining_count); } - lines_to_replace.push(( - i, - res.iter() - .zip(return_data) - .map(|(res_var, expr)| Line::Assignment { - var: res_var.clone(), - value: expr.clone(), - }) - .collect::>(), - )); + + let mut replacement_lines = res + .iter() + .zip(return_data) + .map(|(res_var, expr)| Line::Assignment { + var: res_var.clone(), + value: expr.clone(), + }) + .collect::>(); + + if let Some(label) = &epilogue_label { + replacement_lines.push(Line::Goto { + label: label.clone(), + }); + } + lines_to_replace.push((i, replacement_lines)); } Line::MAlloc { var, size, .. } => { inline_expr(size, args, inlining_count); @@ -1048,7 +1081,7 @@ pub fn inline_lines( rev: _, unroll: _, } => { - inline_lines(body, args, res, inlining_count); + inline_lines(body, args, res, inlining_count, epilogue_label.clone()); inline_internal_var(iterator); inline_expr(start, args, inlining_count); inline_expr(end, args, inlining_count); @@ -1062,7 +1095,7 @@ pub fn inline_lines( inline_expr(index, args, inlining_count); inline_expr(value, args, inlining_count); } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} + Line::Panic | Line::Goto { .. } | Line::Break | Line::LocationReport { .. } => {} } } for (i, new_lines) in lines_to_replace.into_iter().rev() { @@ -1531,7 +1564,7 @@ fn replace_vars_for_unroll( Line::CounterHint { var } => { *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); } - Line::Break | Line::Panic | Line::LocationReport { .. } => {} + Line::Break | Line::Panic | Line::LocationReport { .. } | Line::Goto { .. } => {} } } } @@ -1652,12 +1685,14 @@ fn handle_inlined_functions_helper( .zip(&simplified_args) .map(|((var, _), expr)| (var.clone(), expr.clone())) .collect::>(); + let epilogue_label = format!("@inline_epilogue_{}", total_inlined_counter.0); let mut func_body = func.body.clone(); inline_lines( &mut func_body, &inlined_args, return_data, total_inlined_counter.next(), + Some(epilogue_label.clone()), ); inlined_lines.extend(func_body); @@ -1921,7 +1956,8 @@ fn get_function_called(lines: &[Line], function_called: &mut Vec) { | Line::MAlloc { .. } | Line::Panic | Line::Break - | Line::LocationReport { .. } => {} + | Line::LocationReport { .. } + | Line::Goto { .. } => {} } } } @@ -2025,7 +2061,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { assert!(!map.contains_key(var), "Variable {var} is a constant"); replace_vars_by_const_in_expr(size, map); } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} + Line::Panic | Line::Break | Line::LocationReport { .. } | Line::Goto { .. } => {} } } } @@ -2213,6 +2249,9 @@ impl SimpleLine { } Self::Panic => "panic".to_string(), Self::LocationReport { .. } => Default::default(), + Self::Goto { label } => { + format!("goto_{}", label) + } }; format!("{spaces}{line_str}") } diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index c040ec78..b760e1bf 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -597,6 +597,18 @@ fn compile_lines( location: *location, }); } + SimpleLine::Goto { label } => { + let goto_label = Label::Custom(label.clone()); + instructions.push(IntermediateInstruction::Jump { + dest: IntermediateValue::label(goto_label.clone()), + updated_fp: None, + }); + + let remaining = + compile_lines(&lines[i + 1..], compiler, final_jump, declared_vars)?; + compiler.bytecode.insert(goto_label, remaining); + return Ok(instructions); + } } } @@ -795,7 +807,8 @@ fn find_internal_vars(lines: &[SimpleLine]) -> BTreeSet { | SimpleLine::Print { .. } | SimpleLine::FunctionRet { .. } | SimpleLine::Precompile { .. } - | SimpleLine::LocationReport { .. } => {} + | SimpleLine::LocationReport { .. } + | SimpleLine::Goto { .. } => {} } } internal_vars diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 0c1a989e..2d6afeb7 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -366,6 +366,9 @@ pub enum Line { LocationReport { location: SourceLineNumber, }, + Goto { + label: String, + }, } impl Display for Expression { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { @@ -561,6 +564,9 @@ impl Line { } Self::Break => "break".to_string(), Self::Panic => "panic".to_string(), + Self::Goto { label } => { + format!("goto_{}", label) + } }; format!("{spaces}{line_str}") } diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 3dc73af5..a51023af 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -438,23 +438,24 @@ fn test_match() { compile_and_run(program, &[], &[], DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } -// #[test] -// fn inline_bug_mre() { -// let program = r#" -// fn main() { -// boolean(0); -// return; -// } - -// fn boolean(a) inline -> 1 { -// if a == 0 { -// return 0; -// } -// return 1; -// } -// "#; -// compile_and_run(program, &[], &[]); -// } +#[test] +fn inline_bug_mre() { + let program = r#" + fn main() { + x = boolean(0); + return; + } + + fn boolean(a) inline -> 1 { + if a == 0 { + return 0; + } + return 1; + } + "#; + + compile_and_run(program, &[], &[], DEFAULT_NO_VEC_RUNTIME_MEMORY, false); +} #[test] fn test_const_functions_calling_const_functions() {