Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions ceno_recursion/src/transcript/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,32 +54,3 @@ pub fn transcript_check_pow_witness<C: Config>(
builder.assert_eq::<Var<C::N>>(bit, Usize::from(0));
});
}

/// Creates a new `DuplexChallengerVariable` and copies the state of `src` into it.
///
/// The sponge state is copied element by element. The new challenger's
/// `input_ptr` and `output_ptr` are set so that their offsets from the new
/// challenger's `io_empty_ptr` equal the offsets `src` has from its own
/// `io_empty_ptr`.
///
/// `src` is only read. The returned challenger is built by
/// `DuplexChallengerVariable::new`.
// NOTE(review): presumably each challenger owns a separate IO buffer, which is
// why the pointers are re-based instead of copied verbatim — confirm against
// `DuplexChallengerVariable::new`.
pub fn clone_challenger_state<C: Config>(
    builder: &mut Builder<C>,
    src: &DuplexChallengerVariable<C>,
) -> DuplexChallengerVariable<C> {
    let forked = DuplexChallengerVariable::new(builder);

    // Copy the sponge state; the loop bound is the length of the new state.
    let state_len = forked.sponge_state.len();
    builder.range(0, state_len).for_each(|cursor, builder| {
        let slot = cursor[0];
        let element = builder.get(&src.sponge_state, slot);
        builder.set(&forked.sponge_state, slot, element);
    });

    // Carry over the input pointer as an offset from `io_empty_ptr`.
    let input_rel = src.input_ptr - src.io_empty_ptr;
    builder.assign(&forked.input_ptr, input_rel + forked.io_empty_ptr);

    // Same re-basing for the output pointer.
    let output_rel = src.output_ptr - src.io_empty_ptr;
    builder.assign(&forked.output_ptr, output_rel + forked.io_empty_ptr);

    forked
}

/// Observes `index` into `challenger` as a single field element.
///
/// The variable behind `index` is converted with
/// `Builder::unsafe_cast_var_to_felt` and then passed to `challenger.observe`.
// NOTE(review): the cast is named "unsafe" by the builder API; presumably the
// caller must ensure the index fits in the base field — confirm.
pub fn challenger_add_forked_index<C: Config>(
    builder: &mut Builder<C>,
    challenger: &mut DuplexChallengerVariable<C>,
    index: &Usize<C::N>,
) {
    let index_var = index.get_var();
    let index_felt = builder.unsafe_cast_var_to_felt(index_var);
    challenger.observe(builder, index_felt);
}
34 changes: 14 additions & 20 deletions ceno_recursion/src/zkvm_verifier/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ use crate::{
use ceno_zkvm::structs::{ComposedConstrainSystem, VerifyingKey, ZKVMVerifyingKey};
use ff_ext::BabyBearExt4;

use crate::transcript::{challenger_add_forked_index, clone_challenger_state};
use gkr_iop::{
evaluation::EvalExpression,
gkr::{
Expand Down Expand Up @@ -153,22 +152,6 @@ pub fn verify_zkvm_proof<C: Config<F = F>>(
challenger.observe(builder, log2_max_codeword_size_felt);
}

iter_zip!(builder, zkvm_proof_input.chip_proofs).for_each(|ptr_vec, builder| {
let chip_proofs = builder.iter_ptr_get(&zkvm_proof_input.chip_proofs, ptr_vec[0]);
let chip_idx = builder.get(&chip_proofs, 0).idx_felt;
challenger.observe(builder, chip_idx);

iter_zip!(builder, chip_proofs).for_each(|ptr_vec, builder| {
let chip_proof = builder.iter_ptr_get(&chip_proofs, ptr_vec[0]);

iter_zip!(builder, chip_proof.num_instances).for_each(|ptr_vec, builder| {
let num_instance = builder.iter_ptr_get(&chip_proof.num_instances, ptr_vec[0]);
let num_instance = builder.unsafe_cast_var_to_felt(num_instance);
challenger.observe(builder, num_instance);
});
});
});

challenger_multi_observe(
builder,
&mut challenger,
Expand Down Expand Up @@ -252,9 +235,15 @@ pub fn verify_zkvm_proof<C: Config<F = F>>(

iter_zip!(builder, chip_proofs).for_each(|ptr_vec, builder| {
let chip_proof = builder.iter_ptr_get(&chip_proofs, ptr_vec[0]);
// fork transcript to support chip concurrently proved
let mut chip_challenger = clone_challenger_state(builder, &challenger);
challenger_add_forked_index(builder, &mut chip_challenger, &forked_sample_index);
// Fork chip transcript independently and bind challenges/metadata in verifier order.
let mut chip_challenger = DuplexChallengerVariable::new(builder);
transcript_observe_label(builder, &mut chip_challenger, b"fork");
let alpha_felts = builder.ext2felt(alpha);
chip_challenger.observe_slice(builder, alpha_felts);
let beta_felts = builder.ext2felt(beta);
chip_challenger.observe_slice(builder, beta_felts);
let fork_id_felt = builder.unsafe_cast_var_to_felt(forked_sample_index.get_var());
chip_challenger.observe(builder, fork_id_felt);
builder.assert_usize_eq(
chip_proof.rw_out_evals.length.clone(),
Usize::from(
Expand All @@ -266,6 +255,11 @@ pub fn verify_zkvm_proof<C: Config<F = F>>(
Usize::from(circuit_vk.get_cs().num_lks() * 4),
);
chip_challenger.observe(builder, chip_proof.idx_felt);
iter_zip!(builder, chip_proof.num_instances).for_each(|ptr_vec, builder| {
let num_instance = builder.iter_ptr_get(&chip_proof.num_instances, ptr_vec[0]);
let num_instance = builder.unsafe_cast_var_to_felt(num_instance);
chip_challenger.observe(builder, num_instance);
});

// get the number of dummy padding items that we used in this opcode circuit
let num_lks: Var<C::N> =
Expand Down
60 changes: 23 additions & 37 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use sumcheck::{
structs::IOPProverMessage,
};
use tracing::info_span;
use transcript::{ForkableTranscript, Transcript};
use transcript::{BasicTranscript, ForkableTranscript, Transcript};

use super::{PublicValues, ZKVMChipProof, ZKVMProof, hal::ProverDevice};
#[cfg(feature = "gpu")]
Expand Down Expand Up @@ -171,43 +171,12 @@ impl<
}
exit_span!(span);

// only keep track of circuits that have non-zero instances
for (name, chip_inputs) in &witnesses.witnesses {
let pk = self.pk.circuit_pks.get(name).ok_or(ZKVMError::VKNotFound(
format!("proving key for circuit {} not found", name).into(),
))?;

// include omc init tables iff it's in first shard
if !shard_ctx.is_first_shard() && pk.get_cs().with_omc_init_only() {
continue;
}

// num_instance from witness might include rotation
let num_instances = chip_inputs
.iter()
.flat_map(|chip_input| chip_input.num_instances)
.collect_vec();

if num_instances.iter().sum::<usize>() == 0 {
continue;
}

let circuit_idx = self.pk.circuit_name_to_index.get(name).unwrap();
// write (circuit_idx, num_var) to transcript
transcript.append_field_element(&E::BaseField::from_canonical_usize(*circuit_idx));
for num_instance in num_instances {
transcript
.append_field_element(&E::BaseField::from_canonical_usize(num_instance));
}
}

// extract chip meta info before consuming witnesses
// (circuit_name, num_instances)
let name_and_instances = witnesses.get_witnesses_name_instance();

let commit_to_traces_span = entered_span!("batch commit to traces", profiling_1 = true);
let mut wits_rmms = BTreeMap::new();

// Extract chip metadata before consuming witnesses.
// We reuse this for both transcript appends and task construction.
let name_and_instances = witnesses.get_witnesses_name_instance();
let mut structural_rmms = Vec::with_capacity(name_and_instances.len());
// commit to opcode circuits first and then commit to table circuits, sorted by name
for (i, chip_input) in witnesses.into_iter_sorted().enumerate() {
Expand Down Expand Up @@ -261,7 +230,6 @@ impl<
transcript.read_challenge().elements,
];
tracing::debug!("global challenges in prover: {:?}", challenges);

let main_proofs_span = entered_span!("main_proofs", profiling_1 = true);

// Phase 1: Build all ChipTasks
Expand All @@ -283,8 +251,9 @@ impl<
// GPU concurrent: memory-aware backfilling with standalone impl.
// Sequential (GPU + CPU): unified path via self.create_chip_proof.
let execute_tasks_span = entered_span!("execute_chip_tasks", profiling_1 = true);
let fork_transcript = BasicTranscript::<E>::new(b"fork");
let (results, forked_samples) =
self.run_chip_proofs(tasks, &transcript, &witness_data)?;
self.run_chip_proofs(tasks, &fork_transcript, &witness_data)?;
exit_span!(execute_tasks_span);

// Phase 3: Collect results
Expand Down Expand Up @@ -352,10 +321,20 @@ impl<
let gpu_wd = SyncRef(gpu_witness_data);

return scheduler.execute(tasks, transcript, |task, transcript| {
// Bind global challenges and metadata in the same order as verifier.
transcript.append_field_element_ext(&task.challenges[0]);
transcript.append_field_element_ext(&task.challenges[1]);
transcript
.append_field_element(&E::BaseField::from_canonical_usize(task.task_id));
// Append circuit_idx to per-task forked transcript (matching verifier)
transcript.append_field_element(&E::BaseField::from_canonical_u64(
task.circuit_idx as u64,
));
for num_instance in task.input.num_instances {
transcript.append_field_element(&E::BaseField::from_canonical_usize(
num_instance,
));
}

// SAFETY: TypeId check above (before closure) guarantees PB = GpuBackend<E, PCS>.
let gpu_input: ProofInput<'static, gkr_iop::gpu::GpuBackend<E, PCS>> =
Expand Down Expand Up @@ -389,9 +368,16 @@ impl<
// Sequential path (GPU + CPU unified):
// Uses execute_sequentially directly to avoid Send+Sync requirement on the closure.
scheduler.execute_sequentially(tasks, transcript, |mut task, transcript| {
// Bind global challenges and metadata in the same order as verifier.
transcript.append_field_element_ext(&task.challenges[0]);
transcript.append_field_element_ext(&task.challenges[1]);
transcript.append_field_element(&E::BaseField::from_canonical_usize(task.task_id));
// Append circuit_idx to per-task forked transcript (matching verifier)
transcript
.append_field_element(&E::BaseField::from_canonical_u64(task.circuit_idx as u64));
for num_instance in task.input.num_instances {
transcript.append_field_element(&E::BaseField::from_canonical_usize(num_instance));
}

// Prepare: deferred extraction for GPU, no-op for CPU
self.device.prepare_chip_input(&mut task, witness_data);
Expand Down
25 changes: 5 additions & 20 deletions ceno_zkvm/src/scheme/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use crate::{
use ff_ext::ExtensionField;
use gkr_iop::hal::ProverBackend;
use mpcs::Point;
use p3::field::FieldAlgebra;
use std::sync::OnceLock;
use transcript::Transcript;
static CHIP_PROVING_MODE: OnceLock<ChipProvingMode> = OnceLock::new();
Expand Down Expand Up @@ -152,8 +151,8 @@ impl ChipScheduler {

/// Execute tasks sequentially with automatic transcript forking and sampling.
///
/// Each task gets a transcript cloned from `parent_transcript` with `task_id`
/// appended (identical to `ForkableTranscript::fork` default impl).
/// Each task gets a transcript cloned from `parent_transcript`.
/// Task-specific transcript appends are performed by the task closure.
/// Returns `(results, forked_samples)` both sorted by task_id.
#[allow(clippy::type_complexity)]
pub(crate) fn execute_sequentially<'a, PB, T, F>(
Expand Down Expand Up @@ -186,12 +185,8 @@ impl ChipScheduler {

for task in tasks {
let task_id = task.task_id;
// Fork: clone parent + append task_id
// (identical to ForkableTranscript::fork default impl)
// Fork: clone parent transcript template.
let mut forked = parent_transcript.clone();
forked.append_field_element(&<PB::E as ExtensionField>::BaseField::from_canonical_u64(
task_id as u64,
));

let result = execute_task(task, &mut forked)?;
results.push(result);
Expand All @@ -213,8 +208,7 @@ impl ChipScheduler {
/// Tasks are sorted by memory requirement (descending) and scheduled to
/// maximize GPU utilization while respecting memory constraints.
///
/// Each worker thread clones the parent `transcript` and appends its task_id
/// (reproducing `ForkableTranscript::fork` locally). After proving, the worker
/// Each worker thread clones the parent `transcript`. After proving, the worker
/// samples one extension-field element from its local transcript and returns it.
/// This avoids sending non-`Send` transcript objects across threads.
///
Expand Down Expand Up @@ -250,9 +244,6 @@ impl ChipScheduler {
if tasks.len() == 1 {
let task = tasks.remove(0);
let mut fork = transcript.clone();
fork.append_field_element(&<PB::E as ExtensionField>::BaseField::from_canonical_u64(
task.task_id as u64,
));
let result = execute_task(task, &mut fork)?;
let sample = fork.sample_vec(1)[0];
return Ok((vec![result], vec![sample]));
Expand Down Expand Up @@ -347,14 +338,8 @@ impl ChipScheduler {
// waiting for a CompletionMessage that never arrives).
let outcome =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
// Fork locally: clone parent transcript + append task_id
// (identical to ForkableTranscript::fork default impl)
// Fork locally: clone parent transcript template.
let mut local_transcript = tr.0.clone();
local_transcript.append_field_element(
&<PB::E as ExtensionField>::BaseField::from_canonical_u64(
task_id as u64,
),
);

let result = execute_fn(task, &mut local_transcript);

Expand Down
27 changes: 13 additions & 14 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ use sumcheck::{
structs::{IOPProof, IOPVerifierState},
util::get_challenge_pows,
};
use transcript::{ForkableTranscript, Transcript};
use transcript::{BasicTranscript, ForkableTranscript, Transcript};
use witness::next_pow2_instance_padding;

pub struct ZKVMVerifier<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> {
Expand Down Expand Up @@ -248,15 +248,6 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?;
}

// write (circuit_idx, num_instance) to transcript
for (circuit_idx, proofs) in vm_proof.chip_proofs.iter() {
transcript.append_field_element(&E::BaseField::from_canonical_u32(*circuit_idx as u32));
// length of proof.num_instances will be constrained in verify_chip_proof
for num_instance in proofs.iter().flat_map(|proof| &proof.num_instances) {
transcript.append_field_element(&E::BaseField::from_canonical_usize(*num_instance));
}
}

// write witin commitment to transcript
PCS::write_commitment(&vm_proof.witin_commit, &mut transcript)
.map_err(ZKVMError::PCSError)?;
Expand Down Expand Up @@ -306,16 +297,17 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
}

// fork the transcript so that chips can be proved concurrently
let mut forked_transcripts = transcript.fork(num_proofs);
for ((index, proof), transcript) in vm_proof
let mut forked_transcripts = vec![BasicTranscript::new(b"fork"); num_proofs];
for (index, ((circuit_index, proof), transcript)) in vm_proof
.chip_proofs
.iter()
.flat_map(|(index, proofs)| iter::repeat_n(index, proofs.len()).zip(proofs))
.zip_eq(forked_transcripts.iter_mut())
.enumerate()
{
let num_instance: usize = proof.num_instances.iter().sum();
assert!(num_instance > 0);
let circuit_name = &self.vk.circuit_index_to_name[index];
let circuit_name = &self.vk.circuit_index_to_name[circuit_index];
let circuit_vk = &self.vk.circuit_vks[circuit_name];

if proof.r_out_evals.len() != circuit_vk.get_cs().num_reads()
Expand Down Expand Up @@ -352,7 +344,14 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
})
.sum::<E>();

transcript.append_field_element(&E::BaseField::from_canonical_u64(*index as u64));
transcript.append_field_element_ext(&challenges[0]);
transcript.append_field_element_ext(&challenges[1]);
transcript.append_field_element(&E::BaseField::from_canonical_usize(index));
transcript
.append_field_element(&E::BaseField::from_canonical_u64(*circuit_index as u64));
for num_instance in &proof.num_instances {
transcript.append_field_element(&E::BaseField::from_canonical_usize(*num_instance));
}

// compute logup_sum padding
// getting the number of dummy padding item that we used in this opcode circuit
Expand Down
Loading