diff --git a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs index 9f8415da4..de35653d9 100644 --- a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs +++ b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs @@ -910,6 +910,7 @@ pub fn run_keccakf + 'stat &[], &[], &[], + None, ); exit_span!(span); diff --git a/ceno_zkvm/src/precompiles/fptower/fp.rs b/ceno_zkvm/src/precompiles/fptower/fp.rs index 63a89f61e..04267f53e 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp.rs @@ -407,6 +407,7 @@ mod tests { &[], &[], &challenges, + None, ); let out_evals = { diff --git a/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs b/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs index 85c56e430..fb42f1f26 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs @@ -448,6 +448,7 @@ mod tests { &[], &[], &challenges, + None, ); let out_evals = { diff --git a/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs b/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs index 6d8288710..e84291a10 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs @@ -465,6 +465,7 @@ mod tests { &[], &[], &challenges, + None, ); let out_evals = { diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 4d5de9a4d..9108a2bc5 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -1128,6 +1128,7 @@ pub fn run_lookup_keccakf &[], &[], &challenges, + None, ); exit_span!(span); diff --git a/ceno_zkvm/src/precompiles/sha256/extend.rs b/ceno_zkvm/src/precompiles/sha256/extend.rs index 21b442835..7e646b7c4 100644 --- a/ceno_zkvm/src/precompiles/sha256/extend.rs +++ b/ceno_zkvm/src/precompiles/sha256/extend.rs @@ -470,6 +470,7 @@ mod tests { &[], &[], &challenges, + None, ); let out_evals = { diff --git a/ceno_zkvm/src/precompiles/uint256.rs b/ceno_zkvm/src/precompiles/uint256.rs index 3664b9a59..5c17b35f0 100644 --- a/ceno_zkvm/src/precompiles/uint256.rs +++ b/ceno_zkvm/src/precompiles/uint256.rs @@ -825,6 +825,7 @@ pub fn run_uint256_mul + ' &[], &[], &challenges, + None, ); exit_span!(span); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index 1bd025e60..3cd4aa5ed 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -679,6 +679,7 @@ pub fn run_weierstrass_add< &[], &[], &challenges, + None, ); exit_span!(span); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index ca3ad11b7..0636c0a95 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -659,6 +659,7 @@ pub fn run_weierstrass_decompress< &[], &[], &challenges, + None, ); exit_span!(span); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index 1362e19ad..91620772d 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -681,6 +681,7 @@ pub fn run_weierstrass_double< &[], &[], &challenges, + None, ); exit_span!(span); diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index b105efe75..f9ffc8a3c 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -9,7 +9,8 @@ use crate::{ hal::{DeviceProvingKey, EccQuarkProver, ProofInput, TowerProverSpec}, septic_curve::{SepticExtension, SepticPoint, SymbolicSepticExtension}, utils::{ - assign_group_evals, derive_ecc_bridge_claims, extract_ecc_quark_witness_inputs, + GkrOutputStageMask, assign_group_evals, derive_ecc_bridge_claims, + extract_ecc_quark_witness_inputs, first_layer_output_group_stage_masks, infer_tower_logup_witness, infer_tower_product_witness, split_rotation_evals, }, }, @@ -909,11 +910,13 @@ impl> MainSumcheckProver> MainSumcheckProver> MainSumcheckProver( composed_cs: &ComposedConstrainSystem, num_var_with_rotation: usize, ) -> usize { - let cs = &composed_cs.zkvm_v1_css; - let num_reads = cs.r_expressions.len() + cs.r_table_expressions.len(); - let num_writes = cs.w_expressions.len() + cs.w_table_expressions.len(); - let num_lk_num = cs.lk_table_expressions.len(); - let num_lk_den = if !cs.lk_table_expressions.is_empty() { - cs.lk_table_expressions.len() - } else { - cs.lk_expressions.len() - }; - let num_records = num_reads + num_writes + num_lk_num + num_lk_den; - let elem_size = std::mem::size_of::(); let record_len = 1usize << num_var_with_rotation; - num_records * record_len * elem_size + tower_output_count(composed_cs) * record_len * elem_size } pub(crate) fn estimate_main_constraints_bytes< @@ -233,8 +223,7 @@ pub(crate) fn estimate_main_constraints_bytes< // (see ZerocheckLayer verifier: max_degree = self.max_expr_degree + 1) let main_sumcheck_degree = (layer.max_expr_degree + 1).max(1); - let total_mles = - layer.n_witin + layer.n_structural_witin + layer.n_fixed + layer.n_instance; + let total_mles = layer.n_witin + layer.n_structural_witin + layer.n_fixed; let main_mle_num_vars_list = vec![num_var_with_rotation; total_mles]; let main_est = estimate_sumcheck_memory( num_var_with_rotation, diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index a83e5f4ce..2a3b9f900 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -11,7 +11,8 @@ use crate::{ DeviceProvingKey, MainSumcheckEvals, ProofInput, RotationProverOutput, TowerProverSpec, }, utils::{ - assign_group_evals, derive_ecc_bridge_claims, extract_ecc_quark_witness_inputs, + GkrOutputStageMask, assign_group_evals, derive_ecc_bridge_claims, + extract_ecc_quark_witness_inputs, first_layer_output_group_stage_masks, split_rotation_evals, }, }, @@ -75,13 +76,7 @@ use util::{ pub struct GpuTowerProver; -use crate::{ - scheme::{ - constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, - septic_curve::SepticPoint, - }, - structs::EccQuarkProof, -}; +use crate::scheme::{constants::NUM_FANIN, septic_curve::SepticPoint}; use gkr_iop::{ gpu::{ArcMultilinearExtensionGpu, BB31Base, MultilinearExtensionGpu}, selector::{SelectorContext, SelectorType}, @@ -321,11 +316,13 @@ pub fn prove_main_constraints_impl< panic!("empty gkr circuit") }; let first_layer = gkr_circuit.layers.first().expect("empty gkr circuit layer"); + let group_stage_masks = first_layer_output_group_stage_masks(composed_cs, gkr_circuit); let selector_ctxs = first_layer .out_sel_and_eval_exprs .iter() - .map(|(selector, _)| { - if cs.ec_final_sum.is_empty() { + .zip_eq(group_stage_masks.iter()) + .map(|((selector, _), stage_mask)| { + if !stage_mask.contains(GkrOutputStageMask::TOWER) || cs.ec_final_sum.is_empty() { SelectorContext { offset: 0, num_instances, @@ -362,6 +359,9 @@ pub fn prove_main_constraints_impl< else { panic!("rotation proof provided for non-rotation layer") }; + debug_assert!(group_stage_masks[left_group_idx].contains(GkrOutputStageMask::ROTATION)); + debug_assert!(group_stage_masks[right_group_idx].contains(GkrOutputStageMask::ROTATION)); + debug_assert!(group_stage_masks[point_group_idx].contains(GkrOutputStageMask::ROTATION)); let (left_evals, right_evals, point_evals) = split_rotation_evals(&rotation.proof.evals); @@ -398,6 +398,11 @@ pub fn prove_main_constraints_impl< else { panic!("ecc proof provided for non-ecc layer") }; + debug_assert!(group_stage_masks[x_group_idx].contains(GkrOutputStageMask::ECC)); + debug_assert!(group_stage_masks[y_group_idx].contains(GkrOutputStageMask::ECC)); + debug_assert!(group_stage_masks[slope_group_idx].contains(GkrOutputStageMask::ECC)); + debug_assert!(group_stage_masks[x3_group_idx].contains(GkrOutputStageMask::ECC)); + debug_assert!(group_stage_masks[y3_group_idx].contains(GkrOutputStageMask::ECC)); let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; let claims = derive_ecc_bridge_claims(ecc_proof, sample_r, num_var_with_rotation) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index cf04c0ed7..ea2971500 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -429,8 +429,17 @@ impl< let num_var_with_rotation = log2_num_instances + cs.rotation_vars().unwrap_or(0); // build main witness - let records = info_span!("[ceno] build_main_witness") - .in_scope(|| build_main_witness::(cs, input, challenges)); + let records = info_span!("[ceno] build_main_witness").in_scope(|| { + // ECC and rotation have dedicated witness/eval flows. For tower proving we only + // materialize the tower-facing GKR outputs here to avoid keeping unrelated output + // MLEs resident in VRAM during tower prove. + build_main_witness::( + cs, + input, + challenges, + crate::scheme::utils::WitnessBuildStage::Tower, + ) + }); let span = entered_span!("prove_tower_relation", profiling_2 = true); // prove the product and logup sum relation between layers in tower @@ -762,12 +771,20 @@ where // build main witness let records = info_span!("[ceno] build_main_witness").in_scope(|| { + // ECC and rotation have dedicated witness/eval flows. For tower proving we only + // materialize the tower-facing GKR outputs here to avoid keeping unrelated output + // MLEs resident in VRAM during tower prove. build_main_witness::< E, PCS, GpuBackend, gkr_iop::gpu::GpuProver>, - >(cs, &input, challenges) + >( + cs, + &input, + challenges, + crate::scheme::utils::WitnessBuildStage::Tower, + ) }); let span = entered_span!("prove_tower_relation", profiling_2 = true); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index ec7758ed1..ead260f7d 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -34,6 +34,133 @@ use rayon::{ use std::{iter, sync::Arc}; use witness::next_pow2_instance_padding; +/// Prover-only routing metadata for first-layer GKR output groups. +/// +/// This is group-level metadata describing which downstream proving submodule +/// consumes outputs from a selector group. A group may route to more than one +/// submodule, e.g. `TOWER | ZERO`, when the flat tower-output prefix cuts +/// through the middle of the group. +/// +/// This metadata is not part of the proof format and is not used by the +/// verifier. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] +pub(crate) struct GkrOutputStageMask(u8); + +impl GkrOutputStageMask { + pub(crate) const TOWER: Self = Self(1 << 0); + pub(crate) const ECC: Self = Self(1 << 1); + pub(crate) const ROTATION: Self = Self(1 << 2); + pub(crate) const ZERO: Self = Self(1 << 3); + + pub(crate) const fn union(self, other: Self) -> Self { + Self(self.0 | other.0) + } + + pub(crate) const fn contains(self, other: Self) -> bool { + (self.0 & other.0) == other.0 + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum WitnessBuildStage { + Tower, +} + +pub(crate) fn tower_output_count( + composed_cs: &ComposedConstrainSystem, +) -> usize { + let cs = &composed_cs.zkvm_v1_css; + let num_reads = cs.r_expressions.len() + cs.r_table_expressions.len(); + let num_writes = cs.w_expressions.len() + cs.w_table_expressions.len(); + let num_lk_num = cs.lk_table_expressions.len(); + let num_lk_den = if !cs.lk_table_expressions.is_empty() { + cs.lk_table_expressions.len() + } else { + cs.lk_expressions.len() + }; + num_reads + num_writes + num_lk_num + num_lk_den +} + +fn build_output_materialization_mask( + composed_cs: &ComposedConstrainSystem, + circuit: &GKRCircuit, + stage: WitnessBuildStage, +) -> Vec { + let first_layer = circuit.layers.first().expect("empty gkr circuit layer"); + let group_stage_masks = first_layer_output_group_stage_masks(composed_cs, circuit); + let total_outputs = first_layer + .out_sel_and_eval_exprs + .iter() + .map(|(_, outputs)| outputs.len()) + .sum::(); + let mut mask = vec![false; total_outputs]; + match stage { + WitnessBuildStage::Tower => { + // Materialization is exact at flattened-entry granularity even though routing metadata + // is tracked at group granularity. This is what lets mixed `TOWER | ZERO` groups avoid + // allocating the non-tower suffix during tower prove. + let mut remaining = tower_output_count(composed_cs); + let mut offset = 0usize; + for ((_, outputs), stage_mask) in first_layer + .out_sel_and_eval_exprs + .iter() + .zip(group_stage_masks.iter()) + { + let len = outputs.len(); + if stage_mask.contains(GkrOutputStageMask::TOWER) && remaining > 0 { + let take_len = len.min(remaining); + mask[offset..offset + take_len].fill(true); + remaining -= take_len; + } + offset += len; + } + debug_assert_eq!(remaining, 0, "failed to cover all tower outputs"); + } + } + mask +} + +pub(crate) fn first_layer_output_group_stage_masks( + composed_cs: &ComposedConstrainSystem, + circuit: &GKRCircuit, +) -> Vec { + let first_layer = circuit.layers.first().expect("empty gkr circuit layer"); + let mut group_masks = vec![GkrOutputStageMask::ZERO; first_layer.out_sel_and_eval_exprs.len()]; + + if let Some(rotation_groups) = first_layer.rotation_selector_group_indices() { + for group_idx in rotation_groups { + group_masks[group_idx] = GkrOutputStageMask::ROTATION; + } + } + if let Some(ecc_groups) = first_layer.ecc_bridge_group_indices() { + for group_idx in ecc_groups { + group_masks[group_idx] = GkrOutputStageMask::ECC; + } + } + + let tower_outputs = tower_output_count(composed_cs); + let mut seen_tower_outputs = 0usize; + for (group_mask, (_, outputs)) in group_masks + .iter_mut() + .zip(first_layer.out_sel_and_eval_exprs.iter()) + { + if seen_tower_outputs >= tower_outputs { + break; + } + *group_mask = group_mask.union(GkrOutputStageMask::TOWER); + seen_tower_outputs += outputs.len(); + } + assert!( + seen_tower_outputs >= tower_outputs, + "failed to cover all tower outputs: layer={}, seen_tower_outputs={}, tower_outputs={}", + first_layer.name, + seen_tower_outputs, + tower_outputs, + ); + + group_masks +} + pub(crate) struct EccBridgeClaims { pub(crate) xy_point: Point, pub(crate) s_point: Point, @@ -511,6 +638,7 @@ pub fn build_main_witness< composed_cs: &ComposedConstrainSystem, input: &ProofInput<'a, PB>, challenges: &[E; 2], + stage: WitnessBuildStage, ) -> Vec>> { let ComposedConstrainSystem { zkvm_v1_css: cs, @@ -561,6 +689,7 @@ pub fn build_main_witness< #[cfg(feature = "gpu")] let gpu_mem_tracker = crate::scheme::gpu::init_gpu_mem_tracker(&cuda_hal, "build_main_witness"); + let output_mask = build_output_materialization_mask(composed_cs, gkr_circuit, stage); let (_, gkr_circuit_out) = gkr_witness::( gkr_circuit, &input.witness, @@ -569,6 +698,7 @@ pub fn build_main_witness< &[], &input.pi, challenges, + Some(output_mask.as_slice()), ); // GPU memory check: validate estimation against actual usage @@ -582,6 +712,7 @@ pub fn build_main_witness< gkr_circuit_out.0.0 } +#[allow(clippy::too_many_arguments)] pub fn gkr_witness< 'b, E: ExtensionField, @@ -596,6 +727,7 @@ pub fn gkr_witness< _pub_io_mles: &[Arc>], pub_io_evals: &[Either], challenges: &[E], + output_mask: Option<&[bool]>, ) -> (GKRCircuitWitness<'b, PB>, GKRCircuitOutput<'b, PB>) { // layer order from output to input let mut layer_wits = Vec::>::with_capacity(circuit.layers.len() + 1); @@ -680,12 +812,16 @@ pub fn gkr_witness< ); // infer current layer output + let layer_output_mask = (i + 1 == circuit.layers.len()) + .then_some(output_mask) + .flatten(); let current_layer_output: Vec>> = - >::layer_witness( + >::layer_witness_filtered( layer, ¤t_layer_wits, pub_io_evals, challenges, + layer_output_mask, ); layer_wits.push(LayerWitness::new(current_layer_wits, vec![])); diff --git a/gkr_iop/src/cpu/mod.rs b/gkr_iop/src/cpu/mod.rs index e64bfbe56..76f722f23 100644 --- a/gkr_iop/src/cpu/mod.rs +++ b/gkr_iop/src/cpu/mod.rs @@ -112,6 +112,16 @@ impl> layer_wits: &[Arc< as ProverBackend>::MultilinearPoly<'a>>], pub_io_evals: &[Either], challenges: &[E], + ) -> Vec as ProverBackend>::MultilinearPoly<'a>>> { + Self::layer_witness_filtered(layer, layer_wits, pub_io_evals, challenges, None) + } + + fn layer_witness_filtered<'a>( + layer: &Layer, + layer_wits: &[Arc< as ProverBackend>::MultilinearPoly<'a>>], + pub_io_evals: &[Either], + challenges: &[E], + output_mask: Option<&[bool]>, ) -> Vec as ProverBackend>::MultilinearPoly<'a>>> { let span = entered_span!("witness_infer", profiling_2 = true); let out_evals: Vec<_> = layer @@ -119,13 +129,25 @@ impl> .iter() .flat_map(|(sel_type, out_eval)| izip!(iter::repeat(sel_type), out_eval.iter())) .collect(); + if let Some(mask) = output_mask { + assert_eq!( + mask.len(), + out_evals.len(), + "output_mask len {} != out_evals len {} for layer {}", + mask.len(), + out_evals.len(), + layer.name + ); + } let res = layer .exprs_with_selector_out_eval_monomial_form .par_iter() .zip_eq(layer.expr_names.par_iter()) - .zip_eq(out_evals.par_iter()) - .map(|((expr, expr_name), (_, out_eval))| { + .zip_eq(out_evals.par_iter().enumerate()) + .map(|((expr, expr_name), (idx, (_, out_eval)))| { + let should_materialize = output_mask.is_none_or(|mask| mask[idx]); if cfg!(debug_assertions) + && should_materialize && let EvalExpression::Zero = out_eval { assert!( @@ -137,10 +159,14 @@ impl> ); }; match out_eval { - EvalExpression::Linear(_, _, _) | EvalExpression::Single(_) => { + EvalExpression::Linear(_, _, _) | EvalExpression::Single(_) + if should_materialize => + { wit_infer_by_monomial_expr(expr, layer_wits, pub_io_evals, challenges) } - EvalExpression::Zero => MultilinearExtension::default().into(), + EvalExpression::Linear(_, _, _) + | EvalExpression::Single(_) + | EvalExpression::Zero => MultilinearExtension::default().into(), EvalExpression::Partition(_, _) => unimplemented!(), } }) diff --git a/gkr_iop/src/gpu/mod.rs b/gkr_iop/src/gpu/mod.rs index 0a8f75c02..59de3b63b 100644 --- a/gkr_iop/src/gpu/mod.rs +++ b/gkr_iop/src/gpu/mod.rs @@ -405,6 +405,16 @@ impl> layer_wits: &[Arc< as ProverBackend>::MultilinearPoly<'a>>], pub_io_evals: &[Either], challenges: &[E], + ) -> Vec as ProverBackend>::MultilinearPoly<'a>>> { + Self::layer_witness_filtered(layer, layer_wits, pub_io_evals, challenges, None) + } + + fn layer_witness_filtered<'a>( + layer: &Layer, + layer_wits: &[Arc< as ProverBackend>::MultilinearPoly<'a>>], + pub_io_evals: &[Either], + challenges: &[E], + output_mask: Option<&[bool]>, ) -> Vec as ProverBackend>::MultilinearPoly<'a>>> { let stream = get_thread_stream(); let span = entered_span!("preprocess", profiling_2 = true); @@ -417,21 +427,33 @@ impl> .iter() .flat_map(|(sel_type, out_eval)| izip!(std::iter::repeat(sel_type), out_eval.iter())) .collect(); + if let Some(mask) = output_mask { + assert_eq!( + mask.len(), + out_evals.len(), + "output_mask len {} != out_evals len {} for layer {}", + mask.len(), + out_evals.len(), + layer.name + ); + } // pre-process and flatten indices into friendly GPU format - let (num_non_zero_expr, term_coefficients, mle_indices_per_term, mle_size_info) = layer + let (selected_indices, term_coefficients, mle_indices_per_term, mle_size_info) = layer .exprs_with_selector_out_eval_monomial_form .iter() - .zip_eq(out_evals.iter()) - .filter(|(_, (_, out_eval))| { + .zip_eq(out_evals.iter().enumerate()) + .filter(|(_, (idx, (_, out_eval)))| { + let should_materialize = output_mask.is_none_or(|mask| mask[*idx]); match out_eval { - // only take linear/single to process - EvalExpression::Linear(_, _, _) | EvalExpression::Single(_) => true, + EvalExpression::Linear(_, _, _) | EvalExpression::Single(_) => { + should_materialize + } EvalExpression::Partition(..) => unimplemented!("Partition"), EvalExpression::Zero => false, } }) - .map(|(expr, _)| { + .map(|(expr, (idx, _))| { let (coeffs, indices, size_info) = extract_mle_relationships_from_monomial_terms( expr, &layer_wits.iter().map(|mle| mle.as_ref()).collect_vec(), @@ -439,35 +461,40 @@ impl> challenges, ); let coeffs_gl64: Vec = unsafe { std::mem::transmute(coeffs) }; - (coeffs_gl64, indices, size_info) + (idx, coeffs_gl64, indices, size_info) }) .fold( - (0, Vec::new(), Vec::new(), Vec::new()), - |(mut num_non_zero_expr, mut coeff_acc, mut indices_acc, mut size_acc), - (coeffs, indices, size_info)| { - num_non_zero_expr += 1; + (Vec::new(), Vec::new(), Vec::new(), Vec::new()), + |(mut selected, mut coeff_acc, mut indices_acc, mut size_acc), + (idx, coeffs, indices, size_info)| { + selected.push(idx); coeff_acc.push(coeffs); indices_acc.push(indices); size_acc.push(size_info); - (num_non_zero_expr, coeff_acc, indices_acc, size_acc) + (selected, coeff_acc, indices_acc, size_acc) }, ); + let num_non_zero_expr = selected_indices.len(); - assert!( + let num_vars = if num_non_zero_expr == 0 { + 0 + } else { + assert!( + mle_size_info + .iter() + .flat_map(|mle_size_info| mle_size_info.iter().filter(|(a, b)| { + assert_eq!(a, b); + *a > 0 && *b > 0 + })) + .all_equal() + ); mle_size_info .iter() - .flat_map(|mle_size_info| mle_size_info.iter().filter(|(a, b)| { - assert_eq!(a, b); - *a > 0 && *b > 0 - })) - .all_equal() - ); - let num_vars = mle_size_info - .iter() - .flat_map(|mle_size_info| mle_size_info.iter().filter(|(a, b)| *a > 0 && *b > 0)) - .take(1) - .collect_vec()[0] - .0; + .flat_map(|mle_size_info| mle_size_info.iter().filter(|(a, b)| *a > 0 && *b > 0)) + .take(1) + .collect_vec()[0] + .0 + }; exit_span!(span); let span = entered_span!("witness_infer", profiling_2 = true); @@ -503,15 +530,20 @@ impl> // recover it back and interleaving with default gpu let mut next_iter = next_witness_buf.into_iter(); + let mut selected_iter = selected_indices.into_iter().peekable(); out_evals .into_iter() - .map(|(_, out_eval)| { - if matches!( + .enumerate() + .map(|(idx, (_, out_eval))| { + let should_materialize = matches!( out_eval, EvalExpression::Linear(..) | EvalExpression::Single(_) - ) { - // take next element from next_witness_buf + ) && selected_iter + .peek() + .is_some_and(|selected_idx| *selected_idx == idx); + if should_materialize { + selected_iter.next(); MultilinearExtensionGpu::from_ceno_gpu_ext(GpuPolynomialExt::new( next_iter .next() diff --git a/gkr_iop/src/hal.rs b/gkr_iop/src/hal.rs index efb72e811..13c1cf6ac 100644 --- a/gkr_iop/src/hal.rs +++ b/gkr_iop/src/hal.rs @@ -54,4 +54,15 @@ pub trait ProtocolWitnessGeneratorProver { pub_io_evals: &[Either<::BaseField, PB::E>], challenges: &[PB::E], ) -> Vec>>; + + fn layer_witness_filtered<'a>( + layer: &Layer, + layer_wits: &[Arc>], + pub_io_evals: &[Either<::BaseField, PB::E>], + challenges: &[PB::E], + output_mask: Option<&[bool]>, + ) -> Vec>> { + let _ = output_mask; + Self::layer_witness(layer, layer_wits, pub_io_evals, challenges) + } }