diff --git a/processor/src/decoder/aux_trace/mod.rs b/processor/src/decoder/aux_trace/mod.rs index f98c4ba5ae..ce5ef70e68 100644 --- a/processor/src/decoder/aux_trace/mod.rs +++ b/processor/src/decoder/aux_trace/mod.rs @@ -40,17 +40,17 @@ impl AuxTraceBuilder { let p2 = block_hash_column_builder.build_aux_column(main_trace, rand_elements); let p3 = op_group_table_column_builder.build_aux_column(main_trace, rand_elements); - debug_assert_eq!( + assert_eq!( *p1.last().unwrap(), E::ONE, "block stack table is not empty at the end of program execution" ); - debug_assert_eq!( + assert_eq!( *p2.last().unwrap(), E::ONE, "block hash table is not empty at the end of program execution" ); - debug_assert_eq!( + assert_eq!( *p3.last().unwrap(), E::ONE, "op group table is not empty at the end of program execution" diff --git a/processor/src/fast/trace_state.rs b/processor/src/fast/trace_state.rs index 719c94c6d4..4e425c830d 100644 --- a/processor/src/fast/trace_state.rs +++ b/processor/src/fast/trace_state.rs @@ -1,9 +1,10 @@ -use alloc::{collections::VecDeque, sync::Arc, vec::Vec}; - -use miden_air::trace::{ - RowIndex, - chiplets::hasher::{HasherState, STATE_WIDTH}, +use alloc::{ + collections::{BTreeMap, VecDeque, btree_map::Entry}, + sync::Arc, + vec::Vec, }; + +use miden_air::trace::chiplets::hasher::{HasherState, STATE_WIDTH}; use miden_core::{ Felt, ONE, Word, ZERO, crypto::merkle::MerklePath, @@ -13,7 +14,7 @@ use miden_core::{ }; use crate::{ - AdviceError, ContextId, ErrorContext, ExecutionError, + AdviceError, ContextId, ErrorContext, ExecutionError, RowIndex, chiplets::CircuitEvaluation, continuation_stack::ContinuationStack, fast::FastProcessor, @@ -978,7 +979,7 @@ impl HasherInterface for HasherResponseReplay { pub enum HasherOp { Permute([Felt; STATE_WIDTH]), HashControlBlock((Word, Word, Felt, Word)), - HashBasicBlock((Vec<OpBatch>, Word)), + HashBasicBlock(Word), // Only stores the digest; op_batches are looked up from op_batches_map BuildMerkleRoot((Word, MerklePath, 
Felt)), UpdateMerkleRoot((Word, Word, MerklePath, Felt)), } @@ -988,9 +989,22 @@ pub enum HasherOp { /// /// The hasher requests are recorded during fast processor execution and then replayed during hasher /// chiplet trace generation. -#[derive(Debug, Default)] +#[derive(Default)] pub struct HasherRequestReplay { hasher_ops: VecDeque<HasherOp>, + /// Deduplication map for basic block operation batches. + /// Maps from basic block digest to its operation batches, avoiding duplication when the same + /// basic block is entered multiple times. + op_batches_map: BTreeMap<Word, Vec<OpBatch>>, +} + +impl core::fmt::Debug for HasherRequestReplay { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("HasherRequestReplay") + .field("hasher_ops", &self.hasher_ops) + // Exclude op_batches_map from Debug output to maintain snapshot compatibility + .finish() + } } impl HasherRequestReplay { @@ -1012,8 +1026,40 @@ impl HasherRequestReplay { } /// Records a `Hasher::hash_basic_block()` request. + /// + /// Deduplicates operation batches by storing them in a map keyed by the basic block digest. + /// If the same basic block is entered multiple times, only one copy of the operation batches + /// is stored. 
pub fn record_hash_basic_block(&mut self, op_batches: Vec<OpBatch>, expected_hash: Word) { - self.hasher_ops.push_back(HasherOp::HashBasicBlock((op_batches, expected_hash))); + // Only store the op_batches if we haven't seen this digest before + // If the digest already exists, we verify that the op_batches match (they should, since + // the digest is computed from the op_batches) + match self.op_batches_map.entry(expected_hash) { + Entry::Vacant(entry) => { + entry.insert(op_batches); + }, + Entry::Occupied(entry) => { + // Digest already exists, skip storing (deduplication) + debug_assert_eq!( + entry.get(), + &op_batches, + "Same digest should always map to same op_batches" + ); + }, + } + // Store only the digest in the operation record + self.hasher_ops.push_back(HasherOp::HashBasicBlock(expected_hash)); + } + + /// Returns a reference to the operation batches map for looking up batches during replay. + pub fn op_batches_map(&self) -> &BTreeMap<Word, Vec<OpBatch>> { + &self.op_batches_map + } + + /// Consumes `HasherRequestReplay` and returns both the hasher operations and the operation + /// batches map. This allows accessing the map during iteration without cloning. + pub fn into_parts(self) -> (VecDeque<HasherOp>, BTreeMap<Word, Vec<OpBatch>>) { + (self.hasher_ops, self.op_batches_map) } /// Records a `Hasher::build_merkle_root()` request. @@ -1034,15 +1080,6 @@ impl HasherRequestReplay { } } -impl IntoIterator for HasherRequestReplay { - type Item = HasherOp; - type IntoIter = <VecDeque<HasherOp> as IntoIterator>::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.hasher_ops.into_iter() - } -} - // STACK OVERFLOW REPLAY // ================================================================================================ @@ -1175,3 +1212,40 @@ pub enum NodeExecutionState { /// This is used when completing execution of a control flow construct. 
End(MastNodeId), } + +#[cfg(test)] +mod tests { + use miden_core::{Operation, mast::BasicBlockNodeBuilder}; + + use super::*; + + #[test] + fn test_hasher_request_replay_deduplicates_basic_blocks() { + let mut replay = HasherRequestReplay::default(); + let digest = Word::new([Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)]); + + // Create a simple basic block with one operation to get op_batches + let basic_block = + BasicBlockNodeBuilder::new(vec![Operation::Add], Vec::new()).build().unwrap(); + let op_batches = basic_block.op_batches().to_vec(); + + // Record the same digest three times + replay.record_hash_basic_block(op_batches.clone(), digest); + replay.record_hash_basic_block(op_batches.clone(), digest); + replay.record_hash_basic_block(op_batches.clone(), digest); + + // Verify that the map has only one entry (deduplication worked) + assert_eq!(replay.op_batches_map().len(), 1); + + // Verify that hasher_ops has three entries (one for each record call) + let (hasher_ops, _) = replay.into_parts(); + assert_eq!(hasher_ops.len(), 3); + + // Verify all three entries are HashBasicBlock with the same digest + let mut iter = hasher_ops.into_iter(); + assert!(matches!(iter.next(), Some(HasherOp::HashBasicBlock(h)) if h == digest)); + assert!(matches!(iter.next(), Some(HasherOp::HashBasicBlock(h)) if h == digest)); + assert!(matches!(iter.next(), Some(HasherOp::HashBasicBlock(h)) if h == digest)); + assert!(iter.next().is_none()); + } +} diff --git a/processor/src/parallel/mod.rs b/processor/src/parallel/mod.rs index 16b3d1c1f8..4978833014 100644 --- a/processor/src/parallel/mod.rs +++ b/processor/src/parallel/mod.rs @@ -5,27 +5,27 @@ use miden_air::{ Felt, trace::{ CLK_COL_IDX, CTX_COL_IDX, DECODER_TRACE_OFFSET, DECODER_TRACE_WIDTH, FN_HASH_RANGE, - MIN_TRACE_LEN, MainTrace, PADDED_TRACE_WIDTH, RowIndex, STACK_TRACE_OFFSET, - STACK_TRACE_WIDTH, SYS_TRACE_WIDTH, TRACE_WIDTH, + MIN_TRACE_LEN, PADDED_TRACE_WIDTH, STACK_TRACE_OFFSET, STACK_TRACE_WIDTH, 
SYS_TRACE_WIDTH, + TRACE_WIDTH, decoder::{ ADDR_COL_IDX, GROUP_COUNT_COL_IDX, HASHER_STATE_OFFSET, IN_SPAN_COL_IDX, NUM_HASHER_COLUMNS, NUM_OP_BATCH_FLAGS, NUM_OP_BITS, OP_BATCH_FLAGS_OFFSET, OP_BITS_EXTRA_COLS_OFFSET, OP_BITS_OFFSET, OP_INDEX_COL_IDX, }, + MainTrace, stack::{B0_COL_IDX, B1_COL_IDX, H0_COL_IDX, STACK_TOP_OFFSET}, }, }; use miden_core::{ - Kernel, ONE, Operation, Word, ZERO, - stack::MIN_STACK_DEPTH, - utils::{ColMatrix, uninit_vector}, + Kernel, ONE, Operation, Word, ZERO, stack::MIN_STACK_DEPTH, utils::uninit_vector, }; use rayon::prelude::*; -use tracing::instrument; use crate::{ - ChipletsLengths, ContextId, ExecutionTrace, TraceLenSummary, + ChipletsLengths, ContextId, ExecutionTrace, RowIndex, + TraceLenSummary, chiplets::Chiplets, + crypto::RpoRandomCoin, decoder::AuxTraceBuilder as DecoderAuxTraceBuilder, fast::{ ExecutionOutput, @@ -39,11 +39,16 @@ use crate::{ range::RangeChecker, stack::AuxTraceBuilder as StackAuxTraceBuilder, trace::AuxTraceBuilders, - utils::invert_column_allow_zeros, + utils::{ColMatrix, invert_column_allow_zeros}, }; pub const CORE_TRACE_WIDTH: usize = SYS_TRACE_WIDTH + DECODER_TRACE_WIDTH + STACK_TRACE_WIDTH; +/// Number of random rows to inject at the end of the trace. +/// This matches the RPO permutation rate (8 elements). +const NUM_RAND_ROWS: usize = 8; + + mod core_trace_fragment; #[cfg(test)] @@ -53,7 +58,6 @@ mod tests; // ================================================================================================ /// Builds the main trace from the provided trace states in parallel. 
-#[instrument(name = "build_trace", skip_all)] pub fn build_trace( execution_output: ExecutionOutput, trace_generation_context: TraceGenerationContext, @@ -104,7 +108,7 @@ pub fn build_trace( let trace_len_summary = TraceLenSummary::new(core_trace_len, range_table_len, ChipletsLengths::new(&chiplets)); - // Compute the final main trace length + // Compute the final main trace length, after accounting for random rows let main_trace_len = compute_main_trace_length(core_trace_len, range_table_len, chiplets.trace_len()); @@ -112,7 +116,12 @@ pub fn build_trace( || pad_trace_columns(&mut core_trace_columns, main_trace_len), || { rayon::join( - || range_checker.into_trace_with_table(range_table_len, main_trace_len), + || { + range_checker.into_trace_with_table( + range_table_len, + main_trace_len, + ) + }, || chiplets.into_trace(main_trace_len, final_pc_transcript.state()), ) }, @@ -122,13 +131,23 @@ let padding_columns = vec![vec![ZERO; main_trace_len]; PADDED_TRACE_WIDTH - TRACE_WIDTH]; // Chain all trace columns together - let trace_columns: Vec<Vec<Felt>> = core_trace_columns + let mut trace_columns: Vec<Vec<Felt>> = core_trace_columns .into_iter() .chain(range_checker_trace.trace) .chain(chiplets_trace.trace) .chain(padding_columns) .collect(); + // Initialize random element generator using program hash + let mut rng = RpoRandomCoin::new(program_hash); + + // Inject random values into the last NUM_RAND_ROWS rows for all columns + for i in main_trace_len - NUM_RAND_ROWS..main_trace_len { + for column in trace_columns.iter_mut() { + column[i] = rng.draw(); + } + } + // Create the MainTrace let main_trace = { let last_program_row = RowIndex::from((core_trace_len as u32).saturating_sub(1)); @@ -165,8 +184,9 @@ fn compute_main_trace_length( // Get the trace length required to hold all execution trace steps let max_len = range_table_len.max(core_trace_len).max(chiplets_trace_len); - // Pad the trace length to the next power of two - let trace_len = 
max_len.next_power_of_two(); + // Pad the trace length to the next power of two and ensure that there is space for random + // rows + let trace_len = (max_len + NUM_RAND_ROWS).next_power_of_two(); core::cmp::max(trace_len, MIN_TRACE_LEN) } @@ -222,9 +242,13 @@ fn generate_core_trace_columns( // Run batch inversion on stack's H0 helper column, processing each fragment in parallel. // This must be done after fixup_stack_and_system_rows since that function overwrites the first // row of each fragment with non-inverted values. + // Note: We need to handle zeros properly, so we use a helper function that processes chunks + // and handles zeros by leaving them unchanged. { let h0_column = &mut core_trace_columns[STACK_TRACE_OFFSET + H0_COL_IDX]; - h0_column.par_chunks_mut(fragment_size).for_each(invert_column_allow_zeros); + h0_column.par_chunks_mut(fragment_size).for_each(|chunk| { + invert_column_allow_zeros(chunk); + }); } // Truncate the core trace columns. After this point, there is no more uninitialized memory. 
@@ -422,7 +446,7 @@ fn initialize_range_checker( // Add all u32 range checks recorded during execution for (clk, values) in range_checker_replay.into_iter() { - range_checker.add_range_checks(clk, &values); + range_checker.add_range_checks(clk, &values[..]); } // Add all memory-related range checks @@ -442,8 +466,11 @@ fn initialize_chiplets( ) -> Chiplets { let mut chiplets = Chiplets::new(kernel); + // Extract both the hasher operations and the op_batches_map to avoid cloning + let (hasher_ops, op_batches_map) = hasher_for_chiplet.into_parts(); + // populate hasher chiplet - for hasher_op in hasher_for_chiplet.into_iter() { + for hasher_op in hasher_ops.into_iter() { match hasher_op { HasherOp::Permute(input_state) => { let _ = chiplets.hasher.permute(input_state); @@ -451,14 +478,24 @@ fn initialize_chiplets( HasherOp::HashControlBlock((h1, h2, domain, expected_hash)) => { let _ = chiplets.hasher.hash_control_block(h1, h2, domain, expected_hash); }, - HasherOp::HashBasicBlock((op_batches, expected_hash)) => { - let _ = chiplets.hasher.hash_basic_block(&op_batches, expected_hash); + HasherOp::HashBasicBlock(expected_hash) => { + // Look up the operation batches from the deduplication map + let op_batches = op_batches_map.get(&expected_hash).unwrap_or_else(|| { + panic!( + "op_batches should exist in map for recorded digest {:?}. 
Map contains {} entries with keys: {:?}", + expected_hash, + op_batches_map.len(), + op_batches_map.keys().collect::<Vec<_>>() + ); + }); + // Convert &Vec<OpBatch> to &[OpBatch] for hash_basic_block + let _ = chiplets.hasher.hash_basic_block(op_batches.as_slice(), expected_hash); }, HasherOp::BuildMerkleRoot((value, path, index)) => { let _ = chiplets.hasher.build_merkle_root(value, &path, index); }, HasherOp::UpdateMerkleRoot((old_value, new_value, path, index)) => { - chiplets.hasher.update_merkle_root(old_value, new_value, &path, index); + let _ = chiplets.hasher.update_merkle_root(old_value, new_value, &path, index); }, } } @@ -522,19 +559,19 @@ .expect("memory read element failed when populating chiplet"); }, MemoryAccess::WriteElement(addr, element, ctx, clk) => { - chiplets + let _ = chiplets .memory .write(ctx, addr, clk, element, &()) .expect("memory write element failed when populating chiplet"); }, MemoryAccess::ReadWord(addr, ctx, clk) => { - chiplets + let _ = chiplets .memory .read_word(ctx, addr, clk, &()) .expect("memory read word failed when populating chiplet"); }, MemoryAccess::WriteWord(addr, word, ctx, clk) => { - chiplets + let _ = chiplets .memory .write_word(ctx, addr, clk, word, &()) .expect("memory write word failed when populating chiplet"); @@ -567,7 +604,7 @@ // populate kernel ROM for proc_hash in kernel_replay.into_iter() { - chiplets + let _ = chiplets .kernel_rom .access_proc(proc_hash, &()) .expect("kernel proc access failed when populating chiplet"); @@ -578,7 +615,7 @@ fn pad_trace_columns(trace_columns: &mut [Vec<Felt>], main_trace_len: usize) { let total_program_rows = trace_columns[0].len(); - assert!(total_program_rows <= main_trace_len); + assert!(total_program_rows + NUM_RAND_ROWS - 1 <= main_trace_len); let num_padding_rows = main_trace_len - total_program_rows;