diff --git a/air/src/air.rs b/air/src/air.rs
index e9b140d98..18d31e3b9 100644
--- a/air/src/air.rs
+++ b/air/src/air.rs
@@ -1,3 +1,5 @@
+use alloc::vec;
+use alloc::vec::Vec;
 use core::ops::{Add, Mul, Sub};
 
 use p3_field::{Algebra, ExtensionField, Field, PrimeCharacteristicRing};
@@ -8,10 +10,22 @@ use p3_matrix::dense::RowMajorMatrix;
 pub trait BaseAir<F>: Sync {
     /// The number of columns (a.k.a. registers) in this AIR.
     fn width(&self) -> usize;
 
+    /// Return an optional preprocessed trace matrix to be included in the prover's trace.
     fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
         None
     }
+
+    /// Return the periodic table data.
+    ///
+    /// Periodic columns are columns whose values repeat with a period that divides the trace
+    /// length. Each inner `Vec` represents one periodic column. The length of the inner
+    /// vector is the period of that column (must be a power of 2 that divides the trace length).
+    ///
+    /// By default returns an empty table (no periodic columns).
+    fn periodic_table(&self) -> Vec<Vec<F>> {
+        vec![]
+    }
 }
 
 /// An extension of `BaseAir` that includes support for public values.
@@ -232,6 +246,20 @@ pub trait PermutationAirBuilder: ExtensionBuilder {
     fn permutation_randomness(&self) -> &[Self::RandomVar];
 }
 
+/// Trait for builders supporting periodic columns.
+///
+/// Periodic columns are columns whose values repeat with a period dividing the trace length.
+/// They are never committed to the proof; instead, both prover and verifier compute them
+/// from the periodic table data provided by the AIR.
+pub trait PeriodicAirBuilder: AirBuilder {
+    /// Variable type for periodic column values.
+    /// For the prover this is typically a packed base-field value; for the verifier, an extension-field value.
+    type PeriodicVar: Into<Self::Expr> + Copy;
+
+    /// Return the evaluations of periodic columns at the current row.
+    fn periodic_values(&self) -> &[Self::PeriodicVar];
+}
+
 /// A wrapper around an [`AirBuilder`] that enforces constraints only when a specified condition is met.
/// /// This struct allows selectively applying constraints to certain rows or under certain conditions in the AIR, @@ -287,6 +315,14 @@ impl PairBuilder for FilteredAirBuilder<'_, AB> { } } +impl AirBuilderWithPublicValues for FilteredAirBuilder<'_, AB> { + type PublicVar = AB::PublicVar; + + fn public_values(&self) -> &[Self::PublicVar] { + self.inner.public_values() + } +} + impl ExtensionBuilder for FilteredAirBuilder<'_, AB> { type EF = AB::EF; type ExprEF = AB::ExprEF; @@ -316,3 +352,11 @@ impl PermutationAirBuilder for FilteredAirBuilder<'_, self.inner.permutation_randomness() } } + +impl PeriodicAirBuilder for FilteredAirBuilder<'_, AB> { + type PeriodicVar = AB::PeriodicVar; + + fn periodic_values(&self) -> &[Self::PeriodicVar] { + self.inner.periodic_values() + } +} diff --git a/batch-stark/src/prover.rs b/batch-stark/src/prover.rs index 54090aa66..772ac26b7 100644 --- a/batch-stark/src/prover.rs +++ b/batch-stark/src/prover.rs @@ -646,6 +646,7 @@ where decomposed_alpha_powers: &decomposed_alpha_powers, accumulator, constraint_index: 0, + periodic_values: vec![], // batch-stark doesn't support periodic columns yet }; let packed_perm_challenges = permutation_challenges .iter() diff --git a/batch-stark/src/verifier.rs b/batch-stark/src/verifier.rs index 75395f6d7..039ea8d90 100644 --- a/batch-stark/src/verifier.rs +++ b/batch-stark/src/verifier.rs @@ -540,6 +540,7 @@ where is_transition: sels.is_transition, alpha: *alpha, accumulator: SC::Challenge::ZERO, + periodic_values: vec![], // batch-stark doesn't support periodic columns yet }; let mut folder = VerifierConstraintFolderWithLookups { inner: inner_folder, diff --git a/circle/Cargo.toml b/circle/Cargo.toml index 68a37b912..c3c39827b 100644 --- a/circle/Cargo.toml +++ b/circle/Cargo.toml @@ -10,6 +10,7 @@ keywords.workspace = true categories.workspace = true [dependencies] +p3-air.workspace = true p3-challenger.workspace = true p3-commit.workspace = true p3-dft.workspace = true diff --git a/circle/src/lib.rs b/circle/src/lib.rs index 9605646f2..3ae392563 100644 --- a/circle/src/lib.rs +++ b/circle/src/lib.rs @@ -11,6 +11,7 @@ mod domain; mod folding; mod ordering; mod pcs; +mod periodic; mod point; mod proof; mod prover; @@ -20,4 +21,5 @@ pub use cfft::*; pub use domain::*; pub use ordering::*; pub use pcs::*; +pub use periodic::*; pub use proof::*; diff --git a/circle/src/periodic.rs b/circle/src/periodic.rs new file mode 100644 index 000000000..65f4aa453 --- /dev/null +++ b/circle/src/periodic.rs @@ -0,0 +1,346 @@ +//! Periodic column support for Circle STARKs. +//! +//! This module provides `CirclePeriodicEvaluator` for evaluating periodic columns +//! in Circle STARK proofs. The implementation supports: +//! +//! - `eval_on_lde`: Evaluates periodic columns on the LDE domain using CFFT extrapolation. +//! All columns are padded to the maximum period, creating a rectangular matrix that +//! stores only `max_period × blowup` rows with modular indexing for O(1) lookup. +//! +//! - `eval_at_point`: Evaluates periodic columns at arbitrary points using polynomial +//! evaluation with repeated doubling projection. This is used by the verifier. +//! +//! ## Memory Efficiency +//! +//! Instead of materializing the full LDE-sized table, we store only `max_period × blowup` +//! rows. For a trace of size 2^20 with period-4 columns and blowup 4, this means storing +//! 16 rows instead of 4M rows per column. +//! +//! ## Complexity +//! +//! - `eval_on_lde`: O(max_period × blowup × log(max_period × blowup)) for CFFT extrapolation, +//! 
then O(1) per LDE point lookup using modular indexing. +//! +//! - `eval_at_point`: O(period) per column using polynomial evaluation. +//! +//! Note: The current `eval_at_point` implementation is not optimized for multiple columns +//! with the same period. The interpolation setup could be shared across columns with the +//! same period. + +use alloc::vec::Vec; + +use p3_commit::{PeriodicEvaluator, PeriodicLdeTable, PolynomialSpace}; +use p3_field::ExtensionField; +use p3_field::extension::ComplexExtendable; +use p3_matrix::Matrix; +use p3_matrix::dense::RowMajorMatrix; +use p3_util::log2_strict_usize; + +use crate::CircleEvaluations; +use crate::domain::CircleDomain; +use crate::point::Point; + +/// Evaluates periodic polynomials for Circle STARKs. +/// +/// For a periodic column with period `p` and trace length `n`, the periodic values +/// are interpolated on a Circle domain of size `p`. To evaluate at any point: +/// 1. Interpolate the periodic values on a small Circle domain of size `p` +/// 2. Project the query point to the periodic subdomain via repeated doubling +/// 3. Evaluate the polynomial at the projected point +#[derive(Clone, Copy, Debug, Default)] +pub struct CirclePeriodicEvaluator; + +impl CirclePeriodicEvaluator { + pub const fn new() -> Self { + Self + } +} + +/// Compute parameters for periodic polynomial evaluation. +/// Returns (log_period, log_repetitions). +fn periodic_params(period: usize, trace_len: usize) -> (usize, usize) { + debug_assert!( + period.is_power_of_two(), + "periodic column length must be a power of 2" + ); + + let log_period = log2_strict_usize(period); + let log_repetitions = log2_strict_usize(trace_len / period); + (log_period, log_repetitions) +} + +impl PeriodicEvaluator> for CirclePeriodicEvaluator { + fn eval_on_lde( + periodic_table: &[Vec], + trace_domain: &CircleDomain, + lde_domain: &CircleDomain, + ) -> PeriodicLdeTable { + if periodic_table.is_empty() { + return PeriodicLdeTable::empty(); + } + + let trace_len = trace_domain.size(); + let log_blowup = lde_domain.log_n - trace_domain.log_n; + let blowup = 1 << log_blowup; + + // Find the maximum period and validate all columns + let max_period = periodic_table + .iter() + .map(|col| { + let period = col.len(); + debug_assert!( + period.is_power_of_two(), + "periodic column length must be a power of 2" + ); + period + }) + .max() + .unwrap(); + + let log_max_period = log2_strict_usize(max_period); + let log_repetitions = log2_strict_usize(trace_len / max_period); + let extended_height = max_period * blowup; + let num_cols = periodic_table.len(); + + // Compute the shift for the periodic subdomain at max_period. + // This aligns the periodic domain with the LDE domain so modular indexing works. 
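+        // Doubling the LDE shift `log_repetitions` times mirrors the point projection used in
+        // `eval_at_point`, so the extended periodic domain built below covers exactly the values
+        // that the LDE points project onto.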
+ let extended_shift = lde_domain.shift.repeated_double(log_repetitions); + let extended_periodic_domain = + CircleDomain::new(log_max_period + log_blowup, extended_shift); + + // Process each column: pad to max_period, then extrapolate + // Build the result in column-major order first, then transpose to row-major + let mut columns: Vec> = Vec::with_capacity(num_cols); + + for col in periodic_table { + let period = col.len(); + + // Pad column to max_period by repeating values + let padded: Vec = if period == max_period { + col.clone() + } else { + (0..max_period).map(|i| col[i % period]).collect() + }; + + // Interpolate on the max_period domain + let periodic_domain = CircleDomain::standard(log_max_period); + let evals = CircleEvaluations::from_natural_order( + periodic_domain, + RowMajorMatrix::new_col(padded), + ); + + // Extrapolate to extended_height using CFFT + let extended_evals = evals.extrapolate(extended_periodic_domain); + let extended_values = extended_evals.to_natural_order().to_row_major_matrix(); + columns.push(extended_values.values); + } + + // Convert from column-major to row-major storage + let mut row_major_values = Vec::with_capacity(extended_height * num_cols); + for row_idx in 0..extended_height { + for col in &columns { + row_major_values.push(col[row_idx]); + } + } + + PeriodicLdeTable::new(RowMajorMatrix::new(row_major_values, num_cols)) + } + + fn eval_at_point>( + periodic_table: &[Vec], + trace_domain: &CircleDomain, + point: EF, + ) -> Vec { + let trace_len = trace_domain.size(); + + periodic_table + .iter() + .map(|col| { + let (log_period, log_repetitions) = periodic_params(col.len(), trace_len); + let periodic_domain = CircleDomain::standard(log_period); + + let evals = CircleEvaluations::from_natural_order( + periodic_domain, + RowMajorMatrix::new_col(col.clone()), + ); + + // Project query point to periodic subdomain via repeated doubling + let query_point = Point::::from_projective_line(point); + let periodic_point = query_point.repeated_double(log_repetitions); + evals.evaluate_at_point(periodic_point)[0] + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use alloc::vec; + + use hashbrown::HashMap; + use p3_field::PrimeCharacteristicRing; + use p3_field::extension::BinomialExtensionField; + use p3_mersenne_31::Mersenne31; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::*; + + type F = Mersenne31; + type EF = BinomialExtensionField; + + #[test] + fn test_periodic_eval_consistency_random_points() { + // Test that eval_on_lde and eval_at_point define the same polynomial + // by checking consistency at random out-of-domain points + let log_n = 4; + let log_blowup = 1; + let trace_domain = CircleDomain::::standard(log_n); + let lde_domain = CircleDomain::::standard(log_n + log_blowup); + let lde_len = lde_domain.size(); + + // Periodic column: [10, 20, 30, 40] + let periodic_col = vec![ + F::from_u32(10), + F::from_u32(20), + F::from_u32(30), + F::from_u32(40), + ]; + let periodic_table = vec![periodic_col.clone()]; + + // Evaluate on LDE domain + let lde_table = + CirclePeriodicEvaluator::eval_on_lde(&periodic_table, &trace_domain, &lde_domain); + + assert_eq!(lde_table.width(), 1); + // Compact table has height = period * blowup = 4 * 2 = 8 + assert_eq!(lde_table.height(), 8); + + // Expand compact table to full LDE for interpolation test + let full_lde: Vec = (0..lde_len).map(|i| *lde_table.get(i, 0)).collect(); + + // Interpolate the LDE result to get a polynomial we can evaluate anywhere + let lde_evals = + 
CircleEvaluations::from_natural_order(lde_domain, RowMajorMatrix::new_col(full_lde)); + + // Test at random out-of-domain points + let mut rng = SmallRng::seed_from_u64(42); + for _ in 0..10 { + let random_point: EF = rng.random(); + + // Evaluate the LDE polynomial at the random point + let lde_at_point = + lde_evals.evaluate_at_point(Point::from_projective_line(random_point))[0]; + + // Evaluate using eval_at_point directly + let eval_at_point_result = CirclePeriodicEvaluator::eval_at_point( + &periodic_table, + &trace_domain, + random_point, + ); + + assert_eq!( + lde_at_point, eval_at_point_result[0], + "Mismatch at random point: LDE interpolation={:?}, eval_at_point={:?}", + lde_at_point, eval_at_point_result[0] + ); + } + } + + #[test] + fn test_periodic_eval_at_trace_domain_points() { + // Test that evaluating the periodic polynomial at trace domain points + // gives the expected periodic pattern + let log_n = 4; // 16 rows + let trace_domain = CircleDomain::::standard(log_n); + let trace_len = trace_domain.size(); + let period = 4; + + // Periodic column: [1, 2, 3, 4] + let periodic_col = vec![ + F::from_u32(1), + F::from_u32(2), + F::from_u32(3), + F::from_u32(4), + ]; + let periodic_table = vec![periodic_col.clone()]; + + // Evaluate on trace domain (same as LDE with blowup=1) + let lde_table = + CirclePeriodicEvaluator::eval_on_lde(&periodic_table, &trace_domain, &trace_domain); + + assert_eq!(lde_table.width(), 1); + // Compact table has height = period * blowup = 4 * 1 = 4 + assert_eq!(lde_table.height(), 4); + + // Expand compact table to full trace + let full_trace: Vec = (0..trace_len).map(|i| *lde_table.get(i, 0)).collect(); + + // The values should follow a periodic pattern with period 4 + // But the exact mapping depends on Circle domain structure. + // Verify that we get exactly 4 distinct values, each appearing 4 times. + let mut value_counts = HashMap::new(); + for &val in &full_trace { + *value_counts.entry(val).or_insert(0) += 1; + } + assert_eq!( + value_counts.len(), + period, + "Expected {} distinct values, got {}", + period, + value_counts.len() + ); + for (val, count) in &value_counts { + assert_eq!( + *count, 4, + "Value {:?} appears {} times, expected 4", + val, count + ); + } + } + + #[test] + fn test_cfft_extrapolation_matches_naive() { + // Verify that the CFFT-based eval_on_lde matches point-by-point evaluation + // using the naive repeated_double approach. 
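+        // Each tuple is (log_n, log_blowup, log_period): blowups of 2 and 4, periods from 4 to 16.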
+ for (log_n, log_blowup, log_period) in [(4, 1, 2), (5, 2, 2), (6, 1, 3), (8, 2, 4)] { + let trace_domain = CircleDomain::::standard(log_n); + let lde_domain = CircleDomain::::standard(log_n + log_blowup); + let lde_len = lde_domain.size(); + let period = 1 << log_period; + let log_repetitions = log_n - log_period; + + // Create a periodic column with distinct values + let periodic_col: Vec = (0..period).map(|i| F::from_u32(i as u32 + 1)).collect(); + let periodic_table = vec![periodic_col.clone()]; + + // Evaluate using the optimized CFFT-based method + let cfft_table = + CirclePeriodicEvaluator::eval_on_lde(&periodic_table, &trace_domain, &lde_domain); + + // Expand compact table to full LDE + let cfft_result: Vec = (0..lde_len).map(|i| *cfft_table.get(i, 0)).collect(); + + // Evaluate using the naive point-by-point method + let periodic_domain = CircleDomain::standard(log_period); + let evals = CircleEvaluations::from_natural_order( + periodic_domain, + RowMajorMatrix::new_col(periodic_col.clone()), + ); + + let naive_result: Vec = (0..lde_len) + .map(|lde_idx| { + let lde_point = lde_domain.nth_point(lde_idx); + let periodic_point = lde_point.repeated_double(log_repetitions); + evals.evaluate_at_point(periodic_point)[0] + }) + .collect(); + + assert_eq!( + cfft_result, naive_result, + "CFFT-based and naive methods disagree for log_n={}, log_blowup={}, log_period={}", + log_n, log_blowup, log_period + ); + } + } +} diff --git a/circle/src/point.rs b/circle/src/point.rs index e1cc14a35..574f9abfe 100644 --- a/circle/src/point.rs +++ b/circle/src/point.rs @@ -59,6 +59,14 @@ impl Point { Self::new(self.x.square().double() - F::ONE, self.x.double() * self.y) } + /// Apply the doubling map `n` times: π^n(x,y) + pub fn repeated_double(mut self, n: usize) -> Self { + for _ in 0..n { + self = self.double(); + } + self + } + /// Evaluate the vanishing polynomial for the standard position coset of size 2^log_n /// at this point /// Circle STARKs, Section 3.3, Equation 8 (page 10 of the first revision PDF) diff --git a/commit/src/lib.rs b/commit/src/lib.rs index 2a2f6b97d..4835ef5ea 100644 --- a/commit/src/lib.rs +++ b/commit/src/lib.rs @@ -8,6 +8,7 @@ mod adapters; mod domain; mod mmcs; mod pcs; +mod periodic; #[cfg(any(test, feature = "test-utils"))] pub mod testing; @@ -16,3 +17,4 @@ pub use adapters::*; pub use domain::*; pub use mmcs::*; pub use pcs::*; +pub use periodic::*; diff --git a/commit/src/periodic.rs b/commit/src/periodic.rs new file mode 100644 index 000000000..529a6c547 --- /dev/null +++ b/commit/src/periodic.rs @@ -0,0 +1,225 @@ +//! Periodic column evaluation support. +//! +//! Periodic columns are columns whose values repeat with a period that divides the trace length. +//! This module provides the `PeriodicEvaluator` trait for evaluating periodic polynomials +//! in a domain-agnostic way (supporting both two-adic and circle STARKs). +//! +//! ## Power-of-Two Requirement +//! +//! **All period lengths must be powers of two.** This is because: +//! - The trace domain is a multiplicative/additive group of order `n` (a power of 2) +//! - The periodic subdomain must be a subgroup of order `p` +//! - For `p` to divide `n` as group orders, `p` must also be a power of 2 +//! +//! ## Mathematical Background +//! +//! A periodic column with period `p` and trace length `n` repeats every `p` rows: +//! `col[i] = col[i + p]` for all `i`. +//! +//! **The problem**: We have a polynomial `P` of degree `n-1` over the trace domain `H`, +//! but it only takes `p` distinct values. 
Can we work with a smaller polynomial instead? +//! +//! **Key observation**: We want `P(ω^i) = P(ω^{i+p})` for all `i`. So we need a map +//! `π: H → ?` that identifies points `p` apart: `π(ω^i) = π(ω^{i+p})`, i.e., `π` must +//! be constant on cosets of the subgroup `⟨ω^p⟩` of order `n/p`. +//! +//! **Finding π**: For cyclic groups, raising to the power `k` gives a homomorphism with +//! kernel of size `k`. Since we need `ker(π) = ⟨ω^p⟩` of order `n/p`, we set `π(x) = x^(n/p)`. +//! Indeed, `π(ω^{i+p}) = ω^{(i+p)·n/p} = ω^{i·n/p} · ω^n = π(ω^i)` since `ω^n = 1`. +//! +//! **Where π lands**: The image of `π` is `H_p = {1, ω^(n/p), ω^(2n/p), ...}`, a subgroup +//! of order `p`. Now we can factor `P = Q ∘ π` where `Q: H_p → F` is a degree `p-1` +//! polynomial interpolating the `p` periodic values. +//! +//! **Group-theoretic view**: `π: H → H_p` is a surjective homomorphism with kernel of +//! order `n/p`. By the first isomorphism theorem, `H/ker(π) ≅ H_p`. The periodic column +//! is constant on cosets of `ker(π)`, so it factors through `π`. +//! +//! **For Circle STARKs**: The same idea applies with `π(P) = (n/p)·P` (repeated doubling) +//! instead of exponentiation. +//! +//! **Evaluating at an out-of-domain point `ζ`**: +//! 1. Compute `π(ζ)` to get a point in `H_p` +//! 2. Evaluate `Q(π(ζ))` using Lagrange interpolation over `H_p` +//! +//! ## Memory-Efficient Storage +//! +//! Instead of materializing the full LDE-sized table (which would be wasteful for small periods), +//! we store only `max_period × blowup` rows in a [`PeriodicLdeTable`]. All periodic columns are +//! padded to the maximum period, creating a rectangular matrix that can be efficiently accessed +//! with modular indexing in the constraint evaluation hot loop. + +use alloc::vec::Vec; + +use p3_field::{ExtensionField, Field}; +use p3_matrix::dense::RowMajorMatrix; + +use crate::PolynomialSpace; + +/// Compact storage for periodic column values on the LDE domain. +/// +/// Instead of materializing the full LDE-sized table, stores only `extended_height` rows +/// (where `extended_height = max_period × blowup`) and uses modular indexing to access values. +/// +/// All periodic columns are padded to the maximum period before extrapolation, creating a +/// rectangular matrix for cache-friendly row-wise access. +/// +/// # Invariants +/// +/// - All periods must be powers of 2 (see module-level documentation) +/// - Height is always `max_period × blowup` (both powers of 2, so height is power of 2) +#[derive(Clone, Debug)] +pub struct PeriodicLdeTable { + /// Values in row-major form: height = extended_height, width = num_columns. + /// Empty if there are no periodic columns. + values: RowMajorMatrix, +} + +impl PeriodicLdeTable { + /// Create a new periodic LDE table from extrapolated values. + /// + /// The matrix should have height = `max_period × blowup` and width = `num_periodic_columns`. + pub const fn new(values: RowMajorMatrix) -> Self { + Self { values } + } + + /// Create an empty table (for AIRs without periodic columns). + pub fn empty() -> Self { + Self { + values: RowMajorMatrix::new(Vec::new(), 0), + } + } + + /// Returns true if there are no periodic columns. + pub const fn is_empty(&self) -> bool { + self.values.values.is_empty() + } + + /// Number of periodic columns. + pub const fn width(&self) -> usize { + self.values.width + } + + /// Height of the compact table (max_period × blowup). 
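+    ///
+    /// For example, with `max_period = 4` and `blowup = 2` the compact table stores 8 rows,
+    /// and an LDE row index `i` is looked up at compact row `i % 8`.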
+ pub const fn height(&self) -> usize { + if self.values.width == 0 { + 0 + } else { + self.values.values.len() / self.values.width + } + } + + /// Get all periodic column values for a given LDE index using modular indexing. + /// + /// Returns a slice of length `width()` containing the value of each periodic column. + #[inline] + pub fn get_row(&self, lde_idx: usize) -> &[F] { + let height = self.height(); + debug_assert!(height > 0, "cannot index into empty periodic table"); + let row_idx = lde_idx % height; + let start = row_idx * self.values.width; + let end = start + self.values.width; + &self.values.values[start..end] + } + + /// Get a specific periodic column value for a given LDE index. + #[inline] + pub fn get(&self, lde_idx: usize, col_idx: usize) -> &F { + let height = self.height(); + debug_assert!(height > 0, "cannot index into empty periodic table"); + let row_idx = lde_idx % height; + &self.values.values[row_idx * self.values.width + col_idx] + } +} + +/// Evaluates periodic polynomials for a given domain system. +/// +/// Periodic columns are defined by their values over one period. This trait +/// handles interpolation and evaluation, abstracting over the domain-specific +/// math (two-adic multiplicative groups vs circle groups). +/// +/// # Power-of-Two Requirement +/// +/// **All period lengths must be powers of two.** This ensures the periodic subdomain +/// is a valid subgroup of the trace domain. See module-level documentation for details. +/// +/// # Type Parameters +/// - `F`: The base field type +/// - `D`: The polynomial space / domain type +pub trait PeriodicEvaluator> { + /// Evaluate all periodic columns on the LDE domain, returning a compact table. + /// + /// This is used by the prover to compute periodic column values on the + /// low-degree extension domain for constraint evaluation. + /// + /// The returned table stores only `max_period × blowup` rows. All columns are + /// padded to the maximum period before extrapolation, creating a rectangular + /// matrix for efficient row-wise access with modular indexing. + /// + /// # Arguments + /// * `periodic_table` - Slice of periodic columns, each containing one period of values. + /// The length of each inner `Vec` is the period of that column (must be a power of 2). + /// * `trace_domain` - The original trace domain + /// * `lde_domain` - The low-degree extension domain + /// + /// # Returns + /// A [`PeriodicLdeTable`] with height = `max_period × blowup` and width = number of columns. + fn eval_on_lde( + periodic_table: &[Vec], + trace_domain: &D, + lde_domain: &D, + ) -> PeriodicLdeTable; + + /// Evaluate all periodic columns at a single point (for verification). + /// + /// This is used by the verifier to compute periodic column values at + /// query points during constraint verification. + /// + /// # Arguments + /// * `periodic_table` - Slice of periodic columns. Each column's length (period) + /// must be a power of 2. + /// * `trace_domain` - The original trace domain + /// * `point` - The query point (in extension field) + /// + /// # Returns + /// `Vec` containing the evaluation of each periodic column at `point` + fn eval_at_point>( + periodic_table: &[Vec], + trace_domain: &D, + point: EF, + ) -> Vec; +} + +/// Unit type implements `PeriodicEvaluator` as a no-op. +/// +/// This is used internally by `prove` and `verify` for AIRs without periodic columns. +/// Panics if any periodic columns are present. 
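+///
+/// The plain `prove` and `verify` entry points route through this impl, so they only accept
+/// AIRs whose `periodic_table()` is empty.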
+impl> PeriodicEvaluator for () { + fn eval_on_lde( + periodic_table: &[Vec], + _trace_domain: &D, + _lde_domain: &D, + ) -> PeriodicLdeTable { + assert!( + periodic_table.is_empty(), + "AIR has periodic columns but no PeriodicEvaluator was specified. \ + Use prove_with_periodic or verify_with_periodic with TwoAdicPeriodicEvaluator \ + or CirclePeriodicEvaluator." + ); + PeriodicLdeTable::empty() + } + + fn eval_at_point>( + periodic_table: &[Vec], + _trace_domain: &D, + _point: EF, + ) -> Vec { + assert!( + periodic_table.is_empty(), + "AIR has periodic columns but no PeriodicEvaluator was specified. \ + Use prove_with_periodic or verify_with_periodic with TwoAdicPeriodicEvaluator \ + or CirclePeriodicEvaluator." + ); + Vec::new() + } +} diff --git a/fri/src/lib.rs b/fri/src/lib.rs index 598503cff..b16b42d16 100644 --- a/fri/src/lib.rs +++ b/fri/src/lib.rs @@ -6,6 +6,7 @@ extern crate alloc; mod config; mod hiding_pcs; +mod periodic; mod proof; pub mod prover; mod two_adic_pcs; @@ -13,5 +14,6 @@ pub mod verifier; pub use config::*; pub use hiding_pcs::*; +pub use periodic::*; pub use proof::*; pub use two_adic_pcs::*; diff --git a/fri/src/periodic.rs b/fri/src/periodic.rs new file mode 100644 index 000000000..7a2144525 --- /dev/null +++ b/fri/src/periodic.rs @@ -0,0 +1,246 @@ +//! Two-adic periodic column evaluator. +//! +//! This module provides `TwoAdicPeriodicEvaluator` for evaluating periodic columns +//! in two-adic STARK proofs. The implementation supports: +//! +//! - `eval_on_lde`: Evaluates periodic columns on the LDE domain using FFT extrapolation. +//! All columns are padded to the maximum period, creating a rectangular matrix that +//! stores only `max_period × blowup` rows with modular indexing for O(1) lookup. +//! +//! - `eval_at_point`: Evaluates periodic columns at arbitrary points using Lagrange +//! interpolation. This is used by the verifier. +//! +//! ## Memory Efficiency +//! +//! Instead of materializing the full LDE-sized table, we store only `max_period × blowup` +//! rows. For a trace of size 2^20 with period-4 columns and blowup 4, this means storing +//! 16 rows instead of 4M rows per column. +//! +//! ## Complexity +//! +//! - `eval_on_lde`: O(max_period × blowup × log(max_period × blowup)) for FFT extrapolation, +//! then O(1) per LDE point lookup using modular indexing. +//! +//! - `eval_at_point`: O(period) per column using direct Lagrange interpolation. +//! +//! Note: The current `eval_at_point` implementation is not optimized for multiple columns +//! with the same period. The Lagrange basis evaluations (or barycentric weights) at `point` +//! could be computed once per distinct period and reused across columns sharing that period. + +use alloc::vec::Vec; + +use p3_commit::{PeriodicEvaluator, PeriodicLdeTable}; +use p3_dft::TwoAdicSubgroupDft; +use p3_field::coset::TwoAdicMultiplicativeCoset; +use p3_field::{ExtensionField, TwoAdicField}; +use p3_matrix::dense::RowMajorMatrix; +use p3_util::log2_strict_usize; + +/// Evaluates periodic polynomials for two-adic multiplicative cosets. +/// +/// For a periodic column with period `p` and trace length `n`, the periodic values +/// repeat every `n/p` rows. To evaluate at an arbitrary point `ζ`: +/// 1. Project `ζ` to the periodic subdomain: `ζ_periodic = ζ^(n/p)` +/// 2. 
Evaluate the degree-(p-1) polynomial at `ζ_periodic` +#[derive(Clone, Default)] +pub struct TwoAdicPeriodicEvaluator { + _phantom: core::marker::PhantomData, +} + +impl TwoAdicPeriodicEvaluator { + pub fn new(_dft: Dft) -> Self { + Self { + _phantom: core::marker::PhantomData, + } + } +} + +impl PeriodicEvaluator> for TwoAdicPeriodicEvaluator +where + F: TwoAdicField, + Dft: TwoAdicSubgroupDft, +{ + fn eval_on_lde( + periodic_table: &[Vec], + trace_domain: &TwoAdicMultiplicativeCoset, + lde_domain: &TwoAdicMultiplicativeCoset, + ) -> PeriodicLdeTable { + if periodic_table.is_empty() { + return PeriodicLdeTable::empty(); + } + + let trace_len = trace_domain.size(); + let lde_len = lde_domain.size(); + let lde_shift = lde_domain.shift(); + let blowup = lde_len / trace_len; + let log_blowup = log2_strict_usize(blowup); + + // Find the maximum period and validate all columns + let max_period = periodic_table + .iter() + .map(|col| { + let period = col.len(); + debug_assert!( + period.is_power_of_two(), + "periodic column length must be a power of 2" + ); + period + }) + .max() + .unwrap(); + + let extended_height = max_period * blowup; + let num_cols = periodic_table.len(); + + // Compute the shift for the periodic subdomain at max_period. + // This aligns the periodic domain with the LDE domain so modular indexing works. + let periodic_shift = lde_shift.exp_u64((lde_len / extended_height) as u64); + + let dft = Dft::default(); + + // Process each column: pad to max_period, then extrapolate + // Build the result in column-major order first, then transpose to row-major + let mut columns: Vec> = Vec::with_capacity(num_cols); + + for col in periodic_table { + let period = col.len(); + + // Pad column to max_period by repeating values + let padded: Vec = if period == max_period { + col.clone() + } else { + (0..max_period).map(|i| col[i % period]).collect() + }; + + // Extrapolate to extended_height using DFT + let extended = dft.coset_lde(padded, log_blowup, periodic_shift); + columns.push(extended); + } + + // Convert from column-major to row-major storage + let mut row_major_values = Vec::with_capacity(extended_height * num_cols); + for row_idx in 0..extended_height { + for col in &columns { + row_major_values.push(col[row_idx]); + } + } + + PeriodicLdeTable::new(RowMajorMatrix::new(row_major_values, num_cols)) + } + + fn eval_at_point>( + periodic_table: &[Vec], + trace_domain: &TwoAdicMultiplicativeCoset, + point: EF, + ) -> Vec { + let trace_len = trace_domain.size(); + + periodic_table + .iter() + .map(|col| { + let period = col.len(); + debug_assert!( + period.is_power_of_two(), + "periodic column length must be a power of 2" + ); + + let exponent = (trace_len / period) as u64; + + // Project point to periodic subdomain: ζ^(n/p) + let periodic_point = point.exp_u64(exponent); + + // Evaluate the periodic polynomial at the projected point + eval_periodic_poly(col, periodic_point) + }) + .collect() + } +} + +/// Evaluate a periodic polynomial at a single point using Lagrange interpolation. +/// +/// The polynomial is the unique degree-(period-1) polynomial that interpolates +/// the given values over the subgroup of size `period`. 
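+///
+/// Concretely, for `n = period` and `ω` a primitive `n`-th root of unity, this computes
+/// `p(x) = (x^n - 1)/n · Σ_i v_i·ω^i / (x - ω^i)`, with points on the subgroup handled separately.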
+fn eval_periodic_poly<F: TwoAdicField, EF: ExtensionField<F>>(values: &[F], point: EF) -> EF {
+    let period = values.len();
+    if period == 0 {
+        return EF::ZERO;
+    }
+    if period == 1 {
+        return EF::from(values[0]);
+    }
+
+    let log_period = log2_strict_usize(period);
+    let omega = F::two_adic_generator(log_period);
+
+    // Compute Lagrange interpolation at `point`
+    // L_i(x) = ∏_{j≠i} (x - ω^j) / (ω^i - ω^j)
+    //
+    // For efficiency, use the formula:
+    // p(x) = (x^n - 1) / n * Σ_i (v_i * ω^i / (x - ω^i))
+    // where n = period and ω is a primitive n-th root of unity
+
+    let n = EF::from(F::from_usize(period));
+    let x_n_minus_1 = point.exp_u64(period as u64) - EF::ONE;
+
+    // Handle case where point is on the subgroup (x^n - 1 = 0)
+    // In this case, return the value at that point directly
+    if x_n_minus_1.is_zero() {
+        // Find which root of unity `point` equals
+        let mut omega_i = F::ONE;
+        for &val in values.iter() {
+            if point == EF::from(omega_i) {
+                return EF::from(val);
+            }
+            omega_i *= omega;
+        }
+        // If we get here, something is wrong
+        return EF::ZERO;
+    }
+
+    // Compute Σ_i (v_i * ω^i / (x - ω^i))
+    // Derived from Lagrange interpolation: L_i(x) = (x^n - 1) / (n * ω^{i(n-1)} * (x - ω^i))
+    // Since ω^{-i(n-1)} = ω^i (as ω^n = 1), we get factor ω^i
+    let mut sum = EF::ZERO;
+    let mut omega_i = F::ONE;
+    for &val in values.iter() {
+        let denom = point - EF::from(omega_i);
+        // denom should be non-zero since x^n - 1 ≠ 0
+        sum += EF::from(val * omega_i) * denom.inverse();
+        omega_i *= omega;
+    }
+
+    x_n_minus_1 * sum / n
+}
+
+#[cfg(test)]
+mod tests {
+    use alloc::vec;
+
+    use super::*;
+    use p3_baby_bear::BabyBear;
+    use p3_field::PrimeCharacteristicRing;
+
+    type F = BabyBear;
+
+    #[test]
+    fn test_eval_periodic_poly_constant() {
+        // Constant polynomial: all values are 5
+        let values = vec![F::from_u64(5); 4];
+        let point = F::from_u64(7);
+        let result = eval_periodic_poly(&values, point);
+        assert_eq!(result, F::from_u64(5));
+    }
+
+    #[test]
+    fn test_eval_periodic_poly_at_roots() {
+        // Values at roots of unity should interpolate correctly
+        let values: Vec<F> = vec![1, 2, 3, 4].into_iter().map(F::from_u64).collect();
+        let omega = F::two_adic_generator(2); // 4th root of unity
+
+        for (i, &expected) in values.iter().enumerate() {
+            let point = omega.exp_u64(i as u64);
+            let result = eval_periodic_poly(&values, point);
+            assert_eq!(result, expected, "Failed at root index {}", i);
+        }
+    }
+}
diff --git a/lookup/src/folder.rs b/lookup/src/folder.rs
index 6cde2bc6a..dad44869d 100644
--- a/lookup/src/folder.rs
+++ b/lookup/src/folder.rs
@@ -1,5 +1,6 @@
 use p3_air::{
-    AirBuilder, AirBuilderWithPublicValues, ExtensionBuilder, PairBuilder, PermutationAirBuilder,
+    AirBuilder, AirBuilderWithPublicValues, ExtensionBuilder, PairBuilder, PeriodicAirBuilder,
+    PermutationAirBuilder,
 };
 use p3_matrix::dense::RowMajorMatrixView;
 use p3_matrix::stack::ViewPair;
@@ -8,6 +9,7 @@ use p3_uni_stark::{
     VerifierConstraintFolder,
 };
 
+
 pub struct ProverConstraintFolderWithLookups<'a, SC: StarkGenericConfig> {
     pub inner: ProverConstraintFolder<'a, SC>,
     pub permutation: RowMajorMatrixView<'a, PackedChallenge<SC>>,
@@ -107,6 +109,14 @@ impl<'a, SC: StarkGenericConfig> PermutationAirBuilder
     }
 }
 
+impl<SC: StarkGenericConfig> PeriodicAirBuilder for ProverConstraintFolderWithLookups<'_, SC> {
+    type PeriodicVar = PackedVal<SC>;
+
+    fn periodic_values(&self) -> &[Self::PeriodicVar] {
+        self.inner.periodic_values()
+    }
+}
+
 pub struct VerifierConstraintFolderWithLookups<'a, SC: StarkGenericConfig> {
     pub inner: VerifierConstraintFolder<'a, SC>,
     pub permutation: ViewPair<'a, SC::Challenge>,
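// Example (not part of this diff): the projection x -> x^(n/p) used by `eval_at_point` above can
// be sanity-checked numerically. A minimal sketch over BabyBear, using only the existing
// `TwoAdicField` / `PrimeCharacteristicRing` APIs; the sizes chosen here are illustrative.
use p3_baby_bear::BabyBear;
use p3_field::{PrimeCharacteristicRing, TwoAdicField};

fn main() {
    type F = BabyBear;
    let (log_n, log_p) = (4, 2); // trace length 16, period 4
    let (n, p) = (1usize << log_n, 1usize << log_p);
    let omega = F::two_adic_generator(log_n);

    // Raising to the power n/p identifies trace rows that are `p` apart, which is exactly what
    // lets a periodic column be treated as a degree-(p-1) polynomial over the subgroup of order p.
    for i in 0..n {
        let projected = omega.exp_u64(i as u64).exp_u64((n / p) as u64);
        let reduced = omega.exp_u64((i % p) as u64).exp_u64((n / p) as u64);
        assert_eq!(projected, reduced);
    }
}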
@@ -205,3 +215,11 @@ impl<'a, SC: StarkGenericConfig> PermutationAirBuilder self.permutation_challenges } } + +impl PeriodicAirBuilder for VerifierConstraintFolderWithLookups<'_, SC> { + type PeriodicVar = SC::Challenge; + + fn periodic_values(&self) -> &[Self::PeriodicVar] { + self.inner.periodic_values() + } +} diff --git a/lookup/src/lookup_traits.rs b/lookup/src/lookup_traits.rs index 2776b09c8..f0f7df3e4 100644 --- a/lookup/src/lookup_traits.rs +++ b/lookup/src/lookup_traits.rs @@ -452,6 +452,10 @@ impl> BaseAir for AirNoLookup { fn preprocessed_trace(&self) -> Option> { self.air.preprocessed_trace() } + + fn periodic_table(&self) -> Vec> { + self.air.periodic_table() + } } impl> Air for AirNoLookup { diff --git a/uni-stark/src/check_constraints.rs b/uni-stark/src/check_constraints.rs index 7f8a5f490..2b514ec8d 100644 --- a/uni-stark/src/check_constraints.rs +++ b/uni-stark/src/check_constraints.rs @@ -1,4 +1,6 @@ -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, PairBuilder}; +use alloc::vec::Vec; + +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, PairBuilder, PeriodicAirBuilder}; use p3_field::Field; use p3_matrix::Matrix; use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; @@ -23,6 +25,7 @@ where { let height = main.height(); let preprocessed = air.preprocessed_trace(); + let periodic_table = air.periodic_table(); (0..height).for_each(|row_index| { let row_index_next = (row_index + 1) % height; @@ -49,11 +52,21 @@ where None }; + // Compute periodic values for this row + let periodic_values: Vec = periodic_table + .iter() + .map(|col| { + let period = col.len(); + col[row_index % period] + }) + .collect(); + let mut builder = DebugConstraintBuilder { row_index, main, preprocessed: preprocessed_pair, public_values, + periodic_values, is_first_row: F::from_bool(row_index == 0), is_last_row: F::from_bool(row_index == height - 1), is_transition: F::from_bool(row_index != height - 1), @@ -77,6 +90,8 @@ pub struct DebugConstraintBuilder<'a, F: Field> { preprocessed: Option>, /// The public values provided for constraint validation (e.g. inputs or outputs). public_values: &'a [F], + /// The periodic column values at the current row. + periodic_values: Vec, /// A flag indicating whether this is the first row. is_first_row: F, /// A flag indicating whether this is the last row. @@ -151,6 +166,14 @@ impl<'a, F: Field> PairBuilder for DebugConstraintBuilder<'a, F> { } } +impl PeriodicAirBuilder for DebugConstraintBuilder<'_, F> { + type PeriodicVar = F; + + fn periodic_values(&self) -> &[Self::PeriodicVar] { + &self.periodic_values + } +} + #[cfg(test)] mod tests { use alloc::vec; diff --git a/uni-stark/src/folder.rs b/uni-stark/src/folder.rs index bac5eaa88..7e905102c 100644 --- a/uni-stark/src/folder.rs +++ b/uni-stark/src/folder.rs @@ -1,6 +1,6 @@ use alloc::vec::Vec; -use p3_air::{AirBuilder, AirBuilderWithPublicValues, PairBuilder}; +use p3_air::{AirBuilder, AirBuilderWithPublicValues, PairBuilder, PeriodicAirBuilder}; use p3_field::{BasedVectorSpace, PackedField}; use p3_matrix::dense::RowMajorMatrixView; use p3_matrix::stack::ViewPair; @@ -35,6 +35,8 @@ pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> { pub accumulator: PackedChallenge, /// Current constraint index being processed pub constraint_index: usize, + /// Evaluations of periodic columns at the current row (base field for prover) + pub periodic_values: Vec>, } /// Handles constraint verification for the verifier in a STARK system. 
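// Example (not part of this diff): how an AIR feeds these `periodic_values` slots. A minimal
// sketch of an AIR declaring one periodic column and using it in a constraint; `MaskAir` is a
// hypothetical name and the bounds assume the `PeriodicAirBuilder` trait introduced above.
use p3_air::{Air, AirBuilder, BaseAir, PeriodicAirBuilder};
use p3_field::Field;
use p3_matrix::Matrix;

/// Toy AIR with one main column that must be zero on every 4th row.
struct MaskAir;

impl<F: Field> BaseAir<F> for MaskAir {
    fn width(&self) -> usize {
        1
    }

    // One periodic column with period 4: [1, 0, 0, 0], repeated down the trace.
    fn periodic_table(&self) -> Vec<Vec<F>> {
        vec![vec![F::ONE, F::ZERO, F::ZERO, F::ZERO]]
    }
}

impl<AB: AirBuilder + PeriodicAirBuilder> Air<AB> for MaskAir {
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0).expect("the trace has at least one row");
        // Periodic values arrive already evaluated for the current row; nothing is committed.
        let mask: AB::Expr = builder.periodic_values()[0].into();
        builder.assert_zero(mask * local[0]);
    }
}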
@@ -59,6 +61,8 @@ pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> { pub alpha: SC::Challenge, /// Running accumulator for all constraints pub accumulator: SC::Challenge, + /// Evaluations of periodic columns + pub periodic_values: Vec, } impl<'a, SC: StarkGenericConfig> AirBuilder for ProverConstraintFolder<'a, SC> { @@ -123,6 +127,14 @@ impl AirBuilderWithPublicValues for ProverConstraintFold } } +impl PeriodicAirBuilder for ProverConstraintFolder<'_, SC> { + type PeriodicVar = PackedVal; + + fn periodic_values(&self) -> &[Self::PeriodicVar] { + &self.periodic_values + } +} + impl<'a, SC: StarkGenericConfig> PairBuilder for ProverConstraintFolder<'a, SC> { #[inline] fn preprocessed(&self) -> Self::M { @@ -175,6 +187,14 @@ impl AirBuilderWithPublicValues for VerifierConstraintFo } } +impl PeriodicAirBuilder for VerifierConstraintFolder<'_, SC> { + type PeriodicVar = SC::Challenge; + + fn periodic_values(&self) -> &[Self::PeriodicVar] { + &self.periodic_values + } +} + impl<'a, SC: StarkGenericConfig> PairBuilder for VerifierConstraintFolder<'a, SC> { fn preprocessed(&self) -> Self::M { self.preprocessed diff --git a/uni-stark/src/prover.rs b/uni-stark/src/prover.rs index 5a4353175..b96b32108 100644 --- a/uni-stark/src/prover.rs +++ b/uni-stark/src/prover.rs @@ -5,6 +5,7 @@ use itertools::Itertools; use p3_air::Air; use p3_challenger::{CanObserve, FieldChallenger}; use p3_commit::{Pcs, PolynomialSpace}; +use p3_commit::{PeriodicEvaluator, PeriodicLdeTable}; use p3_field::{BasedVectorSpace, PackedFieldExtension, PackedValue, PrimeCharacteristicRing}; use p3_matrix::Matrix; use p3_matrix::dense::RowMajorMatrix; @@ -18,8 +19,12 @@ use crate::{ get_log_num_quotient_chunks, get_symbolic_constraints, }; +/// Prove a STARK for an AIR with preprocessed columns (but no periodic columns). +/// +/// Use this when your AIR has preprocessed columns but no periodic columns. +/// For AIRs with both, use [`prove_with_preprocessed_and_periodic`]. #[instrument(skip_all)] -#[allow(clippy::multiple_bound_locations, clippy::type_repetition_in_bounds)] // cfg not supported in where clauses? +#[allow(clippy::multiple_bound_locations, clippy::type_repetition_in_bounds)] pub fn prove_with_preprocessed< SC, #[cfg(debug_assertions)] A: for<'a> Air>>, @@ -34,6 +39,38 @@ pub fn prove_with_preprocessed< where SC: StarkGenericConfig, A: Air>> + for<'a> Air>, +{ + prove_with_preprocessed_and_periodic::( + config, + air, + trace, + public_values, + preprocessed, + ) +} + +/// Prove a STARK for an AIR with both preprocessed and periodic columns. +/// +/// This is the most general proving function. The type parameter `PE` specifies +/// how to evaluate periodic columns on the LDE domain. +#[instrument(skip_all)] +#[allow(clippy::multiple_bound_locations, clippy::type_repetition_in_bounds)] +pub fn prove_with_preprocessed_and_periodic< + SC, + #[cfg(debug_assertions)] A: for<'a> Air>>, + #[cfg(not(debug_assertions))] A, + PE, +>( + config: &SC, + air: &A, + trace: RowMajorMatrix>, + public_values: &[Val], + preprocessed: Option<&PreprocessedProverData>, +) -> Proof +where + SC: StarkGenericConfig, + A: Air>> + for<'a> Air>, + PE: PeriodicEvaluator, Domain>, { #[cfg(debug_assertions)] crate::check_constraints::check_constraints(air, &trace, public_values); @@ -215,7 +252,7 @@ where // `C(T_1(x), ..., T_w(x), T_1(hx), ..., T_w(hx), selectors(x)) / Z_H(x)` // at every point in the quotient domain. The degree of `Q(x)` is `<= deg(C(x)) - N = 2N - 2` in the case // where `deg(C) = 3`. 
(See the discussion above constraint_degree for more details.) - let quotient_values = quotient_values( + let quotient_values = quotient_values::( air, public_values, trace_domain, @@ -354,8 +391,13 @@ where } } +/// Prove a STARK for an AIR without preprocessed or periodic columns. +/// +/// This is the simplest entry point. For AIRs with preprocessed columns, +/// use [`prove_with_preprocessed`]. For AIRs with periodic columns, +/// use [`prove_with_periodic`]. #[instrument(skip_all)] -#[allow(clippy::multiple_bound_locations, clippy::type_repetition_in_bounds)] // cfg not supported in where clauses? +#[allow(clippy::multiple_bound_locations, clippy::type_repetition_in_bounds)] pub fn prove< SC, #[cfg(debug_assertions)] A: for<'a> Air>>, @@ -370,13 +412,38 @@ where SC: StarkGenericConfig, A: Air>> + for<'a> Air>, { - prove_with_preprocessed::(config, air, trace, public_values, None) + prove_with_preprocessed_and_periodic::(config, air, trace, public_values, None) +} + +/// Prove a STARK for an AIR with periodic columns (but no preprocessed columns). +/// +/// The type parameter `PE` specifies how to evaluate periodic columns on the LDE domain. +/// Use [`TwoAdicPeriodicEvaluator`](p3_fri::TwoAdicPeriodicEvaluator) for two-adic STARKs +/// or [`CirclePeriodicEvaluator`](p3_circle::CirclePeriodicEvaluator) for Circle STARKs. +#[instrument(skip_all)] +#[allow(clippy::multiple_bound_locations, clippy::type_repetition_in_bounds)] +pub fn prove_with_periodic< + SC, + #[cfg(debug_assertions)] A: for<'a> Air>>, + #[cfg(not(debug_assertions))] A, + PE, +>( + config: &SC, + air: &A, + trace: RowMajorMatrix>, + public_values: &[Val], +) -> Proof +where + SC: StarkGenericConfig, + A: Air>> + for<'a> Air>, + PE: PeriodicEvaluator, Domain>, +{ + prove_with_preprocessed_and_periodic::(config, air, trace, public_values, None) } #[instrument(skip_all)] -// TODO: Group some arguments to remove the `allow`? #[allow(clippy::too_many_arguments)] -pub fn quotient_values( +pub fn quotient_values( air: &A, public_values: &[Val], trace_domain: Domain, @@ -390,22 +457,32 @@ where SC: StarkGenericConfig, A: for<'a> Air>, Mat: Matrix> + Sync, + PE: PeriodicEvaluator, Domain>, { let quotient_size = quotient_domain.size(); let width = trace_on_quotient_domain.width(); let mut sels = debug_span!("Compute Selectors") .in_scope(|| trace_domain.selectors_on_coset(quotient_domain)); + // Compute periodic column values on quotient domain + let periodic_table = air.periodic_table(); + let periodic_on_quotient = debug_span!("Compute Periodic Columns") + .in_scope(|| PE::eval_on_lde(&periodic_table, &trace_domain, "ient_domain)); + let qdb = log2_strict_usize(quotient_domain.size()) - log2_strict_usize(trace_domain.size()); let next_step = 1 << qdb; // We take PackedVal::::WIDTH worth of values at a time from a quotient_size slice, so we need to // pad with default values in the case where quotient_size is smaller than PackedVal::::WIDTH. 
- for _ in quotient_size..PackedVal::::WIDTH { - sels.is_first_row.push(Val::::default()); - sels.is_last_row.push(Val::::default()); - sels.is_transition.push(Val::::default()); - sels.inv_vanishing.push(Val::::default()); + if quotient_size < PackedVal::::WIDTH { + sels.is_first_row + .resize(PackedVal::::WIDTH, Val::::default()); + sels.is_last_row + .resize(PackedVal::::WIDTH, Val::::default()); + sels.is_transition + .resize(PackedVal::::WIDTH, Val::::default()); + sels.inv_vanishing + .resize(PackedVal::::WIDTH, Val::::default()); } let mut alpha_powers = alpha.powers().collect_n(constraint_count); @@ -444,6 +521,10 @@ where ) }); + // Extract packed periodic values for this chunk using modular indexing + let periodic_values: Vec> = + extract_periodic_values(&periodic_on_quotient, i_start); + let accumulator = PackedChallenge::::ZERO; let mut folder = ProverConstraintFolder { main: main.as_view(), @@ -456,6 +537,7 @@ where decomposed_alpha_powers: &decomposed_alpha_powers, accumulator, constraint_index: 0, + periodic_values, }; air.eval(&mut folder); @@ -468,3 +550,30 @@ where }) .collect() } + +/// Extract packed periodic values from a compact periodic LDE table using modular indexing. +/// +/// The periodic table stores only `max_period × blowup` rows. We use modular indexing +/// to access the correct values for each position in the quotient domain. +#[inline] +fn extract_periodic_values(periodic_table: &PeriodicLdeTable, i_start: usize) -> Vec

+where + F: Clone + Send + Sync, + P: PackedValue, +{ + let num_periodic_cols = periodic_table.width(); + if num_periodic_cols == 0 { + return vec![]; + } + + // For each column, gather WIDTH values using modular indexing + let mut result = Vec::with_capacity(num_periodic_cols); + for col_idx in 0..num_periodic_cols { + // Gather WIDTH values from the compact table + let values: Vec = (0..P::WIDTH) + .map(|j| periodic_table.get(i_start + j, col_idx).clone()) + .collect(); + result.push(*P::from_slice(&values)); + } + result +} diff --git a/uni-stark/src/symbolic_builder.rs b/uni-stark/src/symbolic_builder.rs index 5762ad938..ce037599a 100644 --- a/uni-stark/src/symbolic_builder.rs +++ b/uni-stark/src/symbolic_builder.rs @@ -2,12 +2,13 @@ use alloc::vec; use alloc::vec::Vec; use p3_air::{ - Air, AirBuilder, AirBuilderWithPublicValues, ExtensionBuilder, PairBuilder, - PermutationAirBuilder, + Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, ExtensionBuilder, PairBuilder, + PeriodicAirBuilder, PermutationAirBuilder, }; use p3_field::{ExtensionField, Field}; use p3_matrix::dense::RowMajorMatrix; use p3_util::log2_ceil_usize; + use tracing::instrument; use crate::Entry; @@ -23,7 +24,7 @@ pub fn get_log_num_quotient_chunks( ) -> usize where F: Field, - A: Air>, + A: BaseAir + Air>, { get_log_quotient_degree_extension(air, preprocessed_width, num_public_values, 0, 0, is_zk) } @@ -40,7 +41,7 @@ pub fn get_log_quotient_degree_extension( where F: Field, EF: ExtensionField, - A: Air>, + A: BaseAir + Air>, { assert!(is_zk <= 1, "is_zk must be either 0 or 1"); // We pad to at least degree 2, since a quotient argument doesn't make sense with smaller degrees. @@ -67,7 +68,7 @@ pub fn get_max_constraint_degree( ) -> usize where F: Field, - A: Air>, + A: BaseAir + Air>, { get_max_constraint_degree_extension(air, preprocessed_width, num_public_values, 0, 0) } @@ -87,7 +88,7 @@ pub fn get_max_constraint_degree_extension( where F: Field, EF: ExtensionField, - A: Air>, + A: BaseAir + Air>, { let (base_constraints, extension_constraints) = get_all_symbolic_constraints( air, @@ -123,10 +124,17 @@ pub fn get_symbolic_constraints( ) -> Vec> where F: Field, - A: Air>, + A: BaseAir + Air>, { - let mut builder = - SymbolicAirBuilder::new(preprocessed_width, air.width(), num_public_values, 0, 0); + let num_periodic = air.periodic_table().len(); + let mut builder = SymbolicAirBuilder::new_with_periodic( + preprocessed_width, + air.width(), + num_public_values, + 0, + 0, + num_periodic, + ); air.eval(&mut builder); builder.base_constraints() } @@ -146,14 +154,16 @@ pub fn get_symbolic_constraints_extension( where F: Field, EF: ExtensionField, - A: Air>, + A: BaseAir + Air>, { - let mut builder = SymbolicAirBuilder::new( + let num_periodic = air.periodic_table().len(); + let mut builder = SymbolicAirBuilder::new_with_periodic( preprocessed_width, air.width(), num_public_values, permutation_width, num_permutation_challenges, + num_periodic, ); air.eval(&mut builder); builder.extension_constraints() @@ -174,14 +184,16 @@ pub fn get_all_symbolic_constraints( where F: Field, EF: ExtensionField, - A: Air>, + A: BaseAir + Air>, { - let mut builder = SymbolicAirBuilder::new( + let num_periodic = air.periodic_table().len(); + let mut builder = SymbolicAirBuilder::new_with_periodic( preprocessed_width, air.width(), num_public_values, permutation_width, num_permutation_challenges, + num_periodic, ); air.eval(&mut builder); (builder.base_constraints(), builder.extension_constraints()) @@ -193,6 +205,7 @@ pub struct 
SymbolicAirBuilder = F> { preprocessed: RowMajorMatrix>, main: RowMajorMatrix>, public_values: Vec>, + periodic_values: Vec>, base_constraints: Vec>, permutation: RowMajorMatrix>, permutation_challenges: Vec>, @@ -206,6 +219,24 @@ impl> SymbolicAirBuilder { num_public_values: usize, permutation_width: usize, num_permutation_challenges: usize, + ) -> Self { + Self::new_with_periodic( + preprocessed_width, + width, + num_public_values, + permutation_width, + num_permutation_challenges, + 0, + ) + } + + pub fn new_with_periodic( + preprocessed_width: usize, + width: usize, + num_public_values: usize, + permutation_width: usize, + num_permutation_challenges: usize, + num_periodic_values: usize, ) -> Self { let prep_values = [0, 1] .into_iter() @@ -223,6 +254,9 @@ impl> SymbolicAirBuilder { let public_values = (0..num_public_values) .map(move |index| SymbolicVariable::new(Entry::Public, index)) .collect(); + let periodic_values = (0..num_periodic_values) + .map(move |index| SymbolicVariable::new(Entry::Periodic, index)) + .collect(); let perm_values = [0, 1] .into_iter() .flat_map(|offset| { @@ -238,6 +272,7 @@ impl> SymbolicAirBuilder { preprocessed: RowMajorMatrix::new(prep_values, preprocessed_width), main: RowMajorMatrix::new(main_values, width), public_values, + periodic_values, base_constraints: vec![], permutation, permutation_challenges, @@ -333,6 +368,17 @@ where } } +impl> PeriodicAirBuilder for SymbolicAirBuilder +where + SymbolicExpression: From>, +{ + type PeriodicVar = SymbolicVariable; + + fn periodic_values(&self) -> &[Self::PeriodicVar] { + &self.periodic_values + } +} + #[cfg(test)] mod tests { use p3_air::BaseAir; diff --git a/uni-stark/src/symbolic_expression.rs b/uni-stark/src/symbolic_expression.rs index 142471005..177e59182 100644 --- a/uni-stark/src/symbolic_expression.rs +++ b/uni-stark/src/symbolic_expression.rs @@ -156,13 +156,13 @@ impl SymbolicExpression { /// /// Degree 0 (constants): /// - `Constant` - /// - `IsTransition` /// - `Variable` with public values or challenges /// /// Degree 1 (linear in trace length): /// - `Variable` with trace columns (main, preprocessed, permutation) /// - `IsFirstRow` /// - `IsLastRow` + /// - `IsTransition` /// /// Composite expressions: /// - `Add`, `Sub`: max of operands @@ -171,8 +171,8 @@ impl SymbolicExpression { pub const fn degree_multiple(&self) -> usize { match self { Self::Variable(v) => v.degree_multiple(), - Self::IsFirstRow | Self::IsLastRow => 1, - Self::IsTransition | Self::Constant(_) => 0, + Self::IsFirstRow | Self::IsLastRow | Self::IsTransition => 1, + Self::Constant(_) => 0, Self::Add { degree_multiple, .. } @@ -417,8 +417,8 @@ mod tests { let is_transition = SymbolicExpression::::IsTransition; assert_eq!( is_transition.degree_multiple(), - 0, - "IsTransition should have degree 0" + 1, + "IsTransition should have degree 1" ); let add_expr = SymbolicExpression::::Add { diff --git a/uni-stark/src/symbolic_variable.rs b/uni-stark/src/symbolic_variable.rs index 37d176338..80f742a31 100644 --- a/uni-stark/src/symbolic_variable.rs +++ b/uni-stark/src/symbolic_variable.rs @@ -10,6 +10,7 @@ pub enum Entry { Preprocessed { offset: usize }, Main { offset: usize }, Permutation { offset: usize }, + Periodic, Public, Challenge, } @@ -33,7 +34,12 @@ impl SymbolicVariable { pub const fn degree_multiple(&self) -> usize { match self.entry { - Entry::Preprocessed { .. } | Entry::Main { .. } | Entry::Permutation { .. } => 1, + Entry::Preprocessed { .. } + | Entry::Main { .. } + | Entry::Permutation { .. 
} + // Degree 1 is an approximation; see Winterfell's TransitionConstraintDegree for + // a more precise model: https://github.com/facebook/winterfell/blob/main/air/src/air/transition/degree.rs + | Entry::Periodic => 1, Entry::Public | Entry::Challenge => 0, } } diff --git a/uni-stark/src/verifier.rs b/uni-stark/src/verifier.rs index 898f5bccb..d3e301290 100644 --- a/uni-stark/src/verifier.rs +++ b/uni-stark/src/verifier.rs @@ -7,6 +7,7 @@ use alloc::{format, vec}; use itertools::Itertools; use p3_air::Air; use p3_challenger::{CanObserve, FieldChallenger}; +use p3_commit::PeriodicEvaluator; use p3_commit::{Pcs, PolynomialSpace}; use p3_field::{BasedVectorSpace, Field, PrimeCharacteristicRing}; use p3_matrix::dense::RowMajorMatrixView; @@ -79,6 +80,7 @@ pub fn verify_constraints( preprocessed_local: Option<&[SC::Challenge]>, preprocessed_next: Option<&[SC::Challenge]>, public_values: &[Val], + periodic_values: Vec, trace_domain: Domain, zeta: SC::Challenge, alpha: SC::Challenge, @@ -113,6 +115,7 @@ where is_transition: sels.is_transition, alpha, accumulator: SC::Challenge::ZERO, + periodic_values, }; air.eval(&mut folder); let folded_constraints = folder.accumulator; @@ -190,6 +193,11 @@ where } } +/// Verify a STARK proof for an AIR without preprocessed or periodic columns. +/// +/// This is the simplest entry point. For AIRs with preprocessed columns, +/// use [`verify_with_preprocessed`]. For AIRs with periodic columns, +/// use [`verify_with_periodic`]. #[instrument(skip_all)] pub fn verify( config: &SC, @@ -201,9 +209,33 @@ where SC: StarkGenericConfig, A: Air>> + for<'a> Air>, { - verify_with_preprocessed(config, air, proof, public_values, None) + verify_with_preprocessed_and_periodic::(config, air, proof, public_values, None) +} + +/// Verify a STARK proof for an AIR with periodic columns (but no preprocessed columns). +/// +/// The type parameter `PE` specifies how to evaluate periodic columns at the query point. +/// Use [`TwoAdicPeriodicEvaluator`](p3_fri::TwoAdicPeriodicEvaluator) for two-adic STARKs +/// or [`CirclePeriodicEvaluator`](p3_circle::CirclePeriodicEvaluator) for Circle STARKs. +#[instrument(skip_all)] +pub fn verify_with_periodic( + config: &SC, + air: &A, + proof: &Proof, + public_values: &[Val], +) -> Result<(), VerificationError>> +where + SC: StarkGenericConfig, + A: Air>> + for<'a> Air>, + PE: PeriodicEvaluator, Domain>, +{ + verify_with_preprocessed_and_periodic::(config, air, proof, public_values, None) } +/// Verify a STARK proof for an AIR with preprocessed columns (but no periodic columns). +/// +/// Use this when your AIR has preprocessed columns but no periodic columns. +/// For AIRs with both, use [`verify_with_preprocessed_and_periodic`]. #[instrument(skip_all)] pub fn verify_with_preprocessed( config: &SC, @@ -215,6 +247,32 @@ pub fn verify_with_preprocessed( where SC: StarkGenericConfig, A: Air>> + for<'a> Air>, +{ + verify_with_preprocessed_and_periodic::( + config, + air, + proof, + public_values, + preprocessed_vk, + ) +} + +/// Verify a STARK proof for an AIR with both preprocessed and periodic columns. +/// +/// This is the most general verification function. The type parameter `PE` specifies +/// how to evaluate periodic columns at the query point. 
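+/// The same `PE` should be supplied to the matching proving call
+/// (e.g. [`prove_with_preprocessed_and_periodic`]).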
+#[instrument(skip_all)] +pub fn verify_with_preprocessed_and_periodic( + config: &SC, + air: &A, + proof: &Proof, + public_values: &[Val], + preprocessed_vk: Option<&PreprocessedVerifierKey>, +) -> Result<(), VerificationError>> +where + SC: StarkGenericConfig, + A: Air>> + for<'a> Air>, + PE: PeriodicEvaluator, Domain>, { let Proof { commitments, @@ -375,6 +433,10 @@ where zeta, ); + // Evaluate periodic columns at zeta + let periodic_table = air.periodic_table(); + let periodic_values = PE::eval_at_point(&periodic_table, &init_trace_domain, zeta); + verify_constraints::>( air, &opened_values.trace_local, @@ -382,6 +444,7 @@ where opened_values.preprocessed_local.as_deref(), opened_values.preprocessed_next.as_deref(), public_values, + periodic_values, init_trace_domain, zeta, alpha, diff --git a/uni-stark/tests/rescue_hash.rs b/uni-stark/tests/rescue_hash.rs new file mode 100644 index 000000000..770566d6a --- /dev/null +++ b/uni-stark/tests/rescue_hash.rs @@ -0,0 +1,478 @@ +//! Rescue-like hash AIR with periodic round constants. +//! +//! This test arithmetizes a hash chain of length NUM_HASHES to demonstrate +//! periodic column integration. The round constants are provided as periodic +//! columns with period NUM_ROUNDS, allowing them to repeat across multiple +//! hash invocations without duplicating data in the trace. + +use core::marker::PhantomData; +use core::ops::Mul; + +use p3_air::{Air, AirBuilder, BaseAir, PeriodicAirBuilder}; +use p3_baby_bear::BabyBear; +use p3_challenger::{HashChallenger, SerializingChallenger32}; +use p3_circle::{CirclePcs, CirclePeriodicEvaluator}; +use p3_commit::ExtensionMmcs; +use p3_dft::Radix2DitParallel; +use p3_field::exponentiation::{exp_1717986917, exp_1725656503}; +use p3_field::extension::BinomialExtensionField; +use p3_field::{InjectiveMonomial, PrimeCharacteristicRing, PrimeField64}; +use p3_fri::TwoAdicPeriodicEvaluator; +use p3_fri::{FriParameters, TwoAdicFriPcs}; +use p3_keccak::Keccak256Hash; +use p3_matrix::Matrix; +use p3_matrix::dense::RowMajorMatrix; +use p3_merkle_tree::MerkleTreeMmcs; +use p3_mersenne_31::Mersenne31; +use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher}; +use p3_uni_stark::{StarkConfig, prove_with_periodic, verify_with_periodic}; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; + +/// Standard FRI parameters for both tests +const LOG_BLOWUP: usize = 3; +const NUM_QUERIES: usize = 28; + +/// Number of rounds per hash invocation. The periodic columns have this period. +const NUM_ROUNDS: usize = 8; + +/// Length of the hash chain being proved. +const NUM_HASHES: usize = 8; + +/// Trace size = NUM_ROUNDS × NUM_HASHES (one row per round). +const TRACE_SIZE: usize = NUM_ROUNDS * NUM_HASHES; + +/// Circulant matrix coefficients for width 24. +/// This is supposed to be MDS (or near-MDS for M31) but most likely it is not. +const MDS_COEFFS_24: [u64; 24] = [ + 7, 1, 3, 8, 4, 6, 2, 9, 5, 1, 3, 7, 8, 2, 4, 9, 5, 1, 6, 3, 7, 2, 8, 4, +]; + +/// Apply circulant MDS matrix to state. +fn apply_mds( + state: [F; WIDTH], + coeffs: &[u64; WIDTH], +) -> [F; WIDTH] { + core::array::from_fn(|i| { + (0..WIDTH) + .map(|j| state[(i + j) % WIDTH] * F::from_u64(coeffs[j])) + .sum() + }) +} + +/// Rescue-like AIR with periodic round constants. +/// +/// This AIR proves that a Rescue-like hash was computed correctly. +/// Each row represents the state after applying the round function. +/// Round constants (ark1, ark2) are provided as periodic columns with period = NUM_ROUNDS. 
+/// +/// The constraint structure avoids computing the inverse S-box: +/// MDS((MDS(h) + ark1)^α) + ark2 = h'^α +/// +/// Both sides have degree α, so the constraint degree equals the S-box degree. +/// +/// Generic over: +/// - F: the field +/// - WIDTH: state width +/// - ALPHA: S-box degree (x^ALPHA) +#[derive(Clone)] +struct RescueLikeAir { + /// First round constants (ark1): ark1_periodic[i][round] = constant for element i at round + ark1_periodic: Vec>, + /// Second round constants (ark2): ark2_periodic[i][round] = constant for element i at round + ark2_periodic: Vec>, + /// MDS coefficients + mds_coeffs: [u64; WIDTH], +} + +impl RescueLikeAir +where + F: PrimeField64, +{ + fn new(mds_coeffs: [u64; WIDTH]) -> Self { + // Generate deterministic round constants (two sets: ark1, ark2) + let mut rng = SmallRng::seed_from_u64(0x524553435545); // "RESCUE" in hex + + // Generate ark1 constants + let ark1: Vec = (0..NUM_ROUNDS * WIDTH) + .map(|_| F::from_u64(rng.random::() % (1 << 30))) + .collect(); + + // Generate ark2 constants + let ark2: Vec = (0..NUM_ROUNDS * WIDTH) + .map(|_| F::from_u64(rng.random::() % (1 << 30))) + .collect(); + + // Organize as periodic columns: ark_periodic[element][round] + let ark1_periodic: Vec> = (0..WIDTH) + .map(|elem_idx| { + (0..NUM_ROUNDS) + .map(|round| ark1[round * WIDTH + elem_idx]) + .collect() + }) + .collect(); + + let ark2_periodic: Vec> = (0..WIDTH) + .map(|elem_idx| { + (0..NUM_ROUNDS) + .map(|round| ark2[round * WIDTH + elem_idx]) + .collect() + }) + .collect(); + + Self { + ark1_periodic, + ark2_periodic, + mds_coeffs, + } + } + + /// Get ark1 constants for a specific round. + fn get_ark1(&self, round: usize) -> [F; WIDTH] { + core::array::from_fn(|i| self.ark1_periodic[i][round]) + } + + /// Get ark2 constants for a specific round. + fn get_ark2(&self, round: usize) -> [F; WIDTH] { + core::array::from_fn(|i| self.ark2_periodic[i][round]) + } + + /// Compute one round of Rescue-like hash. + /// + /// Full Rescue round: h' = MDS((MDS(h) + ark1)^α + ark2)^(1/α) + /// + /// We compute this as: + /// temp = MDS(h) + ark1 (linear) + /// temp = temp^α (forward S-box) + /// temp = MDS(temp) + ark2 (linear) + /// h' = temp^(1/α) (inverse S-box) + fn compute_round(&self, state: &mut [F; WIDTH], round: usize) + where + F: InjectiveMonomial, + { + let ark1 = self.get_ark1(round); + let ark2 = self.get_ark2(round); + + // MDS + *state = apply_mds(*state, &self.mds_coeffs); + + // Add ark1 + for i in 0..WIDTH { + state[i] += ark1[i]; + } + + // Forward S-box: x^α + for s in state.iter_mut() { + *s = s.injective_exp_n(); + } + + // MDS + *state = apply_mds(*state, &self.mds_coeffs); + + // Add ark2 + for i in 0..WIDTH { + state[i] += ark2[i]; + } + + // Inverse S-box: x^(1/α) + // Uses optimized addition chains for the inverse exponents: + // - Mersenne31 (α=5): x^1717986917 since 5 * 1717986917 ≡ 1 (mod p-1) + // - BabyBear (α=7): x^1725656503 since 7 * 1725656503 ≡ 1 (mod p-1) + for s in state.iter_mut() { + *s = match ALPHA { + 5 => exp_1717986917(*s), + 7 => exp_1725656503(*s), + _ => panic!("Unsupported ALPHA for inverse S-box: {}", ALPHA), + }; + } + } + + /// Compute the full hash. + fn hash(&self, input: [F; WIDTH]) -> [F; WIDTH] + where + F: InjectiveMonomial, + { + let mut state = input; + for round in 0..NUM_ROUNDS { + self.compute_round(&mut state, round); + } + state + } + + /// Generate trace for proving knowledge of preimage. + /// Each row i contains the state BEFORE round (i % NUM_ROUNDS). 
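+ /// Row 0 therefore holds the preimage, and row NUM_ROUNDS holds the output of the first
+ /// hash, which immediately seeds the next hash in the chain. This is why the round
+ /// constraint can be enforced on every transition row, including the boundary rows
+ /// between consecutive hash invocations.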
+ /// The trace has TRACE_SIZE rows, with the hash computation repeating. + fn generate_trace(&self, preimage: [F; WIDTH]) -> RowMajorMatrix + where + F: InjectiveMonomial, + { + let mut values = Vec::with_capacity(TRACE_SIZE * WIDTH); + let mut state = preimage; + + for row in 0..TRACE_SIZE { + // Store state before this round + values.extend_from_slice(&state); + // Compute round (wrapping around when hash completes) + let round = row % NUM_ROUNDS; + self.compute_round(&mut state, round); + } + + RowMajorMatrix::new(values, WIDTH) + } +} + +impl BaseAir for RescueLikeAir +where + F: PrimeCharacteristicRing + Sync + Copy, +{ + fn width(&self) -> usize { + WIDTH + } + + fn periodic_table(&self) -> Vec> { + // Interleave ark1 and ark2: [ark1[0], ark1[1], ..., ark2[0], ark2[1], ...] + let mut table = self.ark1_periodic.clone(); + table.extend(self.ark2_periodic.clone()); + table + } +} + +/// Compute x^5 explicitly (for Mersenne31, α=5) +fn exp5>(x: E) -> E { + let x2 = x.clone() * x.clone(); + let x4 = x2.clone() * x2; + x4 * x +} + +/// Compute x^7 explicitly (for BabyBear/Goldilocks, α=7) +fn exp7>(x: E) -> E { + let x2 = x.clone() * x.clone(); + let x4 = x2.clone() * x2.clone(); + let x6 = x4.clone() * x2; + x6 * x +} + +impl Air for RescueLikeAir +where + AB: AirBuilder + PeriodicAirBuilder, + AB::F: PrimeCharacteristicRing + PrimeField64 + Copy, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0).expect("matrix should have a local row"); + let next = main.row_slice(1).expect("matrix should have a next row"); + + // Get state from local row (h) + let h: [AB::Expr; WIDTH] = core::array::from_fn(|i| local[i].clone().into()); + + // Get next state (h') + let h_next: [AB::Expr; WIDTH] = core::array::from_fn(|i| next[i].clone().into()); + + // Get periodic round constants (ark1 is first WIDTH columns, ark2 is next WIDTH) + let periodic = builder.periodic_values(); + let ark1: [AB::Expr; WIDTH] = core::array::from_fn(|i| periodic[i].clone().into()); + let ark2: [AB::Expr; WIDTH] = core::array::from_fn(|i| periodic[WIDTH + i].clone().into()); + + // Convert MDS coefficients to field elements + let mds_field: [AB::F; WIDTH] = + core::array::from_fn(|i| AB::F::from_u64(self.mds_coeffs[i])); + + // Helper function to apply MDS to an expression array + let apply_mds_expr = |input: &[AB::Expr; WIDTH]| -> [AB::Expr; WIDTH] { + core::array::from_fn(|i| { + (0..WIDTH) + .map(|j| { + let coeff: AB::Expr = mds_field[j].into(); + input[(i + j) % WIDTH].clone() * coeff + }) + .sum() + }) + }; + + // Helper function to apply S-box to an expression array + let apply_sbox = |input: &[AB::Expr; WIDTH]| -> [AB::Expr; WIDTH] { + core::array::from_fn(|i| match ALPHA { + 5 => exp5(input[i].clone()), + 7 => exp7(input[i].clone()), + _ => panic!("Unsupported ALPHA: {}", ALPHA), + }) + }; + + // Rescue-like constraint: MDS((MDS(h) + ark1)^α) + ark2 = h'^α + // + // Left side (forward path from h): + // step1 = MDS(h) + // step2 = step1 + ark1 + // step3 = step2^α (forward S-box) + // step4 = MDS(step3) + ark2 + // + // Right side (from h'): + // step5 = h'^α + // + // Constraint: step4 == step5 + + // Left side computation + let step1 = apply_mds_expr(&h); + let step2: [AB::Expr; WIDTH] = core::array::from_fn(|i| step1[i].clone() + ark1[i].clone()); + let step3 = apply_sbox(&step2); + let step4_mds = apply_mds_expr(&step3); + let step4: [AB::Expr; WIDTH] = + core::array::from_fn(|i| step4_mds[i].clone() + ark2[i].clone()); + + // Right side computation + let step5 = 
apply_sbox(&h_next); + + // Constraint: step4[i] == step5[i] (on transition rows) + for i in 0..WIDTH { + builder + .when_transition() + .assert_eq(step4[i].clone(), step5[i].clone()); + } + } +} + +/// Test proving knowledge of preimage using two-adic FRI with BabyBear. +/// State width 24, digest size 8, S-box x^7. +#[test] +fn test_rescue_preimage_two_adic_babybear() { + const WIDTH: usize = 24; + const ALPHA: u64 = 7; + + type Val = BabyBear; + type Challenge = BinomialExtensionField; + type ByteHash = Keccak256Hash; + type FieldHash = SerializingHasher; + type MyCompress = CompressionFunctionFromHasher; + type ValMmcs = MerkleTreeMmcs; + type ChallengeMmcs = ExtensionMmcs; + type Dft = Radix2DitParallel; + type Challenger = SerializingChallenger32>; + type Pcs = TwoAdicFriPcs; + type MyConfig = StarkConfig; + + // Create the Rescue-like AIR + let air = RescueLikeAir::::new(MDS_COEFFS_24); + + // Generate a random preimage (the secret) + let mut rng = SmallRng::seed_from_u64(42); + let preimage: [Val; WIDTH] = + core::array::from_fn(|_| Val::from_u64(rng.random::() % (1 << 30))); + + // Compute the hash (the public output) + let hash_output = air.hash(preimage); + println!("BabyBear Rescue Preimage (first 3): {:?}", &preimage[..3]); + println!( + "BabyBear Rescue Hash output (first 3): {:?}", + &hash_output[..3] + ); + + // Generate the trace + let trace = air.generate_trace(preimage); + println!( + "BabyBear Rescue Trace: {} rows x {} cols", + trace.height(), + trace.width() + ); + + // Set up PCS with Keccak256 + let byte_hash = ByteHash {}; + let field_hash = FieldHash::new(byte_hash); + let compress = MyCompress::new(byte_hash); + let val_mmcs = ValMmcs::new(field_hash, compress); + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + let dft = Dft::default(); + + let fri_params = FriParameters { + log_blowup: LOG_BLOWUP, + log_final_poly_len: 0, + num_queries: NUM_QUERIES, + commit_proof_of_work_bits: 0, + query_proof_of_work_bits: 0, + mmcs: challenge_mmcs, + }; + let pcs = Pcs::new(dft, val_mmcs, fri_params); + let challenger = Challenger::from_hasher(vec![], byte_hash); + let config = MyConfig::new(pcs, challenger); + + // Prove and verify + let proof = + prove_with_periodic::<_, _, TwoAdicPeriodicEvaluator>(&config, &air, trace, &[]); + + verify_with_periodic::<_, _, TwoAdicPeriodicEvaluator>(&config, &air, &proof, &[]) + .expect("verification failed"); + println!("BabyBear Rescue verification succeeded!"); +} + +/// Test proving knowledge of preimage using Circle STARKs with Mersenne31. +/// State width 24, digest size 8, S-box x^5. 
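+///
+/// Mersenne31 has no large two-adic multiplicative subgroup (p − 1 = 2·(2^30 − 1)), so this
+/// test uses `CirclePcs` with `CirclePeriodicEvaluator` in place of the two-adic FRI PCS and
+/// `TwoAdicPeriodicEvaluator` used above; the AIR is the same construction, instantiated
+/// with α = 5.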
+#[test] +fn test_rescue_preimage_circle_m31() { + const WIDTH: usize = 24; + const ALPHA: u64 = 5; + + type Val = Mersenne31; + type Challenge = BinomialExtensionField; // M31 only supports degree-3 extension + type ByteHash = Keccak256Hash; + type FieldHash = SerializingHasher; + type MyCompress = CompressionFunctionFromHasher; + type ValMmcs = MerkleTreeMmcs; + type ChallengeMmcs = ExtensionMmcs; + type Challenger = SerializingChallenger32>; + type Pcs = CirclePcs; + type MyConfig = StarkConfig; + + // Create the Rescue-like AIR + let air = RescueLikeAir::::new(MDS_COEFFS_24); + + // Generate a random preimage + let mut rng = SmallRng::seed_from_u64(42); + let preimage: [Val; WIDTH] = + core::array::from_fn(|_| Val::from_u64(rng.random::() % (1 << 30))); + + // Compute the hash + let hash_output = air.hash(preimage); + println!("Circle M31 Rescue Preimage (first 3): {:?}", &preimage[..3]); + println!( + "Circle M31 Rescue Hash output (first 3): {:?}", + &hash_output[..3] + ); + + // Generate the trace + let trace = air.generate_trace(preimage); + println!( + "Circle M31 Rescue Trace: {} rows x {} cols", + trace.height(), + trace.width() + ); + + // Set up Circle PCS with Keccak256 + let byte_hash = ByteHash {}; + let field_hash = FieldHash::new(byte_hash); + let compress = MyCompress::new(byte_hash); + let val_mmcs = ValMmcs::new(field_hash, compress); + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + let fri_params = FriParameters { + log_blowup: LOG_BLOWUP, + log_final_poly_len: 0, + num_queries: NUM_QUERIES, + commit_proof_of_work_bits: 0, + query_proof_of_work_bits: 0, + mmcs: challenge_mmcs, + }; + + let pcs = Pcs { + mmcs: val_mmcs, + fri_params, + _phantom: PhantomData, + }; + let challenger = Challenger::from_hasher(vec![], byte_hash); + let config = MyConfig::new(pcs, challenger); + + // Prove and verify + let proof = prove_with_periodic::<_, _, CirclePeriodicEvaluator>(&config, &air, trace, &[]); + + verify_with_periodic::<_, _, CirclePeriodicEvaluator>(&config, &air, &proof, &[]) + .expect("verification failed"); + println!("Circle M31 Rescue verification succeeded!"); +}
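Reference note (not part of the patch): both tests rely on the row-level semantics of the
periodic table, namely that on trace row `r` periodic column `c` carries the table entry at
index `r % period`, matching the `row % NUM_ROUNDS` indexing in `generate_trace`. The helper
below is a hypothetical plain-Rust sketch of that invariant, which could be used to
cross-check `periodic_table()` against a generated trace; the name `naive_periodic_value` is
illustrative and does not appear in this diff.

/// Hypothetical reference helper: the value of periodic column `c` on trace row `r`
/// is the table entry at `r % period`.
fn naive_periodic_value<F: Copy>(table: &[Vec<F>], c: usize, r: usize) -> F {
    let col = &table[c];
    // Each period is expected to be a power of two dividing the trace length.
    debug_assert!(col.len().is_power_of_two());
    col[r % col.len()]
}

For the Rescue-like AIR above, `naive_periodic_value(&air.periodic_table(), i, r)` gives the
`ark1` constant applied to state element `i` when stepping from row `r` to row `r + 1`, and
column `WIDTH + i` gives the matching `ark2` constant.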