|
| 1 | +use alloc::vec::Vec; |
| 2 | +use lambdaworks_math::field::{ |
| 3 | + fields::mersenne31::field::MERSENNE_31_PRIME_FIELD_ORDER, traits::IsField, |
| 4 | +}; |
| 5 | +use sha3::{ |
| 6 | + digest::{ExtendableOutput, Update}, |
| 7 | + Shake128, Shake128Reader, |
| 8 | +}; |
| 9 | + |
| 10 | +mod utils; |
| 11 | +use utils::*; |
| 12 | + |
| 13 | +// Ported from https://github.com/Plonky3/Plonky3/blob/main/monolith |
| 14 | + |
| 15 | +pub const NUM_BARS: usize = 8; |
| 16 | +const MATRIX_CIRC_MDS_16_MERSENNE31_MONOLITH: [u32; 16] = [ |
| 17 | + 61402, 17845, 26798, 59689, 12021, 40901, 41351, 27521, 56951, 12034, 53865, 43244, 7454, |
| 18 | + 33823, 28750, 1108, |
| 19 | +]; |
| 20 | + |
| 21 | +pub struct MonolithMersenne31<const WIDTH: usize, const NUM_FULL_ROUNDS: usize> { |
| 22 | + round_constants: Vec<Vec<u32>>, |
| 23 | + lookup1: Vec<u16>, |
| 24 | + lookup2: Vec<u16>, |
| 25 | +} |
| 26 | + |
| 27 | +impl<const WIDTH: usize, const NUM_FULL_ROUNDS: usize> Default |
| 28 | + for MonolithMersenne31<WIDTH, NUM_FULL_ROUNDS> |
| 29 | +{ |
| 30 | + fn default() -> Self { |
| 31 | + Self::new() |
| 32 | + } |
| 33 | +} |
| 34 | + |
| 35 | +impl<const WIDTH: usize, const NUM_FULL_ROUNDS: usize> MonolithMersenne31<WIDTH, NUM_FULL_ROUNDS> { |
| 36 | + pub fn new() -> Self { |
| 37 | + assert!(WIDTH >= 8); |
| 38 | + assert!(WIDTH <= 24); |
| 39 | + assert!(WIDTH % 4 == 0); |
| 40 | + Self { |
| 41 | + round_constants: Self::instantiate_round_constants(), |
| 42 | + lookup1: Self::instantiate_lookup1(), |
| 43 | + lookup2: Self::instantiate_lookup2(), |
| 44 | + } |
| 45 | + } |
| 46 | + |
| 47 | + fn instantiate_round_constants() -> Vec<Vec<u32>> { |
| 48 | + let mut shake = Shake128::default(); |
| 49 | + shake.update("Monolith".as_bytes()); |
| 50 | + shake.update(&[WIDTH as u8, (NUM_FULL_ROUNDS + 1) as u8]); |
| 51 | + shake.update(&MERSENNE_31_PRIME_FIELD_ORDER.to_le_bytes()); |
| 52 | + shake.update(&[8, 8, 8, 7]); |
| 53 | + let mut shake_finalized = shake.finalize_xof(); |
| 54 | + random_matrix(&mut shake_finalized, NUM_FULL_ROUNDS, WIDTH) |
| 55 | + } |
| 56 | + |
| 57 | + fn instantiate_lookup1() -> Vec<u16> { |
| 58 | + (0..=u16::MAX) |
| 59 | + .map(|i| { |
| 60 | + let hi = (i >> 8) as u8; |
| 61 | + let lo = i as u8; |
| 62 | + ((Self::s_box(hi) as u16) << 8) | Self::s_box(lo) as u16 |
| 63 | + }) |
| 64 | + .collect() |
| 65 | + } |
| 66 | + |
| 67 | + fn instantiate_lookup2() -> Vec<u16> { |
| 68 | + (0..(1 << 15)) |
| 69 | + .map(|i| { |
| 70 | + let hi = (i >> 8) as u8; |
| 71 | + let lo: u8 = i as u8; |
| 72 | + ((Self::final_s_box(hi) as u16) << 8) | Self::s_box(lo) as u16 |
| 73 | + }) |
| 74 | + .collect() |
| 75 | + } |
| 76 | + |
| 77 | + fn s_box(y: u8) -> u8 { |
| 78 | + (y ^ !y.rotate_left(1) & y.rotate_left(2) & y.rotate_left(3)).rotate_left(1) |
| 79 | + } |
| 80 | + |
| 81 | + fn final_s_box(y: u8) -> u8 { |
| 82 | + debug_assert_eq!(y >> 7, 0); |
| 83 | + |
| 84 | + let y_rot_1 = (y >> 6) | (y << 1); |
| 85 | + let y_rot_2 = (y >> 5) | (y << 2); |
| 86 | + |
| 87 | + let tmp = (y ^ !y_rot_1 & y_rot_2) & 0x7F; |
| 88 | + ((tmp >> 6) | (tmp << 1)) & 0x7F |
| 89 | + } |
| 90 | + |
| 91 | + pub fn permutation(&self, state: &mut Vec<u32>) { |
| 92 | + self.concrete(state); |
| 93 | + for round in 0..NUM_FULL_ROUNDS { |
| 94 | + self.bars(state); |
| 95 | + Self::bricks(state); |
| 96 | + self.concrete(state); |
| 97 | + Self::add_round_constants(state, &self.round_constants[round]); |
| 98 | + } |
| 99 | + self.bars(state); |
| 100 | + Self::bricks(state); |
| 101 | + self.concrete(state); |
| 102 | + } |
| 103 | + |
| 104 | + // MDS matrix |
| 105 | + fn concrete(&self, state: &mut Vec<u32>) { |
| 106 | + *state = if WIDTH == 16 { |
| 107 | + Self::apply_circulant(&mut MATRIX_CIRC_MDS_16_MERSENNE31_MONOLITH.clone(), state) |
| 108 | + } else { |
| 109 | + let mut shake = Shake128::default(); |
| 110 | + shake.update("Monolith".as_bytes()); |
| 111 | + shake.update(&[WIDTH as u8, (NUM_FULL_ROUNDS + 1) as u8]); |
| 112 | + shake.update(&MERSENNE_31_PRIME_FIELD_ORDER.to_le_bytes()); |
| 113 | + shake.update(&[16, 15]); |
| 114 | + shake.update("MDS".as_bytes()); |
| 115 | + let mut shake_finalized = shake.finalize_xof(); |
| 116 | + Self::apply_cauchy_mds_matrix(&mut shake_finalized, state) |
| 117 | + }; |
| 118 | + } |
| 119 | + |
| 120 | + // S-box lookups |
| 121 | + fn bars(&self, state: &mut [u32]) { |
| 122 | + for state in state.iter_mut().take(NUM_BARS) { |
| 123 | + *state = (self.lookup2[(*state >> 16) as u16 as usize] as u32) << 16 |
| 124 | + | self.lookup1[*state as u16 as usize] as u32; |
| 125 | + } |
| 126 | + } |
| 127 | + |
| 128 | + // (x_{n+1})² = (x_n)² + x_{n+1} |
| 129 | + fn bricks(state: &mut [u32]) { |
| 130 | + for i in (0..state.len() - 1).rev() { |
| 131 | + state[i + 1] = F::add(&state[i + 1], &F::square(&state[i])); |
| 132 | + } |
| 133 | + } |
| 134 | + |
| 135 | + fn add_round_constants(state: &mut [u32], round_constants: &[u32]) { |
| 136 | + for (x, rc) in state.iter_mut().zip(round_constants) { |
| 137 | + *x = F::add(x, rc); |
| 138 | + } |
| 139 | + } |
| 140 | + |
| 141 | + // O(n²) |
| 142 | + fn apply_circulant(circ_matrix: &mut [u32], input: &[u32]) -> Vec<u32> { |
| 143 | + let mut output = vec![F::zero(); WIDTH]; |
| 144 | + for out_i in output.iter_mut().take(WIDTH - 1) { |
| 145 | + *out_i = dot_product(circ_matrix, input); |
| 146 | + circ_matrix.rotate_right(1); |
| 147 | + } |
| 148 | + output[WIDTH - 1] = dot_product(circ_matrix, input); |
| 149 | + output |
| 150 | + } |
| 151 | + |
| 152 | + fn apply_cauchy_mds_matrix(shake: &mut Shake128Reader, to_multiply: &[u32]) -> Vec<u32> { |
| 153 | + let mut output = vec![F::zero(); WIDTH]; |
| 154 | + |
| 155 | + let bits: u32 = u64::BITS |
| 156 | + - (MERSENNE_31_PRIME_FIELD_ORDER as u64) |
| 157 | + .saturating_sub(1) |
| 158 | + .leading_zeros(); |
| 159 | + |
| 160 | + let x_mask = (1 << (bits - 9)) - 1; |
| 161 | + let y_mask = ((1 << bits) - 1) >> 2; |
| 162 | + |
| 163 | + let y = get_random_y_i(shake, WIDTH, x_mask, y_mask); |
| 164 | + let mut x = y.clone(); |
| 165 | + x.iter_mut().for_each(|x_i| *x_i &= x_mask); |
| 166 | + |
| 167 | + for (i, x_i) in x.iter().enumerate() { |
| 168 | + for (j, yj) in y.iter().enumerate() { |
| 169 | + output[i] = F::add(&output[i], &F::div(&to_multiply[j], &F::add(x_i, yj))); |
| 170 | + } |
| 171 | + } |
| 172 | + |
| 173 | + output |
| 174 | + } |
| 175 | +} |
| 176 | + |
| 177 | +#[cfg(test)] |
| 178 | +mod tests { |
| 179 | + use super::*; |
| 180 | + |
| 181 | + fn get_test_input(width: usize) -> Vec<u32> { |
| 182 | + (0..width).map(|i| F::from_base_type(i as u32)).collect() |
| 183 | + } |
| 184 | + |
| 185 | + #[test] |
| 186 | + fn from_plonky3_concrete_width_16() { |
| 187 | + let mut input = get_test_input(16); |
| 188 | + MonolithMersenne31::<16, 5>::new().concrete(&mut input); |
| 189 | + assert_eq!( |
| 190 | + input, |
| 191 | + [ |
| 192 | + 3470365, 3977394, 4042151, 4025740, 4431233, 4264086, 3927003, 4259216, 3872757, |
| 193 | + 3957178, 3820319, 3690660, 4023081, 3592814, 3688803, 3928040 |
| 194 | + ] |
| 195 | + ); |
| 196 | + } |
| 197 | + |
| 198 | + #[test] |
| 199 | + fn from_plonky3_concrete_width_12() { |
| 200 | + let mut input = get_test_input(12); |
| 201 | + MonolithMersenne31::<12, 5>::new().concrete(&mut input); |
| 202 | + assert_eq!( |
| 203 | + input, |
| 204 | + [ |
| 205 | + 365726249, 1885122147, 379836542, 860204337, 889139350, 1052715727, 151617411, |
| 206 | + 700047874, 925910152, 339398001, 721459023, 464532407 |
| 207 | + ] |
| 208 | + ); |
| 209 | + } |
| 210 | + |
| 211 | + #[test] |
| 212 | + fn from_plonky3_width_16() { |
| 213 | + let mut input = get_test_input(16); |
| 214 | + MonolithMersenne31::<16, 5>::new().permutation(&mut input); |
| 215 | + assert_eq!( |
| 216 | + input, |
| 217 | + [ |
| 218 | + 609156607, 290107110, 1900746598, 1734707571, 2050994835, 1648553244, 1307647296, |
| 219 | + 1941164548, 1707113065, 1477714255, 1170160793, 93800695, 769879348, 375548503, |
| 220 | + 1989726444, 1349325635 |
| 221 | + ] |
| 222 | + ); |
| 223 | + } |
| 224 | +} |
0 commit comments