diff --git a/.codespellignore b/.codespellignore index bcd2c6b3..d77bb05e 100644 --- a/.codespellignore +++ b/.codespellignore @@ -1 +1,2 @@ inout +numer diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml deleted file mode 100644 index 8cb5fb6e..00000000 --- a/.github/workflows/build.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Build Workspace - -on: - push: - branches: ["main"] - pull_request: - branches: ["**"] - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} - cancel-in-progress: true - -jobs: - lint: - name: Build - runs-on: - - runs-on=${{ github.run_id }}/runner=8cpu-linux-arm64 - steps: - - uses: actions/checkout@v4 - - uses: actions-rust-lang/setup-rust-toolchain@v1 - - name: Run build - run: | - cargo build --verbose diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index a9d2c622..84d33d16 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -20,9 +20,9 @@ jobs: - runs-on=${{ github.run_id }}/runner=test-gpu-nvidia steps: - uses: actions/checkout@v4 - - uses: actions-rust-lang/setup-rust-toolchain@v1 - with: - toolchain: nightly + - uses: dtolnay/rust-toolchain@nightly + env: + RUSTUP_PERMIT_COPY_RENAME: "1" - name: Build documentation run: cargo +nightly doc --workspace --no-deps diff --git a/.github/workflows/lints.yml b/.github/workflows/lints.yml index 3f566eba..65be9224 100644 --- a/.github/workflows/lints.yml +++ b/.github/workflows/lints.yml @@ -17,9 +17,11 @@ jobs: - runs-on=${{ github.run_id }}/runner=8cpu-linux-arm64 steps: - uses: actions/checkout@v5 - - uses: actions-rust-lang/setup-rust-toolchain@v1 + - uses: dtolnay/rust-toolchain@stable with: components: rustfmt, clippy + env: + RUSTUP_PERMIT_COPY_RENAME: "1" - uses: codespell-project/actions-codespell@v2 with: skip: Cargo.lock @@ -53,9 +55,11 @@ jobs: steps: - uses: runs-on/action@v2 - uses: actions/checkout@v5 - - uses: actions-rust-lang/setup-rust-toolchain@v1 + - uses: dtolnay/rust-toolchain@stable with: components: rustfmt, clippy + env: + RUSTUP_PERMIT_COPY_RENAME: "1" - uses: ./.github/actions/rust-cache-cuda - name: Verify CUDA setup run: | diff --git a/.github/workflows/rust-cuda-matcher.yml b/.github/workflows/rust-cuda-matcher.yml index 96d3e7bc..37f68cc5 100644 --- a/.github/workflows/rust-cuda-matcher.yml +++ b/.github/workflows/rust-cuda-matcher.yml @@ -20,10 +20,9 @@ jobs: steps: - uses: runs-on/action@v2 - uses: actions/checkout@v5 - - run: | # avoid cross-device link error - rustup component remove clippy || true - rm -rf ~/.rustup/toolchains/stable-x86_64-unknown-linux-gnu || true - uses: dtolnay/rust-toolchain@stable + env: + RUSTUP_PERMIT_COPY_RENAME: "1" - uses: Swatinem/rust-cache@v2 with: cache-on-failure: true diff --git a/.github/workflows/tests-backend-v2.yml b/.github/workflows/tests-backend-v2.yml new file mode 100644 index 00000000..ddac8a65 --- /dev/null +++ b/.github/workflows/tests-backend-v2.yml @@ -0,0 +1,37 @@ +name: Stark Backend V2 Tests + +on: + push: + branches: ["main"] + pull_request: + branches: ["**"] + paths: + - "crates/stark-backend-v2/**" + - ".github/workflows/tests-backend-v2.yml" + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + +env: + CARGO_TERM_COLOR: always + +jobs: + tests: + name: stark-backend-v2 crate tests + runs-on: + - runs-on=${{ github.run_id }}/runner=16cpu-linux-arm64 + + steps: + - uses: runs-on/action@v2 + - uses: actions/checkout@v5 + - uses: 
dtolnay/rust-toolchain@stable + env: + RUSTUP_PERMIT_COPY_RENAME: "1" + - uses: Swatinem/rust-cache@v2 + - uses: taiki-e/install-action@nextest + + - name: Run tests + working-directory: crates/stark-backend-v2 + run: | + cargo nextest run diff --git a/.github/workflows/tests-cpu.yml b/.github/workflows/tests-cpu.yml index 6a21b3b7..fa864782 100644 --- a/.github/workflows/tests-cpu.yml +++ b/.github/workflows/tests-cpu.yml @@ -24,7 +24,10 @@ jobs: steps: - uses: actions/checkout@v5 - - uses: actions-rust-lang/setup-rust-toolchain@v1 + - uses: dtolnay/rust-toolchain@stable + env: + RUSTUP_PERMIT_COPY_RENAME: "1" + - uses: Swatinem/rust-cache@v2 - uses: taiki-e/install-action@nextest - name: Run tests diff --git a/.github/workflows/tests-cuda.yml b/.github/workflows/tests-cuda.yml index 7946ad43..e7ade9e7 100644 --- a/.github/workflows/tests-cuda.yml +++ b/.github/workflows/tests-cuda.yml @@ -25,7 +25,9 @@ jobs: steps: - uses: runs-on/action@v2 - uses: actions/checkout@v5 - - uses: actions-rust-lang/setup-rust-toolchain@v1 + - uses: dtolnay/rust-toolchain@stable + env: + RUSTUP_PERMIT_COPY_RENAME: "1" - uses: ./.github/actions/rust-cache-cuda - uses: taiki-e/install-action@nextest - name: Verify GPU setup diff --git a/Cargo.toml b/Cargo.toml index 1cd0e534..c29ab923 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace.package] -version = "1.2.3" +version = "2.0.0-alpha" edition = "2021" rust-version = "1.83" authors = ["OpenVM contributors"] @@ -8,9 +8,22 @@ repository = "https://github.com/openvm-org/" license = "MIT OR Apache-2.0" [workspace] -members = ["crates/stark-backend", "crates/stark-sdk", "crates/cuda-common", "crates/cuda-backend", "crates/cuda-builder"] +members = [ + "crates/stark-backend", + "crates/stark-sdk", + "crates/cuda-common", + "crates/cuda-backend", + "crates/cuda-builder", + "crates/stark-backend-v2", + "crates/stark-backend-v2/derive", +] # Only these build by default -default-members = ["crates/stark-backend", "crates/stark-sdk"] +default-members = [ + "crates/stark-backend", + "crates/stark-sdk", + "crates/stark-backend-v2", + "crates/stark-backend-v2/derive", +] resolver = "2" # Fastest runtime configuration @@ -55,10 +68,13 @@ openvm-stark-sdk = { path = "crates/stark-sdk", default-features = false } openvm-cuda-builder = { path = "crates/cuda-builder", default-features = false } openvm-cuda-common = { path = "crates/cuda-common", default-features = false } openvm-cuda-backend = { path = "crates/cuda-backend", default-features = false } +stark-backend-v2 = { path = "crates/stark-backend-v2", default-features = false } +codec-derive = { path = "crates/stark-backend-v2/derive", default-features = false } # Plonky3 p3-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-interpolation = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } p3-commit = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } p3-matrix = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } @@ -116,6 
+132,7 @@ async-trait = "0.1.83" getset = "0.1.3" rand = { version = "0.8.5", default-features = false } hex = { version = "0.4.3", default-features = false } +hex-literal = "1.0.0" bitcode = "0.6.5" bincode = "2.0.1" bincode_derive = "2.0.1" diff --git a/crates/cuda-backend/cuda/supra/ntt.cu b/crates/cuda-backend/cuda/supra/ntt.cu index c9b892e6..d324605a 100644 --- a/crates/cuda-backend/cuda/supra/ntt.cu +++ b/crates/cuda-backend/cuda/supra/ntt.cu @@ -14,6 +14,8 @@ // Licensed under the Apache License, Version 2.0, see LICENSE for details. // SPDX-License-Identifier: Apache-2.0 +#include + #include "launcher.cuh" #include "ntt/ntt.cuh" @@ -22,7 +24,8 @@ __launch_bounds__(768, 1) __global__ void _CT_NTT(const unsigned int radix, const unsigned int lg_domain_size, const unsigned int stage, const unsigned int iterations, fr_t* d_inout, const unsigned int padded_poly_size, - bool is_intt, const fr_t d_domain_size_inverse) + const uint32_t poly_count, bool is_intt, + const fr_t d_domain_size_inverse) { #if (__CUDACC_VER_MAJOR__-0) >= 11 __builtin_assume(lg_domain_size <= MAX_LG_DOMAIN_SIZE); @@ -38,7 +41,10 @@ void _CT_NTT(const unsigned int radix, const unsigned int lg_domain_size, const fr_t* d_radixX_twiddles = d_radix6_twiddles + twiddles_offset; index_t tid = threadIdx.x + blockDim.x * (index_t)blockIdx.x; - d_inout += blockIdx.y * padded_poly_size; // [DIFF]: move in/out ptr to another row + const uint32_t poly_idx = blockIdx.y + blockIdx.z * gridDim.y; // [DIFF]: use gridDim.y to calculate poly_idx + if (poly_idx >= poly_count) + return; + d_inout += static_cast(poly_idx) * padded_poly_size; // [DIFF]: move in/out ptr to another row const index_t diff_mask = (1 << (iterations - 1)) - 1; const index_t inp_mask = ((index_t)1 << stage) - 1; @@ -211,6 +217,9 @@ extern "C" int _ct_mixed_radix_narrow( index_t block_size = 1 << (radix - 1); index_t num_blocks; + if (poly_count == 0) + return cudaSuccess; + block_size = (num_threads <= block_size) ? num_threads : block_size; num_blocks = (num_threads + block_size - 1) / block_size; @@ -219,18 +228,23 @@ extern "C" int _ct_mixed_radix_narrow( const int Z_COUNT = 256/8/sizeof(fr_t); size_t shared_sz = sizeof(fr_t) << (radix - 1); + // [DIFF]: calculate grid_y, grid_z from poly_count + const uint32_t MAX_Y = 65535; + uint32_t grid_y = poly_count < MAX_Y ? 
poly_count : MAX_Y; + uint32_t grid_z = (poly_count + grid_y - 1) / grid_y; + #define NTT_ARGUMENTS radix, lg_domain_size, stage, iterations, \ - d_inout, padded_poly_size, is_intt, domain_size_inverse[lg_domain_size] + d_inout, padded_poly_size, poly_count, is_intt, domain_size_inverse[lg_domain_size] // [DIFF]: N -> dim3(N, poly_count) in grid_size; stream -> cudaStreamPerThread if (num_blocks < Z_COUNT) - _CT_NTT<1><<>>(NTT_ARGUMENTS); + _CT_NTT<1><<>>(NTT_ARGUMENTS); else if (stage == 0 || lg_domain_size < 12) - _CT_NTT<<>>(NTT_ARGUMENTS); - else if (lg_domain_size < MAX_LG_DOMAIN_SIZE) - _CT_NTT<<>>(NTT_ARGUMENTS); + _CT_NTT<<>>(NTT_ARGUMENTS); + else if (lg_domain_size <= MAX_LG_DOMAIN_SIZE) + _CT_NTT<<>>(NTT_ARGUMENTS); else - assert(lg_domain_size < MAX_LG_DOMAIN_SIZE); + assert(lg_domain_size <= MAX_LG_DOMAIN_SIZE); #undef NTT_ARGUMENTS diff --git a/crates/cuda-backend/cuda/supra/ntt_bitrev.cu b/crates/cuda-backend/cuda/supra/ntt_bitrev.cu index 1e422ff3..7d4bbe70 100644 --- a/crates/cuda-backend/cuda/supra/ntt_bitrev.cu +++ b/crates/cuda-backend/cuda/supra/ntt_bitrev.cu @@ -2,28 +2,50 @@ * Source: https://github.com/supranational/sppark (tag=v0.1.12) * Status: MODIFIED from sppark/ntt/kernels.cu * Imported: 2025-08-13 by @gaxiom - * + * * LOCAL CHANGES (high level): * - 2025-08-13: Support multiple rows in bit_rev_permutation & bit_rev_permutation_z * - 2025-09-10: Add extern "C" launcher from sppark/ntt/ntt.cuh + * - 2025-12-24: Template field type to support fr_t and bb31_4_t */ +#include + #include "launcher.cuh" #include "ntt/ntt.cuh" +// [DIFF]: Add new type for bit reversal kernel +struct frac_fpext_t { + bb31_4_t num; + bb31_4_t denom; +}; + +/* + * Template type T requirements: + * - Default constructible: T() + * - Copy constructible: T t = other; + * - Copy assignable: t1 = t2; + * - Trivially copyable (for shared memory and global memory operations) + */ + // Permutes the data in an array such that data[i] = data[bit_reverse(i)] // and data[bit_reverse(i)] = data[i] +template __launch_bounds__(1024) __global__ -void bit_rev_permutation(fr_t* d_out, const fr_t *d_in, uint32_t lg_domain_size, uint32_t padded_poly_size) +void bit_rev_permutation(T* d_out, const T *d_in, uint32_t lg_domain_size, + uint32_t padded_poly_size, uint32_t poly_count) { - d_out += blockIdx.y * padded_poly_size; // [DIFF]: move out ptr to another row - d_in += blockIdx.y * padded_poly_size; // [DIFF]: move in ptr to another row + const uint32_t poly_idx = blockIdx.y + blockIdx.z * gridDim.y; // [DIFF]: use gridDim.y to calculate poly_idx + if (poly_idx >= poly_count) + return; + d_out += static_cast(poly_idx) * padded_poly_size; // [DIFF]: move out ptr to another row + d_in += static_cast(poly_idx) * padded_poly_size; // [DIFF]: move in ptr to another row if (gridDim.x == 1 && blockDim.x == (1 << lg_domain_size)) { uint32_t idx = threadIdx.x; uint32_t rev = bit_rev(idx, lg_domain_size); - fr_t t = d_in[idx]; + T t = d_in[idx]; if (d_out == d_in) __syncthreads(); d_out[rev] = t; @@ -33,9 +55,9 @@ void bit_rev_permutation(fr_t* d_out, const fr_t *d_in, uint32_t lg_domain_size, bool copy = d_out != d_in && idx == rev; if (idx < rev || copy) { - fr_t t0 = d_in[idx]; + T t0 = d_in[idx]; if (!copy) { - fr_t t1 = d_in[rev]; + T t1 = d_in[rev]; d_out[idx] = t1; } d_out[rev] = t0; @@ -43,16 +65,22 @@ void bit_rev_permutation(fr_t* d_out, const fr_t *d_in, uint32_t lg_domain_size, } } -template +template __launch_bounds__(192, 2) __global__ -void bit_rev_permutation_z(fr_t* out, const fr_t* in, uint32_t 
lg_domain_size, uint32_t padded_poly_size) +void bit_rev_permutation_z(T* out, const T* in, uint32_t lg_domain_size, + uint32_t padded_poly_size, uint32_t poly_count) { - out += blockIdx.y * padded_poly_size; // [DIFF]: move out ptr to another row - in += blockIdx.y * padded_poly_size; // [DIFF]: move in ptr to another row + const uint32_t poly_idx = blockIdx.y + blockIdx.z * gridDim.y; + if (poly_idx >= poly_count) + return; + out += static_cast(poly_idx) * padded_poly_size; // [DIFF]: move out ptr to another row + in += static_cast(poly_idx) * padded_poly_size; // [DIFF]: move in ptr to another row const uint32_t LG_Z_COUNT = 31 - __clz(Z_COUNT); // [DIFF]: use __clz to get lg2 - extern __shared__ fr_t xchg[][Z_COUNT][Z_COUNT]; + // Use byte array for extern shared memory to avoid symbol conflicts across template instantiations + extern __shared__ unsigned char xchg_raw[]; + T (*xchg)[Z_COUNT][Z_COUNT] = reinterpret_cast(xchg_raw); uint32_t gid = threadIdx.x / Z_COUNT; uint32_t idx = threadIdx.x % Z_COUNT; @@ -72,7 +100,7 @@ void bit_rev_permutation_z(fr_t* out, const fr_t* in, uint32_t lg_domain_size, u index_t base_idx = group_idx * Z_COUNT + idx; index_t base_rev = group_rev * Z_COUNT + idx; - fr_t regs[Z_COUNT]; + T regs[Z_COUNT]; #pragma unroll for (uint32_t i = 0; i < Z_COUNT; i++) { @@ -107,27 +135,34 @@ void bit_rev_permutation_z(fr_t* out, const fr_t* in, uint32_t lg_domain_size, u } -extern "C" int _bit_rev(fr_t* d_out, const fr_t* d_inp, +template +static int bit_rev_impl(T* d_out, const T* d_inp, uint32_t lg_domain_size, uint32_t padded_poly_size, uint32_t poly_count) { - assert(lg_domain_size <= MAX_LG_DOMAIN_SIZE); - size_t domain_size = (size_t)1 << lg_domain_size; // aim to read 4 cache lines of consecutive data per read - const uint32_t Z_COUNT = 256 / sizeof(fr_t); + const uint32_t Z_COUNT = 256 / sizeof(T); const uint32_t bsize = Z_COUNT > WARP_SIZE ? Z_COUNT : WARP_SIZE; + if (poly_count == 0) + return cudaSuccess; + + // [DIFF]: calculate grid_y, grid_z from poly_count + const uint32_t MAX_Y = 65535; + uint32_t grid_y = poly_count < MAX_Y ? poly_count : MAX_Y; + uint32_t grid_z = (poly_count + grid_y - 1) / grid_y; + // [DIFF]: N -> dim3(N, poly_count) in grid_size; stream -> cudaStreamPerThread if (domain_size <= 1024) - bit_rev_permutation<<>> - (d_out, d_inp, lg_domain_size, padded_poly_size); + bit_rev_permutation<<>> + (d_out, d_inp, lg_domain_size, padded_poly_size, poly_count); else if (domain_size < bsize * Z_COUNT) - bit_rev_permutation<<>> - (d_out, d_inp, lg_domain_size, padded_poly_size); + bit_rev_permutation<<>> + (d_out, d_inp, lg_domain_size, padded_poly_size, poly_count); else if (Z_COUNT > WARP_SIZE || lg_domain_size <= 32) - bit_rev_permutation_z<<>> - (d_out, d_inp, lg_domain_size, padded_poly_size); + bit_rev_permutation_z<<>> + (d_out, d_inp, lg_domain_size, padded_poly_size, poly_count); else { // Those GPUs that can reserve 96KB of shared memory can // schedule 2 blocks to each SM... 
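// Worked example of the grid decomposition used by the launches above
// (illustrative numbers, not from the source): with poly_count = 70000 and
// MAX_Y = 65535,
//     grid_y = min(70000, 65535) = 65535
//     grid_z = ceil(70000 / 65535) = 2
// so the launch enumerates 65535 * 2 = 131070 (blockIdx.y, blockIdx.z) pairs.
// Each kernel recomputes poly_idx = blockIdx.y + blockIdx.z * gridDim.y, and
// the guard `if (poly_idx >= poly_count) return;` discards the 61070 excess
// pairs. This sidesteps CUDA's 65535 limit on gridDim.y while leaving
// gridDim.x free for the per-polynomial thread blocks.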
@@ -136,10 +171,28 @@ extern "C" int _bit_rev(fr_t* d_out, const fr_t* d_inp, int sm_count; cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device); - bit_rev_permutation_z<<>> - (d_out, d_inp, lg_domain_size, padded_poly_size); + bit_rev_permutation_z<<>> + (d_out, d_inp, lg_domain_size, padded_poly_size, poly_count); } return CHECK_KERNEL(); +} + +extern "C" int _bit_rev(fr_t* d_out, const fr_t* d_inp, + uint32_t lg_domain_size, uint32_t padded_poly_size, uint32_t poly_count) +{ + return bit_rev_impl(d_out, d_inp, lg_domain_size, padded_poly_size, poly_count); +} + +extern "C" int _bit_rev_ext(bb31_4_t* d_out, const bb31_4_t* d_inp, + uint32_t lg_domain_size, uint32_t padded_poly_size, uint32_t poly_count) +{ + return bit_rev_impl(d_out, d_inp, lg_domain_size, padded_poly_size, poly_count); +} + +extern "C" int _bit_rev_frac_ext(frac_fpext_t* d_out, const frac_fpext_t* d_inp, + uint32_t lg_domain_size, uint32_t padded_poly_size, uint32_t poly_count) +{ + return bit_rev_impl(d_out, d_inp, lg_domain_size, padded_poly_size, poly_count); } \ No newline at end of file diff --git a/crates/cuda-backend/src/base.rs b/crates/cuda-backend/src/base.rs index 2505fafd..8105ba57 100644 --- a/crates/cuda-backend/src/base.rs +++ b/crates/cuda-backend/src/base.rs @@ -76,6 +76,11 @@ impl DeviceMatrix { pub fn strong_count(&self) -> usize { Arc::strong_count(&self.buffer) } + + pub fn as_view(&self) -> DeviceMatrixView<'_, T> { + // SAFETY: buffer is borrowed for lifetime 'a of the view + unsafe { DeviceMatrixView::from_raw_parts(self.buffer.as_ptr(), self.height, self.width) } + } } impl MatrixDimensions for DeviceMatrix { @@ -108,6 +113,48 @@ impl Debug for DeviceMatrix { } } +/// View of a device matrix. Dropping does not free memory. +#[derive(Clone, Copy)] +pub struct DeviceMatrixView<'a, T> { + ptr: *const T, + height: usize, + width: usize, + _ptr_lifetime: PhantomData<&'a T>, +} + +unsafe impl Send for DeviceMatrixView<'_, T> {} +unsafe impl Sync for DeviceMatrixView<'_, T> {} + +impl DeviceMatrixView<'_, T> { + /// # Safety + /// - The pointer must be valid for the lifetime of the view. + /// - The pointer must have memory allocated for the following `height * width` elements of `T`. + pub unsafe fn from_raw_parts(ptr: *const T, height: usize, width: usize) -> Self { + Self { + ptr, + height, + width, + _ptr_lifetime: PhantomData, + } + } + + pub fn as_ptr(&self) -> *const T { + self.ptr + } +} + +impl MatrixDimensions for DeviceMatrixView<'_, T> { + #[inline] + fn height(&self) -> usize { + self.height + } + + #[inline] + fn width(&self) -> usize { + self.width + } +} + /// The following trait and types are borrowed from [halo2](https:://github.com/zcash/halo2). /// The basis over which a polynomial is described. 
pub trait Basis: Copy + Debug + Send + Sync {} diff --git a/crates/cuda-backend/src/cuda/kernels.rs b/crates/cuda-backend/src/cuda/kernels.rs index 2f17b506..ebabf05f 100644 --- a/crates/cuda-backend/src/cuda/kernels.rs +++ b/crates/cuda-backend/src/cuda/kernels.rs @@ -148,6 +148,22 @@ pub mod lde { in_size, )) } + + pub unsafe fn raw_batch_expand_pad( + output: *mut T, + input: *const T, + poly_count: u32, + out_size: u32, + in_size: u32, + ) -> Result<(), CudaError> { + CudaError::from_result(_batch_expand_pad( + output as *mut std::ffi::c_void, + input as *const std::ffi::c_void, + poly_count, + out_size, + in_size, + )) + } } // relate to poseidon2.cu diff --git a/crates/cuda-backend/src/cuda/mod.rs b/crates/cuda-backend/src/cuda/mod.rs index ca53984e..824ffe5e 100644 --- a/crates/cuda-backend/src/cuda/mod.rs +++ b/crates/cuda-backend/src/cuda/mod.rs @@ -1,3 +1,3 @@ #![allow(clippy::missing_safety_doc)] -pub(crate) mod kernels; -pub(crate) mod ntt; +pub mod kernels; +pub mod ntt; diff --git a/crates/cuda-backend/src/cuda/ntt.rs b/crates/cuda-backend/src/cuda/ntt.rs index 4cffdeb3..b4008d8d 100644 --- a/crates/cuda-backend/src/cuda/ntt.rs +++ b/crates/cuda-backend/src/cuda/ntt.rs @@ -3,6 +3,8 @@ use openvm_cuda_common::{d_buffer::DeviceBuffer, error::CudaError}; +use crate::prelude::{EF, F}; + // relate to supra/ntt_params.cu extern "C" { fn _generate_all_twiddles(twiddles: *mut std::ffi::c_void, inverse: bool) -> i32; @@ -35,9 +37,25 @@ extern "C" { padded_poly_size: u32, poly_count: u32, ) -> i32; + + fn _bit_rev_ext( + d_out: *mut std::ffi::c_void, + d_inp: *const std::ffi::c_void, + lg_domain_size: u32, + padded_poly_size: u32, + poly_count: u32, + ) -> i32; + + fn _bit_rev_frac_ext( + d_out: *mut std::ffi::c_void, + d_inp: *const std::ffi::c_void, + lg_domain_size: u32, + padded_poly_size: u32, + poly_count: u32, + ) -> i32; } -pub unsafe fn bit_rev( +pub unsafe fn bit_rev( d_out: &DeviceBuffer, d_inp: &DeviceBuffer, lg_domain_size: u32, @@ -53,6 +71,38 @@ pub unsafe fn bit_rev( )) } +pub unsafe fn bit_rev_ext( + d_out: &DeviceBuffer, + d_inp: &DeviceBuffer, + lg_domain_size: u32, + padded_poly_size: u32, + poly_count: u32, +) -> Result<(), CudaError> { + CudaError::from_result(_bit_rev_ext( + d_out.as_mut_raw_ptr(), + d_inp.as_raw_ptr(), + lg_domain_size, + padded_poly_size, + poly_count, + )) +} + +pub unsafe fn bit_rev_frac_ext( + d_out: &DeviceBuffer<(EF, EF)>, + d_inp: &DeviceBuffer<(EF, EF)>, + lg_domain_size: u32, + padded_poly_size: u32, + poly_count: u32, +) -> Result<(), CudaError> { + CudaError::from_result(_bit_rev_frac_ext( + d_out.as_mut_raw_ptr(), + d_inp.as_raw_ptr(), + lg_domain_size, + padded_poly_size, + poly_count, + )) +} + // relate to supra/ntt.cu extern "C" { fn _ct_mixed_radix_narrow( @@ -67,7 +117,7 @@ extern "C" { ) -> i32; } -pub unsafe fn ct_mixed_radix_narrow( +pub unsafe fn ct_mixed_radix_narrow( d_inout: &DeviceBuffer, radix: u32, lg_domain_size: u32, diff --git a/crates/cuda-backend/src/engine.rs b/crates/cuda-backend/src/engine.rs index 042a7cf3..80690c14 100644 --- a/crates/cuda-backend/src/engine.rs +++ b/crates/cuda-backend/src/engine.rs @@ -44,7 +44,7 @@ impl StarkFriEngine for GpuBabyBearPoseidon2Engine { Self { device: GpuDevice::new( GpuConfig::new(fri_params, BabyBear::GENERATOR), - Some(FriLogUpPhaseGpu::new(log_up_params.clone())), + Some(FriLogUpPhaseGpu::new(log_up_params)), ), config: config_from_perm( &perm, diff --git a/crates/cuda-backend/src/fri_log_up.rs b/crates/cuda-backend/src/fri_log_up.rs index 4d0872ac..d3b3f918 100644 --- 
a/crates/cuda-backend/src/fri_log_up.rs +++ b/crates/cuda-backend/src/fri_log_up.rs @@ -64,7 +64,7 @@ impl FriLogUpPhaseGpu { return None; } - let logup_pow_witness = challenger.grind(self.log_up_params.log_up_pow_bits); + let logup_pow_witness = challenger.grind(self.log_up_params.pow_bits); let challenges: [EF; STARK_LU_NUM_CHALLENGES] = array::from_fn(|_| challenger.sample_ext_element::()); diff --git a/crates/cuda-backend/src/lde/mod.rs b/crates/cuda-backend/src/lde/mod.rs index 8ba44e6c..a548e2de 100644 --- a/crates/cuda-backend/src/lde/mod.rs +++ b/crates/cuda-backend/src/lde/mod.rs @@ -4,7 +4,7 @@ use crate::{base::DeviceMatrix, prelude::F}; mod ops; use ops::*; -mod ntt; +pub(super) mod ntt; /// The top-level LDE abstraction, composed of general matrix access (dimensions), /// trace access, and LDE behavior (which varies by mode). diff --git a/crates/cuda-backend/src/lde/ntt.rs b/crates/cuda-backend/src/lde/ntt.rs index 0b68b182..ec7e0a55 100644 --- a/crates/cuda-backend/src/lde/ntt.rs +++ b/crates/cuda-backend/src/lde/ntt.rs @@ -77,7 +77,14 @@ impl<'a> NttImpl<'a> { } } -pub(super) fn batch_ntt( +/// Performs column-wise batch NTT on `buffer`, where `buffer` is assumed to be column-major with +/// columns of height `2^(log_trace_height + log_blowup)`. The NTT are performed on the first +/// `2^log_trace_height` elements of each column. If `bit_reverse` is true, then the input columns +/// are assumed to be ordered in **natural** ordering, and a bit-reversal permutation is applied for +/// the internal algorithm of the NTT. If `bit_reverse` is false, then the input columns are assumed +/// to be in bit-reverse ordering. If `is_intt` is true, the inverse NTT is performed; otherwise, +/// the forward NTT is performed. +pub fn batch_ntt( buffer: &DeviceBuffer, log_trace_height: u32, log_blowup: u32, diff --git a/crates/cuda-backend/src/lib.rs b/crates/cuda-backend/src/lib.rs index 96f0d84d..5874458c 100644 --- a/crates/cuda-backend/src/lib.rs +++ b/crates/cuda-backend/src/lib.rs @@ -7,9 +7,12 @@ mod lde; mod merkle_tree; mod opener; mod quotient; -mod transpiler; +pub mod transpiler; pub mod types; +pub mod ntt { + pub use crate::lde::ntt::batch_ntt; +} pub mod prelude { pub use crate::types::prelude::*; } diff --git a/crates/cuda-backend/src/transpiler/mod.rs b/crates/cuda-backend/src/transpiler/mod.rs index 7f0078f7..8e5ac907 100644 --- a/crates/cuda-backend/src/transpiler/mod.rs +++ b/crates/cuda-backend/src/transpiler/mod.rs @@ -11,7 +11,7 @@ use p3_field::{Field, PrimeField32}; use rustc_hash::FxHashMap; use tracing::instrument; -pub(crate) mod codec; +pub mod codec; #[derive(Clone, Debug, Eq, PartialEq)] pub enum Source { @@ -136,15 +136,20 @@ impl SymbolicRulesOnGpu { // This should be a list of constraint indices, but we initialize it to be a list of DAG // node indices for now. We'll remap each entry later on. let used_nodes = if dag.constraints.constraint_idx.is_empty() { - // This branch is used only for encoding SymbolicInteractions for use during perm - // trace generation. The `message` is always a single expression for the denominator. - dag.interactions - .iter() - .flat_map(|i| { - assert_eq!(i.message.len(), 1); - [i.count, *i.message.first().unwrap()] - }) - .collect::>() + if is_permute { + // This branch is used only for encoding SymbolicInteractions for use during perm + // trace generation. The `message` is always a single expression for the + // denominator. 
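+                // Illustrative shape: for interactions [(c0, [m0]), (c1, [m1])]
+                // this collects used_nodes = [c0, m0, c1, m1], i.e. flattened
+                // (count, denominator) DAG-node index pairs in interaction order.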
+ dag.interactions + .iter() + .flat_map(|i| { + assert_eq!(i.message.len(), 1); + [i.count, *i.message.first().unwrap()] + }) + .collect::>() + } else { + vec![] + } } else { dag.constraints.constraint_idx.clone() }; diff --git a/crates/cuda-backend/tests/ntt_roundtrip.rs b/crates/cuda-backend/tests/ntt_roundtrip.rs new file mode 100644 index 00000000..cac53ab3 --- /dev/null +++ b/crates/cuda-backend/tests/ntt_roundtrip.rs @@ -0,0 +1,29 @@ +use openvm_cuda_backend::{ntt::batch_ntt, prelude::F}; +use openvm_cuda_common::{ + copy::{MemCopyD2H, MemCopyH2D}, + d_buffer::DeviceBuffer, +}; +use p3_field::FieldAlgebra; + +#[test] +#[ignore] // Run explicitly: requires large GPU memory and significant runtime. +fn ntt_roundtrip_max_log_domain_size() { + const LOG_N: u32 = 27; + let n = 1usize << LOG_N; + let mut host = Vec::::with_capacity(n); + + for i in 0..n { + host.push(F::from_canonical_usize(i)); + } + + let mut device = DeviceBuffer::::with_capacity(n); + host.copy_to(&mut device).expect("host->device copy failed"); + + batch_ntt(&device, LOG_N, 0, 1, true, false); + batch_ntt(&device, LOG_N, 0, 1, true, true); + + let output = device.to_host().expect("device->host copy failed"); + for (i, got) in output.iter().enumerate() { + assert_eq!(*got, host[i], "mismatch at index {}", i); + } +} diff --git a/crates/cuda-builder/src/lib.rs b/crates/cuda-builder/src/lib.rs index beac9757..d12e7add 100644 --- a/crates/cuda-builder/src/lib.rs +++ b/crates/cuda-builder/src/lib.rs @@ -10,6 +10,7 @@ pub struct CudaBuilder { library_name: String, cuda_arch: Vec, cuda_opt_level: Option, + lineinfo: bool, custom_flags: Vec, link_libraries: Vec, link_search_paths: Vec, @@ -32,6 +33,7 @@ impl Default for CudaBuilder { library_name: String::new(), cuda_arch: Vec::new(), cuda_opt_level: None, + lineinfo: false, custom_flags: vec![ "--std=c++17".to_string(), "--expt-relaxed-constexpr".to_string(), @@ -132,6 +134,12 @@ impl CudaBuilder { self } + /// Enable line info for profiling (NCU). Keeps other optimizations intact. 
+ pub fn lineinfo(mut self, enabled: bool) -> Self { + self.lineinfo = enabled; + self + } + /// Add custom compiler flag pub fn flag(mut self, flag: &str) -> Self { self.custom_flags.push(flag.to_string()); @@ -215,6 +223,11 @@ impl CudaBuilder { .flag(format!("--ptxas-options=-O{}", cuda_opt_level)); } + // Add line info for profiling (NCU) if enabled + if self.get_lineinfo() { + builder.flag("-lineinfo"); + } + // Add source files for file in &self.source_files { builder.file(file); @@ -265,6 +278,7 @@ impl CudaBuilder { println!("cargo:rerun-if-env-changed=CUDA_ARCH"); println!("cargo:rerun-if-env-changed=CUDA_OPT_LEVEL"); println!("cargo:rerun-if-env-changed=CUDA_DEBUG"); + println!("cargo:rerun-if-env-changed=CUDA_LINEINFO"); println!("cargo:rerun-if-env-changed=NVCC_THREADS"); // Watch specific paths @@ -304,6 +318,10 @@ impl CudaBuilder { env::var("CUDA_OPT_LEVEL").unwrap_or_else(|_| "3".to_string()) } + fn get_lineinfo(&self) -> bool { + self.lineinfo || env::var("CUDA_LINEINFO").map(|v| v == "1").unwrap_or(false) + } + fn handle_debug_shortcuts(&self, builder: &mut cc::Build) { if env::var("CUDA_DEBUG").map(|v| v == "1").unwrap_or(false) { env::set_var("CUDA_OPT_LEVEL", "0"); diff --git a/crates/cuda-common/Cargo.toml b/crates/cuda-common/Cargo.toml index 21360d78..273e8038 100644 --- a/crates/cuda-common/Cargo.toml +++ b/crates/cuda-common/Cargo.toml @@ -21,4 +21,4 @@ openvm-cuda-builder.workspace = true tokio = { workspace = true, features = ["full"] } [features] -touchemall = [] \ No newline at end of file +touchemall = [] diff --git a/crates/cuda-common/include/poseidon2.cuh b/crates/cuda-common/include/poseidon2.cuh index 6e7dd2d7..73828868 100644 --- a/crates/cuda-common/include/poseidon2.cuh +++ b/crates/cuda-common/include/poseidon2.cuh @@ -7,6 +7,7 @@ #pragma once #include "fp.h" +#include namespace poseidon2 { @@ -65,7 +66,6 @@ static __device__ __constant__ Fp internal_diag16[16] = { 15 // -1/2^27 }; - #define CELLS 16 #define CELLS_RATE 8 #define CELLS_OUT 8 @@ -87,9 +87,7 @@ static __device__ void do_full_sboxes(Fp *cells) { } } -static __device__ void do_partial_sboxes(Fp *cells) { - cells[0] = sbox_d7(cells[0]); -} +static __device__ void do_partial_sboxes(Fp *cells) { cells[0] = sbox_d7(cells[0]); } // Plonky3 version // Multiply a 4-element vector x by: @@ -104,8 +102,8 @@ static __device__ void multiply_by_4x4_circulant(Fp *x) { Fp t01123 = t0123 + x[1]; Fp t01233 = t0123 + x[3]; - x[3] = t01233 + Fp(2) * x[0]; - x[1] = t01123 + Fp(2) * x[2]; + x[3] = t01233 + x[0].doubled(); + x[1] = t01123 + x[2].doubled(); x[0] = t01123 + t01; x[2] = t01233 + t23; } @@ -113,10 +111,6 @@ static __device__ void multiply_by_4x4_circulant(Fp *x) { static __device__ void multiply_by_m_ext(Fp *old_cells) { // Optimized method for multiplication by M_EXT. // See appendix B of Poseidon2 paper for additional details. 
- Fp cells[CELLS]; - for (uint i = 0; i < CELLS; i++) { - cells[0] = 0; - } Fp tmp_sums[4]; for (uint i = 0; i < 4; i++) { tmp_sums[i] = 0; @@ -124,17 +118,19 @@ static __device__ void multiply_by_m_ext(Fp *old_cells) { for (uint i = 0; i < CELLS / 4; i++) { multiply_by_4x4_circulant(old_cells + i * 4); for (uint j = 0; j < 4; j++) { - Fp to_add = old_cells[i * 4 + j]; - tmp_sums[j] += to_add; - cells[i * 4 + j] += to_add; + tmp_sums[j] += old_cells[i * 4 + j]; } } for (uint i = 0; i < CELLS; i++) { - old_cells[i] = cells[i] + tmp_sums[i % 4]; + old_cells[i] += tmp_sums[i % 4]; } } -static __device__ void add_round_constants_full(const Fp *ROUND_CONSTANTS_PLONKY3, Fp *cells, uint round) { +static __device__ void add_round_constants_full( + const Fp *ROUND_CONSTANTS_PLONKY3, + Fp *cells, + uint round +) { for (uint i = 0; i < CELLS; i++) { cells[i] += ROUND_CONSTANTS_PLONKY3[round * CELLS + i]; } @@ -148,12 +144,29 @@ static __device__ void add_round_constants_partial( cells[0] += PARTIAL_ROUND_CONSTANTS_PLONKY3[round]; } -static __device__ __forceinline__ void internal_layer_mat_mul(Fp* cells, Fp sum) { - cells[1] += sum; -#pragma unroll - for (int i = 2; i < CELLS; i++) { - cells[i] = sum + cells[i] * internal_diag16[i]; +static __device__ __forceinline__ void internal_layer_mat_mul(Fp *cells) { + Fp part_sum = cells[1]; + for (uint i = 2; i < CELLS; i++) { + part_sum += cells[i]; } + // https://github.com/Plonky3/Plonky3/blob/ecc66909d17f37a4d626d4601dc1420742571630/baby-bear/src/poseidon2.rs#L219 + Fp sum = part_sum + cells[0]; + cells[0] = part_sum - cells[0]; // -2 + cells[1] += sum; // 1 + cells[2] = sum + cells[2].doubled(); // 2 + cells[3] = sum + cells[3].halve(); // 1/2 + cells[4] = sum + cells[4].doubled() + cells[4]; // 3 + cells[5] = sum + cells[5].doubled().doubled(); // 4 + cells[6] = sum - cells[6].halve(); // -1/2 + cells[7] = sum - (cells[7].doubled() + cells[7]); // -3 + cells[8] = sum - cells[8].doubled().doubled(); // -4 + cells[9] = sum + cells[9] * internal_diag16[9]; // 1/2^8 + cells[10] = sum + cells[10].halve().halve(); // 1/4 + cells[11] = sum + cells[11] * internal_diag16[11]; // 1/8 + cells[12] = sum + cells[12] * internal_diag16[12]; // 1/2^27 + cells[13] = sum + cells[13] * internal_diag16[13]; // -1/2^8 + cells[14] = sum + cells[14] * internal_diag16[14]; // -1/16 + cells[15] = sum + cells[15] * internal_diag16[15]; // -1/2^27 } static __device__ void full_round_half(const Fp *ROUND_CONSTANTS, Fp *cells, uint round) { @@ -165,13 +178,7 @@ static __device__ void full_round_half(const Fp *ROUND_CONSTANTS, Fp *cells, uin static __device__ void partial_round(const Fp *PARTIAL_ROUND_CONSTANTS, Fp *cells, uint round) { add_round_constants_partial(PARTIAL_ROUND_CONSTANTS, cells, round); do_partial_sboxes(cells); - Fp part_sum = Fp(0); - for (uint i = 1; i < CELLS; i++) { - part_sum += cells[i]; - } - Fp full_sum = part_sum + cells[0]; - cells[0] = part_sum - cells[0]; - internal_layer_mat_mul(cells, full_sum); + internal_layer_mat_mul(cells); } static __device__ void poseidon2_mix(Fp *cells) { diff --git a/crates/cuda-common/src/d_buffer.rs b/crates/cuda-common/src/d_buffer.rs index 6b7ccde4..4d823c46 100644 --- a/crates/cuda-common/src/d_buffer.rs +++ b/crates/cuda-common/src/d_buffer.rs @@ -9,9 +9,13 @@ use crate::{ #[link(name = "cudart")] extern "C" { - fn cudaMemsetAsync(dst: *mut c_void, value: i32, count: usize, stream: cudaStream_t) -> i32; + pub fn cudaMemsetAsync(dst: *mut c_void, value: i32, count: usize, stream: cudaStream_t) + -> i32; } +/// Struct 
that owns a buffer allocated on GPU device. The struct only holds the raw pointer and +/// length, but this struct has a `Drop` implementation which frees the associated device memory. +#[repr(C)] pub struct DeviceBuffer { ptr: *mut T, len: usize, @@ -42,6 +46,16 @@ impl DeviceBuffer { } } + /// # Safety + /// - The caller must ensure that the pointer `ptr` is valid for `len` elements of type `T` in + /// device memory. + /// - Dropping the constructed buffer will attempt to free the memory. As such, `ptr` must + /// either have been allocated by the internal memory manager (VPMM) or the caller must use + /// `ManuallyDrop` to prevent double-free. + pub unsafe fn from_raw_parts(ptr: *mut T, len: usize) -> Self { + DeviceBuffer { ptr, len } + } + /// Allocate device memory for `len` elements of type `T`. pub fn with_capacity(len: usize) -> Self { tracing::debug!( diff --git a/crates/cuda-common/src/memory_manager/mod.rs b/crates/cuda-common/src/memory_manager/mod.rs index a57075e8..488912b0 100644 --- a/crates/cuda-common/src/memory_manager/mod.rs +++ b/crates/cuda-common/src/memory_manager/mod.rs @@ -3,6 +3,7 @@ use std::{ ffi::c_void, ptr::NonNull, sync::{Mutex, OnceLock}, + time::{SystemTime, UNIX_EPOCH}, }; use bytesize::ByteSize; @@ -23,10 +24,18 @@ mod tests; extern "C" { fn cudaMallocAsync(dev_ptr: *mut *mut c_void, size: usize, stream: cudaStream_t) -> i32; fn cudaFreeAsync(dev_ptr: *mut c_void, stream: cudaStream_t) -> i32; + fn cudaMemGetInfo(free: *mut usize, total: *mut usize) -> i32; } static MEMORY_MANAGER: OnceLock> = OnceLock::new(); +pub fn device_memory_used() -> usize { + let mut free = 0usize; + let mut total = 0usize; + unsafe { cudaMemGetInfo(&mut free, &mut total) }; + total - free +} + #[ctor::ctor] fn init() { let _ = MEMORY_MANAGER.set(Mutex::new(MemoryManager::new())); @@ -156,11 +165,42 @@ impl MemTracker { .get() .and_then(|m| m.lock().ok()) .map(|m| m.current_size) - .unwrap_or(0); + .unwrap_or_default(); Self { current, label } } + pub fn start_and_reset_peak(label: &'static str) -> Self { + let mut mem = Self::start(label); + mem.reset_peak(); + mem + } + + pub fn emit_metrics(&self) { + self.emit_metrics_with_label(self.label); + } + + pub fn emit_metrics_with_label(&self, label: &'static str) { + let Some(manager) = MEMORY_MANAGER.get().and_then(|m| m.lock().ok()) else { + return; + }; + + let ts = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs_f64() + * 1000.0; + let current = manager.current_size; + // local_peak is local maximum memory size, as observed by the manager, since the last + // reset_peak call + let local_peak = manager.max_used_size; + let reserved = manager.pool.memory_usage(); + metrics::gauge!("gpu_mem.timestamp_ms", "module" => label).set(ts); + metrics::gauge!("gpu_mem.current_bytes", "module" => label).set(current as f64); + metrics::gauge!("gpu_mem.local_peak_bytes", "module" => label).set(local_peak as f64); + metrics::gauge!("gpu_mem.reserved_bytes", "module" => label).set(reserved as f64); + } + #[inline] pub fn tracing_info(&self, msg: impl Into>) { let Some(manager) = MEMORY_MANAGER.get().and_then(|m| m.lock().ok()) else { diff --git a/crates/stark-backend-v2/Cargo.toml b/crates/stark-backend-v2/Cargo.toml new file mode 100644 index 00000000..49f73963 --- /dev/null +++ b/crates/stark-backend-v2/Cargo.toml @@ -0,0 +1,49 @@ +[package] +name = "stark-backend-v2" +version.workspace = true +authors.workspace = true +edition.workspace = true +description = "Proof system v2" + +[dependencies] 
+openvm-stark-backend.workspace = true
+openvm-stark-sdk.workspace = true
+codec-derive.workspace = true
+
+p3-air.workspace = true
+p3-baby-bear.workspace = true
+p3-challenger.workspace = true
+p3-dft.workspace = true
+p3-field.workspace = true
+p3-interpolation.workspace = true
+p3-matrix.workspace = true
+p3-maybe-rayon.workspace = true
+p3-poseidon2.workspace = true
+p3-symmetric.workspace = true
+p3-util.workspace = true
+
+itertools.workspace = true
+rayon.workspace = true
+hex-literal.workspace = true
+serde.workspace = true
+derive-new.workspace = true
+tracing.workspace = true
+getset.workspace = true
+cfg-if.workspace = true
+thiserror.workspace = true
+eyre.workspace = true
+derivative.workspace = true
+bitcode.workspace = true
+metrics = { workspace = true, optional = true }
+
+[dev-dependencies]
+rand = "0.9" # don't use the workspace version, which is older
+p3-keccak-air.workspace = true
+test-case.workspace = true
+
+[features]
+default = ["parallel", "metrics"]
+parallel = ["p3-maybe-rayon/parallel", "openvm-stark-backend/parallel"]
+jemalloc = ["openvm-stark-backend/jemalloc"]
+metrics = ["dep:metrics", "openvm-stark-backend/metrics"]
+test-utils = []
diff --git a/crates/stark-backend-v2/README.md b/crates/stark-backend-v2/README.md
new file mode 100644
index 00000000..07a39792
--- /dev/null
+++ b/crates/stark-backend-v2/README.md
@@ -0,0 +1,7 @@
+Reference implementation only.
+
+Either this crate will be moved into `openvm-stark-backend`, or a separate performance-oriented implementation will be added directly to `openvm-stark-backend`.
+
+# References
+
+- https://github.com/starkware-libs/stwo
diff --git a/crates/stark-backend-v2/derive/Cargo.toml b/crates/stark-backend-v2/derive/Cargo.toml
new file mode 100644
index 00000000..6f4e912c
--- /dev/null
+++ b/crates/stark-backend-v2/derive/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "codec-derive"
+description = "Procedural macros for encoding and decoding."
+version.workspace = true +authors.workspace = true +edition.workspace = true + +[lib] +proc-macro = true + +[dependencies] +quote = "1.0" +syn = { version = "2.0", features = ["parsing", "full"] } +proc-macro-crate = "1.2" +proc-macro2 = "1.0" diff --git a/crates/stark-backend-v2/derive/src/lib.rs b/crates/stark-backend-v2/derive/src/lib.rs new file mode 100644 index 00000000..f547b240 --- /dev/null +++ b/crates/stark-backend-v2/derive/src/lib.rs @@ -0,0 +1,152 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, Data, DeriveInput, Fields}; + +fn codec_crate_root() -> proc_macro2::TokenStream { + match proc_macro_crate::crate_name("stark-backend-v2") { + Ok(proc_macro_crate::FoundCrate::Itself) => quote!(crate), + Ok(proc_macro_crate::FoundCrate::Name(name)) => { + let ident = format_ident!("{}", name); + quote!(::#ident) + } + Err(_) => quote!(::stark_backend_v2), + } +} + +#[proc_macro_derive(Encode)] +pub fn encode_derive(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + let name = &ast.ident; + let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl(); + let codec_root = codec_crate_root(); + + let fields = match &ast.data { + Data::Struct(data_struct) => &data_struct.fields, + Data::Enum(_) => { + return syn::Error::new( + name.span(), + "Encode derive macro only supports structs, not enums", + ) + .to_compile_error() + .into(); + } + Data::Union(_) => { + return syn::Error::new( + name.span(), + "Encode derive macro only supports structs, not unions", + ) + .to_compile_error() + .into(); + } + }; + + let encode_fields = match fields { + Fields::Named(fields_named) => { + let field_encodes = fields_named.named.iter().map(|field| { + let field_name = &field.ident; + quote! { + self.#field_name.encode(writer)?; + } + }); + quote! { + #(#field_encodes)* + } + } + Fields::Unnamed(fields_unnamed) => { + let field_encodes = fields_unnamed.unnamed.iter().enumerate().map(|(idx, _)| { + let index = syn::Index::from(idx); + quote! { + self.#index.encode(writer)?; + } + }); + quote! { + #(#field_encodes)* + } + } + Fields::Unit => { + quote! {} + } + }; + + let expanded = quote! { + impl #impl_generics #codec_root::codec::Encode for #name #type_generics #where_clause { + fn encode(&self, writer: &mut W) -> std::io::Result<()> { + #encode_fields + Ok(()) + } + } + }; + + TokenStream::from(expanded) +} + +#[proc_macro_derive(Decode)] +pub fn decode_derive(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + let name = &ast.ident; + let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl(); + let codec_root = codec_crate_root(); + + let fields = match &ast.data { + Data::Struct(data_struct) => &data_struct.fields, + Data::Enum(_) => { + return syn::Error::new( + name.span(), + "Decode derive macro only supports structs, not enums", + ) + .to_compile_error() + .into(); + } + Data::Union(_) => { + return syn::Error::new( + name.span(), + "Decode derive macro only supports structs, not unions", + ) + .to_compile_error() + .into(); + } + }; + + let decode_fields = match fields { + Fields::Named(fields_named) => { + let field_decodes = fields_named.named.iter().map(|field| { + let field_name = &field.ident; + let field_ty = &field.ty; + quote! { + #field_name: <#field_ty as #codec_root::codec::Decode>::decode(reader)?, + } + }); + quote! 
{ + { + #(#field_decodes)* + } + } + } + Fields::Unnamed(fields_unnamed) => { + let field_decodes = fields_unnamed.unnamed.iter().map(|field| { + let field_ty = &field.ty; + quote! { + <#field_ty as #codec_root::codec::Decode>::decode(reader)?, + } + }); + quote! { + ( + #(#field_decodes)* + ) + } + } + Fields::Unit => { + quote! {} + } + }; + + let expanded = quote! { + impl #impl_generics #codec_root::codec::Decode for #name #type_generics #where_clause { + fn decode(reader: &mut R) -> std::io::Result { + Ok(Self #decode_fields) + } + } + }; + + TokenStream::from(expanded) +} diff --git a/crates/stark-backend-v2/examples/keccakf.rs b/crates/stark-backend-v2/examples/keccakf.rs new file mode 100644 index 00000000..cd8930b8 --- /dev/null +++ b/crates/stark-backend-v2/examples/keccakf.rs @@ -0,0 +1,90 @@ +//! Prove keccakf-air over BabyBear using poseidon2 for FRI hash. + +use std::sync::Arc; + +use eyre::eyre; +use openvm_stark_backend::{ + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::Field, + prover::types::AirProvingContext, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; +use openvm_stark_sdk::{ + bench::run_with_metric_collection, + config::log_up_params::log_up_security_params_baby_bear_100_bits, +}; +use p3_baby_bear::BabyBear; +use p3_keccak_air::KeccakAir; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use stark_backend_v2::{ + poseidon2::sponge::DuplexSponge, + prover::{AirProvingContextV2, DeviceDataTransporterV2, ProvingContextV2}, + verifier::verify, + BabyBearPoseidon2CpuEngineV2, StarkEngineV2, SystemParams, WhirConfig, WhirParams, +}; +use tracing::info_span; + +const NUM_PERMUTATIONS: usize = 1 << 10; + +// Newtype to implement extended traits +struct TestAir(KeccakAir); + +impl BaseAir for TestAir { + fn width(&self) -> usize { + BaseAir::::width(&self.0) + } +} +impl BaseAirWithPublicValues for TestAir {} +impl PartitionedBaseAir for TestAir {} + +impl Air for TestAir { + fn eval(&self, builder: &mut AB) { + self.0.eval(builder); + } +} + +fn main() -> eyre::Result<()> { + let l_skip = 4; + let n_stack = 17; + let k_whir = 4; + let whir_params = WhirParams { + k: k_whir, + log_final_poly_len: 2 * k_whir, + query_phase_pow_bits: 20, + }; + let log_blowup = 1; + let whir = WhirConfig::new(log_blowup, l_skip + n_stack, whir_params, 100); + let params = SystemParams { + l_skip, + n_stack, + log_blowup, + whir, + logup: log_up_security_params_baby_bear_100_bits(), + max_constraint_degree: 3, + }; + + run_with_metric_collection("OUTPUT_PATH", || -> eyre::Result<()> { + let mut rng = StdRng::seed_from_u64(42); + let air = TestAir(KeccakAir {}); + + let engine = BabyBearPoseidon2CpuEngineV2::::new(params.clone()); + let (pk, vk) = engine.keygen(&[Arc::new(air)]); + let air_idx = 0; + + let inputs = (0..NUM_PERMUTATIONS) + .map(|_| rng.random()) + .collect::>(); + let trace = info_span!("generate_trace") + .in_scope(|| p3_keccak_air::generate_trace_rows::(inputs, 0)); + + let air_ctx = AirProvingContextV2::from_v1( + ¶ms, + AirProvingContext::simple_no_pis(Arc::new(trace)), + ); + let d_pk = engine.device().transport_pk_to_device(&pk); + let proof = engine.prove(&d_pk, ProvingContextV2::new(vec![(air_idx, air_ctx)])); + + verify(&vk, &proof, &mut DuplexSponge::default()) + .map_err(|e| eyre!("Proof failed to verify: {e}")) + }) +} diff --git a/crates/stark-backend-v2/profile.sh b/crates/stark-backend-v2/profile.sh new file mode 100755 index 00000000..f2dd5c28 --- /dev/null +++ b/crates/stark-backend-v2/profile.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +# This script profiles 
an example binary for STARK proving +eg_name=$1 + +git_root=$(git rev-parse --show-toplevel) +cd $git_root/crates/backend + +arch=$(uname -m) +case $arch in + arm64|aarch64) + export RUSTFLAGS="-Ctarget-cpu=native -g -C force-frame-pointers=yes" + ;; + x86_64|amd64) + export RUSTFLAGS="-Ctarget-cpu=native -C target-feature=+avx512f -g -C force-frame-pointers=yes" + ;; + *) + echo "Unsupported architecture: $arch" + exit 1 + ;; +esac + +export JEMALLOC_SYS_WITH_MALLOC_CONF="retain:true,background_thread:true,metadata_thp:always,dirty_decay_ms:-1,muzzy_decay_ms:-1,abort_conf:true" + +cargo build --profile=profiling --example $eg_name --no-default-features --features=jemalloc,parallel + +# Check if samply is installed +if ! command -v samply &> /dev/null; then + echo "samply not found. Installing..." + cargo install samply +else + echo "samply is already installed" +fi + + +if command -v perf &> /dev/null && [[ "$(uname -s)" == "Linux" ]]; then + perf record -F 100 --call-graph=fp -g -o perf.data -- $git_root/target/profiling/examples/$eg_name + samply import perf.data +else + samply record $git_root/target/profiling/examples/$eg_name +fi diff --git a/crates/stark-backend-v2/src/chip.rs b/crates/stark-backend-v2/src/chip.rs new file mode 100644 index 00000000..afd4505d --- /dev/null +++ b/crates/stark-backend-v2/src/chip.rs @@ -0,0 +1,47 @@ +// Copied from v1 +// TODO[jpw]: remove duplication +use std::any::Any; + +use crate::prover::{AirProvingContextV2, ProverBackendV2}; + +/// A chip is a [ProverBackend]-specific object that converts execution logs (also referred to as +/// records) into a trace matrix. +/// +/// A chip may be stateful and store state on either host or device, although it is preferred that +/// all state is received through records. +pub trait ChipV2 { + /// Generate all necessary context for proving a single AIR. + fn generate_proving_ctx(&self, records: R) -> AirProvingContextV2; +} + +/// Auto-implemented trait for downcasting of trait objects. +pub trait AnyChip: ChipV2 { + fn as_any(&self) -> &dyn Any; +} + +impl + 'static> AnyChip for C { + fn as_any(&self) -> &dyn Any { + self + } +} + +// impl> ChipV2 for RefCell { +// fn generate_proving_ctx(&self, records: R) -> AirProvingContextV2 { +// self.borrow().generate_proving_ctx(records) +// } +// } +// impl> ChipV2 for Rc { +// fn generate_proving_ctx(&self, records: R) -> AirProvingContextV2 { +// self.as_ref().generate_proving_ctx(records) +// } +// } +// impl> ChipV2 for Arc { +// fn generate_proving_ctx(&self, records: R) -> AirProvingContextV2 { +// self.as_ref().generate_proving_ctx(records) +// } +// } +// impl> ChipV2 for Mutex { +// fn generate_proving_ctx(&self, records: R) -> AirProvingContextV2 { +// self.lock().unwrap().generate_proving_ctx(records) +// } +// } diff --git a/crates/stark-backend-v2/src/codec.rs b/crates/stark-backend-v2/src/codec.rs new file mode 100644 index 00000000..d9525ca0 --- /dev/null +++ b/crates/stark-backend-v2/src/codec.rs @@ -0,0 +1,243 @@ +use std::{ + array::from_fn, + io::{self, Cursor, Read, Result, Write}, +}; + +pub use codec_derive::{Decode, Encode}; +use p3_field::{FieldAlgebra, FieldExtensionAlgebra, PrimeField32}; + +use crate::{D_EF, EF, F}; + +/// Hardware and language independent encoding. +/// Uses the Writer pattern for more efficient encoding without intermediate buffers. +// @dev Trait just for implementation sanity +pub trait Encode { + /// Writes the encoded representation of `self` to the given writer. 
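+    ///
+    /// In this crate's implementations, integers are written little-endian
+    /// (e.g. `u32` as four LE bytes) and variable-length collections are
+    /// length-prefixed (see `encode_slice`).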
+ fn encode(&self, writer: &mut W) -> Result<()>; + + /// Convenience method to encode into a `Vec` + fn encode_to_vec(&self) -> Result> { + let mut buffer = Vec::new(); + self.encode(&mut buffer)?; + Ok(buffer) + } +} + +/// Hardware and language independent decoding. +/// Uses the Reader pattern for efficient decoding. +pub trait Decode: Sized { + /// Reads and decodes a value from the given reader. + fn decode(reader: &mut R) -> Result; + fn decode_from_bytes(bytes: &[u8]) -> Result { + let mut reader = Cursor::new(bytes); + Self::decode(&mut reader) + } +} + +// ==================== Encode implementations for basic types ==================== + +impl Encode for bool { + fn encode(&self, writer: &mut W) -> Result<()> { + writer.write_all(&[*self as u8])?; + Ok(()) + } +} + +impl Encode for u8 { + fn encode(&self, writer: &mut W) -> Result<()> { + writer.write_all(&[*self]) + } +} + +impl Encode for u32 { + fn encode(&self, writer: &mut W) -> Result<()> { + writer.write_all(&self.to_le_bytes()) + } +} + +impl Encode for usize { + fn encode(&self, writer: &mut W) -> Result<()> { + let x: u32 = (*self).try_into().map_err(io::Error::other)?; + x.encode(writer) + } +} + +impl Encode for F { + fn encode(&self, writer: &mut W) -> Result<()> { + writer.write_all(&self.as_canonical_u32().to_le_bytes()) + } +} + +impl Encode for EF { + fn encode(&self, writer: &mut W) -> Result<()> { + let base_slice: &[F] = self.as_base_slice(); + // Fixed length slice, so don't encode length + for val in base_slice { + val.encode(writer)?; + } + Ok(()) + } +} + +// ==================== Encode helpers ==================== + +/// Encodes length of slice and then each element +pub fn encode_slice(slice: &[T], writer: &mut W) -> Result<()> { + slice.len().encode(writer)?; + for elt in slice { + elt.encode(writer)?; + } + Ok(()) +} + +/// Encodes each element (no length) +pub fn encode_iter<'a, T: Encode + 'a, W: Write>( + iter: impl Iterator, + writer: &mut W, +) -> Result<()> { + for elt in iter { + elt.encode(writer)?; + } + Ok(()) +} + +impl Encode for [T; N] { + fn encode(&self, writer: &mut W) -> Result<()> { + for val in self { + val.encode(writer)?; + } + Ok(()) + } +} + +impl Encode for Vec { + fn encode(&self, writer: &mut W) -> Result<()> { + encode_slice(self, writer) + } +} + +impl Encode for (S, T) { + fn encode(&self, writer: &mut W) -> Result<()> { + self.0.encode(writer)?; + self.1.encode(writer) + } +} + +impl Encode for Option { + fn encode(&self, writer: &mut W) -> Result<()> { + self.is_some().encode(writer)?; + if let Some(val) = self { + val.encode(writer)?; + } + Ok(()) + } +} + +// ==================== Decode implementations for basic types ==================== + +impl Decode for bool { + fn decode(reader: &mut R) -> Result { + let mut bytes = [0u8; 1]; + reader.read_exact(&mut bytes)?; + Ok(bytes[0] != 0) + } +} + +impl Decode for u8 { + fn decode(reader: &mut R) -> Result { + let mut bytes = [0u8; 1]; + reader.read_exact(&mut bytes)?; + Ok(bytes[0]) + } +} + +impl Decode for u32 { + fn decode(reader: &mut R) -> Result { + let mut bytes = [0u8; 4]; + reader.read_exact(&mut bytes)?; + Ok(u32::from_le_bytes(bytes)) + } +} + +impl Decode for usize { + fn decode(reader: &mut R) -> Result { + let val = u32::decode(reader)?; + Ok(val as usize) + } +} + +impl Decode for F { + fn decode(reader: &mut R) -> Result { + let mut bytes = [0u8; 4]; + reader.read_exact(&mut bytes)?; + let value = u32::from_le_bytes(bytes); + if value < F::ORDER_U32 { + Ok(F::from_canonical_u32(value)) + } else { + 
Err(io::Error::other(format!( + "Attempted read of {} into F >= F::ORDER_U32 {}", + value, + F::ORDER_U32 + ))) + } + } +} + +impl Decode for EF { + fn decode(reader: &mut R) -> Result { + let mut base_slice = [F::ZERO; D_EF]; + for val in &mut base_slice { + *val = F::decode(reader)?; + } + Ok(EF::from_base_slice(&base_slice)) + } +} + +// ==================== Decode helpers ==================== + +/// Decodes into a vector given preset length +pub fn decode_into_vec(reader: &mut R, len: usize) -> Result> { + let mut vec = Vec::with_capacity(len); + for _ in 0..len { + vec.push(T::decode(reader)?); + } + Ok(vec) +} + +impl Decode for [T; N] { + fn decode(reader: &mut R) -> Result { + let mut result = from_fn(|_| T::default()); + for val in &mut result { + *val = T::decode(reader)?; + } + Ok(result) + } +} + +impl Decode for Vec { + fn decode(reader: &mut R) -> Result { + let len = usize::decode(reader)?; + let mut vec = Vec::with_capacity(len); + for _ in 0..len { + vec.push(T::decode(reader)?); + } + Ok(vec) + } +} + +impl Decode for (S, T) { + fn decode(reader: &mut R) -> Result { + Ok((S::decode(reader)?, T::decode(reader)?)) + } +} + +impl Decode for Option { + fn decode(reader: &mut R) -> Result { + let is_some = bool::decode(reader)?; + if is_some { + Ok(Some(T::decode(reader)?)) + } else { + Ok(None) + } + } +} diff --git a/crates/stark-backend-v2/src/config.rs b/crates/stark-backend-v2/src/config.rs new file mode 100644 index 00000000..5022b584 --- /dev/null +++ b/crates/stark-backend-v2/src/config.rs @@ -0,0 +1,162 @@ +use getset::Getters; +use openvm_stark_backend::interaction::LogUpSecurityParameters; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Getters)] +pub struct SystemParams { + pub l_skip: usize, + pub n_stack: usize, + /// `-log_2` of the rate for the initial Reed-Solomon code. + pub log_blowup: usize, + #[getset(get = "pub")] + pub whir: WhirConfig, + pub logup: LogUpSecurityParameters, + /// Global max constraint degree enforced across all AIR and Interaction constraints + pub max_constraint_degree: usize, +} + +impl SystemParams { + pub fn logup_pow_bits(&self) -> usize { + self.logup.pow_bits + } + + pub fn k_whir(&self) -> usize { + self.whir.k + } + + #[inline] + pub fn log_stacked_height(&self) -> usize { + self.l_skip + self.n_stack + } + + #[inline] + pub fn log_final_poly_len(&self) -> usize { + self.whir.log_final_poly_len(self.log_stacked_height()) + } + + #[inline] + pub fn num_whir_rounds(&self) -> usize { + self.whir.num_whir_rounds() + } + + #[inline] + pub fn num_whir_sumcheck_rounds(&self) -> usize { + self.whir.num_sumcheck_rounds() + } +} + +/// Configurable parameters that are used to determine the [WhirConfig] for a target security level. +#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct WhirParams { + pub k: usize, + /// WHIR rounds will stop as soon as `log2` of the final polynomial length is `<= + /// log_final_poly_len`. + pub log_final_poly_len: usize, + pub query_phase_pow_bits: usize, +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct WhirConfig { + /// Constant folding factor. This means that `2^k` terms are folded per round. + pub k: usize, + pub rounds: Vec, + /// Number of bits of grinding for the query phase of each WHIR round. + /// The PoW bits can vary per round, but for simplicity we use the same number for all rounds. 
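+    /// In `WhirConfig::new`, these bits are subtracted from the target
+    /// security level before computing per-round query counts
+    /// (`protocol_security_level = security_bits - query_phase_pow_bits`).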
+    pub query_phase_pow_bits: usize,
+    /// Number of bits of grinding before sampling folding randomness in each WHIR round.
+    /// The folding PoW bits can vary per round, but for simplicity (and efficiency of the
+    /// recursion circuit) we use the same number for all rounds.
+    pub folding_pow_bits: usize,
+}
+
+#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
+pub struct WhirRoundConfig {
+    pub num_queries: usize,
+}
+
+/// Defines the soundness type for the proof system.
+#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
+pub enum SoundnessType {
+    /// Unique decoding guarantees a single valid witness.
+    UniqueDecoding,
+}
+
+impl WhirConfig {
+    /// Sets parameters targeting 100 bits of provable security, with grinding, using the unique
+    /// decoding regime.
+    pub fn new(
+        log_blowup: usize,
+        log_stacked_height: usize,
+        whir_params: WhirParams,
+        security_bits: usize,
+    ) -> Self {
+        let query_phase_pow_bits = whir_params.query_phase_pow_bits;
+        let protocol_security_level = security_bits.saturating_sub(query_phase_pow_bits);
+        let k_whir = whir_params.k;
+        let num_rounds = log_stacked_height
+            .saturating_sub(whir_params.log_final_poly_len)
+            .div_ceil(k_whir);
+        let mut log_inv_rate = log_blowup;
+
+        // A safe setting for BabyBear and ~200 columns
+        // TODO[jpw]: use rbr_soundness_queries_combination
+        const FOLDING_POW_BITS: usize = 10;
+
+        let mut round_parameters = Vec::with_capacity(num_rounds);
+        for _round in 0..num_rounds {
+            // Queries are set w.r.t. the old rate, while the rest use the new rate
+            let next_rate = log_inv_rate + (k_whir - 1);
+
+            let num_queries = Self::queries(
+                SoundnessType::UniqueDecoding,
+                protocol_security_level,
+                log_inv_rate,
+            );
+            round_parameters.push(WhirRoundConfig { num_queries });
+
+            log_inv_rate = next_rate;
+        }
+
+        Self {
+            k: k_whir,
+            rounds: round_parameters,
+            query_phase_pow_bits,
+            folding_pow_bits: FOLDING_POW_BITS,
+        }
+    }
+
+    #[inline]
+    pub fn log_final_poly_len(&self, log_stacked_height: usize) -> usize {
+        log_stacked_height - self.num_whir_rounds() * self.k
+    }
+
+    pub fn num_whir_rounds(&self) -> usize {
+        self.rounds.len()
+    }
+
+    #[inline]
+    pub fn num_sumcheck_rounds(&self) -> usize {
+        self.num_whir_rounds() * self.k
+    }
+
+    /// Pure function to calculate the number of queries necessary for a given WHIR round.
+    /// - `protocol_security_level` refers to the target bits of security without grinding.
+    /// - `log_inv_rate` is the log blowup for the WHIR round we want to calculate the number of
+    ///   queries for.
+    // Source: https://github.com/WizardOfMenlo/whir/blob/cf1599b56ff50e09142ebe6d2e2fbd86875c9986/src/whir/parameters.rs#L457
+    pub fn queries(
+        soundness_type: SoundnessType,
+        protocol_security_level: usize,
+        log_inv_rate: usize,
+    ) -> usize {
+        let num_queries_f = match soundness_type {
+            SoundnessType::UniqueDecoding => {
+                let rate = 1. / f64::from(1 << log_inv_rate);
+                let denom = (0.5 * (1. + rate)).log2();
+
+                -(protocol_security_level as f64) / denom
+            }
+        };
+        num_queries_f.ceil() as usize
+    }
+}
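For intuition on `queries` in the unique decoding regime: each query contributes a soundness error of `(1 + rate) / 2`, so the query count is `ceil(-security / log2((1 + rate) / 2))`, which is exactly what the code computes. An illustrative evaluation (numbers not from the changeset):

    // rate = 2^-1, so denom = log2(0.75) ≈ -0.415;
    // 90 bits of protocol security => 90 / 0.415 ≈ 216.85 => 217 queries.
    let q = WhirConfig::queries(SoundnessType::UniqueDecoding, 90, 1);
    assert_eq!(q, 217);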
diff --git a/crates/stark-backend-v2/src/debug.rs b/crates/stark-backend-v2/src/debug.rs
new file mode 100644
index 00000000..112820ed
--- /dev/null
+++ b/crates/stark-backend-v2/src/debug.rs
@@ -0,0 +1,135 @@
+// Merge into debug/mod.rs in v1:
+// - debug_constraints
+// - debug_constraints_and_interactions
+use std::sync::Arc;
+
+use itertools::{izip, Itertools};
+use openvm_stark_backend::{
+    air_builders::{
+        debug::{check_constraints, check_logup, USE_DEBUG_BUILDER},
+        symbolic::SymbolicConstraints,
+    },
+    prover::types::AirProofRawInput,
+    AirRef,
+};
+
+use crate::{
+    keygen::{types::StarkProvingKeyV2, MultiStarkKeygenBuilderV2},
+    prover::{
+        ColMajorMatrix, DeviceDataTransporterV2, ProverBackendV2, ProvingContextV2,
+        StridedColMajorMatrixView,
+    },
+    SystemParams, SC,
+};
+
+// TODO[jpw]: move into StarkEngineV2::debug default implementation after `SC` is made generic.
+/// `airs` should be the full list of all AIRs, not just used AIRs.
+pub fn debug_impl<PB: ProverBackendV2, PD: DeviceDataTransporterV2<PB>>(
+    config: SystemParams,
+    device: &PD,
+    airs: &[AirRef<SC>],
+    ctx: &ProvingContextV2<PB>,
+) {
+    let mut keygen_builder = MultiStarkKeygenBuilderV2::new(config);
+    for air in airs {
+        keygen_builder.add_air(air.clone());
+    }
+    let pk = keygen_builder.generate_pk().unwrap();
+
+    let transpose = |mat: ColMajorMatrix| {
+        let row_major = StridedColMajorMatrixView::from(mat.as_view()).to_row_major_matrix();
+        Arc::new(row_major)
+    };
+    let (inputs, used_airs, used_pks): (Vec<_>, Vec<_>, Vec<_>) = ctx
+        .per_trace
+        .iter()
+        .map(|(air_id, air_ctx)| {
+            // Transfer from device **back** to host so the debugger can read the data.
+            let common_main = device.transport_matrix_from_device_to_host(&air_ctx.common_main);
+            let cached_mains = air_ctx
+                .cached_mains
+                .iter()
+                .map(|cd| transpose(device.transport_matrix_from_device_to_host(&cd.trace)))
+                .collect_vec();
+            let common_main = Some(transpose(common_main));
+            let public_values = air_ctx.public_values.clone();
+            (
+                AirProofRawInput {
+                    cached_mains,
+                    common_main,
+                    public_values,
+                },
+                airs[*air_id].clone(),
+                &pk.per_air[*air_id],
+            )
+        })
+        .multiunzip();
+
+    debug_constraints_and_interactions(&used_airs, &used_pks, &inputs);
+}
+
+/// The debugging will check the main AIR constraints and then separately check LogUp constraints by
+/// checking the actual multiset equalities. Currently it will not debug check any after-challenge
+/// phase constraints, for implementation simplicity.
+#[allow(dead_code)]
+#[allow(clippy::too_many_arguments)]
+pub fn debug_constraints_and_interactions(
+    airs: &[AirRef<SC>],
+    pk: &[&StarkProvingKeyV2],
+    inputs: &[AirProofRawInput],
+) {
+    USE_DEBUG_BUILDER.with(|debug| {
+        if *debug.lock().unwrap() {
+            let (main_parts_per_air, pvs_per_air): (Vec<_>, Vec<_>) = inputs
+                .iter()
+                .map(|input| {
+                    let mut main_parts = input
+                        .cached_mains
+                        .iter()
+                        .map(|trace| trace.as_view())
+                        .collect_vec();
+                    if let Some(trace) = input.common_main.as_ref() {
+                        main_parts.push(trace.as_view());
+                    }
+                    (main_parts, input.public_values.clone())
+                })
+                .unzip();
+            let preprocessed = izip!(airs, pk, &main_parts_per_air, &pvs_per_air)
+                .map(|(air, pk, main_parts, pvs)| {
+                    let preprocessed_trace = pk
+                        .preprocessed_data
+                        .as_ref()
+                        .map(|data| data.mat_view(0).to_row_major_matrix());
+                    tracing::debug!("Checking constraints for {}", air.name());
+                    check_constraints(
+                        air.as_ref(),
+                        &air.name(),
+                        &preprocessed_trace.as_ref().map(|t| t.as_view()),
+                        main_parts,
+                        pvs,
+                    );
+                    preprocessed_trace
+                })
+                .collect_vec();
+
+            let (air_names, interactions): (Vec<_>, Vec<_>) = pk
+                .iter()
+                .map(|pk| {
+                    let sym_constraints = SymbolicConstraints::from(&pk.vk.symbolic_constraints);
+                    (pk.air_name.clone(), sym_constraints.interactions)
+                })
+                .unzip();
+            let preprocessed_views = preprocessed
+                .iter()
+                .map(|t| t.as_ref().map(|t| t.as_view()))
+                .collect_vec();
+            check_logup(
+                &air_names,
+                &interactions,
+                &preprocessed_views,
+                &main_parts_per_air,
+                &pvs_per_air,
+            );
+        }
+    });
+}
diff --git a/crates/stark-backend-v2/src/dft/mod.rs b/crates/stark-backend-v2/src/dft/mod.rs
new file mode 100644
index 00000000..db7868d5
--- /dev/null
+++ b/crates/stark-backend-v2/src/dft/mod.rs
@@ -0,0 +1,3 @@
+mod radix_2_bowers_serial;
+
+pub use radix_2_bowers_serial::*;
diff --git a/crates/stark-backend-v2/src/dft/radix_2_bowers_serial.rs b/crates/stark-backend-v2/src/dft/radix_2_bowers_serial.rs
new file mode 100644
index 00000000..c43d0890
--- /dev/null
+++ b/crates/stark-backend-v2/src/dft/radix_2_bowers_serial.rs
@@ -0,0 +1,182 @@
+use std::borrow::BorrowMut;
+
+// Originally copied from p3-dft [src/radix_2_bowers.rs] to turn off rayon
+use p3_dft::{Butterfly, DifButterfly, DitButterfly, TwiddleFreeButterfly, TwoAdicSubgroupDft};
+use p3_field::{Field, PackedValue, Powers, TwoAdicField};
+use p3_matrix::{
+    dense::{DenseMatrix, DenseStorage, RowMajorMatrix, RowMajorMatrixViewMut},
+    Matrix,
+};
+use p3_util::{log2_strict_usize, reverse_bits, reverse_bits_len, reverse_slice_index_bits};
+use tracing::instrument;
+
+/// The Bowers G FFT algorithm.
+/// See: "Improved Twiddle Access for Fast Fourier Transforms"
+#[derive(Default, Clone)]
+pub struct Radix2BowersSerial;
+
+impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2BowersSerial {
+    type Evaluations = RowMajorMatrix<F>;
+
+    fn dft(&self, vec: Vec<F>) -> Vec<F> {
+        self.dft_batch(RowMajorMatrix::new_col(vec)).values
+    }
+
+    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
+        reverse_matrix_index_bits(&mut mat);
+        bowers_g(&mut mat.as_view_mut());
+        mat
+    }
+
+    /// Compute the inverse DFT of `vec`.
+    fn idft(&self, vec: Vec<F>) -> Vec<F> {
+        self.idft_batch(RowMajorMatrix::new(vec, 1)).values
+    }
+
+    /// Compute the inverse DFT of each column in `mat`.
+    fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
+        bowers_g_t(&mut mat.as_view_mut());
+        divide_by_height(&mut mat);
+        reverse_matrix_index_bits(&mut mat);
+        mat
+    }
+
+    fn lde_batch(&self, mut mat: RowMajorMatrix<F>, added_bits: usize) -> RowMajorMatrix<F> {
+        bowers_g_t(&mut mat.as_view_mut());
+        divide_by_height(&mut mat);
+        mat = mat.bit_reversed_zero_pad(added_bits);
+        bowers_g(&mut mat.as_view_mut());
+        mat
+    }
+
+    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
+    fn coset_lde_batch(
+        &self,
+        mut mat: RowMajorMatrix<F>,
+        added_bits: usize,
+        shift: F,
+    ) -> RowMajorMatrix<F> {
+        let h = mat.height();
+        let h_inv = F::from_canonical_usize(h).inverse();
+
+        bowers_g_t(&mut mat.as_view_mut());
+
+        // Rescale coefficients in two ways:
+        // - divide by height (since we're doing an inverse DFT)
+        // - multiply by powers of the coset shift (see default coset LDE impl for an explanation)
+        let weights = Powers {
+            base: shift,
+            current: h_inv,
+        }
+        .take(h);
+        for (row, weight) in weights.enumerate() {
+            // reverse_bits because mat is encoded in bit-reversed order
+            mat.scale_row(reverse_bits(row, h), weight);
+        }
+
+        mat = mat.bit_reversed_zero_pad(added_bits);
+
+        bowers_g(&mut mat.as_view_mut());
+
+        mat
+    }
+}
+
+/// Executes the Bowers G network. This is like a DFT, except it assumes the input is in
+/// bit-reversed order.
+fn bowers_g<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<F>) {
+    let h = mat.height();
+    let log_h = log2_strict_usize(h);
+
+    let root = F::two_adic_generator(log_h);
+    let mut twiddles: Vec<_> = root.powers().take(h / 2).map(DifButterfly).collect();
+    reverse_slice_index_bits(&mut twiddles);
+
+    for log_half_block_size in 0..log_h {
+        butterfly_layer(mat, 1 << log_half_block_size, &twiddles)
+    }
+}
+
+/// Executes the Bowers G^T network. This is like an inverse DFT, except we skip rescaling by
+/// 1/height, and the output is bit-reversed.
+fn bowers_g_t<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<F>) {
+    let h = mat.height();
+    let log_h = log2_strict_usize(h);
+
+    let root_inv = F::two_adic_generator(log_h).inverse();
+    let mut twiddles: Vec<_> = root_inv.powers().take(h / 2).map(DitButterfly).collect();
+    reverse_slice_index_bits(&mut twiddles);
+
+    for log_half_block_size in (0..log_h).rev() {
+        butterfly_layer(mat, 1 << log_half_block_size, &twiddles)
+    }
+}
+
+fn butterfly_layer<F: Field, B: Butterfly<F>>(
+    mat: &mut RowMajorMatrixViewMut<F>,
+    half_block_size: usize,
+    twiddles: &[B],
+) {
+    mat.row_chunks_exact_mut(2 * half_block_size)
+        .enumerate()
+        .for_each(|(block, mut chunks)| {
+            let (mut hi_chunks, mut lo_chunks) = chunks.split_rows_mut(half_block_size);
+            hi_chunks
+                .rows_mut()
+                .zip(lo_chunks.rows_mut())
+                .for_each(|(hi_chunk, lo_chunk)| {
+                    if block == 0 {
+                        TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk)
+                    } else {
+                        twiddles[block].apply_to_rows(hi_chunk, lo_chunk);
+                    }
+                });
+        });
+}
+
+pub fn divide_by_height<F: Field, S: DenseStorage<F> + BorrowMut<[F]>>(
+    mat: &mut DenseMatrix<F, S>,
+) {
+    scale_slice_in_place(
+        F::from_canonical_usize(mat.height()).inverse(),
+        mat.values.borrow_mut(),
+    );
+}
+
+pub fn scale_slice_in_place<F: Field>(s: F, slice: &mut [F]) {
+    let (packed, sfx) = F::Packing::pack_slice_with_suffix_mut(slice);
+    let packed_s: F::Packing = s.into();
+    packed.iter_mut().for_each(|x| *x *= packed_s);
+    sfx.iter_mut().for_each(|x| *x *= s);
+}
+
+#[instrument(level = "debug", skip_all)]
+pub fn reverse_matrix_index_bits<'a, F, S>(mat: &mut DenseMatrix<F, S>)
+where
+    F: Clone + Send + Sync + 'a,
+    S: DenseStorage<F> + BorrowMut<[F]>,
+{
+    let w = mat.width();
+    let h = mat.height();
+    let log_h = log2_strict_usize(h);
+    let values = mat.values.borrow_mut().as_mut_ptr() as usize;
+
+    (0..h).for_each(|i| {
+        let values = values as *mut F;
+        let j = reverse_bits_len(i, log_h);
+        if i < j {
+            unsafe { swap_rows_raw(values, w, i, j) };
+        }
+    });
+}
+
+/// Assumes `i < j`.
+///
+/// SAFETY: The caller must ensure `i < j < h`, where `h` is the height of the matrix.
+pub(crate) unsafe fn swap_rows_raw<F>(mat: *mut F, w: usize, i: usize, j: usize) {
+    let row_i = core::slice::from_raw_parts_mut(mat.add(i * w), w);
+    let row_j = core::slice::from_raw_parts_mut(mat.add(j * w), w);
+    row_i.swap_with_slice(row_j);
+}
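A quick sanity check for the serial Bowers DFT above (illustrative sketch, not part of the changeset; assumes the crate's `F = BabyBear` alias is in scope):

    use p3_dft::TwoAdicSubgroupDft;
    use p3_field::FieldAlgebra;

    let dft = Radix2BowersSerial;
    let v: Vec<F> = (0..8u32).map(F::from_canonical_u32).collect();
    // dft maps coefficients to evaluations over the size-8 subgroup; idft inverts it.
    assert_eq!(dft.idft(dft.dft(v.clone())), v);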
diff --git a/crates/stark-backend-v2/src/engine.rs b/crates/stark-backend-v2/src/engine.rs
new file mode 100644
index 00000000..50842a2c
--- /dev/null
+++ b/crates/stark-backend-v2/src/engine.rs
@@ -0,0 +1,173 @@
+// Replace engine.rs in v1
+// TODO[jpw]: everything is currently assuming fixed types for:
+// - F, EF, Digest, SystemParams
+// We will make these generic in the future
+
+use std::marker::PhantomData;
+
+use openvm_stark_backend::{config::StarkGenericConfig, prover::Prover, AirRef};
+use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config;
+
+use crate::{
+    debug::debug_impl,
+    keygen::{
+        types::{MultiStarkProvingKeyV2, MultiStarkVerifyingKeyV2},
+        MultiStarkKeygenBuilderV2,
+    },
+    poseidon2::sponge::{DuplexSponge, FiatShamirTranscript},
+    proof::*,
+    prover::{
+        AirProvingContextV2, CoordinatorV2, CpuBackendV2, CpuDeviceV2, DeviceDataTransporterV2,
+        DeviceMultiStarkProvingKeyV2, MultiRapProver, OpeningProverV2, ProverBackendV2,
+        ProverDeviceV2, ProvingContextV2,
+    },
+    verifier::{verify, VerifierError},
+    SystemParams,
+};
+
+/// Data for verifying a Stark proof.
+#[derive(Debug)]
+pub struct VerificationDataV2 {
+    pub vk: MultiStarkVerifyingKeyV2,
+    pub proof: Proof,
+}
+
+/// A helper trait to collect the different steps in multi-trace STARK
+/// keygen and proving. Currently this trait is CPU specific.
+pub trait StarkEngineV2
+where
+    <Self::PD as MultiRapProver<Self::PB>>::Artifacts:
+        Into<<Self::PD as OpeningProverV2<Self::PB>>::OpeningPoints>,
+    <Self::PD as MultiRapProver<Self::PB>>::PartialProof:
+        Into<(GkrProof, BatchConstraintProof)>,
+    <Self::PD as OpeningProverV2<Self::PB>>::OpeningProof:
+        Into<(StackingProof, WhirProof)>,
+{
+    type SC: StarkGenericConfig<
+        Pcs = <BabyBearPoseidon2Config as StarkGenericConfig>::Pcs,
+        Challenge = <BabyBearPoseidon2Config as StarkGenericConfig>::Challenge,
+        Challenger = <BabyBearPoseidon2Config as StarkGenericConfig>::Challenger,
+    >;
+    type PB: ProverBackendV2;
+    type PD: ProverDeviceV2<Self::PB> + DeviceDataTransporterV2<Self::PB>;
+    type TS: FiatShamirTranscript + Default;
+
+    fn config(&self) -> &SystemParams {
+        self.device().config()
+    }
+
+    fn device(&self) -> &Self::PD;
+
+    // TODO[jpw]: keygen builder
+
+    fn prover_from_transcript(
+        &self,
+        transcript: Self::TS,
+    ) -> CoordinatorV2<Self::PB, Self::PD, Self::TS>;
+
+    fn prover(&self) -> CoordinatorV2<Self::PB, Self::PD, Self::TS> {
+        self.prover_from_transcript(Self::TS::default())
+    }
+
+    fn keygen(
+        &self,
+        airs: &[AirRef<Self::SC>],
+    ) -> (MultiStarkProvingKeyV2, MultiStarkVerifyingKeyV2) {
+        let mut keygen_builder = MultiStarkKeygenBuilderV2::new(self.config().clone());
+        for air in airs {
+            keygen_builder.add_air(air.clone());
+        }
+
+        let pk = keygen_builder.generate_pk().unwrap();
+        let vk = pk.get_vk();
+        (pk, vk)
+    }
+
+    fn prove(
+        &self,
+        pk: &DeviceMultiStarkProvingKeyV2<Self::PB>,
+        ctx: ProvingContextV2<Self::PB>,
+    ) -> Proof {
+        let mut prover = self.prover();
+        prover.prove(pk, ctx)
+    }
+
+    fn verify(&self, vk: &MultiStarkVerifyingKeyV2, proof: &Proof) -> Result<(), VerifierError> {
+        let mut transcript = Self::TS::default();
+        verify(vk, proof, &mut transcript)
+    }
+
+    /// The indexing of AIR ID in `ctx` should be consistent with the order of `airs`. In
+    /// particular, `airs` should correspond to the global proving key with all AIRs, including ones
+    /// not present in the `ctx`.
+    fn debug(&self, airs: &[AirRef<Self::SC>], ctx: &ProvingContextV2<Self::PB>);
+
+    /// Runs a single end-to-end test for a given set of chips and traces partitions.
+    /// This includes proving/verifying key generation, creating a proof, and verifying the proof.
+    fn run_test(
+        &self,
+        airs: Vec<AirRef<Self::SC>>,
+        ctxs: Vec<AirProvingContextV2<Self::PB>>,
+    ) -> Result<VerificationDataV2, VerifierError> {
+        let (pk, vk) = self.keygen(&airs);
+        let device = self.prover().device;
+        let d_pk = device.transport_pk_to_device(&pk);
+        let ctx = ProvingContextV2::new(ctxs.into_iter().enumerate().collect());
+        let proof = self.prove(&d_pk, ctx);
+        self.verify(&vk, &proof)?;
+        Ok(VerificationDataV2 { vk, proof })
+    }
+}
+
+pub struct BabyBearPoseidon2CpuEngineV2<TS = DuplexSponge> {
+    device: CpuDeviceV2,
+    _transcript: PhantomData<TS>,
+}
+
+impl<TS> BabyBearPoseidon2CpuEngineV2<TS> {
+    pub fn new(params: SystemParams) -> Self {
+        Self {
+            device: CpuDeviceV2::new(params),
+            _transcript: PhantomData,
+        }
+    }
+}
+
+impl<TS> StarkEngineV2 for BabyBearPoseidon2CpuEngineV2<TS>
+where
+    TS: FiatShamirTranscript + Default,
+{
+    type SC = BabyBearPoseidon2Config;
+    type PB = CpuBackendV2;
+    type PD = CpuDeviceV2;
+    type TS = TS;
+
+    fn device(&self) -> &Self::PD {
+        &self.device
+    }
+
+    fn prover_from_transcript(
+        &self,
+        transcript: TS,
+    ) -> CoordinatorV2<Self::PB, Self::PD, Self::TS> {
+        CoordinatorV2::new(CpuBackendV2, self.device.clone(), transcript)
+    }
+
+    fn debug(&self, airs: &[AirRef<Self::SC>], ctx: &ProvingContextV2<Self::PB>) {
+        debug_impl::<CpuBackendV2, CpuDeviceV2>(self.config().clone(), self.device(), airs, ctx);
+    }
+}
+
+// TODO[jpw]: move to stark-sdk
+pub trait StarkWhirEngine: StarkEngineV2 {
+    fn new(params: SystemParams) -> Self;
+}
+
+impl<TS> StarkWhirEngine for BabyBearPoseidon2CpuEngineV2<TS>
+where
+    TS: FiatShamirTranscript + Default,
+{
+    fn new(params: SystemParams) -> Self {
+        Self::new(params)
+    }
+}
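The intended end-to-end flow through `StarkEngineV2` mirrors `run_test` above. An illustrative sketch (not part of the changeset; `params`, `airs`, and `ctxs` are assumed to be constructed elsewhere, and `DuplexSponge` is assumed to be the default transcript as the import suggests):

    let engine = BabyBearPoseidon2CpuEngineV2::<DuplexSponge>::new(params);
    // keygen -> transport pk to device -> prove -> verify, exactly as in run_test
    let data = engine.run_test(airs, ctxs).expect("proof should verify");
    let _proof = data.proof;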
diff --git a/crates/stark-backend-v2/src/keygen/mod.rs b/crates/stark-backend-v2/src/keygen/mod.rs
new file mode 100644
index 00000000..dfa7dd18
--- /dev/null
+++ b/crates/stark-backend-v2/src/keygen/mod.rs
@@ -0,0 +1,377 @@
+use std::{cmp::max, collections::HashMap, sync::Arc};
+
+use itertools::Itertools;
+use openvm_stark_backend::{
+    air_builders::symbolic::{
+        get_symbolic_builder,
+        symbolic_variable::{Entry, SymbolicVariable},
+        SymbolicConstraintsDag, SymbolicExpressionNode, SymbolicRapBuilder,
+    },
+    keygen::types::{LinearConstraint, TraceWidth},
+    prover::MatrixDimensions,
+    rap::AnyRap,
+    AirRef,
+};
+use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config;
+use p3_air::BaseAir;
+use p3_field::{Field, FieldAlgebra};
+use p3_util::log2_strict_usize;
+use tracing::instrument;
+
+use crate::{
+    keygen::types::{
+        KeygenError, MultiStarkProvingKeyV2, MultiStarkVerifyingKey0V2, StarkProvingKeyV2,
+        StarkVerifyingKeyV2, StarkVerifyingParamsV2, VerifierSinglePreprocessedData,
+    },
+    poseidon2::sponge::poseidon2_hash_slice,
+    prover::{
+        stacked_pcs::{stacked_commit, StackedPcsData},
+        ColMajorMatrix,
+    },
+    Digest, SystemParams, F,
+};
+
+pub mod types;
+
+type SC = BabyBearPoseidon2Config;
+
+struct AirKeygenBuilderV2 {
+    pub is_required: bool,
+    air: AirRef<SC>,
+    prep_keygen_data: PrepKeygenDataV2,
+}
+
+/// Stateful builder to create multi-stark proving and verifying keys
+/// for a system of multiple RAPs with multiple multi-matrix commitments
+pub struct MultiStarkKeygenBuilderV2 {
+    pub config: SystemParams,
+    /// Information for partitioned AIRs.
+    partitioned_airs: Vec<AirKeygenBuilderV2>,
+}
+
+impl MultiStarkKeygenBuilderV2 {
+    pub fn new(config: SystemParams) -> Self {
+        Self {
+            config,
+            partitioned_airs: vec![],
+        }
+    }
+
+    /// Default way to add a single Interactive AIR.
+    /// Returns `air_id`
+    pub fn add_air(&mut self, air: AirRef<SC>) -> usize {
+        self.add_air_impl(air, false)
+    }
+
+    pub fn add_required_air(&mut self, air: AirRef<SC>) -> usize {
+        self.add_air_impl(air, true)
+    }
+
+    #[instrument(level = "debug", skip_all, fields(name = air.name(), is_required = is_required))]
+    fn add_air_impl(&mut self, air: AirRef<SC>, is_required: bool) -> usize {
+        self.partitioned_airs
+            .push(AirKeygenBuilderV2::new(&self.config, air, is_required));
+        self.partitioned_airs.len() - 1
+    }
+
+    /// Consume the builder and generate proving key.
+    /// The verifying key can be obtained from the proving key.
+    pub fn generate_pk(self) -> Result<MultiStarkProvingKeyV2, KeygenError> {
+        let max_constraint_degree = self.config.max_constraint_degree;
+        let pk_per_air: Vec<_> = self
+            .partitioned_airs
+            .into_iter()
+            .map(|keygen_builder| {
+                // Second pass: get final constraints, where RAP phase constraints may have changed
+                keygen_builder.generate_pk(max_constraint_degree)
+            })
+            .collect::<Result<Vec<_>, KeygenError>>()?;
+
+        let mut air_max_constraint_degree = 0;
+        for pk in pk_per_air.iter() {
+            let width = &pk.vk.params.width;
+            tracing::info!("{:<20} | Constraint Deg = {:<2} | Prep Cols = {:<2} | Main Cols = {:<8} | {:4} Constraints | {:3} Interactions",
+                pk.air_name,
+                pk.vk.max_constraint_degree,
+                width.preprocessed.unwrap_or(0),
+                format!("{:?}", width.main_widths()),
+                pk.vk.symbolic_constraints.constraints.constraint_idx.len(),
+                pk.vk.symbolic_constraints.interactions.len(),
+            );
+            air_max_constraint_degree = max(air_max_constraint_degree, pk.vk.max_constraint_degree);
+            tracing::debug!(
+                "On Buses {:?}",
+                pk.vk
+                    .symbolic_constraints
+                    .interactions
+                    .iter()
+                    .map(|i| i.bus_index)
+                    .collect_vec()
+            );
+            #[cfg(feature = "metrics")]
+            {
+                let labels = [("air_name", pk.air_name.clone())];
+                metrics::counter!("constraint_deg", &labels)
+                    .absolute(pk.vk.max_constraint_degree as u64);
+                // column info will be logged by prover later
+                metrics::counter!("constraints", &labels)
+                    .absolute(pk.vk.symbolic_constraints.constraints.constraint_idx.len() as u64);
+                metrics::counter!("interactions", &labels)
+                    .absolute(pk.vk.symbolic_constraints.interactions.len() as u64);
+            }
+        }
+        if max_constraint_degree != air_max_constraint_degree as usize {
+            tracing::warn!(
+                "Actual max constraint degree across all AIRs ({air_max_constraint_degree}) does not match configured max constraint degree ({max_constraint_degree})",
+            );
+        }
+
+        let num_airs = pk_per_air.len();
+        let base_order = F::order().to_u32_digits()[0];
+        let mut count_weight_per_air_per_bus_index = HashMap::new();
+
+        let mut num_interactions_per_air: Vec<u32> = Vec::with_capacity(num_airs);
+        // We compute the `a_i`'s for the constraints of the form `a_0 n_0 + ... + a_{k-1} n_{k-1} <
+        // a_k`. First come the constraints that the total number of interactions on each bus is
+        // at most the base field order.
+        for (air_idx, pk) in pk_per_air.iter().enumerate() {
+            let constraints = &pk.vk.symbolic_constraints;
+            num_interactions_per_air.push(constraints.interactions.len().try_into().unwrap());
+            for interaction in &constraints.interactions {
+                // Also make sure that this interaction is valid given the security params.
+                let max_msg_len = self.config.logup.max_message_length();
+                // plus one because of the bus
+                let total_message_length = interaction.message.len() + 1;
+                assert!(
+                    total_message_length <= max_msg_len,
+                    "interaction message with bus has length {}, which is more than max {max_msg_len}",
+                    total_message_length,
+                );
+
+                let b = interaction.bus_index;
+                let constraint = count_weight_per_air_per_bus_index
+                    .entry(b)
+                    .or_insert_with(|| LinearConstraint {
+                        coefficients: vec![0; num_airs],
+                        threshold: base_order,
+                    });
+                constraint.coefficients[air_idx] += interaction.count_weight;
+            }
+        }
+
+        // Sorting by bus index is not necessary, but makes debugging/testing easier.
+        let mut trace_height_constraints = count_weight_per_air_per_bus_index
+            .into_iter()
+            .sorted_by_key(|(bus_index, _)| *bus_index)
+            .map(|(_, constraint)| constraint)
+            .collect_vec();
+
+        let log_up_security_params = self.config.logup;
+
+        // Add a constraint for the total number of interactions.
+        trace_height_constraints.push(LinearConstraint {
+            coefficients: num_interactions_per_air,
+            threshold: log_up_security_params.max_interaction_count,
+        });
+
+        let pre_vk: MultiStarkVerifyingKey0V2 = MultiStarkVerifyingKey0V2 {
+            params: self.config.clone(),
+            per_air: pk_per_air.iter().map(|pk| pk.vk.clone()).collect(),
+            trace_height_constraints: trace_height_constraints.clone(),
+        };
+        // To protect against weak Fiat-Shamir, we hash the "pre"-verifying key and include it in
+        // the final verifying key. This just needs to commit to the verifying key and does
+        // not need to be verified by the verifier, so we just use bitcode to serialize it.
+        let vk_bytes = bitcode::serialize(&pre_vk).unwrap();
+        tracing::debug!("pre-vkey: {} bytes", vk_bytes.len());
+        // Purely to get type compatibility and convenience, we hash using the native hash
+        let vk_pre_hash =
+            poseidon2_hash_slice(&vk_bytes.into_iter().map(F::from_canonical_u8).collect_vec());
+
+        Ok(MultiStarkProvingKeyV2 {
+            params: self.config,
+            per_air: pk_per_air,
+            trace_height_constraints,
+            max_constraint_degree,
+            vk_pre_hash,
+        })
+    }
+}
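To make the bus constraints concrete, consider a hypothetical system (numbers are illustrative, not from this changeset): three AIRs where AIR 0 and AIR 2 each send one interaction with `count_weight = 1` to bus 7. The bus then contributes

    LinearConstraint { coefficients: vec![1, 0, 1], threshold: 2_013_265_921 }

i.e. `n_0 + n_2 < p` over the lifted trace heights `n_i`, where `p = 2^31 - 2^27 + 1` is the BabyBear order, and the final pushed constraint additionally bounds `1*k_0 + 1*k_1 + 1*k_2` (interaction counts per AIR) by `max_interaction_count`.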
+impl AirKeygenBuilderV2 {
+    pub fn new(config: &SystemParams, air: AirRef<SC>, is_required: bool) -> Self {
+        let prep_keygen_data = PrepKeygenDataV2::new(config, air.as_ref());
+        Self {
+            is_required,
+            air,
+            prep_keygen_data,
+        }
+    }
+
+    /// `max_constraint_degree` is the global max constraint degree. If this AIR's constraint degree
+    /// exceeds it, an error will be returned.
+    pub fn generate_pk(
+        self,
+        max_constraint_degree: usize,
+    ) -> Result<StarkProvingKeyV2, KeygenError> {
+        let air_name = self.air.name();
+
+        let symbolic_builder = self.get_symbolic_builder();
+        let width = symbolic_builder.width();
+        let num_public_values = symbolic_builder.num_public_values();
+
+        let symbolic_constraints = symbolic_builder.constraints();
+        let constraint_degree = symbolic_constraints.max_constraint_degree();
+        if constraint_degree > max_constraint_degree {
+            return Err(KeygenError::MaxConstraintDegreeExceeded {
+                name: air_name.clone(),
+                degree: constraint_degree,
+                max_degree: max_constraint_degree,
+            });
+        }
+
+        let Self {
+            is_required,
+            prep_keygen_data:
+                PrepKeygenDataV2 {
+                    verifier_data: preprocessed_vdata,
+                    prover_data: prep_prover_data,
+                },
+            ..
+        } = self;
+
+        let dag = SymbolicConstraintsDag::from(symbolic_constraints);
+        let max_rotation = dag.constraints.max_rotation(); // TODO: exclude unused vars?
+        debug_assert!(max_rotation <= 1);
+        let vparams = StarkVerifyingParamsV2 {
+            width,
+            num_public_values,
+            need_rot: max_rotation == 1,
+        };
+        // Deprecated in v2:
+        assert!(vparams.width.after_challenge.is_empty());
+
+        let unused_variables = find_unused_vars(&dag, &vparams.width);
+        let vk = StarkVerifyingKeyV2 {
+            preprocessed_data: preprocessed_vdata,
+            params: vparams,
+            symbolic_constraints: dag,
+            max_constraint_degree: constraint_degree
+                .try_into()
+                .expect("constraint degree should fit in u8"),
+            is_required,
+            unused_variables,
+        };
+        Ok(StarkProvingKeyV2 {
+            air_name,
+            vk,
+            preprocessed_data: prep_prover_data,
+        })
+    }
+
+    pub fn get_symbolic_builder(&self) -> SymbolicRapBuilder<SC> {
+        let width = TraceWidth {
+            preprocessed: self.prep_keygen_data.width(),
+            cached_mains: self.air.cached_main_widths(),
+            common_main: self.air.common_main_width(),
+            after_challenge: vec![],
+        };
+        get_symbolic_builder(self.air.as_ref(), &width, &[], &[])
+    }
+}
+
+pub(super) struct PrepKeygenDataV2 {
+    pub verifier_data: Option<VerifierSinglePreprocessedData>,
+    pub prover_data: Option<Arc<StackedPcsData<F>>>,
+}
+
+impl PrepKeygenDataV2 {
+    fn new(params: &SystemParams, air: &dyn AnyRap<SC>) -> PrepKeygenDataV2 {
+        let preprocessed_trace = BaseAir::<F>::preprocessed_trace(air);
+        let vpdata_opt = preprocessed_trace.map(|trace| {
+            let trace = ColMajorMatrix::from_row_major(&trace);
+            let (commit, data) = stacked_commit(
+                params.l_skip,
+                params.n_stack,
+                params.log_blowup,
+                params.k_whir(),
+                &[&trace],
+            );
+            debug_assert_eq!(trace.width(), data.mat_view(0).width());
+            let vdata = VerifierSinglePreprocessedData {
+                commit,
+                hypercube_dim: log2_strict_usize(trace.height()) as isize - params.l_skip as isize,
+                stacking_width: data.matrix.width(),
+            };
+            let pdata = Arc::new(data);
+            (vdata, pdata)
+        });
+        if let Some((vdata, pdata)) = vpdata_opt {
+            PrepKeygenDataV2 {
+                prover_data: Some(pdata),
+                verifier_data: Some(vdata),
+            }
+        } else {
+            PrepKeygenDataV2 {
+                prover_data: None,
+                verifier_data: None,
+            }
+        }
+    }
+
+    fn width(&self) -> Option<usize> {
+        self.prover_data.as_ref().map(|d| d.mat_view(0).width())
+    }
+}
+
+pub(crate) fn find_unused_vars(
+    constraints: &SymbolicConstraintsDag<F>,
+    width: &TraceWidth,
+) -> Vec<SymbolicVariable<F>> {
+    let preprocessed_width = width.preprocessed.unwrap_or(0);
+    let mut preprocessed_present = vec![vec![false; 2]; preprocessed_width];
+
+    let mut main_present = vec![];
+    for width in width.main_widths() {
+        main_present.push(vec![vec![false; 2]; width]);
+    }
+
+    for node in &constraints.constraints.nodes {
+        let SymbolicExpressionNode::Variable(var) = node else {
+            continue;
+        };
+
+        match var.entry {
+            Entry::Preprocessed { offset } => {
+                preprocessed_present[var.index][offset] = true;
+            }
+            Entry::Main { part_index, offset } => {
+                main_present[part_index][var.index][offset] = true;
+            }
+            Entry::Public => {}
+            Entry::Challenge | Entry::Exposed | Entry::Permutation { .. } => unreachable!(),
+        }
+    }
+
+    let mut missing = vec![];
+    for (index, presents) in preprocessed_present.iter().enumerate() {
+        for (offset, present) in presents.iter().enumerate() {
+            if !present {
+                missing.push(SymbolicVariable::new(Entry::Preprocessed { offset }, index));
+            }
+        }
+    }
+    for (part_index, present_per_part) in main_present.iter().enumerate() {
+        for (index, presents) in present_per_part.iter().enumerate() {
+            for (offset, present) in presents.iter().enumerate() {
+                if !present {
+                    missing.push(SymbolicVariable::new(
+                        Entry::Main { part_index, offset },
+                        index,
+                    ));
+                }
+            }
+        }
+    }
+    missing
+}
diff --git a/crates/stark-backend-v2/src/keygen/types.rs b/crates/stark-backend-v2/src/keygen/types.rs
new file mode 100644
index 00000000..86a5b8b9
--- /dev/null
+++ b/crates/stark-backend-v2/src/keygen/types.rs
@@ -0,0 +1,177 @@
+// NOTE[jpw]: copied from stark-backend but renamed for V2 and without

+// Keygen API for STARK backend
+// Changes:
+// - All AIRs can be optional
+use std::sync::Arc;
+
+use openvm_stark_backend::{
+    air_builders::symbolic::{symbolic_variable::SymbolicVariable, SymbolicConstraintsDag},
+    keygen::types::{LinearConstraint, TraceWidth},
+};
+use serde::{Deserialize, Serialize};
+use thiserror::Error;
+
+use crate::{prover::stacked_pcs::StackedPcsData, Digest, SystemParams, F};
+
+#[derive(Error, Debug)]
+pub enum KeygenError {
+    #[error("Max constraint degree exceeded for AIR {name}: {degree} > {max_degree}")]
+    MaxConstraintDegreeExceeded {
+        name: String,
+        degree: usize,
+        max_degree: usize,
+    },
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[repr(C)]
+pub struct StarkVerifyingParamsV2 {
+    /// Trace sub-matrix widths
+    pub width: TraceWidth,
+    /// Number of public values for this STARK only
+    pub num_public_values: usize,
+    /// A flag indicating whether rotations are needed
+    pub need_rot: bool,
+}
+
+/// Verifier data for preprocessed trace for a single AIR.
+///
+/// Currently assumes each AIR has its own preprocessed commitment
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct VerifierSinglePreprocessedData {
+    /// Commitment to the preprocessed trace.
+    pub commit: Digest,
+    /// The hypercube dimension of the preprocessed data _before stacking_ (log_height -
+    /// vk.l_skip).
+    pub hypercube_dim: isize,
+    /// The width of the data after stacking.
+    pub stacking_width: usize,
+}
+
+/// Verifying key for a single STARK (corresponding to a single AIR matrix)
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[repr(C)]
+pub struct StarkVerifyingKeyV2 {
+    /// Preprocessed trace data, if any
+    pub preprocessed_data: Option<VerifierSinglePreprocessedData>,
+    /// Parameters of the STARK
+    pub params: StarkVerifyingParamsV2,
+    /// Symbolic constraints of the AIR in all challenge phases. This is
+    /// a serialization of the constraints in the AIR.
+    pub symbolic_constraints: SymbolicConstraintsDag<F>,
+    /// The maximum degree of any polynomial (constraint or interaction) for this AIR.
+    pub max_constraint_degree: u8,
+    /// True means this AIR must have a non-empty trace.
+    pub is_required: bool,
+    /// Symbolic variables that are never referenced by the AIR's constraints.
+    pub unused_variables: Vec<SymbolicVariable<F>>,
+}
+
+/// Common verifying key for multiple AIRs.
+///
+/// This struct contains the necessary data for the verifier to verify proofs generated for
+/// multiple AIRs using a single verifying key.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct MultiStarkVerifyingKeyV2 {
+    /// All parts of the verifying key needed by the verifier, except
+    /// the `pre_hash` used to initialize the Fiat-Shamir transcript.
+    pub inner: MultiStarkVerifyingKey0V2,
+    /// The hash of all other parts of the verifying key. The Fiat-Shamir hasher will
+    /// initialize by observing this hash.
+    pub pre_hash: Digest,
+}
+
+/// Everything in [MultiStarkVerifyingKeyV2] except the `pre_hash` used to initialize the
+/// Fiat-Shamir transcript.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct MultiStarkVerifyingKey0V2 {
+    pub params: SystemParams,
+    pub per_air: Vec<StarkVerifyingKeyV2>,
+    pub trace_height_constraints: Vec<LinearConstraint>,
+}
+
+/// Proving key for a single STARK (corresponding to a single AIR matrix)
+#[derive(Clone, Serialize, Deserialize)]
+pub struct StarkProvingKeyV2 {
+    /// Type name of the AIR, for display purposes only
+    pub air_name: String,
+    /// Verifying key
+    pub vk: StarkVerifyingKeyV2,
+    /// Prover only data for preprocessed trace
+    pub preprocessed_data: Option<Arc<StackedPcsData<F>>>,
+}
+
+/// Common proving key for multiple AIRs.
+///
+/// This struct contains the necessary data for the prover to generate proofs for multiple AIRs
+/// using a single proving key.
+#[derive(Clone, Serialize, Deserialize)]
+pub struct MultiStarkProvingKeyV2 {
+    pub per_air: Vec<StarkProvingKeyV2>,
+    pub trace_height_constraints: Vec<LinearConstraint>,
+    /// Maximum degree of constraints across all AIRs
+    pub max_constraint_degree: usize,
+    pub params: SystemParams,
+    /// See [MultiStarkVerifyingKeyV2]
+    pub vk_pre_hash: Digest,
+}
+
+impl StarkVerifyingKeyV2 {
+    pub fn num_cached_mains(&self) -> usize {
+        self.params.width.cached_mains.len()
+    }
+
+    pub fn num_parts(&self) -> usize {
+        1 + self.num_cached_mains() + (self.preprocessed_data.is_some() as usize)
+    }
+
+    pub fn has_interaction(&self) -> bool {
+        !self.symbolic_constraints.interactions.is_empty()
+    }
+
+    pub fn num_interactions(&self) -> usize {
+        self.symbolic_constraints.interactions.len()
+    }
+
+    /// Converts from a main part index (as used by the constraint DAG) to the
+    /// commitment part indexing scheme that includes the preprocessed trace.
+    pub fn dag_main_part_index_to_commit_index(&self, index: usize) -> usize {
+        // In the dag, common main is the final part index.
+        if index == self.num_cached_mains() {
+            0
+        } else {
+            index + 1 + self.preprocessed_data.is_some() as usize
+        }
+    }
+}
+
+impl MultiStarkProvingKeyV2 {
+    pub fn get_vk(&self) -> MultiStarkVerifyingKeyV2 {
+        MultiStarkVerifyingKeyV2 {
+            inner: self.get_vk0(),
+            pre_hash: self.vk_pre_hash,
+        }
+    }
+
+    fn get_vk0(&self) -> MultiStarkVerifyingKey0V2 {
+        MultiStarkVerifyingKey0V2 {
+            params: self.params.clone(),
+            per_air: self.per_air.iter().map(|pk| pk.vk.clone()).collect(),
+            trace_height_constraints: self.trace_height_constraints.clone(),
+        }
+    }
+}
+
+impl MultiStarkVerifyingKeyV2 {
+    /// Global maximum constraint degree across all AIRs and Interactions.
+    pub fn max_constraint_degree(&self) -> usize {
+        self.inner.max_constraint_degree()
+    }
+}
+
+impl MultiStarkVerifyingKey0V2 {
+    pub fn max_constraint_degree(&self) -> usize {
+        self.params.max_constraint_degree
+    }
+}
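Worked example of `dag_main_part_index_to_commit_index` (a hypothetical AIR with a preprocessed trace and two cached mains, so the DAG orders main parts as `[cached 0, cached 1, common]`; given such a `vk`):

    assert_eq!(vk.dag_main_part_index_to_commit_index(2), 0); // common main -> commit part 0
    assert_eq!(vk.dag_main_part_index_to_commit_index(0), 2); // cached 0    -> commit part 2
    assert_eq!(vk.dag_main_part_index_to_commit_index(1), 3); // cached 1    -> commit part 3
    // Commit part 1 appears to be reserved for the preprocessed trace in this scheme.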
diff --git a/crates/stark-backend-v2/src/lib.rs b/crates/stark-backend-v2/src/lib.rs
new file mode 100644
index 00000000..8625bb30
--- /dev/null
+++ b/crates/stark-backend-v2/src/lib.rs
@@ -0,0 +1,62 @@
+// TODO[TEMP]: remove once we make traits generic in SC
+pub use openvm_stark_sdk;
+use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config;
+use p3_baby_bear::BabyBear;
+use p3_field::extension::BinomialExtensionField;
+use p3_util::log2_ceil_u64;
+
+mod chip;
+pub mod codec;
+mod config;
+pub mod debug;
+pub mod dft;
+mod engine;
+pub mod keygen;
+pub mod poly_common;
+pub mod poseidon2;
+pub mod proof;
+pub mod prover;
+pub mod utils;
+pub mod v1_shims;
+pub mod verifier;
+
+#[cfg(any(test, feature = "test-utils"))]
+pub mod test_utils;
+#[cfg(test)]
+mod tests;
+
+pub use chip::*;
+pub use config::*;
+pub use engine::*;
+
+pub type F = BabyBear;
+pub type EF = BinomialExtensionField<BabyBear, 4>;
+pub const D_EF: usize = 4;
+
+pub const DIGEST_SIZE: usize = poseidon2::CHUNK;
+pub type Digest = [F; DIGEST_SIZE];
+
+// TODO: remove after making SC generic in v2
+pub type SC = BabyBearPoseidon2Config;
+
+/// Common utility function for computing the `n_logup` parameter in terms of `total_interactions`,
+/// which is the sum of interaction message counts across all traces, using the lifted trace
+/// heights.
+///
+/// This calculation must be consistent between the prover and verifier and is enforced by
+/// the verifier.
+// NOTE: we could use a stricter calculation of `n_logup = log2_ceil(total_interactions >>
+// l_skip)` but the `leading_zeros` calculation below is easier to check in the recursion
+// circuit. The formula below is equivalent to `log2_ceil(total_interactions + 1) - l_skip`.
+pub fn calculate_n_logup(l_skip: usize, total_interactions: u64) -> usize {
+    if total_interactions != 0 {
+        let n_logup = (u64::BITS - total_interactions.leading_zeros()) as usize - l_skip;
+        debug_assert_eq!(
+            n_logup + l_skip,
+            log2_ceil_u64(total_interactions + 1) as usize
+        );
+        n_logup
+    } else {
+        0
+    }
+}
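A worked example for `calculate_n_logup` (illustrative numbers): `total_interactions = 1000` has bit length 10, since `2^9 <= 1000 < 2^10`, so with `l_skip = 3`:

    assert_eq!(calculate_n_logup(3, 1000), 7); // = log2_ceil(1001) - l_skip = 10 - 3
    assert_eq!(calculate_n_logup(3, 0), 0);    // no interactions short-circuits to 0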
diff --git a/crates/stark-backend-v2/src/poly_common.rs b/crates/stark-backend-v2/src/poly_common.rs
new file mode 100644
index 00000000..faaa9167
--- /dev/null
+++ b/crates/stark-backend-v2/src/poly_common.rs
@@ -0,0 +1,851 @@
+use core::ops::{Add, Sub};
+use std::{iter::zip, ops::Mul};
+
+use itertools::Itertools;
+use p3_dft::TwoAdicSubgroupDft;
+use p3_field::{ExtensionField, Field, FieldAlgebra, TwoAdicField};
+use p3_matrix::{dense::RowMajorMatrix, Matrix};
+use p3_util::{log2_ceil_usize, log2_strict_usize};
+use tracing::instrument;
+
+use crate::{
+    dft::Radix2BowersSerial, prover::poly::evals_eq_hypercube_serial,
+    utils::batch_multiplicative_inverse_serial,
+};
+
+pub fn eval_eq_mle<F1, F2, F3>(x: &[F1], y: &[F2]) -> F3
+where
+    F1: Field,
+    F2: Field,
+    F3: Field,
+    F1: Mul<F2, Output = F3>,
+    F3: Sub<F1, Output = F3>,
+    F3: Sub<F2, Output = F3>,
+{
+    debug_assert_eq!(x.len(), y.len());
+    zip(x, y).fold(F3::ONE, |acc, (&x_i, &y_i)| {
+        acc * (F3::ONE - y_i - x_i + (x_i * y_i).double())
+    })
+}
+
+/// Let D be the univariate skip domain, the subgroup of `F^*` of order `2^l_skip`.
+///
+/// Computes the polynomial
+/// ```text
+/// eq_D(X, Y) = \sum_{z_1 \in D} \prod_{z_2 \in D, z_2 != z_1} (X - z_2)(Y - z_2) / (z_1 - z_2)^2
+/// ```
+pub fn eval_eq_uni<F: Field>(l_skip: usize, x: F, y: F) -> F {
+    let mut res = F::ONE;
+    for (x_pow, y_pow) in zip(x.exp_powers_of_2(), y.exp_powers_of_2()).take(l_skip) {
+        res = (x_pow + y_pow) * res + (x_pow - F::ONE) * (y_pow - F::ONE);
+    }
+    res * F::ONE.halve().exp_u64(l_skip as u64)
+}
+
+/// Let D be the univariate skip domain, the subgroup of `F^*` of order `2^l_skip`.
+///
+/// Computes the polynomial eq_D(X, 1); see `eval_eq_uni`.
+pub fn eval_eq_uni_at_one<F: Field>(l_skip: usize, x: F) -> F {
+    let mut res = F::ONE;
+    for x_pow in x.exp_powers_of_2().take(l_skip) {
+        res *= x_pow + F::ONE;
+    }
+    res * F::ONE.halve().exp_u64(l_skip as u64)
+}
+
+/// Returns `eq_D(x, Z)` as a polynomial in `Z` in coefficient form.
+/// Derived from `eq_D(x, Z)` being the Lagrange basis at `x`, which is the character sum over the
+/// roots of unity.
+///
+/// If z in D, then `eq_D(x, z) = 1/N sum_{k=1}^N (x/z)^k = 1/N sum_{k=1}^N x^k z^{N-k}`.
+pub fn eq_uni_poly<F, EF>(l_skip: usize, x: EF) -> UnivariatePoly<EF>
+where
+    F: Field,
+    EF: ExtensionField<F>,
+{
+    let n_inv = F::ONE.halve().exp_u64(l_skip as u64);
+    let mut coeffs = x
+        .powers()
+        .skip(1)
+        .take(1 << l_skip)
+        .map(|x_pow| x_pow * n_inv)
+        .collect_vec();
+    coeffs.reverse();
+    coeffs[0] = n_inv.into();
+    UnivariatePoly::new(coeffs)
+}
+
+pub fn eval_in_uni<F: Field>(l_skip: usize, n: isize, z: F) -> F {
+    debug_assert!(n >= -(l_skip as isize));
+    if n.is_negative() {
+        eval_eq_uni_at_one(
+            n.unsigned_abs(),
+            z.exp_power_of_2(l_skip.wrapping_add_signed(n)),
+        )
+    } else {
+        F::ONE
+    }
+}
+
+pub fn eval_eq_prism<F: Field>(l_skip: usize, x: &[F], y: &[F]) -> F {
+    eval_eq_uni(l_skip, x[0], y[0]) * eval_eq_mle(&x[1..], &y[1..])
+}
+
+/// Length of `xi_1` should be `l_skip`.
+pub fn eval_eq_sharp_uni<F, EF>(omega_skip_pows: &[F], xi_1: &[EF], z: EF) -> EF
+where
+    F: Field,
+    EF: ExtensionField<F>,
+{
+    let l_skip = xi_1.len();
+    debug_assert_eq!(omega_skip_pows.len(), 1 << l_skip);
+
+    let mut res = EF::ZERO;
+    let eq_xi_evals = evals_eq_hypercube_serial(xi_1);
+    for (&omega_pow, eq_xi_eval) in omega_skip_pows.iter().zip(eq_xi_evals) {
+        res += eval_eq_uni(l_skip, z, omega_pow.into()) * eq_xi_eval;
+    }
+    #[cfg(debug_assertions)]
+    {
+        let coeffs = (0..(1 << l_skip))
+            .map(|k| {
+                let mut c = EF::ONE;
+                #[allow(clippy::needless_range_loop)]
+                for i in 0..l_skip {
+                    let idx = (k << i) % (1 << l_skip);
+                    c *= EF::ONE - xi_1[i]
+                        + xi_1[i] * omega_skip_pows[((1 << l_skip) - idx) % (1 << l_skip)];
+                }
+                c
+            })
+            .collect_vec();
+        let mut rpow = EF::ONE;
+        let mut other = EF::ZERO;
+        for c in coeffs {
+            other += rpow * c;
+            rpow *= z;
+        }
+        other *= EF::TWO.inverse().exp_u64(l_skip as u64);
+        debug_assert_eq!(other, res);
+    }
+    res
+}
+
+pub fn eq_sharp_uni_poly<EF: TwoAdicField>(xi_1: &[EF]) -> UnivariatePoly<EF> {
+    let evals = evals_eq_hypercube_serial(xi_1);
+    UnivariatePoly::from_evals_idft(&evals)
+}
+
+/// `\kappa_\rot(x, y)` should equal `\delta_{x,rot(y)}` on the hyperprism.
+///
+/// `omega_pows` must have length `2^{l_skip}`.
+pub fn eval_rot_kernel_prism<F: TwoAdicField>(l_skip: usize, x: &[F], y: &[F]) -> F {
+    let omega = F::two_adic_generator(l_skip);
+
+    let (eq_cube, rot_cube) = eval_eq_rot_cube(&x[1..], &y[1..]);
+    // If not at the boundary of D, just rotate in D and don't change cube coordinates.
+    // Otherwise, at the boundary, rotate the cube.
+    eval_eq_uni(l_skip, x[0], y[0] * omega) * eq_cube
+        + eval_eq_uni_at_one(l_skip, x[0])
+            * eval_eq_uni_at_one(l_skip, y[0] * omega)
+            * (rot_cube - eq_cube)
+}
+
+/// MLE of cyclic rotation kernel on hypercube
+pub fn eval_eq_rot_cube<F: Field>(x: &[F], y: &[F]) -> (F, F) {
+    let n = x.len();
+    debug_assert_eq!(n, y.len());
+    // Recursive formula: rot(x, y) = x[0] * (1 - y[0]) * eq(x[1..], y[1..]) + (1 - x[0]) * y[0] *
+    // rot(x[1..], y[1..])
+    let mut rot = F::ONE;
+    let mut eq = F::ONE;
+    for i in (0..n).rev() {
+        rot = x[i] * (F::ONE - y[i]) * eq + (F::ONE - x[i]) * y[i] * rot;
+        eq *= x[i] * y[i] + (F::ONE - x[i]) * (F::ONE - y[i]);
+    }
+    (eq, rot)
+}
+
+// Source: https://github.com/starkware-libs/stwo/blob/dev/crates/stwo/src/prover/lookups/utils.rs#L12
+/// Univariate polynomial in coefficient form.
+#[derive(Clone, Debug)]
+pub struct UnivariatePoly<F>(pub(crate) Vec<F>);
+
+impl<F> UnivariatePoly<F> {
+    pub fn new(coeffs: Vec<F>) -> Self {
+        Self(coeffs)
+    }
+
+    pub fn coeffs(&self) -> &[F] {
+        &self.0
+    }
+
+    pub fn coeffs_mut(&mut self) -> &mut Vec<F> {
+        &mut self.0
+    }
+
+    pub fn into_coeffs(self) -> Vec<F> {
+        self.0
+    }
+}
+
+impl<F: Field> UnivariatePoly<F> {
+    pub fn eval_at_point<EF: ExtensionField<F>>(&self, x: EF) -> EF {
+        horner_eval(&self.0, x)
+    }
+
+    #[instrument(level = "debug", skip_all)]
+    pub fn lagrange_interpolate<BF: Field>(points: &[BF], evals: &[F]) -> Self
+    where
+        F: ExtensionField<BF>,
+    {
+        assert_eq!(points.len(), evals.len());
+        let len = points.len();
+
+        // Special case: empty or single evaluation
+        if len == 0 {
+            return Self(vec![]);
+        }
+        if len == 1 {
+            return Self(vec![evals[0]]);
+        }
+
+        // Lagrange interpolation algorithm
+        // P(x) = sum_{i=0}^{len-1} evals[i] * L_i(x)
+        // where L_i(x) = prod_{j != i} (x - points[j]) / (points[i] - points[j])
+
+        // Step 1: Compute all denominators (points[i] - points[j]) for i != j
+        let mut denominators = Vec::with_capacity(len * (len - 1));
+        for i in 0..len {
+            for j in 0..len {
+                if i != j {
+                    denominators.push(points[i] - points[j]);
+                }
+            }
+        }
+
+        // Step 2: Batch invert all denominators
+        let inv_denominators = batch_multiplicative_inverse_serial(&denominators);
+
+        // Step 3: Build coefficient form by accumulating Lagrange basis polynomials
+        let mut coeffs = vec![F::ZERO; len];
+
+        // Reusable workspace for Lagrange polynomial computation
+        let mut lagrange_poly = Vec::with_capacity(len);
+
+        #[allow(clippy::needless_range_loop)]
+        for i in 0..len {
+            // Skip if evaluation is zero (optimization)
+            if evals[i] == F::ZERO {
+                continue;
+            }
+
+            // Build L_i(x) in coefficient form using polynomial multiplication
+            // L_i(x) = prod_{j != i} (x - points[j]) / (points[i] - points[j])
+
+            // Start with the constant polynomial 1
+            lagrange_poly.clear();
+            lagrange_poly.push(F::ONE);
+
+            // Get the precomputed inverse denominators for this i
+            let inv_denom_start = i * (len - 1);
+            let mut inv_idx = 0;
+
+            // Multiply by (x - points[j]) / (points[i] - points[j]) for each j != i
+            #[allow(clippy::needless_range_loop)]
+            for j in 0..len {
+                if i != j {
+                    let scale = inv_denominators[inv_denom_start + inv_idx];
+                    inv_idx += 1;
+
+                    // Multiply lagrange_poly by (x - points[j]) * scale in place
+                    // This is equivalent to: lagrange_poly * (x - points[j]) * scale
+                    // = lagrange_poly * x * scale - lagrange_poly * points[j] * scale
+
+                    lagrange_poly.push(F::ZERO); // Extend by one for the new highest degree term
+                    for k in (1..lagrange_poly.len()).rev() {
+                        let prev_coeff = lagrange_poly[k - 1] * scale;
+                        lagrange_poly[k] += prev_coeff;
+                        lagrange_poly[k - 1] = -prev_coeff * points[j];
+                    }
+                }
+            }
+
+            // Add evals[i] * L_i(x) to the result
+            for (k, &coeff) in lagrange_poly.iter().enumerate() {
+                coeffs[k] += evals[i] * coeff;
+            }
+        }
+
+        Self(coeffs)
+    }
+}
+
+impl<F: TwoAdicField> UnivariatePoly<F> {
+    /// Computes P(1), P(omega), ..., P(omega^{n-1}).
+    fn chirp_z(poly: &[F], omega: F, n: usize) -> Vec<F> {
+        if n == 0 {
+            return Vec::new();
+        }
+        if poly.is_empty() {
+            return vec![F::ZERO; n];
+        }
+        let s = poly.len() + n;
+        let omega_powers = (0..(s as u64))
+            .map(|i| omega.exp_u64(i * (i.saturating_sub(1)) / 2))
+            .collect_vec();
+        let omega_powers_inv = batch_multiplicative_inverse_serial(&omega_powers);
+        let mut p = zip(poly, &omega_powers_inv)
+            .map(|(&c, &inv)| c * inv)
+            .collect_vec();
+        let mut q = omega_powers.iter().rev().copied().collect_vec();
+
+        let dft_deg = (p.len() + q.len() - 1).next_power_of_two();
+        p.resize(dft_deg, F::ZERO);
+        q.resize(dft_deg, F::ZERO);
+        let dft = Radix2BowersSerial;
+        p = dft.dft(p);
+        q = dft.dft(q);
+        for (x, y) in p.iter_mut().zip(q.iter()) {
+            *x *= *y;
+        }
+        p = dft.idft(p);
+        zip(p.into_iter().skip(n).take(n).rev(), omega_powers_inv)
+            .map(|(x, inv)| x * inv)
+            .collect()
+    }
+
+    /// Given z and n, find the product (1 - x)(1 - zx)...(1 - z^{n-1}x).
+    /// If n is odd, this can be trivially computed from n - 1.
+    /// Otherwise, F_n(x) = F_{n/2}(x) * F_{n/2}(x * z^{n/2}).
+    fn geometric_sequence_linear_product_helper(
+        dft: &Radix2BowersSerial,
+        z: F,
+        n: usize,
+    ) -> Vec<F> {
+        if n == 1 {
+            vec![F::ONE, F::NEG_ONE]
+        } else if n % 2 == 1 {
+            let mut prev = Self::geometric_sequence_linear_product_helper(dft, z, n - 1);
+            let zp = z.exp_u64((n - 1) as u64);
+            prev.push(F::ZERO);
+            for i in (1..prev.len()).rev() {
+                let value = prev[i - 1] * zp;
+                prev[i] -= value;
+            }
+            prev
+        } else {
+            let mut prev = Self::geometric_sequence_linear_product_helper(dft, z, n / 2);
+            let zp = z.exp_u64((n / 2) as u64);
+            let mut another = prev
+                .iter()
+                .zip(zp.powers())
+                .map(|(a, b)| *a * b)
+                .collect_vec();
+            let len = prev.len().next_power_of_two() * 2;
+            prev.resize(len, F::ZERO);
+            another.resize(len, F::ZERO);
+            prev = dft.dft(prev);
+            another = dft.dft(another);
+            for (x, y) in prev.iter_mut().zip(another.into_iter()) {
+                *x *= y;
+            }
+            prev = dft.idft(prev);
+            prev.truncate(n + 1);
+            prev
+        }
+    }
+
+    /// Constructs the polynomial in coefficient form from its evaluations on
+    /// `{omega^0,...,omega^d}` where `d` is the degree of the polynomial. Here `omega` is a
+    /// (fixed) generator of the two-adic subgroup of order `(d+1).next_power_of_two()`.
+    #[instrument(level = "debug", skip_all)]
+    pub fn from_evals(evals: &[F]) -> Self {
+        let n = evals.len();
+        let log_n = log2_ceil_usize(n);
+        let omega = F::two_adic_generator(log_n);
+        let omega_pows = omega.powers().take((1 << log_n) + 1).collect_vec();
+        if n == 0 {
+            return Self(Vec::new());
+        }
+        if n == 1 {
+            return Self(vec![evals[0]]);
+        }
+
+        // We know that, by Lagrange interpolation,
+        // P(x) = \sum_i evals[i] * \prod_{j\neq i} (x - omega^j) / (omega^i - omega^j).
+        // Let y[i] = evals[i] / (omega^{(n-1) * i} * prod_{j < n - 1 - i}(1 - omega^j) * prod_{j <
+        // i}(1 - omega^{-j})). Then P(x) = \sum_i y[i] * \prod_{j\neq i} (x - omega^j).
+
+        let mut positive_denoms = vec![F::ONE; n];
+        let mut negative_denoms = vec![F::ONE; n];
+        for i in 0..(n - 1) {
+            positive_denoms[i + 1] = positive_denoms[i] / (F::ONE - omega_pows[i + 1]);
+            negative_denoms[i + 1] =
+                negative_denoms[i] / (F::ONE - omega_pows[(1 << log_n) - 1 - i]);
+        }
+        let omega_inv = omega_pows[(1 << log_n) - 1];
+        let y = (0..n)
+            .map(|i| {
+                evals[i]
+                    * omega_inv.exp_u64(((n - 1) * i) as u64)
+                    * negative_denoms[i]
+                    * positive_denoms[n - 1 - i]
+            })
+            .collect_vec();
+
+        // If we reverse P and replace all (x - a) with (1 - ax), we'll still have an equality.
+        // So from now on we assume that P(x) = \sum_i y[i] * \prod_{j\neq i} (1 - omega^j * x).
+        // If we divide everything by Q(x) = \prod_i (1 - omega^i * x), then we'll have
+        // P(x) / Q(x) = \sum_i y[i] / (1 - omega^i * x).
+
+        // We want to find the first n coefficients of the right-hand side.
+        // [x^k](\sum_i y[i] / (1 - x * omega^i)) = \sum_i y[i] * omega^{ik} = Y(omega^k).
+
+        let mut rhs = Self::chirp_z(&y, omega, n);
+
+        // Now we need the denominator in the left-hand side.
+        let dft = Radix2BowersSerial;
+        let mut denom = Self::geometric_sequence_linear_product_helper(&dft, omega, n);
+
+        let len = (denom.len() + rhs.len() - 1).next_power_of_two();
+        denom.resize(len, F::ZERO);
+        rhs.resize(len, F::ZERO);
+        denom = dft.dft(denom);
+        rhs = dft.dft(rhs);
+        let res = denom.into_iter().zip(rhs).map(|(a, b)| a * b).collect_vec();
+        let mut res = dft.idft(res);
+        res.truncate(n);
+        // Remember that P(x) is reversed
+        res.reverse();
+        Self(res)
+    }
+
+    /// Constructs the polynomial in coefficient form from its evaluations on a smooth subgroup of
+    /// `F^*` by performing an inverse DFT.
+    ///
+    /// Requires that `evals.len()` is a power of 2.
+    pub fn from_evals_idft(evals: &[F]) -> Self {
+        // NOTE[jpw]: Use Bowers instead of Dit to avoid RefCell
+        let dft = Radix2BowersSerial;
+        let coeffs = dft.idft(evals.to_vec());
+        Self(coeffs)
+    }
+
+    /// Interpolates from evaluations on cosets `init * shift^i * D` for `i = 0,..,width-1` where
+    /// `D` is a smooth subgroup of F.
+    pub fn from_geometric_cosets_evals_idft<BF: Field>(
+        evals: RowMajorMatrix<F>,
+        shift: BF,
+        init: BF,
+    ) -> Self
+    where
+        F: ExtensionField<BF>,
+    {
+        let height = evals.height();
+        let width = evals.width();
+        if height == 0 || width == 0 {
+            return Self(Vec::new());
+        }
+
+        let log_height = log2_strict_usize(height);
+        let dft = Radix2BowersSerial;
+
+        // First interpolate within each coset (size `height`) to get the remainder
+        // modulo `X^height - shift^height`, then unshift coefficients by `(init * shift^i)^{-t}`.
+        let mut coeffs_mat = dft.idft_batch(evals);
+        let shift_inv = shift.inverse();
+        let init_inv = init.inverse();
+        // shift_invs[i] = (init * shift^i)^{-1} = init^{-1} * shift^{-i}
+        let shift_invs = (0..width)
+            .fold(
+                (Vec::with_capacity(width), init_inv),
+                |(mut acc, pow), _| {
+                    acc.push(pow);
+                    (acc, pow * shift_inv)
+                },
+            )
+            .0;
+        let mut shift_pows = vec![F::ONE; width];
+        for row in coeffs_mat.rows_mut() {
+            for (col, value) in row.iter_mut().enumerate() {
+                *value *= shift_pows[col];
+                shift_pows[col] *= shift_invs[col];
+            }
+        }
+
+        // Interpolate across cosets for each coefficient degree.
+        // Points are init^height, init^height * shift^height, ..., init^height *
+        // shift^{(width-1)*height}
+        let coset_base = shift.exp_power_of_2(log_height);
+        let init_base = init.exp_power_of_2(log_height);
+        let lagrange_basis = lagrange_basis_from_geometric_points(coset_base, width, init_base);
+        let mut coeffs = vec![F::ZERO; height * width];
+        for (row_idx, row_vals) in coeffs_mat.row_slices().enumerate() {
+            let mut poly_coeffs = vec![F::ZERO; width];
+            for (i, &value) in row_vals.iter().enumerate() {
+                if value == F::ZERO {
+                    continue;
+                }
+                for (k, basis_coeff) in lagrange_basis[i].iter().enumerate() {
+                    poly_coeffs[k] += value * *basis_coeff;
+                }
+            }
+            for (coset_idx, coeff) in poly_coeffs.into_iter().enumerate() {
+                coeffs[coset_idx * height + row_idx] = coeff;
+            }
+        }
+        Self(coeffs)
+    }
+}
+
+/// Evaluates a univariate polynomial using [Horner's method].
+///
+/// [Horner's method]: https://en.wikipedia.org/wiki/Horner%27s_method
+pub fn horner_eval<F1, F2, F3>(coeffs: &[F1], x: F2) -> F3
+where
+    F1: Field,
+    F2: Field,
+    F3: Field + Add<F1, Output = F3>,
+    F3: Mul<F2, Output = F3>,
+{
+    coeffs.iter().rfold(F3::ZERO, |acc, coeff| acc * x + *coeff)
+}
+
+/// Interpolates a linear polynomial through points (0, evals[0]), (1, evals[1])
+/// and evaluates it at x.
+#[inline(always)]
+pub fn interpolate_linear_at_01<F: Field>(evals: &[F; 2], x: F) -> F {
+    let p = evals[1] - evals[0];
+    p * x + evals[0]
+}
+
+/// Interpolates a quadratic polynomial through points (0, evals[0]), (1, evals[1]),
+/// (2, evals[2]) and evaluates it at x.
+#[inline(always)]
+pub fn interpolate_quadratic_at_012<F: Field>(evals: &[F; 3], x: F) -> F {
+    let s1 = evals[1] - evals[0];
+    let s2 = evals[2] - evals[1];
+    let p = (s2 - s1).halve();
+    let q = s1 - p;
+    (p * x + q) * x + evals[0]
+}
+
+/// Interpolates a cubic polynomial through points (0, evals[0]), (1, evals[1]),
+/// (2, evals[2]), (3, evals[3]) and evaluates it at x.
+#[inline(always)]
+pub fn interpolate_cubic_at_0123<F: Field>(evals: &[F; 4], x: F) -> F {
+    let inv6 = F::from_canonical_u64(6).inverse();
+
+    let s1 = evals[1] - evals[0];
+    let s2 = evals[2] - evals[0];
+    let s3 = evals[3] - evals[0];
+
+    let d3 = s3 - (s2 - s1) * F::from_canonical_u64(3);
+
+    let p = d3 * inv6;
+    let q = (s2 - d3).halve() - s1;
+    let r = s1 - p - q;
+
+    ((p * x + q) * x + r) * x + evals[0]
+}
+
+pub struct ExpPowers2<T> {
+    current: Option<T>,
+}
+
+impl<T: Squarable> Iterator for ExpPowers2<T> {
+    type Item = T;
+    fn next(&mut self) -> Option<T> {
+        if let Some(curr) = self.current.take() {
+            let next = curr.square();
+            self.current = Some(next);
+            Some(curr)
+        } else {
+            None
+        }
+    }
+}
+
+pub trait Squarable: FieldAlgebra + Clone {
+    #[inline]
+    fn exp_powers_of_2(&self) -> ExpPowers2<Self> {
+        ExpPowers2 {
+            current: Some(self.clone()),
+        }
+    }
+}
+
+impl<T: FieldAlgebra + Clone> Squarable for T {}
+
+/// Precompute Lagrange basis polynomials for interpolation at `init * base^i` for i=0..width-1.
+fn lagrange_basis_from_geometric_points<F: Field>(base: F, width: usize, init: F) -> Vec<Vec<F>> {
+    if width == 0 {
+        return Vec::new();
+    }
+    if width == 1 {
+        return vec![vec![F::ONE]];
+    }
+
+    // Points are init, init * base, init * base^2, ..., init * base^{width-1}
+    let points = (0..width)
+        .fold((Vec::with_capacity(width), init), |(mut acc, pow), _| {
+            acc.push(pow);
+            (acc, pow * base)
+        })
+        .0;
+
+    // Build the monic polynomial P(x) = ∏(x - points[i]).
+ let mut root_poly = vec![F::ONE]; + for &x in &points { + root_poly.push(F::ZERO); + for k in (1..root_poly.len()).rev() { + let prev = root_poly[k - 1]; + root_poly[k] = prev - x * root_poly[k]; + } + root_poly[0] = -x * root_poly[0]; + } + + // Precompute products of (1 - base^k) for k=1..width-1. + let mut prefix = vec![F::ONE; width]; + for (i, base_pow) in base.powers().skip(1).take(width - 1).enumerate() { + prefix[i + 1] = prefix[i] * (F::ONE - base_pow); + } + + // Compute P(x)/(x - points[i]) for each i and scale by the inverse denominator. + let mut quotients = Vec::with_capacity(width); + for (i, &x) in points.iter().enumerate() { + let mut q = vec![F::ZERO; width]; + q[width - 1] = root_poly[width]; + for k in (0..width - 1).rev() { + q[k] = root_poly[k + 1] + x * q[k + 1]; + } + + // Denominator is prod_{j != i} (points[i] - points[j]) + // = prod_{j != i} (init * base^i - init * base^j) + // = init^{width-1} * base^{i*(width-1)} * prod_{j != i} (1 - base^{j-i}) + // = init^{width-1} * base^{i*(width-1)} * prod_{k=1}^{i} (1 - base^{-k}) * + // prod_{k=1}^{width-1-i} (1 - base^k) For k=1..i: (1 - base^{-k}) = -base^{-k} * (1 + // - base^k) So prod_{k=1}^{i} (1 - base^{-k}) = (-1)^i * base^{-i*(i+1)/2} * + // prefix[i] + let sign = if i % 2 == 0 { F::ONE } else { F::NEG_ONE }; + let exp = i * (width - 1) - (i * (i + 1) / 2); + let pow = base.exp_u64(exp as u64); + let init_pow = init.exp_u64((width - 1) as u64); + let denom = sign * init_pow * pow * prefix[i] * prefix[width - 1 - i]; + let inv_denom = denom.inverse(); + for coeff in q.iter_mut() { + *coeff *= inv_denom; + } + quotients.push(q); + } + quotients +} + +#[cfg(test)] +mod tests { + use std::iter::zip; + + use itertools::Itertools; + use p3_field::{FieldAlgebra, TwoAdicField}; + use p3_util::{log2_ceil_usize, log2_strict_usize}; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + use super::*; + use crate::{EF, F}; + + #[test] + fn test_lagrange_interpolation_round_trip() { + let mut rng = StdRng::seed_from_u64(0); + // Test various polynomial degrees + for degree in 0..8usize { + let num_evals: usize = degree + 1; + + // Create random coefficients for a polynomial of given degree + let mut original_coeffs = vec![]; + for _ in 0..num_evals { + // Use deterministic values for reproducibility + original_coeffs.push(F::from_wrapped_u32(rng.random())); + } + + // Generate evaluation points (powers of omega) + let log_domain_size = log2_ceil_usize(num_evals); + let omega = F::two_adic_generator(log_domain_size); + + // Evaluate polynomial at these points + let mut evals = vec![]; + let points = omega.powers().take(num_evals).collect_vec(); + for &x in &points { + evals.push(horner_eval(&original_coeffs, x)); + } + + // Reconstruct polynomial from evaluations using Lagrange interpolation + let reconstructed_poly = UnivariatePoly::lagrange_interpolate(&points, &evals); + + // Verify coefficients match (up to the original degree) + for (i, (coeff, reconstructed_coeff)) in + zip(original_coeffs, reconstructed_poly.0).enumerate() + { + assert_eq!( + coeff, reconstructed_coeff, + "Coefficient mismatch at index {} for degree {} polynomial", + i, degree + ); + } + } + } + + #[test] + fn test_chirp_z() { + let mut rng = StdRng::seed_from_u64(0); + for degree in 0..8usize { + let num_evals: usize = degree + 1; + + // Create random coefficients for a polynomial of given degree + let mut original_coeffs = vec![]; + for _ in 0..num_evals { + // Use deterministic values for reproducibility + 
original_coeffs.push(F::from_wrapped_u32(rng.random())); + } + + let log_domain_size = log2_ceil_usize(num_evals); + let omega = F::two_adic_generator(log_domain_size); + + // Evaluate polynomial at these points + let mut evals = vec![]; + let points = omega.powers().take(num_evals).collect_vec(); + for &x in &points { + evals.push(horner_eval(&original_coeffs, x)); + } + + assert_eq!( + evals, + UnivariatePoly::chirp_z(&original_coeffs, omega, num_evals) + ); + + let reconstructed_poly = UnivariatePoly::from_evals(&evals); + + // Verify coefficients match (up to the original degree) + for (i, (coeff, reconstructed_coeff)) in + zip(original_coeffs, reconstructed_poly.0).enumerate() + { + assert_eq!( + coeff, reconstructed_coeff, + "Coefficient mismatch at index {} for degree {} polynomial", + i, degree + ); + } + } + } + + #[test] + fn test_interpolate_linear() { + let evals = [ + EF::from_canonical_u64(20), // s(0) + EF::from_canonical_u64(10), // s(1) + ]; + + // Test interpolation at known points + assert_eq!(interpolate_linear_at_01(&evals, EF::ZERO), evals[0]); + assert_eq!(interpolate_linear_at_01(&evals, EF::ONE), evals[1]); + } + + #[test] + fn test_interpolate_quadratic() { + let evals = [ + EF::from_canonical_u64(20), // s(0) + EF::from_canonical_u64(10), // s(1) + EF::from_canonical_u64(18), // s(2) + ]; + + // Test interpolation at known points + assert_eq!(interpolate_quadratic_at_012(&evals, EF::ZERO), evals[0]); + assert_eq!(interpolate_quadratic_at_012(&evals, EF::ONE), evals[1]); + assert_eq!( + interpolate_quadratic_at_012(&evals, EF::from_canonical_u64(2)), + evals[2] + ); + } + + #[test] + fn test_interpolate_cubic() { + let evals = [ + EF::from_canonical_u64(20), // s(0) + EF::from_canonical_u64(10), // s(1) + EF::from_canonical_u64(18), // s(2) + EF::from_canonical_u64(28), // s(3) + ]; + + // Test interpolation at known points + assert_eq!(interpolate_cubic_at_0123(&evals, EF::ZERO), evals[0]); + assert_eq!(interpolate_cubic_at_0123(&evals, EF::ONE), evals[1]); + assert_eq!( + interpolate_cubic_at_0123(&evals, EF::from_canonical_u64(2)), + evals[2] + ); + assert_eq!( + interpolate_cubic_at_0123(&evals, EF::from_canonical_u64(3)), + evals[3] + ); + } + + #[test] + fn test_exp_powers_of_2() { + let x = F::from_canonical_u32(3); + let s = x.exp_powers_of_2().take(3).collect_vec(); + assert_eq!(s, vec![x, x * x, x * x * x * x],); + } + + #[test] + fn test_eval_in_uni() { + let l = 3; + let n = -2; + let u_0 = F::from_canonical_u32(12345); + let ind = eval_in_uni(l, n, u_0); + let expected = (u_0.exp_power_of_2(l) - F::ONE) + * (u_0.exp_power_of_2(l.wrapping_add_signed(n)) - F::ONE).inverse() + * F::from_canonical_usize(1 << n.unsigned_abs()).inverse(); + assert_eq!(ind, expected); + } + + #[test] + fn test_from_geometric_cosets_evals_idft_round_trip() { + let mut rng = StdRng::seed_from_u64(0); + let height = 8usize; + let log_height = log2_strict_usize(height); + let omega = F::two_adic_generator(log_height); + + let configs = [ + (F::GENERATOR, F::ONE), + (F::two_adic_generator(log_height + 2), F::GENERATOR), + ]; + + for (shift, init) in configs { + for width in 2..=4usize { + let coeffs = (0..height * width) + .map(|_| F::from_wrapped_u32(rng.random())) + .collect_vec(); + let coeffs_ref = &coeffs; + + // Evaluations on cosets init * shift^i * D for i = 0, ..., width - 1 + let evals: Vec = (0..height) + .flat_map(|row| { + let omega_pow = omega.exp_u64(row as u64); + (0..width).map(move |col| { + let coset_shift = init * shift.exp_u64(col as u64); + 
horner_eval(coeffs_ref, coset_shift * omega_pow) + }) + }) + .collect_vec(); + + let evals_mat = RowMajorMatrix::new(evals, width); + let poly = UnivariatePoly::from_geometric_cosets_evals_idft(evals_mat, shift, init); + assert_eq!( + poly.into_coeffs(), + coeffs, + "width {width} round-trip failed" + ); + } + } + } +} diff --git a/crates/stark-backend-v2/src/poseidon2/instance_babybear.rs b/crates/stark-backend-v2/src/poseidon2/instance_babybear.rs new file mode 100644 index 00000000..909a86c0 --- /dev/null +++ b/crates/stark-backend-v2/src/poseidon2/instance_babybear.rs @@ -0,0 +1,382 @@ +use hex_literal::hex; + +pub const RC16: &[&[u32]] = &[ + &[ + u32::from_be_bytes(hex!("69cbb6af")), + u32::from_be_bytes(hex!("46ad93f9")), + u32::from_be_bytes(hex!("60a00f4e")), + u32::from_be_bytes(hex!("6b1297cd")), + u32::from_be_bytes(hex!("23189afe")), + u32::from_be_bytes(hex!("732e7bef")), + u32::from_be_bytes(hex!("72c246de")), + u32::from_be_bytes(hex!("2c941900")), + u32::from_be_bytes(hex!("0557eede")), + u32::from_be_bytes(hex!("1580496f")), + u32::from_be_bytes(hex!("3a3ea77b")), + u32::from_be_bytes(hex!("54f3f271")), + u32::from_be_bytes(hex!("0f49b029")), + u32::from_be_bytes(hex!("47872fe1")), + u32::from_be_bytes(hex!("221e2e36")), + u32::from_be_bytes(hex!("1ab7202e")), + ], + &[ + u32::from_be_bytes(hex!("487779a6")), + u32::from_be_bytes(hex!("3851c9d8")), + u32::from_be_bytes(hex!("38dc17c0")), + u32::from_be_bytes(hex!("209f8849")), + u32::from_be_bytes(hex!("268dcee8")), + u32::from_be_bytes(hex!("350c48da")), + u32::from_be_bytes(hex!("5b9ad32e")), + u32::from_be_bytes(hex!("0523272b")), + u32::from_be_bytes(hex!("3f89055b")), + u32::from_be_bytes(hex!("01e894b2")), + u32::from_be_bytes(hex!("13ddedde")), + u32::from_be_bytes(hex!("1b2ef334")), + u32::from_be_bytes(hex!("7507d8b4")), + u32::from_be_bytes(hex!("6ceeb94e")), + u32::from_be_bytes(hex!("52eb6ba2")), + u32::from_be_bytes(hex!("50642905")), + ], + &[ + u32::from_be_bytes(hex!("05453f3f")), + u32::from_be_bytes(hex!("06349efc")), + u32::from_be_bytes(hex!("6922787c")), + u32::from_be_bytes(hex!("04bfff9c")), + u32::from_be_bytes(hex!("768c714a")), + u32::from_be_bytes(hex!("3e9ff21a")), + u32::from_be_bytes(hex!("15737c9c")), + u32::from_be_bytes(hex!("2229c807")), + u32::from_be_bytes(hex!("0d47f88c")), + u32::from_be_bytes(hex!("097e0ecc")), + u32::from_be_bytes(hex!("27eadba0")), + u32::from_be_bytes(hex!("2d7d29e4")), + u32::from_be_bytes(hex!("3502aaa0")), + u32::from_be_bytes(hex!("0f475fd7")), + u32::from_be_bytes(hex!("29fbda49")), + u32::from_be_bytes(hex!("018afffd")), + ], + &[ + u32::from_be_bytes(hex!("0315b618")), + u32::from_be_bytes(hex!("6d4497d1")), + u32::from_be_bytes(hex!("1b171d9e")), + u32::from_be_bytes(hex!("52861abd")), + u32::from_be_bytes(hex!("2e5d0501")), + u32::from_be_bytes(hex!("3ec8646c")), + u32::from_be_bytes(hex!("6e5f250a")), + u32::from_be_bytes(hex!("148ae8e6")), + u32::from_be_bytes(hex!("17f5fa4a")), + u32::from_be_bytes(hex!("3e66d284")), + u32::from_be_bytes(hex!("0051aa3b")), + u32::from_be_bytes(hex!("483f7913")), + u32::from_be_bytes(hex!("2cfe5f15")), + u32::from_be_bytes(hex!("023427ca")), + u32::from_be_bytes(hex!("2cc78315")), + u32::from_be_bytes(hex!("1e36ea47")), + ], + &[ + u32::from_be_bytes(hex!("5a8053c0")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + 
u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("693be639")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("3858867d")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("19334f6b")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("128f0fd8")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("4e2b1ccb")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + 
u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("61210ce0")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("3c318939")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("0b5b2f22")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("2edb11d5")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("213effdf")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("0cac4606")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + 
u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("241af16d")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + u32::from_be_bytes(hex!("00000000")), + ], + &[ + u32::from_be_bytes(hex!("7290a80d")), + u32::from_be_bytes(hex!("6f7e5329")), + u32::from_be_bytes(hex!("598ec8a8")), + u32::from_be_bytes(hex!("76a859a0")), + u32::from_be_bytes(hex!("6559e868")), + u32::from_be_bytes(hex!("657b83af")), + u32::from_be_bytes(hex!("13271d3f")), + u32::from_be_bytes(hex!("1f876063")), + u32::from_be_bytes(hex!("0aeeae37")), + u32::from_be_bytes(hex!("706e9ca6")), + u32::from_be_bytes(hex!("46400cee")), + u32::from_be_bytes(hex!("72a05c26")), + u32::from_be_bytes(hex!("2c589c9e")), + u32::from_be_bytes(hex!("20bd37a7")), + u32::from_be_bytes(hex!("6a2d3d10")), + u32::from_be_bytes(hex!("20523767")), + ], + &[ + u32::from_be_bytes(hex!("5b8fe9c4")), + u32::from_be_bytes(hex!("2aa501d6")), + u32::from_be_bytes(hex!("1e01ac3e")), + u32::from_be_bytes(hex!("1448bc54")), + u32::from_be_bytes(hex!("5ce5ad1c")), + u32::from_be_bytes(hex!("4918a14d")), + u32::from_be_bytes(hex!("2c46a83f")), + u32::from_be_bytes(hex!("4fcf6876")), + u32::from_be_bytes(hex!("61d8d5c8")), + u32::from_be_bytes(hex!("6ddf4ff9")), + u32::from_be_bytes(hex!("11fda4d3")), + u32::from_be_bytes(hex!("02933a8f")), + u32::from_be_bytes(hex!("170eaf81")), + u32::from_be_bytes(hex!("5a9c314f")), + u32::from_be_bytes(hex!("49a12590")), + u32::from_be_bytes(hex!("35ec52a1")), + ], + &[ + u32::from_be_bytes(hex!("58eb1611")), + u32::from_be_bytes(hex!("5e481e65")), + u32::from_be_bytes(hex!("367125c9")), + u32::from_be_bytes(hex!("0eba33ba")), + u32::from_be_bytes(hex!("1fc28ded")), + u32::from_be_bytes(hex!("066399ad")), + u32::from_be_bytes(hex!("0cbec0ea")), + u32::from_be_bytes(hex!("75fd1af0")), + u32::from_be_bytes(hex!("50f5bf4e")), + u32::from_be_bytes(hex!("643d5f41")), + u32::from_be_bytes(hex!("6f4fe718")), + u32::from_be_bytes(hex!("5b3cbbde")), + u32::from_be_bytes(hex!("1e3afb3e")), + u32::from_be_bytes(hex!("296fb027")), + u32::from_be_bytes(hex!("45e1547b")), + u32::from_be_bytes(hex!("4a8db2ab")), + ], + &[ + u32::from_be_bytes(hex!("59986d19")), + u32::from_be_bytes(hex!("30bcdfa3")), + u32::from_be_bytes(hex!("1db63932")), + u32::from_be_bytes(hex!("1d7c2824")), + u32::from_be_bytes(hex!("53b33681")), + u32::from_be_bytes(hex!("0673b747")), + u32::from_be_bytes(hex!("038a98a3")), + u32::from_be_bytes(hex!("2c5bce60")), + u32::from_be_bytes(hex!("351979cd")), + u32::from_be_bytes(hex!("5008fb73")), + u32::from_be_bytes(hex!("547bca78")), + u32::from_be_bytes(hex!("711af481")), + 
+        u32::from_be_bytes(hex!("3f93bf64")),
+        u32::from_be_bytes(hex!("644d987b")),
+        u32::from_be_bytes(hex!("3c8bcd87")),
+        u32::from_be_bytes(hex!("608758b8")),
+    ],
+];
diff --git a/crates/stark-backend-v2/src/poseidon2/mod.rs b/crates/stark-backend-v2/src/poseidon2/mod.rs
new file mode 100644
index 00000000..6a8e10c7
--- /dev/null
+++ b/crates/stark-backend-v2/src/poseidon2/mod.rs
@@ -0,0 +1,57 @@
+use std::sync::OnceLock;
+
+use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
+use p3_field::FieldAlgebra;
+use p3_poseidon2::ExternalLayerConstants;
+
+mod instance_babybear;
+pub mod sponge;
+
+pub use instance_babybear::*;
+
+pub const WIDTH: usize = 16;
+pub const CHUNK: usize = 8;
+
+// Fixed Poseidon2 configuration
+pub fn poseidon2_perm() -> &'static Poseidon2BabyBear<WIDTH> {
+    static PERM: OnceLock<Poseidon2BabyBear<WIDTH>> = OnceLock::new();
+    PERM.get_or_init(|| {
+        let (external_constants, internal_constants) = horizen_round_consts_16();
+        Poseidon2BabyBear::new(external_constants, internal_constants)
+    })
+}
+
+pub fn horizen_round_consts_16() -> (ExternalLayerConstants<BabyBear, 16>, Vec<BabyBear>) {
+    let p3_rc16: Vec<Vec<BabyBear>> = RC16
+        .iter()
+        .map(|round| {
+            round
+                .iter()
+                .map(|&u32_canonical| BabyBear::from_wrapped_u32(u32_canonical))
+                .collect()
+        })
+        .collect();
+
+    let rounds_f = 8;
+    let rounds_p = 13;
+    let rounds_f_beginning = rounds_f / 2;
+    let p_end = rounds_f_beginning + rounds_p;
+    let initial: Vec<[BabyBear; 16]> = p3_rc16[..rounds_f_beginning]
+        .iter()
+        .cloned()
+        .map(|round| round.try_into().unwrap())
+        .collect();
+    let terminal: Vec<[BabyBear; 16]> = p3_rc16[p_end..]
+        .iter()
+        .cloned()
+        .map(|round| round.try_into().unwrap())
+        .collect();
+    let internal_round_constants: Vec<BabyBear> = p3_rc16[rounds_f_beginning..p_end]
+        .iter()
+        .map(|round| round[0])
+        .collect();
+    (
+        ExternalLayerConstants::new(initial, terminal),
+        internal_round_constants,
+    )
+}
diff --git a/crates/stark-backend-v2/src/poseidon2/sponge.rs b/crates/stark-backend-v2/src/poseidon2/sponge.rs
new file mode 100644
index 00000000..e5ee82a9
--- /dev/null
+++ b/crates/stark-backend-v2/src/poseidon2/sponge.rs
@@ -0,0 +1,477 @@
+use core::{array::from_fn, ops::Deref};
+
+use p3_baby_bear::Poseidon2BabyBear;
+use p3_challenger::CanObserve;
+use p3_field::{FieldAlgebra, FieldExtensionAlgebra, PrimeField32};
+use p3_maybe_rayon::prelude::*;
+use p3_symmetric::Permutation;
+use tracing::instrument;
+
+use super::{poseidon2_perm, CHUNK, WIDTH};
+use crate::{Digest, D_EF, EF, F};
+
+pub trait FiatShamirTranscript: Clone + Send + Sync {
+    fn observe(&mut self, value: F);
+    fn sample(&mut self) -> F;
+
+    fn observe_commit(&mut self, digest: [F; CHUNK]) {
+        for x in digest {
+            self.observe(x);
+        }
+    }
+
+    fn observe_ext(&mut self, value: EF) {
+        // Observe each of the D_EF base-field coefficients.
+        for &base_val in value.as_base_slice() {
+            self.observe(base_val);
+        }
+    }
+
+    fn sample_ext(&mut self) -> EF {
+        let slice: [F; D_EF] = from_fn(|_| self.sample());
+        EF::from_base_slice(&slice)
+    }
+
+    fn sample_bits(&mut self, bits: usize) -> u32 {
+        assert!(bits < (u32::BITS as usize));
+        assert!((1 << bits) < F::ORDER_U32);
+        let rand_f: F = self.sample();
+        let rand_u32 = rand_f.as_canonical_u32();
+        rand_u32 & ((1 << bits) - 1)
+    }
+
+    #[must_use]
+    fn check_witness(&mut self, bits: usize, witness: F) -> bool {
+        self.observe(witness);
+        self.sample_bits(bits) == 0
+    }
+
+    #[instrument(name = "grind_pow", skip_all)]
+    fn grind(&mut self, bits: usize) -> F {
+        assert!(bits < (u32::BITS as usize));
+        assert!((1u32 << bits) < F::ORDER_U32);
+
+        let witness = (0..F::ORDER_U32)
+            .into_par_iter()
+            .map(F::from_canonical_u32)
+            .find_any(|witness| self.clone().check_witness(bits, *witness))
+            .expect("failed to find PoW witness");
+        assert!(self.check_witness(bits, witness));
+        witness
+    }
+}
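+// Illustrative usage (not from the original source): `grind` searches for a
+// witness w such that observing w makes the next `bits`-bit sample zero, and
+// the verifier replays the same two transcript operations via `check_witness`:
+//
+//     let mut prover_ts = DuplexSponge::default();
+//     let w = prover_ts.grind(12);
+//     let mut verifier_ts = DuplexSponge::default();
+//     assert!(verifier_ts.check_witness(12, w));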
+
+pub trait TranscriptHistory {
+    fn len(&self) -> usize;
+    fn into_log(self) -> TranscriptLog;
+
+    fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct TranscriptLog {
+    /// Every sampled or observed value in F
+    values: Vec<F>,
+    /// True iff `values[i]` was a sampled value
+    is_sample: Vec<bool>,
+    /// Sponge state after every permutation; note that not all implementations of
+    /// TranscriptHistory will define this
+    perm_results: Vec<[F; WIDTH]>,
+}
+
+impl TranscriptLog {
+    pub fn new(values: Vec<F>, is_sample: Vec<bool>) -> Self {
+        debug_assert_eq!(values.len(), is_sample.len());
+        Self {
+            values,
+            is_sample,
+            perm_results: vec![],
+        }
+    }
+
+    pub fn values(&self) -> &[F] {
+        &self.values
+    }
+
+    pub fn values_mut(&mut self) -> &mut [F] {
+        &mut self.values
+    }
+
+    pub fn samples(&self) -> &[bool] {
+        &self.is_sample
+    }
+
+    pub fn samples_mut(&mut self) -> &mut [bool] {
+        &mut self.is_sample
+    }
+
+    pub fn push_observe(&mut self, value: F) {
+        self.values.push(value);
+        self.is_sample.push(false);
+    }
+
+    pub fn push_sample(&mut self, value: F) {
+        self.values.push(value);
+        self.is_sample.push(true);
+    }
+
+    pub fn push_perm_result(&mut self, state: [F; WIDTH]) {
+        self.perm_results.push(state);
+    }
+
+    pub fn extend_observe(&mut self, values: &[F]) {
+        self.values.extend_from_slice(values);
+        self.is_sample
+            .extend(core::iter::repeat_n(false, values.len()));
+    }
+
+    pub fn extend_sample(&mut self, values: &[F]) {
+        self.values.extend_from_slice(values);
+        self.is_sample
+            .extend(core::iter::repeat_n(true, values.len()));
+    }
+
+    pub fn extend_with_flags(&mut self, values: &[F], sample_flags: &[bool]) {
+        debug_assert_eq!(values.len(), sample_flags.len());
+        self.values.extend_from_slice(values);
+        self.is_sample.extend(sample_flags.iter().copied());
+    }
+
+    pub fn len(&self) -> usize {
+        self.values.len()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.values.is_empty()
+    }
+
+    pub fn into_parts(self) -> (Vec<F>, Vec<bool>) {
+        (self.values, self.is_sample)
+    }
+
+    pub fn perm_results(&self) -> &Vec<[F; WIDTH]> {
+        &self.perm_results
+    }
+}
+
+impl Deref for TranscriptLog {
+    type Target = [F];
+
+    fn deref(&self) -> &Self::Target {
+        &self.values
+    }
+}
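+// Note (illustrative, based on the constants above): with WIDTH = 16 and
+// CHUNK = 8 the sponge below has rate 8 and capacity 8. Observes overwrite
+// state[0..8] and samples read back out of the same rate region, while
+// state[8..16] is only ever mixed in by the permutation, never exposed.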
+/// Poseidon2-based duplex sponge in overwrite mode.
+///
+/// "Duplex" refers to being able to alternately absorb (observe) and squeeze
+/// (sample), rather than a single absorb phase followed by a single squeeze
+/// phase.
+///
+/// This variant operates in *overwrite mode*, meaning new inputs overwrite
+/// state elements directly (instead of, e.g., being added in).
+#[derive(Clone, Debug)]
+pub struct DuplexSponge {
+    perm: Poseidon2BabyBear<WIDTH>,
+    /// Poseidon2 state
+    state: [F; WIDTH],
+    /// Invariant to be preserved: 0 <= absorb_idx < CHUNK
+    absorb_idx: usize,
+    /// Invariant to be preserved: 0 <= sample_idx <= CHUNK
+    sample_idx: usize,
+    /// True iff the last sample/observe triggered a permutation
+    last_op_perm: bool,
+}
+
+impl Default for DuplexSponge {
+    fn default() -> Self {
+        Self {
+            perm: poseidon2_perm().clone(),
+            state: [F::ZERO; WIDTH],
+            absorb_idx: 0,
+            sample_idx: 0,
+            last_op_perm: false,
+        }
+    }
+}
+
+impl FiatShamirTranscript for DuplexSponge {
+    fn observe(&mut self, value: F) {
+        self.state[self.absorb_idx] = value;
+        self.absorb_idx += 1;
+        self.last_op_perm = self.absorb_idx == CHUNK;
+        if self.last_op_perm {
+            self.perm.permute_mut(&mut self.state);
+            self.absorb_idx = 0;
+            self.sample_idx = CHUNK;
+        }
+    }
+
+    fn sample(&mut self) -> F {
+        self.last_op_perm = self.absorb_idx != 0 || self.sample_idx == 0;
+        if self.last_op_perm {
+            self.perm.permute_mut(&mut self.state);
+            self.absorb_idx = 0;
+            self.sample_idx = CHUNK;
+        }
+        self.sample_idx -= 1;
+        self.state[self.sample_idx]
+    }
+}
+
+impl CanObserve<F> for DuplexSponge {
+    fn observe(&mut self, value: F) {
+        FiatShamirTranscript::observe(self, value);
+    }
+}
+
+impl CanObserve<Digest> for DuplexSponge {
+    fn observe(&mut self, digest: Digest) {
+        FiatShamirTranscript::observe_commit(self, digest);
+    }
+}
+
+pub fn poseidon2_hash_slice(vals: &[F]) -> [F; CHUNK] {
+    let perm = poseidon2_perm();
+    let mut state = [F::ZERO; WIDTH];
+    let mut i = 0;
+    for &val in vals {
+        state[i] = val;
+        i += 1;
+        if i == CHUNK {
+            perm.permute_mut(&mut state);
+            i = 0;
+        }
+    }
+    if i != 0 {
+        perm.permute_mut(&mut state);
+    }
+    state[..CHUNK].try_into().unwrap()
+}
+
+pub fn poseidon2_compress_with_capacity(
+    left: [F; CHUNK],
+    right: [F; CHUNK],
+) -> ([F; CHUNK], [F; CHUNK]) {
+    let mut state = [F::ZERO; WIDTH];
+    state[..CHUNK].copy_from_slice(&left);
+    state[CHUNK..].copy_from_slice(&right);
+    poseidon2_perm().permute_mut(&mut state);
+    (
+        state[..CHUNK].try_into().unwrap(),
+        state[CHUNK..].try_into().unwrap(),
+    )
+}
+
+pub fn poseidon2_compress(left: [F; CHUNK], right: [F; CHUNK]) -> [F; CHUNK] {
+    poseidon2_compress_with_capacity(left, right).0
+}
+
+pub fn poseidon2_tree_compress(mut hashes: Vec<Digest>) -> Digest {
+    debug_assert!(hashes.len().is_power_of_two());
+    while hashes.len() > 1 {
+        let mut next = Vec::with_capacity(hashes.len() / 2);
+        for pair in hashes.chunks_exact(2) {
+            next.push(poseidon2_compress(pair[0], pair[1]));
+        }
+        hashes = next;
+    }
+    hashes.pop().unwrap()
+}
+
+#[derive(Clone)]
+pub struct DuplexSpongeRecorder {
+    pub inner: DuplexSponge,
+    pub log: TranscriptLog,
+}
+
+impl Default for DuplexSpongeRecorder {
+    fn default() -> Self {
+        let mut log = TranscriptLog::default();
+        log.push_perm_result([F::ZERO; WIDTH]);
+        Self {
+            inner: Default::default(),
+            log,
+        }
+    }
+}
+
+impl FiatShamirTranscript for DuplexSpongeRecorder {
+    fn observe(&mut self, x: F) {
+        <DuplexSponge as FiatShamirTranscript>::observe(&mut self.inner, x);
+        self.log.push_observe(x);
+        if self.inner.last_op_perm {
+            self.log.push_perm_result(self.inner.state);
+        }
+    }
+
+    fn sample(&mut self) -> F {
+        let x = self.inner.sample();
+        self.log.push_sample(x);
+        if self.inner.last_op_perm {
+            self.log.push_perm_result(self.inner.state);
+        }
+        x
+    }
+}
+
+impl TranscriptHistory for DuplexSpongeRecorder {
+    fn len(&self) -> usize {
+        self.log.len()
+    }
+
+    fn into_log(self) -> TranscriptLog {
+        self.log
+    }
+}
+
+/// Read-only transcript that
replays a recorded log. +#[derive(Clone, Debug)] +pub struct ReadOnlyTranscript<'a> { + log: &'a TranscriptLog, + position: usize, +} + +impl<'a> ReadOnlyTranscript<'a> { + pub fn new(log: &'a TranscriptLog, start_idx: usize) -> Self { + debug_assert!(start_idx <= log.len(), "start index out of bounds"); + Self { + log, + position: start_idx, + } + } +} + +impl FiatShamirTranscript for ReadOnlyTranscript<'_> { + #[inline] + fn observe(&mut self, value: F) { + debug_assert!( + !self.log.samples()[self.position], + "expected observe at {}", + self.position + ); + debug_assert_eq!( + self.log.values()[self.position], + value, + "value mismatch at {}", + self.position + ); + self.position += 1; + } + + #[inline] + fn sample(&mut self) -> F { + debug_assert!( + self.log.samples()[self.position], + "expected sample at {}", + self.position + ); + let value = self.log.values()[self.position]; + self.position += 1; + value + } +} + +impl TranscriptHistory for ReadOnlyTranscript<'_> { + fn len(&self) -> usize { + self.position + } + + fn into_log(self) -> TranscriptLog { + self.log.clone() + } +} + +#[cfg(test)] +mod test { + use openvm_stark_sdk::config::baby_bear_poseidon2::Challenger; + use p3_baby_bear::BabyBear; + use p3_challenger::{CanObserve, CanSample}; + use p3_field::FieldAlgebra; + + use crate::poseidon2::{ + poseidon2_perm, + sponge::{ + DuplexSponge, DuplexSpongeRecorder, FiatShamirTranscript, ReadOnlyTranscript, + TranscriptHistory, + }, + }; + + #[test] + fn test_sponge() { + let perm = poseidon2_perm(); + + let mut challenger = Challenger::new(perm.clone()); + let mut sponge = DuplexSponge::default(); + + for i in 0..5 { + for _ in 0..(i + 1) * i { + let a: BabyBear = challenger.sample(); + let b = sponge.sample(); + assert_eq!(a, b); + } + + for j in 0..i * i { + challenger.observe(BabyBear::from_canonical_usize(j)); + FiatShamirTranscript::observe(&mut sponge, BabyBear::from_canonical_usize(j)); + } + } + } + + #[test] + fn test_read_only_transcript() { + // Record a sequence of operations + let mut recorder = DuplexSpongeRecorder::default(); + recorder.observe(BabyBear::from_canonical_u32(42)); + recorder.observe(BabyBear::from_canonical_u32(100)); + let s1 = recorder.sample(); + recorder.observe(BabyBear::from_canonical_u32(200)); + let s2 = recorder.sample(); + let s3 = recorder.sample(); + + let log = recorder.into_log(); + + // Replay from start + let mut replay = ReadOnlyTranscript::new(&log, 0); + replay.observe(BabyBear::from_canonical_u32(42)); + replay.observe(BabyBear::from_canonical_u32(100)); + assert_eq!(replay.sample(), s1); + replay.observe(BabyBear::from_canonical_u32(200)); + assert_eq!(replay.sample(), s2); + assert_eq!(replay.sample(), s3); + assert_eq!(replay.len(), 6); + + // Replay from middle + let mut replay2 = ReadOnlyTranscript::new(&log, 2); + assert_eq!(replay2.sample(), s1); + replay2.observe(BabyBear::from_canonical_u32(200)); + assert_eq!(replay2.sample(), s2); + assert_eq!(replay2.len(), 5); + } + + #[test] + #[cfg(debug_assertions)] + #[should_panic(expected = "expected observe at 0")] + fn test_read_only_transcript_wrong_operation() { + let mut recorder = DuplexSpongeRecorder::default(); + let _ = recorder.sample(); + let log = recorder.into_log(); + + let mut replay = ReadOnlyTranscript::new(&log, 0); + replay.observe(BabyBear::from_canonical_u32(42)); // Should panic + } + + #[test] + #[cfg(debug_assertions)] + #[should_panic(expected = "value mismatch at 0")] + fn test_read_only_transcript_wrong_value() { + let mut recorder = 
DuplexSpongeRecorder::default();
+        recorder.observe(BabyBear::from_canonical_u32(42));
+        let log = recorder.into_log();
+
+        let mut replay = ReadOnlyTranscript::new(&log, 0);
+        replay.observe(BabyBear::from_canonical_u32(99)); // Should panic
+    }
+}
diff --git a/crates/stark-backend-v2/src/proof.rs b/crates/stark-backend-v2/src/proof.rs
new file mode 100644
index 00000000..e9408af8
--- /dev/null
+++ b/crates/stark-backend-v2/src/proof.rs
@@ -0,0 +1,577 @@
+use std::io::{Error, Read, Result, Write};
+
+use p3_field::FieldAlgebra;
+use serde::{Deserialize, Serialize};
+
+use crate::{
+    codec::{decode_into_vec, encode_iter, Decode, Encode},
+    Digest, EF, F,
+};
+
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+pub struct Proof {
+    /// The commitment to the data in common_main.
+    pub common_main_commit: Digest,
+
+    /// For each AIR in vkey order, the corresponding trace shape, or None if
+    /// the trace is empty. In a valid proof, if `vk.per_air[i].is_required`,
+    /// then `trace_vdata[i]` must be `Some(_)`.
+    pub trace_vdata: Vec<Option<TraceVData>>,
+
+    /// For each AIR in vkey order, the public values. Public values should be empty if the AIR
+    /// has an empty trace.
+    pub public_values: Vec<Vec<F>>,
+
+    pub gkr_proof: GkrProof,
+    pub batch_constraint_proof: BatchConstraintProof,
+    pub stacking_proof: StackingProof,
+    pub whir_proof: WhirProof,
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize, Encode, Decode)]
+pub struct TraceVData {
+    /// The base 2 logarithm of the trace height. This should be a nonnegative integer and is
+    /// allowed to be `< l_skip`.
+    ///
+    /// If the corresponding AIR has a preprocessed trace, this must match the
+    /// value in the vkey.
+    pub log_height: usize,
+    /// The cached commitments used.
+    ///
+    /// The length must match the value in the vkey.
+    pub cached_commitments: Vec<Digest>,
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+pub struct GkrProof {
+    // TODO[jpw]: I'm not sure this is conceptually the place to put it, but the recursion gkr
+    // module samples alpha, beta
+    pub logup_pow_witness: F,
+    /// The denominator of the root layer.
+    ///
+    /// Note that the numerator claim is always zero, so we don't include it in
+    /// the proof. Even though the numerator is zero, the representation of the
+    /// denominator matters for the verification procedure and thus must be
+    /// provided.
+    pub q0_claim: EF,
+    /// The claims for p_j(xi, 0), p_j(xi, 1), q_j(xi, 0), and q_j(xi, 1) for each layer j > 0.
+    pub claims_per_layer: Vec<GkrLayerClaims>,
+    /// The sumcheck polynomials for each layer, for each sumcheck round, given by their
+    /// evaluations on {1, 2, 3}.
+    pub sumcheck_polys: Vec<Vec<[EF; 3]>>,
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Encode, Decode)]
+pub struct GkrLayerClaims {
+    pub p_xi_0: EF,
+    pub p_xi_1: EF,
+    pub q_xi_0: EF,
+    pub q_xi_1: EF,
+}
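+// Note (illustrative; the exact layer relation is assumed from the standard
+// logup GKR construction): each layer's fraction aggregates combine as
+//     p_j(xi) = p_{j+1}(xi, 0) * q_{j+1}(xi, 1) + p_{j+1}(xi, 1) * q_{j+1}(xi, 0)
+//     q_j(xi) = q_{j+1}(xi, 0) * q_{j+1}(xi, 1)
+// so `GkrLayerClaims` carries exactly the four openings the verifier needs to
+// step from one layer's claim to the next.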
+
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+pub struct BatchConstraintProof {
+    /// The terms \textnormal{sum}_{\hat{p}, T, I} as defined in Protocol 3.4.6, per present AIR
+    /// **in sorted AIR order**.
+    pub numerator_term_per_air: Vec<EF>,
+    /// The terms \textnormal{sum}_{\hat{q}, T, I} as defined in Protocol 3.4.6, per present AIR
+    /// **in sorted AIR order**.
+    pub denominator_term_per_air: Vec<EF>,
+
+    /// Polynomial for the initial round, given by `(vk.d + 1) * (2^{l_skip} - 1) + 1`
+    /// coefficients.
+    pub univariate_round_coeffs: Vec<EF>,
+    /// For rounds `1, ..., n_max`; evaluations on `{1, ..., vk.d + 1}`.
+    pub sumcheck_round_polys: Vec<Vec<EF>>,
+
+    /// Per AIR **in sorted AIR order**, per AIR part, per column index in that part, openings
+    /// for the prismalinear column polynomial and (optionally) its rotational convolution. All
+    /// column openings are stored flat: either the column openings alone, or the column
+    /// openings interleaved with their rotations. The trace parts are ordered: [CommonMain
+    /// (part 0), Preprocessed (if any), Cached(0), Cached(1), ...]
+    pub column_openings: Vec<Vec<Vec<EF>>>,
+}
+
+pub fn column_openings_by_rot<'a>(
+    openings: &'a [EF],
+    need_rot: bool,
+) -> Box<dyn Iterator<Item = (EF, EF)> + 'a> {
+    if need_rot {
+        Box::new(openings.chunks_exact(2).map(|chunk| (chunk[0], chunk[1])))
+    } else {
+        Box::new(openings.iter().map(|&claim| (claim, EF::ZERO)))
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+pub struct StackingProof {
+    /// Polynomial for round 0, given by `2 * (2^{l_skip} - 1) + 1` coefficients.
+    pub univariate_round_coeffs: Vec<EF>,
+    /// Rounds 1, ..., n_stack; evaluations at {1, 2}.
+    pub sumcheck_round_polys: Vec<[EF; 2]>,
+    /// Per commit, per column.
+    pub stacking_openings: Vec<Vec<EF>>,
+}
+
+pub type MerkleProof = Vec<Digest>;
+
+/// WHIR polynomial opening proof for multiple polynomials of the same height, committed to in
+/// multiple commitments.
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+pub struct WhirProof {
+    /// Per sumcheck round; evaluations on {1, 2}. This list is "flattened" with respect to the
+    /// WHIR rounds.
+    pub whir_sumcheck_polys: Vec<[EF; 2]>,
+    /// The codeword commits after each fold, except the final round.
+    pub codeword_commits: Vec<Digest>,
+    /// The out-of-domain values "y0" per round, except the final round.
+    pub ood_values: Vec<EF>,
+    /// For each sumcheck round, the folding PoW witness. Length is `num_whir_sumcheck_rounds =
+    /// num_whir_rounds * k_whir`.
+    pub folding_pow_witnesses: Vec<F>,
+    /// For each WHIR round, the query phase PoW witness. Length is `num_whir_rounds`.
+    pub query_phase_pow_witnesses: Vec<F>,
+    /// For the initial round: per committed matrix, per in-domain query.
+    // num_commits x num_queries x (1 << k) x stacking_width[i]
+    pub initial_round_opened_rows: Vec<Vec<Vec<Vec<F>>>>,
+    pub initial_round_merkle_proofs: Vec<Vec<MerkleProof>>,
+    /// Per non-initial round, per in-domain query.
+    pub codeword_opened_values: Vec<Vec<Vec<EF>>>,
+    pub codeword_merkle_proofs: Vec<Vec<MerkleProof>>,
+    /// Coefficients of the polynomial after the final round.
+    pub final_poly: Vec<EF>,
+}
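+// Illustrative usage of `column_openings_by_rot` above (not from the original
+// source): with rotations the flat opening list is interleaved as
+// [T(r), T_rot(r), ...]; without them each opening stands alone and the
+// rotation slot is reported as zero.
+//
+//     let openings = vec![EF::ONE, EF::TWO];
+//     let with_rot: Vec<_> = column_openings_by_rot(&openings, true).collect();
+//     assert_eq!(with_rot, vec![(EF::ONE, EF::TWO)]);
+//     let without: Vec<_> = column_openings_by_rot(&openings, false).collect();
+//     assert_eq!(without, vec![(EF::ONE, EF::ZERO), (EF::TWO, EF::ZERO)]);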
+
+// ==================== Encode implementations ====================
+
+/// The codec version should change only when the proof system or proof format changes.
+/// It does not necessarily correspond to the main openvm version (which may change more
+/// frequently).
+pub(crate) const CODEC_VERSION: u32 = 2;
+
+// TODO: custom encode/decode for Proof that takes in a vk
+impl Encode for Proof {
+    fn encode<W: Write>(&self, writer: &mut W) -> Result<()> {
+        // We explicitly implement Encode for Proof to add CODEC_VERSION
+        CODEC_VERSION.encode(writer)?;
+        self.common_main_commit.encode(writer)?;
+
+        // We encode trace_vdata by encoding the number of AIRs, encoding a bitmap of
+        // which AIRs are present, and then encoding each present TraceVData.
+        let num_airs: usize = self.trace_vdata.len();
+        num_airs.encode(writer)?;
+        for chunk in self.trace_vdata.chunks(8) {
+            let mut ret = 0u8;
+            for (i, vdata) in chunk.iter().enumerate() {
+                ret |= (vdata.is_some() as u8) << (i as u8);
+            }
+            ret.encode(writer)?;
+        }
+        for vdata in self.trace_vdata.iter().flatten() {
+            vdata.encode(writer)?;
+        }
+
+        self.public_values.encode(writer)?;
+        self.gkr_proof.encode(writer)?;
+        self.batch_constraint_proof.encode(writer)?;
+        self.stacking_proof.encode(writer)?;
+        self.whir_proof.encode(writer)
+    }
+}
+
+impl Encode for GkrProof {
+    fn encode<W: Write>(&self, writer: &mut W) -> Result<()> {
+        self.logup_pow_witness.encode(writer)?;
+        self.q0_claim.encode(writer)?;
+        self.claims_per_layer.encode(writer)?;
+        // We should know the length of sumcheck_polys and each nested vector based
+        // on the length of claims_per_layer.
+        encode_iter(self.sumcheck_polys.iter().flatten(), writer)?;
+        Ok(())
+    }
+}
+
+impl Encode for BatchConstraintProof {
+    fn encode<W: Write>(&self, writer: &mut W) -> Result<()> {
+        // The length of numerator_term_per_air is the number of present AIRs
+        self.numerator_term_per_air.encode(writer)?;
+        encode_iter(self.denominator_term_per_air.iter(), writer)?;
+
+        self.univariate_round_coeffs.encode(writer)?;
+
+        // Each nested vector should be the same length
+        let n_max = self.sumcheck_round_polys.len();
+        n_max.encode(writer)?;
+        if n_max > 0 {
+            self.sumcheck_round_polys[0].len().encode(writer)?;
+            for round_polys in &self.sumcheck_round_polys {
+                encode_iter(round_polys.iter(), writer)?;
+            }
+        }
+
+        // There is one outer vector per present AIR
+        for part_col_openings in &self.column_openings {
+            part_col_openings.encode(writer)?;
+        }
+        Ok(())
+    }
+}
+
+impl Encode for StackingProof {
+    fn encode<W: Write>(&self, writer: &mut W) -> Result<()> {
+        self.univariate_round_coeffs.encode(writer)?;
+        self.sumcheck_round_polys.encode(writer)?;
+        self.stacking_openings.encode(writer)
+    }
+}
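+// Note (illustrative): the encoders above omit length prefixes wherever the
+// decoder can re-derive them, e.g. `sumcheck_polys` lengths follow from
+// `claims_per_layer`, and the trace_vdata presence bitmap packs 8 AIRs per
+// byte, LSB first (traces present at {0, 3} within a chunk encode as
+// 0b0000_1001). `WhirProof` below follows the same convention.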
+
+impl Encode for WhirProof {
+    fn encode<W: Write>(&self, writer: &mut W) -> Result<()> {
+        self.whir_sumcheck_polys.encode(writer)?;
+        let num_whir_sumcheck_rounds = self.whir_sumcheck_polys.len();
+
+        // Each length can be derived from num_whir_rounds
+        self.codeword_commits.encode(writer)?;
+        encode_iter(self.ood_values.iter(), writer)?;
+        let num_whir_rounds = self.codeword_commits.len() + 1;
+        if num_whir_sumcheck_rounds % num_whir_rounds != 0 {
+            return Err(Error::new(
+                std::io::ErrorKind::InvalidData,
+                "num_whir_sumcheck_rounds must be a multiple of num_whir_rounds",
+            ));
+        }
+        assert_eq!(num_whir_rounds, self.query_phase_pow_witnesses.len());
+        encode_iter(self.folding_pow_witnesses.iter(), writer)?;
+        encode_iter(self.query_phase_pow_witnesses.iter(), writer)?;
+
+        let num_commits = self.initial_round_opened_rows.len();
+        assert!(num_commits > 0);
+        num_commits.encode(writer)?;
+        let initial_num_whir_queries = self.initial_round_opened_rows[0].len();
+        initial_num_whir_queries.encode(writer)?;
+
+        if initial_num_whir_queries > 0 {
+            let merkle_depth = self.initial_round_merkle_proofs[0][0].len();
+            merkle_depth.encode(writer)?;
+
+            // We avoid per-row Vec length prefixes by encoding each commit's stacked width,
+            // which we can use to determine the shapes of the remaining WHIR proof fields.
+            let widths: Vec<usize> = self
+                .initial_round_opened_rows
+                .iter()
+                .map(|commit_rows| {
+                    // If there are any queries/rows, infer the width from the first row.
+                    commit_rows
+                        .first()
+                        .and_then(|q| q.first())
+                        .map(|row| row.len())
+                        .unwrap_or(0)
+                })
+                .collect();
+
+            // Encode widths (length is implicit via num_commits).
+            encode_iter(widths.iter(), writer)?;
+
+            // Encode all opened row values (no per-row length prefixes).
+            for (commit_rows, &width) in self.initial_round_opened_rows.iter().zip(&widths) {
+                debug_assert_eq!(commit_rows.len(), initial_num_whir_queries);
+                for query_rows in commit_rows {
+                    for row in query_rows {
+                        debug_assert_eq!(row.len(), width);
+                        encode_iter(row.iter(), writer)?;
+                    }
+                }
+            }
+
+            encode_iter(
+                self.initial_round_merkle_proofs.iter().flatten().flatten(),
+                writer,
+            )?;
+        }
+
+        // The length of the outer vector is num_whir_rounds - 1 (one per non-initial round)
+        for non_init_round in &self.codeword_opened_values {
+            let num_queries = non_init_round.len();
+            num_queries.encode(writer)?;
+            // The length of the nested vector is num_whir_queries, then k_whir_exp.
+            encode_iter(non_init_round.iter().flatten(), writer)?;
+        }
+
+        // The length of the outer vector is num_whir_rounds - 1, then num_whir_queries. Each
+        // inner vector length is one less than the one that precedes it.
+        let mut first_merkle_depth = 0;
+        if num_whir_rounds > 1 && initial_num_whir_queries > 0 {
+            first_merkle_depth = self.codeword_merkle_proofs[0][0].len();
+        }
+        first_merkle_depth.encode(writer)?;
+        encode_iter(
+            self.codeword_merkle_proofs.iter().flatten().flatten(),
+            writer,
+        )?;
+
+        self.final_poly.encode(writer)
+    }
+}
+
+// ==================== Decode implementations ====================
+
+impl Decode for Proof {
+    fn decode<R: Read>(reader: &mut R) -> Result<Self> {
+        // We explicitly implement Decode for Proof to check CODEC_VERSION
+        let codec_version = u32::decode(reader)?;
+        if codec_version != CODEC_VERSION {
+            return Err(Error::other(format!(
+                "CODEC_VERSION mismatch, expected: {}, actual: {}",
+                CODEC_VERSION, codec_version
+            )));
+        }
+        let common_main_commit = Digest::decode(reader)?;
+
+        let num_airs = usize::decode(reader)?;
+        let bitmap_len = num_airs.div_ceil(8);
+        let mut bitmap: Vec<u8> = Vec::with_capacity(bitmap_len);
+        for _ in 0..bitmap_len {
+            bitmap.push(u8::decode(reader)?);
+        }
+        let mut trace_vdata = Vec::with_capacity(num_airs);
+        for byte in bitmap {
+            for i in 0u8..8 {
+                if trace_vdata.len() >= num_airs {
+                    break;
+                }
+                if byte & (1u8 << i) != 0 {
+                    trace_vdata.push(Some(TraceVData::decode(reader)?));
+                } else {
+                    trace_vdata.push(None);
+                }
+            }
+        }
+
+        Ok(Self {
+            common_main_commit,
+            trace_vdata,
+            public_values: Vec::<Vec<F>>::decode(reader)?,
+            gkr_proof: GkrProof::decode(reader)?,
+            batch_constraint_proof: BatchConstraintProof::decode(reader)?,
+            stacking_proof: StackingProof::decode(reader)?,
+            whir_proof: WhirProof::decode(reader)?,
+        })
+    }
+}
+
+impl Decode for GkrProof {
+    fn decode<R: Read>(reader: &mut R) -> Result<Self> {
+        let logup_pow_witness = F::decode(reader)?;
+        let q0_claim = EF::decode(reader)?;
+        let claims_per_layer = Vec::<GkrLayerClaims>::decode(reader)?;
+
+        let num_sumcheck_polys = claims_per_layer.len().saturating_sub(1);
+        let mut sumcheck_polys = Vec::with_capacity(num_sumcheck_polys);
+        for round_idx_minus_one in 0..num_sumcheck_polys {
+            sumcheck_polys.push(decode_into_vec(reader, round_idx_minus_one + 1)?);
+        }
+
+        Ok(Self {
+            logup_pow_witness,
+            q0_claim,
+            claims_per_layer,
+            sumcheck_polys,
+        })
+    }
+}
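+// Note (illustrative): the layer-j GKR sumcheck runs over j variables, so
+// layer j contributes j round polynomials (`round_idx_minus_one + 1` above).
+// This is why `sumcheck_polys` can be reconstructed without explicit length
+// prefixes, purely from `claims_per_layer.len()`.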
+
+impl Decode for BatchConstraintProof {
+    fn decode<R: Read>(reader: &mut R) -> Result<Self> {
+        let numerator_term_per_air = Vec::<EF>::decode(reader)?;
+        let num_present_airs = numerator_term_per_air.len();
+        let denominator_term_per_air = decode_into_vec(reader, num_present_airs)?;
+
+        let univariate_round_coeffs = Vec::<EF>::decode(reader)?;
+
+        let n_max = usize::decode(reader)?;
+        let mut sumcheck_round_polys = Vec::with_capacity(n_max);
+        if n_max > 0 {
+            let max_degree_plus_one = usize::decode(reader)?;
+            for _ in 0..n_max {
+                sumcheck_round_polys.push(decode_into_vec(reader, max_degree_plus_one)?);
+            }
+        }
+
+        let mut column_openings = Vec::with_capacity(num_present_airs);
+        for _ in 0..num_present_airs {
+            column_openings.push(Vec::<Vec<EF>>::decode(reader)?);
+        }
+
+        Ok(Self {
+            numerator_term_per_air,
+            denominator_term_per_air,
+            univariate_round_coeffs,
+            sumcheck_round_polys,
+            column_openings,
+        })
+    }
+}
+
+impl Decode for StackingProof {
+    fn decode<R: Read>(reader: &mut R) -> Result<Self> {
+        Ok(Self {
+            univariate_round_coeffs: Vec::<EF>::decode(reader)?,
+            sumcheck_round_polys: Vec::<[EF; 2]>::decode(reader)?,
+            stacking_openings: Vec::<Vec<EF>>::decode(reader)?,
+        })
+    }
+}
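+// Note (illustrative): `WhirProof` shapes below are likewise derived rather
+// than encoded where possible: num_whir_rounds = codeword_commits.len() + 1,
+// k_whir = num_whir_sumcheck_rounds / num_whir_rounds, and each in-domain
+// query opens 1 << k_whir sibling rows, so only widths and per-round query
+// counts need explicit prefixes.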
+
+impl Decode for WhirProof {
+    fn decode<R: Read>(reader: &mut R) -> Result<Self> {
+        let whir_sumcheck_polys = Vec::<[EF; 2]>::decode(reader)?;
+        let num_whir_sumcheck_rounds = whir_sumcheck_polys.len();
+        let codeword_commits = Vec::<Digest>::decode(reader)?;
+        let num_whir_rounds = codeword_commits.len() + 1;
+        if num_whir_sumcheck_rounds % num_whir_rounds != 0 {
+            return Err(Error::new(
+                std::io::ErrorKind::InvalidData,
+                "num_whir_sumcheck_rounds must be a multiple of num_whir_rounds",
+            ));
+        }
+        let k_whir = num_whir_sumcheck_rounds / num_whir_rounds;
+        let ood_values = decode_into_vec(reader, num_whir_rounds - 1)?;
+        let folding_pow_witnesses = decode_into_vec(reader, num_whir_sumcheck_rounds)?;
+        let query_phase_pow_witnesses = decode_into_vec(reader, num_whir_rounds)?;
+
+        let num_commits = usize::decode(reader)?;
+        assert!(num_commits > 0);
+        let initial_num_whir_queries = usize::decode(reader)?;
+        let k_whir_exp = 1 << k_whir;
+        let mut merkle_depth = 0;
+        if initial_num_whir_queries > 0 {
+            merkle_depth = usize::decode(reader)?;
+        }
+
+        let mut widths = vec![0usize; num_commits];
+        if initial_num_whir_queries > 0 {
+            for width in &mut widths {
+                *width = usize::decode(reader)?;
+            }
+        }
+
+        let mut initial_round_opened_rows = Vec::with_capacity(num_commits);
+        for width in widths {
+            let mut opened_rows = Vec::with_capacity(initial_num_whir_queries);
+            for _ in 0..initial_num_whir_queries {
+                // Each query has k_whir_exp rows. Each row is a fixed-width list of F elements.
+                let mut rows = Vec::with_capacity(k_whir_exp);
+                for _ in 0..k_whir_exp {
+                    rows.push(decode_into_vec(reader, width)?);
+                }
+                opened_rows.push(rows);
+            }
+            initial_round_opened_rows.push(opened_rows);
+        }
+
+        let mut initial_round_merkle_proofs = Vec::with_capacity(num_commits);
+        for _ in 0..num_commits {
+            let mut merkle_proofs = Vec::with_capacity(initial_num_whir_queries);
+            for _ in 0..initial_num_whir_queries {
+                merkle_proofs.push(decode_into_vec(reader, merkle_depth)?);
+            }
+            initial_round_merkle_proofs.push(merkle_proofs);
+        }
+
+        let mut codeword_opened_values = Vec::with_capacity(num_whir_rounds - 1);
+        for _ in 0..num_whir_rounds - 1 {
+            let num_queries = usize::decode(reader)?;
+            let mut opened_values = Vec::with_capacity(num_queries);
+            for _ in 0..num_queries {
+                opened_values.push(decode_into_vec(reader, k_whir_exp)?);
+            }
+            codeword_opened_values.push(opened_values);
+        }
+
+        merkle_depth = usize::decode(reader)?;
+        let mut codeword_merkle_proofs = Vec::with_capacity(num_whir_rounds - 1);
+        for opened_values in codeword_opened_values.iter() {
+            let num_queries = opened_values.len();
+            let mut merkle_proof: Vec<_> = Vec::with_capacity(num_queries);
+            for _ in 0..num_queries {
+                merkle_proof.push(decode_into_vec(reader, merkle_depth)?);
+            }
+            codeword_merkle_proofs.push(merkle_proof);
+            merkle_depth -= 1;
+        }
+
+        let final_poly = Vec::<EF>::decode(reader)?;
+
+        Ok(Self {
+            whir_sumcheck_polys,
+            codeword_commits,
+            ood_values,
+            folding_pow_witnesses,
+            query_phase_pow_witnesses,
+            initial_round_opened_rows,
+            initial_round_merkle_proofs,
+            codeword_opened_values,
+            codeword_merkle_proofs,
+            final_poly,
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use itertools::Itertools;
+
+    use super::*;
+    use crate::{
+        poseidon2::sponge::DuplexSpongeRecorder,
+        test_utils::{
+            test_system_params_small, CachedFixture11, FibFixture, InteractionsFixture11,
+            PreprocessedFibFixture, TestFixture,
+        },
+        BabyBearPoseidon2CpuEngineV2, SystemParams,
+    };
+
+    fn test_proof_encode_decode<Fx: TestFixture>(fx: Fx, params: SystemParams) -> Result<()> {
+        let engine = BabyBearPoseidon2CpuEngineV2::new(params);
+        let pk = fx.keygen(&engine).0;
+        let proof = fx.prove_from_transcript(&engine, &pk, &mut DuplexSpongeRecorder::default());
+
+        let mut proof_bytes = Vec::new();
+        proof.encode(&mut proof_bytes).unwrap();
+
+        let decoded_proof = Proof::decode(&mut &proof_bytes[..]).unwrap();
+        assert_eq!(proof, decoded_proof);
+        Ok(())
+    }
+
+    #[test]
+    fn test_fib_proof_encode_decode() -> Result<()> {
+        let log_trace_height = 5;
+        let fx = FibFixture::new(0, 1, 1 << log_trace_height);
+        let params = SystemParams::new_for_testing(log_trace_height);
+        test_proof_encode_decode(fx, params)
+    }
+
+    #[test]
+    fn test_interactions_proof_encode_decode() -> Result<()> {
+        let fx = InteractionsFixture11;
+        let params = test_system_params_small(2, 5, 3);
+        test_proof_encode_decode(fx, params)
+    }
+
+    #[test]
+    fn test_cached_proof_encode_decode() -> Result<()> {
+        let params = test_system_params_small(2, 5, 3);
+        let fx = CachedFixture11::new(params.clone());
+        test_proof_encode_decode(fx, params)
+    }
+
+    #[test]
+    fn test_preprocessed_proof_encode_decode() -> Result<()> {
+        let log_trace_height = 5;
+        let params = SystemParams::new_for_testing(log_trace_height);
+        let sels = (0..(1 << log_trace_height))
+            .map(|i| i % 2 == 0)
+            .collect_vec();
+        let fx = PreprocessedFibFixture::new(0, 1, sels);
+        test_proof_encode_decode(fx, params)
+    }
+}
diff --git a/crates/stark-backend-v2/src/prover/cpu_backend.rs b/crates/stark-backend-v2/src/prover/cpu_backend.rs
new file mode 100644
index 00000000..562e89c9
--- /dev/null
+++ b/crates/stark-backend-v2/src/prover/cpu_backend.rs
@@ -0,0 +1,206 @@
+//! CPU [ProverBackendV2] trait implementation.
+
+use getset::Getters;
+use itertools::Itertools;
+
+use crate::{
+    keygen::types::MultiStarkProvingKeyV2,
+    poly_common::Squarable,
+    poseidon2::sponge::FiatShamirTranscript,
+    proof::{BatchConstraintProof, GkrProof, StackingProof, WhirProof},
+    prover::{
+        prove_zerocheck_and_logup,
+        stacked_pcs::{stacked_commit, StackedPcsData},
+        stacked_reduction::{prove_stacked_opening_reduction, StackedReductionCpu},
+        whir::WhirProver,
+        ColMajorMatrix, CommittedTraceDataV2, DeviceDataTransporterV2,
+        DeviceMultiStarkProvingKeyV2, DeviceStarkProvingKeyV2, MultiRapProver, OpeningProverV2,
+        ProverBackendV2, ProverDeviceV2, ProvingContextV2, TraceCommitterV2,
+    },
+    Digest, SystemParams, D_EF, EF, F,
+};
+
+#[derive(Clone, Copy)]
+pub struct CpuBackendV2;
+
+#[derive(Clone, Getters, derive_new::new)]
+pub struct CpuDeviceV2 {
+    #[getset(get = "pub")]
+    config: SystemParams,
+}
+
+impl ProverBackendV2 for CpuBackendV2 {
+    const CHALLENGE_EXT_DEGREE: u8 = D_EF as u8;
+
+    type Val = F;
+    type Challenge = EF;
+    type Commitment = Digest;
+    type Matrix = ColMajorMatrix<F>;
+    type OtherAirData = ();
+    type PcsData = StackedPcsData<F>;
+}
+
+impl<TS: FiatShamirTranscript> ProverDeviceV2<CpuBackendV2, TS> for CpuDeviceV2 {
+    fn config(&self) -> &SystemParams {
+        &self.config
+    }
+}
+
+impl TraceCommitterV2<CpuBackendV2> for CpuDeviceV2 {
+    fn commit(&self, traces: &[&ColMajorMatrix<F>]) -> (Digest, StackedPcsData<F>) {
+        stacked_commit(
+            self.config.l_skip,
+            self.config.n_stack,
+            self.config.log_blowup,
+            self.config.k_whir(),
+            traces,
+        )
+    }
+}
+
+impl<TS: FiatShamirTranscript> MultiRapProver<CpuBackendV2, TS> for CpuDeviceV2 {
+    type PartialProof = (GkrProof, BatchConstraintProof);
+    /// The random opening point `r` where the batch constraint sumcheck reduces to evaluation
+    /// claims of the trace matrices `T, T_{rot}` at `r_{n_T}`.
+    type Artifacts = Vec<EF>;
+
+    fn prove_rap_constraints(
+        &self,
+        transcript: &mut TS,
+        mpk: &DeviceMultiStarkProvingKeyV2<CpuBackendV2>,
+        ctx: &ProvingContextV2<CpuBackendV2>,
+        _common_main_pcs_data: &StackedPcsData<F>,
+    ) -> ((GkrProof, BatchConstraintProof), Vec<EF>) {
+        let (gkr_proof, batch_constraint_proof, r) =
+            prove_zerocheck_and_logup(transcript, mpk, ctx);
+        ((gkr_proof, batch_constraint_proof), r)
+    }
+}
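+// Note (illustrative): the vector `r` returned as `Artifacts` above is the
+// same opening point threaded into `prove_openings` below, so every trace
+// matrix is opened at a suffix r_{n_T} of one shared random vector rather
+// than at independent per-AIR points.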
+
+impl<TS: FiatShamirTranscript> OpeningProverV2<CpuBackendV2, TS> for CpuDeviceV2 {
+    type OpeningProof = (StackingProof, WhirProof);
+    /// The shared vector `r` where each trace matrix `T, T_{rot}` is opened at `r_{n_T}`.
+    type OpeningPoints = Vec<EF>;
+
+    fn prove_openings(
+        &self,
+        transcript: &mut TS,
+        mpk: &DeviceMultiStarkProvingKeyV2<CpuBackendV2>,
+        ctx: ProvingContextV2<CpuBackendV2>,
+        common_main_pcs_data: StackedPcsData<F>,
+        r: Vec<EF>,
+    ) -> (StackingProof, WhirProof) {
+        let params = &self.config;
+
+        let need_rot_per_trace = ctx
+            .per_trace
+            .iter()
+            .map(|(air_idx, _)| mpk.per_air[*air_idx].vk.params.need_rot)
+            .collect_vec();
+
+        // Currently alternates between preprocessed and cached pcs data
+        let pre_cached_pcs_data_per_commit: Vec<_> = ctx
+            .per_trace
+            .iter()
+            .flat_map(|(air_idx, air_ctx)| {
+                mpk.per_air[*air_idx]
+                    .preprocessed_data
+                    .iter()
+                    .chain(&air_ctx.cached_mains)
+                    .map(|cd| cd.data.clone())
+            })
+            .collect();
+
+        let mut stacked_per_commit = vec![&common_main_pcs_data];
+        for data in &pre_cached_pcs_data_per_commit {
+            stacked_per_commit.push(data);
+        }
+        let mut need_rot_per_commit = vec![need_rot_per_trace];
+        for (air_idx, air_ctx) in &ctx.per_trace {
+            let need_rot = mpk.per_air[*air_idx].vk.params.need_rot;
+            if mpk.per_air[*air_idx].preprocessed_data.is_some() {
+                need_rot_per_commit.push(vec![need_rot]);
+            }
+            for _ in &air_ctx.cached_mains {
+                need_rot_per_commit.push(vec![need_rot]);
+            }
+        }
+        let (stacking_proof, u_prisma) =
+            prove_stacked_opening_reduction::<_, _, _, StackedReductionCpu>(
+                self,
+                transcript,
+                self.config.n_stack,
+                stacked_per_commit,
+                need_rot_per_commit,
+                &r,
+            );
+
+        let (&u0, u_rest) = u_prisma.split_first().unwrap();
+        let u_cube = u0
+            .exp_powers_of_2()
+            .take(params.l_skip)
+            .chain(u_rest.iter().copied())
+            .collect_vec();
+
+        let whir_proof = self.prove_whir(
+            transcript,
+            common_main_pcs_data,
+            pre_cached_pcs_data_per_commit,
+            &u_cube,
+        );
+        (stacking_proof, whir_proof)
+    }
+}
+
+impl DeviceDataTransporterV2<CpuBackendV2> for CpuDeviceV2 {
+    fn transport_pk_to_device(
+        &self,
+        mpk: &MultiStarkProvingKeyV2,
+    ) -> DeviceMultiStarkProvingKeyV2<CpuBackendV2> {
+        let per_air = mpk
+            .per_air
+            .iter()
+            .map(|pk| {
+                let preprocessed_data = pk.preprocessed_data.as_ref().map(|d| {
+                    let trace = d.mat_view(0).to_matrix();
+                    CommittedTraceDataV2 {
+                        commitment: d.commit(),
+                        trace,
+                        data: d.clone(),
+                    }
+                });
+                DeviceStarkProvingKeyV2 {
+                    air_name: pk.air_name.clone(),
+                    vk: pk.vk.clone(),
+                    preprocessed_data,
+                    other_data: (),
+                }
+            })
+            .collect();
+        DeviceMultiStarkProvingKeyV2::new(
+            per_air,
+            mpk.trace_height_constraints.clone(),
+            mpk.max_constraint_degree,
+            mpk.params.clone(),
+            mpk.vk_pre_hash,
+        )
+    }
+
+    fn transport_matrix_to_device(&self, matrix: &ColMajorMatrix<F>) -> ColMajorMatrix<F> {
+        matrix.clone()
+    }
+
+    fn transport_pcs_data_to_device(&self, pcs_data: &StackedPcsData<F>) -> StackedPcsData<F> {
+        pcs_data.clone()
+    }
+
+    fn transport_matrix_from_device_to_host(
+        &self,
+        matrix: &ColMajorMatrix<F>,
+    ) -> ColMajorMatrix<F> {
+        matrix.clone()
+    }
+}
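+// Note (illustrative): on the CPU backend every `transport_*` method above is
+// a plain `clone()` because host and device memory coincide; a GPU backend
+// would instead copy the column-major buffers into device memory here.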
Protocol specific types moved to ProverDeviceV2 (possibly +// could be renamed ProtocolProver) + +use std::sync::Arc; + +use openvm_stark_backend::prover::MatrixDimensions; +use serde::{de::DeserializeOwned, Serialize}; + +use crate::{ + keygen::types::MultiStarkProvingKeyV2, + prover::{ + stacked_pcs::StackedPcsData, AirProvingContextV2, ColMajorMatrix, CommittedTraceDataV2, + CpuBackendV2, DeviceMultiStarkProvingKeyV2, ProvingContextV2, + }, + SystemParams, +}; + +/// Associated types needed by the prover, in the form of buffers and views, +/// specific to a specific hardware backend. +/// +/// Memory allocation and copying is not handled by this trait. +pub trait ProverBackendV2 { + /// Extension field degree for the challenge field `Self::Challenge` over base field + /// `Self::Val`. + const CHALLENGE_EXT_DEGREE: u8; + // ==== Host Types ==== + /// Base field type, on host. + type Val: Copy + Send + Sync + Serialize + DeserializeOwned; + /// Challenge field (extension field of base field), on host. + type Challenge: Copy + Send + Sync + Serialize + DeserializeOwned; + /// Single commitment on host. + // Commitments are small in size and need to be transferred back to host to be included in + // proof. + type Commitment: Clone + Send + Sync + Serialize + DeserializeOwned; + + // ==== Device Types ==== + /// Single matrix buffer on device together with dimension metadata. Owning this means nothing + /// else has a shared reference to the buffer. + type Matrix: MatrixDimensions + Send + Sync; + /// Backend specific type for any pre-computed data associated with a single AIR. For example, + /// it may contain prover-specific precomputations based on the AIR constraints (but + /// independent from any trace data). + type OtherAirData: Send + Sync; + /// Owned buffer for the preimage of a PCS commitment on device, together with any metadata + /// necessary for computing opening proofs. + /// + /// For example, multiple buffers for LDE matrices, their trace domain sizes, and pointer to + /// mixed merkle tree. + type PcsData: Send + Sync; +} + +pub trait ProverDeviceV2: + TraceCommitterV2 + MultiRapProver + OpeningProverV2 +{ + fn config(&self) -> &SystemParams; +} + +/// Provides functionality for committing to a batch of trace matrices, possibly of different +/// heights. +pub trait TraceCommitterV2 { + fn commit(&self, traces: &[&PB::Matrix]) -> (PB::Commitment, PB::PcsData); +} + +/// This trait is responsible for all proving steps to prove a collection of trace matrices +/// satisfies all constraints of a Randomized AIR with Preprocessing. Such constraints include AIR +/// constraints as well as bus balancing constraints for interactions between AIRs. These +/// constraints may be grouped into challenge phases, where new randomness is sampled between phases +/// via Fiat-Shamir (which would involve committing to more data). +/// +/// This trait is _not_ responsible for committing to the trace matrices or for proving polynomial +/// openings with respect to the committed trace matrices. +pub trait MultiRapProver { + /// The partial proof is the proof that the trace matrices satisfy all constraints assuming that + /// certain polynomial opening claims are validated. In other words, it is a proof that reduces + /// the constraint satisfaction claim to certain polynomial opening claims. + type PartialProof: Clone + Send + Sync + Serialize + DeserializeOwned; + /// Other artifacts of the proof (e.g., sampled randomness) that may be passed to later stages + /// of the protocol. 
+ type Artifacts; + + fn prove_rap_constraints( + &self, + transcript: &mut TS, + mpk: &DeviceMultiStarkProvingKeyV2, + ctx: &ProvingContextV2, + common_main_pcs_data: &PB::PcsData, + ) -> (Self::PartialProof, Self::Artifacts); +} + +/// This trait is responsible for proving the evaluation claims of a collection of polynomials at a +/// collection of points. The opening point may be the same across polynomials. The polynomials may +/// be defined over different domains and are hence of "mixed" nature. The polynomials are already +/// committed and provided in their committed form. +pub trait OpeningProverV2 { + /// PCS opening proof on host. This should not be a reference. + type OpeningProof: Clone + Send + Sync + Serialize + DeserializeOwned; + type OpeningPoints; + + /// Computes the opening proof. + /// The `common_main_pcs_data` is the `PcsData` for the collection of common main trace + /// matrices. It is owned by the function and may be mutated. + /// The `pre_cached_pcs_data_per_commit` is the `PcsData` for the preprocessed and cached trace + /// matrices. These are specified by their `PcsData` per commitment. + fn prove_openings( + &self, + transcript: &mut TS, + mpk: &DeviceMultiStarkProvingKeyV2, + ctx: ProvingContextV2, + common_main_pcs_data: PB::PcsData, + points: Self::OpeningPoints, + ) -> Self::OpeningProof; +} + +/// Trait to manage data transport of prover types from host to device. +pub trait DeviceDataTransporterV2 { + /// Transport the proving key to the device, filtering for only the provided `air_ids`. + fn transport_pk_to_device( + &self, + mpk: &MultiStarkProvingKeyV2, + ) -> DeviceMultiStarkProvingKeyV2; + + fn transport_matrix_to_device(&self, matrix: &ColMajorMatrix) -> PB::Matrix; + + /// The `commitment` and `prover_data` are assumed to have been previously computed from the + /// `trace`. + fn transport_pcs_data_to_device( + &self, + pcs_data: &StackedPcsData, + ) -> PB::PcsData; + + fn transport_committed_trace_data_to_device( + &self, + committed_trace: &CommittedTraceDataV2, + ) -> CommittedTraceDataV2 + where + PB: ProverBackendV2, + { + let trace = self.transport_matrix_to_device(&committed_trace.trace); + let data = self.transport_pcs_data_to_device(committed_trace.data.as_ref()); + + CommittedTraceDataV2 { + commitment: committed_trace.commitment, + trace, + data: Arc::new(data), + } + } + + fn transport_proving_ctx_to_device( + &self, + ctx: &ProvingContextV2, + ) -> ProvingContextV2 + where + PB: ProverBackendV2, + { + let per_trace = ctx + .per_trace + .iter() + .map(|(air_idx, air_ctx)| { + let common_main = self.transport_matrix_to_device(&air_ctx.common_main); + let cached_mains = air_ctx + .cached_mains + .iter() + .map(|cd| self.transport_committed_trace_data_to_device(cd)) + .collect(); + let air_ctx_gpu = AirProvingContextV2::new( + cached_mains, + common_main, + air_ctx.public_values.clone(), + ); + (*air_idx, air_ctx_gpu) + }) + .collect(); + ProvingContextV2::new(per_trace) + } + + // ================================================================================== + // Device-to-Host methods below should only be used for testing / debugging purposes. + // ================================================================================== + + /// Transport a device matrix to host. This should only be used for testing / debugging + /// purposes. 
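+    /// purposes. A minimal round-trip sketch (hypothetical test code, assuming `dev`
+    /// implements this trait and `h_mat` is a host `ColMajorMatrix`):
+    /// ```ignore
+    /// let d_mat = dev.transport_matrix_to_device(&h_mat);
+    /// let back = dev.transport_matrix_from_device_to_host(&d_mat);
+    /// assert_eq!(back.values, h_mat.values);
+    /// ```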
+ fn transport_matrix_from_device_to_host(&self, matrix: &PB::Matrix) -> ColMajorMatrix; +} diff --git a/crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs b/crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs new file mode 100644 index 00000000..0b6fe1e6 --- /dev/null +++ b/crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs @@ -0,0 +1,647 @@ +use std::{ + cmp::max, + iter::{self, zip}, + mem::take, +}; + +use itertools::Itertools; +use openvm_stark_backend::{ + air_builders::symbolic::{ + symbolic_variable::Entry, SymbolicConstraints, SymbolicExpressionNode, + }, + parizip, + prover::MatrixDimensions, +}; +use p3_field::{Field, FieldAlgebra, TwoAdicField}; +use p3_maybe_rayon::prelude::*; +use p3_util::log2_strict_usize; + +use crate::{ + poly_common::{eval_eq_mle, eval_eq_sharp_uni, eval_eq_uni, UnivariatePoly}, + prover::{ + logup_zerocheck::EvalHelper, + poly::evals_eq_hypercubes, + stacked_pcs::StackedLayout, + sumcheck::{ + batch_fold_mle_evals, batch_fold_ple_evals, fold_ple_evals, sumcheck_round0_deg, + sumcheck_round_poly_evals, sumcheck_uni_round0_poly, + }, + ColMajorMatrix, CpuBackendV2, DeviceMultiStarkProvingKeyV2, ProvingContextV2, + }, + EF, F, +}; + +pub struct LogupZerocheckCpu<'a> { + pub alpha_logup: EF, + pub beta_pows: Vec, + + pub l_skip: usize, + pub n_logup: usize, + pub n_max: usize, + + pub omega_skip: F, + pub omega_skip_pows: Vec, + + pub interactions_layout: StackedLayout, + pub(crate) eval_helpers: Vec>, + /// Max constraint degree across constraints and interactions + pub constraint_degree: usize, + pub n_per_trace: Vec, + max_num_constraints: usize, + + // Available after GKR: + pub xi: Vec, + lambda_pows: Vec, + // T -> segment tree of eq(xi[j..1+n_T]) for j=1..=n_T in _reverse_ layout + eq_xi_per_trace: Vec>, + eq_3b_per_trace: Vec>, + sels_per_trace_base: Vec>, + // After univariate round 0: + pub mat_evals_per_trace: Vec>>, + pub sels_per_trace: Vec>, + // Stores \hat{f}(\vec r_n) * r_{n+1} .. 
r_{round-1} for polys f that are "done" in the batch + // sumcheck + pub(crate) zerocheck_tilde_evals: Vec, + pub(crate) logup_tilde_evals: Vec<[EF; 2]>, + + // In round `j`, contains `s_{j-1}(r_{j-1})` + pub(crate) prev_s_eval: EF, + pub(crate) eq_ns: Vec, + pub(crate) eq_sharp_ns: Vec, +} + +impl<'a> LogupZerocheckCpu<'a> { + pub fn new( + pk: &'a DeviceMultiStarkProvingKeyV2, + ctx: &ProvingContextV2, + n_logup: usize, + interactions_layout: StackedLayout, + alpha_logup: EF, + beta_logup: EF, + ) -> Self { + let l_skip = pk.params.l_skip; + let omega_skip = F::two_adic_generator(l_skip); + let omega_skip_pows = omega_skip.powers().take(1 << l_skip).collect_vec(); + let num_airs_present = ctx.per_trace.len(); + + let constraint_degree = pk.max_constraint_degree; + let max_interaction_length = ctx + .per_trace + .iter() + .flat_map(|(air_idx, _)| { + pk.per_air[*air_idx] + .vk + .symbolic_constraints + .interactions + .iter() + .map(|i| i.message.len()) + }) + .max() + .unwrap_or(0); + let beta_pows = beta_logup + .powers() + .take(max_interaction_length + 1) + .collect_vec(); + + let n_per_trace: Vec = ctx + .common_main_traces() + .map(|(_, t)| log2_strict_usize(t.height()) as isize - l_skip as isize) + .collect_vec(); + let n_max: usize = n_per_trace[0].max(0) as usize; + + let eval_helpers: Vec> = ctx + .per_trace + .iter() + .map(|(air_idx, air_ctx)| { + let pk = &pk.per_air[*air_idx]; + let constraints = &pk.vk.symbolic_constraints.constraints; + let public_values = air_ctx.public_values.clone(); + let preprocessed_trace = + pk.preprocessed_data.as_ref().map(|cd| cd.data.mat_view(0)); + let partitioned_main_trace = air_ctx + .cached_mains + .iter() + .map(|cd| cd.data.mat_view(0)) + .chain(iter::once(air_ctx.common_main.as_view().into())) + .collect_vec(); + let constraint_degree = pk.vk.max_constraint_degree; + // Scan constraints to see if we need `next` row and also check index bounds + // so we don't need to check them per row. 
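+                // For example, a constraint referencing the `next` row contains a variable
+                // with `offset == 1`, so `rotation` becomes 1 below and the vk must have
+                // `need_rot` set (checked by the `debug_assert_eq!` after the scan).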
+ let mut rotation = 0; + for node in &constraints.nodes { + if let SymbolicExpressionNode::Variable(var) = node { + match var.entry { + Entry::Preprocessed { offset } => { + rotation = max(rotation, offset); + assert!(var.index < preprocessed_trace.as_ref().unwrap().width()); + } + Entry::Main { part_index, offset } => { + rotation = max(rotation, offset); + assert!( + var.index < partitioned_main_trace[part_index].width(), + "col_index={} >= main partition {} width={}", + var.index, + part_index, + partitioned_main_trace[part_index].width() + ); + } + Entry::Public => { + assert!(var.index < public_values.len()); + } + _ => unreachable!("after_challenge not supported"), + } + } + } + let needs_next = pk.vk.params.need_rot; + debug_assert_eq!(needs_next, rotation > 0); + let symbolic_constraints = SymbolicConstraints::from(&pk.vk.symbolic_constraints); + EvalHelper { + constraints_dag: &pk.vk.symbolic_constraints.constraints, + interactions: symbolic_constraints.interactions, + public_values, + preprocessed_trace, + needs_next, + constraint_degree, + } + }) + .collect(); // end of preparation / loading of constraints + let max_num_constraints = pk + .per_air + .iter() + .map(|pk| pk.vk.symbolic_constraints.constraints.constraint_idx.len()) + .max() + .unwrap_or(0); + + let zerocheck_tilde_evals = vec![EF::ZERO; num_airs_present]; + let logup_tilde_evals = vec![[EF::ZERO; 2]; num_airs_present]; + Self { + alpha_logup, + beta_pows, + l_skip, + n_logup, + n_max, + omega_skip, + omega_skip_pows, + interactions_layout, + constraint_degree, + max_num_constraints, + n_per_trace, + eval_helpers, + xi: vec![], + lambda_pows: vec![], + sels_per_trace_base: vec![], + eq_xi_per_trace: vec![], + eq_3b_per_trace: vec![], + mat_evals_per_trace: vec![], + sels_per_trace: vec![], + zerocheck_tilde_evals, + logup_tilde_evals, + prev_s_eval: EF::ZERO, + eq_ns: Vec::with_capacity(n_max + 1), + eq_sharp_ns: Vec::with_capacity(n_max + 1), + } + } + + /// Returns the `s_0` polynomials in coefficient form. There should be exactly `num_airs_present + /// \* 3` polynomials, in the order `(s_0)_{p,T}, (s_0)_{q,T}, (s_0)_{zerocheck,T}` per trace + /// `T`. This is computed _before_ sampling batching randomness `mu` because the result is + /// used to observe the sum claims `sum_{p,T}, sum_{q,T}`. The `s_0` polynomials could be + /// returned in either coefficient or evaluation form, but we return them all in coefficient + /// form for uniformity and debugging since this interpolation is inexpensive. + pub fn sumcheck_uni_round0_polys( + &mut self, + ctx: &ProvingContextV2, + lambda: EF, + ) -> Vec> { + let n_logup = self.n_logup; + let l_skip = self.l_skip; + let xi = &self.xi; + self.lambda_pows = lambda.powers().take(self.max_num_constraints).collect_vec(); + + // For each trace, for each interaction \hat\sigma, the eq(ξ_3,b_{T,\hat\sigma}) term. + // This is some weight per interaction that does not depend on the row. 
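+        // Concretely, `b_{T,\hat\sigma}` is the bit decomposition of the interaction's
+        // stacked row index above the lifted height, i.e. of `stacked_idx >> (l_skip + n_lift)`.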
+ self.eq_3b_per_trace = self + .eval_helpers + .par_iter() + .zip(&self.n_per_trace) + .enumerate() + .map(|(trace_idx, (helper, &n))| { + // Everything for logup is done with respect to lifted traces + // Note: `n_lift = \tilde{n}` from the paper + let n_lift = n.max(0) as usize; + if helper.interactions.is_empty() { + return vec![]; + } + let mut b_vec = vec![F::ZERO; n_logup - n_lift]; + (0..helper.interactions.len()) + .map(|i| { + // PERF[jpw]: interactions_layout.get is linear + let stacked_idx = + self.interactions_layout.get(trace_idx, i).unwrap().row_idx; + debug_assert!(stacked_idx.trailing_zeros() as usize >= n_lift + l_skip); + let mut b_int = stacked_idx >> (l_skip + n_lift); + for b in &mut b_vec { + *b = F::from_bool(b_int & 1 == 1); + b_int >>= 1; + } + eval_eq_mle(&xi[l_skip + n_lift..l_skip + n_logup], &b_vec) + }) + .collect_vec() + }) + .collect::>(); + + // PERF[jpw]: make Hashmap from unique n -> eq_n(xi, -) + // NOTE: this is evaluations of `x -> eq_{H_{\tilde n}}(x, \xi[l_skip..l_skip + \tilde n])` + // on hypercube `H_{\tilde n}`. We store the univariate component eq_D separately as + // an optimization. + self.eq_xi_per_trace = self + .n_per_trace + .par_iter() + .map(|&n| { + let n_lift = n.max(0) as usize; + // PERF[jpw]: might be able to share computations between eq_xi, eq_sharp + // computations the eq(xi, -) evaluations on hyperprism for + // zerocheck + evals_eq_hypercubes(n_lift, xi[l_skip..l_skip + n_lift].iter().rev()) + }) + .collect(); + + // For each trace, create selectors as a 3-column matrix of _the lifts of_ [is_first, + // is_transition, is_last] + // + // PERF[jpw]: I think it's better to not save these and just + // interpolate directly using the formulas for selectors + self.sels_per_trace_base = self + .n_per_trace + .iter() + .map(|&n| { + let log_height = l_skip.checked_add_signed(n).unwrap(); + let height = 1 << log_height; + let lifted_height = height.max(1 << l_skip); + let mut mat = F::zero_vec(3 * lifted_height); + mat[lifted_height..2 * lifted_height].fill(F::ONE); + for i in (0..lifted_height).step_by(height) { + mat[i] = F::ONE; // is_first + mat[lifted_height + i + height - 1] = F::ZERO; // is_transition + mat[2 * lifted_height + i + height - 1] = F::ONE; // is_last + } + ColMajorMatrix::new(mat, 3) + }) + .collect_vec(); + + let sp_0_zerochecks = self + .eval_helpers + .par_iter() + .enumerate() + .map(|(trace_idx, helper)| { + let trace_ctx = &ctx.per_trace[trace_idx].1; + let n_lift = log2_strict_usize(trace_ctx.height()).saturating_sub(l_skip); + let mats = &helper.view_mats(trace_ctx); + let eq_xi = &self.eq_xi_per_trace[trace_idx][(1 << n_lift) - 1..(2 << n_lift) - 1]; + let sels = self.sels_per_trace_base[trace_idx].as_view(); + let mut parts = vec![(sels.into(), false)]; + parts.extend_from_slice(mats); + // s'_0 has degree dependent on this AIR's constraint degree + // s'_0(Z) is a univariate polynomial which vanishes on D (zerocheck). Hence q(Z) = + // s'_0(Z) / Z_D(Z) = s'_0(Z) / (Z^{2^l_skip} - 1) is a polynomial of degree d * + // (2^l_skip - 1) - 2^l_skip = (d - 1) * 2^l_skip - d We can obtain + // q(Z) by interpolating evaluations on (d - 1) * 2^l_skip points. For computation + // efficiency, we choose these to be (d - 1) cosets of D. To avoid divide by zero, + // we avoid the coset equal to the subgroup D itself. 
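+                // Worked example: with d = 3 and l_skip = 2, deg s'_0 = d * (2^l_skip - 1) = 9,
+                // so deg q = 9 - 4 = 5 = (d - 1) * 2^l_skip - d, and the (d - 1) = 2 cosets
+                // give 8 evaluation points, enough to interpolate q.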
+ let constraint_deg = helper.constraint_degree as usize; + if constraint_deg == 0 { + return UnivariatePoly(vec![]); + } + let num_cosets = constraint_deg - 1; + let [q] = sumcheck_uni_round0_poly( + l_skip, + n_lift, + num_cosets, + &parts, + |z, x, row_parts| { + let eq = eq_xi[x]; + let constraint_eval = helper.acc_constraints(row_parts, &self.lambda_pows); + let zerofier = z.exp_power_of_2(l_skip) - F::ONE; + [eq * constraint_eval * zerofier.inverse()] + }, + ); + // sp_0 = (Z^{2^l_skip} - 1) * q + let sp_0_deg = sumcheck_round0_deg(l_skip, constraint_deg); + let coeffs = (0..=sp_0_deg) + .map(|i| { + let mut c = -*q.coeffs().get(i).unwrap_or(&EF::ZERO); + if i >= 1 << l_skip { + c += q.coeffs()[i - (1 << l_skip)]; + } + c + }) + .collect_vec(); + debug_assert_eq!( + coeffs.iter().step_by(1 << l_skip).copied().sum::(), + EF::ZERO, + "Zerocheck sum is not zero for air_id: {}", + ctx.per_trace[trace_idx].0 + ); + UnivariatePoly(coeffs) + }) + .collect::>(); + // Reminder: sum claims for zerocheck are zero, per AIR + + // We interpolate each logup round 0 sumcheck poly because we need to use it to compute + // sum_{\hat{p}, T, I}, sum_{\hat{q}, T, I} per trace. + let sp_0_logups = self + .eval_helpers + .par_iter() + .enumerate() + .flat_map(|(trace_idx, helper)| { + if helper.interactions.is_empty() { + return [(); 2].map(|_| UnivariatePoly::new(vec![])); + } + let trace_ctx = &ctx.per_trace[trace_idx].1; + let log_height = log2_strict_usize(trace_ctx.height()); + let n_lift = log_height.saturating_sub(l_skip); + let mats = &helper.view_mats(trace_ctx); + let eq_xi = &self.eq_xi_per_trace[trace_idx][(1 << n_lift) - 1..(2 << n_lift) - 1]; + let eq_3bs = &self.eq_3b_per_trace[trace_idx]; + let sels = self.sels_per_trace_base[trace_idx].as_view(); + let mut parts = vec![(sels.into(), false)]; + parts.extend_from_slice(mats); + let norm_factor_denom = 1 << l_skip.saturating_sub(log_height); + let norm_factor = F::from_canonical_usize(norm_factor_denom).inverse(); + + // degree is constraint_degree + 1 due to eq term + let [mut numer, denom] = sumcheck_uni_round0_poly( + l_skip, + n_lift, + helper.constraint_degree as usize, + &parts, + |_z, x, row_parts| { + let eq = eq_xi[x]; + let [numer, denom] = + helper.acc_interactions(row_parts, &self.beta_pows, eq_3bs); + [eq * numer, eq * denom] + }, + ); + for p in numer.coeffs_mut() { + *p *= norm_factor; + } + [numer, denom] + }) + .collect::>(); + + sp_0_logups.into_iter().chain(sp_0_zerochecks).collect() + } + + /// After univariate sumcheck round 0, fold prismalinear evaluations using randomness `r_0`. + /// Folding _could_ directly mutate inplace the trace matrices in `ctx` as they will not be + /// needed after this. + pub fn fold_ple_evals(&mut self, ctx: &ProvingContextV2, r_0: EF) { + let l_skip = self.l_skip; + // "Fold" all PLE evaluations by interpolating and evaluating at `r_0`. + // NOTE: after this folding, \hat{T} and \hat{T_{rot}} will be treated as completely + // distinct matrices. 
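+        // Roughly: each PLE column of height 2^{l_skip + n} is specialized at Z = r_0,
+        // leaving 2^n multilinear evaluations per column (see `fold_ple_evals`).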
+ self.mat_evals_per_trace = self + .eval_helpers + .par_iter() + .zip(ctx.per_trace.par_iter()) + .map(|(helper, (_, trace_ctx))| { + let mats = helper.view_mats(trace_ctx); + mats.into_par_iter() + .map(|(mat, is_rot)| fold_ple_evals(l_skip, mat, is_rot, r_0)) + .collect::>() + }) + .collect::>(); + self.sels_per_trace = + batch_fold_ple_evals(l_skip, take(&mut self.sels_per_trace_base), false, r_0); + let eq_r0 = eval_eq_uni(l_skip, self.xi[0], r_0); + let eq_sharp_r0 = eval_eq_sharp_uni(&self.omega_skip_pows, &self.xi[..l_skip], r_0); + self.eq_ns.push(eq_r0); + self.eq_sharp_ns.push(eq_sharp_r0); + self.eq_xi_per_trace.iter_mut().for_each(|eq| { + // trim the back (which corresponds to r_{j-1}) because we don't need it anymore + if eq.len() > 1 { + eq.truncate(eq.len() / 2); + } + }); + } + + /// Returns length `3 * num_airs_present` polynomials, each polynomial either evaluated at + /// `1,...,deg(s')` or at `1` if a linear term (terms in front-loaded sumcheck that have reached + /// exhaustion) + pub fn sumcheck_polys_eval(&mut self, round: usize, r_prev: EF) -> Vec> { + // sp = s' + let sp_deg = self.constraint_degree; + let sp_zerocheck_evals: Vec> = parizip!( + &self.eval_helpers, + &mut self.zerocheck_tilde_evals, + &self.n_per_trace, + &self.mat_evals_per_trace, + &self.sels_per_trace, + &self.eq_xi_per_trace + ) + .map(|(helper, tilde_eval, &n, mats, sels, eq_xi_tree)| { + let n_lift = n.max(0) as usize; + if round > n_lift { + if round == n_lift + 1 { + // Evaluate \hat{f}(\vec r_n) + let parts = iter::once(sels) + .chain(mats) + .map(|mat| mat.columns().map(|c| c[0]).collect_vec()) + .collect_vec(); + // eq(xi, \vect r_{round-1}) + let eq_r_acc = *self.eq_ns.last().unwrap(); + *tilde_eval = eq_r_acc * helper.acc_constraints(&parts, &self.lambda_pows); + } else { + *tilde_eval *= r_prev; + }; + vec![*tilde_eval] + } else { + let log_num_y = n_lift - round; + let num_y = 1 << log_num_y; + let eq_xi = &eq_xi_tree[num_y - 1..]; + let parts = iter::once(sels) + .chain(mats) + .map(|m| m.as_view()) + .collect_vec(); + let [s] = + sumcheck_round_poly_evals(log_num_y + 1, sp_deg, &parts, |_x, y, row_parts| { + let eq = eq_xi[y]; + let constraint_eval = helper.acc_constraints(row_parts, &self.lambda_pows); + [eq * constraint_eval] + }); + s + } + }) + .collect(); + + let sp_logup_evals: Vec> = parizip!( + &self.eval_helpers, + &mut self.logup_tilde_evals, + &self.n_per_trace, + &self.mat_evals_per_trace, + &self.sels_per_trace, + &self.eq_xi_per_trace, + &self.eq_3b_per_trace + ) + .flat_map(|(helper, tilde_eval, &n, mats, sels, eq_xi_tree, eq_3bs)| { + if helper.interactions.is_empty() { + return [vec![EF::ZERO; sp_deg], vec![EF::ZERO; sp_deg]]; + } + let n_lift = n.max(0) as usize; + let norm_factor_denom = 1 << (-n).max(0); + let norm_factor = F::from_canonical_usize(norm_factor_denom).inverse(); + if round > n_lift { + if round == n_lift + 1 { + // Evaluate \hat{f}(\vec r_n) + let parts = iter::once(sels) + .chain(mats) + .map(|mat| mat.columns().map(|c| c[0]).collect_vec()) + .collect_vec(); + let eq_sharp_r_acc = *self.eq_sharp_ns.last().unwrap(); + *tilde_eval = helper + .acc_interactions(&parts, &self.beta_pows, eq_3bs) + .map(|x| eq_sharp_r_acc * x); + tilde_eval[0] *= norm_factor; + } else { + for x in tilde_eval.iter_mut() { + *x *= r_prev; + } + }; + tilde_eval.map(|tilde_eval| vec![tilde_eval]) + } else { + let parts = iter::once(sels) + .chain(mats) + .map(|m| m.as_view()) + .collect_vec(); + let log_num_y = n_lift - round; + let num_y = 1 << log_num_y; + let eq_xi = 
&eq_xi_tree[num_y - 1..]; + let [mut numer, denom] = + sumcheck_round_poly_evals(log_num_y + 1, sp_deg, &parts, |_x, y, row_parts| { + let eq = eq_xi[y]; + helper + .acc_interactions(row_parts, &self.beta_pows, eq_3bs) + .map(|eval| eq * eval) + }); + for p in &mut numer { + *p *= norm_factor; + } + [numer, denom] + } + }) + .collect(); + + sp_logup_evals + .into_iter() + .chain(sp_zerocheck_evals) + .collect() + } + + pub fn fold_mle_evals(&mut self, round: usize, r_round: EF) { + self.mat_evals_per_trace = take(&mut self.mat_evals_per_trace) + .into_iter() + .map(|mats| batch_fold_mle_evals(mats, r_round)) + .collect_vec(); + self.sels_per_trace = batch_fold_mle_evals(take(&mut self.sels_per_trace), r_round); + self.eq_xi_per_trace.par_iter_mut().for_each(|eq| { + // trim the back (which corresponds to r_{j-1}) because we don't need it anymore + if eq.len() > 1 { + eq.truncate(eq.len() / 2); + } + }); + let xi = self.xi[self.l_skip + round - 1]; + let eq_r = eval_eq_mle(&[xi], &[r_round]); + self.eq_ns.push(self.eq_ns[round - 1] * eq_r); + self.eq_sharp_ns.push(self.eq_sharp_ns[round - 1] * eq_r); + + #[allow(unused_variables)] + #[cfg(debug_assertions)] + if tracing::enabled!(tracing::Level::DEBUG) && round == self.n_max { + use itertools::izip; + + for (trace_idx, (helper, &n, mats, sels, eq_xi)) in izip!( + &self.eval_helpers, + &self.n_per_trace, + &self.mat_evals_per_trace, + &self.sels_per_trace, + &self.eq_xi_per_trace + ) + .enumerate() + { + let parts = iter::once(sels) + .chain(mats) + .map(|mat| mat.columns().map(|c| c[0]).collect_vec()) + .collect_vec(); + let expr = helper.acc_constraints(&parts, &self.lambda_pows); + tracing::debug!(%trace_idx, %expr, "constraints_eval"); + } + + for (trace_idx, (helper, &n, mats, sels, eq_3bs)) in izip!( + &self.eval_helpers, + &self.n_per_trace, + &self.mat_evals_per_trace, + &self.sels_per_trace, + &self.eq_3b_per_trace + ) + .enumerate() + { + if helper.interactions.is_empty() { + continue; + } + let parts = iter::once(sels) + .chain(mats) + .map(|mat| mat.columns().map(|c| c[0]).collect_vec()) + .collect_vec(); + let [num, denom] = helper.acc_interactions(&parts, &self.beta_pows, eq_3bs); + + tracing::debug!(%trace_idx, %num, %denom, "interactions_eval"); + } + } + } + + pub fn into_column_openings(&mut self) -> Vec>> { + let num_airs_present = self.mat_evals_per_trace.len(); + let mut column_openings = Vec::with_capacity(num_airs_present); + // At the end, we've folded all MLEs so they only have one row equal to evaluation at `\vec + // r`. + for (helper, mut mat_evals) in self + .eval_helpers + .iter() + .zip(take(&mut self.mat_evals_per_trace)) + { + // For column openings, we pop common_main (and common_main_rot when present) and put it + // at the front. 
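+            // The resulting per-AIR order is [common_main, preprocessed (if any),
+            // cached_0, ...], matching the observation order in `prove_zerocheck_and_logup`.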
+ let openings_of_air: Vec> = if helper.needs_next { + let common_main_rot = mat_evals.pop().unwrap(); + let common_main = mat_evals.pop().unwrap(); + iter::once(&[common_main, common_main_rot] as &[_]) + .chain(mat_evals.chunks_exact(2)) + .map(|pair| { + zip(pair[0].columns(), pair[1].columns()) + .flat_map(|(claim, claim_rot)| { + assert_eq!(claim.len(), 1); + assert_eq!(claim_rot.len(), 1); + [claim[0], claim_rot[0]] + }) + .collect_vec() + }) + .collect_vec() + } else { + let common_main = mat_evals.pop().unwrap(); + iter::once(common_main) + .chain(mat_evals.into_iter()) + .map(|mat| { + mat.columns() + .map(|claim| { + assert_eq!(claim.len(), 1); + claim[0] + }) + .collect_vec() + }) + .collect_vec() + }; + column_openings.push(openings_of_air); + } + column_openings + } +} diff --git a/crates/stark-backend-v2/src/prover/logup_zerocheck/evaluator.rs b/crates/stark-backend-v2/src/prover/logup_zerocheck/evaluator.rs new file mode 100644 index 00000000..e3a92c30 --- /dev/null +++ b/crates/stark-backend-v2/src/prover/logup_zerocheck/evaluator.rs @@ -0,0 +1,77 @@ +use openvm_stark_backend::air_builders::symbolic::{ + symbolic_expression::SymbolicEvaluator, + symbolic_variable::{Entry, SymbolicVariable}, +}; +use p3_field::{ExtensionField, Field}; + +pub(super) struct ViewPair { + pub(super) local: *const T, + pub(super) next: Option<*const T>, +} + +impl ViewPair { + pub fn new(local: &[T], next: Option<&[T]>) -> Self { + Self { + local: local.as_ptr(), + next: next.map(|nxt| nxt.as_ptr()), + } + } + + /// SAFETY: no matrix bounds checks are done. + pub unsafe fn get(&self, row_offset: usize, column_idx: usize) -> &T { + match row_offset { + 0 => &*self.local.add(column_idx), + 1 => &*self.next.unwrap_unchecked().add(column_idx), + _ => panic!("row offset {row_offset} not supported"), + } + } +} + +/// Struct containing partitioned view of one row together with optional "rotated" row. +/// Constraints are evaluated on this struct. +pub(super) struct ProverConstraintEvaluator<'a, F, EF> { + pub preprocessed: Option>, + pub partitioned_main: Vec>, + pub is_first_row: EF, + pub is_last_row: EF, + pub is_transition: EF, + pub public_values: &'a [F], +} + +impl> SymbolicEvaluator + for ProverConstraintEvaluator<'_, F, EF> +{ + fn eval_const(&self, c: F) -> EF { + c.into() + } + fn eval_is_first_row(&self) -> EF { + self.is_first_row + } + fn eval_is_last_row(&self) -> EF { + self.is_last_row + } + fn eval_is_transition(&self) -> EF { + self.is_transition + } + + /// SAFETY: we only use this trait implementation when we have already done + /// a previous scan to ensure all matrix bounds are satisfied, + /// so no bounds checks are done here. 
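+    /// (The scan happens in `LogupZerocheckCpu::new`, which asserts that every variable
+    /// index is within the width of its matrix.)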
+    fn eval_var(&self, symbolic_var: SymbolicVariable<F>) -> EF {
+        let index = symbolic_var.index;
+        match symbolic_var.entry {
+            Entry::Preprocessed { offset } => unsafe {
+                *self
+                    .preprocessed
+                    .as_ref()
+                    .unwrap_unchecked()
+                    .get(offset, index)
+            },
+            Entry::Main { part_index, offset } => unsafe {
+                *self.partitioned_main[part_index].get(offset, index)
+            },
+            Entry::Public => unsafe { EF::from(*self.public_values.get_unchecked(index)) },
+            _ => unreachable!("after_challenge not supported"),
+        }
+    }
+}
diff --git a/crates/stark-backend-v2/src/prover/logup_zerocheck/fractional_sumcheck_gkr.rs b/crates/stark-backend-v2/src/prover/logup_zerocheck/fractional_sumcheck_gkr.rs
new file mode 100644
index 00000000..98fd2414
--- /dev/null
+++ b/crates/stark-backend-v2/src/prover/logup_zerocheck/fractional_sumcheck_gkr.rs
@@ -0,0 +1,209 @@
+use std::ops::Add;
+
+use p3_field::{Field, FieldAlgebra};
+use p3_util::log2_strict_usize;
+use tracing::{debug, instrument};
+
+use crate::{
+    poseidon2::sponge::FiatShamirTranscript,
+    proof::GkrLayerClaims,
+    prover::{
+        poly::evals_eq_hypercube,
+        sumcheck::{fold_mle_evals, sumcheck_round_poly_evals},
+        ColMajorMatrix,
+    },
+    EF,
+};
+
+/// Proof for the fractional sumcheck protocol
+pub struct FracSumcheckProof {
+    /// The fractional sum p_0 / q_0
+    pub fractional_sum: (EF, EF),
+    /// The claims for p_j(0, rho), p_j(1, rho), q_j(0, rho), and q_j(1, rho) for each layer j > 0.
+    pub claims_per_layer: Vec<GkrLayerClaims>,
+    /// Sumcheck polynomials for each layer, for each sumcheck round, given by their evaluations on
+    /// {1, 2, 3}.
+    pub sumcheck_polys: Vec<Vec<[EF; 3]>>,
+}
+
+#[derive(Clone, Copy, Debug, Default, derive_new::new)]
+#[repr(C)]
+pub struct Frac<EF> {
+    // PERF[jpw]: in the initial round, we can keep `p` in base field
+    pub p: EF,
+    pub q: EF,
+}
+
+impl<EF: Field> Add<Frac<EF>> for Frac<EF> {
+    type Output = Frac<EF>;
+
+    fn add(self, other: Frac<EF>) -> Self::Output {
+        Frac {
+            p: self.p * other.q + self.q * other.p,
+            q: self.q * other.q,
+        }
+    }
+}
+
+/// Runs the fractional sumcheck protocol using a GKR layered circuit.
+///
+/// # Arguments
+/// * `transcript` - The Fiat-Shamir transcript
+/// * `evals` - list of `(p, q)` pairs of fractions in projective coordinates representing
+///   evaluations on the hypercube
+/// * `assert_zero` - Whether to assert that the final sum is zero. If `true`, then the transcript
+///   will not observe the numerator of the final sum.
+///
+/// # Returns
+/// The fractional sumcheck proof and the final random evaluation vector.
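+///
+/// Fractions are kept in projective coordinates: `Frac { p, q }` denotes `p / q`, and
+/// addition is `(p1, q1) + (p2, q2) = (p1*q2 + p2*q1, q1*q2)`, so e.g.
+/// `1/2 + 1/3 = (1*3 + 1*2, 2*3) = 5/6`, with no field inversions needed.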
+#[instrument(level = "info", skip_all)] +pub fn fractional_sumcheck( + transcript: &mut TS, + evals: &[Frac], + assert_zero: bool, +) -> (FracSumcheckProof, Vec) { + if evals.is_empty() { + return ( + FracSumcheckProof { + fractional_sum: (EF::ZERO, EF::ONE), + claims_per_layer: vec![], + sumcheck_polys: vec![], + }, + vec![], + ); + } + // total_rounds = l_skip + n_logup + let total_rounds = log2_strict_usize(evals.len()); + // sumcheck polys for layers j=2,...,total_rounds + let mut sumcheck_polys = Vec::with_capacity(total_rounds); + + // segment tree: layer i=0,...,total_rounds starts at 2^i (index 0 unused) + let mut tree_evals: Vec> = vec![Frac::default(); 2 << total_rounds]; + tree_evals[(1 << total_rounds)..].copy_from_slice(evals); + + for node_idx in (1..(1 << total_rounds)).rev() { + tree_evals[node_idx] = tree_evals[2 * node_idx] + tree_evals[2 * node_idx + 1]; + } + let frac_sum = tree_evals[1]; + if assert_zero { + assert_eq!(frac_sum.p, EF::ZERO); + } else { + transcript.observe_ext(frac_sum.p); + } + transcript.observe_ext(frac_sum.q); + + // Index i is for layer i+1 + let mut claims_per_layer: Vec = Vec::with_capacity(total_rounds); + + // Process each GKR round + // `j = round + 1` goes from `1, ..., total_rounds` + // + // Round `j = 1` is special since "sumcheck" is directly checked by verifier + claims_per_layer.push(GkrLayerClaims { + p_xi_0: tree_evals[2].p, + q_xi_0: tree_evals[2].q, + p_xi_1: tree_evals[3].p, + q_xi_1: tree_evals[3].q, + }); + transcript.observe_ext(claims_per_layer[0].p_xi_0); + transcript.observe_ext(claims_per_layer[0].q_xi_0); + transcript.observe_ext(claims_per_layer[0].p_xi_1); + transcript.observe_ext(claims_per_layer[0].q_xi_1); + let mu_1 = transcript.sample_ext(); + debug!(gkr_round = 0, mu = %mu_1); + // ξ^{(j-1)} + let mut xi_prev = vec![mu_1]; + + // GKR rounds + for round in 1..total_rounds { + // Number of hypercube points + let eval_size = 1 << round; + // We apply batch sumcheck to the polynomials + // \eq(ξ^{(j-1)}, Y) (\hat p_j(0, Y) \hat q_j(1, Y) + \hat p_j(1, Y) \hat q_j(0, Y)) + // \eq(ξ^{(j-1)}, Y) (\hat q_j(0, Y) \hat q_j(1, Y)) + // Note: these are polynomials of degree 3 in each Y_i coordinate. 
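+        // (Degree 2 comes from the product of the two half-layer MLEs in projective
+        // fraction addition, plus 1 from the eq factor; the evaluation at 0 can be
+        // inferred from the running sum claim, so the points {1, 2, 3} suffice.)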
+ + // Sample λ_j for batching + let lambda = transcript.sample_ext(); + debug!(gkr_round = round, %lambda); + + // Columns are p_j0, q_j0, p_j1, q_j1 + // PERF: use a view instead of re-allocating memory + let mut pq_j_evals = EF::zero_vec(4 * eval_size); + let segment = &tree_evals[2 * eval_size..4 * eval_size]; + for x in 0..eval_size { + pq_j_evals[x] = segment[2 * x].p; + pq_j_evals[eval_size + x] = segment[2 * x].q; + pq_j_evals[2 * eval_size + x] = segment[2 * x + 1].p; + pq_j_evals[3 * eval_size + x] = segment[2 * x + 1].q; + } + let mut pq_j_evals = ColMajorMatrix::new(pq_j_evals, 4); + let mut eq_xis = ColMajorMatrix::new(evals_eq_hypercube(&xi_prev), 1); + + // Batch sumcheck where the round polynomials are evaluated at {1, 2, 3} + let (round_polys_eval, rho) = { + let n = round; + let mut round_polys_eval = Vec::with_capacity(n); + let mut r_vec = Vec::with_capacity(n); + + // Sumcheck rounds: apply fraction addition in projective coordinates to MLEs + for sumcheck_round in 0..n { + // Evaluate the univariate polynomial at {1, 2, 3} + // :projective fraction addition is degree 2, and then another +1 for eq + let [s_evals] = sumcheck_round_poly_evals( + n - sumcheck_round, + 3, + &[eq_xis.as_view(), pq_j_evals.as_view()], + |_x, _y, row| { + let eq_xi = row[0][0]; + let &[p_j0, q_j0, p_j1, q_j1] = row[1].as_slice().try_into().unwrap(); + let p_prev = p_j0 * q_j1 + p_j1 * q_j0; + let q_prev = q_j0 * q_j1; + // batch using lambda + [eq_xi * (p_prev + lambda * q_prev)] + }, + ); + let s_evals: [EF; 3] = s_evals.try_into().unwrap(); + for &eval in &s_evals { + transcript.observe_ext(eval); + } + round_polys_eval.push(s_evals); + + let r_round = transcript.sample_ext(); + pq_j_evals = fold_mle_evals(pq_j_evals, r_round); + eq_xis = fold_mle_evals(eq_xis, r_round); + r_vec.push(r_round); + debug!(gkr_round = round, %sumcheck_round, %r_round); + } + (round_polys_eval, r_vec) + }; + claims_per_layer.push(GkrLayerClaims { + p_xi_0: pq_j_evals.column(0)[0], + q_xi_0: pq_j_evals.column(1)[0], + p_xi_1: pq_j_evals.column(2)[0], + q_xi_1: pq_j_evals.column(3)[0], + }); + transcript.observe_ext(claims_per_layer[round].p_xi_0); + transcript.observe_ext(claims_per_layer[round].q_xi_0); + transcript.observe_ext(claims_per_layer[round].p_xi_1); + transcript.observe_ext(claims_per_layer[round].q_xi_1); + + // Sample μ_j for reduction to single evaluation point + let mu = transcript.sample_ext(); + debug!(gkr_round = round, %mu); + + // Update ξ^{(j)} = (μ_j, ρ^{(j-1)}) + xi_prev = [vec![mu], rho].concat(); + + sumcheck_polys.push(round_polys_eval); + } + + ( + FracSumcheckProof { + fractional_sum: (frac_sum.p, frac_sum.q), + claims_per_layer, + sumcheck_polys, + }, + xi_prev, + ) +} diff --git a/crates/stark-backend-v2/src/prover/logup_zerocheck/mod.rs b/crates/stark-backend-v2/src/prover/logup_zerocheck/mod.rs new file mode 100644 index 00000000..80681417 --- /dev/null +++ b/crates/stark-backend-v2/src/prover/logup_zerocheck/mod.rs @@ -0,0 +1,427 @@ +//! 
Batch sumcheck for ZeroCheck constraints and sumcheck for LogUp input layer MLEs + +use std::{cmp::max, iter::zip}; + +use itertools::Itertools; +use openvm_stark_backend::prover::MatrixDimensions; +use p3_dft::TwoAdicSubgroupDft; +use p3_field::{Field, FieldAlgebra}; +use p3_matrix::dense::RowMajorMatrix; +use p3_maybe_rayon::prelude::*; +use p3_util::log2_strict_usize; +use tracing::{debug, info_span, instrument}; + +use crate::{ + calculate_n_logup, + dft::Radix2BowersSerial, + poly_common::{eq_sharp_uni_poly, eq_uni_poly, UnivariatePoly}, + poseidon2::sponge::FiatShamirTranscript, + proof::{column_openings_by_rot, BatchConstraintProof, GkrProof}, + prover::{ + fractional_sumcheck_gkr::{fractional_sumcheck, Frac}, + stacked_pcs::StackedLayout, + sumcheck::sumcheck_round0_deg, + CpuBackendV2, DeviceMultiStarkProvingKeyV2, MatrixView, ProvingContextV2, + }, + EF, F, +}; + +mod cpu; +mod evaluator; +pub mod fractional_sumcheck_gkr; +mod single; + +pub use cpu::LogupZerocheckCpu; +pub use single::*; + +#[instrument(level = "info", skip_all)] +pub fn prove_zerocheck_and_logup( + transcript: &mut TS, + mpk: &DeviceMultiStarkProvingKeyV2, + ctx: &ProvingContextV2, +) -> (GkrProof, BatchConstraintProof, Vec) +where + TS: FiatShamirTranscript, +{ + let l_skip = mpk.params.l_skip; + let constraint_degree = mpk.max_constraint_degree; + let num_traces = ctx.per_trace.len(); + + // Traces are sorted + let n_max = log2_strict_usize(ctx.per_trace[0].1.common_main.height()).saturating_sub(l_skip); + // Gather interactions metadata, including interactions stacked layout which depends on trace + // heights + let mut total_interactions = 0u64; + let interactions_meta: Vec<_> = ctx + .per_trace + .iter() + .map(|(air_idx, air_ctx)| { + let pk = &mpk.per_air[*air_idx]; + + let num_interactions = pk.vk.symbolic_constraints.interactions.len(); + let height = air_ctx.common_main.height(); + let log_height = log2_strict_usize(height); + let log_lifted_height = log_height.max(l_skip); + total_interactions += (num_interactions as u64) << log_lifted_height; + (num_interactions, log_lifted_height) + }) + .collect(); + // Implicitly, the width of this stacking should be 1 + let n_logup = calculate_n_logup(l_skip, total_interactions); + debug!(%n_logup); + // There's no stride threshold for `interactions_layout` because there's no univariate skip for + // GKR + let interactions_layout = StackedLayout::new(0, l_skip + n_logup, interactions_meta); + + // Grind to increase soundness of random sampling for LogUp + let logup_pow_witness = transcript.grind(mpk.params.logup.pow_bits); + let alpha_logup = transcript.sample_ext(); + let beta_logup = transcript.sample_ext(); + debug!(%alpha_logup, %beta_logup); + + let mut prover = LogupZerocheckCpu::new( + mpk, + ctx, + n_logup, + interactions_layout, + alpha_logup, + beta_logup, + ); + // GKR + // Compute logup input layer: these are the evaluations of \hat{p}, \hat{q} on the hypercube + // `H_{l_skip + n_logup}` + let has_interactions = !prover.interactions_layout.sorted_cols.is_empty(); + let gkr_input_evals = if !has_interactions { + vec![] + } else { + // Per trace, a row major matrix of interaction evaluations + // NOTE: these are the evaluations _without_ lifting + // PERF[jpw]: we should write directly to the stacked `evals` in memory below + let unstacked_interaction_evals = prover + .eval_helpers + .par_iter() + .enumerate() + .map(|(trace_idx, helper)| { + let trace_ctx = &ctx.per_trace[trace_idx].1; + let mats = helper.view_mats(trace_ctx); + let height = 
trace_ctx.common_main.height(); + (0..height) + .into_par_iter() + .map(|i| { + let mut row_parts = Vec::with_capacity(mats.len() + 1); + let is_first = F::from_bool(i == 0); + let is_transition = F::from_bool(i != height - 1); + let is_last = F::from_bool(i == height - 1); + let sels = vec![is_first, is_transition, is_last]; + row_parts.push(sels); + for (mat, is_rot) in &mats { + let offset = usize::from(*is_rot); + row_parts.push( + // SAFETY: %height ensures we never go out of bounds + (0..mat.width()) + .map(|j| unsafe { + *mat.get_unchecked((i + offset) % height, j) + }) + .collect_vec(), + ); + } + helper.eval_interactions(&row_parts, &prover.beta_pows) + }) + .collect::>() + }) + .collect::>(); + let mut evals = vec![Frac::default(); 1 << (l_skip + n_logup)]; + for (trace_idx, interaction_idx, s) in + prover.interactions_layout.sorted_cols.iter().copied() + { + let pq_evals = &unstacked_interaction_evals[trace_idx]; + let height = pq_evals.len(); + debug_assert_eq!(s.col_idx, 0); + // the interactions layout has internal striding threshold=0 + debug_assert_eq!(1 << s.log_height(), s.len(0)); + debug_assert_eq!(s.len(0) % height, 0); + let norm_factor_denom = s.len(0) / height; + let norm_factor = F::from_canonical_usize(norm_factor_denom).inverse(); + // We need to fill `evals` with the logup evaluations on the lifted trace, which is + // the same as cyclic repeating of the unlifted evaluations + evals[s.row_idx..s.row_idx + s.len(0)] + .chunks_exact_mut(height) + .for_each(|evals| { + evals + .par_iter_mut() + .zip(pq_evals) + .for_each(|(pq_eval, evals_at_z)| { + let (mut numer, denom) = evals_at_z[interaction_idx]; + numer *= norm_factor; + *pq_eval = Frac::new(numer.into(), denom); + }); + }); + } + // Prevent division by zero: + evals.par_iter_mut().for_each(|frac| frac.q += alpha_logup); + evals + }; + + let (frac_sum_proof, mut xi) = fractional_sumcheck(transcript, &gkr_input_evals, true); + + // Sample more for `\xi` in the edge case that some AIRs don't have interactions + let n_global = max(n_max, n_logup); + debug!(%n_global); + while xi.len() != l_skip + n_global { + xi.push(transcript.sample_ext()); + } + debug!(?xi); + prover.xi = xi; + // we now have full \xi vector + + // begin batch sumcheck + let mut sumcheck_round_polys = Vec::with_capacity(n_max); + let mut r = Vec::with_capacity(n_max + 1); + // batching randomness + let lambda = transcript.sample_ext(); + debug!(%lambda); + + let sp_0_polys = prover.sumcheck_uni_round0_polys(ctx, lambda); + let sp_0_deg = sumcheck_round0_deg(l_skip, constraint_degree); + let s_deg = constraint_degree + 1; + let s_0_deg = sumcheck_round0_deg(l_skip, s_deg); + let large_uni_domain = (s_0_deg + 1).next_power_of_two(); + let dft = Radix2BowersSerial; + let s_0_logup_polys = { + let eq_sharp_uni = eq_sharp_uni_poly(&prover.xi[..l_skip]); + let mut eq_coeffs = eq_sharp_uni.into_coeffs(); + eq_coeffs.resize(large_uni_domain, EF::ZERO); + let eq_evals = dft.dft(eq_coeffs); + + let width = 2 * num_traces; + let mut sp_coeffs_mat = EF::zero_vec(width * large_uni_domain); + for (i, coeffs) in sp_0_polys[..2 * num_traces].iter().enumerate() { + // NOTE: coeffs could have length longer than `sp_0_deg + 1` due to coset evaluation, + // but trailing coefficients should be zero. 
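+            // Layout: `sp_coeffs_mat` is row-major with width `2 * num_traces`, so row j
+            // holds the degree-j coefficients of every logup s'_0 poly and `dft_batch`
+            // transforms each polynomial down its own column.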
+ for (j, &c_j) in coeffs.coeffs().iter().enumerate().take(sp_0_deg + 1) { + // SAFETY: + // - coeffs length is <= sp_0_deg + 1 <= s_0_deg < large_uni_domain + // - sp_coeffs_mat allocated for width + unsafe { + *sp_coeffs_mat.get_unchecked_mut(j * width + i) = c_j; + } + } + } + let mut s_evals = dft.dft_batch(RowMajorMatrix::new(sp_coeffs_mat, width)); + for (eq, row) in zip(eq_evals, s_evals.values.chunks_mut(width)) { + for x in row { + *x *= eq; + } + } + dft.idft_batch(s_evals) + }; + + let skip_domain_size = F::from_canonical_usize(1 << l_skip); + // logup sum claims (sum_{\hat p}, sum_{\hat q}) per present AIR + let (numerator_term_per_air, denominator_term_per_air): (Vec<_>, Vec<_>) = (0..num_traces) + .map(|trace_idx| { + let [sum_claim_p, sum_claim_q] = [0, 1].map(|is_denom| { + // Compute sum over D of s_0(Z) to get the sum claim + (0..=s_0_deg) + .step_by(1 << l_skip) + .map(|j| unsafe { + // SAFETY: matrix is 2 * num_trace x large_uni_domain, s_0_deg < + // large_uni_domain + *s_0_logup_polys + .values + .get_unchecked(j * 2 * num_traces + 2 * trace_idx + is_denom) + }) + .sum::() + * skip_domain_size + }); + transcript.observe_ext(sum_claim_p); + transcript.observe_ext(sum_claim_q); + + (sum_claim_p, sum_claim_q) + }) + .unzip(); + + let mu = transcript.sample_ext(); + debug!(%mu); + let mu_pows = mu.powers().take(3 * num_traces).collect_vec(); + + let s_0_zc_poly = { + let eq_uni = eq_uni_poly::(l_skip, prover.xi[0]); + let mut eq_coeffs = eq_uni.into_coeffs(); + eq_coeffs.resize(large_uni_domain, EF::ZERO); + let eq_evals = dft.dft(eq_coeffs); + + let mut sp_coeffs = EF::zero_vec(large_uni_domain); + let mus = &mu_pows[2 * num_traces..]; + let polys = &sp_0_polys[2 * num_traces..]; + for (j, batch_coeff) in sp_coeffs.iter_mut().enumerate().take(sp_0_deg + 1) { + for (&mu, poly) in zip(mus, polys) { + *batch_coeff += mu * *poly.coeffs().get(j).unwrap_or(&EF::ZERO); + } + } + let mut s_evals = dft.dft(sp_coeffs); + for (eq, x) in zip(eq_evals, &mut s_evals) { + *x *= eq; + } + dft.idft(s_evals) + }; + + // Algebraically batch + let s_0_poly = UnivariatePoly::new( + zip( + s_0_logup_polys.values.chunks_exact(2 * num_traces), + s_0_zc_poly, + ) + .take(s_0_deg + 1) + .map(|(logup_row, batched_zc)| { + let coeff = batched_zc + + zip(&mu_pows, logup_row) + .map(|(&mu_j, &x)| mu_j * x) + .sum::(); + transcript.observe_ext(coeff); + coeff + }) + .collect(), + ); + + let r_0 = transcript.sample_ext(); + r.push(r_0); + debug!(round = 0, r_round = %r_0); + prover.prev_s_eval = s_0_poly.eval_at_point(r_0); + debug!("s_0(r_0) = {}", prover.prev_s_eval); + + prover.fold_ple_evals(ctx, r_0); + + // Sumcheck rounds: + // - each round the prover needs to compute univariate polynomial `s_round`. This poly is linear + // since we are taking MLE of `evals`. + // - at end of each round, sample random `r_round` in `EF` + // + // `s_round` is degree `s_deg` so we evaluate it at `0, ..., =s_deg`. The prover skips + // evaluation at `0` because the verifier can infer it from the previous round's + // `s_{round-1}(r)` claim. 
The degree is constraint_degree + 1, where + 1 is from eq term + let _mle_rounds_span = + info_span!("prover.batch_constraints.mle_rounds", phase = "prover").entered(); + debug!(%s_deg); + for round in 1..=n_max { + let sp_round_evals = prover.sumcheck_polys_eval(round, r[round - 1]); + // From s'_T above, we can form s'_head(X) and s'_tail where s'_tail is constant + // The desired polynomial s(X) for this round `j` is + // s(X) = eq(\vec xi, \vec r_{j-1}) eq(xi_{}, X) s'_head(X) + s'_tail * X + // + // The head vs tail corresponds to the cutoff in front loaded batching where the coordinates + // have been exhausted. + // + // In fact, we further need to split s'_head into s'_{head,zc} and s'_{head,logup} due to + // different eq versus eq_sharp round 0 contributions. + let tail_start = prover + .n_per_trace + .iter() + .find_position(|&&n| round as isize > n) + .map(|(i, _)| i) + .unwrap_or(num_traces); + let mut sp_head_zc = vec![EF::ZERO; constraint_degree]; + let mut sp_head_logup = vec![EF::ZERO; constraint_degree]; + let mut sp_tail = EF::ZERO; + for trace_idx in 0..num_traces { + let zc_idx = 2 * num_traces + trace_idx; + let numer_idx = 2 * trace_idx; + let denom_idx = numer_idx + 1; + if trace_idx < tail_start { + for i in 0..constraint_degree { + sp_head_zc[i] += mu_pows[zc_idx] * sp_round_evals[zc_idx][i]; + sp_head_logup[i] += mu_pows[numer_idx] * sp_round_evals[numer_idx][i] + + mu_pows[denom_idx] * sp_round_evals[denom_idx][i]; + } + } else { + sp_tail += mu_pows[zc_idx] * sp_round_evals[zc_idx][0] + + mu_pows[numer_idx] * sp_round_evals[numer_idx][0] + + mu_pows[denom_idx] * sp_round_evals[denom_idx][0]; + } + } + // With eq(xi,r) contributions + let mut sp_head_evals = vec![EF::ZERO; s_deg]; + for i in 0..constraint_degree { + sp_head_evals[i + 1] = prover.eq_ns[round - 1] * sp_head_zc[i] + + prover.eq_sharp_ns[round - 1] * sp_head_logup[i]; + } + // We need to derive s'(0). 
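+        // (The prover never sends s'(0); recall s(X) = eq(xi, X) * s'_head(X) + s'_tail * X.)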
+ // We use that s_j(0) + s_j(1) = s_{j-1}(r_{j-1}) + let xi_cur = prover.xi[l_skip + round - 1]; + { + let eq_xi_0 = EF::ONE - xi_cur; + let eq_xi_1 = xi_cur; + sp_head_evals[0] = + (prover.prev_s_eval - eq_xi_1 * sp_head_evals[1] - sp_tail) * eq_xi_0.inverse(); + } + // s' has degree s_deg - 1 + let sp_head = UnivariatePoly::lagrange_interpolate( + &(0..s_deg).map(F::from_canonical_usize).collect_vec(), + &sp_head_evals, + ); + // eq(xi, X) = (2 * xi - 1) * X + (1 - xi) + // Compute s(X) = eq(xi, X) * s'_head(X) + s'_tail * X (s'_head now contains eq(..,r)) + // s(X) has degree s_deg + let batch_s = { + let mut coeffs = sp_head.into_coeffs(); + coeffs.push(EF::ZERO); + let b = EF::ONE - xi_cur; + let a = xi_cur - b; + for i in (0..s_deg).rev() { + coeffs[i + 1] = a * coeffs[i] + b * coeffs[i + 1]; + } + coeffs[0] *= b; + coeffs[1] += sp_tail; + UnivariatePoly::new(coeffs) + }; + let batch_s_evals = (1..=s_deg) + .map(|i| batch_s.eval_at_point(EF::from_canonical_usize(i))) + .collect_vec(); + for &eval in &batch_s_evals { + transcript.observe_ext(eval); + } + sumcheck_round_polys.push(batch_s_evals); + + let r_round = transcript.sample_ext(); + debug!(%round, %r_round); + r.push(r_round); + prover.prev_s_eval = batch_s.eval_at_point(r_round); + + prover.fold_mle_evals(round, r_round); + } + drop(_mle_rounds_span); + assert_eq!(r.len(), n_max + 1); + + let column_openings = prover.into_column_openings(); + + // Observe common main openings first, and then preprocessed/cached + for (helper, openings) in prover.eval_helpers.iter().zip(column_openings.iter()) { + for (claim, claim_rot) in column_openings_by_rot(&openings[0], helper.needs_next) { + transcript.observe_ext(claim); + transcript.observe_ext(claim_rot); + } + } + for (helper, openings) in prover.eval_helpers.iter().zip(column_openings.iter()) { + for part in openings.iter().skip(1) { + for (claim, claim_rot) in column_openings_by_rot(part, helper.needs_next) { + transcript.observe_ext(claim); + transcript.observe_ext(claim_rot); + } + } + } + + let batch_constraint_proof = BatchConstraintProof { + numerator_term_per_air, + denominator_term_per_air, + univariate_round_coeffs: s_0_poly.into_coeffs(), + sumcheck_round_polys, + column_openings, + }; + let gkr_proof = GkrProof { + logup_pow_witness, + q0_claim: frac_sum_proof.fractional_sum.1, + claims_per_layer: frac_sum_proof.claims_per_layer, + sumcheck_polys: frac_sum_proof.sumcheck_polys, + }; + (gkr_proof, batch_constraint_proof, r) +} diff --git a/crates/stark-backend-v2/src/prover/logup_zerocheck/single.rs b/crates/stark-backend-v2/src/prover/logup_zerocheck/single.rs new file mode 100644 index 00000000..2f9e413b --- /dev/null +++ b/crates/stark-backend-v2/src/prover/logup_zerocheck/single.rs @@ -0,0 +1,187 @@ +//! 
Single AIR constraint evaluation helpers
+
+use std::iter::zip;
+
+use itertools::Itertools;
+use openvm_stark_backend::{
+    air_builders::symbolic::{symbolic_expression::SymbolicEvaluator, SymbolicExpressionDag},
+    interaction::SymbolicInteraction,
+};
+use p3_field::{ExtensionField, TwoAdicField};
+
+use crate::prover::{
+    logup_zerocheck::evaluator::{ProverConstraintEvaluator, ViewPair},
+    AirProvingContextV2, CpuBackendV2, StridedColMajorMatrixView,
+};
+
+/// For a single AIR
+pub struct EvalHelper<'a, F> {
+    /// AIR constraints
+    pub constraints_dag: &'a SymbolicExpressionDag<F>,
+    /// Interactions
+    pub interactions: Vec<SymbolicInteraction<F>>,
+    pub public_values: Vec<F>,
+    pub preprocessed_trace: Option<StridedColMajorMatrixView<'a, F>>,
+    // TODO: skip rotation if vk dictates it is never used
+    pub needs_next: bool,
+    pub constraint_degree: u8,
+}
+
+impl<'a> EvalHelper<'a, crate::F> {
+    /// Returns list of (ref to column-major matrix, is_rot) pairs in the order:
+    /// - (if has_preprocessed) (preprocessed, false), (preprocessed, true)
+    /// - (cached_0, false), (cached_0, true), ..., (cached_{m-1}, false), (cached_{m-1}, true)
+    /// - (common, false), (common, true)
+    ///
+    /// Note: currently every matrix returns both non-rotated and rotated versions. This will change
+    /// in the future for perf.
+    pub fn view_mats(
+        &self,
+        ctx: &'a AirProvingContextV2<CpuBackendV2>,
+    ) -> Vec<(StridedColMajorMatrixView<'a, crate::F>, bool)> {
+        let base_mats = usize::from(self.has_preprocessed()) + 1 + ctx.cached_mains.len();
+        let mut mats = Vec::with_capacity(if self.needs_next {
+            2 * base_mats
+        } else {
+            base_mats
+        });
+        if let Some(mat) = self.preprocessed_trace {
+            mats.push((mat, false));
+            if self.needs_next {
+                mats.push((mat, true));
+            }
+        }
+        for cd in ctx.cached_mains.iter() {
+            let trace_view = cd.data.mat_view(0);
+            mats.push((trace_view, false));
+            if self.needs_next {
+                mats.push((trace_view, true));
+            }
+        }
+        mats.push((ctx.common_main.as_view().into(), false));
+        if self.needs_next {
+            mats.push((ctx.common_main.as_view().into(), true));
+        }
+        mats
+    }
+}
+
+impl<F: TwoAdicField> EvalHelper<'_, F> {
+    pub fn has_preprocessed(&self) -> bool {
+        self.preprocessed_trace.is_some()
+    }
+
+    /// See [Self::evaluator].
+    // Assumes that `z[0] != 1` or `omega_D^{-1}` to avoid handling division by zero.
+    pub fn acc_constraints<FF: ExtensionField<F>, EF: ExtensionField<FF>>(
+        &self,
+        row_parts: &[Vec<FF>],
+        lambda_pows: &[EF],
+    ) -> EF {
+        let evaluator = self.evaluator(row_parts);
+        let nodes = evaluator.eval_nodes(&self.constraints_dag.nodes);
+        zip(lambda_pows, &self.constraints_dag.constraint_idx)
+            .fold(EF::ZERO, |acc, (&lambda_pow, &idx)| {
+                acc + lambda_pow * nodes[idx]
+            })
+    }
+
+    /// See [Self::evaluator].
+    ///
+    /// Returns sum of ordered list of `interactions`, weighted by `eq(\xi_3, b_{T,\hat\sigma})`
+    /// terms as (numerator, denominator) pair.
+    ///
+    /// Note: the denominator does not include the `alpha` term.
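+    /// (As computed in [`Self::eval_interactions`], a message `(m_0, ..., m_{k-1})` on bus
+    /// `b` contributes numerator `count` and denominator `beta^k * (b + 1) + sum_{j<k}
+    /// beta^j * m_j`; the `alpha` shift is applied once to the stacked evaluations by the
+    /// caller.)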
+ pub fn acc_interactions( + &self, + row_parts: &[Vec], + beta_pows: &[EF], + eq_3bs: &[EF], + ) -> [EF; 2] + where + FF: ExtensionField, + EF: ExtensionField + ExtensionField, + { + // PERF[jpw]: no need to collect the vec, but I ran into a lifetime issue returning iterator + // in `eval_interactions` + let interaction_evals = self.eval_interactions(row_parts, beta_pows); + let mut numer = EF::ZERO; + let mut denom = EF::ZERO; // without alpha term + for (&eq_3b, eval) in zip(eq_3bs, interaction_evals) { + numer += eq_3b * eval.0; + denom += eq_3b * eval.1; + } + [numer, denom] + } + + pub fn eval_interactions( + &self, + row_parts: &[Vec], + beta_pows: &[EF], + ) -> Vec<(FF, EF)> + where + FF: ExtensionField, + EF: ExtensionField + ExtensionField, + { + let evaluator = self.evaluator(row_parts); + self.interactions + .iter() + .map(|interaction| { + let b = F::from_canonical_u32(interaction.bus_index as u32 + 1); + let msg_len = interaction.message.len(); + assert!(msg_len <= beta_pows.len()); + let denom = zip(&interaction.message, beta_pows).fold( + beta_pows[msg_len] * b, + |h_beta, (msg_j, &beta_j)| { + let msg_j_eval = evaluator.eval_expr(msg_j); + h_beta + beta_j * msg_j_eval + }, + ); + let numer = evaluator.eval_expr(&interaction.count); + (numer, denom) + }) + .collect() + } + + // `row_parts` should have separate Vec in following order: + // - selectors [is_first_row, is_transition, is_last_row] + // - (if has_preprocessed) preprocessed + // - (if has_preprocessed) preprocessed_rot + // - cached_0 + // - cached_0_rot + // - ... + // - common + // - common_rot + fn evaluator>( + &self, + row_parts: &[Vec], + ) -> ProverConstraintEvaluator<'_, F, FF> { + let sels = &row_parts[0]; + let mut view_pairs = if self.needs_next { + let mut chunks = row_parts[1..].chunks_exact(2); + let pairs = chunks + .by_ref() + .map(|pair| ViewPair::new(&pair[0], Some(&pair[1][..]))) + .collect_vec(); + debug_assert!(chunks.remainder().is_empty()); + pairs + } else { + row_parts[1..] + .iter() + .map(|part| ViewPair::new(part, None)) + .collect_vec() + }; + let mut preprocessed = None; + if self.has_preprocessed() { + preprocessed = Some(view_pairs.remove(0)); + } + ProverConstraintEvaluator { + preprocessed, + partitioned_main: view_pairs, + is_first_row: sels[0], + is_transition: sels[1], + is_last_row: sels[2], + public_values: &self.public_values, + } + } +} diff --git a/crates/stark-backend-v2/src/prover/matrix.rs b/crates/stark-backend-v2/src/prover/matrix.rs new file mode 100644 index 00000000..674f3afe --- /dev/null +++ b/crates/stark-backend-v2/src/prover/matrix.rs @@ -0,0 +1,255 @@ +use getset::CopyGetters; +use openvm_stark_backend::{ + p3_matrix::{dense::RowMajorMatrix, Matrix}, + prover::MatrixDimensions, +}; +use p3_field::Field; +use p3_maybe_rayon::prelude::*; +use serde::{Deserialize, Serialize}; + +/// Trait to consolidate different virtual matrices that may not be backed by an in-memory matrix +/// with a standard column-major or row-major layout. +/// +/// Note: for performance, users should still be aware of the underlying memory layout. This trait +/// is just an organizational convenience. +pub trait MatrixView: MatrixDimensions { + fn get(&self, row_idx: usize, col_idx: usize) -> Option<&F> { + if col_idx >= self.width() || row_idx >= self.height() { + None + } else { + // SAFETY: bounds checked above + Some(unsafe { self.get_unchecked(row_idx, col_idx) }) + } + } + /// Get a reference to an element without bounds checking. 
+ /// + /// # Safety + /// + /// The caller must ensure that `col_idx` and `row_idx` are within the bounds of the matrix. + unsafe fn get_unchecked(&self, row_idx: usize, col_idx: usize) -> &F; +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ColMajorMatrix { + pub values: Vec, + width: usize, + height: usize, +} + +#[derive(Clone, Copy, Debug)] +pub struct ColMajorMatrixView<'a, F> { + pub values: &'a [F], + width: usize, + height: usize, +} + +/// Vertically strided column-major matrix view. +#[derive(Clone, Copy, Debug, CopyGetters)] +pub struct StridedColMajorMatrixView<'a, F> { + #[getset(get_copy = "pub")] + values: &'a [F], + width: usize, + height: usize, + /// Row stride + #[getset(get_copy = "pub")] + stride: usize, +} + +impl ColMajorMatrix { + pub fn new(values: Vec, width: usize) -> Self { + assert_eq!(values.len() % width, 0); + let height = values.len() / width; + assert!(height.is_power_of_two()); + Self { + values, + width, + height, + } + } + + pub(crate) fn dummy() -> Self { + Self { + values: vec![], + width: 0, + height: 0, + } + } + + pub fn column(&self, col_idx: usize) -> &[F] { + let start = col_idx * self.height; + let end = start + self.height; + &self.values[start..end] + } + + pub fn columns(&self) -> impl Iterator { + self.values.chunks_exact(self.height) + } + + pub fn as_view(&self) -> ColMajorMatrixView<'_, F> { + ColMajorMatrixView { + values: &self.values, + width: self.width, + height: self.height, + } + } + + pub fn from_row_major(mat: &RowMajorMatrix) -> Self + where + F: Field, + { + let mut values = F::zero_vec(mat.values.len()); + let width = mat.width; + let height = mat.height(); + values.par_iter_mut().enumerate().for_each(|(idx, value)| { + let r = idx % height; + let c = idx / height; + // SAFETY: index is in bounds for row-major matrix + *value = unsafe { *mat.values.get_unchecked(r * width + c) }; + }); + Self { + values, + width, + height, + } + } +} + +impl ColMajorMatrix { + pub fn par_columns(&self) -> impl ParallelIterator { + self.values.par_chunks_exact(self.height) + } +} + +impl MatrixDimensions for ColMajorMatrix { + fn width(&self) -> usize { + self.width + } + fn height(&self) -> usize { + self.height + } +} + +impl MatrixView for ColMajorMatrix { + unsafe fn get_unchecked(&self, row_idx: usize, col_idx: usize) -> &F { + debug_assert!(col_idx < self.width); + debug_assert!(row_idx < self.height); + self.values.get_unchecked(col_idx * self.height + row_idx) + } +} + +impl<'a, F> ColMajorMatrixView<'a, F> { + pub fn new(values: &'a [F], width: usize) -> Self { + assert_eq!(values.len() % width, 0); + let height = values.len() / width; + debug_assert!(height == 0 || height.is_power_of_two()); + Self { + values, + width, + height, + } + } + + pub fn column(&self, col_idx: usize) -> &[F] { + let start = col_idx * self.height; + let end = start + self.height; + &self.values[start..end] + } + + pub fn columns(&self) -> impl Iterator { + self.values.chunks_exact(self.height) + } +} + +impl MatrixDimensions for ColMajorMatrixView<'_, F> { + fn width(&self) -> usize { + self.width + } + fn height(&self) -> usize { + self.height + } +} + +impl MatrixView for ColMajorMatrixView<'_, F> { + unsafe fn get_unchecked(&self, row_idx: usize, col_idx: usize) -> &F { + debug_assert!(col_idx < self.width); + debug_assert!(row_idx < self.height); + self.values + .get_unchecked(col_maj_idx(row_idx, col_idx, self.height)) + } +} + +impl<'a, F> StridedColMajorMatrixView<'a, F> { + pub fn new(values: &'a [F], width: usize, stride: usize) -> 
+        assert_eq!(values.len() % (width * stride), 0);
+        let height = values.len() / (width * stride);
+        debug_assert!(height == 0 || height.is_power_of_two());
+        Self {
+            values,
+            width,
+            height,
+            stride,
+        }
+    }
+
+    pub fn to_matrix(&self) -> ColMajorMatrix<F>
+    where
+        F: Field,
+    {
+        let values: Vec<_> = (0..self.width * self.height)
+            .into_par_iter()
+            .map(|i| {
+                let r = i % self.height;
+                let c = i / self.height;
+                unsafe { *self.get_unchecked(r, c) }
+            })
+            .collect();
+        ColMajorMatrix::new(values, self.width)
+    }
+
+    pub fn to_row_major_matrix(&self) -> RowMajorMatrix<F>
+    where
+        F: Field,
+    {
+        let values: Vec<_> = (0..self.width * self.height)
+            .into_par_iter()
+            .map(|i| {
+                let r = i / self.width;
+                let c = i % self.width;
+                unsafe { *self.get_unchecked(r, c) }
+            })
+            .collect();
+        RowMajorMatrix::new(values, self.width)
+    }
+}
+
+impl<F> MatrixDimensions for StridedColMajorMatrixView<'_, F> {
+    fn width(&self) -> usize {
+        self.width
+    }
+    fn height(&self) -> usize {
+        self.height
+    }
+}
+
+impl<F> MatrixView<F> for StridedColMajorMatrixView<'_, F> {
+    unsafe fn get_unchecked(&self, row_idx: usize, col_idx: usize) -> &F {
+        debug_assert!(col_idx < self.width);
+        debug_assert!(row_idx < self.height);
+        self.values.get_unchecked(col_maj_idx(
+            row_idx * self.stride,
+            col_idx,
+            self.height * self.stride,
+        ))
+    }
+}
+
+impl<'a, F> From<ColMajorMatrixView<'a, F>> for StridedColMajorMatrixView<'a, F> {
+    fn from(mat: ColMajorMatrixView<'a, F>) -> Self {
+        Self::new(mat.values, mat.width, 1)
+    }
+}
+
+#[inline(always)]
+pub(crate) fn col_maj_idx(row_idx: usize, col_idx: usize, height: usize) -> usize {
+    col_idx * height + row_idx
+}
diff --git a/crates/stark-backend-v2/src/prover/metrics.rs b/crates/stark-backend-v2/src/prover/metrics.rs
new file mode 100644
index 00000000..13f25d72
--- /dev/null
+++ b/crates/stark-backend-v2/src/prover/metrics.rs
@@ -0,0 +1,224 @@
+use std::fmt::Display;
+
+use itertools::zip_eq;
+use openvm_stark_backend::keygen::types::TraceWidth;
+use serde::{Deserialize, Serialize};
+use tracing::{debug, info};
+
+use crate::{
+    proof::TraceVData,
+    prover::{DeviceMultiStarkProvingKeyV2, ProverBackendV2},
+};
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct TraceMetrics {
+    pub per_air: Vec<SingleTraceMetrics>,
+    /// Total base field cells from all traces, excluding preprocessed traces.
+    pub total_cells: usize,
+    /// For each trace height constraint, the pair (weighted sum, threshold).
+    pub trace_height_inequalities: Vec<(usize, usize)>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct SingleTraceMetrics {
+    pub air_name: String,
+    pub air_id: usize,
+    pub height: usize,
+    /// The after-challenge width is adjusted to be in terms of **base field** elements.
+    pub width: TraceWidth,
+    pub cells: TraceCells,
+    // TODO[jpw]: update this calculation accordingly
+    /// Omitting the preprocessed trace, the total base field cells from the main and
+    /// after-challenge traces.
+    pub total_cells: usize,
+}
+
+/// Trace cells, counted in terms of the number of **base field** elements.
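+/// For example (hedged note, not part of the original patch): since after-challenge widths are
+/// measured in base field elements, a single extension-field cell contributes
+/// `CHALLENGE_EXT_DEGREE` base-field cells to `after_challenge`.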
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct TraceCells {
+    pub preprocessed: Option<usize>,
+    pub cached_mains: Vec<usize>,
+    pub common_main: usize,
+    pub after_challenge: Vec<usize>,
+}
+
+impl Display for TraceMetrics {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        for (i, (weighted_sum, threshold)) in self.trace_height_inequalities.iter().enumerate() {
+            writeln!(
+                f,
+                "trace_height_constraint_{i} | weighted_sum = {:<10} | threshold = {:<10}",
+                format_number_with_underscores(*weighted_sum),
+                format_number_with_underscores(*threshold)
+            )?;
+        }
+        for trace_metrics in &self.per_air {
+            writeln!(f, "{}", trace_metrics)?;
+        }
+        Ok(())
+    }
+}
+
+impl Display for SingleTraceMetrics {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "{:<20} | Rows = {:<10} | Cells = {:<11} | Prep Cols = {:<5} | Main Cols = {:<5} | Perm Cols = {:<5}",
+            self.air_name,
+            format_number_with_underscores(self.height),
+            format_number_with_underscores(self.total_cells),
+            self.width.preprocessed.unwrap_or(0),
+            format!("{:?}", self.width.main_widths()),
+            format!("{:?}", self.width.after_challenge),
+        )?;
+        Ok(())
+    }
+}
+
+/// Computes [TraceMetrics] from the proving key and the per-AIR trace verification data; the
+/// trace height for each AIR is derived from `trace_vdata`.
+pub fn trace_metrics<PB: ProverBackendV2>(
+    mpk: &DeviceMultiStarkProvingKeyV2<PB>,
+    trace_vdata: &[Option<TraceVData>],
+) -> TraceMetrics {
+    let heights = trace_vdata
+        .iter()
+        .map(|vdata| vdata.as_ref().map(|v| 1 << v.log_height).unwrap_or(0))
+        .collect::<Vec<_>>();
+    let trace_height_inequalities = mpk
+        .trace_height_constraints
+        .iter()
+        .map(|trace_height_constraint| {
+            let weighted_sum = heights
+                .iter()
+                .enumerate()
+                .map(|(air_idx, h)| (trace_height_constraint.coefficients[air_idx] as usize) * h)
+                .sum::<usize>();
+            (weighted_sum, trace_height_constraint.threshold as usize)
+        })
+        .collect::<Vec<_>>();
+    let per_air: Vec<_> = zip_eq(&mpk.per_air, heights)
+        .enumerate()
+        .filter(|(_, (_, height))| *height > 0)
+        .map(|(air_idx, (pk, height))| {
+            let air_name = &pk.air_name;
+            let width = pk.vk.params.width.clone();
+            let mut interaction_width = pk.vk.num_interactions();
+            let ext_degree = PB::CHALLENGE_EXT_DEGREE as usize;
+            interaction_width *= ext_degree;
+            let cells = TraceCells {
+                preprocessed: width.preprocessed.map(|w| w * height),
+                cached_mains: width.cached_mains.iter().map(|w| w * height).collect(),
+                common_main: width.common_main * height,
+                after_challenge: vec![interaction_width * height],
+            };
+            let total_cells = cells
+                .cached_mains
+                .iter()
+                .chain([&cells.common_main])
+                .chain(cells.after_challenge.iter())
+                .sum::<usize>();
+            SingleTraceMetrics {
+                air_name: air_name.to_string(),
+                air_id: air_idx,
+                height,
+                width,
+                cells,
+                total_cells,
+            }
+        })
+        .collect();
+    let total_cells = per_air.iter().map(|m| m.total_cells).sum();
+    let metrics = TraceMetrics {
+        per_air,
+        total_cells,
+        trace_height_inequalities,
+    };
+    info!(
+        "total_trace_cells = {} (excluding preprocessed)",
+        format_number_with_underscores(metrics.total_cells)
+    );
+    info!(
+        "preprocessed_trace_cells = {}",
+        format_number_with_underscores(
+            metrics
+                .per_air
+                .iter()
+                .map(|m| m.cells.preprocessed.unwrap_or(0))
+                .sum::<usize>()
+        )
+    );
+    info!(
+        "main_trace_cells = {}",
+        format_number_with_underscores(
+            metrics
+                .per_air
+                .iter()
+                .map(|m| m.cells.cached_mains.iter().sum::<usize>() + m.cells.common_main)
+                .sum::<usize>()
+        )
+    );
+    info!(
+        "perm_trace_cells = {}",
+        format_number_with_underscores(
+            metrics
+                .per_air
+                .iter()
+                .map(|m| m.cells.after_challenge.iter().sum::<usize>())
+                .sum::<usize>()
+        )
+    );
+    debug!("{}", metrics);
+    metrics
+}
+
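+// Illustrative sketch (not part of the original patch; module name is ours): pins down the
+// digit-grouping behavior of `format_number_with_underscores` defined below.
+#[cfg(test)]
+mod format_number_tests {
+    use super::format_number_with_underscores;
+
+    #[test]
+    fn groups_digits_in_threes() {
+        assert_eq!(format_number_with_underscores(0), "0");
+        assert_eq!(format_number_with_underscores(1234), "1_234");
+        assert_eq!(format_number_with_underscores(1234567), "1_234_567");
+    }
+}
+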
+pub fn format_number_with_underscores(n: usize) -> String { + let num_str = n.to_string(); + let mut result = String::new(); + + // Start adding characters from the end of num_str + for (i, c) in num_str.chars().rev().enumerate() { + if i > 0 && i % 3 == 0 { + result.push('_'); + } + result.push(c); + } + + // Reverse the result to get the correct order + result.chars().rev().collect() +} + +#[cfg(feature = "metrics")] +mod emit { + use metrics::counter; + + use super::{SingleTraceMetrics, TraceMetrics}; + + impl TraceMetrics { + pub fn emit(&self) { + for (i, (weighted_sum, threshold)) in self.trace_height_inequalities.iter().enumerate() + { + let labels = [("trace_height_constraint", i.to_string())]; + counter!("weighted_sum", &labels).absolute(*weighted_sum as u64); + counter!("threshold", &labels).absolute(*threshold as u64); + } + for trace_metrics in &self.per_air { + trace_metrics.emit(); + } + counter!("total_cells").absolute(self.total_cells as u64); + } + } + + impl SingleTraceMetrics { + pub fn emit(&self) { + let labels = [ + ("air_name", self.air_name.clone()), + ("air_id", self.air_id.to_string()), + ]; + counter!("rows", &labels).absolute(self.height as u64); + counter!("cells", &labels).absolute(self.total_cells as u64); + counter!("prep_cols", &labels).absolute(self.width.preprocessed.unwrap_or(0) as u64); + counter!("main_cols", &labels).absolute( + (self.width.cached_mains.iter().sum::() + self.width.common_main) as u64, + ); + counter!("perm_cols", &labels) + .absolute(self.width.after_challenge.iter().sum::() as u64); + } + } +} diff --git a/crates/stark-backend-v2/src/prover/mod.rs b/crates/stark-backend-v2/src/prover/mod.rs new file mode 100644 index 00000000..f675b287 --- /dev/null +++ b/crates/stark-backend-v2/src/prover/mod.rs @@ -0,0 +1,168 @@ +use itertools::{izip, Itertools}; +use openvm_stark_backend::prover::{MatrixDimensions, Prover}; +use p3_field::FieldAlgebra; +use p3_util::log2_strict_usize; +use tracing::{info, info_span, instrument}; + +#[cfg(feature = "metrics")] +use crate::prover::metrics::trace_metrics; +use crate::{ + poseidon2::sponge::FiatShamirTranscript, + proof::{BatchConstraintProof, GkrProof, Proof, StackingProof, TraceVData, WhirProof}, + Digest, EF, F, +}; + +mod cpu_backend; +mod hal; +mod logup_zerocheck; +mod matrix; +pub mod metrics; +pub mod poly; +pub mod stacked_pcs; +pub mod stacked_reduction; +pub mod sumcheck; +mod types; +pub mod whir; + +pub use cpu_backend::*; +pub use hal::*; +pub use logup_zerocheck::*; +pub use matrix::*; +pub use types::*; + +#[derive(derive_new::new)] +pub struct CoordinatorV2 { + pub backend: PB, + pub device: PD, + pub(crate) transcript: TS, +} + +impl Prover for CoordinatorV2 +where + // TODO[jpw]: make generic in F, EF, Commitment + PB: ProverBackendV2, + PD: ProverDeviceV2, + PD::Artifacts: Into, + PD::PartialProof: Into<(GkrProof, BatchConstraintProof)>, + PD::OpeningProof: Into<(StackingProof, WhirProof)>, + TS: FiatShamirTranscript, +{ + type Proof = Proof; + type ProvingKeyView<'a> + = &'a DeviceMultiStarkProvingKeyV2 + where + Self: 'a; + + type ProvingContext<'a> + = ProvingContextV2 + where + Self: 'a; + + /// Specialized prove for InteractiveAirs. + /// Handles trace generation of the permutation traces. + /// Assumes the main traces have been generated and committed already. + /// + /// The [DeviceMultiStarkProvingKey] should already be filtered to only include the relevant + /// AIR's proving keys. 
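+    ///
+    /// Hedged summary (not in the original patch): the returned [`Proof`] bundles the per-AIR
+    /// public values and trace metadata together with the GKR, batch-constraint, stacking, and
+    /// WHIR sub-proofs.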
+ #[instrument( + name = "stark_prove_excluding_trace", + level = "info", + skip_all, + fields(phase = "prover") + )] + fn prove<'a>( + &'a mut self, + mpk: &'a DeviceMultiStarkProvingKeyV2, + unsorted_ctx: ProvingContextV2, + ) -> Self::Proof { + assert_eq!(self.device.config(), &mpk.params); + let transcript = &mut self.transcript; + transcript.observe_commit(mpk.vk_pre_hash); + + let ctx = unsorted_ctx.into_sorted(); + // `ctx` should NOT be permuted anymore: the ordering by `trace_idx` is now fixed. + + let num_airs_present = ctx.per_trace.len(); + info!(num_airs_present); + + let _main_commit_span = info_span!("prover.main_trace_commit", phase = "prover").entered(); + let (common_main_commit, common_main_pcs_data) = { + let traces = ctx + .common_main_traces() + .map(|(_, trace)| trace) + .collect_vec(); + self.device.commit(&traces) + }; + + let mut trace_vdata: Vec> = vec![None; mpk.per_air.len()]; + let mut public_values: Vec> = vec![Vec::new(); mpk.per_air.len()]; + + // Hypercube dimension per trace (present AIR) + for (air_id, air_ctx) in &ctx.per_trace { + let trace_height = air_ctx.common_main.height(); + let log_height = log2_strict_usize(trace_height); + + trace_vdata[*air_id] = Some(TraceVData { + log_height, + cached_commitments: air_ctx + .cached_mains + .iter() + .map(|cd| cd.commitment) + .collect(), + }); + public_values[*air_id] = air_ctx.public_values.clone(); + } + #[cfg(feature = "metrics")] + trace_metrics(mpk, &trace_vdata).emit(); + + // Only observe commits for present AIRs. + // Commitments order: + // - 1 commitment of all common main traces + // - for each air: + // - preprocessed commit if present + // - for each cached main trace + // - 1 commitment + transcript.observe_commit(common_main_commit); + drop(_main_commit_span); + + for (trace_vdata, pvs, pk) in izip!(&trace_vdata, &public_values, &mpk.per_air) { + if !pk.vk.is_required { + transcript.observe(F::from_bool(trace_vdata.is_some())); + } + if let Some(trace_vdata) = trace_vdata { + if let Some(cd) = &pk.preprocessed_data { + transcript.observe_commit(cd.commitment); + } else { + transcript.observe(F::from_canonical_usize(trace_vdata.log_height)); + } + for commit in &trace_vdata.cached_commitments { + transcript.observe_commit(*commit); + } + } + for pv in pvs { + transcript.observe(*pv); + } + } + + let (constraints_proof, r) = + self.device + .prove_rap_constraints(transcript, mpk, &ctx, &common_main_pcs_data); + + let opening_proof = + self.device + .prove_openings(transcript, mpk, ctx, common_main_pcs_data, r.into()); + + let (gkr_proof, batch_constraint_proof) = constraints_proof.into(); + let (stacking_proof, whir_proof) = opening_proof.into(); + + Proof { + public_values, + trace_vdata, + common_main_commit, + gkr_proof, + batch_constraint_proof, + stacking_proof, + whir_proof, + } + } +} diff --git a/crates/stark-backend-v2/src/prover/poly.rs b/crates/stark-backend-v2/src/prover/poly.rs new file mode 100644 index 00000000..77443a3b --- /dev/null +++ b/crates/stark-backend-v2/src/prover/poly.rs @@ -0,0 +1,320 @@ +use core::ops::Mul; +use std::iter::zip; + +use getset::Getters; +use openvm_stark_backend::prover::MatrixDimensions; +use p3_dft::{Radix2Bowers, TwoAdicSubgroupDft}; +use p3_field::{ExtensionField, Field, TwoAdicField}; +use p3_maybe_rayon::prelude::*; +use p3_util::log2_strict_usize; + +use crate::{ + poly_common::eval_eq_uni, + prover::{ColMajorMatrix, ColMajorMatrixView}, +}; + +/// Multilinear extension polynomial, in coefficient form. 
+/// +/// Length of `coeffs` is `2^n` where `n` is hypercube dimension. +/// Indexing of `coeffs` is to use the little-endian encoding of integer index as the powers of +/// variables. +#[derive(Getters)] +pub struct Mle { + #[getset(get = "pub")] + coeffs: Vec, +} + +impl Mle { + pub fn from_coeffs(coeffs: Vec) -> Self { + Self { coeffs } + } + + /// Create MLE from evaluations on the hypercube. + /// + /// Takes evaluations of the polynomial at all points in {0,1}^n and converts + /// them to coefficient form. + /// + /// The input `evals` should have length 2^n where n is the number of variables. + /// The evaluation at index i corresponds to the point whose binary representation + /// is i (with bit 0 being the least significant). + pub fn from_evaluations(evals: &[F]) -> Self { + assert!(!evals.is_empty(), "Evaluations cannot be empty"); + let mut coeffs = evals.to_vec(); + Self::evals_to_coeffs_inplace(&mut coeffs); + Self { coeffs } + } + + pub fn into_coeffs(self) -> Vec { + self.coeffs + } + + /// Evaluate with `O(1)` extra memory via naive algorithm. + /// + /// Performs `N*log(N)/2` multiplications when `N = x.len()` is a power of two. + pub fn eval_at_point + Mul>( + &self, + x: &[F2], + ) -> EF { + debug_assert_eq!(log2_strict_usize(self.coeffs.len()), x.len()); + let mut res = EF::ZERO; + for (i, coeff) in self.coeffs.iter().enumerate() { + let mut term = EF::from(*coeff); + for (j, x_j) in x.iter().enumerate() { + if (i >> j) & 1 == 1 { + term = term * *x_j; + } + } + res += term; + } + res + } + + /// Evaluate with `O(1)` extra memory but consuming `self`. + /// + /// Performs `N - 1` multiplications for `N = x.len()`. + pub fn eval_at_point_inplace(self, x: &[F2]) -> F + where + F2: Field, + F: ExtensionField, + { + let mut buf = self.coeffs; + debug_assert_eq!(buf.len(), 1 << x.len()); + let mut len = 1usize << x.len(); + // Assumes caller ensured buf[..len] is initialized with the current coefficients. + for &xj in x.iter().rev() { + len >>= 1; + let (left, right) = buf.split_at_mut(len); + for (li, &ri) in zip(left.iter_mut(), right.iter()) { + *li += ri * xj; + } + } + buf[0] + } + + pub fn evals_to_coeffs_inplace(a: &mut [F]) { + let n = log2_strict_usize(a.len()); + // Go through coordinates X_1, ..., X_n and interpolate each one from s(0), s(1) -> s(0) + + // (s(1) - s(0)) X_i + for log_step in 0..n { + let step = 1usize << log_step; + let span = step << 1; + a.par_chunks_exact_mut(span).for_each(|chunk| { + let (first_half, second_half) = chunk.split_at_mut(step); + first_half + .par_iter() + .zip(second_half.par_iter_mut()) + .for_each(|(u, v)| { + *v -= *u; + }); + }); + } + } + + pub fn coeffs_to_evals_inplace(a: &mut [F]) { + let n = log2_strict_usize(a.len()); + + for b in 0..n { + let step = 1usize << b; + let span = step << 1; + for i in (0..a.len()).step_by(span) { + for j in 0..step { + let u = i + j; + let v = u + step; + a[v] += a[u]; + } + } + } + } +} + +/// Given vector `x` in `F^n`, populates `out` with `eq_n(x, y)` for `y` on hypercube `H_n`. +/// +/// The multilinear equality polynomial is defined as: +/// ```text +/// eq(x, z) = \prod_{i=0}^{n-1} (x_i z_i + (1 - x_i)(1 - z_i)). 
+/// ``` +// Reference: +pub fn evals_eq_hypercube(x: &[F]) -> Vec { + let n = x.len(); + let mut out = F::zero_vec(1 << n); + out[0] = F::ONE; + for (i, &x_i) in x.iter().enumerate() { + let (los, his) = out[..2 << i].split_at_mut(1 << i); + los.par_iter_mut() + .zip(his.par_iter_mut()) + .for_each(|(lo, hi)| { + *hi = *lo * x_i; + *lo *= F::ONE - x_i; + }) + } + out +} + +pub fn evals_eq_hypercube_serial(x: &[F]) -> Vec { + let n = x.len(); + let mut out = F::zero_vec(1 << n); + out[0] = F::ONE; + for (i, &x_i) in x.iter().enumerate() { + let (los, his) = out[..2 << i].split_at_mut(1 << i); + for (lo, hi) in los.iter_mut().zip(his.iter_mut()) { + *hi = *lo * x_i; + *lo *= F::ONE - x_i; + } + } + out +} + +/// Given vector `x` in `F^n`, returns a concatenation of `evals_eq_hypercube(x[..n])` for all valid +/// `n` in order. Also, the order of masks is of different endianness. +pub fn evals_eq_hypercubes<'a, F: Field>(n: usize, x: impl IntoIterator) -> Vec { + let mut out = F::zero_vec((2 << n) - 1); + out[0] = F::ONE; + for (i, &x_i) in x.into_iter().enumerate() { + for y in 0..(1 << i) { + out[(1 << (i + 1)) - 1 + (2 * y + 1)] = out[(1 << i) - 1 + y] * x_i; + out[(1 << (i + 1)) - 1 + (2 * y)] = out[(1 << i) - 1 + y] * (F::ONE - x_i); + } + } + out +} + +/// Given vector `(z,x)` in `F^{n+1}`, populates `out` with `eq_{l_skip,n}(x, y)` for `y` on +/// hyperprism `D_n`. +pub fn evals_eq_hyperprism>( + omega_pows: &[F], + z: EF, + x: &[EF], +) -> Vec { + // Size of D + let d_size = omega_pows.len(); + let l_skip = log2_strict_usize(d_size); + let n = x.len(); + let mut out = EF::zero_vec(d_size << n); + for (omega_pow, eq_uni) in zip(omega_pows, out.iter_mut()) { + *eq_uni = eval_eq_uni(l_skip, z, EF::from(*omega_pow)); + } + for (i, &x_i) in x.iter().enumerate() { + for y in (0..d_size << i).rev() { + let eq_prev = out[y]; + // Don't overwrite in y = 0 case + out[y | (d_size << i)] = eq_prev * x_i; + out[y] = eq_prev * (EF::ONE - x_i); + } + } + out +} + +/// Prismalinear extension polynomial, in coefficient form. +/// +/// Depends on implicit univariate skip parameter `l_skip`. +/// Length of `coeffs` is `2^{l_skip + n}` where `n` is hypercube dimension. +/// Indexing is to decompose `i = i_0 + 2^{l_skip} * (i_1 + 2 * i_2 .. + 2^{n-1} * i_n)` and let +/// `coeffs[i]` be the coefficient of `Z^{i_0} X_1^{i_1} .. X_n^{i_n}`. +#[derive(Getters)] +pub struct Ple { + #[getset(get = "pub")] + pub(crate) coeffs: Vec, +} + +impl Ple { + /// Create PLE from evaluations on the hypercube with univariate skip. + /// + /// Takes evaluations at 2^{l_skip + n} points and converts them to coefficient form + /// for a polynomial in n+1 variables: degree < 2^l_skip in the first variable, + /// degree < 2 (linear) in the other n variables. + /// + /// The input `evals` should have length 2^{l_skip + n}. 
+ /// The evaluation at index i corresponds to: + /// - bits 0 to l_skip-1: univariate point index + /// - bits l_skip to l_skip+n-1: multilinear variable assignments + pub fn from_evaluations(l_skip: usize, evals: &[F]) -> Self { + let prism_dim = log2_strict_usize(evals.len()); + assert!( + prism_dim >= l_skip, + "Total variables must be at least l_skip" + ); + // Go through coordinates Z, X_1, ..., X_n and interpolate each one + // For first Z coordinate, we do parallel iDFT on each 2^l_skip sized chunk + let mut buf: Vec<_> = evals + .par_chunks_exact(1 << l_skip) + .flat_map(|chunk| { + let dft = Radix2Bowers; + dft.idft(chunk.to_vec()) + }) + .collect(); + + let n = prism_dim - l_skip; + // Go through coordinates X_1, ..., X_n and interpolate each one from s(0), s(1) -> s(0) + + // (s(1) - s(0)) X_i + for i in 0..n { + let step = 1usize << (l_skip + i); + let span = step << 1; + buf.par_chunks_exact_mut(span).for_each(|chunk| { + let (first_half, second_half) = chunk.split_at_mut(step); + first_half + .par_iter() + .zip(second_half.par_iter_mut()) + .for_each(|(u, v)| { + *v -= *u; + }); + }); + } + Self { coeffs: buf } + } + + pub fn eval_at_point>(&self, l_skip: usize, z: EF, x: &[EF]) -> EF { + let n = x.len(); + debug_assert_eq!(l_skip + n, log2_strict_usize(self.coeffs.len())); + let mut res = EF::ZERO; + let mut z_pow = EF::ONE; + for (i, coeff) in self.coeffs.iter().enumerate() { + if i.trailing_zeros() >= l_skip as u32 { + z_pow = EF::ONE; + } + let i_x = i >> l_skip; + let mut term = z_pow * *coeff; + for (j, x_j) in x.iter().enumerate() { + if (i_x >> j) & 1 == 1 { + term *= *x_j; + } + } + z_pow *= z; + res += term; + } + res + } + + pub fn into_coeffs(self) -> Vec { + self.coeffs + } +} + +pub struct MleMatrix { + pub columns: Vec>, +} + +impl MleMatrix { + pub fn from_evaluations(evals: &ColMajorMatrix) -> Self { + let width = evals.width(); + let columns = (0..width) + .into_par_iter() + .map(|j| Mle::from_evaluations(evals.column(j))) + .collect(); + Self { columns } + } +} + +pub struct PleMatrix { + pub columns: Vec>, +} + +impl PleMatrix { + pub fn from_evaluations(l_skip: usize, evals: &ColMajorMatrixView) -> Self { + let width = evals.width(); + let columns = (0..width) + .into_par_iter() + .map(|j| Ple::from_evaluations(l_skip, evals.column(j))) + .collect(); + Self { columns } + } +} diff --git a/crates/stark-backend-v2/src/prover/stacked_pcs.rs b/crates/stark-backend-v2/src/prover/stacked_pcs.rs new file mode 100644 index 00000000..1391ada6 --- /dev/null +++ b/crates/stark-backend-v2/src/prover/stacked_pcs.rs @@ -0,0 +1,563 @@ +use getset::{CopyGetters, Getters}; +use itertools::Itertools; +use openvm_stark_backend::prover::MatrixDimensions; +use p3_dft::{Radix2DitParallel, TwoAdicSubgroupDft}; +use p3_field::{Field, TwoAdicField}; +use p3_maybe_rayon::prelude::*; +use p3_util::log2_strict_usize; +use serde::{Deserialize, Serialize}; +use tracing::instrument; + +use crate::{ + prover::{col_maj_idx, poly::Ple, ColMajorMatrix, MatrixView, StridedColMajorMatrixView}, + Digest, F, +}; + +#[derive(Clone, Serialize, Deserialize, Debug, CopyGetters)] +pub struct StackedLayout { + /// The minimum log2 height of a stacked slice. When stacking columns with smaller height, the + /// column is expanded to `2^l_skip` by striding. + #[getset(get_copy = "pub")] + l_skip: usize, + /// Stacked height + #[getset(get_copy = "pub")] + height: usize, + /// Stacked width + #[getset(get_copy = "pub")] + width: usize, + /// The columns of the unstacked matrices in sorted order. 
Each entry `(matrix index, column + /// index, coordinate)` contains the pointer `(matrix index, column index)` to a column of the + /// unstacked collection of matrices as well as `coordinate` which is a pointer to where the + /// column starts in the stacked matrix. + pub sorted_cols: Vec<( + usize, /* unstacked matrix index */ + usize, /* unstacked column index */ + StackedSlice, + )>, + /// `mat_starts[mat_idx]` is the index in `sorted_cols` where the matrix with index `mat_idx` + /// starts. + pub mat_starts: Vec, +} + +/// Pointer to the location of a sub-column within the stacked matrix. +/// This struct contains length information, but information from [StackedLayout] (namely `l_skip`) +/// is needed to determine if this is a strided slice or not. +#[derive(Copy, Clone, Debug, Serialize, Deserialize, CopyGetters, derive_new::new)] +pub struct StackedSlice { + pub col_idx: usize, + pub row_idx: usize, + /// The true log height. If `>= l_skip`, no striding. Otherwise striding by `2^{l_skip - + /// log_height}`. + #[getset(get_copy = "pub")] + log_height: usize, +} + +impl StackedSlice { + #[inline(always)] + pub fn len(&self, l_skip: usize) -> usize { + Self::_len(self.log_height, l_skip) + } + + #[inline(always)] + pub fn stride(&self, l_skip: usize) -> usize { + 1 << l_skip.saturating_sub(self.log_height) + } + + #[inline(always)] + fn _len(log_height: usize, l_skip: usize) -> usize { + if l_skip <= log_height { + 1 << log_height + } else { + 1 << l_skip + } + } +} + +#[derive(Clone, Debug, Getters, CopyGetters, Serialize, Deserialize)] +pub struct MerkleTree { + /// The matrix that is used to form the leaves of the Merkle tree, which are + /// in turn hashed into the bottom digest layer. + /// + /// This is typically the codeword matrix in hash-based PCS. + #[getset(get = "pub")] + pub(crate) backing_matrix: ColMajorMatrix, + #[getset(get = "pub")] + pub(crate) digest_layers: Vec>, + #[getset(get_copy = "pub")] + pub(crate) rows_per_query: usize, +} + +#[derive(Clone, Serialize, Deserialize, derive_new::new)] +pub struct StackedPcsData { + /// Layout of the unstacked collection of matrices within the stacked matrix. + pub layout: StackedLayout, + /// The stacked matrix of evaluations with height `2^{l_skip + n_stack}`. + pub matrix: ColMajorMatrix, + /// Merkle tree of the Reed-Solomon codewords of the stacked matrix. + /// Depends on `k_whir` parameter. + pub tree: MerkleTree, +} + +impl StackedPcsData { + /// Returns the root of the Merkle tree. + pub fn commit(&self) -> Digest { + self.tree.root() + } + + pub fn mat_view(&self, unstacked_mat_idx: usize) -> StridedColMajorMatrixView<'_, F> { + self.layout.mat_view(unstacked_mat_idx, &self.matrix) + } +} + +#[instrument(level = "info", skip_all)] +pub fn stacked_commit( + l_skip: usize, + n_stack: usize, + log_blowup: usize, + k_whir: usize, + traces: &[&ColMajorMatrix], +) -> (Digest, StackedPcsData) { + let (q_trace, layout) = stacked_matrix(l_skip, n_stack, traces); + let rs_matrix = rs_code_matrix(l_skip, log_blowup, &q_trace); + let tree = MerkleTree::new(rs_matrix, 1 << k_whir); + let root = tree.root(); + let data = StackedPcsData::new(layout, q_trace, tree); + (root, data) +} + +impl StackedLayout { + /// Computes the layout of greedily stacking columns with dimension metadata given by `sorted` + /// into a stacked matrix. + /// - `l_skip` is a threshold log2 height: if a column has height less than `2^l_skip`, it is + /// stacked as a column of height `2^l_skip` with stride `2^{l_skip - log_height}`. 
+ /// - `log_stacked_height` is the log2 height of the stacked matrix. + /// - `sorted` is Vec of `(width, log_height)` that must already be **sorted** in descending + /// order of `log_height`. + pub fn new( + l_skip: usize, + log_stacked_height: usize, + sorted: Vec<(usize /* width */, usize /* log_height */)>, + ) -> Self { + debug_assert!(l_skip <= log_stacked_height); + debug_assert!(sorted.is_sorted_by(|a, b| a.1 >= b.1)); + let mut sorted_cols = Vec::with_capacity(sorted.len()); + let mut mat_starts = Vec::new(); + let mut col_idx = 0; + let mut row_idx = 0; + for (mat_idx, (width, log_ht)) in sorted.into_iter().enumerate() { + mat_starts.push(sorted_cols.len()); + if width == 0 { + continue; + } + assert!( + log_ht <= log_stacked_height, + "log_height={log_ht} > log_stacked_height={log_stacked_height}" + ); + for j in 0..width { + let slice_len = StackedSlice::_len(log_ht, l_skip); + if row_idx + slice_len > (1 << log_stacked_height) { + assert_eq!(row_idx, 1 << log_stacked_height); + col_idx += 1; + row_idx = 0; + } + let slice = StackedSlice { + col_idx, + row_idx, + log_height: log_ht, + }; + sorted_cols.push((mat_idx, j, slice)); + row_idx += slice_len; + } + } + let stacked_width = col_idx + usize::from(row_idx != 0); + debug_assert_eq!( + stacked_width, + sorted_cols + .iter() + .map(|(_, _, slice)| slice.col_idx + 1) + .max() + .unwrap_or(0) + ); + Self { + l_skip, + height: 1 << log_stacked_height, + width: stacked_width, + sorted_cols, + mat_starts, + } + } + + /// Raw unsafe constructor + pub fn from_raw_parts( + l_skip: usize, + log_stacked_height: usize, + sorted_cols: Vec<(usize, usize, StackedSlice)>, + ) -> Self { + let height = 1 << log_stacked_height; + let width = sorted_cols + .iter() + .map(|(_, _, slice)| slice.col_idx + 1) + .max() + .unwrap_or(0); + let mut mat_starts = Vec::new(); + for (idx, (mat_idx, _, _)) in sorted_cols.iter().enumerate() { + if idx == 0 || *mat_idx + 1 != mat_starts.len() { + assert_eq!(*mat_idx, mat_starts.len()); + mat_starts.push(idx); + } + } + Self { + l_skip, + height, + width, + sorted_cols, + mat_starts, + } + } + + pub fn unstacked_slices_iter(&self) -> impl Iterator { + self.sorted_cols.iter().map(|(_, _, s)| s) + } + + /// `(mat_idx, col_idx)` should be indexing into the unstacked collection of matrices. + pub fn get(&self, mat_idx: usize, col_idx: usize) -> Option<&StackedSlice> { + let idx = self.mat_starts[mat_idx]; + if idx + col_idx >= self.sorted_cols.len() { + return None; + } + let (mat_idx1, col_idx1, s) = &self.sorted_cols[idx + col_idx]; + debug_assert_eq!(*mat_idx1, mat_idx); + debug_assert_eq!(*col_idx1, col_idx); + Some(s) + } + + pub fn width_of(&self, mat_idx: usize) -> usize { + let start_idx = self.mat_starts[mat_idx]; + debug_assert_eq!(self.sorted_cols[start_idx].0, mat_idx); + debug_assert_eq!(self.sorted_cols[start_idx].1, 0); + let next_idx = *self + .mat_starts + .get(mat_idx + 1) + .unwrap_or(&self.sorted_cols.len()); + debug_assert_ne!(next_idx, usize::MAX); + next_idx - start_idx + } + + /// Due to the definition of stacking, in a column major matrix the lifted columns of the + /// unstacked matrix will always be contiguous in memory within the stacked matrix, so we + /// can return the sub-view. 
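+    ///
+    /// Hedged illustration (not in the original patch): a width-2 matrix with
+    /// `log_height = l_skip - 1` is lifted to two slices of height `2^l_skip` with stride 2;
+    /// they occupy adjacent stacked rows, so the returned view has `width = 2, stride = 2`.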
+ pub fn mat_view<'a, F>( + &self, + unstacked_mat_idx: usize, + stacked_matrix: &'a ColMajorMatrix, + ) -> StridedColMajorMatrixView<'a, F> { + let col_slices = self + .sorted_cols + .iter() + .filter(|(m, _, _)| *m == unstacked_mat_idx) + .collect_vec(); + let width = col_slices.len(); + let s = &col_slices[0].2; + let lifted_height = s.len(self.l_skip); + let stride = s.stride(self.l_skip); + let start = col_maj_idx(s.row_idx, s.col_idx, stacked_matrix.height()); + StridedColMajorMatrixView::new( + &stacked_matrix.values[start..start + lifted_height * width], + width, + stride, + ) + } +} + +/// The `traces` **must** already be in height-sorted order. +#[instrument(skip_all)] +pub fn stacked_matrix( + l_skip: usize, + n_stack: usize, + traces: &[&ColMajorMatrix], +) -> (ColMajorMatrix, StackedLayout) { + let sorted_meta = traces + .iter() + .map(|trace| { + // height cannot be zero: + let log_height = log2_strict_usize(trace.height()); + (trace.width(), log_height) + }) + .collect_vec(); + let mut layout = StackedLayout::new(l_skip, l_skip + n_stack, sorted_meta); + let total_cells: usize = traces + .iter() + .map(|t| t.height().max(1 << l_skip) * t.width()) + .sum(); + let height = 1usize << (l_skip + n_stack); + let width = total_cells.div_ceil(height); + + let mut q_mat = F::zero_vec(width.checked_mul(height).unwrap()); + for (mat_idx, j, s) in &mut layout.sorted_cols { + let start = s.col_idx * height + s.row_idx; + let t_col = traces[*mat_idx].column(*j); + debug_assert_eq!(t_col.len(), 1 << s.log_height); + if s.log_height >= l_skip { + q_mat[start..start + t_col.len()].copy_from_slice(t_col); + } else { + // t_col height is smaller than 2^l_skip, so we stride + let stride = s.stride(l_skip); + for (i, val) in t_col.iter().enumerate() { + q_mat[start + i * stride] = *val; + } + } + } + (ColMajorMatrix::new(q_mat, width), layout) +} + +/// Computes the Reed-Solomon codeword of each column vector of `eval_matrix` where the rate is +/// `2^{-log_blowup}`. The column vectors are treated as evaluations of a prismalinear extension on +/// a hyperprism. +#[instrument(skip_all)] +pub fn rs_code_matrix( + l_skip: usize, + log_blowup: usize, + eval_matrix: &ColMajorMatrix, +) -> ColMajorMatrix { + let height = eval_matrix.height(); + let codewords: Vec<_> = eval_matrix + .values + .par_chunks_exact(height) + .map(|column_evals| { + let ple = Ple::from_evaluations(l_skip, column_evals); + let mut coeffs = ple.coeffs; + // Compute RS codeword on a prismalinear polynomial in coefficient form: + // We use that the coefficients are in a basis that exactly corresponds to the standard + // monomial univariate basis. 
Hence RS codeword is just cosetDFT on the + // relevant smooth domain + let dft = Radix2DitParallel::default(); + coeffs.resize(height.checked_shl(log_blowup as u32).unwrap(), F::ZERO); + dft.dft(coeffs) + }) + .collect::>() + .concat(); + + ColMajorMatrix::new(codewords, eval_matrix.width()) +} + +impl MerkleTree { + pub fn query_stride(&self) -> usize { + self.digest_layers[0].len() + } + + pub fn proof_depth(&self) -> usize { + self.digest_layers.len() - 1 + } +} + +impl MerkleTree { + pub fn root(&self) -> Digest { + self.digest_layers.last().unwrap()[0].clone() + } + + pub fn query_merkle_proof(&self, query_idx: usize) -> Vec { + let stride = self.query_stride(); + assert!( + query_idx < stride, + "query_idx {query_idx} out of bounds for query_stride {stride}" + ); + + let mut idx = query_idx; + let mut proof = Vec::with_capacity(self.proof_depth()); + for layer in self.digest_layers.iter().take(self.proof_depth()) { + let sibling = layer[idx ^ 1].clone(); + proof.push(sibling); + idx >>= 1; + } + proof + } +} + +mod poseidon2_merkle_tree { + use p3_field::ExtensionField; + + use super::*; + use crate::{ + poseidon2::sponge::{poseidon2_compress, poseidon2_hash_slice}, + Digest, F, + }; + + impl MerkleTree + where + EF: ExtensionField, + { + #[instrument(name = "merkle_tree", skip_all)] + pub fn new(matrix: ColMajorMatrix, rows_per_query: usize) -> Self { + let height = matrix.height(); + assert!(height > 0); + assert!(rows_per_query.is_power_of_two()); + let num_leaves = height.next_power_of_two(); + assert!( + rows_per_query <= num_leaves, + "rows_per_query ({rows_per_query}) must not exceed the number of Merkle leaves ({num_leaves})" + ); + let row_hashes: Vec<_> = (0..num_leaves) + .into_par_iter() + .map(|r| { + let hash_input: Vec = Self::row_iter(&matrix, r) + .flat_map(|ef| ef.as_base_slice().to_vec()) + .collect(); + poseidon2_hash_slice(&hash_input) + }) + .collect(); + + let query_stride = num_leaves / rows_per_query; + let mut query_digest_layer = row_hashes; + // For the first log2(rows_per_query) layers, we hash in `query_stride` pairs and don't + // need to store the digest layers + for _ in 0..log2_strict_usize(rows_per_query) { + let prev_layer = query_digest_layer; + query_digest_layer = (0..prev_layer.len() / 2) + .into_par_iter() + .map(|i| { + let x = i / query_stride; + let y = i % query_stride; + let left = prev_layer[2 * x * query_stride + y]; + let right = prev_layer[(2 * x + 1) * query_stride + y]; + poseidon2_compress(left, right) + }) + .collect(); + } + + let mut digest_layers = vec![query_digest_layer]; + while digest_layers.last().unwrap().len() > 1 { + let prev_layer = digest_layers.last().unwrap(); + let layer: Vec<_> = prev_layer + .par_chunks_exact(2) + .map(|pair| poseidon2_compress(pair[0], pair[1])) + .collect(); + digest_layers.push(layer); + } + + Self { + backing_matrix: matrix, + digest_layers, + rows_per_query, + } + } + + /// # Safety + /// - Caller must ensure that `digest_layers` are correctly constructed Merkle hashes for + /// the Merkle tree. + pub unsafe fn from_raw_parts( + backing_matrix: ColMajorMatrix, + digest_layers: Vec>, + rows_per_query: usize, + ) -> Self { + Self { + backing_matrix, + digest_layers, + rows_per_query, + } + } + + /// Returns the ordered set of opened rows for the given query index. + /// The rows are { query_idx + t * query_stride() } for t in 0..rows_per_query. 
+ pub fn get_opened_rows(&self, index: usize) -> Vec> { + let query_stride = self.query_stride(); + assert!( + index < query_stride, + "index {index} out of bounds for query_stride {query_stride}" + ); + + let rows_per_query = self.rows_per_query; + let width = self.backing_matrix.width(); + let mut preimage = Vec::with_capacity(rows_per_query); + for row_offset in 0..rows_per_query { + let row_idx = row_offset * query_stride + index; + let row = Self::row_iter(&self.backing_matrix, row_idx).collect_vec(); + debug_assert_eq!( + row.len(), + width, + "row width mismatch: expected {width}, got {}", + row.len() + ); + preimage.push(row); + } + preimage + } + + fn row_iter(matrix: &ColMajorMatrix, index: usize) -> impl Iterator + '_ { + (0..matrix.width()).map(move |c| matrix.get(index, c).copied().unwrap_or(EF::ZERO)) + } + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use p3_field::FieldAlgebra; + + use super::*; + use crate::{prover::ColMajorMatrix, F}; + + #[test] + fn test_stacked_matrix_manual_0() { + let columns = [vec![1, 2, 3, 4], vec![5, 6], vec![7]] + .map(|v| v.into_iter().map(F::from_canonical_u32).collect_vec()); + let mats = columns + .into_iter() + .map(|c| ColMajorMatrix::new(c, 1)) + .collect_vec(); + let mat_refs = mats.iter().collect_vec(); + let (stacked_mat, layout) = stacked_matrix(0, 2, &mat_refs); + assert_eq!(stacked_mat.height(), 4); + assert_eq!(stacked_mat.width(), 2); + assert_eq!( + stacked_mat.values, + [1, 2, 3, 4, 5, 6, 7, 0].map(F::from_canonical_u32).to_vec() + ); + assert_eq!(layout.mat_starts, vec![0, 1, 2]); + } + + #[test] + fn test_stacked_matrix_manual_strided_0() { + let columns = [vec![1, 2, 3, 4], vec![5, 6], vec![7]] + .map(|v| v.into_iter().map(F::from_canonical_u32).collect_vec()); + let mats = columns + .into_iter() + .map(|c| ColMajorMatrix::new(c, 1)) + .collect_vec(); + let mat_refs = mats.iter().collect_vec(); + let (stacked_mat, _layout) = stacked_matrix(2, 0, &mat_refs); + assert_eq!(stacked_mat.height(), 4); + assert_eq!(stacked_mat.width(), 3); + assert_eq!( + stacked_mat.values, + [1, 2, 3, 4, 5, 0, 6, 0, 7, 0, 0, 0] + .map(F::from_canonical_u32) + .to_vec() + ); + } + + #[test] + fn test_stacked_matrix_manual_strided_1() { + let columns = [vec![1, 2, 3, 4], vec![5, 6], vec![7]] + .map(|v| v.into_iter().map(F::from_canonical_u32).collect_vec()); + let mats = columns + .into_iter() + .map(|c| ColMajorMatrix::new(c, 1)) + .collect_vec(); + let mat_refs = mats.iter().collect_vec(); + let (stacked_mat, _layout) = stacked_matrix(3, 0, &mat_refs); + assert_eq!(stacked_mat.height(), 8); + assert_eq!(stacked_mat.width(), 3); + assert_eq!( + stacked_mat.values, + [ + [1, 0, 2, 0, 3, 0, 4, 0], + [5, 0, 0, 0, 6, 0, 0, 0], + [7, 0, 0, 0, 0, 0, 0, 0] + ] + .into_iter() + .flatten() + .map(F::from_canonical_u32) + .collect_vec() + ); + } +} diff --git a/crates/stark-backend-v2/src/prover/stacked_reduction.rs b/crates/stark-backend-v2/src/prover/stacked_reduction.rs new file mode 100644 index 00000000..c78fda5b --- /dev/null +++ b/crates/stark-backend-v2/src/prover/stacked_reduction.rs @@ -0,0 +1,488 @@ +//! 
Stacked opening reduction + +use std::{array::from_fn, collections::HashMap, iter::zip, mem::take}; + +use itertools::Itertools; +use openvm_stark_backend::prover::MatrixDimensions; +use p3_field::{FieldAlgebra, TwoAdicField}; +use p3_maybe_rayon::prelude::*; +use tracing::{debug, instrument}; + +use crate::{ + poly_common::{eval_eq_mle, eval_eq_uni, eval_eq_uni_at_one, eval_in_uni, UnivariatePoly}, + poseidon2::sponge::FiatShamirTranscript, + proof::StackingProof, + prover::{ + poly::evals_eq_hypercube, + stacked_pcs::{StackedPcsData, StackedSlice}, + sumcheck::{ + batch_fold_mle_evals, fold_mle_evals, fold_ple_evals, sumcheck_round0_deg, + sumcheck_round_poly_evals, sumcheck_uni_round0_poly, + }, + ColMajorMatrix, ColMajorMatrixView, CpuBackendV2, CpuDeviceV2, MatrixView, ProverBackendV2, + }, + Digest, EF, F, +}; + +/// Helper trait for proving the reduction of column opening claims and column rotation opening +/// claims to opening claims of column polynomials of the stacked matrix. +/// +/// Returns the reduction proof and the random vector `u` of length `1 + n_stack`. +pub trait StackedReductionProver<'a, PB: ProverBackendV2, PD> { + /// We only provide a view to the stacked `PcsData` per commitment because the WHIR prover will + /// still use the PLE evaluations of the stacked matrices later. The order of + /// `stacked_per_commit` is `common_main, preprocessed for trace_idx=0 (if any), cached_0 for + /// trace_idx=0, ..., preprocessed for trace_idx=1 (if any), ...`. + /// + /// The `lambda` is the batching randomness for the batch sumcheck. + fn new( + device: &'a PD, + stacked_per_commit: Vec<&'a PB::PcsData>, + need_rot_per_commit: Vec>, + r: &[PB::Challenge], + lambda: PB::Challenge, + ) -> Self; + + /// Return the `s_0` batched polynomial from univariate round 0 of sumcheck. + fn batch_sumcheck_uni_round0_poly(&mut self) -> UnivariatePoly; + + fn fold_ple_evals(&mut self, u_0: PB::Challenge); + + fn batch_sumcheck_poly_eval( + &mut self, + round: usize, + u_prev: PB::Challenge, + ) -> [PB::Challenge; 2]; + + fn fold_mle_evals(&mut self, round: usize, u_round: PB::Challenge); + + fn into_stacked_openings(self) -> Vec>; +} + +/// Batch sumcheck to reduce trace openings, including rotations, to stacked matrix opening. +/// +/// The `stacked_matrix, stacked_layout` should be the result of stacking the `traces` with +/// parameters `l_skip` and `n_stack`. 
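+///
+/// Hedged outline (not in the original patch): the prover runs one univariate-skip round
+/// (sampling `u_0`) followed by `n_stack` multilinear sumcheck rounds, so the returned
+/// challenge vector `u` has length `1 + n_stack`.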
+#[instrument(level = "info", skip_all)] +pub fn prove_stacked_opening_reduction<'a, PB, PD, TS, SRP>( + device: &'a PD, + transcript: &mut TS, + n_stack: usize, + stacked_per_commit: Vec<&'a PB::PcsData>, + need_rot_per_commit: Vec>, + r: &[PB::Challenge], +) -> (StackingProof, Vec) +where + PB: ProverBackendV2, + TS: FiatShamirTranscript, + SRP: StackedReductionProver<'a, PB, PD>, +{ + // Batching randomness + let lambda = transcript.sample_ext(); + + let mut prover = SRP::new(device, stacked_per_commit, need_rot_per_commit, r, lambda); + let s_0 = prover.batch_sumcheck_uni_round0_poly(); + for &coeff in s_0.coeffs() { + transcript.observe_ext(coeff); + } + + let mut u_vec = Vec::with_capacity(n_stack + 1); + let u_0 = transcript.sample_ext(); + u_vec.push(u_0); + debug!(round = 0, u_round = %u_0); + + prover.fold_ple_evals(u_0); + // end round 0 + + let mut sumcheck_round_polys = Vec::with_capacity(n_stack); + + #[allow(clippy::needless_range_loop)] + for round in 1..=n_stack { + let batch_s_evals = prover.batch_sumcheck_poly_eval(round, u_vec[round - 1]); + + for &eval in &batch_s_evals { + transcript.observe_ext(eval); + } + sumcheck_round_polys.push(batch_s_evals); + + let u_round = transcript.sample_ext(); + u_vec.push(u_round); + debug!(%round, %u_round); + + prover.fold_mle_evals(round, u_round); + } + let stacking_openings = prover.into_stacked_openings(); + for claims_for_com in &stacking_openings { + for &claim in claims_for_com { + transcript.observe_ext(claim); + } + } + let proof = StackingProof { + univariate_round_coeffs: s_0.0, + sumcheck_round_polys, + stacking_openings, + }; + (proof, u_vec) +} + +pub struct StackedReductionCpu<'a> { + l_skip: usize, + omega_skip: F, + + r_0: EF, + lambda_pows: Vec, + eq_const: EF, + + stacked_per_commit: Vec<&'a StackedPcsData>, + trace_views: Vec, + ht_diff_idxs: Vec, + + eq_r_per_lht: HashMap>, + + // After round 0: + k_rot_r_per_lht: HashMap>, + q_evals: Vec>, + /// Stores eq(u[1+n_T..round-1], b_{T,j}[..round-n_T-1]) + eq_ub_per_trace: Vec, +} + +struct TraceViewMeta { + com_idx: usize, + slice: StackedSlice, + lambda_eq_idx: usize, + lambda_rot_idx: Option, +} + +impl<'a> StackedReductionProver<'a, CpuBackendV2, CpuDeviceV2> for StackedReductionCpu<'a> { + fn new( + device: &CpuDeviceV2, + stacked_per_commit: Vec<&'a StackedPcsData>, + need_rot_per_commit: Vec>, + r: &[EF], + lambda: EF, + ) -> Self { + let l_skip = device.config().l_skip; + let omega_skip = F::two_adic_generator(l_skip); + + let mut trace_views = Vec::new(); + let mut lambda_idx = 0usize; + for (com_idx, d) in stacked_per_commit.iter().enumerate() { + let need_rot_for_commit = &need_rot_per_commit[com_idx]; + debug_assert_eq!(need_rot_for_commit.len(), d.layout.mat_starts.len()); + for &(mat_idx, _col_idx, slice) in &d.layout.sorted_cols { + let lambda_eq_idx = lambda_idx; + lambda_idx += 1; + let lambda_rot_idx = if need_rot_for_commit[mat_idx] { + Some(lambda_idx) + } else { + None + }; + lambda_idx += 1; + trace_views.push(TraceViewMeta { + com_idx, + slice, + lambda_eq_idx, + lambda_rot_idx, + }); + } + } + let lambda_pows = lambda.powers().take(lambda_idx).collect_vec(); + + let mut ht_diff_idxs = Vec::new(); + let mut eq_r_per_lht: HashMap> = HashMap::new(); + let mut last_height = 0; + for (i, tv) in trace_views.iter().enumerate() { + let n_lift = tv.slice.log_height().saturating_sub(l_skip); + if i == 0 || tv.slice.log_height() != last_height { + ht_diff_idxs.push(i); + last_height = tv.slice.log_height(); + } + eq_r_per_lht + 
.entry(tv.slice.log_height()) + .or_insert_with(|| ColMajorMatrix::new(evals_eq_hypercube(&r[1..1 + n_lift]), 1)); + } + ht_diff_idxs.push(trace_views.len()); + + let eq_const = eval_eq_uni_at_one(l_skip, r[0] * omega_skip); + let eq_ub_per_trace = vec![EF::ONE; trace_views.len()]; + + Self { + l_skip, + omega_skip, + r_0: r[0], + lambda_pows, + eq_const, + stacked_per_commit, + trace_views, + ht_diff_idxs, + eq_r_per_lht, + q_evals: vec![], + k_rot_r_per_lht: HashMap::new(), + eq_ub_per_trace, + } + } + + fn batch_sumcheck_uni_round0_poly(&mut self) -> UnivariatePoly { + let l_skip = self.l_skip; + let omega_skip = self.omega_skip; + // +1 from eq term + let s_0_deg = sumcheck_round0_deg(l_skip, 2); + // We want to compute algebraic batching, via \lambda, + // for each (T, j) pair of (trace, column) of the univariate polynomials + // ```text + // Z -> sum_{x in H_{n_stack}} q(Z,x) in_{D,n_T}(Z) eq_{D_{n_T}}((Z,x[..\tilde n_T]), r[..1+\tilde n_T]) eq(x[\tilde n_T..], b_{T,j}) + // Z -> sum_{x in H_{n_stack}} q(Z,x) in_{D,n_T}(Z) \kappa_{\rot, D_{n_T}}((Z,x[..\tilde n_T]), r[..1+\tilde n_T]) eq(x[\tilde n_T..], b_{T,j}) + // ``` + // where `b_{T,j}` is length `n_stack - n_T` binary encoding of `StackedSlice.row_idx >> + // (l_skip + // + n_T)`. Note that since x is in the hypercube, by definition of `eq` the above + // simplifies to + // ```text + // Z -> sum_{x in H_{n_T}} q(Z,x,b_{T,j}) (in_{D,n_T}(Z) eq(Z,r_0)) eq(x[..\tilde n_T], r[1..1+\tilde n_T]) + // Z -> sum_{x in H_{n_T}} q(Z,x,b_{T,j}) in_{D,n_T}(Z) \kappa_rot((Z, x[..n_T]), (r_0, r[1..1+n_T])) + // ``` + // where we also simplified the other `eq` term. + // We further simplify the second using equation + // ```text + // \kappa_rot((Z, x[..n_T]), (r_0, r[1..1+n_T])) = + // eq_D(Z,omega_D r_0) eq(x[..n_T], r[1..1+n_T]) + eq_D(Z,1)eq_D(omega_D r_0,1) ( kappa_rot(x[..n_T], r[1..1+n_T]) - eq(x[..n_T], r[1..1+n_T]) ) + // ``` + // We compute the last polynomial in our usual way, by considering `q(\vec Z, b_{T,j})` as a + // prismalinear polynomial and using its evaluations on `D_{n_T}`. + let s_0_polys: Vec<_> = self + .ht_diff_idxs + .par_windows(2) + .flat_map(|window| { + let t_window = &self.trace_views[window[0]..window[1]]; + let log_height = t_window[0].slice.log_height(); + let n = log_height as isize - l_skip as isize; + let n_lift = n.max(0) as usize; + let eq_rs = self.eq_r_per_lht.get(&log_height).unwrap().column(0); + debug_assert_eq!(eq_rs.len(), 1 << n_lift); + // Prepare the q subslice eval views + let q_t_cols = t_window + .iter() + .map(|tv| { + debug_assert_eq!(tv.slice.log_height(), log_height); + let q = &self.stacked_per_commit[tv.com_idx].matrix; + let s = tv.slice; + let q_t_col = &q.column(s.col_idx)[s.row_idx..s.row_idx + s.len(l_skip)]; + // NOTE: even if s.stride(l_skip) != 1, we use the full non-strided column + // subslice. The sumcheck will not depend on the values outside of the + // stride because of the `in_{D, n_T}` indicator below. 
+ (ColMajorMatrixView::new(q_t_col, 1).into(), false) + }) + .collect_vec(); + sumcheck_uni_round0_poly(l_skip, n_lift, 2, &q_t_cols, |z, x, evals| { + let eq_cube = eq_rs[x]; + let (l, omega, r_uni) = if n.is_negative() { + ( + l_skip.wrapping_add_signed(n), + omega_skip.exp_power_of_2(-n as usize), + self.r_0.exp_power_of_2(-n as usize), + ) + } else { + (l_skip, omega_skip, self.r_0) + }; + let ind = eval_in_uni(l_skip, n, z); + let eq_uni_r0 = eval_eq_uni(l, z.into(), r_uni); + let eq_uni_r0_rot = eval_eq_uni(l, z.into(), r_uni * omega); + // eq_uni_1, k_rot_cube are only used when n > 0 + let eq_uni_1 = eval_eq_uni_at_one(l_skip, z); + let k_rot_cube = eq_rs[rot_prev(x, n_lift)]; + + let eq = eq_uni_r0 * eq_cube; + let k_rot = + eq_uni_r0_rot * eq_cube + self.eq_const * eq_uni_1 * (k_rot_cube - eq_cube); + zip(t_window, evals).fold([EF::ZERO; 2], |mut acc, (tv, eval)| { + let q = eval[0]; + acc[0] += self.lambda_pows[tv.lambda_eq_idx] * eq * q * ind; + if let Some(rot_idx) = tv.lambda_rot_idx { + acc[1] += self.lambda_pows[rot_idx] * k_rot * q * ind; + } + acc + }) + }) + }) + .collect(); + let s_0_coeffs = (0..=s_0_deg) + .map(|i| s_0_polys.iter().map(|evals| evals.coeffs()[i]).sum::()) + .collect_vec(); + UnivariatePoly::new(s_0_coeffs) + } + + fn fold_ple_evals(&mut self, u_0: EF) { + let l_skip = self.l_skip; + let r_0 = self.r_0; + let omega_skip = self.omega_skip; + self.q_evals = self + .stacked_per_commit + .iter() + .map(|d| fold_ple_evals(l_skip, d.matrix.as_view().into(), false, u_0)) + .collect_vec(); + // fold PLEs into MLEs for \eq and \kappa_\rot, using u_0 + let eq_uni_u0r0 = eval_eq_uni(l_skip, u_0, r_0); + let eq_uni_u0r0_rot = eval_eq_uni(l_skip, u_0, r_0 * omega_skip); + let eq_uni_u01 = eval_eq_uni_at_one(l_skip, u_0); + // \kappa_\rot(x, r) = eq(rot^{-1}(x), r) + self.k_rot_r_per_lht = self + .eq_r_per_lht + .par_iter_mut() + .map(|(&log_height, mat)| { + let n = log_height as isize - l_skip as isize; + let n_lift = n.max(0) as usize; + debug_assert_eq!(mat.values.len(), 1 << n_lift); + let ind = eval_in_uni(l_skip, n, u_0); + let (eq_uni, eq_uni_rot) = if n.is_negative() { + let omega = omega_skip.exp_power_of_2(-n as usize); + let r = r_0.exp_power_of_2(-n as usize); + let l = l_skip.wrapping_add_signed(n); + (eval_eq_uni(l, u_0, r), eval_eq_uni(l, u_0, r * omega)) + } else { + (eq_uni_u0r0, eq_uni_u0r0_rot) + }; + // folded \kappa_\rot evals + let evals: Vec<_> = (0..1 << n_lift) + .into_par_iter() + .map(|x| { + let eq_cube = unsafe { *mat.get_unchecked(x, 0) }; + let k_rot_cube = unsafe { *mat.get_unchecked(rot_prev(x, n_lift), 0) }; + ind * (eq_uni_rot * eq_cube + + self.eq_const * eq_uni_u01 * (k_rot_cube - eq_cube)) + }) + .collect(); + // update \eq with the univariate factor: + mat.values.par_iter_mut().for_each(|v| { + *v *= ind * eq_uni; + }); + (log_height, ColMajorMatrix::new(evals, 1)) + }) + .collect(); + } + + fn batch_sumcheck_poly_eval(&mut self, round: usize, _u_prev: EF) -> [EF; 2] { + let l_skip = self.l_skip; + let s_deg = 2; + // We want to compute algebraic batching, via \lambda, + // for each (T, j) pair of (trace, column) of the univariate polynomials + // ``` + // X -> sum_{y in H_{n_stack-round}} q(u[..round],X,y) eq((u[..round],X,y[..n_T-round]), r[..1+n_T]) eq(y[n_T-round..], b_{T,j}) + // = sum_{y in H_{n_T-round}} q(u[..round],X,y,b_{T,j}) eq((u[..round],X,y), r[..1+n_T]) + // X -> sum_{y in H_{n_stack-round}} q(u[..round],X,y) \kappa_\rot((u[..round],X,y[..n_T]), r[..1+n_T]) eq(y[n_T..], b_{T,j}) + // ``` + // if `round <= n_T`. 
Otherwise we compute + // ``` + // X -> sum_{y in H_{n_stack-round}} q(u[..round],X,y) eq((u[..1+n_T], r[..1+n_T]) eq((u[1+n_T..round],X,y[round..]), b_{T,j}) + // = q(u[..round], X, b_{T,j}[round-n_T..]) eq((u[..1+n_T], r[..1+n_T]) eq((u[1+n_T..round],X), b_{T,j}[..round-n_T]) + // X -> sum_{y in H_{n_stack-round}} q(u[..round],X,y) \kappa_\rot(u[..1+n_T], r[..1+n_T]) eq((u[1+n_T..round],X,y[round..]), b_{T,j}) + // ``` + let s_evals: Vec<_> = self + .ht_diff_idxs + .par_windows(2) + .flat_map(|window| { + let t_views = &self.trace_views[window[0]..window[1]]; + let log_height = t_views[0].slice.log_height(); + let n_lift = log_height.saturating_sub(l_skip); // \tilde{n}_T + let hypercube_dim = n_lift.saturating_sub(round); + let eq_rs = self.eq_r_per_lht.get(&log_height).unwrap().column(0); + let k_rot_rs = self.k_rot_r_per_lht.get(&log_height).unwrap().column(0); + debug_assert_eq!(eq_rs.len(), 1 << n_lift.saturating_sub(round - 1)); + debug_assert_eq!(k_rot_rs.len(), 1 << n_lift.saturating_sub(round - 1)); + // Prepare the q subslice eval views + let t_cols = t_views + .iter() + .map(|tv| { + debug_assert_eq!(tv.slice.log_height(), log_height); + // q(u[..round], X, b_{T,j}[round-\tilde n_T..]) + // q_evals has been folded already + let q = &self.q_evals[tv.com_idx]; + let s = tv.slice; + let row_start = if round <= n_lift { + // round >= 1 so n_lift >= 1 + (s.row_idx >> log_height) << (hypercube_dim + 1) + } else { + (s.row_idx >> (l_skip + round)) << 1 + }; + let t_col = + &q.column(s.col_idx)[row_start..row_start + (2 << hypercube_dim)]; + ColMajorMatrixView::new(t_col, 1) + }) + .collect_vec(); + sumcheck_round_poly_evals(hypercube_dim + 1, s_deg, &t_cols, |x, y, evals| { + evals + .iter() + .enumerate() + .fold([EF::ZERO; 2], |mut acc, (i, eval)| { + let t_idx = window[0] + i; + let tv = &self.trace_views[t_idx]; + let q = eval[0]; + let mut eq_ub = self.eq_ub_per_trace[t_idx]; + let (eq, k_rot) = if round > n_lift { + // Extra contribution of eq(X, b_{T,j}[round-n_T-1]) + let b = (tv.slice.row_idx >> (l_skip + round - 1)) & 1; + eq_ub *= eval_eq_mle(&[x], &[F::from_bool(b == 1)]); + debug_assert_eq!(y, 0); + (eq_rs[0] * eq_ub, k_rot_rs[0] * eq_ub) + } else { + // linearly interpolate eq(-, r[..1+n_T]), \kappa_\rot(-, + // r[..1+n_T]) + let eq_r = eq_rs[y << 1] * (EF::ONE - x) + eq_rs[(y << 1) + 1] * x; + let k_rot_r = + k_rot_rs[y << 1] * (EF::ONE - x) + k_rot_rs[(y << 1) + 1] * x; + (eq_r * eq_ub, k_rot_r * eq_ub) + }; + acc[0] += self.lambda_pows[tv.lambda_eq_idx] * q * eq; + if let Some(rot_idx) = tv.lambda_rot_idx { + acc[1] += self.lambda_pows[rot_idx] * q * k_rot; + } + acc + }) + }) + }) + .collect(); + from_fn(|i| s_evals.iter().map(|evals| evals[i]).sum::()) + } + + fn fold_mle_evals(&mut self, round: usize, u_round: EF) { + let l_skip = self.l_skip; + self.q_evals = batch_fold_mle_evals(take(&mut self.q_evals), u_round); + self.eq_r_per_lht = take(&mut self.eq_r_per_lht) + .into_par_iter() + .map(|(lht, mat)| (lht, fold_mle_evals(mat, u_round))) + .collect(); + self.k_rot_r_per_lht = take(&mut self.k_rot_r_per_lht) + .into_par_iter() + .map(|(lht, mat)| (lht, fold_mle_evals(mat, u_round))) + .collect(); + for (tv, eq_ub) in zip(&self.trace_views, &mut self.eq_ub_per_trace) { + let s = tv.slice; + let n_lift = s.log_height().saturating_sub(l_skip); + if round > n_lift { + // Folding above did nothing, and we update the eq(u[1+n_T..=round], + // b_{T,j}[..=round-n_T-1]) value + let b = (s.row_idx >> (l_skip + round - 1)) & 1; + *eq_ub *= eval_eq_mle(&[u_round], 
&[F::from_bool(b == 1)]);
+            }
+        }
+    }
+
+    fn into_stacked_openings(self) -> Vec<Vec<EF>> {
+        self.q_evals
+            .into_iter()
+            .map(|q| {
+                debug_assert_eq!(q.height(), 1);
+                q.values
+            })
+            .collect()
+    }
+}
+
+/// `x_int` is the integer representation of a point on `H_n`.
+fn rot_prev(x_int: usize, n: usize) -> usize {
+    debug_assert!(x_int < (1 << n));
+    if x_int == 0 {
+        (1 << n) - 1
+    } else {
+        x_int - 1
+    }
+}
diff --git a/crates/stark-backend-v2/src/prover/sumcheck.rs b/crates/stark-backend-v2/src/prover/sumcheck.rs
new file mode 100644
index 00000000..3c4500bc
--- /dev/null
+++ b/crates/stark-backend-v2/src/prover/sumcheck.rs
@@ -0,0 +1,595 @@
+use std::array::from_fn;
+
+use cfg_if::cfg_if;
+use itertools::Itertools;
+use openvm_stark_backend::prover::MatrixDimensions;
+use p3_dft::TwoAdicSubgroupDft;
+use p3_field::{
+    batch_multiplicative_inverse, ExtensionField, Field, FieldAlgebra, FieldExtensionAlgebra,
+    TwoAdicField,
+};
+use p3_interpolation::interpolate_coset;
+use p3_matrix::dense::RowMajorMatrix;
+use p3_maybe_rayon::prelude::*;
+use p3_util::log2_strict_usize;
+use tracing::{debug, instrument, trace};
+
+use crate::{
+    dft::Radix2BowersSerial,
+    poly_common::UnivariatePoly,
+    poseidon2::sponge::FiatShamirTranscript,
+    prover::{ColMajorMatrix, ColMajorMatrixView, MatrixView, StridedColMajorMatrixView},
+    EF,
+};
+
+/// The univariate skip round 0: we want to compute the univariate polynomial `s(Z) = sum_{x \in
+/// H_n} \hat{f}(Z, x)`. For this function, assume that `\hat{f}(\vec z) = \hat\eps(\vec z)
+/// W(\hat{T}_0(\vec z), .., \hat{T}_{m-1}(\vec z))` for a sequence of `\hat{T}_i` where each
+/// `\hat{T}_i` consists of a collection of prismalinear polynomials in `n + 1` variables, with
+/// degree `< 2^{l_skip}` in the first variable.
+///
+/// The `mats` consist of the evaluations of the `\hat{T}_i` on the hyperprism `D_n`, where the
+/// evaluations of each `\hat{T}_i` are in column-major order.
+/// For round 0, we also provide a boolean `is_rotation` indicating whether the matrix should be
+/// accessed at a cyclic offset of 1 (aka rotation).
+///
+/// `eps` is a single column vector of evaluations on `D_n`, except valued in the extension field.
+///
+/// Let `W` be degree `d` in each variable. Then `s` has degree `<= d * (2^{l_skip} - 1)`, so it
+/// can be interpolated using `d * (2^{l_skip} - 1) + 1` points.
+///
+/// This function returns `s` in **coefficient** form.
+///
+/// If `n > 0`, then all `mats` should have the same height, equal to `2^{l_skip + n}`.
+/// If `n = 0`, then all `mats` should have height `<= 2^{l_skip}` and they will be univariate
+/// lifted to height `2^l_skip`.
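+///
+/// Hedged worked example (not in the original patch): with `l_skip = 2` and `d = 2`, the degree
+/// bound is `2 * (2^2 - 1) = 6`, so `s` is interpolated from `7 = d * (2^l_skip - 1) + 1`
+/// evaluation points, matching [`sumcheck_round0_deg`].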
+#[instrument(level = "trace", skip_all)]
+pub fn sumcheck_uni_round0_poly<F, EF, FN, const WD: usize>(
+    l_skip: usize,
+    n: usize,
+    d: usize,
+    mats: &[(StridedColMajorMatrixView<F>, bool)],
+    w: FN,
+) -> [UnivariatePoly<EF>; WD]
+where
+    F: TwoAdicField,
+    EF: ExtensionField<F> + TwoAdicField,
+    FN: Fn(
+            F,         /* Z */
+            usize,     /* x_int */
+            &[Vec<F>], /* mats eval at (Z, bin(x_int)) */
+        ) -> [EF; WD]
+        + Sync,
+{
+    if d == 0 {
+        return from_fn(|_| UnivariatePoly(vec![]));
+    }
+    #[cfg(debug_assertions)]
+    if n > 0 {
+        for (m, _) in mats.iter() {
+            assert_eq!(m.height(), 1 << (l_skip + n));
+        }
+    } else {
+        for (m, _) in mats.iter() {
+            assert!(
+                m.height() <= 1 << l_skip,
+                "mat height {} > 2^{l_skip}",
+                m.height()
+            );
+        }
+    }
+    let g = F::GENERATOR;
+    let omega_skip = F::two_adic_generator(l_skip);
+    // skip 1 to avoid divide by zero in zerocheck
+    let coset_shifts = g.powers().skip(1).take(d).collect_vec();
+
+    // Map-Reduce
+    // Map: for each x in H_n, compute
+    // ```
+    // [W(\hat{T}_0(z, x), ..., \hat{T}_{m-1}(z, x)) for z in g_i D] for the `d` cosets of `D`.
+    // ```
+    // We choose to iterate over x first to avoid multiple memory accesses to `\hat{T}`s
+    let evals = (0..1 << n).into_par_iter().map(|x| {
+        let dft = Radix2BowersSerial;
+        // For fixed `x`, `Z -> \hat{T}_i(Z, x)` is a polynomial of degree `<2^l_skip` and we
+        // have evaluations on the univariate skip domain `D = <ω_skip>` and we want to get
+        // evaluations on the larger domain `L`.
+        //
+        // For now, we apply iDFT on D and then DFT on L.
+        // PERF[jpw]: the most efficient algorithm would be to use Chirp-Z transform on L.
+        let mats_at_zs = mats
+            .iter()
+            .map(|(mat, is_rot)| {
+                let height = mat.height();
+                let offset = usize::from(*is_rot);
+                (0..mat.width())
+                    .map(|col_idx| {
+                        // SAFETY: col_idx < width
+                        // Note that the % height is necessary even when `offset = 0` because we may
+                        // have `height < 2^{l_skip + n}` in the case where we are taking the lifts
+                        // of `mats`
+                        let col_x = ((x << l_skip)..(x + 1) << l_skip)
+                            .map(|i| unsafe { *mat.get_unchecked((i + offset) % height, col_idx) })
+                            .collect_vec();
+                        let coeffs = dft.idft(col_x);
+                        coset_shifts
+                            .iter()
+                            .flat_map(|&shift| dft.coset_dft(coeffs.clone(), shift))
+                            .collect_vec()
+                    })
+                    .collect_vec()
+            })
+            .collect_vec();
+        // Apply W(..) to `{\hat{T}_i(z, x)}` for each z in g_i D
+        omega_skip
+            .powers()
+            .take(1 << l_skip)
+            .enumerate()
+            .flat_map(|(z_idx, z)| {
+                coset_shifts
+                    .iter()
+                    .enumerate()
+                    .map(|(coset_idx, &shift)| {
+                        let z_int = (coset_idx << l_skip) + z_idx;
+                        let row_z_x = mats_at_zs
+                            .iter()
+                            .map(|mat_at_zs| {
+                                mat_at_zs
+                                    .iter()
+                                    .map(|col_at_zs| col_at_zs[z_int])
+                                    .collect_vec()
+                            })
+                            .collect_vec();
+                        w(shift * z, x, &row_z_x)
+                    })
+                    .collect_vec()
+            })
+            .collect_vec()
+    });
+    // Reduce: sum over H_n
+    let hypercube_sum = |mut acc: Vec<[EF; WD]>, x| {
+        for (acc, x) in acc.iter_mut().zip(x) {
+            for (acc_i, x_i) in acc.iter_mut().zip(x) {
+                *acc_i += x_i;
+            }
+        }
+        acc
+    };
+    cfg_if! {
+        if #[cfg(feature = "parallel")] {
+            let evals = evals.reduce(
+                || vec![[EF::ZERO; WD]; d << l_skip],
+                hypercube_sum
+            );
+        } else {
+            let evals = evals.collect_vec();
+            let evals = evals.into_iter().fold(
+                vec![[EF::ZERO; WD]; d << l_skip],
+                hypercube_sum
+            );
+        }
+    }
+    from_fn(|i| {
+        let values = evals.iter().map(|x| x[i]).collect_vec();
+        UnivariatePoly::from_geometric_cosets_evals_idft(RowMajorMatrix::new(values, d), g, g)
+    })
+}
+
+pub const fn sumcheck_round0_deg(l_skip: usize, d: usize) -> usize {
+    d * ((1 << l_skip) - 1)
+}
+
+/// `mat` is a matrix of the evaluations on the hyperprism D_n of the prismalinear extensions of
+/// the columns. We "fold" it by evaluating the prismalinear polynomials at `r` in the univariate
+/// variable `Z`.
+///
+/// If `n < 0` (i.e., `mat` has height `< 2^{l_skip}`), then we evaluate `mat` at `r^{-n}`, which
+/// is equivalent to folding the lift of `mat`.
+#[instrument(level = "trace", skip_all)]
+pub fn fold_ple_evals<F, EF>(
+    l_skip: usize,
+    mat: StridedColMajorMatrixView<F>,
+    is_rot: bool,
+    r: EF,
+) -> ColMajorMatrix<EF>
+where
+    F: TwoAdicField,
+    EF: ExtensionField<F> + TwoAdicField,
+{
+    let height = mat.height();
+    let lifted_height = height.max(1 << l_skip);
+    let width = mat.width();
+
+    let omega = F::two_adic_generator(l_skip);
+    let denoms = omega
+        .powers()
+        .take(1 << l_skip)
+        .map(|x_i| r - EF::from(x_i))
+        .collect_vec();
+    let inv_denoms = batch_multiplicative_inverse(&denoms);
+
+    let offset = usize::from(is_rot);
+    let new_height = lifted_height >> l_skip;
+    let values = (0..width * new_height)
+        .into_par_iter()
+        .map(|idx| {
+            // `values` needs to be column-major
+            let x = idx % new_height;
+            let j = idx / new_height;
+            // SAFETY: j < width and we mod by height so row_idx < height
+            // Note that the `% height` is also necessary to handle lifting of `mats`
+            let uni_evals = (0..1 << l_skip)
+                .map(|z| unsafe { *mat.get_unchecked(((x << l_skip) + z + offset) % height, j) })
+                .collect_vec();
+            interpolate_coset(
+                &RowMajorMatrix::new_col(uni_evals),
+                F::ONE,
+                r,
+                Some(&inv_denoms),
+            )[0]
+        })
+        .collect::<Vec<_>>();
+    ColMajorMatrix::new(values, width)
+}
+
+pub fn batch_fold_ple_evals<F, EF>(
+    l_skip: usize,
+    mats: Vec<ColMajorMatrix<F>>,
+    is_rot: bool,
+    r: EF,
+) -> Vec<ColMajorMatrix<EF>>
+where
+    F: TwoAdicField,
+    EF: ExtensionField<F> + TwoAdicField,
+{
+    mats.into_par_iter()
+        .map(|mat| fold_ple_evals(l_skip, mat.as_view().into(), is_rot, r))
+        .collect()
+}
+
+/// For a sumcheck round, we want to compute the univariate polynomial `s(X) = sum_{y \in H_{n-1}}
+/// \hat{f}(X, y)`. For this function, assume that `\hat{f}(\vec x) = W(\hat{T}_0(\vec x), ..,
+/// \hat{T}_{m-1}(\vec x))` for a sequence of `\hat{T}_i` where each `\hat{T}_i` consists of a
+/// collection of MLE polynomials in `n` variables.
+///
+/// The `mats` consists of the evaluations of `\hat{T}_i` on the hypercube `H_n`, where evaluations
+/// of each `\hat{T}_i` are in column-major order.
+///
+/// Let `W` be degree `d` in each variable. Then `s` is degree `d`, so it can be interpolated using
+/// `d + 1` points. This function returns the evaluations of `s` at `{1, ..., d}`. The evaluation at
+/// `0` is omitted because our use of sumcheck always leaves the verifier to infer the evaluation at
+/// `0` from the previous round's claim.
+///
+/// The generic `FN` is a closure `{\hat{T}_i(X, y)}_i -> W(\hat{T}_0(X, y), .., \hat{T}_{m-1}(X,
+/// y))`.
+///
+/// This function should **not** be used for the univariate skip round.
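+///
+/// # Example
+///
+/// A minimal worked sketch for `n = 1`, `d = 1` with a single one-column matrix holding
+/// `(t_0, t_1)` on `H_1` and `W` the identity (illustrative only): the MLE is
+/// `\hat{t}(X) = t_0 + (t_1 - t_0) X`, so this function returns the single evaluation
+/// `[s(1)] = [t_1]`, and the verifier recovers `s(0) = t_0` from the previous claim
+/// `t_0 + t_1` as `s(0) = claim - s(1)`.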
+#[instrument(level = "trace", skip_all)]
+pub fn sumcheck_round_poly_evals<F, FN, const WD: usize>(
+    n: usize,
+    d: usize,
+    mats: &[ColMajorMatrixView<F>],
+    w: FN,
+) -> [Vec<F>; WD]
+where
+    F: Field,
+    FN: Fn(
+            F,         /* X */
+            usize,     /* y_int */
+            &[Vec<F>], /* mats eval at (X, bin(y_int)) */
+        ) -> [F; WD]
+        + Sync,
+{
+    debug_assert!(mats.iter().all(|mat| mat.height() == 1 << n));
+    if n == 0 {
+        // Sum is trivial, s(X) is constant
+        let evals = mats.iter().map(|row| row.values.to_vec()).collect_vec();
+        return w(F::ONE, 0, &evals).map(|x| vec![x; d]);
+    }
+    let hypercube_dim = n - 1;
+    // \hat{f}(x, \vec y) where \vec y is a point on the hypercube H_{n-1}
+    let f_hat = |x: usize, y: usize| {
+        let x = F::from_canonical_usize(x);
+        let row_x_y = mats
+            .iter()
+            .map(|mat| {
+                mat.columns()
+                    .map(|col| {
+                        let t_0 = col[y << 1];
+                        let t_1 = col[(y << 1) | 1];
+                        // Evaluate \hat{t}(x, \vec y) by linear interpolation since
+                        // \hat{t} is MLE
+                        t_0 + (t_1 - t_0) * x
+                    })
+                    .collect_vec()
+            })
+            .collect_vec();
+        w(x, y, &row_x_y)
+    };
+    trace!(sum_claim = ?{
+        (0..1 << n)
+            .map(|x| f_hat(x & 1, x >> 1))
+            .fold([F::ZERO; WD], |mut acc, x| {
+                for (acc_i, x_i) in acc.iter_mut().zip(x) {
+                    *acc_i += x_i;
+                }
+                acc
+            })
+    }, "sumcheck_round");
+    // Map-Reduce
+    // Map: for each y in H_{n-1}, compute
+    // ```
+    // [W(\hat{T}_0(x, y), ..., \hat{T}_{m-1}(x, y)) for x in {1,...,d}]
+    // ```
+    // We choose to iterate over y first to avoid multiple memory accesses to `\hat{T}`s
+    let evals = (0..1 << hypercube_dim)
+        .into_par_iter()
+        .map(|y| (1..=d).map(|x| f_hat(x, y)).collect_vec());
+    // Reduce: sum over H_{n-1}
+    let hypercube_sum = |mut acc: Vec<[F; WD]>, x| {
+        for (acc, x) in acc.iter_mut().zip(x) {
+            for (acc_i, x_i) in acc.iter_mut().zip(x) {
+                *acc_i += x_i;
+            }
+        }
+        acc
+    };
+    cfg_if! {
+        if #[cfg(feature = "parallel")] {
+            let evals = evals.reduce(
+                || vec![[F::ZERO; WD]; d],
+                hypercube_sum
+            );
+        } else {
+            let evals = evals.collect_vec();
+            let evals = evals.into_iter().fold(
+                vec![[F::ZERO; WD]; d],
+                hypercube_sum
+            );
+        }
+    }
+    from_fn(|i| evals.iter().map(|eval| eval[i]).collect_vec())
+}
+
+#[instrument(level = "trace", skip_all)]
+pub fn fold_mle_evals<EF: Field>(mat: ColMajorMatrix<EF>, r: EF) -> ColMajorMatrix<EF> {
+    let height = mat.height();
+    if height <= 1 {
+        return mat;
+    }
+    let width = mat.width();
+    let values = mat
+        .values
+        .par_chunks_exact(height)
+        .flat_map(|t| {
+            t.par_chunks_exact(2).map(|t_01| {
+                let t_0 = t_01[0];
+                let t_1 = t_01[1];
+                t_0 + (t_1 - t_0) * r
+            })
+        })
+        .collect::<Vec<_>>();
+    ColMajorMatrix::new(values, width)
+}
+
+pub fn batch_fold_mle_evals<EF: Field>(
+    mats: Vec<ColMajorMatrix<EF>>,
+    r: EF,
+) -> Vec<ColMajorMatrix<EF>> {
+    mats.into_par_iter()
+        .map(|mat| fold_mle_evals(mat, r))
+        .collect()
+}
+
+/// `mat` is column-major evaluations on H_n
+pub fn fold_mle_evals_inplace<EF: Field>(mat: &mut ColMajorMatrix<EF>, r: EF) {
+    let height = mat.height();
+    if height <= 1 {
+        return;
+    }
+    mat.values.par_chunks_exact_mut(height).for_each(|t| {
+        for y in 0..height / 2 {
+            let t_0 = t[y << 1];
+            let t_1 = t[(y << 1) + 1];
+            t[y] = t_0 + (t_1 - t_0) * r;
+        }
+    });
+}
+
+pub struct SumcheckCubeProof<EF> {
+    /// Note: the sum claim is always observed as an element of the extension field.
+    pub sum_claim: EF,
+    /// For each `round`, we have univariate polynomial `s_round`. We store evaluations at `{1,
+    /// ..., deg(s_round)}` where the evaluation at `0` is left for the verifier to infer from the
+    /// previous round claim.
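+    /// For example, with a linear `s_round` the previous claim satisfies
+    /// `claim = s_round(0) + s_round(1)`, so the verifier recovers
+    /// `s_round(0) = claim - s_round(1)`.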
+    pub round_polys_eval: Vec<Vec<EF>>,
+    /// Final evaluation claim of the polynomial at the random vector `r`
+    pub eval_claim: EF,
+}
+
+pub struct SumcheckPrismProof<EF> {
+    pub sum_claim: EF,
+    /// The univariate polynomial `s_0` in coefficient form.
+    pub s_0: UnivariatePoly<EF>,
+    /// For each hypercube `round`, the evaluations of the univariate polynomial `s_round` at `{1,
+    /// ..., deg(s_round)}`. See [SumcheckCubeProof] for details.
+    pub round_polys_eval: Vec<Vec<EF>>,
+    /// Final evaluation claim of the polynomial at the random vector `r`
+    pub eval_claim: EF,
+}
+
+/// "Plain" sumcheck on a multilinear polynomial.
+///
+/// The slice `evals` contains the evaluations of a multilinear polynomial on the boolean
+/// hypercube. The length of `evals` should equal `2^n` where `n` is the hypercube dimension.
+///
+/// Returns the sumcheck proof containing all prover messages and the random evaluation point.
+//
+// NOTE[jpw]: we currently fix EF for the transcript, but the evaluations in F can be either base
+// field or extension field
+pub fn sumcheck_multilinear<F: Field, EF, TS: FiatShamirTranscript>(
+    transcript: &mut TS,
+    evals: &[F],
+) -> (SumcheckCubeProof<EF>, Vec<EF>)
+where
+    EF: ExtensionField<F>,
+{
+    let n = log2_strict_usize(evals.len());
+    let mut round_polys_eval = Vec::with_capacity(n);
+    let mut r = Vec::with_capacity(n);
+
+    // Working copy of evaluations that gets folded after each round
+    // PERF[jpw]: the first round should be treated specially in the case F is the base field
+    let mut current_evals =
+        ColMajorMatrix::new(evals.iter().map(|&x| EF::from_base(x)).collect(), 1);
+    let sum_claim: EF = evals.iter().fold(F::ZERO, |acc, &x| acc + x).into();
+    transcript.observe_ext(sum_claim);
+
+    // Sumcheck rounds:
+    // - each round the prover needs to compute univariate polynomial `s_round`. This poly is
+    //   linear since we are taking MLE of `evals`.
+    // - at end of each round, sample random `r_round` in `EF`
+    for round in 0..n {
+        let [s] =
+            sumcheck_round_poly_evals(n - round, 1, &[current_evals.as_view()], |_x, _y, evals| {
+                [evals[0][0]]
+            });
+        assert_eq!(s.len(), 1);
+        transcript.observe_ext(s[0]);
+        round_polys_eval.push(s);
+
+        let r_round = transcript.sample_ext();
+        debug!(%round, %r_round);
+        r.push(r_round);
+
+        current_evals = fold_mle_evals(current_evals, r_round);
+    }
+
+    // After all rounds, current_evals should have exactly one element
+    assert_eq!(current_evals.values.len(), 1);
+    let eval_claim = current_evals.values[0];
+
+    // Add final evaluation to transcript
+    transcript.observe_ext(eval_claim);
+
+    (
+        SumcheckCubeProof {
+            sum_claim,
+            round_polys_eval,
+            eval_claim,
+        },
+        r,
+    )
+}
+
+/// "Plain" sumcheck on a prismalinear polynomial with Gruen's univariate skip.
+///
+/// The slice `evals` contains the evaluations of a prismalinear polynomial on the hyperprism.
+/// The length of `evals` should equal `2^{l_skip + n}` where `l_skip` is the univariate skip
+/// parameter and `n` is the hypercube dimension.
+/// Indexing is such that `evals[x * 2^{l_skip} + i]` is the evaluation of `f(omega_D^i, x)` where
+/// `omega_D` is a fixed generator of the univariate skip domain `D` (which is a subgroup of
+/// `F^*`).
+///
+/// Returns the sumcheck proof containing all prover messages and the random evaluation point.
+//
+// NOTE[jpw]:
+// - we currently fix EF for the transcript, but the evaluations in F can be either base
+//   field or extension field.
+// - for simplicity, the transcript observes `sum_claim` and `s_0` as valued in `EF`. More
+//   fine-grained approaches may observe in `F`.
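+//
+// Example (prover/verifier round-trip, mirroring the unit tests in this crate; the verifier
+// half lives in `crate::verifier::sumcheck`):
+//
+// let mut prover_sponge = DuplexSponge::default();
+// let (proof, r) = sumcheck_prismalinear(&mut prover_sponge, l_skip, &evals);
+// // `r` has length `n + 1`: one univariate-skip challenge followed by `n` hypercube rounds.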
+pub fn sumcheck_prismalinear<F, EF, TS: FiatShamirTranscript>(
+    transcript: &mut TS,
+    l_skip: usize,
+    evals: &[F],
+) -> (SumcheckPrismProof<EF>, Vec<EF>)
+where
+    F: TwoAdicField,
+    EF: ExtensionField<F>,
+{
+    let prism_dim = log2_strict_usize(evals.len());
+    assert!(prism_dim >= l_skip);
+    let n = prism_dim - l_skip;
+
+    let mut round_polys_eval = Vec::with_capacity(n);
+    let mut r = Vec::with_capacity(n + 1);
+
+    let sum_claim: EF = evals.iter().copied().sum::<F>().into();
+    transcript.observe_ext(sum_claim);
+    let current_evals = ColMajorMatrix::new(evals.to_vec(), 1);
+    let [s_0] = sumcheck_uni_round0_poly(
+        l_skip,
+        n,
+        1,
+        &[(current_evals.as_view().into(), false)],
+        |_z, _x, evals| [evals[0][0]],
+    );
+    let s_0_ext = UnivariatePoly::new(
+        s_0.0
+            .into_iter()
+            .map(|x| {
+                let ext = EF::from(x);
+                transcript.observe_ext(ext);
+                ext
+            })
+            .collect(),
+    );
+
+    let r_0 = transcript.sample_ext();
+    debug!(round = 0, r_round = %r_0);
+    r.push(r_0);
+
+    // After sampling r_0, we need to evaluate the prismalinear polynomial at (r_0, x) for each x
+    // in the hypercube. For each x in the hypercube, we have evaluations f(z, x) for z in the
+    // univariate skip domain D. We interpolate these to get a univariate polynomial and evaluate
+    // at r_0.
+    let mut current_evals = fold_ple_evals(l_skip, current_evals.as_view().into(), false, r_0);
+    debug_assert_eq!(current_evals.height(), 1 << n);
+
+    // Sumcheck rounds:
+    // - each round the prover needs to compute univariate polynomial `s_round`. This poly is
+    //   linear since we are taking MLE of `evals`.
+    // - at end of each round, sample random `r_round` in `EF`
+    for round in 1..=n {
+        debug!(
+            cur_sum = %current_evals
+                .values
+                .iter()
+                .fold(EF::ZERO, |acc, x| acc + *x)
+        );
+        let [s] = sumcheck_round_poly_evals(
+            n + 1 - round,
+            1,
+            &[current_evals.as_view()],
+            |_x, _y, evals| [evals[0][0]],
+        );
+        assert_eq!(s.len(), 1);
+        transcript.observe_ext(s[0]);
+        round_polys_eval.push(s);
+
+        let r_round = transcript.sample_ext();
+        debug!(%round, %r_round);
+        r.push(r_round);
+
+        current_evals = fold_mle_evals(current_evals, r_round);
+    }
+
+    assert_eq!(r.len(), n + 1);
+    // After all rounds, current_evals should have exactly one element
+    assert_eq!(current_evals.values.len(), 1);
+    let eval_claim = current_evals.values[0];
+
+    // Add final evaluation to transcript
+    transcript.observe_ext(eval_claim);
+
+    (
+        SumcheckPrismProof {
+            sum_claim,
+            s_0: s_0_ext,
+            round_polys_eval,
+            eval_claim,
+        },
+        r,
+    )
+}
diff --git a/crates/stark-backend-v2/src/prover/types.rs b/crates/stark-backend-v2/src/prover/types.rs
new file mode 100644
index 00000000..0ec42086
--- /dev/null
+++ b/crates/stark-backend-v2/src/prover/types.rs
@@ -0,0 +1,161 @@
+use std::{cmp::Reverse, sync::Arc};
+
+use derivative::Derivative;
+use openvm_stark_backend::{keygen::types::LinearConstraint, prover::MatrixDimensions};
+
+use crate::{
+    keygen::types::{MultiStarkVerifyingKey0V2, MultiStarkVerifyingKeyV2, StarkVerifyingKeyV2},
+    proof::TraceVData,
+    prover::ProverBackendV2,
+    Digest, SystemParams,
+};
+
+/// The committed trace data for a single trace matrix. This type is used to store prover data for
+/// both preprocessed trace and cached trace.
+#[derive(Derivative)]
+#[derivative(Clone(bound = "PB::Matrix: Clone"))]
+pub struct CommittedTraceDataV2<PB: ProverBackendV2> {
+    /// The polynomial commitment.
+    pub commitment: PB::Commitment,
+    /// The trace matrix, unstacked, in evaluation form.
+    pub trace: PB::Matrix,
+    /// The PCS data for a single committed trace matrix.
+ pub data: Arc, +} + +/// The proving key for a circuit consisting of multiple AIRs, after prover-specific data has been +/// transferred to device. The host data (e.g., vkey) is owned by this struct. +/// +/// Ordering is always by AIR ID and includes all AIRs, including ones that may have empty traces. +#[derive(derive_new::new)] +pub struct DeviceMultiStarkProvingKeyV2 { + pub per_air: Vec>, + pub trace_height_constraints: Vec, + /// Maximum degree of constraints across all AIRs + pub max_constraint_degree: usize, + pub params: SystemParams, + pub vk_pre_hash: PB::Commitment, +} + +/// The proving key after prover-specific data has been transferred to device. The host data (e.g., +/// vkey) is owned by this struct. +pub struct DeviceStarkProvingKeyV2 { + /// Type name of the AIR, for display purposes only + pub air_name: String, + pub vk: StarkVerifyingKeyV2, + /// Prover only data for preprocessed trace + pub preprocessed_data: Option>, + pub other_data: PB::OtherAirData, +} + +#[derive(derive_new::new)] +pub struct ProvingContextV2 { + /// For each AIR with non-empty trace, the pair of (AIR ID, [AirProvingContextV2]), where AIR + /// ID is with respect to the vkey ordering. + pub per_trace: Vec<(usize, AirProvingContextV2)>, +} + +#[derive(derive_new::new)] +pub struct AirProvingContextV2 { + /// Cached main trace matrices as `PcsData`. The original trace matrix should be extractable as + /// a view from the `PcsData`. The `PcsData` should also contain the commitment value. Cached + /// trace commitments have a single matrix per commitment. + /// + /// The `PcsData` is kept inside an `Arc` to emphasize that this data is cached and may be + /// shared between multiple proving contexts. In particular, it is not typically safe to mutate + /// the data during a proving job. + pub cached_mains: Vec>, + /// Common main trace matrix + pub common_main: PB::Matrix, + /// Public values + pub public_values: Vec, +} + +/// Proof on the host, with respect to the host types in the generic `PB`. +pub struct HostProof { + /// The commitment to the data in common_main. + pub common_main_commit: PB::Commitment, + + /// For each AIR in vkey order, the corresponding trace shape, or None if + /// the trace is empty. In a valid proof, if `vk.per_air[i].is_required`, + /// then `trace_vdata[i]` must be `Some(_)`. + pub trace_vdata: Vec>, + + /// For each AIR in vkey order, the public values. Public values should be empty if the AIR has + /// an empty trace. 
+ pub public_values: Vec>, + + pub constraints_proof: ConstraintsProof, + /// Opening proof for multiple polynomials over mixed sized domains + pub opening_proof: OpeningProof, +} + +impl CommittedTraceDataV2 { + #[inline(always)] + pub fn height(&self) -> usize { + self.trace.height() + } +} + +impl DeviceMultiStarkProvingKeyV2 +where + PB: ProverBackendV2, +{ + pub fn get_vk(&self) -> MultiStarkVerifyingKeyV2 { + let per_air = self.per_air.iter().map(|pk| pk.vk.clone()).collect(); + let inner = MultiStarkVerifyingKey0V2 { + params: self.params.clone(), + per_air, + trace_height_constraints: self.trace_height_constraints.clone(), + }; + MultiStarkVerifyingKeyV2 { + inner, + pre_hash: self.vk_pre_hash, + } + } +} + +impl IntoIterator for ProvingContextV2 { + type Item = (usize, AirProvingContextV2); + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.per_trace.into_iter() + } +} + +impl ProvingContextV2 { + pub fn common_main_traces(&self) -> impl Iterator { + self.per_trace + .iter() + .map(|(air_idx, air_ctx)| (*air_idx, &air_ctx.common_main)) + } + + // Returns `self` with the trace data sorted to be descending in height for column stacking. For + // equal heights, traces are sorted in ascending order of AIR index. + pub fn into_sorted(mut self) -> Self { + self.sort_for_stacking(); + self + } + + // Stable sort the trace data to be descending in height: this is needed for stacking. For + // equal heights, sort in ascending order of AIR index. + pub fn sort_for_stacking(&mut self) { + self.per_trace + .sort_by_key(|(air_idx, air_ctx)| (Reverse(air_ctx.common_main.height()), *air_idx)); + } +} + +impl AirProvingContextV2 { + pub fn simple(common_main_trace: PB::Matrix, public_values: Vec) -> Self { + Self::new(vec![], common_main_trace, public_values) + } + pub fn simple_no_pis(common_main_trace: PB::Matrix) -> Self { + Self::simple(common_main_trace, vec![]) + } + + /// Return the height of the main trace. + pub fn height(&self) -> usize { + self.common_main.height() + } +} diff --git a/crates/stark-backend-v2/src/prover/whir.rs b/crates/stark-backend-v2/src/prover/whir.rs new file mode 100644 index 00000000..6c57844f --- /dev/null +++ b/crates/stark-backend-v2/src/prover/whir.rs @@ -0,0 +1,304 @@ +use std::{iter::once, sync::Arc}; + +use itertools::Itertools; +use openvm_stark_backend::prover::MatrixDimensions; +use p3_dft::{Radix2DitParallel, TwoAdicSubgroupDft}; +use p3_field::{ExtensionField, Field, FieldAlgebra, TwoAdicField}; +use p3_maybe_rayon::prelude::*; +use p3_util::log2_strict_usize; +use tracing::instrument; + +use crate::{ + poly_common::Squarable, + poseidon2::sponge::FiatShamirTranscript, + proof::{MerkleProof, WhirProof}, + prover::{ + poly::{evals_eq_hypercube, Mle, Ple}, + stacked_pcs::{MerkleTree, StackedPcsData}, + ColMajorMatrix, CpuBackendV2, CpuDeviceV2, ProverBackendV2, + }, + Digest, WhirConfig, EF, F, +}; + +pub trait WhirProver { + /// Prove the WHIR protocol for a collection of MLE polynomials \hat{q}_j, each in n variables, + /// at a single vector `u \in \Fext^n`. + /// + /// This means applying WHIR with weight polynomial `\hat{w}(Z, \vec X) = Z * eq(\vec X, u)`. + /// + /// The matrices in `common_main_pcs_data` and `pre_cached_pcs_data_per_commit` must all have + /// the same height. 
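+    ///
+    /// Here `eq` is the standard multilinear equality polynomial,
+    /// `eq(\vec X, u) = \prod_i (X_i u_i + (1 - X_i)(1 - u_i))`,
+    /// so summing `\hat{q}_j(\vec x) eq(\vec x, u)` over the hypercube yields the claimed
+    /// evaluation `\hat{q}_j(u)`.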
+ fn prove_whir( + &self, + transcript: &mut TS, + common_main_pcs_data: PB::PcsData, + pre_cached_pcs_data_per_commit: Vec>, + u_cube: &[PB::Challenge], + ) -> WhirProof; +} + +impl WhirProver for CpuDeviceV2 { + #[instrument(level = "info", skip_all)] + fn prove_whir( + &self, + transcript: &mut TS, + common_main_pcs_data: StackedPcsData, + pre_cached_pcs_data_per_commit: Vec>>, + u_cube: &[EF], + ) -> WhirProof { + let params = self.config(); + let committed_mats = once(&common_main_pcs_data) + .chain(pre_cached_pcs_data_per_commit.iter().map(|d| d.as_ref())) + .map(|d| (&d.matrix, &d.tree)) + .collect_vec(); + prove_whir_opening( + transcript, + params.l_skip, + params.log_blowup, + ¶ms.whir, + &committed_mats, + u_cube, + ) + } +} + +#[allow(clippy::too_many_arguments)] +pub fn prove_whir_opening( + transcript: &mut TS, + l_skip: usize, + log_blowup: usize, + whir_params: &WhirConfig, + committed_mats: &[(&ColMajorMatrix, &MerkleTree)], + u: &[EF], +) -> WhirProof { + // Sample randomness for algebraic batching. + // We batch the codewords for \hat{q}_j together _before_ applying WHIR. + let mu = transcript.sample_ext(); + let total_width = committed_mats.iter().map(|(mat, _)| mat.width()).sum(); + let mu_powers = mu.powers().take(total_width).collect_vec(); + + let height = committed_mats[0].0.height(); + debug_assert!(committed_mats.iter().all(|(mat, _)| mat.height() == height)); + let mut m = log2_strict_usize(height); + + let k_whir = whir_params.k; + let num_whir_rounds = whir_params.num_whir_rounds(); + let num_sumcheck_rounds = whir_params.num_sumcheck_rounds(); + + let mles: Vec> = committed_mats + .par_iter() + .flat_map(|(mat, _)| { + mat.par_columns().map(|col| { + let mut x = Ple::from_evaluations(l_skip, col).coeffs; + Mle::coeffs_to_evals_inplace(&mut x); + x + }) + }) + .collect(); + + // The evaluations of `\hat{f}` in the current WHIR round on the hypercube `H_m`. + let mut f_evals: Vec<_> = (0..1 << m) + .into_par_iter() + .map(|i| { + mles.iter() + .zip(mu_powers.iter()) + .fold(EF::ZERO, |acc, (mle_j, mu_j)| acc + *mu_j * mle_j[i]) + }) + .collect(); + + // We assume `\hat{w}` in a WHIR round is always multilinear and maintain its + // evaluations on `H_m`. 
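+    // For example (illustrative, assuming the little-endian bit order of the hypercube index
+    // used throughout this file), for u = (u_0, u_1) the table produced below is
+    //   [ (1-u_0)(1-u_1), u_0(1-u_1), (1-u_0)u_1, u_0 u_1 ],
+    // one eq(x, u) value per x in H_2.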
+ let mut w_evals = evals_eq_hypercube(u); + + let mut whir_sumcheck_polys: Vec<[EF; 2]> = Vec::with_capacity(num_sumcheck_rounds); + let mut codeword_commits = vec![]; + let mut ood_values = vec![]; + // per commitment, per whir query, per column + let mut initial_round_opened_rows: Vec>>> = vec![vec![]; committed_mats.len()]; + let mut initial_round_merkle_proofs: Vec> = vec![vec![]; committed_mats.len()]; + let mut codeword_opened_values: Vec>> = Vec::with_capacity(num_whir_rounds - 1); + let mut codeword_merkle_proofs: Vec> = Vec::with_capacity(num_whir_rounds - 1); + let mut folding_pow_witnesses = Vec::with_capacity(num_sumcheck_rounds); + let mut query_phase_pow_witnesses = Vec::with_capacity(num_whir_rounds); + let mut rs_tree = None; + let mut log_rs_domain_size = m + log_blowup; + let mut final_poly = None; + for (whir_round, round_params) in whir_params.rounds.iter().enumerate() { + let is_last_round = whir_round == num_whir_rounds - 1; + // Run k_whir rounds of sumcheck on `sum_{x in H_m} \hat{w}(\hat{f}(x), x)` + for round in 0..k_whir { + debug_assert_eq!(f_evals.len(), 1 << (m - round)); + + // \hat{f} * eq has degree 2 + let s_deg = 2; + let s_evals = (1..=s_deg) + .map(|x| { + let x = F::from_canonical_usize(x); + let hypercube_dim = m - round - 1; + (0..(1usize << hypercube_dim)) + .map(|y| { + let f_0 = f_evals[y << 1]; + let f_1 = f_evals[(y << 1) + 1]; + let f_x = f_0 + (f_1 - f_0) * x; + let w_0 = w_evals[y << 1]; + let w_1 = w_evals[(y << 1) + 1]; + let w_x = w_0 + (w_1 - w_0) * x; + f_x * w_x + }) + .fold(EF::ZERO, |acc, x| acc + x) + }) + .collect_vec(); + + for &eval in &s_evals { + transcript.observe_ext(eval); + } + whir_sumcheck_polys.push(s_evals.try_into().unwrap()); + + folding_pow_witnesses.push(transcript.grind(whir_params.folding_pow_bits)); + // Folding randomness + let alpha = transcript.sample_ext(); + + // Fold the evaluations + let half = f_evals.len() / 2; + for y in 0..half { + let eval_0 = f_evals[y << 1]; + let eval_1 = f_evals[(y << 1) + 1]; + // Linear interpolation at r_round + f_evals[y] = eval_0 + alpha * (eval_1 - eval_0); + + let eval_0 = w_evals[y << 1]; + let eval_1 = w_evals[(y << 1) + 1]; + w_evals[y] = eval_0 + alpha * (eval_1 - eval_0); + } + f_evals.truncate(half); + w_evals.truncate(half); + } + // Define g^ = f^(alpha, \cdot) and send matrix commit of RS(g^) + // f_evals is the evaluations of f^(alpha, \cdot) on hypercube + let g_mle = Mle::from_evaluations(&f_evals); + let (g_tree, z_0) = if !is_last_round { + let dft = Radix2DitParallel::default(); + let mut g_coeffs = g_mle.coeffs().to_vec(); + debug_assert_eq!(g_coeffs.len(), 1 << (m - k_whir)); + g_coeffs.resize(1 << (log_rs_domain_size - 1), EF::ZERO); + // `g: \mathcal{L}^{(2)} \to \mathbb F` + let g_rs = dft.dft(g_coeffs); + let g_tree = MerkleTree::new(ColMajorMatrix::new(g_rs, 1), 1 << k_whir); + let g_commit = g_tree.root(); + transcript.observe_commit(g_commit); + codeword_commits.push(g_commit); + + let z_0 = transcript.sample_ext(); + let z_0_vec = z_0.exp_powers_of_2().take(m - k_whir).collect_vec(); + let g_opened_value = g_mle.eval_at_point(&z_0_vec); + transcript.observe_ext(g_opened_value); + ood_values.push(g_opened_value); + + (Some(g_tree), Some(z_0)) + } else { + let coeffs = g_mle.into_coeffs(); + for coeff in &coeffs { + transcript.observe_ext(*coeff); + } + final_poly = Some(coeffs); + (None, None) + }; + + // omega is generator of RS domain `\mathcal{L}^{(2^k)}` + let omega = F::two_adic_generator(log_rs_domain_size - k_whir); + let num_queries = 
round_params.num_queries; + let mut query_indices = Vec::with_capacity(num_queries); + query_phase_pow_witnesses.push(transcript.grind(whir_params.query_phase_pow_bits)); + // Sample query indices first + for _ in 0..num_queries { + // This is the index of the leaf in the Merkle tree + let index = transcript.sample_bits(log_rs_domain_size - k_whir); + query_indices.push(index as usize); + } + let mut zs = Vec::with_capacity(num_queries); + if !is_last_round { + codeword_opened_values.push(vec![]); + codeword_merkle_proofs.push(vec![]); + } + for (query_idx, index) in query_indices.into_iter().enumerate() { + let z_i = omega.exp_u64(index as u64); + // Get merkle proofs for in-domain samples necessary to evaluate Fold(f, \vec + // \alpha)(z_i) + zs.push(z_i); + + let depth = log_rs_domain_size.saturating_sub(k_whir); + // Row openings are different between first WHIR round (width > 1) and other rounds + // (width = 1): + // NOTE: merkle proof is deterministic from the index and merkle root, so the opened_row + // and merkle proof are both hinted and not observed by the transcript. + if whir_round == 0 { + #[allow(clippy::needless_range_loop)] + for com_idx in 0..committed_mats.len() { + debug_assert_eq!(initial_round_merkle_proofs[com_idx].len(), query_idx); + let tree = &committed_mats[com_idx].1; + assert_eq!(tree.backing_matrix.height(), 1 << log_rs_domain_size); + let opened_rows = tree.get_opened_rows(index); + initial_round_opened_rows[com_idx].push(opened_rows); + debug_assert_eq!(tree.proof_depth(), depth); + let proof = tree.query_merkle_proof(index); + debug_assert_eq!(proof.len(), depth); + initial_round_merkle_proofs[com_idx].push(proof); + } + } else { + let tree: &MerkleTree = rs_tree.as_ref().unwrap(); + assert_eq!(tree.backing_matrix.width(), 1); + let opened_rows = tree + .get_opened_rows(index) + .into_iter() + .flatten() + .collect_vec(); + codeword_opened_values[whir_round - 1].push(opened_rows); + debug_assert_eq!(tree.proof_depth(), depth); + let proof = tree.query_merkle_proof(index); + debug_assert_eq!(proof.len(), depth); + codeword_merkle_proofs[whir_round - 1].push(proof); + } + } + rs_tree = g_tree; + + // We still sample on the last round to match the verifier, who uses a + // final gamma to unify some logic. But we do not need to update + // `w_evals`. + let gamma = transcript.sample_ext(); + + if !is_last_round { + // Update \hat{w} + w_evals_accumulate::(&mut w_evals, z_0.unwrap(), gamma); + for (z_i, gamma_pow) in zs.into_iter().zip(gamma.powers().skip(2)) { + w_evals_accumulate::(&mut w_evals, z_i, gamma_pow); + } + } + + m -= k_whir; + log_rs_domain_size -= 1; + } + + WhirProof { + whir_sumcheck_polys, + codeword_commits, + ood_values, + folding_pow_witnesses, + query_phase_pow_witnesses, + initial_round_opened_rows, + initial_round_merkle_proofs, + codeword_opened_values, + codeword_merkle_proofs, + final_poly: final_poly.unwrap(), + } +} + +/// Given hypercube evaluations `w_evals` of `\hat{w}` on `H_t`, this updates the evaluations +/// in place to be the evaluations of `\hat{w}'(x) = \hat{w}(x) + γ * eq(x, pow(z))`. 
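+/// Here `pow(z) = (z, z^2, z^4, ..., z^{2^{t-1}})`, matching the `exp_powers_of_2` call below.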
+fn w_evals_accumulate>(w_evals: &mut [EF], z: F, gamma: EF) { + let dim = log2_strict_usize(w_evals.len()); + let z_pows = z.exp_powers_of_2().take(dim).collect_vec(); + let evals = evals_eq_hypercube(&z_pows); + for (w, x) in w_evals.iter_mut().zip(evals.into_iter()) { + *w += gamma * x; + } +} diff --git a/crates/stark-backend-v2/src/test_utils.rs b/crates/stark-backend-v2/src/test_utils.rs new file mode 100644 index 00000000..2e94a1ae --- /dev/null +++ b/crates/stark-backend-v2/src/test_utils.rs @@ -0,0 +1,604 @@ +use std::sync::Arc; + +use itertools::Itertools; +use openvm_stark_backend::{ + interaction::BusIndex, + prover::{MatrixDimensions, Prover}, + AirRef, +}; +pub use openvm_stark_sdk::dummy_airs::fib_air::air::FibonacciAir; +use openvm_stark_sdk::{ + any_rap_arc_vec, + config::{ + baby_bear_poseidon2::BabyBearPoseidon2Config, + log_up_params::log_up_security_params_baby_bear_100_bits, setup_tracing, + }, + dummy_airs::{ + self, + fib_selector_air::air::FibonacciSelectorAir, + interaction::{ + dummy_interaction_air::DummyInteractionAir, + self_interaction_air::{SelfInteractionAir, SelfInteractionChip}, + }, + }, +}; +use p3_baby_bear::BabyBear; +use p3_field::{FieldAlgebra, PrimeField32}; +use p3_matrix::dense::RowMajorMatrix; + +use crate::{ + keygen::types::{MultiStarkProvingKeyV2, MultiStarkVerifyingKeyV2}, + poseidon2::sponge::{ + DuplexSponge, DuplexSpongeRecorder, FiatShamirTranscript, TranscriptHistory, TranscriptLog, + }, + proof::Proof, + prover::{ + stacked_pcs::stacked_commit, AirProvingContextV2, ColMajorMatrix, CommittedTraceDataV2, + CpuBackendV2, DeviceDataTransporterV2, DeviceMultiStarkProvingKeyV2, MultiRapProver, + ProvingContextV2, TraceCommitterV2, + }, + BabyBearPoseidon2CpuEngineV2, ChipV2, StarkEngineV2, SystemParams, WhirConfig, WhirParams, F, +}; + +#[allow(clippy::type_complexity)] +pub fn prove_up_to_batch_constraints( + engine: &E, + transcript: &mut E::TS, + pk: &DeviceMultiStarkProvingKeyV2, + ctx: ProvingContextV2, +) -> ( + >::PartialProof, + >::Artifacts, +) { + let (_, common_main_pcs_data) = engine.device().commit( + &ctx.common_main_traces() + .map(|(_, trace)| trace) + .collect_vec(), + ); + engine + .device() + .prove_rap_constraints(transcript, pk, &ctx, &common_main_pcs_data) +} + +fn get_fib_number(mut a: u32, mut b: u32, n: usize) -> u32 { + for _ in 0..n - 1 { + let c = (a + b) % BabyBear::ORDER_U32; + a = b; + b = c; + } + b +} + +fn get_conditional_fib_number(mut a: u32, mut b: u32, sels: &[bool]) -> u32 { + for &s in sels[0..sels.len() - 1].iter() { + if s { + let c = (a + b) % BabyBear::ORDER_U32; + a = b; + b = c; + } + } + b +} + +/// Trait for object responsible for generating the collection of AIRs and trace matrices for a +/// single test case. +pub trait TestFixture { + fn airs(&self) -> Vec>; + + fn generate_proving_ctx(&self) -> ProvingContextV2; + + fn keygen( + &self, + engine: &E, + ) -> (MultiStarkProvingKeyV2, MultiStarkVerifyingKeyV2) { + engine.keygen(&self.airs()) + } + + fn prove(&self, engine: &E, pk: &MultiStarkProvingKeyV2) -> Proof { + self.prove_from_transcript(engine, pk, &mut E::TS::default()) + } + + /// Prove using CPU tracegen and transport to device. 
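+    ///
+    /// Typical end-to-end usage of a fixture (mirroring the tests in this crate):
+    ///
+    /// ```ignore
+    /// let engine = test_engine_small();
+    /// let fx = FibFixture::new(0, 1, 1 << 3);
+    /// let (vk, proof) = fx.keygen_and_prove(&engine);
+    /// engine.verify(&vk, &proof).unwrap();
+    /// ```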
+ fn prove_from_transcript( + &self, + engine: &E, + pk: &MultiStarkProvingKeyV2, + transcript: &mut E::TS, + ) -> Proof { + let ctx = self.generate_proving_ctx(); + let device = engine.device(); + let d_pk = device.transport_pk_to_device(pk); + let d_ctx = device.transport_proving_ctx_to_device(&ctx); + let mut prover = engine.prover_from_transcript(transcript.clone()); + let proof = prover.prove(&d_pk, d_ctx); + *transcript = prover.transcript; + proof + } + + fn keygen_and_prove(&self, engine: &E) -> (MultiStarkVerifyingKeyV2, Proof) { + let (pk, vk) = self.keygen(engine); + let proof = self.prove(engine, &pk); + (vk, proof) + } +} + +pub struct FibFixture { + pub a: u32, + pub b: u32, + pub n: usize, + pub num_airs: usize, + pub empty_air_indices: Vec, +} + +impl FibFixture { + pub fn new(a: u32, b: u32, n: usize) -> Self { + FibFixture { + a, + b, + n, + num_airs: 1, + empty_air_indices: vec![], + } + } + + pub fn new_with_num_airs(a: u32, b: u32, n: usize, num_airs: usize) -> Self { + FibFixture { + a, + b, + n, + num_airs, + empty_air_indices: vec![], + } + } + + pub fn with_empty_air_indices(mut self, empty_air_indices: impl Into>) -> Self { + self.empty_air_indices = empty_air_indices.into(); + self + } +} + +impl TestFixture for FibFixture { + fn airs(&self) -> Vec> { + let air = Arc::new(FibonacciAir); + vec![air; self.num_airs] + } + + fn generate_proving_ctx(&self) -> ProvingContextV2 { + use dummy_airs::fib_air::trace::generate_trace_rows; + let f_n = get_fib_number(self.a, self.b, self.n); + let pis = [self.a, self.b, f_n].map(BabyBear::from_canonical_u32); + + ProvingContextV2::new( + (0..self.num_airs) + .filter(|i| !self.empty_air_indices.contains(i)) + .map(|i| { + ( + i, + AirProvingContextV2::simple( + ColMajorMatrix::from_row_major(&generate_trace_rows::( + self.a, self.b, self.n, + )), + pis.to_vec(), + ), + ) + }) + .collect_vec(), + ) + } +} + +/// Interactions fixture with 1 sender and 1 receiver +pub struct InteractionsFixture11; + +impl TestFixture for InteractionsFixture11 { + fn airs(&self) -> Vec> { + let sender_air = DummyInteractionAir::new(1, true, 0); + let receiver_air = DummyInteractionAir::new(1, false, 0); + any_rap_arc_vec!(sender_air, receiver_air) + } + + fn generate_proving_ctx(&self) -> ProvingContextV2 { + // Default traces + // Sender (2 columns: Mul, Val): + // 0 1 + // 7 4 + // 3 5 + // 546 889 + let sender_trace = RowMajorMatrix::new( + [0, 1, 3, 5, 7, 4, 546, 889] + .into_iter() + .map(BabyBear::from_canonical_usize) + .collect(), + 2, + ); + + // Receiver (2 columns: Mul, Val): + // 1 5 + // 3 4 + // 4 4 + // 2 5 + // 0 123 + // 545 889 + // 1 889 + // 0 456 + let receiver_trace = RowMajorMatrix::new( + [1, 5, 3, 4, 4, 4, 2, 5, 0, 123, 545, 889, 1, 889, 0, 456] + .into_iter() + .map(BabyBear::from_canonical_usize) + .collect(), + 2, + ); + + ProvingContextV2::new( + [sender_trace, receiver_trace] + .into_iter() + .enumerate() + .map(|(air_idx, trace)| { + ( + air_idx, + AirProvingContextV2::simple_no_pis(ColMajorMatrix::from_row_major(&trace)), + ) + }) + .collect(), + ) + } +} + +/// Dummy interaction AIRs with cached trace: 1 sender, 1 receiver +#[derive(derive_new::new)] +pub struct CachedFixture11 { + pub params: SystemParams, +} + +impl TestFixture for CachedFixture11 { + fn airs(&self) -> Vec> { + let sender_air = DummyInteractionAir::new(1, true, 0).partition(); + let receiver_air = DummyInteractionAir::new(1, false, 0).partition(); + any_rap_arc_vec!(sender_air, receiver_air) + } + + fn generate_proving_ctx(&self) -> 
ProvingContextV2 { + // Default traces + // Sender (2 columns: Mul, Val): + // 0 1 + // 3 5 + // 7 4 + // 546 889 + let sender_trace = ColMajorMatrix::new( + [0, 3, 7, 546] + .into_iter() + .map(BabyBear::from_canonical_usize) + .collect(), + 1, + ); + let sender_cached_trace = ColMajorMatrix::new( + [1, 5, 4, 889] + .into_iter() + .map(BabyBear::from_canonical_usize) + .collect(), + 1, + ); + + // Receiver (2 columns: Mul, Val): + // 1 5 + // 3 4 + // 4 4 + // 2 5 + // 0 123 + // 545 889 + // 1 889 + // 0 456 + let receiver_trace = ColMajorMatrix::new( + [1, 3, 4, 2, 0, 545, 1, 0] + .into_iter() + .map(BabyBear::from_canonical_usize) + .collect(), + 1, + ); + let receiver_cached_trace = ColMajorMatrix::new( + [5, 4, 4, 5, 123, 889, 889, 456] + .into_iter() + .map(BabyBear::from_canonical_usize) + .collect(), + 1, + ); + + let params = &self.params; + ProvingContextV2::new( + [ + (sender_trace, sender_cached_trace), + (receiver_trace, receiver_cached_trace), + ] + .map(|(common, cached)| { + let (commit, data) = stacked_commit( + params.l_skip, + params.n_stack, + params.log_blowup, + params.k_whir(), + &[&cached], + ); + assert_eq!(common.height(), cached.height()); + let cached_data = CommittedTraceDataV2 { + commitment: commit, + trace: cached, + data: Arc::new(data), + }; + AirProvingContextV2 { + cached_mains: vec![cached_data], + common_main: common, + public_values: vec![], + } + }) + .into_iter() + .enumerate() + .collect(), + ) + } +} + +#[derive(derive_new::new)] +pub struct PreprocessedFibFixture { + pub a: u32, + pub b: u32, + pub sels: Vec, +} + +impl TestFixture for PreprocessedFibFixture { + fn airs(&self) -> Vec> { + let air = Arc::new(FibonacciSelectorAir::new(self.sels.clone(), false)); + vec![air] + } + + fn generate_proving_ctx(&self) -> ProvingContextV2 { + use openvm_stark_sdk::dummy_airs::fib_selector_air::trace::generate_trace_rows; + let trace = generate_trace_rows(self.a, self.b, &self.sels); + let f_n = get_conditional_fib_number(self.a, self.b, &self.sels); + let pis = [self.a, self.b, f_n].map(BabyBear::from_canonical_u32); + + let single_ctx = + AirProvingContextV2::simple(ColMajorMatrix::from_row_major(&trace), pis.to_vec()); + ProvingContextV2::new(vec![(0, single_ctx)]) + } +} + +#[derive(derive_new::new)] +pub struct SelfInteractionFixture { + pub widths: Vec, + pub log_height: usize, + pub bus_index: BusIndex, +} + +impl TestFixture for SelfInteractionFixture { + fn airs(&self) -> Vec> { + self.widths + .iter() + .map(|&width| { + Arc::new(SelfInteractionAir { + width, + bus_index: self.bus_index, + }) as AirRef + }) + .collect_vec() + } + + fn generate_proving_ctx(&self) -> ProvingContextV2 { + let per_trace = self + .widths + .iter() + .map(|&width| { + let chip = SelfInteractionChip { + width, + log_height: self.log_height, + }; + chip.generate_proving_ctx(()) + }) + .enumerate() + .collect_vec(); + ProvingContextV2 { per_trace } + } +} + +pub struct MixtureFixture { + pub fxs: Vec, +} + +pub enum MixtureFixtureEnum { + FibFixture(FibFixture), + InteractionsFixture11(InteractionsFixture11), + CachedFixture11(CachedFixture11), + PreprocessedFibFixture(PreprocessedFibFixture), + SelfInteractionFixture(SelfInteractionFixture), +} + +impl MixtureFixtureEnum { + fn airs(&self) -> Vec> { + use crate::test_utils::MixtureFixtureEnum::*; + match self { + FibFixture(fx) => fx.airs(), + InteractionsFixture11(fx) => fx.airs(), + CachedFixture11(fx) => fx.airs(), + PreprocessedFibFixture(fx) => fx.airs(), + SelfInteractionFixture(fx) => fx.airs(), + } + } + + fn 
generate_air_proving_ctxs(&self) -> Vec> { + use crate::test_utils::MixtureFixtureEnum::*; + let ctx = match self { + FibFixture(fx) => fx.generate_proving_ctx(), + InteractionsFixture11(fx) => fx.generate_proving_ctx(), + CachedFixture11(fx) => fx.generate_proving_ctx(), + PreprocessedFibFixture(fx) => fx.generate_proving_ctx(), + SelfInteractionFixture(fx) => fx.generate_proving_ctx(), + }; + ctx.per_trace + .into_iter() + .map(|(_, air_ctx)| air_ctx) + .collect_vec() + } +} + +impl MixtureFixture { + pub fn new(fxs: Vec) -> Self { + Self { fxs } + } + + pub fn standard(log_height: usize, params: SystemParams) -> Self { + let height = 1usize << log_height; + let sels = (0..height).map(|i| i % 2 == 0).collect_vec(); + let widths = vec![4, 7, 8, 8, 10, 100]; + Self::new(vec![ + MixtureFixtureEnum::FibFixture(FibFixture::new(8, 8, height)), + MixtureFixtureEnum::InteractionsFixture11(InteractionsFixture11), + MixtureFixtureEnum::CachedFixture11(CachedFixture11::new(params)), + MixtureFixtureEnum::PreprocessedFibFixture(PreprocessedFibFixture::new(7, 3, sels)), + MixtureFixtureEnum::SelfInteractionFixture(SelfInteractionFixture::new( + widths, log_height, 5, + )), + ]) + } +} + +impl TestFixture for MixtureFixture { + fn airs(&self) -> Vec> { + self.fxs.iter().flat_map(|fx| fx.airs()).collect_vec() + } + + fn generate_proving_ctx(&self) -> ProvingContextV2 { + let per_trace = self + .fxs + .iter() + .flat_map(|fx| fx.generate_air_proving_ctxs()) + .enumerate() + .collect_vec(); + ProvingContextV2 { per_trace } + } +} + +impl SystemParams { + /// Parameters for testing traces of height up to `2^log_trace_height` with **toy security + /// parameters** for faster testing. + /// + /// **These parameters should not be used in production!** + pub fn new_for_testing(log_trace_height: usize) -> Self { + let l_skip = 4; + let k_whir = 4; + let mut params = test_system_params_small(l_skip, log_trace_height - l_skip, k_whir); + params.max_constraint_degree = 4; + params + } +} + +/// Trace heights cannot exceed 2^{l_skip + n_stack} when using these system params. 
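+///
+/// For example, `test_system_params_small(2, 8, 3)` (the defaults below) supports traces up to
+/// height `2^10` and picks `log_final_poly_len = (8 + 2) % 3 = 1`.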
+pub fn test_system_params_small(l_skip: usize, n_stack: usize, k_whir: usize) -> SystemParams { + let log_final_poly_len = (n_stack + l_skip) % k_whir; + test_system_params_small_with_poly_len(l_skip, n_stack, k_whir, log_final_poly_len, 3) +} + +pub fn test_system_params_small_with_poly_len( + l_skip: usize, + n_stack: usize, + k_whir: usize, + log_final_poly_len: usize, + max_constraint_degree: usize, +) -> SystemParams { + assert!(log_final_poly_len < l_skip + n_stack); + let log_blowup = 1; + // Use all different numbers + SystemParams { + l_skip, + n_stack, + log_blowup, + whir: test_whir_config_small(log_blowup, l_skip + n_stack, k_whir, log_final_poly_len), + logup: log_up_security_params_baby_bear_100_bits(), + max_constraint_degree, + } +} + +pub fn test_whir_config_small( + log_blowup: usize, + log_stacked_height: usize, + k_whir: usize, + log_final_poly_len: usize, +) -> WhirConfig { + let params = WhirParams { + k: k_whir, + log_final_poly_len, + query_phase_pow_bits: 1, + }; + let security_bits = 5; + WhirConfig::new(log_blowup, log_stacked_height, params, security_bits) +} + +pub fn default_test_params_small() -> SystemParams { + test_system_params_small(2, 8, 3) +} + +pub fn test_engine_small() -> BabyBearPoseidon2CpuEngineV2 { + setup_tracing(); + BabyBearPoseidon2CpuEngineV2::new(default_test_params_small()) +} + +#[derive(Clone)] +pub struct DuplexSpongeValidator { + pub inner: DuplexSpongeRecorder, + pub idx: usize, + log: TranscriptLog, +} + +impl DuplexSpongeValidator { + pub fn new(log: TranscriptLog) -> Self { + debug_assert_eq!(log.len(), log.samples().len()); + Self { + inner: Default::default(), + idx: 0, + log, + } + } +} + +impl FiatShamirTranscript for DuplexSpongeValidator { + fn observe(&mut self, x: F) { + debug_assert!(self.idx < self.log.len(), "transcript replay overflow"); + assert!(!self.log.samples()[self.idx]); + let exp_x = self.log[self.idx]; + assert_eq!(x, exp_x); + self.idx += 1; + self.inner.observe(x); + } + + fn sample(&mut self) -> F { + debug_assert!(self.idx < self.log.len(), "transcript replay overflow"); + assert!(self.log.samples()[self.idx]); + let x = self.inner.sample(); + let exp_x = self.log[self.idx]; + self.idx += 1; + assert_eq!(x, exp_x); + x + } +} + +impl TranscriptHistory for DuplexSpongeValidator { + fn len(&self) -> usize { + self.inner.len() + } + + fn into_log(self) -> TranscriptLog { + debug_assert_eq!(self.inner.len(), self.log.len()); + debug_assert_eq!( + self.inner.len(), + self.idx, + "transcript replay ended with {} of {} entries consumed", + self.idx, + self.inner.len() + ); + debug_assert_eq!( + self.log.len(), + self.idx, + "transcript replay ended with {} of {} entries consumed", + self.idx, + self.log.len() + ); + self.inner.into_log() + } +} diff --git a/crates/stark-backend-v2/src/tests.rs b/crates/stark-backend-v2/src/tests.rs new file mode 100644 index 00000000..5f2a1c15 --- /dev/null +++ b/crates/stark-backend-v2/src/tests.rs @@ -0,0 +1,469 @@ +use itertools::Itertools; +use openvm_stark_backend::prover::MatrixDimensions; +use openvm_stark_sdk::config::{ + log_up_params::log_up_security_params_baby_bear_100_bits, setup_tracing, + setup_tracing_with_log_level, +}; +use p3_field::{FieldAlgebra, PrimeField32, TwoAdicField}; +use p3_matrix::dense::RowMajorMatrix; +use p3_util::log2_strict_usize; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use test_case::test_case; +use tracing::{debug, Level}; + +use crate::{ + poseidon2::sponge::{DuplexSponge, FiatShamirTranscript}, + prover::{ + 
stacked_pcs::stacked_commit, + stacked_reduction::{prove_stacked_opening_reduction, StackedReductionCpu}, + sumcheck::{sumcheck_multilinear, sumcheck_prismalinear}, + AirProvingContextV2, ColMajorMatrix, DeviceDataTransporterV2, MultiRapProver, + ProvingContextV2, + }, + test_utils::{ + prove_up_to_batch_constraints, test_engine_small, test_system_params_small, + CachedFixture11, FibFixture, InteractionsFixture11, MixtureFixture, PreprocessedFibFixture, + SelfInteractionFixture, TestFixture, + }, + verifier::{ + batch_constraints::{verify_zerocheck_and_logup, BatchConstraintError}, + fractional_sumcheck_gkr::verify_gkr, + proof_shape::{verify_proof_shape, ProofShapeError}, + stacked_reduction::{verify_stacked_reduction, StackedReductionError}, + sumcheck::{verify_sumcheck_multilinear, verify_sumcheck_prismalinear}, + }, + BabyBearPoseidon2CpuEngineV2, StarkEngineV2, SystemParams, WhirConfig, WhirRoundConfig, F, +}; + +#[test] +fn test_plain_multilinear_sumcheck() -> Result<(), String> { + let n = 15; + let mut rng = StdRng::from_seed([228; 32]); + + let num_pts = 1 << n; + assert!((F::ORDER_U32 - 1) % num_pts == 0); + + let evals = (0..num_pts) + .map(|_| F::from_canonical_u32(rng.random_range(0..F::ORDER_U32))) + .collect::>(); + let mut prover_sponge = DuplexSponge::default(); + let mut verifier_sponge = DuplexSponge::default(); + + let (proof, _) = sumcheck_multilinear(&mut prover_sponge, &evals); + verify_sumcheck_multilinear::(&mut verifier_sponge, &proof) +} + +#[test] +fn test_plain_prismalinear_sumcheck() -> Result<(), String> { + setup_tracing(); + let n = 5; + let l_skip = 10; + let mut rng = StdRng::from_seed([228; 32]); + + let dim = n + l_skip; + let num_pts = 1 << dim; + assert!((F::ORDER_U32 - 1) % num_pts == 0); + + let evals = (0..num_pts) + .map(|_| F::from_canonical_u32(rng.random_range(0..F::ORDER_U32))) + .collect::>(); + let mut prover_sponge = DuplexSponge::default(); + let mut verifier_sponge = DuplexSponge::default(); + + let (proof, _) = sumcheck_prismalinear(&mut prover_sponge, l_skip, &evals); + verify_sumcheck_prismalinear::(&mut verifier_sponge, l_skip, &proof) +} + +#[test] +fn test_proof_shape_verifier() -> Result<(), ProofShapeError> { + setup_tracing(); + let log_trace_degree = 3; + + // without interactions + let engine = test_engine_small(); + let (vk, proof) = FibFixture::new(0, 1, 1 << log_trace_degree).keygen_and_prove(&engine); + verify_proof_shape(&vk.inner, &proof)?; + + // with interactions + let (vk, proof) = InteractionsFixture11.keygen_and_prove(&engine); + verify_proof_shape(&vk.inner, &proof)?; + + // with cached trace + let params = engine.config().clone(); + let (vk, proof) = CachedFixture11::new(params).keygen_and_prove(&engine); + verify_proof_shape(&vk.inner, &proof)?; + + // with preprocessed trace + let height = 1 << log_trace_degree; + let sels = (0..height).map(|i| i % 2 == 0).collect_vec(); + let (vk, proof) = PreprocessedFibFixture::new(0, 1, sels).keygen_and_prove(&engine); + verify_proof_shape(&vk.inner, &proof)?; + + Ok(()) +} + +#[test] +fn test_proof_shape_verifier_rng_system_params() -> Result<(), ProofShapeError> { + setup_tracing(); + let mut rng = StdRng::from_seed([228; 32]); + for _ in 0..10 { + let l_skip = rng.random_range(1usize..=2); + let n_stack = rng.random_range(8usize..=9); + let k_whir = rng.random_range(1usize..=4); + let log_blowup = rng.random_range(1usize..=3); + let num_whir_rounds = rng.random_range(1..=2); + let mut rounds = Vec::with_capacity(num_whir_rounds); + for _ in 0..num_whir_rounds { + 
rounds.push(WhirRoundConfig { + num_queries: rng.random_range(1..=10), + }); + } + let whir = WhirConfig { + k: k_whir, + rounds, + query_phase_pow_bits: 1, + folding_pow_bits: 1, + }; + let params = SystemParams { + l_skip, + n_stack, + log_blowup, + whir, + logup: log_up_security_params_baby_bear_100_bits(), + max_constraint_degree: 3, + }; + let engine = BabyBearPoseidon2CpuEngineV2::::new(params); + let (vk, proof) = InteractionsFixture11.keygen_and_prove(&engine); + verify_proof_shape(&vk.inner, &proof)?; + } + Ok(()) +} + +#[test_case(4)] +#[test_case(2 ; "when log_height equals l_skip")] +#[test_case(1 ; "when log_height less than l_skip")] +#[test_case(0 ; "when log_height is zero")] +fn test_batch_sumcheck_zero_interactions( + log_trace_degree: usize, +) -> Result<(), BatchConstraintError> { + setup_tracing_with_log_level(Level::DEBUG); + + let engine = test_engine_small(); + let params = engine.config(); + let fib = FibFixture::new(0, 1, 1 << log_trace_degree); + let (pk, vk) = fib.keygen(&engine); + let pk = engine.device().transport_pk_to_device(&pk); + let ctx = fib.generate_proving_ctx(); + + let mut n_per_air: Vec = Vec::with_capacity(ctx.per_trace.len()); + for (_, trace) in ctx.common_main_traces() { + let trace_height = trace.height(); + let prism_dim = log2_strict_usize(trace_height); + let n = prism_dim as isize - params.l_skip as isize; + n_per_air.push(n); + } + + let mut prover_sponge = DuplexSponge::default(); + let mut verifier_sponge = DuplexSponge::default(); + + let omega_skip = F::two_adic_generator(params.l_skip); + let omega_skip_pows = omega_skip.powers().take(1 << params.l_skip).collect_vec(); + + let pvs = vec![ctx.per_trace[0].1.public_values.clone()]; + let ((gkr_proof, batch_proof), _) = + prove_up_to_batch_constraints(&engine, &mut prover_sponge, &pk, ctx); + + let r = verify_zerocheck_and_logup( + &mut verifier_sponge, + &vk.inner, + &pvs, + &gkr_proof, + &batch_proof, + &[0], + &n_per_air, + &omega_skip_pows, + )?; + assert_eq!(r.len(), log_trace_degree.saturating_sub(params.l_skip) + 1); + Ok(()) +} + +#[test_case(9)] +#[test_case(2 ; "when log_height equals l_skip")] +#[test_case(1 ; "when log_height less than l_skip")] +#[test_case(0 ; "when log_height is zero")] +fn test_stacked_opening_reduction(log_trace_degree: usize) -> Result<(), StackedReductionError> { + setup_tracing_with_log_level(Level::DEBUG); + + let engine = test_engine_small(); + let params = engine.config(); + + let engine = BabyBearPoseidon2CpuEngineV2::::new(params.clone()); + let fib = FibFixture::new(0, 1, 1 << log_trace_degree); + let (pk, _vk) = fib.keygen(&engine); + let pk = engine.device().transport_pk_to_device(&pk); + let mut ctx = fib.generate_proving_ctx(); + + ctx.sort_for_stacking(); + + let (_, common_main_pcs_data) = { + stacked_commit( + params.l_skip, + params.n_stack, + params.log_blowup, + params.k_whir(), + &ctx.common_main_traces() + .map(|(_, trace)| trace) + .collect_vec(), + ) + }; + + let omega_skip = F::two_adic_generator(params.l_skip); + let omega_skip_pows = omega_skip.powers().take(1 << params.l_skip).collect_vec(); + + let device = engine.device(); + // We need batch_proof to obtain the column openings + let ((_, batch_proof), r) = device.prove_rap_constraints( + &mut DuplexSponge::default(), + &pk, + &ctx, + &common_main_pcs_data, + ); + + let need_rot = pk.per_air[ctx.per_trace[0].0].vk.params.need_rot; + let need_rot_per_commit = vec![vec![need_rot]]; + let (stacking_proof, _) = prove_stacked_opening_reduction::<_, _, _, StackedReductionCpu>( + 
device, + &mut DuplexSponge::default(), + params.n_stack, + vec![&common_main_pcs_data], + need_rot_per_commit.clone(), + &r, + ); + + debug!(?batch_proof.column_openings); + + let u_prism = verify_stacked_reduction( + &mut DuplexSponge::default(), + &stacking_proof, + &[common_main_pcs_data.layout], + &need_rot_per_commit, + params.l_skip, + params.n_stack, + &batch_proof.column_openings, + &r, + &omega_skip_pows, + )?; + assert_eq!(u_prism.len(), params.n_stack + 1); + Ok(()) +} +#[test_case(3)] +#[test_case(2 ; "when fib log_height equals l_skip")] +#[test_case(1 ; "when fib log_height less than l_skip")] +#[test_case(0 ; "when fib log_height is zero")] +fn test_single_fib_and_dummy_trace_stark(log_trace_degree: usize) { + setup_tracing(); + + // 1. Create parameters + let engine = test_engine_small(); + + // 2. Create interactions fixture with larger trace - generate custom traces + let sender_height = 2 * (1 << 3); + let sender_trace = RowMajorMatrix::new( + [0, 1, 3, 5, 7, 4, 546, 889] + .into_iter() + .cycle() + .take(2 * sender_height) + .map(F::from_canonical_usize) + .collect(), + 2, + ); + let receiver_trace = RowMajorMatrix::new( + [1, 5, 3, 4, 4, 4, 2, 5, 0, 123, 545, 889, 1, 889, 0, 456] + .into_iter() + .cycle() + .take(4 * sender_height) + .map(F::from_canonical_usize) + .collect(), + 2, + ); + + // 3. Create fibonacci fixture with small trace + let height = 2 * (1 << log_trace_degree); + let fib = FibFixture::new(0, 1, height); + + // 4. Generate AIRs and proving keys + let fx_fixture = InteractionsFixture11; + let fx_airs = fx_fixture.airs(); + let fib_airs = fib.airs(); + let mut combined_airs = fx_airs; + combined_airs.extend(fib_airs); + let (combined_pk, _combined_vk) = engine.keygen(&combined_airs); + let combined_pk = engine.device().transport_pk_to_device(&combined_pk); + + // 5. Generate custom contexts for interactions with modified traces + let mut per_trace: Vec<_> = [sender_trace, receiver_trace] + .into_iter() + .enumerate() + .map(|(air_idx, trace)| { + ( + air_idx, + AirProvingContextV2::simple_no_pis(ColMajorMatrix::from_row_major(&trace)), + ) + }) + .collect(); + let fib_ctx = fib.generate_proving_ctx().per_trace.pop().unwrap().1; + + // 6. 
Update air_ids in fib context and combine contexts + per_trace.push((per_trace.len(), fib_ctx)); + let combined_ctx = ProvingContextV2::new(per_trace).into_sorted(); + + let proof = engine.prove(&combined_pk, combined_ctx); + engine.verify(&combined_pk.get_vk(), &proof).unwrap(); +} + +#[test] +fn test_interactions_single_sender_receiver_happy() { + setup_tracing(); + + let engine = test_engine_small(); + let fx = InteractionsFixture11; + let (vk, proof) = fx.keygen_and_prove(&engine); + engine.verify(&vk, &proof).unwrap(); +} + +#[test] +fn test_single_cached_trace_stark() { + setup_tracing(); + let engine = test_engine_small(); + let fx = CachedFixture11::new(engine.config().clone()); + let (vk, proof) = fx.keygen_and_prove(&engine); + engine.verify(&vk, &proof).unwrap(); +} + +#[test_case(10 ; "when log_height equals n_stack l_skip")] +#[test_case(3 ; "when log_height greater than l_skip")] +#[test_case(2 ; "when log_height equals l_skip")] +#[test_case(1 ; "when log_height less than l_skip")] +#[test_case(0 ; "when log_height is zero")] +fn test_single_preprocessed_trace_stark(log_trace_degree: usize) { + setup_tracing(); + let engine = test_engine_small(); + let height = 1 << log_trace_degree; + let sels = (0..height).map(|i| i % 2 == 0).collect_vec(); + let fx = PreprocessedFibFixture::new(0, 1, sels); + let (vk, proof) = fx.keygen_and_prove(&engine); + engine.verify(&vk, &proof).unwrap(); +} + +#[test_case(10 ; "when log_height equals n_stack l_skip")] +#[test_case(3 ; "when log_height greater than l_skip")] +#[test_case(2 ; "when log_height equals l_skip")] +#[test_case(1 ; "when log_height less than l_skip")] +#[test_case(0 ; "when log_height is zero")] +fn test_multi_interaction_traces_stark(log_trace_degree: usize) { + setup_tracing(); + let engine = test_engine_small(); + let fx = SelfInteractionFixture { + widths: vec![4, 7, 8, 8, 10, 100], + log_height: log_trace_degree, + bus_index: 4, + }; + let (vk, proof) = fx.keygen_and_prove(&engine); + engine.verify(&vk, &proof).unwrap(); +} + +#[test_case(10 ; "when log_height equals n_stack l_skip")] +#[test_case(3 ; "when log_height greater than l_skip")] +#[test_case(2 ; "when log_height equals l_skip")] +#[test_case(1 ; "when log_height less than l_skip")] +#[test_case(0 ; "when log_height is zero")] +fn test_mixture_traces_stark(log_trace_degree: usize) { + setup_tracing(); + let engine = test_engine_small(); + let fx = MixtureFixture::standard(log_trace_degree, engine.config().clone()); + let (vk, proof) = fx.keygen_and_prove(&engine); + engine.verify(&vk, &proof).unwrap(); +} + +#[test] +fn test_gkr_verify_zero_interactions() -> eyre::Result<()> { + setup_tracing_with_log_level(Level::DEBUG); + + let engine = test_engine_small(); + let params = engine.config(); + let fx = InteractionsFixture11; + let (pk, _vk) = fx.keygen(&engine); + let pk = engine.device().transport_pk_to_device(&pk); + let ctx = fx.generate_proving_ctx().into_sorted(); + let mut transcript = DuplexSponge::default(); + let ((gkr_proof, _), _) = prove_up_to_batch_constraints(&engine, &mut transcript, &pk, ctx); + + let mut transcript = DuplexSponge::default(); + assert!(transcript.check_witness(params.logup.pow_bits, gkr_proof.logup_pow_witness)); + let _alpha = transcript.sample_ext(); + let _beta = transcript.sample_ext(); + let total_rounds = gkr_proof.claims_per_layer.len(); + verify_gkr(&gkr_proof, &mut transcript, total_rounds)?; + + Ok(()) +} + +#[test] +fn test_batch_constraints_with_interactions() -> eyre::Result<()> { + 
setup_tracing_with_log_level(Level::DEBUG); + + let engine = test_engine_small(); + let fx = InteractionsFixture11; + let (pk, vk) = fx.keygen(&engine); + let pk = engine.device().transport_pk_to_device(&pk); + let ctx = fx.generate_proving_ctx().into_sorted(); + let l_skip = engine.device().config().l_skip; + let mut pvs = vec![vec![]; vk.inner.per_air.len()]; + let (trace_id_to_air_ids, ns): (Vec<_>, Vec<_>) = ctx + .per_trace + .iter() + .map(|(air_idx, air_ctx)| { + pvs[*air_idx] = air_ctx.public_values.clone(); + ( + *air_idx, + log2_strict_usize(air_ctx.common_main.height()) as isize - l_skip as isize, + ) + }) + .multiunzip(); + debug!(?trace_id_to_air_ids); + debug!(n_per_trace = ?ns); + let omega_pows = F::two_adic_generator(l_skip) + .powers() + .take(1 << l_skip) + .collect_vec(); + + let mut transcript = DuplexSponge::default(); + let ((gkr_proof, batch_proof), _) = + prove_up_to_batch_constraints(&engine, &mut transcript, &pk, ctx); + let mut transcript = DuplexSponge::default(); + verify_zerocheck_and_logup( + &mut transcript, + &vk.inner, + &pvs, + &gkr_proof, + &batch_proof, + &trace_id_to_air_ids, + &ns, + &omega_pows, + )?; + Ok(()) +} + +#[test] +fn test_matrix_stacking_overflow() { + setup_tracing(); + let params = test_system_params_small(3, 5, 3); + let engine = BabyBearPoseidon2CpuEngineV2::new(params); + let fx = SelfInteractionFixture { + widths: vec![4, 7, 8, 8, 10], + log_height: 1, + bus_index: 4, + }; + let (vk, proof) = fx.keygen_and_prove(&engine); + engine.verify(&vk, &proof).unwrap(); +} diff --git a/crates/stark-backend-v2/src/utils/batch_inverse.rs b/crates/stark-backend-v2/src/utils/batch_inverse.rs new file mode 100644 index 00000000..d833d6b1 --- /dev/null +++ b/crates/stark-backend-v2/src/utils/batch_inverse.rs @@ -0,0 +1,77 @@ +//! Copied from p3-field [src/batch_inverse.rs] to remove use of rayon + +use p3_field::{Field, FieldAlgebra, FieldArray, PackedValue}; +use tracing::instrument; + +/// Batch multiplicative inverses with Montgomery's trick. At a high level, we invert the product +/// of the given field elements, then derive the individual inverses from that via multiplication. +/// +/// The usual Montgomery trick involves calculating an array of cumulative products, +/// resulting in a long dependency chain. To increase instruction-level parallelism, we +/// compute WIDTH separate cumulative product arrays that only meet at the end. +/// +/// # Panics +/// This will panic if any of the inputs is zero. +#[instrument(level = "debug", skip_all)] +pub fn batch_multiplicative_inverse_serial<F: Field>(x: &[F]) -> Vec<F> { + let n = x.len(); + let mut result = F::zero_vec(n); + + batch_multiplicative_inverse_helper(x, &mut result); + + result +} + +/// Like `batch_multiplicative_inverse`, but writes the result to the given output buffer. +fn batch_multiplicative_inverse_helper<F: Field>(x: &[F], result: &mut [F]) { + // Higher WIDTH increases instruction-level parallelism, but too high a value will cause us + // to run out of registers. + const WIDTH: usize = 4; + + let n = x.len(); + assert_eq!(result.len(), n); + if n % WIDTH != 0 { + // There isn't a very clean way to do this with FieldArray; for now just do it in serial. + // Another simple (though suboptimal) workaround would be to make two separate calls, one + // for the packed part and one for the remainder.
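+ // As a concrete instance of the trick: for x = [a, b, c] the forward pass stores the + // partial products [1, a, a*b]; a single inversion of a*b*c then recovers each x[i]^{-1} + // with two multiplications per element on the backward pass.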
+ return batch_multiplicative_inverse_general(x, result, |x| x.inverse()); + } + + let x_packed = FieldArray::<F, WIDTH>::pack_slice(x); + let result_packed = FieldArray::<F, WIDTH>::pack_slice_mut(result); + + let inv = |x_packed: FieldArray<F, WIDTH>| { + let mut result = FieldArray::<F, WIDTH>::default(); + batch_multiplicative_inverse_general(&x_packed.0, &mut result.0, |x| x.inverse()); + result + }; + batch_multiplicative_inverse_general(x_packed, result_packed, inv); +} + +/// A simple single-threaded implementation of Montgomery's trick. Since not all `FieldAlgebra`s +/// support inversion, this takes a custom inversion function. +pub(crate) fn batch_multiplicative_inverse_general<F, Inv>(x: &[F], result: &mut [F], inv: Inv) +where + F: FieldAlgebra + Copy, + Inv: Fn(F) -> F, +{ + let n = x.len(); + assert_eq!(result.len(), n); + if n == 0 { + return; + } + + result[0] = F::ONE; + for i in 1..n { + result[i] = result[i - 1] * x[i - 1]; + } + + let product = result[n - 1] * x[n - 1]; + let mut inv = inv(product); + + for i in (0..n).rev() { + result[i] *= inv; + inv *= x[i]; + } +} diff --git a/crates/stark-backend-v2/src/utils/mod.rs b/crates/stark-backend-v2/src/utils/mod.rs new file mode 100644 index 00000000..3e6a269c --- /dev/null +++ b/crates/stark-backend-v2/src/utils/mod.rs @@ -0,0 +1,3 @@ +mod batch_inverse; + +pub use batch_inverse::*; diff --git a/crates/stark-backend-v2/src/v1_shims.rs b/crates/stark-backend-v2/src/v1_shims.rs new file mode 100644 index 00000000..6764d9b7 --- /dev/null +++ b/crates/stark-backend-v2/src/v1_shims.rs @@ -0,0 +1,129 @@ +use std::sync::Arc; + +use openvm_stark_backend::{ + prover::{ + cpu::CpuBackend, + types::{AirProvingContext, ProvingContext}, + ProverBackend, + }, + Chip, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config; +use p3_matrix::dense::RowMajorMatrix; + +use crate::{ + prover::{ + stacked_pcs::stacked_commit, AirProvingContextV2, ColMajorMatrix, CommittedTraceDataV2, + CpuBackendV2, ProverBackendV2, ProvingContextV2, + }, + ChipV2, SystemParams, F, +}; + +type SC = BabyBearPoseidon2Config; + +pub trait V1Compat: ProverBackendV2 + Sized { + type V1: ProverBackend<Val = <Self as ProverBackendV2>::Val>; + + fn dummy_matrix() -> Self::Matrix; + fn convert_trace(matrix: <Self::V1 as ProverBackend>::Matrix) -> Self::Matrix; + + fn convert_committed_trace( + params: &SystemParams, + matrix: <Self::V1 as ProverBackend>::Matrix, + ) -> CommittedTraceDataV2<Self>; +} + +impl<PB: V1Compat> ProvingContextV2<PB> { + pub fn from_v1(params: &SystemParams, ctx: ProvingContext<PB::V1>) -> Self { + let per_trace = ctx + .per_air + .into_iter() + .map(|(air_idx, air_ctx)| (air_idx, AirProvingContextV2::from_v1(params, air_ctx))) + .filter(|(_, air_ctx)| air_ctx.height() > 0) + .collect(); + Self::new(per_trace) + } + + pub fn from_v1_no_cached(ctx: ProvingContext<PB::V1>) -> Self { + let per_trace = ctx + .per_air + .into_iter() + .map(|(air_idx, air_ctx)| (air_idx, AirProvingContextV2::from_v1_no_cached(air_ctx))) + .collect(); + Self::new(per_trace) + } +} + +impl<PB: V1Compat> AirProvingContextV2<PB> { + pub fn from_v1(params: &SystemParams, ctx: AirProvingContext<PB::V1>) -> Self { + let common_main = ctx + .common_main + .map(<PB as V1Compat>::convert_trace) + .unwrap_or_else(|| <PB as V1Compat>::dummy_matrix()); + let cached_mains = ctx + .cached_mains + .into_iter() + .map(|d| <PB as V1Compat>::convert_committed_trace(params, d.trace)) + .collect(); + Self { + cached_mains, + common_main, + public_values: ctx.public_values, + } + } + + pub fn from_v1_no_cached(ctx: AirProvingContext<PB::V1>) -> Self { + assert!(ctx.cached_mains.is_empty()); + let common_main = ctx + .common_main + .map(<PB as V1Compat>::convert_trace) + .unwrap_or_else(|| <PB as V1Compat>::dummy_matrix()); + Self {
cached_mains: vec![], + common_main, + public_values: ctx.public_values, + } + } +} + +impl<R, C> ChipV2<R, CpuBackendV2> for C +where + C: Chip<R, CpuBackend<SC>>, +{ + fn generate_proving_ctx(&self, records: R) -> AirProvingContextV2<CpuBackendV2> { + let v1_ctx = self.generate_proving_ctx(records); + AirProvingContextV2::from_v1_no_cached(v1_ctx) + } +} + +impl V1Compat for CpuBackendV2 { + type V1 = CpuBackend<SC>; + + fn dummy_matrix() -> Self::Matrix { + ColMajorMatrix::dummy() + } + + fn convert_trace(matrix: Arc<RowMajorMatrix<F>>) -> Self::Matrix { + ColMajorMatrix::from_row_major(&matrix) + } + + fn convert_committed_trace( + params: &SystemParams, + matrix: Arc<RowMajorMatrix<F>>, + ) -> CommittedTraceDataV2<Self> { + let trace = ColMajorMatrix::from_row_major(&matrix); + let (commitment, data) = stacked_commit( + params.l_skip, + params.n_stack, + params.log_blowup, + params.k_whir(), + &[&trace], + ); + + CommittedTraceDataV2 { + commitment, + trace, + data: Arc::new(data), + } + } +} diff --git a/crates/stark-backend-v2/src/verifier/batch_constraints.rs b/crates/stark-backend-v2/src/verifier/batch_constraints.rs new file mode 100644 index 00000000..f61fb5c0 --- /dev/null +++ b/crates/stark-backend-v2/src/verifier/batch_constraints.rs @@ -0,0 +1,374 @@ +use std::{ + iter::{self, zip}, + slice, +}; + +use itertools::Itertools; +use openvm_stark_backend::air_builders::symbolic::{ + symbolic_expression::SymbolicEvaluator, SymbolicConstraints, +}; +use p3_field::{batch_multiplicative_inverse, Field, FieldAlgebra}; +use thiserror::Error; +use tracing::{debug, instrument}; + +use crate::{ + calculate_n_logup, + keygen::types::MultiStarkVerifyingKey0V2, + poly_common::{eval_eq_mle, eval_eq_sharp_uni, eval_eq_uni, UnivariatePoly}, + poseidon2::sponge::FiatShamirTranscript, + proof::{column_openings_by_rot, BatchConstraintProof, GkrProof}, + verifier::{ + evaluator::VerifierConstraintEvaluator, + fractional_sumcheck_gkr::{verify_gkr, GkrVerificationError}, + }, + EF, F, +}; + +#[derive(Error, Debug, PartialEq, Eq)] +pub enum BatchConstraintError { + #[error("Invalid logup_pow_witness")] + InvalidLogupPowWitness, + + #[error("GKR verification failed: {0}")] + GkrVerificationFailed(#[from] GkrVerificationError), + + #[error("GKR numerator evaluation claim {claim} does not match")] + GkrNumeratorMismatch { claim: EF }, + + #[error("GKR denominator evaluation claim {claim} does not match")] + GkrDenominatorMismatch { claim: EF }, + + #[error( + "`sum_claim` does not equal the sum of `s_0` at all the roots of unity: {sum_claim} != {sum_univ_domain_s_0}" + )] + SumClaimMismatch { + sum_claim: EF, + sum_univ_domain_s_0: EF, + }, + + #[error("Claims are inconsistent")] + InconsistentClaims, +} + +/// `public_values` should be in vkey (air_idx) order, including non-present AIRs. +#[allow(clippy::too_many_arguments)] +#[instrument(level = "debug", skip_all)] +pub fn verify_zerocheck_and_logup<TS: FiatShamirTranscript>( + transcript: &mut TS, + mvk: &MultiStarkVerifyingKey0V2, + public_values: &[Vec<F>], + gkr_proof: &GkrProof, + batch_proof: &BatchConstraintProof, + trace_id_to_air_id: &[usize], + n_per_trace: &[isize], + omega_skip_pows: &[F], +) -> Result<Vec<EF>, BatchConstraintError> { + let l_skip = mvk.params.l_skip; + // let num_airs_present = mvk.per_air.len(); + let BatchConstraintProof { + numerator_term_per_air, + denominator_term_per_air, + univariate_round_coeffs, + sumcheck_round_polys, + column_openings, + } = batch_proof; + + // 1. Check GKR witness + if !transcript.check_witness(mvk.params.logup.pow_bits, gkr_proof.logup_pow_witness) { + return Err(BatchConstraintError::InvalidLogupPowWitness); + }
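+ // A valid witness attests to `logup.pow_bits` bits of grinding before the LogUp challenges + // below are sampled, adding the same number of bits to the soundness budget.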
+ + // 2. Sample alpha and beta, receive xi, sample lambda + let alpha_logup = transcript.sample_ext(); + let beta_logup = transcript.sample_ext(); + debug!(%alpha_logup, %beta_logup); + let total_interactions = zip(trace_id_to_air_id, n_per_trace) + .map(|(&air_idx, &n)| { + let n_lift = n.max(0) as usize; + let num_interactions = mvk.per_air[air_idx].symbolic_constraints.interactions.len(); + (num_interactions as u64) << (l_skip + n_lift) + }) + .sum::<u64>(); + let n_logup: usize = calculate_n_logup(l_skip, total_interactions); + debug!(%n_logup); + + let mut xi = Vec::new(); + let mut p_xi_claim = EF::ZERO; + let mut q_xi_claim = alpha_logup; + if total_interactions > 0 { + (p_xi_claim, q_xi_claim, xi) = verify_gkr(gkr_proof, transcript, l_skip + n_logup)?; + debug_assert_eq!(xi.len(), l_skip + n_logup); + } + + let n_max = n_per_trace.iter().copied().max().unwrap().max(0) as usize; + let n_global = n_max.max(n_logup); + while xi.len() != l_skip + n_global { + xi.push(transcript.sample_ext()); + } + debug!(%n_max); + debug!(?xi); + + let lambda = transcript.sample_ext(); + debug!(%lambda); + + // 3. Observe all terms from numerator_per_air and denominator_per_air and subtract them from the GKR claims + for (&sum_claim_p, &sum_claim_q) in zip(numerator_term_per_air, denominator_term_per_air) { + p_xi_claim -= sum_claim_p; + q_xi_claim -= sum_claim_q; + transcript.observe_ext(sum_claim_p); + transcript.observe_ext(sum_claim_q); + } + if p_xi_claim != EF::ZERO { + return Err(BatchConstraintError::GkrNumeratorMismatch { claim: p_xi_claim }); + } + if q_xi_claim != alpha_logup { + return Err(BatchConstraintError::GkrDenominatorMismatch { claim: q_xi_claim }); + } + + // 4. Sample mu and compute the mu-weighted hash of the interleaving of numerator_per_air and denominator_per_air + let mu = transcript.sample_ext(); + debug!(%mu); + + let mut sum_claim = EF::ZERO; + let mut cur_mu_pow = EF::ONE; + for (&sum_claim_p, &sum_claim_q) in zip(numerator_term_per_air, denominator_term_per_air) { + sum_claim += sum_claim_p * cur_mu_pow; + cur_mu_pow *= mu; + sum_claim += sum_claim_q * cur_mu_pow; + cur_mu_pow *= mu; + } + + // 5. Univariate sumcheck round + for &coeff in univariate_round_coeffs { + transcript.observe_ext(coeff); + } + + let s_deg = mvk.params.max_constraint_degree + 1; + let r_0 = transcript.sample_ext(); + debug!(round = 0, r_round = %r_0); + assert_eq!( + univariate_round_coeffs.len(), + (mvk.max_constraint_degree() + 1) * ((1 << l_skip) - 1) + 1 + ); + let s_0 = UnivariatePoly::new(univariate_round_coeffs.clone()); + let sum_univ_domain_s_0 = s_0 + .coeffs() + .iter() + .step_by(1 << l_skip) + .copied() + .sum::<EF>() + * EF::from_canonical_usize(1 << l_skip); + if sum_claim != sum_univ_domain_s_0 { + return Err(BatchConstraintError::SumClaimMismatch { + sum_claim, + sum_univ_domain_s_0, + }); + } + let mut cur_sum = s_0.eval_at_point(r_0); + let mut rs = vec![r_0];
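+ + // Only coefficients of s_0 whose exponent is a multiple of 2^l_skip survive summation over + // the size-2^l_skip subgroup, each scaled by 2^l_skip; hence the `step_by` above.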
+ + // 6. Multilinear sumcheck rounds + #[allow(clippy::needless_range_loop)] + for round in 0..n_max { + debug!(sumcheck_round = round, sum_claim = %cur_sum, "batch_constraint_sumcheck"); + let batch_s_evals = &sumcheck_round_polys[round]; + for &eval in batch_s_evals.iter() { + transcript.observe_ext(eval); + } + let s_1 = batch_s_evals[0]; + let s_0 = cur_sum - s_1; + let batch_s_evals = iter::once(&s_0).chain(batch_s_evals).collect_vec(); + + let mut factorials = vec![F::ONE; s_deg + 1]; + for i in 1..=s_deg { + factorials[i] = factorials[i - 1] * F::from_canonical_usize(i); + } + let invfact = batch_multiplicative_inverse(&factorials); + + let r = transcript.sample_ext(); + let mut pref_product = vec![EF::ONE; s_deg + 1]; + let mut suf_product = vec![EF::ONE; s_deg + 1]; + for i in 0..s_deg { + pref_product[i + 1] = pref_product[i] * (r - EF::from_canonical_usize(i)); + suf_product[i + 1] = suf_product[i] * (EF::from_canonical_usize(s_deg - i) - r); + } + cur_sum = (0..=s_deg) + .map(|i| { + *batch_s_evals[i] + * pref_product[i] + * suf_product[s_deg - i] + * invfact[i] + * invfact[s_deg - i] + }) + .sum::<EF>(); + + debug!(round = round + 1, r_round = %r); + rs.push(r); + } + + // 7. Compute `eq_3b_per_trace` + let mut stacked_idx = 0usize; + let eq_3b_per_trace = n_per_trace + .iter() + .enumerate() + .map(|(trace_idx, &n)| { + let air_idx = trace_id_to_air_id[trace_idx]; + let interactions = &mvk.per_air[air_idx].symbolic_constraints.interactions; + if interactions.is_empty() { + return vec![]; + } + let n_lift = n.max(0) as usize; + let mut b_vec = vec![F::ZERO; n_logup - n_lift]; + (0..interactions.len()) + .map(|_| { + debug_assert!(stacked_idx < 1 << (l_skip + n_logup)); + debug_assert!(stacked_idx.trailing_zeros() as usize >= l_skip + n_lift); + let mut b_int = stacked_idx >> (l_skip + n_lift); + for b in &mut b_vec { + *b = F::from_bool(b_int & 1 == 1); + b_int >>= 1; + } + stacked_idx += 1 << (l_skip + n_lift); + eval_eq_mle(&xi[l_skip + n_lift..l_skip + n_logup], &b_vec) + }) + .collect_vec() + }) + .collect_vec(); + + // 8. Compute `eq_ns` and `eq_sharp_ns` + let mut eq_ns = vec![EF::ONE; n_max + 1]; + let mut eq_sharp_ns = vec![EF::ONE; n_max + 1]; + eq_ns[0] = eval_eq_uni(l_skip, xi[0], r_0); + eq_sharp_ns[0] = eval_eq_sharp_uni(omega_skip_pows, &xi[..l_skip], r_0); + debug_assert_eq!(rs.len(), n_max + 1); + for (i, r) in rs.iter().enumerate().skip(1) { + let eq_mle = eval_eq_mle(&[xi[l_skip + i - 1]], slice::from_ref(r)); + eq_ns[i] = eq_ns[i - 1] * eq_mle; + eq_sharp_ns[i] = eq_sharp_ns[i - 1] * eq_mle; + } + let mut r_rev_prod = rs[n_max]; + // Product with r_i's to account for \hat{f} vs \tilde{f} for different n's in front-loaded + // batch sumcheck. + for i in (0..n_max).rev() { + eq_ns[i] *= r_rev_prod; + eq_sharp_ns[i] *= r_rev_prod; + r_rev_prod *= rs[i]; + }
+ + // 9. Compute the interaction/constraint evals and their hash + let mut interactions_evals = Vec::new(); + let mut constraints_evals = Vec::new(); + let need_rot_per_trace = trace_id_to_air_id + .iter() + .map(|&air_idx| mvk.per_air[air_idx].params.need_rot) + .collect_vec(); + + // Observe common main openings first, and then preprocessed/cached + for (trace_idx, air_openings) in column_openings.iter().enumerate() { + let need_rot = need_rot_per_trace[trace_idx]; + for (claim, claim_rot) in column_openings_by_rot(&air_openings[0], need_rot) { + transcript.observe_ext(claim); + transcript.observe_ext(claim_rot); + } + } + + for (trace_idx, air_openings) in column_openings.iter().enumerate() { + let air_idx = trace_id_to_air_id[trace_idx]; + let vk = &mvk.per_air[air_idx]; + let n = n_per_trace[trace_idx]; + let n_lift = n.max(0) as usize; + let need_rot = need_rot_per_trace[trace_idx]; + + // claim lengths are checked in proof shape + for claims in air_openings.iter().skip(1) { + for (claim, claim_rot) in column_openings_by_rot(claims, need_rot) { + transcript.observe_ext(claim); + transcript.observe_ext(claim_rot); + } + } + + let has_preprocessed = vk.preprocessed_data.is_some(); + let common_main = column_openings_by_rot(&air_openings[0], need_rot).collect::<Vec<_>>(); + let preprocessed = has_preprocessed + .then(|| column_openings_by_rot(&air_openings[1], need_rot).collect::<Vec<_>>()); + let cached_idx = 1 + has_preprocessed as usize; + let mut partitioned_main: Vec<_> = air_openings[cached_idx..] + .iter() + .map(|opening| column_openings_by_rot(opening, need_rot).collect::<Vec<_>>()) + .collect(); + partitioned_main.push(common_main); + let part_main_slices = partitioned_main + .iter() + .map(|x| x.as_slice()) + .collect::<Vec<_>>(); + + // We are evaluating the lift, which is the same as evaluating the original with domain + // D^{(2^{n})} + let (l, rs_n, norm_factor) = if n.is_negative() { + ( + l_skip.wrapping_add_signed(n), + &[rs[0].exp_power_of_2(-n as usize)] as &[_], + F::from_canonical_usize(1 << n.unsigned_abs()).inverse(), + ) + } else { + (l_skip, &rs[..=(n as usize)], F::ONE) + }; + let evaluator = VerifierConstraintEvaluator::<F, EF>::new( + preprocessed.as_deref(), + &part_main_slices, + &public_values[air_idx], + rs_n, + l, + ); + + let constraints = &vk.symbolic_constraints.constraints; + let nodes = evaluator.eval_nodes(&constraints.nodes); + let expr = zip(lambda.powers(), &constraints.constraint_idx) + .map(|(lambda_pow, idx)| nodes[*idx] * lambda_pow) + .sum::<EF>(); + debug!(%trace_idx, %expr, %air_idx, "constraints_eval"); + let eq_xi_r = eq_ns[n_lift]; + debug!(%trace_idx, %eq_xi_r); + constraints_evals.push(eq_xi_r * expr); + + let symbolic_constraints = SymbolicConstraints::from(&vk.symbolic_constraints); + let interactions = &symbolic_constraints.interactions; + let cur_interactions_evals = interactions + .iter() + .map(|interaction| { + let num = evaluator.eval_expr(&interaction.count); + let denom = interaction + .message + .iter() + .map(|expr| evaluator.eval_expr(expr)) + .chain(std::iter::once(EF::from_canonical_u16( + interaction.bus_index + 1, + ))) + .zip(beta_logup.powers()) + .fold(EF::ZERO, |acc, (x, y)| acc + x * y); + (num, denom) + }) + .collect_vec(); + let eq_3bs = &eq_3b_per_trace[trace_idx]; + let mut num = EF::ZERO; + let mut denom = EF::ZERO; + for (&eq_3b, (n, d)) in eq_3bs.iter().zip_eq(cur_interactions_evals.iter()) { + num += eq_3b * *n; + denom += eq_3b * *d; + } + debug!(%trace_idx, %num, %denom, %air_idx, "interactions_eval"); + interactions_evals.push(num * norm_factor * eq_sharp_ns[n_lift]); + interactions_evals.push(denom * eq_sharp_ns[n_lift]); + }
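+ // Final consistency check: the mu-weighted combination of every per-trace interaction and + // constraint evaluation must reproduce the running sumcheck claim.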
+ let evaluated_claim = interactions_evals + .iter() + .chain(constraints_evals.iter()) + .zip(mu.powers()) + .map(|(x, y)| *x * y) + .sum::<EF>(); + if cur_sum != evaluated_claim { + return Err(BatchConstraintError::InconsistentClaims); + } + + Ok(rs) +} diff --git a/crates/stark-backend-v2/src/verifier/evaluator.rs b/crates/stark-backend-v2/src/verifier/evaluator.rs new file mode 100644 index 00000000..9a42895c --- /dev/null +++ b/crates/stark-backend-v2/src/verifier/evaluator.rs @@ -0,0 +1,107 @@ +use openvm_stark_backend::air_builders::symbolic::{ + symbolic_expression::SymbolicEvaluator, + symbolic_variable::{Entry, SymbolicVariable}, +}; +use p3_field::{ExtensionField, Field, FieldAlgebra, TwoAdicField}; + +type ViewPair<'a, T> = &'a [(T, T)]; + +/// Returns the sum `1 + m + ... + m^{2^l - 1}`, computed as the product +/// `(1 + m)(1 + m^2)(1 + m^4)...(1 + m^{2^{l-1}})`. +/// Could be done with `if m == 1 { ... } else { num / denom }`, +/// but we avoid divisions of field extension elements. +fn progression_exp_2<EF>(m: EF, l: usize) -> EF +where + EF: FieldAlgebra + Copy, +{ + (0..l) + .fold((m, EF::ONE), |(pow, sum), _| { + (pow * pow, sum * (EF::ONE + pow)) + }) + .1 +} + +pub(super) struct VerifierConstraintEvaluator<'a, F, EF> { + pub preprocessed: Option<ViewPair<'a, EF>>, + pub partitioned_main: &'a [ViewPair<'a, EF>], + pub is_first_row: EF, + pub is_last_row: EF, + pub public_values: &'a [F], +} + +impl<'a, F, EF> VerifierConstraintEvaluator<'a, F, EF> +where + F: Field + TwoAdicField, + EF: ExtensionField<F>, +{ + pub(super) fn new( + preprocessed: Option<ViewPair<'a, EF>>, + partitioned_main: &'a [ViewPair<'a, EF>], + public_values: &'a [F], + rs: &'a [EF], + l_skip: usize, + ) -> Self { + let omega = F::two_adic_generator(l_skip); + let inv = EF::from_base(F::from_canonical_usize(1 << l_skip).inverse()); + let is_first_row = inv + * progression_exp_2(rs[0], l_skip) + * rs[1..].iter().fold(EF::ONE, |acc, &x| acc * (EF::ONE - x)); + let is_last_row = inv + * progression_exp_2(rs[0] * omega, l_skip) + * rs[1..].iter().fold(EF::ONE, |acc, &x| acc * x); + Self { + preprocessed, + partitioned_main, + is_first_row, + is_last_row, + public_values, + } + } +} + +impl<F, EF> SymbolicEvaluator<F, EF> for VerifierConstraintEvaluator<'_, F, EF> +where + F: Field, + EF: ExtensionField<F>, +{ + fn eval_const(&self, c: F) -> EF { + EF::from_base(c) + } + + fn eval_var(&self, symbolic_var: SymbolicVariable<F>) -> EF { + let index = symbolic_var.index; + match symbolic_var.entry { + Entry::Preprocessed { offset } => match &self.preprocessed { + Some(vp) => { + if offset == 0 { + vp[index].0 + } else { + vp[index].1 + } + } + None => panic!("preprocessed opening requested but none provided"), + }, + Entry::Main { part_index, offset } => { + let vp = &self.partitioned_main[part_index]; + if offset == 0 { + vp[index].0 + } else { + vp[index].1 + } + } + Entry::Public => EF::from_base(self.public_values[index]), + _ => unimplemented!(), + } + } + + fn eval_is_first_row(&self) -> EF { + self.is_first_row + } + + fn eval_is_last_row(&self) -> EF { + self.is_last_row + } + + fn eval_is_transition(&self) -> EF { + EF::ONE - self.is_last_row + } +} diff --git a/crates/stark-backend-v2/src/verifier/fractional_sumcheck_gkr.rs b/crates/stark-backend-v2/src/verifier/fractional_sumcheck_gkr.rs new file mode 100644 index 00000000..90976ee2 --- /dev/null +++ b/crates/stark-backend-v2/src/verifier/fractional_sumcheck_gkr.rs @@ -0,0 +1,539 @@ +use p3_field::FieldAlgebra; +use thiserror::Error; +use tracing::debug; + +use crate::{ + poly_common::{eval_eq_mle, interpolate_cubic_at_0123, interpolate_linear_at_01},
+ poseidon2::sponge::FiatShamirTranscript, + proof::{GkrLayerClaims, GkrProof}, + EF, +}; + +#[derive(Debug, Error, PartialEq, Eq)] +pub enum GkrVerificationError { + #[error("Zero-round proof: q0_claim should be 1, got {actual}")] + InvalidZeroRoundValue { actual: EF }, + #[error("Zero-check failed: numerator at root should be zero, got {actual}")] + ZeroCheckFailed { actual: EF }, + #[error("Denominator consistency check failed at root: expected {expected}, got {actual}")] + RootConsistencyCheckFailed { expected: EF, actual: EF }, + #[error("Layer consistency check failed at round {round}: expected {expected}, got {actual}")] + LayerConsistencyCheckFailed { + round: usize, + expected: EF, + actual: EF, + }, + // TODO(ayush): remove these errors and make them debug asserts once proof shape verifier is + // implemented + #[error( + "Zero-round proof should have empty layers and sumcheck, got {claims_len} and {sumcheck_len}" + )] + InvalidZeroRoundShape { + claims_len: usize, + sumcheck_len: usize, + }, + #[error("Expected {expected} layers, got {actual}")] + IncorrectLayerCount { expected: usize, actual: usize }, + #[error("Expected {expected} sumcheck polynomial entries, got {actual}")] + IncorrectSumcheckPolyCount { expected: usize, actual: usize }, + #[error("Round {round} expected {expected} sumcheck sub-rounds, got {actual}")] + IncorrectSubroundCount { + round: usize, + expected: usize, + actual: usize, + }, +} + +/// Verifies the GKR protocol for fractional sumcheck. +/// +/// Reduces the fractional sum ∑_{y ∈ H_{ℓ+n_logup}} p̂(y)/q̂(y) = 0 to evaluation claims +/// on the input layer polynomials p̂(ξ) and q̂(ξ) at a random point ξ. +/// +/// The argument `total_rounds` must equal `ℓ+n_logup`. +/// +/// # Returns +/// `(p̂(ξ), q̂(ξ), ξ)` where ξ ∈ F_ext^{ℓ+n_logup} is the random evaluation point. 
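+/// +/// A minimal usage sketch (binding names are illustrative, not from this crate; the verifier's +/// transcript must be seeded identically to the prover's): +/// +/// ```ignore +/// let mut transcript = DuplexSponge::default(); +/// let (p_xi, q_xi, xi) = verify_gkr(&gkr_proof, &mut transcript, l_skip + n_logup)?; +/// ```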
+pub fn verify_gkr<TS: FiatShamirTranscript>( + proof: &GkrProof, + transcript: &mut TS, + total_rounds: usize, +) -> Result<(EF, EF, Vec<EF>), GkrVerificationError> { + if total_rounds == 0 { + // Check proof shape + if !proof.claims_per_layer.is_empty() || !proof.sumcheck_polys.is_empty() { + return Err(GkrVerificationError::InvalidZeroRoundShape { + claims_len: proof.claims_per_layer.len(), + sumcheck_len: proof.sumcheck_polys.len(), + }); + } + if proof.q0_claim != EF::ONE { + return Err(GkrVerificationError::InvalidZeroRoundValue { + actual: proof.q0_claim, + }); + } + return Ok((EF::ZERO, EF::ONE, vec![])); + } + + // Verify proof shape + if proof.claims_per_layer.len() != total_rounds { + return Err(GkrVerificationError::IncorrectLayerCount { + expected: total_rounds, + actual: proof.claims_per_layer.len(), + }); + } + + // Sumcheck polys: one entry per round j = 1..total_rounds (round j holds its j sub-round + // polys), so there are total_rounds - 1 entries + let expected_sumcheck_entries = total_rounds.saturating_sub(1); + if proof.sumcheck_polys.len() != expected_sumcheck_entries { + return Err(GkrVerificationError::IncorrectSumcheckPolyCount { + expected: expected_sumcheck_entries, + actual: proof.sumcheck_polys.len(), + }); + } + + transcript.observe_ext(proof.q0_claim); + + // Handle round 0 (no sumcheck, direct tree evaluation) + let layer_claims = &proof.claims_per_layer[0]; + observe_layer_claims(transcript, layer_claims); + + // Compute recursive relations for layer 1 → 0 + let (p_cross_term, q_cross_term) = compute_recursive_relations(layer_claims); + + // Zero-check: p̂₀ must be zero + if p_cross_term != EF::ZERO { + return Err(GkrVerificationError::ZeroCheckFailed { + actual: p_cross_term, + }); + } + + // Verify q0 consistency + if q_cross_term != proof.q0_claim { + return Err(GkrVerificationError::RootConsistencyCheckFailed { + expected: proof.q0_claim, + actual: q_cross_term, + }); + } + + // Sample μ₁ and reduce to single evaluation + let mu = transcript.sample_ext(); + debug!(gkr_round = 0, %mu); + let (mut numer_claim, mut denom_claim) = reduce_to_single_evaluation(layer_claims, mu); + debug!(%numer_claim, %denom_claim); + let mut gkr_r = vec![mu]; + + // Handle rounds 1..total_rounds with sumcheck + for round in 1..total_rounds { + // Sample batching challenge λⱼ + let lambda = transcript.sample_ext(); + debug!(gkr_round = round, %lambda); + let claim = numer_claim + lambda * denom_claim; + + // Run sumcheck protocol for this round (round j has j sub-rounds) + let (new_claim, round_r, eq_at_r_prime) = + verify_gkr_sumcheck(proof, transcript, round, claim, &gkr_r)?; + debug_assert_eq!(eq_at_r_prime, eval_eq_mle(&gkr_r, &round_r)); + + // Observe layer evaluation claims + let layer_claims = &proof.claims_per_layer[round]; + observe_layer_claims(transcript, layer_claims); + + // Compute recursive relations + let (p_cross_term, q_cross_term) = compute_recursive_relations(layer_claims); + + // Verify consistency + let expected_claim = (p_cross_term + lambda * q_cross_term) * eq_at_r_prime; + if expected_claim != new_claim { + return Err(GkrVerificationError::LayerConsistencyCheckFailed { + round, + expected: expected_claim, + actual: new_claim, + }); + } + + // Sample μⱼ and reduce to single evaluation + let mu = transcript.sample_ext(); + debug!(gkr_round = round, %mu); + (numer_claim, denom_claim) = reduce_to_single_evaluation(layer_claims, mu); + // Update evaluation point: ξ^{(j)} = (μⱼ, ρ^{(j-1)}) + gkr_r = std::iter::once(mu).chain(round_r).collect(); + } + + Ok((numer_claim, denom_claim, gkr_r)) +} + +/// Verify sumcheck for a single GKR round.
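+/// Each sub-round transmits `s(1)`, `s(2)`, `s(3)`; the verifier recovers `s(0) = claim - s(1)` +/// and reduces the claim via cubic interpolation at the sampled challenge.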
+/// +/// Reduces evaluation of (p̂ⱼ₋₁ + λⱼ·q̂ⱼ₋₁)(ξ^{(j-1)}) to evaluations at the next layer. +/// +/// # Returns +/// `(claim, ρ^{(j-1)}, eq(ξ^{(j-1)}, ρ^{(j-1)}))` where ρ^{(j-1)} is randomly sampled from the +/// sumcheck protocol. +fn verify_gkr_sumcheck<TS: FiatShamirTranscript>( + proof: &GkrProof, + transcript: &mut TS, + round: usize, + mut claim: EF, + gkr_r: &[EF], +) -> Result<(EF, Vec<EF>, EF), GkrVerificationError> { + debug_assert!( + round > 0, + "verify_gkr_sumcheck should not be called for round 0" + ); + debug_assert_eq!( + gkr_r.len(), + round, + "gkr_r should have exactly round elements" + ); + + // For round j, there are j sumcheck sub-rounds + let expected_subrounds = round; + let polys = &proof.sumcheck_polys[round - 1]; + if polys.len() != expected_subrounds { + return Err(GkrVerificationError::IncorrectSubroundCount { + round, + expected: expected_subrounds, + actual: polys.len(), + }); + } + let mut gkr_r_prime = Vec::with_capacity(round); + let mut eq = EF::ONE; // eq(ξ^{(j-1)}, ρ^{(j-1)}) computed incrementally + + for (sumcheck_round, poly_evals) in polys.iter().enumerate() { + debug!(gkr_round = round, %sumcheck_round, sum_claim = %claim); + // Observe s(1), s(2), s(3) where s is the sumcheck polynomial + for &eval in poly_evals { + transcript.observe_ext(eval); + } + + let ri = transcript.sample_ext(); + gkr_r_prime.push(ri); + debug!(gkr_round = round, %sumcheck_round, r_round = %ri); + + let ev0 = claim - poly_evals[0]; // s(0) = claim - s(1) + let evals = [ev0, poly_evals[0], poly_evals[1], poly_evals[2]]; + claim = interpolate_cubic_at_0123(&evals, ri); + + // Update eq incrementally: eq *= ξᵢ·rᵢ + (1-ξᵢ)·(1-rᵢ) + let xi = gkr_r[sumcheck_round]; + eq *= xi * ri + (EF::ONE - xi) * (EF::ONE - ri); + } + + Ok((claim, gkr_r_prime, eq)) +} + +/// Observes layer claims in the transcript. +fn observe_layer_claims<TS: FiatShamirTranscript>(transcript: &mut TS, claims: &GkrLayerClaims) { + transcript.observe_ext(claims.p_xi_0); + transcript.observe_ext(claims.q_xi_0); + transcript.observe_ext(claims.p_xi_1); + transcript.observe_ext(claims.q_xi_1); +} + +/// Computes recursive relations from layer claims. +fn compute_recursive_relations(claims: &GkrLayerClaims) -> (EF, EF) { + let p_cross_term = claims.p_xi_0 * claims.q_xi_1 + claims.p_xi_1 * claims.q_xi_0; + let q_cross_term = claims.q_xi_0 * claims.q_xi_1; + (p_cross_term, q_cross_term) +} + +/// Reduces claims to a single evaluation point using linear interpolation. +fn reduce_to_single_evaluation(claims: &GkrLayerClaims, mu: EF) -> (EF, EF) { + let numer = interpolate_linear_at_01(&[claims.p_xi_0, claims.p_xi_1], mu); + let denom = interpolate_linear_at_01(&[claims.q_xi_0, claims.q_xi_1], mu); + (numer, denom) +} + +#[cfg(test)] +mod tests { + use openvm_stark_sdk::config::setup_tracing; + + use super::*; + use crate::{ + poseidon2::sponge::DuplexSponge, + proof::{GkrLayerClaims, GkrProof}, + prover::fractional_sumcheck_gkr::{fractional_sumcheck, Frac}, + F, + }; + + #[test] + fn test_multiple_rounds_shape() { + setup_tracing(); + let proof = GkrProof { + logup_pow_witness: F::ZERO, + q0_claim: EF::ONE, + claims_per_layer: vec![], + sumcheck_polys: vec![], + }; + + let mut transcript = DuplexSponge::default(); + + let result = verify_gkr(&proof, &mut transcript, 2); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + GkrVerificationError::IncorrectLayerCount { ..
} + )); + + let layer_claims = GkrLayerClaims { + p_xi_0: EF::ZERO, + p_xi_1: EF::ZERO, + q_xi_0: EF::ONE, + q_xi_1: EF::ONE, + }; + + let proof2 = GkrProof { + logup_pow_witness: F::ZERO, + q0_claim: EF::ONE, + claims_per_layer: vec![layer_claims.clone(), layer_claims], + sumcheck_polys: vec![], + }; + + let mut transcript = DuplexSponge::default(); + let result = verify_gkr(&proof2, &mut transcript, 2); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + GkrVerificationError::IncorrectSumcheckPolyCount { .. } + )); + } + + #[test] + fn test_gkr_base_layer_numerator_zero() { + setup_tracing(); + let layer1_claims = GkrLayerClaims { + p_xi_0: EF::from_canonical_u64(1), // Non-zero + p_xi_1: EF::from_canonical_u64(2), + q_xi_0: EF::from_canonical_u64(3), + q_xi_1: EF::from_canonical_u64(4), + }; + + // p0 = 1*4 + 2*3 = 10 (non-zero) + let proof = GkrProof { + logup_pow_witness: F::ZERO, + q0_claim: EF::from_canonical_u64(12), // q0 = 3*4 = 12 + claims_per_layer: vec![layer1_claims], + sumcheck_polys: vec![], + }; + + let mut transcript = DuplexSponge::default(); + let result = verify_gkr(&proof, &mut transcript, 1); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + GkrVerificationError::ZeroCheckFailed { .. } + )); + } + + #[test] + fn test_gkr_1_round_integration() { + setup_tracing(); + let fractions = vec![ + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + ]; + + let mut prover_transcript = DuplexSponge::default(); + let (frac_proof, _xi) = fractional_sumcheck(&mut prover_transcript, &fractions, true); + + let gkr_proof = GkrProof { + logup_pow_witness: F::ZERO, + q0_claim: frac_proof.fractional_sum.1, + claims_per_layer: frac_proof.claims_per_layer, + sumcheck_polys: frac_proof.sumcheck_polys, + }; + + let mut verifier_transcript = DuplexSponge::default(); + let total_rounds = p3_util::log2_strict_usize(fractions.len()); + let result = verify_gkr(&gkr_proof, &mut verifier_transcript, total_rounds); + + assert!( + result.is_ok(), + "1-round verification failed: {:?}", + result.err() + ); + let (numer_claim, denom_claim, _) = result.unwrap(); + assert_eq!(numer_claim, EF::ZERO); + assert_ne!(denom_claim, EF::ZERO); + } + + #[test] + fn test_gkr_2_round_integration() { + setup_tracing(); + let fractions = vec![ + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + ]; + + let mut prover_transcript = DuplexSponge::default(); + let (frac_proof, _xi) = fractional_sumcheck(&mut prover_transcript, &fractions, true); + + let gkr_proof = GkrProof { + logup_pow_witness: F::ZERO, + q0_claim: frac_proof.fractional_sum.1, + claims_per_layer: frac_proof.claims_per_layer, + sumcheck_polys: frac_proof.sumcheck_polys, + }; + + let mut verifier_transcript = DuplexSponge::default(); + let total_rounds = p3_util::log2_strict_usize(fractions.len()); + let result = verify_gkr(&gkr_proof, &mut verifier_transcript, total_rounds); + + assert!( + result.is_ok(), + "2-round verification failed: {:?}", + result.err() + ); + let (numer_claim, denom_claim, _) = result.unwrap(); + assert_eq!(numer_claim, EF::ZERO); + assert_ne!(denom_claim, EF::ZERO); + } + + #[test] + fn test_gkr_3_round_integration() { + setup_tracing(); + let fractions = vec![ + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + 
Frac { + p: EF::ZERO, + q: EF::ONE, + }, + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + Frac { + p: EF::ZERO, + q: EF::ONE, + }, + ]; + + let mut prover_transcript = DuplexSponge::default(); + let (frac_proof, _xi) = fractional_sumcheck(&mut prover_transcript, &fractions, true); + + let gkr_proof = GkrProof { + logup_pow_witness: F::ZERO, + q0_claim: frac_proof.fractional_sum.1, + claims_per_layer: frac_proof.claims_per_layer, + sumcheck_polys: frac_proof.sumcheck_polys, + }; + + let mut verifier_transcript = DuplexSponge::default(); + let total_rounds = p3_util::log2_strict_usize(fractions.len()); + let result = verify_gkr(&gkr_proof, &mut verifier_transcript, total_rounds); + + assert!( + result.is_ok(), + "3-round verification failed: {:?}", + result.err() + ); + let (numer_claim, denom_claim, _) = result.unwrap(); + assert_eq!(numer_claim, EF::ZERO); + assert_ne!(denom_claim, EF::ZERO); + } + + #[test] + fn test_gkr_mixed_fractions() { + setup_tracing(); + let fractions = vec![ + Frac { + p: EF::from_canonical_u64(5), + q: EF::ONE, + }, + Frac { + p: -EF::from_canonical_u64(5), + q: EF::ONE, + }, + ]; + + let mut prover_transcript = DuplexSponge::default(); + let (frac_proof, _xi) = fractional_sumcheck(&mut prover_transcript, &fractions, true); + + let gkr_proof = GkrProof { + logup_pow_witness: F::ZERO, + q0_claim: frac_proof.fractional_sum.1, + claims_per_layer: frac_proof.claims_per_layer, + sumcheck_polys: frac_proof.sumcheck_polys, + }; + + let mut verifier_transcript = DuplexSponge::default(); + let total_rounds = p3_util::log2_strict_usize(fractions.len()); + let result = verify_gkr(&gkr_proof, &mut verifier_transcript, total_rounds); + + assert!( + result.is_ok(), + "Mixed fractions verification failed: {:?}", + result.err() + ); + let (_numer_claim, denom_claim, _) = result.unwrap(); + assert_ne!(denom_claim, EF::ZERO); + } + + #[test] + fn test_gkr_empty_case() { + setup_tracing(); + let fractions = vec![]; + + let mut prover_transcript = DuplexSponge::default(); + let (frac_proof, _xi) = fractional_sumcheck(&mut prover_transcript, &fractions, true); + + let gkr_proof = GkrProof { + logup_pow_witness: F::ZERO, + q0_claim: frac_proof.fractional_sum.1, + claims_per_layer: frac_proof.claims_per_layer, + sumcheck_polys: frac_proof.sumcheck_polys, + }; + + let mut verifier_transcript = DuplexSponge::default(); + let result = verify_gkr(&gkr_proof, &mut verifier_transcript, 0); + + assert!( + result.is_ok(), + "Empty case verification failed: {:?}", + result.err() + ); + let (numer_claim, denom_claim, gkr_r) = result.unwrap(); + assert_eq!(numer_claim, EF::ZERO); + assert_eq!(denom_claim, EF::ONE); + assert_eq!(gkr_r, vec![]); + } +} diff --git a/crates/stark-backend-v2/src/verifier/mod.rs b/crates/stark-backend-v2/src/verifier/mod.rs new file mode 100644 index 00000000..8dd044fd --- /dev/null +++ b/crates/stark-backend-v2/src/verifier/mod.rs @@ -0,0 +1,325 @@ +use core::cmp::Reverse; + +use itertools::{izip, Itertools}; +use p3_field::{FieldAlgebra, TwoAdicField}; +use thiserror::Error; + +use crate::{ + keygen::types::{MultiStarkVerifyingKey0V2, MultiStarkVerifyingKeyV2}, + poly_common::Squarable, + poseidon2::sponge::FiatShamirTranscript, + proof::Proof, + verifier::{ + batch_constraints::{verify_zerocheck_and_logup, BatchConstraintError}, + proof_shape::{verify_proof_shape, ProofShapeError}, + stacked_reduction::{verify_stacked_reduction, StackedReductionError}, + whir::{verify_whir, VerifyWhirError}, + }, + F, +}; + +#[derive(Error, 
Debug, PartialEq, Eq)] +pub enum VerifierError { + #[error("Trace heights are too large")] + TraceHeightsTooLarge, + + #[error("Proof shape verification failed: {0}")] + ProofShapeError(#[from] ProofShapeError), + + #[error("Batch constraint verification failed: {0}")] + BatchConstraintError(#[from] BatchConstraintError), + + #[error("Stacked reduction verification failed: {0}")] + StackedReductionError(#[from] StackedReductionError), + + #[error("Whir verification failed: {0}")] + WhirError(#[from] VerifyWhirError), +} + +pub mod batch_constraints; +pub mod evaluator; +pub mod fractional_sumcheck_gkr; +pub mod proof_shape; +pub mod stacked_reduction; +pub mod sumcheck; +pub mod whir; + +pub fn verify<TS: FiatShamirTranscript>( + mvk: &MultiStarkVerifyingKeyV2, + proof: &Proof, + transcript: &mut TS, +) -> Result<(), VerifierError> { + let &Proof { + common_main_commit, + trace_vdata, + public_values, + gkr_proof, + batch_constraint_proof, + stacking_proof, + whir_proof, + } = &proof; + let &MultiStarkVerifyingKeyV2 { + inner: mvk, + pre_hash: mvk_pre_hash, + } = &mvk; + let &MultiStarkVerifyingKey0V2 { + params, + per_air, + trace_height_constraints, + } = &mvk; + let l_skip = params.l_skip; + + let num_airs = per_air.len(); + + let mut trace_id_to_air_id: Vec<usize> = (0..num_airs).collect(); + trace_id_to_air_id.sort_by_key(|&air_id| { + ( + trace_vdata[air_id].is_none(), + trace_vdata[air_id] + .as_ref() + .map(|vdata| Reverse(vdata.log_height)), + air_id, + ) + }); + let num_traces = trace_vdata.iter().flatten().count(); + trace_id_to_air_id.truncate(num_traces); + + for constraint in trace_height_constraints { + let sum = trace_id_to_air_id + .iter() + .map(|&air_id| { + let log_height = trace_vdata[air_id].as_ref().unwrap().log_height; + // Proof shape verification checks that n <= n_stack is in bounds + (1 << log_height.max(l_skip)) as u64 * constraint.coefficients[air_id] as u64 + }) + .sum::<u64>(); + if sum >= constraint.threshold as u64 { + return Err(VerifierError::TraceHeightsTooLarge); + } + } + + let omega_skip = F::two_adic_generator(l_skip); + let omega_skip_pows = omega_skip.powers().take(1 << l_skip).collect_vec(); + + // Preamble + transcript.observe_commit(*mvk_pre_hash); + transcript.observe_commit(proof.common_main_commit); + + for (trace_vdata, avk, pvs) in izip!(&proof.trace_vdata, per_air, &proof.public_values) { + let is_air_present = trace_vdata.is_some(); + + if !avk.is_required { + transcript.observe(F::from_bool(is_air_present)); + } + if let Some(trace_vdata) = trace_vdata { + if let Some(pdata) = avk.preprocessed_data.as_ref() { + transcript.observe_commit(pdata.commit); + } else { + transcript.observe(F::from_canonical_usize(trace_vdata.log_height)); + } + debug_assert_eq!( + avk.params.width.cached_mains.len(), + trace_vdata.cached_commitments.len() + ); + for commit in &trace_vdata.cached_commitments { + transcript.observe_commit(*commit); + } + debug_assert_eq!(avk.params.num_public_values, pvs.len()); + } + for pv in pvs { + transcript.observe(*pv); + } + } + + let layouts = verify_proof_shape(mvk, proof)?; + + let n_per_trace: Vec<isize> = trace_id_to_air_id + .iter() + .map(|&air_id| trace_vdata[air_id].as_ref().unwrap().log_height as isize - l_skip as isize) + .collect(); + let r = verify_zerocheck_and_logup( + transcript, + mvk, + public_values, + gkr_proof, + batch_constraint_proof, + &trace_id_to_air_id, + &n_per_trace, + &omega_skip_pows, + )?; + + let need_rot_per_trace = trace_id_to_air_id + .iter() + .map(|&air_id| per_air[air_id].params.need_rot) + .collect_vec(); + let mut
need_rot_per_commit = vec![need_rot_per_trace]; + for &air_id in &trace_id_to_air_id { + let need_rot = per_air[air_id].params.need_rot; + if per_air[air_id].preprocessed_data.is_some() { + need_rot_per_commit.push(vec![need_rot]); + } + let cached_len = trace_vdata[air_id] + .as_ref() + .unwrap() + .cached_commitments + .len(); + for _ in 0..cached_len { + need_rot_per_commit.push(vec![need_rot]); + } + } + + let u_prism = verify_stacked_reduction( + transcript, + stacking_proof, + &layouts, + &need_rot_per_commit, + l_skip, + params.n_stack, + &proof.batch_constraint_proof.column_openings, + &r, + &omega_skip_pows, + )?; + + let (&u0, u_rest) = u_prism.split_first().unwrap(); + let u_cube = u0 + .exp_powers_of_2() + .take(l_skip) + .chain(u_rest.iter().copied()) + .collect_vec(); + + let mut commits = vec![*common_main_commit]; + for &air_id in trace_id_to_air_id.iter() { + if let Some(preprocessed) = &per_air[air_id].preprocessed_data { + commits.push(preprocessed.commit); + } + commits.extend(&trace_vdata[air_id].as_ref().unwrap().cached_commitments); + } + + verify_whir( + transcript, + params, + whir_proof, + &stacking_proof.stacking_openings, + &commits, + &u_cube, + )?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use openvm_stark_sdk::config::{ + log_up_params::log_up_security_params_baby_bear_100_bits, setup_tracing_with_log_level, + }; + use test_case::test_case; + use tracing::Level; + + use crate::{ + poseidon2::sponge::{DuplexSpongeRecorder, TranscriptHistory}, + test_utils::{ + test_system_params_small, CachedFixture11, DuplexSpongeValidator, FibFixture, + InteractionsFixture11, PreprocessedFibFixture, TestFixture, + }, + verifier::{verify, VerifierError}, + BabyBearPoseidon2CpuEngineV2, SystemParams, WhirConfig, WhirParams, + }; + + #[test_case(2, 10)] + #[test_case(2, 1; "where log_trace_degree=1 less than l_skip=2")] + #[test_case(2, 0; "where log_trace_degree=0 less than l_skip=2")] + #[test_case(3, 2; "where log_trace_degree=2 less than l_skip=3")] + fn test_fib_air_roundtrip(l_skip: usize, log_trace_degree: usize) -> Result<(), VerifierError> { + setup_tracing_with_log_level(Level::DEBUG); + + let n_stack = 8; + let k_whir = 4; + let whir_params = WhirParams { + k: k_whir, + log_final_poly_len: k_whir, + query_phase_pow_bits: 1, + }; + let log_blowup = 1; + let whir = WhirConfig::new(log_blowup, l_skip + n_stack, whir_params, 80); + let params = SystemParams { + l_skip, + n_stack, + log_blowup, + whir, + logup: log_up_security_params_baby_bear_100_bits(), + max_constraint_degree: 3, + }; + let fib = FibFixture::new(0, 1, 1 << log_trace_degree); + + let engine = BabyBearPoseidon2CpuEngineV2::new(params); + let (pk, vk) = fib.keygen(&engine); + let mut recorder = DuplexSpongeRecorder::default(); + let proof = fib.prove_from_transcript(&engine, &pk, &mut recorder); + + let mut validator_sponge = DuplexSpongeValidator::new(recorder.into_log()); + verify(&vk, &proof, &mut validator_sponge) + } + + #[test_case(2, 8, 3)] + #[test_case(5, 5, 4)] + fn test_dummy_interactions_roundtrip( + l_skip: usize, + n_stack: usize, + k_whir: usize, + ) -> Result<(), VerifierError> { + let params = test_system_params_small(l_skip, n_stack, k_whir); + let engine = BabyBearPoseidon2CpuEngineV2::new(params); + let fx = InteractionsFixture11; + let (pk, vk) = fx.keygen(&engine); + + let mut recorder = DuplexSpongeRecorder::default(); + let proof = fx.prove_from_transcript(&engine, &pk, &mut recorder); + + let mut validator_sponge = DuplexSpongeValidator::new(recorder.into_log()); + verify(&vk, 
&proof, &mut validator_sponge) + } + + #[test_case(2, 8, 3)] + #[test_case(5, 5, 4)] + #[test_case(5, 8, 3)] + fn test_cached_trace_roundtrip( + l_skip: usize, + n_stack: usize, + k_whir: usize, + ) -> Result<(), VerifierError> { + setup_tracing_with_log_level(Level::DEBUG); + let params = test_system_params_small(l_skip, n_stack, k_whir); + let engine = BabyBearPoseidon2CpuEngineV2::new(params.clone()); + let fx = CachedFixture11::new(params); + let (pk, vk) = fx.keygen(&engine); + + let mut recorder = DuplexSpongeRecorder::default(); + let proof = fx.prove_from_transcript(&engine, &pk, &mut recorder); + + let mut validator_sponge = DuplexSpongeValidator::new(recorder.into_log()); + verify(&vk, &proof, &mut validator_sponge) + } + + #[test_case(2, 8, 3)] + #[test_case(5, 5, 4)] + fn test_preprocessed_trace_roundtrip( + l_skip: usize, + n_stack: usize, + k_whir: usize, + ) -> Result<(), VerifierError> { + use itertools::Itertools; + let params = test_system_params_small(l_skip, n_stack, k_whir); + let engine = BabyBearPoseidon2CpuEngineV2::new(params); + let log_trace_degree = 8; + let height = 1 << log_trace_degree; + let sels = (0..height).map(|i| i % 2 == 0).collect_vec(); + let fx = PreprocessedFibFixture::new(0, 1, sels); + let (pk, vk) = fx.keygen(&engine); + + let mut recorder = DuplexSpongeRecorder::default(); + let proof = fx.prove_from_transcript(&engine, &pk, &mut recorder); + + let mut validator_sponge = DuplexSpongeValidator::new(recorder.into_log()); + verify(&vk, &proof, &mut validator_sponge) + } +} diff --git a/crates/stark-backend-v2/src/verifier/proof_shape.rs b/crates/stark-backend-v2/src/verifier/proof_shape.rs new file mode 100644 index 00000000..a8e2a451 --- /dev/null +++ b/crates/stark-backend-v2/src/verifier/proof_shape.rs @@ -0,0 +1,783 @@ +use std::cmp::{max, Reverse}; + +use itertools::{izip, Itertools}; +use thiserror::Error; + +use crate::{ + calculate_n_logup, keygen::types::MultiStarkVerifyingKey0V2, proof::Proof, + prover::stacked_pcs::StackedLayout, +}; + +#[derive(Debug, Error, PartialEq, Eq)] +pub enum ProofShapeError { + #[error("Invalid VData: {0}")] + InvalidVData(ProofShapeVDataError), + #[error("Invalid GkrProof shape: {0}")] + InvalidGkrProofShape(GkrProofShapeError), + #[error("Invalid BatchConstraintProof shape: {0}")] + InvalidBatchConstraintProofShape(BatchProofShapeError), + #[error("Invalid StackingProof shape: {0}")] + InvalidStackingProofShape(StackingProofShapeError), + #[error("Invalid WhirProof shape: {0}")] + InvalidWhirProofShape(WhirProofShapeError), +} + +impl ProofShapeError { + fn invalid_vdata<T>(err: ProofShapeVDataError) -> Result<T, Self> { + Err(Self::InvalidVData(err)) + } + + fn invalid_gkr<T>(err: GkrProofShapeError) -> Result<T, Self> { + Err(Self::InvalidGkrProofShape(err)) + } + + fn invalid_batch_constraint<T>(err: BatchProofShapeError) -> Result<T, Self> { + Err(Self::InvalidBatchConstraintProofShape(err)) + } + + fn invalid_stacking<T>(err: StackingProofShapeError) -> Result<T, Self> { + Err(Self::InvalidStackingProofShape(err)) + } + + fn invalid_whir<T>(err: WhirProofShapeError) -> Result<T, Self> { + Err(Self::InvalidWhirProofShape(err)) + } +} + +#[derive(Debug, Error, PartialEq, Eq)] +pub enum ProofShapeVDataError { + #[error("Proof trace_vdata length ({len}) does not match number of AIRs ({num_airs})")] + InvalidVDataLength { len: usize, num_airs: usize }, + #[error("Proof public_values length ({len}) does not match number of AIRs ({num_airs})")] + InvalidPublicValuesLength { len: usize, num_airs: usize }, + #[error("AIR {air_idx} is required, but trace_vdata[{air_idx}]
is None")] + RequiredAirNoVData { air_idx: usize }, + #[error("AIR {air_idx} has no TraceVData, but a non-zero amount of public values")] + PublicValuesNoVData { air_idx: usize }, + #[error( + "TraceVata for AIR {air_idx} should have {expected} cached commitments, but has {actual}" + )] + InvalidCachedCommitments { + air_idx: usize, + expected: usize, + actual: usize, + }, + #[error("AIR {air_idx} should have log_height <= {}, but has {actual} (l_skip = {l_skip}, n_stack = {n_stack}", l_skip + n_stack)] + LogHeightOutOfBounds { + air_idx: usize, + l_skip: usize, + n_stack: usize, + actual: usize, + }, + #[error("AIR {air_idx} should have {expected} public values, but has {actual}")] + InvalidPublicValues { + air_idx: usize, + expected: usize, + actual: usize, + }, +} + +#[derive(Debug, Error, PartialEq, Eq)] +pub enum GkrProofShapeError { + #[error( + "claims_per_layer should have num_gkr_rounds = {expected} claims, but it has {actual}" + )] + InvalidClaimsPerLayer { expected: usize, actual: usize }, + #[error( + "sumcheck_polys should have num_gkr_rounds.saturating_sub(1) = {expected} polynomials, but it has {actual}" + )] + InvalidSumcheckPolys { expected: usize, actual: usize }, + #[error( + "Sumcheck polynomial for round {round} should have {expected} evaluations, but it has {actual}" + )] + InvalidSumcheckPolyEvals { + round: usize, + expected: usize, + actual: usize, + }, +} + +#[derive(Debug, Error, PartialEq, Eq)] +pub enum BatchProofShapeError { + #[error("numerator_term_per_air should have num_airs = {expected} terms, but it has {actual}")] + InvalidNumeratorTerms { expected: usize, actual: usize }, + #[error( + "denominator_term_per_air should have num_airs = {expected} terms, but it has {actual}" + )] + InvalidDenominatorTerms { expected: usize, actual: usize }, + #[error( + "univariate_round_coeffs should have (max_constraint_degree + 1) * (2^l_skip - 1) + 1 = {expected} coefficients, but it has {actual}" + )] + InvalidUnivariateRoundCoeffs { expected: usize, actual: usize }, + #[error( + "sumcheck_round_polys should have n_global = {expected} polynomials, but it has {actual}" + )] + InvalidSumcheckRoundPolys { expected: usize, actual: usize }, + #[error( + "column_openings should have num_airs = {expected} sets of openings, but it has {actual}" + )] + InvalidColumnOpeningsAirs { expected: usize, actual: usize }, + #[error( + "sumcheck_round_polys[{round}] should have degree = {expected} evaluations, but it has {actual}" + )] + InvalidSumcheckRoundPolyEvals { + round: usize, + expected: usize, + actual: usize, + }, + #[error( + "AIR {air_idx} has {expected} parts, but there are {actual} sets of per-part column openings" + )] + InvalidColumnOpeningsPerAir { + air_idx: usize, + expected: usize, + actual: usize, + }, + #[error( + "There should be {expected} column opening pairs for AIR {air_idx}'s main trace, but instead there are {actual}" + )] + InvalidColumnOpeningsPerAirMain { + air_idx: usize, + expected: usize, + actual: usize, + }, + #[error( + "There should be {expected} column opening pairs for AIR {air_idx}'s preprocessed trace, but instead there are {actual}" + )] + InvalidColumnOpeningsPerAirPreprocessed { + air_idx: usize, + expected: usize, + actual: usize, + }, + #[error( + "There should be {expected} column opening pairs for AIR {air_idx}'s cached trace {cached_idx}, but instead there are {actual}" + )] + InvalidColumnOpeningsPerAirCached { + air_idx: usize, + cached_idx: usize, + expected: usize, + actual: usize, + }, + #[error( + "Column opening for AIR 
{air_idx} (part {part_idx}) should have {expected} values, but has {actual}" + )] + InvalidColumnOpeningLen { + air_idx: usize, + part_idx: usize, + expected: usize, + actual: usize, + }, +} + +#[derive(Debug, Error, PartialEq, Eq)] +pub enum StackingProofShapeError { + #[error( + "univariate_round_coeffs should have 2 * ((1 << mvk.params.l_skip) - 1) + 1 = {expected} coefficients, but it has {actual}" + )] + InvalidUnivariateRoundCoeffs { expected: usize, actual: usize }, + #[error( + "sumcheck_round_polys should have n_stack = {expected} polynomials, but it has {actual}" + )] + InvalidSumcheckRoundPolys { expected: usize, actual: usize }, + #[error( + "There should be {expected} sets of per-commit stacking openings, but instead there are {actual}" + )] + InvalidStackOpenings { expected: usize, actual: usize }, + #[error( + "Stacked matrix {commit_idx} should have {expected} stacking openings, but instead there are {actual}" + )] + InvalidStackOpeningsPerMatrix { + commit_idx: usize, + expected: usize, + actual: usize, + }, +} + +#[derive(Debug, Error, PartialEq, Eq)] +pub enum WhirProofShapeError { + #[error( + "whir_sumcheck_polys should have num_whir_sumcheck_rounds = {expected} polynomials, but it has {actual}" + )] + InvalidSumcheckPolys { expected: usize, actual: usize }, + #[error("final_poly should have len = {expected}, but it has {actual}")] + InvalidFinalPolyLen { expected: usize, actual: usize }, + #[error( + "There should be num_whir_rounds = {expected} codeword commits, but there are {actual}" + )] + InvalidCodewordCommits { expected: usize, actual: usize }, + #[error( + "There should be num_whir_rounds = {expected} out-of-domain values, but there are {actual}" + )] + InvalidOodValues { expected: usize, actual: usize }, + #[error( + "There should be num_whir_sumcheck_rounds = {expected} folding PoW witnesses, but there are {actual}" + )] + InvalidFoldingPowWitnesses { expected: usize, actual: usize }, + #[error( + "There should be num_whir_rounds = {expected} query phase PoW witnesses, but there are {actual}" + )] + InvalidQueryPhasePowWitnesses { expected: usize, actual: usize }, + #[error( + "There should be num_commits = {expected} sets of initial round opened rows, but there are {actual}" + )] + InvalidInitialRoundOpenedRows { expected: usize, actual: usize }, + #[error( + "There should be num_commits = {expected} sets of initial round merkle proofs, but there are {actual}" + )] + InvalidInitialRoundMerkleProofs { expected: usize, actual: usize }, + #[error( + "There should be num_whir_rounds = {expected} sets of non-initial round opened rows, but there are {actual}" + )] + InvalidCodewordOpenedRows { expected: usize, actual: usize }, + #[error( + "There should be num_whir_rounds = {expected} sets of non-initial round merkle proofs, but there are {actual}" + )] + InvalidCodewordMerkleProofs { expected: usize, actual: usize }, + #[error( + "There should be num_whir_queries = {expected} initial round opened rows for commit {commit_idx}, but there are {actual}" + )] + InvalidInitialRoundOpenedRowsQueries { + commit_idx: usize, + expected: usize, + actual: usize, + }, + #[error( + "There should be num_whir_queries = {expected} initial round merkle proofs for commit {commit_idx}, but there are {actual}" + )] + InvalidInitialRoundMerkleProofsQueries { + commit_idx: usize, + expected: usize, + actual: usize, + }, + #[error( + "Initial round opened row {opened_idx} for commit {commit_idx} should have length {expected}, but it has length {actual}" + )] + 
InvalidInitialRoundOpenedRowK {
+        opened_idx: usize,
+        commit_idx: usize,
+        expected: usize,
+        actual: usize,
+    },
+    #[error(
+        "Initial round opened row {row_idx} for commit {commit_idx} should have width {expected}, but it has width {actual}"
+    )]
+    InvalidInitialRoundOpenedRowWidth {
+        row_idx: usize,
+        commit_idx: usize,
+        expected: usize,
+        actual: usize,
+    },
+    #[error(
+        "Initial round merkle proof {opened_idx} for commit {commit_idx} should have depth {expected}, but it has depth {actual}"
+    )]
+    InvalidInitialRoundMerkleProofDepth {
+        opened_idx: usize,
+        commit_idx: usize,
+        expected: usize,
+        actual: usize,
+    },
+    #[error(
+        "There should be num_whir_queries = {expected} round {round} opened rows, but there are {actual}"
+    )]
+    InvalidCodewordOpenedRowsQueries {
+        round: usize,
+        expected: usize,
+        actual: usize,
+    },
+    #[error(
+        "There should be num_whir_queries = {expected} round {round} merkle proofs, but there are {actual}"
+    )]
+    InvalidCodewordMerkleProofsQueries {
+        round: usize,
+        expected: usize,
+        actual: usize,
+    },
+    #[error(
+        "Round {round} opened row {opened_idx} should have length {expected}, but it has length {actual}"
+    )]
+    InvalidCodewordOpenedValues {
+        round: usize,
+        opened_idx: usize,
+        expected: usize,
+        actual: usize,
+    },
+    #[error(
+        "Round {round} merkle proof {opened_idx} should have depth {expected}, but it has depth {actual}"
+    )]
+    InvalidCodewordMerkleProofDepth {
+        round: usize,
+        opened_idx: usize,
+        expected: usize,
+        actual: usize,
+    },
+}
+
+pub fn verify_proof_shape(
+    mvk: &MultiStarkVerifyingKey0V2,
+    proof: &Proof,
+) -> Result<Vec<StackedLayout>, ProofShapeError> {
+    // TRACE HEIGHTS AND PUBLIC VALUES
+    let num_airs = mvk.per_air.len();
+    let l_skip = mvk.params.l_skip;
+
+    if proof.trace_vdata.len() != num_airs {
+        return ProofShapeError::invalid_vdata(ProofShapeVDataError::InvalidVDataLength {
+            len: proof.trace_vdata.len(),
+            num_airs,
+        });
+    } else if proof.public_values.len() != num_airs {
+        return ProofShapeError::invalid_vdata(ProofShapeVDataError::InvalidPublicValuesLength {
+            len: proof.public_values.len(),
+            num_airs,
+        });
+    }
+
+    for (air_idx, (vk, vdata, pvs)) in
+        izip!(&mvk.per_air, &proof.trace_vdata, &proof.public_values).enumerate()
+    {
+        if vdata.is_none() {
+            if vk.is_required {
+                return ProofShapeError::invalid_vdata(ProofShapeVDataError::RequiredAirNoVData {
+                    air_idx,
+                });
+            } else if !pvs.is_empty() {
+                return ProofShapeError::invalid_vdata(ProofShapeVDataError::PublicValuesNoVData {
+                    air_idx,
+                });
+            }
+        } else {
+            let vdata = vdata.as_ref().unwrap();
+            if vdata.cached_commitments.len() != vk.num_cached_mains() {
+                return ProofShapeError::invalid_vdata(
+                    ProofShapeVDataError::InvalidCachedCommitments {
+                        air_idx,
+                        expected: vk.num_cached_mains(),
+                        actual: vdata.cached_commitments.len(),
+                    },
+                );
+            } else if vdata.log_height > l_skip + mvk.params.n_stack {
+                return ProofShapeError::invalid_vdata(
+                    ProofShapeVDataError::LogHeightOutOfBounds {
+                        air_idx,
+                        l_skip,
+                        n_stack: mvk.params.n_stack,
+                        actual: vdata.log_height,
+                    },
+                );
+            } else if vk.params.num_public_values != pvs.len() {
+                return ProofShapeError::invalid_vdata(ProofShapeVDataError::InvalidPublicValues {
+                    air_idx,
+                    expected: vk.params.num_public_values,
+                    actual: pvs.len(),
+                });
+            }
+        }
+    }
+
+    let per_trace = mvk
+        .per_air
+        .iter()
+        .zip(&proof.trace_vdata)
+        .enumerate()
+        .filter_map(|(air_idx, (vk, vdata))| vdata.as_ref().map(|vdata| (air_idx, vk, vdata)))
+        .sorted_by_key(|(_, _, vdata)| Reverse(vdata.log_height))
+        .collect_vec();
+    let 
num_airs_present = per_trace.len(); + + // GKR PROOF SHAPE + let total_interactions = per_trace.iter().fold(0u64, |acc, (_, vk, vdata)| { + acc + ((vk.num_interactions() as u64) << max(vdata.log_height, l_skip)) + }); + let n_logup = calculate_n_logup(l_skip, total_interactions); + let num_gkr_rounds = if total_interactions == 0 { + 0 + } else { + l_skip + n_logup + }; + + if proof.gkr_proof.claims_per_layer.len() != num_gkr_rounds { + return ProofShapeError::invalid_gkr(GkrProofShapeError::InvalidClaimsPerLayer { + expected: num_gkr_rounds, + actual: proof.gkr_proof.claims_per_layer.len(), + }); + } else if proof.gkr_proof.sumcheck_polys.len() != num_gkr_rounds.saturating_sub(1) { + return ProofShapeError::invalid_gkr(GkrProofShapeError::InvalidSumcheckPolys { + expected: num_gkr_rounds.saturating_sub(1), + actual: proof.gkr_proof.sumcheck_polys.len(), + }); + } + + for (i, poly) in proof.gkr_proof.sumcheck_polys.iter().enumerate() { + if poly.len() != i + 1 { + return ProofShapeError::invalid_gkr(GkrProofShapeError::InvalidSumcheckPolyEvals { + round: i + 1, + expected: i + 1, + actual: poly.len(), + }); + } + } + + // BATCH CONSTRAINTS PROOF SHAPE + let batch_proof = &proof.batch_constraint_proof; + + let n_max = per_trace[0].2.log_height.saturating_sub(l_skip); + + let s_0_deg = (mvk.max_constraint_degree() + 1) * ((1 << l_skip) - 1); + if batch_proof.numerator_term_per_air.len() != num_airs_present { + return ProofShapeError::invalid_batch_constraint( + BatchProofShapeError::InvalidNumeratorTerms { + expected: num_airs_present, + actual: batch_proof.numerator_term_per_air.len(), + }, + ); + } else if batch_proof.denominator_term_per_air.len() != num_airs_present { + return ProofShapeError::invalid_batch_constraint( + BatchProofShapeError::InvalidDenominatorTerms { + expected: num_airs_present, + actual: batch_proof.denominator_term_per_air.len(), + }, + ); + } else if batch_proof.univariate_round_coeffs.len() != s_0_deg + 1 { + return ProofShapeError::invalid_batch_constraint( + BatchProofShapeError::InvalidUnivariateRoundCoeffs { + expected: s_0_deg + 1, + actual: batch_proof.univariate_round_coeffs.len(), + }, + ); + } else if batch_proof.sumcheck_round_polys.len() != n_max { + return ProofShapeError::invalid_batch_constraint( + BatchProofShapeError::InvalidSumcheckRoundPolys { + expected: n_max, + actual: batch_proof.sumcheck_round_polys.len(), + }, + ); + } else if batch_proof.column_openings.len() != num_airs_present { + return ProofShapeError::invalid_batch_constraint( + BatchProofShapeError::InvalidColumnOpeningsAirs { + expected: num_airs_present, + actual: batch_proof.column_openings.len(), + }, + ); + } + + for (i, evals) in batch_proof.sumcheck_round_polys.iter().enumerate() { + if evals.len() != mvk.max_constraint_degree() + 1 { + return ProofShapeError::invalid_batch_constraint( + BatchProofShapeError::InvalidSumcheckRoundPolyEvals { + round: i, + expected: mvk.max_constraint_degree() + 1, + actual: evals.len(), + }, + ); + } + } + + for (part_openings, &(air_idx, vk, _)) in batch_proof.column_openings.iter().zip(&per_trace) { + let need_rot = mvk.per_air[air_idx].params.need_rot; + let openings_per_col = if need_rot { 2 } else { 1 }; + if part_openings.len() != vk.num_parts() { + return ProofShapeError::invalid_batch_constraint( + BatchProofShapeError::InvalidColumnOpeningsPerAir { + air_idx, + expected: vk.num_parts(), + actual: part_openings.len(), + }, + ); + } else if part_openings[0].len() != vk.params.width.common_main * openings_per_col { + return 
ProofShapeError::invalid_batch_constraint( + BatchProofShapeError::InvalidColumnOpeningsPerAirMain { + air_idx, + expected: vk.params.width.common_main, + actual: part_openings[0].len(), + }, + ); + } else if let Some(preprocessed_width) = &vk.params.width.preprocessed { + if part_openings[1].len() != *preprocessed_width * openings_per_col { + return ProofShapeError::invalid_batch_constraint( + BatchProofShapeError::InvalidColumnOpeningsPerAirPreprocessed { + air_idx, + expected: *preprocessed_width, + actual: part_openings[1].len(), + }, + ); + } + } + + let cached_openings = &part_openings[1 + (vk.preprocessed_data.is_some() as usize)..]; + for (cached_idx, (col_opening, &width)) in cached_openings + .iter() + .zip(&vk.params.width.cached_mains) + .enumerate() + { + if col_opening.len() != width * openings_per_col { + return ProofShapeError::invalid_batch_constraint( + BatchProofShapeError::InvalidColumnOpeningsPerAirCached { + air_idx, + cached_idx, + expected: width, + actual: col_opening.len(), + }, + ); + } + } + } + + // STACKING PROOF SHAPE + let stacking_proof = &proof.stacking_proof; + + let s_0_deg = 2 * ((1 << l_skip) - 1); + if stacking_proof.univariate_round_coeffs.len() != s_0_deg + 1 { + return ProofShapeError::invalid_stacking( + StackingProofShapeError::InvalidUnivariateRoundCoeffs { + expected: s_0_deg + 1, + actual: stacking_proof.univariate_round_coeffs.len(), + }, + ); + } else if stacking_proof.sumcheck_round_polys.len() != mvk.params.n_stack { + return ProofShapeError::invalid_stacking( + StackingProofShapeError::InvalidSumcheckRoundPolys { + expected: mvk.params.n_stack, + actual: stacking_proof.sumcheck_round_polys.len(), + }, + ); + } + + let common_main_layout = StackedLayout::new( + l_skip, + mvk.params.n_stack + l_skip, + per_trace + .iter() + .map(|(_, vk, vdata)| (vk.params.width.common_main, vdata.log_height)) + .collect_vec(), + ); + + let other_layouts = per_trace + .iter() + .flat_map(|(_, vk, vdata)| { + vk.params + .width + .preprocessed + .iter() + .chain(&vk.params.width.cached_mains) + .copied() + .map(|width| (width, vdata.log_height)) + .collect_vec() + }) + .map(|sorted| StackedLayout::new(l_skip, mvk.params.n_stack + l_skip, vec![sorted])) + .collect_vec(); + + let layouts = [common_main_layout] + .into_iter() + .chain(other_layouts) + .collect_vec(); + + if stacking_proof.stacking_openings.len() != layouts.len() { + return ProofShapeError::invalid_stacking(StackingProofShapeError::InvalidStackOpenings { + expected: layouts.len(), + actual: stacking_proof.stacking_openings.len(), + }); + } + + for (commit_idx, (openings, layout)) in stacking_proof + .stacking_openings + .iter() + .zip(&layouts) + .enumerate() + { + let stacked_matrix_width = layout.sorted_cols.last().unwrap().2.col_idx + 1; + if openings.len() != stacked_matrix_width { + return ProofShapeError::invalid_stacking( + StackingProofShapeError::InvalidStackOpeningsPerMatrix { + commit_idx, + expected: stacked_matrix_width, + actual: openings.len(), + }, + ); + } + } + + // WHIR PROOF SHAPE + let whir_proof = &proof.whir_proof; + + let log_stacked_height = mvk.params.log_stacked_height(); + let num_whir_rounds = mvk.params.num_whir_rounds(); + let num_whir_sumcheck_rounds = mvk.params.num_whir_sumcheck_rounds(); + let k_whir = mvk.params.k_whir(); + debug_assert_ne!(num_whir_rounds, 0); + + if whir_proof.whir_sumcheck_polys.len() != num_whir_sumcheck_rounds { + return ProofShapeError::invalid_whir(WhirProofShapeError::InvalidSumcheckPolys { + expected: num_whir_sumcheck_rounds, + 
actual: whir_proof.whir_sumcheck_polys.len(), + }); + } else if whir_proof.codeword_commits.len() != num_whir_rounds - 1 { + return ProofShapeError::invalid_whir(WhirProofShapeError::InvalidCodewordCommits { + expected: num_whir_rounds - 1, + actual: whir_proof.codeword_commits.len(), + }); + } else if whir_proof.ood_values.len() != num_whir_rounds - 1 { + return ProofShapeError::invalid_whir(WhirProofShapeError::InvalidOodValues { + expected: num_whir_rounds - 1, + actual: whir_proof.ood_values.len(), + }); + } else if whir_proof.folding_pow_witnesses.len() != num_whir_sumcheck_rounds { + return ProofShapeError::invalid_whir(WhirProofShapeError::InvalidFoldingPowWitnesses { + expected: num_whir_sumcheck_rounds, + actual: whir_proof.folding_pow_witnesses.len(), + }); + } else if whir_proof.query_phase_pow_witnesses.len() != num_whir_rounds { + return ProofShapeError::invalid_whir(WhirProofShapeError::InvalidQueryPhasePowWitnesses { + expected: num_whir_rounds, + actual: whir_proof.query_phase_pow_witnesses.len(), + }); + } else if whir_proof.initial_round_opened_rows.len() != layouts.len() { + return ProofShapeError::invalid_whir(WhirProofShapeError::InvalidInitialRoundOpenedRows { + expected: layouts.len(), + actual: whir_proof.initial_round_opened_rows.len(), + }); + } else if whir_proof.initial_round_merkle_proofs.len() != layouts.len() { + return ProofShapeError::invalid_whir( + WhirProofShapeError::InvalidInitialRoundMerkleProofs { + expected: layouts.len(), + actual: whir_proof.initial_round_merkle_proofs.len(), + }, + ); + } else if whir_proof.codeword_opened_values.len() != num_whir_rounds - 1 { + return ProofShapeError::invalid_whir(WhirProofShapeError::InvalidCodewordOpenedRows { + expected: num_whir_rounds - 1, + actual: whir_proof.codeword_opened_values.len(), + }); + } else if whir_proof.codeword_merkle_proofs.len() != num_whir_rounds - 1 { + return ProofShapeError::invalid_whir(WhirProofShapeError::InvalidCodewordMerkleProofs { + expected: num_whir_rounds - 1, + actual: whir_proof.codeword_merkle_proofs.len(), + }); + } else if whir_proof.final_poly.len() != 1 << mvk.params.log_final_poly_len() { + return ProofShapeError::invalid_whir(WhirProofShapeError::InvalidFinalPolyLen { + expected: 1 << mvk.params.log_final_poly_len(), + actual: whir_proof.final_poly.len(), + }); + } + + let initial_whir_round_num_queries = mvk.params.whir.rounds[0].num_queries; + for (commit_idx, (opened_rows, merkle_proofs)) in whir_proof + .initial_round_opened_rows + .iter() + .zip(&whir_proof.initial_round_merkle_proofs) + .enumerate() + { + if opened_rows.len() != initial_whir_round_num_queries { + return ProofShapeError::invalid_whir( + WhirProofShapeError::InvalidInitialRoundOpenedRowsQueries { + commit_idx, + expected: initial_whir_round_num_queries, + actual: opened_rows.len(), + }, + ); + } else if merkle_proofs.len() != initial_whir_round_num_queries { + return ProofShapeError::invalid_whir( + WhirProofShapeError::InvalidInitialRoundMerkleProofsQueries { + commit_idx, + expected: initial_whir_round_num_queries, + actual: merkle_proofs.len(), + }, + ); + } + let width = stacking_proof.stacking_openings[commit_idx].len(); + for (opened_idx, rows) in opened_rows.iter().enumerate() { + if rows.len() != 1 << k_whir { + return ProofShapeError::invalid_whir( + WhirProofShapeError::InvalidInitialRoundOpenedRowK { + opened_idx, + commit_idx, + expected: 1 << k_whir, + actual: rows.len(), + }, + ); + } + for (row_idx, row) in rows.iter().enumerate() { + if row.len() != width { + return 
ProofShapeError::invalid_whir(
+                        WhirProofShapeError::InvalidInitialRoundOpenedRowWidth {
+                            row_idx,
+                            commit_idx,
+                            expected: width,
+                            actual: row.len(),
+                        },
+                    );
+                }
+            }
+        }
+
+        let merkle_depth = (log_stacked_height + mvk.params.log_blowup).saturating_sub(k_whir);
+        for (opened_idx, proof) in merkle_proofs.iter().enumerate() {
+            if proof.len() != merkle_depth {
+                return ProofShapeError::invalid_whir(
+                    WhirProofShapeError::InvalidInitialRoundMerkleProofDepth {
+                        opened_idx,
+                        commit_idx,
+                        expected: merkle_depth,
+                        actual: proof.len(),
+                    },
+                );
+            }
+        }
+    }
+
+    for (round_minus_one, (opened_values_per_query, merkle_proofs)) in whir_proof
+        .codeword_opened_values
+        .iter()
+        .zip(&whir_proof.codeword_merkle_proofs)
+        .take(num_whir_rounds - 1)
+        .enumerate()
+    {
+        let round = round_minus_one + 1;
+        let num_queries = mvk.params.whir.rounds[round].num_queries;
+        if opened_values_per_query.len() != num_queries {
+            return ProofShapeError::invalid_whir(
+                WhirProofShapeError::InvalidCodewordOpenedRowsQueries {
+                    round,
+                    expected: num_queries,
+                    actual: opened_values_per_query.len(),
+                },
+            );
+        } else if merkle_proofs.len() != num_queries {
+            return ProofShapeError::invalid_whir(
+                WhirProofShapeError::InvalidCodewordMerkleProofsQueries {
+                    round,
+                    expected: num_queries,
+                    actual: merkle_proofs.len(),
+                },
+            );
+        }
+
+        for (opened_idx, opened_values) in opened_values_per_query.iter().enumerate() {
+            if opened_values.len() != 1 << mvk.params.k_whir() {
+                return ProofShapeError::invalid_whir(
+                    WhirProofShapeError::InvalidCodewordOpenedValues {
+                        round,
+                        opened_idx,
+                        expected: 1 << mvk.params.k_whir(),
+                        actual: opened_values.len(),
+                    },
+                );
+            }
+        }
+
+        let merkle_depth = log_stacked_height + mvk.params.log_blowup - k_whir - round;
+        for (opened_idx, proof) in merkle_proofs.iter().enumerate() {
+            if proof.len() != merkle_depth {
+                return ProofShapeError::invalid_whir(
+                    WhirProofShapeError::InvalidCodewordMerkleProofDepth {
+                        round,
+                        opened_idx,
+                        expected: merkle_depth,
+                        actual: proof.len(),
+                    },
+                );
+            }
+        }
+    }
+
+    Ok(layouts)
+}
diff --git a/crates/stark-backend-v2/src/verifier/stacked_reduction.rs b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs
new file mode 100644
index 00000000..d6e2e436
--- /dev/null
+++ b/crates/stark-backend-v2/src/verifier/stacked_reduction.rs
@@ -0,0 +1,451 @@
+use std::iter::zip;
+
+use itertools::Itertools;
+use p3_field::FieldAlgebra;
+use thiserror::Error;
+use tracing::{debug, instrument};
+
+use crate::{
+    poly_common::{
+        eval_eq_mle, eval_eq_prism, eval_in_uni, eval_rot_kernel_prism, horner_eval,
+        interpolate_quadratic_at_012,
+    },
+    poseidon2::sponge::FiatShamirTranscript,
+    proof::{column_openings_by_rot, StackingProof},
+    prover::stacked_pcs::StackedLayout,
+    EF, F,
+};
+
+#[derive(Error, Debug, PartialEq, Eq)]
+pub enum StackedReductionError {
+    #[error("s_0 does not match s_0 polynomial evaluation sum: {s_0} != {s_0_sum_eval}")]
+    S0Mismatch { s_0: EF, s_0_sum_eval: EF },
+
+    #[error("s_n(u_n) does not match claimed q(u) sum: {claim} != {final_sum}")]
+    FinalSumMismatch { claim: EF, final_sum: EF },
+}
+
+/// `need_rot_per_commit` must be indexed per commit, then per matrix within that commit, with
+/// traces in sorted (descending height) order.
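+///
+/// `column_openings` must be indexed per AIR (sorted by descending trace height), then per
+/// part (common main, preprocessed, then cached), then per column; see the SETUP notes in
+/// the body below.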
+#[allow(clippy::too_many_arguments)]
+#[instrument(level = "debug", skip_all)]
+pub fn verify_stacked_reduction<TS: FiatShamirTranscript>(
+    transcript: &mut TS,
+    proof: &StackingProof,
+    layouts: &[StackedLayout],
+    need_rot_per_commit: &[Vec<bool>],
+    l_skip: usize,
+    n_stack: usize,
+    column_openings: &Vec<Vec<Vec<EF>>>,
+    r: &[EF],
+    omega_shift_pows: &[F],
+) -> Result<Vec<EF>, StackedReductionError> {
+    /*
+     * SETUP
+     *
+     * We start by setting up for the rounds below. Most importantly, we need to ensure that
+     * we process column_openings in the same order as the stacked reduction prover. The
+     * prover orders the claims per commit -> per column (as in layouts), but column_openings
+     * is per AIR -> per part (common main, preprocessed, then cached) -> per column. Note
+     * that the verifier needs to compute and pass in need_rot_per_commit, which is expected
+     * to follow the same sorted (descending trace height) order as column_openings.
+     */
+    let omega_order = omega_shift_pows.len();
+    let omega_order_f = F::from_canonical_usize(omega_order);
+
+    debug_assert_eq!(layouts.len(), need_rot_per_commit.len());
+    let mut lambda_idx = 0usize;
+    let lambda_indices_per_layout: Vec<Vec<(usize, bool)>> = layouts
+        .iter()
+        .enumerate()
+        .map(|(commit_idx, layout)| {
+            let need_rot_for_commit = &need_rot_per_commit[commit_idx];
+            debug_assert_eq!(need_rot_for_commit.len(), layout.mat_starts.len());
+            layout
+                .sorted_cols
+                .iter()
+                .map(|&(mat_idx, _col_idx, _slice)| {
+                    lambda_idx += 1;
+                    (lambda_idx - 1, need_rot_for_commit[mat_idx])
+                })
+                .collect_vec()
+        })
+        .collect_vec();
+    let t_claims_len = lambda_idx;
+    let mut t_claims = Vec::with_capacity(t_claims_len);
+
+    // common main columns (commit 0)
+    for (trace_idx, parts) in column_openings.iter().enumerate() {
+        let need_rot = need_rot_per_commit[0][trace_idx];
+        t_claims.extend(column_openings_by_rot(&parts[0], need_rot));
+    }
+
+    // preprocessed and cached columns (commits 1..)
+    let mut commit_idx = 1usize;
+    for parts in column_openings {
+        for cols in parts.iter().skip(1) {
+            let need_rot = need_rot_per_commit[commit_idx][0];
+            t_claims.extend(column_openings_by_rot(cols, need_rot));
+            commit_idx += 1;
+        }
+    }
+
+    assert_eq!(t_claims.len(), t_claims_len);
+    debug!(?t_claims);
+
+    let lambda = transcript.sample_ext();
+    let lambda_sqr_powers = (lambda * lambda).powers().take(t_claims_len).collect_vec();
+
+    /*
+     * INITIAL UNIVARIATE ROUND
+     *
+     * In this round we compute s_0 = sum_i (t_i * lambda^i) from the column opening claims
+     * t_i and compare it to the s_1 polynomial in the proof. If the polynomial was correctly
+     * computed, then we should have s_0 == sum_{z in D} s_1(z).
+     *
+     * Note that we abuse the properties of the multiplicative subgroup D to speed up the
+     * computation of sum_{z in D} s_1(z). Suppose s_1(x) = a_0 + a_1 * x + ... + a_k * x^k.
+     * Because we have omega^{|D|} == 1, sum_{z in D} s_1(z) = |D| * (a_0 + a_{|D|} + ...).
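+     *
+     * For example, take l_skip = 2, so D has order |D| = 4 and s_1 has degree
+     * 2 * (2^l_skip - 1) = 6. The sums of z, z^2 and z^3 over D all vanish (each is a full
+     * sum over non-trivial roots of unity), so sum_{z in D} s_1(z) = 4 * (a_0 + a_4), which
+     * is exactly the step_by(omega_order) sum computed below.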
+     */
+    let s_0 = zip(&t_claims, &lambda_sqr_powers)
+        .map(|(&t_i, &lambda_i)| (t_i.0 + t_i.1 * lambda) * lambda_i)
+        .sum::<EF>();
+    let s_0_sum_eval = proof
+        .univariate_round_coeffs
+        .iter()
+        .step_by(omega_order)
+        .copied()
+        .sum::<EF>()
+        * omega_order_f;
+
+    if s_0 != s_0_sum_eval {
+        return Err(StackedReductionError::S0Mismatch { s_0, s_0_sum_eval });
+    }
+
+    for coeffs in &proof.univariate_round_coeffs {
+        transcript.observe_ext(*coeffs);
+    }
+
+    let mut u = vec![EF::ZERO; n_stack + 1];
+    u[0] = transcript.sample_ext();
+    debug!(round = 0, u_round = %u[0]);
+
+    let mut s_j_0 = s_0;
+    let mut claim = horner_eval(&proof.univariate_round_coeffs, u[0]);
+
+    /*
+     * SUMCHECK ROUNDS 1 TO N
+     *
+     * We sample the remaining n_stack coordinates of u using the transcript, and run the
+     * verifier sumcheck for rounds 1 to n_stack. We start by evaluating the univariate round
+     * polynomial at u_0, which we store as s_0(u_0). We then evaluate
+     * s_j(0) = s_{j - 1}(u_{j - 1}) - s_j(1) for each j, which we then use with s_j(1) and
+     * s_j(2) to interpolate s_j(u_j).
+     */
+
+    u.iter_mut().enumerate().skip(1).for_each(|(j, u_j)| {
+        let s_j_1 = proof.sumcheck_round_polys[j - 1][0];
+        let s_j_2 = proof.sumcheck_round_polys[j - 1][1];
+        transcript.observe_ext(s_j_1);
+        transcript.observe_ext(s_j_2);
+        *u_j = transcript.sample_ext();
+        s_j_0 = claim - s_j_1;
+        claim = interpolate_quadratic_at_012(&[s_j_0, s_j_1, s_j_2], *u_j);
+        debug!(round = %j, sum_claim = %claim);
+    });
+
+    /*
+     * FINAL VERIFICATION
+     *
+     * Finally, to verify that the claims about t_i(r) were properly reduced, we assert that
+     * the final s_{n_stack}(u_{n_stack}) == sum_j (lambda^j * q_{j'}(u) * h(u, r, b_j)), where
+     * each j maps to some (non-unique) j' and h(u, r, b_j) is either (a) eq(u_{n_j}, r_{n_j}) *
+     * eq(u_{> n_j}, b_j) or (b) rot(u_{n_j}, r_{n_j}) * eq(u_{> n_j}, b_j).
+     *
+     * It is up to the verifier to compute each h(u, r, b_j). Let q_coeffs[j'] be the sum of
+     * all lambda^j * h(u, r, b_j) such that j maps to j'; given claims q_{j'}(u), we thus want
+     * to assert s_{n_stack}(u_{n_stack}) == sum_{j'} q_{j'}(u) * q_coeffs[j'].
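+     *
+     * For example, if layout column j = 3 of some commit is stacked into stacked column
+     * j' = 2, then its identity claim contributes lambda^6 * eq(...) and, if its AIR needs
+     * rotations, its rotated claim contributes lambda^7 * rot(...), both accumulated into
+     * q_coeffs[2] of that commit.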
+ */ + let mut q_coeffs = proof + .stacking_openings + .iter() + .map(|vec| vec![EF::ZERO; vec.len()]) + .collect_vec(); + + layouts + .iter() + .enumerate() + .zip(q_coeffs.iter_mut()) + .for_each(|((commit_idx, layout), coeffs)| { + let lambda_indices = &lambda_indices_per_layout[commit_idx]; + layout + .sorted_cols + .iter() + .enumerate() + .for_each(|(col_idx, &(_, _, s))| { + let (lambda_idx, need_rot) = lambda_indices[col_idx]; + let n = s.log_height() as isize - l_skip as isize; + let n_lift = n.max(0) as usize; + let b = (l_skip + n_lift..l_skip + n_stack) + .map(|j| F::from_bool((s.row_idx >> j) & 1 == 1)) + .collect_vec(); + let eq_mle = eval_eq_mle(&u[n_lift + 1..], &b); + let ind = eval_in_uni(l_skip, n, u[0]); + let (l, rs_n) = if n.is_negative() { + ( + l_skip.wrapping_add_signed(n), + &[r[0].exp_power_of_2(-n as usize)] as &[_], + ) + } else { + (l_skip, &r[..=n_lift]) + }; + let eq_prism = eval_eq_prism(l, &u[..=n_lift], rs_n); + let mut batched = lambda_sqr_powers[lambda_idx] * eq_prism; + if need_rot { + let rot_kernel_prism = eval_rot_kernel_prism(l, &u[..=n_lift], rs_n); + batched += lambda_sqr_powers[lambda_idx] * lambda * rot_kernel_prism; + } + coeffs[s.col_idx] += eq_mle * batched * ind; + }); + }); + + let final_sum = q_coeffs.iter().zip(proof.stacking_openings.iter()).fold( + EF::ZERO, + |acc, (q_coeff_vec, q_j_vec)| { + acc + q_coeff_vec + .iter() + .zip(q_j_vec.iter()) + .fold(EF::ZERO, |acc, (&q_coeff, &q_j)| { + transcript.observe_ext(q_j); + acc + (q_coeff * q_j) + }) + }, + ); + + if claim != final_sum { + return Err(StackedReductionError::FinalSumMismatch { claim, final_sum }); + } + + Ok(u) +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use p3_dft::{Radix2Bowers, TwoAdicSubgroupDft}; + use p3_field::{FieldAlgebra, FieldExtensionAlgebra, PrimeField32, TwoAdicField}; + use p3_util::log2_ceil_usize; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + use super::*; + use crate::{poseidon2::sponge::DuplexSponge, prover::stacked_pcs::StackedSlice}; + + const N_STACK: usize = 4; + const L_SKIP: usize = 2; + + struct StackedReductionTestCase { + pub transcript: DuplexSponge, + pub proof: StackingProof, + pub layouts: Vec, + pub need_rot_per_commit: Vec>, + pub column_openings: Vec>>, + pub r: Vec, + pub omega_pows: Vec, + } + + fn compute_t( + q: impl Fn(&[EF]) -> EF, + r: &[EF], + b: &[F], + u: &[EF], + x: EF, + round: usize, + l_skip: usize, + ) -> EF { + let n_t = N_STACK - b.len(); + let mut sum = EF::ZERO; + for i in 0..(1 << (N_STACK - round)) { + let hypercube = (0..(N_STACK - round)) + .map(|bit_idx| EF::from_canonical_usize((i >> bit_idx) & 1)) + .collect_vec(); + let z = u + .iter() + .take(round) + .chain([x].iter()) + .chain(hypercube.iter()) + .copied() + .collect_vec(); + let eq_or_rot_eval = if ROT { + eval_rot_kernel_prism(l_skip, &z[..=n_t], &r[..=n_t]) + } else { + eval_eq_prism(l_skip, &z[..=n_t], &r[..=n_t]) + }; + sum += q(&z) * eq_or_rot_eval * eval_eq_mle(&z[n_t + 1..], b); + } + sum + } + + fn generate_random_linear_q(rng: &mut StdRng) -> impl Fn(&[EF]) -> EF { + let coeffs = (0..=N_STACK) + .map(|_| EF::from_canonical_usize(rng.random_range(0usize..100))) + .collect_vec(); + move |vals: &[EF]| { + coeffs + .iter() + .zip(vals) + .fold(EF::ZERO, |acc, (&coeff, &val)| acc + coeff * val) + } + } + + fn generate_single_column_test_case() -> StackedReductionTestCase { + let mut rng = StdRng::from_seed([42; 32]); + let omega_pows = F::two_adic_generator(L_SKIP) + .powers() + .take(1 << L_SKIP) + .collect_vec(); + + let slice = 
StackedSlice::new(0, 0, L_SKIP + N_STACK); + let layout = StackedLayout::from_raw_parts(L_SKIP, L_SKIP + N_STACK, vec![(0, 0, slice)]); + + let q = generate_random_linear_q(&mut rng); + let r = (0..=N_STACK) + .map(|_| EF::from_canonical_u32(rng.random_range(0..F::ORDER_U32))) + .collect_vec(); + let n = slice.log_height() - L_SKIP; + let b = (L_SKIP + n..L_SKIP + N_STACK) + .map(|j| F::from_bool((slice.row_idx >> j) & 1 == 1)) + .collect_vec(); + let mut u = vec![]; + + let t = omega_pows.iter().fold(EF::ZERO, |acc, &omega| { + acc + compute_t::(&q, &r, &b, &u, EF::from_base(omega), 0, L_SKIP) + }); + let t_rot = omega_pows.iter().fold(EF::ZERO, |acc, &omega| { + acc + compute_t::(&q, &r, &b, &u, EF::from_base(omega), 0, L_SKIP) + }); + let column_openings = vec![vec![vec![t, t_rot]]]; + + let mut transcript = DuplexSponge::default(); + let lambda = transcript.sample_ext(); + + let s_0_deg = 2 * ((1 << L_SKIP) - 1); + let log_dft_size = log2_ceil_usize(s_0_deg + 1); + let two_adic_gen = F::two_adic_generator(log_dft_size); + + let univariate_round_evals = two_adic_gen + .powers() + .take(1 << log_dft_size) + .map(|z| { + compute_t::(&q, &r, &b, &u, z.into(), 0, L_SKIP) + + lambda * compute_t::(&q, &r, &b, &u, z.into(), 0, L_SKIP) + }) + .collect_vec(); + let univariate_round_coeffs = Radix2Bowers.coset_idft(univariate_round_evals, EF::ONE); + + for coeffs in &univariate_round_coeffs { + transcript.observe_ext(*coeffs); + } + + let mut sumcheck_round_polys = vec![]; + u.push(transcript.sample_ext()); + for round in 1..=N_STACK { + sumcheck_round_polys.push([ + compute_t::(&q, &r, &b, &u, EF::ONE, round, L_SKIP) + + lambda * compute_t::(&q, &r, &b, &u, EF::ONE, round, L_SKIP), + compute_t::(&q, &r, &b, &u, EF::TWO, round, L_SKIP) + + lambda * compute_t::(&q, &r, &b, &u, EF::TWO, round, L_SKIP), + ]); + transcript.observe_ext(sumcheck_round_polys[round - 1][0]); + transcript.observe_ext(sumcheck_round_polys[round - 1][1]); + u.push(transcript.sample_ext()); + } + + let q_at_u = q(&u); + transcript.observe_ext(q_at_u); + + let proof = StackingProof { + univariate_round_coeffs, + sumcheck_round_polys, + stacking_openings: vec![vec![q_at_u]], + }; + + StackedReductionTestCase { + transcript: DuplexSponge::default(), + proof, + layouts: vec![layout], + need_rot_per_commit: vec![vec![true]], + column_openings, + r, + omega_pows, + } + } + + #[test] + fn verify_single_column_test() { + let mut test_case = generate_single_column_test_case(); + verify_stacked_reduction( + &mut test_case.transcript, + &test_case.proof, + &test_case.layouts, + &test_case.need_rot_per_commit, + L_SKIP, + N_STACK, + &test_case.column_openings, + &test_case.r, + &test_case.omega_pows, + ) + .unwrap(); + } + + #[test] + fn single_column_univariate_round_negative_test() { + let mut test_case = generate_single_column_test_case(); + test_case.proof.univariate_round_coeffs[0] += EF::ONE; + verify_stacked_reduction( + &mut test_case.transcript, + &test_case.proof, + &test_case.layouts, + &test_case.need_rot_per_commit, + L_SKIP, + N_STACK, + &test_case.column_openings, + &test_case.r, + &test_case.omega_pows, + ) + .unwrap_err(); + } + + #[test] + fn single_column_sumcheck_rounds_negative_test() { + let mut test_case = generate_single_column_test_case(); + test_case.proof.sumcheck_round_polys[N_STACK - 1][0] += EF::ONE; + verify_stacked_reduction( + &mut test_case.transcript, + &test_case.proof, + &test_case.layouts, + &test_case.need_rot_per_commit, + L_SKIP, + N_STACK, + &test_case.column_openings, + &test_case.r, + 
&test_case.omega_pows, + ) + .unwrap_err(); + } + + #[test] + fn single_column_stacking_openings_negative_test() { + let mut test_case = generate_single_column_test_case(); + test_case.proof.stacking_openings[0][0] += EF::ONE; + verify_stacked_reduction( + &mut test_case.transcript, + &test_case.proof, + &test_case.layouts, + &test_case.need_rot_per_commit, + L_SKIP, + N_STACK, + &test_case.column_openings, + &test_case.r, + &test_case.omega_pows, + ) + .unwrap_err(); + } +} diff --git a/crates/stark-backend-v2/src/verifier/sumcheck.rs b/crates/stark-backend-v2/src/verifier/sumcheck.rs new file mode 100644 index 00000000..46d328b4 --- /dev/null +++ b/crates/stark-backend-v2/src/verifier/sumcheck.rs @@ -0,0 +1,110 @@ +use p3_field::{ExtensionField, Field, FieldAlgebra}; +use p3_util::log2_strict_usize; +use tracing::debug; + +use crate::{ + poseidon2::sponge::FiatShamirTranscript, + prover::sumcheck::{SumcheckCubeProof, SumcheckPrismProof}, + EF, +}; + +pub fn verify_sumcheck_multilinear( + transcript: &mut TS, + proof: &SumcheckCubeProof, +) -> Result<(), String> +where + EF: ExtensionField, +{ + let SumcheckCubeProof { + sum_claim, + round_polys_eval, + eval_claim, + } = proof; + let n = round_polys_eval.len(); + + let mut cur_sum = *sum_claim; + #[allow(clippy::needless_range_loop)] + for round in 0..n { + assert_eq!(round_polys_eval[round].len(), 1); + let s_1 = round_polys_eval[round][0]; + let s_0 = cur_sum - s_1; + + if round == 0 { + transcript.observe_ext(*sum_claim); + } + transcript.observe_ext(s_1); + + let r_round = transcript.sample_ext(); + debug!(%round, %r_round); + cur_sum = s_0 + (s_1 - s_0) * r_round; + } + if cur_sum != *eval_claim { + return Err("The provided evaluations are inconsistent".to_string()); + } + + transcript.observe_ext(*eval_claim); + + Ok(()) +} + +pub fn verify_sumcheck_prismalinear( + transcript: &mut TS, + l_skip: usize, + proof: &SumcheckPrismProof, +) -> Result<(), String> +where + EF: ExtensionField, +{ + let SumcheckPrismProof { + sum_claim, + s_0, + round_polys_eval, + eval_claim, + } = proof; + let n = round_polys_eval.len(); + + if log2_strict_usize(s_0.0.len()) != l_skip { + return Err(format!( + "Wrong proof shape: `s_0` must have length 2^l_skip = {}, but has {}", + (1 << l_skip), + s_0.0.len() + )); + } + + transcript.observe_ext(*sum_claim); + for x in s_0.0.iter() { + transcript.observe_ext(*x); + } + let r_0 = transcript.sample_ext(); + debug!(round = 0, r_round = %r_0); + + if *sum_claim != s_0.0[0] * EF::from_canonical_usize(s_0.0.len()) { + return Err(format!( + "`sum_claim` does not equal the sum of `s_0` at all the roots of unity: {} != {}", + *sum_claim, + s_0.0[0] * EF::from_canonical_usize(s_0.0.len()) + )); + } + + let mut cur_sum = s_0.eval_at_point(r_0); + #[allow(clippy::needless_range_loop)] + for round in 0..n { + debug!(%round, %cur_sum); + assert_eq!(round_polys_eval[round].len(), 1); + let s_1 = round_polys_eval[round][0]; + let s_0 = cur_sum - s_1; + + transcript.observe_ext(s_1); + let r_round = transcript.sample_ext(); + debug!(%round, %r_round); + cur_sum = s_0 + (s_1 - s_0) * r_round; + } + + if cur_sum != *eval_claim { + return Err("The provided evaluations are inconsistent".to_string()); + } + + transcript.observe_ext(*eval_claim); + + Ok(()) +} diff --git a/crates/stark-backend-v2/src/verifier/whir.rs b/crates/stark-backend-v2/src/verifier/whir.rs new file mode 100644 index 00000000..e2d4a81c --- /dev/null +++ b/crates/stark-backend-v2/src/verifier/whir.rs @@ -0,0 +1,725 @@ +use core::iter::zip; + +use 
itertools::{izip, Itertools};
+use p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra, TwoAdicField};
+use thiserror::Error;
+use tracing::instrument;
+
+use crate::{
+    poly_common::{eval_eq_mle, horner_eval, interpolate_quadratic_at_012, Squarable},
+    poseidon2::sponge::{
+        poseidon2_compress, poseidon2_hash_slice, poseidon2_tree_compress, FiatShamirTranscript,
+    },
+    proof::WhirProof,
+    prover::poly::Mle,
+    Digest, SystemParams, EF, F,
+};
+
+#[inline]
+fn ensure(cond: bool, err: VerifyWhirError) -> Result<(), VerifyWhirError> {
+    if cond {
+        Ok(())
+    } else {
+        Err(err)
+    }
+}
+
+/// Verify a WHIR proof.
+///
+/// Assumes that all inputs have already been checked to have the correct sizes.
+#[instrument(level = "debug", skip_all)]
+pub fn verify_whir<TS: FiatShamirTranscript>(
+    transcript: &mut TS,
+    params: &SystemParams,
+    whir_proof: &WhirProof,
+    stacking_openings: &[Vec<EF>],
+    commitments: &[Digest],
+    u: &[EF],
+) -> Result<(), VerifyWhirError> {
+    let widths = stacking_openings
+        .iter()
+        .map(|v| v.len())
+        .collect::<Vec<_>>();
+
+    let mu = transcript.sample_ext();
+
+    let WhirProof {
+        whir_sumcheck_polys,
+        codeword_commits,
+        ood_values,
+        initial_round_opened_rows,
+        initial_round_merkle_proofs,
+        codeword_opened_values,
+        codeword_merkle_proofs,
+        folding_pow_witnesses,
+        query_phase_pow_witnesses,
+        final_poly,
+    } = whir_proof;
+
+    let m = params.l_skip + params.n_stack;
+    let k_whir = params.k_whir();
+    debug_assert_eq!((m - params.log_final_poly_len()) % k_whir, 0);
+    let num_whir_rounds = params.num_whir_rounds();
+    let mut log_rs_domain_size = m + params.log_blowup;
+    debug_assert!(params.num_whir_sumcheck_rounds() <= m);
+    debug_assert_eq!(
+        folding_pow_witnesses.len(),
+        params.num_whir_sumcheck_rounds()
+    );
+
+    let mut sumcheck_poly_iter = whir_sumcheck_polys.iter();
+    let mut folding_pow_iter = folding_pow_witnesses.iter();
+    let mu_pows: Vec<_> = mu.powers().take(widths.iter().sum::<usize>()).collect();
+    let mut claim = stacking_openings
+        .iter()
+        .flatten()
+        .zip(mu_pows.iter())
+        .fold(EF::ZERO, |acc, (&opening, &mu_pow)| acc + mu_pow * opening);
+
+    let mut gammas = Vec::with_capacity(num_whir_rounds);
+    let mut zs = Vec::with_capacity(num_whir_rounds);
+    let mut z0s = Vec::with_capacity(num_whir_rounds);
+    let mut alphas = Vec::with_capacity(m);
+
+    debug_assert_eq!(query_phase_pow_witnesses.len(), num_whir_rounds);
+    for (whir_round, (query_phase_pow_witness, round_params)) in
+        zip(query_phase_pow_witnesses, &params.whir.rounds).enumerate()
+    {
+        // A WHIR round consists of the following steps:
+        // 1) Run k rounds of sumcheck to obtain polynomial f'.
+        // 2) On non-final rounds, observe the commitment to f' on the shifted domain.
+        // 3) On non-final rounds, sample OOD point z0 and observe claim y0 =?= f'(z0).
+        // 4) Sample in-domain queries z_i and compute f'(z_i) from openings. On the first round,
+        //    the codeword is not committed directly; instead it is derived from the stacking
+        //    commitments. In all other rounds, the previous codeword is committed directly.
+        // 5) On non-final rounds, sample batching parameter gamma to define the next codeword
+        //    and derive the new WHIR constraint target (`claim`).
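+        //
+        // For example (matching the multiple-commitment test below), with
+        // m = l_skip + n_stack = 6, k_whir = 2 and log_final_poly_len = 2, there are
+        // (6 - 2) / 2 = 2 WHIR rounds and 2 * 2 = 4 folding sumcheck rounds in total, and
+        // log_rs_domain_size shrinks by 1 per round.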
+
+        let is_initial_round = whir_round == 0;
+        let is_final_round = whir_round == num_whir_rounds - 1;
+
+        let mut alphas_round = Vec::with_capacity(k_whir);
+
+        for _ in 0..k_whir {
+            if let Some(evals) = sumcheck_poly_iter.next() {
+                let &[ev1, ev2] = evals;
+
+                transcript.observe_ext(ev1);
+                transcript.observe_ext(ev2);
+
+                let pow_witness = *folding_pow_iter.next().unwrap();
+                if !transcript.check_witness(params.whir.folding_pow_bits, pow_witness) {
+                    return Err(VerifyWhirError::FoldingPoWInvalid);
+                }
+                let alpha = transcript.sample_ext();
+                alphas_round.push(alpha);
+
+                let ev0 = claim - ev1;
+                claim = interpolate_quadratic_at_012(&[ev0, ev1, ev2], alpha);
+            }
+        }
+
+        let y0 = if is_final_round {
+            // Observe the final polynomial before the queries on the final
+            // round.
+            for coeff in final_poly {
+                transcript.observe_ext(*coeff);
+            }
+            None
+        } else {
+            let commit = codeword_commits[whir_round];
+            transcript.observe_commit(commit);
+
+            let z0 = transcript.sample_ext();
+            z0s.push(z0);
+
+            let y0 = ood_values[whir_round];
+            transcript.observe_ext(y0);
+            Some(y0)
+        };
+
+        if !transcript.check_witness(params.whir.query_phase_pow_bits, *query_phase_pow_witness) {
+            return Err(VerifyWhirError::QueryPhasePoWInvalid);
+        }
+
+        let num_queries = round_params.num_queries;
+        let query_indices =
+            (0..num_queries).map(|_| transcript.sample_bits(log_rs_domain_size - k_whir));
+
+        let mut zs_round = Vec::with_capacity(num_queries);
+        let mut ys_round = Vec::with_capacity(num_queries);
+
+        let omega = F::two_adic_generator(log_rs_domain_size);
+        for (query_idx, index) in query_indices.into_iter().enumerate() {
+            let zi_root = omega.exp_u64(index as u64);
+            let zi = zi_root.exp_power_of_2(k_whir);
+
+            let yi = if is_initial_round {
+                let mut codeword_vals = vec![EF::ZERO; 1 << k_whir];
+                let mut mu_pow_iter = mu_pows.iter();
+                for (&commit, &width, opened_rows_per_query, merkle_proofs) in izip!(
+                    commitments,
+                    &widths,
+                    initial_round_opened_rows,
+                    initial_round_merkle_proofs
+                ) {
+                    let opened_rows = &opened_rows_per_query[query_idx];
+                    let leaf_hashes = opened_rows
+                        .iter()
+                        .map(|opened_row| poseidon2_hash_slice(opened_row))
+                        .collect_vec();
+                    let query_digest = poseidon2_tree_compress(leaf_hashes);
+                    let merkle_proof = &merkle_proofs[query_idx];
+                    merkle_verify(commit, index, query_digest, merkle_proof)?;
+
+                    for c in 0..width {
+                        let mu_pow = mu_pow_iter.next().unwrap(); // ok; mu_pows has total_width length
+                        for j in 0..(1 << k_whir) {
+                            codeword_vals[j] += *mu_pow * opened_rows[j][c];
+                        }
+                    }
+                }
+                binary_k_fold(codeword_vals, &alphas_round, zi_root)
+            } else {
+                let opened_values = codeword_opened_values[whir_round - 1][query_idx].clone();
+                let merkle_proof = &codeword_merkle_proofs[whir_round - 1][query_idx];
+                let leaf_hashes = opened_values
+                    .iter()
+                    .map(|opened_value| poseidon2_hash_slice(opened_value.as_base_slice()))
+                    .collect_vec();
+                let query_digest = poseidon2_tree_compress(leaf_hashes);
+                merkle_verify(
+                    codeword_commits[whir_round - 1],
+                    index,
+                    query_digest,
+                    merkle_proof,
+                )?;
+                binary_k_fold(opened_values, &alphas_round, zi_root)
+            };
+            zs_round.push(zi);
+            ys_round.push(yi);
+        }
+        // We sample `gamma` even in the final round: there are no observations after this
+        // challenge, and sampling it regardless serves purely to unify the verifier logic.
+        // Rather than checking that `final_poly(zi) = yi` for all `i` in the last round, we
+        // accumulate those claims into `claim`; the final WHIR check below then performs
+        // this check for us (now with high probability).
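+        // Concretely: claim <- claim + gamma * y0 + sum_i gamma^{i + 2} * y_i, matching the
+        // gamma.powers().skip(2) accumulation below.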
+ let gamma = transcript.sample_ext(); + if let Some(y0) = y0 { + claim += y0 * gamma; + } + for (yi, gamma_pow) in ys_round.iter().zip(gamma.powers().skip(2)) { + claim += *yi * gamma_pow; + } + gammas.push(gamma); + zs.push(zs_round); + alphas.extend(alphas_round); + + log_rs_domain_size -= 1; + } + debug_assert!(sumcheck_poly_iter.next().is_none()); + + ensure( + final_poly.len() == 1 << params.log_final_poly_len(), + VerifyWhirError::FinalPolyDegree, + )?; + + debug_assert_eq!(alphas.len(), k_whir * num_whir_rounds); + debug_assert_eq!(z0s.len(), num_whir_rounds - 1); + debug_assert_eq!(zs.len(), num_whir_rounds); + debug_assert_eq!(gammas.len(), num_whir_rounds); + + // Here we perform the final WHIR check, which requires us to compute + // + // sum_{b in H_{m-t}} f(b) (eq(u, alpha || b) + + // sum_i sum_j gamma_{i,j} eq(pow(z_i) alpha[ki..] || b)), + // + // where || denotes concatenation. + // + // If we let u' = u[..t] and u'' = u[t..], then by factoring we can rewrite the term + // + // sum_{b in H_{m-t}} f(b) eq(u, alpha || b) = eq(u', alpha) * + // sum_{b in H_{m-t}} f(b) eq(u'', b) + // = eq(u', alpha) * f(u''). + // + // Similar algebra allows us to control the terms with eq(pow(z_i)). Note that here we actually + // end up with f(pow(z_i^{2^p})) for some power p, which is a univariate evaluation. + let t = k_whir * num_whir_rounds; + let f = Mle::from_coeffs(final_poly.clone()); + let mut acc = eval_eq_mle(&alphas[..t], &u[..t]) * f.eval_at_point_inplace(&u[t..]); + let mut j = k_whir; + for i in 0..num_whir_rounds { + let zis = &zs[i]; + let gamma = gammas[i]; + let alpha_slc = &alphas[j..t]; + let slc_len = (t - j) + 1; + + if i != num_whir_rounds - 1 { + let z0_pow = z0s[i].exp_powers_of_2().take(slc_len).collect_vec(); + let (z0_pow_max, z0_pow_left) = z0_pow.split_last().unwrap(); + acc += gamma + * eval_eq_mle(alpha_slc, z0_pow_left) + * horner_eval::(final_poly, *z0_pow_max); + } + + debug_assert_eq!(zis.len(), params.whir.rounds[i].num_queries); + for (zi, gamma_pow) in zip(zis, gamma.powers().skip(2)) { + let zi_pow = zi.exp_powers_of_2().take(slc_len).collect_vec(); + let (zi_pow_max, zi_pow_left) = zi_pow.split_last().unwrap(); + acc += gamma_pow + * eval_eq_mle(alpha_slc, zi_pow_left) + * horner_eval::(final_poly, *zi_pow_max); + } + j += k_whir; + } + ensure(acc == claim, VerifyWhirError::FinalPolyConstraint) +} + +#[derive(Debug, Error, PartialEq, Eq)] +pub enum VerifyWhirError { + #[error("final polynomial has wrong degree")] + FinalPolyDegree, + #[error("folding proof-of-work witness check failed")] + FoldingPoWInvalid, + #[error("query phase proof-of-work witness check failed")] + QueryPhasePoWInvalid, + #[error("final polynomial doesn't explain queries")] + FinalPolyQueryMismatch, + #[error("final poly is not in the final constrained RS code")] + FinalPolyConstraint, + #[error("merkle verification failed")] + MerkleVerify, +} + +/// Evaluates the k-fold binary fold of `f` at `x^{2^k}` given its evaluations +/// `values` on the coset `H = {x, ωx, …, ω^{2^k-1}x}` and fold points `alphas`. +/// +/// Let `g₀ = f`. For `i >= 1` define +/// +/// gᵢ(Y) = fold(g_{i-1}; α_{i-1})(Y), +/// +/// where +/// +/// fold(h; α)(X²) = h(X) + (α - X) * (h(X) - h(-X)) / (2X). +/// +/// If `values = [f(x), f(ωx), …, f(ω^{2^k-1}x)]`, then +/// `binary_k_fold(values, alphas, x)` returns `g_k(x^{2^k})`. 
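+///
+/// For example, for `k = 1` we have `ω = -1`, so `binary_k_fold(vec![f(x), f(-x)], &[α], x)`
+/// returns `f(x) + (α - x) * (f(x) - f(-x)) / (2x) = g_1(x²)`; see `test_fold_single` below.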
+pub fn binary_k_fold(mut values: Vec<EF>, alphas: &[EF], x: F) -> EF {
+    let n = values.len();
+    let k = alphas.len();
+    debug_assert_eq!(n, 1 << k);
+
+    let omega_k = F::two_adic_generator(k);
+    let omega_k_inv = omega_k.inverse();
+
+    let tw = omega_k.powers().take(1 << (k - 1)).collect_vec();
+    let inv_tw = omega_k_inv.powers().take(1 << (k - 1)).collect_vec();
+
+    for (j, (&alpha, x_pow, x_inv_pow)) in izip!(
+        alphas.iter(),
+        x.exp_powers_of_2(),
+        x.inverse().exp_powers_of_2()
+    )
+    .enumerate()
+    {
+        let m = n >> (j + 1);
+        let (lo, hi) = values.split_at_mut(m);
+
+        for i in 0..m {
+            let t = tw[i << j] * x_pow;
+            let t_inv = inv_tw[i << j] * x_inv_pow;
+            lo[i] += (alpha - t) * (lo[i] - hi[i]) * t_inv.halve();
+        }
+    }
+    values[0]
+}
+
+pub fn merkle_verify(
+    root: Digest,
+    mut idx: u32,
+    leaf_hash: Digest,
+    merkle_proof: &[Digest],
+) -> Result<(), VerifyWhirError> {
+    let mut cur = leaf_hash;
+    for &sibling in merkle_proof {
+        cur = if idx & 1 == 0 {
+            poseidon2_compress(cur, sibling)
+        } else {
+            poseidon2_compress(sibling, cur)
+        };
+        idx >>= 1;
+    }
+    if root != cur {
+        Err(VerifyWhirError::MerkleVerify)
+    } else {
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use itertools::Itertools;
+    use openvm_stark_backend::prover::MatrixDimensions;
+    use openvm_stark_sdk::config::{
+        log_up_params::log_up_security_params_baby_bear_100_bits, setup_tracing_with_log_level,
+    };
+    use p3_field::{Field, FieldAlgebra, TwoAdicField};
+    use rand::{rngs::StdRng, Rng, SeedableRng};
+    use test_case::test_case;
+    use tracing::Level;
+
+    use super::*;
+    use crate::{
+        poly_common::Squarable,
+        poseidon2::sponge::{DuplexSponge, DuplexSpongeRecorder, TranscriptHistory},
+        prover::{
+            poly::Ple, stacked_pcs::stacked_commit, whir::prove_whir_opening, ColMajorMatrix,
+            CpuBackendV2, DeviceMultiStarkProvingKeyV2, ProvingContextV2,
+        },
+        test_utils::{test_whir_config_small, DuplexSpongeValidator, FibFixture, TestFixture},
+        verifier::whir::{binary_k_fold, verify_whir, VerifyWhirError},
+        WhirConfig, WhirRoundConfig, EF, F,
+    };
+
+    fn generate_random_z(params: &SystemParams, rng: &mut StdRng) -> (Vec<EF>, Vec<EF>) {
+        let z_prism: Vec<_> = (0..params.n_stack + 1)
+            .map(|_| EF::from_wrapped_u64(rng.random()))
+            .collect();
+
+        let z_cube = {
+            let z_cube = z_prism[0]
+                .exp_powers_of_2()
+                .take(params.l_skip)
+                .chain(z_prism[1..].iter().copied())
+                .collect_vec();
+            debug_assert_eq!(z_cube.len(), params.n_stack + params.l_skip);
+            z_cube
+        };
+
+        (z_prism, z_cube)
+    }
+
+    fn stacking_openings_for_matrix(
+        params: &SystemParams,
+        z_prism: &[EF],
+        matrix: &ColMajorMatrix<F>,
+    ) -> Vec<EF> {
+        matrix
+            .columns()
+            .map(|col| {
+                Ple::from_evaluations(params.l_skip, col).eval_at_point(
+                    params.l_skip,
+                    z_prism[0],
+                    &z_prism[1..],
+                )
+            })
+            .collect()
+    }
+
+    fn run_whir_test(
+        params: SystemParams,
+        pk: DeviceMultiStarkProvingKeyV2<CpuBackendV2>,
+        ctx: &ProvingContextV2<CpuBackendV2>,
+    ) -> Result<(), VerifyWhirError> {
+        let (common_main_commit, common_main_pcs_data) = {
+            let traces = ctx
+                .common_main_traces()
+                .map(|(_, trace)| trace)
+                .collect_vec();
+            stacked_commit(
+                params.l_skip,
+                params.n_stack,
+                params.log_blowup,
+                params.k_whir(),
+                &traces,
+            )
+        };
+
+        let mut commits = vec![common_main_commit];
+        let mut committed_mats = vec![(&common_main_pcs_data.matrix, &common_main_pcs_data.tree)];
+        for (air_id, air_ctx) in &ctx.per_trace {
+            let pcs_datas = pk.per_air[*air_id]
+                .preprocessed_data
+                .iter()
+                .chain(&air_ctx.cached_mains);
+            for cd in pcs_datas {
+                let data = &cd.data;
+                committed_mats.push((&data.matrix, 
&data.tree)); + commits.push(data.commit()); + } + } + + let mut rng = StdRng::seed_from_u64(0); + + let (z_prism, z_cube) = generate_random_z(¶ms, &mut rng); + + let mut prover_sponge = DuplexSpongeRecorder::default(); + + let proof = prove_whir_opening( + &mut prover_sponge, + params.l_skip, + params.log_blowup, + params.whir(), + &committed_mats, + &z_cube, + ); + + let stacking_openings = committed_mats + .iter() + .map(|(matrix, _)| stacking_openings_for_matrix(¶ms, &z_prism, matrix)) + .collect_vec(); + + let mut verifier_sponge = DuplexSpongeValidator::new(prover_sponge.into_log()); + verify_whir( + &mut verifier_sponge, + ¶ms, + &proof, + &stacking_openings, + &commits, + &z_cube, + ) + } + + fn run_whir_fib_test(params: SystemParams) -> Result<(), VerifyWhirError> { + use crate::{ + poseidon2::sponge::DuplexSponge, prover::DeviceDataTransporterV2, + BabyBearPoseidon2CpuEngineV2, StarkEngineV2, + }; + let engine = BabyBearPoseidon2CpuEngineV2::::new(params.clone()); + let fib = FibFixture::new(0, 1, 1 << params.log_stacked_height()); + let (pk, _vk) = fib.keygen(&engine); + let pk = engine.device().transport_pk_to_device(&pk); + let ctx = fib.generate_proving_ctx(); + run_whir_test(params, pk, &ctx) + } + + #[test_case(0, 1, 1, 0)] + #[test_case(2, 1, 1, 2)] + #[test_case(2, 1, 2, 0)] + #[test_case(2, 1, 3, 1)] + #[test_case(2, 1, 4, 0)] + #[test_case(2, 2, 4, 0)] + fn test_whir_single_fib( + n_stack: usize, + log_blowup: usize, + k_whir: usize, + log_final_poly_len: usize, + ) -> Result<(), VerifyWhirError> { + setup_tracing_with_log_level(Level::DEBUG); + let l_skip = 2; + let whir = test_whir_config_small(log_blowup, l_skip + n_stack, k_whir, log_final_poly_len); + + let params = SystemParams { + l_skip, + n_stack, + log_blowup, + whir, + logup: log_up_security_params_baby_bear_100_bits(), + max_constraint_degree: 3, + }; + run_whir_fib_test(params) + } + + #[test] + fn test_fold_single() { + let mut rng = StdRng::seed_from_u64(0); + + let a0 = EF::from_wrapped_u32(rng.random()); + let a1 = EF::from_wrapped_u32(rng.random()); + let alpha = EF::from_wrapped_u32(rng.random()); + let x = F::from_wrapped_u32(rng.random()); + + let result = binary_k_fold(vec![a0, a1], &[alpha], x); + assert_eq!(result, a0 + (alpha - x) * (a0 - a1) * x.double().inverse()); + } + + #[test] + fn test_fold_double() { + let mut rng = StdRng::seed_from_u64(0); + + let a0 = EF::from_wrapped_u32(rng.random()); + let a1 = EF::from_wrapped_u32(rng.random()); + let a2 = EF::from_wrapped_u32(rng.random()); + let a3 = EF::from_wrapped_u32(rng.random()); + let alpha0 = EF::from_wrapped_u32(rng.random()); + let alpha1 = EF::from_wrapped_u32(rng.random()); + + let x = F::from_wrapped_u32(rng.random()); + + let result = binary_k_fold(vec![a0, a1, a2, a3], &[alpha0, alpha1], x); + let tw = F::two_adic_generator(2); + + let b0 = a0 + (alpha0 - x) * (a0 - a2) * x.double().inverse(); + let b1 = a1 + (alpha0 - (tw * x)) * (a1 - a3) * (tw * x).double().inverse(); + let x2 = x.square(); + let expected = b0 + (alpha1 - x2) * (b0 - b1) * x2.double().inverse(); + + assert_eq!(result, expected); + } + + fn whir_test_config(k_whir: usize) -> WhirConfig { + WhirConfig { + k: k_whir, + rounds: vec![ + WhirRoundConfig { num_queries: 6 }, + WhirRoundConfig { num_queries: 5 }, + ], + query_phase_pow_bits: 1, + folding_pow_bits: 1, + } + } + + #[test] + fn test_whir_multiple_commitments() -> Result<(), VerifyWhirError> { + setup_tracing_with_log_level(Level::DEBUG); + + let mut rng = StdRng::seed_from_u64(42); + + let params = 
SystemParams { + l_skip: 3, + n_stack: 3, + log_blowup: 1, + whir: whir_test_config(2), + logup: log_up_security_params_baby_bear_100_bits(), + max_constraint_degree: 3, + }; + + let n_rows = 1 << (params.n_stack + params.l_skip); + + let mut matrices = vec![]; + let mut commits = vec![]; + let mut trees = vec![]; + + let num_commitments = 5; + for _ in 0..num_commitments { + let n_cols = (rng.random::() % 10 + 3) as usize; + let data = (0..n_rows * n_cols) + .map(|_| F::from_wrapped_u64(rng.random())) + .collect_vec(); + let mat = ColMajorMatrix::new(data, n_cols); + + let (commit, pcs_data) = stacked_commit( + params.l_skip, + params.n_stack, + params.log_blowup, + params.k_whir(), + &[&mat], + ); + + matrices.push(mat); + commits.push(commit); + trees.push(pcs_data.tree); + } + + debug_assert_eq!(matrices[0].height(), 1 << (params.n_stack + params.l_skip)); + + let (z_prism, z_cube) = generate_random_z(¶ms, &mut rng); + + let mut prover_sponge = DuplexSpongeRecorder::default(); + + let committed_mats = matrices.iter().zip(trees.iter()).collect_vec(); + let proof = prove_whir_opening( + &mut prover_sponge, + params.l_skip, + params.log_blowup, + params.whir(), + &committed_mats, + &z_cube, + ); + + let stacking_openings: Vec> = matrices + .iter() + .map(|mat| stacking_openings_for_matrix(¶ms, &z_prism, mat)) + .collect(); + + let mut verifier_sponge = DuplexSpongeValidator::new(prover_sponge.into_log()); + verify_whir( + &mut verifier_sponge, + ¶ms, + &proof, + &stacking_openings, + &commits, + &z_cube, + ) + } + + #[test] + fn test_whir_multiple_commitments_negative() { + setup_tracing_with_log_level(Level::DEBUG); + + let mut rng = StdRng::seed_from_u64(42); + + let params = SystemParams { + l_skip: 3, + n_stack: 3, + log_blowup: 1, + whir: whir_test_config(2), + logup: log_up_security_params_baby_bear_100_bits(), + max_constraint_degree: 3, + }; + + let n_rows = 1 << (params.n_stack + params.l_skip); + + let mut matrices = vec![]; + let mut commits = vec![]; + let mut trees = vec![]; + + let num_commitments = 5; + for _ in 0..num_commitments { + let n_cols = (rng.random::() % 10 + 3) as usize; + let data = (0..n_rows * n_cols) + .map(|_| F::from_wrapped_u64(rng.random())) + .collect_vec(); + let mat = ColMajorMatrix::new(data, n_cols); + + let (commit, pcs_data) = stacked_commit( + params.l_skip, + params.n_stack, + params.log_blowup, + params.k_whir(), + &[&mat], + ); + + matrices.push(mat); + commits.push(commit); + trees.push(pcs_data.tree); + } + + debug_assert_eq!(matrices[0].height(), 1 << (params.n_stack + params.l_skip)); + + let (z_prism, z_cube) = generate_random_z(¶ms, &mut rng); + + let mut prover_sponge = DuplexSponge::default(); + let mut verifier_sponge = DuplexSponge::default(); + + let committed_mats = matrices.iter().zip(trees.iter()).collect_vec(); + let proof = prove_whir_opening( + &mut prover_sponge, + params.l_skip, + params.log_blowup, + params.whir(), + &committed_mats, + &z_cube, + ); + + let mut stacking_openings: Vec> = matrices + .iter() + .map(|mat| stacking_openings_for_matrix(¶ms, &z_prism, mat)) + .collect(); + + // change an opening to test soundness + stacking_openings[1][2] = EF::ONE; + + assert!(matches!( + verify_whir( + &mut verifier_sponge, + ¶ms, + &proof, + &stacking_openings, + &commits, + &z_cube, + ), + Err(VerifyWhirError::FinalPolyConstraint) + )); + } +} diff --git a/crates/stark-backend/src/air_builders/debug/mod.rs b/crates/stark-backend/src/air_builders/debug/mod.rs index 4e27525d..b4c4cf4a 100644 --- 
a/crates/stark-backend/src/air_builders/debug/mod.rs +++ b/crates/stark-backend/src/air_builders/debug/mod.rs @@ -10,17 +10,14 @@ use p3_matrix::{dense::RowMajorMatrixView, stack::VerticalPair}; use super::{symbolic::SymbolicConstraints, PartitionedAirBuilder, ViewPair}; use crate::{ config::{StarkGenericConfig, Val}, - interaction::{ - rap::InteractionPhaseAirBuilder, Interaction, InteractionBuilder, RapPhaseSeqKind, - SymbolicInteraction, - }, + interaction::{Interaction, InteractionBuilder, RapPhaseSeqKind}, keygen::types::StarkProvingKey, rap::{AnyRap, PermutationAirBuilderWithExposedValues}, }; mod check_constraints; -use check_constraints::*; +pub use check_constraints::*; use crate::interaction::BusIndex; @@ -279,20 +276,3 @@ where &[] } } - -// No-op -impl InteractionPhaseAirBuilder for DebugConstraintBuilder<'_, SC> { - fn finalize_interactions(&mut self) {} - - fn max_constraint_degree(&self) -> usize { - 0 - } - - fn rap_phase_seq_kind(&self) -> RapPhaseSeqKind { - self.rap_phase_seq_kind - } - - fn symbolic_interactions(&self) -> Vec>> { - vec![] - } -} diff --git a/crates/stark-backend/src/air_builders/symbolic/dag.rs b/crates/stark-backend/src/air_builders/symbolic/dag.rs index 4dafeb20..6eeb1540 100644 --- a/crates/stark-backend/src/air_builders/symbolic/dag.rs +++ b/crates/stark-backend/src/air_builders/symbolic/dag.rs @@ -15,7 +15,7 @@ use crate::{ /// A node in symbolic expression DAG. /// Basically replace `Arc`s in `SymbolicExpression` with node IDs. /// Intended to be serializable and deserializable. -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)] #[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))] #[repr(C)] pub enum SymbolicExpressionNode { @@ -95,25 +95,22 @@ pub(crate) fn build_symbolic_constraints_dag( constraints: &[SymbolicExpression], interactions: &[SymbolicInteraction], ) -> SymbolicConstraintsDag { - let mut expr_to_idx = FxHashMap::default(); - let mut nodes = Vec::new(); + let mut builder = SymbolicDagBuilder::new(); let mut constraint_idx: Vec = constraints .iter() - .map(|expr| topological_sort_symbolic_expr(expr, &mut expr_to_idx, &mut nodes)) + .map(|expr| builder.add_expr(expr)) .collect(); constraint_idx.sort(); + constraint_idx.dedup(); let interactions: Vec> = interactions .iter() .map(|interaction| { let fields: Vec = interaction .message .iter() - .map(|field_expr| { - topological_sort_symbolic_expr(field_expr, &mut expr_to_idx, &mut nodes) - }) + .map(|field_expr| builder.add_expr(field_expr)) .collect(); - let count = - topological_sort_symbolic_expr(&interaction.count, &mut expr_to_idx, &mut nodes); + let count = builder.add_expr(&interaction.count); Interaction { message: fields, count, @@ -122,11 +119,8 @@ pub(crate) fn build_symbolic_constraints_dag( } }) .collect(); - // Note[jpw]: there could be few nodes created after `constraint_idx` is built - // from `interactions` even though constraints already contain all interactions. - // This should be marginal and is not optimized for now. let constraints = SymbolicExpressionDag { - nodes, + nodes: builder.nodes, constraint_idx, }; SymbolicConstraintsDag { @@ -135,77 +129,207 @@ pub(crate) fn build_symbolic_constraints_dag( } } -/// `expr_to_idx` is a cache so that the `Arc<_>` references within symbolic expressions get -/// mapped to the same node ID if their underlying references are the same. 
-fn topological_sort_symbolic_expr<'a, F: Field>( - expr: &'a SymbolicExpression<F>, - expr_to_idx: &mut FxHashMap<&'a SymbolicExpression<F>, usize>, - nodes: &mut Vec<SymbolicExpressionNode<F>>, -) -> usize { - if let Some(&idx) = expr_to_idx.get(expr) { - return idx; +/// Builder for constructing a symbolic expression DAG with structural deduplication +/// and algebraic simplifications. +/// +/// Two caches are used: +/// - `expr_to_idx`: Fast path for expressions with the same Arc pointer +/// - `node_to_idx`: Structural deduplication - catches identical nodes with different Arc pointers +/// +/// Algebraic simplifications performed: +/// - Constant folding: `a + b` → `c`, `a - b` → `c`, `a * b` → `c`, `-a` → `c` (for constants a,b) +/// - `x + 0` → `x`, `0 + x` → `x` +/// - `x - 0` → `x` +/// - `x * 1` → `x`, `1 * x` → `x` +/// - `x * 0` → `0`, `0 * x` → `0` +/// - `x + (-y)` → `x - y` +/// - `x - (-y)` → `x + y` +pub struct SymbolicDagBuilder<F: Field> { + /// Cache: Arc pointer -> node index (fast path for same Arc) + pub expr_to_idx: FxHashMap<*const SymbolicExpression<F>, usize>, + /// Cache: node structure -> node index (structural deduplication) + pub node_to_idx: FxHashMap<SymbolicExpressionNode<F>, usize>, + /// Nodes in topological order + pub nodes: Vec<SymbolicExpressionNode<F>>, +} + +impl<F: Field> Default for SymbolicDagBuilder<F> { + fn default() -> Self { + Self::new() } - let node = match expr { - SymbolicExpression::Variable(var) => SymbolicExpressionNode::Variable(*var), - SymbolicExpression::IsFirstRow => SymbolicExpressionNode::IsFirstRow, - SymbolicExpression::IsLastRow => SymbolicExpressionNode::IsLastRow, - SymbolicExpression::IsTransition => SymbolicExpressionNode::IsTransition, - SymbolicExpression::Constant(cons) => SymbolicExpressionNode::Constant(*cons), - SymbolicExpression::Add { - x, - y, - degree_multiple, - } => { - let left_idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes); - let right_idx = topological_sort_symbolic_expr(y.as_ref(), expr_to_idx, nodes); - SymbolicExpressionNode::Add { - left_idx, - right_idx, - degree_multiple: *degree_multiple, - } +} + +impl<F: Field> SymbolicDagBuilder<F> { + pub fn new() -> Self { + Self { + expr_to_idx: FxHashMap::default(), + node_to_idx: FxHashMap::default(), + nodes: Vec::new(), } - SymbolicExpression::Sub { - x, - y, - degree_multiple, - } => { - let left_idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes); - let right_idx = topological_sort_symbolic_expr(y.as_ref(), expr_to_idx, nodes); - SymbolicExpressionNode::Sub { - left_idx, - right_idx, - degree_multiple: *degree_multiple, - } + } + + /// Add a symbolic expression to the DAG, returning its node index. + /// Performs structural deduplication and algebraic simplifications. + pub fn add_expr(&mut self, expr: &SymbolicExpression<F>) -> usize { + // Fast path: check if we've seen this exact Arc pointer before + let ptr = expr as *const SymbolicExpression<F>; + if let Some(&idx) = self.expr_to_idx.get(&ptr) { + return idx; } - SymbolicExpression::Neg { x, degree_multiple } => { - let idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes); - SymbolicExpressionNode::Neg { - idx, - degree_multiple: *degree_multiple, + + let idx = match expr { + SymbolicExpression::Variable(var) => { + self.intern_node(SymbolicExpressionNode::Variable(*var)) } - } - SymbolicExpression::Mul { - x, - y, - degree_multiple, - } => { - // An important case to remember: square will have Arc::as_ptr(&x) == Arc::as_ptr(&y) - // The `expr_to_idx` will ensure only one topological sort is done to prevent exponential - // behavior.
- let left_idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes); - let right_idx = topological_sort_symbolic_expr(y.as_ref(), expr_to_idx, nodes); - SymbolicExpressionNode::Mul { - left_idx, - right_idx, - degree_multiple: *degree_multiple, + SymbolicExpression::IsFirstRow => self.intern_node(SymbolicExpressionNode::IsFirstRow), + SymbolicExpression::IsLastRow => self.intern_node(SymbolicExpressionNode::IsLastRow), + SymbolicExpression::IsTransition => { + self.intern_node(SymbolicExpressionNode::IsTransition) + } + SymbolicExpression::Constant(cons) => { + self.intern_node(SymbolicExpressionNode::Constant(*cons)) + } + SymbolicExpression::Add { + x, + y, + degree_multiple, + } => { + let left_idx = self.add_expr(x.as_ref()); + let right_idx = self.add_expr(y.as_ref()); + + // Constant folding: const + const = const + if let (Some(a), Some(b)) = (self.get_const(left_idx), self.get_const(right_idx)) { + self.intern_node(SymbolicExpressionNode::Constant(a + b)) + } + // Simplify: 0 + x = x, x + 0 = x + else if self.is_const(left_idx, F::ZERO) { + right_idx + } else if self.is_const(right_idx, F::ZERO) { + left_idx + } + // Normalize: x + (-y) = x - y + else if let Some(neg_child_idx) = self.get_neg_child(right_idx) { + self.intern_node(SymbolicExpressionNode::Sub { + left_idx, + right_idx: neg_child_idx, + degree_multiple: *degree_multiple, + }) + } else { + self.intern_node(SymbolicExpressionNode::Add { + left_idx, + right_idx, + degree_multiple: *degree_multiple, + }) + } + } + SymbolicExpression::Sub { + x, + y, + degree_multiple, + } => { + let left_idx = self.add_expr(x.as_ref()); + let right_idx = self.add_expr(y.as_ref()); + + // Constant folding: const - const = const + if let (Some(a), Some(b)) = (self.get_const(left_idx), self.get_const(right_idx)) { + self.intern_node(SymbolicExpressionNode::Constant(a - b)) + } + // Simplify: x - 0 = x + else if self.is_const(right_idx, F::ZERO) { + left_idx + } + // Simplify: x - (-y) = x + y (double negation) + else if let Some(neg_child_idx) = self.get_neg_child(right_idx) { + self.intern_node(SymbolicExpressionNode::Add { + left_idx, + right_idx: neg_child_idx, + degree_multiple: *degree_multiple, + }) + } else { + self.intern_node(SymbolicExpressionNode::Sub { + left_idx, + right_idx, + degree_multiple: *degree_multiple, + }) + } + } + SymbolicExpression::Neg { x, degree_multiple } => { + let child_idx = self.add_expr(x.as_ref()); + + // Constant folding: -const = const + if let Some(c) = self.get_const(child_idx) { + self.intern_node(SymbolicExpressionNode::Constant(-c)) + } else { + self.intern_node(SymbolicExpressionNode::Neg { + idx: child_idx, + degree_multiple: *degree_multiple, + }) + } + } + SymbolicExpression::Mul { + x, + y, + degree_multiple, + } => { + // An important case to remember: square will have Arc::as_ptr(&x) == + // Arc::as_ptr(&y) The `expr_to_idx` will ensure only one recursive + // call is done to prevent exponential behavior. 
+ let left_idx = self.add_expr(x.as_ref()); + let right_idx = self.add_expr(y.as_ref()); + + // Constant folding: const * const = const + if let (Some(a), Some(b)) = (self.get_const(left_idx), self.get_const(right_idx)) { + self.intern_node(SymbolicExpressionNode::Constant(a * b)) + } + // Simplify: 0 * x = 0, x * 1 = x (return left_idx) + // Simplify: x * 0 = 0, 1 * x = x (return right_idx) + else if self.is_const(left_idx, F::ZERO) || self.is_const(right_idx, F::ONE) { + left_idx + } else if self.is_const(right_idx, F::ZERO) || self.is_const(left_idx, F::ONE) { + right_idx + } else { + self.intern_node(SymbolicExpressionNode::Mul { + left_idx, + right_idx, + degree_multiple: *degree_multiple, + }) + } } + }; + + self.expr_to_idx.insert(ptr, idx); + idx + } + + /// Intern a node: return existing index if the node already exists, otherwise add it. + fn intern_node(&mut self, node: SymbolicExpressionNode<F>) -> usize { + *self.node_to_idx.entry(node.clone()).or_insert_with(|| { + let idx = self.nodes.len(); + self.nodes.push(node); + idx + }) + } + + /// Check if a node at given index is a constant with specific value. + fn is_const(&self, idx: usize, val: F) -> bool { + matches!(&self.nodes[idx], SymbolicExpressionNode::Constant(c) if *c == val) + } + + /// Get constant value from a node, if it is a constant. + fn get_const(&self, idx: usize) -> Option<F> { + match &self.nodes[idx] { + SymbolicExpressionNode::Constant(c) => Some(*c), + _ => None, } - }; + } - let idx = nodes.len(); - nodes.push(node); - expr_to_idx.insert(expr, idx); - idx + /// If a node is a Neg, return its child index. + fn get_neg_child(&self, idx: usize) -> Option<usize> { + match &self.nodes[idx] { + SymbolicExpressionNode::Neg { idx, .. } => Some(*idx), + _ => None, + } + } } impl<F: Field> SymbolicExpressionDag<F> { @@ -324,8 +448,217 @@ mod tests { type F = BabyBear; + #[test] + fn test_duplicate_constraints_are_deduplicated() { + // Create a simple expression + let expr: SymbolicExpression<F> = SymbolicExpression::Variable(SymbolicVariable::new( + Entry::Main { + part_index: 0, + offset: 0, + }, + 0, + )); + + // Simulate calling assert_zero twice on the same expression + let constraints = vec![expr.clone(), expr.clone()]; + let interactions = vec![]; + + let dag = build_symbolic_constraints_dag(&constraints, &interactions); + + // Nodes are deduplicated - there's only 1 node in the DAG + assert_eq!( + dag.constraints.nodes.len(), + 1, + "Nodes should be deduplicated" + ); + + // constraint_idx should also be deduplicated + assert_eq!( + dag.constraints.constraint_idx, + vec![0], + "constraint_idx should be deduplicated" + ); + + // Only 1 constraint + assert_eq!( + dag.constraints.num_constraints(), + 1, + "Duplicate constraints should be deduplicated" + ); + } + + #[test] + fn test_structural_deduplication() { + // Create two structurally identical expressions with different Arc pointers + // This simulates: builder.assert_zero(x - ONE); builder.assert_zero(x - ONE); + let var = SymbolicVariable::<F>::new( + Entry::Main { + part_index: 0, + offset: 0, + }, + 0, + ); + let expr1 = SymbolicExpression::from(var) - SymbolicExpression::Constant(F::ONE); + let expr2 = SymbolicExpression::from(var) - SymbolicExpression::Constant(F::ONE); + + // These are different Arc allocations + assert!(!std::ptr::eq(&expr1, &expr2)); + + let constraints = vec![expr1, expr2]; + let dag = build_symbolic_constraints_dag(&constraints, &[]); + + // With structural deduplication, both expressions should map to the same node + // Nodes: Variable(0), Constant(1),
Sub(0,1) + assert_eq!(dag.constraints.nodes.len(), 3); + assert_eq!(dag.constraints.constraint_idx, vec![2]); + } + + #[test] + fn test_algebraic_simplifications() { + let var = SymbolicVariable::<F>::new( + Entry::Main { + part_index: 0, + offset: 0, + }, + 0, + ); + let x = SymbolicExpression::from(var); + let zero = SymbolicExpression::Constant(F::ZERO); + let one = SymbolicExpression::Constant(F::ONE); + + // Test x + 0 = x + let expr_add_zero = x.clone() + zero.clone(); + let dag = build_symbolic_constraints_dag(&[expr_add_zero], &[]); + // Should only have Variable node, no Add node + assert_eq!(dag.constraints.nodes.len(), 2); // Variable + Constant(0) interned but Add simplified away + assert!(matches!( + dag.constraints.nodes[dag.constraints.constraint_idx[0]], + SymbolicExpressionNode::Variable(_) + )); + + // Test 0 + x = x + let expr_zero_add = zero.clone() + x.clone(); + let dag = build_symbolic_constraints_dag(&[expr_zero_add], &[]); + assert!(matches!( + dag.constraints.nodes[dag.constraints.constraint_idx[0]], + SymbolicExpressionNode::Variable(_) + )); + + // Test x * 1 = x + let expr_mul_one = x.clone() * one.clone(); + let dag = build_symbolic_constraints_dag(&[expr_mul_one], &[]); + assert!(matches!( + dag.constraints.nodes[dag.constraints.constraint_idx[0]], + SymbolicExpressionNode::Variable(_) + )); + + // Test 1 * x = x + let expr_one_mul = one.clone() * x.clone(); + let dag = build_symbolic_constraints_dag(&[expr_one_mul], &[]); + assert!(matches!( + dag.constraints.nodes[dag.constraints.constraint_idx[0]], + SymbolicExpressionNode::Variable(_) + )); + + // Test x * 0 = 0 + let expr_mul_zero = x.clone() * zero.clone(); + let dag = build_symbolic_constraints_dag(&[expr_mul_zero], &[]); + assert!(matches!( + dag.constraints.nodes[dag.constraints.constraint_idx[0]], + SymbolicExpressionNode::Constant(c) if c == F::ZERO + )); + + // Test x - 0 = x + let expr_sub_zero = x.clone() - zero.clone(); + let dag = build_symbolic_constraints_dag(&[expr_sub_zero], &[]); + assert!(matches!( + dag.constraints.nodes[dag.constraints.constraint_idx[0]], + SymbolicExpressionNode::Variable(_) + )); + + // Test x + (-y) normalizes to x - y (same as Sub) + let y = SymbolicExpression::from(SymbolicVariable::<F>::new( + Entry::Main { + part_index: 0, + offset: 0, + }, + 1, + )); + let expr_add_neg = x.clone() + (-y.clone()); + let expr_sub = x.clone() - y.clone(); + let dag1 = build_symbolic_constraints_dag(&[expr_add_neg], &[]); + let dag2 = build_symbolic_constraints_dag(&[expr_sub], &[]); + // Both should produce the same constraint node (Sub) + assert!(matches!( + dag1.constraints.nodes[dag1.constraints.constraint_idx[0]], + SymbolicExpressionNode::Sub { .. } + )); + assert_eq!( + dag1.constraints.nodes[dag1.constraints.constraint_idx[0]], + dag2.constraints.nodes[dag2.constraints.constraint_idx[0]] + ); + + // Test x - (-y) = x + y + let expr_sub_neg = x.clone() - (-y.clone()); + let dag = build_symbolic_constraints_dag(&[expr_sub_neg], &[]); + assert!(matches!( + dag.constraints.nodes[dag.constraints.constraint_idx[0]], + SymbolicExpressionNode::Add { ..
} + )); + } + + #[test] + fn test_constant_folding() { + let two = SymbolicExpression::<F>::Constant(F::TWO); + let three = SymbolicExpression::<F>::Constant(F::from_canonical_u32(3)); + let five = F::from_canonical_u32(5); + let six = F::from_canonical_u32(6); + let neg_three = -F::from_canonical_u32(3); + + // Test 2 + 3 = 5 + let expr_add = two.clone() + three.clone(); + let dag = build_symbolic_constraints_dag(&[expr_add], &[]); + assert!(matches!( + dag.constraints.nodes[dag.constraints.constraint_idx[0]], + SymbolicExpressionNode::Constant(c) if c == five + )); + + // Test 3 - 2 = 1 + let expr_sub = three.clone() - two.clone(); + let dag = build_symbolic_constraints_dag(&[expr_sub], &[]); + assert!(matches!( + dag.constraints.nodes[dag.constraints.constraint_idx[0]], + SymbolicExpressionNode::Constant(c) if c == F::ONE + )); + + // Test 2 * 3 = 6 + let expr_mul = two.clone() * three.clone(); + let dag = build_symbolic_constraints_dag(&[expr_mul], &[]); + assert!(matches!( + dag.constraints.nodes[dag.constraints.constraint_idx[0]], + SymbolicExpressionNode::Constant(c) if c == six + )); + + // Test -3 = neg_three + let expr_neg = -three.clone(); + let dag = build_symbolic_constraints_dag(&[expr_neg], &[]); + assert!(matches!( + dag.constraints.nodes[dag.constraints.constraint_idx[0]], + SymbolicExpressionNode::Constant(c) if c == neg_three + )); + + // Test chained: (2 + 3) * 2 = 10 + let expr_chain = (two.clone() + three.clone()) * two.clone(); + let dag = build_symbolic_constraints_dag(&[expr_chain], &[]); + assert!(matches!( + dag.constraints.nodes[dag.constraints.constraint_idx[0]], + SymbolicExpressionNode::Constant(c) if c == F::from_canonical_u32(10) + )); + } + #[test] fn test_symbolic_constraints_dag() { + // expr = Constant(1) * Variable, which simplifies to just Variable let expr = SymbolicExpression::Constant(F::ONE) * SymbolicVariable::new( Entry::Main { @@ -365,19 +698,14 @@ right_idx: 3, degree_multiple: 2 }, - // Currently topological sort does not detect all subgraph isomorphisms. For - // example each IsFirstRow and IsLastRow is a new reference so ptr::hash is - // distinct. - SymbolicExpressionNode::Mul { - left_idx: 0, - right_idx: 1, - degree_multiple: 2 - }, + // With structural deduplication, IsFirstRow * IsLastRow is now reused + // instead of duplicated. The second occurrence reuses node index 2. SymbolicExpressionNode::Add { left_idx: 4, - right_idx: 5, + right_idx: 2, degree_multiple: 2 }, + // expr = Constant(1) * Variable simplifies to just Variable (1 * x = x) SymbolicExpressionNode::Variable(SymbolicVariable::new( Entry::Main { part_index: 1, @@ -385,31 +713,29 @@ }, 3 )), - SymbolicExpressionNode::Mul { - left_idx: 3, - right_idx: 7, - degree_multiple: 1 - }, + // First constraint: ...
+ expr (which is Variable at index 6) SymbolicExpressionNode::Add { - left_idx: 6, - right_idx: 8, + left_idx: 5, + right_idx: 6, degree_multiple: 2 }, + // Second constraint: expr * expr = Variable * Variable SymbolicExpressionNode::Mul { - left_idx: 8, - right_idx: 8, + left_idx: 6, + right_idx: 6, degree_multiple: 2 }, SymbolicExpressionNode::Constant(F::TWO), ], - constraint_idx: vec![9, 10], + constraint_idx: vec![7, 8], } ); assert_eq!( dag.interactions, vec![Interaction { bus_index: 0, - message: vec![8, 11], + // expr simplified to Variable at index 6, Constant(2) at index 9 + message: vec![6, 9], count: 3, count_weight: 1, }] diff --git a/crates/stark-backend/src/air_builders/symbolic/mod.rs b/crates/stark-backend/src/air_builders/symbolic/mod.rs index 8baf1b13..53172736 100644 --- a/crates/stark-backend/src/air_builders/symbolic/mod.rs +++ b/crates/stark-backend/src/air_builders/symbolic/mod.rs @@ -1,5 +1,7 @@ // Originally copied from uni-stark/src/symbolic_builder.rs to allow A: ?Sized +use std::iter; + use itertools::Itertools; use p3_air::{ AirBuilder, AirBuilderWithPublicValues, ExtensionBuilder, PairBuilder, PermutationAirBuilder, @@ -15,10 +17,7 @@ use self::{ }; use super::PartitionedAirBuilder; use crate::{ - interaction::{ - fri_log_up::find_interaction_chunks, rap::InteractionPhaseAirBuilder, Interaction, - InteractionBuilder, RapPhaseSeqKind, SymbolicInteraction, - }, + interaction::{Interaction, InteractionBuilder, SymbolicInteraction}, keygen::types::{StarkVerifyingParams, TraceWidth}, rap::{BaseAirWithPublicValues, PermutationAirBuilderWithExposedValues, Rap}, }; @@ -50,7 +49,16 @@ pub struct SymbolicConstraints<F> { impl<F: Field> SymbolicConstraints<F> { pub fn max_constraint_degree(&self) -> usize { - Iterator::max(self.constraints.iter().map(|c| c.degree_multiple())).unwrap_or(0) + iter::empty() .chain(&self.constraints) .chain( self.interactions .iter() .flat_map(|i| iter::once(&i.count).chain(&i.message)), ) .map(|expr| expr.degree_multiple()) .max() .unwrap_or(0) } pub fn get_log_quotient_degree(&self) -> usize { @@ -97,8 +105,6 @@ pub fn get_symbolic_builder<F, A>( width: &TraceWidth, num_challenges_to_sample: &[usize], num_exposed_values_after_challenge: &[usize], - rap_phase_seq_kind: RapPhaseSeqKind, - max_constraint_degree: usize, ) -> SymbolicRapBuilder<F> where F: Field, @@ -109,8 +115,6 @@ where rap.num_public_values(), num_challenges_to_sample, num_exposed_values_after_challenge, - rap_phase_seq_kind, - max_constraint_degree, ); Rap::eval(rap, &mut builder); builder @@ -127,12 +131,7 @@ pub struct SymbolicRapBuilder<F> { exposed_values_after_challenge: Vec<Vec<SymbolicVariable<F>>>, constraints: Vec<SymbolicExpression<F>>, interactions: Vec<SymbolicInteraction<F>>, - max_constraint_degree: usize, - rap_phase_seq_kind: RapPhaseSeqKind, trace_width: TraceWidth, - - /// Caching for FRI logup to avoid recomputation during keygen - interaction_partitions: Option<Vec<Vec<usize>>>, } impl<F: Field> SymbolicRapBuilder<F> { @@ -144,8 +143,6 @@ impl<F: Field> SymbolicRapBuilder<F> { num_public_values: usize, num_challenges_to_sample: &[usize], num_exposed_values_after_challenge: &[usize], - rap_phase_seq_kind: RapPhaseSeqKind, - max_constraint_degree: usize, ) -> Self { let preprocessed_width = width.preprocessed.unwrap_or(0); let prep_values = [0, 1] @@ -186,10 +183,7 @@ impl<F: Field> SymbolicRapBuilder<F> { exposed_values_after_challenge, constraints: vec![], interactions: vec![], - max_constraint_degree, - rap_phase_seq_kind, trace_width: width.clone(), - interaction_partitions: None, } } @@ -200,6 +194,8 @@ impl<F: Field> SymbolicRapBuilder<F> { } } + #[deprecated] + #[allow(deprecated)] pub fn
params(&self) -> StarkVerifyingParams { let width = self.width(); let num_exposed_values_after_challenge = self.num_exposed_values_after_challenge(); @@ -212,12 +208,17 @@ } } + pub fn num_public_values(&self) -> usize { + self.public_values.len() + } + pub fn width(&self) -> TraceWidth { let mut ret = self.trace_width.clone(); ret.after_challenge = self.after_challenge.iter().map(|m| m.width()).collect(); ret } + #[deprecated] pub fn num_exposed_values_after_challenge(&self) -> Vec<usize> { self.exposed_values_after_challenge .iter() @@ -225,6 +226,7 @@ .collect() } + #[deprecated] pub fn num_challenges_to_sample(&self) -> Vec<usize> { self.challenges.iter().map(|c| c.len()).collect() } @@ -392,49 +394,6 @@ impl<F: Field> InteractionBuilder for SymbolicRapBuilder<F> { } } -impl<F: Field> InteractionPhaseAirBuilder for SymbolicRapBuilder<F> { - fn finalize_interactions(&mut self) { - let num_interactions = self.num_interactions(); - if num_interactions != 0 { - assert!( - self.after_challenge.is_empty(), - "after_challenge width should be auto-populated by the InteractionBuilder" - ); - assert!(self.challenges.is_empty()); - assert!(self.exposed_values_after_challenge.is_empty()); - - if self.rap_phase_seq_kind == RapPhaseSeqKind::FriLogUp { - let interaction_partitions = - find_interaction_chunks(&self.interactions, self.max_constraint_degree) - .interaction_partitions(); - let num_chunks = interaction_partitions.len(); - self.interaction_partitions.replace(interaction_partitions); - let perm_width = num_chunks + 1; - self.after_challenge = Self::new_after_challenge(&[perm_width]); - } - - let phases_shapes = self.rap_phase_seq_kind.shape(); - let phase_shape = phases_shapes.first().unwrap(); - - self.challenges = Self::new_challenges(&[phase_shape.num_challenges]); - self.exposed_values_after_challenge = - Self::new_exposed_values_after_challenge(&[phase_shape.num_exposed_values]); - } - } - - fn max_constraint_degree(&self) -> usize { - self.max_constraint_degree - } - - fn rap_phase_seq_kind(&self) -> RapPhaseSeqKind { - self.rap_phase_seq_kind - } - - fn symbolic_interactions(&self) -> Vec<SymbolicInteraction<F>> { - self.interactions.clone() - } -} - impl<F: Field> PartitionedAirBuilder for SymbolicRapBuilder<F> { fn cached_mains(&self) -> &[Self::M] { &self.partitioned_main[..self.trace_width.cached_mains.len()] diff --git a/crates/stark-backend/src/air_builders/symbolic/symbolic_expression.rs b/crates/stark-backend/src/air_builders/symbolic/symbolic_expression.rs index b20c2a15..bcbf79bd 100644 --- a/crates/stark-backend/src/air_builders/symbolic/symbolic_expression.rs +++ b/crates/stark-backend/src/air_builders/symbolic/symbolic_expression.rs @@ -122,7 +122,7 @@ impl<F: Field> SymbolicExpression<F> { SymbolicExpression::Variable(v) => v.degree_multiple(), SymbolicExpression::IsFirstRow => 1, SymbolicExpression::IsLastRow => 1, - SymbolicExpression::IsTransition => 0, + SymbolicExpression::IsTransition => 1, SymbolicExpression::Constant(_) => 0, SymbolicExpression::Add { degree_multiple, .. diff --git a/crates/stark-backend/src/chip.rs b/crates/stark-backend/src/chip.rs index 5b999427..62dfeaf9 100644 --- a/crates/stark-backend/src/chip.rs +++ b/crates/stark-backend/src/chip.rs @@ -5,7 +5,7 @@ use std::{ sync::{Arc, Mutex}, }; -use crate::prover::{hal::ProverBackend, types::AirProvingContext}; +use crate::prover::{types::AirProvingContext, ProverBackend}; /// A chip is a [ProverBackend]-specific object that converts execution logs (also referred to as /// records) into a trace matrix.
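Note: for illustration, a minimal sketch of what the new `SymbolicDagBuilder` in `dag.rs` above buys, mirroring the new unit tests. The import paths are assumptions (the builder and its `nodes` field are `pub` in this patch, but the exact module re-exports are not shown), and BabyBear stands in for any `Field`:

    use openvm_stark_backend::air_builders::symbolic::{
        dag::{SymbolicDagBuilder, SymbolicExpressionNode}, // assumed path
        symbolic_expression::SymbolicExpression,
        symbolic_variable::{Entry, SymbolicVariable},
    };
    use openvm_stark_backend::p3_field::FieldAlgebra;
    use p3_baby_bear::BabyBear;

    type F = BabyBear;

    fn main() {
        let x = SymbolicExpression::from(SymbolicVariable::<F>::new(
            Entry::Main { part_index: 0, offset: 0 },
            0,
        ));
        let mut builder = SymbolicDagBuilder::new();
        // `1 * x + 0` folds down to the bare variable node:
        let idx = builder.add_expr(
            &(SymbolicExpression::Constant(F::ONE) * x.clone()
                + SymbolicExpression::Constant(F::ZERO)),
        );
        assert!(matches!(
            builder.nodes[idx],
            SymbolicExpressionNode::Variable(_)
        ));
        // A structurally identical but freshly allocated expression maps to
        // the same node index via the `node_to_idx` cache:
        let idx2 = builder.add_expr(&(SymbolicExpression::Constant(F::ONE) * x));
        assert_eq!(idx, idx2);
    }

Before this change, each `Arc`-distinct occurrence became its own node (the old reference-keyed cache only caught literal sharing), so repeated or trivially reducible constraints inflated the DAG.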
diff --git a/crates/stark-backend/src/engine.rs b/crates/stark-backend/src/engine.rs index 4f0fe4a1..664af24e 100644 --- a/crates/stark-backend/src/engine.rs +++ b/crates/stark-backend/src/engine.rs @@ -10,9 +10,8 @@ use crate::{ proof::{OpeningProof, Proof}, prover::{ coordinator::Coordinator, - hal::{DeviceDataTransporter, ProverBackend, ProverDevice}, types::{AirProofRawInput, AirProvingContext, DeviceMultiStarkProvingKey, ProvingContext}, - Prover, + DeviceDataTransporter, Prover, ProverBackend, ProverDevice, }, verifier::{MultiTraceStarkVerifier, VerificationError}, AirRef, diff --git a/crates/stark-backend/src/interaction/fri_log_up.rs b/crates/stark-backend/src/interaction/fri_log_up.rs index 2df92949..427fe4eb 100644 --- a/crates/stark-backend/src/interaction/fri_log_up.rs +++ b/crates/stark-backend/src/interaction/fri_log_up.rs @@ -115,7 +115,7 @@ where } // Proof of work phase to boost logup security. - let logup_pow_witness = challenger.grind(self.log_up_params.log_up_pow_bits); + let logup_pow_witness = challenger.grind(self.log_up_params.pow_bits); let challenges: [Challenge; STARK_LU_NUM_CHALLENGES] = array::from_fn(|_| challenger.sample_ext_element::<Challenge>()); @@ -178,10 +178,7 @@ where } }; - if !challenger.check_witness( - self.log_up_params.log_up_pow_bits, - partial_proof.logup_pow_witness, - ) { + if !challenger.check_witness(self.log_up_params.pow_bits, partial_proof.logup_pow_witness) { return ( RapPhaseVerifierData::default(), Err(FriLogUpError::InvalidPowWitness), @@ -626,8 +623,7 @@ pub(crate) fn find_interaction_chunks( cur_chunk.push(interaction_idx); numerator_max_degree = count_degree; running_sum_field_degree = field_degree; - if max_constraint_degree > 0 - && max(count_degree, field_degree + 1) > max_constraint_degree + if max_constraint_degree > 0 && max(count_degree, field_degree) > max_constraint_degree { panic!("Interaction with field_degree={field_degree}, count_degree={count_degree} exceeds max_constraint_degree={max_constraint_degree}"); } diff --git a/crates/stark-backend/src/interaction/mod.rs b/crates/stark-backend/src/interaction/mod.rs index 4985ee8b..b4a2adb4 100644 --- a/crates/stark-backend/src/interaction/mod.rs +++ b/crates/stark-backend/src/interaction/mod.rs @@ -329,7 +329,7 @@ pub trait RapPhaseSeq { type PairTraceView<'a, F> = PairView>, F>; /// Parameters to ensure sufficient soundness of the LogUp part of the protocol. -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] #[repr(C)] pub struct LogUpSecurityParameters { /// A bound on the total number of interactions. pub max_interaction_count: u32, @@ -339,7 +339,7 @@ /// keygen. pub log_max_message_length: u32, /// The number of proof-of-work bits for the LogUp proof-of-work phase. - pub log_up_pow_bits: usize, + pub pow_bits: usize, } impl LogUpSecurityParameters { @@ -350,7 +350,7 @@ log_order - log2_ceil_usize(2 * self.max_interaction_count as usize) as u32 // multiply by two to account for the poles as well - self.log_max_message_length - + u32::try_from(self.log_up_pow_bits).unwrap() + + u32::try_from(self.pow_bits).unwrap() } pub fn max_message_length(&self) -> usize { 2usize diff --git a/crates/stark-backend/src/interaction/rap.rs b/crates/stark-backend/src/interaction/rap.rs index d4e7d67a..d526e812 100644 --- a/crates/stark-backend/src/interaction/rap.rs +++ b/crates/stark-backend/src/interaction/rap.rs @@ -1,45 +1,18 @@ //!
An AIR with specified interactions can be augmented into a RAP. //! This module auto-converts any [Air] implemented on an [InteractionBuilder] into a [Rap]. -use p3_air::{Air, AirBuilder}; +use p3_air::Air; -use super::{InteractionBuilder, RapPhaseSeqKind, SymbolicInteraction}; -use crate::{ - interaction::fri_log_up::eval_fri_log_up_phase, - rap::{PermutationAirBuilderWithExposedValues, Rap}, -}; - -/// Used internally to select RAP phase evaluation function. -pub(crate) trait InteractionPhaseAirBuilder: InteractionBuilder { - fn finalize_interactions(&mut self); - /// The symbolic interactions **must** correspond to the `InteractionBuilder::all_interactions` - /// function. - fn symbolic_interactions(&self) -> Vec<SymbolicInteraction<Self::F>>; - /// The maximum constraint degree allowed in a RAP. - fn max_constraint_degree(&self) -> usize; - fn rap_phase_seq_kind(&self) -> RapPhaseSeqKind; -} +use super::InteractionBuilder; +use crate::rap::{PermutationAirBuilderWithExposedValues, Rap}; impl<A, AB> Rap<AB> for A where A: Air<AB>, - AB: InteractionBuilder + PermutationAirBuilderWithExposedValues + InteractionPhaseAirBuilder, + AB: InteractionBuilder + PermutationAirBuilderWithExposedValues, { fn eval(&self, builder: &mut AB) { // Constraints for the main trace: Air::eval(self, builder); - builder.finalize_interactions(); - if builder.num_interactions() != 0 { - match builder.rap_phase_seq_kind() { - RapPhaseSeqKind::FriLogUp => { - let symbolic_interactions = builder.symbolic_interactions(); - eval_fri_log_up_phase( - builder, - &symbolic_interactions, - builder.max_constraint_degree(), - ); - } - } - } } } diff --git a/crates/stark-backend/src/keygen/mod.rs b/crates/stark-backend/src/keygen/mod.rs index 081e1cd5..c3af1390 100644 --- a/crates/stark-backend/src/keygen/mod.rs +++ b/crates/stark-backend/src/keygen/mod.rs @@ -205,7 +205,7 @@ impl<'a, SC: StarkGenericConfig> MultiStarkKeygenBuilder<'a, SC> { let pre_vk: MultiStarkVerifyingKey0<SC> = MultiStarkVerifyingKey0 { per_air: pk_per_air.iter().map(|pk| pk.vk.clone()).collect(), trace_height_constraints: trace_height_constraints.clone(), - log_up_pow_bits: log_up_security_params.log_up_pow_bits, + log_up_pow_bits: log_up_security_params.pow_bits, }; // To protect against weak Fiat-Shamir, we hash the "pre"-verifying key and include it in // the final verifying key. This just needs to commit to the verifying key and does @@ -228,7 +228,7 @@ per_air: pk_per_air, trace_height_constraints, max_constraint_degree: self.max_constraint_degree, - log_up_pow_bits: log_up_security_params.log_up_pow_bits, + log_up_pow_bits: log_up_security_params.pow_bits, vk_pre_hash, } } @@ -258,6 +258,7 @@ impl<SC: StarkGenericConfig> AirKeygenBuilder<SC> { let air_name = self.air.name(); let symbolic_builder = self.get_symbolic_builder(Some(max_constraint_degree)); + #[allow(deprecated)] let params = symbolic_builder.params(); let symbolic_constraints = symbolic_builder.constraints(); let log_quotient_degree = symbolic_constraints.get_log_quotient_degree(); @@ -272,11 +273,14 @@ ..
} = self; + let max_constraint_degree: u8 = + u8::try_from(symbolic_constraints.max_constraint_degree()).unwrap(); let vk: StarkVerifyingKey<Val<SC>, Com<SC>> = StarkVerifyingKey { preprocessed_data: prep_verifier_data, params, symbolic_constraints: symbolic_constraints.into(), quotient_degree, + max_constraint_degree, rap_phase_seq_kind: self.rap_phase_seq_kind, }; StarkProvingKey { @@ -289,7 +293,7 @@ } fn get_symbolic_builder( &self, - max_constraint_degree: Option<usize>, + _max_constraint_degree: Option<usize>, ) -> SymbolicRapBuilder<Val<SC>> { let width = TraceWidth { preprocessed: self.prep_keygen_data.width(), @@ -297,14 +301,7 @@ common_main: self.air.common_main_width(), after_challenge: vec![], }; - get_symbolic_builder( - self.air.as_ref(), - &width, - &[], - &[], - SC::RapPhaseSeq::ID, - max_constraint_degree.unwrap_or(0), - ) + get_symbolic_builder(self.air.as_ref(), &width, &[], &[]) } } diff --git a/crates/stark-backend/src/keygen/types.rs b/crates/stark-backend/src/keygen/types.rs index 0d509167..72533a43 100644 --- a/crates/stark-backend/src/keygen/types.rs +++ b/crates/stark-backend/src/keygen/types.rs @@ -85,6 +85,7 @@ pub struct StarkVerifyingKey<Val, Com> { /// Determined from the max constraint degree of the AIR constraints. This is equivalently /// the number of chunks the quotient polynomial is split into. pub quotient_degree: u8, + pub max_constraint_degree: u8, pub rap_phase_seq_kind: RapPhaseSeqKind, } diff --git a/crates/stark-backend/src/prover/cpu/mod.rs b/crates/stark-backend/src/prover/cpu/mod.rs index d9534a29..8f85fab9 100644 --- a/crates/stark-backend/src/prover/cpu/mod.rs +++ b/crates/stark-backend/src/prover/cpu/mod.rs @@ -1,7 +1,7 @@ use std::{iter::zip, marker::PhantomData, mem::ManuallyDrop, ops::Deref, sync::Arc}; use derivative::Derivative; -use itertools::{izip, zip_eq, Itertools}; +use itertools::{izip, Itertools}; use opener::OpeningProver; use p3_challenger::FieldChallenger; use p3_commit::{Pcs, PolynomialSpace}; @@ -199,7 +199,7 @@ impl<SC: StarkGenericConfig> hal::RapPartialProver<CpuBackend<SC>> for CpuDevice<SC> public_values: v.public_values, }) .collect_vec(); - let (rap_phase_seq_proof, rap_phase_seq_data) = self + let (rap_phase_seq_proof, _rap_phase_seq_data) = self .config() .rap_phase_seq() .partially_prove( @@ -210,11 +210,12 @@ ) .map_or((None, None), |(p, d)| (Some(p), Some(d))); - let mvk_view = mpk.vk_view(); + // let mvk_view = mpk.vk_view(); - let mut perm_matrix_idx = 0usize; - let rap_views_per_phase; - let perm_trace_per_air = if let Some(phase_data) = rap_phase_seq_data { + // let mut perm_matrix_idx = 0usize; + /*let rap_views_per_phase; + let perm_trace_per_air = + if let Some(phase_data) = rap_phase_seq_data { assert_eq!(mvk_view.num_phases(), 1); assert_eq!( mvk_view.num_challenges_in_phase(0), @@ -239,44 +240,47 @@ .collect_vec(); rap_views_per_phase = vec![perm_views]; // 1 challenge phase phase_data.after_challenge_trace_per_air - } else { - assert_eq!(mvk_view.num_phases(), 0); + } else + { + // assert_eq!(mvk_view.num_phases(), 0); rap_views_per_phase = vec![]; vec![None; num_airs] - }; + };*/ // Commit to permutation traces: this means only 1 challenge round right now // One shared commit for all permutation traces let committed_pcs_data_per_phase: Vec<(Com<SC>, PcsData<SC>)> = info_span!("perm_trace_commit") .in_scope(|| { - let (log_trace_heights, flattened_traces): (Vec<_>, Vec<_>) = - perm_trace_per_air - .into_iter() - .flatten() - .map(|perm_trace| { - // SAFETY:
`Challenge` is assumed to be extension field of `F` - // with memory layout `[F; Challenge::D]` - let trace = unsafe { transmute_to_base(perm_trace) }; - let height = trace.height(); - let log_height: u8 = log2_strict_usize(height).try_into().unwrap(); - let domain = self.pcs().natural_domain_for_degree(height); - (log_height, (domain, trace)) - }) - .collect(); - // Only commit if there are permutation traces - if !flattened_traces.is_empty() { - let (commit, data) = self.pcs().commit(flattened_traces); - Some((commit, PcsData::new(Arc::new(data), log_trace_heights))) - } else { - None - } + // let (log_trace_heights, flattened_traces): (Vec<_>, Vec<_>) = + // perm_trace_per_air + // .into_iter() + // .flatten() + // .map(|perm_trace| { + // // SAFETY: `Challenge` is assumed to be extension field of `F` + // // with memory layout `[F; Challenge::D]` + // let trace = unsafe { transmute_to_base(perm_trace) }; + // let height = trace.height(); + // let log_height: u8 = + // log2_strict_usize(height).try_into().unwrap(); + // let domain = self.pcs().natural_domain_for_degree(height); + // (log_height, (domain, trace)) + // }) + // .collect(); + // // Only commit if there are permutation traces + // if !flattened_traces.is_empty() { + // let (commit, data) = self.pcs().commit(flattened_traces); + // Some((commit, PcsData::new(Arc::new(data), log_trace_heights))) + // } else { + // None + // } + None }) .into_iter() .collect(); let prover_view = ProverDataAfterRapPhases { committed_pcs_data_per_phase, - rap_views_per_phase, + rap_views_per_phase: vec![], }; (rap_phase_seq_proof, prover_view) } diff --git a/crates/stark-backend/src/prover/mod.rs b/crates/stark-backend/src/prover/mod.rs index ff47fc47..cdbf53dd 100644 --- a/crates/stark-backend/src/prover/mod.rs +++ b/crates/stark-backend/src/prover/mod.rs @@ -20,6 +20,8 @@ pub mod helper; // [jpw]: maybe this should be moved to sdk /// Metrics about trace and other statistics related to prover performance pub mod metrics; +pub use hal::*; + /// Trait for STARK/SNARK proving at the highest abstraction level. 
pub trait Prover<PB: ProverBackend> { type ProvingKeyView<'a> diff --git a/crates/stark-backend/tests/integration_test.rs b/crates/stark-backend/tests/integration_test.rs index 7359c129..5d0cbe2f 100644 --- a/crates/stark-backend/tests/integration_test.rs +++ b/crates/stark-backend/tests/integration_test.rs @@ -19,7 +19,6 @@ use openvm_stark_sdk::{ use p3_baby_bear::BabyBear; mod cached_lookup; -mod fib_selector_air; mod fib_triples_air; pub mod interaction; mod partitioned_sum_air; @@ -73,7 +72,9 @@ fn test_single_fib_triples_stark() { #[test] fn test_single_fib_selector_stark() { - use fib_selector_air::{air::FibonacciSelectorAir, trace::generate_trace_rows}; + use openvm_stark_sdk::dummy_airs::fib_selector_air::{ + air::FibonacciSelectorAir, trace::generate_trace_rows, + }; let log_trace_degree = 3; @@ -98,8 +99,10 @@ fn test_single_fib_selector_stark() { #[test] fn test_double_fib_starks() { - use fib_selector_air::air::FibonacciSelectorAir; - use openvm_stark_sdk::dummy_airs::{fib_air, fib_air::air::FibonacciAir}; + use openvm_stark_sdk::dummy_airs::{ + fib_air, fib_air::air::FibonacciAir, fib_selector_air, + fib_selector_air::air::FibonacciSelectorAir, + }; let log_n1 = 3; let log_n2 = 5; diff --git a/crates/stark-backend/tests/interaction/mod.rs b/crates/stark-backend/tests/interaction/mod.rs index 64709060..cf8854e1 100644 --- a/crates/stark-backend/tests/interaction/mod.rs +++ b/crates/stark-backend/tests/interaction/mod.rs @@ -12,18 +12,17 @@ use openvm_stark_backend::{ use openvm_stark_sdk::{ any_rap_arc_vec, config::{self, baby_bear_poseidon2::BabyBearPoseidon2Engine}, - dummy_airs::interaction::dummy_interaction_air::DummyInteractionAir, + dummy_airs::{ + fib_selector_air::{air::FibonacciSelectorAir, trace::generate_trace_rows}, + interaction::dummy_interaction_air::DummyInteractionAir, + }, engine::StarkFriEngine, }; use p3_baby_bear::BabyBear; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; -use crate::{ - fib_selector_air::{air::FibonacciSelectorAir, trace::generate_trace_rows}, - get_conditional_fib_number, - utils::to_field_vec, -}; +use crate::{get_conditional_fib_number, utils::to_field_vec}; type Val = BabyBear; diff --git a/crates/stark-sdk/src/config/log_up_params.rs b/crates/stark-sdk/src/config/log_up_params.rs index 5147dc40..244a3a90 100644 --- a/crates/stark-sdk/src/config/log_up_params.rs +++ b/crates/stark-sdk/src/config/log_up_params.rs @@ -8,7 +8,7 @@ pub fn log_up_security_params_baby_bear_100_bits() -> LogUpSecurityParameters { let params = LogUpSecurityParameters { max_interaction_count: BabyBear::ORDER_U32, log_max_message_length: 7, - log_up_pow_bits: 16, + pow_bits: 16, }; assert!(params.conjectured_bits_of_security::<BinomialExtensionField<BabyBear, 4>>() >= 100); params } diff --git a/crates/stark-backend/tests/fib_selector_air/air.rs b/crates/stark-sdk/src/dummy_airs/fib_selector_air/air.rs similarity index 93% rename from crates/stark-backend/tests/fib_selector_air/air.rs rename to crates/stark-sdk/src/dummy_airs/fib_selector_air/air.rs index a2fefa56..04929e83 100644 --- a/crates/stark-backend/tests/fib_selector_air/air.rs +++ b/crates/stark-sdk/src/dummy_airs/fib_selector_air/air.rs @@ -2,14 +2,14 @@ use std::borrow::Borrow; use openvm_stark_backend::{ interaction::{InteractionBuilder, LookupBus}, + p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder}, p3_field::{Field, FieldAlgebra}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; -use openvm_stark_sdk::dummy_airs::fib_air::columns::{FibonacciCols,
NUM_FIBONACCI_COLS}; -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder}; -use p3_matrix::{dense::RowMajorMatrix, Matrix}; use super::columns::FibonacciSelectorCols; +use crate::dummy_airs::fib_air::columns::{FibonacciCols, NUM_FIBONACCI_COLS}; pub struct FibonacciSelectorAir { sels: Vec<bool>, diff --git a/crates/stark-backend/tests/fib_selector_air/columns.rs b/crates/stark-sdk/src/dummy_airs/fib_selector_air/columns.rs similarity index 100% rename from crates/stark-backend/tests/fib_selector_air/columns.rs rename to crates/stark-sdk/src/dummy_airs/fib_selector_air/columns.rs diff --git a/crates/stark-backend/tests/fib_selector_air/mod.rs b/crates/stark-sdk/src/dummy_airs/fib_selector_air/mod.rs similarity index 100% rename from crates/stark-backend/tests/fib_selector_air/mod.rs rename to crates/stark-sdk/src/dummy_airs/fib_selector_air/mod.rs diff --git a/crates/stark-backend/tests/fib_selector_air/trace.rs b/crates/stark-sdk/src/dummy_airs/fib_selector_air/trace.rs similarity index 78% rename from crates/stark-backend/tests/fib_selector_air/trace.rs rename to crates/stark-sdk/src/dummy_airs/fib_selector_air/trace.rs index 18fbcfc6..0b5a2df5 100644 --- a/crates/stark-backend/tests/fib_selector_air/trace.rs +++ b/crates/stark-sdk/src/dummy_airs/fib_selector_air/trace.rs @@ -1,6 +1,6 @@ -use openvm_stark_backend::p3_field::PrimeField32; -use openvm_stark_sdk::dummy_airs::fib_air::columns::NUM_FIBONACCI_COLS; -use p3_matrix::dense::RowMajorMatrix; +use openvm_stark_backend::{p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix}; + +use crate::dummy_airs::fib_air::columns::NUM_FIBONACCI_COLS; /// sels contain boolean selectors to enable the fibonacci gate pub fn generate_trace_rows<F: PrimeField32>(a: u32, b: u32, sels: &[bool]) -> RowMajorMatrix<F> { diff --git a/crates/stark-sdk/src/dummy_airs/interaction/dummy_interaction_air.rs b/crates/stark-sdk/src/dummy_airs/interaction/dummy_interaction_air.rs index fa3a03e7..ce480cb4 100644 --- a/crates/stark-sdk/src/dummy_airs/interaction/dummy_interaction_air.rs +++ b/crates/stark-sdk/src/dummy_airs/interaction/dummy_interaction_air.rs @@ -17,8 +17,8 @@ use openvm_stark_backend::{ p3_matrix::{dense::RowMajorMatrix, Matrix}, prover::{ cpu::{CpuBackend, CpuDevice}, - hal::TraceCommitter, types::{AirProvingContext, CommittedTraceData}, + TraceCommitter, }, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, AirRef, Chip, ChipUsageGetter, diff --git a/crates/stark-sdk/src/dummy_airs/interaction/mod.rs b/crates/stark-sdk/src/dummy_airs/interaction/mod.rs index 12f0a1d0..f824e71d 100644 --- a/crates/stark-sdk/src/dummy_airs/interaction/mod.rs +++ b/crates/stark-sdk/src/dummy_airs/interaction/mod.rs @@ -1 +1,2 @@ pub mod dummy_interaction_air; +pub mod self_interaction_air; diff --git a/crates/stark-sdk/src/dummy_airs/interaction/self_interaction_air.rs b/crates/stark-sdk/src/dummy_airs/interaction/self_interaction_air.rs new file mode 100644 index 00000000..2f2515eb --- /dev/null +++ b/crates/stark-sdk/src/dummy_airs/interaction/self_interaction_air.rs @@ -0,0 +1,87 @@ +use std::sync::Arc; + +use itertools::{fold, Itertools}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + interaction::{BusIndex, InteractionBuilder}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + prover::{cpu::CpuBackend, types::AirProvingContext}, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, +}; + +#[derive(Debug, Clone, Copy)] +pub struct SelfInteractionAir {
pub width: usize, + pub bus_index: BusIndex, +} + +impl<F> BaseAir<F> for SelfInteractionAir { + fn width(&self) -> usize { + self.width + } +} +impl<F> BaseAirWithPublicValues<F> for SelfInteractionAir {} +impl<F> PartitionedBaseAir<F> for SelfInteractionAir {} + +impl<AB: InteractionBuilder> Air<AB> for SelfInteractionAir +where + AB::F: PrimeField32, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + + let (local, next) = (main.row_slice(0), main.row_slice(1)); + let mut local: Vec<<AB as AirBuilder>::Expr> = + (*local).iter().map(|v| (*v).into()).collect_vec(); + let mut next: Vec<<AB as AirBuilder>::Expr> = + (*next).iter().map(|v| (*v).into()).collect_vec(); + + let local_sum = fold(&local, AB::Expr::ZERO, |acc, val| acc + val.clone()); + let next_sum = fold(&next, AB::Expr::ZERO, |acc, val| acc + val.clone()); + + // Interaction where count is constant + builder.push_interaction(self.bus_index, local.clone(), AB::Expr::ONE, 1); + builder.push_interaction(self.bus_index, next.clone(), AB::Expr::NEG_ONE, 1); + + // Interaction where count is an expression + common with interaction below + builder.push_interaction(self.bus_index, local.clone(), local_sum.clone(), 1); + builder.push_interaction(self.bus_index, next.clone(), -next_sum.clone(), 1); + + // Interaction where count == fields[0] + builder.push_interaction(self.bus_index, local.clone(), local[0].clone(), 1); + builder.push_interaction(self.bus_index, next.clone(), -next[0].clone(), 1); + + local.reverse(); + next.reverse(); + + // Interaction where count_weight != 1 + builder.push_interaction(self.bus_index, local.clone(), AB::Expr::TWO, 2); + builder.push_interaction(self.bus_index, next.clone(), -AB::Expr::TWO, 2); + + // Interaction where count is an expression + common with interaction above + builder.push_interaction(self.bus_index, local, local_sum, 1); + builder.push_interaction(self.bus_index, next, -next_sum, 1); + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SelfInteractionChip { + pub width: usize, + pub log_height: usize, +} + +impl<SC: StarkGenericConfig> Chip<(), CpuBackend<SC>> for SelfInteractionChip { + fn generate_proving_ctx(&self, _: ()) -> AirProvingContext<CpuBackend<SC>> { + assert!(self.width > 0); + let mut trace = vec![Val::<SC>::ZERO; (1 << self.log_height) * self.width]; + for (row_idx, chunk) in trace.chunks_mut(self.width).enumerate() { + for (i, val) in chunk.iter_mut().enumerate() { + *val = Val::<SC>::from_canonical_usize((row_idx + i) % self.width); + } + } + AirProvingContext::simple_no_pis(Arc::new(RowMajorMatrix::new(trace, self.width))) + } +} diff --git a/crates/stark-sdk/src/dummy_airs/mod.rs b/crates/stark-sdk/src/dummy_airs/mod.rs index faa4d821..42067dab 100644 --- a/crates/stark-sdk/src/dummy_airs/mod.rs +++ b/crates/stark-sdk/src/dummy_airs/mod.rs @@ -1,3 +1,4 @@ -pub mod fib_air; /// Some dummy AIRs for testing.
+pub mod fib_air; +pub mod fib_selector_air; pub mod interaction; diff --git a/crates/stark-sdk/src/engine.rs b/crates/stark-sdk/src/engine.rs index bfc5a3b4..fdb71114 100644 --- a/crates/stark-sdk/src/engine.rs +++ b/crates/stark-sdk/src/engine.rs @@ -5,7 +5,7 @@ use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, engine::VerificationData, p3_matrix::dense::RowMajorMatrix, - prover::{hal::DeviceDataTransporter, types::AirProvingContext}, + prover::{types::AirProvingContext, DeviceDataTransporter}, verifier::VerificationError, AirRef, }; diff --git a/crates/stark-sdk/src/metrics_tracing.rs b/crates/stark-sdk/src/metrics_tracing.rs index 6bddd76e..388324dd 100644 --- a/crates/stark-sdk/src/metrics_tracing.rs +++ b/crates/stark-sdk/src/metrics_tracing.rs @@ -19,6 +19,7 @@ pub struct TimingMetricsLayer { struct SpanTiming { name: String, start_time: Instant, + labels: Vec<(String, String)>, } /// A visitor to extract the return value from span events @@ -39,11 +40,37 @@ impl Visit for ReturnValueVisitor { fn record_str(&mut self, _field: &Field, _value: &str) {} } +/// A visitor to extract all string fields from span attributes as metric labels +#[derive(Default)] +struct LabelVisitor { + labels: Vec<(String, String)>, +} + +impl Visit for LabelVisitor { + fn record_debug(&mut self, _field: &Field, _value: &dyn std::fmt::Debug) {} + fn record_i64(&mut self, _field: &Field, _value: i64) {} + fn record_u64(&mut self, _field: &Field, _value: u64) {} + fn record_bool(&mut self, _field: &Field, _value: bool) {} + fn record_str(&mut self, field: &Field, value: &str) { + self.labels + .push((field.name().to_string(), value.to_string())); + } +} + impl TimingMetricsLayer { /// Create a new TimingMetricsLayer pub fn new() -> Self { Self::default() } + + fn emit_metric(name: &str, duration_ms: f64, labels: &[(String, String)]) { + let metric_name = format!("{}_time_ms", name); + let labels: Vec<metrics::Label> = labels + .iter() + .map(|(k, v)| metrics::Label::new(k.clone(), v.clone())) + .collect(); + metrics::gauge!(metric_name, labels).set(duration_ms); + } } impl<S> Layer<S> for TimingMetricsLayer where @@ -52,7 +79,7 @@ { fn on_new_span( &self, - _attrs: &tracing::span::Attributes<'_>, + attrs: &tracing::span::Attributes<'_>, id: &Id, ctx: tracing_subscriber::layer::Context<'_, S>, ) { @@ -62,11 +89,16 @@ // Only track spans at INFO level or higher to match metrics_span behavior if metadata.level() <= &tracing::Level::INFO { + // Extract all string fields from span attributes as labels + let mut label_visitor = LabelVisitor::default(); + attrs.record(&mut label_visitor); + self.span_timings.insert( id.clone(), SpanTiming { name: name.to_string(), start_time: Instant::now(), + labels: label_visitor.labels, }, ); } @@ -86,10 +118,7 @@ // Emit metric for the span that's returning if let Some((_, timing)) = self.span_timings.remove(&span_id) { let duration_ms = timing.start_time.elapsed().as_millis() as f64; - - // Emit the metric gauge with the span name - // This matches the behavior of metrics_span - metrics::gauge!(format!("{}_time_ms", timing.name)).set(duration_ms); + Self::emit_metric(&timing.name, duration_ms, &timing.labels); } } } @@ -100,9 +129,7 @@ // This handles spans that don't have instrumented return values if let Some((_, timing)) = self.span_timings.remove(&id) { let duration_ms = timing.start_time.elapsed().as_millis() as f64; - - // Emit the metric gauge with the span name - metrics::gauge!(format!("{}_time_ms", timing.name)).set(duration_ms); +
Self::emit_metric(&timing.name, duration_ms, &timing.labels); } } }
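Note: with the `LabelVisitor` added in `metrics_tracing.rs` above, any string field recorded on an INFO-or-higher span now becomes a label on the emitted `{span_name}_time_ms` gauge. A sketch of the intended effect, assuming `TimingMetricsLayer` is installed in the subscriber stack and a `metrics` recorder is set (both setups are outside this diff):

    use tracing::info_span;

    // Previously this span emitted an unlabeled `prove_time_ms` gauge when it
    // closed; with this change it emits `prove_time_ms{air="fibonacci"}`,
    // because the string field is captured by LabelVisitor in on_new_span.
    let span = info_span!("prove", air = "fibonacci");
    let _guard = span.enter();
    // ... timed work ...

Non-string fields (i64, u64, bool, debug) are deliberately ignored by `LabelVisitor`, so only explicit string fields turn into labels.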
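Note: the `log_up_pow_bits` to `pow_bits` rename does not change the security accounting in `interaction/mod.rs`. Restating the formula from that code, with $\log_2 |\mathbb{F}_c|$ the `log_order` term of the challenge field the method is instantiated with, $n$ the `max_interaction_count` bound, $\ell$ = `log_max_message_length`, and $b$ = `pow_bits`:

    \text{conjectured bits of security} = \log_2 |\mathbb{F}_c| - \lceil \log_2(2n) \rceil - \ell + b

The factor of 2 in $2n$ accounts for poles as well as zeros (per the in-code comment), and the proof-of-work grind buys back $b$ bits. The BabyBear preset in `log_up_params.rs` ($\ell = 7$, $b = 16$, with a degree-4 extension as the challenge field) asserts the result is at least 100 bits.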