diff --git a/Cargo.toml b/Cargo.toml index 5962f18..41637e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,3 +60,7 @@ with-gen-benches-poseidon-top-level = [] [[bench]] name = "benchmark" harness = false + +[profile.profiling] +inherits = "release" +debug = true \ No newline at end of file diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index 7ab2d7b..e3e5395 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -17,6 +17,7 @@ use crate::{F, PackedF}; use super::TweakableHash; use p3_koala_bear::Poseidon2KoalaBear; +use std::cell::RefCell; const DOMAIN_PARAMETERS_LENGTH: usize = 4; /// The state width for compressing a single hash in a chain. @@ -201,34 +202,46 @@ where ); let rate = WIDTH - capacity_value.len(); - let extra_elements = (rate - (input.len() % rate)) % rate; - let mut input_vector = input.to_vec(); - // We pad the input with zeros to make its length a multiple of the rate. - // - // This is safe because the input's original length is effectively encoded - // in the `capacity_value`, which serves as a domain separator. - input_vector.resize(input.len() + extra_elements, A::ZERO); - // initialize let mut state = [A::ZERO; WIDTH]; state[rate..].copy_from_slice(capacity_value); - // absorb - for chunk in input_vector.chunks(rate) { - for i in 0..chunk.len() { - state[i] += chunk[i]; + let extra_elements = (rate - (input.len() % rate)) % rate; + // Instead of converting the input to a vector, resizing and feeding the data into the + // sponge, we instead fill in the vector from all chunks until we are left with a non + // full chunk. We only add to the state, so padded data does not mutate `state` at all. + + // 1. fill in all full chunks and permute + let mut it = input.chunks_exact(rate); + for chunk in &mut it { + // iterate the chunks + for (s, &x) in state.iter_mut().take(rate).zip(chunk) { + *s += x; } perm.permute_mut(&mut state); } + // 2. fill the remainder and extend with zeros + let remainder = rate - extra_elements; + if remainder > 0 { + for (i, x) in it.remainder().iter().enumerate() { + state[i] += *x; + } + // was a remainder, so permute. No need to mutate `state` as we *add* only anyway + perm.permute_mut(&mut state); + } // squeeze - let mut out = vec![]; - while out.len() < OUT_LEN { - out.extend_from_slice(&state[..rate]); - perm.permute_mut(&mut state); + let mut out = [A::ZERO; OUT_LEN]; + let mut out_idx = 0; + while out_idx < OUT_LEN { + let chunk_size = (OUT_LEN - out_idx).min(rate); + out[out_idx..out_idx + chunk_size].copy_from_slice(&state[..chunk_size]); + out_idx += chunk_size; + if out_idx < OUT_LEN { + perm.permute_mut(&mut state); + } } - let slice = &out[0..OUT_LEN]; - slice.try_into().expect("Length mismatch") + out } /// A tweakable hash function implemented using Poseidon2 @@ -297,12 +310,14 @@ impl< [single] => { // we compress parameter, tweak, message let perm = poseidon2_16(); - let combined_input: Vec = parameter - .iter() - .chain(tweak_fe.iter()) - .chain(single.iter()) - .copied() - .collect(); + + // Build input on stack: [parameter | tweak | message] + let mut combined_input = [F::ZERO; CHAIN_COMPRESSION_WIDTH]; + combined_input[..PARAMETER_LEN].copy_from_slice(¶meter.0); + combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe); + combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN] + .copy_from_slice(&single.0); + FieldArray( poseidon_compress::( &perm, @@ -314,13 +329,17 @@ impl< [left, right] => { // we compress parameter, tweak, message (now containing two parts) let perm = poseidon2_24(); - let combined_input: Vec = parameter - .iter() - .chain(tweak_fe.iter()) - .chain(left.iter()) - .chain(right.iter()) - .copied() - .collect(); + + // Build input on stack: [parameter | tweak | left | right] + let mut combined_input = [F::ZERO; MERGE_COMPRESSION_WIDTH]; + combined_input[..PARAMETER_LEN].copy_from_slice(¶meter.0); + combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe); + combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN] + .copy_from_slice(&left.0); + combined_input[PARAMETER_LEN + TWEAK_LEN + HASH_LEN + ..PARAMETER_LEN + TWEAK_LEN + 2 * HASH_LEN] + .copy_from_slice(&right.0); + FieldArray( poseidon_compress::( &perm, @@ -492,6 +511,17 @@ impl< let capacity_val: [PackedF; CAPACITY] = poseidon_safe_domain_separator::(&sponge_perm, &lengths).map(PackedF::from); + // Compute sponge input length. Required to init packed input vector for each rayon worker + let sponge_tweak_offset = PARAMETER_LEN; + let sponge_chains_offset = PARAMETER_LEN + TWEAK_LEN; + let sponge_input_len = PARAMETER_LEN + TWEAK_LEN + NUM_CHUNKS * HASH_LEN; + + // We use thread-local storage to guarantee the `packed_leaf_input` vector is only allocated + // once per thread + thread_local! { + static PACKED_LEAF_INPUT: RefCell> = const { RefCell::new(Vec::new()) }; + } + // PARALLEL SIMD PROCESSING // // Process epochs in batches of size `width`. @@ -581,38 +611,44 @@ impl< // Assemble the sponge input. // Layout: [parameter | tree_tweak | all_chain_ends] - let sponge_tweak_offset = PARAMETER_LEN; - let sponge_chains_offset = PARAMETER_LEN + TWEAK_LEN; - let sponge_input_len = PARAMETER_LEN + TWEAK_LEN + NUM_CHUNKS * HASH_LEN; - - let mut packed_leaf_input = vec![PackedF::ZERO; sponge_input_len]; + // NOTE: `packed_leaf_input` is preallocated per thread. We overwrite the entire + // vector in each iteration, so no need to `fill(0)`! + let packed_leaves = PACKED_LEAF_INPUT.with_borrow_mut(|packed_leaf_input| { + // Resize on first use for this thread + if packed_leaf_input.len() != sponge_input_len { + packed_leaf_input.resize(sponge_input_len, PackedF::ZERO); + } - // Copy pre-packed parameter - packed_leaf_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter); - - // Pack tree tweaks directly (level 0 for bottom-layer leaves) - pack_fn_into::( - &mut packed_leaf_input, - sponge_tweak_offset, - |t_idx, lane| { - Self::tree_tweak(0, epoch_chunk[lane]).to_field_elements::() - [t_idx] - }, - ); + // Copy pre-packed parameter + packed_leaf_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter); + + // Pack tree tweaks directly (level 0 for bottom-layer leaves) + pack_fn_into::( + packed_leaf_input, + sponge_tweak_offset, + |t_idx, lane| { + Self::tree_tweak(0, epoch_chunk[lane]).to_field_elements::() + [t_idx] + }, + ); - // Copy all chain ends (already packed) - let dst = &mut packed_leaf_input[sponge_chains_offset .. sponge_chains_offset + packed_chains.len() * HASH_LEN]; - for (dst_chunk, src_chain) in dst.chunks_exact_mut(HASH_LEN).zip(packed_chains.iter()) { - dst_chunk.copy_from_slice(src_chain); - } + // Copy all chain ends (already packed) + let dst = &mut packed_leaf_input[sponge_chains_offset + ..sponge_chains_offset + packed_chains.len() * HASH_LEN]; + for (dst_chunk, src_chain) in + dst.chunks_exact_mut(HASH_LEN).zip(packed_chains.iter()) + { + dst_chunk.copy_from_slice(src_chain); + } - // Apply the sponge hash to produce the leaf. - // This absorbs all chain ends and squeezes out the final hash. - let packed_leaves = poseidon_sponge::( - &sponge_perm, - &capacity_val, - &packed_leaf_input, - ); + // Apply the sponge hash to produce the leaf. + // This absorbs all chain ends and squeezes out the final hash. + poseidon_sponge::( + &sponge_perm, + &capacity_val, + packed_leaf_input, + ) + }); // STEP 4: UNPACK RESULTS TO SCALAR REPRESENTATION // @@ -1655,13 +1691,13 @@ mod tests { let parameter = PoseidonTweak44::rand_parameter(&mut rng); let children: Vec<_> = (0..num_pairs * 2) - .map(|_| PoseidonTweak44::rand_domain(&mut rng)) - .collect(); + .map(|_| PoseidonTweak44::rand_domain(&mut rng)) + .collect(); let simd_result = - PoseidonTweak44::compute_tree_layer(¶meter, level, parent_start, &children); + PoseidonTweak44::compute_tree_layer(¶meter, level, parent_start, &children); let scalar_result = - compute_tree_layer_scalar::(¶meter, level, parent_start, &children); + compute_tree_layer_scalar::(¶meter, level, parent_start, &children); prop_assert_eq!(simd_result.len(), num_pairs); prop_assert_eq!(simd_result, scalar_result);