Reduce heap allocations #28
```diff
@@ -17,6 +17,7 @@ use crate::{F, PackedF};
 use super::TweakableHash;
 use p3_koala_bear::Poseidon2KoalaBear;
+use std::cell::RefCell;
 
 const DOMAIN_PARAMETERS_LENGTH: usize = 4;
 /// The state width for compressing a single hash in a chain.
```
```diff
@@ -201,34 +202,46 @@ where
     );
     let rate = WIDTH - capacity_value.len();
 
-    let extra_elements = (rate - (input.len() % rate)) % rate;
-    let mut input_vector = input.to_vec();
-    // We pad the input with zeros to make its length a multiple of the rate.
-    //
-    // This is safe because the input's original length is effectively encoded
-    // in the `capacity_value`, which serves as a domain separator.
-    input_vector.resize(input.len() + extra_elements, A::ZERO);
-
     // initialize
     let mut state = [A::ZERO; WIDTH];
     state[rate..].copy_from_slice(capacity_value);
 
-    // absorb
-    for chunk in input_vector.chunks(rate) {
-        for i in 0..chunk.len() {
-            state[i] += chunk[i];
+    let extra_elements = (rate - (input.len() % rate)) % rate;
+    // Instead of converting the input to a vector, resizing and feeding the data into the
+    // sponge, we instead fill in the vector from all chunks until we are left with a non
+    // full chunk. We only add to the state, so padded data does not mutate `state` at all.
+
+    // 1. fill in all full chunks and permute
+    let mut it = input.chunks_exact(rate);
+    for chunk in &mut it {
+        // iterate the chunks
+        for (s, &x) in state.iter_mut().take(rate).zip(chunk) {
+            *s += x;
         }
         perm.permute_mut(&mut state);
     }
+    // 2. fill the remainder and extend with zeros
+    let remainder = rate - extra_elements;
+    if remainder > 0 {
+        for (i, x) in it.remainder().iter().enumerate() {
+            state[i] += *x;
+        }
+        // was a remainder, so permute. No need to mutate `state` as we *add* only anyway
+        perm.permute_mut(&mut state);
+    }
 
     // squeeze
-    let mut out = vec![];
-    while out.len() < OUT_LEN {
-        out.extend_from_slice(&state[..rate]);
-        perm.permute_mut(&mut state);
+    let mut out = [A::ZERO; OUT_LEN];
+    let mut out_idx = 0;
+    while out_idx < OUT_LEN {
+        let chunk_size = (OUT_LEN - out_idx).min(rate);
+        out[out_idx..out_idx + chunk_size].copy_from_slice(&state[..chunk_size]);
+        out_idx += chunk_size;
+        if out_idx < OUT_LEN {
+            perm.permute_mut(&mut state);
+        }
```
Comment on lines +240 to +242

Contributor: To match exactly the logic we had before, I think we don't need the `if` statement. This piece of code is an important factor for security, so we should not diminish the number of permutations. [...] A good exercise for this kind of sensitive refactoring is to check the outputs of a call before and after the modification; they should be the same, since we didn't change any logic here. [A sketch of such a check follows this hunk.]

Contributor (Author): Maybe I screwed that up. Will check next year, sorry.
```diff
     }
-    let slice = &out[0..OUT_LEN];
-    slice.try_into().expect("Length mismatch")
+    out
 }
 
 /// A tweakable hash function implemented using Poseidon2
```
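The check the reviewer suggests can be prototyped without the real crate. Below is a minimal, self-contained sketch (hypothetical names, `u64` lanes, a toy permutation instead of Poseidon2 and the crate's field type) comparing the original zero-padding sponge against an allocation-free rewrite; with an identical permutation schedule, i.e. the unconditional permute in the squeeze that the reviewer asks for, the two agree on every tested input length.

```rust
// Sketch only: toy stand-ins for the real permutation and field type.
const WIDTH: usize = 8;
const RATE: usize = 6;
const OUT_LEN: usize = 7; // > RATE, so the squeeze needs more than one block

// Hypothetical toy permutation (NOT Poseidon2): just mixes the lanes deterministically.
fn toy_permute(state: &mut [u64; WIDTH]) {
    for i in 0..WIDTH {
        state[i] = state[i]
            .wrapping_mul(6364136223846793005)
            .wrapping_add(state[(i + 1) % WIDTH])
            .rotate_left(13);
    }
}

/// Reference behaviour: copy the input, zero-pad to a multiple of RATE, absorb, squeeze.
fn sponge_padded(capacity: &[u64], input: &[u64]) -> [u64; OUT_LEN] {
    let mut padded = input.to_vec();
    padded.resize(input.len().div_ceil(RATE) * RATE, 0);

    let mut state = [0u64; WIDTH];
    state[RATE..].copy_from_slice(capacity); // capacity must have WIDTH - RATE elements
    for chunk in padded.chunks(RATE) {
        for (s, &x) in state.iter_mut().zip(chunk) {
            *s = s.wrapping_add(x);
        }
        toy_permute(&mut state);
    }

    let mut out = vec![];
    while out.len() < OUT_LEN {
        out.extend_from_slice(&state[..RATE]);
        toy_permute(&mut state);
    }
    out[..OUT_LEN].try_into().expect("length mismatch")
}

/// Allocation-free behaviour: no intermediate Vec; zero padding is skipped because
/// absorbing zeros leaves the state unchanged. Same permutation schedule as above.
fn sponge_in_place(capacity: &[u64], input: &[u64]) -> [u64; OUT_LEN] {
    let mut state = [0u64; WIDTH];
    state[RATE..].copy_from_slice(capacity);

    let mut it = input.chunks_exact(RATE);
    for chunk in &mut it {
        for (s, &x) in state.iter_mut().zip(chunk) {
            *s = s.wrapping_add(x);
        }
        toy_permute(&mut state);
    }
    // Guarded on the actual remainder being non-empty, so the permutation count
    // matches the padded reference exactly.
    if !it.remainder().is_empty() {
        for (s, &x) in state.iter_mut().zip(it.remainder()) {
            *s = s.wrapping_add(x);
        }
        toy_permute(&mut state);
    }

    let mut out = [0u64; OUT_LEN];
    let mut idx = 0;
    while idx < OUT_LEN {
        let n = (OUT_LEN - idx).min(RATE);
        out[idx..idx + n].copy_from_slice(&state[..n]);
        idx += n;
        // Unconditional permute, matching the reference above exactly.
        toy_permute(&mut state);
    }
    out
}

fn main() {
    let capacity = [7u64, 11];
    for len in 0..40 {
        let input: Vec<u64> = (0..len as u64).map(|i| i * i + 1).collect();
        assert_eq!(sponge_padded(&capacity, &input), sponge_in_place(&capacity, &input));
    }
    println!("padded and in-place sponges agree on all tested lengths");
}
```

Note that skipping the final permutation (as the `if` in the diff does) only changes the internal state after the output has already been extracted, so an output-equality check like this passes either way; keeping the permutation count unchanged is a separate, conservative choice.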
```diff
@@ -297,12 +310,14 @@ impl<
             [single] => {
                 // we compress parameter, tweak, message
                 let perm = poseidon2_16();
-                let combined_input: Vec<F> = parameter
-                    .iter()
-                    .chain(tweak_fe.iter())
-                    .chain(single.iter())
-                    .copied()
-                    .collect();
+
+                // Build input on stack: [parameter | tweak | message]
+                let mut combined_input = [F::ZERO; CHAIN_COMPRESSION_WIDTH];
+                combined_input[..PARAMETER_LEN].copy_from_slice(&parameter.0);
+                combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe);
+                combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN]
+                    .copy_from_slice(&single.0);
+
                 FieldArray(
                     poseidon_compress::<F, _, CHAIN_COMPRESSION_WIDTH, HASH_LEN>(
                         &perm,
```
```diff
@@ -314,13 +329,17 @@ impl<
             [left, right] => {
                 // we compress parameter, tweak, message (now containing two parts)
                 let perm = poseidon2_24();
-                let combined_input: Vec<F> = parameter
-                    .iter()
-                    .chain(tweak_fe.iter())
-                    .chain(left.iter())
-                    .chain(right.iter())
-                    .copied()
-                    .collect();
+
+                // Build input on stack: [parameter | tweak | left | right]
+                let mut combined_input = [F::ZERO; MERGE_COMPRESSION_WIDTH];
+                combined_input[..PARAMETER_LEN].copy_from_slice(&parameter.0);
+                combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe);
+                combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN]
+                    .copy_from_slice(&left.0);
+                combined_input[PARAMETER_LEN + TWEAK_LEN + HASH_LEN
+                    ..PARAMETER_LEN + TWEAK_LEN + 2 * HASH_LEN]
+                    .copy_from_slice(&right.0);
+
                 FieldArray(
                     poseidon_compress::<F, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
                         &perm,
```
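Both arms above replace an iterator chain collected into a `Vec<F>` with a fixed-size array filled via `copy_from_slice`, so the compression input never touches the heap. As a standalone illustration of that pattern only (made-up lengths, `u64` in place of the crate's field type `F`, hypothetical function names), a minimal sketch:

```rust
// Placeholder lengths; the real code uses PARAMETER_LEN / TWEAK_LEN / HASH_LEN.
const A_LEN: usize = 4;
const B_LEN: usize = 2;
const C_LEN: usize = 7;
const TOTAL: usize = A_LEN + B_LEN + C_LEN;

// Heap-allocating version: iterator chain collected into a Vec.
fn concat_vec(a: &[u64; A_LEN], b: &[u64; B_LEN], c: &[u64; C_LEN]) -> Vec<u64> {
    a.iter().chain(b.iter()).chain(c.iter()).copied().collect()
}

// Stack version: a fixed-size array filled with copy_from_slice, no allocation.
fn concat_array(a: &[u64; A_LEN], b: &[u64; B_LEN], c: &[u64; C_LEN]) -> [u64; TOTAL] {
    let mut out = [0u64; TOTAL];
    out[..A_LEN].copy_from_slice(a);
    out[A_LEN..A_LEN + B_LEN].copy_from_slice(b);
    out[A_LEN + B_LEN..].copy_from_slice(c);
    out
}

fn main() {
    let a = [1u64; A_LEN];
    let b = [2u64; B_LEN];
    let c = [3u64; C_LEN];
    // Both layouts agree; only the allocation behaviour differs.
    assert_eq!(concat_vec(&a, &b, &c), concat_array(&a, &b, &c).to_vec());
}
```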
```diff
@@ -492,6 +511,17 @@ impl<
         let capacity_val: [PackedF; CAPACITY] =
             poseidon_safe_domain_separator::<CAPACITY>(&sponge_perm, &lengths).map(PackedF::from);
 
+        // Compute sponge input length. Required to init packed input vector for each rayon worker
+        let sponge_tweak_offset = PARAMETER_LEN;
+        let sponge_chains_offset = PARAMETER_LEN + TWEAK_LEN;
+        let sponge_input_len = PARAMETER_LEN + TWEAK_LEN + NUM_CHUNKS * HASH_LEN;
+
+        // We use thread-local storage to guarantee the `packed_leaf_input` vector is only allocated
+        // once per thread
+        thread_local! {
+            static PACKED_LEAF_INPUT: RefCell<Vec<PackedF>> = const { RefCell::new(Vec::new()) };
+        }
+
         // PARALLEL SIMD PROCESSING
         //
         // Process epochs in batches of size `width`.
```
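The thread-local buffer introduced here carries the main allocation saving for the leaf computation: each rayon worker allocates `packed_leaf_input` once and then reuses it for every batch it processes, as the next hunk shows. Below is a minimal sketch of the same pattern in isolation, with a toy workload and hypothetical names (`SCRATCH`, `process_batch`); it assumes only the `rayon` crate, not any of this crate's real types.

```rust
use rayon::prelude::*;
use std::cell::RefCell;

thread_local! {
    // One scratch buffer per rayon worker thread, allocated lazily on first use.
    static SCRATCH: RefCell<Vec<u64>> = const { RefCell::new(Vec::new()) };
}

fn process_batch(batch: &[u64]) -> u64 {
    SCRATCH.with_borrow_mut(|scratch| {
        // Grow once per thread; later calls on the same thread reuse the allocation.
        if scratch.len() < batch.len() {
            scratch.resize(batch.len(), 0);
        }
        let scratch = &mut scratch[..batch.len()];
        // Every used element is overwritten, so the buffer never needs re-zeroing.
        for (dst, &src) in scratch.iter_mut().zip(batch) {
            *dst = src.wrapping_mul(3).wrapping_add(1);
        }
        scratch.iter().copied().sum()
    })
}

fn main() {
    let data: Vec<u64> = (0..1_000).collect();
    // Rayon distributes the chunks across its worker threads; each worker
    // touches its own SCRATCH, so there is no sharing and no locking.
    let total: u64 = data.par_chunks(64).map(process_batch).sum();
    println!("total = {total}");
}
```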
```diff
@@ -581,38 +611,44 @@ impl<
 
             // Assemble the sponge input.
             // Layout: [parameter | tree_tweak | all_chain_ends]
-            let sponge_tweak_offset = PARAMETER_LEN;
-            let sponge_chains_offset = PARAMETER_LEN + TWEAK_LEN;
-            let sponge_input_len = PARAMETER_LEN + TWEAK_LEN + NUM_CHUNKS * HASH_LEN;
-
-            let mut packed_leaf_input = vec![PackedF::ZERO; sponge_input_len];
-
-            // Copy pre-packed parameter
-            packed_leaf_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter);
-
-            // Pack tree tweaks directly (level 0 for bottom-layer leaves)
-            pack_fn_into::<TWEAK_LEN>(
-                &mut packed_leaf_input,
-                sponge_tweak_offset,
-                |t_idx, lane| {
-                    Self::tree_tweak(0, epoch_chunk[lane]).to_field_elements::<TWEAK_LEN>()
-                        [t_idx]
-                },
-            );
-
-            // Copy all chain ends (already packed)
-            let dst = &mut packed_leaf_input[sponge_chains_offset
-                ..sponge_chains_offset + packed_chains.len() * HASH_LEN];
-            for (dst_chunk, src_chain) in dst.chunks_exact_mut(HASH_LEN).zip(packed_chains.iter()) {
-                dst_chunk.copy_from_slice(src_chain);
-            }
-
-            // Apply the sponge hash to produce the leaf.
-            // This absorbs all chain ends and squeezes out the final hash.
-            let packed_leaves = poseidon_sponge::<PackedF, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
-                &sponge_perm,
-                &capacity_val,
-                &packed_leaf_input,
-            );
+            // NOTE: `packed_leaf_input` is preallocated per thread. We overwrite the entire
+            // vector in each iteration, so no need to `fill(0)`!
+            let packed_leaves = PACKED_LEAF_INPUT.with_borrow_mut(|packed_leaf_input| {
+                // Resize on first use for this thread
+                if packed_leaf_input.len() != sponge_input_len {
+                    packed_leaf_input.resize(sponge_input_len, PackedF::ZERO);
+                }
+
+                // Copy pre-packed parameter
+                packed_leaf_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter);
+
+                // Pack tree tweaks directly (level 0 for bottom-layer leaves)
+                pack_fn_into::<TWEAK_LEN>(
+                    packed_leaf_input,
+                    sponge_tweak_offset,
+                    |t_idx, lane| {
+                        Self::tree_tweak(0, epoch_chunk[lane]).to_field_elements::<TWEAK_LEN>()
+                            [t_idx]
+                    },
+                );
+
+                // Copy all chain ends (already packed)
+                let dst = &mut packed_leaf_input[sponge_chains_offset
+                    ..sponge_chains_offset + packed_chains.len() * HASH_LEN];
+                for (dst_chunk, src_chain) in
+                    dst.chunks_exact_mut(HASH_LEN).zip(packed_chains.iter())
+                {
+                    dst_chunk.copy_from_slice(src_chain);
+                }
+
+                // Apply the sponge hash to produce the leaf.
+                // This absorbs all chain ends and squeezes out the final hash.
+                poseidon_sponge::<PackedF, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
+                    &sponge_perm,
+                    &capacity_val,
+                    packed_leaf_input,
+                )
+            });
 
             // STEP 4: UNPACK RESULTS TO SCALAR REPRESENTATION
             //
```
```diff
@@ -1655,13 +1691,13 @@ mod tests {
 
         let parameter = PoseidonTweak44::rand_parameter(&mut rng);
         let children: Vec<_> = (0..num_pairs * 2)
-            .map(|_| PoseidonTweak44::rand_domain(&mut rng))
-            .collect();
+        .map(|_| PoseidonTweak44::rand_domain(&mut rng))
+        .collect();
 
         let simd_result =
-            PoseidonTweak44::compute_tree_layer(&parameter, level, parent_start, &children);
+        PoseidonTweak44::compute_tree_layer(&parameter, level, parent_start, &children);
         let scalar_result =
-            compute_tree_layer_scalar::<PoseidonTweak44>(&parameter, level, parent_start, &children);
+        compute_tree_layer_scalar::<PoseidonTweak44>(&parameter, level, parent_start, &children);
 
         prop_assert_eq!(simd_result.len(), num_pairs);
         prop_assert_eq!(simd_result, scalar_result);
```
Contributor: What do you mean here by "no need to mutate `state`"? Since you pass a mutable reference of `state` to the permutation?

Contributor (Author): The point is that before, we zero-padded the input. But because absorbing only *adds* to the state, we don't have to add the zeroes for the previously padded data.
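A tiny self-contained illustration of that point, with `u64` standing in for a field element: absorbing a zero-padded block and absorbing only the real data leave the absorbing lanes identical, because absorption only adds and adding zero is a no-op.

```rust
fn main() {
    const RATE: usize = 6;
    let data = [3u64, 5, 8]; // a non-full block
    let mut padded_state = [0u64; RATE];
    let mut unpadded_state = [0u64; RATE];

    // Old approach: pad the block with zeros to the full rate, then add every element.
    let mut padded_block = data.to_vec();
    padded_block.resize(RATE, 0);
    for (s, x) in padded_state.iter_mut().zip(&padded_block) {
        *s += x;
    }

    // New approach: add only the real data; the zeros would not have changed anything.
    for (s, x) in unpadded_state.iter_mut().zip(&data) {
        *s += x;
    }

    assert_eq!(padded_state, unpadded_state);
}
```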