Skip to content
Merged
12 changes: 12 additions & 0 deletions miden-air/src/air.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
use crate::{MidenAirBuilder, RowMajorMatrix};

pub enum BusType {
/// A multiset bus
Multiset,
/// A logup bus
Logup,
}

/// Super trait for all AIR definitions in the Miden VM ecosystem.
///
/// This trait contains all methods from `BaseAir`, `BaseAirWithPublicValues`,
Expand Down Expand Up @@ -59,6 +66,11 @@ pub trait MidenAir<F, EF>: Sync {
0
}

/// Types of buses
fn bus_types(&self) -> Vec<BusType> {
vec![]
}

/// Build an aux trace (EF-based) given the main trace and EF challenges.
/// Return None to indicate no aux or to fall back to legacy behavior.
/// The output is a matrix of EF elements, flattened to a matrix of F elements.
Expand Down
2 changes: 1 addition & 1 deletion miden-air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ mod air;
mod builder;
mod filtered_builder;

pub use air::MidenAir;
pub use air::{BusType, MidenAir};
pub use builder::MidenAirBuilder;
pub use filtered_builder::FilteredMidenAirBuilder;
// Re-export for convenience
Expand Down
93 changes: 93 additions & 0 deletions miden-prover/src/air_wrapper_bus_boundary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use std::marker::PhantomData;

use miden_air::{BusType, MidenAir, MidenAirBuilder};
use p3_field::{PrimeCharacteristicRing, TwoAdicField};
use p3_matrix::Matrix;
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::prelude::*;

use crate::{StarkGenericConfig, Val};

pub struct AirWithBoundaryConstraints<'a, SC, A>
where
SC: StarkGenericConfig + std::marker::Sync,
A: MidenAir<Val<SC>, SC::Challenge>,
Val<SC>: TwoAdicField,
{
pub inner: &'a A,
pub phantom: PhantomData<SC>,
}

impl<'a, SC, A> MidenAir<Val<SC>, SC::Challenge> for AirWithBoundaryConstraints<'a, SC, A>
where
SC: StarkGenericConfig + std::marker::Sync,
A: MidenAir<Val<SC>, SC::Challenge>,
Val<SC>: TwoAdicField,
{
fn width(&self) -> usize {
self.inner.width()
}

fn preprocessed_trace(&self) -> Option<RowMajorMatrix<Val<SC>>> {
self.inner.preprocessed_trace()
}

fn num_public_values(&self) -> usize {
self.inner.num_public_values()
}

fn periodic_table(&self) -> Vec<Vec<Val<SC>>> {
self.inner.periodic_table()
}

fn num_randomness(&self) -> usize {
self.inner.num_randomness()
}

fn aux_width(&self) -> usize {
self.inner.aux_width()
}

/// Types of buses
fn bus_types(&self) -> Vec<BusType> {
self.inner.bus_types()
}

fn build_aux_trace(
&self,
_main: &RowMajorMatrix<Val<SC>>,
_challenges: &[SC::Challenge],
) -> Option<RowMajorMatrix<Val<SC>>> {
self.inner.build_aux_trace(_main, _challenges)
}

fn eval<AB: MidenAirBuilder<F = Val<SC>>>(&self, builder: &mut AB) {
// First, apply the inner AIR's constraints
self.inner.eval(builder);

if self.inner.num_randomness() > 0 {
// Then, apply any additional boundary constraints as needed
let aux = builder.permutation();
let aux_current = aux.row_slice(0).unwrap();
let aux_bus_boundary_values = builder.aux_bus_boundary_values().to_vec();

for (idx, bus_type) in self.inner.bus_types().into_iter().enumerate() {
match bus_type {
BusType::Multiset => {
builder
.when_first_row()
.assert_zero_ext(aux_current[idx].into() - AB::ExprEF::ONE);
}
BusType::Logup => {
builder
.when_first_row()
.assert_zero_ext(aux_current[idx].into());
}
}
builder
.when_last_row()
.assert_zero_ext(aux_current[idx].into() - aux_bus_boundary_values[idx].into());
}
}
}
}
14 changes: 13 additions & 1 deletion miden-prover/src/check_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ pub(crate) fn check_constraints<F, EF, A>(
)
};

let aux_bus_boundary_values;
if let Some(aux_matrix) = aux_trace.as_ref() {
let aux_bus_boundary_values_base =
unsafe { aux_matrix.row_slice_unchecked(height - 1) };
aux_bus_boundary_values = prover_row_to_ext::<F, EF>(&aux_bus_boundary_values_base);
} else {
aux_bus_boundary_values = vec![];
};

let preprocessed_pair = preprocessed.as_ref().map(|preprocessed_matrix| {
let preprocessed_local = preprocessed_matrix
.values
Expand Down Expand Up @@ -109,6 +118,7 @@ pub(crate) fn check_constraints<F, EF, A>(
main,
aux,
aux_randomness,
aux_bus_boundary_values: &aux_bus_boundary_values,
preprocessed: preprocessed_pair,
public_values,
periodic_values,
Expand All @@ -135,6 +145,8 @@ pub struct DebugConstraintBuilder<'a, F: Field, EF: ExtensionField<F>> {
aux: ViewPair<'a, EF>,
/// randomness that is used to compute aux trace
aux_randomness: &'a [EF],
/// Aux bus boundary values (against the last row)
aux_bus_boundary_values: &'a [EF],
/// A view of the preprocessed current and next row as a vertical pair (if present).
preprocessed: Option<ViewPair<'a, F>>,
/// The public values provided for constraint validation (e.g. inputs or outputs).
Expand Down Expand Up @@ -233,7 +245,7 @@ where
}

fn aux_bus_boundary_values(&self) -> &[Self::VarEF] {
unimplemented!()
self.aux_bus_boundary_values
}

fn periodic_evals(&self) -> &[Self::PeriodicVal] {
Expand Down
8 changes: 6 additions & 2 deletions miden-prover/src/folder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> {
/// The randomness used to compute the aux trace; can be zero width.
/// Cached EF randomness packed from base randomness to avoid temporary leaks
pub packed_randomness: Vec<PackedChallenge<SC>>,
/// Aux trace bus boundary values packed from base field to extension field
pub aux_bus_boundary_values: &'a [PackedChallenge<SC>],
/// The preprocessed columns (if any)
pub preprocessed: Option<RowMajorMatrixView<'a, PackedVal<SC>>>,
/// Public inputs to the AIR
Expand Down Expand Up @@ -117,7 +119,7 @@ impl<'a, SC: StarkGenericConfig> MidenAirBuilder for ProverConstraintFolder<'a,
}

fn aux_bus_boundary_values(&self) -> &[Self::VarEF] {
unimplemented!()
self.aux_bus_boundary_values
}

fn periodic_evals(&self) -> &[Self::PeriodicVal] {
Expand All @@ -137,6 +139,8 @@ pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> {
pub aux: ViewPair<'a, SC::Challenge>,
/// The randomness used to compute the aux tract; can be zero width.
pub randomness: &'a [SC::Challenge],
/// Aux trace bus boundary values; can be zero width.
pub aux_bus_boundary_values: &'a [SC::Challenge],
/// The preprocessed columns (if any)
pub preprocessed: Option<ViewPair<'a, SC::Challenge>>,
/// Public values that are inputs to the computation
Expand Down Expand Up @@ -227,7 +231,7 @@ impl<'a, SC: StarkGenericConfig> MidenAirBuilder for VerifierConstraintFolder<'a
}

fn aux_bus_boundary_values(&self) -> &[Self::VarEF] {
unimplemented!()
self.aux_bus_boundary_values
}

fn periodic_evals(&self) -> &[Self::PeriodicVal] {
Expand Down
2 changes: 2 additions & 0 deletions miden-prover/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod air_wrapper_bus_boundary;
#[cfg(debug_assertions)]
mod check_constraints;
mod config;
Expand All @@ -12,6 +13,7 @@ mod symbolic_variable;
mod util;
mod verifier;

pub use air_wrapper_bus_boundary::*;
#[cfg(debug_assertions)]
pub use check_constraints::*;
pub use config::*;
Expand Down
54 changes: 46 additions & 8 deletions miden-prover/src/prover.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::marker::PhantomData;

use itertools::Itertools;
use miden_air::MidenAir;
use p3_challenger::{CanObserve, FieldChallenger};
Expand All @@ -12,9 +14,11 @@ use p3_util::log2_strict_usize;
use tracing::{debug_span, info_span, instrument};

use crate::periodic_tables::compute_periodic_on_quotient_eval_domain;
use crate::util::prover_row_to_ext;
use crate::{
Commitments, Domain, OpenedValues, PackedChallenge, PackedVal, Proof, ProverConstraintFolder,
StarkGenericConfig, Val, get_log_quotient_degree, get_symbolic_constraints,
AirWithBoundaryConstraints, Commitments, Domain, OpenedValues, PackedChallenge, PackedVal,
Proof, ProverConstraintFolder, StarkGenericConfig, Val, get_log_quotient_degree,
get_symbolic_constraints,
};

/// Commits the preprocessed trace if present.
Expand Down Expand Up @@ -44,10 +48,15 @@ pub fn prove<SC, A>(
public_values: &[Val<SC>],
) -> Proof<SC>
where
SC: StarkGenericConfig,
SC: StarkGenericConfig + Sync,
A: MidenAir<Val<SC>, SC::Challenge>,
Val<SC>: TwoAdicField,
{
let air = &AirWithBoundaryConstraints {
inner: air,
phantom: PhantomData::<SC>,
};

// Compute the height `N = 2^n` and `log_2(height)`, `n`, of the trace.
let degree = trace.height();
let log_degree = log2_strict_usize(degree);
Expand Down Expand Up @@ -104,7 +113,7 @@ where
// From the degree of the constraint polynomial, compute the number
// of quotient polynomials we will split Q(x) into. This is chosen to
// always be a power of 2.
let log_quotient_degree = get_log_quotient_degree::<Val<SC>, SC::Challenge, A>(
let log_quotient_degree = get_log_quotient_degree::<Val<SC>, SC::Challenge, _>(
air,
preprocessed_width,
public_values.len(),
Expand Down Expand Up @@ -170,7 +179,7 @@ where
// begin aux trace generation (optional)
let num_randomness = air.num_randomness();

let (aux_trace_commit_opt, _aux_trace_opt, aux_trace_data_opt, randomness) =
let (aux_trace_commit_opt, _aux_trace_opt, aux_trace_data_opt, randomness, aux_finals) =
if num_randomness > 0 {
let randomness: Vec<SC::Challenge> = (0..num_randomness)
.map(|_| challenger.sample_algebra_element())
Expand All @@ -184,19 +193,30 @@ where
let aux_trace = aux_trace_opt
.expect("aux_challenges > 0 but no aux trace was provided or generated");

let aux_finals_base = aux_trace
.last_row()
.expect("aux_challenges > 0 but aux trace was empty")
.into_iter()
.collect_vec();
let aux_finals = prover_row_to_ext(&aux_finals_base);

let (aux_trace_commit, aux_trace_data) = info_span!("commit to aux trace data")
.in_scope(|| pcs.commit([(ext_trace_domain, aux_trace.clone().flatten_to_base())]));

challenger.observe(aux_trace_commit.clone());
for aux_final in &aux_finals {
challenger.observe_algebra_element(*aux_final);
}

(
Some(aux_trace_commit),
Some(aux_trace),
Some(aux_trace_data),
randomness,
aux_finals,
)
} else {
(None, None, None, vec![])
(None, None, None, vec![], vec![])
};

#[cfg(debug_assertions)]
Expand Down Expand Up @@ -254,14 +274,15 @@ where
// `C(T_1(x), ..., T_w(x), T_1(hx), ..., T_w(hx), selectors(x)) / Z_H(x)`
// at every point in the quotient domain. The degree of `Q(x)` is `<= deg(C(x)) - N = 2N - 2` in the case
// where `deg(C) = 3`. (See the discussion above constraint_degree for more details.)
let quotient_values: Vec<SC::Challenge> = quotient_values::<SC, A, _>(
let quotient_values: Vec<SC::Challenge> = quotient_values::<SC, _, _>(
air,
public_values,
trace_domain,
quotient_domain,
&trace_on_quotient_domain,
aux_trace_on_quotient_domain.as_ref(),
&randomness,
&aux_finals,
preprocessed_on_quotient_domain.as_ref(),
alpha,
constraint_count,
Expand Down Expand Up @@ -396,6 +417,12 @@ where
} else {
(None, None)
};
let aux_finals = if aux_trace_data_opt.is_some() {
Some(aux_finals)
} else {
None
};

let opened_values = OpenedValues {
trace_local,
trace_next,
Expand All @@ -410,6 +437,7 @@ where
commitments,
opened_values,
opening_proof,
aux_finals,
degree_bits: log_ext_degree,
}
}
Expand All @@ -425,12 +453,13 @@ pub fn quotient_values<SC, A, Mat>(
trace_on_quotient_domain: &Mat,
aux_trace_on_quotient_domain: Option<&Mat>,
randomness: &[SC::Challenge],
aux_bus_boundary_values: &[SC::Challenge],
preprocessed_on_quotient_domain: Option<&Mat>,
alpha: SC::Challenge,
constraint_count: usize,
) -> Vec<SC::Challenge>
where
SC: StarkGenericConfig,
SC: StarkGenericConfig + Sync,
A: MidenAir<Val<SC>, SC::Challenge>,
Mat: Matrix<Val<SC>> + Sync,
Val<SC>: TwoAdicField,
Expand Down Expand Up @@ -547,6 +576,13 @@ where
let packed_randomness: Vec<PackedChallenge<SC>> =
randomness.iter().copied().map(Into::into).collect();

// Pack aux bus boundary values
let packed_aux_bus_boundary_values: Vec<PackedChallenge<SC>> = aux_bus_boundary_values
.iter()
.copied()
.map(Into::into)
.collect();

// Grab precomputed periodic evaluations for this packed chunk.
let periodic_values: Vec<PackedChallenge<SC>> = periodic_on_quotient
.as_ref()
Expand Down Expand Up @@ -575,7 +611,9 @@ where
accumulator,
constraint_index: 0,
packed_randomness,
aux_bus_boundary_values: &packed_aux_bus_boundary_values,
};

air.eval(&mut folder);

// quotient(x) = constraints(x) / Z_H(x)
Expand Down
Loading