diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 16e7eddb0..2cb2c3d3a 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -1,15 +1,15 @@ +use crate::addr::{Addr, RegIdx}; use core::fmt::{self, Formatter}; use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; use std::{collections::BTreeSet, fmt::Display, ops::Range, sync::Arc}; -use crate::addr::{Addr, RegIdx}; - /// The Platform struct holds the parameters of the VM. /// It defines: /// - the layout of virtual memory, /// - special addresses, such as the initial PC, /// - codes of environment calls. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Platform { pub rom: Range, pub prog_data: Arc>, @@ -55,51 +55,43 @@ impl Display for Platform { } } -/// alined with [`memory.x`] -// ┌───────────────────────────── 0x4000_0000 (end of _sheap, or heap) +/// aligned with [`memory.x`] +// ┌───────────────────────────── 0x4000_0000 (stack top) // │ -// │ HEAP (128 MB, grows upward) -// │ 0x3800_0000 .. 0x4000_0000 +// │ STACK (256 MB window, grows downward) +// │ 0x3000_0000 .. 0x4000_0000 // │ -// ├───────────────────────────── 0x3800_0000 (_sheap, align 0x800_0000) -// │ RAM (128 MB) -// │ 0x3000_0000 .. 0x3800_0000 -// ├───────────────────────────── 0x3000_0000 (RAM base / hints end) +// ├───────────────────────────── 0x3000_0000 (stack base / hints end) // │ // │ HINTS (128 MB) // │ 0x2800_0000 .. 0x3000_0000 // │ -// │───────────────────────────── 0x2800_0000 (hint base / gap end) +// │───────────────────────────── 0x2800_0000 (hint start / gap end) // │ // │ [Reserved gap: 128 MB for debug I/O] // │ 0x2000_0000 .. 0x2800_0000 -// │───────────────────────────── 0x2000_0000 (gap / stack end) +// │───────────────────────────── 0x2000_0000 (gap / heap end) // │ -// │ STACK (≈128 MB, grows downward) +// │ HEAP (128 MB, grows upward) // │ 0x1800_0000 .. 0x2000_0000 // │ -// ├───────────────────────────── 0x1800_0000 (stack base) +// ├───────────────────────────── 0x1800_0000 (heap midpoint) // │ -// │ STACK (stack-only memory window, includes reserved low-address area -// │ previously used for PUBLIC I/O) +// │ RAM / DATA / BSS / HEAP // │ 0x1000_0000 .. 0x2000_0000 // │ -// ├───────────────────────────── 0x1000_0000 (stack base / rom end) +// ├───────────────────────────── 0x1000_0000 (ram base / rom end) // │ // │ ROM / TEXT / RODATA (128 MB) // │ 0x0800_0000 .. 0x1000_0000 // │ -// └───────────────────────────── 0x8000_0000 (rom base) +// └───────────────────────────── 0x0800_0000 (rom base) pub static CENO_PLATFORM: Lazy = Lazy::new(|| Platform { rom: 0x0800_0000..0x1000_0000, // 128 MB - stack: 0x1000_0000..0x2000_4000, // stack grows downward, 0x4000 reserved for debug io. - // we make hints start from 0x2800_0000 thus reserve a 128MB gap for debug io - // at the end of stack + stack: 0x3000_0000..0x4000_4000, // stack grows downward, 0x4000 reserved for debug io. hints: 0x2800_0000..0x3000_0000, // 128 MB - // heap grows upward, reserved 128 MB for it - // the beginning of heap address got bss/sbss data - // and the real heap start from 0x3800_0000 - heap: 0x3000_0000..0x4000_0000, + // heap grows upward in the low RAM window; .data/.bss live at its beginning. + heap: 0x1000_0000..0x2000_0000, unsafe_ecall_nop: false, prog_data: Arc::new(BTreeSet::new()), is_debug: false, diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 45821ae64..9b68dff95 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -194,7 +194,7 @@ impl LatestAccesses { Self { store: DenseAddrSpace::new( WordAddr::from(0u32), - ByteAddr::from(platform.heap.end).waddr(), + ByteAddr::from(platform.stack.end).waddr(), ), len: 0, #[cfg(any(test, debug_assertions))] diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index d65844682..6fd02ce4e 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -66,7 +66,14 @@ impl VMState { program: program.clone(), memory: DenseAddrSpace::new( ByteAddr::from(platform.rom.start).waddr(), - ByteAddr::from(platform.heap.end).waddr(), + ByteAddr::from( + platform + .stack + .end + .max(platform.heap.end) + .max(platform.hints.end), + ) + .waddr(), ), registers: [0; VM_REG_COUNT], halt_state: None, diff --git a/ceno_recursion/src/aggregation/mod.rs b/ceno_recursion/src/aggregation/mod.rs index 8d3706f8d..68ed3fd2a 100644 --- a/ceno_recursion/src/aggregation/mod.rs +++ b/ceno_recursion/src/aggregation/mod.rs @@ -2,8 +2,12 @@ use crate::zkvm_verifier::{ binding::{E, F, ZKVMProofInput, ZKVMProofInputVariable}, verifier::verify_zkvm_proof, }; +use ceno_emul::WORD_SIZE; use ceno_zkvm::{ - instructions::riscv::constants::{END_PC_IDX, EXIT_CODE_IDX, INIT_PC_IDX}, + instructions::riscv::constants::{ + END_PC_IDX, EXIT_CODE_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, HINT_LENGTH_IDX, + HINT_START_ADDR_IDX, INIT_PC_IDX, + }, scheme::ZKVMProof, structs::ZKVMVerifyingKey, }; @@ -56,7 +60,7 @@ use openvm_stark_sdk::{ openvm_stark_backend::keygen::types::MultiStarkVerifyingKey, p3_bn254_fr::Bn254Fr, }; -use p3::field::FieldAlgebra; +use p3::field::{FieldAlgebra, PrimeField32}; use serde::{Deserialize, Serialize}; use std::{borrow::Borrow, sync::Arc, time::Instant}; pub type RecPcs = Basefold; @@ -387,6 +391,49 @@ pub struct CenoLeafVmVerifierConfig { } impl CenoLeafVmVerifierConfig { + fn assert_felt_lt>( + builder: &mut Builder, + lhs: Felt, + rhs: Felt, + max_bits: u32, + ) { + Self::check_felt_lt(builder, lhs, rhs, max_bits, true) + } + + fn assert_felt_ge>( + builder: &mut Builder, + lhs: Felt, + rhs: Felt, + max_bits: u32, + ) { + Self::check_felt_lt(builder, lhs, rhs, max_bits, false) + } + + fn assert_felt_range>( + builder: &mut Builder, + value: Felt, + start: Felt, + end: Felt, + max_bits: u32, + ) { + Self::assert_felt_ge(builder, value, start, max_bits); + Self::assert_felt_lt(builder, value, end, max_bits); + } + + fn check_felt_lt>( + builder: &mut Builder, + lhs: Felt, + rhs: Felt, + max_bits: u32, + is_lt: bool, + ) { + let range: Felt<_> = builder.constant(C::F::from_canonical_u64(1u64 << max_bits)); + let zero = builder.constant(F::ZERO); + let diff = builder.eval(lhs - rhs + if is_lt { range } else { zero }); + let diff = builder.cast_felt_to_var(diff); + builder.range_check_var(diff, max_bits as usize); + } + pub fn build_program(&self) -> Program { let mut builder = Builder::::default(); @@ -436,6 +483,97 @@ impl CenoLeafVmVerifierConfig { builder.assign(&stark_pvs.public_values_commit[i], F::ZERO); } + assert!( + 2 * self.vk.mem_state_verifier.heap.end < F::ORDER_U32, + "2 * {:x} >= {}", + self.vk.mem_state_verifier.heap.end, + F::ORDER_U32 + ); + assert!( + 2 * self.vk.mem_state_verifier.hints.end < F::ORDER_U32, + "2 * {:x} >= {}", + self.vk.mem_state_verifier.hints.end, + F::ORDER_U32 + ); + fn bits_needed(x: u32) -> u32 { + if x == 0 { 1 } else { 32 - x.leading_zeros() } + } + let heap_max_bits = bits_needed( + self.vk.mem_state_verifier.heap.end - self.vk.mem_state_verifier.heap.start, + ); + let hint_max_bits = bits_needed( + self.vk.mem_state_verifier.hints.end - self.vk.mem_state_verifier.hints.start, + ); + let heap_min_start_addr = { + let v = builder.eval(Usize::from(self.vk.mem_state_verifier.heap.start as usize)); + builder.unsafe_cast_var_to_felt(v) + }; + let heap_max_end_addr = { + let v = builder.eval(Usize::from(self.vk.mem_state_verifier.heap.end as usize)); + builder.unsafe_cast_var_to_felt(v) + }; + let heap_max_addr_diff = { + let v = builder.eval(Usize::from( + (self.vk.mem_state_verifier.heap.end - self.vk.mem_state_verifier.heap.start) + as usize, + )); + builder.unsafe_cast_var_to_felt(v) + }; + let hint_min_start_addr = { + let v = builder.eval(Usize::from(self.vk.mem_state_verifier.hints.start as usize)); + builder.unsafe_cast_var_to_felt(v) + }; + let hint_max_end_addr = { + let v = builder.eval(Usize::from(self.vk.mem_state_verifier.hints.end as usize)); + builder.unsafe_cast_var_to_felt(v) + }; + let hint_max_addr_diff = { + let v = builder.eval(Usize::from( + (self.vk.mem_state_verifier.hints.end - self.vk.mem_state_verifier.hints.start) + as usize, + )); + builder.unsafe_cast_var_to_felt(v) + }; + + let heap_start_addr = builder.get(pv, HEAP_START_ADDR_IDX); + let heap_length_words = builder.get(pv, HEAP_LENGTH_IDX); + let word_size = builder.constant::>(F::from_canonical_usize(WORD_SIZE)); + let heap_length = builder.eval(heap_length_words * word_size); + let heap_end_addr = builder.eval(heap_start_addr + heap_length); + let hint_start_addr = builder.get(pv, HINT_START_ADDR_IDX); + let hint_length_words = builder.get(pv, HINT_LENGTH_IDX); + let hint_length = builder.eval(hint_length_words * word_size); + let hint_end_addr = builder.eval(hint_start_addr + hint_length); + + Self::assert_felt_range( + &mut builder, + heap_start_addr, + heap_min_start_addr, + heap_max_end_addr, + heap_max_bits, + ); + Self::assert_felt_lt( + &mut builder, + heap_end_addr, + heap_max_end_addr, + heap_max_bits, + ); + Self::assert_felt_lt(&mut builder, heap_length, heap_max_addr_diff, heap_max_bits); + Self::assert_felt_range( + &mut builder, + hint_start_addr, + hint_min_start_addr, + hint_max_end_addr, + hint_max_bits, + ); + Self::assert_felt_lt( + &mut builder, + hint_end_addr, + hint_max_end_addr, + hint_max_bits, + ); + Self::assert_felt_lt(&mut builder, hint_length, hint_max_addr_diff, hint_max_bits); + // TODO: assign shard_ec_sum to stark_pvs.shard_ec_sum // builder diff --git a/ceno_rt/ceno_link.x b/ceno_rt/ceno_link.x index 0aa7dd928..dba22c194 100644 --- a/ceno_rt/ceno_link.x +++ b/ceno_rt/ceno_link.x @@ -21,11 +21,10 @@ SECTIONS *(.rodata .rodata.*); } > ROM - .stack (NOLOAD) : ALIGN(4) { *(.stack .stack.*) - } > STACK_PUBIO + } > STACK /* Define a section for runtime-populated EEPROM-like HINTS data */ .hints (NOLOAD) : ALIGN(4) @@ -55,4 +54,5 @@ SECTIONS . = ALIGN(0x8000000); _sheap = .; } > RAM + } diff --git a/ceno_rt/memory.x b/ceno_rt/memory.x index d42358be7..28f538edc 100644 --- a/ceno_rt/memory.x +++ b/ceno_rt/memory.x @@ -1,15 +1,15 @@ MEMORY { ROM (rx) : ORIGIN = 0x08000000, LENGTH = 128M - STACK_PUBIO (rw) : ORIGIN = 0x10000000, LENGTH = 256M /* Stack region */ + RAM (rw) : ORIGIN = 0x10000000, LENGTH = 256M /* heap/data/bss */ HINTS (r) : ORIGIN = 0x20000000, LENGTH = 256M /* will shift hint to 0x28000000 with 128M to reserve gap*/ - RAM (rw) : ORIGIN = 0x30000000, LENGTH = 256M /* heap/data/bss */ + STACK (rw) : ORIGIN = 0x30000000, LENGTH = 256M /* stack-only region */ } REGION_ALIAS("REGION_TEXT", ROM); REGION_ALIAS("REGION_RODATA", ROM); -REGION_ALIAS("REGION_STACK", STACK_PUBIO); +REGION_ALIAS("REGION_STACK", STACK); REGION_ALIAS("REGION_HINTS", HINTS); REGION_ALIAS("REGION_DATA", RAM); diff --git a/ceno_rt/src/params.rs b/ceno_rt/src/params.rs index 36a4fef05..ed7b38fee 100644 --- a/ceno_rt/src/params.rs +++ b/ceno_rt/src/params.rs @@ -1,4 +1,4 @@ pub const WORD_SIZE: usize = 4; /// address defined in `memory.x` under RAM section. -pub const INFO_OUT_ADDR: u32 = 0x2000_0000; +pub const INFO_OUT_ADDR: u32 = 0x4000_0000; diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index 78afe8434..e86f341e2 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -13,7 +13,10 @@ use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; -use ceno_zkvm::{e2e::MultiProver, scheme::verifier::ZKVMVerifier}; +use ceno_zkvm::{ + e2e::MultiProver, + scheme::verifier::{RV32imMemStateConfig, ZKVMVerifier}, +}; use mpcs::BasefoldDefault; use transcript::BasicTranscript; @@ -69,7 +72,7 @@ fn fibonacci_prove(c: &mut Criterion) { println!("e2e proof {}", proof); let transcript = BasicTranscript::new(b"riscv"); - let verifier = ZKVMVerifier::::new(vk); + let verifier = ZKVMVerifier::::new(vk); assert!( verifier .verify_full_trace_proofs_halt(vec![proof], vec![transcript], false) diff --git a/ceno_zkvm/benches/keccak.rs b/ceno_zkvm/benches/keccak.rs index 07ca96f37..33bf1b3c2 100644 --- a/ceno_zkvm/benches/keccak.rs +++ b/ceno_zkvm/benches/keccak.rs @@ -8,7 +8,10 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; -use ceno_zkvm::{e2e::MultiProver, scheme::verifier::ZKVMVerifier}; +use ceno_zkvm::{ + e2e::MultiProver, + scheme::verifier::{RV32imMemStateConfig, ZKVMVerifier}, +}; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; @@ -66,7 +69,7 @@ fn keccak_prove(c: &mut Criterion) { println!("e2e proof {}", proof); let transcript = BasicTranscript::new(b"riscv"); - let verifier = ZKVMVerifier::::new(vk); + let verifier = ZKVMVerifier::::new(vk); assert!( verifier .verify_full_trace_proofs_halt(vec![proof], vec![transcript], true) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index dfa52b333..6b1a18e6d 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1600,8 +1600,8 @@ impl E2EProgramCtx { self.zkvm_fixed_traces.clone(), ) .expect("keygen failed"); - let vk = pk.get_vk_slow(); pk.set_program_ctx(self); + let vk = pk.get_vk_slow(); (pk, vk) } @@ -1623,8 +1623,8 @@ impl E2EProgramCtx { self.zkvm_fixed_traces.clone(), ) .expect("keygen failed"); - let vk = pk.get_vk_slow(); pk.set_program_ctx(self); + let vk = pk.get_vk_slow(); (pk, vk) } @@ -2073,15 +2073,19 @@ pub fn run_e2e_single_shard_debug_verify, max_steps: usize, ) { - let expect_halt = zkvm_proof.has_halt(&verifier.vk) || exit_code.is_some(); + let expect_halt = zkvm_proof.has_halt(&verifier.vk); let verified = verifier .verify_single_shard_segment_halt(zkvm_proof, Transcript::new(b"riscv"), expect_halt) .expect("verify proof return with error"); assert!(verified); - match exit_code { - Some(0) => tracing::info!("exit code 0. Success."), - Some(code) => tracing::error!("exit code {}. Failure.", code), - None => tracing::error!("Unfinished execution. max_steps={:?}.", max_steps), + if expect_halt { + match exit_code { + Some(0) => tracing::info!("exit code 0. Success."), + Some(code) => tracing::error!("exit code {}. Failure.", code), + None => tracing::error!("Unfinished execution. max_steps={:?}.", max_steps), + } + } else { + tracing::info!("single shard segment verified without full-trace continuation checks"); } } diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 3b10a1145..a1f172e20 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -12,7 +12,8 @@ use super::{PublicValues, ZKVMChipProof, ZKVMProof}; use crate::{ error::ZKVMError, instructions::riscv::constants::{ - END_PC_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, + END_PC_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, HINT_LENGTH_IDX, HINT_START_ADDR_IDX, + INIT_CYCLE_IDX, INIT_PC_IDX, }, scheme::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, @@ -46,11 +47,87 @@ use sumcheck::{ use transcript::{ForkableTranscript, Transcript}; use witness::next_pow2_instance_padding; -pub struct ZKVMVerifier> { - pub vk: ZKVMVerifyingKey, +pub use crate::structs::RV32imMemStateConfig; + +pub struct ZKVMVerifier< + E: ExtensionField, + PCS: PolynomialCommitmentScheme, + M = RV32imMemStateConfig, +> where + M: Clone + Default + serde::Serialize + serde::de::DeserializeOwned, +{ + pub vk: ZKVMVerifyingKey, } -impl> ZKVMVerifier { +impl, M> ZKVMVerifier +where + M: Clone + Default + serde::Serialize + serde::de::DeserializeOwned, +{ + pub fn new(vk: ZKVMVerifyingKey) -> Self { + ZKVMVerifier { vk } + } + + pub fn into_inner(self) -> ZKVMVerifyingKey { + self.vk + } +} + +impl> + ZKVMVerifier +{ + fn validate_mem_state( + mem_state: &RV32imMemStateConfig, + prev_heap_addr_end: Option, + prev_hint_addr_end: Option, + vm_proof: &ZKVMProof, + ) -> Result<(u32, u32), ZKVMError> { + let heap_addr_start = vm_proof + .public_values + .query_by_index::(HEAP_START_ADDR_IDX) + .to_canonical_u64() as u32; + let heap_len = vm_proof + .public_values + .query_by_index::(HEAP_LENGTH_IDX) + .to_canonical_u64() as u32; + let next_heap_addr_end = heap_addr_start + heap_len * WORD_SIZE as u32; + if !mem_state.heap.contains(&heap_addr_start) + || !mem_state.heap.contains(&next_heap_addr_end) + { + return Err(ZKVMError::VerifyError( + "heap continuation out of range".into(), + )); + } + if let Some(prev_heap_addr_end) = prev_heap_addr_end + && heap_addr_start != prev_heap_addr_end + { + return Err(ZKVMError::VerifyError("heap continuation mismatch".into())); + } + + let hint_addr_start = vm_proof + .public_values + .query_by_index::(HINT_START_ADDR_IDX) + .to_canonical_u64() as u32; + let hint_len = vm_proof + .public_values + .query_by_index::(HINT_LENGTH_IDX) + .to_canonical_u64() as u32; + let next_hint_addr_end = hint_addr_start + hint_len * WORD_SIZE as u32; + if !mem_state.hints.contains(&hint_addr_start) + || !mem_state.hints.contains(&next_hint_addr_end) + { + return Err(ZKVMError::VerifyError( + "hint continuation out of range".into(), + )); + } + if let Some(prev_hint_addr_end) = prev_hint_addr_end + && hint_addr_start != prev_hint_addr_end + { + return Err(ZKVMError::VerifyError("hint continuation mismatch".into())); + } + + Ok((next_heap_addr_end, next_hint_addr_end)) + } + #[allow(clippy::type_complexity)] fn split_input_opening_evals( circuit_vk: &VerifyingKey, @@ -84,14 +161,6 @@ impl> ZKVMVerifier Ok((wits_in_evals, fixed_in_evals)) } - pub fn new(vk: ZKVMVerifyingKey) -> Self { - ZKVMVerifier { vk } - } - - pub fn into_inner(self) -> ZKVMVerifyingKey { - self.vk - } - /// Verify a full zkVM trace from program entry to halt. /// /// This is the production verifier API. It treats a single proof as a @@ -156,13 +225,13 @@ impl> ZKVMVerifier ) -> Result { assert!(!vm_proofs.is_empty()); let num_proofs = vm_proofs.len(); - let (_end_pc, _end_heap_addr, shard_ec_sum) = vm_proofs + let (_end_pc, _end_heap_addr, _end_hint_addr, shard_ec_sum) = vm_proofs .into_iter() .zip_eq(transcripts) // optionally halt on last chunk .zip_eq(iter::repeat_n(false, num_proofs - 1).chain(iter::once(expect_halt))) .enumerate() - .try_fold((None, None, SepticPoint::::default()), |(prev_pc, prev_heap_addr_end, mut shard_ec_sum), (shard_id, ((vm_proof, transcript), expect_halt))| { + .try_fold((None, None, None, SepticPoint::::default()), |(prev_pc, prev_heap_addr_end, prev_hint_addr_end, mut shard_ec_sum), (shard_id, ((vm_proof, transcript), expect_halt))| { // require ecall/halt proof to exist, depend on whether we expect a halt. let has_halt = vm_proof.has_halt(&self.vk); if has_halt != expect_halt { @@ -194,29 +263,23 @@ impl> ZKVMVerifier } let end_pc = vm_proof.public_values.query_by_index::(END_PC_IDX); - // check memory continuation consistency - let heap_addr_start_u32 = vm_proof - .public_values - .query_by_index::(HEAP_START_ADDR_IDX) - .to_canonical_u64() as u32; - let heap_len = vm_proof - .public_values - .query_by_index::(HEAP_LENGTH_IDX) - .to_canonical_u64() as u32; - if let Some(prev_heap_addr_end) = prev_heap_addr_end { - assert_eq!(heap_addr_start_u32, prev_heap_addr_end); - // TODO check heap addr in prime field within range - } else { - // TODO first chunk, check initial heap addr - }; - // TODO check heap_len == heap chip num_instances - let next_heap_addr_end: u32 = heap_addr_start_u32 + heap_len * WORD_SIZE as u32; + let (next_heap_addr_end, next_hint_addr_end) = Self::validate_mem_state( + &self.vk.mem_state_verifier, + prev_heap_addr_end, + prev_hint_addr_end, + &vm_proof, + )?; // add to shard ec sum let shard_ec = self.verify_proof_validity(shard_id, vm_proof, transcript)?; shard_ec_sum = shard_ec_sum + shard_ec; - Ok((Some(end_pc), Some(next_heap_addr_end), shard_ec_sum)) + Ok(( + Some(end_pc), + Some(next_heap_addr_end), + Some(next_hint_addr_end), + shard_ec_sum, + )) })?; // TODO check _end_heap_addr within heap range from vk // check shard ec_sum is_infinity @@ -336,6 +399,11 @@ impl> ZKVMVerifier .into(), )); } + if circuit_vk.get_cs().with_omc_init_dyn() && proofs.len() > 1 { + return Err(ZKVMError::InvalidProof( + format!("{shard_id}th shard got > 1 dynamic table init").into(), + )); + } num_proofs += proofs.len(); } @@ -352,6 +420,29 @@ impl> ZKVMVerifier let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; + if circuit_name == "HeapTable" { + let heap_len = vm_proof + .public_values + .query_by_index::(HEAP_LENGTH_IDX) + .to_canonical_u64() as usize; + if num_instance != heap_len { + return Err(ZKVMError::InvalidProof( + format!("heap shard length mismatch: proof {num_instance} != public value {heap_len}").into(), + )); + } + } + if circuit_name == "HintsTable" { + let hint_len = vm_proof + .public_values + .query_by_index::(HINT_LENGTH_IDX) + .to_canonical_u64() as usize; + if num_instance != hint_len { + return Err(ZKVMError::InvalidProof( + format!("hint shard length mismatch: proof {num_instance} != public value {hint_len}").into(), + )); + } + } + if proof.r_out_evals.len() != circuit_vk.get_cs().num_reads() || proof.w_out_evals.len() != circuit_vk.get_cs().num_writes() { diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 31540e3bc..3c856b0e7 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -11,7 +11,10 @@ use crate::{ }; use ceno_emul::{Addr, CENO_PLATFORM, Platform, RegIdx, StepIndex, StepRecord, WordAddr}; use ff_ext::{ExtensionField, PoseidonField}; -use gkr_iop::{gkr::GKRCircuit, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; +use gkr_iop::{ + circuit_builder::ShardOMCInitType, gkr::GKRCircuit, tables::LookupTable, + utils::lk_multiplicity::Multiplicity, +}; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::Instance; @@ -30,6 +33,33 @@ use sumcheck::structs::{IOPProof, IOPProverMessage}; use tracing::Level; use witness::RowMajorMatrix; +#[derive(Clone, Default, Serialize, Deserialize)] +pub struct RV32imMemStateConfig { + pub heap: Range, + pub hints: Range, +} + +impl RV32imMemStateConfig { + pub fn from_platform(platform: &Platform) -> Self { + Self { + heap: platform.heap.start..platform.heap.end, + hints: platform.hints.start..platform.hints.end, + } + } +} + +impl From for RV32imMemStateConfig { + fn from(platform: Platform) -> Self { + Self::from_platform(&platform) + } +} + +impl From<&Platform> for RV32imMemStateConfig { + fn from(platform: &Platform) -> Self { + Self::from_platform(platform) + } +} + /// Proof that the sum of N (not necessarily a power of two) EC points /// is equal to `sum` in one layer instead of multiple layers in a /// GKR layered circuit approach that we used for offline memory checking. @@ -187,7 +217,11 @@ impl ComposedConstrainSystem { } pub fn with_omc_init_only(&self) -> bool { - self.zkvm_v1_css.with_omc_init_only + matches!(self.zkvm_v1_css.omc_init_type, ShardOMCInitType::InitOnce) + } + + pub fn with_omc_init_dyn(&self) -> bool { + matches!(self.zkvm_v1_css.omc_init_type, ShardOMCInitType::InitDyn) } } @@ -778,6 +812,13 @@ impl> ZKVMProvingKey> ZKVMProvingKey { pub fn get_vk_slow(&self) -> ZKVMVerifyingKey { + self.get_vk_slow_with_mem_state::() + } + + pub fn get_vk_slow_with_mem_state(&self) -> ZKVMVerifyingKey + where + M: Clone + Default + From + Serialize + DeserializeOwned, + { ZKVMVerifyingKey { vp: self.vp.clone(), entry_pc: self.entry_pc, @@ -794,6 +835,11 @@ impl> ZKVMProvingKey> ZKVMProvingKey> -where +pub struct ZKVMVerifyingKey< + E: ExtensionField, + PCS: PolynomialCommitmentScheme, + M = RV32imMemStateConfig, +> where PCS::VerifierParam: Sized, + M: Clone + Default + Serialize + DeserializeOwned, { pub vp: PCS::VerifierParam, // entry program counter @@ -821,4 +871,5 @@ where // circuit index -> circuit name // mainly used for debugging pub circuit_index_to_name: BTreeMap, + pub mem_state_verifier: M, } diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index f736947e1..79af6b918 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -48,6 +48,8 @@ pub trait TableCircuit { let r_table_len = cb.cs.r_table_expressions.len(); let w_table_len = cb.cs.w_table_expressions.len(); let lk_table_len = cb.cs.lk_table_expressions.len() * 2; + let zero_len = + cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); let selector = cb.create_placeholder_structural_witin(|| "selector"); let selector_type = SelectorType::Prefix(selector.expr()); @@ -62,7 +64,7 @@ pub trait TableCircuit { // lk_record (r_table_len + w_table_len..r_table_len + w_table_len + lk_table_len).collect_vec(), // zero_record - vec![], + (0..zero_len).collect_vec(), ], Chip::new_from_cb(cb), ); @@ -77,6 +79,9 @@ pub trait TableCircuit { if lk_table_len > 0 { cb.cs.lk_selector = Some(selector_type.clone()); } + if zero_len > 0 { + cb.cs.zero_selector = Some(selector_type.clone()); + } let layer = Layer::from_circuit_builder(cb, Self::name(), out_evals); chip.add_layer(layer); diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index 505178944..7597abbdc 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -1,11 +1,11 @@ use ceno_emul::{Addr, VM_REG_COUNT, WORD_SIZE}; use ff_ext::ExtensionField; use gkr_iop::error::CircuitBuilderError; -use multilinear_extensions::{Expression, StructuralWitIn, StructuralWitInType, ToExpr}; +use multilinear_extensions::{Expression, Instance, StructuralWitIn, StructuralWitInType, ToExpr}; use ram_circuit::{DynVolatileRamCircuit, NonVolatileRamCircuit}; use crate::{ - instructions::riscv::constants::UINT_LIMBS, + instructions::riscv::constants::{HEAP_LENGTH_IDX, HINT_LENGTH_IDX, UINT_LIMBS}, structs::{ProgramParams, RAMType}, }; @@ -90,6 +90,10 @@ impl DynVolatileRamTable for HeapTable { ); addr } + + fn dynamic_length_instance() -> Option { + Some(Instance(HEAP_LENGTH_IDX)) + } } pub type HeapInitCircuit = @@ -196,6 +200,10 @@ impl DynVolatileRamTable for HintsTable { fn name() -> &'static str { "HintsTable" } + + fn dynamic_length_instance() -> Option { + Some(Instance(HINT_LENGTH_IDX)) + } } pub type HintsInitCircuit = DynVolatileRamCircuit>; diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 9b7d54aa4..0d40bda52 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -16,7 +16,7 @@ use gkr_iop::{ selector::SelectorType, }; use itertools::Itertools; -use multilinear_extensions::{Expression, StructuralWitIn, StructuralWitInType, ToExpr}; +use multilinear_extensions::{Expression, Instance, StructuralWitIn, StructuralWitInType, ToExpr}; use std::{collections::HashMap, marker::PhantomData, ops::Range}; use witness::{InstancePaddingStrategy, RowMajorMatrix}; @@ -180,6 +180,10 @@ pub trait DynVolatileRamTable { fn dynamic_addr(_params: &ProgramParams, _entry_index: usize, _pv: &PublicValues) -> Addr { unimplemented!() } + + fn dynamic_length_instance() -> Option { + None + } } pub trait DynVolatileRamTableConfigTrait: Sized + Send + Sync { diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 78049ecb3..009fefefc 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -384,7 +384,9 @@ impl DynVolatileRamTableConfig cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { - if !DVRAM::DYNAMIC_OFFSET { + if DVRAM::dynamic_length_instance().is_some() { + cb.set_omc_init_dyn(); + } else if !DVRAM::DYNAMIC_OFFSET { cb.set_omc_init_only(); } @@ -436,7 +438,7 @@ impl DynVolatileRamTableConfig return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); } assert_eq!(num_structural_witin, 2); - if DVRAM::DYNAMIC_OFFSET { + if DVRAM::dynamic_length_instance().is_some() || DVRAM::DYNAMIC_OFFSET { Self::assign_instances_dynamic(config, num_witin, num_structural_witin, data) } else { Self::assign_instances(config, num_witin, num_structural_witin, data) diff --git a/examples/examples/ceno_rt_mem.rs b/examples/examples/ceno_rt_mem.rs index 837dfc25e..d7735a18a 100644 --- a/examples/examples/ceno_rt_mem.rs +++ b/examples/examples/ceno_rt_mem.rs @@ -2,7 +2,7 @@ use core::ptr::{read_volatile, write_volatile}; extern crate ceno_rt; -const OUTPUT_ADDRESS: u32 = 0x3800_0000; +const OUTPUT_ADDRESS: u32 = 0x1800_0000; #[inline(never)] fn main() { diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 874afa208..f211c8dbe 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -89,6 +89,15 @@ pub struct SetTableExpression { pub table_spec: SetTableSpec, } +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub enum ShardOMCInitType { + None, + // only init once in first shard + InitOnce, + // init in multi-shards with continuation address range + InitDyn, +} + #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] #[serde(bound = "E: ExtensionField + DeserializeOwned")] pub struct ConstraintSystem { @@ -131,9 +140,7 @@ pub struct ConstraintSystem { pub r_table_expressions_namespace_map: Vec, pub w_table_expressions: Vec>, pub w_table_expressions_namespace_map: Vec, - // specify whether constrains system cover only init_w - // as it imply w/r set and final_w might happen ACROSS shards - pub with_omc_init_only: bool, + pub omc_init_type: ShardOMCInitType, pub lk_selector: Option>, /// lookup expression @@ -196,7 +203,7 @@ impl ConstraintSystem { r_table_expressions_namespace_map: vec![], w_table_expressions: vec![], w_table_expressions_namespace_map: vec![], - with_omc_init_only: false, + omc_init_type: ShardOMCInitType::None, lk_selector: None, lk_expressions: vec![], lk_table_expressions: vec![], @@ -505,11 +512,7 @@ impl ConstraintSystem { name_fn: N, assert_zero_expr: Expression, ) -> Result<(), CircuitBuilderError> { - assert!( - assert_zero_expr.degree() > 0, - "constant expression assert to zero ?" - ); - if assert_zero_expr.degree() == 1 { + if assert_zero_expr.degree() <= 1 { self.assert_zero_expressions.push(assert_zero_expr); let path = self.ns.compute_path(name_fn().into()); self.assert_zero_expressions_namespace_map.push(path); @@ -542,7 +545,11 @@ impl ConstraintSystem { } pub fn set_omc_init_only(&mut self) { - self.with_omc_init_only = true; + self.omc_init_type = ShardOMCInitType::InitOnce; + } + + pub fn set_omc_init_dyn(&mut self) { + self.omc_init_type = ShardOMCInitType::InitDyn; } } @@ -1348,6 +1355,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn set_omc_init_only(&mut self) { self.cs.set_omc_init_only(); } + + pub fn set_omc_init_dyn(&mut self) { + self.cs.set_omc_init_dyn(); + } } /// take items from an iterator until the accumulated "weight" (measured by `f`)