Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,7 @@ with-gen-benches-poseidon-top-level = []
[[bench]]
name = "benchmark"
harness = false

[profile.profiling]
inherits = "release"
debug = true
164 changes: 100 additions & 64 deletions src/symmetric/tweak_hash/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::{F, PackedF};
use super::TweakableHash;

use p3_koala_bear::Poseidon2KoalaBear;
use std::cell::RefCell;

const DOMAIN_PARAMETERS_LENGTH: usize = 4;
/// The state width for compressing a single hash in a chain.
Expand Down Expand Up @@ -201,34 +202,46 @@ where
);
let rate = WIDTH - capacity_value.len();

let extra_elements = (rate - (input.len() % rate)) % rate;
let mut input_vector = input.to_vec();
// We pad the input with zeros to make its length a multiple of the rate.
//
// This is safe because the input's original length is effectively encoded
// in the `capacity_value`, which serves as a domain separator.
input_vector.resize(input.len() + extra_elements, A::ZERO);

// initialize
let mut state = [A::ZERO; WIDTH];
state[rate..].copy_from_slice(capacity_value);

// absorb
for chunk in input_vector.chunks(rate) {
for i in 0..chunk.len() {
state[i] += chunk[i];
let extra_elements = (rate - (input.len() % rate)) % rate;
// Instead of converting the input to a vector, resizing and feeding the data into the
// sponge, we instead fill in the vector from all chunks until we are left with a non
// full chunk. We only add to the state, so padded data does not mutate `state` at all.

// 1. fill in all full chunks and permute
let mut it = input.chunks_exact(rate);
for chunk in &mut it {
// iterate the chunks
for (s, &x) in state.iter_mut().take(rate).zip(chunk) {
*s += x;
}
perm.permute_mut(&mut state);
}
// 2. fill the remainder and extend with zeros
let remainder = rate - extra_elements;
if remainder > 0 {
for (i, x) in it.remainder().iter().enumerate() {
state[i] += *x;
}
// was a remainder, so permute. No need to mutate `state` as we *add* only anyway
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean here by "no need to mutate state"? Since you pass a mutable reference of state to the permutation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point is that before we zero pad. but because one only adds to the state, we don't have to add zeroes for the previously padded data.

perm.permute_mut(&mut state);
}

// squeeze
let mut out = vec![];
while out.len() < OUT_LEN {
out.extend_from_slice(&state[..rate]);
perm.permute_mut(&mut state);
let mut out = [A::ZERO; OUT_LEN];
let mut out_idx = 0;
while out_idx < OUT_LEN {
let chunk_size = (OUT_LEN - out_idx).min(rate);
out[out_idx..out_idx + chunk_size].copy_from_slice(&state[..chunk_size]);
out_idx += chunk_size;
if out_idx < OUT_LEN {
perm.permute_mut(&mut state);
}
Comment on lines +240 to +242
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here to match exactly the logic we had before I think that we don't need the if statement. Here this piece of code is an important factor for security and so we should not diminish the number of permutations. Due to the if after out_idx += chunk_size I think that the number of permutations was not exactly the same.

A good exercise to do for this kind of sensitive refactoring is to check the outputs of a call before and after the modification, they should be the same since we didn't change any logic here.

Suggested change
if out_idx < OUT_LEN {
perm.permute_mut(&mut state);
}
perm.permute_mut(&mut state);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe I screwed that up. Will check next year, sorry.

}
let slice = &out[0..OUT_LEN];
slice.try_into().expect("Length mismatch")
out
}

/// A tweakable hash function implemented using Poseidon2
Expand Down Expand Up @@ -297,12 +310,14 @@ impl<
[single] => {
// we compress parameter, tweak, message
let perm = poseidon2_16();
let combined_input: Vec<F> = parameter
.iter()
.chain(tweak_fe.iter())
.chain(single.iter())
.copied()
.collect();

// Build input on stack: [parameter | tweak | message]
let mut combined_input = [F::ZERO; CHAIN_COMPRESSION_WIDTH];
combined_input[..PARAMETER_LEN].copy_from_slice(&parameter.0);
combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe);
combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN]
.copy_from_slice(&single.0);

FieldArray(
poseidon_compress::<F, _, CHAIN_COMPRESSION_WIDTH, HASH_LEN>(
&perm,
Expand All @@ -314,13 +329,17 @@ impl<
[left, right] => {
// we compress parameter, tweak, message (now containing two parts)
let perm = poseidon2_24();
let combined_input: Vec<F> = parameter
.iter()
.chain(tweak_fe.iter())
.chain(left.iter())
.chain(right.iter())
.copied()
.collect();

// Build input on stack: [parameter | tweak | left | right]
let mut combined_input = [F::ZERO; MERGE_COMPRESSION_WIDTH];
combined_input[..PARAMETER_LEN].copy_from_slice(&parameter.0);
combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe);
combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN]
.copy_from_slice(&left.0);
combined_input[PARAMETER_LEN + TWEAK_LEN + HASH_LEN
..PARAMETER_LEN + TWEAK_LEN + 2 * HASH_LEN]
.copy_from_slice(&right.0);

FieldArray(
poseidon_compress::<F, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
&perm,
Expand Down Expand Up @@ -492,6 +511,17 @@ impl<
let capacity_val: [PackedF; CAPACITY] =
poseidon_safe_domain_separator::<CAPACITY>(&sponge_perm, &lengths).map(PackedF::from);

// Compute sponge input length. Required to init packed input vector for each rayon worker
let sponge_tweak_offset = PARAMETER_LEN;
let sponge_chains_offset = PARAMETER_LEN + TWEAK_LEN;
let sponge_input_len = PARAMETER_LEN + TWEAK_LEN + NUM_CHUNKS * HASH_LEN;

// We use thread-local storage to guarantee the `packed_leaf_input` vector is only allocated
// once per thread
thread_local! {
static PACKED_LEAF_INPUT: RefCell<Vec<PackedF>> = const { RefCell::new(Vec::new()) };
}

// PARALLEL SIMD PROCESSING
//
// Process epochs in batches of size `width`.
Expand Down Expand Up @@ -581,38 +611,44 @@ impl<

// Assemble the sponge input.
// Layout: [parameter | tree_tweak | all_chain_ends]
let sponge_tweak_offset = PARAMETER_LEN;
let sponge_chains_offset = PARAMETER_LEN + TWEAK_LEN;
let sponge_input_len = PARAMETER_LEN + TWEAK_LEN + NUM_CHUNKS * HASH_LEN;

let mut packed_leaf_input = vec![PackedF::ZERO; sponge_input_len];
// NOTE: `packed_leaf_input` is preallocated per thread. We overwrite the entire
// vector in each iteration, so no need to `fill(0)`!
let packed_leaves = PACKED_LEAF_INPUT.with_borrow_mut(|packed_leaf_input| {
// Resize on first use for this thread
if packed_leaf_input.len() != sponge_input_len {
packed_leaf_input.resize(sponge_input_len, PackedF::ZERO);
}

// Copy pre-packed parameter
packed_leaf_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter);

// Pack tree tweaks directly (level 0 for bottom-layer leaves)
pack_fn_into::<TWEAK_LEN>(
&mut packed_leaf_input,
sponge_tweak_offset,
|t_idx, lane| {
Self::tree_tweak(0, epoch_chunk[lane]).to_field_elements::<TWEAK_LEN>()
[t_idx]
},
);
// Copy pre-packed parameter
packed_leaf_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter);

// Pack tree tweaks directly (level 0 for bottom-layer leaves)
pack_fn_into::<TWEAK_LEN>(
packed_leaf_input,
sponge_tweak_offset,
|t_idx, lane| {
Self::tree_tweak(0, epoch_chunk[lane]).to_field_elements::<TWEAK_LEN>()
[t_idx]
},
);

// Copy all chain ends (already packed)
let dst = &mut packed_leaf_input[sponge_chains_offset .. sponge_chains_offset + packed_chains.len() * HASH_LEN];
for (dst_chunk, src_chain) in dst.chunks_exact_mut(HASH_LEN).zip(packed_chains.iter()) {
dst_chunk.copy_from_slice(src_chain);
}
// Copy all chain ends (already packed)
let dst = &mut packed_leaf_input[sponge_chains_offset
..sponge_chains_offset + packed_chains.len() * HASH_LEN];
for (dst_chunk, src_chain) in
dst.chunks_exact_mut(HASH_LEN).zip(packed_chains.iter())
{
dst_chunk.copy_from_slice(src_chain);
}

// Apply the sponge hash to produce the leaf.
// This absorbs all chain ends and squeezes out the final hash.
let packed_leaves = poseidon_sponge::<PackedF, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
&sponge_perm,
&capacity_val,
&packed_leaf_input,
);
// Apply the sponge hash to produce the leaf.
// This absorbs all chain ends and squeezes out the final hash.
poseidon_sponge::<PackedF, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
&sponge_perm,
&capacity_val,
packed_leaf_input,
)
});

// STEP 4: UNPACK RESULTS TO SCALAR REPRESENTATION
//
Expand Down Expand Up @@ -1655,13 +1691,13 @@ mod tests {

let parameter = PoseidonTweak44::rand_parameter(&mut rng);
let children: Vec<_> = (0..num_pairs * 2)
.map(|_| PoseidonTweak44::rand_domain(&mut rng))
.collect();
.map(|_| PoseidonTweak44::rand_domain(&mut rng))
.collect();

let simd_result =
PoseidonTweak44::compute_tree_layer(&parameter, level, parent_start, &children);
PoseidonTweak44::compute_tree_layer(&parameter, level, parent_start, &children);
let scalar_result =
compute_tree_layer_scalar::<PoseidonTweak44>(&parameter, level, parent_start, &children);
compute_tree_layer_scalar::<PoseidonTweak44>(&parameter, level, parent_start, &children);

prop_assert_eq!(simd_result.len(), num_pairs);
prop_assert_eq!(simd_result, scalar_result);
Expand Down