Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions crates/lean_compiler/src/b_compile_intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ struct Compiler {
match_blocks: Vec<MatchBlock>,
if_counter: usize,
call_counter: usize,
match_counter: usize,
func_name: String,
stack_frame_layout: StackFrameLayout,
args_count: usize,
Expand Down Expand Up @@ -210,8 +211,9 @@ fn compile_lines(
SimpleLine::Match { value, arms } => {
compiler.stack_frame_layout.scopes.push(ScopeLayout::default());

let match_index = compiler.match_blocks.len();
let end_label = Label::match_end(match_index);
let label_id = compiler.match_counter;
compiler.match_counter += 1;
let end_label = Label::match_end(label_id);

let value_simplified = IntermediateValue::from_simple_expr(value, compiler);

Expand All @@ -232,6 +234,8 @@ fn compile_lines(
function_name: function_name.clone(),
match_cases: compiled_arms,
});
// Get the actual index AFTER pushing (nested matches may have pushed their blocks first)
let match_index = compiler.match_blocks.len() - 1;

let value_scaled_offset = IntermediateValue::MemoryAfterFp {
offset: compiler.stack_pos.into(),
Expand Down Expand Up @@ -263,9 +267,9 @@ fn compile_lines(
compiler.bytecode.insert(end_label, remaining);

compiler.stack_frame_layout.scopes.pop();
compiler.stack_pos = saved_stack_pos;
// It is not necessary to update compiler.stack_size here because the preceding call to
// compile lines should have done so.
// Don't reset stack_pos here - we need to preserve space for the temps we allocated.
// Nested matches would otherwise reuse the same temp positions, causing conflicts.
// This is consistent with IfNotZero which also doesn't reset stack_pos.

return Ok(instructions);
}
Expand Down
156 changes: 156 additions & 0 deletions crates/lean_compiler/tests/test_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1455,3 +1455,159 @@ fn test_len_2d_array() {
false,
);
}

#[test]
fn test_nested_matches() {
let program = r#"
fn main() {
assert test_func(0, 0) == 6;
assert test_func(1, 0) == 3;
return;
}

fn test_func(a, b) -> 1 {
x = 1;

var mut_x_2;
match a {
0 => {
var mut_x_1;
mut_x_1 = x + 2;
match b {
0 => {
mut_x_2 = mut_x_1 + 3;
}
}
}
1 => {
mut_x_2 = x + 2;
}
}

return mut_x_2;
}
"#;
compile_and_run(
&ProgramSource::Raw(program.to_string()),
(&[], &[]),
DEFAULT_NO_VEC_RUNTIME_MEMORY,
false,
);
}

#[test]
fn test_deeply_nested_match() {
// Test with 3 levels of nesting, multiple arms, and variables at each level
let program = r#"
fn main() {
// Test each combination with expected values
// (0,0,0): base=1000, local_a=5, local_b=8, inner_val=1008
assert compute(0, 0, 0) == 1008;
// (0,0,1): base=1000, local_a=5, local_b=8, inner_val=1009
assert compute(0, 0, 1) == 1009;
// (0,1,0): base=1000, local_a=5, local_b=12, inner_val=1012
assert compute(0, 1, 0) == 1012;
// (0,1,1): base=1000, local_a=5, local_b=12, inner_val=1013
assert compute(0, 1, 1) == 1013;
// (1,0,0): base=1000, local_a=16, local_b=36, inner_val=1036
assert compute(1, 0, 0) == 1036;
// (1,0,1): base=1000, local_a=16, local_b=36, inner_val=1037
assert compute(1, 0, 1) == 1037;
// (1,1,0): base=1000, local_a=16, local_b=46, inner_val=1046
assert compute(1, 1, 0) == 1046;
// (1,1,1): base=1000, local_a=16, local_b=46, inner_val=1047
assert compute(1, 1, 1) == 1047;
return;
}

fn compute(a, b, c) -> 1 {
base = 1000;
var outer_val;
var mid_val;
var inner_val;

match a {
0 => {
outer_val = 5;
var local_a;
local_a = a + outer_val; // local_a = 5

match b {
0 => {
mid_val = 3;
var local_b;
local_b = local_a + mid_val; // local_b = 8

match c {
0 => {
inner_val = base + local_b + c; // 1000 + 8 + 0 = 1008
}
1 => {
inner_val = base + local_b + c; // 1000 + 8 + 1 = 1009
}
}
}
1 => {
mid_val = 7;
var local_b;
local_b = local_a + mid_val; // local_b = 12

match c {
0 => {
inner_val = base + local_b + c; // 1000 + 12 + 0 = 1012
}
1 => {
inner_val = base + local_b + c; // 1000 + 12 + 1 = 1013
}
}
}
}
}
1 => {
outer_val = 15;
var local_a;
local_a = a + outer_val; // local_a = 16

match b {
0 => {
mid_val = 20;
var local_b;
local_b = local_a + mid_val; // local_b = 36

match c {
0 => {
inner_val = base + local_b + c; // 1000 + 36 + 0 = 1036
}
1 => {
inner_val = base + local_b + c; // 1000 + 36 + 1 = 1037
}
}
}
1 => {
mid_val = 30;
var local_b;
local_b = local_a + mid_val; // local_b = 46

match c {
0 => {
inner_val = base + local_b + c; // 1000 + 46 + 0 = 1046
}
1 => {
inner_val = base + local_b + c; // 1000 + 46 + 1 = 1047
}
}
}
}
}
}

return inner_val;
}
"#;
compile_and_run(
&ProgramSource::Raw(program.to_string()),
(&[], &[]),
DEFAULT_NO_VEC_RUNTIME_MEMORY,
false,
);
}