diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index 1da6923..7ab2d7b 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -16,6 +16,8 @@ use crate::{F, PackedF}; use super::TweakableHash; +use p3_koala_bear::Poseidon2KoalaBear; + const DOMAIN_PARAMETERS_LENGTH: usize = 4; /// The state width for compressing a single hash in a chain. const CHAIN_COMPRESSION_WIDTH: usize = 16; @@ -132,26 +134,13 @@ where } /// Computes a Poseidon-based domain separator by compressing an array of `u32` -/// values using a fixed Poseidon instance. +/// values using the Poseidon2 KoalaBear permutation with width 24. /// -/// This function works generically over `A: Algebra`, allowing it to process both: -/// - Scalar fields, -/// - Packed SIMD fields -/// -/// ### Usage constraints -/// - This function is private because it's tailored to a very specific case: -/// the Poseidon2 instance with arity 24 and a fixed 4-word input. -/// - As this function operates on constants, its output can be **precomputed** -/// for significant performance gains, especially within a circuit. -/// - If generalization is ever needed, a more generic and slower version should be used. -fn poseidon_safe_domain_separator( - perm: &P, +/// Returns scalar field elements. For SIMD use, broadcast to `PackedF` at the call site. +fn poseidon_safe_domain_separator( + perm: &Poseidon2KoalaBear, params: &[u32; DOMAIN_PARAMETERS_LENGTH], -) -> [A; OUT_LEN] -where - A: Algebra + Copy, - P: CryptographicPermutation<[A; WIDTH]>, -{ +) -> [F; OUT_LEN] { // Combine params into a single number in base 2^32 // // WARNING: We can use a u128 instead of a BigUint only because `params` @@ -162,16 +151,13 @@ where } // Compute base-p decomposition - // - // We can use 24 as hardcoded because the only time we use this function - // is for the corresponding Poseidon instance. - let input = std::array::from_fn::<_, 24, _>(|_| { + let input: [F; MERGE_COMPRESSION_WIDTH] = std::array::from_fn(|_| { let digit = (acc % F::ORDER_U64 as u128) as u64; acc /= F::ORDER_U64 as u128; - A::from_u64(digit) + F::from_u64(digit) }); - poseidon_compress::(perm, &input) + poseidon_compress::(perm, &input) } /// Poseidon Sponge Hash Function @@ -359,10 +345,7 @@ impl< NUM_CHUNKS as u32, HASH_LEN as u32, ]; - let capacity_value = - poseidon_safe_domain_separator::( - &perm, &lengths, - ); + let capacity_value = poseidon_safe_domain_separator::(&perm, &lengths); FieldArray(poseidon_sponge::( &perm, &capacity_value, @@ -506,11 +489,8 @@ impl< NUM_CHUNKS as u32, HASH_LEN as u32, ]; - let capacity_val = - poseidon_safe_domain_separator::( - &sponge_perm, - &lengths, - ); + let capacity_val: [PackedF; CAPACITY] = + poseidon_safe_domain_separator::(&sponge_perm, &lengths).map(PackedF::from); // PARALLEL SIMD PROCESSING //