diff --git a/provekit/common/Cargo.toml b/provekit/common/Cargo.toml index 3acdbefc3..ad355845d 100644 --- a/provekit/common/Cargo.toml +++ b/provekit/common/Cargo.toml @@ -41,6 +41,7 @@ rayon.workspace = true ruint.workspace = true serde.workspace = true serde_json.workspace = true +sha3.workspace = true tracing.workspace = true xz2.workspace = true zerocopy.workspace = true diff --git a/provekit/common/src/r1cs.rs b/provekit/common/src/r1cs.rs index c971d9f19..2e34bb006 100644 --- a/provekit/common/src/r1cs.rs +++ b/provekit/common/src/r1cs.rs @@ -4,6 +4,7 @@ use { }, ark_ff::Zero, serde::{Deserialize, Serialize}, + sha3::{Digest, Sha3_256}, std::collections::HashMap, }; @@ -61,6 +62,12 @@ impl R1CS { self.a.num_cols } + #[must_use] + pub fn hash(&self) -> [u8; 32] { + let bytes = postcard::to_stdvec(self).expect("R1CS serialization failed"); + Sha3_256::digest(&bytes).into() + } + // Increase the size of the R1CS matrices to the specified dimensions. pub fn grow_matrices(&mut self, num_rows: usize, num_cols: usize) { self.a.grow(num_rows, num_cols); diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index e36713a5c..05c9714ff 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -27,6 +27,7 @@ pub struct WhirR1CSScheme { pub num_challenges: usize, pub has_public_inputs: bool, pub whir_witness: WhirZkConfig, + pub r1cs_hash: [u8; 32], } impl WhirR1CSScheme { diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index f7cf80db2..858d244ae 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -10,7 +10,7 @@ use { utils::{serde_ark, serde_ark_vec}, FieldElement, }, - ark_ff::{BigInt, One, PrimeField}, + ark_ff::{BigInt, BigInteger, One, PrimeField}, serde::{Deserialize, Serialize}, }; pub use { @@ -91,6 +91,16 @@ impl PublicInputs { _ => self.0.iter().copied().reduce(compress).unwrap(), } } + + #[must_use] + pub fn hash_bytes(&self) -> [u8; 32] { + let hash = self.hash(); + let bytes = hash.into_bigint().to_bytes_le(); + let mut result = [0u8; 32]; + let len = bytes.len().min(32); + result[..len].copy_from_slice(&bytes[..len]); + result + } } impl Default for PublicInputs { diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index bbac4de0a..fa80307df 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -5,7 +5,7 @@ use { r1cs::{CompressedLayers, CompressedR1CS}, whir_r1cs::WhirR1CSProver, }, - acir::native_types::WitnessMap, + acir::native_types::{Witness, WitnessMap}, anyhow::{Context, Result}, bn254_blackbox_solver::Bn254BlackBoxSolver, mavros_vm::interpreter as mavros_interpreter, @@ -13,12 +13,12 @@ use { noir_artifact_cli::fs::inputs::read_inputs_from_file, noirc_abi::InputMap, provekit_common::{ - FieldElement, MavrosProver, NoirElement, NoirProof, NoirProver, Prover, PublicInputs, - TranscriptSponge, + utils::noir_to_native, FieldElement, MavrosProver, NoirElement, NoirProof, NoirProver, + Prover, PublicInputs, TranscriptSponge, }, std::{mem::size_of, path::Path}, tracing::{debug, info_span, instrument}, - whir::transcript::{codecs::Empty, ProverState, VerifierMessage}, + whir::transcript::{ProverState, VerifierMessage}, }; pub mod input_utils; @@ -69,7 +69,24 @@ impl Prove for NoirProver { provekit_common::register_ntt(); let acir_witness_idx_to_value_map = generate_noir_witness(&mut self, input_map)?; - let num_public_inputs = self.program.functions[0].public_inputs().indices().len(); + + let mut public_input_indices = self.program.functions[0].public_inputs().indices(); + public_input_indices.sort_unstable(); + let public_inputs = if public_input_indices.is_empty() { + PublicInputs::new() + } else { + let values = public_input_indices + .iter() + .map(|&idx| { + let noir_val = acir_witness_idx_to_value_map + .get(&Witness::from(idx)) + .ok_or_else(|| anyhow::anyhow!("Missing public input at index {idx}"))?; + Ok(noir_to_native(*noir_val)) + }) + .collect::>>()?; + PublicInputs::from_vec(values) + }; + drop(self.program); drop(self.witness_generator); @@ -80,18 +97,17 @@ impl Prove for NoirProver { let num_witnesses = compressed_r1cs.num_witnesses(); let num_constraints = compressed_r1cs.num_constraints(); - // Set up transcript with sponge selected by hash_config. + // Set up transcript with public inputs bound to the instance. + let instance = public_inputs.hash_bytes(); let ds = self .whir_for_witness .create_domain_separator() - .instance(&Empty); + .instance(&instance); let mut merlin = ProverState::new(&ds, TranscriptSponge::from_config(self.hash_config)); let mut witness: Vec> = vec![None; num_witnesses]; // Solve w1 (or all witnesses if no challenges). - // Outer span captures memory AFTER w1_layers parameter is freed - // (parameter drop happens before outer span close). { let _s = info_span!("solve_w1").entered(); crate::r1cs::solve_witness_vec( @@ -178,17 +194,6 @@ impl Prove for NoirProver { r1cs.test_witness_satisfaction(&witness.iter().map(|w| w.unwrap()).collect::>()) .context("While verifying R1CS instance")?; - let public_inputs = if num_public_inputs == 0 { - PublicInputs::new() - } else { - PublicInputs::from_vec( - witness[1..=num_public_inputs] - .iter() - .map(|w| w.ok_or_else(|| anyhow::anyhow!("Missing public input witness"))) - .collect::>>()?, - ) - }; - let full_witness: Vec = witness .into_iter() .enumerate() @@ -227,10 +232,20 @@ impl Prove for MavrosProver { ¶ms, ); + let num_public_inputs = self.num_public_inputs; + let public_inputs = if num_public_inputs == 0 { + PublicInputs::new() + } else { + // TODO : Verify marvos prover's handling of public input + PublicInputs::from_vec(phase1.out_wit_pre_comm[1..=num_public_inputs].to_vec()) + }; + + // Set up transcript with public inputs bound to the instance. + let instance = public_inputs.hash_bytes(); let ds = self .whir_for_witness .create_domain_separator() - .instance(&Empty); + .instance(&instance); let mut merlin = ProverState::new(&ds, TranscriptSponge::from_config(self.hash_config)); let commitment_1 = self @@ -244,7 +259,7 @@ impl Prove for MavrosProver { ) .context("While committing to w1")?; - let (commitments, witgen_result) = if self.whir_for_witness.num_challenges > 0 { + let commitments = if self.whir_for_witness.num_challenges > 0 { let challenges: Vec = (0..self.witness_layout.challenges_size) .map(|_| merlin.verifier_message()) .collect(); @@ -267,22 +282,15 @@ impl Prove for MavrosProver { ) .context("While committing to w2")?; - (vec![commitment_1, commitment_2], witgen_result) + vec![commitment_1, commitment_2] } else { - let witgen_result = mavros_interpreter::run_phase2( + mavros_interpreter::run_phase2( phase1.clone(), &[], self.witness_layout, self.constraints_layout, ); - (vec![commitment_1], witgen_result) - }; - - let num_public_inputs = self.num_public_inputs; - let public_inputs = if num_public_inputs == 0 { - PublicInputs::new() - } else { - PublicInputs::from_vec(witgen_result.out_wit_pre_comm[1..=num_public_inputs].to_vec()) + vec![commitment_1] }; let whir_r1cs_proof = self diff --git a/provekit/r1cs-compiler/src/whir_r1cs.rs b/provekit/r1cs-compiler/src/whir_r1cs.rs index 0b18b19f5..526a1022a 100644 --- a/provekit/r1cs-compiler/src/whir_r1cs.rs +++ b/provekit/r1cs-compiler/src/whir_r1cs.rs @@ -76,6 +76,7 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { num_challenges, whir_witness: Self::new_whir_zk_config_for_size(m_raw, 1, hash_id), has_public_inputs, + r1cs_hash: r1cs.hash(), } } @@ -154,6 +155,7 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { w1_size, num_challenges, has_public_inputs, + r1cs_hash: [0u8; 32], // TODO: Mavros path needs r1cs_hash } } } diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 59288172b..daabee7ec 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -14,7 +14,7 @@ use { tracing::instrument, whir::{ algebra::linear_form::LinearForm, - transcript::{codecs::Empty, Proof, VerifierMessage, VerifierState}, + transcript::{Proof, VerifierMessage, VerifierState}, }, }; @@ -44,7 +44,8 @@ impl WhirR1CSVerifier for WhirR1CSScheme { r1cs: &R1CS, hash_config: HashConfig, ) -> Result<()> { - let ds = self.create_domain_separator().instance(&Empty); + let instance = public_inputs.hash_bytes(); + let ds = self.create_domain_separator().instance(&instance); let whir_proof = Proof { narg_string: proof.narg_string.clone(), hints: proof.hints.clone(),