6 changes: 6 additions & 0 deletions Cargo.toml
@@ -46,6 +46,8 @@ p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }

thread_local = "1.1.9"
Contributor

I find it very risky to rely on this kind of external crate. We should minimize the number of external dependencies, and this one looks like a personal project, so I find it pretty risky. I would love to avoid it and would prefer another alternative such as for_each_init or something equivalent in std.

Contributor Author

Yeah, that's what I thought and why I kept the for_each_init approach. I'll have a look at the std thread_local! macro: https://doc.rust-lang.org/std/macro.thread_local.html.

Contributor Author

Using the stdlib macro now.
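
For reference, a minimal sketch of the std route discussed above: a per-thread scratch buffer declared with the std thread_local! macro and reused across rayon iterations. The buffer name, element type, and sizes are illustrative only and not taken from this PR.

use std::cell::RefCell;

use rayon::prelude::*;

// Lazily allocated once per OS thread, then reused by every iteration
// that rayon schedules onto that thread.
thread_local! {
    static SCRATCH: RefCell<Vec<u64>> = RefCell::new(Vec::new());
}

fn sum_with_scratch(items: &[u64], scratch_len: usize) -> Vec<u64> {
    items
        .par_iter()
        .map(|&item| {
            SCRATCH.with(|cell| {
                let mut buf = cell.borrow_mut();
                // The first call on a given thread allocates; later calls reuse the buffer.
                buf.resize(scratch_len, 0);
                buf[0] = item;
                buf.iter().copied().sum()
            })
        })
        .collect()
}

fn main() {
    println!("{:?}", sum_with_scratch(&[1, 2, 3, 4], 8));
}

rayon's for_each_init / map_init mentioned above is the other option: it hands each worker task its own freshly built scratch value, though the init closure may run more than once per thread.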


[dev-dependencies]
criterion = "0.7"
proptest = "1.7"
@@ -60,3 +62,7 @@ with-gen-benches-poseidon-top-level = []
[[bench]]
name = "benchmark"
harness = false

[profile.profiling]
inherits = "release"
debug = true
24 changes: 24 additions & 0 deletions examples/single_keygen.rs
Contributor

I think we can remove this file, no? It's just used for the benchmark, I imagine, but it shouldn't be pushed to main.

Contributor Author

Sure, we can do that. I personally find it nice to have files like this around for easy profiling, but they are also easy to write again. I'll remove them.

@@ -0,0 +1,24 @@
use std::hint::black_box;

use leansig::signature::{
SignatureScheme,
generalized_xmss::instantiations_poseidon_top_level::lifetime_2_to_the_8::SIGTopLevelTargetSumLifetime8Dim64Base8,
};

fn main() {
let mut rng = rand::rng();

// 2^8 lifetime, full activation
let activation_duration = SIGTopLevelTargetSumLifetime8Dim64Base8::LIFETIME as usize;

eprintln!("Running single key_gen for 2^8 lifetime...");
let (pk, sk) = black_box(SIGTopLevelTargetSumLifetime8Dim64Base8::key_gen(
&mut rng,
0,
activation_duration,
));
eprintln!("Done. pk size: {} bytes", std::mem::size_of_val(&pk));

// Prevent optimization from removing the key_gen call
black_box((pk, sk));
}
33 changes: 33 additions & 0 deletions examples/single_keygen_2_32.rs
Contributor

I think we can remove this file, no? It's just used for the benchmark, I imagine, but it shouldn't be pushed to main.

@@ -0,0 +1,33 @@
use std::hint::black_box;

use leansig::signature::{
SignatureScheme,
generalized_xmss::instantiations_poseidon_top_level::lifetime_2_to_the_32::size_optimized::SIGTopLevelTargetSumLifetime32Dim32Base26,
};

/// Cap activation duration to 2^18 to keep runtime reasonable (same as benchmark)
const MAX_LOG_ACTIVATION_DURATION: usize = 18;

fn main() {
let mut rng = rand::rng();

// 2^32 lifetime, activation capped at 2^18
let activation_duration = std::cmp::min(
1 << MAX_LOG_ACTIVATION_DURATION,
SIGTopLevelTargetSumLifetime32Dim32Base26::LIFETIME as usize,
);

eprintln!(
"Running single key_gen for 2^32 lifetime (activation 2^{})...",
MAX_LOG_ACTIVATION_DURATION
);
let (pk, sk) = black_box(SIGTopLevelTargetSumLifetime32Dim32Base26::key_gen(
&mut rng,
0,
activation_duration,
));
eprintln!("Done. pk size: {} bytes", std::mem::size_of_val(&pk));

// Prevent optimization from removing the key_gen call
black_box((pk, sk));
}
147 changes: 93 additions & 54 deletions src/symmetric/tweak_hash/poseidon.rs
@@ -17,6 +17,8 @@ use crate::{F, PackedF};
use super::TweakableHash;

use p3_koala_bear::Poseidon2KoalaBear;
use std::cell::RefCell;
use thread_local::ThreadLocal;

const DOMAIN_PARAMETERS_LENGTH: usize = 4;
/// The state width for compressing a single hash in a chain.
@@ -201,34 +203,47 @@ where
);
let rate = WIDTH - capacity_value.len();

let extra_elements = (rate - (input.len() % rate)) % rate;
let mut input_vector = input.to_vec();
// We pad the input with zeros to make its length a multiple of the rate.
//
// This is safe because the input's original length is effectively encoded
// in the `capacity_value`, which serves as a domain separator.
input_vector.resize(input.len() + extra_elements, A::ZERO);

// initialize
let mut state = [A::ZERO; WIDTH];
state[rate..].copy_from_slice(capacity_value);

// absorb
for chunk in input_vector.chunks(rate) {
let extra_elements = (rate - (input.len() % rate)) % rate;
// Instead of copying the input into a zero-padded vector and feeding that into the
// sponge, we absorb all full chunks directly into the state and then handle the final,
// non-full chunk separately. Since absorbing only *adds* to the state, the zero padding
// would not have changed `state` anyway.

// 1. fill in all full chunks and permute
let mut it = input.chunks_exact(rate);
for chunk in &mut it {
// add this chunk into the rate portion of the state
for i in 0..chunk.len() {
state[i] += chunk[i];
}
perm.permute_mut(&mut state);
}
// 2. absorb the remaining partial chunk, if any (implicit zero padding).
// The `% rate` keeps `remainder` at zero when the input length is already a multiple
// of the rate, so no extra permutation is performed in that case.
let remainder = (rate - extra_elements) % rate;
if remainder > 0 {
for (i, x) in it.remainder().iter().enumerate() {
state[i] += *x;
}
// there was a remainder, so permute. No need to mutate `state` for the padded
// positions, as we only *add* to the state anyway
Contributor

What do you mean here by "no need to mutate state"? Since you pass a mutable reference to state into the permutation?

Contributor Author

The point is that before, we zero-padded the input. But because we only ever add to the state, we don't have to absorb zeroes for the padded positions.
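
For the record, a tiny standalone sketch of that argument, with plain u64 addition standing in for field addition: absorbing an explicitly zero-padded final chunk and absorbing only the remainder leave the state identical going into the permutation. All names here are made up for the illustration.

// Old style: pad the final chunk with zeros, then absorb the full chunk.
fn absorb_padded(state: &mut [u64], remainder: &[u64], rate: usize) {
    let mut chunk = remainder.to_vec();
    chunk.resize(rate, 0); // explicit zero padding
    for i in 0..rate {
        state[i] += chunk[i];
    }
}

// New style: absorb only the remainder; padded positions are skipped,
// which is equivalent because adding zero is a no-op.
fn absorb_remainder(state: &mut [u64], remainder: &[u64]) {
    for (i, x) in remainder.iter().enumerate() {
        state[i] += *x;
    }
}

fn main() {
    let rate = 4usize;
    let remainder = [7u64, 11];
    let mut a = [1u64, 2, 3, 4, 5];
    let mut b = a;
    absorb_padded(&mut a, &remainder, rate);
    absorb_remainder(&mut b, &remainder);
    assert_eq!(a, b); // identical state before the permutation
}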

perm.permute_mut(&mut state);
}

// squeeze
let mut out = vec![];
while out.len() < OUT_LEN {
out.extend_from_slice(&state[..rate]);
perm.permute_mut(&mut state);
let mut out = [A::ZERO; OUT_LEN];
let mut out_idx = 0;
while out_idx < OUT_LEN {
let chunk_size = (OUT_LEN - out_idx).min(rate);
out[out_idx..out_idx + chunk_size].copy_from_slice(&state[..chunk_size]);
out_idx += chunk_size;
if out_idx < OUT_LEN {
perm.permute_mut(&mut state);
}
Comment on lines +240 to +242
Contributor

Here, to match exactly the logic we had before, I think we don't need the if statement. This piece of code is an important factor for security, so we should not reduce the number of permutations. Because of the if after out_idx += chunk_size, I think the number of permutations is not exactly the same.

A good exercise for this kind of sensitive refactoring is to check the outputs of a call before and after the modification; they should be the same since we didn't change any logic here.

Suggested change
if out_idx < OUT_LEN {
perm.permute_mut(&mut state);
}
perm.permute_mut(&mut state);

Contributor Author

Maybe I screwed that up. Will check next year, sorry.
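
As a sketch of the kind of before/after check suggested here: compare the two squeeze loops on the same starting state, with a dummy permutation and a u64 state standing in for Poseidon2 and the field type. All names are made up for the illustration; this is not the project's test code.

// Dummy stand-in for the permutation; only used to make the two loops comparable.
fn dummy_permute(state: &mut [u64; 8]) {
    for (i, s) in state.iter_mut().enumerate() {
        *s = s.wrapping_mul(6364136223846793005).wrapping_add(i as u64 + 1);
    }
}

// Squeeze as it was before the refactor: extend by `rate`, permute every round, truncate.
fn squeeze_old(mut state: [u64; 8], rate: usize, out_len: usize) -> Vec<u64> {
    let mut out = vec![];
    while out.len() < out_len {
        out.extend_from_slice(&state[..rate]);
        dummy_permute(&mut state);
    }
    out.truncate(out_len);
    out
}

// Refactored squeeze with the unconditional permute from the suggested change.
fn squeeze_new(mut state: [u64; 8], rate: usize, out_len: usize) -> Vec<u64> {
    let mut out = vec![0u64; out_len];
    let mut out_idx = 0;
    while out_idx < out_len {
        let chunk_size = (out_len - out_idx).min(rate);
        out[out_idx..out_idx + chunk_size].copy_from_slice(&state[..chunk_size]);
        out_idx += chunk_size;
        dummy_permute(&mut state);
    }
    out
}

fn main() {
    let state = [3u64, 1, 4, 1, 5, 9, 2, 6];
    for out_len in 1..=16 {
        assert_eq!(squeeze_old(state, 4, out_len), squeeze_new(state, 4, out_len));
    }
    println!("outputs match");
}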

}
let slice = &out[0..OUT_LEN];
slice.try_into().expect("Length mismatch")
out
}

/// A tweakable hash function implemented using Poseidon2
@@ -297,12 +312,14 @@ impl<
[single] => {
// we compress parameter, tweak, message
let perm = poseidon2_16();
let combined_input: Vec<F> = parameter
.iter()
.chain(tweak_fe.iter())
.chain(single.iter())
.copied()
.collect();

// Build input on stack: [parameter | tweak | message]
let mut combined_input = [F::ZERO; CHAIN_COMPRESSION_WIDTH];
combined_input[..PARAMETER_LEN].copy_from_slice(&parameter.0);
combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe);
combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN]
.copy_from_slice(&single.0);

FieldArray(
poseidon_compress::<F, _, CHAIN_COMPRESSION_WIDTH, HASH_LEN>(
&perm,
@@ -314,13 +331,17 @@ impl<
[left, right] => {
// we compress parameter, tweak, message (now containing two parts)
let perm = poseidon2_24();
let combined_input: Vec<F> = parameter
.iter()
.chain(tweak_fe.iter())
.chain(left.iter())
.chain(right.iter())
.copied()
.collect();

// Build input on stack: [parameter | tweak | left | right]
let mut combined_input = [F::ZERO; MERGE_COMPRESSION_WIDTH];
combined_input[..PARAMETER_LEN].copy_from_slice(&parameter.0);
combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe);
combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN]
.copy_from_slice(&left.0);
combined_input[PARAMETER_LEN + TWEAK_LEN + HASH_LEN
..PARAMETER_LEN + TWEAK_LEN + 2 * HASH_LEN]
.copy_from_slice(&right.0);

FieldArray(
poseidon_compress::<F, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
&perm,
@@ -492,6 +513,15 @@ impl<
let capacity_val: [PackedF; CAPACITY] =
poseidon_safe_domain_separator::<CAPACITY>(&sponge_perm, &lengths).map(PackedF::from);

// Compute the sponge input layout. The total length is needed to initialize the packed
// input vector once per rayon worker thread.
let sponge_tweak_offset = PARAMETER_LEN;
let sponge_chains_offset = PARAMETER_LEN + TWEAK_LEN;
let sponge_input_len = PARAMETER_LEN + TWEAK_LEN + NUM_CHUNKS * HASH_LEN;

// We use thread-local storage so that the `packed_leaf_input` vector is allocated only
// once per thread
let tls: ThreadLocal<RefCell<Vec<PackedF>>> = ThreadLocal::new();

// PARALLEL SIMD PROCESSING
//
// Process epochs in batches of size `width`.
Expand All @@ -508,11 +538,18 @@ impl<
//
// This layout enables efficient SIMD operations across epochs.

let cell = tls.get_or(|| {
RefCell::new(vec![PackedF::ZERO; sponge_input_len])
});
let mut packed_leaf_input = cell.borrow_mut();
// reset not needed: every element is overwritten below

let mut packed_chains: [[PackedF; HASH_LEN]; NUM_CHUNKS] =
array::from_fn(|c_idx| {
// Generate starting points for this chain across all epochs.
let starts: [_; PackedF::WIDTH] = array::from_fn(|lane| {
PRF::get_domain_element(prf_key, epoch_chunk[lane], c_idx as u64).into()
PRF::get_domain_element(prf_key, epoch_chunk[lane], c_idx as u64)
.into()
});

// Transpose to vertical packing for SIMD efficiency.
@@ -565,10 +602,10 @@ impl<
// Apply the hash function to advance the chain.
// This single call processes all epochs in parallel.
*packed_chain =
poseidon_compress::<PackedF, _, CHAIN_COMPRESSION_WIDTH, HASH_LEN>(
&chain_perm,
&packed_input,
);
poseidon_compress::<PackedF, _, CHAIN_COMPRESSION_WIDTH, HASH_LEN>(
&chain_perm,
&packed_input,
);
}
}

@@ -581,11 +618,8 @@ impl<

// Assemble the sponge input.
// Layout: [parameter | tree_tweak | all_chain_ends]
let sponge_tweak_offset = PARAMETER_LEN;
let sponge_chains_offset = PARAMETER_LEN + TWEAK_LEN;
let sponge_input_len = PARAMETER_LEN + TWEAK_LEN + NUM_CHUNKS * HASH_LEN;

let mut packed_leaf_input = vec![PackedF::ZERO; sponge_input_len];
// NOTE: `packed_leaf_input` is preallocated per thread. We overwrite the entire
// vector in each iteration, so no need to `fill(0)`!

// Copy pre-packed parameter
packed_leaf_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter);
Expand All @@ -596,30 +630,35 @@ impl<
sponge_tweak_offset,
|t_idx, lane| {
Self::tree_tweak(0, epoch_chunk[lane]).to_field_elements::<TWEAK_LEN>()
[t_idx]
[t_idx]
},
);

// Copy all chain ends (already packed)
let dst = &mut packed_leaf_input[sponge_chains_offset .. sponge_chains_offset + packed_chains.len() * HASH_LEN];
for (dst_chunk, src_chain) in dst.chunks_exact_mut(HASH_LEN).zip(packed_chains.iter()) {
dst_chunk.copy_from_slice(src_chain);
}
let dst = &mut packed_leaf_input[sponge_chains_offset
..sponge_chains_offset + packed_chains.len() * HASH_LEN];
for (dst_chunk, src_chain) in
dst.chunks_exact_mut(HASH_LEN).zip(packed_chains.iter())
{
dst_chunk.copy_from_slice(src_chain);
}

// Apply the sponge hash to produce the leaf.
// This absorbs all chain ends and squeezes out the final hash.
let packed_leaves = poseidon_sponge::<PackedF, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
&sponge_perm,
&capacity_val,
&packed_leaf_input,
);
let packed_leaves =
poseidon_sponge::<PackedF, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
&sponge_perm,
&capacity_val,
&packed_leaf_input,
);

// STEP 4: UNPACK RESULTS TO SCALAR REPRESENTATION
//
// Convert from vertical packing back to scalar layout.
// Each lane becomes one leaf in the output slice.
unpack_array(&packed_leaves, leaves_chunk);
});
},
);

// HANDLE REMAINDER EPOCHS
//
@@ -1655,13 +1694,13 @@ mod tests {

let parameter = PoseidonTweak44::rand_parameter(&mut rng);
let children: Vec<_> = (0..num_pairs * 2)
.map(|_| PoseidonTweak44::rand_domain(&mut rng))
.collect();
.map(|_| PoseidonTweak44::rand_domain(&mut rng))
.collect();

let simd_result =
PoseidonTweak44::compute_tree_layer(&parameter, level, parent_start, &children);
PoseidonTweak44::compute_tree_layer(&parameter, level, parent_start, &children);
let scalar_result =
compute_tree_layer_scalar::<PoseidonTweak44>(&parameter, level, parent_start, &children);
compute_tree_layer_scalar::<PoseidonTweak44>(&parameter, level, parent_start, &children);

prop_assert_eq!(simd_result.len(), num_pairs);
prop_assert_eq!(simd_result, scalar_result);