diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 2f30428a..72192c41 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -40,12 +40,7 @@ pub enum VarOrConstMallocAccess { impl From for SimpleExpr { fn from(var_or_const: VarOrConstMallocAccess) -> Self { - match var_or_const { - VarOrConstMallocAccess::Var(var) => Self::Var(var), - VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset } => { - Self::ConstMallocAccess { malloc_label, offset } - } - } + Self::Memory(var_or_const) } } @@ -54,10 +49,7 @@ impl TryInto for SimpleExpr { fn try_into(self) -> Result { match self { - Self::Var(var) => Ok(VarOrConstMallocAccess::Var(var)), - Self::ConstMallocAccess { malloc_label, offset } => { - Ok(VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset }) - } + Self::Memory(var_or_const) => Ok(var_or_const), _ => Err(()), } } @@ -395,8 +387,7 @@ fn check_block_scoping(block: &[Line], ctx: &mut Context) { } // Second pass: check array access targets for target in targets { - if let AssignmentTarget::ArrayAccess { array, index } = target { - check_simple_expr_scoping(array, ctx); + if let AssignmentTarget::ArrayAccess { array: _, index } = target { check_expr_scoping(index, ctx); } } @@ -480,8 +471,7 @@ fn check_expr_scoping(expr: &Expression, ctx: &Context) { Expression::Value(simple_expr) => { check_simple_expr_scoping(simple_expr, ctx); } - Expression::ArrayAccess { array, index } => { - check_simple_expr_scoping(array, ctx); + Expression::ArrayAccess { array: _, index } => { for idx in index { check_expr_scoping(idx, ctx); } @@ -507,11 +497,11 @@ fn check_expr_scoping(expr: &Expression, ctx: &Context) { /// Analyzes the simple expression to verify that each variable is defined in the given context. fn check_simple_expr_scoping(expr: &SimpleExpr, ctx: &Context) { match expr { - SimpleExpr::Var(v) => { + SimpleExpr::Memory(VarOrConstMallocAccess::Var(v)) => { assert!(ctx.defines(v), "Variable used but not defined: {v}"); } + SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { .. }) => {} SimpleExpr::Constant(_) => {} - SimpleExpr::ConstMallocAccess { .. } => {} } } @@ -558,8 +548,8 @@ struct SimplifyState<'a> { #[derive(Debug, Clone, Default)] struct ArrayManager { counter: usize, - aux_vars: BTreeMap<(SimpleExpr, Expression), Var>, // (array, index) -> aux_var - valid: BTreeSet, // currently valid aux vars + aux_vars: BTreeMap<(Var, Expression), Var>, // (array, index) -> aux_var + valid: BTreeSet, // currently valid aux vars } #[derive(Debug, Clone, Default)] @@ -569,7 +559,7 @@ pub struct ConstMalloc { } impl ArrayManager { - fn get_aux_var(&mut self, array: &SimpleExpr, index: &Expression) -> Var { + fn get_aux_var(&mut self, array: &Var, index: &Expression) -> Var { if let Some(var) = self.aux_vars.get(&(array.clone(), index.clone())) { return var.clone(); } @@ -650,7 +640,7 @@ fn simplify_lines( .collect::>(); let mut temp_vars = Vec::new(); - let mut array_targets: Vec<(usize, SimpleExpr, Box)> = Vec::new(); + let mut array_targets: Vec<(usize, Var, Box)> = Vec::new(); for (i, target) in targets.iter().enumerate() { match target { @@ -678,11 +668,11 @@ fn simplify_lines( state, const_malloc, &mut res, - array, + &array, &[*index], - ArrayAccessType::ArrayIsAssigned(Expression::Value(SimpleExpr::Var( - temp_vars[i].clone(), - ))), + ArrayAccessType::ArrayIsAssigned(Expression::Value( + VarOrConstMallocAccess::Var(temp_vars[i].clone()).into(), + )), ); } } @@ -703,7 +693,7 @@ fn simplify_lines( state, const_malloc, &mut res, - array.clone(), + array, index, ArrayAccessType::VarIsAssigned(var.clone()), ); @@ -740,7 +730,7 @@ fn simplify_lines( state, const_malloc, &mut res, - array.clone(), + array, std::slice::from_ref(&**index), ArrayAccessType::ArrayIsAssigned(value.clone()), ); @@ -950,14 +940,16 @@ fn simplify_lines( let start_simplified = simplify_expr(ctx, state, const_malloc, start, &mut res); let mut end_simplified = simplify_expr(ctx, state, const_malloc, end, &mut res); - if let SimpleExpr::ConstMallocAccess { malloc_label, offset } = end_simplified.clone() { + if let SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset }) = + end_simplified.clone() + { // we use an auxilary variable to store the end value (const malloc inside non-unrolled loops does not work) let aux_end_var = state.counters.aux_var(); res.push(SimpleLine::equality( aux_end_var.clone(), - SimpleExpr::ConstMallocAccess { malloc_label, offset }, + VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset }, )); - end_simplified = SimpleExpr::Var(aux_end_var); + end_simplified = VarOrConstMallocAccess::Var(aux_end_var).into(); } for (simplified, original) in [ @@ -966,7 +958,7 @@ fn simplify_lines( ] { if !matches!(original, Expression::Value(_)) { // the simplified var is auxiliary - if let SimpleExpr::Var(var) = simplified { + if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) = simplified { external_vars.push(var); } } @@ -1100,12 +1092,10 @@ fn simplify_expr( lines: &mut Vec, ) -> SimpleExpr { match expr { - Expression::Value(value) => value.simplify_if_const(), + Expression::Value(value) => value.clone(), Expression::ArrayAccess { array, index } => { // Check for const array access - if let SimpleExpr::Var(array_var) = array - && let Some(arr) = ctx.const_arrays.get(array_var) - { + if let Some(arr) = ctx.const_arrays.get(array) { let simplified_index = index .iter() .map(|idx| { @@ -1129,20 +1119,20 @@ fn simplify_expr( assert_eq!(index.len(), 1); let index = index[0].clone(); - if let SimpleExpr::Var(array_var) = array - && let Some(label) = const_malloc.map.get(array_var) + if let Some(label) = const_malloc.map.get(array) && let Ok(offset) = ConstExpression::try_from(index.clone()) { - return SimpleExpr::ConstMallocAccess { + return VarOrConstMallocAccess::ConstMallocAccess { malloc_label: *label, offset, - }; + } + .into(); } let aux_arr = state.array_manager.get_aux_var(array, &index); // auxiliary var to store m[array + index] if !state.array_manager.valid.insert(aux_arr.clone()) { - return SimpleExpr::Var(aux_arr); + return VarOrConstMallocAccess::Var(aux_arr).into(); } handle_array_assignment( @@ -1150,11 +1140,11 @@ fn simplify_expr( state, const_malloc, lines, - array.clone(), + array, &[index], ArrayAccessType::VarIsAssigned(aux_arr.clone()), ); - SimpleExpr::Var(aux_arr) + VarOrConstMallocAccess::Var(aux_arr).into() } Expression::MathExpr(operation, args) => { let simplified_args = args @@ -1172,7 +1162,7 @@ fn simplify_expr( arg0: simplified_args[0].clone(), arg1: simplified_args[1].clone(), }); - SimpleExpr::Var(aux_var) + VarOrConstMallocAccess::Var(aux_var).into() } Expression::FunctionCall { function_name, args } => { let function = ctx @@ -1200,7 +1190,7 @@ fn simplify_expr( line_number: 0, // No source line number for nested calls }); - SimpleExpr::Var(result_var) + VarOrConstMallocAccess::Var(result_var).into() } Expression::Len { .. } => unreachable!(), } @@ -1254,11 +1244,9 @@ pub fn find_variable_usage( internal_vars.insert(var.clone()); } AssignmentTarget::ArrayAccess { array, index } => { - if let SimpleExpr::Var(var) = array { - assert!(!const_arrays.contains_key(var), "Cannot assign to const array"); - if !internal_vars.contains(var) { - external_vars.insert(var.clone()); - } + assert!(!const_arrays.contains_key(array), "Cannot assign to const array"); + if !internal_vars.contains(array) { + external_vars.insert(array.clone()); } on_new_expr(index, &internal_vars, &mut external_vars); } @@ -1342,7 +1330,7 @@ pub fn find_variable_usage( } fn inline_simple_expr(simple_expr: &mut SimpleExpr, args: &BTreeMap, inlining_count: usize) { - if let SimpleExpr::Var(var) = simple_expr { + if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) = simple_expr { if let Some(replacement) = args.get(var) { *simple_expr = replacement.clone(); } else { @@ -1357,7 +1345,14 @@ fn inline_expr(expr: &mut Expression, args: &BTreeMap, inlining inline_simple_expr(value, args, inlining_count); } Expression::ArrayAccess { array, index } => { - inline_simple_expr(array, args, inlining_count); + if let Some(replacement) = args.get(array) { + let SimpleExpr::Memory(VarOrConstMallocAccess::Var(new_array)) = replacement else { + panic!("Cannot inline array access with non-variable array argument"); + }; + *array = new_array.clone(); + } else { + *array = format!("@inlined_var_{inlining_count}_{array}"); + } for idx in index { inline_expr(idx, args, inlining_count); } @@ -1424,7 +1419,16 @@ fn inline_lines( inline_internal_var(var); } AssignmentTarget::ArrayAccess { array, index } => { - inline_simple_expr(array, args, inlining_count); + if let Some(replacement) = args.get(array) { + // Array is a function argument - replace with the argument's var name + let SimpleExpr::Memory(VarOrConstMallocAccess::Var(new_array)) = replacement else { + panic!("Cannot inline array access target with non-variable array argument"); + }; + *array = new_array.clone(); + } else { + // Internal variable - rename with inlining prefix + *array = format!("@inlined_var_{inlining_count}_{array}"); + } inline_expr(index, args, inlining_count); } } @@ -1513,14 +1517,12 @@ fn vars_in_expression(expr: &Expression, const_arrays: &BTreeMap { - if let SimpleExpr::Var(var) = value { + if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) = value { vars.insert(var.clone()); } } Expression::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array) = array - && !const_arrays.contains_key(array) - { + if !const_arrays.contains_key(array) { vars.insert(array.clone()); } for idx in index { @@ -1557,7 +1559,7 @@ fn handle_array_assignment( state: &mut SimplifyState<'_>, const_malloc: &ConstMalloc, res: &mut Vec, - array: SimpleExpr, + array: &Var, index: &[Expression], access_type: ArrayAccessType, ) { @@ -1566,8 +1568,8 @@ fn handle_array_assignment( .map(|idx| simplify_expr(ctx, state, const_malloc, idx, res)) .collect::>(); - if let (ArrayAccessType::VarIsAssigned(var), SimpleExpr::Var(array_var)) = (&access_type, &array) - && let Some(const_array) = ctx.const_arrays.get(array_var) + if let ArrayAccessType::VarIsAssigned(var) = &access_type + && let Some(const_array) = ctx.const_arrays.get(array) { let idx = simplified_index .iter() @@ -1590,8 +1592,7 @@ fn handle_array_assignment( if simplified_index.len() == 1 && let SimpleExpr::Constant(offset) = simplified_index[0].clone() - && let SimpleExpr::Var(array_var) = &array - && let Some(label) = const_malloc.map.get(array_var) + && let Some(label) = const_malloc.map.get(array) && let ArrayAccessType::ArrayIsAssigned(Expression::MathExpr(operation, args)) = &access_type { let var = VarOrConstMallocAccess::ConstMallocAccess { @@ -1617,7 +1618,7 @@ fn handle_array_assignment( } let value_simplified = match access_type { - ArrayAccessType::VarIsAssigned(var) => SimpleExpr::Var(var), + ArrayAccessType::VarIsAssigned(var) => SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)), ArrayAccessType::ArrayIsAssigned(expr) => simplify_expr(ctx, state, const_malloc, &expr, res), }; @@ -1625,17 +1626,20 @@ fn handle_array_assignment( assert_eq!(simplified_index.len(), 1); let simplified_index = simplified_index[0].clone(); let (index_var, shift) = match simplified_index { - SimpleExpr::Constant(c) => (array, c), + SimpleExpr::Constant(c) => (SimpleExpr::Memory(VarOrConstMallocAccess::Var(array.clone())), c), _ => { // Create pointer variable: ptr = array + index let ptr_var = state.counters.aux_var(); res.push(SimpleLine::Assignment { var: ptr_var.clone().into(), operation: MathOperation::Add, - arg0: array, + arg0: SimpleExpr::Memory(VarOrConstMallocAccess::Var(array.clone())), arg1: simplified_index, }); - (SimpleExpr::Var(ptr_var), ConstExpression::zero()) + ( + SimpleExpr::Memory(VarOrConstMallocAccess::Var(ptr_var)), + ConstExpression::zero(), + ) } }; @@ -1711,21 +1715,19 @@ fn replace_vars_for_unroll_in_expr( ) { match expr { Expression::Value(value_expr) => match value_expr { - SimpleExpr::Var(var) => { + SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) => { if var == iterator { *value_expr = SimpleExpr::Constant(ConstExpression::from(iterator_value)); } else if internal_vars.contains(var) { *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); } } - SimpleExpr::Constant(_) | SimpleExpr::ConstMallocAccess { .. } => {} + SimpleExpr::Constant(_) | SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { .. }) => {} }, Expression::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array_var) = array { - assert!(array_var != iterator, "Weird"); - if internal_vars.contains(array_var) { - *array_var = format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); - } + assert!(array != iterator, "Weird"); + if internal_vars.contains(array) { + *array = format!("@unrolled_{unroll_index}_{iterator_value}_{array}"); } for index in index { replace_vars_for_unroll_in_expr(index, iterator, unroll_index, iterator_value, internal_vars); @@ -1776,11 +1778,9 @@ fn replace_vars_for_unroll( *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); } AssignmentTarget::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array_var) = array { - assert!(array_var != iterator, "Weird"); - if internal_vars.contains(array_var) { - *array_var = format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); - } + assert!(array != iterator, "Weird"); + if internal_vars.contains(array) { + *array = format!("@unrolled_{unroll_index}_{iterator_value}_{array}"); } replace_vars_for_unroll_in_expr( index, @@ -2032,7 +2032,7 @@ fn extract_inlined_calls_from_expr( }, line_number: 0, }); - (Expression::Value(SimpleExpr::Var(aux_var)), lines) + (Expression::Value(VarOrConstMallocAccess::Var(aux_var).into()), lines) } else { (expr.clone(), lines) } @@ -2157,7 +2157,7 @@ fn handle_inlined_functions_helper( line_number: 0, }); } - simplified_args.push(SimpleExpr::Var(aux_var)); + simplified_args.push(VarOrConstMallocAccess::Var(aux_var).into()); } } assert_eq!(simplified_args.len(), func.arguments.len()); @@ -2575,20 +2575,18 @@ fn handle_const_arguments_helper( fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap) { match expr { Expression::Value(value) => match &value { - SimpleExpr::Var(var) => { + SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) => { if let Some(const_value) = map.get(var) { *value = SimpleExpr::scalar(const_value.to_usize()); } } - SimpleExpr::ConstMallocAccess { .. } => { + SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { .. }) => { unreachable!() } SimpleExpr::Constant(_) => {} }, Expression::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array_var) = array { - assert!(!map.contains_key(array_var), "Array {array_var} is a constant"); - } + assert!(!map.contains_key(array), "Array {array} is a constant"); for index in index { replace_vars_by_const_in_expr(index, map); } @@ -2631,9 +2629,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { assert!(!map.contains_key(var), "Variable {var} is a constant"); } AssignmentTarget::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array_var) = array { - assert!(!map.contains_key(array_var), "Array {array_var} is a constant"); - } + assert!(!map.contains_key(array), "Array {array} is a constant"); replace_vars_by_const_in_expr(index, map); } } diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index 3cb3762d..06bb2505 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -63,16 +63,18 @@ impl Compiler { impl SimpleExpr { fn to_mem_after_fp_or_constant(&self, compiler: &Compiler) -> IntermediateValue { match self { - Self::Var(var) => IntermediateValue::MemoryAfterFp { + Self::Memory(VarOrConstMallocAccess::Var(var)) => IntermediateValue::MemoryAfterFp { offset: compiler.get_offset(&var.clone().into()), }, + Self::Memory(VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset }) => { + IntermediateValue::MemoryAfterFp { + offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: *malloc_label, + offset: offset.clone(), + }), + } + } Self::Constant(c) => IntermediateValue::Constant(c.clone()), - Self::ConstMallocAccess { malloc_label, offset } => IntermediateValue::MemoryAfterFp { - offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess { - malloc_label: *malloc_label, - offset: offset.clone(), - }), - }, } } } @@ -80,16 +82,18 @@ impl SimpleExpr { impl IntermediateValue { fn from_simple_expr(expr: &SimpleExpr, compiler: &Compiler) -> Self { match expr { - SimpleExpr::Var(var) => Self::MemoryAfterFp { + SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) => Self::MemoryAfterFp { offset: compiler.get_offset(&var.clone().into()), }, + SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset }) => { + Self::MemoryAfterFp { + offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: *malloc_label, + offset: offset.clone(), + }), + } + } SimpleExpr::Constant(c) => Self::Constant(c.clone()), - SimpleExpr::ConstMallocAccess { malloc_label, offset } => Self::MemoryAfterFp { - offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess { - malloc_label: *malloc_label, - offset: offset.clone(), - }), - }, } } @@ -373,7 +377,7 @@ fn compile_lines( } SimpleLine::RawAccess { res, index, shift } => { - if let SimpleExpr::Var(var) = res + if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) = res && !compiler.is_in_scope(var) { let current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap(); diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 3ceda7fc..4afe46a0 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -5,6 +5,7 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Display, Formatter}; use utils::ToUsize; +use crate::a_simplify_lang::VarOrConstMallocAccess; use crate::{F, parser::ConstArrayValue}; pub use lean_vm::{FileId, FunctionName, SourceLocation}; @@ -39,12 +40,8 @@ pub type ConstMallocLabel = usize; #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum SimpleExpr { - Var(Var), + Memory(VarOrConstMallocAccess), Constant(ConstExpression), - ConstMallocAccess { - malloc_label: ConstMallocLabel, - offset: ConstExpression, - }, } impl SimpleExpr { @@ -63,13 +60,6 @@ impl SimpleExpr { pub const fn is_constant(&self) -> bool { matches!(self, Self::Constant(_)) } - - pub fn simplify_if_const(&self) -> Self { - if let Self::Constant(constant) = self { - return constant.clone().into(); - } - self.clone() - } } impl From for SimpleExpr { @@ -86,16 +76,15 @@ impl From for SimpleExpr { impl From for SimpleExpr { fn from(var: Var) -> Self { - Self::Var(var) + VarOrConstMallocAccess::Var(var).into() } } impl SimpleExpr { pub fn as_constant(&self) -> Option { match self { - Self::Var(_) => None, Self::Constant(constant) => Some(constant.clone()), - Self::ConstMallocAccess { .. } => None, + Self::Memory(_) => None, } } @@ -226,7 +215,7 @@ impl Display for Condition { pub enum Expression { Value(SimpleExpr), ArrayAccess { - array: SimpleExpr, + array: Var, index: Vec, // multi-dimensional array access }, MathExpr(MathOperation, Vec), @@ -356,10 +345,7 @@ impl Expression { self.eval_with( &|value: &SimpleExpr| value.as_constant()?.naive_eval(), &|arr, indexes| { - let SimpleExpr::Var(name) = arr else { - return None; - }; - let array = const_arrays.get(name)?; + let array = const_arrays.get(arr)?; assert_eq!(indexes.len(), array.depth()); let idx = indexes.iter().map(|e| e.to_usize()).collect::>(); array.navigate(&idx)?.as_scalar().map(F::from_usize) @@ -370,7 +356,7 @@ impl Expression { pub fn eval_with(&self, value_fn: &ValueFn, array_fn: &ArrayFn) -> Option where ValueFn: Fn(&SimpleExpr) -> Option, - ArrayFn: Fn(&SimpleExpr, Vec) -> Option, + ArrayFn: Fn(&Var, Vec) -> Option, { match self { Self::Value(value) => value_fn(value), @@ -405,7 +391,7 @@ impl Expression { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum AssignmentTarget { Var(Var), - ArrayAccess { array: SimpleExpr, index: Box }, + ArrayAccess { array: Var, index: Box }, } impl Display for AssignmentTarget { @@ -713,11 +699,8 @@ impl Display for ConstantValue { impl Display for SimpleExpr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::Var(var) => write!(f, "{var}"), Self::Constant(constant) => write!(f, "{constant}"), - Self::ConstMallocAccess { malloc_label, offset } => { - write!(f, "malloc_access({malloc_label}, {offset})") - } + Self::Memory(var_or_const_malloc_access) => write!(f, "{var_or_const_malloc_access}"), } } } diff --git a/crates/lean_compiler/src/parser/parsers/expression.rs b/crates/lean_compiler/src/parser/parsers/expression.rs index bf331a84..28473e5f 100644 --- a/crates/lean_compiler/src/parser/parsers/expression.rs +++ b/crates/lean_compiler/src/parser/parsers/expression.rs @@ -80,16 +80,13 @@ pub struct ArrayAccessParser; impl Parse for ArrayAccessParser { fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { let mut inner = pair.into_inner(); - let array_name = next_inner_pair(&mut inner, "array name")?.as_str().to_string(); + let array = next_inner_pair(&mut inner, "array name")?.as_str().to_string(); let index: Vec = inner .map(|idx_pair| ExpressionParser.parse(idx_pair, ctx)) .collect::, _>>()?; - Ok(Expression::ArrayAccess { - array: SimpleExpr::Var(array_name), - index, - }) + Ok(Expression::ArrayAccess { array, index }) } } diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index e351cc8e..ebfd7d63 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -3,7 +3,7 @@ use super::statement::StatementParser; use super::{Parse, ParseContext, next_inner_pair}; use crate::{ SourceLineNumber, - lang::{AssignmentTarget, Expression, Function, Line, SimpleExpr, SourceLocation}, + lang::{AssignmentTarget, Expression, Function, Line, SourceLocation}, parser::{ error::{ParseResult, SemanticError}, grammar::{ParsePair, Rule}, @@ -129,7 +129,7 @@ impl Parse for AssignmentTargetParser { let array = next_inner_pair(&mut inner_pairs, "array name")?.as_str().to_string(); let index = ExpressionParser.parse(next_inner_pair(&mut inner_pairs, "array index")?, ctx)?; Ok(AssignmentTarget::ArrayAccess { - array: SimpleExpr::Var(array), + array, index: Box::new(index), }) } diff --git a/crates/lean_compiler/src/parser/parsers/literal.rs b/crates/lean_compiler/src/parser/parsers/literal.rs index 6a245aa9..2c63a6e0 100644 --- a/crates/lean_compiler/src/parser/parsers/literal.rs +++ b/crates/lean_compiler/src/parser/parsers/literal.rs @@ -1,5 +1,6 @@ use super::expression::ExpressionParser; use super::{ConstArrayValue, Parse, ParseContext, ParsedConstant, next_inner_pair}; +use crate::a_simplify_lang::VarOrConstMallocAccess; use crate::{ F, lang::{ConstExpression, ConstantValue, SimpleExpr}, @@ -90,16 +91,13 @@ pub fn evaluate_const_expr(expr: &crate::lang::Expression, ctx: &ParseContext) - expr.eval_with( &|simple_expr| match simple_expr { SimpleExpr::Constant(cst) => cst.naive_eval(), - SimpleExpr::Var(var) => ctx.get_constant(var).map(F::from_usize), - SimpleExpr::ConstMallocAccess { .. } => None, + SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) => ctx.get_constant(var).map(F::from_usize), + SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { .. }) => None, }, &|arr, index| { // Support const array access in expressions - let SimpleExpr::Var(name) = arr else { - return None; - }; let idx = index.iter().map(|e| e.to_usize()).collect::>(); - let array = ctx.get_const_array(name)?; + let array = ctx.get_const_array(arr)?; array.navigate(&idx)?.as_scalar().map(F::from_usize) }, ) @@ -161,7 +159,7 @@ impl VarOrConstantParser { } // Otherwise treat as variable reference else { - Ok(SimpleExpr::Var(text.to_string())) + Ok(VarOrConstMallocAccess::Var(text.to_string()).into()) } } } diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 5fcc7037..b3c4c3e6 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -456,6 +456,36 @@ fn test_mini_program_4() { dbg!(&poseidon24_permute(public_input)[16..]); } +#[test] +fn test_mini_program_5() { + let program = r#" + fn main() { + arr = malloc(10); + arr[6] = 42; + arr[8] = 11; + sum_1 = func_1(arr[6], arr[8]); + assert sum_1 == 53; + return; + } + + fn func_1(i, j) inline -> 1 { + for k in 0..i { + for u in 0..j { + assert k + u != 1000000; + } + } + return i + j; + } + + "#; + compile_and_run( + &ProgramSource::Raw(program.to_string()), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); +} + #[test] fn test_inlined() { let program = r#"