diff --git a/CHANGELOG.md b/CHANGELOG.md index b18ae9368..86f22e17d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## 0.5.0 (TBD) - Incremented MSRV to 1.89. +- Add a constant propagation pass after other mir passes (#439). ## 0.4.0 (2025-06-20) diff --git a/air-script/tests/binary/binary.rs b/air-script/tests/binary/binary.rs index cbb2fcb97..8b7185bbd 100644 --- a/air-script/tests/binary/binary.rs +++ b/air-script/tests/binary/binary.rs @@ -82,8 +82,8 @@ impl Air for BinaryAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_current[0] * main_current[0] - main_current[0] - E::ZERO; - result[1] = main_current[1] * main_current[1] - main_current[1] - E::ZERO; + result[0] = main_current[0] * main_current[0] - main_current[0]; + result[1] = main_current[1] * main_current[1] - main_current[1]; } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/bitwise/bitwise.rs b/air-script/tests/bitwise/bitwise.rs index c631d2262..84916092d 100644 --- a/air-script/tests/bitwise/bitwise.rs +++ b/air-script/tests/bitwise/bitwise.rs @@ -82,23 +82,23 @@ impl Air for BitwiseAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_current[0] * main_current[0] - main_current[0] - E::ZERO; - result[1] = periodic_values[1] * (main_next[0] - main_current[0]) - E::ZERO; - result[2] = main_current[3] * main_current[3] - main_current[3] - E::ZERO; - result[3] = main_current[4] * main_current[4] - main_current[4] - E::ZERO; - result[4] = main_current[5] * main_current[5] - main_current[5] - E::ZERO; - result[5] = main_current[6] * main_current[6] - main_current[6] - E::ZERO; - result[6] = main_current[7] * main_current[7] - main_current[7] - E::ZERO; - result[7] = main_current[8] * main_current[8] - main_current[8] - E::ZERO; - result[8] = main_current[9] * main_current[9] - main_current[9] - E::ZERO; - result[9] = main_current[10] * main_current[10] - main_current[10] - E::ZERO; - result[10] = periodic_values[0] * (main_current[1] - (E::ONE * main_current[3] + E::from(Felt::new(2_u64)) * main_current[4] + E::from(Felt::new(4_u64)) * main_current[5] + E::from(Felt::new(8_u64)) * main_current[6])) - E::ZERO; - result[11] = periodic_values[0] * (main_current[2] - (E::ONE * main_current[7] + E::from(Felt::new(2_u64)) * main_current[8] + E::from(Felt::new(4_u64)) * main_current[9] + E::from(Felt::new(8_u64)) * main_current[10])) - E::ZERO; - result[12] = periodic_values[1] * (main_next[1] - (main_current[1] * E::from(Felt::new(16_u64)) + E::ONE * main_current[3] + E::from(Felt::new(2_u64)) * main_current[4] + E::from(Felt::new(4_u64)) * main_current[5] + E::from(Felt::new(8_u64)) * main_current[6])) - E::ZERO; - result[13] = periodic_values[1] * (main_next[2] - (main_current[2] * E::from(Felt::new(16_u64)) + E::ONE * main_current[7] + E::from(Felt::new(2_u64)) * main_current[8] + E::from(Felt::new(4_u64)) * main_current[9] + E::from(Felt::new(8_u64)) * main_current[10])) - E::ZERO; - result[14] = periodic_values[0] * main_current[11] - E::ZERO; - result[15] = periodic_values[1] * (main_current[12] - main_next[11]) - E::ZERO; - result[16] = (E::ONE - main_current[0]) * (main_current[12] - (main_current[11] * E::from(Felt::new(16_u64)) + E::ONE * main_current[3] * main_current[7] + E::from(Felt::new(2_u64)) * main_current[4] * main_current[8] + E::from(Felt::new(4_u64)) * main_current[5] * main_current[9] + E::from(Felt::new(8_u64)) * main_current[6] * main_current[10])) + main_current[0] * (main_current[12] - (main_current[11] * E::from(Felt::new(16_u64)) + E::ONE * (main_current[3] + main_current[7] - E::from(Felt::new(2_u64)) * main_current[3] * main_current[7]) + E::from(Felt::new(2_u64)) * (main_current[4] + main_current[8] - E::from(Felt::new(2_u64)) * main_current[4] * main_current[8]) + E::from(Felt::new(4_u64)) * (main_current[5] + main_current[9] - E::from(Felt::new(2_u64)) * main_current[5] * main_current[9]) + E::from(Felt::new(8_u64)) * (main_current[6] + main_current[10] - E::from(Felt::new(2_u64)) * main_current[6] * main_current[10]))) - E::ZERO; + result[0] = main_current[0] * main_current[0] - main_current[0]; + result[1] = periodic_values[1] * (main_next[0] - main_current[0]); + result[2] = main_current[3] * main_current[3] - main_current[3]; + result[3] = main_current[4] * main_current[4] - main_current[4]; + result[4] = main_current[5] * main_current[5] - main_current[5]; + result[5] = main_current[6] * main_current[6] - main_current[6]; + result[6] = main_current[7] * main_current[7] - main_current[7]; + result[7] = main_current[8] * main_current[8] - main_current[8]; + result[8] = main_current[9] * main_current[9] - main_current[9]; + result[9] = main_current[10] * main_current[10] - main_current[10]; + result[10] = periodic_values[0] * (main_current[1] - (main_current[3] + E::from(Felt::new(2_u64)) * main_current[4] + E::from(Felt::new(4_u64)) * main_current[5] + E::from(Felt::new(8_u64)) * main_current[6])); + result[11] = periodic_values[0] * (main_current[2] - (main_current[7] + E::from(Felt::new(2_u64)) * main_current[8] + E::from(Felt::new(4_u64)) * main_current[9] + E::from(Felt::new(8_u64)) * main_current[10])); + result[12] = periodic_values[1] * (main_next[1] - (main_current[1] * E::from(Felt::new(16_u64)) + main_current[3] + E::from(Felt::new(2_u64)) * main_current[4] + E::from(Felt::new(4_u64)) * main_current[5] + E::from(Felt::new(8_u64)) * main_current[6])); + result[13] = periodic_values[1] * (main_next[2] - (main_current[2] * E::from(Felt::new(16_u64)) + main_current[7] + E::from(Felt::new(2_u64)) * main_current[8] + E::from(Felt::new(4_u64)) * main_current[9] + E::from(Felt::new(8_u64)) * main_current[10])); + result[14] = periodic_values[0] * main_current[11]; + result[15] = periodic_values[1] * (main_current[12] - main_next[11]); + result[16] = (E::ONE - main_current[0]) * (main_current[12] - (main_current[11] * E::from(Felt::new(16_u64)) + main_current[3] * main_current[7] + E::from(Felt::new(2_u64)) * main_current[4] * main_current[8] + E::from(Felt::new(4_u64)) * main_current[5] * main_current[9] + E::from(Felt::new(8_u64)) * main_current[6] * main_current[10])) + main_current[0] * (main_current[12] - (main_current[11] * E::from(Felt::new(16_u64)) + main_current[3] + main_current[7] - E::from(Felt::new(2_u64)) * main_current[3] * main_current[7] + E::from(Felt::new(2_u64)) * (main_current[4] + main_current[8] - E::from(Felt::new(2_u64)) * main_current[4] * main_current[8]) + E::from(Felt::new(4_u64)) * (main_current[5] + main_current[9] - E::from(Felt::new(2_u64)) * main_current[5] * main_current[9]) + E::from(Felt::new(8_u64)) * (main_current[6] + main_current[10] - E::from(Felt::new(2_u64)) * main_current[6] * main_current[10]))); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/buses/buses_complex.rs b/air-script/tests/buses/buses_complex.rs index e822dcc45..baf58e6da 100644 --- a/air-script/tests/buses/buses_complex.rs +++ b/air-script/tests/buses/buses_complex.rs @@ -98,7 +98,7 @@ impl Air for BusesAir { let main_next = main_frame.next(); let aux_current = aux_frame.current(); let aux_next = aux_frame.next(); - result[0] = ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + (E::ZERO + E::ZERO + E::ONE + E::from(Felt::new(2_u64)) + E::from(main_current[1])) * aux_rand_elements.rand_elements()[2] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[3]) * E::from(main_current[2]) + E::ONE - E::from(main_current[2])) * ((aux_rand_elements.rand_elements()[0] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * (E::ONE - E::from(main_current[2])) + E::ONE - (E::ONE - E::from(main_current[2]))) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + (E::ZERO + E::ZERO + E::ONE + E::from(Felt::new(2_u64)) + E::from(main_current[1])) * aux_rand_elements.rand_elements()[2] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[3]) * E::from(main_current[3]) + E::ONE - E::from(main_current[3])) * ((aux_rand_elements.rand_elements()[0] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (E::ONE - E::from(main_current[3])) + E::ONE - (E::ONE - E::from(main_current[3]))) * aux_next[0]; + result[0] = ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + (E::from(Felt::new(3_u64)) + E::from(main_current[1])) * aux_rand_elements.rand_elements()[2] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[3]) * E::from(main_current[2]) + E::ONE - E::from(main_current[2])) * ((aux_rand_elements.rand_elements()[0] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * (E::ONE - E::from(main_current[2])) + E::ONE - (E::ONE - E::from(main_current[2]))) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + (E::from(Felt::new(3_u64)) + E::from(main_current[1])) * aux_rand_elements.rand_elements()[2] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[3]) * E::from(main_current[3]) + E::ONE - E::from(main_current[3])) * ((aux_rand_elements.rand_elements()[0] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (E::ONE - E::from(main_current[3])) + E::ONE - (E::ONE - E::from(main_current[3]))) * aux_next[0]; result[1] = (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(4_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * aux_current[1] + (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(4_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[2]) + (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(4_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[2]) - ((aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(4_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * aux_next[1] + (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[4])); } } \ No newline at end of file diff --git a/air-script/tests/constant_in_range/constant_in_range.rs b/air-script/tests/constant_in_range/constant_in_range.rs index aff438e9f..2c53ec6ff 100644 --- a/air-script/tests/constant_in_range/constant_in_range.rs +++ b/air-script/tests/constant_in_range/constant_in_range.rs @@ -82,7 +82,7 @@ impl Air for ConstantInRangeAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_current[0] - (E::ZERO + main_current[1] - main_current[4] - main_current[8] + E::ONE + main_current[2] - main_current[5] - main_current[9] + E::from(Felt::new(2_u64)) + main_current[3] - main_current[6] - main_current[10]); + result[0] = main_current[0] - (main_current[1] - main_current[4] - main_current[8] + E::ONE + main_current[2] - main_current[5] - main_current[9] + E::from(Felt::new(2_u64)) + main_current[3] - main_current[6] - main_current[10]); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/constants/constants.rs b/air-script/tests/constants/constants.rs index 6dbae1db0..9c228099a 100644 --- a/air-script/tests/constants/constants.rs +++ b/air-script/tests/constants/constants.rs @@ -85,7 +85,7 @@ impl Air for ConstantsAir { result.push(Assertion::single(0, 0, Felt::ONE)); result.push(Assertion::single(1, 0, Felt::ONE)); result.push(Assertion::single(2, 0, Felt::ZERO)); - result.push(Assertion::single(3, 0, Felt::ONE - Felt::new(2) + Felt::new(2) - Felt::ZERO)); + result.push(Assertion::single(3, 0, Felt::ONE - Felt::new(2) + Felt::new(2))); result.push(Assertion::single(4, 0, Felt::ONE)); result.push(Assertion::single(6, self.last_step(), Felt::ZERO)); result @@ -100,9 +100,9 @@ impl Air for ConstantsAir { let main_current = frame.current(); let main_next = frame.next(); result[0] = main_next[0] - (main_current[0] + E::ONE); - result[1] = main_next[1] - E::ZERO * main_current[1]; - result[2] = main_next[2] - E::ONE * main_current[2]; - result[3] = main_next[5] - (main_current[5] + E::ONE + E::ZERO); + result[1] = main_next[1]; + result[2] = main_next[2] - main_current[2]; + result[3] = main_next[5] - (main_current[5] + E::ONE); result[4] = main_current[4] - E::ONE; } diff --git a/air-script/tests/functions/functions_complex.rs b/air-script/tests/functions/functions_complex.rs index f9ba361dd..f165c2d0e 100644 --- a/air-script/tests/functions/functions_complex.rs +++ b/air-script/tests/functions/functions_complex.rs @@ -83,7 +83,7 @@ impl Air for FunctionsAir { let main_current = frame.current(); let main_next = frame.next(); result[0] = main_next[16] - main_current[16] * ((main_current[3] * main_current[3] * main_current[3] * main_current[3] * main_current[3] * main_current[3] * main_current[3] * main_current[1] * main_current[2] + main_current[3] * main_current[3] * (E::ONE - main_current[1]) * main_current[2] + main_current[3] * main_current[1] * (E::ONE - main_current[2]) + (E::ONE - main_current[1]) * (E::ONE - main_current[2])) * main_current[0] - main_current[0] + E::ONE); - result[1] = main_next[3] - (E::ZERO + main_current[4] + main_current[5] + main_current[6] + main_current[7] + main_current[8] + main_current[9] + main_current[10] + main_current[11] + main_current[12] + main_current[13] + main_current[14] + main_current[15] + E::ONE) * E::from(Felt::new(2_u64)); + result[1] = main_next[3] - (main_current[4] + main_current[5] + main_current[6] + main_current[7] + main_current[8] + main_current[9] + main_current[10] + main_current[11] + main_current[12] + main_current[13] + main_current[14] + main_current[15] + E::ONE) * E::from(Felt::new(2_u64)); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/list_comprehension/list_comprehension.rs b/air-script/tests/list_comprehension/list_comprehension.rs index b8e8ecf69..561032260 100644 --- a/air-script/tests/list_comprehension/list_comprehension.rs +++ b/air-script/tests/list_comprehension/list_comprehension.rs @@ -83,10 +83,10 @@ impl Air for ListComprehensionAir { let main_current = frame.current(); let main_next = frame.next(); result[0] = main_current[0] - main_current[2]; - result[1] = main_current[4] - main_current[0] * E::from(Felt::new(2_u64)) * E::from(Felt::new(2_u64)) * E::from(Felt::new(2_u64)) * main_current[11]; + result[1] = main_current[4] - main_current[0] * E::from(Felt::new(8_u64)) * main_current[11]; result[2] = main_current[4] - main_current[0] * (main_next[8] - main_next[12]); result[3] = main_current[6] - main_current[0] * (main_current[9] - main_current[14]); - result[4] = main_current[1] - (E::ZERO + main_current[5] - main_current[8] - main_current[12] + E::ONE + main_current[6] - main_current[9] - main_current[13] + E::from(Felt::new(2_u64)) + main_current[7] - main_current[10] - main_current[14]); + result[4] = main_current[1] - (main_current[5] - main_current[8] - main_current[12] + E::ONE + main_current[6] - main_current[9] - main_current[13] + E::from(Felt::new(2_u64)) + main_current[7] - main_current[10] - main_current[14]); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/list_comprehension/list_comprehension_nested.rs b/air-script/tests/list_comprehension/list_comprehension_nested.rs index ab9703069..dfdb535f9 100644 --- a/air-script/tests/list_comprehension/list_comprehension_nested.rs +++ b/air-script/tests/list_comprehension/list_comprehension_nested.rs @@ -82,9 +82,9 @@ impl Air for ListComprehensionAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = E::ZERO + main_current[0] * E::ONE + main_current[1] * E::from(Felt::new(2_u64)) - E::from(Felt::new(3_u64)); - result[1] = E::ZERO + main_current[0] * E::from(Felt::new(2_u64)) + main_current[1] * E::from(Felt::new(3_u64)) - E::from(Felt::new(5_u64)); - result[2] = E::ZERO + main_current[0] * E::from(Felt::new(3_u64)) + main_current[1] * E::from(Felt::new(4_u64)) - E::from(Felt::new(7_u64)); + result[0] = main_current[0] + main_current[1] * E::from(Felt::new(2_u64)) - E::from(Felt::new(3_u64)); + result[1] = main_current[0] * E::from(Felt::new(2_u64)) + main_current[1] * E::from(Felt::new(3_u64)) - E::from(Felt::new(5_u64)); + result[2] = main_current[0] * E::from(Felt::new(3_u64)) + main_current[1] * E::from(Felt::new(4_u64)) - E::from(Felt::new(7_u64)); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/list_folding/list_folding.rs b/air-script/tests/list_folding/list_folding.rs index 3ad7e54ea..d94090cfa 100644 --- a/air-script/tests/list_folding/list_folding.rs +++ b/air-script/tests/list_folding/list_folding.rs @@ -82,10 +82,10 @@ impl Air for ListFoldingAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_next[5] - (E::ZERO + main_current[9] + main_current[10] + main_current[11] + main_current[12] + E::ONE * main_current[13] * main_current[14] * main_current[15] * main_current[16]); - result[1] = main_next[6] - (E::ZERO + main_current[9] + main_current[10] + main_current[11] + main_current[12] + E::ONE * main_current[13] * main_current[14] * main_current[15] * main_current[16]); - result[2] = main_next[7] - (E::ZERO + main_current[9] * main_current[13] + main_current[10] * main_current[14] + main_current[11] * main_current[15] + main_current[12] * main_current[16] + E::ONE * (main_current[9] + main_current[13]) * (main_current[10] + main_current[14]) * (main_current[11] + main_current[15]) * (main_current[12] + main_current[16])); - result[3] = main_next[8] - (main_current[1] + E::ZERO + main_current[9] * main_current[13] + main_current[10] * main_current[14] + main_current[11] * main_current[15] + main_current[12] * main_current[16] + E::ZERO + main_current[9] * main_current[13] + main_current[10] * main_current[14] + main_current[11] * main_current[15] + main_current[12] * main_current[16]); + result[0] = main_next[5] - (main_current[9] + main_current[10] + main_current[11] + main_current[12] + main_current[13] * main_current[14] * main_current[15] * main_current[16]); + result[1] = main_next[6] - (main_current[9] + main_current[10] + main_current[11] + main_current[12] + main_current[13] * main_current[14] * main_current[15] * main_current[16]); + result[2] = main_next[7] - (main_current[9] * main_current[13] + main_current[10] * main_current[14] + main_current[11] * main_current[15] + main_current[12] * main_current[16] + (main_current[9] + main_current[13]) * (main_current[10] + main_current[14]) * (main_current[11] + main_current[15]) * (main_current[12] + main_current[16])); + result[3] = main_next[8] - (main_current[1] + main_current[9] * main_current[13] + main_current[10] * main_current[14] + main_current[11] * main_current[15] + main_current[12] * main_current[16] + main_current[9] * main_current[13] + main_current[10] * main_current[14] + main_current[11] * main_current[15] + main_current[12] * main_current[16]); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/periodic_columns/periodic_columns.rs b/air-script/tests/periodic_columns/periodic_columns.rs index eb1c592da..db71fbe06 100644 --- a/air-script/tests/periodic_columns/periodic_columns.rs +++ b/air-script/tests/periodic_columns/periodic_columns.rs @@ -82,8 +82,8 @@ impl Air for PeriodicColumnsAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = periodic_values[0] * (main_current[1] + main_current[2]) - E::ZERO; - result[1] = periodic_values[1] * (main_next[0] - main_current[0]) - E::ZERO; + result[0] = periodic_values[0] * (main_current[1] + main_current[2]); + result[1] = periodic_values[1] * (main_next[0] - main_current[0]); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/selectors/selectors.rs b/air-script/tests/selectors/selectors.rs index 796ff2ba0..6bbdee036 100644 --- a/air-script/tests/selectors/selectors.rs +++ b/air-script/tests/selectors/selectors.rs @@ -82,8 +82,8 @@ impl Air for SelectorsAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_current[0] * (E::ONE - main_current[1]) * (main_next[3] - E::ZERO) - E::ZERO; - result[1] = main_current[0] * main_current[1] * main_current[2] * (main_next[3] - main_current[3]) + (E::ONE - main_current[1]) * (E::ONE - main_current[2]) * (main_next[3] - E::ONE) - E::ZERO; + result[0] = main_current[0] * (E::ONE - main_current[1]) * main_next[3]; + result[1] = main_current[0] * main_current[1] * main_current[2] * (main_next[3] - main_current[3]) + (E::ONE - main_current[1]) * (E::ONE - main_current[2]) * (main_next[3] - E::ONE); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/selectors/selectors_combine_complex.rs b/air-script/tests/selectors/selectors_combine_complex.rs index bbc8fdf59..90b1cb4fe 100644 --- a/air-script/tests/selectors/selectors_combine_complex.rs +++ b/air-script/tests/selectors/selectors_combine_complex.rs @@ -84,9 +84,9 @@ impl Air for SelectorsAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = (main_current[0] + (E::ONE - main_current[0]) * main_current[1]) * (main_current[3] - E::from(Felt::new(2_u64)) * E::from(Felt::new(8_u64))) + (E::ONE - main_current[0]) * (E::ONE - main_current[1]) * (main_current[4] - E::from(Felt::new(5_u64))) - E::ZERO; - result[1] = ((E::ONE - main_current[0]) * main_current[1] + (E::ONE - main_current[0]) * (E::ONE - main_current[1])) * (main_current[5] - E::from(Felt::new(5_u64))) + main_current[0] * (main_current[4] - E::from(Felt::new(4_u64))) - E::ZERO; - result[2] = main_current[0] * (main_current[5] - E::from(Felt::new(20_u64))) + (E::ONE - main_current[0]) * main_current[1] * (main_current[4] - E::from(Felt::new(31_u64))) - E::ZERO; + result[0] = (main_current[0] + (E::ONE - main_current[0]) * main_current[1]) * (main_current[3] - E::from(Felt::new(16_u64))) + (E::ONE - main_current[0]) * (E::ONE - main_current[1]) * (main_current[4] - E::from(Felt::new(5_u64))); + result[1] = ((E::ONE - main_current[0]) * main_current[1] + (E::ONE - main_current[0]) * (E::ONE - main_current[1])) * (main_current[5] - E::from(Felt::new(5_u64))) + main_current[0] * (main_current[4] - E::from(Felt::new(4_u64))); + result[2] = main_current[0] * (main_current[5] - E::from(Felt::new(20_u64))) + (E::ONE - main_current[0]) * main_current[1] * (main_current[4] - E::from(Felt::new(31_u64))); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/selectors/selectors_combine_simple.rs b/air-script/tests/selectors/selectors_combine_simple.rs index 21756f394..5b1a06bcd 100644 --- a/air-script/tests/selectors/selectors_combine_simple.rs +++ b/air-script/tests/selectors/selectors_combine_simple.rs @@ -82,8 +82,8 @@ impl Air for SelectorsAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = (main_current[3] + E::ONE - main_current[3]) * (main_next[1] - main_current[2]) - E::ZERO; - result[1] = main_current[3] * (main_next[0] - (main_current[0] + main_current[1])) + (E::ONE - main_current[3]) * (main_next[0] - main_current[0] * main_current[1]) - E::ZERO; + result[0] = (main_current[3] + E::ONE - main_current[3]) * (main_next[1] - main_current[2]); + result[1] = main_current[3] * (main_next[0] - (main_current[0] + main_current[1])) + (E::ONE - main_current[3]) * (main_next[0] - main_current[0] * main_current[1]); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/selectors/selectors_with_evaluators.rs b/air-script/tests/selectors/selectors_with_evaluators.rs index d4466503a..4902ac774 100644 --- a/air-script/tests/selectors/selectors_with_evaluators.rs +++ b/air-script/tests/selectors/selectors_with_evaluators.rs @@ -82,8 +82,8 @@ impl Air for SelectorsAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_current[0] * (E::ONE - main_current[1]) * (main_next[3] - E::ZERO) - E::ZERO; - result[1] = main_current[1] * main_current[2] * (main_current[0] * (main_next[3] - main_current[3]) - E::ZERO) + (E::ONE - main_current[1]) * (E::ONE - main_current[2]) * (main_next[3] - E::ONE) - E::ZERO; + result[0] = main_current[0] * (E::ONE - main_current[1]) * main_next[3]; + result[1] = main_current[1] * main_current[2] * main_current[0] * (main_next[3] - main_current[3]) + (E::ONE - main_current[1]) * (E::ONE - main_current[2]) * (main_next[3] - E::ONE); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/variables/variables.rs b/air-script/tests/variables/variables.rs index 22127d0b8..9acea8ee2 100644 --- a/air-script/tests/variables/variables.rs +++ b/air-script/tests/variables/variables.rs @@ -88,9 +88,9 @@ impl Air for VariablesAir { let main_current = frame.current(); let main_next = frame.next(); result[0] = main_current[0] * main_current[0] - main_current[0]; - result[1] = periodic_values[0] * (main_next[0] - main_current[0]) - E::ZERO; + result[1] = periodic_values[0] * (main_next[0] - main_current[0]); result[2] = (E::ONE - main_current[0]) * (main_current[3] - main_current[1] - main_current[2]) - (E::from(Felt::new(6_u64)) - (E::from(Felt::new(7_u64)) - main_current[0])); - result[3] = main_current[0] * (main_current[3] - main_current[1] * main_current[2]) - (E::from(Felt::new(4_u64)) - E::from(Felt::new(3_u64)) - main_next[0]); + result[3] = main_current[0] * (main_current[3] - main_current[1] * main_current[2]) - (E::ONE - main_next[0]); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air/src/passes/translate_from_mir.rs b/air/src/passes/translate_from_mir.rs index 3477b4623..76735138b 100644 --- a/air/src/passes/translate_from_mir.rs +++ b/air/src/passes/translate_from_mir.rs @@ -6,7 +6,10 @@ use air_parser::{ }; use air_pass::Pass; use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; -use mir::ir::{ConstantValue, Link, Mir, MirValue, Op, Parent, SpannedMirValue}; +use mir::ir::{ + Boundary as MirBoundary, ConstantValue, Link, Mir, MirValue, Op, Parent, SpannedMirValue, + TraceAccess as MirTraceAccess, +}; use crate::{CompileError, graph::NodeIndex, ir::*}; @@ -400,72 +403,9 @@ impl AirBuilder<'_> { let boundary = lhs.as_boundary().unwrap().clone(); - let expected_trace_access_expr = boundary.expr.clone(); - let Op::Value(value) = expected_trace_access_expr.borrow().deref().clone() else { - unreachable!(); // Raise diag - }; - - let (trace_access, _) = match value.value.clone() { - SpannedMirValue { - value: MirValue::TraceAccess(trace_access), - span: lhs_span, - } => (trace_access, lhs_span), - SpannedMirValue { - value: MirValue::TraceAccessBinding(trace_access_binding), - span: lhs_span, - } => { - if trace_access_binding.size != 1 { - self.diagnostics.diagnostic(Severity::Error) - .with_message("invalid boundary constraint") - .with_primary_label(lhs_span, "this has a trace access binding with a size greater than 1") - .with_note("Boundary constraints require both sides of the constraint to be single columns.") - .emit(); - return Err(CompileError::Failed); - } - let trace_access = mir::ir::TraceAccess { - segment: trace_access_binding.segment, - column: trace_access_binding.offset, - row_offset: 0, - }; - (trace_access, lhs_span) - }, - SpannedMirValue { - value: MirValue::BusAccess(bus_access), - span: lhs_span, - } => { - let bus = bus_access.bus; - let name = bus.borrow().deref().name(); - let column = self.bus_bindings_map.get(&name).unwrap(); - let trace_access = mir::ir::TraceAccess::new( - TraceSegmentId::Aux, - *column, - bus_access.row_offset, - ); - (trace_access, lhs_span) - }, - _ => unreachable!( - "Expected TraceAccess or BusAccess, received {:?}", - value.value - ), // Raise diag - }; + let trace_access = self.extract_trace_from_boundary(boundary.clone())?; - if let Some(prev) = self - .trace_columns - .get_mut(&trace_access.segment) - .expect("Boundary constraint on an unknown trace segment") - .mark_constrained(lhs_span, trace_access.column, boundary.kind) - { - self.diagnostics - .diagnostic(Severity::Error) - .with_message("overlapping boundary constraints") - .with_primary_label( - lhs_span, - "this constrains a column and boundary that has already been constrained", - ) - .with_secondary_label(prev, "previous constraint occurs here") - .emit(); - return Err(CompileError::Failed); - } + self.mark_constrained_boundary(trace_access, &boundary)?; let lhs = self.air.constraint_graph_mut().insert_node(Operation::Value( crate::ir::Value::TraceAccess(crate::ir::TraceAccess { @@ -514,6 +454,25 @@ impl AirBuilder<'_> { self.air.constraints.insert_constraint(trace_access.segment, root, domain); Ok(()) }, + Op::Boundary(boundary) => { + let trace_access = self.extract_trace_from_boundary(boundary.clone())?; + + self.mark_constrained_boundary(trace_access, boundary)?; + + let root = self.air.constraint_graph_mut().insert_node(Operation::Value( + crate::ir::Value::TraceAccess(crate::ir::TraceAccess { + segment: trace_access.segment, + column: trace_access.column, + row_offset: trace_access.row_offset, + }), + )); + + let domain = boundary.kind.into(); + + // Store the generated constraint + self.air.constraints.insert_constraint(trace_access.segment, root, domain); + Ok(()) + }, _ => unreachable!(), } } @@ -551,7 +510,14 @@ impl AirBuilder<'_> { bus.borrow_mut().latches.push(latch.clone()); bus.borrow_mut().columns.push(child_op.clone()); }, - _ => unreachable!("Enforced with unexpected operation: {:?}", child_op), + _ => { + let root = self.insert_mir_operation(&child_op)?; + let (trace_segment, domain) = self + .air + .constraint_graph() + .node_details(&root, ConstraintDomain::EveryRow)?; + self.air.constraints.insert_constraint(trace_segment, root, domain); + }, } }, Op::Sub(sub) => { @@ -564,7 +530,7 @@ impl AirBuilder<'_> { self.air.constraint_graph().node_details(&root, ConstraintDomain::EveryRow)?; self.air.constraints.insert_constraint(trace_segment, root, domain); }, - _ => unreachable!(), + _ => unreachable!("Unexpected integrity constraint root: {:?}", ic), } Ok(()) } @@ -604,6 +570,89 @@ impl AirBuilder<'_> { fn insert_op(&mut self, op: Operation) -> NodeIndex { self.air.constraint_graph_mut().insert_node(op) } + + /// Extracts the trace access information from a given [Mir] `Boundary`. + /// Returns a [Mir] `TraceAccess` with the corresponding segment id and column if the boundary + /// wraps a valid trace access column, or raises a diagnostic if the trace access has a size + /// greater than 1. + /// + /// Note: the boundary expression must only reference the constrained trace access, not the + /// whole boundary constraint expression. + /// + /// # Panics + /// Panics if the boundary does not wrap a trace access column, which should have been caught + /// during semantic analysis. + fn extract_trace_from_boundary( + &self, + boundary: MirBoundary, + ) -> Result { + let Op::Value(value) = boundary.expr.borrow().deref().clone() else { + unreachable!(); // Raise diag + }; + + let trace_access = match value.value.clone() { + SpannedMirValue { + value: MirValue::TraceAccess(trace_access), + .. + } => trace_access, + SpannedMirValue { + value: MirValue::TraceAccessBinding(trace_access_binding), + span, + } => { + if trace_access_binding.size != 1 { + self.diagnostics.diagnostic(Severity::Error) + .with_message("invalid boundary constraint") + .with_primary_label(span, "this has a trace access binding with a size greater than 1") + .with_note("Boundary constraints require both sides of the constraint to be single columns.") + .emit(); + return Err(CompileError::Failed); + } + MirTraceAccess { + segment: trace_access_binding.segment, + column: trace_access_binding.offset, + row_offset: 0, + } + }, + SpannedMirValue { + value: MirValue::BusAccess(bus_access), .. + } => { + let bus = bus_access.bus; + let name = bus.borrow().deref().name(); + let column = self.bus_bindings_map.get(&name).unwrap(); + MirTraceAccess::new(TraceSegmentId::Aux, *column, bus_access.row_offset) + }, + _ => unreachable!("Expected TraceAccess or BusAccess, received {:?}", value.value), /* Raise diag */ + }; + + Ok(trace_access) + } + + /// Marks a boundary as constrained by the given trace access information. + /// This is used to ensure that we do not insert duplicate boundary constraints in the graph. + fn mark_constrained_boundary( + &mut self, + trace_access: MirTraceAccess, + boundary: &MirBoundary, + ) -> Result<(), CompileError> { + if let Some(prev) = self + .trace_columns + .get_mut(&trace_access.segment) + .expect("Boundary constraint on an unknown trace segment") + .mark_constrained(boundary.span(), trace_access.column, boundary.kind) + { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("overlapping boundary constraints") + .with_primary_label( + boundary.span(), + "this constrains a column and boundary that has already been constrained", + ) + .with_secondary_label(prev, "previous constraint occurs here") + .emit(); + return Err(CompileError::Failed); + } + Ok(()) + } } // HELPERS FUNCTIONS diff --git a/codegen/ace/tests/regressions/SimpleIntegrityAux.dot b/codegen/ace/tests/regressions/SimpleIntegrityAux.dot index 7408f033f..af7a23064 100644 --- a/codegen/ace/tests/regressions/SimpleIntegrityAux.dot +++ b/codegen/ace/tests/regressions/SimpleIntegrityAux.dot @@ -1,6 +1,6 @@ digraph G { -const0 [label="0"] -const1 [label="1"] +const0 [label="1"] +const1 [label="0"] input0 [label="PI[stack_inputs][0]"] input8 [label="α"] input9 [label="β"] @@ -12,18 +12,18 @@ input36 [label="zⁿ"] input37 [label="g⁻¹"] input38 [label="zᵐᵃˣ"] input39 [label="g⁻²"] -op0 [label="op0\ninput35 - const1"] +op0 [label="op0\ninput35 - const0"] input35 -> op0 -const1 -> op0 +const0 -> op0 op1 [label="op1\ninput35 - input39"] input35 -> op1 input39 -> op1 op2 [label="op2\ninput35 - input37"] input35 -> op2 input37 -> op2 -op3 [label="op3\ninput36 - const1"] +op3 [label="op3\ninput36 - const0"] input36 -> op3 -const1 -> op3 +const0 -> op3 op4 [label="op4\nop0 × op1"] op0 -> op4 op1 -> op4 diff --git a/codegen/winterfell/src/air/boundary_constraints.rs b/codegen/winterfell/src/air/boundary_constraints.rs index d3909249f..fd77f030c 100644 --- a/codegen/winterfell/src/air/boundary_constraints.rs +++ b/codegen/winterfell/src/air/boundary_constraints.rs @@ -55,7 +55,10 @@ fn add_main_trace_assertions(func_body: &mut codegen::Function, ir: &Air) { split_boundary_constraint(ir.constraint_graph(), constraint.node_index()); debug_assert_eq!(trace_access.segment, TraceSegmentId::Main); - let expr_root_string = expr_root.to_string(ir, ElemType::Base, TraceSegmentId::Main); + let expr_root_string = match expr_root { + Some(node_index) => node_index.to_string(ir, ElemType::Base, TraceSegmentId::Main), + None => "Felt::ZERO".to_string(), // If no root, the expression is zero + }; let assertion = format!( "result.push(Assertion::single({}, {}, {}));", @@ -101,7 +104,10 @@ fn add_aux_trace_assertions(func_body: &mut codegen::Function, ir: &Air) { split_boundary_constraint(ir.constraint_graph(), constraint.node_index()); debug_assert_eq!(trace_access.segment, TraceSegmentId::Aux); - let expr_root_string = expr_root.to_string(ir, ElemType::Ext, TraceSegmentId::Aux); + let expr_root_string = match expr_root { + Some(node_index) => node_index.to_string(ir, ElemType::Ext, TraceSegmentId::Aux), + None => "E::ZERO".to_string(), // If no root, the expression is zero + }; let assertion = format!( "result.push(Assertion::single({}, {}, {}));", @@ -132,24 +138,31 @@ fn domain_to_str(domain: ConstraintDomain) -> String { /// boundary constraint expression must hold, as well as the node index that represents the root /// of the constraint expression that must equal zero during evaluation. /// -/// TODO: replace panics with Result and Error +/// Note: If, after the CSE pass, the boundary constraint is a single trace access, +/// we return None for the constraint expression. This expression should then be assumed to be zero +/// during evaluation by the caller. pub fn split_boundary_constraint( graph: &AlgebraicGraph, index: &NodeIndex, -) -> (TraceAccess, NodeIndex) { +) -> (TraceAccess, Option) { let node = graph.node(index); - match node.op() { + match *node.op() { Operation::Sub(lhs, rhs) => { - if let Operation::Value(air_ir::Value::TraceAccess(trace_access)) = graph.node(lhs).op() + if let Operation::Value(air_ir::Value::TraceAccess(trace_access)) = + graph.node(&lhs).op() { debug_assert_eq!(trace_access.row_offset, 0); - (*trace_access, *rhs) + (*trace_access, Some(rhs)) } else { panic!( "InvalidUsage: index {index:?} is not the constraint root of a boundary constraint" ); } }, + Operation::Value(air_ir::Value::TraceAccess(trace_access)) => { + debug_assert_eq!(trace_access.row_offset, 0); + (trace_access, None) + }, _ => panic!("InvalidUsage: index {index:?} is not the root index of a constraint"), } } diff --git a/mir/derive-ir/src/lib.rs b/mir/derive-ir/src/lib.rs index 3bb4642a4..090241f8a 100644 --- a/mir/derive-ir/src/lib.rs +++ b/mir/derive-ir/src/lib.rs @@ -71,8 +71,8 @@ use syn::{DeriveInput, parse_macro_input}; /// ``` /// /// We then generate an implementation for each state, with a method for each field., which -/// transitions to the next state. The `enum_wrapper`` attribute is used to automatically wrap the -/// struct in a `Link`` to reduce boilerplate. The only supported `enum_wrapper`s are +/// transitions to the next state. The `enum_wrapper` attribute is used to automatically wrap the +/// struct in a `Link` to reduce boilerplate. The only supported `enum_wrapper`s are /// `Op` and `Root`. /// /// The following API is generated: diff --git a/mir/src/ir/node.rs b/mir/src/ir/node.rs index cd433ffbb..9bf47a696 100644 --- a/mir/src/ir/node.rs +++ b/mir/src/ir/node.rs @@ -229,7 +229,8 @@ impl Link { Root::None(span) => Node::None(*span), }; } else { - unreachable!(); + // If the [Node] is stale, we set it to None + to_update = Node::None(self.span()); } *self.borrow_mut() = to_update; diff --git a/mir/src/ir/owner.rs b/mir/src/ir/owner.rs index e8f76080b..d1a2c513b 100644 --- a/mir/src/ir/owner.rs +++ b/mir/src/ir/owner.rs @@ -208,7 +208,8 @@ impl Link { Root::None(span) => Owner::None(*span), }; } else { - unreachable!(); + // If the [Owner] is stale, we set it to None + to_update = Owner::None(self.span()); } *self.borrow_mut() = to_update; diff --git a/mir/src/lib.rs b/mir/src/lib.rs index e2a3443d8..450b33a6a 100644 --- a/mir/src/lib.rs +++ b/mir/src/lib.rs @@ -31,7 +31,8 @@ impl Pass for MirPasses<'_> { fn run<'a>(&mut self, input: Self::Input<'a>) -> Result, Self::Error> { let mut passes = passes::AstToMir::new(self.diagnostics) .chain(passes::Inlining::new(self.diagnostics)) - .chain(passes::Unrolling::new(self.diagnostics)); + .chain(passes::Unrolling::new(self.diagnostics)) + .chain(passes::ConstantPropagation::new(self.diagnostics)); passes.run(input) } } diff --git a/mir/src/passes/constant_propagation.rs b/mir/src/passes/constant_propagation.rs index 4744bf4bf..139eeab65 100644 --- a/mir/src/passes/constant_propagation.rs +++ b/mir/src/passes/constant_propagation.rs @@ -1,18 +1,18 @@ +use std::ops::Deref; + +use air_parser::ast::AccessType; use air_pass::Pass; -use miden_diagnostics::DiagnosticsHandler; +use miden_diagnostics::{DiagnosticsHandler, SourceSpan, Spanned}; use super::visitor::Visitor; use crate::{ - ir::{Link, Mir, Node}, CompileError, + ir::{ + BackLink, ConstantValue, Graph, Link, Mir, MirValue, Node, Op, Parent, SpannedMirValue, + Value, + }, }; -/// TODO MIR: -/// If needed, implement constant propagation / folding pass on MIR -/// Run through every operation in the graph -/// If we can deduce the resulting value based on the constants of the operands, -/// replace the operation itself with a constant -/// pub struct ConstantPropagation<'a> { #[allow(unused)] diagnostics: &'a DiagnosticsHandler, @@ -31,11 +31,94 @@ impl Pass for ConstantPropagation<'_> { } impl<'a> ConstantPropagation<'a> { - #[allow(unused)] pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { - Self { - diagnostics, - work_stack: vec![], + Self { diagnostics, work_stack: vec![] } + } +} + +// For the ConstantPropagation, we use a tweaked version of the Visitor trait, +// each visit_*_bis function returns an Option> instead of Result<(), CompileError>, +// to mutate the nodes (e.g. modifying a Add(lhs, rhs) to Value(lhs + rhs)). +impl ConstantPropagation<'_> { + fn visit_add_bis( + &mut self, + _graph: &mut Graph, + add: Link, + ) -> Result>, CompileError> { + // safe to unwrap because we just dispatched on it + let add_ref = add.as_add().unwrap(); + let lhs = add_ref.lhs.clone(); + let rhs = add_ref.rhs.clone(); + + if let Some(0) = get_inner_const(&lhs) { + Ok(Some(rhs)) + } else if let Some(0) = get_inner_const(&rhs) { + Ok(Some(lhs)) + } else { + try_fold_const_binary_op(lhs, rhs, add.clone(), add_ref.span()) + } + } + + fn visit_sub_bis( + &mut self, + _graph: &mut Graph, + sub: Link, + ) -> Result>, CompileError> { + // safe to unwrap because we just dispatched on it + let sub_ref = sub.as_sub().unwrap(); + let lhs = sub_ref.lhs.clone(); + let rhs = sub_ref.rhs.clone(); + + if let Some(0) = get_inner_const(&rhs) { + Ok(Some(lhs)) + } else { + try_fold_const_binary_op(lhs, rhs, sub.clone(), sub_ref.span()) + } + } + + fn visit_mul_bis( + &mut self, + _graph: &mut Graph, + mul: Link, + ) -> Result>, CompileError> { + // safe to unwrap because we just dispatched on it + let mul_ref = mul.as_mul().unwrap(); + let lhs = mul_ref.lhs.clone(); + let rhs = mul_ref.rhs.clone(); + + match (get_inner_const(&lhs), get_inner_const(&rhs)) { + (Some(0), _) | (_, Some(0)) => Ok(Some(Value::create(SpannedMirValue { + value: MirValue::Constant(ConstantValue::Felt(0)), + span: mul_ref.span, + }))), + (Some(1), _) => Ok(Some(rhs)), + (_, Some(1)) => Ok(Some(lhs)), + _ => try_fold_const_binary_op(lhs, rhs, mul.clone(), mul_ref.span()), + } + } + + fn visit_exp_bis( + &mut self, + _graph: &mut Graph, + exp: Link, + ) -> Result>, CompileError> { + // safe to unwrap because we just dispatched on it + let exp_ref = exp.as_exp().unwrap(); + let lhs = exp_ref.lhs.clone(); + let rhs = exp_ref.rhs.clone(); + + if let Some(0) = get_inner_const(&lhs) { + Ok(Some(Value::create(SpannedMirValue { + value: MirValue::Constant(ConstantValue::Felt(0)), + span: exp_ref.span, + }))) + } else if let Some(0) = get_inner_const(&rhs) { + Ok(Some(Value::create(SpannedMirValue { + value: MirValue::Constant(ConstantValue::Felt(1)), + span: exp_ref.span, + }))) + } else { + try_fold_const_binary_op(lhs, rhs, exp.clone(), exp_ref.span()) } } } @@ -44,22 +127,160 @@ impl Visitor for ConstantPropagation<'_> { fn work_stack(&mut self) -> &mut Vec> { &mut self.work_stack } - fn root_nodes_to_visit( - &self, - graph: &crate::ir::Graph, - ) -> Vec> { + + // We visit all boundary constraints and all integrity constraints + // No need to visit the functions or evaluators, as they should have been inlined before this + // pass + fn root_nodes_to_visit(&self, graph: &Graph) -> Vec> { let boundary_constraints_roots_ref = graph.boundary_constraints_roots.borrow(); let integrity_constraints_roots_ref = graph.integrity_constraints_roots.borrow(); + let bus_roots: Vec<_> = graph + .buses + .values() + .flat_map(|b| b.borrow().clone().columns.into_iter().collect::>()) + .collect(); let combined_roots = boundary_constraints_roots_ref .clone() .into_iter() .map(|bc| bc.as_node()) - .chain( - integrity_constraints_roots_ref - .clone() - .into_iter() - .map(|ic| ic.as_node()), - ); + .chain(integrity_constraints_roots_ref.clone().into_iter().map(|ic| ic.as_node())) + .chain(bus_roots.into_iter().map(|b| b.as_node())); combined_roots.collect() } + + fn visit_node(&mut self, graph: &mut Graph, node: Link) -> Result<(), CompileError> { + if node.is_stale() { + return Ok(()); + } + + // In this pass, we both need to dispatch the visitor depending on the node type, + // and also mutate the node if needed. We implement custom visit_*_bis methods + // that returns a Some(updated_node) if we need to update the node's value. + let updated_op: Result>, CompileError> = match node.borrow().deref() { + Node::Add(a) => to_link_and(a.clone(), graph, |g, el| self.visit_add_bis(g, el)), + Node::Sub(s) => to_link_and(s.clone(), graph, |g, el| self.visit_sub_bis(g, el)), + Node::Mul(m) => to_link_and(m.clone(), graph, |g, el| self.visit_mul_bis(g, el)), + Node::Exp(e) => to_link_and(e.clone(), graph, |g, el| self.visit_exp_bis(g, el)), + // For all the following cases, there is nothing to fold + Node::Vector(_) + | Node::Matrix(_) + | Node::Enf(_) + | Node::Boundary(_) + | Node::BusOp(_) + | Node::Value(_) + | Node::Accessor(_) + | Node::None(_) => Ok(None), + Node::Function(_) | Node::Evaluator(_) | Node::Call(_) => { + unreachable!( + "Unexpected node during Mir's ConstantPropagation: Function, Evaluators and Calls should have been inlined before this pass. Found: {:?}", + node + ); + }, + Node::If(_) | Node::For(_) | Node::Fold(_) | Node::Parameter(_) => { + unreachable!( + "Unexpected node during Mir's ConstantPropagation: If, For, Fold and Parameter should have been unrolled before this pass. Found: {:?}", + node + ); + }, + }; + + // We update the node if needed + if let Some(updated_op) = updated_op? { + node.as_op().unwrap().set(&updated_op); + } + + Ok(()) + } +} + +// HELPERS FUNCTIONS +// ================================================================================================ + +/// Tries to upgrade a BackLink to a Link and apply a given closure to it if it is successful, +/// otherwise returns None. +fn to_link_and( + back: BackLink, + graph: &mut Graph, + f: F, +) -> Result>, CompileError> +where + F: FnOnce(&mut Graph, Link) -> Result>, CompileError>, +{ + if let Some(op) = back.to_link() { + f(graph, op) + } else { + Ok(None) + } +} + +/// Helper function to extract the constant felt value from a Link if it is one. +fn get_inner_const(value: &Link) -> Option { + match value.borrow().deref() { + Op::Value(Value { + value: + SpannedMirValue { + value: MirValue::Constant(ConstantValue::Felt(c)), + .. + }, + .. + }) => Some(*c), + Op::Accessor(accessor) => { + match (accessor.access_type.clone(), accessor.indexable.borrow().deref()) { + (AccessType::Default, _) => get_inner_const(&accessor.indexable), + (AccessType::Index(index), Op::Vector(vector)) => { + let vec_children = vector.children(); + let vec_ref = vec_children.borrow(); + vec_ref.get(index).and_then(get_inner_const) + }, + (AccessType::Matrix(row, col), Op::Matrix(matrix)) => { + let mat_children = matrix.children(); + let mat_ref = mat_children.borrow(); + mat_ref.get(row).and_then(|row| { + if let Op::Vector(row_vector) = row.borrow().deref() { + let row_children = row_vector.children(); + let row_ref = row_children.borrow(); + row_ref.get(col).and_then(get_inner_const) + } else { + None + } + }) + }, + _ => None, + } + }, + _ => None, + } +} + +/// Helper function to fold constant binary operations (Add, Sub, Mul, Exp) +/// into their resulting value if both operands are constant values. +fn try_fold_const_binary_op( + lhs: Link, + rhs: Link, + parent: Link, + span: SourceSpan, +) -> Result>, CompileError> { + let mut updated_binary_op = None; + + if let (Some(lhs_const), Some(rhs_const)) = (get_inner_const(&lhs), get_inner_const(&rhs)) { + let folded = match parent.borrow().deref() { + Op::Add(_) => lhs_const.checked_add(rhs_const), + Op::Sub(_) => lhs_const.checked_sub(rhs_const), + Op::Mul(_) => lhs_const.checked_mul(rhs_const), + Op::Exp(_) => { + let rhs_const = rhs_const.try_into().map_err(|_| CompileError::Failed)?; + lhs_const.checked_pow(rhs_const) + }, + _ => unreachable!("Unexpected parent operation: {:?}", parent), + }; + if let Some(folded) = folded { + let new_value = Value::create(SpannedMirValue { + value: MirValue::Constant(crate::ir::ConstantValue::Felt(folded)), + span, + }); + updated_binary_op = Some(new_value); + } + } + + Ok(updated_binary_op) } diff --git a/mir/src/passes/mod.rs b/mir/src/passes/mod.rs index cfc1c9126..ceeb09a6e 100644 --- a/mir/src/passes/mod.rs +++ b/mir/src/passes/mod.rs @@ -1,14 +1,11 @@ +mod constant_propagation; mod inlining; mod translate; mod unrolling; mod visitor; -// Note: ConstantPropagation and ValueNumbering are not implemented yet in the MIR -//mod constant_propagation; -//mod value_numbering; -//pub use constant_propagation::ConstantPropagation; -//pub use value_numbering::ValueNumbering; use std::{collections::HashMap, ops::Deref}; +pub use constant_propagation::ConstantPropagation; pub use inlining::Inlining; use miden_diagnostics::Spanned; pub use translate::AstToMir; diff --git a/mir/src/passes/unrolling.rs b/mir/src/passes/unrolling.rs index d6db74630..e118413c3 100644 --- a/mir/src/passes/unrolling.rs +++ b/mir/src/passes/unrolling.rs @@ -602,16 +602,8 @@ impl UnrollingFirstPass<'_> { }; } - // Note: This will ensure the resulting constraint is in the form `Sub(x,y)` - // representing `enf x = y` - let zero_node = Value::create(SpannedMirValue { - span: Default::default(), - value: MirValue::Constant(ConstantValue::Felt(0)), - }); // The following unwrap is safe as we always have at least one constraint above - let new_node_with_sub_zero = Sub::create(cur_node.unwrap(), zero_node, if_ref.span()); - - new_vec.push(new_node_with_sub_zero); + new_vec.push(cur_node.unwrap()); } // 4. Add all the constraints that are bus-related @@ -1076,48 +1068,26 @@ impl Visitor for UnrollingSecondPass<'_> { let new_node = self.nodes_to_replace.get(&body.get_ptr()).unwrap().1.clone(); // If there is a selector, we need to enforce it on the body - let new_node_with_selector_if_needed = if let Some(selector) = - self.for_inlining_context.clone().unwrap().selector - { - if let Op::Vector(new_node_vector) = new_node.borrow().deref() { - let new_node_vec = new_node_vector.children().borrow().deref().clone(); - let mut new_vec = vec![]; - for new_node_child in new_node_vec.into_iter() { - let zero_node = Value::create(SpannedMirValue { - span: Default::default(), - value: MirValue::Constant(ConstantValue::Felt(0)), - }); - // FIXME: The Sub here is used to keep the form of Eq(lhs, rhs) -> - // Enf(Sub(lhs, rhs) == 0), but it introduces an - // unnecessary zero node - let new_node_child_with_selector = Sub::create( - Mul::create( + let new_node_with_selector_if_needed = + if let Some(selector) = self.for_inlining_context.clone().unwrap().selector { + if let Op::Vector(new_node_vector) = new_node.borrow().deref() { + let new_node_vec = new_node_vector.children().borrow().deref().clone(); + let mut new_vec = vec![]; + for new_node_child in new_node_vec.into_iter() { + let new_node_child_with_selector = Mul::create( duplicate_node(selector.clone(), &mut HashMap::new()), new_node_child, root.span(), - ), - zero_node, - root.span(), - ); - new_vec.push(new_node_child_with_selector); + ); + new_vec.push(new_node_child_with_selector); + } + Vector::create(new_vec, root.span()) + } else { + Mul::create(selector, new_node, root.span()) } - Vector::create(new_vec, root.span()) } else { - let zero_node = Value::create(SpannedMirValue { - span: Default::default(), - value: MirValue::Constant(ConstantValue::Felt(0)), - }); - // FIXME: The Sub here is used to keep the form of Eq(lhs, rhs) -> Enf(Sub(lhs, - // rhs) == 0), but it introduces an unnecessary zero node - Sub::create( - Mul::create(selector, new_node, root.span()), - zero_node, - root.span(), - ) - } - } else { - new_node - }; + new_node + }; root.as_op().unwrap().set(&new_node_with_selector_if_needed); diff --git a/mir/src/passes/value_numbering.rs b/mir/src/passes/value_numbering.rs deleted file mode 100644 index 7fcb4838d..000000000 --- a/mir/src/passes/value_numbering.rs +++ /dev/null @@ -1,63 +0,0 @@ -use air_pass::Pass; -use miden_diagnostics::DiagnosticsHandler; - -use super::visitor::Visitor; -use crate::{ - ir::{Link, Mir, Node}, - CompileError, -}; - -/// TODO MIR: -/// If needed, implement value numbering pass on MIR -/// See https://en.wikipedia.org/wiki/Value_numbering -/// -pub struct ValueNumbering<'a> { - #[allow(unused)] - diagnostics: &'a DiagnosticsHandler, - work_stack: Vec>, -} - -impl Pass for ValueNumbering<'_> { - type Input<'a> = Mir; - type Output<'a> = Mir; - type Error = CompileError; - - fn run<'a>(&mut self, mut ir: Self::Input<'a>) -> Result, Self::Error> { - Visitor::run(self, ir.constraint_graph_mut())?; - Ok(ir) - } -} - -impl<'a> ValueNumbering<'a> { - #[allow(unused)] - pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { - Self { - diagnostics, - work_stack: vec![], - } - } -} - -impl Visitor for ValueNumbering<'_> { - fn work_stack(&mut self) -> &mut Vec> { - &mut self.work_stack - } - fn root_nodes_to_visit( - &self, - graph: &crate::ir::Graph, - ) -> Vec> { - let boundary_constraints_roots_ref = graph.boundary_constraints_roots.borrow(); - let integrity_constraints_roots_ref = graph.integrity_constraints_roots.borrow(); - let combined_roots = boundary_constraints_roots_ref - .clone() - .into_iter() - .map(|bc| bc.as_node()) - .chain( - integrity_constraints_roots_ref - .clone() - .into_iter() - .map(|ic| ic.as_node()), - ); - combined_roots.collect() - } -}