From e69257360306befa28d0096f55f56d12bdbb63dd Mon Sep 17 00:00:00 2001 From: markosg04 Date: Tue, 25 Nov 2025 19:51:38 -0500 Subject: [PATCH 1/4] feat: recursion --- Cargo.lock | 91 +++ Cargo.toml | 6 + README.md | 60 +- examples/recursion.rs | 152 +++++ src/backends/arkworks/ark_witness.rs | 326 +++++++++++ src/backends/arkworks/mod.rs | 10 + src/evaluation_proof.rs | 296 ++++++++++ src/lib.rs | 94 ++++ src/recursion/collection.rs | 139 +++++ src/recursion/collector.rs | 271 +++++++++ src/recursion/context.rs | 254 +++++++++ src/recursion/hint_map.rs | 324 +++++++++++ src/recursion/mod.rs | 61 ++ src/recursion/trace.rs | 797 +++++++++++++++++++++++++++ src/recursion/witness.rs | 105 ++++ tests/arkworks/mod.rs | 4 + tests/arkworks/recursion.rs | 315 +++++++++++ tests/arkworks/witness.rs | 47 ++ 18 files changed, 3350 insertions(+), 2 deletions(-) create mode 100644 examples/recursion.rs create mode 100644 src/backends/arkworks/ark_witness.rs create mode 100644 src/recursion/collection.rs create mode 100644 src/recursion/collector.rs create mode 100644 src/recursion/context.rs create mode 100644 src/recursion/hint_map.rs create mode 100644 src/recursion/mod.rs create mode 100644 src/recursion/trace.rs create mode 100644 src/recursion/witness.rs create mode 100644 tests/arkworks/recursion.rs create mode 100644 tests/arkworks/witness.rs diff --git a/Cargo.lock b/Cargo.lock index 0c2ce1c..871a32c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -424,6 +424,7 @@ dependencies = [ "serde", "thiserror 2.0.17", "tracing", + "tracing-subscriber", ] [[package]] @@ -562,6 +563,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "libc" version = "0.2.177" @@ -578,12 +585,36 @@ dependencies = [ "libc", ] +[[package]] +name = "log" +version = "0.4.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + [[package]] name = "memchr" version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -851,6 +882,21 @@ dependencies = [ "serde_core", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + [[package]] name = "subtle" version = "2.6.1" @@ -908,6 +954,15 @@ dependencies = [ "syn", ] +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -947,6 +1002,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", ] [[package]] @@ -961,6 +1046,12 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "462eeb75aeb73aea900253ce739c8e18a67423fadf006037cd3ff27e82748a06" +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "version_check" version = "0.9.5" diff --git a/Cargo.toml b/Cargo.toml index 19195a2..d86c717 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ arkworks = [ parallel = ["dep:rayon", "ark-ec?/parallel", "ark-ff?/parallel"] cache = ["arkworks", "dep:once_cell", "parallel"] disk-persistence = ["dep:dirs"] +recursion = ["arkworks"] [dependencies] thiserror = "2.0" @@ -72,6 +73,7 @@ rayon = { version = "1.10", optional = true } [dev-dependencies] rand = "0.8" criterion = { version = "0.5", features = ["html_reports"] } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } [[example]] name = "basic_e2e" @@ -85,6 +87,10 @@ required-features = ["backends"] name = "non_square" required-features = ["backends"] +[[example]] +name = "recursion" +required-features = ["recursion"] + [[bench]] name = "arkworks_proof" harness = false diff --git a/README.md b/README.md index 89f9959..fac3b7d 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,45 @@ Com(r₁·P₁ + r₂·P₂ + ... + rₙ·Pₙ) = r₁·Com(P₁) + r₂·Com(P This property enables efficient proof aggregation and batch verification. See `examples/homomorphic.rs` for a demonstration. +### Recursive Proof Composition + +The `recursion` feature enables traced verification for building recursive SNARKs that compose Dory: + +1. **Witness Generation**: Run verification while capturing traces of all arithmetic operations (GT exponentiations, scalar multiplications, pairings, etc.) + +2. **Hint-Based Verification**: Re-run verification using pre-computed hints instead of performing expensive ops + +```rust +use std::rc::Rc; +use dory_pcs::{verify_recursive, setup, prove}; +use dory_pcs::backends::arkworks::{ + SimpleWitnessBackend, SimpleWitnessGenerator, BN254, G1Routines, G2Routines, +}; +use dory_pcs::recursion::TraceContext; + +type Ctx = TraceContext; + +// Phase 1: Witness generation - captures operation traces +let ctx = Rc::new(Ctx::for_witness_gen()); +verify_recursive::<_, BN254, G1Routines, G2Routines, _, _, _>( + commitment, evaluation, &point, &proof, setup.clone(), &mut transcript, ctx.clone(), +)?; + +let collection = Rc::try_unwrap(ctx).ok().unwrap().finalize().unwrap(); +// collection contains detailed witnesses for each operation + +// Convert to hints +let hints = collection.to_hints::(); + +// Phase 2: Hint-based verification +let ctx = Rc::new(Ctx::for_hints(hints)); +verify_recursive::<_, BN254, G1Routines, G2Routines, _, _, _>( + commitment, evaluation, &point, &proof, setup, &mut transcript, ctx, +)?; +``` + +See `examples/recursion.rs` for a complete demonstration. + ## Usage ```rust @@ -170,6 +209,11 @@ The repository includes three comprehensive examples demonstrating different asp cargo run --example non_square --features backends ``` +4. **`recursion`** - Trace generation and hint-based verification for recursive proof composition + ```bash + cargo run --example recursion --features recursion + ``` + ## Development Setup After cloning the repository, install Git hooks to ensure code quality: @@ -238,6 +282,7 @@ cargo bench --features backends,cache,parallel - `cache` - Enable prepared point caching for ~20-30% pairing speedup. Requires `arkworks` and `parallel`. - `parallel` - Enable parallelization using Rayon for MSMs and pairings. Works with both `arkworks` backend and enables parallel features in `ark-ec` and `ark-ff`. - `disk-persistence` - Enable automatic setup caching to disk. When enabled, `setup()` will load from OS-specific cache directories if available, avoiding regeneration. +- `recursion` - Enable traced verification for recursive proof composition. Provides witness generation and hint-based verification modes. ## Project Structure @@ -263,7 +308,15 @@ src/ ├── reduce_and_fold.rs # Inner product protocol ├── messages.rs # Protocol messages ├── proof.rs # Proof structure -└── error.rs # Error types +├── error.rs # Error types +└── recursion/ # Recursive verification support + ├── mod.rs # Module exports + ├── witness.rs # WitnessBackend, OpId, OpType traits/types + ├── context.rs # TraceContext for execution modes + ├── trace.rs # TraceG1, TraceG2, TraceGT wrappers + ├── collection.rs # WitnessCollection storage + ├── collector.rs # WitnessCollector and generator traits + └── hint_map.rs # Lightweight HintMap storage tests/arkworks/ ├── mod.rs # Test utilities @@ -271,7 +324,9 @@ tests/arkworks/ ├── commitment.rs # Commitment tests ├── evaluation.rs # Evaluation tests ├── integration.rs # End-to-end tests -└── soundness.rs # Soundness tests +├── soundness.rs # Soundness tests +├── recursion.rs # Trace and hint verification tests +└── witness.rs # Witness generation tests ``` ## Test Coverage @@ -285,6 +340,7 @@ The implementation includes comprehensive tests covering: - Non-square matrix support (nu < sigma, nu = sigma - 1, and very rectangular cases) - Soundness (tampering resistance for all proof components across 20+ attack vectors) - Prepared point caching correctness +- Recursive verification (witness generation and hint-based verification) ## Acknowledgments diff --git a/examples/recursion.rs b/examples/recursion.rs new file mode 100644 index 0000000..f6a353f --- /dev/null +++ b/examples/recursion.rs @@ -0,0 +1,152 @@ +//! Recursion example: trace generation and hint-based verification +//! +//! This example demonstrates the recursion API workflow: +//! 1. Standard proof generation +//! 2. Witness-generating verification (captures operation traces) +//! 3. Converting witnesses to hints +//! 4. Hint-based verification +//! +//! The hint-based verification enables efficient recursive proof composition. +//! +//! Run with: `cargo run --features recursion --example recursion` + +use std::rc::Rc; + +use dory_pcs::backends::arkworks::{ + ArkFr, ArkworksPolynomial, Blake2bTranscript, G1Routines, G2Routines, SimpleWitnessBackend, + SimpleWitnessGenerator, BN254, +}; +use dory_pcs::primitives::arithmetic::Field; +use dory_pcs::primitives::poly::Polynomial; +use dory_pcs::recursion::TraceContext; +use dory_pcs::{prove, setup, verify, verify_recursive}; +use rand::thread_rng; +use tracing::info; +use tracing_subscriber::EnvFilter; + +type Ctx = TraceContext; + +fn main() -> Result<(), Box> { + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), + ) + .init(); + + info!("Dory PCS - Recursion API Example"); + info!("=================================\n"); + + let mut rng = thread_rng(); + + // Step 1: Setup + let max_log_n = 8; + info!("1. Generating setup (max_log_n = {})...", max_log_n); + let (prover_setup, verifier_setup) = setup::(&mut rng, max_log_n); + info!(" Setup complete\n"); + + // Step 2: Create polynomial + let nu = 3; + let sigma = 3; + let poly_size = 1 << (nu + sigma); // 64 coefficients + let num_vars = nu + sigma; + + info!("2. Creating random polynomial..."); + info!(" Matrix layout: {}x{}", 1 << nu, 1 << sigma); + info!(" Total coefficients: {}", poly_size); + + let coefficients: Vec = (0..poly_size).map(|_| ArkFr::random(&mut rng)).collect(); + let poly = ArkworksPolynomial::new(coefficients); + + // Step 3: Commit + info!("\n3. Computing commitment..."); + let (tier_2, tier_1) = poly.commit::(nu, sigma, &prover_setup)?; + + // Step 4: Create evaluation proof + let point: Vec = (0..num_vars).map(|_| ArkFr::random(&mut rng)).collect(); + let evaluation = poly.evaluate(&point); + + info!("4. Generating proof..."); + let mut prover_transcript = Blake2bTranscript::new(b"dory-recursion-example"); + let proof = prove::<_, BN254, G1Routines, G2Routines, _, _>( + &poly, + &point, + tier_1, + nu, + sigma, + &prover_setup, + &mut prover_transcript, + )?; + + // Step 5: Standard verification) + info!("\n5. Standard verification..."); + let mut std_transcript = Blake2bTranscript::new(b"dory-recursion-example"); + verify::<_, BN254, G1Routines, G2Routines, _>( + tier_2, + evaluation, + &point, + &proof, + verifier_setup.clone(), + &mut std_transcript, + )?; + info!(" Standard verification passed\n"); + + // Step 6: Witness-generating verification + info!("6. Witness-generating verification..."); + info!(" This captures traces of all arithmetic operations"); + + let ctx = Rc::new(Ctx::for_witness_gen()); + let mut witness_transcript = Blake2bTranscript::new(b"dory-recursion-example"); + + verify_recursive::<_, BN254, G1Routines, G2Routines, _, _, _>( + tier_2, + evaluation, + &point, + &proof, + verifier_setup.clone(), + &mut witness_transcript, + ctx.clone(), + )?; + + // Finalize and get witness collection + let collection = Rc::try_unwrap(ctx) + .ok() + .expect("should have sole ownership") + .finalize() + .expect("should have witnesses"); + + info!(" Witness collection stats:"); + info!(" - GT exponentiation: {}", collection.gt_exp.len()); + info!(" - G1 scalar mul: {}", collection.g1_scalar_mul.len()); + info!(" - G2 scalar mul: {}", collection.g2_scalar_mul.len()); + info!(" - GT multiplication: {}", collection.gt_mul.len()); + info!(" - Single pairing: {}", collection.pairing.len()); + info!(" - Multi-pairing: {}", collection.multi_pairing.len()); + info!(" - G1 MSM: {}", collection.msm_g1.len()); + info!(" - G2 MSM: {}", collection.msm_g2.len()); + info!(" - Total operations: {}", collection.total_witnesses()); + info!(" - Reduce-fold rounds: {}\n", collection.num_rounds); + + // Step 7: Convert to hints + info!("7. Converting witnesses to hints..."); + let hints = collection.to_hints::(); + info!(" HintMap entries: {} (one per operation)", hints.len()); + + // Step 8: Hint-based verification + info!("8. Hint-based verification..."); + + let ctx = Rc::new(Ctx::for_hints(hints)); + let mut hint_transcript = Blake2bTranscript::new(b"dory-recursion-example"); + + verify_recursive::<_, BN254, G1Routines, G2Routines, _, _, _>( + tier_2, + evaluation, + &point, + &proof, + verifier_setup, + &mut hint_transcript, + ctx, + )?; + info!(" Hint-based verification passed\n"); + + Ok(()) +} diff --git a/src/backends/arkworks/ark_witness.rs b/src/backends/arkworks/ark_witness.rs new file mode 100644 index 0000000..3654a89 --- /dev/null +++ b/src/backends/arkworks/ark_witness.rs @@ -0,0 +1,326 @@ +//! Simple/testing witness types for recursive proof composition. +//! +//! This module provides basic witness structures that capture inputs and outputs +//! of arithmetic operations without detailed intermediate computation steps. +//! +//! For Jolt or other proof systems, we would provide a more involved witness gen and backend + +use super::{ArkFr, ArkG1, ArkG2, ArkGT, BN254}; +use crate::primitives::arithmetic::Group; +use crate::recursion::{WitnessBackend, WitnessGenerator, WitnessResult}; +use ark_ff::{BigInteger, PrimeField}; + +/// BN254 scalar field bit length +const SCALAR_BITS: usize = 254; + +/// Simplified witness backend for BN254 curve. +/// +/// This backend defines witness types that store inputs, outputs, and basic +/// scalar bit decompositions. Intermediate computation steps are mostly empty. +pub struct SimpleWitnessBackend; + +impl WitnessBackend for SimpleWitnessBackend { + type GtExpWitness = GtExpWitness; + type G1ScalarMulWitness = G1ScalarMulWitness; + type G2ScalarMulWitness = G2ScalarMulWitness; + type GtMulWitness = GtMulWitness; + type PairingWitness = PairingWitness; + type MultiPairingWitness = MultiPairingWitness; + type MsmG1Witness = MsmG1Witness; + type MsmG2Witness = MsmG2Witness; +} + +/// Witness for GT exponentiation using square-and-multiply. +/// +/// Captures the intermediate values during exponentiation: base^scalar. +/// In GT (multiplicative group), this is computed as repeated squaring and multiplication. +#[derive(Clone, Debug)] +pub struct GtExpWitness { + /// The base element being exponentiated + pub base: ArkGT, + /// Scalar decomposed into bits (LSB first) + pub scalar_bits: Vec, + /// Intermediate squaring results: base, base^2, base^4, ... + pub squares: Vec, + /// Running accumulator after processing each bit + pub accumulators: Vec, + /// Final result: base^scalar + pub result: ArkGT, +} + +impl WitnessResult for GtExpWitness { + fn result(&self) -> &ArkGT { + &self.result + } +} + +/// Witness for G1 scalar multiplication using double-and-add. +#[derive(Clone, Debug)] +pub struct G1ScalarMulWitness { + /// The point being scaled + pub point: ArkG1, + /// Scalar decomposed into bits (LSB first) + pub scalar_bits: Vec, + /// Intermediate doubling results: P, 2P, 4P, ... + pub doubles: Vec, + /// Running accumulator after processing each bit + pub accumulators: Vec, + /// Final result: point * scalar + pub result: ArkG1, +} + +impl WitnessResult for G1ScalarMulWitness { + fn result(&self) -> &ArkG1 { + &self.result + } +} + +/// Witness for G2 scalar multiplication using double-and-add. +#[derive(Clone, Debug)] +pub struct G2ScalarMulWitness { + /// The point being scaled + pub point: ArkG2, + /// Scalar decomposed into bits (LSB first) + pub scalar_bits: Vec, + /// Intermediate doubling results: P, 2P, 4P, ... + pub doubles: Vec, + /// Running accumulator after processing each bit + pub accumulators: Vec, + /// Final result: point * scalar + pub result: ArkG2, +} + +impl WitnessResult for G2ScalarMulWitness { + fn result(&self) -> &ArkG2 { + &self.result + } +} + +/// Witness for GT multiplication (Fq12 multiplication). +/// +/// Since GT is a multiplicative group, "group addition" is field multiplication. +#[derive(Clone, Debug)] +pub struct GtMulWitness { + /// Left operand + pub lhs: ArkGT, + /// Right operand + pub rhs: ArkGT, + /// Intermediate values during Fq12 multiplication (Karatsuba steps) + pub intermediates: Vec, + /// Final result: lhs * rhs + pub result: ArkGT, +} + +impl WitnessResult for GtMulWitness { + fn result(&self) -> &ArkGT { + &self.result + } +} + +/// Single step in the Miller loop computation. +#[derive(Clone, Debug)] +pub struct MillerStep { + /// Line evaluation at this step + pub line_eval: ArkGT, + /// Accumulated value after this step + pub accumulator: ArkGT, +} + +/// Witness for single pairing e(G1, G2) -> GT. +/// +/// Captures the Miller loop iterations and final exponentiation. +#[derive(Clone, Debug)] +pub struct PairingWitness { + /// G1 input point + pub g1: ArkG1, + /// G2 input point + pub g2: ArkG2, + /// Miller loop step-by-step trace + pub miller_steps: Vec, + /// Final exponentiation intermediate values + pub final_exp_steps: Vec, + /// Final pairing result + pub result: ArkGT, +} + +impl WitnessResult for PairingWitness { + fn result(&self) -> &ArkGT { + &self.result + } +} + +/// Witness for multi-pairing: `∏ e(g1s[i], g2s[i])`. +#[derive(Clone, Debug)] +pub struct MultiPairingWitness { + /// G1 input points + pub g1s: Vec, + /// G2 input points + pub g2s: Vec, + /// Miller loop traces for each pair + pub individual_millers: Vec>, + /// Combined Miller loop result before final exponentiation + pub combined_miller: ArkGT, + /// Final exponentiation steps + pub final_exp_steps: Vec, + /// Final multi-pairing result + pub result: ArkGT, +} + +impl WitnessResult for MultiPairingWitness { + fn result(&self) -> &ArkGT { + &self.result + } +} + +/// Witness for G1 multi-scalar multiplication. +/// +/// For detailed Pippenger algorithm traces, stores bucket states. +#[derive(Clone, Debug)] +pub struct MsmG1Witness { + /// Base points + pub bases: Vec, + /// Scalar values + pub scalars: Vec, + /// Bucket sums (simplified - actual Pippenger has more structure) + pub bucket_sums: Vec, + /// Running sum intermediates + pub running_sums: Vec, + /// Final MSM result + pub result: ArkG1, +} + +impl WitnessResult for MsmG1Witness { + fn result(&self) -> &ArkG1 { + &self.result + } +} + +/// Witness for G2 multi-scalar multiplication. +#[derive(Clone, Debug)] +pub struct MsmG2Witness { + /// Base points + pub bases: Vec, + /// Scalar values + pub scalars: Vec, + /// Bucket sums + pub bucket_sums: Vec, + /// Running sum intermediates + pub running_sums: Vec, + /// Final MSM result + pub result: ArkG2, +} + +impl WitnessResult for MsmG2Witness { + fn result(&self) -> &ArkG2 { + &self.result + } +} + +/// Simplified witness generator for the Arkworks backend. +/// +/// This generator creates basic witnesses with inputs, outputs, and scalar +/// bit decompositions. Most intermediate traces are empty. +pub struct SimpleWitnessGenerator; + +impl WitnessGenerator for SimpleWitnessGenerator { + fn generate_gt_exp(base: &ArkGT, scalar: &ArkFr, result: &ArkGT) -> GtExpWitness { + // Get scalar bits (LSB first) + let bigint = scalar.0.into_bigint(); + let scalar_bits: Vec = (0..SCALAR_BITS).map(|i| bigint.get_bit(i)).collect(); + + // Doesn't record intermediate results + let squares = vec![*base]; + let accumulators = vec![*result]; + + GtExpWitness { + base: *base, + scalar_bits, + squares, + accumulators, + result: *result, + } + } + + fn generate_g1_scalar_mul(point: &ArkG1, scalar: &ArkFr, result: &ArkG1) -> G1ScalarMulWitness { + let bigint = scalar.0.into_bigint(); + let scalar_bits: Vec = (0..SCALAR_BITS).map(|i| bigint.get_bit(i)).collect(); + + // Doesn't record intermediate results + let doubles = vec![*point]; + let accumulators = vec![*result]; + + G1ScalarMulWitness { + point: *point, + scalar_bits, + doubles, + accumulators, + result: *result, + } + } + + fn generate_g2_scalar_mul(point: &ArkG2, scalar: &ArkFr, result: &ArkG2) -> G2ScalarMulWitness { + let bigint = scalar.0.into_bigint(); + let scalar_bits: Vec = (0..SCALAR_BITS).map(|i| bigint.get_bit(i)).collect(); + + let doubles = vec![*point]; + let accumulators = vec![*result]; + + G2ScalarMulWitness { + point: *point, + scalar_bits, + doubles, + accumulators, + result: *result, + } + } + + fn generate_gt_mul(lhs: &ArkGT, rhs: &ArkGT, result: &ArkGT) -> GtMulWitness { + GtMulWitness { + lhs: *lhs, + rhs: *rhs, + intermediates: vec![], + result: *result, + } + } + + fn generate_pairing(g1: &ArkG1, g2: &ArkG2, result: &ArkGT) -> PairingWitness { + PairingWitness { + g1: *g1, + g2: *g2, + miller_steps: vec![], + final_exp_steps: vec![], + result: *result, + } + } + + fn generate_multi_pairing(g1s: &[ArkG1], g2s: &[ArkG2], result: &ArkGT) -> MultiPairingWitness { + MultiPairingWitness { + g1s: g1s.to_vec(), + g2s: g2s.to_vec(), + individual_millers: vec![], + combined_miller: ArkGT::identity(), + final_exp_steps: vec![], + result: *result, + } + } + + fn generate_msm_g1(bases: &[ArkG1], scalars: &[ArkFr], result: &ArkG1) -> MsmG1Witness { + MsmG1Witness { + bases: bases.to_vec(), + scalars: scalars.to_vec(), + bucket_sums: vec![], + running_sums: vec![], + result: *result, + } + } + + fn generate_msm_g2(bases: &[ArkG2], scalars: &[ArkFr], result: &ArkG2) -> MsmG2Witness { + MsmG2Witness { + bases: bases.to_vec(), + scalars: scalars.to_vec(), + bucket_sums: vec![], + running_sums: vec![], + result: *result, + } + } +} diff --git a/src/backends/arkworks/mod.rs b/src/backends/arkworks/mod.rs index 63372a4..a716183 100644 --- a/src/backends/arkworks/mod.rs +++ b/src/backends/arkworks/mod.rs @@ -12,6 +12,9 @@ mod blake2b_transcript; #[cfg(feature = "cache")] pub mod ark_cache; +#[cfg(feature = "recursion")] +mod ark_witness; + pub use ark_field::ArkFr; pub use ark_group::{ArkG1, ArkG2, ArkGT, G1Routines, G2Routines}; pub use ark_pairing::BN254; @@ -22,3 +25,10 @@ pub use blake2b_transcript::Blake2bTranscript; #[cfg(feature = "cache")] pub use ark_cache::{get_prepared_g1, get_prepared_g2, init_cache, is_cached}; + +#[cfg(feature = "recursion")] +pub use ark_witness::{ + G1ScalarMulWitness, G2ScalarMulWitness, GtExpWitness, GtMulWitness, MillerStep, MsmG1Witness, + MsmG2Witness, MultiPairingWitness, PairingWitness, SimpleWitnessBackend, + SimpleWitnessGenerator, +}; diff --git a/src/evaluation_proof.rs b/src/evaluation_proof.rs index d1bbbcf..270717e 100644 --- a/src/evaluation_proof.rs +++ b/src/evaluation_proof.rs @@ -34,6 +34,9 @@ use crate::proof::DoryProof; use crate::reduce_and_fold::{DoryProverState, DoryVerifierState}; use crate::setup::{ProverSetup, VerifierSetup}; +#[cfg(feature = "recursion")] +use crate::recursion::{WitnessBackend, WitnessGenerator}; + /// Create evaluation proof for a polynomial at a point /// /// Implements Eval-VMV-RE protocol from Dory Section 5. @@ -366,3 +369,296 @@ where verifier_state.verify_final(&proof.final_message, &gamma, &d) } + +/// Verify an evaluation proof with automatic operation tracing. +/// +/// This function verifies a Dory evaluation proof while automatically tracing +/// all expensive arithmetic operations through the provided +/// [`TraceContext`](crate::recursion::TraceContext). The context determines the behavior: +/// +/// - **Witness Generation Mode**: All operations are computed and their witnesses +/// are recorded in the context's collector. +/// - **Hint-Based Mode**: Operations use pre-computed hints when available, +/// falling back to computation with a warning when hints are missing. +/// +/// # Parameters +/// - `commitment`: Polynomial commitment (in GT) +/// - `evaluation`: Claimed evaluation result +/// - `point`: Evaluation point (length must equal proof.nu + proof.sigma) +/// - `proof`: Evaluation proof to verify +/// - `setup`: Verifier setup +/// - `transcript`: Fiat-Shamir transcript for challenge generation +/// - `ctx`: Trace context (from `TraceContext::for_witness_gen()` or `TraceContext::for_hints()`) +/// +/// # Returns +/// `Ok(())` if proof is valid, `Err(DoryError)` otherwise. +/// +/// After verification, call `ctx.finalize()` to get the collected witnesses +/// (in witness generation mode) or check `ctx.had_missing_hints()` to see +/// if any hints were missing (in hint-based mode). +/// +/// # Errors +/// Returns `DoryError::InvalidProof` if verification fails, or +/// `DoryError::InvalidPointDimension` if point length doesn't match proof dimensions. +/// +/// # Panics +/// Panics if transcript challenge scalars (alpha, beta, gamma, d) are zero +/// (if this happens, go buy a lottery ticket) +/// +/// # Example +/// +/// ```ignore +/// use std::rc::Rc; +/// use dory_pcs::recursion::TraceContext; +/// +/// // Witness generation mode +/// let ctx = Rc::new(TraceContext::for_witness_gen()); +/// verify_recursive(commitment, evaluation, &point, &proof, setup.clone(), &mut transcript, ctx.clone())?; +/// let witnesses = Rc::try_unwrap(ctx).ok().unwrap().finalize(); +/// +/// // Hint-based mode +/// let hints = witnesses.to_hints::(); +/// let ctx = Rc::new(TraceContext::for_hints(hints)); +/// verify_recursive(commitment, evaluation, &point, &proof, setup, &mut transcript, ctx)?; +/// +/// TODO(markosg04) this unrolls all the reduce_and_fold fns. We could make it more ergonomic by not unrolling. +/// ``` +#[cfg(feature = "recursion")] +#[tracing::instrument(skip_all, name = "verify_recursive")] +#[allow(clippy::too_many_arguments)] +pub fn verify_recursive( + commitment: E::GT, + evaluation: F, + point: &[F], + proof: &DoryProof, + setup: VerifierSetup, + transcript: &mut T, + ctx: crate::recursion::CtxHandle, +) -> Result<(), DoryError> +where + F: Field, + E: PairingCurve, + E::G1: Group, + E::G2: Group, + E::GT: Group, + M1: DoryRoutines, + M2: DoryRoutines, + T: Transcript, + W: WitnessBackend, + Gen: WitnessGenerator, +{ + use crate::recursion::{TraceG1, TraceG2, TraceGT, TracePairing}; + use std::rc::Rc; + + let nu = proof.nu; + let sigma = proof.sigma; + + if point.len() != nu + sigma { + return Err(DoryError::InvalidPointDimension { + expected: nu + sigma, + actual: point.len(), + }); + } + + let vmv_message = &proof.vmv_message; + transcript.append_serde(b"vmv_c", &vmv_message.c); + transcript.append_serde(b"vmv_d2", &vmv_message.d2); + transcript.append_serde(b"vmv_e1", &vmv_message.e1); + + // Create trace operators + let pairing = TracePairing::new(Rc::clone(&ctx)); + + // VMV check pairing: d2 == e(e1, h2) + let e1_trace = TraceG1::new(vmv_message.e1, Rc::clone(&ctx)); + let h2_trace = TraceG2::new(setup.h2, Rc::clone(&ctx)); + let pairing_check = pairing.pair(&e1_trace, &h2_trace); + + if vmv_message.d2 != *pairing_check.inner() { + return Err(DoryError::InvalidProof); + } + + // e2 = h2 * evaluation (traced G2 scalar mul) + let e2 = h2_trace.scale(&evaluation); + + let num_rounds = sigma; + let col_coords = &point[..sigma]; + let s1_coords: Vec = col_coords.to_vec(); + let mut s2_coords: Vec = vec![F::zero(); sigma]; + let row_coords = &point[sigma..sigma + nu]; + s2_coords[..nu].copy_from_slice(&row_coords[..nu]); + + // Initialize traced verifier state + let mut c = TraceGT::new(vmv_message.c, Rc::clone(&ctx)); + let mut d1 = TraceGT::new(commitment, Rc::clone(&ctx)); + let mut d2 = TraceGT::new(vmv_message.d2, Rc::clone(&ctx)); + let mut e1 = TraceG1::new(vmv_message.e1, Rc::clone(&ctx)); + let mut e2_state = e2; + let mut s1_acc = F::one(); + let mut s2_acc = F::one(); + let mut remaining_rounds = num_rounds; + + ctx.set_num_rounds(num_rounds); + + // Process each round with automatic tracing + for round in 0..num_rounds { + ctx.advance_round(); + let first_msg = &proof.first_messages[round]; + let second_msg = &proof.second_messages[round]; + + transcript.append_serde(b"d1_left", &first_msg.d1_left); + transcript.append_serde(b"d1_right", &first_msg.d1_right); + transcript.append_serde(b"d2_left", &first_msg.d2_left); + transcript.append_serde(b"d2_right", &first_msg.d2_right); + transcript.append_serde(b"e1_beta", &first_msg.e1_beta); + transcript.append_serde(b"e2_beta", &first_msg.e2_beta); + let beta = transcript.challenge_scalar(b"beta"); + + transcript.append_serde(b"c_plus", &second_msg.c_plus); + transcript.append_serde(b"c_minus", &second_msg.c_minus); + transcript.append_serde(b"e1_plus", &second_msg.e1_plus); + transcript.append_serde(b"e1_minus", &second_msg.e1_minus); + transcript.append_serde(b"e2_plus", &second_msg.e2_plus); + transcript.append_serde(b"e2_minus", &second_msg.e2_minus); + let alpha = transcript.challenge_scalar(b"alpha"); + + let alpha_inv = alpha.inv().expect("alpha must be invertible"); + let beta_inv = beta.inv().expect("beta must be invertible"); + + // Update C with traced operations + let chi = &setup.chi[remaining_rounds]; + c = c + TraceGT::new(*chi, Rc::clone(&ctx)); + + // d2.scale(beta) - traced GT exp + let d2_scaled = d2.scale(&beta); + // c + d2_scaled - traced GT mul (via Add impl) + c = c + d2_scaled; + + // d1.scale(beta_inv) - traced GT exp + let d1_scaled = d1.scale(&beta_inv); + c = c + d1_scaled; + + // c_plus.scale(alpha) - traced GT exp + let c_plus_trace = TraceGT::new(second_msg.c_plus, Rc::clone(&ctx)); + let c_plus_scaled = c_plus_trace.scale(&alpha); + c = c + c_plus_scaled; + + // c_minus.scale(alpha_inv) - traced GT exp + let c_minus_trace = TraceGT::new(second_msg.c_minus, Rc::clone(&ctx)); + let c_minus_scaled = c_minus_trace.scale(&alpha_inv); + c = c + c_minus_scaled; + + // Update D1 (GT operations - traced via scale and add) + let delta_1l = &setup.delta_1l[remaining_rounds]; + let delta_1r = &setup.delta_1r[remaining_rounds]; + let alpha_beta = alpha * beta; + let d1_left_trace = TraceGT::new(first_msg.d1_left, Rc::clone(&ctx)); + d1 = d1_left_trace.scale(&alpha); + d1 = d1 + TraceGT::new(first_msg.d1_right, Rc::clone(&ctx)); + let delta_1l_trace = TraceGT::new(*delta_1l, Rc::clone(&ctx)); + d1 = d1 + delta_1l_trace.scale(&alpha_beta); + let delta_1r_trace = TraceGT::new(*delta_1r, Rc::clone(&ctx)); + d1 = d1 + delta_1r_trace.scale(&beta); + + // Update D2 (GT operations - traced via scale and add) + let delta_2l = &setup.delta_2l[remaining_rounds]; + let delta_2r = &setup.delta_2r[remaining_rounds]; + let alpha_inv_beta_inv = alpha_inv * beta_inv; + let d2_left_trace = TraceGT::new(first_msg.d2_left, Rc::clone(&ctx)); + d2 = d2_left_trace.scale(&alpha_inv); + d2 = d2 + TraceGT::new(first_msg.d2_right, Rc::clone(&ctx)); + let delta_2l_trace = TraceGT::new(*delta_2l, Rc::clone(&ctx)); + d2 = d2 + delta_2l_trace.scale(&alpha_inv_beta_inv); + let delta_2r_trace = TraceGT::new(*delta_2r, Rc::clone(&ctx)); + d2 = d2 + delta_2r_trace.scale(&beta_inv); + + // Update E1 (G1 operations - traced via scale) + let e1_beta_trace = TraceG1::new(first_msg.e1_beta, Rc::clone(&ctx)); + let e1_beta_scaled = e1_beta_trace.scale(&beta); + e1 = e1 + e1_beta_scaled; + let e1_plus_trace = TraceG1::new(second_msg.e1_plus, Rc::clone(&ctx)); + e1 = e1 + e1_plus_trace.scale(&alpha); + let e1_minus_trace = TraceG1::new(second_msg.e1_minus, Rc::clone(&ctx)); + e1 = e1 + e1_minus_trace.scale(&alpha_inv); + + // Update E2 (G2 operations - traced via scale) + let e2_beta_trace = TraceG2::new(first_msg.e2_beta, Rc::clone(&ctx)); + let e2_beta_scaled = e2_beta_trace.scale(&beta_inv); + e2_state = e2_state + e2_beta_scaled; + let e2_plus_trace = TraceG2::new(second_msg.e2_plus, Rc::clone(&ctx)); + e2_state = e2_state + e2_plus_trace.scale(&alpha); + let e2_minus_trace = TraceG2::new(second_msg.e2_minus, Rc::clone(&ctx)); + e2_state = e2_state + e2_minus_trace.scale(&alpha_inv); + + // Update scalar accumulators (field ops, not traced) + let idx = remaining_rounds - 1; + let y_t = s1_coords[idx]; + let x_t = s2_coords[idx]; + let one = F::one(); + let s1_term = alpha * (one - y_t) + y_t; + let s2_term = alpha_inv * (one - x_t) + x_t; + s1_acc = s1_acc * s1_term; + s2_acc = s2_acc * s2_term; + + remaining_rounds -= 1; + } + + ctx.enter_final(); + + let gamma = transcript.challenge_scalar(b"gamma"); + let d_challenge = transcript.challenge_scalar(b"d"); + + let gamma_inv = gamma.inv().expect("gamma must be invertible"); + let d_inv = d_challenge.inv().expect("d must be invertible"); + + // Final verification with tracing + let s_product = s1_acc * s2_acc; + let ht_trace = TraceGT::new(setup.ht, Rc::clone(&ctx)); + let ht_scaled = ht_trace.scale(&s_product); + c = c + ht_scaled; + + // Traced pairings + let h1_trace = TraceG1::new(setup.h1, Rc::clone(&ctx)); + let pairing_h1_e2 = pairing.pair(&h1_trace, &e2_state); + let pairing_e1_h2 = pairing.pair(&e1, &h2_trace); + + c = c + pairing_h1_e2.scale(&gamma); + c = c + pairing_e1_h2.scale(&gamma_inv); + + // D1 update with traced operations + let scalar_for_g2_in_d1 = s1_acc * gamma; + let g2_0_trace = TraceG2::new(setup.g2_0, Rc::clone(&ctx)); + let g2_0_scaled = g2_0_trace.scale(&scalar_for_g2_in_d1); + + let pairing_h1_g2 = pairing.pair(&h1_trace, &g2_0_scaled); + d1 = d1 + pairing_h1_g2; + + // D2 update with traced operations + let scalar_for_g1_in_d2 = s2_acc * gamma_inv; + let g1_0_trace = TraceG1::new(setup.g1_0, Rc::clone(&ctx)); + let g1_0_scaled = g1_0_trace.scale(&scalar_for_g1_in_d2); + + let pairing_g1_h2 = pairing.pair(&g1_0_scaled, &h2_trace); + d2 = d2 + pairing_g1_h2; + + // Final pairing check + let e1_final = TraceG1::new(proof.final_message.e1, Rc::clone(&ctx)); + let g1_0_d_scaled = g1_0_trace.scale(&d_challenge); + let e1_modified = e1_final + g1_0_d_scaled; + + let e2_final = TraceG2::new(proof.final_message.e2, Rc::clone(&ctx)); + let g2_0_d_inv_scaled = g2_0_trace.scale(&d_inv); + let e2_modified = e2_final + g2_0_d_inv_scaled; + + let lhs = pairing.pair(&e1_modified, &e2_modified); + + let mut rhs = c; + rhs = rhs + TraceGT::new(setup.chi[0], Rc::clone(&ctx)); + rhs = rhs + d2.scale(&d_challenge); + rhs = rhs + d1.scale(&d_inv); + + if *lhs.inner() == *rhs.inner() { + Ok(()) + } else { + Err(DoryError::InvalidProof) + } +} diff --git a/src/lib.rs b/src/lib.rs index 37940e5..c014f4b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,6 +100,9 @@ pub mod setup; #[cfg(feature = "arkworks")] pub mod backends; +#[cfg(feature = "recursion")] +pub mod recursion; + pub use error::DoryError; pub use evaluation_proof::create_evaluation_proof; pub use messages::{FirstReduceMessage, ScalarProductMessage, SecondReduceMessage, VMVMessage}; @@ -107,6 +110,8 @@ use primitives::arithmetic::{DoryRoutines, Field, Group, PairingCurve}; pub use primitives::poly::{MultilinearLagrange, Polynomial}; use primitives::serialization::{DoryDeserialize, DorySerialize}; pub use proof::DoryProof; +#[cfg(feature = "recursion")] +use recursion::WitnessBackend; pub use reduce_and_fold::{DoryProverState, DoryVerifierState}; pub use setup::{ProverSetup, VerifierSetup}; @@ -338,3 +343,92 @@ where commitment, evaluation, point, proof, setup, transcript, ) } + +/// Verifies an evaluation proof with automatic operation tracing. +/// +/// This function verifies a Dory evaluation proof while automatically tracing +/// all expensive arithmetic operations through the provided +/// [`TraceContext`](recursion::TraceContext). The context determines the behavior: +/// +/// - **Witness Generation Mode**: Create context with +/// [`TraceContext::for_witness_gen()`](recursion::TraceContext::for_witness_gen). +/// All operations are computed and their witnesses are recorded. +/// +/// - **Hint-Based Mode**: Create context with +/// [`TraceContext::for_hints(hints)`](recursion::TraceContext::for_hints). +/// Operations use pre-computed hints when available, falling back to computation +/// with a warning when hints are missing. +/// +/// # Arguments +/// +/// - `commitment`: The polynomial commitment (tier-2/GT element) +/// - `evaluation`: The claimed evaluation value +/// - `point`: The evaluation point +/// - `proof`: The Dory proof +/// - `setup`: Verifier setup parameters +/// - `transcript`: Fiat-Shamir transcript +/// - `ctx`: Trace context handle (use `Rc::new(TraceContext::for_witness_gen())` or +/// `Rc::new(TraceContext::for_hints(hints))`) +/// +/// # Returns +/// +/// `Ok(())` if the proof is valid. +/// +/// After verification: +/// - In witness generation mode: Call `Rc::try_unwrap(ctx).ok().unwrap().finalize()` +/// to get the collected witnesses. +/// - In hint-based mode: Check `ctx.had_missing_hints()` to see if any hints were missing. +/// +/// # Example +/// +/// ```ignore +/// use std::rc::Rc; +/// use dory_pcs::recursion::TraceContext; +/// +/// // Witness generation +/// let ctx = Rc::new(TraceContext::for_witness_gen()); +/// verify_recursive::<_, E, M1, M2, _, W, Gen>( +/// commitment, evaluation, &point, &proof, setup.clone(), &mut transcript, ctx.clone() +/// )?; +/// let witnesses = Rc::try_unwrap(ctx).ok().unwrap().finalize(); +/// +/// // Convert to lightweight hints +/// let hints = witnesses.unwrap().to_hints::(); +/// +/// // Hint-based verification +/// let ctx = Rc::new(TraceContext::for_hints(hints)); +/// verify_recursive::<_, E, M1, M2, _, W, Gen>( +/// commitment, evaluation, &point, &proof, setup, &mut transcript, ctx +/// )?; +/// ``` +/// +/// # Errors +/// +/// Returns `DoryError::InvalidProof` if verification fails. +#[cfg(feature = "recursion")] +#[allow(clippy::too_many_arguments)] +pub fn verify_recursive( + commitment: E::GT, + evaluation: F, + point: &[F], + proof: &DoryProof, + setup: VerifierSetup, + transcript: &mut T, + ctx: recursion::CtxHandle, +) -> Result<(), DoryError> +where + F: Field, + E: PairingCurve + Clone, + E::G1: Group, + E::G2: Group, + E::GT: Group, + M1: DoryRoutines, + M2: DoryRoutines, + T: primitives::transcript::Transcript, + W: WitnessBackend, + Gen: recursion::WitnessGenerator, +{ + evaluation_proof::verify_recursive::( + commitment, evaluation, point, proof, setup, transcript, ctx, + ) +} diff --git a/src/recursion/collection.rs b/src/recursion/collection.rs new file mode 100644 index 0000000..bc724df --- /dev/null +++ b/src/recursion/collection.rs @@ -0,0 +1,139 @@ +//! Witness collection storage for recursive proof composition. + +use std::collections::HashMap; + +use super::hint_map::HintMap; +use super::witness::{OpId, WitnessBackend, WitnessResult}; +use crate::primitives::arithmetic::PairingCurve; + +/// Storage for all witnesses collected during a verification run. +/// +/// This struct holds witnesses for each type of arithmetic operation, indexed +/// by their [`OpId`]. It is produced internally during witness generation and can +/// be converted to a [`HintMap`](crate::recursion::HintMap) for hint-based verification. +/// +/// # Type Parameters +/// +/// - `W`: The witness backend defining concrete witness types +pub struct WitnessCollection { + /// Number of reduce-and-fold rounds in the verification + pub num_rounds: usize, + + /// GT exponentiation witnesses (base^scalar) + pub gt_exp: HashMap, + + /// G1 scalar multiplication witnesses + pub g1_scalar_mul: HashMap, + + /// G2 scalar multiplication witnesses + pub g2_scalar_mul: HashMap, + + /// GT multiplication witnesses + pub gt_mul: HashMap, + + /// Single pairing witnesses + pub pairing: HashMap, + + /// Multi-pairing witnesses + pub multi_pairing: HashMap, + + /// G1 MSM witnesses + pub msm_g1: HashMap, + + /// G2 MSM witnesses + pub msm_g2: HashMap, +} + +impl WitnessCollection { + /// Create an empty witness collection. + pub fn new() -> Self { + Self { + num_rounds: 0, + gt_exp: HashMap::new(), + g1_scalar_mul: HashMap::new(), + g2_scalar_mul: HashMap::new(), + gt_mul: HashMap::new(), + pairing: HashMap::new(), + multi_pairing: HashMap::new(), + msm_g1: HashMap::new(), + msm_g2: HashMap::new(), + } + } + + /// Total number of witnesses across all operation types. + pub fn total_witnesses(&self) -> usize { + self.gt_exp.len() + + self.g1_scalar_mul.len() + + self.g2_scalar_mul.len() + + self.gt_mul.len() + + self.pairing.len() + + self.multi_pairing.len() + + self.msm_g1.len() + + self.msm_g2.len() + } + + /// Check if the collection is empty. + pub fn is_empty(&self) -> bool { + self.total_witnesses() == 0 + } +} + +impl Default for WitnessCollection { + fn default() -> Self { + Self::new() + } +} + +impl WitnessCollection { + /// Convert full witness collection to hints (outputs only). + /// + /// # Type Parameters + /// + /// - `E`: The pairing curve whose group elements are stored in the witnesses + pub fn to_hints(&self) -> HintMap + where + E: PairingCurve, + W::GtExpWitness: WitnessResult, + W::G1ScalarMulWitness: WitnessResult, + W::G2ScalarMulWitness: WitnessResult, + W::GtMulWitness: WitnessResult, + W::PairingWitness: WitnessResult, + W::MultiPairingWitness: WitnessResult, + W::MsmG1Witness: WitnessResult, + W::MsmG2Witness: WitnessResult, + { + let mut hints = HintMap::new(self.num_rounds); + + // Extract GT results + for (id, w) in &self.gt_exp { + hints.insert_gt(*id, *w.result()); + } + for (id, w) in &self.gt_mul { + hints.insert_gt(*id, *w.result()); + } + for (id, w) in &self.pairing { + hints.insert_gt(*id, *w.result()); + } + for (id, w) in &self.multi_pairing { + hints.insert_gt(*id, *w.result()); + } + + // Extract G1 results + for (id, w) in &self.g1_scalar_mul { + hints.insert_g1(*id, *w.result()); + } + for (id, w) in &self.msm_g1 { + hints.insert_g1(*id, *w.result()); + } + + // Extract G2 results + for (id, w) in &self.g2_scalar_mul { + hints.insert_g2(*id, *w.result()); + } + for (id, w) in &self.msm_g2 { + hints.insert_g2(*id, *w.result()); + } + + hints + } +} diff --git a/src/recursion/collector.rs b/src/recursion/collector.rs new file mode 100644 index 0000000..39a23b0 --- /dev/null +++ b/src/recursion/collector.rs @@ -0,0 +1,271 @@ +//! Witness collection for recursive proof composition. + +use std::collections::HashMap; +use std::marker::PhantomData; + +use super::witness::{OpId, OpType, WitnessBackend}; +use crate::primitives::arithmetic::{Group, PairingCurve}; + +use super::WitnessCollection; + +/// Builder for tracking operation IDs during witness collection. +/// +/// Maintains counters for each operation type within a round, +/// providing deterministic operation IDs. +#[derive(Debug, Clone)] +pub(crate) struct OpIdBuilder { + current_round: u16, + counters: HashMap, +} + +impl OpIdBuilder { + /// Create a new builder starting at round 0 (VMV phase). + pub(crate) fn new() -> Self { + Self { + current_round: 0, + counters: HashMap::new(), + } + } + + /// Advance to the next round. + pub(crate) fn advance_round(&mut self) { + self.current_round += 1; + self.counters.clear(); + } + + /// Enter the final verification phase (base case of Dory reduce) + pub(crate) fn enter_final(&mut self) { + self.current_round = u16::MAX; + self.counters.clear(); + } + + /// Get the current round number. + pub(crate) fn round(&self) -> u16 { + self.current_round + } + + /// Generate the next operation ID for the given type. + pub(crate) fn next(&mut self, op_type: OpType) -> OpId { + let index = self.counters.entry(op_type).or_insert(0); + let id = OpId::new(self.current_round, op_type, *index); + *index += 1; + id + } +} + +impl Default for OpIdBuilder { + fn default() -> Self { + Self::new() + } +} + +/// Trait for generating detailed witness traces from operation inputs/outputs. +/// +/// Backend implementations provide this to create witnesses with intermediate +/// computation steps (e.g., Miller loop iterations, square-and-multiply steps). +pub trait WitnessGenerator { + /// Generate a GT exponentiation witness with intermediate steps. + fn generate_gt_exp( + base: &E::GT, + scalar: &::Scalar, + result: &E::GT, + ) -> W::GtExpWitness; + + /// Generate a G1 scalar multiplication witness with intermediate steps. + fn generate_g1_scalar_mul( + point: &E::G1, + scalar: &::Scalar, + result: &E::G1, + ) -> W::G1ScalarMulWitness; + + /// Generate a G2 scalar multiplication witness with intermediate steps. + fn generate_g2_scalar_mul( + point: &E::G2, + scalar: &::Scalar, + result: &E::G2, + ) -> W::G2ScalarMulWitness; + + /// Generate a GT multiplication witness with intermediate steps. + fn generate_gt_mul(lhs: &E::GT, rhs: &E::GT, result: &E::GT) -> W::GtMulWitness; + + /// Generate a single pairing witness with Miller loop steps. + fn generate_pairing(g1: &E::G1, g2: &E::G2, result: &E::GT) -> W::PairingWitness; + + /// Generate a multi-pairing witness with all Miller loop steps. + fn generate_multi_pairing( + g1s: &[E::G1], + g2s: &[E::G2], + result: &E::GT, + ) -> W::MultiPairingWitness; + + /// Generate a G1 MSM witness with bucket and accumulator states. + fn generate_msm_g1( + bases: &[E::G1], + scalars: &[::Scalar], + result: &E::G1, + ) -> W::MsmG1Witness; + + /// Generate a G2 MSM witness with bucket and accumulator states. + fn generate_msm_g2( + bases: &[E::G2], + scalars: &[::Scalar], + result: &E::G2, + ) -> W::MsmG2Witness; +} + +/// Witness collector that generates and stores witnesses during verification. +/// +/// # Type Parameters +/// +/// - `W`: The witness backend defining witness types +/// - `E`: The pairing curve providing group element types +/// - `Gen`: A witness generator that creates detailed traces +pub(crate) struct WitnessCollector +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + collection: WitnessCollection, + _phantom: PhantomData<(E, Gen)>, +} + +impl WitnessCollector +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + /// Create a new witness collector. + pub(crate) fn new() -> Self { + Self { + collection: WitnessCollection::new(), + _phantom: PhantomData, + } + } + + /// Set the number of rounds for the verification. + pub(crate) fn set_num_rounds(&mut self, num_rounds: usize) { + self.collection.num_rounds = num_rounds; + } + + /// Finalize collection and return all accumulated witnesses. + pub(crate) fn finalize(self) -> WitnessCollection { + self.collection + } + + /// Collect a GT exponentiation witness. + pub(crate) fn collect_gt_exp( + &mut self, + id: OpId, + base: &E::GT, + scalar: &::Scalar, + result: &E::GT, + ) -> W::GtExpWitness { + let witness = Gen::generate_gt_exp(base, scalar, result); + self.collection.gt_exp.insert(id, witness.clone()); + witness + } + + /// Collect a G1 scalar multiplication witness. + pub(crate) fn collect_g1_scalar_mul( + &mut self, + id: OpId, + point: &E::G1, + scalar: &::Scalar, + result: &E::G1, + ) -> W::G1ScalarMulWitness { + let witness = Gen::generate_g1_scalar_mul(point, scalar, result); + self.collection.g1_scalar_mul.insert(id, witness.clone()); + witness + } + + /// Collect a G2 scalar multiplication witness. + pub(crate) fn collect_g2_scalar_mul( + &mut self, + id: OpId, + point: &E::G2, + scalar: &::Scalar, + result: &E::G2, + ) -> W::G2ScalarMulWitness { + let witness = Gen::generate_g2_scalar_mul(point, scalar, result); + self.collection.g2_scalar_mul.insert(id, witness.clone()); + witness + } + + /// Collect a GT multiplication witness. + pub(crate) fn collect_gt_mul( + &mut self, + id: OpId, + lhs: &E::GT, + rhs: &E::GT, + result: &E::GT, + ) -> W::GtMulWitness { + let witness = Gen::generate_gt_mul(lhs, rhs, result); + self.collection.gt_mul.insert(id, witness.clone()); + witness + } + + /// Collect a single pairing witness. + pub(crate) fn collect_pairing( + &mut self, + id: OpId, + g1: &E::G1, + g2: &E::G2, + result: &E::GT, + ) -> W::PairingWitness { + let witness = Gen::generate_pairing(g1, g2, result); + self.collection.pairing.insert(id, witness.clone()); + witness + } + + /// Collect a multi-pairing witness. + pub(crate) fn collect_multi_pairing( + &mut self, + id: OpId, + g1s: &[E::G1], + g2s: &[E::G2], + result: &E::GT, + ) -> W::MultiPairingWitness { + let witness = Gen::generate_multi_pairing(g1s, g2s, result); + self.collection.multi_pairing.insert(id, witness.clone()); + witness + } + + /// Collect a G1 MSM witness. + pub(crate) fn collect_msm_g1( + &mut self, + id: OpId, + bases: &[E::G1], + scalars: &[::Scalar], + result: &E::G1, + ) -> W::MsmG1Witness { + let witness = Gen::generate_msm_g1(bases, scalars, result); + self.collection.msm_g1.insert(id, witness.clone()); + witness + } + + /// Collect a G2 MSM witness. + pub(crate) fn collect_msm_g2( + &mut self, + id: OpId, + bases: &[E::G2], + scalars: &[::Scalar], + result: &E::G2, + ) -> W::MsmG2Witness { + let witness = Gen::generate_msm_g2(bases, scalars, result); + self.collection.msm_g2.insert(id, witness.clone()); + witness + } +} + +impl Default for WitnessCollector +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + fn default() -> Self { + Self::new() + } +} diff --git a/src/recursion/context.rs b/src/recursion/context.rs new file mode 100644 index 0000000..19f8696 --- /dev/null +++ b/src/recursion/context.rs @@ -0,0 +1,254 @@ +//! Trace context for automatic operation tracing during verification. +//! +//! This module provides [`TraceContext`], a unified context that manages both +//! witness generation and hint-based verification modes. Operations executed +//! through trace types automatically record witnesses or use hints based on +//! the context's mode. + +use std::cell::RefCell; +use std::marker::PhantomData; +use std::rc::Rc; + +use super::witness::{OpId, OpType, WitnessBackend}; +use crate::primitives::arithmetic::{Group, PairingCurve}; + +use super::{HintMap, OpIdBuilder, WitnessCollection, WitnessCollector, WitnessGenerator}; + +/// Execution mode for traced verification operations. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ExecutionMode { + /// Always compute operations and record witnesses. + /// Used during initial witness generation phase. + #[default] + WitnessGeneration, + + /// Try hints first, fall back to compute with warning. + /// Used during recursive verification when hints should be available. + HintBased, +} + +/// Handle to a trace context +pub type CtxHandle = Rc>; + +/// Context for executing arithmetic operations with automatic tracing. +/// +/// In **witness generation** mode, all traced operations are computed and +/// their witnesses are recorded. +/// +/// In **hint-based** mode, traced operations first check for pre-computed hints. +/// If a hint is missing, the operation is computed with a warning logged via +/// `tracing::warn!`. +/// +/// # Interior Mutability +/// +/// This context uses [`RefCell`] for interior mutability because arithmetic +/// operators (`Add`, `Sub`, `Mul`) take `&self`, not `&mut self`. Since +/// verification is single-threaded, `RefCell` provides the necessary mutability. +pub struct TraceContext +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + mode: ExecutionMode, + id_builder: RefCell, + collector: RefCell>>, + hints: Option>, + missing_hints: RefCell>, + _phantom: PhantomData<(W, E, Gen)>, +} + +impl TraceContext +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + /// Create a context for witness generation mode. + /// + /// All traced operations will be computed and their witnesses recorded. + pub fn for_witness_gen() -> Self { + Self { + mode: ExecutionMode::WitnessGeneration, + id_builder: RefCell::new(OpIdBuilder::new()), + collector: RefCell::new(Some(WitnessCollector::new())), + hints: None, + missing_hints: RefCell::new(Vec::new()), + _phantom: PhantomData, + } + } + + /// Create a context for hint-based verification. + /// + /// Traced operations will use pre-computed hints when available, + /// falling back to computation with a warning when hints are missing. + pub fn for_hints(hints: HintMap) -> Self { + Self { + mode: ExecutionMode::HintBased, + id_builder: RefCell::new(OpIdBuilder::new()), + collector: RefCell::new(None), + hints: Some(hints), + missing_hints: RefCell::new(Vec::new()), + _phantom: PhantomData, + } + } + + /// Get the current execution mode. + #[inline] + pub fn mode(&self) -> ExecutionMode { + self.mode + } + + /// Advance to the next round. + pub fn advance_round(&self) { + self.id_builder.borrow_mut().advance_round(); + } + + /// Enter the final verification phase. + pub fn enter_final(&self) { + self.id_builder.borrow_mut().enter_final(); + } + + /// Get the current round number. + pub fn round(&self) -> u16 { + self.id_builder.borrow().round() + } + + /// Set the number of rounds for witness collection. + pub fn set_num_rounds(&self, num_rounds: usize) { + if let Some(ref mut collector) = *self.collector.borrow_mut() { + collector.set_num_rounds(num_rounds); + } + } + + /// Generate the next operation ID for the given type. + pub fn next_id(&self, op_type: OpType) -> OpId { + self.id_builder.borrow_mut().next(op_type) + } + + /// Get all missing hints encountered during hint-based verification. + pub fn missing_hints(&self) -> Vec { + self.missing_hints.borrow().clone() + } + + /// Check if any hints were missing during verification. + pub fn had_missing_hints(&self) -> bool { + !self.missing_hints.borrow().is_empty() + } + + /// Record that a hint was missing for the given operation. + pub fn record_missing_hint(&self, id: OpId) { + self.missing_hints.borrow_mut().push(id); + } + + /// Finalize and return the collected witnesses (if in witness generation mode). + /// + /// Returns `None` if no collector was active (pure hint mode without recording). + pub fn finalize(self) -> Option> { + self.collector.into_inner().map(|c| c.finalize()) + } + + /// Get a G1 hint for the given operation. + #[inline] + pub fn get_hint_g1(&self, id: OpId) -> Option { + self.hints.as_ref().and_then(|h| h.get_g1(id).copied()) + } + + /// Get a G2 hint for the given operation. + #[inline] + pub fn get_hint_g2(&self, id: OpId) -> Option { + self.hints.as_ref().and_then(|h| h.get_g2(id).copied()) + } + + /// Get a GT hint for the given operation. + #[inline] + pub fn get_hint_gt(&self, id: OpId) -> Option { + self.hints.as_ref().and_then(|h| h.get_gt(id).copied()) + } + + /// Record a GT exponentiation witness. + pub fn record_gt_exp( + &self, + id: OpId, + base: &E::GT, + scalar: &::Scalar, + result: &E::GT, + ) { + if let Some(ref mut collector) = *self.collector.borrow_mut() { + collector.collect_gt_exp(id, base, scalar, result); + } + } + + /// Record a G1 scalar multiplication witness. + pub fn record_g1_scalar_mul( + &self, + id: OpId, + point: &E::G1, + scalar: &::Scalar, + result: &E::G1, + ) { + if let Some(ref mut collector) = *self.collector.borrow_mut() { + collector.collect_g1_scalar_mul(id, point, scalar, result); + } + } + + /// Record a G2 scalar multiplication witness. + pub fn record_g2_scalar_mul( + &self, + id: OpId, + point: &E::G2, + scalar: &::Scalar, + result: &E::G2, + ) { + if let Some(ref mut collector) = *self.collector.borrow_mut() { + collector.collect_g2_scalar_mul(id, point, scalar, result); + } + } + + /// Record a GT multiplication witness. + pub fn record_gt_mul(&self, id: OpId, lhs: &E::GT, rhs: &E::GT, result: &E::GT) { + if let Some(ref mut collector) = *self.collector.borrow_mut() { + collector.collect_gt_mul(id, lhs, rhs, result); + } + } + + /// Record a pairing witness. + pub fn record_pairing(&self, id: OpId, g1: &E::G1, g2: &E::G2, result: &E::GT) { + if let Some(ref mut collector) = *self.collector.borrow_mut() { + collector.collect_pairing(id, g1, g2, result); + } + } + + /// Record a multi-pairing witness. + pub fn record_multi_pairing(&self, id: OpId, g1s: &[E::G1], g2s: &[E::G2], result: &E::GT) { + if let Some(ref mut collector) = *self.collector.borrow_mut() { + collector.collect_multi_pairing(id, g1s, g2s, result); + } + } + + /// Record a G1 MSM witness. + pub fn record_msm_g1( + &self, + id: OpId, + bases: &[E::G1], + scalars: &[::Scalar], + result: &E::G1, + ) { + if let Some(ref mut collector) = *self.collector.borrow_mut() { + collector.collect_msm_g1(id, bases, scalars, result); + } + } + + /// Record a G2 MSM witness. + pub fn record_msm_g2( + &self, + id: OpId, + bases: &[E::G2], + scalars: &[::Scalar], + result: &E::G2, + ) { + if let Some(ref mut collector) = *self.collector.borrow_mut() { + collector.collect_msm_g2(id, bases, scalars, result); + } + } +} diff --git a/src/recursion/hint_map.rs b/src/recursion/hint_map.rs new file mode 100644 index 0000000..ea7c183 --- /dev/null +++ b/src/recursion/hint_map.rs @@ -0,0 +1,324 @@ +//! Lightweight hint storage for recursive verification. +//! +//! This module provides [`HintMap`], a simplified storage structure that holds +//! only operation results (not full witnesses with intermediate computation steps). +//! This results in ~30-50x smaller storage compared to full witness collections. + +use std::collections::HashMap; +use std::io::{Read, Write}; + +use super::witness::{OpId, OpType}; +use crate::primitives::arithmetic::PairingCurve; +use crate::primitives::serialization::{ + Compress, DoryDeserialize, DorySerialize, SerializationError, Valid, Validate, +}; + +/// Tag bytes for HintResult discriminant during serialization. +const TAG_G1: u8 = 0; +const TAG_G2: u8 = 1; +const TAG_GT: u8 = 2; + +/// Result value storing only the computed output of an operation. +/// +/// Unlike full witness types which store intermediate computation steps, +/// this stores only the final result, suitable for hint-based verification. +#[derive(Clone)] +pub enum HintResult { + /// G1 point result (from G1ScalarMul, MsmG1) + G1(E::G1), + /// G2 point result (from G2ScalarMul, MsmG2) + G2(E::G2), + /// GT element result (from GtExp, GtMul, Pairing, MultiPairing) + GT(E::GT), +} + +impl HintResult { + /// Returns true if this is a G1 result. + #[inline] + pub fn is_g1(&self) -> bool { + matches!(self, HintResult::G1(_)) + } + + /// Returns true if this is a G2 result. + #[inline] + pub fn is_g2(&self) -> bool { + matches!(self, HintResult::G2(_)) + } + + /// Returns true if this is a GT result. + #[inline] + pub fn is_gt(&self) -> bool { + matches!(self, HintResult::GT(_)) + } + + /// Try to get as G1, returns None if wrong variant. + #[inline] + pub fn as_g1(&self) -> Option<&E::G1> { + match self { + HintResult::G1(g1) => Some(g1), + _ => None, + } + } + + /// Try to get as G2, returns None if wrong variant. + #[inline] + pub fn as_g2(&self) -> Option<&E::G2> { + match self { + HintResult::G2(g2) => Some(g2), + _ => None, + } + } + + /// Try to get as GT, returns None if wrong variant. + #[inline] + pub fn as_gt(&self) -> Option<&E::GT> { + match self { + HintResult::GT(gt) => Some(gt), + _ => None, + } + } +} + +impl Valid for HintResult { + fn check(&self) -> Result<(), SerializationError> { + // Curve points are validated during deserialization + Ok(()) + } +} + +impl DorySerialize for HintResult { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + match self { + HintResult::G1(g1) => { + TAG_G1.serialize_with_mode(&mut writer, compress)?; + g1.serialize_with_mode(writer, compress) + } + HintResult::G2(g2) => { + TAG_G2.serialize_with_mode(&mut writer, compress)?; + g2.serialize_with_mode(writer, compress) + } + HintResult::GT(gt) => { + TAG_GT.serialize_with_mode(&mut writer, compress)?; + gt.serialize_with_mode(writer, compress) + } + } + } + + fn serialized_size(&self, compress: Compress) -> usize { + 1 + match self { + HintResult::G1(g1) => g1.serialized_size(compress), + HintResult::G2(g2) => g2.serialized_size(compress), + HintResult::GT(gt) => gt.serialized_size(compress), + } + } +} + +impl DoryDeserialize for HintResult { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let tag = u8::deserialize_with_mode(&mut reader, compress, validate)?; + match tag { + TAG_G1 => Ok(HintResult::G1(E::G1::deserialize_with_mode( + reader, compress, validate, + )?)), + TAG_G2 => Ok(HintResult::G2(E::G2::deserialize_with_mode( + reader, compress, validate, + )?)), + TAG_GT => Ok(HintResult::GT(E::GT::deserialize_with_mode( + reader, compress, validate, + )?)), + _ => Err(SerializationError::InvalidData(format!( + "Invalid HintResult tag: {tag}" + ))), + } + } +} + +/// Hint storage +/// +/// Unlike [`WitnessCollection`](crate::recursion::WitnessCollection) which stores +/// full computation traces, this stores only the final results for each operation, +/// indexed by [`OpId`]. +#[derive(Clone)] +pub struct HintMap { + /// Number of reduce-and-fold rounds in the verification + pub num_rounds: usize, + /// All operation results indexed by OpId + results: HashMap>, +} + +impl HintMap { + /// Create a new empty hint map. + pub fn new(num_rounds: usize) -> Self { + Self { + num_rounds, + results: HashMap::new(), + } + } + + /// Get G1 result for an operation. + /// + /// Returns None if the operation is not found or is not a G1 result. + #[inline] + pub fn get_g1(&self, id: OpId) -> Option<&E::G1> { + self.results.get(&id).and_then(|r| r.as_g1()) + } + + /// Get G2 result for an operation. + /// + /// Returns None if the operation is not found or is not a G2 result. + #[inline] + pub fn get_g2(&self, id: OpId) -> Option<&E::G2> { + self.results.get(&id).and_then(|r| r.as_g2()) + } + + /// Get GT result for an operation. + /// + /// Returns None if the operation is not found or is not a GT result. + #[inline] + pub fn get_gt(&self, id: OpId) -> Option<&E::GT> { + self.results.get(&id).and_then(|r| r.as_gt()) + } + + /// Get raw result enum for an operation. + #[inline] + pub fn get(&self, id: OpId) -> Option<&HintResult> { + self.results.get(&id) + } + + /// Insert a G1 result. + #[inline] + pub fn insert_g1(&mut self, id: OpId, value: E::G1) { + self.results.insert(id, HintResult::G1(value)); + } + + /// Insert a G2 result. + #[inline] + pub fn insert_g2(&mut self, id: OpId, value: E::G2) { + self.results.insert(id, HintResult::G2(value)); + } + + /// Insert a GT result. + #[inline] + pub fn insert_gt(&mut self, id: OpId, value: E::GT) { + self.results.insert(id, HintResult::GT(value)); + } + + /// Total number of hints stored. + #[inline] + pub fn len(&self) -> usize { + self.results.len() + } + + /// Check if the hint map is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.results.is_empty() + } + + /// Iterate over all (OpId, HintResult) pairs. + pub fn iter(&self) -> impl Iterator)> { + self.results.iter() + } + + /// Check if a hint exists for the given operation. + #[inline] + pub fn contains(&self, id: OpId) -> bool { + self.results.contains_key(&id) + } +} + +impl Default for HintMap { + fn default() -> Self { + Self::new(0) + } +} + +impl Valid for HintMap { + fn check(&self) -> Result<(), SerializationError> { + for result in self.results.values() { + result.check()?; + } + Ok(()) + } +} + +impl DorySerialize for HintMap { + fn serialize_with_mode( + &self, + mut writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + (self.num_rounds as u64).serialize_with_mode(&mut writer, compress)?; + (self.results.len() as u64).serialize_with_mode(&mut writer, compress)?; + + for (id, result) in &self.results { + // Serialize OpId as (round: u16, op_type: u8, index: u16) + id.round.serialize_with_mode(&mut writer, compress)?; + (id.op_type as u8).serialize_with_mode(&mut writer, compress)?; + id.index.serialize_with_mode(&mut writer, compress)?; + result.serialize_with_mode(&mut writer, compress)?; + } + Ok(()) + } + + fn serialized_size(&self, compress: Compress) -> usize { + let header = 8 + 8; // num_rounds + len + let entries: usize = self + .results + .values() + .map(|r| 2 + 1 + 2 + r.serialized_size(compress)) + .sum(); + header + entries + } +} + +impl DoryDeserialize for HintMap { + fn deserialize_with_mode( + mut reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + let num_rounds = u64::deserialize_with_mode(&mut reader, compress, validate)? as usize; + let len = u64::deserialize_with_mode(&mut reader, compress, validate)? as usize; + + let mut results = HashMap::with_capacity(len); + for _ in 0..len { + let round = u16::deserialize_with_mode(&mut reader, compress, validate)?; + let op_type_byte = u8::deserialize_with_mode(&mut reader, compress, validate)?; + let index = u16::deserialize_with_mode(&mut reader, compress, validate)?; + + let op_type = match op_type_byte { + 0 => OpType::GtExp, + 1 => OpType::G1ScalarMul, + 2 => OpType::G2ScalarMul, + 3 => OpType::GtMul, + 4 => OpType::Pairing, + 5 => OpType::MultiPairing, + 6 => OpType::MsmG1, + 7 => OpType::MsmG2, + _ => { + return Err(SerializationError::InvalidData(format!( + "Invalid OpType: {op_type_byte}" + ))) + } + }; + + let id = OpId::new(round, op_type, index); + let result = HintResult::deserialize_with_mode(&mut reader, compress, validate)?; + results.insert(id, result); + } + + Ok(Self { + num_rounds, + results, + }) + } +} diff --git a/src/recursion/mod.rs b/src/recursion/mod.rs new file mode 100644 index 0000000..f95fec7 --- /dev/null +++ b/src/recursion/mod.rs @@ -0,0 +1,61 @@ +//! Recursion support for Dory polynomial commitment verification. +//! +//! This module provides infrastructure for recursive proof composition by enabling: +//! +//! 1. **Witness Generation**: Capture detailed traces of all arithmetic operations +//! during verification, suitable for proving in a bespoke SNARK. +//! +//! 2. **Hint-Based Verification**: Run verification using pre-computed hints instead +//! of performing expensive operations, enabling faster verification. +//! +//! # Architecture +//! +//! The recursion system is built around these core abstractions: +//! +//! - [`TraceContext`]: Unified context managing witness generation or hint-based modes +//! - Internal trace wrappers (`TraceG1`, `TraceG2`, `TraceGT`): Auto-trace operations +//! - Internal operators (`TracePairing`): Traced pairing operations +//! - [`HintMap`]: Hint storage for operation results +//! - [`WitnessBackend`]: Backend-defined witness types +//! +//! # Usage +//! +//! ```ignore +//! use std::rc::Rc; +//! use dory_pcs::recursion::TraceContext; +//! use dory_pcs::verify_recursive; +//! +//! // Witness generation mode +//! let ctx = Rc::new(TraceContext::for_witness_gen()); +//! verify_recursive::<_, E, M1, M2, _, W, Gen>( +//! commitment, evaluation, &point, &proof, setup.clone(), &mut transcript, ctx.clone() +//! )?; +//! let witnesses = Rc::try_unwrap(ctx).ok().unwrap().finalize(); +//! +//! // Convert to lightweight hints +//! let hints = witnesses.unwrap().to_hints::(); +//! +//! // Hint-based verification (with fallback on missing hints) +//! let ctx = Rc::new(TraceContext::for_hints(hints)); +//! verify_recursive::<_, E, M1, M2, _, W, Gen>( +//! commitment, evaluation, &point, &proof, setup, &mut transcript, ctx +//! )?; +//! ``` + +mod collection; +mod collector; +mod context; +mod hint_map; +mod trace; +mod witness; + +pub use collection::WitnessCollection; +pub use collector::WitnessGenerator; +pub use context::{CtxHandle, TraceContext}; +pub use hint_map::HintMap; +pub use witness::{OpId, OpType, WitnessBackend}; + +pub(crate) use collector::{OpIdBuilder, WitnessCollector}; +pub(crate) use context::ExecutionMode; +pub(crate) use trace::{TraceG1, TraceG2, TraceGT, TracePairing}; +pub(crate) use witness::WitnessResult; diff --git a/src/recursion/trace.rs b/src/recursion/trace.rs new file mode 100644 index 0000000..5f16d65 --- /dev/null +++ b/src/recursion/trace.rs @@ -0,0 +1,797 @@ +//! Trace wrapper types for automatic operation tracing. +//! +//! This module provides wrapper types (`TraceG1`, `TraceG2`, `TraceGT`) that +//! automatically trace arithmetic operations during verification. Operations +//! are recorded (in witness generation mode) or use hints (in hint-based mode) + +// Some methods/types are kept for API completeness but not currently used +#![allow(dead_code)] + +use std::ops::{Add, Neg, Sub}; +use std::rc::Rc; + +use super::witness::{OpType, WitnessBackend}; +use crate::primitives::arithmetic::{Group, PairingCurve}; + +use super::{CtxHandle, ExecutionMode, WitnessGenerator}; + +/// G1 element with automatic operation tracing. +#[derive(Clone)] +pub(crate) struct TraceG1 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + inner: E::G1, + ctx: CtxHandle, +} + +impl TraceG1 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + /// Wrap a G1 element with a trace context. + #[inline] + pub(crate) fn new(inner: E::G1, ctx: CtxHandle) -> Self { + Self { inner, ctx } + } + + /// Get a reference to the underlying G1 element. + #[inline] + pub(crate) fn inner(&self) -> &E::G1 { + &self.inner + } + + /// Unwrap to get the raw G1 value. + #[inline] + pub(crate) fn into_inner(self) -> E::G1 { + self.inner + } + + /// Get a clone of the context handle. + #[inline] + pub(crate) fn ctx(&self) -> CtxHandle { + Rc::clone(&self.ctx) + } + + /// Traced scalar multiplication. + pub(crate) fn scale(&self, scalar: &::Scalar) -> Self { + let id = self.ctx.next_id(OpType::G1ScalarMul); + + match self.ctx.mode() { + ExecutionMode::WitnessGeneration => { + let result = self.inner.scale(scalar); + self.ctx + .record_g1_scalar_mul(id, &self.inner, scalar, &result); + Self::new(result, Rc::clone(&self.ctx)) + } + ExecutionMode::HintBased => { + if let Some(result) = self.ctx.get_hint_g1(id) { + Self::new(result, Rc::clone(&self.ctx)) + } else { + tracing::warn!( + op_id = ?id, + op_type = "G1ScalarMul", + round = id.round, + index = id.index, + "Missing hint, computing fallback" + ); + self.ctx.record_missing_hint(id); + let result = self.inner.scale(scalar); + Self::new(result, Rc::clone(&self.ctx)) + } + } + } + } + + /// Get the identity element for G1. + pub(crate) fn identity(ctx: CtxHandle) -> Self { + Self::new(E::G1::identity(), ctx) + } +} + +// G1 + G1 +impl Add for TraceG1 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Self::new(self.inner + rhs.inner, self.ctx) + } +} + +impl Add<&Self> for TraceG1 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn add(self, rhs: &Self) -> Self { + Self::new(self.inner + rhs.inner, self.ctx) + } +} + +// G1 - G1 +impl Sub for TraceG1 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + Self::new(self.inner - rhs.inner, self.ctx) + } +} + +impl Sub<&Self> for TraceG1 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn sub(self, rhs: &Self) -> Self { + Self::new(self.inner - rhs.inner, self.ctx) + } +} + +// -G1 +impl Neg for TraceG1 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn neg(self) -> Self { + Self::new(-self.inner, self.ctx) + } +} + +/// G2 element with automatic operation tracing. +#[derive(Clone)] +pub(crate) struct TraceG2 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + inner: E::G2, + ctx: CtxHandle, +} + +impl TraceG2 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + /// Wrap a G2 element with a trace context. + #[inline] + pub(crate) fn new(inner: E::G2, ctx: CtxHandle) -> Self { + Self { inner, ctx } + } + + /// Get a reference to the underlying G2 element. + #[inline] + pub(crate) fn inner(&self) -> &E::G2 { + &self.inner + } + + /// Unwrap to get the raw G2 value. + #[inline] + pub(crate) fn into_inner(self) -> E::G2 { + self.inner + } + + /// Get a clone of the context handle. + #[inline] + pub(crate) fn ctx(&self) -> CtxHandle { + Rc::clone(&self.ctx) + } + + /// Traced scalar multiplication. + pub(crate) fn scale(&self, scalar: &::Scalar) -> Self + where + E::G2: Group::Scalar>, + { + let id = self.ctx.next_id(OpType::G2ScalarMul); + + match self.ctx.mode() { + ExecutionMode::WitnessGeneration => { + let result = self.inner.scale(scalar); + self.ctx + .record_g2_scalar_mul(id, &self.inner, scalar, &result); + Self::new(result, Rc::clone(&self.ctx)) + } + ExecutionMode::HintBased => { + if let Some(result) = self.ctx.get_hint_g2(id) { + Self::new(result, Rc::clone(&self.ctx)) + } else { + tracing::warn!( + op_id = ?id, + op_type = "G2ScalarMul", + round = id.round, + index = id.index, + "Missing hint, computing fallback" + ); + self.ctx.record_missing_hint(id); + let result = self.inner.scale(scalar); + Self::new(result, Rc::clone(&self.ctx)) + } + } + } + } + + /// Get the identity element for G2. + pub(crate) fn identity(ctx: CtxHandle) -> Self { + Self::new(E::G2::identity(), ctx) + } +} + +// G2 + G2 +impl Add for TraceG2 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Self::new(self.inner + rhs.inner, self.ctx) + } +} + +impl Add<&Self> for TraceG2 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn add(self, rhs: &Self) -> Self { + Self::new(self.inner + rhs.inner, self.ctx) + } +} + +// G2 - G2 +impl Sub for TraceG2 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + Self::new(self.inner - rhs.inner, self.ctx) + } +} + +impl Sub<&Self> for TraceG2 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn sub(self, rhs: &Self) -> Self { + Self::new(self.inner - rhs.inner, self.ctx) + } +} + +// -G2 +impl Neg for TraceG2 +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn neg(self) -> Self { + Self::new(-self.inner, self.ctx) + } +} + +/// GT element with automatic operation tracing. +/// +/// Note: GT is a multiplicative group, so "addition" in the Group trait +/// corresponds to field multiplication in Fq12 +#[derive(Clone)] +pub(crate) struct TraceGT +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + inner: E::GT, + ctx: CtxHandle, +} + +impl TraceGT +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + /// Wrap a GT element with a trace context. + #[inline] + pub(crate) fn new(inner: E::GT, ctx: CtxHandle) -> Self { + Self { inner, ctx } + } + + /// Get a reference to the underlying GT element. + #[inline] + pub(crate) fn inner(&self) -> &E::GT { + &self.inner + } + + /// Unwrap to get the raw GT value. + #[inline] + pub(crate) fn into_inner(self) -> E::GT { + self.inner + } + + /// Get a clone of the context handle. + #[inline] + pub(crate) fn ctx(&self) -> CtxHandle { + Rc::clone(&self.ctx) + } + + /// Traced GT exponentiation (scalar multiplication in multiplicative group). + pub(crate) fn scale(&self, scalar: &::Scalar) -> Self + where + E::GT: Group::Scalar>, + { + let id = self.ctx.next_id(OpType::GtExp); + + match self.ctx.mode() { + ExecutionMode::WitnessGeneration => { + let result = self.inner.scale(scalar); + self.ctx.record_gt_exp(id, &self.inner, scalar, &result); + Self::new(result, Rc::clone(&self.ctx)) + } + ExecutionMode::HintBased => { + if let Some(result) = self.ctx.get_hint_gt(id) { + Self::new(result, Rc::clone(&self.ctx)) + } else { + tracing::warn!( + op_id = ?id, + op_type = "GtExp", + round = id.round, + index = id.index, + "Missing hint, computing fallback" + ); + self.ctx.record_missing_hint(id); + let result = self.inner.scale(scalar); + Self::new(result, Rc::clone(&self.ctx)) + } + } + } + } + + /// Traced GT multiplication. + pub(crate) fn mul_traced(&self, rhs: &Self) -> Self { + let id = self.ctx.next_id(OpType::GtMul); + + match self.ctx.mode() { + ExecutionMode::WitnessGeneration => { + let result = self.inner + rhs.inner; + self.ctx.record_gt_mul(id, &self.inner, &rhs.inner, &result); + Self::new(result, Rc::clone(&self.ctx)) + } + ExecutionMode::HintBased => { + if let Some(result) = self.ctx.get_hint_gt(id) { + Self::new(result, Rc::clone(&self.ctx)) + } else { + tracing::warn!( + op_id = ?id, + op_type = "GtMul", + round = id.round, + index = id.index, + "Missing hint, computing fallback" + ); + self.ctx.record_missing_hint(id); + let result = self.inner + rhs.inner; + Self::new(result, Rc::clone(&self.ctx)) + } + } + } + } + + /// Get the identity element for GT (the multiplicative identity). + pub(crate) fn identity(ctx: CtxHandle) -> Self { + Self::new(E::GT::identity(), ctx) + } +} + +// GT * GT +impl Add for TraceGT +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn add(self, rhs: Self) -> Self { + self.mul_traced(&rhs) + } +} + +impl Add<&Self> for TraceGT +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn add(self, rhs: &Self) -> Self { + self.mul_traced(rhs) + } +} + +// GT^(-1) (NOT traced) +impl Neg for TraceGT +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + type Output = Self; + + fn neg(self) -> Self { + Self::new(-self.inner, self.ctx) + } +} + +/// Traced pairing operations. +/// +/// Provides `pair` and `multi_pair` methods that automatically trace +/// the pairing computation. +pub(crate) struct TracePairing +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + ctx: CtxHandle, +} + +impl TracePairing +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + /// Create a new traced pairing operator with the given context. + pub(crate) fn new(ctx: CtxHandle) -> Self { + Self { ctx } + } + + /// Traced single pairing e(G1, G2) -> GT. + pub(crate) fn pair( + &self, + g1: &TraceG1, + g2: &TraceG2, + ) -> TraceGT { + let id = self.ctx.next_id(OpType::Pairing); + + match self.ctx.mode() { + ExecutionMode::WitnessGeneration => { + let result = E::pair(&g1.inner, &g2.inner); + self.ctx.record_pairing(id, &g1.inner, &g2.inner, &result); + TraceGT::new(result, Rc::clone(&self.ctx)) + } + ExecutionMode::HintBased => { + if let Some(result) = self.ctx.get_hint_gt(id) { + TraceGT::new(result, Rc::clone(&self.ctx)) + } else { + tracing::warn!( + op_id = ?id, + op_type = "Pairing", + round = id.round, + index = id.index, + "Missing hint, computing fallback" + ); + self.ctx.record_missing_hint(id); + let result = E::pair(&g1.inner, &g2.inner); + TraceGT::new(result, Rc::clone(&self.ctx)) + } + } + } + } + + /// Traced single pairing from raw G1/G2 elements. + pub(crate) fn pair_raw(&self, g1: &E::G1, g2: &E::G2) -> TraceGT { + let id = self.ctx.next_id(OpType::Pairing); + + match self.ctx.mode() { + ExecutionMode::WitnessGeneration => { + let result = E::pair(g1, g2); + self.ctx.record_pairing(id, g1, g2, &result); + TraceGT::new(result, Rc::clone(&self.ctx)) + } + ExecutionMode::HintBased => { + if let Some(result) = self.ctx.get_hint_gt(id) { + TraceGT::new(result, Rc::clone(&self.ctx)) + } else { + tracing::warn!( + op_id = ?id, + op_type = "Pairing", + round = id.round, + index = id.index, + "Missing hint, computing fallback" + ); + self.ctx.record_missing_hint(id); + let result = E::pair(g1, g2); + TraceGT::new(result, Rc::clone(&self.ctx)) + } + } + } + } + + /// Traced multi-pairing: product of e(g1s[i], g2s[i]). + pub(crate) fn multi_pair( + &self, + g1s: &[TraceG1], + g2s: &[TraceG2], + ) -> TraceGT { + let id = self.ctx.next_id(OpType::MultiPairing); + + let g1_inners: Vec = g1s.iter().map(|g| g.inner).collect(); + let g2_inners: Vec = g2s.iter().map(|g| g.inner).collect(); + + match self.ctx.mode() { + ExecutionMode::WitnessGeneration => { + let result = E::multi_pair(&g1_inners, &g2_inners); + self.ctx + .record_multi_pairing(id, &g1_inners, &g2_inners, &result); + TraceGT::new(result, Rc::clone(&self.ctx)) + } + ExecutionMode::HintBased => { + if let Some(result) = self.ctx.get_hint_gt(id) { + TraceGT::new(result, Rc::clone(&self.ctx)) + } else { + tracing::warn!( + op_id = ?id, + op_type = "MultiPairing", + round = id.round, + index = id.index, + num_pairs = g1s.len(), + "Missing hint, computing fallback" + ); + self.ctx.record_missing_hint(id); + let result = E::multi_pair(&g1_inners, &g2_inners); + TraceGT::new(result, Rc::clone(&self.ctx)) + } + } + } + } + + /// Traced multi-pairing from raw slices. + pub(crate) fn multi_pair_raw(&self, g1s: &[E::G1], g2s: &[E::G2]) -> TraceGT { + let id = self.ctx.next_id(OpType::MultiPairing); + + match self.ctx.mode() { + ExecutionMode::WitnessGeneration => { + let result = E::multi_pair(g1s, g2s); + self.ctx.record_multi_pairing(id, g1s, g2s, &result); + TraceGT::new(result, Rc::clone(&self.ctx)) + } + ExecutionMode::HintBased => { + if let Some(result) = self.ctx.get_hint_gt(id) { + TraceGT::new(result, Rc::clone(&self.ctx)) + } else { + tracing::warn!( + op_id = ?id, + op_type = "MultiPairing", + round = id.round, + index = id.index, + num_pairs = g1s.len(), + "Missing hint, computing fallback" + ); + self.ctx.record_missing_hint(id); + let result = E::multi_pair(g1s, g2s); + TraceGT::new(result, Rc::clone(&self.ctx)) + } + } + } + } +} + +/// Traced MSM operations. +pub(crate) struct TraceMsm +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + ctx: CtxHandle, +} + +impl TraceMsm +where + W: WitnessBackend, + E: PairingCurve, + Gen: WitnessGenerator, +{ + /// Create a new traced MSM operator with the given context. + pub(crate) fn new(ctx: CtxHandle) -> Self { + Self { ctx } + } + + /// Traced G1 MSM using the provided MSM implementation. + pub(crate) fn msm_g1( + &self, + bases: &[TraceG1], + scalars: &[::Scalar], + msm_fn: F, + ) -> TraceG1 + where + F: FnOnce(&[E::G1], &[::Scalar]) -> E::G1, + { + let id = self.ctx.next_id(OpType::MsmG1); + let base_inners: Vec = bases.iter().map(|b| b.inner).collect(); + + match self.ctx.mode() { + ExecutionMode::WitnessGeneration => { + let result = msm_fn(&base_inners, scalars); + self.ctx.record_msm_g1(id, &base_inners, scalars, &result); + TraceG1::new(result, Rc::clone(&self.ctx)) + } + ExecutionMode::HintBased => { + if let Some(result) = self.ctx.get_hint_g1(id) { + TraceG1::new(result, Rc::clone(&self.ctx)) + } else { + tracing::warn!( + op_id = ?id, + op_type = "MsmG1", + round = id.round, + index = id.index, + size = bases.len(), + "Missing hint, computing fallback" + ); + self.ctx.record_missing_hint(id); + let result = msm_fn(&base_inners, scalars); + TraceG1::new(result, Rc::clone(&self.ctx)) + } + } + } + } + + /// Traced G1 MSM from raw bases. + pub(crate) fn msm_g1_raw( + &self, + bases: &[E::G1], + scalars: &[::Scalar], + msm_fn: F, + ) -> TraceG1 + where + F: FnOnce(&[E::G1], &[::Scalar]) -> E::G1, + { + let id = self.ctx.next_id(OpType::MsmG1); + + match self.ctx.mode() { + ExecutionMode::WitnessGeneration => { + let result = msm_fn(bases, scalars); + self.ctx.record_msm_g1(id, bases, scalars, &result); + TraceG1::new(result, Rc::clone(&self.ctx)) + } + ExecutionMode::HintBased => { + if let Some(result) = self.ctx.get_hint_g1(id) { + TraceG1::new(result, Rc::clone(&self.ctx)) + } else { + tracing::warn!( + op_id = ?id, + op_type = "MsmG1", + round = id.round, + index = id.index, + size = bases.len(), + "Missing hint, computing fallback" + ); + self.ctx.record_missing_hint(id); + let result = msm_fn(bases, scalars); + TraceG1::new(result, Rc::clone(&self.ctx)) + } + } + } + } + + /// Traced G2 MSM using the provided MSM implementation. + pub(crate) fn msm_g2( + &self, + bases: &[TraceG2], + scalars: &[::Scalar], + msm_fn: F, + ) -> TraceG2 + where + F: FnOnce(&[E::G2], &[::Scalar]) -> E::G2, + E::G2: Group::Scalar>, + { + let id = self.ctx.next_id(OpType::MsmG2); + let base_inners: Vec = bases.iter().map(|b| b.inner).collect(); + + match self.ctx.mode() { + ExecutionMode::WitnessGeneration => { + let result = msm_fn(&base_inners, scalars); + self.ctx.record_msm_g2(id, &base_inners, scalars, &result); + TraceG2::new(result, Rc::clone(&self.ctx)) + } + ExecutionMode::HintBased => { + if let Some(result) = self.ctx.get_hint_g2(id) { + TraceG2::new(result, Rc::clone(&self.ctx)) + } else { + tracing::warn!( + op_id = ?id, + op_type = "MsmG2", + round = id.round, + index = id.index, + size = bases.len(), + "Missing hint, computing fallback" + ); + self.ctx.record_missing_hint(id); + let result = msm_fn(&base_inners, scalars); + TraceG2::new(result, Rc::clone(&self.ctx)) + } + } + } + } + + /// Traced G2 MSM from raw bases. + pub(crate) fn msm_g2_raw( + &self, + bases: &[E::G2], + scalars: &[::Scalar], + msm_fn: F, + ) -> TraceG2 + where + F: FnOnce(&[E::G2], &[::Scalar]) -> E::G2, + E::G2: Group::Scalar>, + { + let id = self.ctx.next_id(OpType::MsmG2); + + match self.ctx.mode() { + ExecutionMode::WitnessGeneration => { + let result = msm_fn(bases, scalars); + self.ctx.record_msm_g2(id, bases, scalars, &result); + TraceG2::new(result, Rc::clone(&self.ctx)) + } + ExecutionMode::HintBased => { + if let Some(result) = self.ctx.get_hint_g2(id) { + TraceG2::new(result, Rc::clone(&self.ctx)) + } else { + tracing::warn!( + op_id = ?id, + op_type = "MsmG2", + round = id.round, + index = id.index, + size = bases.len(), + "Missing hint, computing fallback" + ); + self.ctx.record_missing_hint(id); + let result = msm_fn(bases, scalars); + TraceG2::new(result, Rc::clone(&self.ctx)) + } + } + } + } +} diff --git a/src/recursion/witness.rs b/src/recursion/witness.rs new file mode 100644 index 0000000..02e66a3 --- /dev/null +++ b/src/recursion/witness.rs @@ -0,0 +1,105 @@ +//! Witness generation types and traits for recursive proof composition. + +/// Operation type identifier for witness indexing. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u8)] +pub enum OpType { + /// GT exponentiation: base^scalar in the target group + GtExp = 0, + /// G1 scalar multiplication: scalar * point + G1ScalarMul = 1, + /// G2 scalar multiplication: scalar * point + G2ScalarMul = 2, + /// GT multiplication: lhs * rhs in the target group + GtMul = 3, + /// Single pairing: e(G1, G2) -> GT + Pairing = 4, + /// Multi-pairing: product of pairings + MultiPairing = 5, + /// Multi-scalar multiplication in G1 + MsmG1 = 6, + /// Multi-scalar multiplication in G2 + MsmG2 = 7, +} + +/// Unique identifier for an arithmetic operation in the verification protocol. +/// +/// Operations are indexed by (round, op_type, index) to enable deterministic +/// mapping between witness generation and hint consumption. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct OpId { + /// Protocol round number (0 for initial checks, 1..=num_rounds for reduce rounds) + pub round: u16, + /// Type of arithmetic operation + pub op_type: OpType, + /// Index within the round for operations of the same type + pub index: u16, +} + +impl OpId { + /// Create a new operation identifier. + #[inline] + pub const fn new(round: u16, op_type: OpType, index: u16) -> Self { + Self { + round, + op_type, + index, + } + } + + /// Create an operation ID for the initial VMV check phase (round 0). + #[inline] + pub const fn vmv(op_type: OpType, index: u16) -> Self { + Self::new(0, op_type, index) + } + + /// Create an operation ID for a reduce-and-fold round. + #[inline] + pub const fn reduce(round: u16, op_type: OpType, index: u16) -> Self { + Self::new(round, op_type, index) + } + + /// Create an operation ID for the final verification phase. + /// Uses round = u16::MAX to distinguish from reduce rounds. + #[inline] + pub const fn final_verify(op_type: OpType, index: u16) -> Self { + Self::new(u16::MAX, op_type, index) + } +} + +/// Backend-defined witness types for arithmetic operations. +/// +/// Each proof system backend implements this trait to define +/// the structure of witness data for each operation type. This allows different +/// proof systems to capture the level of detail they need. +pub trait WitnessBackend: Sized + Send + Sync + 'static { + /// Witness type for GT exponentiation (base^scalar). + type GtExpWitness: Clone + Send + Sync; + + /// Witness type for G1 scalar multiplication. + type G1ScalarMulWitness: Clone + Send + Sync; + + /// Witness type for G2 scalar multiplication. + type G2ScalarMulWitness: Clone + Send + Sync; + + /// Witness type for GT multiplication (Fq12 multiplication). + type GtMulWitness: Clone + Send + Sync; + + /// Witness type for single pairing e(G1, G2) -> GT. + type PairingWitness: Clone + Send + Sync; + + /// Witness type for multi-pairing (product of pairings). + type MultiPairingWitness: Clone + Send + Sync; + + /// Witness type for G1 multi-scalar multiplication. + type MsmG1Witness: Clone + Send + Sync; + + /// Witness type for G2 multi-scalar multiplication. + type MsmG2Witness: Clone + Send + Sync; +} + +/// Trait for extracting the result from a witness. +pub trait WitnessResult { + /// Get the result of the operation. + fn result(&self) -> &T; +} diff --git a/tests/arkworks/mod.rs b/tests/arkworks/mod.rs index e235c47..2e27416 100644 --- a/tests/arkworks/mod.rs +++ b/tests/arkworks/mod.rs @@ -16,8 +16,12 @@ pub mod evaluation; pub mod homomorphic; pub mod integration; pub mod non_square; +#[cfg(feature = "recursion")] +pub mod recursion; pub mod setup; pub mod soundness; +#[cfg(feature = "recursion")] +pub mod witness; pub fn random_polynomial(size: usize) -> ArkworksPolynomial { let mut rng = thread_rng(); diff --git a/tests/arkworks/recursion.rs b/tests/arkworks/recursion.rs new file mode 100644 index 0000000..31fcef0 --- /dev/null +++ b/tests/arkworks/recursion.rs @@ -0,0 +1,315 @@ +//! Integration tests for recursion feature (witness generation and hint-based verification) + +use std::rc::Rc; + +use super::*; +use dory_pcs::backends::arkworks::{SimpleWitnessBackend, SimpleWitnessGenerator}; +use dory_pcs::primitives::poly::Polynomial; +use dory_pcs::recursion::TraceContext; +use dory_pcs::{prove, setup, verify_recursive}; + +type TestCtx = TraceContext; + +#[test] +fn test_witness_gen_roundtrip() { + let mut rng = rand::thread_rng(); + let max_log_n = 10; + + let (prover_setup, verifier_setup) = setup::(&mut rng, max_log_n); + + let poly = random_polynomial(256); + let nu = 4; + let sigma = 4; + + let (tier_2, tier_1) = poly + .commit::(nu, sigma, &prover_setup) + .unwrap(); + + let point = random_point(8); + + let mut prover_transcript = fresh_transcript(); + let proof = prove::<_, BN254, TestG1Routines, TestG2Routines, _, _>( + &poly, + &point, + tier_1, + nu, + sigma, + &prover_setup, + &mut prover_transcript, + ) + .unwrap(); + let evaluation = poly.evaluate(&point); + + // Phase 1: Witness generation + let ctx = Rc::new(TestCtx::for_witness_gen()); + let mut witness_transcript = fresh_transcript(); + + verify_recursive::<_, BN254, TestG1Routines, TestG2Routines, _, _, _>( + tier_2, + evaluation, + &point, + &proof, + verifier_setup.clone(), + &mut witness_transcript, + ctx.clone(), + ) + .expect("Witness-generating verification should succeed"); + + let collection = Rc::try_unwrap(ctx) + .ok() + .expect("Should have sole ownership") + .finalize() + .expect("Should have witnesses"); + + // Phase 2: Hint-based verification + let hints = collection.to_hints::(); + let ctx = Rc::new(TestCtx::for_hints(hints)); + let mut hint_transcript = fresh_transcript(); + + let result = verify_recursive::<_, BN254, TestG1Routines, TestG2Routines, _, _, _>( + tier_2, + evaluation, + &point, + &proof, + verifier_setup, + &mut hint_transcript, + ctx, + ); + + assert!(result.is_ok(), "Hint-based verification should succeed"); +} + +#[test] +fn test_witness_collection_contents() { + let mut rng = rand::thread_rng(); + let max_log_n = 6; + + let (prover_setup, verifier_setup) = setup::(&mut rng, max_log_n); + + let poly = random_polynomial(16); + let nu = 2; + let sigma = 2; + + let (tier_2, tier_1) = poly + .commit::(nu, sigma, &prover_setup) + .unwrap(); + + let point = random_point(4); + + let mut prover_transcript = fresh_transcript(); + let proof = prove::<_, BN254, TestG1Routines, TestG2Routines, _, _>( + &poly, + &point, + tier_1, + nu, + sigma, + &prover_setup, + &mut prover_transcript, + ) + .unwrap(); + let evaluation = poly.evaluate(&point); + + let ctx = Rc::new(TestCtx::for_witness_gen()); + let mut witness_transcript = fresh_transcript(); + + verify_recursive::<_, BN254, TestG1Routines, TestG2Routines, _, _, _>( + tier_2, + evaluation, + &point, + &proof, + verifier_setup, + &mut witness_transcript, + ctx.clone(), + ) + .expect("Witness-generating verification should succeed"); + + let collection = Rc::try_unwrap(ctx) + .ok() + .expect("Should have sole ownership") + .finalize() + .expect("Should have witnesses"); + + // Verify the collection contains expected operation types + assert!( + !collection.gt_exp.is_empty(), + "Should have GT exponentiation witnesses" + ); + assert!( + !collection.pairing.is_empty() || !collection.multi_pairing.is_empty(), + "Should have pairing witnesses" + ); + + tracing::info!( + gt_exp = collection.gt_exp.len(), + g1_scalar_mul = collection.g1_scalar_mul.len(), + g2_scalar_mul = collection.g2_scalar_mul.len(), + gt_mul = collection.gt_mul.len(), + pairing = collection.pairing.len(), + multi_pairing = collection.multi_pairing.len(), + msm_g1 = collection.msm_g1.len(), + msm_g2 = collection.msm_g2.len(), + total = collection.total_witnesses(), + rounds = collection.num_rounds, + "Witness collection stats" + ); +} + +#[test] +fn test_hint_verification_with_missing_hints() { + let mut rng = rand::thread_rng(); + let max_log_n = 6; + + let (prover_setup, verifier_setup) = setup::(&mut rng, max_log_n); + + // Create two different polynomials + let poly1 = random_polynomial(16); + let poly2 = random_polynomial(16); + let nu = 2; + let sigma = 2; + + let (tier_2_1, tier_1_1) = poly1 + .commit::(nu, sigma, &prover_setup) + .unwrap(); + + let (tier_2_2, tier_1_2) = poly2 + .commit::(nu, sigma, &prover_setup) + .unwrap(); + + let point = random_point(4); + + // Create proof for poly1 + let mut prover_transcript1 = fresh_transcript(); + let proof1 = prove::<_, BN254, TestG1Routines, TestG2Routines, _, _>( + &poly1, + &point, + tier_1_1, + nu, + sigma, + &prover_setup, + &mut prover_transcript1, + ) + .unwrap(); + let evaluation1 = poly1.evaluate(&point); + + // Create proof for poly2 + let mut prover_transcript2 = fresh_transcript(); + let proof2 = prove::<_, BN254, TestG1Routines, TestG2Routines, _, _>( + &poly2, + &point, + tier_1_2, + nu, + sigma, + &prover_setup, + &mut prover_transcript2, + ) + .unwrap(); + let evaluation2 = poly2.evaluate(&point); + + // Generate hints for poly1's verification + let ctx = Rc::new(TestCtx::for_witness_gen()); + let mut witness_transcript = fresh_transcript(); + + verify_recursive::<_, BN254, TestG1Routines, TestG2Routines, _, _, _>( + tier_2_1, + evaluation1, + &point, + &proof1, + verifier_setup.clone(), + &mut witness_transcript, + ctx.clone(), + ) + .expect("Witness-generating verification should succeed"); + + let collection = Rc::try_unwrap(ctx) + .ok() + .expect("Should have sole ownership") + .finalize() + .expect("Should have witnesses"); + + let hints = collection.to_hints::(); + + // Try to use poly1's hints for poly2's verification + let ctx = Rc::new(TestCtx::for_hints(hints)); + let mut hint_transcript = fresh_transcript(); + + let result = verify_recursive::<_, BN254, TestG1Routines, TestG2Routines, _, _, _>( + tier_2_2, + evaluation2, + &point, + &proof2, + verifier_setup, + &mut hint_transcript, + ctx.clone(), + ); + + // The verification should fail because the hints don't match the proof + assert!(result.is_err(), "Verification with wrong hints should fail"); +} + +#[test] +fn test_hint_map_size_reduction() { + let mut rng = rand::thread_rng(); + let max_log_n = 8; + + let (prover_setup, verifier_setup) = setup::(&mut rng, max_log_n); + + let poly = random_polynomial(64); + let nu = 3; + let sigma = 3; + + let (tier_2, tier_1) = poly + .commit::(nu, sigma, &prover_setup) + .unwrap(); + + let point = random_point(6); + + let mut prover_transcript = fresh_transcript(); + let proof = prove::<_, BN254, TestG1Routines, TestG2Routines, _, _>( + &poly, + &point, + tier_1, + nu, + sigma, + &prover_setup, + &mut prover_transcript, + ) + .unwrap(); + let evaluation = poly.evaluate(&point); + + let ctx = Rc::new(TestCtx::for_witness_gen()); + let mut witness_transcript = fresh_transcript(); + + verify_recursive::<_, BN254, TestG1Routines, TestG2Routines, _, _, _>( + tier_2, + evaluation, + &point, + &proof, + verifier_setup, + &mut witness_transcript, + ctx.clone(), + ) + .expect("Verification should succeed"); + + let collection = Rc::try_unwrap(ctx) + .ok() + .expect("Should have sole ownership") + .finalize() + .expect("Should have witnesses"); + + let hints = collection.to_hints::(); + + // Verify hint count matches total operations + let total_ops = collection.total_witnesses(); + tracing::info!( + total_ops, + hint_map_size = hints.len(), + "Hint map conversion stats" + ); + + // HintMap should have same number of entries as total witnesses + assert_eq!( + hints.len(), + total_ops, + "HintMap should have one entry per operation" + ); +} diff --git a/tests/arkworks/witness.rs b/tests/arkworks/witness.rs new file mode 100644 index 0000000..97dfb9a --- /dev/null +++ b/tests/arkworks/witness.rs @@ -0,0 +1,47 @@ +//! Tests for Arkworks witness generation + +use dory_pcs::backends::arkworks::{ArkFr, ArkG1, ArkG2, ArkGT, SimpleWitnessGenerator, BN254}; +use dory_pcs::primitives::arithmetic::{Field, Group, PairingCurve}; +use dory_pcs::recursion::WitnessGenerator; +use rand::thread_rng; + +#[test] +fn test_gt_exp_witness_generation() { + let mut rng = thread_rng(); + let base = ArkGT::random(&mut rng); + let scalar = ArkFr::random(&mut rng); + let result = base.scale(&scalar); + + let witness = SimpleWitnessGenerator::generate_gt_exp(&base, &scalar, &result); + + assert_eq!(witness.base, base); + assert_eq!(witness.result, result); + assert_eq!(witness.scalar_bits.len(), 254); +} + +#[test] +fn test_g1_scalar_mul_witness_generation() { + let mut rng = thread_rng(); + let point = ArkG1::random(&mut rng); + let scalar = ArkFr::random(&mut rng); + let result = point.scale(&scalar); + + let witness = SimpleWitnessGenerator::generate_g1_scalar_mul(&point, &scalar, &result); + + assert_eq!(witness.point, point); + assert_eq!(witness.result, result); +} + +#[test] +fn test_pairing_witness_generation() { + let mut rng = thread_rng(); + let g1 = ArkG1::random(&mut rng); + let g2 = ArkG2::random(&mut rng); + let result = BN254::pair(&g1, &g2); + + let witness = SimpleWitnessGenerator::generate_pairing(&g1, &g2, &result); + + assert_eq!(witness.g1, g1); + assert_eq!(witness.g2, g2); + assert_eq!(witness.result, result); +} From a735c99016b1f8b301c1866f3181545db8fe6566 Mon Sep 17 00:00:00 2001 From: markosg04 Date: Thu, 4 Dec 2025 12:49:05 -0500 Subject: [PATCH 2/4] refactor: some more recursion improvements --- src/backends/arkworks/ark_witness.rs | 32 ++++++++++++++-------------- src/recursion/collection.rs | 32 +++++++++++++++++++++------- src/recursion/mod.rs | 3 +-- src/recursion/witness.rs | 5 +++-- 4 files changed, 44 insertions(+), 28 deletions(-) diff --git a/src/backends/arkworks/ark_witness.rs b/src/backends/arkworks/ark_witness.rs index 3654a89..03ab629 100644 --- a/src/backends/arkworks/ark_witness.rs +++ b/src/backends/arkworks/ark_witness.rs @@ -49,8 +49,8 @@ pub struct GtExpWitness { } impl WitnessResult for GtExpWitness { - fn result(&self) -> &ArkGT { - &self.result + fn result(&self) -> Option<&ArkGT> { + Some(&self.result) } } @@ -70,8 +70,8 @@ pub struct G1ScalarMulWitness { } impl WitnessResult for G1ScalarMulWitness { - fn result(&self) -> &ArkG1 { - &self.result + fn result(&self) -> Option<&ArkG1> { + Some(&self.result) } } @@ -91,8 +91,8 @@ pub struct G2ScalarMulWitness { } impl WitnessResult for G2ScalarMulWitness { - fn result(&self) -> &ArkG2 { - &self.result + fn result(&self) -> Option<&ArkG2> { + Some(&self.result) } } @@ -112,8 +112,8 @@ pub struct GtMulWitness { } impl WitnessResult for GtMulWitness { - fn result(&self) -> &ArkGT { - &self.result + fn result(&self) -> Option<&ArkGT> { + Some(&self.result) } } @@ -144,8 +144,8 @@ pub struct PairingWitness { } impl WitnessResult for PairingWitness { - fn result(&self) -> &ArkGT { - &self.result + fn result(&self) -> Option<&ArkGT> { + Some(&self.result) } } @@ -167,8 +167,8 @@ pub struct MultiPairingWitness { } impl WitnessResult for MultiPairingWitness { - fn result(&self) -> &ArkGT { - &self.result + fn result(&self) -> Option<&ArkGT> { + Some(&self.result) } } @@ -190,8 +190,8 @@ pub struct MsmG1Witness { } impl WitnessResult for MsmG1Witness { - fn result(&self) -> &ArkG1 { - &self.result + fn result(&self) -> Option<&ArkG1> { + Some(&self.result) } } @@ -211,8 +211,8 @@ pub struct MsmG2Witness { } impl WitnessResult for MsmG2Witness { - fn result(&self) -> &ArkG2 { - &self.result + fn result(&self) -> Option<&ArkG2> { + Some(&self.result) } } diff --git a/src/recursion/collection.rs b/src/recursion/collection.rs index bc724df..0da6d35 100644 --- a/src/recursion/collection.rs +++ b/src/recursion/collection.rs @@ -106,32 +106,48 @@ impl WitnessCollection { // Extract GT results for (id, w) in &self.gt_exp { - hints.insert_gt(*id, *w.result()); + if let Some(result) = w.result() { + hints.insert_gt(*id, *result); + } } for (id, w) in &self.gt_mul { - hints.insert_gt(*id, *w.result()); + if let Some(result) = w.result() { + hints.insert_gt(*id, *result); + } } for (id, w) in &self.pairing { - hints.insert_gt(*id, *w.result()); + if let Some(result) = w.result() { + hints.insert_gt(*id, *result); + } } for (id, w) in &self.multi_pairing { - hints.insert_gt(*id, *w.result()); + if let Some(result) = w.result() { + hints.insert_gt(*id, *result); + } } // Extract G1 results for (id, w) in &self.g1_scalar_mul { - hints.insert_g1(*id, *w.result()); + if let Some(result) = w.result() { + hints.insert_g1(*id, *result); + } } for (id, w) in &self.msm_g1 { - hints.insert_g1(*id, *w.result()); + if let Some(result) = w.result() { + hints.insert_g1(*id, *result); + } } // Extract G2 results for (id, w) in &self.g2_scalar_mul { - hints.insert_g2(*id, *w.result()); + if let Some(result) = w.result() { + hints.insert_g2(*id, *result); + } } for (id, w) in &self.msm_g2 { - hints.insert_g2(*id, *w.result()); + if let Some(result) = w.result() { + hints.insert_g2(*id, *result); + } } hints diff --git a/src/recursion/mod.rs b/src/recursion/mod.rs index f95fec7..42f65f5 100644 --- a/src/recursion/mod.rs +++ b/src/recursion/mod.rs @@ -53,9 +53,8 @@ pub use collection::WitnessCollection; pub use collector::WitnessGenerator; pub use context::{CtxHandle, TraceContext}; pub use hint_map::HintMap; -pub use witness::{OpId, OpType, WitnessBackend}; +pub use witness::{OpId, OpType, WitnessBackend, WitnessResult}; pub(crate) use collector::{OpIdBuilder, WitnessCollector}; pub(crate) use context::ExecutionMode; pub(crate) use trace::{TraceG1, TraceG2, TraceGT, TracePairing}; -pub(crate) use witness::WitnessResult; diff --git a/src/recursion/witness.rs b/src/recursion/witness.rs index 02e66a3..9691a30 100644 --- a/src/recursion/witness.rs +++ b/src/recursion/witness.rs @@ -100,6 +100,7 @@ pub trait WitnessBackend: Sized + Send + Sync + 'static { /// Trait for extracting the result from a witness. pub trait WitnessResult { - /// Get the result of the operation. - fn result(&self) -> &T; + /// Get the result of the operation if implemented. + /// Returns None for unimplemented operations. + fn result(&self) -> Option<&T>; } From 08a7fef962e06d4b3b81fce7cc8e37598b24bc50 Mon Sep 17 00:00:00 2001 From: markosg04 Date: Mon, 12 Jan 2026 15:59:27 -0500 Subject: [PATCH 3/4] fix: recursion flag --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index c66484a..98c96d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ arkworks = [ parallel = ["dep:rayon", "ark-ec?/parallel", "ark-ff?/parallel"] cache = ["arkworks", "dep:once_cell", "parallel"] disk-persistence = [] +recursion = ["arkworks"] [dependencies] thiserror = "2.0" From 3fefea95efdfd86d2389cecbc0d3c7bbbb8ae4bb Mon Sep 17 00:00:00 2001 From: markosg04 Date: Thu, 15 Jan 2026 22:51:30 -0500 Subject: [PATCH 4/4] refactor: HashMap -> BtreeMap for HintMap --- src/evaluation_proof.rs | 75 ++++----- src/recursion/context.rs | 6 +- src/recursion/hint_map.rs | 336 ++++++++++++++++++++++++-------------- src/recursion/witness.rs | 4 +- 4 files changed, 253 insertions(+), 168 deletions(-) diff --git a/src/evaluation_proof.rs b/src/evaluation_proof.rs index dc831d7..0e54c37 100644 --- a/src/evaluation_proof.rs +++ b/src/evaluation_proof.rs @@ -473,10 +473,12 @@ where // Create trace operators let pairing = TracePairing::new(Rc::clone(&ctx)); - // VMV check pairing: d2 == e(e1, h2) - let e1_trace = TraceG1::new(vmv_message.e1, Rc::clone(&ctx)); + // VMV check: d2 == e(e1, h2) + // This check binds the hints to the proof - using wrong hints will cause + // the pairing result to not match the proof's d2 value. let h2_trace = TraceG2::new(setup.h2, Rc::clone(&ctx)); - let pairing_check = pairing.pair(&e1_trace, &h2_trace); + let e1_vmv = TraceG1::new(vmv_message.e1, Rc::clone(&ctx)); + let pairing_check = pairing.pair(&e1_vmv, &h2_trace); if vmv_message.d2 != *pairing_check.inner() { return Err(DoryError::InvalidProof); @@ -610,56 +612,49 @@ where ctx.enter_final(); let gamma = transcript.challenge_scalar(b"gamma"); + + transcript.append_serde(b"final_e1", &proof.final_message.e1); + transcript.append_serde(b"final_e2", &proof.final_message.e2); + let d_challenge = transcript.challenge_scalar(b"d"); let gamma_inv = gamma.inv().expect("gamma must be invertible"); let d_inv = d_challenge.inv().expect("d must be invertible"); + let neg_gamma = -gamma; + let neg_gamma_inv = -gamma_inv; - // Final verification with tracing + // Optimized final verification using batched multi-pairing (3 ML + 1 FE) + // Note: VMV check was done early (for hint binding), so no d² terms here. + // + // RHS (GT terms): T = C + (s₁·s₂)·HT + χ₀ + d·D₂ + d⁻¹·D₁ let s_product = s1_acc * s2_acc; let ht_trace = TraceGT::new(setup.ht, Rc::clone(&ctx)); - let ht_scaled = ht_trace.scale(&s_product); - c = c + ht_scaled; - - // Traced pairings - let h1_trace = TraceG1::new(setup.h1, Rc::clone(&ctx)); - let pairing_h1_e2 = pairing.pair(&h1_trace, &e2_state); - let pairing_e1_h2 = pairing.pair(&e1, &h2_trace); - - c = c + pairing_h1_e2.scale(&gamma); - c = c + pairing_e1_h2.scale(&gamma_inv); - - // D1 update with traced operations - let scalar_for_g2_in_d1 = s1_acc * gamma; - let g2_0_trace = TraceG2::new(setup.g2_0, Rc::clone(&ctx)); - let g2_0_scaled = g2_0_trace.scale(&scalar_for_g2_in_d1); - - let pairing_h1_g2 = pairing.pair(&h1_trace, &g2_0_scaled); - d1 = d1 + pairing_h1_g2; + let mut rhs = c + ht_trace.scale(&s_product); + rhs = rhs + TraceGT::new(setup.chi[0], Rc::clone(&ctx)); + rhs = rhs + d2.scale(&d_challenge); + rhs = rhs + d1.scale(&d_inv); - // D2 update with traced operations - let scalar_for_g1_in_d2 = s2_acc * gamma_inv; + // Pair 1: (E₁_final + d·Γ₁₀, E₂_final + d⁻¹·Γ₂₀) let g1_0_trace = TraceG1::new(setup.g1_0, Rc::clone(&ctx)); - let g1_0_scaled = g1_0_trace.scale(&scalar_for_g1_in_d2); - - let pairing_g1_h2 = pairing.pair(&g1_0_scaled, &h2_trace); - d2 = d2 + pairing_g1_h2; - - // Final pairing check + let g2_0_trace = TraceG2::new(setup.g2_0, Rc::clone(&ctx)); let e1_final = TraceG1::new(proof.final_message.e1, Rc::clone(&ctx)); - let g1_0_d_scaled = g1_0_trace.scale(&d_challenge); - let e1_modified = e1_final + g1_0_d_scaled; - let e2_final = TraceG2::new(proof.final_message.e2, Rc::clone(&ctx)); - let g2_0_d_inv_scaled = g2_0_trace.scale(&d_inv); - let e2_modified = e2_final + g2_0_d_inv_scaled; + let p1_g1 = e1_final + g1_0_trace.scale(&d_challenge); + let p1_g2 = e2_final + g2_0_trace.scale(&d_inv); - let lhs = pairing.pair(&e1_modified, &e2_modified); + // Pair 2: (H₁, (-γ)·(E₂_acc + (d⁻¹·s₁)·Γ₂₀)) + let h1_trace = TraceG1::new(setup.h1, Rc::clone(&ctx)); + let d_inv_s1 = d_inv * s1_acc; + let g2_term = e2_state + g2_0_trace.scale(&d_inv_s1); + let p2_g2 = g2_term.scale(&neg_gamma); - let mut rhs = c; - rhs = rhs + TraceGT::new(setup.chi[0], Rc::clone(&ctx)); - rhs = rhs + d2.scale(&d_challenge); - rhs = rhs + d1.scale(&d_inv); + // Pair 3: ((-γ⁻¹)·(E₁_acc + (d·s₂)·Γ₁₀), H₂) + let d_s2 = d_challenge * s2_acc; + let g1_term = e1 + g1_0_trace.scale(&d_s2); + let p3_g1 = g1_term.scale(&neg_gamma_inv); + + // Multi-pairing check: 3 miller loops + 1 final exponentiation + let lhs = pairing.multi_pair(&[p1_g1, h1_trace, p3_g1], &[p1_g2, p2_g2, h2_trace]); if *lhs.inner() == *rhs.inner() { Ok(()) diff --git a/src/recursion/context.rs b/src/recursion/context.rs index 19f8696..5abad44 100644 --- a/src/recursion/context.rs +++ b/src/recursion/context.rs @@ -151,19 +151,19 @@ where /// Get a G1 hint for the given operation. #[inline] pub fn get_hint_g1(&self, id: OpId) -> Option { - self.hints.as_ref().and_then(|h| h.get_g1(id).copied()) + self.hints.as_ref().and_then(|h| h.get_g1(&id).copied()) } /// Get a G2 hint for the given operation. #[inline] pub fn get_hint_g2(&self, id: OpId) -> Option { - self.hints.as_ref().and_then(|h| h.get_g2(id).copied()) + self.hints.as_ref().and_then(|h| h.get_g2(&id).copied()) } /// Get a GT hint for the given operation. #[inline] pub fn get_hint_gt(&self, id: OpId) -> Option { - self.hints.as_ref().and_then(|h| h.get_gt(id).copied()) + self.hints.as_ref().and_then(|h| h.get_gt(&id).copied()) } /// Record a GT exponentiation witness. diff --git a/src/recursion/hint_map.rs b/src/recursion/hint_map.rs index ea7c183..cc5e1b5 100644 --- a/src/recursion/hint_map.rs +++ b/src/recursion/hint_map.rs @@ -1,10 +1,4 @@ -//! Lightweight hint storage for recursive verification. -//! -//! This module provides [`HintMap`], a simplified storage structure that holds -//! only operation results (not full witnesses with intermediate computation steps). -//! This results in ~30-50x smaller storage compared to full witness collections. - -use std::collections::HashMap; +use std::collections::BTreeMap; use std::io::{Read, Write}; use super::witness::{OpId, OpType}; @@ -51,41 +45,37 @@ impl HintResult { matches!(self, HintResult::GT(_)) } - /// Try to get as G1, returns None if wrong variant. + /// Extract a G1 result, returning None if this is not a G1 variant. #[inline] pub fn as_g1(&self) -> Option<&E::G1> { - match self { - HintResult::G1(g1) => Some(g1), - _ => None, + if let HintResult::G1(g1) = self { + Some(g1) + } else { + None } } - /// Try to get as G2, returns None if wrong variant. + /// Extract a G2 result, returning None if this is not a G2 variant. #[inline] pub fn as_g2(&self) -> Option<&E::G2> { - match self { - HintResult::G2(g2) => Some(g2), - _ => None, + if let HintResult::G2(g2) = self { + Some(g2) + } else { + None } } - /// Try to get as GT, returns None if wrong variant. + /// Extract a GT result, returning None if this is not a GT variant. #[inline] pub fn as_gt(&self) -> Option<&E::GT> { - match self { - HintResult::GT(gt) => Some(gt), - _ => None, + if let HintResult::GT(gt) = self { + Some(gt) + } else { + None } } } -impl Valid for HintResult { - fn check(&self) -> Result<(), SerializationError> { - // Curve points are validated during deserialization - Ok(()) - } -} - impl DorySerialize for HintResult { fn serialize_with_mode( &self, @@ -94,18 +84,19 @@ impl DorySerialize for HintResult { ) -> Result<(), SerializationError> { match self { HintResult::G1(g1) => { - TAG_G1.serialize_with_mode(&mut writer, compress)?; - g1.serialize_with_mode(writer, compress) + DorySerialize::serialize_with_mode(&TAG_G1, &mut writer, compress)?; + DorySerialize::serialize_with_mode(g1, &mut writer, compress)?; } HintResult::G2(g2) => { - TAG_G2.serialize_with_mode(&mut writer, compress)?; - g2.serialize_with_mode(writer, compress) + DorySerialize::serialize_with_mode(&TAG_G2, &mut writer, compress)?; + DorySerialize::serialize_with_mode(g2, &mut writer, compress)?; } HintResult::GT(gt) => { - TAG_GT.serialize_with_mode(&mut writer, compress)?; - gt.serialize_with_mode(writer, compress) + DorySerialize::serialize_with_mode(&TAG_GT, &mut writer, compress)?; + DorySerialize::serialize_with_mode(gt, &mut writer, compress)?; } } + Ok(()) } fn serialized_size(&self, compress: Compress) -> usize { @@ -123,25 +114,35 @@ impl DoryDeserialize for HintResult { compress: Compress, validate: Validate, ) -> Result { - let tag = u8::deserialize_with_mode(&mut reader, compress, validate)?; + let tag = ::deserialize_with_mode(&mut reader, compress, validate)?; match tag { - TAG_G1 => Ok(HintResult::G1(E::G1::deserialize_with_mode( - reader, compress, validate, - )?)), - TAG_G2 => Ok(HintResult::G2(E::G2::deserialize_with_mode( - reader, compress, validate, - )?)), - TAG_GT => Ok(HintResult::GT(E::GT::deserialize_with_mode( - reader, compress, validate, - )?)), - _ => Err(SerializationError::InvalidData(format!( - "Invalid HintResult tag: {tag}" - ))), + TAG_G1 => { + let g1 = E::G1::deserialize_with_mode(&mut reader, compress, validate)?; + Ok(HintResult::G1(g1)) + } + TAG_G2 => { + let g2 = E::G2::deserialize_with_mode(&mut reader, compress, validate)?; + Ok(HintResult::G2(g2)) + } + TAG_GT => { + let gt = E::GT::deserialize_with_mode(&mut reader, compress, validate)?; + Ok(HintResult::GT(gt)) + } + _ => Err(SerializationError::InvalidData( + "Invalid HintResult tag".to_string(), + )), } } } -/// Hint storage +impl Valid for HintResult { + fn check(&self) -> Result<(), SerializationError> { + // Curve elements are already validated upon creation + Ok(()) + } +} + +/// A lightweight hint storage for recursive verification. /// /// Unlike [`WitnessCollection`](crate::recursion::WitnessCollection) which stores /// full computation traces, this stores only the final results for each operation, @@ -151,7 +152,7 @@ pub struct HintMap { /// Number of reduce-and-fold rounds in the verification pub num_rounds: usize, /// All operation results indexed by OpId - results: HashMap>, + results: BTreeMap>, } impl HintMap { @@ -159,94 +160,98 @@ impl HintMap { pub fn new(num_rounds: usize) -> Self { Self { num_rounds, - results: HashMap::new(), + results: BTreeMap::new(), } } /// Get G1 result for an operation. - /// - /// Returns None if the operation is not found or is not a G1 result. - #[inline] - pub fn get_g1(&self, id: OpId) -> Option<&E::G1> { - self.results.get(&id).and_then(|r| r.as_g1()) + pub fn get_g1(&self, op_id: &OpId) -> Option<&E::G1> { + self.results.get(op_id)?.as_g1() } /// Get G2 result for an operation. - /// - /// Returns None if the operation is not found or is not a G2 result. - #[inline] - pub fn get_g2(&self, id: OpId) -> Option<&E::G2> { - self.results.get(&id).and_then(|r| r.as_g2()) + pub fn get_g2(&self, op_id: &OpId) -> Option<&E::G2> { + self.results.get(op_id)?.as_g2() } /// Get GT result for an operation. - /// - /// Returns None if the operation is not found or is not a GT result. - #[inline] - pub fn get_gt(&self, id: OpId) -> Option<&E::GT> { - self.results.get(&id).and_then(|r| r.as_gt()) + pub fn get_gt(&self, op_id: &OpId) -> Option<&E::GT> { + self.results.get(op_id)?.as_gt() } - /// Get raw result enum for an operation. - #[inline] - pub fn get(&self, id: OpId) -> Option<&HintResult> { - self.results.get(&id) + /// Get any result for an operation. + pub fn get(&self, op_id: &OpId) -> Option<&HintResult> { + self.results.get(op_id) } - /// Insert a G1 result. - #[inline] - pub fn insert_g1(&mut self, id: OpId, value: E::G1) { - self.results.insert(id, HintResult::G1(value)); + /// Insert a result for an operation. + pub fn insert(&mut self, op_id: OpId, result: HintResult) -> Option> { + self.results.insert(op_id, result) } - /// Insert a G2 result. - #[inline] - pub fn insert_g2(&mut self, id: OpId, value: E::G2) { - self.results.insert(id, HintResult::G2(value)); + /// Insert a G1 result for an operation. + pub fn insert_g1(&mut self, op_id: OpId, result: E::G1) -> Option> { + self.results.insert(op_id, HintResult::G1(result)) } - /// Insert a GT result. - #[inline] - pub fn insert_gt(&mut self, id: OpId, value: E::GT) { - self.results.insert(id, HintResult::GT(value)); + /// Insert a G2 result for an operation. + pub fn insert_g2(&mut self, op_id: OpId, result: E::G2) -> Option> { + self.results.insert(op_id, HintResult::G2(result)) } - /// Total number of hints stored. - #[inline] + /// Insert a GT result for an operation. + pub fn insert_gt(&mut self, op_id: OpId, result: E::GT) -> Option> { + self.results.insert(op_id, HintResult::GT(result)) + } + + /// Number of operations stored. pub fn len(&self) -> usize { self.results.len() } - /// Check if the hint map is empty. - #[inline] + /// Check if empty. pub fn is_empty(&self) -> bool { self.results.is_empty() } - /// Iterate over all (OpId, HintResult) pairs. + /// Iterator over all operations and results. pub fn iter(&self) -> impl Iterator)> { self.results.iter() } - /// Check if a hint exists for the given operation. - #[inline] - pub fn contains(&self, id: OpId) -> bool { - self.results.contains_key(&id) - } -} + /// Count operations by type. + pub fn count_by_type(&self) -> (usize, usize, usize) { + let mut g1_count = 0; + let mut g2_count = 0; + let mut gt_count = 0; + + for result in self.results.values() { + match result { + HintResult::G1(_) => g1_count += 1, + HintResult::G2(_) => g2_count += 1, + HintResult::GT(_) => gt_count += 1, + } + } -impl Default for HintMap { - fn default() -> Self { - Self::new(0) + (g1_count, g2_count, gt_count) } -} -impl Valid for HintMap { - fn check(&self) -> Result<(), SerializationError> { - for result in self.results.values() { - result.check()?; + /// Count operations by round and type. + pub fn stats(&self) -> Vec<(u16, OpType, usize)> { + use std::collections::HashMap; + + let mut stats: HashMap<(u16, OpType), usize> = HashMap::new(); + + for op_id in self.results.keys() { + *stats.entry((op_id.round, op_id.op_type)).or_insert(0) += 1; } - Ok(()) + + let mut result: Vec<_> = stats + .into_iter() + .map(|((round, op_type), count)| (round, op_type, count)) + .collect(); + result.sort(); + result } } @@ -256,27 +261,31 @@ impl DorySerialize for HintMap { mut writer: W, compress: Compress, ) -> Result<(), SerializationError> { - (self.num_rounds as u64).serialize_with_mode(&mut writer, compress)?; - (self.results.len() as u64).serialize_with_mode(&mut writer, compress)?; + DorySerialize::serialize_with_mode(&(self.num_rounds as u64), &mut writer, compress)?; + DorySerialize::serialize_with_mode(&(self.results.len() as u64), &mut writer, compress)?; for (id, result) in &self.results { // Serialize OpId as (round: u16, op_type: u8, index: u16) - id.round.serialize_with_mode(&mut writer, compress)?; - (id.op_type as u8).serialize_with_mode(&mut writer, compress)?; - id.index.serialize_with_mode(&mut writer, compress)?; - result.serialize_with_mode(&mut writer, compress)?; + DorySerialize::serialize_with_mode(&id.round, &mut writer, compress)?; + DorySerialize::serialize_with_mode(&(id.op_type as u8), &mut writer, compress)?; + DorySerialize::serialize_with_mode(&id.index, &mut writer, compress)?; + + // Serialize the result + DorySerialize::serialize_with_mode(result, &mut writer, compress)?; } + Ok(()) } fn serialized_size(&self, compress: Compress) -> usize { - let header = 8 + 8; // num_rounds + len - let entries: usize = self - .results - .values() - .map(|r| 2 + 1 + 2 + r.serialized_size(compress)) - .sum(); - header + entries + let mut size = 8 + 8; // num_rounds + len + + for result in self.results.values() { + size += 2 + 1 + 2; // OpId: round + op_type + index + size += result.serialized_size(compress); + } + + size } } @@ -286,14 +295,20 @@ impl DoryDeserialize for HintMap { compress: Compress, validate: Validate, ) -> Result { - let num_rounds = u64::deserialize_with_mode(&mut reader, compress, validate)? as usize; - let len = u64::deserialize_with_mode(&mut reader, compress, validate)? as usize; + let num_rounds = + ::deserialize_with_mode(&mut reader, compress, validate)? + as usize; + let len = ::deserialize_with_mode(&mut reader, compress, validate)? + as usize; - let mut results = HashMap::with_capacity(len); + let mut results = BTreeMap::new(); for _ in 0..len { - let round = u16::deserialize_with_mode(&mut reader, compress, validate)?; - let op_type_byte = u8::deserialize_with_mode(&mut reader, compress, validate)?; - let index = u16::deserialize_with_mode(&mut reader, compress, validate)?; + let round = + ::deserialize_with_mode(&mut reader, compress, validate)?; + let op_type_byte = + ::deserialize_with_mode(&mut reader, compress, validate)?; + let index = + ::deserialize_with_mode(&mut reader, compress, validate)?; let op_type = match op_type_byte { 0 => OpType::GtExp, @@ -305,15 +320,20 @@ impl DoryDeserialize for HintMap { 6 => OpType::MsmG1, 7 => OpType::MsmG2, _ => { - return Err(SerializationError::InvalidData(format!( - "Invalid OpType: {op_type_byte}" - ))) + return Err(SerializationError::InvalidData( + "Invalid OpType byte".to_string(), + )) } }; - let id = OpId::new(round, op_type, index); + let op_id = OpId { + round, + op_type, + index, + }; + let result = HintResult::deserialize_with_mode(&mut reader, compress, validate)?; - results.insert(id, result); + results.insert(op_id, result); } Ok(Self { @@ -322,3 +342,73 @@ impl DoryDeserialize for HintMap { }) } } + +// Implement ark-serialize traits by delegating to DorySerialize/DoryDeserialize +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress as ArkCompress, Read as ArkRead, + SerializationError as ArkSerializationError, Valid as ArkValid, Validate as ArkValidate, + Write as ArkWrite, +}; + +// NOTE: These implementations preserve the original error information from Dory's +// serialization for better debugging. The error messages include the underlying +// cause to help diagnose serialization/deserialization failures. +impl CanonicalSerialize for HintMap { + fn serialize_with_mode( + &self, + writer: W, + compress: ArkCompress, + ) -> Result<(), ArkSerializationError> { + let compress = if matches!(compress, ArkCompress::Yes) { + Compress::Yes + } else { + Compress::No + }; + DorySerialize::serialize_with_mode(self, writer, compress).map_err(|e| { + ArkSerializationError::IoError(std::io::Error::new( + std::io::ErrorKind::Other, + format!("HintMap serialization failed: {:?}", e), + )) + }) + } + + fn serialized_size(&self, compress: ArkCompress) -> usize { + let compress = if matches!(compress, ArkCompress::Yes) { + Compress::Yes + } else { + Compress::No + }; + DorySerialize::serialized_size(self, compress) + } +} + +impl CanonicalDeserialize for HintMap { + fn deserialize_with_mode( + reader: R, + compress: ArkCompress, + validate: ArkValidate, + ) -> Result { + let compress = if matches!(compress, ArkCompress::Yes) { + Compress::Yes + } else { + Compress::No + }; + let validate = if matches!(validate, ArkValidate::Yes) { + Validate::Yes + } else { + Validate::No + }; + DoryDeserialize::deserialize_with_mode(reader, compress, validate).map_err(|e| { + ArkSerializationError::IoError(std::io::Error::new( + std::io::ErrorKind::Other, + format!("HintMap deserialization failed: {:?}", e), + )) + }) + } +} + +impl ArkValid for HintMap { + fn check(&self) -> Result<(), ArkSerializationError> { + Ok(()) + } +} diff --git a/src/recursion/witness.rs b/src/recursion/witness.rs index 9691a30..4ed3a36 100644 --- a/src/recursion/witness.rs +++ b/src/recursion/witness.rs @@ -1,7 +1,7 @@ //! Witness generation types and traits for recursive proof composition. /// Operation type identifier for witness indexing. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] #[repr(u8)] pub enum OpType { /// GT exponentiation: base^scalar in the target group @@ -26,7 +26,7 @@ pub enum OpType { /// /// Operations are indexed by (round, op_type, index) to enable deterministic /// mapping between witness generation and hint consumption. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct OpId { /// Protocol round number (0 for initial checks, 1..=num_rounds for reduce rounds) pub round: u16,