From 80d5537249d53ee7e205271173a3d2e392c29d5d Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 7 Apr 2026 18:34:53 +0800 Subject: [PATCH 1/3] wip jagged_commit --- crates/mpcs/src/jagged.rs | 333 ++++++++++++++++++++++++++++++++++++++ crates/mpcs/src/lib.rs | 2 + 2 files changed, 335 insertions(+) create mode 100644 crates/mpcs/src/jagged.rs diff --git a/crates/mpcs/src/jagged.rs b/crates/mpcs/src/jagged.rs new file mode 100644 index 0000000..ddfd558 --- /dev/null +++ b/crates/mpcs/src/jagged.rs @@ -0,0 +1,333 @@ +//! Jagged PCS Commit Adaptor +//! +//! This module implements the commit protocol for the Jagged PCS as described in +//! the SP1 Jagged PCS paper () and Ceno issues #1272 / #1288. +//! +//! ## Overview +//! +//! The "Jagged PCS" reduces proof size by packing all trace polynomials from multiple +//! chips into a single "giga" multilinear polynomial `q'`: +//! +//! ```text +//! q' = bitrev(p_0) || bitrev(p_1) || ... || bitrev(p_N) +//! ``` +//! +//! where each `p_i` is a column polynomial extracted from the input trace matrices, and +//! `bitrev` is the suffix-to-prefix bit-reversal transformation. +//! +//! ## Suffix-to-Prefix Transformation +//! +//! The main sumcheck prover outputs evaluations `v_i = p_i(r[(n-s)..n])` — i.e., at the +//! **suffix** of the random challenge point. To make these evaluations compatible with the +//! jagged sumcheck (which operates on prefix-aligned polynomials), we apply a bit-reversal +//! permutation to each polynomial's evaluations: +//! +//! ```text +//! p_i'[j] = p_i[bitrev_s(j)] (for j in 0..2^s) +//! ``` +//! +//! After bit-reversal, `v_i = p_i(r[(n-s)..n]) = p_i'(reverse(r[(n-s)..n]))`. +//! +//! ## Cumulative Heights +//! +//! The cumulative height sequence `t` tracks the starting position of each polynomial in `q'`: +//! - `t[0] = 0` +//! - `t[i+1] = t[i] + h_i` where `h_i = 2^(num_vars of p_i)` is the number of evaluations +//! +//! Given a position `b` in `q'`, the verifier can locate the corresponding `(i, r)` pair via: +//! - `t[i-1] <= b < t[i]` +//! - `r = b - t[i-1]` +//! +//! The cumulative heights allow the verifier to succinctly evaluate the indicator function +//! `g(z_r, z_b, t[i-1], t[i])` needed for the jagged sumcheck. +//! +//! ## Commit Protocol +//! +//! 1. For each input matrix `M_k` (with `h_k` rows and `w_k` columns): +//! a. Extract each column as a polynomial with `h_k` evaluations. +//! b. Apply bit-reversal to the evaluations. +//! 2. Concatenate all bit-reversed polynomials: `cat = bitrev(p_0) || bitrev(p_1) || ...` +//! 3. Compute cumulative heights `t[i]`. +//! 4. Pad `cat` to the next power of two (required for MLE representation). +//! 5. Commit to the padded `cat` as a single-column matrix using the inner PCS. + +use std::iter::once; + +use crate::{Error, PolynomialCommitmentScheme}; +use ff_ext::ExtensionField; +use itertools::Itertools; +use p3::{ + matrix::{Matrix, bitrev::BitReversableMatrix}, + maybe_rayon::prelude::{ + IndexedParallelIterator, IntoParallelIterator, ParallelIterator, ParallelSliceMut, + }, +}; +use serde::{Deserialize, Serialize}; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; + +/// Commitment to a jagged polynomial `q'`, together with all witness data needed +/// for opening proofs. +/// +/// Generic over the inner PCS `Pcs` so that any `PolynomialCommitmentScheme` can +/// serve as the underlying commitment engine. +pub struct JaggedCommitmentWithWitness> { + /// Commitment (with witness) to the "giga" polynomial `q'` via `Pcs`. + pub inner: Pcs::CommitmentWithWitness, + /// Cumulative height sequence `t`: + /// - `t[0] = 0` + /// - `t[i+1] = t[i] + poly_heights[i]` + /// - Length: `num_polys + 1` + pub cumulative_heights: Vec, + /// Number of evaluations `h_i = 2^(num_vars_i)` for each polynomial `p_i`. + /// Length: `num_polys`. + pub poly_heights: Vec, +} + +/// The pure commitment (without witness data) for a jagged polynomial `q'`. +/// This is what the verifier receives. +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound(serialize = "", deserialize = ""))] +pub struct JaggedCommitment> { + /// Pure commitment to the underlying giga polynomial `q'`. + pub inner: Pcs::Commitment, + /// Cumulative height sequence `t` (verifier needs this to evaluate `f(b)`). + pub cumulative_heights: Vec, +} + +impl> JaggedCommitmentWithWitness { + /// Extract the pure commitment (without witness data). + pub fn to_commitment(&self) -> JaggedCommitment { + JaggedCommitment { + inner: Pcs::get_pure_commitment(&self.inner), + cumulative_heights: self.cumulative_heights.clone(), + } + } + + /// Total number of polynomials packed into `q'`. + pub fn num_polys(&self) -> usize { + self.poly_heights.len() + } + + /// Total number of evaluations in the *unpadded* concatenated polynomial + /// (= `t[num_polys]` = `cumulative_heights.last()`). + pub fn total_evaluations(&self) -> usize { + self.cumulative_heights.last().copied().unwrap_or(0) + } +} + +/// Commit to a sequence of row-major matrices using the Jagged PCS scheme. +/// +/// This function implements the commit phase described in Ceno issue #1288: +/// 1. For each matrix, bit-reverse its rows (suffix-to-prefix transformation). +/// 2. Transpose the bit-reversed matrix (row-major → column-major), so each +/// column polynomial occupies a contiguous region in memory. +/// 3. Concatenate all column polynomials: `q' = col_0 || col_1 || ...` +/// 4. Compute the cumulative height sequence `t`. +/// 5. Commit to `q'` as a single-column matrix using `Pcs::batch_commit`. +/// +/// # Arguments +/// * `pp` — Prover parameters for `Pcs`. +/// * `rmms` — Non-empty sequence of row-major matrices. This function uses each matrix's height exactly as given. +/// +/// # Errors +/// Returns `Error::InvalidPcsParam` if `rmms` is empty or all matrices are empty. +/// Any error from the inner `Pcs::batch_commit` is propagated as-is. +pub fn jagged_commit>( + pp: &Pcs::ProverParam, + rmms: Vec>, +) -> Result, Error> { + if rmms.is_empty() { + return Err(Error::InvalidPcsParam( + "jagged_commit: cannot commit to empty sequence of matrices".to_string(), + )); + } + + // --- Step 1: Compute cumulative heights --- + let mut poly_heights: Vec = Vec::new(); + for rmm in &rmms { + let num_rows = rmm.height(); + let num_cols = rmm.width(); + + if num_rows == 0 { + return Err(Error::InvalidPcsParam( + "jagged_commit: matrix has zero rows".to_string(), + )); + } + if num_cols == 0 { + return Err(Error::InvalidPcsParam( + "jagged_commit: matrix has zero columns".to_string(), + )); + } + for _ in 0..num_cols { + poly_heights.push(num_rows); + } + } + if poly_heights.is_empty() { + return Err(Error::InvalidPcsParam( + "jagged_commit: no polynomials found in input matrices".to_string(), + )); + } + // t[0] = 0, t[i+1] = t[i] + poly_heights[i] + let cumulative_heights = poly_heights + .iter() + .chain(once(&0)) + .scan(0usize, |acc, &h| { + let current = *acc; + *acc += h; + Some(current) + }) + .collect::>(); + + // --- Steps 2 & 3: Bit-reverse rows, transpose, and write to concatenated --- + let total_size = cumulative_heights.last().copied().unwrap(); + let mut concatenated: Vec = Vec::with_capacity(total_size); + // Safety: every element in `concatenated[0..total_size]` is fully written + // by the transpose loop below before it is read. + unsafe { concatenated.set_len(total_size) }; + + // `poly_idx` tracks which poly (column index in cumulative_heights) is the + // first polynomial of the current matrix. + let mut poly_idx = 0; + for rmm in &rmms { + // Step 2: Bit-reverse the rows (suffix-to-prefix transformation). + // br.values[i * n_cols + j] = original[bitrev(i)][j] + let br = rmm.as_view().bit_reverse_rows().to_row_major_matrix(); + + let n_cols = br.width(); + let n_rows = br.height(); + let n_cells = n_cols * n_rows; + + // The start position in `concatenated` for this matrix's block of polynomials. + let start = cumulative_heights[poly_idx]; + + // Step 3: Transpose — write each column j of `br` (= one polynomial) + // into its corresponding contiguous slice in `concatenated`. + (0..n_cols) + .into_par_iter() + .zip(concatenated[start..start + n_cells].par_chunks_mut(n_rows)) + .for_each(|(j, chunk)| { + br.values + .iter() + .skip(j) + .step_by(n_cols) + .zip_eq(chunk.iter_mut()) + .for_each(|(v, out)| *out = *v); + }); + + poly_idx += n_cols; + } + + // --- Step 4: Commit via the inner PCS --- + // q' is committed as a single-column matrix with height = total_size. + let giga_rmm = RowMajorMatrix::::new_by_values( + concatenated, + 1, // width = 1 (single polynomial q') + InstancePaddingStrategy::Default, + ); + + let inner = Pcs::batch_commit(pp, vec![giga_rmm])?; + + Ok(JaggedCommitmentWithWitness { + inner, + cumulative_heights, + poly_heights, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + basefold::{Basefold, BasefoldRSParams}, + test_util::setup_pcs, + }; + use ff_ext::GoldilocksExt2; + use p3::{field::FieldAlgebra, goldilocks::Goldilocks}; + + type F = Goldilocks; + type E = GoldilocksExt2; + type Pcs = Basefold; + + fn make_rmm(num_rows: usize, num_cols: usize) -> RowMajorMatrix { + let values: Vec = (0..num_rows * num_cols) + .map(|i| F::from_canonical_u64(i as u64 + 1)) + .collect(); + RowMajorMatrix::::new_by_values(values, num_cols, InstancePaddingStrategy::Default) + } + + #[test] + fn test_cumulative_heights_single_matrix() { + // 4x2 matrix → 2 polynomials with 4 evaluations each → cumulative = [0, 4, 8] + let rmm = make_rmm(4, 2); + let num_rows = rmm.height(); + let num_cols = rmm.width(); + let mut poly_heights = Vec::new(); + for _ in 0..num_cols { + poly_heights.push(num_rows); + } + let mut ch = vec![0usize]; + for &h in &poly_heights { + ch.push(ch.last().unwrap() + h); + } + assert_eq!(poly_heights, vec![4, 4]); + assert_eq!(ch, vec![0, 4, 8]); + } + + #[test] + fn test_cumulative_heights_multiple_matrices() { + // 4x1 + 8x2 → heights [4, 8, 8] → cumulative [0, 4, 12, 20] + let m1 = make_rmm(4, 1); + let m2 = make_rmm(8, 2); + let mut poly_heights: Vec = Vec::new(); + for rmm in &[m1, m2] { + for _ in 0..rmm.width() { + poly_heights.push(rmm.height()); + } + } + let mut ch = vec![0usize]; + for &h in &poly_heights { + ch.push(ch.last().unwrap() + h); + } + assert_eq!(poly_heights, vec![4, 8, 8]); + assert_eq!(ch, vec![0, 4, 12, 20]); + } + + #[test] + fn test_jagged_commit_smoke() { + // Two matrices: 4x1 and 4x2 → 3 polynomials, total 12 evals, padded to 16 + let (pp, _vp) = setup_pcs::(4); + let m1 = make_rmm(4, 1); + let m2 = make_rmm(4, 2); + + let comm = jagged_commit::(&pp, vec![m1, m2]).expect("commit should succeed"); + + assert_eq!(comm.num_polys(), 3); + assert_eq!(comm.poly_heights, vec![4, 4, 4]); + assert_eq!(comm.cumulative_heights, vec![0, 4, 8, 12]); + assert_eq!(comm.total_evaluations(), 12); + + let pure = comm.to_commitment(); + assert_eq!(pure.cumulative_heights, vec![0, 4, 8, 12]); + } + + #[test] + fn test_jagged_commit_single_poly() { + // 8x1 matrix → 1 polynomial, 8 evals, no padding needed + let (pp, _vp) = setup_pcs::(3); + let m = make_rmm(8, 1); + + let comm = jagged_commit::(&pp, vec![m]).expect("commit should succeed"); + + assert_eq!(comm.num_polys(), 1); + assert_eq!(comm.poly_heights, vec![8]); + assert_eq!(comm.cumulative_heights, vec![0, 8]); + assert_eq!(comm.total_evaluations(), 8); + } + + #[test] + fn test_jagged_commit_empty_error() { + let (pp, _vp) = setup_pcs::(4); + let result = jagged_commit::(&pp, vec![]); + assert!(matches!(result, Err(Error::InvalidPcsParam(_)))); + } +} diff --git a/crates/mpcs/src/lib.rs b/crates/mpcs/src/lib.rs index a8defea..be6380d 100644 --- a/crates/mpcs/src/lib.rs +++ b/crates/mpcs/src/lib.rs @@ -267,6 +267,8 @@ pub use basefold::{ Basefold, BasefoldCommitment, BasefoldCommitmentWithWitness, BasefoldDefault, BasefoldParams, BasefoldRSParams, BasefoldSpec, EncodingScheme, RSCode, RSCodeDefaultSpec, }; +pub mod jagged; +pub use jagged::{JaggedCommitment, JaggedCommitmentWithWitness, jagged_commit}; #[cfg(feature = "whir")] extern crate whir as whir_external; #[cfg(feature = "whir")] From 91aa59f2c9408cb0080b954f214dc97e15b8e6cf Mon Sep 17 00:00:00 2001 From: xkx Date: Tue, 14 Apr 2026 21:35:35 +0800 Subject: [PATCH 2/3] feat: jagged sumcheck (#32) * claude code impl plan * implement the jagged_sumcheck using time-space tradeoff sumcheck prover algorithm * remove debugging codes * ref to the original paper * add jagged sumcheck bench * #32 parallelize (#34) * par wip * check f(z) * g(z) matches expected evaluation * fix clippy * fix clippy: add #[cfg(test)] to test-only method and fix unused import/variable warnings Agent-Logs-Url: https://github.com/scroll-tech/gkr-backend/sessions/4a61316e-cf28-47b8-a43e-fb6ab432701e Co-authored-by: hero78119 <3962077+hero78119@users.noreply.github.com> * refactor test * apply functional programming style * avoid unnecessary BaseField-to-ExtensionField conversion in q_evals access Use E * BaseField multiplication directly instead of converting q_evals elements to extension field with .into() first. Co-Authored-By: Claude Opus 4.6 * replace col_row binary search with incremental ColRowIter Add ColRowIter that does one binary search at construction and O(1) per step, replacing per-element binary searches in build_m_table, bind_and_materialize, compute_claimed_sum, and final_evaluations_slow. Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: hero78119 <3962077+hero78119@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 * extend jagged sumcheck benchmark to cover n=25..31 * switch jagged sumcheck benchmark to BabyBearExt4 Co-Authored-By: Claude Opus 4.6 * remove jagged sumcheck plan doc Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: hero78119 <3962077+hero78119@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- Cargo.lock | 13 + crates/mpcs/Cargo.toml | 5 + crates/mpcs/benches/jagged_sumcheck.rs | 65 +++ crates/mpcs/src/jagged.rs | 622 ++++++++++++++++++++++++- crates/mpcs/src/lib.rs | 5 +- 5 files changed, 707 insertions(+), 3 deletions(-) create mode 100644 crates/mpcs/benches/jagged_sumcheck.rs diff --git a/Cargo.lock b/Cargo.lock index 303a17c..554bd40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -838,6 +838,7 @@ dependencies = [ "serde", "sumcheck", "tracing", + "tracing-forest", "tracing-subscriber", "transcript", "whir", @@ -1919,6 +1920,18 @@ dependencies = [ "valuable", ] +[[package]] +name = "tracing-forest" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee40835db14ddd1e3ba414292272eddde9dad04d3d4b65509656414d1c42592f" +dependencies = [ + "smallvec", + "thiserror", + "tracing", + "tracing-subscriber", +] + [[package]] name = "tracing-log" version = "0.2.0" diff --git a/crates/mpcs/Cargo.toml b/crates/mpcs/Cargo.toml index b54e03d..0962118 100644 --- a/crates/mpcs/Cargo.toml +++ b/crates/mpcs/Cargo.toml @@ -30,6 +30,7 @@ witness.workspace = true [dev-dependencies] criterion.workspace = true +tracing-forest = "0.1" [features] nightly-features = ["ff_ext/nightly-features"] @@ -54,3 +55,7 @@ name = "interpolate" harness = false name = "whir" required-features = ["whir"] + +[[bench]] +harness = false +name = "jagged_sumcheck" diff --git a/crates/mpcs/benches/jagged_sumcheck.rs b/crates/mpcs/benches/jagged_sumcheck.rs new file mode 100644 index 0000000..1390bfc --- /dev/null +++ b/crates/mpcs/benches/jagged_sumcheck.rs @@ -0,0 +1,65 @@ +use criterion::*; +use ff_ext::{BabyBearExt4, ExtensionField, FieldFrom, FromUniformBytes}; +use mpcs::{JaggedSumcheckInput, jagged_sumcheck_prove}; +use multilinear_extensions::virtual_poly::build_eq_x_r_vec; +use rand::thread_rng; +use transcript::BasicTranscript; + +type E = BabyBearExt4; +type F = ::BaseField; + +const NUM_SAMPLES: usize = 10; + +fn bench_jagged_sumcheck(c: &mut Criterion) { + let mut group = c.benchmark_group("jagged_sumcheck"); + group.sample_size(NUM_SAMPLES); + + // (num_giga_vars, num_polys, poly_height_log2) + let configs: Vec<(usize, usize, usize)> = (25..=31) + .map(|n| { + let s = 21usize; + let num_polys = 1usize << (n - s); + (n, num_polys, s) + }) + .collect(); + + for (num_giga_vars, num_polys, s) in configs { + let poly_height = 1usize << s; + let total_evals = num_polys * poly_height; + + let mut rng = thread_rng(); + + let q_evals: Vec = (0..total_evals) + .map(|i| F::from_v((i as u64 * 13 + 7) % (1 << 30))) + .collect(); + + let cumulative_heights: Vec = (0..=num_polys).map(|i| i * poly_height).collect(); + + let z_row: Vec = (0..s).map(|_| E::random(&mut rng)).collect(); + let z_col_vars = (num_polys as f64).log2().ceil() as usize; + let z_col: Vec = (0..z_col_vars).map(|_| E::random(&mut rng)).collect(); + + let input = JaggedSumcheckInput { + q_evals: &q_evals, + num_giga_vars, + cumulative_heights: &cumulative_heights, + eq_row: build_eq_x_r_vec(&z_row), + eq_col: build_eq_x_r_vec(&z_col), + }; + + group.bench_function( + BenchmarkId::new("prove", format!("n={}", num_giga_vars)), + |b| { + b.iter(|| { + let mut transcript = BasicTranscript::::new(b"jagged_bench"); + jagged_sumcheck_prove(black_box(&input), &mut transcript) + }) + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, bench_jagged_sumcheck); +criterion_main!(benches); diff --git a/crates/mpcs/src/jagged.rs b/crates/mpcs/src/jagged.rs index ddfd558..0d044cd 100644 --- a/crates/mpcs/src/jagged.rs +++ b/crates/mpcs/src/jagged.rs @@ -56,13 +56,22 @@ use std::iter::once; use crate::{Error, PolynomialCommitmentScheme}; use ff_ext::ExtensionField; use itertools::Itertools; +use multilinear_extensions::{ + mle::MultilinearExtension, util::max_usable_threads, virtual_poly::build_eq_x_r_vec, +}; use p3::{ matrix::{Matrix, bitrev::BitReversableMatrix}, maybe_rayon::prelude::{ - IndexedParallelIterator, IntoParallelIterator, ParallelIterator, ParallelSliceMut, + IndexedParallelIterator, IntoParallelIterator, ParallelIterator, ParallelSlice, + ParallelSliceMut, }, }; use serde::{Deserialize, Serialize}; +use sumcheck::{ + macros::{entered_span, exit_span}, + structs::{IOPProof, IOPProverMessage, IOPProverState}, +}; +use transcript::Transcript; use witness::{InstancePaddingStrategy, RowMajorMatrix}; /// Commitment to a jagged polynomial `q'`, together with all witness data needed @@ -183,7 +192,10 @@ pub fn jagged_commit>( let mut concatenated: Vec = Vec::with_capacity(total_size); // Safety: every element in `concatenated[0..total_size]` is fully written // by the transpose loop below before it is read. - unsafe { concatenated.set_len(total_size) }; + #[allow(clippy::uninit_vec)] + unsafe { + concatenated.set_len(total_size) + }; // `poly_idx` tracks which poly (column index in cumulative_heights) is the // first polynomial of the current matrix. @@ -234,6 +246,427 @@ pub fn jagged_commit>( }) } +// --------------------------------------------------------------------------- +// Jagged Sumcheck Prover (M-table streaming algorithm) +// --------------------------------------------------------------------------- + +// Streaming sumcheck prover using the M-table algorithm from +// "Time-Space Trade-Offs for Sumcheck" (eprint 2025/1473), Section 4. + +/// Number of streaming rounds before switching to standard sumcheck. +/// Determined by the epoch schedule: 1 + 2 + 4 + 8 = 15. +#[allow(dead_code)] +const STREAMING_ROUNDS: usize = 15; +/// Epoch sizes used in the streaming phase: j' = 1, 2, 4, 8. +const EPOCH_SIZES: [usize; 4] = [1, 2, 4, 8]; + +/// All inputs needed for the jagged sumcheck. +pub struct JaggedSumcheckInput<'a, E: ExtensionField> { + /// Giga polynomial evaluations (concatenated, bit-reversed). + pub q_evals: &'a [E::BaseField], + /// n = log2(padded_total_size). + pub num_giga_vars: usize, + /// Cumulative height sequence t[j], length num_polys + 1. + pub cumulative_heights: &'a [usize], + /// Precomputed eq table for the row evaluation point: `build_eq_x_r_vec(z_row)`. + pub eq_row: Vec, + /// Precomputed eq table for the column challenge point: `build_eq_x_r_vec(z_col)`. + pub eq_col: Vec, +} + +/// Iterator that yields `(col, row)` pairs for consecutive giga indices. +/// Uses one binary search at construction, then O(1) per step. +struct ColRowIter<'a> { + cumulative_heights: &'a [usize], + col: usize, + row: usize, + num_polys: usize, +} + +impl<'a> Iterator for ColRowIter<'a> { + type Item = (usize, usize); + + fn next(&mut self) -> Option<(usize, usize)> { + if self.col >= self.num_polys { + return None; + } + let result = (self.col, self.row); + self.row += 1; + let poly_height = self.cumulative_heights[self.col + 1] - self.cumulative_heights[self.col]; + if self.row >= poly_height { + self.row = 0; + self.col += 1; + } + Some(result) + } +} + +impl<'a, E: ExtensionField> JaggedSumcheckInput<'a, E> { + fn total_evaluations(&self) -> usize { + *self.cumulative_heights.last().unwrap_or(&0) + } + + /// Return an iterator yielding `(col, row)` for consecutive giga indices + /// starting from `start`. One binary search at construction, O(1) per step. + fn col_row_iter(&self, start: usize) -> ColRowIter<'_> { + let num_polys = self.cumulative_heights.len() - 1; + if start >= self.total_evaluations() { + return ColRowIter { + cumulative_heights: self.cumulative_heights, + col: num_polys, + row: 0, + num_polys, + }; + } + let j = self.cumulative_heights.partition_point(|&t| t <= start) - 1; + ColRowIter { + cumulative_heights: self.cumulative_heights, + col: j, + row: start - self.cumulative_heights[j], + num_polys, + } + } + + /// Brute-force computation of sum_b q'(b) * f(b). + /// O(2^n) time — only for debugging and tests. + #[cfg(test)] + fn compute_claimed_sum(&self) -> E { + self.q_evals[..self.total_evaluations()] + .iter() + .zip(self.col_row_iter(0)) + .fold(E::ZERO, |sum, (&q, (col, row))| { + sum + (self.eq_row[row] * self.eq_col[col]) * q + }) + } + + /// Brute-force MLE evaluation of q'(rb) and f(rb) at the given point. + /// O(2^n) time and memory — only for debugging and tests. + /// Returns `(q_at_point, f_at_point)`. + #[cfg(test)] + fn final_evaluations_slow(&self, point: &[E]) -> (E, E) { + let n = self.num_giga_vars; + let total_evals = self.total_evaluations(); + + // Build q' MLE (padded with zeros) and evaluate at point. + let mut q_padded: Vec = self.q_evals.to_vec(); + q_padded.resize(1 << n, Default::default()); + let q_mle = MultilinearExtension::from_evaluations_vec(n, q_padded); + let q_at_point = q_mle.evaluate(point); + + // Build f MLE from eq tables and evaluate at point. + let mut f_evals = vec![E::ZERO; 1 << n]; + for (f_eval, (col, row)) in f_evals[..total_evals].iter_mut().zip(self.col_row_iter(0)) { + *f_eval = self.eq_row[row] * self.eq_col[col]; + } + let f_mle = MultilinearExtension::from_evaluations_ext_vec(n, f_evals); + let f_at_point = f_mle.evaluate(point); + + (q_at_point, f_at_point) + } +} + +/// Run the full jagged sumcheck: streaming phase (rounds 1..K) + standard phase (rounds K+1..n). +/// +/// Returns the proof and the full list of challenges (r_1, ..., r_n). +pub fn jagged_sumcheck_prove( + input: &JaggedSumcheckInput, + transcript: &mut impl Transcript, +) -> (IOPProof, Vec) { + let n = input.num_giga_vars; + let max_degree: usize = 2; + + let mut challenges: Vec = Vec::with_capacity(n); + let mut proof_messages: Vec> = Vec::with_capacity(n); + + // Write transcript header (must match verifier's expectations). + transcript.append_message(&n.to_le_bytes()); + transcript.append_message(&max_degree.to_le_bytes()); + + // --- Streaming phase: epochs j' = 1, 2, 4, 8 --- + for &epoch_size in &EPOCH_SIZES { + // Epoch j' handles rounds j'..2j'-1. Skip if all rounds are done. + if epoch_size > n { + break; + } + + // Build M-table for this epoch. + let span = entered_span!("build_m_table", epoch = epoch_size); + let m_table = build_m_table(input, &challenges, epoch_size); + exit_span!(span); + + // Extract rounds j = epoch_size .. min(2*epoch_size - 1, n) + let span = entered_span!("compute_rounds_from_m", epoch = epoch_size); + for j in epoch_size..(2 * epoch_size).min(n + 1) { + let d = j - epoch_size; // intra-epoch offset + let intra_challenges = challenges[epoch_size - 1..epoch_size - 1 + d].to_vec(); + + let [_h0, h1, h2] = compute_round_from_m(&m_table, epoch_size, &intra_challenges); + + // Append [h(1), h(2)] to transcript and sample challenge. + transcript.append_field_element_ext(&h1); + transcript.append_field_element_ext(&h2); + let challenge = transcript + .sample_and_append_challenge(b"Internal round") + .elements; + + proof_messages.push(IOPProverMessage { + evaluations: vec![h1, h2], + }); + challenges.push(challenge); + } + exit_span!(span); + } + + // --- Phase 2: Bind and materialize, then standard sumcheck --- + let k = challenges.len(); // actual number of streaming rounds completed + if k < n { + let span = entered_span!("bind_and_materialize"); + let (q_bound, f_bound) = bind_and_materialize(input, &challenges); + exit_span!(span); + + let remaining_vars = n - k; + let q_mle = MultilinearExtension::from_evaluations_ext_vec(remaining_vars, q_bound); + let f_mle = MultilinearExtension::from_evaluations_ext_vec(remaining_vars, f_bound); + + // Use VirtualPolynomial + round-by-round proving (no extra transcript header). + use multilinear_extensions::virtual_poly::VirtualPolynomial; + use std::sync::Arc; + let q_arc = Arc::new(q_mle); + let f_arc = Arc::new(f_mle); + let vp = VirtualPolynomial::new_from_product(vec![q_arc, f_arc], E::ONE); + + let span = entered_span!("standard_sumcheck", rounds = remaining_vars); + let mut prover_state = + IOPProverState::prover_init_with_extrapolation_aux(true, vp, None, None); + let mut challenge = None; + for _ in 0..remaining_vars { + let prover_msg = + IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); + prover_msg + .evaluations + .iter() + .for_each(|e| transcript.append_field_element_ext(e)); + challenge = Some(transcript.sample_and_append_challenge(b"Internal round")); + challenges.push(challenge.unwrap().elements); + proof_messages.push(prover_msg); + } + exit_span!(span); + } + + ( + IOPProof { + proofs: proof_messages, + }, + challenges, + ) +} + +/// Build M-table for epoch j'. +/// +/// M[beta1 * 2^{j'} + beta2] = sum_b Q_bound(beta1, b) * F_bound(beta2, b) +/// +/// where: +/// - Q_bound(beta, b) = sum_{a in {0,1}^{j'-1}} eq(R, a) * q'[a || beta || b] +/// - F_bound(beta, b) = sum_{a in {0,1}^{j'-1}} eq(R, a) * f[a || beta || b] +fn build_m_table( + input: &JaggedSumcheckInput, + challenges: &[E], // R_{j'} = (r_1, ..., r_{j'-1}) + epoch_size: usize, // j' +) -> Vec { + let n = input.num_giga_vars; + let bound_vars = epoch_size - 1; // j' - 1 + + let eq_r = if bound_vars > 0 { + build_eq_x_r_vec(&challenges[..bound_vars]) + } else { + vec![E::ONE] + }; + + let beta_count = 1usize << epoch_size; // 2^{j'} + let a_count = 1usize << bound_vars; // 2^{j'-1} + let chunk_size = a_count * beta_count; // 2^{2j'-1} + // When n < 2j'-1, all variables fit in a single chunk (no b dimension). + let n_chunks = 1usize << n.saturating_sub(2 * epoch_size - 1); // 2^{max(0, n - 2j' + 1)} + + let m_size = beta_count * beta_count; // 2^{2j'} + + // Step 1: Each thread processes a batch of b-chunks, producing a local M-table. + let span = entered_span!( + "streaming_pass", + n_chunks = n_chunks, + beta_count = beta_count + ); + let indices: Vec = (0..n_chunks).collect(); + let n_threads = max_usable_threads(); + let batch_size = (n_chunks / n_threads).max(1); + let partial_tables: Vec> = indices + .par_chunks(batch_size) + .map(|batch| { + let mut local_m = vec![E::ZERO; m_size]; + let mut q_bound = vec![E::ZERO; beta_count]; + let mut f_bound = vec![E::ZERO; beta_count]; + + for &b_idx in batch { + let chunk_start = b_idx * chunk_size; + q_bound + .iter_mut() + .zip(f_bound.iter_mut()) + .enumerate() + .for_each(|(beta, (q_b, f_b))| { + let base = chunk_start + beta * a_count; + let (q_acc, f_acc) = eq_r + .iter() + .zip(input.q_evals.get(base..).unwrap_or(&[])) + .zip(input.col_row_iter(base)) + .fold( + (E::ZERO, E::ZERO), + |(q_acc, f_acc), ((&eq_r_a, &q), (col, row))| { + ( + q_acc + eq_r_a * q, + f_acc + eq_r_a * (input.eq_row[row] * input.eq_col[col]), + ) + }, + ); + *q_b = q_acc; + *f_b = f_acc; + }); + + // Outer product accumulation into local M-table. + for b1 in 0..beta_count { + if q_bound[b1] == E::ZERO { + continue; + } + for b2 in 0..beta_count { + local_m[b1 * beta_count + b2] += q_bound[b1] * f_bound[b2]; + } + } + } + local_m + }) + .collect(); + exit_span!(span); + + // Step 2: Sum partial M-tables in parallel, each thread handles a slice of cells. + let span = entered_span!("reduce_partial_tables", n_partials = partial_tables.len()); + let n_partials = partial_tables.len(); + if n_partials == 0 { + exit_span!(span); + return vec![E::ZERO; m_size]; + } + let mut m_table = partial_tables[0].clone(); + let cell_batch = (m_size / n_threads).max(1); + m_table + .par_chunks_mut(cell_batch) + .enumerate() + .for_each(|(ci, cells)| { + let start = ci * cell_batch; + for partial in &partial_tables[1..] { + for (j, cell) in cells.iter_mut().enumerate() { + *cell += partial[start + j]; + } + } + }); + exit_span!(span); + m_table +} + +/// Extract round univariate h_j(x) from M-table. +/// +/// For round j in epoch j', d = j - j' intra-epoch challenges have been collected. +/// Returns [h(0), h(1), h(2)]. +fn compute_round_from_m( + m_table: &[E], + epoch_size: usize, // j' + intra_challenges: &[E], // r_{j'}, ..., r_{j-1} (d elements) +) -> [E; 3] { + let d = intra_challenges.len(); + let beta_count = 1usize << epoch_size; + let pad_bits = epoch_size - d - 1; // number of "future" bits to sum over + let pad_count = 1usize << pad_bits; + + let eq_intra = if d > 0 { + build_eq_x_r_vec(intra_challenges) + } else { + vec![E::ONE] + }; + + let a_count = 1usize << d; // 2^d + + let mut h = [E::ZERO; 3]; + + for a in 0..a_count { + for c in 0..a_count { + let eq_weight = eq_intra[a] * eq_intra[c]; + if eq_weight == E::ZERO { + continue; + } + + // Sum over all pad bit assignments (same pad for both beta1/beta2 + // since they correspond to the same physical "future" variables). + for p in 0..pad_count { + // beta = a_bits || x_bit || pad_bits (little-endian) + // beta_val = a + x_bit * 2^d + pad * 2^{d+1} + let base1 = a + (p << (d + 1)); + let base2 = c + (p << (d + 1)); + let b1_0 = base1; // x=0 + let b1_1 = base1 + (1 << d); // x=1 + let b2_0 = base2; + let b2_1 = base2 + (1 << d); + + let m00 = m_table[b1_0 * beta_count + b2_0]; + let m10 = m_table[b1_1 * beta_count + b2_0]; + let m01 = m_table[b1_0 * beta_count + b2_1]; + let m11 = m_table[b1_1 * beta_count + b2_1]; + + h[0] += eq_weight * m00; + h[1] += eq_weight * m11; + // h(2) via bilinear: (1-2)^2*M00 + 2(1-2)*M10 + (1-2)*2*M01 + 4*M11 + h[2] += eq_weight * (m00 - m10.double() - m01.double() + m11.double().double()); + } + } + } + + h +} + +/// Bind first K variables and materialize reduced q' and f as extension-field vectors. +/// +/// q_bound[idx] = sum_{a in {0,1}^K} eq(R, a) * q'[a + idx * 2^K] +/// f_bound[idx] = sum_{a in {0,1}^K} eq(R, a) * f[a + idx * 2^K] +fn bind_and_materialize( + input: &JaggedSumcheckInput, + challenges: &[E], // R_K = (r_1, ..., r_K) +) -> (Vec, Vec) { + let n = input.num_giga_vars; + let k = challenges.len(); + let remaining_size = 1usize << (n - k); + let a_count = 1usize << k; + + let eq_r = build_eq_x_r_vec(challenges); + + // Each output index is independent — parallelize over idx. + let results: Vec<(E, E)> = (0..remaining_size) + .into_par_iter() + .map(|idx| { + let base = idx * a_count; + eq_r.iter() + .zip(input.q_evals.get(base..).unwrap_or(&[])) + .zip(input.col_row_iter(base)) + .fold( + (E::ZERO, E::ZERO), + |(q_acc, f_acc), ((&eq_r_a, &q), (col, row))| { + ( + q_acc + eq_r_a * q, + f_acc + eq_r_a * (input.eq_row[row] * input.eq_col[col]), + ) + }, + ) + }) + .collect(); + + results.into_iter().unzip() +} + #[cfg(test)] mod tests { use super::*; @@ -330,4 +763,189 @@ mod tests { let result = jagged_commit::(&pp, vec![]); assert!(matches!(result, Err(Error::InvalidPcsParam(_)))); } + + // --- Sumcheck tests --- + + use multilinear_extensions::virtual_poly::{VPAuxInfo, build_eq_x_r_vec}; + use rand::thread_rng; + use std::marker::PhantomData; + use sumcheck::structs::IOPVerifierState; + use transcript::basic::BasicTranscript; + + #[test] + fn test_jagged_sumcheck_small() { + use ff_ext::FromUniformBytes; + + let mut rng = thread_rng(); + + // 3 polynomials of height 4 (s=2), total 12 evals, padded to 16 (n=4). + let num_polys = 3usize; + let poly_height = 4usize; + let s = 2; // log2(poly_height) + let total_evals = num_polys * poly_height; + let num_giga_vars = 4; // ceil(log2(12)) = 4, 2^4 = 16 + + let q_evals: Vec = (0..total_evals) + .map(|i| F::from_canonical_u64(i as u64 + 1)) + .collect(); + + let cumulative_heights: Vec = (0..=num_polys).map(|i| i * poly_height).collect(); + + let z_row: Vec = (0..s).map(|_| E::random(&mut rng)).collect(); + // z_col needs ceil(log2(num_polys)) = 2 variables + let z_col: Vec = (0..2).map(|_| E::random(&mut rng)).collect(); + + let input = JaggedSumcheckInput { + q_evals: &q_evals, + num_giga_vars, + cumulative_heights: &cumulative_heights, + eq_row: build_eq_x_r_vec(&z_row), + eq_col: build_eq_x_r_vec(&z_col), + }; + + let claimed_sum = input.compute_claimed_sum(); + + let mut transcript = BasicTranscript::::new(b"jagged_sumcheck_test"); + let (proof, challenges) = jagged_sumcheck_prove(&input, &mut transcript); + + assert_eq!(proof.proofs.len(), num_giga_vars); + assert_eq!(challenges.len(), num_giga_vars); + + // Verify using the standard sumcheck verifier. + let mut transcript_v = BasicTranscript::::new(b"jagged_sumcheck_test"); + let aux_info = VPAuxInfo { + max_degree: 2, + max_num_variables: num_giga_vars, + phantom: PhantomData::, + }; + let subclaim = + IOPVerifierState::::verify(claimed_sum, &proof, &aux_info, &mut transcript_v); + + // The subclaim point should match our challenges. + for (sc, ch) in subclaim.point.iter().zip(challenges.iter()) { + assert_eq!(sc.elements, *ch); + } + + // Verify the final evaluation: q'(point) * f(point) == expected_evaluation + let (q_at_point, f_at_point) = input.final_evaluations_slow(&challenges); + assert_eq!( + q_at_point * f_at_point, + subclaim.expected_evaluation, + "final evaluation mismatch" + ); + } + + #[test] + fn test_jagged_sumcheck_all_epochs() { + // n=16: exercises all 4 epochs (j'=1,2,4,8) + 1 round of standard sumcheck. + use ff_ext::FromUniformBytes; + + let mut rng = thread_rng(); + + let num_polys = 8usize; + let poly_height = 1 << 13; // 8192, s=13 + let s = 13; + let total_evals = num_polys * poly_height; // 65536 + let num_giga_vars = 16; // 2^16 = 65536 + + let q_evals: Vec = (0..total_evals) + .map(|i| F::from_canonical_u64((i as u64 * 7 + 3) % (1 << 20))) + .collect(); + + let cumulative_heights: Vec = (0..=num_polys).map(|i| i * poly_height).collect(); + + let z_row: Vec = (0..s).map(|_| E::random(&mut rng)).collect(); + let z_col: Vec = (0..3).map(|_| E::random(&mut rng)).collect(); // ceil(log2(8))=3 + + let input = JaggedSumcheckInput { + q_evals: &q_evals, + num_giga_vars, + cumulative_heights: &cumulative_heights, + eq_row: build_eq_x_r_vec(&z_row), + eq_col: build_eq_x_r_vec(&z_col), + }; + + let claimed_sum = input.compute_claimed_sum(); + + let mut transcript = BasicTranscript::::new(b"jagged_test_16"); + let (proof, challenges) = jagged_sumcheck_prove(&input, &mut transcript); + + assert_eq!(proof.proofs.len(), num_giga_vars); + assert_eq!(challenges.len(), num_giga_vars); + + let mut transcript_v = BasicTranscript::::new(b"jagged_test_16"); + let aux_info = VPAuxInfo { + max_degree: 2, + max_num_variables: num_giga_vars, + phantom: PhantomData::, + }; + let subclaim = + IOPVerifierState::::verify(claimed_sum, &proof, &aux_info, &mut transcript_v); + + for (sc, ch) in subclaim.point.iter().zip(challenges.iter()) { + assert_eq!(sc.elements, *ch); + } + } + + #[test] + fn test_jagged_sumcheck_n25() { + // n=25: 2^25 = 33M evaluations. Exercises all epochs + 10 rounds of standard sumcheck. + use ff_ext::FromUniformBytes; + + tracing_forest::init(); + + let mut rng = thread_rng(); + + let num_polys = 1 << 10; // 1024 polynomials + let poly_height = 1 << 15; // 32768 each, s=15 + let s = 15; + let total_evals = num_polys * poly_height; // 2^25 = 33554432 + let num_giga_vars = 25; + + let q_evals: Vec = (0..total_evals) + .map(|i| F::from_canonical_u64((i as u64 * 13 + 7) % (1 << 30))) + .collect(); + + let cumulative_heights: Vec = (0..=num_polys).map(|i| i * poly_height).collect(); + + let z_row: Vec = (0..s).map(|_| E::random(&mut rng)).collect(); + let z_col: Vec = (0..10).map(|_| E::random(&mut rng)).collect(); // ceil(log2(1024))=10 + + let input = JaggedSumcheckInput { + q_evals: &q_evals, + num_giga_vars, + cumulative_heights: &cumulative_heights, + eq_row: build_eq_x_r_vec(&z_row), + eq_col: build_eq_x_r_vec(&z_col), + }; + + let claimed_sum = input.compute_claimed_sum(); + + let mut transcript = BasicTranscript::::new(b"jagged_test_25"); + let (proof, challenges) = jagged_sumcheck_prove(&input, &mut transcript); + + assert_eq!(proof.proofs.len(), num_giga_vars); + assert_eq!(challenges.len(), num_giga_vars); + + let mut transcript_v = BasicTranscript::::new(b"jagged_test_25"); + let aux_info = VPAuxInfo { + max_degree: 2, + max_num_variables: num_giga_vars, + phantom: PhantomData::, + }; + let subclaim = + IOPVerifierState::::verify(claimed_sum, &proof, &aux_info, &mut transcript_v); + + for (sc, ch) in subclaim.point.iter().zip(challenges.iter()) { + assert_eq!(sc.elements, *ch); + } + + // Verify the final evaluation: q'(point) * f(point) == expected_evaluation + let (q_at_point, f_at_point) = input.final_evaluations_slow(&challenges); + assert_eq!( + q_at_point * f_at_point, + subclaim.expected_evaluation, + "final evaluation mismatch" + ); + } } diff --git a/crates/mpcs/src/lib.rs b/crates/mpcs/src/lib.rs index be6380d..4a4bb04 100644 --- a/crates/mpcs/src/lib.rs +++ b/crates/mpcs/src/lib.rs @@ -268,7 +268,10 @@ pub use basefold::{ BasefoldRSParams, BasefoldSpec, EncodingScheme, RSCode, RSCodeDefaultSpec, }; pub mod jagged; -pub use jagged::{JaggedCommitment, JaggedCommitmentWithWitness, jagged_commit}; +pub use jagged::{ + JaggedCommitment, JaggedCommitmentWithWitness, JaggedSumcheckInput, jagged_commit, + jagged_sumcheck_prove, +}; #[cfg(feature = "whir")] extern crate whir as whir_external; #[cfg(feature = "whir")] From adc2636905ae9ceb663dc6ab76b01c13a52d1e3f Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Apr 2026 22:00:39 +0800 Subject: [PATCH 3/3] feat: make streaming epoch schedule configurable in jagged_sumcheck_prove (#38) * claude code impl plan * implement the jagged_sumcheck using time-space tradeoff sumcheck prover algorithm * remove debugging codes * ref to the original paper * add jagged sumcheck bench * #32 parallelize (#34) * par wip * check f(z) * g(z) matches expected evaluation * fix clippy * fix clippy: add #[cfg(test)] to test-only method and fix unused import/variable warnings Agent-Logs-Url: https://github.com/scroll-tech/gkr-backend/sessions/4a61316e-cf28-47b8-a43e-fb6ab432701e Co-authored-by: hero78119 <3962077+hero78119@users.noreply.github.com> * refactor test * apply functional programming style * avoid unnecessary BaseField-to-ExtensionField conversion in q_evals access Use E * BaseField multiplication directly instead of converting q_evals elements to extension field with .into() first. Co-Authored-By: Claude Opus 4.6 * replace col_row binary search with incremental ColRowIter Add ColRowIter that does one binary search at construction and O(1) per step, replacing per-element binary searches in build_m_table, bind_and_materialize, compute_claimed_sum, and final_evaluations_slow. Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: hero78119 <3962077+hero78119@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 * extend jagged sumcheck benchmark to cover n=25..31 * switch jagged sumcheck benchmark to BabyBearExt4 Co-Authored-By: Claude Opus 4.6 * remove jagged sumcheck plan doc Co-Authored-By: Claude Opus 4.6 * Initial plan * Make EPOCH_SIZES configurable in jagged_sumcheck_prove via optional parameter Agent-Logs-Url: https://github.com/scroll-tech/gkr-backend/sessions/b18805c7-6e15-44c0-ab03-a1905068d964 Co-authored-by: hero78119 <3962077+hero78119@users.noreply.github.com> * Fix doc spacing and add debug_assert for epoch_sizes validation Agent-Logs-Url: https://github.com/scroll-tech/gkr-backend/sessions/b18805c7-6e15-44c0-ab03-a1905068d964 Co-authored-by: hero78119 <3962077+hero78119@users.noreply.github.com> --------- Co-authored-by: kunxian xia Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: hero78119 <3962077+hero78119@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- crates/mpcs/benches/jagged_sumcheck.rs | 2 +- crates/mpcs/src/jagged.rs | 28 ++++++++++++++++---------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/crates/mpcs/benches/jagged_sumcheck.rs b/crates/mpcs/benches/jagged_sumcheck.rs index 1390bfc..4dbaed0 100644 --- a/crates/mpcs/benches/jagged_sumcheck.rs +++ b/crates/mpcs/benches/jagged_sumcheck.rs @@ -52,7 +52,7 @@ fn bench_jagged_sumcheck(c: &mut Criterion) { |b| { b.iter(|| { let mut transcript = BasicTranscript::::new(b"jagged_bench"); - jagged_sumcheck_prove(black_box(&input), &mut transcript) + jagged_sumcheck_prove(black_box(&input), &mut transcript, None) }) }, ); diff --git a/crates/mpcs/src/jagged.rs b/crates/mpcs/src/jagged.rs index 0d044cd..97e3b30 100644 --- a/crates/mpcs/src/jagged.rs +++ b/crates/mpcs/src/jagged.rs @@ -253,12 +253,10 @@ pub fn jagged_commit>( // Streaming sumcheck prover using the M-table algorithm from // "Time-Space Trade-Offs for Sumcheck" (eprint 2025/1473), Section 4. -/// Number of streaming rounds before switching to standard sumcheck. -/// Determined by the epoch schedule: 1 + 2 + 4 + 8 = 15. -#[allow(dead_code)] -const STREAMING_ROUNDS: usize = 15; -/// Epoch sizes used in the streaming phase: j' = 1, 2, 4, 8. -const EPOCH_SIZES: [usize; 4] = [1, 2, 4, 8]; +/// Default log2 of the maximum epoch size for the streaming phase. +/// Epoch sizes are `[2^0, 2^1, ..., 2^LOG2_MAX_EPOCH]` = `[1, 2, 4, 8]`, +/// covering 1 + 2 + 4 + 8 = 15 streaming rounds before switching to standard sumcheck. +const LOG2_MAX_EPOCH: u32 = 3; /// All inputs needed for the jagged sumcheck. pub struct JaggedSumcheckInput<'a, E: ExtensionField> { @@ -367,14 +365,22 @@ impl<'a, E: ExtensionField> JaggedSumcheckInput<'a, E> { /// Run the full jagged sumcheck: streaming phase (rounds 1..K) + standard phase (rounds K+1..n). /// +/// `log2_max_epoch` controls the streaming phase epoch schedule. Epoch sizes are +/// `[1, 2, 4, ..., 2^k]` where `k = log2_max_epoch`. Pass `None` to use the default +/// `LOG2_MAX_EPOCH = 3`, giving epoch sizes `[1, 2, 4, 8]` and 15 streaming rounds. +/// /// Returns the proof and the full list of challenges (r_1, ..., r_n). pub fn jagged_sumcheck_prove( input: &JaggedSumcheckInput, transcript: &mut impl Transcript, + log2_max_epoch: Option, ) -> (IOPProof, Vec) { let n = input.num_giga_vars; let max_degree: usize = 2; + let k = log2_max_epoch.unwrap_or(LOG2_MAX_EPOCH); + let epoch_sizes: Vec = (0..=k).map(|i| 1usize << i).collect(); + let mut challenges: Vec = Vec::with_capacity(n); let mut proof_messages: Vec> = Vec::with_capacity(n); @@ -382,8 +388,8 @@ pub fn jagged_sumcheck_prove( transcript.append_message(&n.to_le_bytes()); transcript.append_message(&max_degree.to_le_bytes()); - // --- Streaming phase: epochs j' = 1, 2, 4, 8 --- - for &epoch_size in &EPOCH_SIZES { + // --- Streaming phase: epochs j' = 1, 2, 4, ..., 2^k --- + for &epoch_size in &epoch_sizes { // Epoch j' handles rounds j'..2j'-1. Skip if all rounds are done. if epoch_size > n { break; @@ -806,7 +812,7 @@ mod tests { let claimed_sum = input.compute_claimed_sum(); let mut transcript = BasicTranscript::::new(b"jagged_sumcheck_test"); - let (proof, challenges) = jagged_sumcheck_prove(&input, &mut transcript); + let (proof, challenges) = jagged_sumcheck_prove(&input, &mut transcript, None); assert_eq!(proof.proofs.len(), num_giga_vars); assert_eq!(challenges.len(), num_giga_vars); @@ -868,7 +874,7 @@ mod tests { let claimed_sum = input.compute_claimed_sum(); let mut transcript = BasicTranscript::::new(b"jagged_test_16"); - let (proof, challenges) = jagged_sumcheck_prove(&input, &mut transcript); + let (proof, challenges) = jagged_sumcheck_prove(&input, &mut transcript, None); assert_eq!(proof.proofs.len(), num_giga_vars); assert_eq!(challenges.len(), num_giga_vars); @@ -922,7 +928,7 @@ mod tests { let claimed_sum = input.compute_claimed_sum(); let mut transcript = BasicTranscript::::new(b"jagged_test_25"); - let (proof, challenges) = jagged_sumcheck_prove(&input, &mut transcript); + let (proof, challenges) = jagged_sumcheck_prove(&input, &mut transcript, None); assert_eq!(proof.proofs.len(), num_giga_vars); assert_eq!(challenges.len(), num_giga_vars);