diff --git a/ceno_recursion/src/zkvm_verifier/binding.rs b/ceno_recursion/src/zkvm_verifier/binding.rs index a8b7f985f..17d8f0bcd 100644 --- a/ceno_recursion/src/zkvm_verifier/binding.rs +++ b/ceno_recursion/src/zkvm_verifier/binding.rs @@ -385,6 +385,10 @@ pub struct ZKVMChipProofInput { pub has_gkr_proof: usize, pub gkr_iop_proof: GKRProofInput, + // chip-level rotation proof + pub has_rotation_proof: usize, + pub rotation_proof: SumcheckLayerProofInput, + // ecc proof pub has_ecc_proof: usize, pub ecc_proof: EccQuarkProofInput, @@ -476,6 +480,24 @@ impl From<(usize, ZKVMChipProof, usize, usize)> for ZKVMChipProofInput { } else { GKRProofInput::default() }, + has_rotation_proof: if p.rotation_proof.is_some() { 1 } else { 0 }, + rotation_proof: if let Some(rotation) = p.rotation_proof { + SumcheckLayerProofInput { + proof: IOPProverMessageVec::from( + rotation + .proof + .proofs + .iter() + .map(|p| IOPProverMessage { + evaluations: p.evaluations.clone(), + }) + .collect::>(), + ), + evals: rotation.evals, + } + } else { + SumcheckLayerProofInput::default() + }, has_ecc_proof: if p.ecc_proof.is_some() { 1 } else { 0 }, ecc_proof: if p.ecc_proof.is_some() { p.ecc_proof.unwrap().into() @@ -507,6 +529,8 @@ pub struct ZKVMChipProofInputVariable { pub main_sumcheck_proofs: IOPProverMessageVecVariable, pub has_gkr_iop_proof: Usize, pub gkr_iop_proof: GKRProofVariable, + pub has_rotation_proof: Usize, + pub rotation_proof: SumcheckLayerProofVariable, pub tower_proof: TowerProofInputVariable, pub has_ecc_proof: Usize, pub ecc_proof: EccQuarkProofVariable, @@ -543,6 +567,8 @@ impl Hintable for ZKVMChipProofInput { builder.cycle_tracker_end("read main sumcheck proofs"); let has_gkr_iop_proof = Usize::Var(usize::read(builder)); let gkr_iop_proof = GKRProofInput::read(builder); + let has_rotation_proof = Usize::Var(usize::read(builder)); + let rotation_proof = SumcheckLayerProofInput::read(builder); let has_ecc_proof = Usize::Var(usize::read(builder)); let ecc_proof = EccQuarkProofInput::read(builder); @@ -565,6 +591,8 @@ impl Hintable for ZKVMChipProofInput { main_sumcheck_proofs, has_gkr_iop_proof, gkr_iop_proof, + has_rotation_proof, + rotation_proof, tower_proof, has_ecc_proof, ecc_proof, @@ -626,6 +654,10 @@ impl Hintable for ZKVMChipProofInput { stream.extend(self.main_sumcheck_proofs.write()); stream.extend(>::write(&self.has_gkr_proof)); stream.extend(self.gkr_iop_proof.write()); + stream.extend(>::write( + &self.has_rotation_proof, + )); + stream.extend(self.rotation_proof.write()); stream.extend(>::write(&self.has_ecc_proof)); stream.extend(self.ecc_proof.write()); @@ -680,32 +712,12 @@ impl Hintable for SumcheckLayerProofInput { } } pub struct LayerProofInput { - pub has_rotation: usize, - pub rotation: SumcheckLayerProofInput, pub main: SumcheckLayerProofInput, } impl From> for LayerProofInput { fn from(p: LayerProof) -> Self { Self { - has_rotation: if p.rotation.is_some() { 1 } else { 0 }, - rotation: if p.rotation.is_some() { - let r = p.rotation.unwrap(); - SumcheckLayerProofInput { - proof: IOPProverMessageVec::from( - r.proof - .proofs - .iter() - .map(|p| IOPProverMessage { - evaluations: p.evaluations.clone(), - }) - .collect::>(), - ), - evals: r.evals, - } - } else { - SumcheckLayerProofInput::default() - }, main: SumcheckLayerProofInput { proof: IOPProverMessageVec::from( p.main @@ -725,8 +737,6 @@ impl From> for LayerProofInput { #[derive(DslVariable, Clone)] pub struct LayerProofVariable { - pub has_rotation: Usize, - pub rotation: SumcheckLayerProofVariable, pub main: SumcheckLayerProofVariable, } impl VecAutoHintable for LayerProofInput {} @@ -734,20 +744,12 @@ impl Hintable for LayerProofInput { type HintVariable = LayerProofVariable; fn read(builder: &mut Builder) -> Self::HintVariable { - let has_rotation = Usize::Var(usize::read(builder)); - let rotation = SumcheckLayerProofInput::read(builder); let main = SumcheckLayerProofInput::read(builder); - Self::HintVariable { - has_rotation, - rotation, - main, - } + Self::HintVariable { main } } fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); - stream.extend(>::write(&self.has_rotation)); - stream.extend(self.rotation.write()); stream.extend(self.main.write()); stream } diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index a745c916d..12fc3dd2a 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -32,11 +32,7 @@ use ff_ext::BabyBearExt4; use crate::transcript::{challenger_add_forked_index, clone_challenger_state}; use gkr_iop::{ evaluation::EvalExpression, - gkr::{ - GKRCircuit, - booleanhypercube::BooleanHypercube, - layer::{Layer, ROTATION_OPENING_COUNT}, - }, + gkr::{GKRCircuit, booleanhypercube::BooleanHypercube, layer::Layer}, selector::SelectorType, }; use itertools::{Itertools, izip}; @@ -583,15 +579,7 @@ pub fn verify_chip_proof( is_infinity: Usize::uninit(builder), }; - if composed_cs.has_ecc_ops() { - builder.assert_nonzero(&chip_proof.has_ecc_proof); - let ecc_proof = &chip_proof.ecc_proof; - builder.assert_usize_eq(ecc_proof.sum.is_infinity.clone(), Usize::from(0)); - verify_ecc_proof(builder, challenger, ecc_proof, unipoly_extrapolator); - builder.assign(&shard_ec_sum, ecc_proof.sum.clone()); - } else { - builder.assign(&shard_ec_sum.is_infinity, Usize::from(1)); - } + builder.assign(&shard_ec_sum.is_infinity, Usize::from(1)); let tower_proof = &chip_proof.tower_proof; let num_variables: Array> = builder.dyn_array(num_batched); @@ -635,18 +623,23 @@ pub fn verify_chip_proof( }); } + if composed_cs.has_ecc_ops() { + builder.assert_nonzero(&chip_proof.has_ecc_proof); + let ecc_proof = &chip_proof.ecc_proof; + builder.assert_usize_eq(ecc_proof.sum.is_infinity.clone(), Usize::from(0)); + verify_ecc_proof(builder, challenger, ecc_proof, unipoly_extrapolator); + builder.assign(&shard_ec_sum, ecc_proof.sum.clone()); + } + let num_rw_records: Usize = builder.eval(r_counts_per_instance + w_counts_per_instance); builder.assert_usize_eq(record_evals.len(), num_rw_records.clone()); builder.assert_usize_eq(logup_p_evals.len(), lk_counts_per_instance.clone()); builder.assert_usize_eq(logup_q_evals.len(), lk_counts_per_instance.clone()); // GKR circuit - let out_evals_len: Usize = if cs.lk_table_expressions.is_empty() { - builder.eval(record_evals.len() + logup_q_evals.len()) - } else { - builder.eval(record_evals.len() + logup_p_evals.len() + logup_q_evals.len()) - }; - let out_evals: Array> = builder.dyn_array(out_evals_len.clone()); + let gkr_circuit = gkr_circuit.clone().unwrap(); + let out_evals: Array> = + builder.dyn_array(Usize::from(gkr_circuit.n_evaluations)); builder .range(0, record_evals.len()) @@ -670,14 +663,15 @@ pub fn verify_chip_proof( builder.assign(&end, record_evals.len()); } - let q_slice = out_evals.slice(builder, end, out_evals_len); + let q_end: Usize = builder.eval(end.clone() + logup_q_evals.len()); + let q_slice = out_evals.slice(builder, end, q_end); builder .range(0, logup_q_evals.len()) .for_each(|idx_vec, builder| { let cpt = builder.get(&logup_q_evals, idx_vec[0]); builder.set(&q_slice, idx_vec[0], cpt); }); - let gkr_circuit = gkr_circuit.clone().unwrap(); + let circuit_pi_evals: Array> = builder.dyn_array(Usize::from(cs.instance.len())); for (i, instance) in cs.instance.iter().enumerate() { @@ -686,93 +680,321 @@ pub fn verify_chip_proof( builder.set(&circuit_pi_evals, i, eval); } + let first_layer = gkr_circuit.layers.first().expect("empty gkr circuit layer"); let zero_bit_decomps: Array> = builder.dyn_array(32); - let selector_ctxs: Vec> = if cs.ec_final_sum.is_empty() { - let non_shard_n1 = Usize::Var(builder.get(&chip_proof.num_instances, 1)); - builder.assert_usize_eq(non_shard_n1, Usize::from(0)); - let num_instances_bit_decomps: Array>> = builder.dyn_array(1); - builder.set( - &num_instances_bit_decomps, - 0, - chip_proof - .sum_num_instances_minus_one_bit_decomposition - .clone(), - ); - vec![ + let sum_num_instances_bit_decomps: Array>> = builder.dyn_array(1); + builder.set( + &sum_num_instances_bit_decomps, + 0, + chip_proof + .sum_num_instances_minus_one_bit_decomposition + .clone(), + ); + + let mut selector_ctxs = Vec::with_capacity(first_layer.out_sel_and_eval_exprs.len()); + for (selector, _) in &first_layer.out_sel_and_eval_exprs { + let ctx = if cs.ec_final_sum.is_empty() { + let non_shard_n1 = Usize::Var(builder.get(&chip_proof.num_instances, 1)); + builder.assert_usize_eq(non_shard_n1, Usize::from(0)); SelectorContextVariable { offset: Usize::from(0), - offset_bit_decomps: zero_bit_decomps, + offset_bit_decomps: zero_bit_decomps.clone(), num_instances: chip_proof.sum_num_instances.clone(), - num_instances_layered_ns: builder.dyn_array(0), /* Only used in QuarkBinaryTreeLessThan(Expression) */ - num_instances_bit_decomps, + num_instances_layered_ns: builder.dyn_array(0), + num_instances_bit_decomps: sum_num_instances_bit_decomps.clone(), offset_instance_sum_bit_decomps: chip_proof .sum_num_instances_minus_one_bit_decomposition .clone(), num_vars: num_var_with_rotation.clone(), - }; - gkr_circuit - .layers - .first() - .map(|layer| layer.out_sel_and_eval_exprs.len()) - .unwrap_or(0) - ] - } else { - let num_inst_0_bit_decomps: Array>> = builder.dyn_array(1); - let num_inst_1_bit_decomps: Array>> = builder.dyn_array(1); - let num_inst_sum_bit_decomps: Array>> = builder.dyn_array(1); - - builder.set( - &num_inst_0_bit_decomps, - 0, - chip_proof.n_inst_0_bit_decomps.clone(), - ); - builder.set( - &num_inst_1_bit_decomps, - 0, - chip_proof.n_inst_1_bit_decomps.clone(), - ); - builder.set( - &num_inst_sum_bit_decomps, - 0, - chip_proof - .sum_num_instances_minus_one_bit_decomposition - .clone(), - ); - - vec![ + } + } else if cs.r_selector.as_ref() == Some(selector) { + let num_inst_0_bit_decomps: Array>> = builder.dyn_array(1); + builder.set( + &num_inst_0_bit_decomps, + 0, + chip_proof.n_inst_0_bit_decomps.clone(), + ); SelectorContextVariable { offset: Usize::from(0), offset_bit_decomps: zero_bit_decomps.clone(), num_instances: Usize::Var(builder.get(&chip_proof.num_instances, 0)), - num_instances_layered_ns: builder.dyn_array(0), /* Only used in QuarkBinaryTreeLessThan(Expression) */ + num_instances_layered_ns: builder.dyn_array(0), num_instances_bit_decomps: num_inst_0_bit_decomps, offset_instance_sum_bit_decomps: chip_proof.n_inst_0_bit_decomps.clone(), num_vars: num_var_with_rotation.clone(), - }, + } + } else if cs.w_selector.as_ref() == Some(selector) { + let num_inst_1_bit_decomps: Array>> = builder.dyn_array(1); + builder.set( + &num_inst_1_bit_decomps, + 0, + chip_proof.n_inst_1_bit_decomps.clone(), + ); SelectorContextVariable { offset: Usize::Var(builder.get(&chip_proof.num_instances, 0)), offset_bit_decomps: chip_proof.n_inst_0_bit_decomps.clone(), num_instances: Usize::Var(builder.get(&chip_proof.num_instances, 1)), - num_instances_layered_ns: builder.dyn_array(0), /* Only used in QuarkBinaryTreeLessThan(Expression) */ + num_instances_layered_ns: builder.dyn_array(0), num_instances_bit_decomps: num_inst_1_bit_decomps, offset_instance_sum_bit_decomps: chip_proof .sum_num_instances_minus_one_bit_decomposition .clone(), num_vars: num_var_with_rotation.clone(), - }, + } + } else { SelectorContextVariable { offset: Usize::from(0), - offset_bit_decomps: zero_bit_decomps, + offset_bit_decomps: zero_bit_decomps.clone(), num_instances: chip_proof.sum_num_instances.clone(), - num_instances_layered_ns: builder.dyn_array(0), /* Only used in QuarkBinaryTreeLessThan(Expression) */ - num_instances_bit_decomps: num_inst_sum_bit_decomps, + num_instances_layered_ns: builder.dyn_array(0), + num_instances_bit_decomps: sum_num_instances_bit_decomps.clone(), offset_instance_sum_bit_decomps: chip_proof .sum_num_instances_minus_one_bit_decomposition .clone(), num_vars: num_var_with_rotation.clone(), - }, - ] - }; + } + }; + selector_ctxs.push(ctx); + } + + if !first_layer.rotation_exprs.1.is_empty() { + builder.assert_usize_eq(chip_proof.has_rotation_proof.clone(), Usize::from(1)); + + let first_claim = builder.get(&out_evals, 0); + let rt_tower = builder.eval(first_claim.point.fs.clone()); + let RotationClaim { + left_evals, + right_evals, + target_evals, + left_point, + right_point, + origin_point, + } = verify_rotation( + builder, + challenger, + num_var_with_rotation.clone(), + first_layer.rotation_exprs.1.len(), + first_layer + .rotation_sumcheck_expression + .as_ref() + .expect("missing rotation sumcheck expression"), + &chip_proof.rotation_proof, + first_layer.rotation_cyclic_subgroup_size, + first_layer.rotation_cyclic_group_log2, + rt_tower, + challenges, + unipoly_extrapolator, + ); + + let [left_group_idx, right_group_idx, point_group_idx] = first_layer + .rotation_selector_group_indices() + .expect("rotation selectors missing"); + + let left_point: Array> = builder.eval(left_point); + let right_point: Array> = builder.eval(right_point); + let origin_point: Array> = builder.eval(origin_point); + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[left_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("rotation groups must use EvalExpression::Single"); + }; + let eval = builder.get(&left_evals, idx); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: left_point.clone(), + }, + eval, + }); + builder.set(&out_evals, *out_idx, claim); + } + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[right_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("rotation groups must use EvalExpression::Single"); + }; + let eval = builder.get(&right_evals, idx); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: right_point.clone(), + }, + eval, + }); + builder.set(&out_evals, *out_idx, claim); + } + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[point_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("rotation groups must use EvalExpression::Single"); + }; + let eval = builder.get(&target_evals, idx); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: origin_point.clone(), + }, + eval, + }); + builder.set(&out_evals, *out_idx, claim); + } + } + + if composed_cs.has_ecc_ops() { + let [ + x_group_idx, + y_group_idx, + slope_group_idx, + x3_group_idx, + y3_group_idx, + ] = first_layer + .ecc_bridge_group_indices() + .expect("ecc bridge selectors missing"); + + transcript_observe_label(builder, challenger, b"ecc_gkr_bridge_r"); + let sample_r: Ext = challenger.sample_ext(builder); + let one_minus_r: Ext = builder.eval(one - sample_r); + let ecc_proof = &chip_proof.ecc_proof; + + let xy_point_len: Usize = builder.eval(ecc_proof.rt.fs.len() + Usize::from(1)); + let xy_point: Array> = builder.dyn_array(xy_point_len); + builder.set(&xy_point, 0, sample_r); + builder + .range(0, ecc_proof.rt.fs.len()) + .for_each(|idx_vec, builder| { + let idx = idx_vec[0]; + let v = builder.get(&ecc_proof.rt.fs, idx); + let shifted_idx = Usize::Var(Var::uninit(builder)); + builder.assign(&shifted_idx, idx + Usize::from(1)); + builder.set(&xy_point, shifted_idx, v); + }); + + let s_point_len: Usize = builder.eval(ecc_proof.rt.fs.len() + Usize::from(1)); + let s_point: Array> = builder.dyn_array(s_point_len.clone()); + builder + .range(0, ecc_proof.rt.fs.len()) + .for_each(|idx_vec, builder| { + let idx = idx_vec[0]; + let v = builder.get(&ecc_proof.rt.fs, idx); + builder.set(&s_point, idx, v); + }); + builder.set(&s_point, ecc_proof.rt.fs.len(), sample_r); + + let x3y3_point: Array> = builder.dyn_array(s_point_len.clone()); + builder + .range(0, ecc_proof.rt.fs.len()) + .for_each(|idx_vec, builder| { + let idx = idx_vec[0]; + let v = builder.get(&ecc_proof.rt.fs, idx); + builder.set(&x3y3_point, idx, v); + }); + builder.set(&x3y3_point, ecc_proof.rt.fs.len(), one); + + let degree = SEPTIC_EXTENSION_DEGREE; + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[x_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge x group must use EvalExpression::Single"); + }; + let x0 = builder.get(&ecc_proof.evals, 3 + degree + idx); + let x1 = builder.get(&ecc_proof.evals, 3 + degree * 3 + idx); + let eval = builder.eval(x0 * one_minus_r + x1 * sample_r); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: xy_point.clone(), + }, + eval, + }); + builder.set(&out_evals, *out_idx, claim); + } + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[y_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge y group must use EvalExpression::Single"); + }; + let y0 = builder.get(&ecc_proof.evals, 3 + degree * 2 + idx); + let y1 = builder.get(&ecc_proof.evals, 3 + degree * 4 + idx); + let eval = builder.eval(y0 * one_minus_r + y1 * sample_r); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: xy_point.clone(), + }, + eval, + }); + builder.set(&out_evals, *out_idx, claim); + } + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[slope_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge slope group must use EvalExpression::Single"); + }; + let s1 = builder.get(&ecc_proof.evals, 3 + idx); + let eval = builder.eval(s1 * sample_r); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: s_point.clone(), + }, + eval, + }); + builder.set(&out_evals, *out_idx, claim); + } + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[x3_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge x3 group must use EvalExpression::Single"); + }; + let eval = builder.get(&ecc_proof.evals, 3 + degree * 5 + idx); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: x3y3_point.clone(), + }, + eval, + }); + builder.set(&out_evals, *out_idx, claim); + } + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[y3_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge y3 group must use EvalExpression::Single"); + }; + let eval = builder.get(&ecc_proof.evals, 3 + degree * 6 + idx); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: x3y3_point.clone(), + }, + eval, + }); + builder.set(&out_evals, *out_idx, claim); + } + } builder.cycle_tracker_start("Verify GKR Circuit"); let rt = verify_gkr_circuit( @@ -812,26 +1034,13 @@ pub fn verify_gkr_circuit( for (i, layer) in gkr_circuit.layers.iter().enumerate() { let layer_proof = builder.get(&gkr_proof.layer_proofs, i); - let eval_and_dedup_points: Array> = extract_claim_and_point( - builder, - layer, - claims, - challenges, - &layer_proof.has_rotation, + let eval_and_dedup_points: Array> = + extract_claim_and_point(builder, layer, claims, challenges); + builder.assert_usize_eq( + Usize::from(layer.out_sel_and_eval_exprs.len()), + eval_and_dedup_points.len(), ); - if layer.rotation_sumcheck_expression.is_some() { - builder.assert_usize_eq( - Usize::from(layer.out_sel_and_eval_exprs.len() + 3), - eval_and_dedup_points.len(), - ); - } else { - builder.assert_usize_eq( - Usize::from(layer.out_sel_and_eval_exprs.len()), - eval_and_dedup_points.len(), - ); - } - // ZeroCheckLayer verification (might include other layer types in the future) let LayerProofVariable { main: @@ -840,91 +1049,14 @@ pub fn verify_gkr_circuit( evals: main_evals, evals_len_div_3: _main_evals_len_div_3, }, - rotation: rotation_proof, - has_rotation, } = layer_proof; - let expected_main_evals_len: Usize = Usize::from( - layer.n_witin + layer.n_fixed + layer.n_instance + layer.n_structural_witin, - ); + let expected_main_evals_len: Usize = + Usize::from(layer.n_witin + layer.n_fixed + layer.n_structural_witin); builder.assert_usize_eq(expected_main_evals_len, main_evals.len()); - if layer.rotation_sumcheck_expression.is_some() { - builder.if_eq(has_rotation, Usize::from(1)).then(|builder| { - let first = builder.get(&eval_and_dedup_points, 0); - builder.assert_usize_eq(first.has_point, Usize::from(1)); // Rotation proof should have at least one point - let rt = builder.eval(first.point.fs.clone()); - - let RotationClaim { - left_evals, - right_evals, - target_evals, - left_point, - right_point, - origin_point, - } = verify_rotation( - builder, - challenger, - max_num_variables.clone(), - layer.rotation_exprs.1.len(), - layer.rotation_sumcheck_expression.as_ref().unwrap(), - &rotation_proof, - layer.rotation_cyclic_subgroup_size, - layer.rotation_cyclic_group_log2, - rt, - challenges, - unipoly_extrapolator, - ); - - // extend eval_and_dedup_points by - // [ - // (left_evals, left_point), - // (right_evals, right_point), - // (target_evals, origin_point), - // ] - let last_idx: Usize = - builder.eval(eval_and_dedup_points.len() - Usize::from(1)); - builder.set( - &eval_and_dedup_points, - last_idx.clone(), - ClaimAndPoint { - evals: target_evals, - has_point: Usize::from(1), - point: PointVariable { fs: origin_point }, - }, - ); - - builder.assign(&last_idx, last_idx.clone() - Usize::from(1)); - builder.set( - &eval_and_dedup_points, - last_idx.clone(), - ClaimAndPoint { - evals: right_evals, - has_point: Usize::from(1), - point: PointVariable { fs: right_point }, - }, - ); - - builder.assign(&last_idx, last_idx.clone() - Usize::from(1)); - builder.set( - &eval_and_dedup_points, - last_idx.clone(), - ClaimAndPoint { - evals: left_evals, - has_point: Usize::from(1), - point: PointVariable { fs: left_point }, - }, - ); - }); - } - - let rotation_exprs_len = layer.rotation_exprs.1.len(); transcript_observe_label(builder, challenger, b"combine subset evals"); - let alpha_pows = gen_alpha_pows( - builder, - challenger, - Usize::from(layer.exprs.len() + rotation_exprs_len * ROTATION_OPENING_COUNT), - ); + let alpha_pows = gen_alpha_pows(builder, challenger, Usize::from(layer.exprs.len())); let sigma: Ext = builder.constant(C::EF::ZERO); let alpha_idx: Usize = Usize::Var(Var::uninit(builder)); @@ -961,7 +1093,7 @@ pub fn verify_gkr_circuit( unipoly_extrapolator, ); - let structural_witin_offset = layer.n_witin + layer.n_fixed + layer.n_instance; + let structural_witin_offset = layer.n_witin + layer.n_fixed; // check selector evaluations layer @@ -1176,10 +1308,16 @@ pub fn verify_rotation( let SumcheckLayerProofVariable { proof, evals, - evals_len_div_3: rotation_expr_len, + evals_len_div_3: hinted_rotation_expr_len, } = rotation_proof; - let rotation_expr_len = Usize::Var(*rotation_expr_len); + let rotation_expr_len = Usize::from(num_rotations); + builder.assert_usize_eq( + Usize::Var(*hinted_rotation_expr_len), + Usize::from(num_rotations), + ); + let expected_rotation_eval_len = Usize::from(num_rotations * 3); + builder.assert_usize_eq(evals.len(), expected_rotation_eval_len); transcript_observe_label(builder, challenger, b"combine subset evals"); let rotation_alpha_pows = gen_alpha_pows(builder, challenger, Usize::from(num_rotations)); let rotation_challenges = concat(builder, challenges, &rotation_alpha_pows); @@ -1293,7 +1431,7 @@ pub fn rotation_selector_eval( rotation_cyclic_subgroup_size: usize, cyclic_group_log2_size: usize, ) -> Ext { - let bh = BooleanHypercube::new(5); + let bh = BooleanHypercube::new(cyclic_group_log2_size); let eval: Ext = builder.constant(C::EF::ZERO); let rotation_index = bh .into_iter() @@ -1491,33 +1629,57 @@ pub fn evaluate_selector( (*wit_id as usize, eval) } -// TODO: make this as a function of BooleanHypercube pub fn get_rotation_points( builder: &mut Builder, - _num_vars: usize, + num_vars: usize, point: &Array>, ) -> (Array>, Array>) { let left: Array> = builder.dyn_array(point.len()); let right: Array> = builder.dyn_array(point.len()); - // left = (0,s0,s1,s2,s3,...) - // right = (1,s0,1-s1,s2,s3,...) - builder.range(0, 4).for_each(|idx_vec, builder| { - let e = builder.get(point, idx_vec[0]); - let dest_idx: Var = builder.eval(idx_vec[0] + RVar::from(1)); - builder.set(&left, dest_idx, e); - builder.set(&right, dest_idx, e); - }); - let one: Ext = builder.constant(C::EF::ONE); + let zero: Ext = builder.constant(C::EF::ZERO); + builder.set(&left, 0, zero); builder.set(&right, 0, one); - let r1 = builder.get(&right, 2); - builder.set(&right, 2, one - r1); - builder.range(5, point.len()).for_each(|idx_vec, builder| { - let e = builder.get(point, idx_vec[0]); - builder.set(&left, idx_vec[0], e); - builder.set(&right, idx_vec[0], e); - }); + match num_vars { + 5 => { + // left: (0, r0, r1, r2, r3, r5, r6, ...) + // right: (1, r0, 1-r1, r2, r3, r5, r6, ...) + builder.range(0, 4).for_each(|idx_vec, builder| { + let e = builder.get(point, idx_vec[0]); + let dest_idx: Var = builder.eval(idx_vec[0] + RVar::from(1)); + builder.set(&left, dest_idx, e); + builder.set(&right, dest_idx, e); + }); + let r1 = builder.get(point, 1); + builder.set(&right, 2, one - r1); + + builder.range(5, point.len()).for_each(|idx_vec, builder| { + let e = builder.get(point, idx_vec[0]); + builder.set(&left, idx_vec[0], e); + builder.set(&right, idx_vec[0], e); + }); + } + 6 => { + // left: (0, r0, r1, r2, r3, r4, r6, r7, ...) + // right: (1, 1-r0, r1, r2, r3, r4, r6, r7, ...) + builder.range(0, 5).for_each(|idx_vec, builder| { + let e = builder.get(point, idx_vec[0]); + let dest_idx: Var = builder.eval(idx_vec[0] + RVar::from(1)); + builder.set(&left, dest_idx, e); + builder.set(&right, dest_idx, e); + }); + let r0 = builder.get(point, 0); + builder.set(&right, 1, one - r0); + + builder.range(6, point.len()).for_each(|idx_vec, builder| { + let e = builder.get(point, idx_vec[0]); + builder.set(&left, idx_vec[0], e); + builder.set(&right, idx_vec[0], e); + }); + } + unsupported => unimplemented!("rotation cyclic group not supported: {unsupported}"), + } (left, right) } @@ -1629,13 +1791,8 @@ pub fn extract_claim_and_point( layer: &Layer, claims: &Array>, challenges: &Array>, - has_rotation: &Usize, ) -> Array> { - let r_len: Usize = Usize::Var(Var::uninit(builder)); - builder.assign( - &r_len, - has_rotation.clone() * Usize::from(3) + Usize::from(layer.out_sel_and_eval_exprs.len()), - ); + let r_len = Usize::from(layer.out_sel_and_eval_exprs.len()); let r = builder.dyn_array(r_len); layer .out_sel_and_eval_exprs diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 4fed156dc..4d5de9a4d 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -7,10 +7,11 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, default_out_eval_groups, error::{BackendError, CircuitBuilderError}, + evaluation::EvalExpression, gkr::{ GKRCircuit, GKRProof, GKRProverOutput, booleanhypercube::{BooleanHypercube, CYCLIC_POW2_5}, - layer::Layer, + layer::{Layer, cpu::prove_rotation, zerocheck_layer::verify_rotation}, mock::MockProver, }, selector::{SelectorContext, SelectorType}, @@ -1131,14 +1132,37 @@ pub fn run_lookup_keccakf exit_span!(span); let span = entered_span!("out_eval", profiling_2 = true); - let out_evals = { - let mut point = Vec::with_capacity(log2_num_instance_rounds); - point.extend( - prover_transcript - .sample_vec(log2_num_instance_rounds) - .to_vec(), - ); + let mut point = Vec::with_capacity(log2_num_instance_rounds); + point.extend( + prover_transcript + .sample_vec(log2_num_instance_rounds) + .to_vec(), + ); + + let first_layer = gkr_circuit.layers.first().expect("empty gkr circuit layer"); + assert!( + !first_layer.rotation_exprs.1.is_empty(), + "lookup_keccakf unittest expects rotation-enabled circuit" + ); + let rotation_terms = first_layer + .rotation_sumcheck_expression_monomial_terms + .as_ref() + .expect("missing rotation sumcheck terms") + .clone(); + let (rotation_proof, rotation_points) = prove_rotation::( + num_threads, + log2_num_instance_rounds, + first_layer.rotation_cyclic_subgroup_size, + first_layer.rotation_cyclic_group_log2, + &gkr_witness.layers[0], + &first_layer.rotation_exprs.1, + rotation_terms, + &point, + &challenges, + &mut prover_transcript, + ); + let mut out_evals = { if test_outputs { // Confront outputs with tiny_keccak::keccakf call let mut instance_outputs = vec![vec![]; num_instances]; @@ -1189,6 +1213,60 @@ pub fn run_lookup_keccakf }) .collect::>() }; + + let Some([left_group_idx, right_group_idx, point_group_idx]) = + first_layer.rotation_selector_group_indices() + else { + panic!("rotation selectors missing"); + }; + + let mut left_evals = Vec::new(); + let mut right_evals = Vec::new(); + let mut point_evals = Vec::new(); + for chunk in rotation_proof.evals.chunks_exact(3) { + left_evals.push(chunk[0]); + right_evals.push(chunk[1]); + point_evals.push(chunk[2]); + } + + let assign_group = |out_evals: &mut [PointAndEval], + eval_exprs: &[EvalExpression], + evals: &[E], + point: &[E]| { + assert_eq!( + eval_exprs.len(), + evals.len(), + "rotation eval length mismatch" + ); + for (eval_expr, eval) in eval_exprs.iter().zip_eq(evals.iter()) { + let EvalExpression::Single(index) = eval_expr else { + panic!("rotation groups must use EvalExpression::Single"); + }; + out_evals[*index] = PointAndEval { + point: point.to_vec(), + eval: *eval, + }; + } + }; + + assign_group( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[left_group_idx].1, + &left_evals, + &rotation_points.left, + ); + assign_group( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[right_group_idx].1, + &right_evals, + &rotation_points.right, + ); + assign_group( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[point_group_idx].1, + &point_evals, + &rotation_points.origin, + ); exit_span!(span); if cfg!(debug_assertions) { @@ -1199,7 +1277,15 @@ pub fn run_lookup_keccakf } let span = entered_span!("create_proof", profiling_2 = true); - let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance_rounds); 3]; + let first_layer_selector_groups = gkr_circuit + .layers + .first() + .map(|layer| layer.out_sel_and_eval_exprs.len()) + .unwrap_or(0); + let selector_ctxs = vec![ + SelectorContext::new(0, num_instances, log2_num_instance_rounds); + first_layer_selector_groups + ]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -1230,6 +1316,22 @@ pub fn run_lookup_keccakf .to_vec(), ); + verify_rotation( + log2_num_instance_rounds, + first_layer.rotation_exprs.1.len(), + first_layer + .rotation_sumcheck_expression + .as_ref() + .expect("missing rotation sumcheck expression"), + rotation_proof.clone(), + first_layer.rotation_cyclic_subgroup_size, + first_layer.rotation_cyclic_group_log2, + &point, + &challenges, + &mut verifier_transcript, + ) + .expect("rotation verify failed"); + gkr_circuit .verify( log2_num_instance_rounds, diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 0f6898f6c..0d6a32680 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -1,6 +1,6 @@ use crate::structs::EccQuarkProof; use ff_ext::ExtensionField; -use gkr_iop::gkr::GKRProof; +use gkr_iop::gkr::{GKRProof, layer::sumcheck_layer::SumcheckLayerProof}; use itertools::Itertools; use mpcs::PolynomialCommitmentScheme; use p3::field::FieldAlgebra; @@ -66,6 +66,8 @@ pub struct ZKVMChipProof { pub main_sumcheck_proofs: Option>>, pub gkr_iop_proof: Option>, + // Rotation is proved at chip scope and consumed before layer verification. + pub rotation_proof: Option>, pub tower_proof: TowerProofs, pub ecc_proof: Option>, diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 929c7850d..b105efe75 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -1,6 +1,6 @@ use super::hal::{ DeviceTransporter, MainSumcheckEvals, MainSumcheckProver, OpeningProver, ProverDevice, - TowerProver, TraceCommitter, + RotationProver, RotationProverOutput, TowerProver, TraceCommitter, }; use crate::{ error::ZKVMError, @@ -8,7 +8,10 @@ use crate::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, hal::{DeviceProvingKey, EccQuarkProver, ProofInput, TowerProverSpec}, septic_curve::{SepticExtension, SepticPoint, SymbolicSepticExtension}, - utils::{infer_tower_logup_witness, infer_tower_product_witness}, + utils::{ + assign_group_evals, derive_ecc_bridge_claims, extract_ecc_quark_witness_inputs, + infer_tower_logup_witness, infer_tower_product_witness, split_rotation_evals, + }, }, structs::{ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs}, }; @@ -311,19 +314,22 @@ impl> EccQuarkProver( &self, - num_instances: usize, - xs: Vec>>, - ys: Vec>>, - invs: Vec>>, + cs: &ComposedConstrainSystem, + input: &ProofInput<'a, CpuBackend>, transcript: &mut impl Transcript, - ) -> Result, ZKVMError> { - Ok(CpuEccProver::create_ecc_proof( - num_instances, - xs, - ys, - invs, + ) -> Result>, ZKVMError> { + let Some(ecc_inputs) = extract_ecc_quark_witness_inputs::>(cs, input) + else { + return Ok(None); + }; + + Ok(Some(CpuEccProver::create_ecc_proof( + input.num_instances(), + ecc_inputs.xs, + ecc_inputs.ys, + ecc_inputs.slopes, transcript, - )) + ))) } } @@ -803,6 +809,64 @@ impl> TowerProver> RotationProver> + for CpuProver> +{ + fn prove_rotation<'a>( + &self, + composed_cs: &ComposedConstrainSystem, + input: &ProofInput<'a, CpuBackend>, + rt_tower: &Point, + challenges: &[E; 2], + transcript: &mut impl Transcript, + ) -> Result>, ZKVMError> { + let Some(gkr_circuit) = composed_cs.gkr_circuit.as_ref() else { + return Ok(None); + }; + let Some(layer) = gkr_circuit.layers.first() else { + return Ok(None); + }; + if layer.rotation_exprs.1.is_empty() { + return Ok(None); + } + + let Some(rotation_sumcheck_expression) = + layer.rotation_sumcheck_expression_monomial_terms.as_ref() + else { + return Ok(None); + }; + + let log2_num_instances = input.log2_num_instances(); + let num_threads = optimal_sumcheck_threads(log2_num_instances); + let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0); + let wit = LayerWitness( + chain!(&input.witness, &input.fixed, &input.structural_witness,) + .cloned() + .collect_vec(), + ); + + let (proof, points) = gkr_iop::gkr::layer::cpu::prove_rotation::( + num_threads, + num_var_with_rotation, + layer.rotation_cyclic_subgroup_size, + layer.rotation_cyclic_group_log2, + &wit, + &layer.rotation_exprs.1, + rotation_sumcheck_expression.clone(), + rt_tower, + challenges, + transcript, + ); + + Ok(Some(RotationProverOutput { + proof, + left_point: points.left, + right_point: points.right, + point: points.origin, + })) + } +} + impl> MainSumcheckProver> for CpuProver> { @@ -816,6 +880,8 @@ impl> MainSumcheckProver( &self, rt_tower: Vec, + rotation: Option>, + ecc_proof: Option<&EccQuarkProof>, input: &'b ProofInput<'a, CpuBackend>, composed_cs: &ComposedConstrainSystem, challenges: &[E; 2], @@ -842,40 +908,121 @@ impl> MainSumcheckProver> MainSumcheckProver( (r_out_evals, w_out_evals, lk_out_evals) } +/// Standalone function for prove_rotation that doesn't require &self. +/// This allows rotation proof generation from parallel task code paths. +pub fn prove_rotation_impl>( + composed_cs: &ComposedConstrainSystem, + input: &ProofInput<'_, GpuBackend>, + rt_tower: &Point, + challenges: &[E; 2], + transcript: &mut impl Transcript, +) -> Result>, ZKVMError> { + let Some(gkr_circuit) = composed_cs.gkr_circuit.as_ref() else { + return Ok(None); + }; + let Some(layer) = gkr_circuit.layers.first() else { + return Ok(None); + }; + if layer.rotation_exprs.1.is_empty() { + return Ok(None); + } + + let Some(rotation_sumcheck_expression) = + layer.rotation_sumcheck_expression_monomial_terms.as_ref() + else { + return Ok(None); + }; + + let log2_num_instances = input.log2_num_instances(); + let num_threads = optimal_sumcheck_threads(log2_num_instances); + let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0); + let wit = LayerWitness( + chain!(&input.witness, &input.fixed, &input.structural_witness) + .cloned() + .collect_vec(), + ); + + let (proof, points) = gkr_iop::gkr::layer::gpu::prove_rotation_gpu::( + num_threads, + num_var_with_rotation, + layer.rotation_cyclic_subgroup_size, + layer.rotation_cyclic_group_log2, + &wit, + &layer.rotation_exprs.1, + rotation_sumcheck_expression.clone(), + rt_tower, + challenges, + transcript, + ); + + Ok(Some(RotationProverOutput { + proof, + left_point: points.left, + right_point: points.right, + point: points.origin, + })) +} + /// Standalone function for prove_main_constraints that doesn't require &self /// This allows it to be called from parallel threads without Send/Sync bounds on GpuProver #[allow(clippy::type_complexity)] @@ -230,6 +292,8 @@ pub fn prove_main_constraints_impl< PCS: PolynomialCommitmentScheme + 'static, >( rt_tower: Vec, + rotation: Option>, + ecc_proof: Option<&EccQuarkProof>, input: &ProofInput<'_, GpuBackend>, composed_cs: &ComposedConstrainSystem, challenges: &[E; 2], @@ -256,40 +320,121 @@ pub fn prove_main_constraints_impl< let Some(gkr_circuit) = gkr_circuit else { panic!("empty gkr circuit") }; - let selector_ctxs = if cs.ec_final_sum.is_empty() { - // it's not global chip - vec![ - SelectorContext { - offset: 0, - num_instances, - num_vars: num_var_with_rotation, - }; - gkr_circuit - .layers - .first() - .map(|layer| layer.out_sel_and_eval_exprs.len()) - .unwrap_or(0) - ] - } else { - // it's global chip - vec![ - SelectorContext { - offset: 0, - num_instances: input.num_instances[0], - num_vars: num_var_with_rotation, - }, - SelectorContext { - offset: input.num_instances[0], - num_instances: input.num_instances[1], - num_vars: num_var_with_rotation, - }, - SelectorContext { - offset: 0, - num_instances, - num_vars: num_var_with_rotation, - }, - ] - }; + let first_layer = gkr_circuit.layers.first().expect("empty gkr circuit layer"); + let selector_ctxs = first_layer + .out_sel_and_eval_exprs + .iter() + .map(|(selector, _)| { + if cs.ec_final_sum.is_empty() { + SelectorContext { + offset: 0, + num_instances, + num_vars: num_var_with_rotation, + } + } else if cs.r_selector.as_ref() == Some(selector) { + SelectorContext { + offset: 0, + num_instances: input.num_instances[0], + num_vars: num_var_with_rotation, + } + } else if cs.w_selector.as_ref() == Some(selector) { + SelectorContext { + offset: input.num_instances[0], + num_instances: input.num_instances[1], + num_vars: num_var_with_rotation, + } + } else { + SelectorContext { + offset: 0, + num_instances, + num_vars: num_var_with_rotation, + } + } + }) + .collect_vec(); + + let mut out_evals = + vec![PointAndEval::new(rt_tower.clone(), E::ZERO); gkr_circuit.n_evaluations]; + + if let Some(rotation) = rotation.as_ref() { + let Some([left_group_idx, right_group_idx, point_group_idx]) = + first_layer.rotation_selector_group_indices() + else { + panic!("rotation proof provided for non-rotation layer") + }; + + let (left_evals, right_evals, point_evals) = split_rotation_evals(&rotation.proof.evals); + + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[left_group_idx].1, + &left_evals, + &rotation.left_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[right_group_idx].1, + &right_evals, + &rotation.right_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[point_group_idx].1, + &point_evals, + &rotation.point, + ); + } + + if let Some(ecc_proof) = ecc_proof { + let Some( + [ + x_group_idx, + y_group_idx, + slope_group_idx, + x3_group_idx, + y3_group_idx, + ], + ) = first_layer.ecc_bridge_group_indices() + else { + panic!("ecc proof provided for non-ecc layer") + }; + + let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; + let claims = derive_ecc_bridge_claims(ecc_proof, sample_r, num_var_with_rotation) + .expect("invalid internal ecc bridge claims"); + + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x_group_idx].1, + &claims.x_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y_group_idx].1, + &claims.y_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[slope_group_idx].1, + &claims.s_evals, + &claims.s_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x3_group_idx].1, + &claims.x3_evals, + &claims.x3y3_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y3_group_idx].1, + &claims.y3_evals, + &claims.x3y3_point, + ); + } + let GKRProverOutput { gkr_proof, opening_evaluations, @@ -304,8 +449,7 @@ pub fn prove_main_constraints_impl< .collect_vec(), )], }, - // eval value doesn't matter as it won't be used by prover - &vec![PointAndEval::new(rt_tower, E::ZERO); gkr_circuit.final_out_evals.len()], + &out_evals, &input .pi .iter() @@ -347,12 +491,20 @@ pub fn prove_main_constraints_impl< level = "trace" )] pub fn prove_ec_sum_quark_impl<'a, E: ExtensionField, PCS: PolynomialCommitmentScheme>( - num_instances: usize, - xs: Vec>>, - ys: Vec>>, - invs: Vec>>, + composed_cs: &ComposedConstrainSystem, + input: &ProofInput<'a, GpuBackend>, transcript: &mut impl Transcript, -) -> Result, ZKVMError> { +) -> Result>, ZKVMError> { + let Some(ecc_inputs) = + extract_ecc_quark_witness_inputs::>(composed_cs, input) + else { + return Ok(None); + }; + let xs = ecc_inputs.xs; + let ys = ecc_inputs.ys; + let invs = ecc_inputs.slopes; + + let num_instances = input.num_instances(); let stream = gkr_iop::gpu::get_thread_stream(); assert_eq!(xs.len(), SEPTIC_EXTENSION_DEGREE); assert_eq!(ys.len(), SEPTIC_EXTENSION_DEGREE); @@ -558,13 +710,13 @@ pub fn prove_ec_sum_quark_impl<'a, E: ExtensionField, PCS: PolynomialCommitmentS assert_eq!(evals.len(), 3 + SEPTIC_EXTENSION_DEGREE * 7); let final_sum = SepticPoint::from_affine(final_sum_x.clone(), final_sum_y.clone()); - Ok(EccQuarkProof { + Ok(Some(EccQuarkProof { zerocheck_proof: proof_gpu_e, num_instances, evals, rt, sum: final_sum, - }) + })) } impl> TraceCommitter> @@ -1124,6 +1276,8 @@ impl> MainSumcheckProver( &self, rt_tower: Vec, + rotation: Option>, + ecc_proof: Option<&EccQuarkProof>, // _records: Vec>, // not used by GPU after delegation input: &'b ProofInput<'a, GpuBackend>, composed_cs: &ComposedConstrainSystem, @@ -1143,6 +1297,8 @@ impl> MainSumcheckProver( rt_tower, + rotation, + ecc_proof, input, composed_cs, challenges, @@ -1156,26 +1312,39 @@ impl> MainSumcheckProver> RotationProver> + for GpuProver> +{ + fn prove_rotation<'a>( + &self, + composed_cs: &ComposedConstrainSystem, + input: &ProofInput<'a, GpuBackend>, + rt_tower: &Point, + challenges: &[E; 2], + transcript: &mut impl Transcript, + ) -> Result>, ZKVMError> { + prove_rotation_impl::(composed_cs, input, rt_tower, challenges, transcript) + } +} + impl> EccQuarkProver> for GpuProver> { fn prove_ec_sum_quark<'a>( &self, - num_instances: usize, - xs: Vec>>, - ys: Vec>>, - invs: Vec>>, + composed_cs: &ComposedConstrainSystem, + input: &ProofInput<'a, GpuBackend>, transcript: &mut impl Transcript, - ) -> Result, ZKVMError> { - // n = num_vars of the ecc quark sumcheck (xs[0].num_vars - 1) - let n = xs[0].mle.num_vars() - 1; + ) -> Result>, ZKVMError> { let cuda_hal = get_cuda_hal().expect("Failed to get CUDA HAL"); let gpu_mem_tracker = init_gpu_mem_tracker(&cuda_hal, "prove_ec_sum_quark"); - let res = prove_ec_sum_quark_impl::(num_instances, xs, ys, invs, transcript); + let res = prove_ec_sum_quark_impl::(composed_cs, input, transcript); - let estimated_bytes = estimate_ecc_quark_bytes_from_num_vars(n); - check_gpu_mem_estimation(gpu_mem_tracker, estimated_bytes); + if let Ok(Some(proof)) = &res { + let estimated_bytes = estimate_ecc_quark_bytes_from_num_vars(proof.rt.len()); + check_gpu_mem_estimation(gpu_mem_tracker, estimated_bytes); + } res } diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index 873020f6c..65fe06f2d 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -6,7 +6,7 @@ use crate::{ use either::Either; use ff_ext::ExtensionField; use gkr_iop::{ - gkr::GKRProof, + gkr::{GKRProof, layer::sumcheck_layer::SumcheckLayerProof}, hal::{ProtocolWitnessGeneratorProver, ProverBackend}, }; use mpcs::{Point, PolynomialCommitmentScheme}; @@ -24,6 +24,7 @@ pub trait ProverDevice: + DeviceTransporter + ProtocolWitnessGeneratorProver + EccQuarkProver + + RotationProver + ChipInputPreparer // + FixedMLEPadder where @@ -107,12 +108,10 @@ pub trait TraceCommitter { pub trait EccQuarkProver { fn prove_ec_sum_quark<'a>( &self, - num_instances: usize, - xs: Vec>>, - ys: Vec>>, - invs: Vec>>, + cs: &ComposedConstrainSystem, + input: &ProofInput<'a, PB>, transcript: &mut impl Transcript, - ) -> Result, ZKVMError>; + ) -> Result>, ZKVMError>; } pub trait TowerProver { @@ -155,16 +154,40 @@ pub struct MainSumcheckEvals { pub fixed_in_evals: Vec, } +#[derive(Clone)] +pub struct RotationProverOutput { + pub proof: SumcheckLayerProof, + pub left_point: Point, + pub right_point: Point, + pub point: Point, +} + +pub trait RotationProver { + fn prove_rotation<'a>( + &self, + cs: &ComposedConstrainSystem, + input: &ProofInput<'a, PB>, + rt_tower: &Point, + challenges: &[PB::E; 2], + transcript: &mut impl Transcript, + ) -> Result>, ZKVMError> { + let _ = (cs, input, rt_tower, challenges, transcript); + Ok(None) + } +} + pub trait MainSumcheckProver { // this prover aims to achieve two goals: // 1. the validity of last layer in the tower tree is reduced to // the validity of read/write/logup records through sumchecks; // 2. multiple multiplication relations between witness multilinear polynomials // achieved via zerochecks. - #[allow(clippy::type_complexity)] + #[allow(clippy::type_complexity, clippy::too_many_arguments)] fn prove_main_constraints<'a, 'b>( &self, rt_tower: Vec, + rotation: Option>, + ecc_proof: Option<&EccQuarkProof>, input: &'b ProofInput<'a, PB>, cs: &ComposedConstrainSystem, challenges: &[PB::E; 2], diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2e0d18dbf..cf04c0ed7 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -8,14 +8,13 @@ use std::{collections::BTreeMap, marker::PhantomData, sync::Arc}; #[cfg(feature = "gpu")] use crate::scheme::gpu::estimate_chip_proof_memory; use crate::scheme::{ - constants::SEPTIC_EXTENSION_DEGREE, hal::MainSumcheckEvals, scheduler::{ChipScheduler, ChipTask, ChipTaskResult}, }; use either::Either; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; -use multilinear_extensions::{Expression, Instance}; +use multilinear_extensions::Instance; use p3::field::FieldAlgebra; use std::iter::Iterator; use sumcheck::{ @@ -429,39 +428,6 @@ impl< let log2_num_instances = input.log2_num_instances(); let num_var_with_rotation = log2_num_instances + cs.rotation_vars().unwrap_or(0); - // run ecc quark prover - let ecc_proof = if !cs.zkvm_v1_css.ec_final_sum.is_empty() { - let span = entered_span!("run_ecc_final_sum", profiling_2 = true); - let ec_point_exprs = &cs.zkvm_v1_css.ec_point_exprs; - assert_eq!(ec_point_exprs.len(), SEPTIC_EXTENSION_DEGREE * 2); - let mut xs_ys = ec_point_exprs - .iter() - .map(|expr| match expr { - Expression::WitIn(id) => input.witness[*id as usize].clone(), - _ => unreachable!("ec point's expression must be WitIn"), - }) - .collect_vec(); - let ys = xs_ys.split_off(SEPTIC_EXTENSION_DEGREE); - let xs = xs_ys; - let slopes = cs - .zkvm_v1_css - .ec_slope_exprs - .iter() - .map(|expr| match expr { - Expression::WitIn(id) => input.witness[*id as usize].clone(), - _ => unreachable!("slope's expression must be WitIn"), - }) - .collect_vec(); - let ecc_proof = Some(info_span!("[ceno] prove_ec_sum_quark").in_scope(|| { - self.device - .prove_ec_sum_quark(input.num_instances(), xs, ys, slopes, transcript) - })?); - exit_span!(span); - ecc_proof - } else { - None - }; - // build main witness let records = info_span!("[ceno] build_main_witness") .in_scope(|| build_main_witness::(cs, input, challenges)); @@ -481,13 +447,32 @@ impl< num_var_with_rotation, ); + let span = entered_span!("run_ecc_final_sum", profiling_2 = true); + let ecc_proof = info_span!("[ceno] prove_ec_sum_quark") + .in_scope(|| self.device.prove_ec_sum_quark(cs, input, transcript))?; + exit_span!(span); + + let span = entered_span!("prove_rotation", profiling_2 = true); + let rotation = info_span!("[ceno] prove_rotation").in_scope(|| { + self.device + .prove_rotation(cs, input, &rt_tower, challenges, transcript) + })?; + exit_span!(span); + // 1. prove the main constraints among witness polynomials // 2. prove the relation between last layer in the tower and read/write/logup records let span = entered_span!("prove_main_constraints", profiling_2 = true); let (input_opening_point, evals, main_sumcheck_proofs, gkr_iop_proof) = info_span!("[ceno] prove_main_constraints").in_scope(|| { - self.device - .prove_main_constraints(rt_tower, input, cs, challenges, transcript) + self.device.prove_main_constraints( + rt_tower, + rotation.clone(), + ecc_proof.as_ref(), + input, + cs, + challenges, + transcript, + ) })?; let MainSumcheckEvals { wits_in_evals, @@ -502,6 +487,7 @@ impl< lk_out_evals, main_sumcheck_proofs, gkr_iop_proof, + rotation_proof: rotation.map(|r| r.proof), tower_proof, ecc_proof, num_instances: input.num_instances, @@ -736,7 +722,7 @@ where { use crate::scheme::gpu::{ extract_witness_mles_for_trace, prove_ec_sum_quark_impl, prove_main_constraints_impl, - prove_tower_relation_impl, transport_structural_witness_to_gpu, + prove_rotation_impl, prove_tower_relation_impl, transport_structural_witness_to_gpu, }; use gkr_iop::gpu::{GpuBackend, get_cuda_hal}; @@ -773,38 +759,6 @@ where }); } - // run ecc quark prover using _impl function - let ecc_proof = if !cs.zkvm_v1_css.ec_final_sum.is_empty() { - let span = entered_span!("run_ecc_final_sum", profiling_2 = true); - let ec_point_exprs = &cs.zkvm_v1_css.ec_point_exprs; - assert_eq!(ec_point_exprs.len(), SEPTIC_EXTENSION_DEGREE * 2); - let mut xs_ys = ec_point_exprs - .iter() - .map(|expr| match expr { - Expression::WitIn(id) => input.witness[*id as usize].clone(), - _ => unreachable!("ec point's expression must be WitIn"), - }) - .collect_vec(); - let ys = xs_ys.split_off(SEPTIC_EXTENSION_DEGREE); - let xs = xs_ys; - let slopes = cs - .zkvm_v1_css - .ec_slope_exprs - .iter() - .map(|expr| match expr { - Expression::WitIn(id) => input.witness[*id as usize].clone(), - _ => unreachable!("slope's expression must be WitIn"), - }) - .collect_vec(); - let ecc_proof = Some(info_span!("[ceno] prove_ec_sum_quark").in_scope(|| { - prove_ec_sum_quark_impl::(input.num_instances(), xs, ys, slopes, transcript) - })?); - exit_span!(span); - ecc_proof - } else { - None - }; - // build main witness let records = info_span!("[ceno] build_main_witness").in_scope(|| { @@ -828,11 +782,30 @@ where assert_eq!(rt_tower.len(), num_var_with_rotation,); + let span = entered_span!("run_ecc_final_sum", profiling_2 = true); + let ecc_proof = info_span!("[ceno] prove_ec_sum_quark") + .in_scope(|| prove_ec_sum_quark_impl::(cs, &input, transcript))?; + exit_span!(span); + + let span = entered_span!("prove_rotation", profiling_2 = true); + let rotation = info_span!("[ceno] prove_rotation").in_scope(|| { + prove_rotation_impl::(cs, &input, &rt_tower, challenges, transcript) + })?; + exit_span!(span); + // prove main constraints using _impl function let span = entered_span!("prove_main_constraints", profiling_2 = true); let (input_opening_point, evals, main_sumcheck_proofs, gkr_iop_proof) = info_span!("[ceno] prove_main_constraints").in_scope(|| { - prove_main_constraints_impl::(rt_tower, &input, cs, challenges, transcript) + prove_main_constraints_impl::( + rt_tower, + rotation.clone(), + ecc_proof.as_ref(), + &input, + cs, + challenges, + transcript, + ) })?; let MainSumcheckEvals { wits_in_evals, @@ -847,6 +820,7 @@ where lk_out_evals, main_sumcheck_proofs, gkr_iop_proof, + rotation_proof: rotation.map(|r| r.proof), tower_proof, ecc_proof, num_instances: input.num_instances, diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 8583e2710..ec7758ed1 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -1,21 +1,26 @@ use crate::{ + error::ZKVMError, scheme::{ - constants::MIN_PAR_SIZE, + constants::{MIN_PAR_SIZE, SEPTIC_EXTENSION_DEGREE}, hal::{ProofInput, ProverDevice}, }, - structs::ComposedConstrainSystem, + structs::{ComposedConstrainSystem, EccQuarkProof, PointAndEval}, }; use either::Either; use ff_ext::ExtensionField; use gkr_iop::{ evaluation::EvalExpression, - gkr::{GKRCircuit, GKRCircuitOutput, GKRCircuitWitness, layer::LayerWitness}, + gkr::{ + GKRCircuit, GKRCircuitOutput, GKRCircuitWitness, + layer::{LayerWitness, ROTATION_OPENING_COUNT}, + }, hal::{MultilinearPolynomial, ProtocolWitnessGeneratorProver, ProverBackend}, }; use itertools::Itertools; -use mpcs::PolynomialCommitmentScheme; +use mpcs::{Point, PolynomialCommitmentScheme}; pub use multilinear_extensions::wit_infer_by_expr; use multilinear_extensions::{ + Expression, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, util::ceil_log2, }; @@ -29,6 +34,187 @@ use rayon::{ use std::{iter, sync::Arc}; use witness::next_pow2_instance_padding; +pub(crate) struct EccBridgeClaims { + pub(crate) xy_point: Point, + pub(crate) s_point: Point, + pub(crate) x3y3_point: Point, + pub(crate) x_evals: Vec, + pub(crate) y_evals: Vec, + pub(crate) s_evals: Vec, + pub(crate) x3_evals: Vec, + pub(crate) y3_evals: Vec, +} + +pub(crate) struct EccQuarkWitnessInputs<'a, PB: ProverBackend> { + pub(crate) xs: Vec>>, + pub(crate) ys: Vec>>, + pub(crate) slopes: Vec>>, +} + +pub(crate) fn extract_ecc_quark_witness_inputs<'a, PB: ProverBackend>( + cs: &ComposedConstrainSystem, + input: &ProofInput<'a, PB>, +) -> Option> { + let cs = &cs.zkvm_v1_css; + if cs.ec_final_sum.is_empty() { + return None; + } + + let ec_point_exprs = &cs.ec_point_exprs; + assert_eq!(ec_point_exprs.len(), SEPTIC_EXTENSION_DEGREE * 2); + let mut xs_ys = ec_point_exprs + .iter() + .map(|expr| match expr { + Expression::WitIn(id) => input.witness[*id as usize].clone(), + _ => unreachable!("ec point's expression must be WitIn"), + }) + .collect_vec(); + let ys = xs_ys.split_off(SEPTIC_EXTENSION_DEGREE); + let xs = xs_ys; + + let slopes = cs + .ec_slope_exprs + .iter() + .map(|expr| match expr { + Expression::WitIn(id) => input.witness[*id as usize].clone(), + _ => unreachable!("slope's expression must be WitIn"), + }) + .collect_vec(); + + Some(EccQuarkWitnessInputs { xs, ys, slopes }) +} + +pub(crate) fn derive_ecc_bridge_claims( + ecc_proof: &EccQuarkProof, + sample_r: E, + num_var_with_rotation: usize, +) -> Result, ZKVMError> { + let degree = SEPTIC_EXTENSION_DEGREE; + if ecc_proof.evals.len() < 3 { + return Err(ZKVMError::InvalidProof( + "ecc proof evals shorter than selector prefix".into(), + )); + } + let evals = &ecc_proof.evals[3..]; + if evals.len() != degree * 7 { + return Err(ZKVMError::InvalidProof( + format!( + "invalid ecc proof eval length: expected {}, got {}", + degree * 7, + evals.len() + ) + .into(), + )); + } + + let s1 = &evals[0..degree]; + let x0 = &evals[degree..2 * degree]; + let y0 = &evals[2 * degree..3 * degree]; + let x1 = &evals[3 * degree..4 * degree]; + let y1 = &evals[4 * degree..5 * degree]; + let x3 = &evals[5 * degree..6 * degree]; + let y3 = &evals[6 * degree..7 * degree]; + + let one_minus_r = E::ONE - sample_r; + let x_evals = x0 + .iter() + .zip_eq(x1.iter()) + .map(|(a, b)| *a * one_minus_r + *b * sample_r) + .collect_vec(); + let y_evals = y0 + .iter() + .zip_eq(y1.iter()) + .map(|(a, b)| *a * one_minus_r + *b * sample_r) + .collect_vec(); + let s_evals = s1.iter().map(|v| *v * sample_r).collect_vec(); + let x3_evals = x3.to_vec(); + let y3_evals = y3.to_vec(); + + let mut xy_point = vec![sample_r]; + xy_point.extend(ecc_proof.rt.iter().copied()); + if xy_point.len() != num_var_with_rotation { + return Err(ZKVMError::InvalidProof( + format!( + "invalid ecc xy point length: expected {}, got {}", + num_var_with_rotation, + xy_point.len() + ) + .into(), + )); + } + + let mut s_point = ecc_proof.rt.clone(); + s_point.push(sample_r); + if s_point.len() != num_var_with_rotation { + return Err(ZKVMError::InvalidProof( + format!( + "invalid ecc slope point length: expected {}, got {}", + num_var_with_rotation, + s_point.len() + ) + .into(), + )); + } + + let mut x3y3_point = ecc_proof.rt.clone(); + x3y3_point.push(E::ONE); + if x3y3_point.len() != num_var_with_rotation { + return Err(ZKVMError::InvalidProof( + format!( + "invalid ecc x3/y3 point length: expected {}, got {}", + num_var_with_rotation, + x3y3_point.len() + ) + .into(), + )); + } + + Ok(EccBridgeClaims { + xy_point, + s_point, + x3y3_point, + x_evals, + y_evals, + s_evals, + x3_evals, + y3_evals, + }) +} + +pub(crate) fn split_rotation_evals(evals: &[E]) -> (Vec, Vec, Vec) { + assert_eq!( + evals.len() % ROTATION_OPENING_COUNT, + 0, + "rotation evals length must be a multiple of {}, got {}", + ROTATION_OPENING_COUNT, + evals.len() + ); + let mut left_evals = Vec::new(); + let mut right_evals = Vec::new(); + let mut point_evals = Vec::new(); + for chunk in evals.chunks_exact(ROTATION_OPENING_COUNT) { + left_evals.push(chunk[0]); + right_evals.push(chunk[1]); + point_evals.push(chunk[2]); + } + (left_evals, right_evals, point_evals) +} + +pub(crate) fn assign_group_evals( + out_evals: &mut [PointAndEval], + eval_exprs: &[EvalExpression], + evals: &[E], + point: &Point, +) { + assert_eq!(eval_exprs.len(), evals.len(), "group eval length mismatch"); + for (eval_expr, eval) in eval_exprs.iter().zip_eq(evals.iter()) { + let EvalExpression::Single(index) = eval_expr else { + panic!("group must use EvalExpression::Single"); + }; + out_evals[*index] = PointAndEval::new(point.clone(), *eval); + } +} + /// Wrapper that asserts a shared reference is safe to send across threads. /// /// # Safety @@ -407,7 +593,7 @@ pub fn gkr_witness< phase1_witness_group: &[Arc>], structural_witness: &[Arc>], fixed: &[Arc>], - pub_io_mles: &[Arc>], + _pub_io_mles: &[Arc>], pub_io_evals: &[Either], challenges: &[E], ) -> (GKRCircuitWitness<'b, PB>, GKRCircuitOutput<'b, PB>) { @@ -438,16 +624,6 @@ pub fn gkr_witness< witness_mle_flatten[*index] = Some(fixed_mle.clone()); }); - first_layer - .in_eval_expr - .iter() - .skip(first_layer.n_witin + first_layer.n_fixed) - .take(first_layer.n_instance) - .zip_eq(pub_io_mles.iter()) - .for_each(|(index, pubio_mle)| { - witness_mle_flatten[*index] = Some(pubio_mle.clone()); - }); - // XXX currently fixed poly not support in layers > 1 // TODO process fixed (and probably short) mle // @@ -500,10 +676,7 @@ pub fn gkr_witness< assert_eq!( current_layer_wits.len(), - layer.n_witin - + layer.n_fixed - + layer.n_instance - + if i == 0 { layer.n_structural_witin } else { 0 } + layer.n_witin + layer.n_fixed + if i == 0 { layer.n_structural_witin } else { 0 } ); // infer current layer output diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 9d595d0a5..60613f1ab 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -17,6 +17,7 @@ use crate::{ scheme::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, septic_curve::{SepticExtension, SepticPoint}, + utils::{assign_group_evals, derive_ecc_bridge_claims}, }, structs::{ ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VerifyingKey, @@ -536,19 +537,7 @@ impl> ZKVMVerifier assert_eq!(num_vars, log2_num_instances); }); - // verify ecc proof if exists - let shard_ec_sum: Option> = if composed_cs.has_ecc_ops() { - tracing::debug!("verifying ecc proof..."); - assert!(proof.ecc_proof.is_some()); - let ecc_proof = proof.ecc_proof.as_ref().unwrap(); - assert!(!ecc_proof.sum.is_infinity); - - EccVerifier::verify_ecc_proof(ecc_proof, transcript)?; - tracing::debug!("ecc proof verified."); - Some(ecc_proof.sum.clone()) - } else { - None - }; + let mut shard_ec_sum: Option> = None; // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; @@ -564,7 +553,7 @@ impl> ZKVMVerifier transcript.append_field_element_ext(eval); } - let (_, record_evals, logup_p_evals, logup_q_evals) = TowerVerify::verify( + let (rt_tower, record_evals, logup_p_evals, logup_q_evals) = TowerVerify::verify( proof .r_out_evals .iter() @@ -593,6 +582,23 @@ impl> ZKVMVerifier })?; } + if composed_cs.has_ecc_ops() { + tracing::debug!("verifying ecc proof..."); + let ecc_proof = proof + .ecc_proof + .as_ref() + .ok_or_else(|| ZKVMError::InvalidProof("missing ecc proof".into()))?; + if ecc_proof.sum.is_infinity { + return Err(ZKVMError::InvalidProof( + "invalid ecc proof: infinity shard sum".into(), + )); + } + + EccVerifier::verify_ecc_proof(ecc_proof, transcript)?; + tracing::debug!("ecc proof verified."); + shard_ec_sum = Some(ecc_proof.sum.clone()); + } + debug_assert!( chain!(&record_evals, &logup_p_evals, &logup_q_evals) .map(|e| &e.point) @@ -605,7 +611,7 @@ impl> ZKVMVerifier debug_assert_eq!(logup_p_evals.len(), lk_counts_per_instance); debug_assert_eq!(logup_q_evals.len(), lk_counts_per_instance); - let evals = record_evals + let base_evals = record_evals .iter() // append p_evals if there got lk table expressions .chain(if cs.lk_table_expressions.is_empty() { @@ -618,43 +624,133 @@ impl> ZKVMVerifier .collect_vec(); let gkr_circuit = gkr_circuit.as_ref().unwrap(); - let selector_ctxs = if cs.ec_final_sum.is_empty() { - assert_eq!(proof.num_instances[1], 0); - // it's not shard chip - vec![ - SelectorContext::new(0, num_instances, num_var_with_rotation); - gkr_circuit - .layers - .first() - .map(|layer| layer.out_sel_and_eval_exprs.len()) - .unwrap_or(0) - ] - } else { - // it's shard chip - tracing::debug!( - "num_reads: {}, num_writes: {}, total: {}", - proof.num_instances[0], - proof.num_instances[1], - proof.num_instances[0] + proof.num_instances[1], + let first_layer = gkr_circuit.layers.first().expect("empty gkr circuit layer"); + let selector_ctxs = first_layer + .out_sel_and_eval_exprs + .iter() + .map(|(selector, _)| { + if cs.ec_final_sum.is_empty() { + SelectorContext::new(0, num_instances, num_var_with_rotation) + } else if cs.r_selector.as_ref() == Some(selector) { + SelectorContext::new(0, proof.num_instances[0], num_var_with_rotation) + } else if cs.w_selector.as_ref() == Some(selector) { + SelectorContext::new( + proof.num_instances[0], + proof.num_instances[1], + num_var_with_rotation, + ) + } else { + SelectorContext::new(0, num_instances, num_var_with_rotation) + } + }) + .collect_vec(); + + let mut out_evals = vec![PointAndEval::default(); gkr_circuit.n_evaluations]; + for (idx, point_and_eval) in base_evals.into_iter().enumerate() { + out_evals[idx] = point_and_eval; + } + + if !first_layer.rotation_exprs.1.is_empty() { + let rotation_proof = proof + .rotation_proof + .as_ref() + .ok_or_else(|| ZKVMError::InvalidProof("missing rotation proof".into()))? + .clone(); + + let rotation_claims = gkr_iop::gkr::layer::zerocheck_layer::verify_rotation( + num_var_with_rotation, + first_layer.rotation_exprs.1.len(), + first_layer + .rotation_sumcheck_expression + .as_ref() + .expect("missing rotation sumcheck expression"), + rotation_proof, + first_layer.rotation_cyclic_subgroup_size, + first_layer.rotation_cyclic_group_log2, + &rt_tower, + challenges, + transcript, + )?; + + let Some([left_group_idx, right_group_idx, point_group_idx]) = + first_layer.rotation_selector_group_indices() + else { + return Err(ZKVMError::InvalidProof( + "rotation claims expected but selectors are missing".into(), + )); + }; + + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[left_group_idx].1, + &rotation_claims.left_evals, + &rotation_claims.rotation_points.left, ); - vec![ - SelectorContext { - offset: 0, - num_instances: proof.num_instances[0], - num_vars: num_var_with_rotation, - }, - SelectorContext { - offset: proof.num_instances[0], - num_instances: proof.num_instances[1], - num_vars: num_var_with_rotation, - }, - SelectorContext { - offset: 0, - num_instances: proof.num_instances[0] + proof.num_instances[1], - num_vars: num_var_with_rotation, - }, - ] - }; + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[right_group_idx].1, + &rotation_claims.right_evals, + &rotation_claims.rotation_points.right, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[point_group_idx].1, + &rotation_claims.target_evals, + &rotation_claims.rotation_points.origin, + ); + } + + if let Some(ecc_proof) = proof.ecc_proof.as_ref() { + let Some( + [ + x_group_idx, + y_group_idx, + slope_group_idx, + x3_group_idx, + y3_group_idx, + ], + ) = first_layer.ecc_bridge_group_indices() + else { + return Err(ZKVMError::InvalidProof( + "ecc bridge claims expected but selectors are missing".into(), + )); + }; + + let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; + let claims = derive_ecc_bridge_claims(ecc_proof, sample_r, num_var_with_rotation)?; + + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x_group_idx].1, + &claims.x_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y_group_idx].1, + &claims.y_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[slope_group_idx].1, + &claims.s_evals, + &claims.s_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x3_group_idx].1, + &claims.x3_evals, + &claims.x3y3_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y3_group_idx].1, + &claims.y3_evals, + &claims.x3y3_point, + ); + } + let pi = cs .instance .iter() @@ -664,7 +760,7 @@ impl> ZKVMVerifier let (_, rt) = gkr_circuit.verify( num_var_with_rotation, proof.gkr_iop_proof.clone().unwrap(), - &evals, + &out_evals, &pi, challenges, transcript, diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 82719a7bc..31540e3bc 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -44,7 +44,7 @@ pub struct EccQuarkProof { pub zerocheck_proof: IOPProof, /// Number of EC points being summed pub num_instances: usize, - pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[0,rt], y[0,rt], s[0,rt] + pub evals: Vec, /* [sel_add, sel_bypass, sel_export] ++ [s[1,rt], x[rt,0], y[rt,0], x[rt,1], y[rt,1], x[1,rt], y[1,rt]] */ pub rt: Point, pub sum: SepticPoint, } diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 6c5ef93b5..658ec85a1 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -418,6 +418,11 @@ impl TableCircuit for ShardRamCircuit { let selector_r = cb.create_placeholder_structural_witin(|| "selector_r"); let selector_w = cb.create_placeholder_structural_witin(|| "selector_w"); let selector_zero = cb.create_placeholder_structural_witin(|| "selector_zero"); + let selector_ecc_x = cb.create_placeholder_structural_witin(|| "selector_ecc_x"); + let selector_ecc_y = cb.create_placeholder_structural_witin(|| "selector_ecc_y"); + let selector_ecc_s = cb.create_placeholder_structural_witin(|| "selector_ecc_s"); + let selector_ecc_x3 = cb.create_placeholder_structural_witin(|| "selector_ecc_x3"); + let selector_ecc_y3 = cb.create_placeholder_structural_witin(|| "selector_ecc_y3"); let config = Self::construct_circuit(cb, param)?; @@ -439,6 +444,13 @@ impl TableCircuit for ShardRamCircuit { cb.cs.w_selector = Some(selector_w); cb.cs.zero_selector = Some(selector_zero.clone()); cb.cs.lk_selector = Some(selector_zero); + cb.cs.ec_bridge_selectors = Some([ + SelectorType::Whole(selector_ecc_x.expr()), + SelectorType::Whole(selector_ecc_y.expr()), + SelectorType::Whole(selector_ecc_s.expr()), + SelectorType::Whole(selector_ecc_x3.expr()), + SelectorType::Whole(selector_ecc_y3.expr()), + ]); // all shared the same selector let (out_evals, mut chip) = ( @@ -487,10 +499,20 @@ impl TableCircuit for ShardRamCircuit { // this is workaround, as call `construct_circuit` will not initialized selector // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` - assert_eq!(num_structural_witin, 3); + // ShardRam expects exactly these structural selectors: + // r, w, zero, ecc_x, ecc_y, ecc_s, ecc_x3, ecc_y3. + assert_eq!( + num_structural_witin, 8, + "ShardRam requires exactly 8 structural selectors (r,w,zero,ecc_x,ecc_y,ecc_s,ecc_x3,ecc_y3)" + ); let selector_r_witin = WitIn { id: 0 }; let selector_w_witin = WitIn { id: 1 }; let selector_zero_witin = WitIn { id: 2 }; + let selector_ecc_x_witin = WitIn { id: 3 }; + let selector_ecc_y_witin = WitIn { id: 4 }; + let selector_ecc_s_witin = WitIn { id: 5 }; + let selector_ecc_x3_witin = WitIn { id: 6 }; + let selector_ecc_y3_witin = WitIn { id: 7 }; let nthreads = max_usable_threads(); @@ -539,6 +561,17 @@ impl TableCircuit for ShardRamCircuit { ); RowMajorMatrix::new(value, num_structural_witin) }; + // ECC bridge selectors are `Whole`, so keep them active on all rows. + raw_structual_witin + .values + .par_chunks_mut(num_structural_witin) + .for_each(|row| { + set_val!(row, selector_ecc_x_witin, E::BaseField::ONE); + set_val!(row, selector_ecc_y_witin, E::BaseField::ONE); + set_val!(row, selector_ecc_s_witin, E::BaseField::ONE); + set_val!(row, selector_ecc_x3_witin, E::BaseField::ONE); + set_val!(row, selector_ecc_y3_witin, E::BaseField::ONE); + }); let raw_witin_iter = raw_witin.values[0..steps.len() * num_witin] .par_chunks_mut(num_instance_per_batch * num_witin); let raw_structual_witin_iter = raw_structual_witin.values diff --git a/gkr_iop/src/chip.rs b/gkr_iop/src/chip.rs index c1d423c26..39be9436b 100644 --- a/gkr_iop/src/chip.rs +++ b/gkr_iop/src/chip.rs @@ -1,4 +1,7 @@ -use crate::{circuit_builder::CircuitBuilder, gkr::layer::Layer}; +use crate::{ + circuit_builder::CircuitBuilder, + gkr::layer::{ECC_BRIDGE_OPENING_COUNT, Layer, ROTATION_OPENING_COUNT}, +}; use ff_ext::ExtensionField; use itertools::Itertools; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -30,24 +33,25 @@ pub struct Chip { impl Chip { pub fn new_from_cb(cb: &CircuitBuilder) -> Chip { + let rotation_eval_count = cb.cs.rotations.len() * ROTATION_OPENING_COUNT; + let ecc_eval_count = if cb.cs.ec_point_exprs.is_empty() { + 0 + } else { + cb.cs.ec_slope_exprs.len() * ECC_BRIDGE_OPENING_COUNT + }; + let num_non_zero_outputs = cb.cs.w_expressions.len() + + cb.cs.r_expressions.len() + + cb.cs.lk_expressions.len() + + cb.cs.w_table_expressions.len() + + cb.cs.r_table_expressions.len() + + cb.cs.lk_table_expressions.len() * 2 + + rotation_eval_count + + ecc_eval_count; Self { n_fixed: cb.cs.num_fixed, n_committed: cb.cs.num_witin as usize, - n_evaluations: cb.cs.w_expressions.len() - + cb.cs.r_expressions.len() - + cb.cs.lk_expressions.len() - + cb.cs.w_table_expressions.len() - + cb.cs.r_table_expressions.len() - + cb.cs.lk_table_expressions.len() * 2 - + cb.cs.num_fixed - + cb.cs.num_witin as usize, - final_out_evals: (0..cb.cs.w_expressions.len() - + cb.cs.r_expressions.len() - + cb.cs.lk_expressions.len() - + cb.cs.w_table_expressions.len() - + cb.cs.r_table_expressions.len() - + cb.cs.lk_table_expressions.len() * 2) - .collect_vec(), + n_evaluations: num_non_zero_outputs + cb.cs.num_fixed + cb.cs.num_witin as usize, + final_out_evals: (0..num_non_zero_outputs).collect_vec(), layers: vec![], } } diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index b66d05300..874afa208 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -8,8 +8,11 @@ use serde::de::DeserializeOwned; use std::{collections::HashMap, iter::once, marker::PhantomData}; use crate::{ - RAMType, error::CircuitBuilderError, gkr::layer::ROTATION_OPENING_COUNT, - selector::SelectorType, tables::LookupTable, + RAMType, + error::CircuitBuilderError, + gkr::layer::{ECC_BRIDGE_OPENING_COUNT, ROTATION_OPENING_COUNT}, + selector::SelectorType, + tables::LookupTable, }; use p3::field::FieldAlgebra; @@ -107,6 +110,7 @@ pub struct ConstraintSystem { pub ec_point_exprs: Vec>, pub ec_slope_exprs: Vec>, pub ec_final_sum: Vec>, + pub ec_bridge_selectors: Option<[SelectorType; ECC_BRIDGE_OPENING_COUNT]>, pub r_selector: Option>, pub r_expressions: Vec>, @@ -179,6 +183,7 @@ impl ConstraintSystem { ec_final_sum: vec![], ec_slope_exprs: vec![], ec_point_exprs: vec![], + ec_bridge_selectors: None, r_selector: None, r_expressions: vec![], r_expressions_namespace_map: vec![], diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 175e7fb0e..ebf279b11 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -41,6 +41,7 @@ pub type RotateExprs = ( // rotation contribute // left + right + target, overall 3 pub const ROTATION_OPENING_COUNT: usize = 3; +pub const ECC_BRIDGE_OPENING_COUNT: usize = 5; #[derive(Clone, Debug, Serialize, Deserialize)] pub enum LayerType { @@ -71,7 +72,6 @@ pub struct Layer { pub n_witin: usize, pub n_structural_witin: usize, pub n_fixed: usize, - pub n_instance: usize, pub max_expr_degree: usize, /// keep all structural witin which could be evaluated succinctly without PCS pub structural_witins: Vec, @@ -100,6 +100,7 @@ pub struct Layer { // there got 3 different eq for (left, right, target) during rotation argument // refer https://hackmd.io/HAAj1JTQQiKfu0SIwOJDRw?view#Rotation pub rotation_exprs: RotateExprs, + pub ecc_bridge_group_indices: Option<[usize; ECC_BRIDGE_OPENING_COUNT]>, pub rotation_cyclic_group_log2: usize, pub rotation_cyclic_subgroup_size: usize, @@ -140,7 +141,6 @@ impl Layer { n_witin: usize, n_structural_witin: usize, n_fixed: usize, - n_instance: usize, // exprs concat zero/non-zero expression. exprs: Vec>, in_eval_expr: Vec, @@ -169,7 +169,6 @@ impl Layer { n_witin, n_structural_witin, n_fixed, - n_instance, max_expr_degree, structural_witins, exprs, @@ -177,6 +176,7 @@ impl Layer { in_eval_expr, out_sel_and_eval_exprs, rotation_exprs: (rotation_eq, rotation_exprs), + ecc_bridge_group_indices: None, rotation_cyclic_group_log2, rotation_cyclic_subgroup_size, expr_names, @@ -338,18 +338,54 @@ impl Layer { assert_eq!(lookup_evals.len(), lk_len); assert_eq!(zero_evals.len(), zero_len); + let rotation_expr_len = cb.cs.rotations.len() * ROTATION_OPENING_COUNT; + let ecc_bridge_expr_len = if cb.cs.ec_point_exprs.is_empty() { + 0 + } else { + cb.cs.ec_slope_exprs.len() * ECC_BRIDGE_OPENING_COUNT + }; + let mut next_non_zero_eval_idx = r_record_evals + .iter() + .chain(w_record_evals.iter()) + .chain(lookup_evals.iter()) + .copied() + .max() + .map_or(0, |max_idx| max_idx + 1); let non_zero_expr_len = cb.cs.w_expressions.len() + cb.cs.w_table_expressions.len() + cb.cs.r_expressions.len() + cb.cs.r_table_expressions.len() + cb.cs.lk_expressions.len() - + cb.cs.lk_table_expressions.len() * 2; + + cb.cs.lk_table_expressions.len() * 2 + + rotation_expr_len + + ecc_bridge_expr_len; let zero_expr_len = cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - let mut expr_evals = Vec::with_capacity(4); + let selector_group_capacity = [ + cb.cs.r_selector.as_ref(), + cb.cs.w_selector.as_ref(), + cb.cs.lk_selector.as_ref(), + cb.cs.zero_selector.as_ref(), + ] + .iter() + .filter(|selector| selector.is_some()) + .count() + + if cb.cs.rotations.is_empty() { + 0 + } else { + ROTATION_OPENING_COUNT + } + + if cb.cs.ec_point_exprs.is_empty() { + 0 + } else { + ECC_BRIDGE_OPENING_COUNT + }; + let mut expr_evals = Vec::with_capacity(selector_group_capacity); let mut expr_names = Vec::with_capacity(non_zero_expr_len + zero_expr_len); let mut expressions = Vec::with_capacity(non_zero_expr_len + zero_expr_len); + let mut ecc_bridge_group_indices: Option<[usize; ECC_BRIDGE_OPENING_COUNT]> = None; + let mut ecc_bridge_eval_bases: Option<[usize; ECC_BRIDGE_OPENING_COUNT]> = None; if let Some(r_selector) = cb.cs.r_selector.as_ref() { // process r_record @@ -474,6 +510,144 @@ impl Layer { } } + if !cb.cs.rotations.is_empty() { + let Some(RotationParams { + rotation_eqs: Some([rotation_left_eq, rotation_right_eq, rotation_eq]), + .. + }) = cb.cs.rotation_params.as_ref() + else { + panic!("rotation params not set"); + }; + + // Rotation claims occupy 3 * num_rotations dedicated out-eval entries: + // [left_0..left_n][right_0..right_n][target_0..target_n]. + let num_rotations = cb.cs.rotations.len(); + let rotation_left_eval_base = next_non_zero_eval_idx; + let rotation_right_eval_base = rotation_left_eval_base + num_rotations; + let rotation_eval_base = rotation_right_eval_base + num_rotations; + next_non_zero_eval_idx = rotation_eval_base + num_rotations; + + // Rotation selector groups must be fresh groups (no dedup with preceding + // r/w/lookup groups) so chip-level rotation claim assignment is unambiguous. + let left_group_idx = expr_evals.len(); + expr_evals.push((SelectorType::Whole(rotation_left_eq.clone()), vec![])); + let right_group_idx = expr_evals.len(); + expr_evals.push((SelectorType::Whole(rotation_right_eq.clone()), vec![])); + let target_group_idx = expr_evals.len(); + expr_evals.push((SelectorType::Whole(rotation_eq.clone()), vec![])); + + // Expression order must match flattened out-eval group order: + // [left...][right...][target...]. + for (idx, (rotate_expr, _)) in cb.cs.rotations.iter().enumerate() { + expressions.push(rotate_expr.clone()); + expr_evals[left_group_idx] + .1 + .push(EvalExpression::Single(rotation_left_eval_base + idx)); + expr_names.push(format!("rotation/left/{idx}")); + } + + for (idx, (rotate_expr, _)) in cb.cs.rotations.iter().enumerate() { + expressions.push(rotate_expr.clone()); + expr_evals[right_group_idx] + .1 + .push(EvalExpression::Single(rotation_right_eval_base + idx)); + expr_names.push(format!("rotation/right/{idx}")); + } + + for (idx, (_, target_expr)) in cb.cs.rotations.iter().enumerate() { + expressions.push(target_expr.clone()); + expr_evals[target_group_idx] + .1 + .push(EvalExpression::Single(rotation_eval_base + idx)); + expr_names.push(format!("rotation/point/{idx}")); + } + } + + if !cb.cs.ec_point_exprs.is_empty() { + let septic_degree = cb.cs.ec_slope_exprs.len(); + assert_eq!(cb.cs.ec_point_exprs.len(), septic_degree * 2); + + // ECC bridge selector groups must be explicitly supplied and independent. + // Do not fall back to (or reuse) preceding r/w/lk/zero selectors. + let [ecc_sel_x, ecc_sel_y, ecc_sel_s, ecc_sel_x3, ecc_sel_y3] = + cb.cs.ec_bridge_selectors.clone().expect( + "ecc bridge selectors must be provided when ec_point_exprs is non-empty", + ); + + let x_eval_base = next_non_zero_eval_idx; + let y_eval_base = x_eval_base + septic_degree; + let s_eval_base = y_eval_base + septic_degree; + let x3_eval_base = s_eval_base + septic_degree; + let y3_eval_base = x3_eval_base + septic_degree; + next_non_zero_eval_idx = y3_eval_base + septic_degree; + ecc_bridge_eval_bases = Some([ + x_eval_base, + y_eval_base, + s_eval_base, + x3_eval_base, + y3_eval_base, + ]); + + let x_group_idx = expr_evals.len(); + expr_evals.push((ecc_sel_x.clone(), vec![])); + let y_group_idx = expr_evals.len(); + expr_evals.push((ecc_sel_y.clone(), vec![])); + let s_group_idx = expr_evals.len(); + expr_evals.push((ecc_sel_s.clone(), vec![])); + let x3_group_idx = expr_evals.len(); + expr_evals.push((ecc_sel_x3, vec![])); + let y3_group_idx = expr_evals.len(); + expr_evals.push((ecc_sel_y3, vec![])); + ecc_bridge_group_indices = Some([ + x_group_idx, + y_group_idx, + s_group_idx, + x3_group_idx, + y3_group_idx, + ]); + + for (idx, x_expr) in cb.cs.ec_point_exprs[..septic_degree].iter().enumerate() { + expressions.push(x_expr.clone()); + expr_evals[x_group_idx] + .1 + .push(EvalExpression::Single(x_eval_base + idx)); + expr_names.push(format!("ecc_bridge/x/{idx}")); + } + + for (idx, y_expr) in cb.cs.ec_point_exprs[septic_degree..].iter().enumerate() { + expressions.push(y_expr.clone()); + expr_evals[y_group_idx] + .1 + .push(EvalExpression::Single(y_eval_base + idx)); + expr_names.push(format!("ecc_bridge/y/{idx}")); + } + + for (idx, slope_expr) in cb.cs.ec_slope_exprs.iter().enumerate() { + expressions.push(slope_expr.clone()); + expr_evals[s_group_idx] + .1 + .push(EvalExpression::Single(s_eval_base + idx)); + expr_names.push(format!("ecc_bridge/slope/{idx}")); + } + + // x3/y3 reuse x/y expressions but are opened at rt||1 instead of [r]||rt. + for (idx, x_expr) in cb.cs.ec_point_exprs[..septic_degree].iter().enumerate() { + expressions.push(x_expr.clone()); + expr_evals[x3_group_idx] + .1 + .push(EvalExpression::Single(x3_eval_base + idx)); + expr_names.push(format!("ecc_bridge/x3/{idx}")); + } + + for (idx, y_expr) in cb.cs.ec_point_exprs[septic_degree..].iter().enumerate() { + expressions.push(y_expr.clone()); + expr_evals[y3_group_idx] + .1 + .push(EvalExpression::Single(y3_eval_base + idx)); + expr_names.push(format!("ecc_bridge/y3/{idx}")); + } + } + if let Some(zero_selector) = cb.cs.zero_selector.as_ref() { // process zero_record let evals = Self::dedup_last_selector_evals(zero_selector, &mut expr_evals); @@ -503,24 +677,63 @@ impl Layer { .. } = &cb.cs; - let in_eval_expr = (non_zero_expr_len..) + // Drop selector groups that ended up without eval expressions. + expr_evals.retain(|(_, evals)| !evals.is_empty()); + + if let Some([x_base, y_base, s_base, x3_base, y3_base]) = ecc_bridge_eval_bases { + let find_group_by_base = |base: usize| { + expr_evals + .iter() + .enumerate() + .find_map(|(idx, (_, evals))| match evals.first() { + Some(EvalExpression::Single(pos)) if *pos == base => Some(idx), + _ => None, + }) + }; + let x_idx = find_group_by_base(x_base) + .expect("missing x ecc bridge selector group after retain"); + let y_idx = find_group_by_base(y_base) + .expect("missing y ecc bridge selector group after retain"); + let s_idx = find_group_by_base(s_base) + .expect("missing slope ecc bridge selector group after retain"); + let x3_idx = find_group_by_base(x3_base) + .expect("missing x3 ecc bridge selector group after retain"); + let y3_idx = find_group_by_base(y3_base) + .expect("missing y3 ecc bridge selector group after retain"); + ecc_bridge_group_indices = Some([x_idx, y_idx, s_idx, x3_idx, y3_idx]); + } + + let out_eval_count = expr_evals + .iter() + .map(|(_, evals)| evals.len()) + .sum::(); + assert_eq!( + expressions.len(), + out_eval_count, + "expression/out-eval ordering mismatch: exprs={}, out_evals={}", + expressions.len(), + out_eval_count, + ); + + let in_eval_expr = (next_non_zero_eval_idx..) .take(cb.cs.num_witin as usize + cb.cs.num_fixed) .collect_vec(); if rotations.is_empty() { - Layer::new( + let mut layer = Layer::new( layer_name, LayerType::Zerocheck, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, cb.cs.num_fixed, - 0, expressions, in_eval_expr, expr_evals, ((None, vec![]), 0, 0), expr_names, cb.cs.structural_witins.clone(), - ) + ); + layer.ecc_bridge_group_indices = ecc_bridge_group_indices; + layer } else { let Some(RotationParams { rotation_eqs, @@ -530,13 +743,12 @@ impl Layer { else { panic!("rotation params not set"); }; - Layer::new( + let mut layer = Layer::new( layer_name, LayerType::Zerocheck, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, cb.cs.num_fixed, - 0, expressions, in_eval_expr, expr_evals, @@ -547,7 +759,9 @@ impl Layer { ), expr_names, cb.cs.structural_witins.clone(), - ) + ); + layer.ecc_bridge_group_indices = ecc_bridge_group_indices; + layer } } @@ -571,6 +785,37 @@ impl Layer { &mut expr_evals.last_mut().unwrap().1 } + + pub fn rotation_selector_group_indices(&self) -> Option<[usize; ROTATION_OPENING_COUNT]> { + let [left_eq, right_eq, point_eq] = self.rotation_exprs.0.as_ref()?; + + let find_group = |selector_expr: &Expression| { + self.out_sel_and_eval_exprs + .iter() + .enumerate() + .find_map(|(idx, (sel_type, _))| { + let expr = match sel_type { + SelectorType::Whole(expr) + | SelectorType::Prefix(expr) + | SelectorType::OrderedSparse { + expression: expr, .. + } + | SelectorType::QuarkBinaryTreeLessThan(expr) => expr, + SelectorType::None => return None, + }; + (expr == selector_expr).then_some(idx) + }) + }; + + let left_idx = find_group(left_eq)?; + let right_idx = find_group(right_eq)?; + let point_idx = find_group(point_eq)?; + Some([left_idx, right_idx, point_idx]) + } + + pub fn ecc_bridge_group_indices(&self) -> Option<[usize; ECC_BRIDGE_OPENING_COUNT]> { + self.ecc_bridge_group_indices + } } impl<'a, PB: ProverBackend> LayerWitness<'a, PB> { diff --git a/gkr_iop/src/gkr/layer/cpu/mod.rs b/gkr_iop/src/gkr/layer/cpu/mod.rs index 00e472bf0..de12967a1 100644 --- a/gkr_iop/src/gkr/layer/cpu/mod.rs +++ b/gkr_iop/src/gkr/layer/cpu/mod.rs @@ -16,10 +16,7 @@ use ff_ext::ExtensionField; use itertools::{Itertools, chain}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - Expression, - mle::{MultilinearExtension, Point}, - monomial::Term, - virtual_poly::build_eq_x_r_vec, + Expression, mle::Point, monomial::Term, virtual_poly::build_eq_x_r_vec, virtual_polys::VirtualPolynomialsBuilder, }; use rayon::{ @@ -37,11 +34,11 @@ use transcript::Transcript; use crate::{ gkr::layer::{ - ROTATION_OPENING_COUNT, hal::LinearLayerProver, sumcheck_layer::{LayerProof, SumcheckLayerProof}, }, hal::ProverBackend, + selector::SelectorType, }; impl> LinearLayerProver> @@ -65,7 +62,6 @@ impl> LinearLayerProver> SumcheckLayerProver< proof, evals: prover_state.get_mle_flatten_final_evaluations(), }, - rotation: None, } } } @@ -134,115 +129,81 @@ impl> ZerocheckLayerProver selector_ctxs.len() ); - let (_, raw_rotation_exprs) = &layer.rotation_exprs; - let (rotation_proof, rotation_left, rotation_right, rotation_point) = - if let Some(rotation_sumcheck_expression) = - layer.rotation_sumcheck_expression_monomial_terms.as_ref() - { - // 1st sumcheck: process rotation_exprs - let rt = out_points.first().unwrap(); - let ( - proof, - RotationPoints { - left, - right, - origin, - }, - ) = prove_rotation( - num_threads, - max_num_variables, - layer.rotation_cyclic_subgroup_size, - layer.rotation_cyclic_group_log2, - &wit, - raw_rotation_exprs, - rotation_sumcheck_expression.clone(), - rt, - challenges, - transcript, - ); - (Some(proof), Some(left), Some(right), Some(origin)) - } else { - (None, None, None, None) - }; - - // 2th sumcheck: batch rotation with other constrains + // Main sumcheck: constraints are fully unified in out_sel_and_eval_exprs. let span = entered_span!("build_out_points_eq", profiling_4 = true); let main_sumcheck_challenges = chain!( challenges.iter().copied(), - get_challenge_pows( - layer.exprs.len() + raw_rotation_exprs.len() * ROTATION_OPENING_COUNT, - transcript, - ) + get_challenge_pows(layer.exprs.len(), transcript) ) .collect_vec(); - // zero check eq || rotation eq - let mut eqs = layer + // Build selector eq MLEs in parallel, then merge deterministically by structural wit id. + let selector_eq_pairs = layer .out_sel_and_eval_exprs .par_iter() .zip(out_points.par_iter()) .zip(selector_ctxs.par_iter()) .filter_map(|(((sel_type, _), point), selector_ctx)| { - sel_type.compute(point, selector_ctx) + let eq = sel_type.compute(point, selector_ctx)?; + let selector_expr = match sel_type { + SelectorType::Whole(expr) + | SelectorType::Prefix(expr) + | SelectorType::OrderedSparse { + expression: expr, .. + } + | SelectorType::QuarkBinaryTreeLessThan(expr) => expr, + SelectorType::None => return None, + }; + let Expression::StructuralWitIn(wit_id, _) = selector_expr else { + panic!("selector expression must be StructuralWitIn"); + }; + let wit_id = *wit_id as usize; + assert!( + wit_id < layer.n_structural_witin, + "selector wit id out of range" + ); + Some((wit_id, eq)) }) - // for rotation left point - .chain(rotation_left.par_iter().map(|rotation_left| { - MultilinearExtension::from_evaluations_ext_vec( - rotation_left.len(), - build_eq_x_r_vec(rotation_left), - ) - })) - // for rotation right point - .chain(rotation_right.par_iter().map(|rotation_right| { - MultilinearExtension::from_evaluations_ext_vec( - rotation_right.len(), - build_eq_x_r_vec(rotation_right), - ) - })) - // for rotation point - .chain(rotation_point.par_iter().map(|rotation_point| { - MultilinearExtension::from_evaluations_ext_vec( - rotation_point.len(), - build_eq_x_r_vec(rotation_point), - ) - })) .collect::>(); + + let mut selector_eq_by_wit_id = vec![None; layer.n_structural_witin]; + for (wit_id, eq) in selector_eq_pairs { + if selector_eq_by_wit_id[wit_id].is_none() { + selector_eq_by_wit_id[wit_id] = Some(eq); + } + } exit_span!(span); - // `wit` := witin ++ fixed ++ pubio - // we concat eq in between `wit` := witin ++ eqs ++ fixed - let all_witins = wit - .iter() - .take(layer.n_witin + layer.n_fixed + layer.n_instance) - .map(|mle| Either::Left(mle.as_ref())) - .chain( - // some non-selector structural witin - wit.iter() - .skip(layer.n_witin + layer.n_fixed + layer.n_instance) - .take( - layer.n_structural_witin - - layer.out_sel_and_eval_exprs.len() - - layer - .rotation_exprs - .0 - .as_ref() - .map(|_| ROTATION_OPENING_COUNT) - .unwrap_or(0), - ) - .map(|mle| Either::Left(mle.as_ref())), - ) - .chain(eqs.iter_mut().map(Either::Right)) - .collect_vec(); + // `wit` := witin ++ fixed ++ structural + // selector structural witins are replaced by computed eq MLEs in-place by witness id. + let base_wit_count = layer.n_witin + layer.n_fixed; + let mut all_witins = + Vec::with_capacity(layer.n_witin + layer.n_structural_witin + layer.n_fixed); + all_witins.extend( + wit.iter() + .take(base_wit_count) + .map(|mle| Either::Left(mle.as_ref())), + ); + for (selector_eq, mle) in selector_eq_by_wit_id.iter_mut().zip( + wit.iter() + .skip(base_wit_count) + .take(layer.n_structural_witin), + ) { + if let Some(eq) = selector_eq.as_mut() { + all_witins.push(Either::Right(eq)); + } else { + all_witins.push(Either::Left(mle.as_ref())); + } + } assert_eq!( all_witins.len(), - layer.n_witin + layer.n_structural_witin + layer.n_fixed + layer.n_instance, - "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {} + layer.n_instance {}", + layer.n_witin + layer.n_structural_witin + layer.n_fixed, + "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {}", all_witins.len(), layer.n_witin, layer.n_structural_witin, layer.n_fixed, - layer.n_instance, ); let builder = @@ -266,7 +227,6 @@ impl> ZerocheckLayerProver ( LayerProof { main: SumcheckLayerProof { proof, evals }, - rotation: rotation_proof, }, prover_state.collect_raw_challenges(), ) @@ -281,7 +241,7 @@ impl> ZerocheckLayerProver /// rotated_rotation_expr[i].0(rx) == (1 - rx_4) * rotation_expr[i].1(0, rx_0, rx_1, ..., rx_3, rx_5, ...) /// + rx_4 * rotation_expr[i].1(1, rx_0, 1 - rx_1, ..., rx_3, rx_5, ...) #[allow(clippy::too_many_arguments)] -pub(crate) fn prove_rotation>( +pub fn prove_rotation>( num_threads: usize, max_num_variables: usize, rotation_cyclic_subgroup_size: usize, diff --git a/gkr_iop/src/gkr/layer/gpu/mod.rs b/gkr_iop/src/gkr/layer/gpu/mod.rs index 3ec80e99b..14372729e 100644 --- a/gkr_iop/src/gkr/layer/gpu/mod.rs +++ b/gkr_iop/src/gkr/layer/gpu/mod.rs @@ -27,7 +27,6 @@ use transcript::{BasicTranscript, Transcript}; use crate::{ gkr::layer::{ - ROTATION_OPENING_COUNT, hal::LinearLayerProver, sumcheck_layer::{LayerProof, SumcheckLayerProof}, }, @@ -38,7 +37,7 @@ use ceno_gpu::common::sumcheck::CommonTermPlan; use crate::gpu::{MultilinearExtensionGpu, gpu_prover::*}; pub mod utils; -use crate::selector::SelectorContext; +use crate::selector::{SelectorContext, SelectorType}; use utils::*; impl> LinearLayerProver> @@ -88,7 +87,7 @@ impl> ZerocheckLayerProver ) { let stream = crate::gpu::get_thread_stream(); let span = entered_span!("ZerocheckLayerProver", profiling_2 = true); - let num_threads = 1; // VP builder for GPU: do not use _num_threads + let _num_threads = 1; // VP builder for GPU: do not use host thread parallelism assert_eq!(challenges.len(), 2); assert_eq!( @@ -99,108 +98,83 @@ impl> ZerocheckLayerProver out_points.len(), ); - let (_, raw_rotation_exprs) = &layer.rotation_exprs; - let (rotation_proof, rotation_left, rotation_right, rotation_point) = - if let Some(rotation_sumcheck_expression) = - layer.rotation_sumcheck_expression_monomial_terms.as_ref() - { - // 1st sumcheck: process rotation_exprs - let rt = out_points.first().unwrap(); - let ( - proof, - RotationPoints { - left, - right, - origin, - }, - ) = prove_rotation_gpu( - num_threads, - max_num_variables, - layer.rotation_cyclic_subgroup_size, - layer.rotation_cyclic_group_log2, - &wit, - raw_rotation_exprs, - rotation_sumcheck_expression.clone(), - rt, - challenges, - transcript, - ); - (Some(proof), Some(left), Some(right), Some(origin)) - } else { - (None, None, None, None) - }; - - // 2th sumcheck: batch rotation with other constrains + // Main sumcheck: constraints are fully unified in out_sel_and_eval_exprs. let main_sumcheck_challenges = chain!( challenges.iter().copied(), - get_challenge_pows( - layer.exprs.len() + raw_rotation_exprs.len() * ROTATION_OPENING_COUNT, - transcript, - ) + get_challenge_pows(layer.exprs.len(), transcript) ) .collect_vec(); let span_eq = entered_span!("build eqs", profiling_2 = true); let cuda_hal = get_cuda_hal().unwrap(); - let eqs_gpu = layer + let selector_eq_pairs = layer .out_sel_and_eval_exprs .iter() .zip(out_points.iter()) .zip(selector_ctxs.iter()) - .map(|(((sel_type, _), point), selector_ctx)| { - build_eq_x_r_with_sel_gpu(&cuda_hal, point, selector_ctx, sel_type) + .filter_map(|(((sel_type, _), point), selector_ctx)| { + let eq = build_eq_x_r_with_sel_gpu(&cuda_hal, point, selector_ctx, sel_type); + let selector_expr = match sel_type { + SelectorType::Whole(expr) + | SelectorType::Prefix(expr) + | SelectorType::OrderedSparse { + expression: expr, .. + } + | SelectorType::QuarkBinaryTreeLessThan(expr) => expr, + SelectorType::None => return None, + }; + let Expression::StructuralWitIn(wit_id, _) = selector_expr else { + panic!("selector expression must be StructuralWitIn"); + }; + let wit_id = *wit_id as usize; + assert!( + wit_id < layer.n_structural_witin, + "selector wit id out of range" + ); + Some((wit_id, eq)) }) - // for rotation left point - .chain( - rotation_left - .iter() - .map(|rotation_left| build_eq_x_r_gpu(&cuda_hal, rotation_left)), - ) - // for rotation right point - .chain( - rotation_right - .iter() - .map(|rotation_right| build_eq_x_r_gpu(&cuda_hal, rotation_right)), - ) - // for rotation point - .chain( - rotation_point - .iter() - .map(|rotation_point| build_eq_x_r_gpu(&cuda_hal, rotation_point)), - ) .collect::>(); - // `wit` := witin ++ fixed ++ pubio + + let mut selector_eq_by_wit_id: Vec>> = + vec![None; layer.n_structural_witin]; + for (wit_id, eq) in selector_eq_pairs { + if selector_eq_by_wit_id[wit_id].is_none() { + selector_eq_by_wit_id[wit_id] = Some(eq); + } + } + + // `wit` := witin ++ fixed ++ structural + // selector structural witins are replaced by computed eq MLEs in-place by witness id. + let base_wit_count = layer.n_witin + layer.n_fixed; let all_witins_gpu = wit .iter() - .take(layer.n_witin + layer.n_fixed + layer.n_instance) + .take(base_wit_count) .map(|mle| mle.as_ref()) .chain( - // some non-selector structural witin - wit.iter() - .skip(layer.n_witin + layer.n_fixed + layer.n_instance) - .take( - layer.n_structural_witin - - layer.out_sel_and_eval_exprs.len() - - layer - .rotation_exprs - .0 - .as_ref() - .map(|_| ROTATION_OPENING_COUNT) - .unwrap_or(0), + selector_eq_by_wit_id + .iter_mut() + .zip( + wit.iter() + .skip(base_wit_count) + .take(layer.n_structural_witin), ) - .map(|mle| mle.as_ref()), + .map(|(selector_eq, mle)| { + if let Some(eq) = selector_eq.as_mut() { + eq + } else { + mle.as_ref() + } + }), ) - .chain(eqs_gpu.iter()) .collect_vec(); assert_eq!( all_witins_gpu.len(), - layer.n_witin + layer.n_structural_witin + layer.n_fixed + layer.n_instance, - "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {} + layer.n_instance {}", + layer.n_witin + layer.n_structural_witin + layer.n_fixed, + "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {}", all_witins_gpu.len(), layer.n_witin, layer.n_structural_witin, layer.n_fixed, - layer.n_instance, ); exit_span!(span_eq); @@ -307,7 +281,6 @@ impl> ZerocheckLayerProver proof: proof_gpu_e, evals: evals_gpu_e, }, - rotation: rotation_proof, }, row_challenges_e, ) @@ -323,7 +296,7 @@ impl> ZerocheckLayerProver /// + rx_4 * rotation_expr[i].1(1, rx_0, 1 - rx_1, ..., rx_3, rx_5, ...) #[allow(clippy::too_many_arguments)] #[tracing::instrument(skip_all, name = "prove_rotation_gpu", level = "info")] -pub(crate) fn prove_rotation_gpu>( +pub fn prove_rotation_gpu>( _num_threads: usize, max_num_variables: usize, rotation_cyclic_subgroup_size: usize, diff --git a/gkr_iop/src/gkr/layer/sumcheck_layer.rs b/gkr_iop/src/gkr/layer/sumcheck_layer.rs index fb739935f..68ca7d311 100644 --- a/gkr_iop/src/gkr/layer/sumcheck_layer.rs +++ b/gkr_iop/src/gkr/layer/sumcheck_layer.rs @@ -21,7 +21,6 @@ use super::{Layer, LayerWitness, linear_layer::LayerClaims}; deserialize = "E::BaseField: DeserializeOwned" ))] pub struct LayerProof { - pub rotation: Option>, pub main: SumcheckLayerProof, } diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index cb0b91fa1..2ac2cb2b6 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -27,9 +27,7 @@ use crate::{ evaluation::EvalExpression, gkr::{ booleanhypercube::BooleanHypercube, - layer::{ - ROTATION_OPENING_COUNT, hal::ZerocheckLayerProver, sumcheck_layer::SumcheckLayerProof, - }, + layer::{hal::ZerocheckLayerProver, sumcheck_layer::SumcheckLayerProof}, }, hal::{ProverBackend, ProverDevice}, selector::{SelectorContext, SelectorType}, @@ -40,17 +38,17 @@ use crate::{ }, }; -pub(crate) struct RotationPoints { +pub struct RotationPoints { pub left: Point, pub right: Point, pub origin: Point, } -pub(crate) struct RotationClaims { - left_evals: Vec, - right_evals: Vec, - target_evals: Vec, - rotation_points: RotationPoints, +pub struct RotationClaims { + pub left_evals: Vec, + pub right_evals: Vec, + pub target_evals: Vec, + pub rotation_points: RotationPoints, } pub trait ZerocheckLayer { @@ -140,17 +138,17 @@ impl ZerocheckLayer for Layer { &expr, self.n_witin as WitnessId, self.n_fixed as WitnessId, - self.n_instance, + 0, ) }) .collect::>(); // build main sumcheck expression let alpha_pows_expr = (2..) - .take(self.exprs.len() + num_rotations * ROTATION_OPENING_COUNT) + .take(self.exprs.len()) .map(|id| Expression::Challenge(id as ChallengeId, 1, E::ONE, E::ZERO)) .collect_vec(); - let mut zero_expr = extend_exprs_with_rotation(self, &alpha_pows_expr) + let mut zero_expr = rlc_zero_expr(self, &alpha_pows_expr) .into_iter() .sum::>(); @@ -161,7 +159,7 @@ impl ZerocheckLayer for Layer { expr, self.n_witin as WitnessId, self.n_fixed as WitnessId, - self.n_instance, + 0, ) }); @@ -169,7 +167,7 @@ impl ZerocheckLayer for Layer { &mut zero_expr, self.n_witin as WitnessId, self.n_fixed as WitnessId, - self.n_instance, + 0, ); tracing::trace!("{} main sumcheck degree: {}", self.name, zero_expr.degree()); self.main_sumcheck_expression = Some(zero_expr); @@ -227,7 +225,7 @@ impl ZerocheckLayer for Layer { &self, max_num_variables: usize, proof: LayerProof, - mut eval_and_dedup_points: Vec<(Vec, Option>)>, + eval_and_dedup_points: Vec<(Vec, Option>)>, pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, @@ -246,54 +244,17 @@ impl ZerocheckLayer for Layer { proof: IOPProof { proofs }, evals: main_evals, }, - rotation: rotation_proof, } = proof; assert_eq!( main_evals.len(), - self.n_witin + self.n_fixed + self.n_instance + self.n_structural_witin, + self.n_witin + self.n_fixed + self.n_structural_witin, "invalid main_evals length", ); - if let Some(rotation_proof) = rotation_proof { - // verify rotation proof - let rt = eval_and_dedup_points - .first() - .and_then(|(_, rt)| rt.as_ref()) - .expect("rotation proof should have at least one point"); - let RotationClaims { - left_evals, - right_evals, - target_evals, - rotation_points: - RotationPoints { - left: left_point, - right: right_point, - origin: origin_point, - }, - } = verify_rotation( - max_num_variables, - self.rotation_exprs.1.len(), - self.rotation_sumcheck_expression.as_ref().unwrap(), - rotation_proof, - self.rotation_cyclic_subgroup_size, - self.rotation_cyclic_group_log2, - rt, - challenges, - transcript, - )?; - eval_and_dedup_points.push((left_evals, Some(left_point))); - eval_and_dedup_points.push((right_evals, Some(right_point))); - eval_and_dedup_points.push((target_evals, Some(origin_point))); - } - - let rotation_exprs_len = self.rotation_exprs.1.len(); let main_sumcheck_challenges = chain!( challenges.iter().copied(), - get_challenge_pows( - self.exprs.len() + rotation_exprs_len * ROTATION_OPENING_COUNT, - transcript, - ) + get_challenge_pows(self.exprs.len(), transcript) ) .collect_vec(); @@ -320,7 +281,7 @@ impl ZerocheckLayer for Layer { ); let in_point = in_point.into_iter().map(|c| c.elements).collect_vec(); - let structural_witin_offset = self.n_witin + self.n_fixed + self.n_instance; + let structural_witin_offset = self.n_witin + self.n_fixed; // eval selector and set to respective witin izip!( &self.out_sel_and_eval_exprs, @@ -706,7 +667,7 @@ fn log_common_term_plan_stats( } #[allow(clippy::too_many_arguments)] -fn verify_rotation( +pub fn verify_rotation( max_num_variables: usize, num_rotations: usize, rotation_sumcheck_expression: &Expression, @@ -805,11 +766,11 @@ fn verify_rotation( }) } -pub fn extend_exprs_with_rotation( +pub fn rlc_zero_expr( layer: &Layer, alpha_pows: &[Expression], ) -> Vec> { - let offset_structural_witid = (layer.n_witin + layer.n_fixed + layer.n_instance) as WitnessId; + let offset_structural_witid = (layer.n_witin + layer.n_fixed) as WitnessId; let mut alpha_pows_iter = alpha_pows.iter(); let mut expr_iter = layer.exprs.iter(); let mut zero_check_exprs = Vec::with_capacity(layer.out_sel_and_eval_exprs.len()); @@ -840,72 +801,6 @@ pub fn extend_exprs_with_rotation( zero_check_exprs.push(expr); } - // prepare rotation expr - let (rotation_eq, rotation_exprs) = &layer.rotation_exprs; - if rotation_eq.is_none() { - return zero_check_exprs; - } - - let left_rotation_expr: Expression = izip!( - rotation_exprs.iter(), - alpha_pows_iter.by_ref().take(rotation_exprs.len()) - ) - .map(|((rotate_expr, _), alpha)| { - assert!(matches!(rotate_expr, Expression::WitIn(_))); - alpha * rotate_expr - }) - .sum(); - let right_rotation_expr: Expression = izip!( - rotation_exprs.iter(), - alpha_pows_iter.by_ref().take(rotation_exprs.len()) - ) - .map(|((rotate_expr, _), alpha)| { - assert!(matches!(rotate_expr, Expression::WitIn(_))); - alpha * rotate_expr - }) - .sum(); - let rotation_expr: Expression = izip!( - rotation_exprs.iter(), - alpha_pows_iter.by_ref().take(rotation_exprs.len()) - ) - .map(|((_, expr), alpha)| { - assert!(matches!(expr, Expression::WitIn(_))); - alpha * expr - }) - .sum(); - - // push rotation expr to zerocheck expr - if let Some( - [ - rotation_left_eq_expr, - rotation_right_eq_expr, - rotation_eq_expr, - ], - ) = rotation_eq.as_ref() - { - let (rotation_left_eq_expr, rotation_right_eq_expr, rotation_eq_expr) = match ( - rotation_left_eq_expr, - rotation_right_eq_expr, - rotation_eq_expr, - ) { - ( - Expression::StructuralWitIn(left_eq_id, ..), - Expression::StructuralWitIn(right_eq_id, ..), - Expression::StructuralWitIn(eq_id, ..), - ) => ( - Expression::WitIn(offset_structural_witid + *left_eq_id), - Expression::WitIn(offset_structural_witid + *right_eq_id), - Expression::WitIn(offset_structural_witid + *eq_id), - ), - invalid => panic!("invalid eq format {:?}", invalid), - }; - // add rotation left expr - zero_check_exprs.push(rotation_left_eq_expr * left_rotation_expr); - // add rotation right expr - zero_check_exprs.push(rotation_right_eq_expr * right_rotation_expr); - // add target expr - zero_check_exprs.push(rotation_eq_expr * rotation_expr); - } assert!(expr_iter.next().is_none() && alpha_pows_iter.next().is_none()); zero_check_exprs diff --git a/gkr_iop/src/gkr/layer_constraint_system.rs b/gkr_iop/src/gkr/layer_constraint_system.rs index c074feedc..a76f4d927 100644 --- a/gkr_iop/src/gkr/layer_constraint_system.rs +++ b/gkr_iop/src/gkr/layer_constraint_system.rs @@ -410,7 +410,6 @@ impl LayerConstraintSystem { self.num_witin, 0, self.num_fixed, - 0, expressions, in_eval_expr, expr_evals, @@ -433,7 +432,6 @@ impl LayerConstraintSystem { self.num_witin, 0, self.num_fixed, - 0, expressions, in_eval_expr, expr_evals, diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs index 43c8239ff..07b5658fd 100644 --- a/gkr_iop/src/gkr/mock.rs +++ b/gkr_iop/src/gkr/mock.rs @@ -74,7 +74,7 @@ impl MockProver { &(sel.selector_expr() * expr), layer.n_witin as WitnessId, layer.n_fixed as WitnessId, - layer.n_instance, + 0, &[], &wits, &structural_wits, @@ -93,7 +93,6 @@ impl MockProver { out.mock_evaluate( layer.n_witin as WitnessId, layer.n_fixed as WitnessId, - layer.n_instance, &evaluations, &challenges, num_vars, @@ -148,7 +147,6 @@ impl EvalExpression { &self, n_witin: WitnessId, n_fixed: WitnessId, - n_instance: usize, evals: &[ArcMultilinearExtension<'a, E>], challenges: &[E], num_vars: usize, @@ -162,7 +160,7 @@ impl EvalExpression { &(Expression::WitIn(*i as WitnessId) * *c0.clone() + *c1.clone()), n_witin, n_fixed, - n_instance, + 0, &[], evals, &[], @@ -174,11 +172,7 @@ impl EvalExpression { assert_eq!(parts.len(), 1 << indices.len()); let parts = parts .iter() - .map(|part| { - part.mock_evaluate( - n_witin, n_fixed, n_instance, evals, challenges, num_vars, - ) - }) + .map(|part| part.mock_evaluate(n_witin, n_fixed, evals, challenges, num_vars)) .collect::, _>>()?; indices .iter()