diff --git a/ceno_recursion/src/transcript/mod.rs b/ceno_recursion/src/transcript/mod.rs index 9ea0e1ed4..b98ae4243 100644 --- a/ceno_recursion/src/transcript/mod.rs +++ b/ceno_recursion/src/transcript/mod.rs @@ -54,32 +54,3 @@ pub fn transcript_check_pow_witness( builder.assert_eq::>(bit, Usize::from(0)); }); } - -pub fn clone_challenger_state( - builder: &mut Builder, - src: &DuplexChallengerVariable, -) -> DuplexChallengerVariable { - let dst = DuplexChallengerVariable::new(builder); - builder - .range(0, dst.sponge_state.len()) - .for_each(|idx_vec, builder| { - let value = builder.get(&src.sponge_state, idx_vec[0]); - builder.set(&dst.sponge_state, idx_vec[0], value); - }); - - let input_offset = src.input_ptr - src.io_empty_ptr; - builder.assign(&dst.input_ptr, input_offset + dst.io_empty_ptr); - - let output_offset = src.output_ptr - src.io_empty_ptr; - builder.assign(&dst.output_ptr, output_offset + dst.io_empty_ptr); - dst -} - -pub fn challenger_add_forked_index( - builder: &mut Builder, - challenger: &mut DuplexChallengerVariable, - index: &Usize, -) { - let felt = builder.unsafe_cast_var_to_felt(index.get_var()); - challenger.observe(builder, felt); -} diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index b1b2e0c5d..7c45b9fd8 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -29,7 +29,6 @@ use crate::{ use ceno_zkvm::structs::{ComposedConstrainSystem, VerifyingKey, ZKVMVerifyingKey}; use ff_ext::BabyBearExt4; -use crate::transcript::{challenger_add_forked_index, clone_challenger_state}; use gkr_iop::{ evaluation::EvalExpression, gkr::{ @@ -153,22 +152,6 @@ pub fn verify_zkvm_proof>( challenger.observe(builder, log2_max_codeword_size_felt); } - iter_zip!(builder, zkvm_proof_input.chip_proofs).for_each(|ptr_vec, builder| { - let chip_proofs = builder.iter_ptr_get(&zkvm_proof_input.chip_proofs, ptr_vec[0]); - let chip_idx = builder.get(&chip_proofs, 0).idx_felt; - challenger.observe(builder, chip_idx); - - iter_zip!(builder, chip_proofs).for_each(|ptr_vec, builder| { - let chip_proof = builder.iter_ptr_get(&chip_proofs, ptr_vec[0]); - - iter_zip!(builder, chip_proof.num_instances).for_each(|ptr_vec, builder| { - let num_instance = builder.iter_ptr_get(&chip_proof.num_instances, ptr_vec[0]); - let num_instance = builder.unsafe_cast_var_to_felt(num_instance); - challenger.observe(builder, num_instance); - }); - }); - }); - challenger_multi_observe( builder, &mut challenger, @@ -252,9 +235,15 @@ pub fn verify_zkvm_proof>( iter_zip!(builder, chip_proofs).for_each(|ptr_vec, builder| { let chip_proof = builder.iter_ptr_get(&chip_proofs, ptr_vec[0]); - // fork transcript to support chip concurrently proved - let mut chip_challenger = clone_challenger_state(builder, &challenger); - challenger_add_forked_index(builder, &mut chip_challenger, &forked_sample_index); + // Fork chip transcript independently and bind challenges/metadata in verifier order. + let mut chip_challenger = DuplexChallengerVariable::new(builder); + transcript_observe_label(builder, &mut chip_challenger, b"fork"); + let alpha_felts = builder.ext2felt(alpha); + chip_challenger.observe_slice(builder, alpha_felts); + let beta_felts = builder.ext2felt(beta); + chip_challenger.observe_slice(builder, beta_felts); + let fork_id_felt = builder.unsafe_cast_var_to_felt(forked_sample_index.get_var()); + chip_challenger.observe(builder, fork_id_felt); builder.assert_usize_eq( chip_proof.rw_out_evals.length.clone(), Usize::from( @@ -266,6 +255,11 @@ pub fn verify_zkvm_proof>( Usize::from(circuit_vk.get_cs().num_lks() * 4), ); chip_challenger.observe(builder, chip_proof.idx_felt); + iter_zip!(builder, chip_proof.num_instances).for_each(|ptr_vec, builder| { + let num_instance = builder.iter_ptr_get(&chip_proof.num_instances, ptr_vec[0]); + let num_instance = builder.unsafe_cast_var_to_felt(num_instance); + chip_challenger.observe(builder, num_instance); + }); // getting the number of dummy padding item that we used in this opcode circuit let num_lks: Var = diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2e0d18dbf..b7a67c9e4 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -23,7 +23,7 @@ use sumcheck::{ structs::IOPProverMessage, }; use tracing::info_span; -use transcript::{ForkableTranscript, Transcript}; +use transcript::{BasicTranscript, ForkableTranscript, Transcript}; use super::{PublicValues, ZKVMChipProof, ZKVMProof, hal::ProverDevice}; #[cfg(feature = "gpu")] @@ -171,43 +171,12 @@ impl< } exit_span!(span); - // only keep track of circuits that have non-zero instances - for (name, chip_inputs) in &witnesses.witnesses { - let pk = self.pk.circuit_pks.get(name).ok_or(ZKVMError::VKNotFound( - format!("proving key for circuit {} not found", name).into(), - ))?; - - // include omc init tables iff it's in first shard - if !shard_ctx.is_first_shard() && pk.get_cs().with_omc_init_only() { - continue; - } - - // num_instance from witness might include rotation - let num_instances = chip_inputs - .iter() - .flat_map(|chip_input| chip_input.num_instances) - .collect_vec(); - - if num_instances.iter().sum::() == 0 { - continue; - } - - let circuit_idx = self.pk.circuit_name_to_index.get(name).unwrap(); - // write (circuit_idx, num_var) to transcript - transcript.append_field_element(&E::BaseField::from_canonical_usize(*circuit_idx)); - for num_instance in num_instances { - transcript - .append_field_element(&E::BaseField::from_canonical_usize(num_instance)); - } - } - - // extract chip meta info before consuming witnesses - // (circuit_name, num_instances) - let name_and_instances = witnesses.get_witnesses_name_instance(); - let commit_to_traces_span = entered_span!("batch commit to traces", profiling_1 = true); let mut wits_rmms = BTreeMap::new(); + // Extract chip metadata before consuming witnesses. + // We reuse this for both transcript appends and task construction. + let name_and_instances = witnesses.get_witnesses_name_instance(); let mut structural_rmms = Vec::with_capacity(name_and_instances.len()); // commit to opcode circuits first and then commit to table circuits, sorted by name for (i, chip_input) in witnesses.into_iter_sorted().enumerate() { @@ -261,7 +230,6 @@ impl< transcript.read_challenge().elements, ]; tracing::debug!("global challenges in prover: {:?}", challenges); - let main_proofs_span = entered_span!("main_proofs", profiling_1 = true); // Phase 1: Build all ChipTasks @@ -283,8 +251,9 @@ impl< // GPU concurrent: memory-aware backfilling with standalone impl. // Sequential (GPU + CPU): unified path via self.create_chip_proof. let execute_tasks_span = entered_span!("execute_chip_tasks", profiling_1 = true); + let fork_transcript = BasicTranscript::::new(b"fork"); let (results, forked_samples) = - self.run_chip_proofs(tasks, &transcript, &witness_data)?; + self.run_chip_proofs(tasks, &fork_transcript, &witness_data)?; exit_span!(execute_tasks_span); // Phase 3: Collect results @@ -352,10 +321,20 @@ impl< let gpu_wd = SyncRef(gpu_witness_data); return scheduler.execute(tasks, transcript, |task, transcript| { + // Bind global challenges and metadata in the same order as verifier. + transcript.append_field_element_ext(&task.challenges[0]); + transcript.append_field_element_ext(&task.challenges[1]); + transcript + .append_field_element(&E::BaseField::from_canonical_usize(task.task_id)); // Append circuit_idx to per-task forked transcript (matching verifier) transcript.append_field_element(&E::BaseField::from_canonical_u64( task.circuit_idx as u64, )); + for num_instance in task.input.num_instances { + transcript.append_field_element(&E::BaseField::from_canonical_usize( + num_instance, + )); + } // SAFETY: TypeId check above (before closure) guarantees PB = GpuBackend. let gpu_input: ProofInput<'static, gkr_iop::gpu::GpuBackend> = @@ -389,9 +368,16 @@ impl< // Sequential path (GPU + CPU unified): // Uses execute_sequentially directly to avoid Send+Sync requirement on the closure. scheduler.execute_sequentially(tasks, transcript, |mut task, transcript| { + // Bind global challenges and metadata in the same order as verifier. + transcript.append_field_element_ext(&task.challenges[0]); + transcript.append_field_element_ext(&task.challenges[1]); + transcript.append_field_element(&E::BaseField::from_canonical_usize(task.task_id)); // Append circuit_idx to per-task forked transcript (matching verifier) transcript .append_field_element(&E::BaseField::from_canonical_u64(task.circuit_idx as u64)); + for num_instance in task.input.num_instances { + transcript.append_field_element(&E::BaseField::from_canonical_usize(num_instance)); + } // Prepare: deferred extraction for GPU, no-op for CPU self.device.prepare_chip_input(&mut task, witness_data); diff --git a/ceno_zkvm/src/scheme/scheduler.rs b/ceno_zkvm/src/scheme/scheduler.rs index c1bd260f4..9ab22ebbd 100644 --- a/ceno_zkvm/src/scheme/scheduler.rs +++ b/ceno_zkvm/src/scheme/scheduler.rs @@ -21,7 +21,6 @@ use crate::{ use ff_ext::ExtensionField; use gkr_iop::hal::ProverBackend; use mpcs::Point; -use p3::field::FieldAlgebra; use std::sync::OnceLock; use transcript::Transcript; static CHIP_PROVING_MODE: OnceLock = OnceLock::new(); @@ -152,8 +151,8 @@ impl ChipScheduler { /// Execute tasks sequentially with automatic transcript forking and sampling. /// - /// Each task gets a transcript cloned from `parent_transcript` with `task_id` - /// appended (identical to `ForkableTranscript::fork` default impl). + /// Each task gets a transcript cloned from `parent_transcript`. + /// Task-specific transcript appends are performed by the task closure. /// Returns `(results, forked_samples)` both sorted by task_id. #[allow(clippy::type_complexity)] pub(crate) fn execute_sequentially<'a, PB, T, F>( @@ -186,12 +185,8 @@ impl ChipScheduler { for task in tasks { let task_id = task.task_id; - // Fork: clone parent + append task_id - // (identical to ForkableTranscript::fork default impl) + // Fork: clone parent transcript template. let mut forked = parent_transcript.clone(); - forked.append_field_element(&::BaseField::from_canonical_u64( - task_id as u64, - )); let result = execute_task(task, &mut forked)?; results.push(result); @@ -213,8 +208,7 @@ impl ChipScheduler { /// Tasks are sorted by memory requirement (descending) and scheduled to /// maximize GPU utilization while respecting memory constraints. /// - /// Each worker thread clones the parent `transcript` and appends its task_id - /// (reproducing `ForkableTranscript::fork` locally). After proving, the worker + /// Each worker thread clones the parent `transcript`. After proving, the worker /// samples one extension-field element from its local transcript and returns it. /// This avoids sending non-`Send` transcript objects across threads. /// @@ -250,9 +244,6 @@ impl ChipScheduler { if tasks.len() == 1 { let task = tasks.remove(0); let mut fork = transcript.clone(); - fork.append_field_element(&::BaseField::from_canonical_u64( - task.task_id as u64, - )); let result = execute_task(task, &mut fork)?; let sample = fork.sample_vec(1)[0]; return Ok((vec![result], vec![sample])); @@ -347,14 +338,8 @@ impl ChipScheduler { // waiting for a CompletionMessage that never arrives). let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - // Fork locally: clone parent transcript + append task_id - // (identical to ForkableTranscript::fork default impl) + // Fork locally: clone parent transcript template. let mut local_transcript = tr.0.clone(); - local_transcript.append_field_element( - &::BaseField::from_canonical_u64( - task_id as u64, - ), - ); let result = execute_fn(task, &mut local_transcript); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 9d595d0a5..d8bab6661 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -42,7 +42,7 @@ use sumcheck::{ structs::{IOPProof, IOPVerifierState}, util::get_challenge_pows, }; -use transcript::{ForkableTranscript, Transcript}; +use transcript::{BasicTranscript, ForkableTranscript, Transcript}; use witness::next_pow2_instance_padding; pub struct ZKVMVerifier> { @@ -248,15 +248,6 @@ impl> ZKVMVerifier PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?; } - // write (circuit_idx, num_instance) to transcript - for (circuit_idx, proofs) in vm_proof.chip_proofs.iter() { - transcript.append_field_element(&E::BaseField::from_canonical_u32(*circuit_idx as u32)); - // length of proof.num_instances will be constrained in verify_chip_proof - for num_instance in proofs.iter().flat_map(|proof| &proof.num_instances) { - transcript.append_field_element(&E::BaseField::from_canonical_usize(*num_instance)); - } - } - // write witin commitment to transcript PCS::write_commitment(&vm_proof.witin_commit, &mut transcript) .map_err(ZKVMError::PCSError)?; @@ -306,16 +297,17 @@ impl> ZKVMVerifier } // fork transcript to support chip concurrently proved - let mut forked_transcripts = transcript.fork(num_proofs); - for ((index, proof), transcript) in vm_proof + let mut forked_transcripts = vec![BasicTranscript::new(b"fork"); num_proofs]; + for (index, ((circuit_index, proof), transcript)) in vm_proof .chip_proofs .iter() .flat_map(|(index, proofs)| iter::repeat_n(index, proofs.len()).zip(proofs)) .zip_eq(forked_transcripts.iter_mut()) + .enumerate() { let num_instance: usize = proof.num_instances.iter().sum(); assert!(num_instance > 0); - let circuit_name = &self.vk.circuit_index_to_name[index]; + let circuit_name = &self.vk.circuit_index_to_name[circuit_index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; if proof.r_out_evals.len() != circuit_vk.get_cs().num_reads() @@ -352,7 +344,14 @@ impl> ZKVMVerifier }) .sum::(); - transcript.append_field_element(&E::BaseField::from_canonical_u64(*index as u64)); + transcript.append_field_element_ext(&challenges[0]); + transcript.append_field_element_ext(&challenges[1]); + transcript.append_field_element(&E::BaseField::from_canonical_usize(index)); + transcript + .append_field_element(&E::BaseField::from_canonical_u64(*circuit_index as u64)); + for num_instance in &proof.num_instances { + transcript.append_field_element(&E::BaseField::from_canonical_usize(*num_instance)); + } // compute logup_sum padding // getting the number of dummy padding item that we used in this opcode circuit