diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index bb194fe2..2abedf59 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -94,8 +94,7 @@ pub enum SimpleLine { else_branch: Vec, line_number: SourceLineNumber, }, - TestZero { - // Test that the result of the given operation is zero + AssertZero { operation: HighLevelOperation, arg0: SimpleExpr, arg1: SimpleExpr, @@ -146,6 +145,17 @@ pub enum SimpleLine { DebugAssert(BooleanExpr, SourceLineNumber), } +impl SimpleLine { + pub fn equality(arg0: impl Into, arg1: impl Into) -> Self { + SimpleLine::Assignment { + var: arg0.into(), + operation: HighLevelOperation::Add, + arg0: arg1.into(), + arg1: SimpleExpr::zero(), + } + } +} + pub fn simplify_program(mut program: Program) -> SimpleProgram { check_program_scoping(&program); handle_inlined_functions(&mut program); @@ -180,20 +190,25 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram { let mut new_functions = BTreeMap::new(); let mut counters = Counters::default(); let mut const_malloc = ConstMalloc::default(); - let const_arrays = &program.const_arrays; + let ctx = SimplifyContext { + functions: &program.functions, + const_arrays: &program.const_arrays, + }; for (name, func) in &program.functions { let mut array_manager = ArrayManager::default(); + let mut state = SimplifyState { + counters: &mut counters, + array_manager: &mut array_manager, + }; let simplified_instructions = simplify_lines( - &program.functions, + &ctx, + &mut state, + &mut const_malloc, + &mut new_functions, func.file_id, func.n_returned_vars, &func.body, - &mut counters, - &mut new_functions, false, - &mut array_manager, - &mut const_malloc, - const_arrays, ); let arguments = func .arguments @@ -367,16 +382,23 @@ fn check_block_scoping(block: &[Line], ctx: &mut Context) { ctx.scopes.pop(); } } - Line::Assignment { var, value } => { + Line::Statement { targets, value, .. } => { check_expr_scoping(value, ctx); - if !ctx.defines(var) { - ctx.add_var(var); + // First: add new variables to scope + for target in targets { + if let AssignmentTarget::Var(var) = target + && !ctx.defines(var) + { + ctx.add_var(var); + } + } + // Second pass: check array access targets + for target in targets { + if let AssignmentTarget::ArrayAccess { array, index } = target { + check_simple_expr_scoping(array, ctx); + check_expr_scoping(index, ctx); + } } - } - Line::ArrayAssign { array, index, value } => { - check_simple_expr_scoping(array, ctx); - check_expr_scoping(index, ctx); - check_expr_scoping(value, ctx); } Line::Assert { boolean, .. } => { check_boolean_scoping(boolean, ctx); @@ -411,35 +433,6 @@ fn check_block_scoping(block: &[Line], ctx: &mut Context) { check_block_scoping(body, ctx); ctx.scopes.pop(); } - Line::FunctionCall { - function_name: _, - args, - return_data, - line_number: _, - } => { - for arg in args { - check_expr_scoping(arg, ctx); - } - for target in return_data { - match target { - AssignmentTarget::Var(var) => { - if !ctx.defines(var) { - ctx.add_var(var); - } - } - AssignmentTarget::ArrayAccess { .. } => {} - } - } - for target in return_data { - match target { - AssignmentTarget::Var(_) => {} - AssignmentTarget::ArrayAccess { array, index } => { - check_simple_expr_scoping(array, ctx); - check_expr_scoping(index, ctx); - } - } - } - } Line::FunctionRet { return_data } => { for expr in return_data { check_expr_scoping(expr, ctx); @@ -551,6 +544,24 @@ struct Counters { loops: usize, } +impl Counters { + fn aux_var(&mut self) -> Var { + let var = format!("@aux_var_{}", self.aux_vars); + self.aux_vars += 1; + var + } +} + +struct SimplifyContext<'a> { + functions: &'a BTreeMap, + const_arrays: &'a BTreeMap, +} + +struct SimplifyState<'a> { + counters: &'a mut Counters, + array_manager: &'a mut ArrayManager, +} + #[derive(Debug, Clone, Default)] struct ArrayManager { counter: usize, @@ -578,16 +589,14 @@ impl ArrayManager { #[allow(clippy::too_many_arguments)] fn simplify_lines( - functions: &BTreeMap, + ctx: &SimplifyContext<'_>, + state: &mut SimplifyState<'_>, + const_malloc: &mut ConstMalloc, + new_functions: &mut BTreeMap, file_id: FileId, n_returned_vars: usize, lines: &[Line], - counters: &mut Counters, - new_functions: &mut BTreeMap, in_a_loop: bool, - array_manager: &mut ArrayManager, - const_malloc: &mut ConstMalloc, - const_arrays: &BTreeMap, ) -> Vec { let mut res = Vec::new(); for line in lines { @@ -596,29 +605,19 @@ fn simplify_lines( res.push(SimpleLine::ForwardDeclaration { var: var.clone() }); } Line::Match { value, arms } => { - let simple_value = simplify_expr( - value, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); + let simple_value = simplify_expr(ctx, state, const_malloc, value, &mut res); let mut simple_arms = vec![]; for (i, (pattern, statements)) in arms.iter().enumerate() { assert_eq!(*pattern, i, "match patterns should be consecutive, starting from 0"); simple_arms.push(simplify_lines( - functions, + ctx, + state, + const_malloc, + new_functions, file_id, n_returned_vars, statements, - counters, - new_functions, in_a_loop, - array_manager, - const_malloc, - const_arrays, )); } res.push(SimpleLine::Match { @@ -626,125 +625,149 @@ fn simplify_lines( arms: simple_arms, }); } - Line::Assignment { var, value } => match value { - Expression::Value(value) => { - res.push(SimpleLine::Assignment { - var: var.clone().into(), - operation: HighLevelOperation::Add, - arg0: value.clone(), - arg1: SimpleExpr::zero(), - }); - } - Expression::ArrayAccess { array, index } => { - handle_array_assignment( - counters, - &mut res, - array.clone(), - index, - ArrayAccessType::VarIsAssigned(var.clone()), - array_manager, - const_malloc, - const_arrays, - functions, - ); - } - Expression::Binary { left, operation, right } => { - let left = simplify_expr( - left, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); - let right = simplify_expr( - right, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); - // If both operands are constants, evaluate at compile time and assign the result - if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = (&left, &right) { - let result = ConstExpression::Binary { - left: Box::new(left_cst.clone()), - operation: *operation, - right: Box::new(right_cst.clone()), - } - .try_naive_simplification(); - res.push(SimpleLine::Assignment { - var: var.clone().into(), - operation: HighLevelOperation::Add, - arg0: SimpleExpr::Constant(result), - arg1: SimpleExpr::zero(), + Line::Statement { + targets, + value, + line_number, + } => { + match value { + Expression::FunctionCall { function_name, args } => { + // Function call - may have zero, one, or multiple targets + let function = ctx.functions.get(function_name).unwrap_or_else(|| { + panic!("Function used but not defined: {function_name}, at line {line_number}") }); - } else { - res.push(SimpleLine::Assignment { - var: var.clone().into(), - operation: *operation, - arg0: left, - arg1: right, + if targets.len() != function.n_returned_vars { + panic!( + "Expected {} returned vars (and not {}) in call to {function_name}, at line {line_number}", + function.n_returned_vars, + targets.len() + ); + } + if args.len() != function.arguments.len() { + panic!( + "Expected {} arguments (and not {}) in call to {function_name}, at line {line_number}", + function.arguments.len(), + args.len() + ); + } + + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) + .collect::>(); + + let mut temp_vars = Vec::new(); + let mut array_targets: Vec<(usize, SimpleExpr, Box)> = Vec::new(); + + for (i, target) in targets.iter().enumerate() { + match target { + AssignmentTarget::Var(var) => { + temp_vars.push(var.clone()); + } + AssignmentTarget::ArrayAccess { array, index } => { + temp_vars.push(state.counters.aux_var()); + array_targets.push((i, array.clone(), index.clone())); + } + } + } + + res.push(SimpleLine::FunctionCall { + function_name: function_name.clone(), + args: simplified_args, + return_data: temp_vars.clone(), + line_number: *line_number, }); + + // For array access targets, add DEREF instructions to copy temp to array element + for (i, array, index) in array_targets { + handle_array_assignment( + ctx, + state, + const_malloc, + &mut res, + array, + &[*index], + ArrayAccessType::ArrayIsAssigned(Expression::Value(SimpleExpr::Var( + temp_vars[i].clone(), + ))), + ); + } + } + _ => { + // Non-function call - must have exactly one target + assert!(targets.len() == 1, "Non-function call must have exactly one target"); + let target = &targets[0]; + + match target { + AssignmentTarget::Var(var) => { + // Variable assignment + match value { + Expression::Value(val) => { + res.push(SimpleLine::equality(var.clone(), val.clone())); + } + Expression::ArrayAccess { array, index } => { + handle_array_assignment( + ctx, + state, + const_malloc, + &mut res, + array.clone(), + index, + ArrayAccessType::VarIsAssigned(var.clone()), + ); + } + Expression::Binary { left, operation, right } => { + let left = simplify_expr(ctx, state, const_malloc, left, &mut res); + let right = simplify_expr(ctx, state, const_malloc, right, &mut res); + // If both operands are constants, evaluate at compile time and assign the result + if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = + (&left, &right) + { + let result = ConstExpression::Binary { + left: Box::new(left_cst.clone()), + operation: *operation, + right: Box::new(right_cst.clone()), + } + .try_naive_simplification(); + res.push(SimpleLine::equality(var.clone(), SimpleExpr::Constant(result))); + } else { + res.push(SimpleLine::Assignment { + var: var.clone().into(), + operation: *operation, + arg0: left, + arg1: right, + }); + } + } + Expression::MathExpr(_, _) | Expression::Len { .. } => unreachable!(), + Expression::FunctionCall { .. } => { + unreachable!("FunctionCall should be handled above") + } + } + } + AssignmentTarget::ArrayAccess { array, index } => { + // Array element assignment + handle_array_assignment( + ctx, + state, + const_malloc, + &mut res, + array.clone(), + std::slice::from_ref(&**index), + ArrayAccessType::ArrayIsAssigned(value.clone()), + ); + } + } } } - Expression::MathExpr(_, _) | Expression::Len { .. } => unreachable!(), - Expression::FunctionCall { .. } => { - let result = simplify_expr( - value, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); - res.push(SimpleLine::Assignment { - var: var.clone().into(), - operation: HighLevelOperation::Add, - arg0: result, - arg1: SimpleExpr::zero(), - }); - } - }, - Line::ArrayAssign { array, index, value } => { - handle_array_assignment( - counters, - &mut res, - array.clone(), - std::slice::from_ref(index), - ArrayAccessType::ArrayIsAssigned(value.clone()), - array_manager, - const_malloc, - const_arrays, - functions, - ); } Line::Assert { boolean, line_number, debug, } => { - let left = simplify_expr( - &boolean.left, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); - let right = simplify_expr( - &boolean.right, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); + let left = simplify_expr(ctx, state, const_malloc, &boolean.left, &mut res); + let right = simplify_expr(ctx, state, const_malloc, &boolean.right, &mut res); if *debug { res.push(SimpleLine::DebugAssert( @@ -758,8 +781,7 @@ fn simplify_lines( } else { match boolean.kind { Boolean::Different => { - let diff_var = format!("@aux_var_{}", counters.aux_vars); - counters.aux_vars += 1; + let diff_var = state.counters.aux_var(); res.push(SimpleLine::Assignment { var: diff_var.clone().into(), operation: HighLevelOperation::Sub, @@ -774,7 +796,7 @@ fn simplify_lines( }); } Boolean::Equal => { - let (var, other) = if let Ok(left) = left.clone().try_into() { + let (var, other): (VarOrConstMallocAccess, _) = if let Ok(left) = left.clone().try_into() { (left, right) } else if let Ok(right) = right.clone().try_into() { (right, left) @@ -799,12 +821,7 @@ fn simplify_lines( } panic!("Unsupported equality assertion: {left:?}, {right:?}") }; - res.push(SimpleLine::Assignment { - var, - operation: HighLevelOperation::Add, - arg0: other, - arg1: SimpleExpr::zero(), - }); + res.push(SimpleLine::equality(var, other)); } Boolean::LessThan => unreachable!(), } @@ -826,27 +843,10 @@ fn simplify_lines( Boolean::LessThan => unreachable!(), }; - let left_simplified = simplify_expr( - left, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); - let right_simplified = simplify_expr( - right, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); + let left_simplified = simplify_expr(ctx, state, const_malloc, left, &mut res); + let right_simplified = simplify_expr(ctx, state, const_malloc, right, &mut res); - let diff_var = format!("@diff_{}", counters.aux_vars); - counters.aux_vars += 1; + let diff_var = state.counters.aux_var(); res.push(SimpleLine::Assignment { var: diff_var.clone().into(), operation: HighLevelOperation::Sub, @@ -856,29 +856,20 @@ fn simplify_lines( (diff_var.into(), then_branch, else_branch) } Condition::Expression(condition, assume_boolean) => { - let condition_simplified = simplify_expr( - condition, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); + let condition_simplified = simplify_expr(ctx, state, const_malloc, condition, &mut res); match assume_boolean { AssumeBoolean::AssumeBoolean => {} AssumeBoolean::DoNotAssumeBoolean => { // Check condition_simplified is boolean - let one_minus_condition_var = format!("@aux_{}", counters.aux_vars); - counters.aux_vars += 1; + let one_minus_condition_var = state.counters.aux_var(); res.push(SimpleLine::Assignment { var: one_minus_condition_var.clone().into(), operation: HighLevelOperation::Sub, arg0: SimpleExpr::Constant(ConstExpression::Value(ConstantValue::Scalar(1))), arg1: condition_simplified.clone(), }); - res.push(SimpleLine::TestZero { + res.push(SimpleLine::AssertZero { operation: HighLevelOperation::Mul, arg0: condition_simplified.clone(), arg1: one_minus_condition_var.into(), @@ -890,38 +881,43 @@ fn simplify_lines( } }; - let mut array_manager_then = array_manager.clone(); + let mut array_manager_then = state.array_manager.clone(); + let mut state_then = SimplifyState { + counters: state.counters, + array_manager: &mut array_manager_then, + }; let then_branch_simplified = simplify_lines( - functions, + ctx, + &mut state_then, + const_malloc, + new_functions, file_id, n_returned_vars, then_branch, - counters, - new_functions, in_a_loop, - &mut array_manager_then, - const_malloc, - const_arrays, ); let mut array_manager_else = array_manager_then.clone(); - array_manager_else.valid = array_manager.valid.clone(); // Crucial: remove the access added in the IF branch + array_manager_else.valid = state.array_manager.valid.clone(); // Crucial: remove the access added in the IF branch + let mut state_else = SimplifyState { + counters: state.counters, + array_manager: &mut array_manager_else, + }; let else_branch_simplified = simplify_lines( - functions, + ctx, + &mut state_else, + const_malloc, + new_functions, file_id, n_returned_vars, else_branch, - counters, - new_functions, in_a_loop, - &mut array_manager_else, - const_malloc, - const_arrays, ); - *array_manager = array_manager_else.clone(); + *state.array_manager = array_manager_else.clone(); // keep the intersection both branches - array_manager.valid = array_manager + state.array_manager.valid = state + .array_manager .valid .intersection(&array_manager_then.valid) .cloned() @@ -953,32 +949,30 @@ fn simplify_lines( counter: const_malloc.counter, ..ConstMalloc::default() }; - let valid_aux_vars_in_array_manager_before = array_manager.valid.clone(); - array_manager.valid.clear(); + let valid_aux_vars_in_array_manager_before = state.array_manager.valid.clone(); + state.array_manager.valid.clear(); let simplified_body = simplify_lines( - functions, + ctx, + state, + &mut loop_const_malloc, + new_functions, file_id, 0, body, - counters, - new_functions, true, - array_manager, - &mut loop_const_malloc, - const_arrays, ); const_malloc.counter = loop_const_malloc.counter; - array_manager.valid = valid_aux_vars_in_array_manager_before; // restore the valid aux vars + state.array_manager.valid = valid_aux_vars_in_array_manager_before; // restore the valid aux vars - let func_name = format!("@loop_{}_line_{}", counters.loops, line_number); - counters.loops += 1; + let func_name = format!("@loop_{}_line_{}", state.counters.loops, line_number); + state.counters.loops += 1; // Find variables used inside loop but defined outside - let (_, mut external_vars) = find_variable_usage(body, const_arrays); + let (_, mut external_vars) = find_variable_usage(body, ctx.const_arrays); // Include variables in start/end for expr in [start, end] { - for var in vars_in_expression(expr, const_arrays) { + for var in vars_in_expression(expr, ctx.const_arrays) { external_vars.insert(var); } } @@ -986,34 +980,15 @@ fn simplify_lines( let mut external_vars: Vec<_> = external_vars.into_iter().collect(); - let start_simplified = simplify_expr( - start, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); - let mut end_simplified = simplify_expr( - end, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); + let start_simplified = simplify_expr(ctx, state, const_malloc, start, &mut res); + let mut end_simplified = simplify_expr(ctx, state, const_malloc, end, &mut res); if let SimpleExpr::ConstMallocAccess { malloc_label, offset } = end_simplified.clone() { // we use an auxilary variable to store the end value (const malloc inside non-unrolled loops does not work) - let aux_end_var = format!("@aux_end_{}", counters.aux_vars); - counters.aux_vars += 1; - res.push(SimpleLine::Assignment { - var: aux_end_var.clone().into(), - operation: HighLevelOperation::Add, - arg0: SimpleExpr::ConstMallocAccess { malloc_label, offset }, - arg1: SimpleExpr::zero(), - }); + let aux_end_var = state.counters.aux_var(); + res.push(SimpleLine::equality( + aux_end_var.clone(), + SimpleExpr::ConstMallocAccess { malloc_label, offset }, + )); end_simplified = SimpleExpr::Var(aux_end_var); } @@ -1059,77 +1034,6 @@ fn simplify_lines( line_number: *line_number, }); } - Line::FunctionCall { - function_name, - args, - return_data, - line_number, - } => { - let function = functions - .get(function_name) - .unwrap_or_else(|| panic!("Function used but not defined: {function_name}")); - if return_data.len() != function.n_returned_vars { - panic!( - "Expected {} returned vars in call to {function_name}", - function.n_returned_vars - ); - } - - let simplified_args = args - .iter() - .map(|arg| { - simplify_expr( - arg, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ) - }) - .collect::>(); - - // Generate temp vars for all return values and track array targets - let mut temp_vars = Vec::new(); - let mut array_targets: Vec<(usize, SimpleExpr, Box)> = Vec::new(); - - for (i, target) in return_data.iter().enumerate() { - match target { - AssignmentTarget::Var(var) => { - temp_vars.push(var.clone()); - } - AssignmentTarget::ArrayAccess { array, index } => { - let temp_var = format!("@ret_temp_{}", counters.aux_vars); - counters.aux_vars += 1; - temp_vars.push(temp_var); - array_targets.push((i, array.clone(), index.clone())); - } - } - } - - res.push(SimpleLine::FunctionCall { - function_name: function_name.clone(), - args: simplified_args, - return_data: temp_vars.clone(), - line_number: *line_number, - }); - - // For array access targets, add DEREF instructions to copy temp to array element - for (i, array, index) in array_targets { - handle_array_assignment( - counters, - &mut res, - array, - &[*index], - ArrayAccessType::ArrayIsAssigned(Expression::Value(SimpleExpr::Var(temp_vars[i].clone()))), - array_manager, - const_malloc, - const_arrays, - functions, - ); - } - } Line::FunctionRet { return_data } => { assert!(!in_a_loop, "Function return inside a loop is not currently supported"); assert!( @@ -1139,17 +1043,7 @@ fn simplify_lines( ); let simplified_return_data = return_data .iter() - .map(|ret| { - simplify_expr( - ret, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ) - }) + .map(|ret| simplify_expr(ctx, state, const_malloc, ret, &mut res)) .collect::>(); res.push(SimpleLine::FunctionRet { return_data: simplified_return_data, @@ -1158,17 +1052,7 @@ fn simplify_lines( Line::Precompile { table, args } => { let simplified_args = args .iter() - .map(|arg| { - simplify_expr( - arg, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ) - }) + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) .collect::>(); res.push(SimpleLine::Precompile { table: *table, @@ -1178,17 +1062,7 @@ fn simplify_lines( Line::Print { line_info, content } => { let simplified_content = content .iter() - .map(|var| { - simplify_expr( - var, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ) - }) + .map(|var| simplify_expr(ctx, state, const_malloc, var, &mut res)) .collect::>(); res.push(SimpleLine::Print { line_info: line_info.clone(), @@ -1205,24 +1079,8 @@ fn simplify_lines( vectorized, vectorized_len, } => { - let simplified_size = simplify_expr( - size, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); - let simplified_vectorized_len = simplify_expr( - vectorized_len, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); + let simplified_size = simplify_expr(ctx, state, const_malloc, size, &mut res); + let simplified_vectorized_len = simplify_expr(ctx, state, const_malloc, vectorized_len, &mut res); match simplified_size { SimpleExpr::Constant(const_size) if !*vectorized => { let label = const_malloc.counter; @@ -1250,17 +1108,7 @@ fn simplify_lines( Line::CustomHint(hint, args) => { let simplified_args = args .iter() - .map(|expr| { - simplify_expr( - expr, - &mut res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ) - }) + .map(|expr| simplify_expr(ctx, state, const_malloc, expr, &mut res)) .collect::>(); res.push(SimpleLine::CustomHint(*hint, simplified_args)); } @@ -1277,38 +1125,28 @@ fn simplify_lines( } fn simplify_expr( + ctx: &SimplifyContext<'_>, + state: &mut SimplifyState<'_>, + const_malloc: &ConstMalloc, expr: &Expression, lines: &mut Vec, - counters: &mut Counters, - array_manager: &mut ArrayManager, - const_malloc: &ConstMalloc, - const_arrays: &BTreeMap, - functions: &BTreeMap, ) -> SimpleExpr { match expr { Expression::Value(value) => value.simplify_if_const(), Expression::ArrayAccess { array, index } => { // Check for const array access if let SimpleExpr::Var(array_var) = array - && let Some(arr) = const_arrays.get(array_var) + && let Some(arr) = ctx.const_arrays.get(array_var) { let simplified_index = index .iter() .map(|idx| { - simplify_expr( - idx, - lines, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ) - .as_constant() - .expect("Const array access index should be constant") - .naive_eval() - .expect("Const array access index should be constant") - .to_usize() + simplify_expr(ctx, state, const_malloc, idx, lines) + .as_constant() + .expect("Const array access index should be constant") + .naive_eval() + .expect("Const array access index should be constant") + .to_usize() }) .collect::>(); @@ -1334,44 +1172,26 @@ fn simplify_expr( }; } - let aux_arr = array_manager.get_aux_var(array, &index); // auxiliary var to store m[array + index] + let aux_arr = state.array_manager.get_aux_var(array, &index); // auxiliary var to store m[array + index] - if !array_manager.valid.insert(aux_arr.clone()) { + if !state.array_manager.valid.insert(aux_arr.clone()) { return SimpleExpr::Var(aux_arr); } handle_array_assignment( - counters, + ctx, + state, + const_malloc, lines, array.clone(), &[index], ArrayAccessType::VarIsAssigned(aux_arr.clone()), - array_manager, - const_malloc, - const_arrays, - functions, ); SimpleExpr::Var(aux_arr) } Expression::Binary { left, operation, right } => { - let left_var = simplify_expr( - left, - lines, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); - let right_var = simplify_expr( - right, - lines, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); + let left_var = simplify_expr(ctx, state, const_malloc, left, lines); + let right_var = simplify_expr(ctx, state, const_malloc, right, lines); if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = (&left_var, &right_var) { return SimpleExpr::Constant(ConstExpression::Binary { @@ -1381,8 +1201,7 @@ fn simplify_expr( }); } - let aux_var = format!("@aux_var_{}", counters.aux_vars); - counters.aux_vars += 1; + let aux_var = state.counters.aux_var(); lines.push(SimpleLine::Assignment { var: aux_var.clone().into(), operation: *operation, @@ -1395,23 +1214,16 @@ fn simplify_expr( let simplified_args = args .iter() .map(|arg| { - simplify_expr( - arg, - lines, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ) - .as_constant() - .unwrap() + simplify_expr(ctx, state, const_malloc, arg, lines) + .as_constant() + .unwrap() }) .collect::>(); SimpleExpr::Constant(ConstExpression::MathExpr(*formula, simplified_args)) } Expression::FunctionCall { function_name, args } => { - let function = functions + let function = ctx + .functions .get(function_name) .unwrap_or_else(|| panic!("Function used but not defined: {function_name}")); assert_eq!( @@ -1422,22 +1234,11 @@ fn simplify_expr( let simplified_args = args .iter() - .map(|arg| { - simplify_expr( - arg, - lines, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ) - }) + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, lines)) .collect::>(); // Create a temporary variable for the function result - let result_var = format!("@nested_call_{}", counters.aux_vars); - counters.aux_vars += 1; + let result_var = state.counters.aux_var(); lines.push(SimpleLine::FunctionCall { function_name: function_name.clone(), @@ -1492,9 +1293,24 @@ pub fn find_variable_usage( external_vars.extend(stmt_external.into_iter().filter(|v| !internal_vars.contains(v))); } } - Line::Assignment { var, value } => { + Line::Statement { targets, value, .. } => { on_new_expr(value, &internal_vars, &mut external_vars); - internal_vars.insert(var.clone()); + for target in targets { + match target { + AssignmentTarget::Var(var) => { + internal_vars.insert(var.clone()); + } + AssignmentTarget::ArrayAccess { array, index } => { + if let SimpleExpr::Var(var) = array { + assert!(!const_arrays.contains_key(var), "Cannot assign to const array"); + if !internal_vars.contains(var) { + external_vars.insert(var.clone()); + } + } + on_new_expr(index, &internal_vars, &mut external_vars); + } + } + } } Line::IfCondition { condition, @@ -1515,27 +1331,6 @@ pub fn find_variable_usage( .cloned(), ); } - Line::FunctionCall { args, return_data, .. } => { - for arg in args { - on_new_expr(arg, &internal_vars, &mut external_vars); - } - for target in return_data { - match target { - AssignmentTarget::Var(var) => { - internal_vars.insert(var.clone()); - } - AssignmentTarget::ArrayAccess { array, index } => { - if let SimpleExpr::Var(var) = array { - assert!(!const_arrays.contains_key(var), "Cannot assign to const array"); - if !internal_vars.contains(var) { - external_vars.insert(var.clone()); - } - } - on_new_expr(index, &internal_vars, &mut external_vars); - } - } - } - } Line::Assert { boolean, .. } => { on_new_condition( &Condition::Comparison(boolean.clone()), @@ -1586,11 +1381,6 @@ pub fn find_variable_usage( on_new_expr(start, &internal_vars, &mut external_vars); on_new_expr(end, &internal_vars, &mut external_vars); } - Line::ArrayAssign { array, index, value } => { - on_new_expr(&array.clone().into(), &internal_vars, &mut external_vars); - on_new_expr(index, &internal_vars, &mut external_vars); - on_new_expr(value, &internal_vars, &mut external_vars); - } Line::Panic | Line::Break | Line::LocationReport { .. } => {} } } @@ -1677,9 +1467,19 @@ fn inline_lines( inline_lines(statements, args, res, inlining_count); } } - Line::Assignment { var, value } => { + Line::Statement { targets, value, .. } => { inline_expr(value, args, inlining_count); - inline_internal_var(var); + for target in targets { + match target { + AssignmentTarget::Var(var) => { + inline_internal_var(var); + } + AssignmentTarget::ArrayAccess { array, index } => { + inline_simple_expr(array, args, inlining_count); + inline_expr(index, args, inlining_count); + } + } + } } Line::IfCondition { condition, @@ -1692,26 +1492,6 @@ fn inline_lines( inline_lines(then_branch, args, res, inlining_count); inline_lines(else_branch, args, res, inlining_count); } - Line::FunctionCall { - args: func_args, - return_data, - .. - } => { - for arg in func_args { - inline_expr(arg, args, inlining_count); - } - for target in return_data { - match target { - AssignmentTarget::Var(var) => { - inline_internal_var(var); - } - AssignmentTarget::ArrayAccess { array, index } => { - inline_simple_expr(array, args, inlining_count); - inline_expr(index, args, inlining_count); - } - } - } - } Line::Assert { boolean, .. } => { inline_comparison(boolean); } @@ -1725,16 +1505,10 @@ fn inline_lines( i, res.iter() .zip(return_data) - .map(|(target, expr)| match target { - AssignmentTarget::Var(res_var) => Line::Assignment { - var: res_var.clone(), - value: expr.clone(), - }, - AssignmentTarget::ArrayAccess { array, index } => Line::ArrayAssign { - array: array.clone(), - index: (**index).clone(), - value: expr.clone(), - }, + .map(|(target, expr)| Line::Statement { + targets: vec![target.clone()], + value: expr.clone(), + line_number: 0, }) .collect::>(), )); @@ -1778,11 +1552,6 @@ fn inline_lines( inline_expr(start, args, inlining_count); inline_expr(end, args, inlining_count); } - Line::ArrayAssign { array, index, value } => { - inline_simple_expr(array, args, inlining_count); - inline_expr(index, args, inlining_count); - inline_expr(value, args, inlining_count); - } Line::Panic | Line::Break | Line::LocationReport { .. } => {} } } @@ -1838,25 +1607,22 @@ pub enum ArrayAccessType { ArrayIsAssigned(Expression), // array[index] = expr } -#[allow(clippy::too_many_arguments)] fn handle_array_assignment( - counters: &mut Counters, + ctx: &SimplifyContext<'_>, + state: &mut SimplifyState<'_>, + const_malloc: &ConstMalloc, res: &mut Vec, array: SimpleExpr, index: &[Expression], access_type: ArrayAccessType, - array_manager: &mut ArrayManager, - const_malloc: &ConstMalloc, - const_arrays: &BTreeMap, - functions: &BTreeMap, ) { let simplified_index = index .iter() - .map(|idx| simplify_expr(idx, res, counters, array_manager, const_malloc, const_arrays, functions)) + .map(|idx| simplify_expr(ctx, state, const_malloc, idx, res)) .collect::>(); if let (ArrayAccessType::VarIsAssigned(var), SimpleExpr::Var(array_var)) = (&access_type, &array) - && let Some(const_array) = const_arrays.get(array_var) + && let Some(const_array) = ctx.const_arrays.get(array_var) { let idx = simplified_index .iter() @@ -1873,12 +1639,7 @@ fn handle_array_assignment( .expect("Const array access index out of bounds") .as_scalar() .expect("Const array access should return a scalar"); - res.push(SimpleLine::Assignment { - var: var.clone().into(), - operation: HighLevelOperation::Add, - arg0: SimpleExpr::Constant(ConstExpression::from(value)), - arg1: SimpleExpr::zero(), - }); + res.push(SimpleLine::equality(var.clone(), ConstExpression::from(value))); return; } @@ -1888,24 +1649,8 @@ fn handle_array_assignment( && let Some(label) = const_malloc.map.get(array_var) && let ArrayAccessType::ArrayIsAssigned(Expression::Binary { left, operation, right }) = &access_type { - let arg0 = simplify_expr( - left, - res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); - let arg1 = simplify_expr( - right, - res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ); + let arg0 = simplify_expr(ctx, state, const_malloc, left, res); + let arg1 = simplify_expr(ctx, state, const_malloc, right, res); res.push(SimpleLine::Assignment { var: VarOrConstMallocAccess::ConstMallocAccess { malloc_label: *label, @@ -1920,15 +1665,7 @@ fn handle_array_assignment( let value_simplified = match access_type { ArrayAccessType::VarIsAssigned(var) => SimpleExpr::Var(var), - ArrayAccessType::ArrayIsAssigned(expr) => simplify_expr( - &expr, - res, - counters, - array_manager, - const_malloc, - const_arrays, - functions, - ), + ArrayAccessType::ArrayIsAssigned(expr) => simplify_expr(ctx, state, const_malloc, &expr, res), }; // TODO opti: in some case we could use ConstMallocAccess @@ -1938,8 +1675,7 @@ fn handle_array_assignment( SimpleExpr::Constant(c) => (array, c), _ => { // Create pointer variable: ptr = array + index - let ptr_var = format!("@aux_var_{}", counters.aux_vars); - counters.aux_vars += 1; + let ptr_var = state.counters.aux_var(); res.push(SimpleLine::Assignment { var: ptr_var.clone().into(), operation: HighLevelOperation::Add, @@ -2082,25 +1818,31 @@ fn replace_vars_for_unroll( Line::ForwardDeclaration { var } => { *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); } - Line::Assignment { var, value } => { - assert!(var != iterator, "Weird"); - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); + Line::Statement { targets, value, .. } => { replace_vars_for_unroll_in_expr(value, iterator, unroll_index, iterator_value, internal_vars); - } - Line::ArrayAssign { - // array[index] = value - array, - index, - value, - } => { - if let SimpleExpr::Var(array_var) = array { - assert!(array_var != iterator, "Weird"); - if internal_vars.contains(array_var) { - *array_var = format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); + for target in targets { + match target { + AssignmentTarget::Var(var) => { + assert!(var != iterator, "Weird"); + *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); + } + AssignmentTarget::ArrayAccess { array, index } => { + if let SimpleExpr::Var(array_var) = array { + assert!(array_var != iterator, "Weird"); + if internal_vars.contains(array_var) { + *array_var = format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); + } + } + replace_vars_for_unroll_in_expr( + index, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } } } - replace_vars_for_unroll_in_expr(index, iterator, unroll_index, iterator_value, internal_vars); - replace_vars_for_unroll_in_expr(value, iterator, unroll_index, iterator_value, internal_vars); } Line::Assert { boolean, .. } => { replace_vars_for_unroll_in_expr( @@ -2164,37 +1906,6 @@ fn replace_vars_for_unroll( replace_vars_for_unroll_in_expr(end, iterator, unroll_index, iterator_value, internal_vars); replace_vars_for_unroll(body, iterator, unroll_index, iterator_value, internal_vars); } - Line::FunctionCall { - function_name: _, - args, - return_data, - line_number: _, - } => { - for arg in args { - replace_vars_for_unroll_in_expr(arg, iterator, unroll_index, iterator_value, internal_vars); - } - for target in return_data { - match target { - AssignmentTarget::Var(ret) => { - *ret = format!("@unrolled_{unroll_index}_{iterator_value}_{ret}"); - } - AssignmentTarget::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array_var) = array - && internal_vars.contains(array_var) - { - *array_var = format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); - } - replace_vars_for_unroll_in_expr( - index, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } - } - } Line::FunctionRet { return_data } => { for ret in return_data { replace_vars_for_unroll_in_expr(ret, iterator, unroll_index, iterator_value, internal_vars); @@ -2378,10 +2089,12 @@ fn extract_inlined_calls_from_expr( if inlined_functions.contains_key(function_name) { let aux_var = format!("@inlined_var_{}", inlined_var_counter.next()); lines.push(Line::ForwardDeclaration { var: aux_var.clone() }); - lines.push(Line::FunctionCall { - function_name: function_name.clone(), - args: args.clone(), - return_data: vec![AssignmentTarget::Var(aux_var.clone())], + lines.push(Line::Statement { + targets: vec![AssignmentTarget::Var(aux_var.clone())], + value: Expression::FunctionCall { + function_name: function_name.clone(), + args: args.clone(), + }, line_number: 0, }); (Expression::Value(SimpleExpr::Var(aux_var)), lines) @@ -2454,17 +2167,16 @@ fn handle_inlined_functions_helper( Line::Break | Line::Panic | Line::LocationReport { .. } => { lines_out.push(line.clone()); } - Line::FunctionCall { - function_name, - args, - return_data, + Line::Statement { + targets, + value: Expression::FunctionCall { function_name, args }, line_number: _, } => { if let Some(func) = inlined_functions.get(function_name) { let mut inlined_lines = vec![]; // Only add forward declarations for variable targets, not array accesses - for target in return_data.iter() { + for target in targets.iter() { if let AssignmentTarget::Var(var) = target && !ctx.defines(var) { @@ -2480,7 +2192,7 @@ fn handle_inlined_functions_helper( } else { let aux_var = format!("@inlined_var_{}", inlined_var_counter.next()); // Check if the argument is a function call to an inlined function - // If so, create a Line::FunctionCall so it gets inlined in subsequent iterations + // If so, create a Line::Statement so it gets inlined in subsequent iterations if let Expression::FunctionCall { function_name: arg_func_name, args: arg_args, @@ -2488,22 +2200,26 @@ fn handle_inlined_functions_helper( { if inlined_functions.contains_key(arg_func_name) { inlined_lines.push(Line::ForwardDeclaration { var: aux_var.clone() }); - inlined_lines.push(Line::FunctionCall { - function_name: arg_func_name.clone(), - args: arg_args.clone(), - return_data: vec![AssignmentTarget::Var(aux_var.clone())], + inlined_lines.push(Line::Statement { + targets: vec![AssignmentTarget::Var(aux_var.clone())], + value: Expression::FunctionCall { + function_name: arg_func_name.clone(), + args: arg_args.clone(), + }, line_number: 0, }); } else { - inlined_lines.push(Line::Assignment { - var: aux_var.clone(), + inlined_lines.push(Line::Statement { + targets: vec![AssignmentTarget::Var(aux_var.clone())], value: arg.clone(), + line_number: 0, }); } } else { - inlined_lines.push(Line::Assignment { - var: aux_var.clone(), + inlined_lines.push(Line::Statement { + targets: vec![AssignmentTarget::Var(aux_var.clone())], value: arg.clone(), + line_number: 0, }); } simplified_args.push(SimpleExpr::Var(aux_var)); @@ -2517,13 +2233,34 @@ fn handle_inlined_functions_helper( .map(|((var, _), expr)| (var.clone(), expr.clone())) .collect::>(); let mut func_body = func.body.clone(); - inline_lines(&mut func_body, &inlined_args, return_data, total_inlined_counter.next()); + inline_lines(&mut func_body, &inlined_args, targets, total_inlined_counter.next()); inlined_lines.extend(func_body); lines_out.extend(inlined_lines); } else { lines_out.push(line.clone()); } } + Line::Statement { + targets, + value, + line_number, + } => { + let (value, value_lines) = + extract_inlined_calls_from_expr(value, inlined_functions, inlined_var_counter); + lines_out.extend(value_lines); + for target in targets { + if let AssignmentTarget::Var(var) = target + && !ctx.defines(var) + { + ctx.add_var(var); + } + } + lines_out.push(Line::Statement { + targets: targets.clone(), + value, + line_number: *line_number, + }); + } Line::IfCondition { condition, then_branch, @@ -2638,31 +2375,6 @@ fn handle_inlined_functions_helper( line_number: *line_number, }); } - Line::Assignment { var, value } => { - let (value, value_lines) = - extract_inlined_calls_from_expr(value, inlined_functions, inlined_var_counter); - lines_out.extend(value_lines); - if !ctx.defines(var) { - ctx.add_var(var); - } - lines_out.push(Line::Assignment { - var: var.clone(), - value, - }); - } - Line::ArrayAssign { array, index, value } => { - let (index, index_lines) = - extract_inlined_calls_from_expr(index, inlined_functions, inlined_var_counter); - lines_out.extend(index_lines); - let (value, value_lines) = - extract_inlined_calls_from_expr(value, inlined_functions, inlined_var_counter); - lines_out.extend(value_lines); - lines_out.push(Line::ArrayAssign { - array: array.clone(), - index, - value, - }); - } Line::Print { line_info, content } => { let mut new_content = vec![]; for expr in content { @@ -2823,13 +2535,12 @@ fn handle_const_arguments_helper( let mut changed = false; 'outer: for line in lines { match line { - Line::FunctionCall { - function_name, - args, - return_data: _, + Line::Statement { + targets: _, + value: Expression::FunctionCall { function_name, args }, line_number: _, } => { - if let Some(func) = constant_functions.get(function_name) { + if let Some(func) = constant_functions.get(function_name.as_str()) { // Check if all const arguments can be evaluated let mut const_evals = Vec::new(); for (arg_expr, (arg_var, is_constant)) in args.iter().zip(&func.arguments) { @@ -2889,6 +2600,7 @@ fn handle_const_arguments_helper( ); } } + Line::Statement { .. } => {} Line::IfCondition { then_branch, else_branch, @@ -2980,25 +2692,12 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { Line::ForwardDeclaration { var } => { assert!(!map.contains_key(var), "Variable {var} is a constant"); } - Line::Assignment { var, value } => { - assert!(!map.contains_key(var), "Variable {var} is a constant"); - replace_vars_by_const_in_expr(value, map); - } - Line::ArrayAssign { array, index, value } => { - if let SimpleExpr::Var(array_var) = array { - assert!(!map.contains_key(array_var), "Array {array_var} is a constant"); - } - replace_vars_by_const_in_expr(index, map); + Line::Statement { targets, value, .. } => { replace_vars_by_const_in_expr(value, map); - } - Line::FunctionCall { args, return_data, .. } => { - for arg in args { - replace_vars_by_const_in_expr(arg, map); - } - for target in return_data { + for target in targets { match target { - AssignmentTarget::Var(ret) => { - assert!(!map.contains_key(ret), "Return variable {ret} is a constant"); + AssignmentTarget::Var(var) => { + assert!(!map.contains_key(var), "Variable {var} is a constant"); } AssignmentTarget::ArrayAccess { array, index } => { if let SimpleExpr::Var(array_var) = array { @@ -3129,7 +2828,7 @@ impl SimpleLine { Self::RawAccess { res, index, shift } => { format!("{res} = memory[{index} + {shift}]") } - Self::TestZero { operation, arg0, arg1 } => { + Self::AssertZero { operation, arg0, arg1 } => { format!("0 = {arg0} {operation} {arg1}") } Self::IfNotZero { diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index 7c581720..f6a4f86f 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -198,7 +198,7 @@ fn compile_lines( )); } - SimpleLine::TestZero { operation, arg0, arg1 } => { + SimpleLine::AssertZero { operation, arg0, arg1 } => { instructions.push(IntermediateInstruction::computation( *operation, IntermediateValue::from_simple_expr(arg0, compiler), diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index 9ba1123a..9a1f57f2 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -23,9 +23,6 @@ return_count = { "->" ~ number } // Statements statement = { forward_declaration | - function_call | - single_assignment | - array_assign | if_statement | for_statement | match_statement | @@ -36,7 +33,8 @@ statement = { assert_not_eq_statement | debug_assert_eq_statement | debug_assert_not_eq_statement | - debug_assert_lt_statement + debug_assert_lt_statement | + assignment } return_statement = { "return" ~ (tuple_expression)? ~ ";" } @@ -46,9 +44,10 @@ continue_statement = { "continue" ~ ";" } forward_declaration = { "var" ~ identifier ~ ";" } -single_assignment = { identifier ~ "=" ~ expression ~ ";" } - -array_assign = { identifier ~ "[" ~ expression ~ "]" ~ "=" ~ expression ~ ";" } +// General assignment: LHS is optional list of variables/array accesses, RHS is any expression +assignment = { (assignment_target_list ~ "=")? ~ expression ~ ";" } +assignment_target_list = { assignment_target ~ ("," ~ assignment_target)* } +assignment_target = { array_access_expr | identifier } if_statement = { "if" ~ condition ~ "{" ~ statement* ~ "}" ~ else_if_clause* ~ else_clause? } @@ -68,11 +67,6 @@ match_statement = { "match" ~ expression ~ "{" ~ match_arm* ~ "}" } match_arm = { pattern ~ "=>" ~ "{" ~ statement* ~ "}" } pattern = { constant_value } -function_call = { function_res? ~ identifier ~ "(" ~ tuple_expression? ~ ")" ~ ";" } -function_res = { return_target_list ~ "=" } -return_target_list = { return_target ~ ("," ~ return_target)* } -return_target = { array_access_expr | identifier } - assert_eq_statement = { "assert" ~ add_expr ~ "==" ~ add_expr ~ ";" } assert_not_eq_statement = { "assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" } diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 47d96fab..3a421853 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -419,15 +419,10 @@ pub enum Line { ForwardDeclaration { var: Var, }, - Assignment { - var: Var, - value: Expression, - }, - ArrayAssign { - // array[index] = value - array: SimpleExpr, - index: Expression, - value: Expression, + Statement { + targets: Vec, // LHS - can be empty for standalone calls + value: Expression, // RHS - any expression + line_number: SourceLineNumber, }, Assert { debug: bool, @@ -449,12 +444,6 @@ pub enum Line { unroll: bool, line_number: SourceLineNumber, }, - FunctionCall { - function_name: String, - args: Vec, - return_data: Vec, // Changed from Vec - line_number: SourceLineNumber, - }, FunctionRet { return_data: Vec, }, @@ -582,11 +571,17 @@ impl Line { Self::ForwardDeclaration { var } => { format!("var {var}") } - Self::Assignment { var, value } => { - format!("{var} = {value}") - } - Self::ArrayAssign { array, index, value } => { - format!("{array}[{index}] = {value}") + Self::Statement { targets, value, .. } => { + if targets.is_empty() { + format!("{value}") + } else { + let targets_str = targets + .iter() + .map(|target| target.to_string()) + .collect::>() + .join(", "); + format!("{targets_str} = {value}") + } } Self::PrivateInputStart { result } => { format!("{result} = private_input_start()") @@ -645,25 +640,6 @@ impl Line { spaces ) } - Self::FunctionCall { - function_name, - args, - return_data, - line_number: _, - } => { - let args_str = args.iter().map(|arg| format!("{arg}")).collect::>().join(", "); - let return_data_str = return_data - .iter() - .map(|target| target.to_string()) - .collect::>() - .join(", "); - - if return_data.is_empty() { - format!("{function_name}({args_str})") - } else { - format!("{return_data_str} = {function_name}({args_str})") - } - } Self::FunctionRet { return_data } => { let return_data_str = return_data .iter() diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index efda8db0..e351cc8e 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -116,23 +116,12 @@ impl Parse for ReturnCountParser { } } -/// Parser for return target lists (used in function calls). -pub struct ReturnTargetListParser; +/// Parser for individual assignment targets (variable or array access). +pub struct AssignmentTargetParser; -impl Parse> for ReturnTargetListParser { - fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult> { - pair.into_inner() - .map(|item| ReturnTargetParser.parse(item, ctx)) - .collect() - } -} - -/// Parser for individual return targets (variable or array access). -pub struct ReturnTargetParser; - -impl Parse for ReturnTargetParser { +impl Parse for AssignmentTargetParser { fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let inner = next_inner_pair(&mut pair.into_inner(), "return target")?; + let inner = next_inner_pair(&mut pair.into_inner(), "assignment target")?; match inner.as_rule() { Rule::array_access_expr => { @@ -150,35 +139,31 @@ impl Parse for ReturnTargetParser { } } -/// Parser for function calls with special handling for built-in functions. -pub struct FunctionCallParser; +pub struct AssignmentParser; -impl Parse for FunctionCallParser { +impl Parse for AssignmentParser { fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let mut return_data: Vec = Vec::new(); - let mut function_name = String::new(); - let mut args = Vec::new(); let line_number = pair.line_col().0; + let mut inner = pair.into_inner().peekable(); - for item in pair.into_inner() { - match item.as_rule() { - Rule::function_res => { - for res_item in item.into_inner() { - if res_item.as_rule() == Rule::return_target_list { - return_data = ReturnTargetListParser.parse(res_item, ctx)?; - } - } - } - Rule::identifier => function_name = item.as_str().to_string(), - Rule::tuple_expression => { - args = TupleExpressionParser.parse(item, ctx)?; - } - _ => {} - } + // Check if there's an assignment_target_list (LHS) + let mut targets: Vec = Vec::new(); + if let Some(first) = inner.peek() + && first.as_rule() == Rule::assignment_target_list + { + targets = inner + .next() + .unwrap() + .into_inner() + .map(|item| AssignmentTargetParser.parse(item, ctx)) + .collect::>>()?; } - // Replace trash variables with unique names - for target in &mut return_data { + // Parse the expression (RHS) + let expr_pair = next_inner_pair(&mut inner, "expression")?; + let expr = ExpressionParser.parse(expr_pair, ctx)?; + + for target in &mut targets { if let AssignmentTarget::Var(var) = target && var == "_" { @@ -186,13 +171,34 @@ impl Parse for FunctionCallParser { } } - // Handle built-in functions - Self::handle_builtin_function(line_number, function_name, args, return_data) + match &expr { + Expression::FunctionCall { function_name, args } => { + Self::handle_function_call(line_number, function_name.clone(), args.clone(), targets) + } + _ => { + // Non-function-call expression - must have exactly one target + if targets.is_empty() { + return Err(SemanticError::new("Expression statement has no effect").into()); + } + if targets.len() > 1 { + return Err(SemanticError::new( + "Multiple assignment targets require a function call on the right side", + ) + .into()); + } + + Ok(Line::Statement { + targets, + value: expr, + line_number, + }) + } + } } } -impl FunctionCallParser { - fn handle_builtin_function( +impl AssignmentParser { + fn handle_function_call( line_number: SourceLineNumber, function_name: String, args: Vec, @@ -288,10 +294,9 @@ impl FunctionCallParser { return Ok(Line::CustomHint(hint, args)); } // Regular function call - allow array access targets - Ok(Line::FunctionCall { - function_name, - args, - return_data, + Ok(Line::Statement { + targets: return_data, + value: Expression::FunctionCall { function_name, args }, line_number, }) } diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index 62738a00..915b8f27 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -1,7 +1,7 @@ use lean_vm::{Boolean, BooleanExpr}; use super::expression::ExpressionParser; -use super::function::{FunctionCallParser, TupleExpressionParser}; +use super::function::{AssignmentParser, TupleExpressionParser}; use super::literal::ConstExprParser; use super::{Parse, ParseContext, next_inner_pair}; use crate::{ @@ -23,13 +23,11 @@ impl Parse for StatementParser { match inner.as_rule() { Rule::forward_declaration => ForwardDeclarationParser.parse(inner, ctx), - Rule::single_assignment => AssignmentParser.parse(inner, ctx), - Rule::array_assign => ArrayAssignParser.parse(inner, ctx), + Rule::assignment => AssignmentParser.parse(inner, ctx), Rule::if_statement => IfStatementParser.parse(inner, ctx), Rule::for_statement => ForStatementParser.parse(inner, ctx), Rule::match_statement => MatchStatementParser.parse(inner, ctx), Rule::return_statement => ReturnStatementParser.parse(inner, ctx), - Rule::function_call => FunctionCallParser.parse(inner, ctx), Rule::assert_eq_statement => AssertEqParser::.parse(inner, ctx), Rule::assert_not_eq_statement => AssertNotEqParser::.parse(inner, ctx), Rule::debug_assert_eq_statement => AssertEqParser::.parse(inner, ctx), @@ -53,38 +51,6 @@ impl Parse for ForwardDeclarationParser { } } -/// Parser for variable assignments. -pub struct AssignmentParser; - -impl Parse for AssignmentParser { - fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let mut inner = pair.into_inner(); - let var = next_inner_pair(&mut inner, "variable name")?.as_str().to_string(); - let expr = next_inner_pair(&mut inner, "assignment value")?; - let value = ExpressionParser.parse(expr, ctx)?; - - Ok(Line::Assignment { var, value }) - } -} - -/// Parser for array element assignments. -pub struct ArrayAssignParser; - -impl Parse for ArrayAssignParser { - fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let mut inner = pair.into_inner(); - let array = next_inner_pair(&mut inner, "array name")?.as_str().to_string(); - let index = ExpressionParser.parse(next_inner_pair(&mut inner, "array index")?, ctx)?; - let value = ExpressionParser.parse(next_inner_pair(&mut inner, "array value")?, ctx)?; - - Ok(Line::ArrayAssign { - array: array.into(), - index, - value, - }) - } -} - /// Parser for if-else conditional statements. pub struct IfStatementParser;