diff --git a/core/src/stack/inputs.rs b/core/src/stack/inputs.rs index f35c771ad1..70316375f6 100644 --- a/core/src/stack/inputs.rs +++ b/core/src/stack/inputs.rs @@ -16,7 +16,7 @@ use crate::{ /// /// The values in the struct are stored in the "stack order" - i.e., the last input is at the top /// of the stack (in position 0). -#[derive(Clone, Debug, Default)] +#[derive(Clone, Copy, Debug, Default)] pub struct StackInputs { elements: [Felt; MIN_STACK_DEPTH], } diff --git a/core/src/stack/outputs.rs b/core/src/stack/outputs.rs index 603dd2a261..886acc9123 100644 --- a/core/src/stack/outputs.rs +++ b/core/src/stack/outputs.rs @@ -19,7 +19,7 @@ use crate::utils::{ByteReader, Deserializable, DeserializationError, range}; /// `stack` is expected to be ordered as if the elements were popped off the stack one by one. /// Thus, the value at the top of the stack is expected to be in the first position, and the order /// of the rest of the output elements will also match the order on the stack. -#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct StackOutputs { elements: [Felt; MIN_STACK_DEPTH], } diff --git a/crates/lib/core/tests/stark/mod.rs b/crates/lib/core/tests/stark/mod.rs index 3b693da7f6..2db2304dc2 100644 --- a/crates/lib/core/tests/stark/mod.rs +++ b/crates/lib/core/tests/stark/mod.rs @@ -88,7 +88,7 @@ pub fn generate_recursive_verifier_data( ProvingOptions::new(27, 8, 0, FieldExtension::Quadratic, 4, 127, HashFunction::Rpo256); let (stack_outputs, proof) = - prove(&program, stack_inputs.clone(), advice_inputs, &mut host, options).unwrap(); + prove(&program, stack_inputs, advice_inputs, &mut host, options).unwrap(); let program_info = ProgramInfo::from(program); diff --git a/crates/test-utils/src/lib.rs b/crates/test-utils/src/lib.rs index 3fa3396f1d..e467f85079 100644 --- a/crates/test-utils/src/lib.rs +++ b/crates/test-utils/src/lib.rs @@ -276,7 +276,7 @@ impl Test { let mut host = host.with_source_manager(self.source_manager.clone()); // execute the test - let stack_inputs: Vec = self.stack_inputs.clone().into_iter().rev().collect(); + let stack_inputs: Vec = self.stack_inputs.into_iter().rev().collect(); let processor = if self.in_debug_mode { FastProcessor::new_debug(&stack_inputs, self.advice_inputs.clone()) } else { @@ -385,7 +385,7 @@ impl Test { let mut host = host.with_source_manager(self.source_manager.clone()); let fast_stack_result = { - let stack_inputs: Vec = self.stack_inputs.clone().into_iter().rev().collect(); + let stack_inputs: Vec = self.stack_inputs.into_iter().rev().collect(); let advice_inputs: AdviceInputs = self.advice_inputs.clone(); let fast_processor = FastProcessor::new_with_advice_inputs(&stack_inputs, advice_inputs); @@ -417,7 +417,7 @@ impl Test { let mut host = host.with_source_manager(self.source_manager.clone()); let processor = FastProcessor::new_debug( - &self.stack_inputs.clone().into_iter().rev().collect::>(), + &self.stack_inputs.into_iter().rev().collect::>(), self.advice_inputs.clone(), ); @@ -438,7 +438,7 @@ impl Test { .with_debug_handler(debug_handler); let processor = FastProcessor::new_debug( - &self.stack_inputs.clone().into_iter().rev().collect::>(), + &self.stack_inputs.into_iter().rev().collect::>(), self.advice_inputs.clone(), ); @@ -466,7 +466,7 @@ impl Test { let stack_inputs = StackInputs::try_from_ints(pub_inputs).unwrap(); let (mut stack_outputs, proof) = miden_prover::prove_sync( &program, - stack_inputs.clone(), + stack_inputs, self.advice_inputs.clone(), &mut host, ProvingOptions::default(), @@ -573,14 +573,14 @@ impl Test { let mut host = host.with_source_manager(self.source_manager.clone()); let fast_result_by_step = { - let stack_inputs: Vec = self.stack_inputs.clone().into_iter().rev().collect(); + let stack_inputs: Vec = self.stack_inputs.into_iter().rev().collect(); let advice_inputs: AdviceInputs = self.advice_inputs.clone(); let fast_process = FastProcessor::new_with_advice_inputs(&stack_inputs, advice_inputs); fast_process.execute_by_step_sync(&program, &mut host) }; compare_results( - fast_result.as_ref().map(|(output, _)| output.stack.clone()), + fast_result.as_ref().map(|(output, _)| output.stack), &fast_result_by_step, "fast processor", "fast processor by step", diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 8b81ed77fe..065abae735 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -1,34 +1,49 @@ #![no_std] +#[cfg_attr(all(feature = "metal", target_arch = "aarch64", target_os = "macos"), macro_use)] extern crate alloc; #[cfg(feature = "std")] extern crate std; -use alloc::string::ToString; +use core::marker::PhantomData; -use miden_processor::{Program, fast::FastProcessor, math::Felt, parallel::build_trace}; +use miden_air::{AuxRandElements, PartitionOptions, ProcessorAir, PublicInputs}; +#[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))] +use miden_gpu::HashFn; +use miden_processor::{ + ExecutionTrace, Program, + crypto::{ + Blake3_192, Blake3_256, ElementHasher, Poseidon2, RandomCoin, Rpo256, RpoRandomCoin, + Rpx256, RpxRandomCoin, WinterRandomCoin, + }, + fast::{DEFAULT_CORE_TRACE_FRAGMENT_SIZE, FastProcessor}, + math::{Felt, FieldElement}, + parallel::build_trace, +}; use tracing::instrument; - -// Trace conversion utilities -mod trace_adapter; +use winter_maybe_async::{maybe_async, maybe_await}; +use winter_prover::{ + CompositionPoly, CompositionPolyTrace, ConstraintCompositionCoefficients, + DefaultConstraintCommitment, DefaultConstraintEvaluator, DefaultTraceLde, + ProofOptions as WinterProofOptions, Prover, StarkDomain, TraceInfo, TracePolyTable, + matrix::ColMatrix, +}; +#[cfg(feature = "std")] +use {std::time::Instant, winter_prover::Trace}; +mod gpu; // EXPORTS // ================================================================================================ pub use miden_air::{ - DEFAULT_CORE_TRACE_FRAGMENT_SIZE, DeserializationError, ExecutionProof, HashFunction, - ProcessorAir, ProvingOptions, config, -}; -pub use miden_crypto::{ - stark, - stark::{Commitments, OpenedValues, Proof}, + DeserializationError, ExecutionProof, FieldExtension, HashFunction, ProvingOptions, }; pub use miden_processor::{ - AdviceInputs, AsyncHost, BaseHost, ExecutionError, InputError, StackInputs, StackOutputs, - SyncHost, Word, crypto, math, utils, + AdviceInputs, AsyncHost, BaseHost, ExecutionError, InputError, PrecompileRequest, StackInputs, + StackOutputs, SyncHost, Word, crypto, math, utils, }; -pub use trace_adapter::{aux_trace_to_row_major, execution_trace_to_row_major}; +pub use winter_prover::{Proof, crypto::MerkleTree as MerkleTreeVC}; // PROVER // ================================================================================================ @@ -36,8 +51,6 @@ pub use trace_adapter::{aux_trace_to_row_major, execution_trace_to_row_major}; /// Executes and proves the specified `program` and returns the result together with a STARK-based /// proof of the program's execution. /// -/// This is an async function that works on all platforms including wasm32. -/// /// - `stack_inputs` specifies the initial state of the stack for the VM. /// - `host` specifies the host environment which contain non-deterministic (secret) inputs for the /// prover @@ -46,7 +59,8 @@ pub use trace_adapter::{aux_trace_to_row_major, execution_trace_to_row_major}; /// # Errors /// Returns an error if program execution or STARK proof generation fails for any reason. #[instrument("prove_program", skip_all)] -pub async fn prove( +#[maybe_async] +pub fn prove( program: &Program, stack_inputs: StackInputs, advice_inputs: AdviceInputs, @@ -54,6 +68,8 @@ pub async fn prove( options: ProvingOptions, ) -> Result<(StackOutputs, ExecutionProof), ExecutionError> { // execute the program to create an execution trace using FastProcessor + #[cfg(feature = "std")] + let now = Instant::now(); // Reverse stack inputs since FastProcessor expects them in reverse order // (first element = bottom of stack, last element = top) @@ -65,114 +81,218 @@ pub async fn prove( FastProcessor::new_with_advice_inputs(&stack_inputs_reversed, advice_inputs) }; - let (execution_output, trace_generation_context) = processor - .execute_for_trace(program, host, options.execution_options().core_trace_fragment_size()) - .await?; + let (execution_output, trace_generation_context) = + processor.execute_for_trace_sync(program, host, DEFAULT_CORE_TRACE_FRAGMENT_SIZE)?; - let trace = build_trace( + let mut trace = build_trace( execution_output, trace_generation_context, program.hash(), program.kernel().clone(), ); + #[cfg(feature = "std")] tracing::event!( tracing::Level::INFO, - "Generated execution trace of {} columns and {} steps (padded from {})", - miden_air::trace::TRACE_WIDTH, + "Generated execution trace of {} columns and {} steps ({}% padded) in {} ms", + trace.info().main_trace_width(), trace.trace_len_summary().padded_trace_len(), - trace.trace_len_summary().main_trace_len() + trace.trace_len_summary().padding_percentage(), + now.elapsed().as_millis() ); let stack_outputs = trace.stack_outputs().clone(); - let precompile_requests = trace.precompile_requests().to_vec(); let hash_fn = options.hash_fn(); - // Convert trace to row-major format - let trace_matrix = { - let _span = tracing::info_span!("execution_trace_to_row_major").entered(); - execution_trace_to_row_major(&trace) - }; + // extract precompile requests from the trace to include in the proof + let pc_requests = trace.precompile_requests().to_vec(); - // Build public values - let public_values = trace.to_public_values(); - - // Create AIR with aux trace builders - let air = ProcessorAir::with_aux_builder(trace.aux_trace_builders().clone()); - - // Generate STARK proof using unified miden-prover - let proof_bytes = match hash_fn { - HashFunction::Blake3_256 => { - let config = miden_air::config::create_blake3_256_config(); - let proof = stark::prove(&config, &air, &trace_matrix, &public_values); - serialize_proof(&proof)? + // generate STARK proof + let proof = match hash_fn { + HashFunction::Blake3_192 => { + let prover = ExecutionProver::>::new( + options, + stack_inputs, + stack_outputs, + ); + maybe_await!(prover.prove(trace)) }, - HashFunction::Keccak => { - let config = miden_air::config::create_keccak_config(); - let proof = stark::prove(&config, &air, &trace_matrix, &public_values); - serialize_proof(&proof)? + HashFunction::Blake3_256 => { + let prover = ExecutionProver::>::new( + options, + stack_inputs, + stack_outputs, + ); + maybe_await!(prover.prove(trace)) }, HashFunction::Rpo256 => { - let config = miden_air::config::create_rpo_config(); - let proof = stark::prove(&config, &air, &trace_matrix, &public_values); - serialize_proof(&proof)? - }, - HashFunction::Poseidon2 => { - let config = miden_air::config::create_poseidon2_config(); - let proof = stark::prove(&config, &air, &trace_matrix, &public_values); - serialize_proof(&proof)? + let prover = + ExecutionProver::::new(options, stack_inputs, stack_outputs); + #[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))] + let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpo256); + maybe_await!(prover.prove(trace)) }, HashFunction::Rpx256 => { - let config = miden_air::config::create_rpx_config(); - let proof = stark::prove(&config, &air, &trace_matrix, &public_values); - serialize_proof(&proof)? + let prover = + ExecutionProver::::new(options, stack_inputs, stack_outputs); + #[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))] + let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpx256); + maybe_await!(prover.prove(trace)) }, - }; + HashFunction::Poseidon2 => { + let prover = ExecutionProver::>::new( + options, + stack_inputs, + stack_outputs, + ); + maybe_await!(prover.prove(trace)) + }, + } + .map_err(ExecutionError::ProverError)?; - let proof = miden_air::ExecutionProof::new(proof_bytes, hash_fn, precompile_requests); + let proof = ExecutionProof::new(proof, hash_fn, pc_requests); Ok((stack_outputs, proof)) } -/// Synchronous wrapper for the async `prove()` function. -/// -/// This method is only available on non-wasm32 targets. On wasm32, use the -/// async `prove()` method directly since wasm32 runs in the browser's event loop. -/// -/// # Panics -/// Panics if called from within an existing Tokio runtime. Use the async `prove()` -/// method instead in async contexts. -#[cfg(not(target_arch = "wasm32"))] -#[instrument("prove_program_sync", skip_all)] -pub fn prove_sync( - program: &Program, +// PROVER +// ================================================================================================ + +struct ExecutionProver +where + H: ElementHasher, + R: RandomCoin, +{ + random_coin: PhantomData, + options: WinterProofOptions, stack_inputs: StackInputs, - advice_inputs: AdviceInputs, - host: &mut impl AsyncHost, - options: ProvingOptions, -) -> Result<(StackOutputs, ExecutionProof), ExecutionError> { - match tokio::runtime::Handle::try_current() { - Ok(_handle) => { - // We're already inside a Tokio runtime - this is not supported - // because we cannot safely create a nested runtime or move the - // non-Send host reference to another thread - panic!( - "Cannot call prove_sync from within a Tokio runtime. \ - Use the async prove() method instead." - ) - }, - Err(_) => { - // No runtime exists - create one and use it - let rt = tokio::runtime::Builder::new_current_thread().build().unwrap(); - rt.block_on(prove(program, stack_inputs, advice_inputs, host, options)) - }, + stack_outputs: StackOutputs, +} + +impl ExecutionProver +where + H: ElementHasher, + R: RandomCoin, +{ + pub fn new( + options: ProvingOptions, + stack_inputs: StackInputs, + stack_outputs: StackOutputs, + ) -> Self { + Self { + random_coin: PhantomData, + options: options.into(), + stack_inputs, + stack_outputs, + } + } + + // HELPER FUNCTIONS + // -------------------------------------------------------------------------------------------- + + /// Validates the stack inputs against the provided execution trace and returns true if valid. + fn are_inputs_valid(&self, trace: &ExecutionTrace) -> bool { + self.stack_inputs + .iter() + .zip(trace.init_stack_state().iter()) + .all(|(l, r)| l == r) + } + + /// Validates the stack outputs against the provided execution trace and returns true if valid. + fn are_outputs_valid(&self, trace: &ExecutionTrace) -> bool { + self.stack_outputs + .iter() + .zip(trace.last_stack_state().iter()) + .all(|(l, r)| l == r) } } -// HELPER FUNCTIONS -// ================================================================================================ +impl Prover for ExecutionProver +where + H: ElementHasher + Sync, + R: RandomCoin + Send, +{ + type BaseField = Felt; + type Air = ProcessorAir; + type Trace = ExecutionTrace; + type HashFn = H; + type VC = MerkleTreeVC; + type RandomCoin = R; + type TraceLde> = DefaultTraceLde; + type ConstraintEvaluator<'a, E: FieldElement> = + DefaultConstraintEvaluator<'a, ProcessorAir, E>; + type ConstraintCommitment> = + DefaultConstraintCommitment; -/// Serializes a proof to bytes, converting serialization errors to ExecutionError. -fn serialize_proof(proof: &T) -> Result, ExecutionError> { - bincode::serialize(proof).map_err(|e| ExecutionError::ProofSerializationError(e.to_string())) + fn options(&self) -> &WinterProofOptions { + &self.options + } + + fn get_pub_inputs(&self, trace: &ExecutionTrace) -> PublicInputs { + // ensure inputs and outputs are consistent with the execution trace. + debug_assert!( + self.are_inputs_valid(trace), + "provided inputs do not match the execution trace" + ); + debug_assert!( + self.are_outputs_valid(trace), + "provided outputs do not match the execution trace" + ); + + let program_info = trace.program_info().clone(); + let final_pc_transcript = trace.final_precompile_transcript(); + PublicInputs::new( + program_info, + self.stack_inputs, + self.stack_outputs, + final_pc_transcript.state(), + ) + } + + #[maybe_async] + fn new_trace_lde>( + &self, + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + partition_options: PartitionOptions, + ) -> (Self::TraceLde, TracePolyTable) { + DefaultTraceLde::new(trace_info, main_trace, domain, partition_options) + } + + #[maybe_async] + fn new_evaluator<'a, E: FieldElement>( + &self, + air: &'a ProcessorAir, + aux_rand_elements: Option>, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self::ConstraintEvaluator<'a, E> { + DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) + } + + #[instrument(skip_all)] + #[maybe_async] + fn build_aux_trace>( + &self, + trace: &Self::Trace, + aux_rand_elements: &AuxRandElements, + ) -> ColMatrix { + trace.build_aux_trace(aux_rand_elements.rand_elements()).unwrap() + } + + #[maybe_async] + fn build_constraint_commitment>( + &self, + composition_poly_trace: CompositionPolyTrace, + num_constraint_composition_columns: usize, + domain: &StarkDomain, + partition_options: PartitionOptions, + ) -> (Self::ConstraintCommitment, CompositionPoly) { + DefaultConstraintCommitment::new( + composition_poly_trace, + num_constraint_composition_columns, + domain, + partition_options, + ) + } }