diff --git a/CHANGELOG.md b/CHANGELOG.md index 34f7309ba..2c5b05a7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ - Update documentation and tests thereof (#437). - Add a constant propagation pass after other mir passes (#439). - Removed `TraceSegmentId::index()` and replaced segment indexing with `TraceShape`/`FullTraceShape` across ACE codegen, MIR-to-AIR pass, and constraints (#442). +- Allow computed indices (#444). - Fix regressions on MIR and list_comprehensions (#449). ## 0.4.0 (2025-06-20) diff --git a/air-script/tests/codegen/winterfell.rs b/air-script/tests/codegen/winterfell.rs index 7f8bb9ee8..28a993cd6 100644 --- a/air-script/tests/codegen/winterfell.rs +++ b/air-script/tests/codegen/winterfell.rs @@ -2,9 +2,6 @@ use expect_test::expect_file; use super::helpers::{Target, Test}; -// tests_wo_mir -// ================================================================================================ - #[test] fn binary() { let generated_air = Test::new("tests/binary/binary.air".to_string()) @@ -16,21 +13,12 @@ fn binary() { } #[test] -fn buses_simple() { - let generated_air = Test::new("tests/buses/buses_simple.air".to_string()) - .transpile(Target::Winterfell) - .unwrap(); - - let expected = expect_file!["../buses/buses_simple.rs"]; - expected.assert_eq(&generated_air); -} -#[test] -fn buses_simple_with_evaluators() { - let generated_air = Test::new("tests/buses/buses_simple_with_evaluators.air".to_string()) +fn bitwise() { + let generated_air = Test::new("tests/bitwise/bitwise.air".to_string()) .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../buses/buses_simple.rs"]; + let expected = expect_file!["../bitwise/bitwise.rs"]; expected.assert_eq(&generated_air); } @@ -45,22 +33,22 @@ fn buses_complex() { } #[test] -fn buses_varlen_boundary_first() { - let generated_air = Test::new("tests/buses/buses_varlen_boundary_first.air".to_string()) +fn buses_simple() { + let generated_air = Test::new("tests/buses/buses_simple.air".to_string()) .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../buses/buses_varlen_boundary_first.rs"]; + let expected = expect_file!["../buses/buses_simple.rs"]; expected.assert_eq(&generated_air); } #[test] -fn buses_varlen_boundary_last() { - let generated_air = Test::new("tests/buses/buses_varlen_boundary_last.air".to_string()) +fn buses_simple_with_evaluators() { + let generated_air = Test::new("tests/buses/buses_simple_with_evaluators.air".to_string()) .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../buses/buses_varlen_boundary_last.rs"]; + let expected = expect_file!["../buses/buses_simple.rs"]; expected.assert_eq(&generated_air); } @@ -75,42 +63,53 @@ fn buses_varlen_boundary_both() { } #[test] -fn periodic_columns() { - let generated_air = Test::new("tests/periodic_columns/periodic_columns.air".to_string()) +fn buses_varlen_boundary_first() { + let generated_air = Test::new("tests/buses/buses_varlen_boundary_first.air".to_string()) .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../periodic_columns/periodic_columns.rs"]; + let expected = expect_file!["../buses/buses_varlen_boundary_first.rs"]; expected.assert_eq(&generated_air); } #[test] -fn pub_inputs() { - let generated_air = Test::new("tests/pub_inputs/pub_inputs.air".to_string()) +fn buses_varlen_boundary_last() { + let generated_air = Test::new("tests/buses/buses_varlen_boundary_last.air".to_string()) .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../pub_inputs/pub_inputs.rs"]; + let expected = expect_file!["../buses/buses_varlen_boundary_last.rs"]; expected.assert_eq(&generated_air); } #[test] -fn system() { - let generated_air = Test::new("tests/system/system.air".to_string()) +fn computed_indices_complex() { + let generated_air = + Test::new("tests/computed_indices/computed_indices_complex.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../computed_indices/computed_indices_complex.rs"]; + expected.assert_eq(&generated_air); +} + +#[test] +fn computed_indices_simple() { + let generated_air = Test::new("tests/computed_indices/computed_indices_simple.air".to_string()) .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../system/system.rs"]; + let expected = expect_file!["../computed_indices/computed_indices_simple.rs"]; expected.assert_eq(&generated_air); } #[test] -fn bitwise() { - let generated_air = Test::new("tests/bitwise/bitwise.air".to_string()) +fn constant_in_range() { + let generated_air = Test::new("tests/constant_in_range/constant_in_range.air".to_string()) .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../bitwise/bitwise.rs"]; + let expected = expect_file!["../constant_in_range/constant_in_range.rs"]; expected.assert_eq(&generated_air); } @@ -125,12 +124,21 @@ fn constants() { } #[test] -fn constant_in_range() { - let generated_air = Test::new("tests/constant_in_range/constant_in_range.air".to_string()) - .transpile(Target::Winterfell) - .unwrap(); +fn constraint_comprehension() { + let generated_air = + Test::new("tests/constraint_comprehension/constraint_comprehension.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); - let expected = expect_file!["../constant_in_range/constant_in_range.rs"]; + let expected = expect_file!["../constraint_comprehension/constraint_comprehension.rs"]; + expected.assert_eq(&generated_air); + + let generated_air = + Test::new("tests/constraint_comprehension/cc_with_evaluators.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../constraint_comprehension/constraint_comprehension.rs"]; expected.assert_eq(&generated_air); } @@ -154,28 +162,6 @@ fn fibonacci() { expected.assert_eq(&generated_air); } -#[test] -fn functions_simple() { - let generated_air = Test::new("tests/functions/functions_simple.air".to_string()) - .transpile(Target::Winterfell) - .unwrap(); - - let expected = expect_file!["../functions/functions_simple.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn functions_simple_inlined() { - // make sure that the constraints generated using inlined functions are the same as the ones - // generated using regular functions - let generated_air = Test::new("tests/functions/inlined_functions_simple.air".to_string()) - .transpile(Target::Winterfell) - .unwrap(); - - let expected = expect_file!["../functions/functions_simple.rs"]; - expected.assert_eq(&generated_air); -} - #[test] fn functions_complex() { let generated_air = Test::new("tests/functions/functions_complex.air".to_string()) @@ -187,22 +173,24 @@ fn functions_complex() { } #[test] -fn variables() { - let generated_air = Test::new("tests/variables/variables.air".to_string()) +fn functions_simple() { + let generated_air = Test::new("tests/functions/functions_simple.air".to_string()) .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../variables/variables.rs"]; + let expected = expect_file!["../functions/functions_simple.rs"]; expected.assert_eq(&generated_air); } #[test] -fn trace_col_groups() { - let generated_air = Test::new("tests/trace_col_groups/trace_col_groups.air".to_string()) +fn functions_simple_inlined() { + // make sure that the constraints generated using inlined functions are the same as the ones + // generated using regular functions + let generated_air = Test::new("tests/functions/inlined_functions_simple.air".to_string()) .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../trace_col_groups/trace_col_groups.rs"]; + let expected = expect_file!["../functions/functions_simple.rs"]; expected.assert_eq(&generated_air); } @@ -248,6 +236,26 @@ fn list_folding() { expected.assert_eq(&generated_air); } +#[test] +fn periodic_columns() { + let generated_air = Test::new("tests/periodic_columns/periodic_columns.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../periodic_columns/periodic_columns.rs"]; + expected.assert_eq(&generated_air); +} + +#[test] +fn pub_inputs() { + let generated_air = Test::new("tests/pub_inputs/pub_inputs.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../pub_inputs/pub_inputs.rs"]; + expected.assert_eq(&generated_air); +} + #[test] fn selectors() { let generated_air = Test::new("tests/selectors/selectors.air".to_string()) @@ -297,20 +305,31 @@ fn selectors_combine_with_list_comprehensions() { } #[test] -fn constraint_comprehension() { - let generated_air = - Test::new("tests/constraint_comprehension/constraint_comprehension.air".to_string()) - .transpile(Target::Winterfell) - .unwrap(); +fn system() { + let generated_air = Test::new("tests/system/system.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); - let expected = expect_file!["../constraint_comprehension/constraint_comprehension.rs"]; + let expected = expect_file!["../system/system.rs"]; expected.assert_eq(&generated_air); +} - let generated_air = - Test::new("tests/constraint_comprehension/cc_with_evaluators.air".to_string()) - .transpile(Target::Winterfell) - .unwrap(); +#[test] +fn trace_col_groups() { + let generated_air = Test::new("tests/trace_col_groups/trace_col_groups.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); - let expected = expect_file!["../constraint_comprehension/constraint_comprehension.rs"]; + let expected = expect_file!["../trace_col_groups/trace_col_groups.rs"]; + expected.assert_eq(&generated_air); +} + +#[test] +fn variables() { + let generated_air = Test::new("tests/variables/variables.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../variables/variables.rs"]; expected.assert_eq(&generated_air); } diff --git a/air-script/tests/computed_indices/computed_indices_complex.air b/air-script/tests/computed_indices/computed_indices_complex.air new file mode 100644 index 000000000..1b1dbffbb --- /dev/null +++ b/air-script/tests/computed_indices/computed_indices_complex.air @@ -0,0 +1,61 @@ +def ComputedIndicesAir + +const MDS = [ + [1, 2], + [2, 3], + [3, 4] +]; + +trace_columns { + main: [a[2], s[2]], +} + +public_inputs { + input: [1], +} + +fn double(a: felt) -> felt { + let x = 3 * a; + let y = a; + return x - y; +} + +boundary_constraints { + enf a[0].first = 0; +} + +# Note: +# In this test, we aim to test that computed indices work well even if: +# - The value of the index can only be known late during the compilation process (during MIR's constant propagation) +# - The +integrity_constraints { + + # vec_1 is a list_comprehension that depends on the state + # vec_1 = [ + # 1 * s[0] + 2 * s[1], + # 2 * s[0] + 3 * s[1], + # 3 * s[0] + 4 * s[1] + # ]; + let vec_1 = apply_mds(s); + + let state_2 = [2, 0]; + # vec_2 is a list_comprehension that will not get constant-folded early, but will produce constant values + let vec_2 = apply_mds(state_2); + + # x will get the value 2 * 2 - 4 = 0 + let x = double(2) - vec_2[1]; + + # y will get the value vec_2[0] = 2 + let y = vec_2[x]; + + # z will then be vec_1[2] = 3 * s[0] + 4 * s[1] + let z = vec_1[y]; + + # we enforce 3 * s[0] + 4 * s[1] = 0 + enf z = 0; +} + +# We use apply_mds function to produce a list comprehension that will not get constant-folded during AST +fn apply_mds(state: felt[2]) -> felt[3] { + return [sum([s * m for (s, m) in (state, mds_row)]) for mds_row in MDS]; +} diff --git a/air-script/tests/computed_indices/computed_indices_complex.rs b/air-script/tests/computed_indices/computed_indices_complex.rs new file mode 100644 index 000000000..b39fe04fa --- /dev/null +++ b/air-script/tests/computed_indices/computed_indices_complex.rs @@ -0,0 +1,97 @@ +use winter_air::{Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo}; +use winter_math::fields::f64::BaseElement as Felt; +use winter_math::{ExtensionOf, FieldElement, ToElements}; +use winter_utils::{ByteWriter, Serializable}; + +pub struct PublicInputs { + input: [Felt; 1], +} + +impl PublicInputs { + pub fn new(input: [Felt; 1]) -> Self { + Self { input } + } +} + +impl Serializable for PublicInputs { + fn write_into(&self, target: &mut W) { + self.input.write_into(target); + } +} + +impl ToElements for PublicInputs { + fn to_elements(&self) -> Vec { + let mut elements = Vec::new(); + elements.extend_from_slice(&self.input); + elements + } +} + +pub struct ComputedIndicesAir { + context: AirContext, + input: [Felt; 1], +} + +impl ComputedIndicesAir { + pub fn last_step(&self) -> usize { + self.trace_length() - self.context().num_transition_exemptions() + } +} + +impl Air for ComputedIndicesAir { + type BaseField = Felt; + type PublicInputs = PublicInputs; + + fn context(&self) -> &AirContext { + &self.context + } + + fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { + let main_degrees = vec![TransitionConstraintDegree::new(1)]; + let aux_degrees = vec![]; + let num_main_assertions = 1; + let num_aux_assertions = 0; + + let context = AirContext::new_multi_segment( + trace_info, + main_degrees, + aux_degrees, + num_main_assertions, + num_aux_assertions, + options, + ) + .set_num_transition_exemptions(2); + Self { context, input: public_inputs.input } + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } + + fn get_assertions(&self) -> Vec> { + let mut result = Vec::new(); + result.push(Assertion::single(0, 0, Felt::ZERO)); + result + } + + fn get_aux_assertions>(&self, aux_rand_elements: &AuxRandElements) -> Vec> { + let mut result = Vec::new(); + result + } + + fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { + let main_current = frame.current(); + let main_next = frame.next(); + result[0] = main_current[2] * E::from(Felt::new(3_u64)) + main_current[3] * E::from(Felt::new(4_u64)); + } + + fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) + where F: FieldElement, + E: FieldElement + ExtensionOf, + { + let main_current = main_frame.current(); + let main_next = main_frame.next(); + let aux_current = aux_frame.current(); + let aux_next = aux_frame.next(); + } +} \ No newline at end of file diff --git a/air-script/tests/computed_indices/computed_indices_simple.air b/air-script/tests/computed_indices/computed_indices_simple.air new file mode 100644 index 000000000..3dbaae038 --- /dev/null +++ b/air-script/tests/computed_indices/computed_indices_simple.air @@ -0,0 +1,34 @@ +def ComputedIndicesAir + +trace_columns { + main: [a, b, c, d, e, f, g, h], +} + +public_inputs { + stack_inputs: [16], +} + +boundary_constraints { + enf a.first = 0; +} + +integrity_constraints { + # vec = [0, 1, 2, 3, 4]; + let vec = [i for i in 0..5]; + + # x = [0 * 2, 1 * 2, 2 * 2, 3 * 2]; + # x = [0, 2, 4, 6]; + let x = [j * vec[1 + 1] for j in 0..4]; + enf a = x[0]; + enf b = x[1]; + enf c = x[2]; + enf d = x[3]; + + # y = [0 * 1, 1 * 2, 2 * 3, 3 * 4]; + # y = [0, 2, 6, 12]; + let y = [j * vec[j + 1] for j in 0..4]; + enf e' = y[0] * e; + enf f' = y[1] * f; + enf g' = y[2] * g; + enf h' = y[3] * h; +} diff --git a/air-script/tests/computed_indices/computed_indices_simple.rs b/air-script/tests/computed_indices/computed_indices_simple.rs new file mode 100644 index 000000000..bb55d3dd5 --- /dev/null +++ b/air-script/tests/computed_indices/computed_indices_simple.rs @@ -0,0 +1,104 @@ +use winter_air::{Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo}; +use winter_math::fields::f64::BaseElement as Felt; +use winter_math::{ExtensionOf, FieldElement, ToElements}; +use winter_utils::{ByteWriter, Serializable}; + +pub struct PublicInputs { + stack_inputs: [Felt; 16], +} + +impl PublicInputs { + pub fn new(stack_inputs: [Felt; 16]) -> Self { + Self { stack_inputs } + } +} + +impl Serializable for PublicInputs { + fn write_into(&self, target: &mut W) { + self.stack_inputs.write_into(target); + } +} + +impl ToElements for PublicInputs { + fn to_elements(&self) -> Vec { + let mut elements = Vec::new(); + elements.extend_from_slice(&self.stack_inputs); + elements + } +} + +pub struct ComputedIndicesAir { + context: AirContext, + stack_inputs: [Felt; 16], +} + +impl ComputedIndicesAir { + pub fn last_step(&self) -> usize { + self.trace_length() - self.context().num_transition_exemptions() + } +} + +impl Air for ComputedIndicesAir { + type BaseField = Felt; + type PublicInputs = PublicInputs; + + fn context(&self) -> &AirContext { + &self.context + } + + fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { + let main_degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)]; + let aux_degrees = vec![]; + let num_main_assertions = 1; + let num_aux_assertions = 0; + + let context = AirContext::new_multi_segment( + trace_info, + main_degrees, + aux_degrees, + num_main_assertions, + num_aux_assertions, + options, + ) + .set_num_transition_exemptions(2); + Self { context, stack_inputs: public_inputs.stack_inputs } + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } + + fn get_assertions(&self) -> Vec> { + let mut result = Vec::new(); + result.push(Assertion::single(0, 0, Felt::ZERO)); + result + } + + fn get_aux_assertions>(&self, aux_rand_elements: &AuxRandElements) -> Vec> { + let mut result = Vec::new(); + result + } + + fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { + let main_current = frame.current(); + let main_next = frame.next(); + result[0] = main_current[0]; + result[1] = main_current[1] - E::from(Felt::new(2_u64)); + result[2] = main_current[2] - E::from(Felt::new(4_u64)); + result[3] = main_current[3] - E::from(Felt::new(6_u64)); + result[4] = main_next[4]; + result[5] = main_next[5] - E::from(Felt::new(2_u64)) * main_current[5]; + result[6] = main_next[6] - E::from(Felt::new(6_u64)) * main_current[6]; + result[7] = main_next[7] - E::from(Felt::new(12_u64)) * main_current[7]; + } + + fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) + where F: FieldElement, + E: FieldElement + ExtensionOf, + { + let main_current = main_frame.current(); + let main_next = main_frame.next(); + let aux_current = aux_frame.current(); + let aux_next = aux_frame.next(); + } +} \ No newline at end of file diff --git a/air-script/tests/computed_indices/mod.rs b/air-script/tests/computed_indices/mod.rs new file mode 100644 index 000000000..5f582b136 --- /dev/null +++ b/air-script/tests/computed_indices/mod.rs @@ -0,0 +1,7 @@ +#[rustfmt::skip] +#[allow(clippy::all)] +mod computed_indices_complex; +#[rustfmt::skip] +#[allow(clippy::all)] +mod computed_indices_simple; +mod test_air; diff --git a/air-script/tests/computed_indices/test_air.rs b/air-script/tests/computed_indices/test_air.rs new file mode 100644 index 000000000..5a953f6e5 --- /dev/null +++ b/air-script/tests/computed_indices/test_air.rs @@ -0,0 +1,61 @@ +use winter_air::Air; +use winter_math::fields::f64::BaseElement as Felt; +use winterfell::{Trace, TraceTable}; + +use crate::{ + computed_indices::computed_indices_simple::{ComputedIndicesAir, PublicInputs}, + helpers::{AirTester, MyTraceTable}, +}; + +#[derive(Clone)] +struct ComputedIndicesAirTester {} + +impl AirTester for ComputedIndicesAirTester { + type PubInputs = PublicInputs; + + fn build_main_trace(&self, length: usize) -> MyTraceTable { + let trace_width = 8; + let mut trace = TraceTable::new(trace_width, length); + + trace.fill( + |state| { + state[0] = Felt::new(0); + state[1] = Felt::new(2); + state[2] = Felt::new(4); + state[3] = Felt::new(6); + state[4] = Felt::new(0); + state[5] = Felt::new(0); + state[6] = Felt::new(0); + state[7] = Felt::new(0); + }, + |_, state| { + state[4] *= Felt::new(0); + state[5] *= Felt::new(2); + state[6] *= Felt::new(6); + state[7] *= Felt::new(12); + }, + ); + + MyTraceTable::new(trace, 0) + } + + fn public_inputs(&self) -> PublicInputs { + let zero = Felt::new(0); + PublicInputs::new([zero; 16]) + } +} + +#[test] +fn test_computed_indices_air() { + let air_tester = Box::new(ComputedIndicesAirTester {}); + let length = 1024; + + let main_trace = air_tester.build_main_trace(length); + let aux_trace = air_tester.build_aux_trace(length); + let pub_inputs = air_tester.public_inputs(); + let trace_info = air_tester.build_trace_info(length); + let options = air_tester.build_proof_options(); + + let air = ComputedIndicesAir::new(trace_info, pub_inputs, options); + main_trace.validate::(&air, aux_trace.as_ref()); +} diff --git a/air-script/tests/mod.rs b/air-script/tests/mod.rs index eb8adac8c..1f8c0fc09 100644 --- a/air-script/tests/mod.rs +++ b/air-script/tests/mod.rs @@ -9,6 +9,8 @@ mod bitwise; #[allow(unused_variables, dead_code, unused_mut)] mod buses; #[allow(unused_variables, dead_code, unused_mut)] +mod computed_indices; +#[allow(unused_variables, dead_code, unused_mut)] mod constant_in_range; #[allow(unused_variables, dead_code, unused_mut)] mod constants; diff --git a/air/src/passes/translate_from_mir.rs b/air/src/passes/translate_from_mir.rs index 06fe024eb..5fffa9a38 100644 --- a/air/src/passes/translate_from_mir.rs +++ b/air/src/passes/translate_from_mir.rs @@ -6,9 +6,12 @@ use air_parser::{ }; use air_pass::Pass; use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; -use mir::ir::{ - Boundary as MirBoundary, ConstantValue, Link, Mir, MirValue, Op, Parent, SpannedMirValue, - TraceAccess as MirTraceAccess, +use mir::{ + ir::{ + Boundary as MirBoundary, ConstantValue, Link, Mir, MirAccessType, MirValue, Op, Parent, + SpannedMirValue, TraceAccess as MirTraceAccess, + }, + passes::get_inner_const, }; use crate::{CompileError, graph::NodeIndex, ir::*}; @@ -135,10 +138,13 @@ struct AirBuilder<'a> { /// so we need to ensure these cases are properly indexed. fn accessor_to_scalar(mir_node: &Link) -> Link { if let Some(accessor) = mir_node.as_accessor() { - match accessor.access_type { - AccessType::Index(index) => { + match accessor.access_type.clone() { + MirAccessType::Index(index) => { if let Some(vec) = accessor.indexable.as_vector() { let children = vec.elements.borrow().deref().clone(); + let index = get_inner_const(&index) + .expect("Index should be a constant value after constant propagation") + as usize; if index >= children.len() { panic!( "Index out of bounds during indexed accessor translation from MIR to AIR: {index}", @@ -149,7 +155,7 @@ fn accessor_to_scalar(mir_node: &Link) -> Link { mir_node.clone() } }, - AccessType::Default => { + MirAccessType::Default => { add_row_offset_if_trace_access(&accessor.indexable, accessor.offset) }, _ => mir_node.clone(), diff --git a/mir/src/ir/nodes/ops/accessor.rs b/mir/src/ir/nodes/ops/accessor.rs index aea462cf4..5a5d8318f 100644 --- a/mir/src/ir/nodes/ops/accessor.rs +++ b/mir/src/ir/nodes/ops/accessor.rs @@ -1,6 +1,5 @@ use std::hash::Hash; -use air_parser::ast::AccessType; use miden_diagnostics::{SourceSpan, Spanned}; use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; @@ -14,7 +13,7 @@ use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singlet pub struct Accessor { pub parents: Vec>, pub indexable: Link, - pub access_type: AccessType, + pub access_type: MirAccessType, pub offset: usize, pub _node: Singleton, pub _owner: Singleton, @@ -22,10 +21,18 @@ pub struct Accessor { pub span: SourceSpan, } +#[derive(Hash, Clone, PartialEq, Eq, Debug, Default)] +pub enum MirAccessType { + #[default] + Default, + Index(Link), + Matrix(Link, Link), +} + impl Accessor { pub fn create( indexable: Link, - access_type: AccessType, + access_type: MirAccessType, offset: usize, span: SourceSpan, ) -> Link { @@ -43,7 +50,14 @@ impl Accessor { impl Parent for Accessor { type Child = Op; fn children(&self) -> Link>> { - Link::new(vec![self.indexable.clone()]) + let vec = match self.access_type { + MirAccessType::Default => vec![self.indexable.clone()], + MirAccessType::Index(ref idx) => vec![self.indexable.clone(), idx.clone()], + MirAccessType::Matrix(ref row, ref col) => { + vec![self.indexable.clone(), row.clone(), col.clone()] + }, + }; + Link::new(vec) } } diff --git a/mir/src/ir/nodes/ops/mod.rs b/mir/src/ir/nodes/ops/mod.rs index 64d1674a3..c45120d47 100644 --- a/mir/src/ir/nodes/ops/mod.rs +++ b/mir/src/ir/nodes/ops/mod.rs @@ -15,7 +15,7 @@ mod sub; mod value; mod vector; -pub use accessor::Accessor; +pub use accessor::{Accessor, MirAccessType}; pub use add::Add; pub use boundary::Boundary; pub use bus_op::{BusOp, BusOpKind}; diff --git a/mir/src/ir/nodes/ops/value.rs b/mir/src/ir/nodes/ops/value.rs index c102f5777..a515df1bd 100644 --- a/mir/src/ir/nodes/ops/value.rs +++ b/mir/src/ir/nodes/ops/value.rs @@ -21,6 +21,12 @@ impl Value { pub fn create(value: SpannedMirValue) -> Link { Op::Value(Self { value, ..Default::default() }).into() } + pub fn get_inner_const(&self) -> Option { + match &self.value.value { + MirValue::Constant(ConstantValue::Felt(v)) => Some(*v), + _ => None, + } + } } impl From for Value { diff --git a/mir/src/passes/constant_propagation.rs b/mir/src/passes/constant_propagation.rs index 139eeab65..9e066fa00 100644 --- a/mir/src/passes/constant_propagation.rs +++ b/mir/src/passes/constant_propagation.rs @@ -1,6 +1,5 @@ use std::ops::Deref; -use air_parser::ast::AccessType; use air_pass::Pass; use miden_diagnostics::{DiagnosticsHandler, SourceSpan, Spanned}; @@ -8,9 +7,10 @@ use super::visitor::Visitor; use crate::{ CompileError, ir::{ - BackLink, ConstantValue, Graph, Link, Mir, MirValue, Node, Op, Parent, SpannedMirValue, - Value, + ConstantValue, Graph, Link, Mir, MirAccessType, MirValue, Node, Op, Parent, + SpannedMirValue, Value, }, + passes::handle_accessor_visit, }; pub struct ConstantPropagation<'a> { @@ -40,11 +40,7 @@ impl<'a> ConstantPropagation<'a> { // each visit_*_bis function returns an Option> instead of Result<(), CompileError>, // to mutate the nodes (e.g. modifying a Add(lhs, rhs) to Value(lhs + rhs)). impl ConstantPropagation<'_> { - fn visit_add_bis( - &mut self, - _graph: &mut Graph, - add: Link, - ) -> Result>, CompileError> { + fn visit_add_bis(&mut self, add: Link) -> Result>, CompileError> { // safe to unwrap because we just dispatched on it let add_ref = add.as_add().unwrap(); let lhs = add_ref.lhs.clone(); @@ -59,11 +55,7 @@ impl ConstantPropagation<'_> { } } - fn visit_sub_bis( - &mut self, - _graph: &mut Graph, - sub: Link, - ) -> Result>, CompileError> { + fn visit_sub_bis(&mut self, sub: Link) -> Result>, CompileError> { // safe to unwrap because we just dispatched on it let sub_ref = sub.as_sub().unwrap(); let lhs = sub_ref.lhs.clone(); @@ -76,11 +68,7 @@ impl ConstantPropagation<'_> { } } - fn visit_mul_bis( - &mut self, - _graph: &mut Graph, - mul: Link, - ) -> Result>, CompileError> { + fn visit_mul_bis(&mut self, mul: Link) -> Result>, CompileError> { // safe to unwrap because we just dispatched on it let mul_ref = mul.as_mul().unwrap(); let lhs = mul_ref.lhs.clone(); @@ -97,11 +85,7 @@ impl ConstantPropagation<'_> { } } - fn visit_exp_bis( - &mut self, - _graph: &mut Graph, - exp: Link, - ) -> Result>, CompileError> { + fn visit_exp_bis(&mut self, exp: Link) -> Result>, CompileError> { // safe to unwrap because we just dispatched on it let exp_ref = exp.as_exp().unwrap(); let lhs = exp_ref.lhs.clone(); @@ -121,6 +105,10 @@ impl ConstantPropagation<'_> { try_fold_const_binary_op(lhs, rhs, exp.clone(), exp_ref.span()) } } + + fn visit_accessor_bis(&mut self, accessor: Link) -> Result>, CompileError> { + handle_accessor_visit(accessor.clone(), true, self.diagnostics) + } } impl Visitor for ConstantPropagation<'_> { @@ -148,7 +136,7 @@ impl Visitor for ConstantPropagation<'_> { combined_roots.collect() } - fn visit_node(&mut self, graph: &mut Graph, node: Link) -> Result<(), CompileError> { + fn visit_node(&mut self, _graph: &mut Graph, node: Link) -> Result<(), CompileError> { if node.is_stale() { return Ok(()); } @@ -156,36 +144,29 @@ impl Visitor for ConstantPropagation<'_> { // In this pass, we both need to dispatch the visitor depending on the node type, // and also mutate the node if needed. We implement custom visit_*_bis methods // that returns a Some(updated_node) if we need to update the node's value. - let updated_op: Result>, CompileError> = match node.borrow().deref() { - Node::Add(a) => to_link_and(a.clone(), graph, |g, el| self.visit_add_bis(g, el)), - Node::Sub(s) => to_link_and(s.clone(), graph, |g, el| self.visit_sub_bis(g, el)), - Node::Mul(m) => to_link_and(m.clone(), graph, |g, el| self.visit_mul_bis(g, el)), - Node::Exp(e) => to_link_and(e.clone(), graph, |g, el| self.visit_exp_bis(g, el)), - // For all the following cases, there is nothing to fold + let updated_op: Option> = match node.borrow().deref() { + Node::Add(a) => a.to_link().map_or(Ok(None), |el| self.visit_add_bis(el))?, + Node::Sub(s) => s.to_link().map_or(Ok(None), |el| self.visit_sub_bis(el))?, + Node::Mul(m) => m.to_link().map_or(Ok(None), |el| self.visit_mul_bis(el))?, + Node::Exp(e) => e.to_link().map_or(Ok(None), |el| self.visit_exp_bis(el))?, + Node::Accessor(e) => e.to_link().map_or(Ok(None), |el| self.visit_accessor_bis(el))?, Node::Vector(_) | Node::Matrix(_) | Node::Enf(_) | Node::Boundary(_) | Node::BusOp(_) | Node::Value(_) - | Node::Accessor(_) - | Node::None(_) => Ok(None), - Node::Function(_) | Node::Evaluator(_) | Node::Call(_) => { - unreachable!( - "Unexpected node during Mir's ConstantPropagation: Function, Evaluators and Calls should have been inlined before this pass. Found: {:?}", - node - ); - }, - Node::If(_) | Node::For(_) | Node::Fold(_) | Node::Parameter(_) => { + | Node::None(_) => None, + _ => { unreachable!( - "Unexpected node during Mir's ConstantPropagation: If, For, Fold and Parameter should have been unrolled before this pass. Found: {:?}", + "Unexpected node during Mir's ConstantPropagation: Function, Evaluators, Calls, If, For, Fold and Parameter should have been inlined and unrolled before this pass. Found: {:?}", node ); }, }; // We update the node if needed - if let Some(updated_op) = updated_op? { + if let Some(updated_op) = updated_op { node.as_op().unwrap().set(&updated_op); } @@ -196,43 +177,23 @@ impl Visitor for ConstantPropagation<'_> { // HELPERS FUNCTIONS // ================================================================================================ -/// Tries to upgrade a BackLink to a Link and apply a given closure to it if it is successful, -/// otherwise returns None. -fn to_link_and( - back: BackLink, - graph: &mut Graph, - f: F, -) -> Result>, CompileError> -where - F: FnOnce(&mut Graph, Link) -> Result>, CompileError>, -{ - if let Some(op) = back.to_link() { - f(graph, op) - } else { - Ok(None) - } -} - -/// Helper function to extract the constant felt value from a Link if it is one. -fn get_inner_const(value: &Link) -> Option { +pub fn get_inner_const(value: &Link) -> Option { match value.borrow().deref() { - Op::Value(Value { - value: - SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(c)), - .. - }, - .. - }) => Some(*c), + Op::Value(v) => v.get_inner_const(), Op::Accessor(accessor) => { match (accessor.access_type.clone(), accessor.indexable.borrow().deref()) { - (AccessType::Default, _) => get_inner_const(&accessor.indexable), - (AccessType::Index(index), Op::Vector(vector)) => { + (MirAccessType::Default, _) => get_inner_const(&accessor.indexable), + (MirAccessType::Index(index), Op::Vector(vector)) => { + let index = get_inner_const(&index).expect("Expected constant index") as usize; + let vec_children = vector.children(); let vec_ref = vec_children.borrow(); vec_ref.get(index).and_then(get_inner_const) }, - (AccessType::Matrix(row, col), Op::Matrix(matrix)) => { + (MirAccessType::Matrix(row, col), Op::Matrix(matrix)) => { + let row = get_inner_const(&row).expect("Expected constant row") as usize; + let col = get_inner_const(&col).expect("Expected constant column") as usize; + let mat_children = matrix.children(); let mat_ref = mat_children.borrow(); mat_ref.get(row).and_then(|row| { diff --git a/mir/src/passes/inlining.rs b/mir/src/passes/inlining.rs index 79c637bfc..1bff0ab39 100644 --- a/mir/src/passes/inlining.rs +++ b/mir/src/passes/inlining.rs @@ -30,6 +30,59 @@ impl<'a> Inlining<'a> { pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { Self { diagnostics } } + + /// Runs the Inlining pass once (with both InliningFirstPass and InliningSecondPass) + /// + /// Returns true if any calls were inlined, false otherwise, to let the caller know if any + /// changes were made or if we reached a fixed point. + fn run_once(&mut self, ir: &mut Mir) -> Result { + // The first pass only identifies the call graph dependencies and the needed calls to inline + let mut first_pass = InliningFirstPass::new(self.diagnostics); + Visitor::run(&mut first_pass, ir.constraint_graph_mut())?; + + // We then create the inlining order (inlining first the functions and evaluators that do + // not call other functions or evaluators) + let func_eval_inlining_order = + create_inlining_order(self.diagnostics, first_pass.func_eval_dependency_graph.clone())?; + + // The second pass actually inlines the calls + let mut second_pass = InliningSecondPass::new( + self.diagnostics, + func_eval_inlining_order.clone(), + first_pass.func_eval_nodes_where_called.clone(), + ); + Visitor::run(&mut second_pass, ir.constraint_graph_mut())?; + + Ok(second_pass.had_calls) + } +} + +// If we have to run the inlining algorithm `INLINING_LIMIT`, we return an error +const INLINING_LIMIT: usize = 10; + +impl Pass for Inlining<'_> { + type Input<'a> = Mir; + type Output<'a> = Mir; + type Error = CompileError; + + fn run<'a>(&mut self, mut ir: Self::Input<'a>) -> Result, Self::Error> { + let mut had_calls = true; + let mut iterations = 0; + + while had_calls && iterations < INLINING_LIMIT { + had_calls = self.run_once(&mut ir)?; + iterations += 1; + } + + if had_calls { + self.diagnostics.error( + "Inlining call depth limit reached, some calls may not have been inlined. Aborting.".to_string(), + ); + return Err(CompileError::Failed); + } + + Ok(ir) + } } pub struct InliningFirstPass<'a> { @@ -86,6 +139,7 @@ pub struct InliningSecondPass<'a> { // HashMap)> func_eval_nodes_where_called: HashMap, Vec>)>, // Op is a Call here + had_calls: bool, } impl<'a> InliningSecondPass<'a> { @@ -102,37 +156,11 @@ impl<'a> InliningSecondPass<'a> { params_for_ref_node: HashMap::new(), func_eval_nodes_where_called, func_eval_inlining_order, + had_calls: false, } } } -impl Pass for Inlining<'_> { - type Input<'a> = Mir; - type Output<'a> = Mir; - type Error = CompileError; - - fn run<'a>(&mut self, mut ir: Self::Input<'a>) -> Result, Self::Error> { - // The first pass only identifies the call graph dependencies and the needed calls to inline - let mut first_pass = InliningFirstPass::new(self.diagnostics); - Visitor::run(&mut first_pass, ir.constraint_graph_mut())?; - - // We then create the inlining order (inlining first the functions and evaluators that do - // not call other functions or evaluators) - let func_eval_inlining_order = - create_inlining_order(self.diagnostics, first_pass.func_eval_dependency_graph.clone())?; - - // The second pass actually inlines the calls - let mut second_pass = InliningSecondPass::new( - self.diagnostics, - func_eval_inlining_order.clone(), - first_pass.func_eval_nodes_where_called.clone(), - ); - Visitor::run(&mut second_pass, ir.constraint_graph_mut())?; - - Ok(ir) - } -} - /// Helper function to create the inlining order depending on the dependency graph /// /// Raises an error if a circular dependency is detected @@ -283,7 +311,10 @@ impl Visitor for InliningSecondPass<'_> { call_nodes_to_inline_in_order } fn run(&mut self, graph: &mut Graph) -> Result<(), CompileError> { - for root_node in self.root_nodes_to_visit(graph).iter() { + let root_nodes_to_visit = self.root_nodes_to_visit(graph); + self.had_calls = !root_nodes_to_visit.is_empty(); + + for root_node in root_nodes_to_visit { let mut updated_op = None; if let Some(op) = root_node.as_op() { @@ -635,7 +666,7 @@ fn unpack_evaluator_arguments(args: &[Link]) -> Vec> { _ => unreachable!("expected trace access binding, got {:?}", value), }; - args_unpacked.push(indexable.clone()); + args_unpacked.push(arg.clone()); } else if let Some(parameter) = indexable.as_parameter() { let Parameter { ty, .. } = parameter.deref(); let _size = match ty { @@ -644,7 +675,7 @@ fn unpack_evaluator_arguments(args: &[Link]) -> Vec> { _ => unreachable!("expected felt or vector, got {:?}", ty), }; - args_unpacked.push(indexable.clone()); + args_unpacked.push(arg.clone()); } else { unreachable!("expected value or parameter (or accessor on one), got {:?}", arg); } diff --git a/mir/src/passes/mod.rs b/mir/src/passes/mod.rs index 74c72ffd0..789a74017 100644 --- a/mir/src/passes/mod.rs +++ b/mir/src/passes/mod.rs @@ -7,14 +7,18 @@ use std::{collections::HashMap, ops::Deref}; pub use constant_propagation::ConstantPropagation; pub use inlining::Inlining; -use miden_diagnostics::Spanned; +use miden_diagnostics::{DiagnosticsHandler, Spanned}; pub use translate::AstToMir; pub use unrolling::Unrolling; pub use visitor::Visitor; -use crate::ir::{ - Accessor, Add, Boundary, BusOp, Call, Enf, Exp, Fold, For, If, Link, MatchArm, Matrix, Mul, - Node, Op, Owner, Parameter, Parent, Sub, Value, Vector, +use crate::{ + CompileError, + ir::{ + Accessor, Add, Boundary, BusOp, Call, Enf, Exp, Fold, For, If, Link, MatchArm, Matrix, + MirAccessType, MirValue, Mul, Node, Op, Owner, Parameter, Parent, PublicInputAccess, + SpannedMirValue, Sub, TraceAccess, Value, Vector, + }, }; /// Helper to duplicate a MIR node and its children recursively @@ -162,10 +166,19 @@ pub fn duplicate_node( }, Op::Accessor(accessor) => { let indexable = accessor.indexable.clone(); - let access_type = accessor.access_type.clone(); + let new_access_type = match accessor.access_type.clone() { + MirAccessType::Default => MirAccessType::Default, + MirAccessType::Index(index) => { + MirAccessType::Index(duplicate_node(index, current_replace_map)) + }, + MirAccessType::Matrix(row, col) => MirAccessType::Matrix( + duplicate_node(row, current_replace_map), + duplicate_node(col, current_replace_map), + ), + }; let offset = accessor.offset; let new_indexable = duplicate_node(indexable, current_replace_map); - Accessor::create(new_indexable, access_type, offset, accessor.span()) + Accessor::create(new_indexable, new_access_type, offset, accessor.span()) }, Op::BusOp(bus_op) => { let bus = bus_op.bus.clone(); @@ -187,8 +200,9 @@ pub fn duplicate_node( if let Some(_root_ref) = owner_ref.as_root() { new_param.as_parameter_mut().unwrap().set_ref_node(owner_ref); - } else if let Some((_replaced_node, replaced_by)) = - current_replace_map.get(&owner_ref.as_op().unwrap().get_ptr()) + } else if let Some(op_ref) = owner_ref.as_op() + && let Some((_replaced_node, replaced_by)) = + current_replace_map.get(&op_ref.get_ptr()) { new_param .as_parameter_mut() @@ -230,46 +244,46 @@ pub fn duplicate_node_or_replace( match node.borrow().deref() { Op::Enf(enf) => { let expr = enf.expr.clone(); - let new_expr = current_replace_map.get(&expr.get_ptr()).unwrap().1.clone(); + let new_expr = current_replace_map[&expr.get_ptr()].1.clone(); let new_node = Enf::create(new_expr, enf.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); }, Op::Boundary(boundary) => { let expr = boundary.expr.clone(); let kind = boundary.kind; - let new_expr = current_replace_map.get(&expr.get_ptr()).unwrap().1.clone(); + let new_expr = current_replace_map[&expr.get_ptr()].1.clone(); let new_node = Boundary::create(new_expr, kind, boundary.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); }, Op::Add(add) => { let lhs = add.lhs.clone(); let rhs = add.rhs.clone(); - let new_lhs_node = current_replace_map.get(&lhs.get_ptr()).unwrap().1.clone(); - let new_rhs_node = current_replace_map.get(&rhs.get_ptr()).unwrap().1.clone(); + let new_lhs_node = current_replace_map[&lhs.get_ptr()].1.clone(); + let new_rhs_node = current_replace_map[&rhs.get_ptr()].1.clone(); let new_node = Add::create(new_lhs_node, new_rhs_node, add.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); }, Op::Sub(sub) => { let lhs = sub.lhs.clone(); let rhs = sub.rhs.clone(); - let new_lhs_node = current_replace_map.get(&lhs.get_ptr()).unwrap().1.clone(); - let new_rhs_node = current_replace_map.get(&rhs.get_ptr()).unwrap().1.clone(); + let new_lhs_node = current_replace_map[&lhs.get_ptr()].1.clone(); + let new_rhs_node = current_replace_map[&rhs.get_ptr()].1.clone(); let new_node = Sub::create(new_lhs_node, new_rhs_node, sub.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); }, Op::Mul(mul) => { let lhs = mul.lhs.clone(); let rhs = mul.rhs.clone(); - let new_lhs_node = current_replace_map.get(&lhs.get_ptr()).unwrap().1.clone(); - let new_rhs_node = current_replace_map.get(&rhs.get_ptr()).unwrap().1.clone(); + let new_lhs_node = current_replace_map[&lhs.get_ptr()].1.clone(); + let new_rhs_node = current_replace_map[&rhs.get_ptr()].1.clone(); let new_node = Mul::create(new_lhs_node, new_rhs_node, mul.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); }, Op::Exp(exp) => { let lhs = exp.lhs.clone(); let rhs = exp.rhs.clone(); - let new_lhs_node = current_replace_map.get(&lhs.get_ptr()).unwrap().1.clone(); - let new_rhs_node = current_replace_map.get(&rhs.get_ptr()).unwrap().1.clone(); + let new_lhs_node = current_replace_map[&lhs.get_ptr()].1.clone(); + let new_rhs_node = current_replace_map[&rhs.get_ptr()].1.clone(); let new_node = Exp::create(new_lhs_node, new_rhs_node, exp.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); }, @@ -280,9 +294,8 @@ pub fn duplicate_node_or_replace( .iter() .cloned() .map(|arm| { - let new_expr = current_replace_map.get(&arm.expr.get_ptr()).unwrap().1.clone(); - let new_cond = - current_replace_map.get(&arm.condition.get_ptr()).unwrap().1.clone(); + let new_expr = current_replace_map[&arm.expr.get_ptr()].1.clone(); + let new_cond = current_replace_map[&arm.condition.get_ptr()].1.clone(); MatchArm::new(new_expr, new_cond) }) .collect::>(); @@ -297,10 +310,10 @@ pub fn duplicate_node_or_replace( .borrow() .iter() .cloned() - .map(|iterator| current_replace_map.get(&iterator.get_ptr()).unwrap().1.clone()) + .map(|iterator| current_replace_map[&iterator.get_ptr()].1.clone()) .collect::>() .into(); - let new_body = current_replace_map.get(&body.get_ptr()).unwrap().1.clone(); + let new_body = current_replace_map[&body.get_ptr()].1.clone(); let new_selector = current_replace_map .get(&selector.get_ptr()) .map(|selector| selector.1.clone()) @@ -328,7 +341,7 @@ pub fn duplicate_node_or_replace( .borrow() .iter() .cloned() - .map(|argument| current_replace_map.get(&argument.get_ptr()).unwrap().1.clone()) + .map(|argument| current_replace_map[&argument.get_ptr()].1.clone()) .collect::>(); let new_node = Call::create(function, new_arguments, call.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); @@ -337,9 +350,8 @@ pub fn duplicate_node_or_replace( let iterator = fold.iterator.clone(); let operator = fold.operator.clone(); let initial_value = fold.initial_value.clone(); - let new_iterator = current_replace_map.get(&iterator.get_ptr()).unwrap().1.clone(); - let new_initial_value = - current_replace_map.get(&initial_value.get_ptr()).unwrap().1.clone(); + let new_iterator = current_replace_map[&iterator.get_ptr()].1.clone(); + let new_initial_value = current_replace_map[&initial_value.get_ptr()].1.clone(); let new_node = Fold::create(new_iterator, operator, new_initial_value, fold.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); }, @@ -350,7 +362,7 @@ pub fn duplicate_node_or_replace( let new_children = children .iter() .cloned() - .map(|child| current_replace_map.get(&child.get_ptr()).unwrap().1.clone()) + .map(|child| current_replace_map[&child.get_ptr()].1.clone()) .collect(); let new_node = Vector::create(new_children, vector.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); @@ -372,7 +384,7 @@ pub fn duplicate_node_or_replace( let new_row_as_vec = row_children .iter() .cloned() - .map(|child| current_replace_map.get(&child.get_ptr()).unwrap().1.clone()) + .map(|child| current_replace_map[&child.get_ptr()].1.clone()) .collect::>(); let new_row = Vector::create(new_row_as_vec, row.span()); new_matrix.push(new_row); @@ -382,10 +394,20 @@ pub fn duplicate_node_or_replace( }, Op::Accessor(accessor) => { let indexable = accessor.indexable.clone(); - let access_type = accessor.access_type.clone(); + let new_access_type = match accessor.access_type.clone() { + MirAccessType::Default => MirAccessType::Default, + MirAccessType::Index(index) => { + MirAccessType::Index(current_replace_map[&index.get_ptr()].1.clone()) + }, + MirAccessType::Matrix(row, col) => MirAccessType::Matrix( + current_replace_map[&row.get_ptr()].1.clone(), + current_replace_map[&col.get_ptr()].1.clone(), + ), + }; let offset = accessor.offset; - let new_indexable = current_replace_map.get(&indexable.get_ptr()).unwrap().1.clone(); - let new_node = Accessor::create(new_indexable, access_type, offset, accessor.span()); + let new_indexable = current_replace_map[&indexable.get_ptr()].1.clone(); + let new_node = + Accessor::create(new_indexable, new_access_type, offset, accessor.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); }, Op::BusOp(bus_op) => { @@ -397,9 +419,9 @@ pub fn duplicate_node_or_replace( let new_args = args .iter() .cloned() - .map(|arg| current_replace_map.get(&arg.get_ptr()).unwrap().1.clone()) + .map(|arg| current_replace_map[&arg.get_ptr()].1.clone()) .collect(); - let new_latch = current_replace_map.get(&latch.get_ptr()).unwrap().1.clone(); + let new_latch = current_replace_map[&latch.get_ptr()].1.clone(); let new_node = BusOp::create(bus.clone(), kind, new_args, bus_op.span()); // Update latch of cloned bus_op @@ -424,7 +446,8 @@ pub fn duplicate_node_or_replace( }; if owner_ref == ref_owner { - let new_node = replace_parameter_list[parameter.position].clone(); + let replace_by_node = replace_parameter_list[parameter.position].clone(); + let new_node = duplicate_node(replace_by_node, &mut Default::default()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); } else { let new_param = @@ -450,3 +473,260 @@ pub fn duplicate_node_or_replace( Op::None(_) => {}, } } + +/// Helper function to extract the constant felt value from a Link if it is one. +pub fn get_inner_const(value: &Link) -> Option { + match value.borrow().deref() { + Op::Value(v) => v.get_inner_const(), + _ => None, + } +} + +/// Handle the visit of an accessor node, used for both Unrolling and ConstantPropagation passes +/// The `expect_constant_indices` bool indicates whether the indices need to be constant at +/// this stage. +pub fn handle_accessor_visit( + accessor: Link, + expect_constant_indices: bool, + diagnostics: &DiagnosticsHandler, +) -> Result>, CompileError> { + let accessor_ref = accessor.as_accessor().unwrap(); + let indexable = accessor_ref.indexable.clone(); + let mir_access_type = accessor_ref.access_type.clone(); + let offset = accessor_ref.offset; + + match mir_access_type { + // If we have a Default accessor, we add the row offset if needed, otherwise we just return + // the indexable + MirAccessType::Default => Ok(Some(add_row_offset_if_trace_access(&indexable, offset))), + // If we have an Index accessor, we compute the index and query the index-th element of the + // indexable. If the index is not a constant and computed_indices is true, we raise + // a diagnostic. If the index is not a constant and computed_indices is false, we + // keep the node as is. If the index is an out-of-bound constant, we raise a + // diagnostic. + MirAccessType::Index(index) => unroll_accessor_index_access_type( + indexable, + index, + offset, + expect_constant_indices, + diagnostics, + ), + // If we have an Matrix accessor, we compute both the corresponding row and column, and + // query the indexable accordingly. If either of row or col is not a constant and + // computed_indices is true, we raise a diagnostic. If either of row or col is not a + // constant and computed_indices is false, we keep the node as is. If either of row + // or col is an out-of-bound constant, we raise a diagnostic. + MirAccessType::Matrix(row, col) => unroll_accessor_matrix_access_type( + indexable, + row, + col, + expect_constant_indices, + diagnostics, + ), + } +} + +/// Helper function to unroll an Index accessor +fn unroll_accessor_index_access_type( + indexable: Link, + index: Link, + accessor_offset: usize, + expect_constant_indices: bool, + diagnostics: &DiagnosticsHandler, +) -> Result>, CompileError> { + let Some(index_usize) = extract_index_value(&index, expect_constant_indices, diagnostics)? + else { + return Ok(None); + }; + if let Op::Vector(indexable_vector) = indexable.borrow().deref() { + let indexable_vec = indexable_vector.children().borrow().deref().clone(); + let child_accessed = match indexable_vec.get(index_usize) { + Some(child_accessed) => child_accessed, + None => { + diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access an index which is out of bounds") + .with_primary_label(index.span(), "index out of bounds") + .emit(); + return Err(CompileError::Failed); + }, + }; + Ok(Some(add_row_offset_if_trace_access(child_accessed, accessor_offset))) + } else if let Some(value) = indexable.clone().as_value() { + // If the indexable is either a PublicInput or a TraceAccess, we treat the index as + // an offset + let mir_value = value.value.value.clone(); + match mir_value { + MirValue::PublicInput(public_input_access) => { + let new_node = Value::create(SpannedMirValue { + span: value.value.span(), + value: MirValue::PublicInput(PublicInputAccess { + name: public_input_access.name, + index: public_input_access.index + index_usize, + }), + }); + Ok(Some(new_node)) + }, + MirValue::TraceAccess(trace_access) => { + // We also need to account for the row offset + let new_node = Value::create(SpannedMirValue { + span: value.value.span(), + value: MirValue::TraceAccess(TraceAccess { + segment: trace_access.segment, + column: trace_access.column + index_usize, + row_offset: trace_access.row_offset, + }), + }); + Ok(Some(new_node)) + }, + _ => { + unreachable!( + "Unexpected accessor, cannot have MirAccessType::Index with indexable {:?}", + indexable + ); + }, + } + } else { + unreachable!( + "Unexpected accessor, cannot have MirAccessType::Index with indexable {:?}", + indexable + ); + } +} + +/// Helper function to unroll a Matrix accessor +fn unroll_accessor_matrix_access_type( + indexable: Link, + row: Link, + col: Link, + expect_constant_indices: bool, + diagnostics: &DiagnosticsHandler, +) -> Result>, CompileError> { + let Some(row_usize) = extract_index_value(&row, expect_constant_indices, diagnostics)? else { + return Ok(None); + }; + let Some(col_usize) = extract_index_value(&col, expect_constant_indices, diagnostics)? else { + return Ok(None); + }; + // Replace the current node by the index-th element of the vector + // Raise diag if index is out of bounds + if let Op::Vector(indexable_vector) = indexable.borrow().deref() { + let indexable_vec = indexable_vector.children().borrow().deref().clone(); + let row_accessed = match indexable_vec.get(row_usize) { + Some(row_accessed) => row_accessed, + None => { + diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access a row which is out of bounds") + .with_primary_label(row.span(), "row out of bounds") + .emit(); + return Err(CompileError::Failed); + }, + }; + if let Op::Vector(row_accessed_vector) = row_accessed.borrow().deref() { + let row_accessed_vec = row_accessed_vector.children().borrow().deref().clone(); + let child_accessed = match row_accessed_vec.get(col_usize) { + Some(child_accessed) => child_accessed, + None => { + diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access a col which is out of bounds") + .with_primary_label(col.span(), "col out of bounds") + .emit(); + return Err(CompileError::Failed); + }, + }; + Ok(Some(child_accessed.clone())) + } else { + unreachable!( + "Unexpected accessor, cannot have MirAccessType::Matrix with indexable {:?}", + indexable + ); + } + } else if let Op::Matrix(indexable_matrix) = indexable.borrow().deref() { + let indexable_vec = indexable_matrix.children().borrow().deref().clone(); + let row_accessed = match indexable_vec.get(row_usize) { + Some(row_accessed) => row_accessed, + None => { + diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access a row which is out of bounds") + .with_primary_label(row.span(), "row out of bounds") + .emit(); + return Err(CompileError::Failed); + }, + }; + if let Op::Vector(row_accessed_vector) = row_accessed.borrow().deref() { + let row_accessed_vec = row_accessed_vector.children().borrow().deref().clone(); + let child_accessed = match row_accessed_vec.get(col_usize) { + Some(child_accessed) => child_accessed, + None => { + diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access a col which is out of bounds") + .with_primary_label(col.span(), "col out of bounds") + .emit(); + return Err(CompileError::Failed); + }, + }; + Ok(Some(child_accessed.clone())) + } else { + unreachable!( + "Unexpected accessor, cannot have MirAccessType::Matrix with indexable {:?}", + indexable + ); + } + } else { + unreachable!( + "Unexpected accessor, cannot have MirAccessType::Matrix with indexable {:?}", + indexable + ); + } +} + +/// Helper function to extract a usize value from an index expression with consistent error handling +/// +/// Returns: +/// - `Ok(Some(usize))` - Successfully extracted constant value +/// - `Ok(None)` - Not a constant value but not required (expect_constant_indices=false) +/// - `Err(CompileError)` - Not a constant value when required (expect_constant_indices=true) +fn extract_index_value( + index: &Link, + expect_constant_indices: bool, + diagnostics: &DiagnosticsHandler, +) -> Result, CompileError> { + match (get_inner_const(index), expect_constant_indices) { + (Some(value), _) => Ok(Some(value as usize)), + (None, true) => { + diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("the index is not constant during constant propagation") + .with_primary_label(index.span(), "index is not constant") + .emit(); + Err(CompileError::Failed) + }, + (None, false) => Ok(None), + } +} + +/// Helper function to add a row offset to a `TraceAccess` value, and return the node unchanged +/// otherwise. +fn add_row_offset_if_trace_access(node: &Link, offset: usize) -> Link { + if let Some(value) = node.clone().as_value() { + let mir_value = value.value.value.clone(); + if let MirValue::TraceAccess(trace_access) = mir_value { + Value::create(SpannedMirValue { + span: value.value.span(), + value: MirValue::TraceAccess(TraceAccess { + segment: trace_access.segment, + column: trace_access.column, + row_offset: trace_access.row_offset + offset, + }), + }) + } else { + node.clone() + } + } else { + node.clone() + } +} diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index be52c24f4..3ff75c570 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use air_parser::{ LexicalScope, - ast::{self, AccessType, TraceSegmentId}, + ast::{self, AccessType, ScalarExpr, TraceSegmentId}, symbols, }; use air_pass::Pass; @@ -14,8 +14,9 @@ use crate::{ ir::{ Accessor, Add, Boundary, Builder, Bus, BusAccess, BusOp, BusOpKind, Call, ConstantValue, Enf, Evaluator, Exp, Fold, FoldOperator, For, Function, If, Link, MatchArm, Matrix, Mir, - MirType, MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, PublicInputTableAccess, - Root, SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Value, Vector, + MirAccessType, MirType, MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, + PublicInputTableAccess, Root, SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Value, + Vector, }, passes::duplicate_node, }; @@ -688,7 +689,7 @@ impl<'a> MirBuilder<'a> { fn translate_symbol_access( &mut self, - access: &ast::SymbolAccess, + access: &'a ast::SymbolAccess, ) -> Result, CompileError> { match access.name { // At this point during compilation, fully-qualified identifiers can only possibly refer @@ -999,6 +1000,7 @@ impl<'a> MirBuilder<'a> { .clone_from(&selector_node.borrow()); self.bindings.exit(); + Ok(for_node) } @@ -1040,7 +1042,7 @@ impl<'a> MirBuilder<'a> { fn translate_bounded_symbol_access( &mut self, - access: &ast::BoundedSymbolAccess, + access: &'a ast::BoundedSymbolAccess, ) -> Result, CompileError> { let access_node = self.translate_symbol_access(&access.column)?; let node = Boundary::builder() @@ -1051,6 +1053,28 @@ impl<'a> MirBuilder<'a> { Ok(node) } + fn translate_access_type( + &mut self, + access_type: &'a ast::AccessType, + ) -> Result { + let mir_access_type = match access_type { + AccessType::Default => MirAccessType::Default, + AccessType::Index(index) => { + let index_node = self.translate_scalar_expr(index)?; + MirAccessType::Index(index_node) + }, + AccessType::Matrix(row, col) => { + let row_node = self.translate_scalar_expr(row)?; + let col_node = self.translate_scalar_expr(col)?; + MirAccessType::Matrix(row_node, col_node) + }, + AccessType::Slice(_range_expr) => unreachable!( + "Slices should have been transformed into vector operations during constant propagation" + ), + }; + Ok(mir_access_type) + } + fn translate_bus_operation( &mut self, ast_bus_op: &'a ast::BusOperation, @@ -1091,7 +1115,7 @@ impl<'a> MirBuilder<'a> { let accessor_mut = arg_node.clone(); if let Some(accessor) = accessor_mut.as_accessor_mut() { match accessor.access_type { - AccessType::Default => { + MirAccessType::Default => { arg_node = accessor.indexable.clone(); }, _ => { @@ -1156,28 +1180,18 @@ impl<'a> MirBuilder<'a> { fn translate_symbol_access_global_or_local( &mut self, ident: &ast::Identifier, - access: &ast::SymbolAccess, + access: &'a ast::SymbolAccess, ) -> Result, CompileError> { // Special identifiers are those which are `$`-prefixed, and must refer to the names of // trace segments (e.g. `$main`) if ident.is_special() { // Must be a trace segment name - if let Some(trace_access) = self.trace_access(access) { - return Ok(Value::builder() - .value(SpannedMirValue { - span: access.span(), - value: MirValue::TraceAccess(trace_access), - }) - .build()); + if let Some(trace_access) = self.trace_access(access)? { + return Ok(trace_access); } if let Some(tab) = self.trace_access_binding(access) { - return Ok(Value::builder() - .value(SpannedMirValue { - span: access.span(), - value: MirValue::TraceAccessBinding(tab), - }) - .build()); + return Ok(tab); } // It should never be possible to reach this point - semantic analysis @@ -1188,7 +1202,7 @@ impl<'a> MirBuilder<'a> { ); } - // // If we reach here, this must be a let-bound variable + // If we reach here, this must be a let-bound variable if let Some(let_bound_access_expr) = self.bindings.get(access.name.as_ref()).cloned() { // If the let-bound variable is a parameter, we probably already have the type // @@ -1198,53 +1212,27 @@ impl<'a> MirBuilder<'a> { { param.ty = self.translate_type(access_ty); } + let mir_access_type = self.translate_access_type(&access.access_type)?; let accessor: Link = Accessor::create( duplicate_node(let_bound_access_expr, &mut Default::default()), - access.access_type.clone(), + mir_access_type, access.offset, access.span(), ); - return Ok(accessor); } - if let Some(trace_access) = self.trace_access(access) { - return Ok(Value::builder() - .value(SpannedMirValue { - span: access.span(), - value: MirValue::TraceAccess(trace_access), - }) - .build()); + if let Some(trace_access) = self.trace_access(access)? { + return Ok(trace_access); } // Otherwise, we check bindings, trace bindings, and public inputs, in that order if let Some(tab) = self.trace_access_binding(access) { - return Ok(Value::builder() - .value(SpannedMirValue { - span: access.span(), - value: MirValue::TraceAccessBinding(tab), - }) - .build()); + return Ok(tab); } - match self.public_input_access(access) { - (Some(public_input_access), None) => { - return Ok(Value::builder() - .value(SpannedMirValue { - span: access.span(), - value: MirValue::PublicInput(public_input_access), - }) - .build()); - }, - (None, Some(public_input_table_access)) => { - return Ok(Value::builder() - .value(SpannedMirValue { - span: access.span(), - value: MirValue::PublicInputTable(public_input_table_access), - }) - .build()); - }, - _ => {}, + if let Some(public_input_access) = self.public_input_access(access)? { + return Ok(public_input_access); } self.diagnostics @@ -1258,21 +1246,54 @@ impl<'a> MirBuilder<'a> { Err(CompileError::Failed) } - // Check assumptions, probably this assumed that the inlining pass did some work fn public_input_access( - &self, - access: &ast::SymbolAccess, - ) -> (Option, Option) { + &mut self, + access: &'a ast::SymbolAccess, + ) -> Result>, CompileError> { let Some(public_input) = self.mir.public_inputs.get(access.name.as_ref()) else { - return (None, None); + return Ok(None); }; - match access.access_type { - AccessType::Default => ( - None, - Some(PublicInputTableAccess::new(public_input.name(), public_input.size())), - ), + match access.access_type.clone() { + AccessType::Default => { + let public_input_table = + PublicInputTableAccess::new(public_input.name(), public_input.size()); + Ok(Some( + Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::PublicInputTable(public_input_table), + }) + .build(), + )) + }, AccessType::Index(index) => { - (Some(PublicInputAccess::new(public_input.name(), index)), None) + // If the index is a constant, we construct the corresponding PublicInputAccess + if let ScalarExpr::Const(c) = *index { + let public_input_access = + PublicInputAccess::new(public_input.name(), c.item as usize); + let value = Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::PublicInput(public_input_access), + }) + .build(); + Ok(Some(value)) + } else { + // Otherwise, we need to wrap a PublicInputAccess with an accessor. In this + // case, the accessor is not used to index into a vector, + // but rather to offset the targeted column + let public_input_access = PublicInputAccess::new(public_input.name(), 0); + let value = Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::PublicInput(public_input_access), + }) + .build(); + let mir_access_type = self.translate_access_type(&access.access_type)?; + let accessor = + Accessor::create(value, mir_access_type, access.offset, access.span()); + Ok(Some(accessor)) + } }, _ => { // This should have been caught earlier during compilation @@ -1284,22 +1305,41 @@ impl<'a> MirBuilder<'a> { } } - // Check assumptions, probably this assumed that the inlining pass did some work - fn trace_access_binding(&self, access: &ast::SymbolAccess) -> Option { + fn trace_access_binding(&self, access: &ast::SymbolAccess) -> Option> { let id = access.name.as_ref(); for segment in self.trace_columns.iter() { if let Some(binding) = segment.bindings.iter().find(|tb| tb.name.as_ref() == Some(id)) { return match &access.access_type { - AccessType::Default => Some(TraceAccessBinding { - segment: binding.segment, - offset: binding.offset, - size: binding.size, - }), - AccessType::Slice(range_expr) => Some(TraceAccessBinding { - segment: binding.segment, - offset: binding.offset + range_expr.to_slice_range().start, - size: range_expr.to_slice_range().count(), - }), + AccessType::Default => { + let tab = TraceAccessBinding { + segment: binding.segment, + offset: binding.offset, + size: binding.size, + }; + Some( + Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::TraceAccessBinding(tab), + }) + .build(), + ) + }, + AccessType::Slice(range_expr) => { + let tab = TraceAccessBinding { + segment: binding.segment, + offset: binding.offset + range_expr.to_slice_range().start, + size: range_expr.to_slice_range().count(), + }; + Some( + Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::TraceAccessBinding(tab), + }) + .build(), + ) + }, _ => None, }; } @@ -1307,8 +1347,10 @@ impl<'a> MirBuilder<'a> { None } - // Check assumptions, probably this assumed that the inlining pass did some work - fn trace_access(&self, access: &ast::SymbolAccess) -> Option { + fn trace_access( + &mut self, + access: &'a ast::SymbolAccess, + ) -> Result>, CompileError> { assert_eq!( self.trace_columns.len(), 1, @@ -1319,8 +1361,15 @@ impl<'a> MirBuilder<'a> { if segment.name == id { // We access $main[i] - if let AccessType::Index(column) = access.access_type { - Some(TraceAccess::new(TraceSegmentId::Main, column, access.offset)) + if let AccessType::Index(column) = access.access_type.clone() { + let node = self.translate_indexed_trace_access( + column, + TraceSegmentId::Main, + 0, + access.offset, + access, + )?; + Ok(Some(node)) } else { // This should have been caught earlier during compilation unreachable!( @@ -1332,20 +1381,70 @@ impl<'a> MirBuilder<'a> { segment.bindings.iter().find(|tb| tb.name.as_ref() == Some(id)) { // We access a trace binding defined in the main trace. - match access.access_type { + match access.access_type.clone() { AccessType::Default if binding.size == 1 => { - Some(TraceAccess::new(binding.segment, binding.offset, access.offset)) + let ta = TraceAccess::new(binding.segment, binding.offset, access.offset); + let value = Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::TraceAccess(ta), + }) + .build(); + let mir_binding_access = self.translate_access_type(&binding.access)?; + let accessor = Accessor::create(value, mir_binding_access, 0, access.span()); + Ok(Some(accessor)) }, - AccessType::Index(extra_offset) if binding.size > 1 => Some(TraceAccess::new( - binding.segment, - binding.offset + extra_offset, - access.offset, - )), - _ => None, + AccessType::Index(extra_offset) if binding.size > 1 => { + let node = self.translate_indexed_trace_access( + extra_offset, + binding.segment, + binding.offset, + access.offset, + access, + )?; + Ok(Some(node)) + }, + _ => Ok(None), } } else { // We do not access a trace - None + Ok(None) + } + } + + /// Helper function to translate a trace_access based on an index. If the index is a constant, + /// we build the corresponding TraceAccess. Otherwise, we build an Accessor around the + /// TraceAccess, the index should then be treated as an offset and not a way to index into a + /// collection. + fn translate_indexed_trace_access( + &mut self, + index: Box, + segment: TraceSegmentId, + offset: usize, + row_offset: usize, + access: &'a ast::SymbolAccess, + ) -> Result, CompileError> { + // If the index is a constant, we construct the corresponding TraceAccess + if let ScalarExpr::Const(c) = *index { + let ta = TraceAccess::new(segment, offset + c.item as usize, row_offset); + Ok(Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::TraceAccess(ta), + }) + .build()) + } else { + // Otherwise, we need to wrap a TraceAccess with an accessor. In this case, the accessor + // is not used to index into a vector, but rather to offset the targeted column + let ta = TraceAccess::new(segment, offset, row_offset); + let value = Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::TraceAccess(ta), + }) + .build(); + let mir_access_type = self.translate_access_type(&access.access_type)?; + Ok(Accessor::create(value, mir_access_type, 0, access.span())) } } } diff --git a/mir/src/passes/unrolling/unrolling_first_pass.rs b/mir/src/passes/unrolling/unrolling_first_pass.rs index 589616b6f..567f62d3d 100644 --- a/mir/src/passes/unrolling/unrolling_first_pass.rs +++ b/mir/src/passes/unrolling/unrolling_first_pass.rs @@ -1,16 +1,15 @@ use std::{collections::HashMap, ops::Deref}; -use air_parser::ast::AccessType; use miden_diagnostics::{DiagnosticsHandler, Spanned}; use crate::{ CompileError, ir::{ - Accessor, Graph, Link, MirType, MirValue, Node, Op, Owner, Parameter, Parent, - SpannedMirValue, TraceAccess, Value, Vector, + Accessor, ConstantValue, Graph, Link, MirAccessType, MirType, MirValue, Node, Op, Owner, + Parameter, Parent, SpannedMirValue, Value, Vector, }, passes::{ - Visitor, + Visitor, handle_accessor_visit, unrolling::{ ForInliningContext, visit_enf_bis, visit_fold_bis, visit_value_bis, visit_vector_bis, }, @@ -45,109 +44,6 @@ impl<'a> UnrollingFirstPass<'a> { } } -/// Unrolls an `Accessor` with `AccessType::Default` access type. -fn unroll_accessor_default_access_type( - indexable: Link, - accessor_offset: usize, -) -> Option> { - if let Some(value) = indexable.clone().as_value() { - let mir_value = value.value.value.clone(); - - if let MirValue::TraceAccess(trace_access) = mir_value { - let new_node = Value::create(SpannedMirValue { - span: value.value.span(), - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access.segment, - column: trace_access.column, - row_offset: trace_access.row_offset + accessor_offset, - }), - }); - return Some(new_node); - } - } - Some(indexable.clone()) -} - -/// Unrolls an `Accessor` with `AccessType::Index` access type. -fn unroll_accessor_index_access_type( - indexable: Link, - index: usize, - accessor_offset: usize, -) -> Option> { - // Check that the child node is a vector, raise diag otherwise - // Replace the current node by the index-th element of the vector - // Raise diag if index is out of bounds - if let Op::Vector(indexable_vector) = indexable.borrow().deref() { - let indexable_vec = indexable_vector.children().borrow().clone(); - let child_accessed = indexable_vec - .get(index) - .unwrap_or_else(|| panic!("Index access out of bounds for indexable: {indexable:?}")); - if let Some(value) = child_accessed.clone().as_value() { - let mir_value = value.value.value.clone(); - match mir_value { - MirValue::TraceAccess(trace_access) => { - let new_node = Value::create(SpannedMirValue { - span: value.value.span(), - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access.segment, - column: trace_access.column, - row_offset: trace_access.row_offset + accessor_offset, - }), - }); - Some(new_node) - }, - _ => Some(child_accessed.clone()), - } - } else { - Some(child_accessed.clone()) - } - } else { - unreachable!("indexable is {:?}", indexable); // raise diag - } -} - -/// Unrolls an `Accessor` with `AccessType::Matrix` access type. -fn unroll_accessor_matrix_access_type( - indexable: Link, - row: usize, - col: usize, -) -> Option> { - // Check that the child node is a matrix, raise diag otherwise - // Replace the current node by the index-th element of the vector - // Raise diag if index is out of bounds - if let Op::Vector(indexable_vector) = indexable.borrow().deref() { - let indexable_vec = indexable_vector.children().borrow().clone(); - let row_accessed = indexable_vec - .get(row) - .unwrap_or_else(|| panic!("Matrix access out of bounds for indexable: {indexable:?}")); - if let Op::Vector(row_accessed_vector) = row_accessed.borrow().deref() { - let row_accessed_vec = row_accessed_vector.children().borrow().clone(); - let child_accessed = row_accessed_vec.get(col).unwrap_or_else(|| { - panic!("Matrix access out of bounds for indexable: {indexable:?}") - }); - Some(child_accessed.clone()) - } else { - unreachable!("unexpected non-vector child of a Matrix: {:?}", row_accessed); - } - } else if let Op::Matrix(indexable_matrix) = indexable.borrow().deref() { - let indexable_vec = indexable_matrix.children().borrow().clone(); - let row_accessed = indexable_vec - .get(row) - .unwrap_or_else(|| panic!("Matrix access out of bounds for indexable: {indexable:?}")); - if let Op::Vector(row_accessed_vector) = row_accessed.borrow().deref() { - let row_accessed_vec = row_accessed_vector.children().borrow().clone(); - let child_accessed = row_accessed_vec.get(col).unwrap_or_else(|| { - panic!("Matrix access out of bounds for indexable: {indexable:?}") - }); - Some(child_accessed.clone()) - } else { - unreachable!("unexpected non-vector child of a Matrix: {:?}", row_accessed); - } - } else { - unreachable!("unexpected matrix access type on indexable {:?}", indexable); - } -} - // For the first pass of Unrolling, we use a tweaked version of the Visitor trait, // each visit_*_bis function returns an `Option>` instead of `Result<(), CompileError>`, // to mutate the nodes (e.g. modifying an `Operation` to `Vector`) @@ -169,6 +65,17 @@ impl UnrollingFirstPass<'_> { Ok(None) } + fn visit_accessor_bis(&mut self, accessor: Link) -> Result>, CompileError> { + let accessor_ref = accessor.as_accessor().unwrap(); + let indexable = accessor_ref.indexable.clone(); + if indexable.clone().as_parameter().is_none() { + handle_accessor_visit(accessor.clone(), false, self.diagnostics) + } else { + // We keep accessors wrapping parameters to allow for nested list comprehensions. + Ok(None) + } + } + fn visit_for_bis(&mut self, for_node: Link) -> Result>, CompileError> { // For each value produced by the iterators, we need to: // - Duplicate the body @@ -262,7 +169,7 @@ impl Visitor for UnrollingFirstPass<'_> { Node::Enf(e) => e.to_link().map_or(Ok(None), visit_enf_bis)?, Node::Fold(f) => f.to_link().map_or(Ok(None), visit_fold_bis)?, Node::Vector(v) => v.to_link().map_or(Ok(None), visit_vector_bis)?, - Node::Accessor(a) => a.to_link().map_or(Ok(None), visit_accessor_bis)?, + Node::Accessor(a) => a.to_link().map_or(Ok(None), |el| self.visit_accessor_bis(el))?, Node::Value(v) => v.to_link().map_or(Ok(None), visit_value_bis)?, Node::For(f) => f.to_link().map_or(Ok(None), |el| self.visit_for_bis(el))?, Node::Parameter(p) => { @@ -297,36 +204,6 @@ impl Visitor for UnrollingFirstPass<'_> { // HELPERS FUNCTIONS // ================================================================================================ -/// Unrolls an `Accessor` depending on its `AccessType`. -/// -/// Note: If the `indexable` is a `Parameter` (referencing a `For` node), we do not unroll the -/// `Accessor`, in order to handle nested `For` nodes. -pub fn visit_accessor_bis(accessor: Link) -> Result>, CompileError> { - let accessor_ref = accessor.as_accessor().unwrap(); - let indexable = accessor_ref.indexable.clone(); - let access_type = accessor_ref.access_type.clone(); - let offset = accessor_ref.offset; - // If the indexable is a parameter, we keep the accessor as is, in - // order to handle nested `For` nodes - if indexable.clone().as_parameter().is_none() { - match access_type { - AccessType::Default => { - return Ok(unroll_accessor_default_access_type(indexable, offset)); - }, - AccessType::Index(index) => { - return Ok(unroll_accessor_index_access_type(indexable, index, offset)); - }, - AccessType::Matrix(row, col) => { - return Ok(unroll_accessor_matrix_access_type(indexable, row, col)); - }, - AccessType::Slice(_range_expr) => { - unreachable!(); // Slices are not scalar, raise diag - }, - } - } - Ok(None) -} - /// Sanity check that all iterators have the same length. /// Note that semantic analysis should have already checked they are valid. fn validate_iterators_and_get_expected_len(iterators: &[Link]) -> usize { @@ -343,15 +220,14 @@ fn validate_iterators_and_get_expected_len(iterators: &[Link]) -> usize { iterator_expected_len } -/// Computes the length of a node that is used as an iterator in a For node. +/// Computes the length of a node that is used as an iterator in a `For` node. fn compute_iterator_len(iterator: Link) -> usize { match iterator.borrow().deref() { Op::Vector(vector) => vector.size, Op::Matrix(matrix) => matrix.size, Op::Accessor(accessor) => match &accessor.access_type { - AccessType::Default => compute_iterator_len(accessor.indexable.clone()), - AccessType::Slice(range_expr) => range_expr.to_slice_range().count(), - AccessType::Index(_) => match accessor.indexable.borrow().deref() { + MirAccessType::Default => compute_iterator_len(accessor.indexable.clone()), + MirAccessType::Index(_) => match accessor.indexable.borrow().deref() { Op::Vector(_) => 1, Op::Matrix(matrix) => { let children = matrix.children().borrow().clone(); @@ -364,7 +240,7 @@ fn compute_iterator_len(iterator: Link) -> usize { }, _ => unreachable!("Unexpected index into non indexable type"), }, - AccessType::Matrix(..) => 1, + MirAccessType::Matrix(..) => 1, }, Op::Parameter(parameter) => match parameter.ty { MirType::Felt => 1, @@ -391,12 +267,18 @@ fn get_iterator_child(op: Link, i: usize) -> Link { // If we access an outer loop parameter in the body of an inner // loop, we need to create // an Accessor for the correct index in this parameter - Op::Parameter(_parameter) => Accessor::create( - accessor.indexable.clone(), - AccessType::Index(i), - 0, - accessor.span(), - ), + Op::Parameter(_parameter) => { + let mir_access_type = MirAccessType::Index(Value::create(SpannedMirValue { + span: accessor.span(), + value: MirValue::Constant(ConstantValue::Felt(i as u64)), + })); + Accessor::create( + accessor.indexable.clone(), + mir_access_type, + 0, + accessor.span(), + ) + }, _ => op.clone(), } }, diff --git a/mir/src/tests/computed_indices.rs b/mir/src/tests/computed_indices.rs new file mode 100644 index 000000000..880963e5e --- /dev/null +++ b/mir/src/tests/computed_indices.rs @@ -0,0 +1,109 @@ +use super::{compile, expect_diagnostic}; + +#[test] +fn basic_computed_indices() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + + enf a = x[1 + 1]; + }"; + + assert!(compile(source).is_ok()); +} + +#[test] +fn basic_computed_indices_in_lc() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + let y = [i * x[1 + 1] for i in 0..5]; + + enf a = y[1 + 1]; + }"; + + assert!(compile(source).is_ok()); +} + +#[test] +fn computed_indices_in_lc() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + let y = [i * x[i + 1] for i in 0..4]; + + enf a = y[1 + 1]; + }"; + + assert!(compile(source).is_ok()); +} + +// Tests that should return errors +#[test] +fn err_computed_indices_in_lc() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + let y = [i * x[a + 1] for i in 0..4]; + + enf a = y[1 + 1]; + }"; + + expect_diagnostic(source, "error: the index is not constant during constant propagation"); +} diff --git a/mir/src/tests/mod.rs b/mir/src/tests/mod.rs index 60f49f298..49f7b6c5f 100644 --- a/mir/src/tests/mod.rs +++ b/mir/src/tests/mod.rs @@ -1,6 +1,7 @@ mod access; mod boundary_constraints; mod buses; +mod computed_indices; mod constant; mod evaluators; mod functions; diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index 49b0f285a..56c55951c 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -859,7 +859,7 @@ impl fmt::Display for Boundary { } /// Represents the way an identifier is accessed/referenced in the source. -#[derive(Hash, Debug, Clone, Eq, PartialEq, Default)] +#[derive(Debug, Clone, Eq, PartialEq, Default)] pub enum AccessType { /// Access refers to the entire bound value #[default] @@ -869,9 +869,9 @@ pub enum AccessType { /// Access binds the value at a specific index of an aggregate value (i.e. vector or matrix) /// /// The result type may be either a scalar or a vector, depending on the type of the aggregate - Index(usize), + Index(Box), /// Access binds the value at a specific row and column of a matrix value - Matrix(usize, usize), + Matrix(Box, Box), } impl fmt::Display for AccessType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -960,7 +960,7 @@ impl SymbolAccess { AccessType::Slice(base_range) => { self.access_slice(base_range.to_slice_range(), access_type) }, - AccessType::Index(base_idx) => self.access_index(*base_idx, access_type), + AccessType::Index(base_idx) => self.access_index(base_idx.clone(), access_type), AccessType::Matrix(..) => match access_type { AccessType::Default => Ok(self.clone()), _ => Err(InvalidAccessError::IndexIntoScalar), @@ -974,13 +974,11 @@ impl SymbolAccess { AccessType::Default => Ok(self.clone()), AccessType::Index(idx) => match ty { Type::Felt => Err(InvalidAccessError::IndexIntoScalar), - Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), Type::Vector(_) => Ok(Self { access_type: AccessType::Index(idx), ty: Some(Type::Felt), ..self.clone() }), - Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), Type::Matrix(_, cols) => Ok(Self { access_type: AccessType::Index(idx), ty: Some(Type::Vector(cols)), @@ -1012,9 +1010,6 @@ impl SymbolAccess { }, AccessType::Matrix(row, col) => match ty { Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar), - Type::Matrix(rows, cols) if row >= rows || col >= cols => { - Err(InvalidAccessError::IndexOutOfBounds) - }, Type::Matrix(..) => Ok(Self { access_type: AccessType::Matrix(row, col), ty: Some(Type::Felt), @@ -1034,15 +1029,29 @@ impl SymbolAccess { AccessType::Default => Ok(self.clone()), AccessType::Index(idx) => match ty { Type::Felt => unreachable!(), - Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), Type::Vector(_) => Ok(Self { - access_type: AccessType::Index(base_range.start + idx), + access_type: AccessType::Index(Box::new(ScalarExpr::Binary(BinaryExpr { + span: self.span(), + op: BinaryOp::Add, + lhs: Box::new(ScalarExpr::Const(Span::new( + self.span(), + base_range.start as u64, + ))), + rhs: idx.clone(), + }))), ty: Some(Type::Felt), ..self.clone() }), - Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), Type::Matrix(_, cols) => Ok(Self { - access_type: AccessType::Index(base_range.start + idx), + access_type: AccessType::Index(Box::new(ScalarExpr::Binary(BinaryExpr { + span: self.span(), + op: BinaryOp::Add, + lhs: Box::new(ScalarExpr::Const(Span::new( + self.span(), + base_range.start as u64, + ))), + rhs: idx.clone(), + }))), ty: Some(Type::Vector(cols)), ..self.clone() }), @@ -1080,9 +1089,6 @@ impl SymbolAccess { }, AccessType::Matrix(row, col) => match ty { Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar), - Type::Matrix(rows, cols) if row >= rows || col >= cols => { - Err(InvalidAccessError::IndexOutOfBounds) - }, Type::Matrix(..) => Ok(Self { access_type: AccessType::Matrix(row, col), ty: Some(Type::Felt), @@ -1094,7 +1100,7 @@ impl SymbolAccess { fn access_index( &self, - base_idx: usize, + base_idx: Box, access_type: AccessType, ) -> Result { let ty = self.ty.unwrap(); @@ -1102,13 +1108,11 @@ impl SymbolAccess { AccessType::Default => Ok(self.clone()), AccessType::Index(idx) => match ty { Type::Felt => Err(InvalidAccessError::IndexIntoScalar), - Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), Type::Vector(_) => Ok(Self { access_type: AccessType::Matrix(base_idx, idx), ty: Some(Type::Felt), ..self.clone() }), - Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), Type::Matrix(_, cols) => Ok(Self { access_type: AccessType::Matrix(base_idx, idx), ty: Some(Type::Vector(cols)), diff --git a/parser/src/ast/trace.rs b/parser/src/ast/trace.rs index b093db90a..5b26a05d0 100644 --- a/parser/src/ast/trace.rs +++ b/parser/src/ast/trace.rs @@ -212,7 +212,7 @@ impl fmt::Debug for FormatConstrainedFlags<'_> { /// [TraceBinding] is used to represent one or more columns in the execution trace that are bound to /// a name. For single columns, the size is 1. For groups, the size is the number of columns in the /// group. The offset is the column index in the trace where the first column of the binding starts. -#[derive(Copy, Clone, Spanned)] +#[derive(Clone, Spanned)] pub struct TraceBinding { #[span] pub span: SourceSpan, @@ -226,6 +226,8 @@ pub struct TraceBinding { pub size: usize, /// The effective type of this binding pub ty: Type, + /// The access type associated to this TraceBinding + pub access: AccessType, } impl TraceBinding { /// Creates a new trace binding. @@ -244,48 +246,77 @@ impl TraceBinding { offset, size, ty, + access: AccessType::Default, } } /// Returns a [Type] that describes what type of value this binding represents #[inline] pub fn ty(&self) -> Type { - self.ty + match self.access.clone() { + AccessType::Default => self.ty, + AccessType::Slice(range_expr) => Type::Vector(range_expr.to_slice_range().len()), + AccessType::Index(_) => Type::Felt, + AccessType::Matrix(..) => { + unreachable!("matrix access not supported on trace bindings") + }, + } } #[inline] pub fn is_scalar(&self) -> bool { - self.ty.is_scalar() + self.ty().is_scalar() + } + + /// Returns the size of the trace binding, taking into account how it is accessed + pub fn tb_size(&self) -> usize { + match self.ty() { + Type::Vector(len) => len, + Type::Felt => 1, + _ => self.size, + } } /// Derive a new [TraceBinding] derived from the current one given an [AccessType] pub fn access(&self, access_type: AccessType) -> Result { - match access_type { - AccessType::Default => Ok(*self), - AccessType::Slice(_) if self.is_scalar() => Err(InvalidAccessError::SliceOfScalar), - AccessType::Slice(range) => { - let slice_range = range.to_slice_range(); - if slice_range.end > self.size { - Err(InvalidAccessError::IndexOutOfBounds) - } else { - let offset = self.offset + slice_range.start; - let size = slice_range.len(); - Ok(Self { - offset, - size, - ty: Type::Vector(size), - ..*self - }) - } + let combined_access = match (self.access.clone(), access_type.clone()) { + (AccessType::Default, _) => access_type, + (_, AccessType::Default) => self.access.clone(), + (AccessType::Slice(range_expr), AccessType::Slice(range_expr1)) => { + let range_expr = range_expr.to_slice_range(); + let range_expr1 = range_expr1.to_slice_range(); + let combined_range = + (range_expr.start + range_expr1.start)..(range_expr.end + range_expr1.end); + AccessType::Slice(combined_range.into()) }, - AccessType::Index(_) if self.is_scalar() => Err(InvalidAccessError::IndexIntoScalar), - AccessType::Index(idx) if idx >= self.size => Err(InvalidAccessError::IndexOutOfBounds), - AccessType::Index(idx) => { - let offset = self.offset + idx; - Ok(Self { offset, size: 1, ty: Type::Felt, ..*self }) + (AccessType::Slice(range_expr), AccessType::Index(index_expr)) => { + let range_expr_usize = range_expr.to_slice_range(); + let new_expr = ScalarExpr::Binary(BinaryExpr::new( + self.span(), + BinaryOp::Add, + ScalarExpr::Const(Span::new(range_expr.span(), range_expr_usize.start as u64)), + *index_expr, + )); + AccessType::Index(Box::new(new_expr)) }, - AccessType::Matrix(..) => Err(InvalidAccessError::IndexIntoScalar), + (AccessType::Index(_), AccessType::Index(_)) => { + return Err(InvalidAccessError::IndexIntoScalar); + }, + (AccessType::Matrix(..), _) | (_, AccessType::Matrix(..)) => { + return Err(InvalidAccessError::IndexIntoScalar); + }, + (expression::AccessType::Index(_), expression::AccessType::Slice(_)) => { + return Err(InvalidAccessError::SliceOfScalar); + }, + }; + + if let AccessType::Index(idx) = combined_access.clone() + && let ScalarExpr::Const(value) = *idx + && value.item as usize >= self.size + { + return Err(InvalidAccessError::IndexOutOfBounds); } + Ok(Self { access: combined_access, ..*self }) } } impl Eq for TraceBinding {} diff --git a/parser/src/ast/types.rs b/parser/src/ast/types.rs index 3155180ec..cc2f26931 100644 --- a/parser/src/ast/types.rs +++ b/parser/src/ast/types.rs @@ -52,7 +52,6 @@ impl Type { Ok(Self::Vector(slice_range.len())) } }, - AccessType::Index(idx) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), AccessType::Index(_) => Ok(Self::Felt), AccessType::Matrix(..) => Err(InvalidAccessError::IndexIntoScalar), _ => unreachable!(), @@ -66,11 +65,7 @@ impl Type { Ok(Self::Matrix(slice_range.len(), cols)) } }, - AccessType::Index(idx) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), AccessType::Index(_) => Ok(Self::Vector(cols)), - AccessType::Matrix(row, col) if row >= rows || col >= cols => { - Err(InvalidAccessError::IndexOutOfBounds) - }, AccessType::Matrix(..) => Ok(Self::Felt), _ => unreachable!(), }, diff --git a/parser/src/ast/visit.rs b/parser/src/ast/visit.rs index 834c0d95d..9537c9159 100644 --- a/parser/src/ast/visit.rs +++ b/parser/src/ast/visit.rs @@ -26,7 +26,7 @@ use crate::ast; /// /// use miden_diagnostics::{Span, Spanned}; /// -/// use air_parser::ast::{self, visit}; +/// use air_parser::ast::{self, visit, ScalarExpr}; /// /// /// A simple visitor which replaces accesses to constant values with the values themselves, /// /// evaluates constant expressions (i.e. expressions whose operands are constant), and propagates @@ -59,19 +59,13 @@ use crate::ast; /// core::mem::replace(expr, ast::ScalarExpr::Const(Span::new(span, value))); /// } /// Some((span, ast::ConstantExpr::Vector(value))) => { -/// match sym.access_type { -/// ast::AccessType::Index(idx) => { -/// core::mem::replace(expr, ast::ScalarExpr::Const(Span::new(span, value[idx]))); -/// } -/// _ => panic!("invalid constant reference, expected scalar access"), +/// if let ast::AccessType::Index(idx) = sym.access_type.clone() && let ScalarExpr::Const(idx) = *idx { +/// core::mem::replace(expr, ast::ScalarExpr::Const(Span::new(span, value[idx.item as usize]))); /// } /// } /// Some((span, ast::ConstantExpr::Matrix(value))) => { -/// match sym.access_type { -/// ast::AccessType::Matrix(row, col) => { -/// core::mem::replace(expr, ast::ScalarExpr::Const(Span::new(span, value[row][col]))); -/// } -/// _ => panic!("invalid constant reference, expected scalar access"), +/// if let ast::AccessType::Matrix(row, col) = sym.access_type.clone() && let ScalarExpr::Const(row) = *row && let ScalarExpr::Const(col) = *col { +/// core::mem::replace(expr, ast::ScalarExpr::Const(Span::new(span, value[row.item as usize][col.item as usize]))); /// } /// } /// } @@ -703,8 +697,11 @@ where V: ?Sized + VisitMut, { match expr { - ast::AccessType::Default | ast::AccessType::Index(_) | ast::AccessType::Matrix(..) => { - ControlFlow::Continue(()) + ast::AccessType::Default => ControlFlow::Continue(()), + ast::AccessType::Index(index) => visitor.visit_mut_scalar_expr(index), + ast::AccessType::Matrix(row, col) => { + visitor.visit_mut_scalar_expr(row)?; + visitor.visit_mut_scalar_expr(col) }, ast::AccessType::Slice(range) => { visitor.visit_mut_range_bound(&mut range.start)?; @@ -740,6 +737,7 @@ pub fn visit_mut_symbol_access( where V: ?Sized + VisitMut, { + visitor.visit_mut_access_type(&mut expr.access_type)?; visitor.visit_mut_resolvable_identifier(&mut expr.name) } diff --git a/parser/src/lexer/mod.rs b/parser/src/lexer/mod.rs index 8265300ae..f354cf6f4 100644 --- a/parser/src/lexer/mod.rs +++ b/parser/src/lexer/mod.rs @@ -644,9 +644,10 @@ where match num.parse::() { Ok(i) => Token::Num(i), - Err(err) => { - Token::Error(LexicalError::InvalidInt { span: self.span(), reason: *err.kind() }) - }, + Err(err) => Token::Error(LexicalError::InvalidInt { + span: self.span(), + reason: err.kind().clone(), + }), } } } diff --git a/parser/src/parser/grammar.lalrpop b/parser/src/parser/grammar.lalrpop index c1351b31c..16ece1713 100644 --- a/parser/src/parser/grammar.lalrpop +++ b/parser/src/parser/grammar.lalrpop @@ -614,8 +614,8 @@ Size: u64 = { "[" "]" => <> } -Index: usize = { - "[" "]" => idx as usize +Index: Box = { + "[" "]" => Box::new(idx) } TableSize: u64 = { diff --git a/parser/src/parser/tests/computed_indices.rs b/parser/src/parser/tests/computed_indices.rs new file mode 100644 index 000000000..3238ecf49 --- /dev/null +++ b/parser/src/parser/tests/computed_indices.rs @@ -0,0 +1,143 @@ +use miden_diagnostics::{SourceSpan, Span}; + +use super::ParseTest; +use crate::ast::*; + +#[test] +fn basic_computed_indices() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + + enf a = x[1 + 1]; + }"; + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] + )); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], + )); + expected.integrity_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![let_!(x = vector!(int!(0), int!(1), int!(2), int!(3), int!(4)) => + enforce!(eq!(access!(a), access!(x[Box::new(add!(int!(1), int!(1)))]))))], + )); + + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn basic_computed_indices_in_lc() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + let y = [i * x[1 + 1] for i in 0..5]; + + enf a = y[1 + 1]; + }"; + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] + )); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], + )); + expected.integrity_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![let_!(x = vector!(int!(0), int!(1), int!(2), int!(3), int!(4)) => + let_!(y = lc!(((i, range!(0usize, 5usize))) => mul!(access!(i), access!(x[Box::new(add!(int!(1), int!(1)))]))).into() => + enforce!(eq!(access!(a), access!(y[Box::new(add!(int!(1), int!(1)))])))))], + )); + + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn computed_indices_in_lc() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + let y = [i * x[i + 1] for i in 0..4]; + + enf a = y[1 + 1]; + }"; + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] + )); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], + )); + expected.integrity_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![let_!(x = vector!(int!(0), int!(1), int!(2), int!(3), int!(4)) => + let_!(y = lc!(((i, range!(0usize, 4usize))) => mul!(access!(i), access!(x[Box::new(add!(access!(i), int!(1)))]))).into() => + enforce!(eq!(access!(a), access!(y[Box::new(add!(int!(1), int!(1)))])))))], + )); + + ParseTest::new().expect_module_ast(source, expected); +} diff --git a/parser/src/parser/tests/mod.rs b/parser/src/parser/tests/mod.rs index 80bc49c60..22f1110c6 100644 --- a/parser/src/parser/tests/mod.rs +++ b/parser/src/parser/tests/mod.rs @@ -165,6 +165,15 @@ macro_rules! access { }; ($name:ident [ $idx:literal ]) => { + ScalarExpr::SymbolAccess(SymbolAccess::new( + miden_diagnostics::SourceSpan::UNKNOWN, + ident!($name), + AccessType::Index(Box::new(int!($idx))), + 0, + )) + }; + + ($name:ident [ $idx:expr ]) => { ScalarExpr::SymbolAccess(SymbolAccess::new( miden_diagnostics::SourceSpan::UNKNOWN, ident!($name), @@ -174,6 +183,15 @@ macro_rules! access { }; ($name:literal [ $idx:literal ]) => { + ScalarExpr::SymbolAccess(SymbolAccess::new( + miden_diagnostics::SourceSpan::UNKNOWN, + ident!($name), + AccessType::Index(Box::new(int!($idx))), + 0, + )) + }; + + ($name:literal [ $idx:expr ]) => { ScalarExpr::SymbolAccess(SymbolAccess::new( miden_diagnostics::SourceSpan::UNKNOWN, ident!($name), @@ -183,6 +201,15 @@ macro_rules! access { }; ($name:ident [ $row:literal ] [ $col:literal ]) => { + ScalarExpr::SymbolAccess(SymbolAccess::new( + miden_diagnostics::SourceSpan::UNKNOWN, + ident!($name), + AccessType::Matrix(Box::new(int!($row)), Box::new(int!($col))), + 0, + )) + }; + + ($name:ident [ $row:expr ] [ $col:expr ]) => { ScalarExpr::SymbolAccess(SymbolAccess::new( miden_diagnostics::SourceSpan::UNKNOWN, ident!($name), @@ -192,6 +219,16 @@ macro_rules! access { }; ($name:ident [ $row:literal ] [ $col:literal ], $ty:expr) => { + ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(ident!($name)), + access_type: AccessType::Matrix(Box::new(int!($row)), Box::new(int!($col))), + offset: 0, + ty: Some($ty), + }) + }; + + ($name:ident [ $row:expr ] [ $col:expr ], $ty:expr) => { ScalarExpr::SymbolAccess(SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ResolvableIdentifier::Local(ident!($name)), @@ -202,6 +239,16 @@ macro_rules! access { }; ($module:ident, $name:ident [ $idx:literal ], $ty:expr) => { + ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ident!($module, $name).into(), + access_type: AccessType::Index(Box::new(int!($idx))), + offset: 0, + ty: Some($ty), + }) + }; + + ($module:ident, $name:ident [ $idx:expr ], $ty:expr) => { ScalarExpr::SymbolAccess(SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ident!($module, $name).into(), @@ -212,6 +259,16 @@ macro_rules! access { }; ($module:ident, $name:ident [ $row:literal ] [ $col:literal ], $ty:expr) => { + ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ident!($module, $name).into(), + access_type: AccessType::Matrix(Box::new(int!($row)), Box::new(int!($col))), + offset: 0, + ty: Some($ty), + }) + }; + + ($module:ident, $name:ident [ $row:expr ] [ $col:expr ], $ty:expr) => { ScalarExpr::SymbolAccess(SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ident!($module, $name).into(), @@ -222,6 +279,15 @@ macro_rules! access { }; ($name:ident [ $idx:literal ], $offset:literal) => { + ScalarExpr::SymbolAccess(SymbolAccess::new( + miden_diagnostics::SourceSpan::UNKNOWN, + ident!($name), + AccessType::Index(Box::new(int!($idx))), + $offset, + )) + }; + + ($name:ident [ $idx:expr ], $offset:literal) => { ScalarExpr::SymbolAccess(SymbolAccess::new( miden_diagnostics::SourceSpan::UNKNOWN, ident!($name), @@ -231,6 +297,16 @@ macro_rules! access { }; ($name:ident [ $idx:literal ], $ty:expr) => { + ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(ident!($name)), + access_type: AccessType::Index(Box::new(int!($idx))), + offset: 0, + ty: Some($ty), + }) + }; + + ($name:ident [ $idx:expr ], $ty:expr) => { ScalarExpr::SymbolAccess(SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ResolvableIdentifier::Local(ident!($name)), @@ -244,13 +320,32 @@ macro_rules! access { ScalarExpr::SymbolAccess(SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ResolvableIdentifier::Local(ident!($name)), - access_type: AccessType::Index($idx), + access_type: AccessType::Index(Box::new(int!($idx))), + offset: $offset, + ty: Some($ty), + }) + }; + + ($name:ident [ $idx:literal ], $offset:literal, $ty:expr) => { + ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(ident!($name)), + access_type: AccessType::Index(Box::new(int!($idx))), offset: $offset, ty: Some($ty), }) }; ($name:literal [ $idx:literal ], $offset:literal) => { + ScalarExpr::SymbolAccess(SymbolAccess::new( + miden_diagnostics::SourceSpan::UNKNOWN, + ident!($name), + AccessType::Index(Box::new(int!($idx))), + $offset, + )) + }; + + ($name:literal [ $idx:expr ], $offset:literal) => { ScalarExpr::SymbolAccess(SymbolAccess::new( miden_diagnostics::SourceSpan::UNKNOWN, ident!($name), @@ -322,7 +417,7 @@ macro_rules! bounded_access { SymbolAccess::new( miden_diagnostics::SourceSpan::UNKNOWN, ident!($name), - AccessType::Index($idx), + AccessType::Index(Box::new(int!($idx))), 0, ), $bound, @@ -335,7 +430,7 @@ macro_rules! bounded_access { SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ResolvableIdentifier::Local(ident!($name)), - access_type: AccessType::Index($idx), + access_type: AccessType::Index(Box::new(int!($idx))), offset: 0, ty: Some($ty), }, @@ -657,6 +752,7 @@ mod arithmetic_ops; mod boundary_constraints; mod buses; mod calls; +mod computed_indices; mod constant_propagation; mod constants; mod evaluators; diff --git a/parser/src/sema/binding_type.rs b/parser/src/sema/binding_type.rs index 2659634ef..d9bf5eed9 100644 --- a/parser/src/sema/binding_type.rs +++ b/parser/src/sema/binding_type.rs @@ -1,6 +1,8 @@ use std::fmt; -use crate::ast::{AccessType, BusType, FunctionType, InvalidAccessError, TraceBinding, Type}; +use crate::ast::{ + AccessType, BusType, FunctionType, InvalidAccessError, ScalarExpr, TraceBinding, Type, +}; /// This type provides type and contextual information about a binding, /// i.e. not only does it tell us the type of a binding, but what type @@ -59,10 +61,19 @@ impl BindingType { Self::TraceParam(tb) => tb.access(access_type).map(Self::TraceParam), Self::Vector(elems) => match access_type { AccessType::Default => Ok(Self::Vector(elems.clone())), - AccessType::Index(idx) if idx >= elems.len() => { - Err(InvalidAccessError::IndexOutOfBounds) + AccessType::Index(idx) => { + if let ScalarExpr::Const(idx) = *idx { + if idx.item as usize >= elems.len() { + Err(InvalidAccessError::IndexOutOfBounds) + } else { + Ok(elems[idx.item as usize].clone()) + } + } else { + // Items are all of the same type, we can just return the first one for now, + // as we cannot determine its value for now. + Ok(elems[0].clone()) + } }, - AccessType::Index(idx) => Ok(elems[idx].clone()), AccessType::Slice(range) => { let slice_range = range.to_slice_range(); if slice_range.end > elems.len() { @@ -71,10 +82,19 @@ impl BindingType { Ok(Self::Vector(elems[slice_range].to_vec())) } }, - AccessType::Matrix(row, _) if row >= elems.len() => { - Err(InvalidAccessError::IndexOutOfBounds) + AccessType::Matrix(row, col) => { + if let ScalarExpr::Const(row) = *row { + if row.item as usize >= elems.len() { + Err(InvalidAccessError::IndexOutOfBounds) + } else { + elems[row.item as usize].access(AccessType::Index(col)) + } + } else { + // Items are all of the same type, we can just return the first one for now, + // as we cannot determine its value for now. + elems[0].access(AccessType::Index(col)) + } }, - AccessType::Matrix(row, col) => elems[row].access(AccessType::Index(col)), }, Self::PublicInput(ty) => ty.access(access_type).map(Self::PublicInput), Self::PeriodicColumn(period) => match access_type { diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index 07d9d089a..f2093b262 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -169,11 +169,12 @@ impl VisitMut for SemanticAnalysis<'_> { offset: 0, size: segment.size, ty: Type::Vector(segment.size), + access: AccessType::Default, }) ), None ); - for binding in segment.bindings.iter().copied() { + for binding in segment.bindings.iter().cloned() { assert_eq!( self.locals.insert( NamespacedIdentifier::Binding(binding.name.unwrap()), @@ -184,6 +185,7 @@ impl VisitMut for SemanticAnalysis<'_> { offset: binding.offset, size: binding.size, ty: binding.ty, + access: binding.access, }) ), None @@ -353,6 +355,7 @@ impl VisitMut for SemanticAnalysis<'_> { offset: trace_binding.offset, size: trace_binding.size, ty: trace_binding.ty, + access: trace_binding.access.clone(), }), ); } @@ -586,7 +589,10 @@ impl VisitMut for SemanticAnalysis<'_> { match self.expr_binding_type(iterable) { Ok(iterable_binding_ty) => { let binding_ty = iterable_binding_ty - .access(AccessType::Index(0)) + .access(AccessType::Index(Box::new(ScalarExpr::Const(Span::new( + iterable.span(), + 0, + ))))) .expect("unexpected scalar iterable"); binding_tys.push((binding, iterable.span(), Some(binding_ty))); }, @@ -1153,7 +1159,7 @@ impl SemanticAnalysis<'_> { Expr::SymbolAccess(access) => { match self.access_binding_type(access) { Ok(BindingType::TraceColumn(tr) | BindingType::TraceParam(tr)) => { - if tr.size == param.size { + if tr.tb_size() == param.size { // Success, the argument and parameter types match up, but // we must make sure the segments also match let same_segment = tr.segment == param.id; @@ -1162,27 +1168,27 @@ impl SemanticAnalysis<'_> { let segment_name = segment_id_to_name(tr.segment); self.has_type_errors = true; self.diagnostics - .diagnostic(Severity::Error) - .with_message("invalid evaluator function argument") - .with_primary_label( - arg.span(), - format!( - "callee expects columns from the {expected_segment} trace"), - ) - .with_secondary_label( - tr.span, - format!( - "but this column is from the {segment_name} trace"), - ) - .emit(); + .diagnostic(Severity::Error) + .with_message("invalid evaluator function argument") + .with_primary_label( + arg.span(), + format!( + "callee expects columns from the {expected_segment} trace"), + ) + .with_secondary_label( + tr.span, + format!( + "but this column is from the {segment_name} trace"), + ) + .emit(); } } else { self.has_type_errors = true; self.diagnostics.diagnostic(Severity::Error) - .with_message("invalid call") - .with_primary_label(span, "type mismatch in function argument") - .with_secondary_label(arg.span(), format!("callee expects {} trace columns here, but this binding provides {}", param.size, tr.size)) - .emit(); + .with_message("invalid call") + .with_primary_label(span, "type mismatch in function argument") + .with_secondary_label(arg.span(), format!("callee expects {} trace columns here, but this binding provides {}", param.size, tr.tb_size())) + .emit(); } }, Ok(BindingType::Vector(ref elems)) => { @@ -1191,25 +1197,25 @@ impl SemanticAnalysis<'_> { match elem { BindingType::TraceColumn(tr) | BindingType::TraceParam(tr) => { if tr.segment == param.id { - size += tr.size; + size += tr.tb_size(); } else { let expected_segment = segment_id_to_name(param.id); let segment_name = segment_id_to_name(tr.segment); self.has_type_errors = true; self.diagnostics - .diagnostic(Severity::Error) - .with_message("invalid evaluator function argument") - .with_primary_label( - arg.span(), - format!( - "callee expects columns from the {expected_segment} trace"), - ) - .with_secondary_label( - tr.span, - format!( - "but this column is from the {segment_name} trace"), - ) - .emit(); + .diagnostic(Severity::Error) + .with_message("invalid evaluator function argument") + .with_primary_label( + arg.span(), + format!( + "callee expects columns from the {expected_segment} trace"), + ) + .with_secondary_label( + tr.span, + format!( + "but this column is from the {segment_name} trace"), + ) + .emit(); return ControlFlow::Continue(()); } }, @@ -1282,7 +1288,7 @@ impl SemanticAnalysis<'_> { match self.expr_binding_type(elem) { Ok(BindingType::TraceColumn(tr) | BindingType::TraceParam(tr)) => { if tr.segment == param.id { - size += tr.size; + size += tr.tb_size(); } else { let expected_segment = segment_id_to_name(param.id); let segment_name = segment_id_to_name(tr.segment); @@ -1367,15 +1373,19 @@ impl SemanticAnalysis<'_> { // Ensure the referenced symbol was a trace column, and that it produces a // scalar value, or a bus - let (found, _segment) = - match self.resolvable_binding_type(&access.column.name) { - Ok(ty) => match ty.item.access(access.column.access_type.clone()) { + let (found, _segment) = match self + .resolvable_binding_type(&access.column.name) + { + Ok(ty) => { + let accessed_ty = ty.item.access(access.column.access_type.clone()); + match accessed_ty.clone() { Ok(BindingType::TraceColumn(tb)) | Ok(BindingType::TraceParam(tb)) => { - if tb.is_scalar() { - (ty, tb.segment) + let tb_type = tb.ty(); + if tb_type.is_scalar() { + (Span::new(ty.span(), accessed_ty.unwrap()), tb.segment) } else { - let inferred = tb.ty(); + let inferred = tb_type; return self.type_mismatch( Some(&inferred), access.span(), @@ -1407,12 +1417,13 @@ impl SemanticAnalysis<'_> { ); }, _ => return ControlFlow::Break(SemanticAnalysisError::Invalid), - }, - Err(_) => { - // We've already raised a diagnostic for the undefined variable - return ControlFlow::Break(SemanticAnalysisError::Invalid); - }, - }; + } + }, + Err(_) => { + // We've already raised a diagnostic for the undefined variable + return ControlFlow::Break(SemanticAnalysisError::Invalid); + }, + }; match (found.clone().item, expr.rhs.as_mut()) { // Buses boundaries can be constrained by null or set to be diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs index 098025220..d0eca421f 100644 --- a/parser/src/transforms/constant_propagation.rs +++ b/parser/src/transforms/constant_propagation.rs @@ -190,6 +190,8 @@ impl VisitMut for ConstantPropagation<'_> { // Need to check if this access is to a constant value, and transform to a constant if // so ScalarExpr::SymbolAccess(sym) => { + self.visit_mut_access_type(&mut sym.access_type)?; + let constant_value = match sym.name { // Possibly a reference to a constant declaration ResolvableIdentifier::Resolved(ref qid) => { @@ -208,9 +210,24 @@ impl VisitMut for ConstantPropagation<'_> { assert_eq!(sym.access_type, AccessType::Default); *expr = ScalarExpr::Const(Span::new(span, value)); }, - ConstantExpr::Vector(value) => match sym.access_type { - AccessType::Index(idx) => { - *expr = ScalarExpr::Const(Span::new(span, value[idx])); + ConstantExpr::Vector(value) => match sym.access_type.clone() { + AccessType::Index(idx) => match *idx { + ScalarExpr::Const(idx) => { + if idx.item >= value.len() as u64 { + self.diagnostics.diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access an index which is out of bounds") + .with_primary_label(span, "index out of bounds") + .emit(); + return ControlFlow::Break(SemanticAnalysisError::Invalid); + } + *expr = ScalarExpr::Const(Span::new( + span, + value[idx.item as usize], + )); + }, + _ => { + self.live.insert(*sym.name.as_ref()); + }, }, // This access cannot be resolved here, so we need to record the fact // that there are still live uses of this binding @@ -218,9 +235,26 @@ impl VisitMut for ConstantPropagation<'_> { self.live.insert(*sym.name.as_ref()); }, }, - ConstantExpr::Matrix(value) => match sym.access_type { - AccessType::Matrix(row, col) => { - *expr = ScalarExpr::Const(Span::new(span, value[row][col])); + ConstantExpr::Matrix(value) => match sym.access_type.clone() { + AccessType::Matrix(row, col) => match (*row, *col) { + (ScalarExpr::Const(row), ScalarExpr::Const(col)) => { + if row.item >= value.len() as u64 + || col.item >= value[row.item as usize].len() as u64 + { + self.diagnostics.diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access an index which is out of bounds") + .with_primary_label(span, "index out of bounds") + .emit(); + return ControlFlow::Break(SemanticAnalysisError::Invalid); + } + *expr = ScalarExpr::Const(Span::new( + span, + value[row.item as usize][col.item as usize], + )); + }, + _ => { + self.live.insert(*sym.name.as_ref()); + }, }, // This access cannot be resolved here, so we need to record the fact // that there are still live uses of this binding @@ -315,6 +349,8 @@ impl VisitMut for ConstantPropagation<'_> { // // We deal with symbol accesses directly, as they may evaluate to an aggregate constant Expr::SymbolAccess(access) => { + self.visit_mut_access_type(&mut access.access_type)?; + let constant_value = match access.name { // Possibly a reference to a constant declaration ResolvableIdentifier::Resolved(ref qid) => { @@ -342,9 +378,16 @@ impl VisitMut for ConstantPropagation<'_> { let vector = value[range].to_vec(); *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(vector))); }, - AccessType::Index(idx) => { - *expr = - Expr::Const(Span::new(span, ConstantExpr::Scalar(value[idx]))); + AccessType::Index(idx) => match *idx { + ScalarExpr::Const(idx) => { + *expr = Expr::Const(Span::new( + span, + ConstantExpr::Scalar(value[idx.item as usize]), + )); + }, + _ => { + self.live.insert(*access.name.as_ref()); + }, }, ref ty => panic!( "invalid constant reference, expected scalar access, got {ty:?}", @@ -359,17 +402,29 @@ impl VisitMut for ConstantPropagation<'_> { let matrix = value[range].to_vec(); *expr = Expr::Const(Span::new(span, ConstantExpr::Matrix(matrix))); }, - AccessType::Index(idx) => { - *expr = Expr::Const(Span::new( - span, - ConstantExpr::Vector(value[idx].clone()), - )); + AccessType::Index(idx) => match *idx { + ScalarExpr::Const(idx) => { + *expr = Expr::Const(Span::new( + span, + ConstantExpr::Vector(value[idx.item as usize].clone()), + )); + }, + _ => { + self.live.insert(*access.name.as_ref()); + }, }, - AccessType::Matrix(row, col) => { - *expr = Expr::Const(Span::new( - span, - ConstantExpr::Scalar(value[row][col]), - )); + AccessType::Matrix(row, col) => match (*row, *col) { + (ScalarExpr::Const(row), ScalarExpr::Const(col)) => { + *expr = Expr::Const(Span::new( + span, + ConstantExpr::Scalar( + value[row.item as usize][col.item as usize], + ), + )); + }, + _ => { + self.live.insert(*access.name.as_ref()); + }, }, }, }