From 2b1eb382eb796b3518d2dc6c5077c212a8cba517 Mon Sep 17 00:00:00 2001 From: Alexander Golovanov Date: Tue, 27 Jan 2026 17:36:17 +0000 Subject: [PATCH 1/7] Implement the no-rotation thing --- crates/stark-backend-v2/src/keygen/mod.rs | 18 ++- crates/stark-backend-v2/src/keygen/types.rs | 2 + .../src/prover/cpu_backend.rs | 17 +++ .../src/prover/logup_zerocheck/cpu.rs | 62 +++++---- .../src/prover/logup_zerocheck/single.rs | 39 ++++-- .../src/prover/stacked_reduction.rs | 108 +++++++++------ crates/stark-backend-v2/src/tests.rs | 4 + .../src/verifier/batch_constraints.rs | 17 ++- crates/stark-backend-v2/src/verifier/mod.rs | 21 +++ .../src/verifier/stacked_reduction.rs | 123 +++++++++++++----- 10 files changed, 297 insertions(+), 114 deletions(-) diff --git a/crates/stark-backend-v2/src/keygen/mod.rs b/crates/stark-backend-v2/src/keygen/mod.rs index 9584710c..dfa7dd18 100644 --- a/crates/stark-backend-v2/src/keygen/mod.rs +++ b/crates/stark-backend-v2/src/keygen/mod.rs @@ -218,12 +218,8 @@ impl AirKeygenBuilderV2 { let air_name = self.air.name(); let symbolic_builder = self.get_symbolic_builder(); - let vparams = StarkVerifyingParamsV2 { - width: symbolic_builder.width(), - num_public_values: symbolic_builder.num_public_values(), - }; - // Deprecated in v2: - assert!(vparams.width.after_challenge.is_empty()); + let width = symbolic_builder.width(); + let num_public_values = symbolic_builder.num_public_values(); let symbolic_constraints = symbolic_builder.constraints(); let constraint_degree = symbolic_constraints.max_constraint_degree(); @@ -245,6 +241,16 @@ impl AirKeygenBuilderV2 { } = self; let dag = SymbolicConstraintsDag::from(symbolic_constraints); + let max_rotation = dag.constraints.max_rotation(); // TODO: exclude unused vars? 
+ debug_assert!(max_rotation <= 1); + let vparams = StarkVerifyingParamsV2 { + width, + num_public_values, + need_rot: max_rotation == 1, + }; + // Deprecated in v2: + assert!(vparams.width.after_challenge.is_empty()); + let unused_variables = find_unused_vars(&dag, &vparams.width); let vk = StarkVerifyingKeyV2 { preprocessed_data: preprocessed_vdata, diff --git a/crates/stark-backend-v2/src/keygen/types.rs b/crates/stark-backend-v2/src/keygen/types.rs index f86e777b..86a5b8b9 100644 --- a/crates/stark-backend-v2/src/keygen/types.rs +++ b/crates/stark-backend-v2/src/keygen/types.rs @@ -31,6 +31,8 @@ pub struct StarkVerifyingParamsV2 { pub width: TraceWidth, /// Number of public values for this STARK only pub num_public_values: usize, + /// A flag indication whether we need the rotations + pub need_rot: bool, } /// Verifier data for preprocessed trace for a single AIR. diff --git a/crates/stark-backend-v2/src/prover/cpu_backend.rs b/crates/stark-backend-v2/src/prover/cpu_backend.rs index b88d5a10..562e89c9 100644 --- a/crates/stark-backend-v2/src/prover/cpu_backend.rs +++ b/crates/stark-backend-v2/src/prover/cpu_backend.rs @@ -92,6 +92,12 @@ impl OpeningProverV2 for CpuDeviceV2 ) -> (StackingProof, WhirProof) { let params = &self.config; + let need_rot_per_trace = ctx + .per_trace + .iter() + .map(|(air_idx, _)| mpk.per_air[*air_idx].vk.params.need_rot) + .collect_vec(); + // Currently alternates between preprocessed and cached pcs data let pre_cached_pcs_data_per_commit: Vec<_> = ctx .per_trace @@ -109,12 +115,23 @@ impl OpeningProverV2 for CpuDeviceV2 for data in &pre_cached_pcs_data_per_commit { stacked_per_commit.push(data); } + let mut need_rot_per_commit = vec![need_rot_per_trace]; + for (air_idx, air_ctx) in &ctx.per_trace { + let need_rot = mpk.per_air[*air_idx].vk.params.need_rot; + if mpk.per_air[*air_idx].preprocessed_data.is_some() { + need_rot_per_commit.push(vec![need_rot]); + } + for _ in &air_ctx.cached_mains { + 
need_rot_per_commit.push(vec![need_rot]); + } + } let (stacking_proof, u_prisma) = prove_stacked_opening_reduction::<_, _, _, StackedReductionCpu>( self, transcript, self.config.n_stack, stacked_per_commit, + need_rot_per_commit, &r, ); diff --git a/crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs b/crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs index 22e6b0cc..26a47a16 100644 --- a/crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs +++ b/crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs @@ -152,7 +152,8 @@ impl<'a> LogupZerocheckCpu<'a> { } } } - let needs_next = rotation > 0; + let needs_next = pk.vk.params.need_rot; + debug_assert_eq!(needs_next, rotation > 0); let symbolic_constraints = SymbolicConstraints::from(&pk.vk.symbolic_constraints); EvalHelper { constraints_dag: &pk.vk.symbolic_constraints.constraints, @@ -603,29 +604,42 @@ impl<'a> LogupZerocheckCpu<'a> { let mut column_openings = Vec::with_capacity(num_airs_present); // At the end, we've folded all MLEs so they only have one row equal to evaluation at `\vec // r`. - for mut mat_evals in take(&mut self.mat_evals_per_trace) { - // Order of mats is: - // - preprocessed (if has_preprocessed), - // - preprocessed_rot (if has_preprocessed), - // - cached(0), cached(0)_rot, ... 
- // - common_main - // - common_main_rot - // For column openings, we pop common_main, common_main_rot and put it at the front - assert_eq!(mat_evals.len() % 2, 0); // always include rot for now - let common_main_rot = mat_evals.pop().unwrap(); - let common_main = mat_evals.pop().unwrap(); - let openings_of_air = iter::once(&[common_main, common_main_rot] as &[_]) - .chain(mat_evals.chunks_exact(2)) - .map(|pair| { - zip(pair[0].columns(), pair[1].columns()) - .map(|(claim, claim_rot)| { - assert_eq!(claim.len(), 1); - assert_eq!(claim_rot.len(), 1); - (claim[0], claim_rot[0]) - }) - .collect_vec() - }) - .collect_vec(); + for (helper, mut mat_evals) in self + .eval_helpers + .iter() + .zip(take(&mut self.mat_evals_per_trace)) + { + // For column openings, we pop common_main (and common_main_rot when present) and put it + // at the front. + let openings_of_air = if helper.needs_next { + let common_main_rot = mat_evals.pop().unwrap(); + let common_main = mat_evals.pop().unwrap(); + iter::once(&[common_main, common_main_rot] as &[_]) + .chain(mat_evals.chunks_exact(2)) + .map(|pair| { + zip(pair[0].columns(), pair[1].columns()) + .map(|(claim, claim_rot)| { + assert_eq!(claim.len(), 1); + assert_eq!(claim_rot.len(), 1); + (claim[0], claim_rot[0]) + }) + .collect_vec() + }) + .collect_vec() + } else { + let common_main = mat_evals.pop().unwrap(); + iter::once(common_main) + .chain(mat_evals.into_iter()) + .map(|mat| { + mat.columns() + .map(|claim| { + assert_eq!(claim.len(), 1); + (claim[0], EF::ZERO) + }) + .collect_vec() + }) + .collect_vec() + }; column_openings.push(openings_of_air); } column_openings diff --git a/crates/stark-backend-v2/src/prover/logup_zerocheck/single.rs b/crates/stark-backend-v2/src/prover/logup_zerocheck/single.rs index c015507e..2f9e413b 100644 --- a/crates/stark-backend-v2/src/prover/logup_zerocheck/single.rs +++ b/crates/stark-backend-v2/src/prover/logup_zerocheck/single.rs @@ -39,20 +39,29 @@ impl<'a> EvalHelper<'a, crate::F> { &self, 
ctx: &'a AirProvingContextV2, ) -> Vec<(StridedColMajorMatrixView<'a, crate::F>, bool)> { - let mut mats = Vec::with_capacity( - 2 * (usize::from(self.has_preprocessed()) + 1 + ctx.cached_mains.len()), - ); + let base_mats = usize::from(self.has_preprocessed()) + 1 + ctx.cached_mains.len(); + let mut mats = Vec::with_capacity(if self.needs_next { + 2 * base_mats + } else { + base_mats + }); if let Some(mat) = self.preprocessed_trace { mats.push((mat, false)); - mats.push((mat, true)); + if self.needs_next { + mats.push((mat, true)); + } } for cd in ctx.cached_mains.iter() { let trace_view = cd.data.mat_view(0); mats.push((trace_view, false)); - mats.push((trace_view, true)); + if self.needs_next { + mats.push((trace_view, true)); + } } mats.push((ctx.common_main.as_view().into(), false)); - mats.push((ctx.common_main.as_view().into(), true)); + if self.needs_next { + mats.push((ctx.common_main.as_view().into(), true)); + } mats } } @@ -148,10 +157,20 @@ impl EvalHelper<'_, F> { row_parts: &[Vec], ) -> ProverConstraintEvaluator<'_, F, FF> { let sels = &row_parts[0]; - let mut view_pairs = row_parts[1..] - .chunks_exact(2) - .map(|pair| ViewPair::new(&pair[0], self.needs_next.then(|| &pair[1][..]))) - .collect_vec(); + let mut view_pairs = if self.needs_next { + let mut chunks = row_parts[1..].chunks_exact(2); + let pairs = chunks + .by_ref() + .map(|pair| ViewPair::new(&pair[0], Some(&pair[1][..]))) + .collect_vec(); + debug_assert!(chunks.remainder().is_empty()); + pairs + } else { + row_parts[1..] 
+ .iter() + .map(|part| ViewPair::new(part, None)) + .collect_vec() + }; let mut preprocessed = None; if self.has_preprocessed() { preprocessed = Some(view_pairs.remove(0)); diff --git a/crates/stark-backend-v2/src/prover/stacked_reduction.rs b/crates/stark-backend-v2/src/prover/stacked_reduction.rs index c5d55fc3..da2e3bcf 100644 --- a/crates/stark-backend-v2/src/prover/stacked_reduction.rs +++ b/crates/stark-backend-v2/src/prover/stacked_reduction.rs @@ -38,6 +38,7 @@ pub trait StackedReductionProver<'a, PB: ProverBackendV2, PD> { fn new( device: &'a PD, stacked_per_commit: Vec<&'a PB::PcsData>, + need_rot_per_commit: Vec>, r: &[PB::Challenge], lambda: PB::Challenge, ) -> Self; @@ -68,6 +69,7 @@ pub fn prove_stacked_opening_reduction<'a, PB, PD, TS, SRP>( transcript: &mut TS, n_stack: usize, stacked_per_commit: Vec<&'a PB::PcsData>, + need_rot_per_commit: Vec>, r: &[PB::Challenge], ) -> (StackingProof, Vec) where @@ -78,7 +80,7 @@ where // Batching randomness let lambda = transcript.sample_ext(); - let mut prover = SRP::new(device, stacked_per_commit, r, lambda); + let mut prover = SRP::new(device, stacked_per_commit, need_rot_per_commit, r, lambda); let s_0 = prover.batch_sumcheck_uni_round0_poly(); for &coeff in s_0.coeffs() { transcript.observe_ext(coeff); @@ -132,7 +134,7 @@ pub struct StackedReductionCpu<'a> { eq_const: EF, stacked_per_commit: Vec<&'a StackedPcsData>, - trace_views: Vec<(usize, StackedSlice)>, + trace_views: Vec, ht_diff_idxs: Vec, eq_r_per_lht: HashMap>, @@ -144,41 +146,62 @@ pub struct StackedReductionCpu<'a> { eq_ub_per_trace: Vec, } +struct TraceViewMeta { + com_idx: usize, + slice: StackedSlice, + lambda_eq_idx: usize, + lambda_rot_idx: Option, +} + impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReductionCpu<'a> { fn new( device: &CpuDeviceV2, stacked_per_commit: Vec<&'a StackedPcsData>, + need_rot_per_commit: Vec>, r: &[EF], lambda: EF, ) -> Self { let l_skip = device.config().l_skip; let omega_skip = 
F::two_adic_generator(l_skip); - let total_num_col_openings = stacked_per_commit - .iter() - .map(|d| d.layout.sorted_cols.len() * 2) - .sum(); - let lambda_pows = lambda.powers().take(total_num_col_openings).collect_vec(); - - // Flattened list of unstacked trace column slices for convenience - let trace_views = stacked_per_commit - .iter() - .enumerate() - .flat_map(|(com_idx, d)| d.layout.unstacked_slices_iter().map(move |s| (com_idx, *s))) - .collect_vec(); + let mut trace_views = Vec::new(); + let mut lambda_idx = 0usize; + for (com_idx, d) in stacked_per_commit.iter().enumerate() { + let need_rot_for_commit = &need_rot_per_commit[com_idx]; + debug_assert_eq!(need_rot_for_commit.len(), d.layout.mat_starts.len()); + for &(mat_idx, _col_idx, slice) in &d.layout.sorted_cols { + let need_rot = need_rot_for_commit[mat_idx]; + let lambda_eq_idx = lambda_idx; + lambda_idx += 1; + let lambda_rot_idx = if need_rot { + let idx = lambda_idx; + lambda_idx += 1; + Some(idx) + } else { + None + }; + trace_views.push(TraceViewMeta { + com_idx, + slice, + lambda_eq_idx, + lambda_rot_idx, + }); + } + } + let lambda_pows = lambda.powers().take(lambda_idx).collect_vec(); let mut ht_diff_idxs = Vec::new(); let mut eq_r_per_lht: HashMap> = HashMap::new(); let mut last_height = 0; - for (i, (_, s)) in trace_views.iter().enumerate() { - let n_lift = s.log_height().saturating_sub(l_skip); - if i == 0 || s.log_height() != last_height { + for (i, tv) in trace_views.iter().enumerate() { + let n_lift = tv.slice.log_height().saturating_sub(l_skip); + if i == 0 || tv.slice.log_height() != last_height { ht_diff_idxs.push(i); - last_height = s.log_height(); + last_height = tv.slice.log_height(); } - eq_r_per_lht - .entry(s.log_height()) - .or_insert_with(|| ColMajorMatrix::new(evals_eq_hypercube(&r[1..1 + n_lift]), 1)); + eq_r_per_lht.entry(tv.slice.log_height()).or_insert_with(|| { + ColMajorMatrix::new(evals_eq_hypercube(&r[1..1 + n_lift]), 1) + }); } ht_diff_idxs.push(trace_views.len()); 
@@ -233,7 +256,7 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct .par_windows(2) .flat_map(|window| { let t_window = &self.trace_views[window[0]..window[1]]; - let log_height = t_window[0].1.log_height(); + let log_height = t_window[0].slice.log_height(); let n = log_height as isize - l_skip as isize; let n_lift = n.max(0) as usize; let eq_rs = self.eq_r_per_lht.get(&log_height).unwrap().column(0); @@ -241,9 +264,10 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct // Prepare the q subslice eval views let q_t_cols = t_window .iter() - .map(|(com_idx, s)| { - debug_assert_eq!(s.log_height(), log_height); - let q = &self.stacked_per_commit[*com_idx].matrix; + .map(|tv| { + debug_assert_eq!(tv.slice.log_height(), log_height); + let q = &self.stacked_per_commit[tv.com_idx].matrix; + let s = tv.slice; let q_t_col = &q.column(s.col_idx)[s.row_idx..s.row_idx + s.len(l_skip)]; // NOTE: even if s.stride(l_skip) != 1, we use the full non-strided column // subslice. 
The sumcheck will not depend on the values outside of the @@ -272,14 +296,12 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct let eq = eq_uni_r0 * eq_cube; let k_rot = eq_uni_r0_rot * eq_cube + self.eq_const * eq_uni_1 * (k_rot_cube - eq_cube); - zip( - self.lambda_pows[2 * window[0]..2 * window[1]].chunks_exact(2), - evals, - ) - .fold([EF::ZERO; 2], |mut acc, (lambdas, eval)| { + zip(t_window, evals).fold([EF::ZERO; 2], |mut acc, (tv, eval)| { let q = eval[0]; - acc[0] += lambdas[0] * eq * q * ind; - acc[1] += lambdas[1] * k_rot * q * ind; + acc[0] += self.lambda_pows[tv.lambda_eq_idx] * eq * q * ind; + if let Some(rot_idx) = tv.lambda_rot_idx { + acc[1] += self.lambda_pows[rot_idx] * k_rot * q * ind; + } acc }) }) @@ -361,7 +383,7 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct .par_windows(2) .flat_map(|window| { let t_views = &self.trace_views[window[0]..window[1]]; - let log_height = t_views[0].1.log_height(); + let log_height = t_views[0].slice.log_height(); let n_lift = log_height.saturating_sub(l_skip); // \tilde{n}_T let hypercube_dim = n_lift.saturating_sub(round); let eq_rs = self.eq_r_per_lht.get(&log_height).unwrap().column(0); @@ -371,11 +393,12 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct // Prepare the q subslice eval views let t_cols = t_views .iter() - .map(|(com_idx, s)| { - debug_assert_eq!(s.log_height(), log_height); + .map(|tv| { + debug_assert_eq!(tv.slice.log_height(), log_height); // q(u[..round], X, b_{T,j}[round-\tilde n_T..]) // q_evals has been folded already - let q = &self.q_evals[*com_idx]; + let q = &self.q_evals[tv.com_idx]; + let s = tv.slice; let row_start = if round <= n_lift { // round >= 1 so n_lift >= 1 (s.row_idx >> log_height) << (hypercube_dim + 1) @@ -393,12 +416,12 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct .enumerate() .fold([EF::ZERO; 2], |mut acc, (i, eval)| { let 
t_idx = window[0] + i; + let tv = &self.trace_views[t_idx]; let q = eval[0]; let mut eq_ub = self.eq_ub_per_trace[t_idx]; let (eq, k_rot) = if round > n_lift { // Extra contribution of eq(X, b_{T,j}[round-n_T-1]) - let b = - (self.trace_views[t_idx].1.row_idx >> (l_skip + round - 1)) & 1; + let b = (tv.slice.row_idx >> (l_skip + round - 1)) & 1; eq_ub *= eval_eq_mle(&[x], &[F::from_bool(b == 1)]); debug_assert_eq!(y, 0); (eq_rs[0] * eq_ub, k_rot_rs[0] * eq_ub) @@ -410,8 +433,10 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct k_rot_rs[y << 1] * (EF::ONE - x) + k_rot_rs[(y << 1) + 1] * x; (eq_r * eq_ub, k_rot_r * eq_ub) }; - acc[0] += self.lambda_pows[t_idx * 2] * q * eq; - acc[1] += self.lambda_pows[t_idx * 2 + 1] * q * k_rot; + acc[0] += self.lambda_pows[tv.lambda_eq_idx] * q * eq; + if let Some(rot_idx) = tv.lambda_rot_idx { + acc[1] += self.lambda_pows[rot_idx] * q * k_rot; + } acc }) }) @@ -431,7 +456,8 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct .into_par_iter() .map(|(lht, mat)| (lht, fold_mle_evals(mat, u_round))) .collect(); - for ((_, s), eq_ub) in zip(&self.trace_views, &mut self.eq_ub_per_trace) { + for (tv, eq_ub) in zip(&self.trace_views, &mut self.eq_ub_per_trace) { + let s = tv.slice; let n_lift = s.log_height().saturating_sub(l_skip); if round > n_lift { // Folding above did nothing, and we update the eq(u[1+n_T..=round], diff --git a/crates/stark-backend-v2/src/tests.rs b/crates/stark-backend-v2/src/tests.rs index ce4ef987..5f2a1c15 100644 --- a/crates/stark-backend-v2/src/tests.rs +++ b/crates/stark-backend-v2/src/tests.rs @@ -229,11 +229,14 @@ fn test_stacked_opening_reduction(log_trace_degree: usize) -> Result<(), Stacked &common_main_pcs_data, ); + let need_rot = pk.per_air[ctx.per_trace[0].0].vk.params.need_rot; + let need_rot_per_commit = vec![vec![need_rot]]; let (stacking_proof, _) = prove_stacked_opening_reduction::<_, _, _, StackedReductionCpu>( device, &mut 
DuplexSponge::default(), params.n_stack, vec![&common_main_pcs_data], + need_rot_per_commit.clone(), &r, ); @@ -243,6 +246,7 @@ fn test_stacked_opening_reduction(log_trace_degree: usize) -> Result<(), Stacked &mut DuplexSponge::default(), &stacking_proof, &[common_main_pcs_data.layout], + &need_rot_per_commit, params.l_skip, params.n_stack, &batch_proof.column_openings, diff --git a/crates/stark-backend-v2/src/verifier/batch_constraints.rs b/crates/stark-backend-v2/src/verifier/batch_constraints.rs index 171db429..eb4d10cc 100644 --- a/crates/stark-backend-v2/src/verifier/batch_constraints.rs +++ b/crates/stark-backend-v2/src/verifier/batch_constraints.rs @@ -48,6 +48,9 @@ pub enum BatchConstraintError { #[error("Claims are inconsistent")] InconsistentClaims, + + #[error("rotation opening provided when rotations are not needed")] + RotationsNotNeeded, } /// `public_values` should be in vkey (air_idx) order, including non-present AIRs. @@ -258,10 +261,18 @@ pub fn verify_zerocheck_and_logup( // 9. 
Compute the interaction/constraint evals and their hash let mut interactions_evals = Vec::new(); let mut constraints_evals = Vec::new(); + let need_rot_per_trace = trace_id_to_air_id + .iter() + .map(|&air_idx| mvk.per_air[air_idx].params.need_rot) + .collect_vec(); // Observe common main openings first, and then preprocessed/cached - for air_openings in column_openings.iter() { + for (trace_idx, air_openings) in column_openings.iter().enumerate() { + let need_rot = need_rot_per_trace[trace_idx]; for &(claim, claim_rot) in &air_openings[0] { + if !need_rot && claim_rot != EF::ZERO { + return Err(BatchConstraintError::RotationsNotNeeded); + } transcript.observe_ext(claim); transcript.observe_ext(claim_rot); } @@ -272,10 +283,14 @@ pub fn verify_zerocheck_and_logup( let vk = &mvk.per_air[air_idx]; let n = n_per_trace[trace_idx]; let n_lift = n.max(0) as usize; + let need_rot = need_rot_per_trace[trace_idx]; // claim lengths are checked in proof shape for claims in air_openings.iter().skip(1) { for &(claim, claim_rot) in claims.iter() { + if !need_rot && claim_rot != EF::ZERO { + return Err(BatchConstraintError::RotationsNotNeeded); + } transcript.observe_ext(claim); transcript.observe_ext(claim_rot); } diff --git a/crates/stark-backend-v2/src/verifier/mod.rs b/crates/stark-backend-v2/src/verifier/mod.rs index cdeda1c2..8dd044fd 100644 --- a/crates/stark-backend-v2/src/verifier/mod.rs +++ b/crates/stark-backend-v2/src/verifier/mod.rs @@ -148,10 +148,31 @@ pub fn verify( &omega_skip_pows, )?; + let need_rot_per_trace = trace_id_to_air_id + .iter() + .map(|&air_id| per_air[air_id].params.need_rot) + .collect_vec(); + let mut need_rot_per_commit = vec![need_rot_per_trace]; + for &air_id in &trace_id_to_air_id { + let need_rot = per_air[air_id].params.need_rot; + if per_air[air_id].preprocessed_data.is_some() { + need_rot_per_commit.push(vec![need_rot]); + } + let cached_len = trace_vdata[air_id] + .as_ref() + .unwrap() + .cached_commitments + .len(); + for _ in 
0..cached_len { + need_rot_per_commit.push(vec![need_rot]); + } + } + let u_prism = verify_stacked_reduction( transcript, stacking_proof, &layouts, + &need_rot_per_commit, l_skip, params.n_stack, &proof.batch_constraint_proof.column_openings, diff --git a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs index a9a1789c..67df1c92 100644 --- a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs +++ b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs @@ -23,6 +23,9 @@ pub enum StackedReductionError { #[error("s_n(u_n) does not match claimed q(u) sum: {claim} != {final_sum}")] FinalSumMismatch { claim: EF, final_sum: EF }, + + #[error("rotation opening provided when rotations are not needed")] + RotationsNotNeeded, } /// `has_preprocessed` must be per present trace in sorted AIR order. @@ -32,6 +35,7 @@ pub fn verify_stacked_reduction( transcript: &mut TS, proof: &StackingProof, layouts: &[StackedLayout], + need_rot_per_commit: &[Vec], l_skip: usize, n_stack: usize, column_openings: &Vec>>, @@ -51,26 +55,68 @@ pub fn verify_stacked_reduction( let omega_order = omega_shift_pows.len(); let omega_order_f = F::from_canonical_usize(omega_order); - let t_claims_len = layouts + debug_assert_eq!(layouts.len(), need_rot_per_commit.len()); + let mut lambda_idx = 0usize; + let lambda_indices_per_layout: Vec)>> = layouts .iter() - .map(|l| l.sorted_cols.len() * 2) - .sum::(); + .enumerate() + .map(|(commit_idx, layout)| { + let need_rot_for_commit = &need_rot_per_commit[commit_idx]; + debug_assert_eq!(need_rot_for_commit.len(), layout.mat_starts.len()); + layout + .sorted_cols + .iter() + .map(|&(mat_idx, _col_idx, _slice)| { + let lambda_eq_idx = lambda_idx; + lambda_idx += 1; + let lambda_rot_idx = if need_rot_for_commit[mat_idx] { + let idx = lambda_idx; + lambda_idx += 1; + Some(idx) + } else { + None + }; + (lambda_eq_idx, lambda_rot_idx) + }) + .collect_vec() + }) + .collect_vec(); + let 
t_claims_len = lambda_idx; let mut t_claims = Vec::with_capacity(t_claims_len); - // common main columns - column_openings.iter().for_each(|parts| { - t_claims.extend(parts[0].iter().flat_map(|(t, t_rot)| [*t, *t_rot])); - }); + // common main columns (commit 0) + for (trace_idx, parts) in column_openings.iter().enumerate() { + let need_rot = need_rot_per_commit[0][trace_idx]; + for &(t, t_rot) in &parts[0] { + t_claims.push(t); + if need_rot { + t_claims.push(t_rot); + } else { + if t_rot != EF::ZERO { + return Err(StackedReductionError::RotationsNotNeeded); + } + } + } + } - // preprocessed and cached columns - column_openings.iter().for_each(|parts| { - t_claims.extend( - parts - .iter() - .skip(1) - .flat_map(|cols| cols.iter().flat_map(|(t, t_rot)| [*t, *t_rot])), - ); - }); + // preprocessed and cached columns (commits 1..) + let mut commit_idx = 1usize; + for parts in column_openings { + for cols in parts.iter().skip(1) { + let need_rot = need_rot_per_commit[commit_idx][0]; + for &(t, t_rot) in cols { + t_claims.push(t); + if need_rot { + t_claims.push(t_rot); + } else { + if t_rot != EF::ZERO { + return Err(StackedReductionError::RotationsNotNeeded); + } + } + } + commit_idx += 1; + } + } assert_eq!(t_claims.len(), t_claims_len); debug!(?t_claims); @@ -153,17 +199,23 @@ pub fn verify_stacked_reduction( .map(|vec| vec![EF::ZERO; vec.len()]) .collect_vec(); - let mut j = 0usize; layouts .iter() + .enumerate() .zip(q_coeffs.iter_mut()) - .for_each(|(layout, coeffs)| { - layout.sorted_cols.iter().for_each(|&(_, _, s)| { - let n = s.log_height() as isize - l_skip as isize; - let n_lift = n.max(0) as usize; - let b = (l_skip + n_lift..l_skip + n_stack) - .map(|j| F::from_bool((s.row_idx >> j) & 1 == 1)) - .collect_vec(); + .for_each(|((commit_idx, layout), coeffs)| { + let lambda_indices = &lambda_indices_per_layout[commit_idx]; + layout + .sorted_cols + .iter() + .enumerate() + .for_each(|(col_idx, &(_, _, s))| { + let (lambda_eq_idx, lambda_rot_idx) = 
lambda_indices[col_idx]; + let n = s.log_height() as isize - l_skip as isize; + let n_lift = n.max(0) as usize; + let b = (l_skip + n_lift..l_skip + n_stack) + .map(|j| F::from_bool((s.row_idx >> j) & 1 == 1)) + .collect_vec(); let eq_mle = eval_eq_mle(&u[n_lift + 1..], &b); let ind = eval_in_uni(l_skip, n, u[0]); let (l, rs_n) = if n.is_negative() { @@ -173,14 +225,15 @@ pub fn verify_stacked_reduction( ) } else { (l_skip, &r[..=n_lift]) - }; - let eq_prism = eval_eq_prism(l, &u[..=n_lift], rs_n); - let rot_kernel_prism = eval_rot_kernel_prism(l, &u[..=n_lift], rs_n); - coeffs[s.col_idx] += eq_mle - * (lambda_powers[j] * eq_prism + lambda_powers[j + 1] * rot_kernel_prism) - * ind; - j += 2; - }); + }; + let eq_prism = eval_eq_prism(l, &u[..=n_lift], rs_n); + let rot_kernel_prism = eval_rot_kernel_prism(l, &u[..=n_lift], rs_n); + let mut batched = lambda_powers[lambda_eq_idx] * eq_prism; + if let Some(rot_idx) = lambda_rot_idx { + batched += lambda_powers[rot_idx] * rot_kernel_prism; + } + coeffs[s.col_idx] += eq_mle * batched * ind; + }); }); let final_sum = q_coeffs.iter().zip(proof.stacking_openings.iter()).fold( @@ -221,6 +274,7 @@ mod tests { pub transcript: DuplexSponge, pub proof: StackingProof, pub layouts: Vec, + pub need_rot_per_commit: Vec>, pub column_openings: Vec>>, pub r: Vec, pub omega_pows: Vec, @@ -346,6 +400,7 @@ mod tests { transcript: DuplexSponge::default(), proof, layouts: vec![layout], + need_rot_per_commit: vec![vec![true]], column_openings, r, omega_pows, @@ -359,6 +414,7 @@ mod tests { &mut test_case.transcript, &test_case.proof, &test_case.layouts, + &test_case.need_rot_per_commit, L_SKIP, N_STACK, &test_case.column_openings, @@ -376,6 +432,7 @@ mod tests { &mut test_case.transcript, &test_case.proof, &test_case.layouts, + &test_case.need_rot_per_commit, L_SKIP, N_STACK, &test_case.column_openings, @@ -393,6 +450,7 @@ mod tests { &mut test_case.transcript, &test_case.proof, &test_case.layouts, + &test_case.need_rot_per_commit, L_SKIP, 
N_STACK, &test_case.column_openings, @@ -410,6 +468,7 @@ mod tests { &mut test_case.transcript, &test_case.proof, &test_case.layouts, + &test_case.need_rot_per_commit, L_SKIP, N_STACK, &test_case.column_openings, From 2478f2c34f7698aac1d9f8cc2aefbc50352d468f Mon Sep 17 00:00:00 2001 From: Alexander Golovanov Date: Tue, 27 Jan 2026 17:54:37 +0000 Subject: [PATCH 2/7] `cargo fmt` --- .../src/prover/stacked_reduction.rs | 6 +++--- .../src/verifier/stacked_reduction.rs | 18 +++++++++--------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/crates/stark-backend-v2/src/prover/stacked_reduction.rs b/crates/stark-backend-v2/src/prover/stacked_reduction.rs index da2e3bcf..246bec72 100644 --- a/crates/stark-backend-v2/src/prover/stacked_reduction.rs +++ b/crates/stark-backend-v2/src/prover/stacked_reduction.rs @@ -199,9 +199,9 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct ht_diff_idxs.push(i); last_height = tv.slice.log_height(); } - eq_r_per_lht.entry(tv.slice.log_height()).or_insert_with(|| { - ColMajorMatrix::new(evals_eq_hypercube(&r[1..1 + n_lift]), 1) - }); + eq_r_per_lht + .entry(tv.slice.log_height()) + .or_insert_with(|| ColMajorMatrix::new(evals_eq_hypercube(&r[1..1 + n_lift]), 1)); } ht_diff_idxs.push(trace_views.len()); diff --git a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs index 67df1c92..61e39de1 100644 --- a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs +++ b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs @@ -216,15 +216,15 @@ pub fn verify_stacked_reduction( let b = (l_skip + n_lift..l_skip + n_stack) .map(|j| F::from_bool((s.row_idx >> j) & 1 == 1)) .collect_vec(); - let eq_mle = eval_eq_mle(&u[n_lift + 1..], &b); - let ind = eval_in_uni(l_skip, n, u[0]); - let (l, rs_n) = if n.is_negative() { - ( - l_skip.wrapping_add_signed(n), - &[r[0].exp_power_of_2(-n as usize)] as &[_], - ) - } else { - (l_skip, 
&r[..=n_lift]) + let eq_mle = eval_eq_mle(&u[n_lift + 1..], &b); + let ind = eval_in_uni(l_skip, n, u[0]); + let (l, rs_n) = if n.is_negative() { + ( + l_skip.wrapping_add_signed(n), + &[r[0].exp_power_of_2(-n as usize)] as &[_], + ) + } else { + (l_skip, &r[..=n_lift]) }; let eq_prism = eval_eq_prism(l, &u[..=n_lift], rs_n); let rot_kernel_prism = eval_rot_kernel_prism(l, &u[..=n_lift], rs_n); From e069b4561207526b26425c3de0fdb4d85d08f1c8 Mon Sep 17 00:00:00 2001 From: Alexander Golovanov Date: Tue, 27 Jan 2026 18:07:20 +0000 Subject: [PATCH 3/7] `cargo clippy --fix` --- .../src/verifier/stacked_reduction.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs index 61e39de1..51e085a5 100644 --- a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs +++ b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs @@ -91,10 +91,8 @@ pub fn verify_stacked_reduction( t_claims.push(t); if need_rot { t_claims.push(t_rot); - } else { - if t_rot != EF::ZERO { - return Err(StackedReductionError::RotationsNotNeeded); - } + } else if t_rot != EF::ZERO { + return Err(StackedReductionError::RotationsNotNeeded); } } } @@ -108,10 +106,8 @@ pub fn verify_stacked_reduction( t_claims.push(t); if need_rot { t_claims.push(t_rot); - } else { - if t_rot != EF::ZERO { - return Err(StackedReductionError::RotationsNotNeeded); - } + } else if t_rot != EF::ZERO { + return Err(StackedReductionError::RotationsNotNeeded); } } commit_idx += 1; From c5a799e9e188406ed82a859a2ee03a0b72e70b83 Mon Sep 17 00:00:00 2001 From: Alexander Golovanov Date: Thu, 29 Jan 2026 21:17:00 +0000 Subject: [PATCH 4/7] Return always increasing lambda power by 2 --- .../src/prover/stacked_reduction.rs | 22 ++++++-------- .../src/verifier/stacked_reduction.rs | 30 +++++++------------ 2 files changed, 20 insertions(+), 32 deletions(-) diff --git 
a/crates/stark-backend-v2/src/prover/stacked_reduction.rs b/crates/stark-backend-v2/src/prover/stacked_reduction.rs index 246bec72..8ef95552 100644 --- a/crates/stark-backend-v2/src/prover/stacked_reduction.rs +++ b/crates/stark-backend-v2/src/prover/stacked_reduction.rs @@ -150,7 +150,8 @@ struct TraceViewMeta { com_idx: usize, slice: StackedSlice, lambda_eq_idx: usize, - lambda_rot_idx: Option, + lambda_rot_idx: usize, + need_rot: bool, } impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReductionCpu<'a> { @@ -172,19 +173,14 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct for &(mat_idx, _col_idx, slice) in &d.layout.sorted_cols { let need_rot = need_rot_for_commit[mat_idx]; let lambda_eq_idx = lambda_idx; - lambda_idx += 1; - let lambda_rot_idx = if need_rot { - let idx = lambda_idx; - lambda_idx += 1; - Some(idx) - } else { - None - }; + let lambda_rot_idx = lambda_idx + 1; + lambda_idx += 2; trace_views.push(TraceViewMeta { com_idx, slice, lambda_eq_idx, lambda_rot_idx, + need_rot, }); } } @@ -299,8 +295,8 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct zip(t_window, evals).fold([EF::ZERO; 2], |mut acc, (tv, eval)| { let q = eval[0]; acc[0] += self.lambda_pows[tv.lambda_eq_idx] * eq * q * ind; - if let Some(rot_idx) = tv.lambda_rot_idx { - acc[1] += self.lambda_pows[rot_idx] * k_rot * q * ind; + if tv.need_rot { + acc[1] += self.lambda_pows[tv.lambda_rot_idx] * k_rot * q * ind; } acc }) @@ -434,8 +430,8 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct (eq_r * eq_ub, k_rot_r * eq_ub) }; acc[0] += self.lambda_pows[tv.lambda_eq_idx] * q * eq; - if let Some(rot_idx) = tv.lambda_rot_idx { - acc[1] += self.lambda_pows[rot_idx] * q * k_rot; + if tv.need_rot { + acc[1] += self.lambda_pows[tv.lambda_rot_idx] * q * k_rot; } acc }) diff --git a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs 
b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs index 51e085a5..99ac4987 100644 --- a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs +++ b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs @@ -57,7 +57,7 @@ pub fn verify_stacked_reduction( debug_assert_eq!(layouts.len(), need_rot_per_commit.len()); let mut lambda_idx = 0usize; - let lambda_indices_per_layout: Vec)>> = layouts + let lambda_indices_per_layout: Vec> = layouts .iter() .enumerate() .map(|(commit_idx, layout)| { @@ -68,15 +68,9 @@ pub fn verify_stacked_reduction( .iter() .map(|&(mat_idx, _col_idx, _slice)| { let lambda_eq_idx = lambda_idx; - lambda_idx += 1; - let lambda_rot_idx = if need_rot_for_commit[mat_idx] { - let idx = lambda_idx; - lambda_idx += 1; - Some(idx) - } else { - None - }; - (lambda_eq_idx, lambda_rot_idx) + let lambda_rot_idx = lambda_idx + 1; + lambda_idx += 2; + (lambda_eq_idx, lambda_rot_idx, need_rot_for_commit[mat_idx]) }) .collect_vec() }) @@ -89,9 +83,8 @@ pub fn verify_stacked_reduction( let need_rot = need_rot_per_commit[0][trace_idx]; for &(t, t_rot) in &parts[0] { t_claims.push(t); - if need_rot { - t_claims.push(t_rot); - } else if t_rot != EF::ZERO { + t_claims.push(t_rot); + if !need_rot && t_rot != EF::ZERO { return Err(StackedReductionError::RotationsNotNeeded); } } @@ -104,9 +97,8 @@ pub fn verify_stacked_reduction( let need_rot = need_rot_per_commit[commit_idx][0]; for &(t, t_rot) in cols { t_claims.push(t); - if need_rot { - t_claims.push(t_rot); - } else if t_rot != EF::ZERO { + t_claims.push(t_rot); + if !need_rot && t_rot != EF::ZERO { return Err(StackedReductionError::RotationsNotNeeded); } } @@ -206,7 +198,7 @@ pub fn verify_stacked_reduction( .iter() .enumerate() .for_each(|(col_idx, &(_, _, s))| { - let (lambda_eq_idx, lambda_rot_idx) = lambda_indices[col_idx]; + let (lambda_eq_idx, lambda_rot_idx, need_rot) = lambda_indices[col_idx]; let n = s.log_height() as isize - l_skip as isize; let n_lift = n.max(0) as usize; let b = 
(l_skip + n_lift..l_skip + n_stack) @@ -225,8 +217,8 @@ pub fn verify_stacked_reduction( let eq_prism = eval_eq_prism(l, &u[..=n_lift], rs_n); let rot_kernel_prism = eval_rot_kernel_prism(l, &u[..=n_lift], rs_n); let mut batched = lambda_powers[lambda_eq_idx] * eq_prism; - if let Some(rot_idx) = lambda_rot_idx { - batched += lambda_powers[rot_idx] * rot_kernel_prism; + if need_rot { + batched += lambda_powers[lambda_rot_idx] * rot_kernel_prism; } coeffs[s.col_idx] += eq_mle * batched * ind; }); From be9d4b994c20106b3c9ee75ea5d798844e9fed90 Mon Sep 17 00:00:00 2001 From: Alexander Golovanov Date: Mon, 2 Feb 2026 23:37:05 +0100 Subject: [PATCH 5/7] Move rot computation under an if --- crates/stark-backend-v2/src/verifier/stacked_reduction.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs index 99ac4987..7f782dd0 100644 --- a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs +++ b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs @@ -215,9 +215,9 @@ pub fn verify_stacked_reduction( (l_skip, &r[..=n_lift]) }; let eq_prism = eval_eq_prism(l, &u[..=n_lift], rs_n); - let rot_kernel_prism = eval_rot_kernel_prism(l, &u[..=n_lift], rs_n); let mut batched = lambda_powers[lambda_eq_idx] * eq_prism; if need_rot { + let rot_kernel_prism = eval_rot_kernel_prism(l, &u[..=n_lift], rs_n); batched += lambda_powers[lambda_rot_idx] * rot_kernel_prism; } coeffs[s.col_idx] += eq_mle * batched * ind; From a4aa1283943cad1a4da30a3614a3afc6f1cf5d77 Mon Sep 17 00:00:00 2001 From: Alexander Golovanov Date: Tue, 3 Feb 2026 15:46:06 +0100 Subject: [PATCH 6/7] Remove a redundant parameter --- crates/stark-backend-v2/src/prover/stacked_reduction.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/crates/stark-backend-v2/src/prover/stacked_reduction.rs b/crates/stark-backend-v2/src/prover/stacked_reduction.rs index 
8ef95552..e70c40ca 100644 --- a/crates/stark-backend-v2/src/prover/stacked_reduction.rs +++ b/crates/stark-backend-v2/src/prover/stacked_reduction.rs @@ -150,7 +150,6 @@ struct TraceViewMeta { com_idx: usize, slice: StackedSlice, lambda_eq_idx: usize, - lambda_rot_idx: usize, need_rot: bool, } @@ -173,13 +172,11 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct for &(mat_idx, _col_idx, slice) in &d.layout.sorted_cols { let need_rot = need_rot_for_commit[mat_idx]; let lambda_eq_idx = lambda_idx; - let lambda_rot_idx = lambda_idx + 1; lambda_idx += 2; trace_views.push(TraceViewMeta { com_idx, slice, lambda_eq_idx, - lambda_rot_idx, need_rot, }); } @@ -296,7 +293,7 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct let q = eval[0]; acc[0] += self.lambda_pows[tv.lambda_eq_idx] * eq * q * ind; if tv.need_rot { - acc[1] += self.lambda_pows[tv.lambda_rot_idx] * k_rot * q * ind; + acc[1] += self.lambda_pows[tv.lambda_eq_idx + 1] * k_rot * q * ind; } acc }) @@ -431,7 +428,7 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct }; acc[0] += self.lambda_pows[tv.lambda_eq_idx] * q * eq; if tv.need_rot { - acc[1] += self.lambda_pows[tv.lambda_rot_idx] * q * k_rot; + acc[1] += self.lambda_pows[tv.lambda_eq_idx + 1] * q * k_rot; } acc }) From 9bfa64c413480635eec2822a662ac6eb9ba36190 Mon Sep 17 00:00:00 2001 From: Alexander Golovanov Date: Wed, 4 Feb 2026 16:02:16 +0100 Subject: [PATCH 7/7] Do not pack zeroes when no rotations --- crates/stark-backend-v2/src/proof.rs | 26 +++++++--- .../src/prover/logup_zerocheck/cpu.rs | 10 ++-- .../src/prover/logup_zerocheck/mod.rs | 18 +++---- .../src/prover/stacked_reduction.rs | 21 +++++---- .../src/verifier/batch_constraints.rs | 30 +++++------- .../src/verifier/proof_shape.rs | 17 +++++-- .../src/verifier/stacked_reduction.rs | 47 ++++++------------- 7 files changed, 89 insertions(+), 80 deletions(-) diff --git 
a/crates/stark-backend-v2/src/proof.rs b/crates/stark-backend-v2/src/proof.rs index 53ebd1fa..18f005c2 100644 --- a/crates/stark-backend-v2/src/proof.rs +++ b/crates/stark-backend-v2/src/proof.rs @@ -7,6 +7,8 @@ use crate::{ Digest, EF, F, }; +use p3_field::FieldAlgebra; + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct Proof { /// The commitment to the data in common_main. @@ -82,11 +84,23 @@ pub struct BatchConstraintProof { /// For rounds `1, ..., n_max`; evaluations on `{1, ..., vk.d + 1}`. pub sumcheck_round_polys: Vec>, - /// Per AIR **in sorted AIR order**, per AIR part, per column index in that part, opening of - /// the prismalinear column polynomial and its rotational convolution. - /// The trace parts are ordered: [CommonMain (part - /// 0), Preprocessed (if any), Cached(0), Cached(1), ...] - pub column_openings: Vec>>, + /// Per AIR **in sorted AIR order**, per AIR part, per column index in that part, openings for + /// the prismalinear column polynomial and (optionally) its rotational convolution. All column + /// openings are stored flat: either just the column openings, or the openings interleaved + /// with their rotations. The trace parts are ordered: [CommonMain (part 0), Preprocessed (if + /// any), Cached(0), Cached(1), ...] 
+ pub column_openings: Vec>>, +} + +pub fn column_openings_by_rot<'a>( + openings: &'a [EF], + need_rot: bool, +) -> Box + 'a> { + if need_rot { + Box::new(openings.chunks_exact(2).map(|chunk| (chunk[0], chunk[1]))) + } else { + Box::new(openings.iter().map(|&claim| (claim, EF::ZERO))) + } } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -384,7 +398,7 @@ impl Decode for BatchConstraintProof { let mut column_openings = Vec::with_capacity(num_present_airs); for _ in 0..num_present_airs { - column_openings.push(Vec::>::decode(reader)?); + column_openings.push(Vec::>::decode(reader)?); } Ok(Self { diff --git a/crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs b/crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs index 26a47a16..0b6fe1e6 100644 --- a/crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs +++ b/crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs @@ -599,7 +599,7 @@ impl<'a> LogupZerocheckCpu<'a> { } } - pub fn into_column_openings(mut self) -> Vec>> { + pub fn into_column_openings(&mut self) -> Vec>> { let num_airs_present = self.mat_evals_per_trace.len(); let mut column_openings = Vec::with_capacity(num_airs_present); // At the end, we've folded all MLEs so they only have one row equal to evaluation at `\vec @@ -611,17 +611,17 @@ impl<'a> LogupZerocheckCpu<'a> { { // For column openings, we pop common_main (and common_main_rot when present) and put it // at the front. 
- let openings_of_air = if helper.needs_next { + let openings_of_air: Vec> = if helper.needs_next { let common_main_rot = mat_evals.pop().unwrap(); let common_main = mat_evals.pop().unwrap(); iter::once(&[common_main, common_main_rot] as &[_]) .chain(mat_evals.chunks_exact(2)) .map(|pair| { zip(pair[0].columns(), pair[1].columns()) - .map(|(claim, claim_rot)| { + .flat_map(|(claim, claim_rot)| { assert_eq!(claim.len(), 1); assert_eq!(claim_rot.len(), 1); - (claim[0], claim_rot[0]) + [claim[0], claim_rot[0]] }) .collect_vec() }) @@ -634,7 +634,7 @@ impl<'a> LogupZerocheckCpu<'a> { mat.columns() .map(|claim| { assert_eq!(claim.len(), 1); - (claim[0], EF::ZERO) + claim[0] }) .collect_vec() }) diff --git a/crates/stark-backend-v2/src/prover/logup_zerocheck/mod.rs b/crates/stark-backend-v2/src/prover/logup_zerocheck/mod.rs index b38ab095..80681417 100644 --- a/crates/stark-backend-v2/src/prover/logup_zerocheck/mod.rs +++ b/crates/stark-backend-v2/src/prover/logup_zerocheck/mod.rs @@ -16,7 +16,7 @@ use crate::{ dft::Radix2BowersSerial, poly_common::{eq_sharp_uni_poly, eq_uni_poly, UnivariatePoly}, poseidon2::sponge::FiatShamirTranscript, - proof::{BatchConstraintProof, GkrProof}, + proof::{column_openings_by_rot, BatchConstraintProof, GkrProof}, prover::{ fractional_sumcheck_gkr::{fractional_sumcheck, Frac}, stacked_pcs::StackedLayout, @@ -395,17 +395,17 @@ where let column_openings = prover.into_column_openings(); // Observe common main openings first, and then preprocessed/cached - for openings in &column_openings { - for (claim, claim_rot) in &openings[0] { - transcript.observe_ext(*claim); - transcript.observe_ext(*claim_rot); + for (helper, openings) in prover.eval_helpers.iter().zip(column_openings.iter()) { + for (claim, claim_rot) in column_openings_by_rot(&openings[0], helper.needs_next) { + transcript.observe_ext(claim); + transcript.observe_ext(claim_rot); } } - for openings in &column_openings { + for (helper, openings) in 
prover.eval_helpers.iter().zip(column_openings.iter()) { for part in openings.iter().skip(1) { - for (claim, claim_rot) in part { - transcript.observe_ext(*claim); - transcript.observe_ext(*claim_rot); + for (claim, claim_rot) in column_openings_by_rot(part, helper.needs_next) { + transcript.observe_ext(claim); + transcript.observe_ext(claim_rot); } } } diff --git a/crates/stark-backend-v2/src/prover/stacked_reduction.rs b/crates/stark-backend-v2/src/prover/stacked_reduction.rs index e70c40ca..c78fda5b 100644 --- a/crates/stark-backend-v2/src/prover/stacked_reduction.rs +++ b/crates/stark-backend-v2/src/prover/stacked_reduction.rs @@ -150,7 +150,7 @@ struct TraceViewMeta { com_idx: usize, slice: StackedSlice, lambda_eq_idx: usize, - need_rot: bool, + lambda_rot_idx: Option, } impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReductionCpu<'a> { @@ -170,14 +170,19 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct let need_rot_for_commit = &need_rot_per_commit[com_idx]; debug_assert_eq!(need_rot_for_commit.len(), d.layout.mat_starts.len()); for &(mat_idx, _col_idx, slice) in &d.layout.sorted_cols { - let need_rot = need_rot_for_commit[mat_idx]; let lambda_eq_idx = lambda_idx; - lambda_idx += 2; + lambda_idx += 1; + let lambda_rot_idx = if need_rot_for_commit[mat_idx] { + Some(lambda_idx) + } else { + None + }; + lambda_idx += 1; trace_views.push(TraceViewMeta { com_idx, slice, lambda_eq_idx, - need_rot, + lambda_rot_idx, }); } } @@ -292,8 +297,8 @@ impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct zip(t_window, evals).fold([EF::ZERO; 2], |mut acc, (tv, eval)| { let q = eval[0]; acc[0] += self.lambda_pows[tv.lambda_eq_idx] * eq * q * ind; - if tv.need_rot { - acc[1] += self.lambda_pows[tv.lambda_eq_idx + 1] * k_rot * q * ind; + if let Some(rot_idx) = tv.lambda_rot_idx { + acc[1] += self.lambda_pows[rot_idx] * k_rot * q * ind; } acc }) @@ -427,8 +432,8 @@ impl<'a> 
StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReduct (eq_r * eq_ub, k_rot_r * eq_ub) }; acc[0] += self.lambda_pows[tv.lambda_eq_idx] * q * eq; - if tv.need_rot { - acc[1] += self.lambda_pows[tv.lambda_eq_idx + 1] * q * k_rot; + if let Some(rot_idx) = tv.lambda_rot_idx { + acc[1] += self.lambda_pows[rot_idx] * q * k_rot; } acc }) diff --git a/crates/stark-backend-v2/src/verifier/batch_constraints.rs b/crates/stark-backend-v2/src/verifier/batch_constraints.rs index eb4d10cc..f61fb5c0 100644 --- a/crates/stark-backend-v2/src/verifier/batch_constraints.rs +++ b/crates/stark-backend-v2/src/verifier/batch_constraints.rs @@ -16,7 +16,7 @@ use crate::{ keygen::types::MultiStarkVerifyingKey0V2, poly_common::{eval_eq_mle, eval_eq_sharp_uni, eval_eq_uni, UnivariatePoly}, poseidon2::sponge::FiatShamirTranscript, - proof::{BatchConstraintProof, GkrProof}, + proof::{column_openings_by_rot, BatchConstraintProof, GkrProof}, verifier::{ evaluator::VerifierConstraintEvaluator, fractional_sumcheck_gkr::{verify_gkr, GkrVerificationError}, @@ -48,9 +48,6 @@ pub enum BatchConstraintError { #[error("Claims are inconsistent")] InconsistentClaims, - - #[error("rotation opening provided when rotations are not needed")] - RotationsNotNeeded, } /// `public_values` should be in vkey (air_idx) order, including non-present AIRs. 
@@ -269,10 +266,7 @@ pub fn verify_zerocheck_and_logup( // Observe common main openings first, and then preprocessed/cached for (trace_idx, air_openings) in column_openings.iter().enumerate() { let need_rot = need_rot_per_trace[trace_idx]; - for &(claim, claim_rot) in &air_openings[0] { - if !need_rot && claim_rot != EF::ZERO { - return Err(BatchConstraintError::RotationsNotNeeded); - } + for (claim, claim_rot) in column_openings_by_rot(&air_openings[0], need_rot) { transcript.observe_ext(claim); transcript.observe_ext(claim_rot); } @@ -287,24 +281,26 @@ pub fn verify_zerocheck_and_logup( // claim lengths are checked in proof shape for claims in air_openings.iter().skip(1) { - for &(claim, claim_rot) in claims.iter() { - if !need_rot && claim_rot != EF::ZERO { - return Err(BatchConstraintError::RotationsNotNeeded); - } + for (claim, claim_rot) in column_openings_by_rot(claims, need_rot) { transcript.observe_ext(claim); transcript.observe_ext(claim_rot); } } let has_preprocessed = vk.preprocessed_data.is_some(); - let common_main = air_openings[0].as_slice(); - let preprocessed = has_preprocessed.then(|| air_openings[1].as_slice()); + let common_main = column_openings_by_rot(&air_openings[0], need_rot).collect::>(); + let preprocessed = has_preprocessed + .then(|| column_openings_by_rot(&air_openings[1], need_rot).collect::>()); let cached_idx = 1 + has_preprocessed as usize; let mut partitioned_main: Vec<_> = air_openings[cached_idx..] 
.iter() - .map(|opening| opening.as_slice()) + .map(|opening| column_openings_by_rot(opening, need_rot).collect::>()) .collect(); partitioned_main.push(common_main); + let part_main_slices = partitioned_main + .iter() + .map(|x| x.as_slice()) + .collect::>(); // We are evaluating the lift, which is the same as evaluating the original with domain // D^{(2^{n})} @@ -318,8 +314,8 @@ pub fn verify_zerocheck_and_logup( (l_skip, &rs[..=(n as usize)], F::ONE) }; let evaluator = VerifierConstraintEvaluator::::new( - preprocessed, - &partitioned_main, + preprocessed.as_deref(), + &part_main_slices, &public_values[air_idx], rs_n, l, diff --git a/crates/stark-backend-v2/src/verifier/proof_shape.rs b/crates/stark-backend-v2/src/verifier/proof_shape.rs index ee9bc14b..a8e2a451 100644 --- a/crates/stark-backend-v2/src/verifier/proof_shape.rs +++ b/crates/stark-backend-v2/src/verifier/proof_shape.rs @@ -158,6 +158,15 @@ pub enum BatchProofShapeError { expected: usize, actual: usize, }, + #[error( + "Column opening for AIR {air_idx} (part {part_idx}) should have {expected} values, but has {actual}" + )] + InvalidColumnOpeningLen { + air_idx: usize, + part_idx: usize, + expected: usize, + actual: usize, + }, } #[derive(Debug, Error, PartialEq, Eq)] @@ -464,6 +473,8 @@ pub fn verify_proof_shape( } for (part_openings, &(air_idx, vk, _)) in batch_proof.column_openings.iter().zip(&per_trace) { + let need_rot = mvk.per_air[air_idx].params.need_rot; + let openings_per_col = if need_rot { 2 } else { 1 }; if part_openings.len() != vk.num_parts() { return ProofShapeError::invalid_batch_constraint( BatchProofShapeError::InvalidColumnOpeningsPerAir { @@ -472,7 +483,7 @@ pub fn verify_proof_shape( actual: part_openings.len(), }, ); - } else if part_openings[0].len() != vk.params.width.common_main { + } else if part_openings[0].len() != vk.params.width.common_main * openings_per_col { return ProofShapeError::invalid_batch_constraint( BatchProofShapeError::InvalidColumnOpeningsPerAirMain { 
air_idx, @@ -481,7 +492,7 @@ pub fn verify_proof_shape( }, ); } else if let Some(preprocessed_width) = &vk.params.width.preprocessed { - if part_openings[1].len() != *preprocessed_width { + if part_openings[1].len() != *preprocessed_width * openings_per_col { return ProofShapeError::invalid_batch_constraint( BatchProofShapeError::InvalidColumnOpeningsPerAirPreprocessed { air_idx, @@ -498,7 +509,7 @@ pub fn verify_proof_shape( .zip(&vk.params.width.cached_mains) .enumerate() { - if col_opening.len() != width { + if col_opening.len() != width * openings_per_col { return ProofShapeError::invalid_batch_constraint( BatchProofShapeError::InvalidColumnOpeningsPerAirCached { air_idx, diff --git a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs index 7f782dd0..d6e2e436 100644 --- a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs +++ b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs @@ -11,7 +11,7 @@ use crate::{ interpolate_quadratic_at_012, }, poseidon2::sponge::FiatShamirTranscript, - proof::StackingProof, + proof::{column_openings_by_rot, StackingProof}, prover::stacked_pcs::StackedLayout, EF, F, }; @@ -23,9 +23,6 @@ pub enum StackedReductionError { #[error("s_n(u_n) does not match claimed q(u) sum: {claim} != {final_sum}")] FinalSumMismatch { claim: EF, final_sum: EF }, - - #[error("rotation opening provided when rotations are not needed")] - RotationsNotNeeded, } /// `has_preprocessed` must be per present trace in sorted AIR order. 
@@ -38,7 +35,7 @@ pub fn verify_stacked_reduction( need_rot_per_commit: &[Vec], l_skip: usize, n_stack: usize, - column_openings: &Vec>>, + column_openings: &Vec>>, r: &[EF], omega_shift_pows: &[F], ) -> Result, StackedReductionError> { @@ -57,7 +54,7 @@ pub fn verify_stacked_reduction( debug_assert_eq!(layouts.len(), need_rot_per_commit.len()); let mut lambda_idx = 0usize; - let lambda_indices_per_layout: Vec> = layouts + let lambda_indices_per_layout: Vec> = layouts .iter() .enumerate() .map(|(commit_idx, layout)| { @@ -67,10 +64,8 @@ pub fn verify_stacked_reduction( .sorted_cols .iter() .map(|&(mat_idx, _col_idx, _slice)| { - let lambda_eq_idx = lambda_idx; - let lambda_rot_idx = lambda_idx + 1; - lambda_idx += 2; - (lambda_eq_idx, lambda_rot_idx, need_rot_for_commit[mat_idx]) + lambda_idx += 1; + (lambda_idx - 1, need_rot_for_commit[mat_idx]) }) .collect_vec() }) @@ -81,13 +76,7 @@ pub fn verify_stacked_reduction( // common main columns (commit 0) for (trace_idx, parts) in column_openings.iter().enumerate() { let need_rot = need_rot_per_commit[0][trace_idx]; - for &(t, t_rot) in &parts[0] { - t_claims.push(t); - t_claims.push(t_rot); - if !need_rot && t_rot != EF::ZERO { - return Err(StackedReductionError::RotationsNotNeeded); - } - } + t_claims.extend(column_openings_by_rot(&parts[0], need_rot)); } // preprocessed and cached columns (commits 1..) 
@@ -95,13 +84,7 @@ pub fn verify_stacked_reduction( for parts in column_openings { for cols in parts.iter().skip(1) { let need_rot = need_rot_per_commit[commit_idx][0]; - for &(t, t_rot) in cols { - t_claims.push(t); - t_claims.push(t_rot); - if !need_rot && t_rot != EF::ZERO { - return Err(StackedReductionError::RotationsNotNeeded); - } - } + t_claims.extend(column_openings_by_rot(cols, need_rot)); commit_idx += 1; } } @@ -110,7 +93,7 @@ pub fn verify_stacked_reduction( debug!(?t_claims); let lambda = transcript.sample_ext(); - let lambda_powers = lambda.powers().take(t_claims_len).collect_vec(); + let lambda_sqr_powers = (lambda * lambda).powers().take(t_claims_len).collect_vec(); /* * INITIAL UNIVARIATE ROUND @@ -123,8 +106,8 @@ pub fn verify_stacked_reduction( * of sum_{z in D} s_1(z). Suppose s_1(x) = a_0 + a_1 * x + ... a_k * x^k. Because we have * omega^{|D|} == 1, sum_{z in D} s_1(z) = |D| * (a_0 + a_{|D|} + ...). */ - let s_0 = zip(&t_claims, &lambda_powers) - .map(|(&t_i, &lambda_i)| t_i * lambda_i) + let s_0 = zip(&t_claims, &lambda_sqr_powers) + .map(|(&t_i, &lambda_i)| (t_i.0 + t_i.1 * lambda) * lambda_i) .sum::(); let s_0_sum_eval = proof .univariate_round_coeffs @@ -198,7 +181,7 @@ pub fn verify_stacked_reduction( .iter() .enumerate() .for_each(|(col_idx, &(_, _, s))| { - let (lambda_eq_idx, lambda_rot_idx, need_rot) = lambda_indices[col_idx]; + let (lambda_idx, need_rot) = lambda_indices[col_idx]; let n = s.log_height() as isize - l_skip as isize; let n_lift = n.max(0) as usize; let b = (l_skip + n_lift..l_skip + n_stack) @@ -215,10 +198,10 @@ pub fn verify_stacked_reduction( (l_skip, &r[..=n_lift]) }; let eq_prism = eval_eq_prism(l, &u[..=n_lift], rs_n); - let mut batched = lambda_powers[lambda_eq_idx] * eq_prism; + let mut batched = lambda_sqr_powers[lambda_idx] * eq_prism; if need_rot { let rot_kernel_prism = eval_rot_kernel_prism(l, &u[..=n_lift], rs_n); - batched += lambda_powers[lambda_rot_idx] * rot_kernel_prism; + batched += 
lambda_sqr_powers[lambda_idx] * lambda * rot_kernel_prism; } coeffs[s.col_idx] += eq_mle * batched * ind; }); @@ -263,7 +246,7 @@ mod tests { pub proof: StackingProof, pub layouts: Vec, pub need_rot_per_commit: Vec>, - pub column_openings: Vec>>, + pub column_openings: Vec>>, pub r: Vec, pub omega_pows: Vec, } @@ -338,7 +321,7 @@ mod tests { let t_rot = omega_pows.iter().fold(EF::ZERO, |acc, &omega| { acc + compute_t::(&q, &r, &b, &u, EF::from_base(omega), 0, L_SKIP) }); - let column_openings = vec![vec![vec![(t, t_rot)]]]; + let column_openings = vec![vec![vec![t, t_rot]]]; let mut transcript = DuplexSponge::default(); let lambda = transcript.sample_ext();