diff --git a/air/src/ir/constraints.rs b/air/src/ir/constraints.rs index bf468793d..f3fd399a6 100644 --- a/air/src/ir/constraints.rs +++ b/air/src/ir/constraints.rs @@ -69,6 +69,15 @@ impl Constraints { &self.boundary_constraints[trace_segment] } + /// Returns the number of integrity constraints applied against the specified trace segment. + pub fn num_integrity_constraints(&self, trace_segment: TraceSegmentId) -> usize { + if self.integrity_constraints.len() <= trace_segment { + return 0; + } + + self.integrity_constraints[trace_segment].len() + } + /// Returns a vector of the degrees of the integrity constraints for the specified trace /// segment. pub fn integrity_constraint_degrees( diff --git a/air/src/ir/mod.rs b/air/src/ir/mod.rs index 0d327a5aa..8cbb0855d 100644 --- a/air/src/ir/mod.rs +++ b/air/src/ir/mod.rs @@ -140,6 +140,11 @@ impl Air { self.constraints.boundary_constraints(trace_segment) } + /// Return the number of integrity constraints + pub fn num_integrity_constraints(&self, trace_segment: TraceSegmentId) -> usize { + self.constraints.num_integrity_constraints(trace_segment) + } + /// Return the set of [ConstraintRoot] corresponding to the integrity constraints pub fn integrity_constraints(&self, trace_segment: TraceSegmentId) -> &[ConstraintRoot] { self.constraints.integrity_constraints(trace_segment) diff --git a/codegen/ace/src/lib.rs b/codegen/ace/src/lib.rs index eb6a734f1..977a758ef 100644 --- a/codegen/ace/src/lib.rs +++ b/codegen/ace/src/lib.rs @@ -4,6 +4,7 @@ mod dot; mod encoded; mod inputs; mod layout; +mod masm; #[cfg(test)] mod tests; @@ -19,6 +20,7 @@ pub use crate::{ encoded::EncodedCircuit as EncodedAceCircuit, inputs::{AceVars, AirInputs}, layout::Layout as AirLayout, + masm::{MasmVerifier, generate_masm_verifier}, }; type QuadFelt = QuadExtension; diff --git a/codegen/ace/src/masm/constraints_eval.rs b/codegen/ace/src/masm/constraints_eval.rs new file mode 100644 index 000000000..fcc23c680 --- /dev/null +++ 
b/codegen/ace/src/masm/constraints_eval.rs @@ -0,0 +1,173 @@ +use std::ops::Div; + +use crate::{ + EncodedAceCircuit, + masm::{DOUBLE_WORD_SIZE, MasmVerifierParameters, add_section}, +}; + +/// Generates the MASM module of the STARK verifier for checking constraints evaluation. +/// +/// The main logic for generating the module needs to determine the parameters to the `eval_circuit` +/// op, the arithmetic circuit description, and the number of iterations required for hashing it. +pub fn generate_constraints_eval_module( + masm_verifier_parameters: &MasmVerifierParameters, + encoded_circuit: &EncodedAceCircuit, +) -> String { + // generate the constants + let num_inputs_circuit = encoded_circuit.num_vars(); + let num_eval_gates_circuit = encoded_circuit.num_eval_rows(); + let circuit_description = encoded_circuit.instructions(); + let num_iterations_hash_ace_circuit = circuit_description.len().div(DOUBLE_WORD_SIZE); + assert_eq!(circuit_description.len() % DOUBLE_WORD_SIZE, 0); + + // fill the template with the generated constants + let mut file = CONSTRAINTS_EVAL_MASM + .to_string() + .replace("NUM_CONSTRAINTS_VALUE", &masm_verifier_parameters.num_constraints().to_string()) + .replace( + "MAX_CYCLE_LEN_LOG_VALUE", + &masm_verifier_parameters.max_cycle_len_log().to_string(), + ) + .replace("NUM_INPUTS_CIRCUIT_VALUE", &num_inputs_circuit.to_string()) + .replace("NUM_ITERATIONS_HASH_ACE_CIRCUIT", &num_iterations_hash_ace_circuit.to_string()) + .replace("NUM_EVAL_GATES_CIRCUIT_VALUE", &num_eval_gates_circuit.to_string()); + + // generate the section containing the circuit description for loading into the advice map + let advice_map_circuit_description_section = format_circuit_description(encoded_circuit); + + // add the generated section to the template file + add_section( + &mut file, + "ADVICE_MAP_CIRCUIT_DESCRIPTION", + &advice_map_circuit_description_section, + ); + + file +} + +// HELPERS +// 
================================================================================================ + +/// Formats the circuit description using the syntax for loading static data through the advice map. +fn format_circuit_description(circuit: &EncodedAceCircuit) -> String { + let circuit_digest = circuit.circuit_hash(); + let mut result: String = + format!("adv_map.CIRCUIT_COMMITMENT(0x{circuit_digest})=[\n").to_string(); + result += &circuit + .instructions() + .iter() + .map(|elem| format!("{}", elem.as_int())) + .collect::>() + .join(",\n"); + result += "\n]\n"; + + result +} + +// TEMPLATES +// ================================================================================================ + +const CONSTRAINTS_EVAL_MASM: &str = r#" +use.std::crypto::hashes::rpo +use.std::crypto::stark::constants +use.std::crypto::stark::utils + +# CONSTANTS +# ================================================================================================= + +# Number of constraints, both boundary and transitional +const.NUM_CONSTRAINTS=NUM_CONSTRAINTS_VALUE + +# Number of inputs to the constraint evaluation circuit +const.NUM_INPUTS_CIRCUIT=NUM_INPUTS_CIRCUIT_VALUE + +# Number of evaluation gates in the constraint evaluation circuit +const.NUM_EVAL_GATES_CIRCUIT=NUM_EVAL_GATES_CIRCUIT_VALUE + +# Max cycle length for periodic columns +const.MAX_CYCLE_LEN_LOG=MAX_CYCLE_LEN_LOG_VALUE + +# --- constant getters -------------------------------------------------------- + +proc.get_num_constraints + push.NUM_CONSTRAINTS +end + +# ERRORS +# ================================================================================================= + +const.ERR_FAILED_TO_LOAD_CIRCUIT_DESCRIPTION="failed to load the circuit description for the constraints evaluation check" + + +# CONSTRAINT EVALUATION CHECKER +# ================================================================================================= + +#! 
Executes the constraints evaluation check by evaluating an arithmetic circuit using the ACE +#! chiplet. +#! +#! The circuit description is hardcoded into the verifier using its commitment, which is computed as +#! the sequential hash of its description using RPO hasher. The circuit description, containing both +#! constants and evaluation gates description, is stored at the contiguous memory region starting +#! at `ACE_CIRCUIT_PTR`. The variable part of the circuit input is stored at the contiguous memory +#! region starting at `pi_ptr`. The (variable) inputs to the circuit are laid out such that the +#! aforementioned memory regions are together contiguous with the (variable) inputs section. +#! +#! Inputs: [] +#! Outputs: [] +export.execute_constraint_evaluation_check + # Compute and store at the appropriate memory location the auxiliary inputs needed by + # the arithmetic circuit. + push.MAX_CYCLE_LEN_LOG + exec.utils::set_up_auxiliary_inputs_ace + # => [] + + # Load the circuit description from the advice tape and check that it matches the hardcoded digest + exec.load_ace_circuit_description + # => [] + + # Set up the inputs to the "eval_circuit" op. Namely: + # 1. a pointer to the inputs of the circuit in memory, + # 2. the number of inputs to the circuit, + # 3. the number of evaluation gates in the circuit. + push.NUM_EVAL_GATES_CIRCUIT + push.NUM_INPUTS_CIRCUIT + exec.constants::public_inputs_address_ptr mem_load + # => [pi_ptr, n_read, n_eval] + + # Perform the constraint evaluation check by checking that the circuit evaluates to zero, which + # boils down to the "eval_circuit" returning. + eval_circuit + # => [pi_ptr, n_read, n_eval] + + # Clean up the stack. + drop drop drop + # => [] +end + +#! Loads the description of the ACE circuit for the constraints evaluation check. +#! +#! Inputs: [] +#! 
Outputs: [] +proc.load_ace_circuit_description + push.CIRCUIT_COMMITMENT + adv.push_mapval + exec.constants::get_arithmetic_circuit_ptr + padw padw padw + repeat.NUM_ITERATIONS_HASH_ACE_CIRCUIT + adv_pipe + hperm + end + exec.rpo::squeeze_digest + movup.4 drop + assert_eqw.err=ERR_FAILED_TO_LOAD_CIRCUIT_DESCRIPTION + # => [] +end + + +# CONSTRAINT EVALUATION CIRCUIT DESCRIPTION +# ================================================================================================= + +# BEGIN_SECTION:ADVICE_MAP_CIRCUIT_DESCRIPTION +# END_SECTION:ADVICE_MAP_CIRCUIT_DESCRIPTION + +"#; diff --git a/codegen/ace/src/masm/deep_queries.rs b/codegen/ace/src/masm/deep_queries.rs new file mode 100644 index 000000000..34ea213d5 --- /dev/null +++ b/codegen/ace/src/masm/deep_queries.rs @@ -0,0 +1,312 @@ +use std::collections::HashMap; + +use crate::masm::{ + DOUBLE_WORD_SIZE, FIELD_EXTENSION_DEGREE, MasmVerifierParameters, generate_with_map_sections, +}; + +/// Generates the MASM module of the STARK verifier for computing the queries to the DEEP +/// composition polynomial. +/// +/// The main logic for generating the module needs to determine the need for adding a section +/// to process the part of the query related to the auxiliary trace in addition to the sizes +/// of each of the three sections, namely the main, auxilary and constraints composition polynomials +/// trace sections. Once this is determined, we can assign the values to the loops +/// processing each of the sections involved in the query computation. 
+pub fn generate_deep_queries_module(masm_verifier_parameters: &MasmVerifierParameters) -> String { + // compute the constants related to the processing of the main segment and constraint + // segment traces + let num_iterations_main_trace_part = + (masm_verifier_parameters.main_trace_width() as usize).div_ceil(DOUBLE_WORD_SIZE); + let num_iterations_cc_trace_part = + (masm_verifier_parameters.constraints_composition_trace_width() * FIELD_EXTENSION_DEGREE) + .div_ceil(DOUBLE_WORD_SIZE); + + // fill the template with the computed constants + let mut file = DEEP_QUERIES_MASM + .to_string() + .replace("NUM_ITERATIONS_MAIN_TRACE_PART", &num_iterations_main_trace_part.to_string()) + .replace( + "NUM_ITERATIONS_CONSTRAINTS_COMPOSITION_TRACE_PART", + &num_iterations_cc_trace_part.to_string(), + ); + + // handle the case when auxiliary segment exists + if let Some(aux_trace_width) = masm_verifier_parameters.aux_trace_width { + // first, build the section for processing the auxiliary segment, which requires computing + // the number of iterations in its main loop + let num_iterations_aux_trace_part = + (aux_trace_width as usize * FIELD_EXTENSION_DEGREE).div_ceil(DOUBLE_WORD_SIZE); + let aux_trace_part_processing = AUX_TRACE_PART_PROCESSING + .to_string() + .replace("NUM_ITERATIONS_AUX_TRACE_PART", &num_iterations_aux_trace_part.to_string()); + + // add a call to the generated procedure and include the generated procedure in the main + // file + let mut sections_map = HashMap::new(); + sections_map.insert("PROCESS_AUX_TRACE_PART_CALL", PROCESS_AUX_TRACE_PART_CALL.to_string()); + sections_map + .insert("IMPL_PROCESS_AUX_TRACE_PART_CALL", aux_trace_part_processing.to_string()); + generate_with_map_sections(&mut file, sections_map); + } + + file +} + +// TEMPLATES +// ================================================================================================ + +const DEEP_QUERIES_MASM: &str = r#" +use.std::crypto::stark::constants +use.std::crypto::stark::deep_queries + +# 
ERRORS +# ================================================================================================= + +const.LEAF_VALUE_MISMATCH="hash of leaf pre-image does not match leaf value from the Merkle tree" + +# MAIN PROCEDURE +# ================================================================================================= + +#! Computes the DEEP composition polynomial FRI queries. +#! +#! This procedures iterates over all FRI query indices stored in memory at `query_ptr` in +#! a word-aligned and overwrites each word with `[eval0, eval1, index, poe]` where: +#! +#! 1. `index` is the FRI query index, +#! 2. `poe := g^index`, with `g` being the evaluation domain generator, +#! 3. `eval := (eval0, eval1)` is the computed DEEP composition polynomial query. +#! +#! Inputs: [Y, query_ptr, query_end_ptr, W, query_ptr] +#! Outputs: [] +#! +#! where: +#! +#! 1. `Y` is a garbage word, +#! 2. `query_ptr` is a pointer to the memory region from where the query indices will be fetched +#! and to where the computed FRI queries will be stored in a word-aligned manner, +#! 3. `query_end_ptr` is a memory pointer used to indicate the end of the memory region used in +#! storing the computed FRI queries, +#! 4. `W` is the word `[q_z_0, q_z_1, q_gz_0, q_gz_1]` where `q_z = (q_z_0, q_z_1)` and +#! `q_gz = (q_gz_0, q_gz_1)` represent the constant terms across all FRI queries computations. +export.compute_deep_composition_polynomial_queries + + # Iterate over all FRI query indices and compute their correspond DEEP query + # The following assumes the existence of at least one query to compute which is a necessary + # requirement to get any soundness from the protocol. This assumption is validate at + # the start of the verification procedure. + push.1 + while.true + # Load the (main, aux, constraint)-traces rows associated with the current query and get + # the index of the query. 
+ exec.load_query_row + #=> [Y, X, index, query_ptr, query_end_ptr, W, query_ptr] + + # Compute the current query and store the result + # We also re-arrange the stack for the next iteration of the loop + exec.deep_queries::compute_deep_query + # => [has_more_queries, Y, query_ptr+1, query_end_ptr] + end + + # Clean up the stack and return + dropw dropw drop drop drop +end + +# HELPER PROCEDURES +# ================================================================================================= + +#! Loads the next query rows in the main, auxiliary and constraint composition polynomials traces +#! and computes the values of the DEEP code word at the index corresponding to the query. +#! +#! It takes a pointer to the current random query index and returns that index, together with +#! the value +#! +#! Q^x(alpha) = (q_x_at_alpha_0, q_x_at_alpha_1) = \sum_{i=0}^{n+m+l} T_i * alpha^i +#! +#! where: +#! +#! 1. n, m and l are the widths of the main segment, auxiliary segment and constraint composition +#! traces, respectively. +#! 2. T_i are the values of columns in the main segment, auxiliary segment and constraint +#! composition traces, for the query. +#! 3. alpha is the randomness used in order to build the DEEP polynomial. +#! +#! Inputs: [Y, query_ptr] +#! Outputs: [Y, q_x_at_alpha_1, q_x_at_alpha_0, q_x_at_alpha_1, q_x_at_alpha_0, index, query_ptr] +#! +#! where: +#! - Y is a "garbage" word. 
+proc.load_query_row + # Process the main segment of the execution trace portion of the query + exec.process_main_segment_execution_trace + #=> [Y, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] + + # BEGIN_SECTION:PROCESS_AUX_TRACE_PART_CALL + # END_SECTION:PROCESS_AUX_TRACE_PART_CALL + + # Process the constraints composition polys trace portion of the query + exec.process_constraints_composition_poly_trace + #=> [Y, q_x_at_alpha_1, q_x_at_alpha_0, q_x_at_alpha_1, q_x_at_alpha_0, index, query_ptr] +end + +# MAIN TRACE SEGMENT PROCESSING +# ================================================================================================= + +#! Handles the logic for processing the main segment of the execution trace. +#! +#! Inputs: [Y, query_ptr] +#! Output: [Y, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] +proc.process_main_segment_execution_trace + # Load the query index + dup.4 mem_loadw + #=> [index, depth, y, y, query_ptr] where y are "garbage" values here and throughout + + # Get commitment to main segment of the execution trace + movdn.3 movdn.2 + push.0.0 + exec.constants::main_trace_com_ptr mem_loadw + #=>[MAIN_TRACE_TREE_ROOT, depth, index, query_ptr] + + # Use the commitment to get the leaf and save it + dup.5 dup.5 + mtree_get + exec.constants::tmp3 mem_storew + adv.push_mapval + #=>[LEAF_VALUE, MAIN_TRACE_TREE_ROOT, depth, index, query_ptr] + + exec.constants::tmp2 mem_loadw + swapw + #=>[LEAF_VALUE, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] + + # Load the values of the main segment of the execution trace at the current query. We also + # compute their hashing and the value of their random linear combination using powers of a + # single random value alpha. 
+ padw swapw padw + #=> [Y, Y, 0, 0, 0, 0, ptr, y, y, y] + exec.load_main_segment_execution_trace + #=> [Y, L, C, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] + + # Load the leaf value we got using `mtree_get` and compare it against the hash we just computed + exec.constants::tmp3 mem_loadw + assert_eqw.err=LEAF_VALUE_MISMATCH +end + +#! Loads the portion of the query associated to the main segment of the execution trace. +#! +#! Inputs: [Y, Y, 0, 0, 0, 0, ptr] +#! Outputs: [Y, D, C, ptr] +proc.load_main_segment_execution_trace + repeat.NUM_ITERATIONS_MAIN_TRACE_PART + adv_pipe + horner_eval_base + hperm + end +end + +# BEGIN_SECTION:IMPL_PROCESS_AUX_TRACE_PART +# END_SECTION:IMPL_PROCESS_AUX_TRACE_PART + +# CONSTRAINTS COMPOSITION POLYNOMIALS TRACE PROCESSING +# ================================================================================================= + +#! Handles the logic for processing the constraints composition polynomials trace. +#! +#! Inputs: [Y, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] +#! Output: [Y, q_x_at_alpha_1, q_x_at_alpha_0, q_x_at_alpha_1, q_x_at_alpha_0, index, query_ptr] +proc.process_constraints_composition_poly_trace + # Load the commitment to the constraint trace + exec.constants::composition_poly_com_ptr mem_loadw + #=> [R, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] + + # Get the leaf against the commitment + dup.9 movup.9 + mtree_get + exec.constants::tmp3 mem_storew + adv.push_mapval + #=>[L, R, ptr_x, ptr_alpha_inv, acc1, acc0, index, query_ptr] + + # Load the 8 columns as quadratic extension field elements in batches of 4. 
+ padw + swapw.2 + exec.load_constraints_composition_polys_trace + #=> [Y, L, Y, ptr_x, ptr_alpha_inv, acc1, acc0, index, query_ptr] + + # Load the leaf value we got using `mtree_get` and compare it against the hash we just computed + exec.constants::tmp3 mem_loadw + assert_eqw.err=LEAF_VALUE_MISMATCH + #=> [Y, ptr_x, ptr_alpha_inv, acc1, acc0, index, query_ptr] + + # Re-order the stack + swapw + drop drop + dup.1 dup.1 + swapw +end + +#! Loads the portion of the query associated to the constraints composition polynomials trace. +#! +#! Inputs: [Y, Y, 0, 0, 0, 0, ptr] +#! Outputs: [Y, D, C, ptr] +proc.load_constraints_composition_polys_trace + repeat.NUM_ITERATIONS_CONSTRAINTS_COMPOSITION_TRACE_PART + adv_pipe + horner_eval_ext + hperm + end +end +"#; + +const PROCESS_AUX_TRACE_PART_CALL: &str = r#" +# Process the auxiliary segment of the execution trace portion of the query +exec.process_aux_segment_execution_trace +#=> [Y, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] +"#; + +const AUX_TRACE_PART_PROCESSING: &str = r#" +# AUX TRACE SEGMENT PROCESSING +# ================================================================================================= + +#! Handles the logic for processing the auxiliary segment of the execution trace, if such a trace +#! exists. +#! +#! Inputs: [Y, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] +#! 
Output: [Y, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] +proc.process_aux_segment_execution_trace + # Load aux trace commitment and get leaf + exec.constants::aux_trace_com_ptr mem_loadw + + # Get the leaf against the auxiliary trace commitment for the current query + dup.9 dup.9 + mtree_get + exec.constants::tmp3 mem_storew + adv.push_mapval + #=> [L, R, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] + + # Load the values of the auxiliary segment of the execution trace at the current query + + # Set up the stack + exec.constants::zero_word_ptr mem_loadw + swapw padw + #=> [Y, Y, C, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] + + # Load the first 4 columns as a batch of 4 quadratic extension field elements. + exec.load_aux_segment_execution_trace + #=> [Y, D, C, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] + + # Load the leaf value we got using `mtree_get` and compare it against the hash we just computed + exec.constants::tmp3 mem_loadw + assert_eqw.err=LEAF_VALUE_MISMATCH + #=> [Y, ptr_x, ptr_alpha_inv, acc1, acc0, depth, index, query_ptr] +end + +#! Loads the portion of the query associated to the auxiliary segment of the execution trace. +#! +#! Inputs: [Y, Y, 0, 0, 0, 0, ptr] +#! 
Outputs: [Y, D, C, ptr] +proc.load_aux_segment_execution_trace + repeat.NUM_ITERATIONS_AUX_TRACE_PART + adv_pipe + horner_eval_ext + hperm + end +end +"#; diff --git a/codegen/ace/src/masm/mod.rs b/codegen/ace/src/masm/mod.rs new file mode 100644 index 000000000..98a95566f --- /dev/null +++ b/codegen/ace/src/masm/mod.rs @@ -0,0 +1,272 @@ +use std::collections::{BTreeMap, HashMap}; + +use air_ir::{Air, BusType}; +use anyhow::Ok; +use ood_frames::generate_ood_frames_module; +use public_inputs::generate_public_inputs; + +use crate::{ + AceCircuit, + masm::{ + constraints_eval::generate_constraints_eval_module, + deep_queries::generate_deep_queries_module, verifier::generate_verifier_module, + }, +}; + +mod constraints_eval; +mod deep_queries; +mod ood_frames; +mod public_inputs; +mod verifier; + +// CONSTANTS +// ================================================================================================ + +const FIELD_EXTENSION_DEGREE: usize = 2; +const NUM_CONSTRAINTS_COMPOSITION_POLYS: usize = 8; +const DOUBLE_WORD_SIZE: usize = 8; + +// MASM CODE GENERATOR +// ================================================================================================ + +/// Generates the modules of a Miden assembly (MASM) STARK verifier using the core verifier in +/// the Miden standard library `stdlib`. +/// +/// The following assumptions are made: +/// +/// 1. Number of constraints composition polynomials is set to 8, +/// 2. FRI folding factor is set to 4, +/// 3. Extension degree of the cryptographic field is 2. 
pub fn generate_masm_verifier(air: &Air, circuit: &AceCircuit) -> anyhow::Result<MasmVerifier> {
+struct MasmVerifierParameters { + num_auxiliary_randomness: u16, + max_cycle_len_log: u32, + + main_trace_width: u16, + aux_trace_width: Option, + constraints_composition_trace_width: usize, + + variable_len_pub_inputs_sizes: BTreeMap, + fixed_len_pub_inputs_total_size: usize, + num_constraints: usize, +} + +impl MasmVerifierParameters { + fn from_air(air: &Air) -> Self { + let main_trace_width = air.trace_segment_widths[0].next_multiple_of(8); + let aux_trace_width = + air.trace_segment_widths.get(1).map(|width| width.next_multiple_of(8)); + let num_auxiliary_randomness = air.num_random_values; + + let max_cycle_length = air.periodic_columns().map(|col| col.period()).max(); + let max_cycle_len_log = max_cycle_length.unwrap_or(1).ilog2(); + + // iterate over the public inputs and build a map from the table identifier to + // its width and its bus type + let mut variable_len_pub_inputs_sizes = BTreeMap::new(); + for bus in air.buses.iter() { + for pi in air.public_inputs() { + if let air_ir::BusBoundary::PublicInputTable(public_input_table_access) = + bus.1.first + { + if public_input_table_access.table_name == pi.name() { + let _ = variable_len_pub_inputs_sizes + .insert(pi.name(), (pi.size(), public_input_table_access.bus_type)); + } + } + if let air_ir::BusBoundary::PublicInputTable(public_input_table_access) = bus.1.last + { + if public_input_table_access.table_name == pi.name() { + let _ = variable_len_pub_inputs_sizes + .insert(pi.name(), (pi.size(), public_input_table_access.bus_type)); + } + } + } + } + + // compute the total number of fixed length public inputs + let mut fixed_len_pub_inputs_total_size = 0; + for pi in air.public_inputs() { + if let air_ir::PublicInput::Vector { size, .. 
} = pi { + fixed_len_pub_inputs_total_size += size + } + } + + // compute the number of constraints + let num_constraints: usize = [0, 1] + .iter() + .map(|trace_id| { + air.num_boundary_constraints(*trace_id) + air.num_integrity_constraints(*trace_id) + }) + .sum(); + + Self { + num_auxiliary_randomness, + max_cycle_len_log, + main_trace_width, + aux_trace_width, + num_constraints, + fixed_len_pub_inputs_total_size, + variable_len_pub_inputs_sizes, + constraints_composition_trace_width: NUM_CONSTRAINTS_COMPOSITION_POLYS, + } + } + + fn num_constraints(&self) -> usize { + self.num_constraints + } + + fn num_auxiliary_randomness(&self) -> u16 { + self.num_auxiliary_randomness + } + + fn max_cycle_len_log(&self) -> u32 { + self.max_cycle_len_log + } + + fn main_trace_width(&self) -> u16 { + self.main_trace_width + } + + fn aux_trace_width(&self) -> Option { + self.aux_trace_width + } + + fn constraints_composition_trace_width(&self) -> usize { + self.constraints_composition_trace_width + } + + fn variable_len_pub_inputs_sizes(&self) -> &BTreeMap { + &self.variable_len_pub_inputs_sizes + } + + fn fixed_len_pub_inputs_total_size(&self) -> usize { + self.fixed_len_pub_inputs_total_size + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Given a map with keys section labels placeholders and values the `String` to assign to +/// the placholders, returns the resulting filled `String`. 
+fn generate_with_map_sections(file: &mut String, sections_map: HashMap<&'static str, String>) { + for (section_name, code) in sections_map { + let begin_marker = format!("# BEGIN_SECTION:{section_name}"); + let end_marker = format!("# END_SECTION:{section_name}"); + + if let Some(begin_index) = file.find(&begin_marker) { + if let Some(end_index) = file.find(&end_marker) { + let end_position = end_index + end_marker.len(); + + // find the line start to preserve indentation + let line_start = file[..begin_index].rfind('\n').map(|i| i + 1).unwrap_or(0); + + // extract indentation from the line containing the end marker + let indentation = &file[line_start..begin_index]; + let indent_str = + indentation.chars().take_while(|&c| c == ' ' || c == '\t').collect::(); + + let before = &file[..line_start]; + let after_line_end = file[end_position..] + .find('\n') + .map(|i| end_position + i + 1) + .unwrap_or(file.len()); + let after = &file[after_line_end..]; + + // indent each line of the code section to be inserted + let indented_code = if code.trim().is_empty() { + String::new() + } else { + code.lines() + .map(|line| { + if line.trim().is_empty() { + String::new() + } else { + format!("{indent_str}{line}") + } + }) + .collect::>() + .join("\n") + }; + + // assemble the resulting file + *file = if indented_code.is_empty() { + format!("{before}{after}") + } else { + format!("{before}{indented_code}\n{after}") + }; + } + } + } +} + +/// Given a section placeholder identifier and a `String` value to assign to the placholder, +/// returns the resulting filled `String`. 
+fn add_section(file: &mut String, section_placeholder_id: &'static str, section: &str) { + let mut sections_map = HashMap::new(); + sections_map.insert(section_placeholder_id, section.to_string()); + generate_with_map_sections(file, sections_map) +} diff --git a/codegen/ace/src/masm/ood_frames.rs b/codegen/ace/src/masm/ood_frames.rs new file mode 100644 index 000000000..22a7b6568 --- /dev/null +++ b/codegen/ace/src/masm/ood_frames.rs @@ -0,0 +1,57 @@ +use std::ops::Div; + +use super::MasmVerifierParameters; +use crate::masm::{DOUBLE_WORD_SIZE, FIELD_EXTENSION_DEGREE}; + +/// Generates the MASM module of the STARK verifier for processing the out-of-domain (OOD) +/// evaluations. +pub fn generate_ood_frames_module(masm_verifier_parameters: &MasmVerifierParameters) -> String { + let main_trace_width = masm_verifier_parameters.main_trace_width(); + let aux_trace_width = masm_verifier_parameters.aux_trace_width().unwrap_or(0); + let num_constraints_composition_polys = + masm_verifier_parameters.constraints_composition_trace_width(); + + // we are loading and hashing extension field elements and hence we need to first compute + // the number of extension field elements to process and thereafter convert this to + // a number over base field elements + let num_extension_field_elements = + main_trace_width + aux_trace_width + num_constraints_composition_polys as u16; + let num_base_field_elements = num_extension_field_elements * FIELD_EXTENSION_DEGREE as u16; + + // since we are loading two words per iteration, we need to divide by `DOUBLE_WORD_SIZE` + let num_iterations = num_base_field_elements.div(DOUBLE_WORD_SIZE as u16); + + // we check double-word alignment + debug_assert_eq!( + num_base_field_elements % DOUBLE_WORD_SIZE as u16, + 0, + "each of trace is expected to be double-word aligned" + ); + + OOD_FRAMES_MASM + .to_string() + .replace("{NUM_ITERATIONS_PROCESS_OOD_EVALS}", &num_iterations.to_string()) +} + +// TEMPLATES +// 
================================================================================================ + +const OOD_FRAMES_MASM: &str = r#" +#! Processes the out-of-domain (OOD) evaluations of all committed polynomials. +#! +#! Takes as input an RPO hasher state and a pointer, and loads from the advice provider the OOD +#! evaluations and stores at memory region using pointer `ptr` while absorbing the evaluations +#! into the hasher state and simultaneously computing a random linear combination using Horner +#! evaluation. +#! +#! +#! Inputs: [R2, R1, C, ptr, acc1, acc0] +#! Outputs: [R2, R1, C, ptr, acc1`, acc0`] +export.process_row_ood_evaluations + repeat.{NUM_ITERATIONS_PROCESS_OOD_EVALS} + adv_pipe + horner_eval_ext + hperm + end +end +"#; diff --git a/codegen/ace/src/masm/public_inputs.rs b/codegen/ace/src/masm/public_inputs.rs new file mode 100644 index 000000000..9f7e15871 --- /dev/null +++ b/codegen/ace/src/masm/public_inputs.rs @@ -0,0 +1,599 @@ +use std::collections::HashMap; + +use super::MasmVerifierParameters; +use crate::masm::{DOUBLE_WORD_SIZE, generate_with_map_sections}; + +/// Generates the MASM module of the STARK verifier for handling public inputs. +/// +/// There are two main components to this module: +/// +/// 1. Fixed length public inputs processing: this takes as input the total number of fixed length +/// public inputs (as base field elements) which is used in order to determine the number of +/// iterations of the loop responsible for loading-storing-hashing the fixed length public +/// inputs, +/// 2. Variable length public inputs processing: this takes as input the map from identifiers to +/// (width, bus_type) of the so-called variable length tables, also called messages widths, and +/// the type of the bus corresponding to each table. This is used in order to generate +/// procedures, one per variable length table, in order to reduce each table, using auxiliary +/// randomness, to an element in the extension field. 
pub fn generate_public_inputs(masm_verifier_parameters: &MasmVerifierParameters) -> String {
    // number of double-word loads needed to stream all fixed-length public inputs
    let num_iter_load_fixed_len_pub_inputs =
        (masm_verifier_parameters.fixed_len_pub_inputs_total_size()).div_ceil(DOUBLE_WORD_SIZE);

    // procedures to reduce variable length inputs tables, one per table/group
    let mut var_len_pi_reduction_procedures = String::new();
    // code section for calling the above procedures
    let mut reduce_var_len_pi_groups_call = String::new();

    // for each variable length public input group, we create a procedure to reduce the variable
    // length inputs and add a call to the said procedure
    // For each table/group, we associate an label in order to domain separate messages
    // TODO: this will probably be the responsibility of the backend in the near term
    // NOTE(review): op labels are assigned in iteration order of `variable_len_pub_inputs_sizes()`;
    // if that container is a `HashMap`, label assignment is nondeterministic across runs — confirm
    // it is an order-preserving/sorted map.
    let mut op_batch_section = String::new();
    let mut op_group_label = 0;
    for (identifier, (message_width, bus_type)) in
        masm_verifier_parameters.variable_len_pub_inputs_sizes().iter()
    {
        // create the label for the current group
        op_batch_section += &VAR_LEN_PI_GROUP_OP_LABELS
            .to_string()
            .replace("{GROUP_ID}", &identifier.to_string().to_uppercase())
            // NOTE(review): `to_uppercase()` on a decimal label string is a no-op
            .replace("OP_LABEL_VALUE", &op_group_label.to_string().to_uppercase());
        op_group_label += 1;

        // add a call to the procedure for this group
        reduce_var_len_pi_groups_call += &REDUCE_VAR_LEN_PI_GROUP_ID
            .to_string()
            .replace("{group_id}", &identifier.to_string());

        // depending on the type of bus, generate the appropriate procedure for reducing
        // the variable length public inputs
        let procedure = match bus_type {
            air_ir::BusType::Multiset => REDUCE_VAR_LEN_PI_MULTISET_PROCEDURE
                .to_string()
                .replace("{group_id}", &identifier.to_string())
                .replace("{GROUP_ID}", &identifier.to_string().to_uppercase())
                .replace(
                    "{WIDTH_INTERACTION_GROUP_ID_IN_DOUBLE_WORD}",
                    &((*message_width).div_ceil(DOUBLE_WORD_SIZE)).to_string(),
                ),
            air_ir::BusType::Logup => REDUCE_VAR_LEN_PI_LOGUP_PROCEDURE
                .to_string()
                .replace("{group_id}", &identifier.to_string())
                .replace("{GROUP_ID}", &identifier.to_string().to_uppercase())
                .replace(
                    "{WIDTH_INTERACTION_GROUP_ID_IN_DOUBLE_WORD}",
                    &((*message_width).div_ceil(DOUBLE_WORD_SIZE)).to_string(),
                ),
        };
        var_len_pi_reduction_procedures.push_str(&procedure);
    }

    // generate the map for filling the sections
    let mut sections_map = HashMap::new();
    sections_map.insert("DEFINE_VAR_LEN_PI_GROUP_OP_LABELS", op_batch_section);
    sections_map.insert("REDUCE_VAR_LEN_PI_GROUP_ID_CALL", reduce_var_len_pi_groups_call);
    sections_map.insert(
        "REDUCE_VAR_LEN_PI_GROUP_ID_PROCEDURES_DEFINITIONS",
        var_len_pi_reduction_procedures,
    );

    // fill the constants first
    let mut file = PUBLIC_INPUTS_MASM
        .to_string()
        .replace(
            "NUM_FIXED_LEN_PUBLIC_INPUTS_VALUE",
            &masm_verifier_parameters
                .fixed_len_pub_inputs_total_size()
                .next_multiple_of(DOUBLE_WORD_SIZE)
                .to_string(),
        )
        .replace("NUM_VAR_LEN_PI_GROUPS_VALUE", &op_group_label.to_string())
        .replace(
            "NUM_ITER_LOAD_FIXED_LEN_PUB_INPUTS",
            &num_iter_load_fixed_len_pub_inputs.to_string(),
        );

    // then we fill the sections
    generate_with_map_sections(&mut file, sections_map);

    file
}

// TEMPLATES
// ================================================================================================

const PUBLIC_INPUTS_MASM: &str = r#"
use.std::crypto::stark::constants
use.std::crypto::stark::random_coin
use.std::crypto::stark::public_inputs

use.std::crypto::hashes::rpo

# CONSTANTS
# =================================================================================================

# Number of fixed length public inputs with padding (in field elements)
const.NUM_FIXED_LEN_PUBLIC_INPUTS=NUM_FIXED_LEN_PUBLIC_INPUTS_VALUE

# Number of variable length public input groups
const.NUM_VAR_LEN_PI_GROUPS=NUM_VAR_LEN_PI_GROUPS_VALUE

# Op label for variable length public input groups
# BEGIN_SECTION:DEFINE_VAR_LEN_PI_GROUP_OP_LABELS
# END_SECTION:DEFINE_VAR_LEN_PI_GROUP_OP_LABELS

# CONSTANTS GETTERS
# =================================================================================================

export.get_num_fixed_len_public_inputs
    push.NUM_FIXED_LEN_PUBLIC_INPUTS
end

# MAIN PROCEDURE
# =================================================================================================

#! Processes the public inputs.
#!
#! This involves:
#!
#! 1. Loading from the advice stack the fixed-length public inputs and storing them in memory
#!    starting from the address pointed to by `public_inputs_address_ptr`.
#! 2. Loading from the advice stack the variable-length public inputs, storing them temporarily
#!    in memory, and then reducing them to an element in the challenge field using the auxiliary
#!    randomness. This reduced value is then used to impose a boundary condition on the relevant
#!    auxiliary column.
#!
#! Note that the fixed length public inputs are stored as extension field elements while
#! the variable length ones are stored as base field elements.
#!
#! Note also that, while loading the above, we compute the hash of the public inputs. The hashing
#! starts with capacity registers of the hash function set to `C` that is the result of hashing
#! the proof context.
#!
#! The output D, that is the digest of the above hashing, is then used in order to reseed
#! the random coin.
#!
#! It is worth noting that:
#!
#! 1. Only the fixed-length public inputs are stored for the lifetime of the verification procedure.
#!    The variable-length public inputs are stored temporarily, as this simplifies the task of
#!    reducing them using the auxiliary randomness. On the other hand, the resulting values from
#!    the aforementioned reductions are stored right after the fixed-length public inputs. These
#!    are stored in a word-aligned manner and padded with zeros if needed.
#! 2. The public inputs address is computed in such a way so as we end up with the following
#!    memory layout:
#!
#!    [..., a_0...a_{m-1}, b_0...b_{n-1}, alpha0, alpha1, beta0, beta1, OOD-evaluations-start, ...]
#!
#!    where:
#!
#!    1. [a_0...a_{m-1}] are the fixed-length public inputs stored as extension field elements. This
#!       section is double-word-aligned.
#!    2. [b_0...b_{n-1}] are the results of reducing the variable length public inputs using
#!       auxiliary randomness. This section is word-aligned.
#!    3. [alpha0, alpha1, beta0, beta1] is the auxiliary randomness.
#!    4. `OOD-evaluations-start` is the first field element of the section containing the OOD
#!       evaluations.
#! 3. Note that for each bus message in a group in the variable length public inputs, each
#!    message is expected to be padded to the next multiple of 8 and provided in reverse order.
#!    This has the benefit of making the reduction using the auxiliary randomness more efficient
#!    using `horner_eval_base`.
#!
#!
#! Input: [C, ...]
#! Output: [...]
export.process_public_inputs
    # 1) Compute the address where the public inputs will be stored and store it.
    #    This also computes the address where the reduced variable-length public inputs will be stored.
    exec.get_num_fixed_len_public_inputs push.NUM_VAR_LEN_PI_GROUPS
    exec.public_inputs::compute_and_store_public_inputs_address
    # => [C, ...]

    # 2) Load the public inputs.
    #    This will also hash them so that we can absorb them in the transcript.
    exec.load_public_inputs
    # => [D, ...]

    # 3) Absorb into the transcript
    exec.random_coin::reseed
    # => [...]

    # 4) Reduce the variable-length public inputs using randomness.
    exec.reduce_variable_length_public_inputs
end

# HELPER PROCEDURES
# =================================================================================================

#! Loads from the advice stack the public inputs and stores them in memory starting from address
#! pointed to by `public_inputs_address_ptr`.
#!
#! Note that the public inputs are stored as extension field elements.
#!
#! In parallel, it computes the hash of the public inputs being loaded. The hashing starts with
#! capacity registers of the hash function set to `C` resulting from hashing the proof context.
#! The output D is the digest of the hashing of the public inputs.
#!
#! Inputs: [C, ...]
#! Outputs: [D, ...]
proc.load_public_inputs
    # 1) Load and hash the fixed length public inputs

    exec.constants::public_inputs_address_ptr mem_load
    movdn.4
    padw padw
    repeat.NUM_ITER_LOAD_FIXED_LEN_PUB_INPUTS
        exec.public_inputs::load_base_store_extension_double_word
        hperm
    end

    # 2) Load and hash the variable length public inputs

    ## a) Compute the number of base field elements in total in the variable length public inputs
    exec.constants::num_public_inputs_ptr mem_load
    exec.get_num_fixed_len_public_inputs
    sub
    # => [num_var_len_pi, R2, R1, C, ptr, ...]

    ## b) Compute the number of hash iteration needed to hash the variable length public inputs.
    ##    We also check the double-word alignment.
    u32divmod.8
    # => [rem, num_iter, R2, R1, C, ptr, ...]
    push.0 assert_eq
    # => [num_iter, R2, R1, C, ptr, ...]

    ## c) Prepare the stack for hashing
    movdn.13
    # => [R2, R1, C, ptr, num_iter, ...]
    dup.13 sub.1 swap.14
    push.0 neq
    # => [(num_iter == 0), R2, R1, C, ptr, num_iter - 1, ...]

    ## d) Hash the variable length public inputs
    while.true
        adv_pipe
        hperm
        # => [R2, R1, C, ptr, num_iter, ...]
        dup.13 sub.1 swap.14
        push.0 neq
    end
    # => [R2, R1, C, ptr, num_iter, ...]

    # 3) Return the final digest
    exec.rpo::squeeze_digest
    # => [D, ptr, num_iter, ...] where D = R1 the digest
    movup.4 drop
    movup.4 drop
    # => [D, ...]
end

#! Reduces the variable-length public inputs using the auxiliary randomness.
#!
#! The procedure non-deterministically loads the auxiliary randomness from the advice tape and
#! stores it at `aux_rand_nd_ptr` so that it can be later checked for correctness. After this,
#! the procedure uses the auxiliary randomness in order to reduce the variable-length public
#! inputs to a single element in the challenge field. The resulting values are then stored
#! contiguously after the fixed-length public inputs.
#!
#! Input:
#! - Operand stack: [...]
#! - Advice stack: [beta0, beta1, alpha0, alpha1, var_len_pi_1_len, ..., var_len_pi_k_len, ...]
#! Output: [D, ...]
proc.reduce_variable_length_public_inputs
    # 1) Load the auxiliary randomness i.e., alpha and beta
    #    We store them as [beta0, beta1, alpha0, alpha1] since `horner_eval_ext` requires memory
    #    word-alignment.
    adv_push.4
    exec.constants::aux_rand_nd_ptr mem_storew
    # => [alpha1, alpha0, beta1, beta0, ...]
    dropw
    # => [...]

    # 2) Get the pointer to the variable-length public inputs.
    #    This is also the pointer to the first address at which we will store the results of
    #    the reductions.
    exec.constants::variable_length_public_inputs_address_ptr mem_load
    dup
    # => [next_var_len_pub_inputs_ptr, var_len_pub_inputs_res_ptr, ...] where
    # `next_var_len_pub_inputs_ptr` points to the next chunk of variable public inputs to be reduced,
    # and `var_len_pub_inputs_res_ptr` points to the next available memory location where the result
    # of the reduction can be stored.
    # Note that, as mentioned in the top of this module, the variable-length public inputs are only
    # stored temporarily and they will be over-written by, among other data, the result of reducing
    # the variable public inputs.

    # BEGIN_SECTION:REDUCE_VAR_LEN_PI_GROUP_ID_CALL
    # END_SECTION:REDUCE_VAR_LEN_PI_GROUP_ID_CALL

    # 3) Clean up the stack.
    drop drop
    # => [...]
end

# BEGIN_SECTION:REDUCE_VAR_LEN_PI_GROUP_ID_PROCEDURES_DEFINITIONS
# END_SECTION:REDUCE_VAR_LEN_PI_GROUP_ID_PROCEDURES_DEFINITIONS

"#;

// Per-group op-label constant definition; `{GROUP_ID}` and `OP_LABEL_VALUE` are substituted
// per variable-length public-input group.
const VAR_LEN_PI_GROUP_OP_LABELS: &str = r#"
const.VAR_LEN_PI_GROUP_{GROUP_ID}_OP_LABEL=OP_LABEL_VALUE
"#;

// Call-site snippet for one group's reduction procedure.
const REDUCE_VAR_LEN_PI_GROUP_ID: &str = r#"adv_push.1 exec.reduce_var_len_pi_group_{group_id}
# => [next_var_len_pub_inputs_ptr, var_len_pub_inputs_res_ptr, ...]
"#;

// NOTE(review): a template comment below says "`mem_stream` + `horner_eval_ext`" while the
// instruction emitted inside the loop is `horner_eval_base` — confirm which is intended.
const REDUCE_VAR_LEN_PI_MULTISET_PROCEDURE: &str = r#"

#! Reduces the variable length public inputs for this group using auxiliary randomness.
#!
#! Inputs: [num_interaction, interaction_ptr]
#! Outputs: [next_ptr]
#!
#! where `interaction_ptr` is a pointer to the messages in this group.
proc.reduce_var_len_pi_group_{group_id}
    # Assert that the number of interactions is at most 1023
    dup u32lt.1024 assert

    # Store number of interactions
    push.0.0 dup.2
    exec.constants::tmp1 mem_storew
    # => [num_interaction, 0, 0, num_interaction, interaction_ptr, ...]

    # Load alpha
    exec.constants::aux_rand_nd_ptr mem_loadw
    # => [alpha1, alpha0, beta1, beta0, interaction_ptr, ...]

    # We will keep [beta0, beta1, alpha0 + op_label, alpha1] on the stack so that we can compute
    # the final result, where op_label is a unique label to domain separate the interaction with
    # the chiplets` bus.
    # The final result is then computed as:
    #
    # alpha + op_label * beta^0 + beta * (r_0 * beta^0 + r_1 * beta^1 + r_2 * beta^2 + r_3 * beta^3)
    swap
    push.VAR_LEN_PI_GROUP_{GROUP_ID}_OP_LABEL
    add
    swap
    # => [alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, ...]

    # Push the `horner_eval_ext` accumulator
    push.0.0
    # => [acc1, acc0, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, ...]

    # Push the pointer to the evaluation point beta
    exec.constants::aux_rand_nd_ptr
    # => [beta_ptr, acc1, acc0, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, ...]

    # Get the pointer to interactions
    movup.7
    # => [interaction_ptr, beta_ptr, acc1, acc0, alpha1, alpha0 + op_label, beta1, beta0, ...]

    # Set up the stack for `mem_stream` + `horner_eval_ext`
    swapw
    padw padw
    # => [Y, Y, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, acc1, acc0, ...]
    # where `Y` is a garbage word.

    exec.constants::tmp1 mem_loadw dup
    push.0
    neq

    while.true
        repeat.{WIDTH_INTERACTION_GROUP_ID_IN_DOUBLE_WORD}
            mem_stream
            horner_eval_base
        end
        # => [Y, Y, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, acc1, acc0, ...]

        swapdw
        # => [alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, acc1, acc0, Y, Y, ...]

        movup.7 movup.7
        # => [acc1, acc0, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, Y, Y, ...]

        dup.5 dup.5
        # => [beta1, beta0, acc1, acc0, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, Y, Y, ...]
        ext2mul
        # => [tmp1', tmp0', alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, Y, Y, ...]

        dup.3 dup.3
        ext2add
        # => [term1', term0', alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, Y, Y, ...]

        movdn.15
        movdn.15
        # => [alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, Y, Y, term1', term0', ...]

        push.0 movdn.6
        push.0 movdn.6
        # => [alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, 0, 0, Y, Y, term1', term0', ...]

        swapdw
        # => [Y, Y, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, 0, 0, term1', term0', ...]

        exec.constants::tmp1 mem_loadw sub.1
        exec.constants::tmp1 mem_storew

        dup
        push.0
        neq
    end
    # => [Y, Y, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, 0, 0, term1', term0', ...]

    dropw dropw dropw
    # => [interaction_ptr, beta_ptr, 0, 0, term1', term0', ...]
    dup exec.constants::tmp2 mem_store
    exec.constants::tmp1 mem_loadw drop drop drop

    push.1.0
    movup.2
    dup
    push.0
    neq
    # => [loop, n, acc1, acc0, term1_1, term1_0, ..., termn_1, termn_0, ...]

    while.true
        sub.1 movdn.4
        # => [acc1, acc0, term1_1, term1_0, n - 1, ..., termn_1, termn_0, ...]
        ext2mul
        # => [acc1', acc0', n - 1, ..., termn_1, termn_0, ...]
        movup.2
        dup
        push.0
        neq
        # => [loop, n - 1, acc1', acc0', term1_1, term1_0, ..., termn_1, termn_0, ...]
    end

    drop
    exec.constants::tmp2 mem_load movdn.2
    # since we are initializing the bus with "requests", we should invert the reduced result
    ext2inv
    # => [prod_acc1, prod_acc0, interaction_ptr, ...]

    # Store the result
    push.0.0
    # => [0, 0, prod_acc1, prod_acc0, interaction_ptr, var_len_pub_inputs_res_ptr, ...]
    dup.5 add.4 swap.6
    mem_storew
    dropw
    # => [interaction_ptr, var_len_pub_inputs_res_ptr, ...]
end
"#;

const REDUCE_VAR_LEN_PI_LOGUP_PROCEDURE: &str = r#"
#! Reduces the variable length public inputs for this group using auxiliary randomness.
#!
#! Inputs: [num_interaction, interaction_ptr]
#! Outputs: [next_ptr]
#!
#! where `interaction_ptr` is a pointer to the messages in this group.
proc.reduce_var_len_pi_group_{group_id}
    # Assert that the number of interactions is at most 1023
    dup u32lt.1024 assert

    # Store number of interactions
    push.0.0 dup.2
    exec.constants::tmp1 mem_storew
    # => [num_interaction, 0, 0, num_interaction, interaction_ptr, ...]

    # Load alpha
    exec.constants::aux_rand_nd_ptr mem_loadw
    # => [alpha1, alpha0, beta1, beta0, interaction_ptr, ...]

    # We will keep [beta0, beta1, alpha0 + op_label, alpha1] on the stack so that we can compute
    # the final result, where op_label is a unique label to domain separate the interaction with
    # the chiplets` bus.
    # The final result is then computed as:
    #
    # alpha + op_label * beta^0 + beta * (r_0 * beta^0 + r_1 * beta^1 + r_2 * beta^2 + r_3 * beta^3)
    swap
    push.VAR_LEN_PI_GROUP_{GROUP_ID}_OP_LABEL
    add
    swap
    # => [alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, ...]

    # Push the `horner_eval_ext` accumulator
    push.0.0
    # => [acc1, acc0, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, ...]

    # Push the pointer to the evaluation point beta
    exec.constants::aux_rand_nd_ptr
    # => [beta_ptr, acc1, acc0, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, ...]

    # Get the pointer to interactions
    movup.7
    # => [interaction_ptr, beta_ptr, acc1, acc0, alpha1, alpha0 + op_label, beta1, beta0, ...]

    # Set up the stack for `mem_stream` + `horner_eval_ext`
    swapw
    padw padw
    # => [Y, Y, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, acc1, acc0, ...]
    # where `Y` is a garbage word.

    exec.constants::tmp1 mem_loadw dup
    push.0
    neq

    while.true
        repeat.{WIDTH_INTERACTION_GROUP_ID_IN_DOUBLE_WORD}
            mem_stream
            horner_eval_base
        end
        # => [Y, Y, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, acc1, acc0, ...]

        swapdw
        # => [alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, acc1, acc0, Y, Y, ...]

        movup.7 movup.7
        # => [acc1, acc0, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, Y, Y, ...]

        dup.5 dup.5
        # => [beta1, beta0, acc1, acc0, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, Y, Y, ...]
        ext2mul
        # => [tmp1', tmp0', alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, Y, Y, ...]

        dup.3 dup.3
        ext2add
        ext2inv
        # => [term1', term0', alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, Y, Y, ...]

        movdn.15
        movdn.15
        # => [alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, Y, Y, term1', term0', ...]

        push.0 movdn.6
        push.0 movdn.6
        # => [alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, 0, 0, Y, Y, term1', term0', ...]

        swapdw
        # => [Y, Y, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, 0, 0, term1', term0', ...]

        exec.constants::tmp1 mem_loadw sub.1
        exec.constants::tmp1 mem_storew

        dup
        push.0
        neq
    end
    # => [Y, Y, alpha1, alpha0 + op_label, beta1, beta0, interaction_ptr, beta_ptr, 0, 0, term1', term0', ...]

    dropw dropw dropw
    # => [interaction_ptr, beta_ptr, 0, 0, term1', term0', ...]
    dup exec.constants::tmp2 mem_store
    exec.constants::tmp1 mem_loadw drop drop drop

    push.0.0
    movup.2
    dup
    push.0
    neq
    # => [loop, n, acc1, acc0, term1_1, term1_0, ..., termn_1, termn_0, ...]

    while.true
        sub.1 movdn.4
        # => [acc1, acc0, term1_1, term1_0, n - 1, ..., termn_1, termn_0, ...]
        ext2add
        # => [acc1', acc0', n - 1, ..., termn_1, termn_0, ...]
        movup.2
        dup
        push.0
        neq
        # => [loop, n - 1, acc1', acc0', term1_1, term1_0, ..., termn_1, termn_0, ...]
    end

    drop
    exec.constants::tmp2 mem_load movdn.2
    # since we are initializing the bus with "requests", we should negate the reduced result
    ext2neg
    # => [sum_acc1, sum_acc0, interaction_ptr, ...]

    # Store the result
    push.0.0
    # => [0, 0, sum_acc1, sum_acc0, interaction_ptr, var_len_pub_inputs_res_ptr, ...]
    dup.5 add.4 swap.6
    mem_storew
    dropw
    # => [interaction_ptr, var_len_pub_inputs_res_ptr, ...]
end
"#;

// --- codegen/ace/src/masm/verifier.rs ---

use crate::masm::MasmVerifierParameters;

/// Generates the main MASM module of the STARK verifier.
+pub fn generate_verifier_module(masm_verifier_parameters: &MasmVerifierParameters) -> String { + let is_aux_trace = masm_verifier_parameters.aux_trace_width().is_some() as u8; + let trace_info = build_trace_info(masm_verifier_parameters); + + VERIFIER_MASM + .to_string() + .replace("IS_AUX_TRACE_VALUE", &is_aux_trace.to_string()) + .replace("TRACE_INFO_VALUE", &trace_info) +} + +// HELPERS +// ================================================================================================ + +/// Builds the trace info constant. +fn build_trace_info(masm_verifier_parameters: &MasmVerifierParameters) -> String { + let main_segment_width = masm_verifier_parameters.main_trace_width as u8; + let (num_aux_segments, aux_segment_width): (u8, u8) = + match masm_verifier_parameters.aux_trace_width { + Some(aux_seg_width) => (1, aux_seg_width as u8), + None => (0, 0), + }; + let num_aux_randomness = masm_verifier_parameters.num_auxiliary_randomness() as u8; + + "0x".to_string() + + &format!("{main_segment_width:02x}",).to_string() + + &format!("{num_aux_segments:02x}",).to_string() + + &format!("{aux_segment_width:02x}",).to_string() + + &format!("{num_aux_randomness:02x}",).to_string() +} + +// TEMPLATES +// ================================================================================================ + +const VERIFIER_MASM: &str = r#" +use.std::crypto::hashes::rpo + +use.std::sys::vm::deep_queries +use.std::sys::vm::constraints_eval +use.std::sys::vm::ood_frames +use.std::sys::vm::public_inputs + +use.std::crypto::stark::verifier + +# Indicates the existence of auxiliary trace segment. +const.IS_AUX_TRACE=IS_AUX_TRACE_VALUE + +# A constant encoding the main segment width, the number of auxiliary segments (either 0 or 1), +# width of the auxiliary segment if it exists and the number of auxiliary randomness used by it +const.TRACE_INFO=TRACE_INFO_VALUE + +#! Verifies STARK proof. +#! +#! Inputs: [log(trace_length), num_queries, grinding] +#! 
Outputs: [] +export.verify_proof + # --- Get constants ------------------------------------------------------- + + # Flag indicating the existence of auxiliary trace + push.IS_AUX_TRACE movdn.3 + # => [log(trace_length), num_queries, grinding, is_aux_trace] + + # Number of fixed length public inputs + exec.public_inputs::get_num_fixed_len_public_inputs movdn.3 + # => [log(trace_length), num_queries, grinding, num_fixed_len_pi, is_aux_trace] + + # Trace info as one field element + push.TRACE_INFO movdn.3 + # => [log(trace_length), num_queries, grinding, trace_info, num_fixed_len_pi, is_aux_trace] + + # Number of constraints + exec.constraints_eval::get_num_constraints movdn.3 + # => [log(trace_length), num_queries, grinding, num_constraints, trace_info, num_fixed_len_pi, is_aux_trace] + + # --- Load the digests of all dynamically invoked procedures -------------- + + procref.deep_queries::compute_deep_composition_polynomial_queries + procref.constraints_eval::execute_constraint_evaluation_check + procref.ood_frames::process_row_ood_evaluations + procref.public_inputs::process_public_inputs + # =>[D3, D2, D1, D0, log(trace_length), num_queries, grinding, num_constraints, trace_info, num_fixed_len_pi, is_aux_trace] + + # --- Call the core verification procedure from `stdlib` ------------------ + + exec.verifier::verify + # => [...] +end +"#;