Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 8 additions & 48 deletions benchmark/zig_benchmark/src/cross_lang_zig_tool.zig
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,12 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime:

var scheme: *hash_zig.GeneralizedXMSSSignatureScheme = undefined;
const keypair: hash_zig.GeneralizedXMSSSignatureScheme.KeyGenResult = blk: {
// For 2^8 lifetime, always regenerate from seed to avoid epoch configuration issues
// The keygen -> sign flow for 2^8 can have mismatched active epochs in the SSZ file
const skip_ssz_for_2_8 = (lifetime == .lifetime_2_8);

// Try to load SSZ secret key first if use_ssz is true and file exists
if (use_ssz) {
if (use_ssz and !skip_ssz_for_2_8) {
if (std.fs.cwd().readFileAlloc(allocator, "tmp/zig_sk.ssz", std.math.maxInt(usize))) |sk_ssz| {
defer allocator.free(sk_ssz);

Expand Down Expand Up @@ -327,7 +331,6 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime:
};
defer seed_file.close();

// Read seed hex string
var buf: [64]u8 = undefined;
const read_len = try seed_file.readAll(&buf);
const hex_slice = buf[0..read_len];
Expand Down Expand Up @@ -365,18 +368,13 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime:

const sk_data = try hash_zig.serialization.deserializeSecretKeyData(allocator, sk_json);

// Use the original seed (not PRF key) to ensure RNG state matches original keygen
// The PRF key was generated from the seed, so we need to start from the seed
// and consume RNG state to match where we were after generating parameter and PRF key
const seed_file = std.fs.cwd().openFile("tmp/zig_seed.hex", .{}) catch {
// If seed file is missing, fall back to using PRF key as seed (may not match exactly)
scheme = try hash_zig.GeneralizedXMSSSignatureScheme.initWithSeed(allocator, lifetime, sk_data.prf_key);
const kp = try scheme.keyGenWithParameter(sk_data.activation_epoch, sk_data.num_active_epochs, sk_data.parameter, sk_data.prf_key, false);
break :blk kp;
};
defer seed_file.close();

// Read seed hex string
var seed_buf: [64]u8 = undefined;
const seed_read_len = try seed_file.readAll(&seed_buf);
const seed_hex_slice = seed_buf[0..seed_read_len];
Expand All @@ -387,47 +385,13 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime:
}
_ = try std.fmt.hexToBytes(&seed, seed_hex_slice);

// Initialize with original seed to match RNG state from keygen
scheme = try hash_zig.GeneralizedXMSSSignatureScheme.initWithSeed(allocator, lifetime, seed);

// CRITICAL: We need to match the RNG state exactly as it was when keyGenWithParameter
// was called from keyGen(). In keyGen(), the flow is:
// 1. generateRandomParameter() - peeks 20 bytes (doesn't consume)
// 2. generateRandomPRFKey() - consumes 32 bytes
// 3. keyGenWithParameter() - consumes another 32 bytes (to match state after step 2)
//
// But wait - that's wrong! When keyGenWithParameter is called from keyGen(), the RNG
// state is already after consuming 32 bytes. So keyGenWithParameter shouldn't consume
// another 32 bytes when called from keyGen(). But it does, which means it's consuming
// 64 bytes total when called from keyGen().
//
// Actually, I think the issue is that keyGenWithParameter is designed to be called
// directly (not from keyGen()), so it consumes 32 bytes to match the state after
// parameter/PRF key generation. But when called from keyGen(), this causes double
// consumption.
//
// For now, let's NOT consume here, because keyGenWithParameter will consume 32 bytes
// internally. But we need to account for the peek (20 bytes) and PRF key (32 bytes).
// Actually, the peek doesn't consume, so we just need to consume 32 bytes for the PRF key.
// But keyGenWithParameter already does that, so we shouldn't consume here.
//
// Wait, let me re-read the code. keyGenWithParameter consumes 32 bytes to match the
// state AFTER parameter/PRF key generation. So when we call it directly, we need to
// have consumed 32 bytes already. But we're starting fresh, so we need to consume
// 32 bytes to get to the state after PRF key generation.
// CRITICAL: Simulate the exact RNG consumption from keyGen():
// 1. generateRandomParameter() - peeks 20 bytes (doesn't consume RNG offset)
// 2. generateRandomPRFKey() - consumes 32 bytes (advances RNG offset)
//
// Even though peek doesn't consume, we should call the actual function to ensure
// the RNG state is in the exact same condition. The peek reads from the current
// offset without advancing it, but we want to ensure we're reading from the same
// position in the RNG stream.
_ = try scheme.generateRandomParameter(); // Peek at 20 bytes (doesn't consume)
// Simulate RNG consumption from keyGen: peek parameter, consume PRF key
_ = try scheme.generateRandomParameter();
var dummy_prf_key: [32]u8 = undefined;
scheme.rng.fill(&dummy_prf_key); // Consume 32 bytes to match generateRandomPRFKey()
scheme.rng.fill(&dummy_prf_key);

// We've already consumed 32 bytes to match PRF key generation, so pass true
const kp = try scheme.keyGenWithParameter(sk_data.activation_epoch, sk_data.num_active_epochs, sk_data.parameter, sk_data.prf_key, true);
break :blk kp;
};
Expand Down Expand Up @@ -462,7 +426,6 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime:
}

if (use_ssz) {
// IMPORTANT: Also update the public key SSZ to match the regenerated keypair.
const pk_bytes = try keypair.public_key.toBytes(allocator);
defer allocator.free(pk_bytes);
var pk_file = try std.fs.cwd().createFile("tmp/zig_pk.ssz", .{});
Expand All @@ -478,9 +441,6 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime:
try sig_file.writeAll(sig_bytes);
std.debug.print("✅ Signature saved to tmp/zig_sig.ssz ({} bytes)\n", .{sig_bytes.len});
} else {
// IMPORTANT: Also update the public key JSON to match the regenerated keypair.
// This ensures that verification (in both Zig and Rust) uses a public key that
// is consistent with the trees/roots used during signing.
const pk_json = try hash_zig.serialization.serializePublicKey(allocator, &keypair.public_key);
defer allocator.free(pk_json);
var pk_file = try std.fs.cwd().createFile("tmp/zig_pk.json", .{});
Expand Down
12 changes: 6 additions & 6 deletions benchmark/zig_benchmark/src/remote_hash_tool.zig
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ pub fn writeSignatureBincode(path: []const u8, signature: *const hash_zig.Genera
}

// Write rho (7 u32 values in CANONICAL form, no length prefix for fixed array)
// CRITICAL: Rust's bincode serializes field elements in CANONICAL form (matching path and hashes)
// Rust's bincode serializes field elements in CANONICAL form
// This must match Rust's FieldArray serialization which uses as_canonical_u32()
const rho = signature.getRho();
if (rand_len > rho.len) return BincodeError.InvalidRandLength;
Expand Down Expand Up @@ -217,15 +217,15 @@ pub fn readSignatureBincode(path: []const u8, allocator: std.mem.Allocator, rand

// Read path nodes (each has: HASH_LEN u32 values in CANONICAL form, NO length prefix for fixed arrays)
// Rust's bincode serializes Vec<FieldArray<N>> as: Vec length + elements directly (no per-array length)
// CRITICAL: Rust writes FieldArray<HASH_LEN> which serializes exactly HASH_LEN elements using as_canonical_u32()
// Rust writes FieldArray<HASH_LEN> which serializes exactly HASH_LEN elements
// For lifetime 2^8/2^32: HASH_LEN=8, Rust writes 8 elements
// For lifetime 2^18: HASH_LEN=7, Rust writes 7 elements
// We must read exactly hash_len elements to match Rust's serialization
var path_nodes = try allocator.alloc([8]FieldElement, path_len);
errdefer allocator.free(path_nodes);
for (0..path_len) |i| {
// Read array elements in canonical form (fixed-size array, no length prefix)
// CRITICAL: Read exactly hash_len elements (matching Rust's FieldArray<HASH_LEN>)
// Read exactly hash_len elements (matching Rust's FieldArray<HASH_LEN>)
for (0..hash_len) |j| {
const canonical = try reader.readInt(u32, .little);
path_nodes[i][j] = FieldElement.fromCanonical(canonical);
Expand All @@ -241,7 +241,7 @@ pub fn readSignatureBincode(path: []const u8, allocator: std.mem.Allocator, rand
allocator.free(path_nodes);

// Read rho (rand_len u32 values in CANONICAL form, no length prefix for fixed array)
// CRITICAL: Rust's bincode serializes field elements in CANONICAL form (matching path and hashes)
// Rust's bincode serializes field elements in CANONICAL form
// This must match Rust's FieldArray serialization which uses as_canonical_u32()
// For lifetime 2^8/2^32: rand_len=7, for lifetime 2^18: rand_len=6
if (rand_len > 7) {
Expand All @@ -268,15 +268,15 @@ pub fn readSignatureBincode(path: []const u8, allocator: std.mem.Allocator, rand

// Read hashes (each has: HASH_LEN u32 values in CANONICAL form, NO length prefix for fixed arrays)
// Rust's bincode serializes Vec<FieldArray<N>> as: Vec length + elements directly (no per-array length)
// CRITICAL: Rust writes FieldArray<HASH_LEN> which serializes exactly HASH_LEN elements using as_canonical_u32()
// Rust writes FieldArray<HASH_LEN> which serializes exactly HASH_LEN elements
// For lifetime 2^8/2^32: HASH_LEN=8, Rust writes 8 elements
// For lifetime 2^18: HASH_LEN=7, Rust writes 7 elements
// We must read exactly hash_len elements to match Rust's serialization
var hashes_tmp = try allocator.alloc([8]FieldElement, hashes_len);
errdefer allocator.free(hashes_tmp);
for (0..hashes_len) |i| {
// Read array elements in canonical form (fixed-size array, no length prefix)
// CRITICAL: Read exactly hash_len elements (matching Rust's FieldArray<HASH_LEN>)
// Read exactly hash_len elements (matching Rust's FieldArray<HASH_LEN>)
for (0..hash_len) |j| {
const canonical = try reader.readInt(u32, .little);
hashes_tmp[i][j] = FieldElement.fromCanonical(canonical);
Expand Down
2 changes: 1 addition & 1 deletion investigations/test/rust_compatibility_test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ const std = @import("std");
const log = @import("hash-zig").utils.log;
const hash_zig = @import("hash-zig");

// CRITICAL: Comprehensive Rust compatibility test
// Comprehensive Rust compatibility test
test "rust compatibility: GeneralizedXMSS validation (CRITICAL)" {
const allocator = std.testing.allocator;

Expand Down
4 changes: 2 additions & 2 deletions src/hash/poseidon2_hash_simd.zig
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub const Poseidon2SIMD = struct {
/// Input: packed_input is [element][lane] format (vertical packing)
/// Output: packed_output is [element][lane] format
///
/// CRITICAL OPTIMIZATION: Writes to pre-allocated output buffer instead of allocating.
/// Writes to pre-allocated output buffer instead of allocating.
/// This matches Rust's approach of returning fixed-size stack arrays, eliminating
/// 114,688 allocations in chain walking (64 chains × 7 steps × 256 batches).
pub fn compress16SIMD(
Expand Down Expand Up @@ -187,7 +187,7 @@ pub const Poseidon2SIMD = struct {
/// Input: packed_input is [element][lane] format (vertical packing)
/// Output: packed_output is [element][lane] format
///
/// CRITICAL OPTIMIZATION: Writes to pre-allocated output buffer instead of allocating.
/// Writes to pre-allocated output buffer instead of allocating.
/// This matches Rust's approach of returning fixed-size stack arrays.
pub fn compress24SIMD(
self: *Poseidon2SIMD,
Expand Down
12 changes: 6 additions & 6 deletions src/poseidon2/poseidon2.zig
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ pub fn sbox(x: F) F {
}

// Apply MDS matrix to 4 elements (exact Plonky3 logic from apply_mat4)
// CRITICAL FIX: Rust uses .clone() to preserve original values, so we must store them first
// Rust uses .clone() to preserve original values, so we must store them first
// Matrix: [ 2 3 1 1 ]
// [ 1 2 3 1 ]
// [ 1 1 2 3 ]
Expand Down Expand Up @@ -520,7 +520,7 @@ pub fn poseidon2_24_plonky3(state: []F) void {
// FIX: Rust applies MDS light BEFORE the first external round for 24-width!
// This matches Rust's external_initial_permute_state which calls mds_light_permutation first.
pub fn poseidon2_24_plonky3_with_mds_light(state: []F, apply_mds_light: bool) void {
// CRITICAL FIX: Rust applies MDS light BEFORE the first external round for 24-width!
// Rust applies MDS light BEFORE the first external round for 24-width
// This matches Rust's external_initial_permute_state behavior.
// For 24-width, we should ALWAYS apply MDS light first (matching Rust).
// The apply_mds_light parameter is kept for backward compatibility but should be true for 24-width.
Expand All @@ -529,7 +529,7 @@ pub fn poseidon2_24_plonky3_with_mds_light(state: []F, apply_mds_light: bool) vo
}

// Initial external rounds (4 rounds)
// CRITICAL: Rust's external_terminal_permute_state applies MDS light INSIDE each round
// Rust's external_terminal_permute_state applies MDS light INSIDE each round
// So each external round does: add_rc_and_sbox, then mds_light_permutation
// NOT: add_rc_and_sbox, then full MDS matrix
for (0..4) |i| {
Expand All @@ -541,7 +541,7 @@ pub fn poseidon2_24_plonky3_with_mds_light(state: []F, apply_mds_light: bool) vo
for (state) |*elem| {
elem.* = sbox(elem.*);
}
// CRITICAL FIX: Apply MDS light (not full MDS matrix) - matching Rust's external_terminal_permute_state
// Apply MDS light (not full MDS matrix) - matching Rust's external_terminal_permute_state
mds_light_permutation_24(state);
}

Expand All @@ -551,7 +551,7 @@ pub fn poseidon2_24_plonky3_with_mds_light(state: []F, apply_mds_light: bool) vo
}

// Final external rounds (4 rounds)
// CRITICAL: Rust's external_terminal_permute_state applies MDS light INSIDE each round
// Rust's external_terminal_permute_state applies MDS light INSIDE each round
// So each external round does: add_rc_and_sbox, then mds_light_permutation
// NOT: add_rc_and_sbox, then full MDS matrix
for (0..4) |i| {
Expand All @@ -563,7 +563,7 @@ pub fn poseidon2_24_plonky3_with_mds_light(state: []F, apply_mds_light: bool) vo
for (state) |*elem| {
elem.* = sbox(elem.*);
}
// CRITICAL FIX: Apply MDS light (not full MDS matrix) - matching Rust's external_terminal_permute_state
// Apply MDS light (not full MDS matrix) - matching Rust's external_terminal_permute_state
mds_light_permutation_24(state);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/prf/shake_prf_to_field.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ const crypto = std.crypto;
const plonky3_field = @import("../poseidon2/plonky3_field.zig");

// Constants matching Rust implementation
// CRITICAL: Rust uses 16 bytes per FE (reads as u128), not 8!
// Rust uses 16 bytes per FE (reads as u128), not 8
const PRF_BYTES_PER_FE: usize = 16;
const KEY_LENGTH: usize = 32; // 32 bytes
const MESSAGE_LENGTH: usize = 32; // From Rust hash-sig
Expand Down
Loading
Loading