-
Notifications
You must be signed in to change notification settings - Fork 9
Reduce heap allocations #28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
ba510ab
e3e5ceb
ed2f132
595dbe0
41d240e
816fbbe
9453208
a1abd1e
7d7d0aa
d01fa2c
ceab87d
1305e33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| use std::hint::black_box; | ||
|
|
||
| use leansig::signature::{ | ||
| SignatureScheme, | ||
| generalized_xmss::instantiations_poseidon_top_level::lifetime_2_to_the_8::SIGTopLevelTargetSumLifetime8Dim64Base8, | ||
| }; | ||
|
|
||
| fn main() { | ||
| let mut rng = rand::rng(); | ||
|
|
||
| // 2^8 lifetime, full activation | ||
| let activation_duration = SIGTopLevelTargetSumLifetime8Dim64Base8::LIFETIME as usize; | ||
|
|
||
| eprintln!("Running single key_gen for 2^8 lifetime..."); | ||
| let (pk, sk) = black_box(SIGTopLevelTargetSumLifetime8Dim64Base8::key_gen( | ||
| &mut rng, | ||
| 0, | ||
| activation_duration, | ||
| )); | ||
| eprintln!("Done. pk size: {} bytes", std::mem::size_of_val(&pk)); | ||
|
|
||
| // Prevent optimization from removing the key_gen call | ||
| black_box((pk, sk)); | ||
| } |
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| use std::hint::black_box; | ||
|
|
||
| use leansig::signature::{ | ||
| SignatureScheme, | ||
| generalized_xmss::instantiations_poseidon_top_level::lifetime_2_to_the_32::size_optimized::SIGTopLevelTargetSumLifetime32Dim32Base26, | ||
| }; | ||
|
|
||
| /// Cap activation duration to 2^18 to keep runtime reasonable (same as benchmark) | ||
| const MAX_LOG_ACTIVATION_DURATION: usize = 18; | ||
|
|
||
| fn main() { | ||
| let mut rng = rand::rng(); | ||
|
|
||
| // 2^32 lifetime, activation capped at 2^18 | ||
| let activation_duration = std::cmp::min( | ||
| 1 << MAX_LOG_ACTIVATION_DURATION, | ||
| SIGTopLevelTargetSumLifetime32Dim32Base26::LIFETIME as usize, | ||
| ); | ||
|
|
||
| eprintln!( | ||
| "Running single key_gen for 2^32 lifetime (activation 2^{})...", | ||
| MAX_LOG_ACTIVATION_DURATION | ||
| ); | ||
| let (pk, sk) = black_box(SIGTopLevelTargetSumLifetime32Dim32Base26::key_gen( | ||
| &mut rng, | ||
| 0, | ||
| activation_duration, | ||
| )); | ||
| eprintln!("Done. pk size: {} bytes", std::mem::size_of_val(&pk)); | ||
|
|
||
| // Prevent optimization from removing the key_gen call | ||
| black_box((pk, sk)); | ||
| } |
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,6 +17,8 @@ use crate::{F, PackedF}; | |||||||||
| use super::TweakableHash; | ||||||||||
|
|
||||||||||
| use p3_koala_bear::Poseidon2KoalaBear; | ||||||||||
| use std::cell::RefCell; | ||||||||||
| use thread_local::ThreadLocal; | ||||||||||
|
|
||||||||||
| const DOMAIN_PARAMETERS_LENGTH: usize = 4; | ||||||||||
| /// The state width for compressing a single hash in a chain. | ||||||||||
|
|
@@ -201,34 +203,47 @@ where | |||||||||
| ); | ||||||||||
| let rate = WIDTH - capacity_value.len(); | ||||||||||
|
|
||||||||||
| let extra_elements = (rate - (input.len() % rate)) % rate; | ||||||||||
| let mut input_vector = input.to_vec(); | ||||||||||
| // We pad the input with zeros to make its length a multiple of the rate. | ||||||||||
| // | ||||||||||
| // This is safe because the input's original length is effectively encoded | ||||||||||
| // in the `capacity_value`, which serves as a domain separator. | ||||||||||
| input_vector.resize(input.len() + extra_elements, A::ZERO); | ||||||||||
|
|
||||||||||
| // initialize | ||||||||||
| let mut state = [A::ZERO; WIDTH]; | ||||||||||
| state[rate..].copy_from_slice(capacity_value); | ||||||||||
|
|
||||||||||
| // absorb | ||||||||||
| for chunk in input_vector.chunks(rate) { | ||||||||||
| let extra_elements = (rate - (input.len() % rate)) % rate; | ||||||||||
| // Instead of converting the input to a vector, resizing and feeding the data into the | ||||||||||
| // sponge, we instead fill in the vector from all chunks until we are left with a non | ||||||||||
| // full chunk. We only add to the state, so padded data does not mutate `state` at all. | ||||||||||
|
|
||||||||||
| // 1. fill in all full chunks and permute | ||||||||||
| let mut it = input.chunks_exact(rate); | ||||||||||
| for chunk in &mut it { | ||||||||||
| //input.chunks_exact(rate) { | ||||||||||
| // iterate the chunks | ||||||||||
| for i in 0..chunk.len() { | ||||||||||
| state[i] += chunk[i]; | ||||||||||
| } | ||||||||||
| perm.permute_mut(&mut state); | ||||||||||
| } | ||||||||||
tcoratger marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
| // 2. fill the remainder and extend with zeros | ||||||||||
| let remainder = rate - extra_elements; | ||||||||||
| if remainder > 0 { | ||||||||||
| for (i, x) in it.remainder().iter().enumerate() { | ||||||||||
| state[i] += *x; | ||||||||||
| } | ||||||||||
| // was a remainder, so permute. No need to mutate `state` as we *add* only anyway | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What do you mean here by "no need to mutate state", given that you pass a mutable reference to the state into the permutation?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The point is that previously we zero-padded the input. But because we only ever add to the state, we don't have to add zeroes for the data that used to be padded. |
||||||||||
| perm.permute_mut(&mut state); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // squeeze | ||||||||||
| let mut out = vec![]; | ||||||||||
| while out.len() < OUT_LEN { | ||||||||||
| out.extend_from_slice(&state[..rate]); | ||||||||||
| perm.permute_mut(&mut state); | ||||||||||
| let mut out = [A::ZERO; OUT_LEN]; | ||||||||||
| let mut out_idx = 0; | ||||||||||
| while out_idx < OUT_LEN { | ||||||||||
| let chunk_size = (OUT_LEN - out_idx).min(rate); | ||||||||||
| out[out_idx..out_idx + chunk_size].copy_from_slice(&state[..chunk_size]); | ||||||||||
| out_idx += chunk_size; | ||||||||||
| if out_idx < OUT_LEN { | ||||||||||
| perm.permute_mut(&mut state); | ||||||||||
| } | ||||||||||
|
Comment on lines
+240
to
+242
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To match exactly the logic we had before, I think we don't need the if statement here. This piece of code is an important factor for security, so we should not reduce the number of permutations. Due to the if after […] A good exercise for this kind of sensitive refactoring is to check the outputs of a call before and after the modification; they should be the same, since we didn't change any logic here.
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe I screwed that up. Will check next year, sorry. |
||||||||||
| } | ||||||||||
| let slice = &out[0..OUT_LEN]; | ||||||||||
| slice.try_into().expect("Length mismatch") | ||||||||||
| out | ||||||||||
| } | ||||||||||
|
|
||||||||||
| /// A tweakable hash function implemented using Poseidon2 | ||||||||||
|
|
@@ -297,12 +312,14 @@ impl< | |||||||||
| [single] => { | ||||||||||
| // we compress parameter, tweak, message | ||||||||||
| let perm = poseidon2_16(); | ||||||||||
| let combined_input: Vec<F> = parameter | ||||||||||
| .iter() | ||||||||||
| .chain(tweak_fe.iter()) | ||||||||||
| .chain(single.iter()) | ||||||||||
| .copied() | ||||||||||
| .collect(); | ||||||||||
|
|
||||||||||
| // Build input on stack: [parameter | tweak | message] | ||||||||||
| let mut combined_input = [F::ZERO; CHAIN_COMPRESSION_WIDTH]; | ||||||||||
| combined_input[..PARAMETER_LEN].copy_from_slice(¶meter.0); | ||||||||||
| combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe); | ||||||||||
| combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN] | ||||||||||
| .copy_from_slice(&single.0); | ||||||||||
|
|
||||||||||
| FieldArray( | ||||||||||
| poseidon_compress::<F, _, CHAIN_COMPRESSION_WIDTH, HASH_LEN>( | ||||||||||
| &perm, | ||||||||||
|
|
@@ -314,13 +331,17 @@ impl< | |||||||||
| [left, right] => { | ||||||||||
| // we compress parameter, tweak, message (now containing two parts) | ||||||||||
| let perm = poseidon2_24(); | ||||||||||
| let combined_input: Vec<F> = parameter | ||||||||||
| .iter() | ||||||||||
| .chain(tweak_fe.iter()) | ||||||||||
| .chain(left.iter()) | ||||||||||
| .chain(right.iter()) | ||||||||||
| .copied() | ||||||||||
| .collect(); | ||||||||||
|
|
||||||||||
| // Build input on stack: [parameter | tweak | left | right] | ||||||||||
| let mut combined_input = [F::ZERO; MERGE_COMPRESSION_WIDTH]; | ||||||||||
| combined_input[..PARAMETER_LEN].copy_from_slice(¶meter.0); | ||||||||||
| combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe); | ||||||||||
| combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN] | ||||||||||
| .copy_from_slice(&left.0); | ||||||||||
| combined_input[PARAMETER_LEN + TWEAK_LEN + HASH_LEN | ||||||||||
| ..PARAMETER_LEN + TWEAK_LEN + 2 * HASH_LEN] | ||||||||||
| .copy_from_slice(&right.0); | ||||||||||
|
|
||||||||||
| FieldArray( | ||||||||||
| poseidon_compress::<F, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>( | ||||||||||
| &perm, | ||||||||||
|
|
@@ -492,6 +513,15 @@ impl< | |||||||||
| let capacity_val: [PackedF; CAPACITY] = | ||||||||||
| poseidon_safe_domain_separator::<CAPACITY>(&sponge_perm, &lengths).map(PackedF::from); | ||||||||||
|
|
||||||||||
| // Compute sponge input length. Required to init packed input vector for each rayon worker | ||||||||||
| let sponge_tweak_offset = PARAMETER_LEN; | ||||||||||
| let sponge_chains_offset = PARAMETER_LEN + TWEAK_LEN; | ||||||||||
| let sponge_input_len = PARAMETER_LEN + TWEAK_LEN + NUM_CHUNKS * HASH_LEN; | ||||||||||
|
|
||||||||||
| // We use a thread local storage to guarantee the `packed_leaf_input` vector is only allocated | ||||||||||
| // once per thread | ||||||||||
| let tls: ThreadLocal<RefCell<Vec<PackedF>>> = ThreadLocal::new(); | ||||||||||
|
|
||||||||||
| // PARALLEL SIMD PROCESSING | ||||||||||
| // | ||||||||||
| // Process epochs in batches of size `width`. | ||||||||||
|
|
@@ -508,11 +538,18 @@ impl< | |||||||||
| // | ||||||||||
| // This layout enables efficient SIMD operations across epochs. | ||||||||||
|
|
||||||||||
| let cell = tls.get_or(|| { | ||||||||||
| RefCell::new(vec![PackedF::ZERO; sponge_input_len]) | ||||||||||
| }); | ||||||||||
| let mut packed_leaf_input = cell.borrow_mut(); | ||||||||||
| // reset not needed | ||||||||||
|
|
||||||||||
| let mut packed_chains: [[PackedF; HASH_LEN]; NUM_CHUNKS] = | ||||||||||
| array::from_fn(|c_idx| { | ||||||||||
| // Generate starting points for this chain across all epochs. | ||||||||||
| let starts: [_; PackedF::WIDTH] = array::from_fn(|lane| { | ||||||||||
| PRF::get_domain_element(prf_key, epoch_chunk[lane], c_idx as u64).into() | ||||||||||
| PRF::get_domain_element(prf_key, epoch_chunk[lane], c_idx as u64) | ||||||||||
| .into() | ||||||||||
| }); | ||||||||||
|
|
||||||||||
| // Transpose to vertical packing for SIMD efficiency. | ||||||||||
|
|
@@ -565,10 +602,10 @@ impl< | |||||||||
| // Apply the hash function to advance the chain. | ||||||||||
| // This single call processes all epochs in parallel. | ||||||||||
| *packed_chain = | ||||||||||
| poseidon_compress::<PackedF, _, CHAIN_COMPRESSION_WIDTH, HASH_LEN>( | ||||||||||
| &chain_perm, | ||||||||||
| &packed_input, | ||||||||||
| ); | ||||||||||
| poseidon_compress::<PackedF, _, CHAIN_COMPRESSION_WIDTH, HASH_LEN>( | ||||||||||
| &chain_perm, | ||||||||||
| &packed_input, | ||||||||||
| ); | ||||||||||
| } | ||||||||||
| } | ||||||||||
|
|
||||||||||
|
|
@@ -581,11 +618,8 @@ impl< | |||||||||
|
|
||||||||||
| // Assemble the sponge input. | ||||||||||
| // Layout: [parameter | tree_tweak | all_chain_ends] | ||||||||||
| let sponge_tweak_offset = PARAMETER_LEN; | ||||||||||
| let sponge_chains_offset = PARAMETER_LEN + TWEAK_LEN; | ||||||||||
| let sponge_input_len = PARAMETER_LEN + TWEAK_LEN + NUM_CHUNKS * HASH_LEN; | ||||||||||
|
|
||||||||||
| let mut packed_leaf_input = vec![PackedF::ZERO; sponge_input_len]; | ||||||||||
| // NOTE: `packed_leaf_input` is preallocated per thread. We overwrite the entire | ||||||||||
| // vector in each iteration, so no need to `fill(0)`! | ||||||||||
|
|
||||||||||
| // Copy pre-packed parameter | ||||||||||
| packed_leaf_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter); | ||||||||||
|
|
@@ -596,30 +630,35 @@ impl< | |||||||||
| sponge_tweak_offset, | ||||||||||
| |t_idx, lane| { | ||||||||||
| Self::tree_tweak(0, epoch_chunk[lane]).to_field_elements::<TWEAK_LEN>() | ||||||||||
| [t_idx] | ||||||||||
| [t_idx] | ||||||||||
| }, | ||||||||||
| ); | ||||||||||
|
|
||||||||||
| // Copy all chain ends (already packed) | ||||||||||
| let dst = &mut packed_leaf_input[sponge_chains_offset .. sponge_chains_offset + packed_chains.len() * HASH_LEN]; | ||||||||||
| for (dst_chunk, src_chain) in dst.chunks_exact_mut(HASH_LEN).zip(packed_chains.iter()) { | ||||||||||
| dst_chunk.copy_from_slice(src_chain); | ||||||||||
| } | ||||||||||
| let dst = &mut packed_leaf_input[sponge_chains_offset | ||||||||||
| ..sponge_chains_offset + packed_chains.len() * HASH_LEN]; | ||||||||||
| for (dst_chunk, src_chain) in | ||||||||||
| dst.chunks_exact_mut(HASH_LEN).zip(packed_chains.iter()) | ||||||||||
| { | ||||||||||
| dst_chunk.copy_from_slice(src_chain); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // Apply the sponge hash to produce the leaf. | ||||||||||
| // This absorbs all chain ends and squeezes out the final hash. | ||||||||||
| let packed_leaves = poseidon_sponge::<PackedF, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>( | ||||||||||
| &sponge_perm, | ||||||||||
| &capacity_val, | ||||||||||
| &packed_leaf_input, | ||||||||||
| ); | ||||||||||
| let packed_leaves = | ||||||||||
| poseidon_sponge::<PackedF, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>( | ||||||||||
| &sponge_perm, | ||||||||||
| &capacity_val, | ||||||||||
| &packed_leaf_input, | ||||||||||
| ); | ||||||||||
|
|
||||||||||
| // STEP 4: UNPACK RESULTS TO SCALAR REPRESENTATION | ||||||||||
| // | ||||||||||
| // Convert from vertical packing back to scalar layout. | ||||||||||
| // Each lane becomes one leaf in the output slice. | ||||||||||
| unpack_array(&packed_leaves, leaves_chunk); | ||||||||||
| }); | ||||||||||
| }, | ||||||||||
| ); | ||||||||||
|
|
||||||||||
| // HANDLE REMAINDER EPOCHS | ||||||||||
| // | ||||||||||
|
|
@@ -1655,13 +1694,13 @@ mod tests { | |||||||||
|
|
||||||||||
| let parameter = PoseidonTweak44::rand_parameter(&mut rng); | ||||||||||
| let children: Vec<_> = (0..num_pairs * 2) | ||||||||||
| .map(|_| PoseidonTweak44::rand_domain(&mut rng)) | ||||||||||
| .collect(); | ||||||||||
| .map(|_| PoseidonTweak44::rand_domain(&mut rng)) | ||||||||||
| .collect(); | ||||||||||
|
|
||||||||||
| let simd_result = | ||||||||||
| PoseidonTweak44::compute_tree_layer(¶meter, level, parent_start, &children); | ||||||||||
| PoseidonTweak44::compute_tree_layer(¶meter, level, parent_start, &children); | ||||||||||
| let scalar_result = | ||||||||||
| compute_tree_layer_scalar::<PoseidonTweak44>(¶meter, level, parent_start, &children); | ||||||||||
| compute_tree_layer_scalar::<PoseidonTweak44>(¶meter, level, parent_start, &children); | ||||||||||
|
|
||||||||||
| prop_assert_eq!(simd_result.len(), num_pairs); | ||||||||||
| prop_assert_eq!(simd_result, scalar_result); | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find it very risky to rely on this kind of external crate. We should minimize the number of external deps, and this one looks like a personal project, so I find it pretty risky. I would love to avoid this and would prefer an alternative such as
`for_each_init` or something equivalent in std. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that's what I thought and why I kept the
`for_each_init` approach. I'll have a look at the `thread_local` std package (https://doc.rust-lang.org/std/macro.thread_local.html). There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using the stdlib macro now.