diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 5f1e1087d..eb495a02c 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -63,10 +63,11 @@ jobs: RUSTFLAGS: "-C opt-level=3" run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/fibonacci - - name: Run fibonacci (release + goldilocks) - env: - RUSTFLAGS: "-C opt-level=3" - run: cargo run --release --package ceno_zkvm --no-default-features --features goldilocks --bin e2e -- --field=goldilocks --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/fibonacci + # note: the global chip does not support goldilocks field yet + # - name: Run fibonacci (release + goldilocks) + # env: + # RUSTFLAGS: "-C opt-level=3" + # run: cargo run --release --package ceno_zkvm --no-default-features --features goldilocks --bin e2e -- --field=goldilocks --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/fibonacci - name: Run Guest Heap Alloc (debug) env: @@ -80,10 +81,11 @@ jobs: RUSTFLAGS: "-C opt-level=3" run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/ceno_rt_alloc - - name: Run Guest Heap Alloc (release + goldilocks) - env: - RUSTFLAGS: "-C opt-level=3" - run: cargo run --release --package ceno_zkvm --no-default-features --features goldilocks --bin e2e -- --field=goldilocks --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/ceno_rt_alloc + # note: the global chip does not support goldilocks field yet + # - name: Run Guest Heap Alloc (release + goldilocks) + # env: + # RUSTFLAGS: "-C opt-level=3" + # run: cargo run --release --package ceno_zkvm --no-default-features --features goldilocks --bin e2e -- --field=goldilocks --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/ceno_rt_alloc - name: Run keccak_syscall (release) env: diff --git a/Cargo.lock b/Cargo.lock index 22db24a3e..64902df6f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1044,6 +1044,7 @@ dependencies = [ "multilinear_extensions", "ndarray", "num", + "num-bigint", "once_cell", "p3", "parse-size", @@ -1903,7 +1904,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" dependencies = [ "once_cell", "p3", @@ -2715,7 +2716,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" dependencies = [ "bincode", "clap", @@ -2739,7 +2740,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" dependencies = [ "either", "ff_ext", @@ -3060,8 +3061,9 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" dependencies = [ + "p3-air", "p3-baby-bear", "p3-challenger", "p3-commit", @@ -3073,12 +3075,23 @@ dependencies = [ "p3-maybe-rayon", "p3-mds", "p3-merkle-tree", + "p3-monty-31", "p3-poseidon", "p3-poseidon2", + "p3-poseidon2-air", "p3-symmetric", "p3-util", ] +[[package]] +name = "p3-air" +version = "0.1.0" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" +dependencies = [ + "p3-field", + "p3-matrix", +] + [[package]] name = "p3-baby-bear" version = "0.1.0" @@ -3294,6 +3307,22 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "p3-poseidon2-air" +version = "0.1.0" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" +dependencies = [ + "p3-air", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-poseidon2", + "p3-util", + "rand 0.8.5", + "tikv-jemallocator", + "tracing", +] + [[package]] name = "p3-symmetric" version = "0.1.0" @@ -3469,7 +3498,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" dependencies = [ "ff_ext", "p3", @@ -4453,7 +4482,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" dependencies = [ "cfg-if", "dashu", @@ -4575,7 +4604,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" dependencies = [ "either", "ff_ext", @@ -4593,7 +4622,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" dependencies = [ "itertools 0.13.0", "p3", @@ -4988,7 +5017,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -5260,7 +5289,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" dependencies = [ "bincode", "clap", @@ -5547,7 +5576,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index 733390e07..5db870151 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,16 +23,16 @@ repository = "https://github.com/scroll-tech/ceno" version = "0.1.0" [workspace.dependencies] -ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", rev = "v1.0.0-alpha.9" } -mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", rev = "v1.0.0-alpha.9" } -multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", rev = "v1.0.0-alpha.9" } -p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", rev = "v1.0.0-alpha.9" } -poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", rev = "v1.0.0-alpha.9" } -sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", rev = "v1.0.0-alpha.9" } -sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", rev = "v1.0.0-alpha.9" } -transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", rev = "v1.0.0-alpha.9" } -whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", rev = "v1.0.0-alpha.9" } -witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", rev = "v1.0.0-alpha.9" } +ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.12" } +mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.12" } +multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.12" } +p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.12" } +poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.12" } +sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.12" } +sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.12" } +transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.12" } +whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.12" } +witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.12" } alloy-primitives = "1.3" anyhow = { version = "1.0", default-features = false } @@ -99,13 +99,14 @@ lto = "thin" # [patch."ssh://git@github.com/scroll-tech/ceno-gpu.git"] # ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal" } -#[patch."https://github.com/scroll-tech/gkr-backend"] -#ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } -#mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } -#multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } -#p3 = { path = "../gkr-backend/crates/p3", package = "p3" } -#poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } -#sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } -#transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } -#whir = { path = "../gkr-backend/crates/whir", package = "whir" } -#witness = { path = "../gkr-backend/crates/witness", package = "witness" } +# [patch."https://github.com/scroll-tech/gkr-backend"] +# ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } +# mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } +# multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } +# p3 = { path = "../gkr-backend/crates/p3", package = "p3" } +# poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } +# sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" } +# sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } +# transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } +# whir = { path = "../gkr-backend/crates/whir", package = "whir" } +# witness = { path = "../gkr-backend/crates/witness", package = "witness" } diff --git a/ceno_emul/src/syscalls/bn254/bn254_fptower.rs b/ceno_emul/src/syscalls/bn254/bn254_fptower.rs index 75c70a055..3fa98f368 100644 --- a/ceno_emul/src/syscalls/bn254/bn254_fptower.rs +++ b/ceno_emul/src/syscalls/bn254/bn254_fptower.rs @@ -12,6 +12,7 @@ use crate::{ use super::types::{BN254_FP_WORDS, BN254_FP2_WORDS}; pub struct Bn254FpAddSpec; + impl SyscallSpec for Bn254FpAddSpec { const NAME: &'static str = "BN254_FP_ADD"; diff --git a/ceno_emul/src/syscalls/secp256k1.rs b/ceno_emul/src/syscalls/secp256k1.rs index 2facffba4..fafabe78c 100644 --- a/ceno_emul/src/syscalls/secp256k1.rs +++ b/ceno_emul/src/syscalls/secp256k1.rs @@ -6,7 +6,9 @@ use std::iter; use super::{SyscallEffects, SyscallSpec, SyscallWitness}; pub struct Secp256k1AddSpec; + pub struct Secp256k1DoubleSpec; + pub struct Secp256k1DecompressSpec; impl SyscallSpec for Secp256k1AddSpec { diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 3c1c99ed4..07d1394ac 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -48,6 +48,7 @@ derive = { path = "../derive" } generic-array.workspace = true generic_static = "0.2" num.workspace = true +num-bigint = "0.4.6" parse-size = "1.1" rand.workspace = true sp1-curves.workspace = true diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 028748058..9d8cc22e8 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -111,7 +111,8 @@ fn bench_add(c: &mut Criterion) { witness: polys, structural_witness: vec![], public_input: vec![], - num_instances, + num_instances: vec![num_instances], + has_ecc_ops: false, }; let _ = prover .create_chip_proof( diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index dbd9961a9..9ea76595a 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -4,9 +4,10 @@ use gkr_iop::{error::CircuitBuilderError, tables::LookupTable}; use crate::{ circuit_builder::CircuitBuilder, instructions::riscv::constants::{ - END_CYCLE_IDX, END_PC_IDX, END_SHARD_ID_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, - PUBLIC_IO_IDX, UINT_LIMBS, + END_CYCLE_IDX, END_PC_IDX, END_SHARD_ID_IDX, EXIT_CODE_IDX, GLOBAL_RW_SUM_IDX, + INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX, UINT_LIMBS, }, + scheme::constants::SEPTIC_EXTENSION_DEGREE, tables::InsnRecord, }; use multilinear_extensions::{Expression, Instance}; @@ -21,6 +22,7 @@ pub trait PublicIOQuery { fn query_init_cycle(&mut self) -> Result; fn query_end_pc(&mut self) -> Result; fn query_end_cycle(&mut self) -> Result; + fn query_global_rw_sum(&mut self) -> Result, CircuitBuilderError>; fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; #[allow(dead_code)] fn query_shard_id(&mut self) -> Result; @@ -73,4 +75,23 @@ impl<'a, E: ExtensionField> PublicIOQuery for CircuitBuilder<'a, E> { .query_instance(|| "public_io_high", PUBLIC_IO_IDX + 1)?, ]) } + + fn query_global_rw_sum(&mut self) -> Result, CircuitBuilderError> { + let x = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| { + self.cs + .query_instance(|| format!("global_rw_sum_x_{}", i), GLOBAL_RW_SUM_IDX + i) + }) + .collect::, CircuitBuilderError>>()?; + let y = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| { + self.cs.query_instance( + || format!("global_rw_sum_y_{}", i), + GLOBAL_RW_SUM_IDX + SEPTIC_EXTENSION_DEGREE + i, + ) + }) + .collect::, CircuitBuilderError>>()?; + + Ok([x, y].concat()) + } } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 712f3b7a1..62f3e425f 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -3,6 +3,7 @@ use crate::{ instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig}, scheme::{ PublicValues, ZKVMProof, + constants::SEPTIC_EXTENSION_DEGREE, hal::ProverDevice, mock_prover::{LkMultiplicityKey, MockProver}, prover::ZKVMProver, @@ -105,10 +106,16 @@ pub struct RAMRecord { pub ram_type: RAMType, pub id: u64, pub addr: WordAddr, + // prev_cycle and cycle are global cycle pub prev_cycle: Cycle, pub cycle: Cycle, + // shard_cycle is cycle in current local shard, which already offset by start cycle + pub shard_cycle: Cycle, pub prev_value: Option, pub value: Word, + // for global reads, `shard_id` refers to the shard that previously produced this value. + // for global write, `shard_id` refers to current shard. + pub shard_id: usize, } #[derive(Clone, Debug)] @@ -154,6 +161,7 @@ pub struct ShardContext<'a> { write_thread_based_record_storage: Either>, &'a mut BTreeMap>, pub cur_shard_cycle_range: std::ops::Range, + pub expected_inst_per_shard: usize, } impl<'a> Default for ShardContext<'a> { @@ -176,6 +184,7 @@ impl<'a> Default for ShardContext<'a> { .collect::>(), ), cur_shard_cycle_range: Tracer::SUBCYCLES_PER_INSN as usize..usize::MAX, + expected_inst_per_shard: usize::MAX, } } } @@ -222,6 +231,7 @@ impl<'a> ShardContext<'a> { .collect::>(), ), cur_shard_cycle_range, + expected_inst_per_shard, } } @@ -243,6 +253,7 @@ impl<'a> ShardContext<'a> { read_thread_based_record_storage: Either::Right(read), write_thread_based_record_storage: Either::Right(write), cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), + expected_inst_per_shard: self.expected_inst_per_shard, }) .collect_vec(), _ => panic!("invalid type"), @@ -278,15 +289,28 @@ impl<'a> ShardContext<'a> { self.cur_shard_cycle_range.contains(&(cycle as usize)) } + #[inline(always)] + pub fn extract_prev_shard_id(&self, cycle: Cycle) -> usize { + let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN; + let per_shard_cycles = + (self.expected_inst_per_shard as u64).saturating_mul(subcycle_per_insn); + ((cycle.saturating_sub(subcycle_per_insn)) / per_shard_cycles) as usize + } + #[inline(always)] pub fn aligned_prev_ts(&self, prev_cycle: Cycle) -> Cycle { - let mut ts = prev_cycle - self.current_shard_offset_cycle(); + let mut ts = prev_cycle.saturating_sub(self.current_shard_offset_cycle()); if ts < Tracer::SUBCYCLES_PER_INSN { ts = 0 } ts } + #[inline(always)] + pub fn aligned_current_ts(&self, cycle: Cycle) -> Cycle { + cycle.saturating_sub(self.current_shard_offset_cycle()) + } + pub fn current_shard_offset_cycle(&self) -> Cycle { // cycle of each local shard start from Tracer::SUBCYCLES_PER_INSN (self.cur_shard_cycle_range.start as Cycle) - Tracer::SUBCYCLES_PER_INSN @@ -310,6 +334,7 @@ impl<'a> ShardContext<'a> { && self.is_current_shard_cycle(cycle) && !self.is_first_shard() { + let prev_shard_id = self.extract_prev_shard_id(prev_cycle); let ram_record = self .read_thread_based_record_storage .as_mut() @@ -323,8 +348,10 @@ impl<'a> ShardContext<'a> { addr, prev_cycle, cycle, + shard_cycle: 0, prev_value, value, + shard_id: prev_shard_id, }, ); } @@ -347,6 +374,7 @@ impl<'a> ShardContext<'a> { && future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle && self.is_current_shard_cycle(cycle) { + let shard_cycle = self.aligned_current_ts(cycle); let ram_record = self .write_thread_based_record_storage .as_mut() @@ -360,8 +388,10 @@ impl<'a> ShardContext<'a> { addr, prev_cycle, cycle, + shard_cycle, prev_value, value, + shard_id: self.shards.shard_id, }, ); } @@ -433,6 +463,7 @@ pub fn emulate_program<'a>( end_cycle, shards.shard_id as u32, io_init.iter().map(|rec| rec.value).collect_vec(), + vec![0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity ); // Find the final register values and cycles. diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index 5e429354f..a4d624568 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -1,6 +1,7 @@ mod div; mod field; mod is_lt; +mod poseidon2; mod signed; mod signed_ext; mod signed_limbs; @@ -13,6 +14,7 @@ pub use gkr_iop::gadgets::{ AssertLtConfig, InnerLtConfig, IsEqualConfig, IsLtConfig, IsZeroConfig, cal_lt_diff, }; pub use is_lt::{AssertSignedLtConfig, SignedLtConfig}; +pub use poseidon2::{Poseidon2BabyBearConfig, Poseidon2Config}; pub use signed::Signed; pub use signed_ext::SignedExtendConfig; pub use signed_limbs::{UIntLimbsLT, UIntLimbsLTConfig}; diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs new file mode 100644 index 000000000..0eca74c50 --- /dev/null +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -0,0 +1,524 @@ +// Poseidon2 over BabyBear field + +use std::{ + borrow::{Borrow, BorrowMut}, + iter::from_fn, + mem::transmute, +}; + +use ff_ext::{BabyBearExt4, ExtensionField}; +use gkr_iop::error::CircuitBuilderError; +use itertools::Itertools; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use num_bigint::BigUint; +use p3::{ + babybear::BabyBearInternalLayerParameters, + field::{Field, FieldAlgebra, PrimeField}, + monty_31::InternalLayerBaseParameters, + poseidon2::{GenericPoseidon2LinearLayers, MDSMat4, mds_light_permutation}, + poseidon2_air::{FullRound, PartialRound, Poseidon2Cols, SBox, num_cols}, +}; + +use crate::circuit_builder::CircuitBuilder; + +// copied from poseidon2-air/src/constants.rs +// as the original one cannot be accessed here +#[derive(Debug, Clone)] +pub struct RoundConstants< + F: Field, + const WIDTH: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> { + pub beginning_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], + pub partial_round_constants: [F; PARTIAL_ROUNDS], + pub ending_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], +} + +impl + From> for RoundConstants +{ + fn from(value: Vec) -> Self { + let mut iter = value.into_iter(); + let mut beginning_full_round_constants = [[F::ZERO; WIDTH]; HALF_FULL_ROUNDS]; + + beginning_full_round_constants.iter_mut().for_each(|arr| { + arr.iter_mut() + .for_each(|c| *c = iter.next().expect("insufficient round constants")) + }); + + let mut partial_round_constants = [F::ZERO; PARTIAL_ROUNDS]; + + partial_round_constants + .iter_mut() + .for_each(|arr| *arr = iter.next().expect("insufficient round constants")); + + let mut ending_full_round_constants = [[F::ZERO; WIDTH]; HALF_FULL_ROUNDS]; + ending_full_round_constants.iter_mut().for_each(|arr| { + arr.iter_mut() + .for_each(|c| *c = iter.next().expect("insufficient round constants")) + }); + + assert!(iter.next().is_none(), "round constants are too many"); + + RoundConstants { + beginning_full_round_constants, + partial_round_constants, + ending_full_round_constants, + } + } +} + +pub type Poseidon2BabyBearConfig = Poseidon2Config; +pub struct Poseidon2Config< + E: ExtensionField, + const STATE_WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> { + cols: Vec, + constants: RoundConstants, +} + +#[derive(Debug, Clone)] +pub struct Poseidon2LinearLayers; + +impl GenericPoseidon2LinearLayers + for Poseidon2LinearLayers +{ + fn internal_linear_layer(state: &mut [F; WIDTH]) { + // this only works when F is BabyBear field for now + let babybear_prime = BigUint::from(0x7800_0001u32); + if F::order() == babybear_prime { + let diag_m1_matrix = &>::INTERNAL_DIAG_MONTY; + let diag_m1_matrix: &[F; WIDTH] = unsafe { transmute(diag_m1_matrix) }; + let sum = state.iter().cloned().sum::(); + for (input, diag_m1) in state.iter_mut().zip(diag_m1_matrix) { + *input = sum + F::from_f(*diag_m1) * *input; + } + } else { + panic!("Unsupported field"); + } + } + + fn external_linear_layer(state: &mut [F; WIDTH]) { + mds_light_permutation(state, &MDSMat4); + } +} + +impl< + E: ExtensionField, + const STATE_WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> Poseidon2Config +{ + // constraints taken from poseidon2_air/src/air.rs + fn eval_sbox( + sbox: &SBox, SBOX_DEGREE, SBOX_REGISTERS>, + x: &mut Expression, + cb: &mut CircuitBuilder, + ) -> Result<(), CircuitBuilderError> { + *x = match (SBOX_DEGREE, SBOX_REGISTERS) { + (3, 0) => x.cube(), + (5, 0) => x.exp_const_u64::<5>(), + (7, 0) => x.exp_const_u64::<7>(), + (5, 1) => { + let committed_x3: Expression = sbox.0[0].clone(); + let x2: Expression = x.square(); + cb.require_zero( + || "x3 = x.cube()", + committed_x3.clone() - x2.clone() * x.clone(), + )?; + committed_x3 * x2 + } + (7, 1) => { + let committed_x3: Expression = sbox.0[0].clone(); + // TODO: avoid x^3 as x may have ~STATE_WIDTH terms after the linear layer + // we can allocate one more column to store x^2 (which has ~STATE_WIDTH^2 terms) + // then x^3 = x * x^2 + // but this will increase the number of columns (by FULL_ROUNDS * STATE_WIDTH + PARTIAL_ROUNDS) + cb.require_zero(|| "x3 = x.cube()", committed_x3.clone() - x.cube())?; + committed_x3.square() * x.clone() + } + _ => panic!( + "Unexpected (SBOX_DEGREE, SBOX_REGISTERS) of ({}, {})", + SBOX_DEGREE, SBOX_REGISTERS + ), + }; + + Ok(()) + } + + fn eval_full_round( + state: &mut [Expression; STATE_WIDTH], + full_round: &FullRound, STATE_WIDTH, SBOX_DEGREE, SBOX_REGISTERS>, + round_constants: &[E::BaseField], + cb: &mut CircuitBuilder, + ) -> Result<(), CircuitBuilderError> { + for (i, (s, r)) in state.iter_mut().zip_eq(round_constants.iter()).enumerate() { + *s = s.clone() + r.expr(); + Self::eval_sbox(&full_round.sbox[i], s, cb)?; + } + Self::external_linear_layer(state); + for (state_i, post_i) in state.iter_mut().zip_eq(full_round.post.iter()) { + cb.require_zero(|| "post_i = state_i", state_i.clone() - post_i)?; + *state_i = post_i.clone(); + } + + Ok(()) + } + + fn eval_partial_round( + state: &mut [Expression; STATE_WIDTH], + partial_round: &PartialRound, STATE_WIDTH, SBOX_DEGREE, SBOX_REGISTERS>, + round_constant: &E::BaseField, + cb: &mut CircuitBuilder, + ) -> Result<(), CircuitBuilderError> { + state[0] = state[0].clone() + round_constant.expr(); + Self::eval_sbox(&partial_round.sbox, &mut state[0], cb)?; + + cb.require_zero( + || "state[0] = post_sbox", + state[0].clone() - partial_round.post_sbox.clone(), + )?; + state[0] = partial_round.post_sbox.clone(); + + Self::internal_linear_layer(state); + + Ok(()) + } + + fn external_linear_layer(state: &mut [Expression; STATE_WIDTH]) { + mds_light_permutation(state, &MDSMat4); + } + + fn internal_linear_layer(state: &mut [Expression; STATE_WIDTH]) { + let sum: Expression = state.iter().map(|s| s.get_monomial_form()).sum(); + // reduce to monomial form + let sum = sum.get_monomial_form(); + let babybear_prime = BigUint::from(0x7800_0001u32); + if E::BaseField::order() == babybear_prime { + // BabyBear + let diag_m1_matrix_bb = + &>:: + INTERNAL_DIAG_MONTY; + let diag_m1_matrix: &[E::BaseField; STATE_WIDTH] = + unsafe { transmute(diag_m1_matrix_bb) }; + for (input, diag_m1) in state.iter_mut().zip_eq(diag_m1_matrix) { + let updated = sum.clone() + Expression::from_f(*diag_m1) * input.clone(); + // reduce to monomial form + *input = updated.get_monomial_form(); + } + } else { + panic!("Unsupported field"); + } + } + + pub fn construct( + cb: &mut CircuitBuilder, + round_constants: RoundConstants< + E::BaseField, + STATE_WIDTH, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + ) -> Self { + let num_cols = + num_cols::( + ); + let cols = from_fn(|| Some(cb.create_witin(|| "poseidon2 col"))) + .take(num_cols) + .collect::>(); + let mut col_exprs = cols + .iter() + .map(|c| c.expr()) + .collect::>>(); + + let poseidon2_cols: &mut Poseidon2Cols< + Expression, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = col_exprs.as_mut_slice().borrow_mut(); + + // external linear layer + Self::external_linear_layer(&mut poseidon2_cols.inputs); + + // eval full round + for round in 0..HALF_FULL_ROUNDS { + Self::eval_full_round( + &mut poseidon2_cols.inputs, + &poseidon2_cols.beginning_full_rounds[round], + &round_constants.beginning_full_round_constants[round], + cb, + ) + .unwrap(); + } + + // eval partial round + for round in 0..PARTIAL_ROUNDS { + Self::eval_partial_round( + &mut poseidon2_cols.inputs, + &poseidon2_cols.partial_rounds[round], + &round_constants.partial_round_constants[round], + cb, + ) + .unwrap(); + } + + // TODO: after the last partial round, each state_i has ~STATE_WIDTH terms + // which will make the next full round to have many terms + + // eval full round + for round in 0..HALF_FULL_ROUNDS { + Self::eval_full_round( + &mut poseidon2_cols.inputs, + &poseidon2_cols.ending_full_rounds[round], + &round_constants.ending_full_round_constants[round], + cb, + ) + .unwrap(); + } + + Poseidon2Config { + cols, + constants: round_constants, + } + } + + pub fn inputs(&self) -> Vec> { + let col_exprs = self.cols.iter().map(|c| c.expr()).collect::>(); + + let poseidon2_cols: &Poseidon2Cols< + Expression, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = col_exprs.as_slice().borrow(); + + poseidon2_cols.inputs.to_vec() + } + + pub fn output(&self) -> Vec> { + let col_exprs = self.cols.iter().map(|c| c.expr()).collect::>(); + + let poseidon2_cols: &Poseidon2Cols< + Expression, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = col_exprs.as_slice().borrow(); + + poseidon2_cols + .ending_full_rounds + .last() + .map(|r| r.post.to_vec()) + .unwrap() + } + + pub fn assign_instance( + &self, + instance: &mut [E::BaseField], + state: [E::BaseField; STATE_WIDTH], + ) { + let poseidon2_cols: &mut Poseidon2Cols< + E::BaseField, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = instance.borrow_mut(); + + generate_trace_rows_for_perm::< + E::BaseField, + Poseidon2LinearLayers, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >(poseidon2_cols, state, &self.constants); + } +} + +////////////////////////////////////////////////////////////////////////// +/// The following routines are taken from poseidon2-air/src/generation.rs +////////////////////////////////////////////////////////////////////////// +fn generate_trace_rows_for_perm< + F: PrimeField, + LinearLayers: GenericPoseidon2LinearLayers, + const WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +>( + perm: &mut Poseidon2Cols< + F, + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + mut state: [F; WIDTH], + constants: &RoundConstants, +) { + perm.export = F::ONE; + perm.inputs + .iter_mut() + .zip(state.iter()) + .for_each(|(input, &x)| { + *input = x; + }); + + LinearLayers::external_linear_layer(&mut state); + + for (full_round, constants) in perm + .beginning_full_rounds + .iter_mut() + .zip(&constants.beginning_full_round_constants) + { + generate_full_round::( + &mut state, full_round, constants, + ); + } + + for (partial_round, constant) in perm + .partial_rounds + .iter_mut() + .zip(&constants.partial_round_constants) + { + generate_partial_round::( + &mut state, + partial_round, + *constant, + ); + } + + for (full_round, constants) in perm + .ending_full_rounds + .iter_mut() + .zip(&constants.ending_full_round_constants) + { + generate_full_round::( + &mut state, full_round, constants, + ); + } +} + +#[inline] +fn generate_full_round< + F: PrimeField, + LinearLayers: GenericPoseidon2LinearLayers, + const WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, +>( + state: &mut [F; WIDTH], + full_round: &mut FullRound, + round_constants: &[F; WIDTH], +) { + for (state_i, const_i) in state.iter_mut().zip(round_constants) { + *state_i += *const_i; + } + for (state_i, sbox_i) in state.iter_mut().zip(full_round.sbox.iter_mut()) { + generate_sbox(sbox_i, state_i); + } + LinearLayers::external_linear_layer(state); + full_round + .post + .iter_mut() + .zip(*state) + .for_each(|(post, x)| { + *post = x; + }); +} + +#[inline] +fn generate_partial_round< + F: PrimeField, + LinearLayers: GenericPoseidon2LinearLayers, + const WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, +>( + state: &mut [F; WIDTH], + partial_round: &mut PartialRound, + round_constant: F, +) { + state[0] += round_constant; + generate_sbox(&mut partial_round.sbox, &mut state[0]); + partial_round.post_sbox = state[0]; + LinearLayers::internal_linear_layer(state); +} + +#[inline] +fn generate_sbox( + sbox: &mut SBox, + x: &mut F, +) { + *x = match (DEGREE, REGISTERS) { + (3, 0) => x.cube(), + (5, 0) => x.exp_const_u64::<5>(), + (7, 0) => x.exp_const_u64::<7>(), + (5, 1) => { + let x2 = x.square(); + let x3 = x2 * *x; + sbox.0[0] = x3; + x3 * x2 + } + (7, 1) => { + let x3 = x.cube(); + sbox.0[0] = x3; + x3 * x3 * *x + } + (11, 2) => { + let x2 = x.square(); + let x3 = x2 * *x; + let x9 = x3.cube(); + sbox.0[0] = x3; + sbox.0[1] = x9; + x9 * x2 + } + _ => panic!( + "Unexpected (DEGREE, REGISTERS) of ({}, {})", + DEGREE, REGISTERS + ), + } +} + +#[cfg(test)] +mod tests { + use crate::gadgets::poseidon2::Poseidon2BabyBearConfig; + use ff_ext::{BabyBearExt4, PoseidonField}; + use gkr_iop::circuit_builder::{CircuitBuilder, ConstraintSystem}; + use p3::babybear::BabyBear; + + type E = BabyBearExt4; + type F = BabyBear; + #[test] + fn test_poseidon2_gadget() { + let mut cs = ConstraintSystem::new(|| "poseidon2 gadget test"); + let mut cb = CircuitBuilder::::new(&mut cs); + + // let poseidon2_constants = horizen_round_consts(); + let rc = ::get_default_perm_rc().into(); + let _ = Poseidon2BabyBearConfig::construct(&mut cb, rc); + } +} diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 13a3ed22b..12c137aa8 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -19,6 +19,7 @@ use rayon::{ }; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_val}; +pub mod global; pub mod riscv; pub trait Instruction { @@ -56,7 +57,7 @@ pub trait Instruction { descending: false, }, ); - let selector_type = SelectorType::Prefix(E::BaseField::ZERO, selector.expr()); + let selector_type = SelectorType::Prefix(selector.expr()); // all shared the same selector let (out_evals, mut chip) = ( @@ -79,7 +80,7 @@ pub trait Instruction { cb.cs.lk_selector = Some(selector_type.clone()); cb.cs.zero_selector = Some(selector_type.clone()); - let layer = Layer::from_circuit_builder(cb, "Rounds".to_string(), 0, out_evals); + let layer = Layer::from_circuit_builder(cb, format!("{}_main", Self::name()), 0, out_evals); chip.add_layer(layer); Ok((config, chip.gkr_circuit())) diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs new file mode 100644 index 000000000..c98cb3634 --- /dev/null +++ b/ceno_zkvm/src/instructions/global.rs @@ -0,0 +1,844 @@ +use std::{collections::HashMap, iter::repeat_n, marker::PhantomData}; + +use crate::{ + Value, + chip_handler::general::PublicIOQuery, + e2e::RAMRecord, + error::ZKVMError, + gadgets::Poseidon2Config, + instructions::riscv::constants::UINT_LIMBS, + scheme::septic_curve::{SepticExtension, SepticPoint}, + structs::{ProgramParams, RAMType}, + tables::{RMMCollections, TableCircuit}, + witness::LkMultiplicity, +}; +use ceno_emul::WordAddr; +use ff_ext::{ExtensionField, FieldInto, PoseidonField, SmallField}; +use gkr_iop::{ + chip::Chip, + circuit_builder::CircuitBuilder, + error::CircuitBuilderError, + gkr::{GKRCircuit, layer::Layer}, + selector::SelectorType, +}; +use itertools::{Itertools, chain}; +use multilinear_extensions::{ + Expression, StructuralWitInType::EqualDistanceSequence, ToExpr, WitIn, util::max_usable_threads, +}; +use p3::{ + field::{Field, FieldAlgebra}, + matrix::dense::RowMajorMatrix, + symmetric::Permutation, +}; +use rayon::{ + iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelExtend, + ParallelIterator, + }, + prelude::ParallelSliceMut, + slice::ParallelSlice, +}; +use std::ops::Deref; +use witness::{InstancePaddingStrategy, next_pow2_instance_padding, set_val}; + +use crate::{instructions::riscv::constants::UInt, scheme::constants::SEPTIC_EXTENSION_DEGREE}; + +/// A record for a read/write into the global set +#[derive(Debug, Clone)] +pub struct GlobalRecord { + pub addr: u32, + pub ram_type: RAMType, + pub value: u32, + pub shard: u64, + pub local_clk: u64, + pub global_clk: u64, + pub is_to_write_set: bool, +} + +impl From<(&WordAddr, &RAMRecord, bool)> for GlobalRecord { + fn from((vma, record, is_to_write_set): (&WordAddr, &RAMRecord, bool)) -> Self { + let addr = match record.ram_type { + RAMType::Register => record.id as u32, + RAMType::Memory => (*vma).into(), + _ => unreachable!(), + }; + let (shard, local_clk, global_clk, value) = if is_to_write_set { + // global write -> local read + ( + record.shard_id, + record.shard_cycle, + record.cycle, + // local read is for cancel final write value in `Write` set + record.value, + ) + } else { + // global read -> local write + debug_assert_eq!(record.shard_cycle, 0); + ( + record.shard_id, + 0, + record.prev_cycle, + // local write is for adapting write from previous shard + record.prev_value.unwrap_or(record.value), + ) + }; + + GlobalRecord { + addr, + ram_type: record.ram_type, + value, + shard: shard as u64, + local_clk, + global_clk, + is_to_write_set, + } + } +} +/// An EC point corresponding to a global read/write record +/// whose x-coordinate is derived from Poseidon2 hash of the record +#[derive(Clone, Debug)] +pub struct GlobalPoint { + pub nonce: u32, + pub point: SepticPoint, +} + +impl GlobalRecord { + pub fn to_ec_point>>( + &self, + hasher: &P, + ) -> GlobalPoint { + let mut nonce = 0; + let mut input = vec![ + E::BaseField::from_canonical_u32(self.addr), + E::BaseField::from_canonical_u32(self.ram_type as u32), + E::BaseField::from_canonical_u32(self.value & 0xFFFF), // lower 16 bits + E::BaseField::from_canonical_u32((self.value >> 16) & 0xFFFF), // higher 16 bits + E::BaseField::from_canonical_u64(self.shard), + E::BaseField::from_canonical_u64(self.global_clk), + E::BaseField::from_canonical_u32(nonce), + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + ]; + + let prime = E::BaseField::order().to_u64_digits()[0]; + loop { + let x: SepticExtension = + hasher.permute(input.clone())[0..SEPTIC_EXTENSION_DEGREE].into(); + if let Some(p) = SepticPoint::from_x(x) { + let y6 = (p.y.0)[SEPTIC_EXTENSION_DEGREE - 1].to_canonical_u64(); + let is_y_in_2nd_half = y6 >= (prime / 2); + + // we negate y if needed + // to ensure read => y in [0, p/2) and write => y in [p/2, p) + let negate = match (self.is_to_write_set, is_y_in_2nd_half) { + (true, false) => true, // write, y in [0, p/2) + (false, true) => true, // read, y in [p/2, p) + _ => false, + }; + + let point = if negate { -p } else { p }; + + return GlobalPoint { nonce, point }; + } else { + // try again with different nonce + nonce += 1; + input[6] = E::BaseField::from_canonical_u32(nonce); + } + } + } +} +/// opcode circuit + mem init/final table + local finalize circuit + global chip +/// global chip is used to ensure the **local** reads and writes produced by +/// opcode circuits / memory init / memory finalize table / local finalize circuit +/// can balance out. +/// +/// 1. For a local memory read record whose previous write is not in the same shard, +/// the global chip will read it from the **global set** and insert a local write record. +/// 2. For a local memory write record which will **not** be read in the future, +/// the local finalize circuit will consume it by inserting a local read record. +/// 3. For a local memory write record which will be read in the future, +/// the global chip will insert a local read record and write it to the **global set**. +pub struct GlobalConfig { + addr: WitIn, + is_ram_register: WitIn, + value: UInt, + shard: WitIn, + global_clk: WitIn, + local_clk: WitIn, + nonce: WitIn, + // if it's a write to global set, then insert a local read record + // s.t. local offline memory checking can cancel out + // this serves as propagating local write to global. + is_global_write: WitIn, + x: Vec, + y: Vec, + slope: Vec, + perm_config: Poseidon2Config, +} + +impl GlobalConfig { + // TODO: make `WIDTH`, `HALF_FULL_ROUNDS`, `PARTIAL_ROUNDS` generic parameters + pub fn configure(cb: &mut CircuitBuilder) -> Result { + let x: Vec = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| cb.create_witin(|| format!("x{}", i))) + .collect(); + let y: Vec = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| cb.create_witin(|| format!("y{}", i))) + .collect(); + let slope: Vec = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| cb.create_witin(|| format!("slope{}", i))) + .collect(); + let addr = cb.create_witin(|| "addr"); + let is_ram_register = cb.create_witin(|| "is_ram_register"); + let value = UInt::new_unchecked(|| "value", cb)?; + let shard = cb.create_witin(|| "shard"); + let global_clk = cb.create_witin(|| "global_clk"); + let local_clk = cb.create_witin(|| "local_clk"); + let nonce = cb.create_witin(|| "nonce"); + let is_global_write = cb.create_witin(|| "is_global_write"); + + let is_ram_reg: Expression = is_ram_register.expr(); + let reg: Expression = RAMType::Register.into(); + let mem: Expression = RAMType::Memory.into(); + let ram_type: Expression = is_ram_reg.clone() * reg + (1 - is_ram_reg) * mem; + + let rc = ::get_default_perm_rc().into(); + let perm_config = Poseidon2Config::construct(cb, rc); + + let mut input = vec![]; + input.push(addr.expr()); + input.push(ram_type.clone()); + // memory expr has same number of limbs as register expr + input.extend(value.memory_expr()); + input.push(shard.expr()); + input.push(global_clk.expr()); + // add nonce to ensure poseidon2(input) always map to a valid ec point + input.push(nonce.expr()); + input.extend(repeat_n(E::BaseField::ZERO.expr(), 16 - input.len())); + + let mut record = vec![]; + record.push(ram_type.clone()); + record.push(addr.expr()); + record.extend(value.memory_expr()); + record.push(local_clk.expr()); + + // if is_global_write = 1, then it means we are propagating a local write to global + // so we need to insert a local read record to cancel out this local write + cb.assert_bit(|| "is_global_write must be boolean", is_global_write.expr())?; + // TODO: for all local reads, enforce they come to global writes + // TODO: for all local writes, enforce they come from global reads + + // global read => insert a local write with local_clk = 0 + cb.condition_require_zero( + || "is_global_read => local_clk = 0", + 1 - is_global_write.expr(), + local_clk.expr(), + )?; + // TODO: enforce shard = shard_id in the public values + cb.read_rlc_record( + || "r_record", + ram_type.clone(), + record.clone(), + cb.rlc_chip_record(record.clone()), + )?; + cb.write_rlc_record( + || "w_record", + ram_type, + record.clone(), + cb.rlc_chip_record(record), + )?; + + // enforces final_sum = \sum_i (x_i, y_i) using ecc quark protocol + let final_sum = cb.query_global_rw_sum()?; + cb.ec_sum( + x.iter().map(|xi| xi.expr()).collect::>(), + y.iter().map(|yi| yi.expr()).collect::>(), + slope.iter().map(|si| si.expr()).collect::>(), + final_sum.into_iter().map(|x| x.expr()).collect::>(), + ); + + // enforces x = poseidon2([addr, ram_type, value[0], value[1], shard, global_clk, nonce, 0, ..., 0]) + for (input_expr, hasher_input) in input.into_iter().zip_eq(perm_config.inputs().into_iter()) + { + cb.require_equal(|| "poseidon2 input", input_expr, hasher_input)?; + } + for (xi, hasher_output) in x.iter().zip(perm_config.output().into_iter()) { + cb.require_equal(|| "x = poseidon2's output", xi.expr(), hasher_output)?; + } + + // both (x, y) and (x, -y) are valid ec points + // if is_global_write = 1, then y should be in [0, p/2) + // if is_global_write = 0, then y should be in [p/2, p) + + // TODO: enforce 0 <= y < p/2 if is_global_write = 1 + // enforce p/2 <= y < p if is_global_write = 0 + + Ok(GlobalConfig { + x, + y, + slope, + addr, + is_ram_register, + value, + shard, + global_clk, + local_clk, + nonce, + is_global_write, + perm_config, + }) + } +} + +/// This chip is used to manage read/write into a global set +/// shared among multiple shards +#[derive(Default)] +pub struct GlobalChip { + _marker: PhantomData, +} + +#[derive(Clone, Debug)] +pub struct GlobalChipInput { + pub record: GlobalRecord, + pub ec_point: GlobalPoint, +} + +impl GlobalChip { + fn assign_instance( + config: &GlobalConfig, + instance: &mut [E::BaseField], + _lk_multiplicity: &mut LkMultiplicity, + input: &GlobalChipInput, + ) -> Result<(), crate::error::ZKVMError> { + // assign basic fields + let record = &input.record; + let is_ram_register = match record.ram_type { + RAMType::Register => 1, + RAMType::Memory => 0, + _ => unreachable!(), + }; + set_val!(instance, config.addr, record.addr as u64); + set_val!(instance, config.is_ram_register, is_ram_register as u64); + let value = Value::new_unchecked(record.value); + config.value.assign_limbs(instance, value.as_u16_limbs()); + set_val!(instance, config.shard, record.shard); + set_val!(instance, config.global_clk, record.global_clk); + set_val!(instance, config.local_clk, record.local_clk); + set_val!( + instance, + config.is_global_write, + record.is_to_write_set as u64 + ); + + // assign (x, y) and nonce + let GlobalPoint { nonce, point } = &input.ec_point; + set_val!(instance, config.nonce, *nonce as u64); + config + .x + .iter() + .chain(config.y.iter()) + .zip_eq((point.x.deref()).iter().chain((point.y.deref()).iter())) + .for_each(|(witin, fe)| { + instance[witin.id as usize] = *fe; + }); + + let ram_type = E::BaseField::from_canonical_u32(record.ram_type as u32); + let mut input = [E::BaseField::ZERO; 16]; + + let k = UINT_LIMBS; + input[0] = E::BaseField::from_canonical_u32(record.addr); + input[1] = ram_type; + input[2..(k + 2)] + .iter_mut() + .zip(value.as_u16_limbs().iter()) + .for_each(|(i, v)| *i = E::BaseField::from_canonical_u16(*v)); + input[2 + k] = E::BaseField::from_canonical_u64(record.shard); + input[2 + k + 1] = E::BaseField::from_canonical_u64(record.global_clk); + input[2 + k + 2] = E::BaseField::from_canonical_u32(*nonce); + + config + .perm_config + // TODO: remove hardcoded constant 28 + .assign_instance(&mut instance[28 + UINT_LIMBS..], input); + + Ok(()) + } +} + +impl TableCircuit for GlobalChip { + type TableConfig = GlobalConfig; + type FixedInput = (); + type WitnessInput = Vec>; + + fn name() -> String { + "Global".to_string() + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + _param: &ProgramParams, + ) -> Result { + let config = GlobalConfig::configure(cb)?; + + Ok(config) + } + + fn build_gkr_iop_circuit( + cb: &mut CircuitBuilder, + param: &ProgramParams, + ) -> Result<(Self::TableConfig, Option>), crate::error::ZKVMError> { + // create three selectors: selector_r, selector_w, selector_zero + let selector_r = cb.create_structural_witin( + || "selector_r", + // this is just a placeholder, the actural type is SelectorType::Prefix() + EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: 0, + descending: false, + }, + ); + let selector_w = cb.create_structural_witin( + || "selector_w", + // this is just a placeholder, the actural type is SelectorType::Prefix() + EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: 0, + descending: false, + }, + ); + let selector_zero = cb.create_structural_witin( + || "selector_zero", + // this is just a placeholder, the actural type is SelectorType::Prefix() + EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: 0, + descending: false, + }, + ); + + let config = Self::construct_circuit(cb, param)?; + + let w_len = cb.cs.w_expressions.len(); + let r_len = cb.cs.r_expressions.len(); + let lk_len = cb.cs.lk_expressions.len(); + let zero_len = + cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); + + let selector_r = SelectorType::Prefix(selector_r.expr()); + // note that the actual offset should be set by prover + // depending on the number of local read instances + let selector_w = SelectorType::Prefix(selector_w.expr()); + // TODO: when selector_r = 1 => selector_zero = 1 + // when selector_w = 1 => selector_zero = 1 + let selector_zero = SelectorType::Prefix(selector_zero.expr()); + + cb.cs.r_selector = Some(selector_r); + cb.cs.w_selector = Some(selector_w); + cb.cs.zero_selector = Some(selector_zero.clone()); + cb.cs.lk_selector = Some(selector_zero); + + // all shared the same selector + let (out_evals, mut chip) = ( + [ + // r_record + (0..r_len).collect_vec(), + // w_record + (r_len..r_len + w_len).collect_vec(), + // lk_record + (r_len + w_len..r_len + w_len + lk_len).collect_vec(), + // zero_record + (0..zero_len).collect_vec(), + ], + Chip::new_from_cb(cb, 0), + ); + + let layer = Layer::from_circuit_builder(cb, format!("{}_main", Self::name()), 0, out_evals); + chip.add_layer(layer); + + Ok((config, Some(chip.gkr_circuit()))) + } + + fn generate_fixed_traces( + _config: &Self::TableConfig, + _num_fixed: usize, + _input: &Self::FixedInput, + ) -> witness::RowMajorMatrix<::BaseField> { + unimplemented!() + } + + /// steps format: local reads ++ local writes + fn assign_instances<'a>( + config: &Self::TableConfig, + num_witin: usize, + num_structural_witin: usize, + _multiplicity: &[HashMap], + steps: &Self::WitnessInput, + ) -> Result, ZKVMError> { + if steps.is_empty() { + return Ok([ + witness::RowMajorMatrix::empty(), + witness::RowMajorMatrix::empty(), + ]); + } + // FIXME selector is the only structural witness + // this is workaround, as call `construct_circuit` will not initialized selector + // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` + + assert_eq!(num_structural_witin, 3); + let selector_r_witin = WitIn { id: 0 }; + let selector_w_witin = WitIn { id: 1 }; + let selector_zero_witin = WitIn { id: 2 }; + + let nthreads = max_usable_threads(); + + // local read iff it's global write + let num_local_reads = steps + .iter() + .take_while(|s| s.record.is_to_write_set) + .count(); + tracing::debug!( + "{} local reads / {} local writes in global chip", + num_local_reads, + steps.len() - num_local_reads + ); + + let num_instance_per_batch = if steps.len() > 256 { + steps.len().div_ceil(nthreads) + } else { + steps.len() + } + .max(1); + + let n = next_pow2_instance_padding(steps.len()); + // compute the input for the binary tree for ec point summation + + let lk_multiplicity = LkMultiplicity::default(); + // *2 because we need to store the internal nodes of binary tree for ec point summation + let num_rows_padded = 2 * n; + + let mut raw_witin = { + let matrix_size = num_rows_padded * num_witin; + let mut value = Vec::with_capacity(matrix_size); + value.par_extend( + (0..matrix_size) + .into_par_iter() + .map(|_| E::BaseField::default()), + ); + RowMajorMatrix::new(value, num_witin) + }; + let mut raw_structual_witin = { + let matrix_size = num_rows_padded * num_structural_witin; + let mut value = Vec::with_capacity(matrix_size); + value.par_extend( + (0..matrix_size) + .into_par_iter() + .map(|_| E::BaseField::default()), + ); + RowMajorMatrix::new(value, num_structural_witin) + }; + let raw_witin_iter = raw_witin.values[0..steps.len() * num_witin] + .par_chunks_mut(num_instance_per_batch * num_witin); + let raw_structual_witin_iter = raw_structual_witin.values + [0..steps.len() * num_structural_witin] + .par_chunks_mut(num_instance_per_batch * num_structural_witin); + + raw_witin_iter + .zip_eq(raw_structual_witin_iter) + .zip_eq(steps.par_chunks(num_instance_per_batch)) + .enumerate() + .flat_map(|(chunk_idx, ((instances, structural_instance), steps))| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip_eq(structural_instance.chunks_mut(num_structural_witin)) + .zip_eq(steps) + .enumerate() + .map(|(i, ((instance, structural_instance), step))| { + let row = chunk_idx * num_instance_per_batch + i; + let (sel_r, sel_w) = if row < num_local_reads { + (E::BaseField::ONE, E::BaseField::ZERO) + } else { + (E::BaseField::ZERO, E::BaseField::ONE) + }; + set_val!(structural_instance, selector_r_witin, sel_r); + set_val!(structural_instance, selector_w_witin, sel_w); + set_val!(structural_instance, selector_zero_witin, E::BaseField::ONE); + Self::assign_instance(config, instance, &mut lk_multiplicity, step) + }) + .collect::>() + }) + .collect::>()?; + + // allocate num_rows_padded size, fill points on first half + let mut cur_layer_points_buffer: Vec<_> = (0..num_rows_padded) + .into_par_iter() + .map(|i| { + steps + .get(i) + .map(|step| step.ec_point.point.clone()) + .unwrap_or_else(SepticPoint::default) + }) + .collect(); + // raw_witin offset start from n. + // left node is at b, right node is at b + 1 + // op(left node, right node) = offset + b / 2 + let mut offset = num_rows_padded / 2; + let mut current_layer_len = cur_layer_points_buffer.len() / 2; + + // slope[1,b] = (input[b,0].y - input[b,1].y) / (input[b,0].x - input[b,1].x) + loop { + if current_layer_len <= 1 { + break; + } + let (current_layer, next_layer) = + cur_layer_points_buffer.split_at_mut(current_layer_len); + current_layer + .par_chunks(2) + .zip_eq(next_layer[..current_layer_len / 2].par_iter_mut()) + .zip(raw_witin.values[offset * num_witin..].par_chunks_mut(num_witin)) + .for_each(|((pair, parent), instance)| { + let p1 = &pair[0]; + let p2 = &pair[1]; + let (slope, q) = if p2.is_infinity { + // input[1,b] = bypass_left(input[b,0], input[b,1]) + (SepticExtension::zero(), p1.clone()) + } else { + // input[1,b] = affine_add(input[b,0], input[b,1]) + let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap(); + let q = p1.clone() + p2.clone(); + (slope, q) + }; + config + .x + .iter() + .chain(config.y.iter()) + .chain(config.slope.iter()) + .zip_eq(chain!( + q.x.deref().iter(), + q.y.deref().iter(), + slope.deref().iter(), + )) + .for_each(|(witin, fe)| { + set_val!(instance, *witin, *fe); + }); + *parent = q.clone(); + }); + cur_layer_points_buffer = cur_layer_points_buffer.split_off(current_layer_len); + current_layer_len /= 2; + offset += current_layer_len; + } + + let raw_witin = witness::RowMajorMatrix::new_by_inner_matrix( + raw_witin, + InstancePaddingStrategy::Default, + ); + let raw_structual_witin = witness::RowMajorMatrix::new_by_inner_matrix( + raw_structual_witin, + InstancePaddingStrategy::Default, + ); + Ok([raw_witin, raw_structual_witin]) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use ff_ext::{BabyBearExt4, FromUniformBytes, PoseidonField}; + use itertools::Itertools; + use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; + use p3::babybear::BabyBear; + use rand::thread_rng; + use tracing_forest::{ForestLayer, util::LevelFilter}; + use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; + use transcript::BasicTranscript; + + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::global::{GlobalChip, GlobalChipInput, GlobalRecord}, + scheme::{ + PublicValues, create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, + septic_curve::SepticPoint, verifier::ZKVMVerifier, + }, + structs::{ComposedConstrainSystem, PointAndEval, ProgramParams, RAMType, ZKVMProvingKey}, + tables::TableCircuit, + }; + use multilinear_extensions::mle::IntoMLE; + use p3::field::PrimeField32; + + type E = BabyBearExt4; + type F = BabyBear; + type Perm = ::P; + type Pcs = BasefoldDefault; + + #[test] + fn test_global_chip() { + // default filter + let default_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::DEBUG.into()) + .from_env_lossy(); + + Registry::default() + .with(ForestLayer::default()) + .with(default_filter) + .init(); + + // init global chip with horizen_rc_consts + let perm = ::get_default_perm(); + + let mut cs = ConstraintSystem::new(|| "global chip test"); + let mut cb = CircuitBuilder::new(&mut cs); + + let (config, gkr_circuit) = + GlobalChip::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + // create a bunch of random memory read/write records + let n_global_reads = 1700; + let n_global_writes = 1420; + let global_reads = (0..n_global_reads) + .map(|i| { + let addr = i * 8; + let value = (i + 1) * 8; + + GlobalRecord { + addr: addr as u32, + ram_type: RAMType::Memory, + value: value as u32, + shard: 0, + local_clk: 0, + global_clk: i, + is_to_write_set: false, + } + }) + .collect::>(); + + let global_writes = (0..n_global_writes) + .map(|i| { + let addr = i * 8; + let value = (i + 1) * 8; + + GlobalRecord { + addr: addr as u32, + ram_type: RAMType::Memory, + value: value as u32, + shard: 1, + local_clk: i, + global_clk: i, + is_to_write_set: true, + } + }) + .collect::>(); + + let input = global_writes // local reads + .into_iter() + .chain(global_reads) // local writes + .map(|record| { + let ec_point = record.to_ec_point::(&perm); + GlobalChipInput { record, ec_point } + }) + .collect::>(); + + let global_ec_sum: SepticPoint = input + .iter() + .map(|record| record.ec_point.point.clone()) + .sum(); + + let public_value = PublicValues::new( + 0, + 0, + 0, + 0, + 0, + 0, + vec![0], // dummy + global_ec_sum + .x + .iter() + .chain(global_ec_sum.y.iter()) + .map(|fe| fe.as_canonical_u32()) + .collect_vec(), + ); + + // assign witness + let witness = GlobalChip::assign_instances( + &config, + cs.num_witin as usize, + cs.num_structural_witin as usize, + &[], + &input, + ) + .unwrap(); + + let composed_cs = ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit, + }; + let pk = composed_cs.key_gen(); + + // create chip proof for global chip + let pcs_param = Pcs::setup(1 << 20, SecurityLevel::Conjecture100bits).unwrap(); + let (pp, vp) = Pcs::trim(pcs_param, 1 << 20).unwrap(); + let backend = create_backend::(20, SecurityLevel::Conjecture100bits); + let pd = create_prover(backend); + + let zkvm_pk = ZKVMProvingKey::new(pp, vp); + let zkvm_vk = zkvm_pk.get_vk_slow(); + let zkvm_prover = ZKVMProver::new(zkvm_pk, pd); + let mut transcript = BasicTranscript::new(b"global chip test"); + + let public_input_mles = public_value + .to_vec::() + .into_iter() + .map(|v| Arc::new(v.into_mle())) + .collect_vec(); + let proof_input = ProofInput { + witness: witness[0].to_mles().into_iter().map(Arc::new).collect(), + structural_witness: witness[1].to_mles().into_iter().map(Arc::new).collect(), + fixed: vec![], + public_input: public_input_mles.clone(), + num_instances: vec![n_global_writes as usize, n_global_reads as usize], + has_ecc_ops: true, + }; + let mut rng = thread_rng(); + let challenges = [E::random(&mut rng), E::random(&mut rng)]; + let (proof, _, point) = zkvm_prover + .create_chip_proof( + "global chip", + &pk, + proof_input, + &mut transcript, + &challenges, + ) + .unwrap(); + + let mut transcript = BasicTranscript::new(b"global chip test"); + let verifier = ZKVMVerifier::new(zkvm_vk); + let pi_evals = public_input_mles + .iter() + .map(|mle| mle.evaluate(&point[..mle.num_vars()])) + .collect_vec(); + let vrf_point = verifier + .verify_opcode_proof( + "global", + &pk.vk, + &proof, + &pi_evals, + &mut transcript, + 2, + &PointAndEval::default(), + &challenges, + ) + .expect("verify global chip proof"); + assert_eq!(vrf_point, point); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 4e3786235..d98412b6f 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -11,6 +11,7 @@ pub const END_PC_IDX: usize = 4; pub const END_CYCLE_IDX: usize = 5; pub const END_SHARD_ID_IDX: usize = 6; pub const PUBLIC_IO_IDX: usize = 7; +pub const GLOBAL_RW_SUM_IDX: usize = PUBLIC_IO_IDX + 2; pub const LIMB_BITS: usize = 16; pub const LIMB_MASK: u32 = 0xFFFF; diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 82a8d0c91..900672a3d 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -1,12 +1,13 @@ use crate::{ e2e::ShardContext, error::ZKVMError, + instructions::global::GlobalChip, structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ DynVolatileRamTable, HeapInitCircuit, HeapTable, HintsCircuit, LocalFinalCircuit, - MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RBCircuit, - RegTable, RegTableInitCircuit, StackInitCircuit, StackTable, StaticMemInitCircuit, - StaticMemTable, TableCircuit, + MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RegTable, + RegTableInitCircuit, StackInitCircuit, StackTable, StaticMemInitCircuit, StaticMemTable, + TableCircuit, }, }; use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; @@ -31,7 +32,7 @@ pub struct MmuConfig<'a, E: ExtensionField> { /// finalized circuit for all MMIO pub local_final_circuit: as TableCircuit>::TableConfig, /// ram bus to deal with cross shard read/write - pub ram_bus_circuit: as TableCircuit>::TableConfig, + pub ram_bus_circuit: as TableCircuit>::TableConfig, pub params: ProgramParams, } @@ -47,7 +48,7 @@ impl MmuConfig<'_, E> { let stack_init_config = cs.register_table_circuit::>(); let heap_init_config = cs.register_table_circuit::>(); let local_final_circuit = cs.register_table_circuit::>(); - let ram_bus_circuit = cs.register_table_circuit::>(); + let ram_bus_circuit = cs.register_table_circuit::>(); Self { reg_init_config, @@ -94,7 +95,7 @@ impl MmuConfig<'_, E> { fixed.register_table_circuit::>(cs, &self.stack_init_config, &()); fixed.register_table_circuit::>(cs, &self.heap_init_config, &()); fixed.register_table_circuit::>(cs, &self.local_final_circuit, &()); - fixed.register_table_circuit::>(cs, &self.ram_bus_circuit, &()); + // fixed.register_table_circuit::>(cs, &self.ram_bus_circuit, &()); } #[allow(clippy::too_many_arguments)] @@ -156,14 +157,13 @@ impl MmuConfig<'_, E> { .into_iter() .filter(|(_, record)| !record.is_empty()) .collect_vec(); - // take all mem result and + witness.assign_table_circuit::>( cs, &self.local_final_circuit, &(shard_ctx, all_records.as_slice()), )?; - - witness.assign_table_circuit::>(cs, &self.ram_bus_circuit, shard_ctx)?; + witness.assign_global_chip_circuit(cs, shard_ctx, &self.ram_bus_circuit)?; Ok(()) } diff --git a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs index e25ee972d..51bf0092a 100644 --- a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs +++ b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs @@ -30,7 +30,7 @@ use gkr_iop::{ layer::Layer, layer_constraint_system::{LayerConstraintSystem, expansion_expr}, }, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, utils::{indices_arr_with_offset, lk_multiplicity::LkMultiplicity, wits_fixed_and_eqs}, }; @@ -963,6 +963,14 @@ pub fn run_keccakf + 'stat }; let span = entered_span!("prove", profiling_1 = true); + let selector_ctxs = vec![ + SelectorContext::new(0, num_instances, log2_num_instances); + gkr_circuit + .layers + .first() + .map(|layer| layer.out_sel_and_eval_exprs.len()) + .unwrap() + ]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -972,7 +980,7 @@ pub fn run_keccakf + 'stat &[], &[], &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -993,7 +1001,7 @@ pub fn run_keccakf + 'stat &[], &[], &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 5b2c1867f..bb105899d 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -14,7 +14,7 @@ use gkr_iop::{ layer::Layer, mock::MockProver, }, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, utils::lk_multiplicity::LkMultiplicity, }; use itertools::{Itertools, iproduct, izip, zip_eq}; @@ -1228,6 +1228,7 @@ pub fn run_faster_keccakf } let span = entered_span!("create_proof", profiling_2 = true); + let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance_rounds); 3]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -1237,7 +1238,7 @@ pub fn run_faster_keccakf &[], &challenges, &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -1266,7 +1267,7 @@ pub fn run_faster_keccakf &[], &challenges, &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index 18e1a205b..76df2b06a 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -36,7 +36,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; @@ -141,11 +141,12 @@ impl WeierstrassAddAssignLayout { descending: false, }, ); + let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_mem_write: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_lookup: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_zero: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), + sel_mem_read: sel.clone(), + sel_mem_write: sel.clone(), + sel_lookup: sel.clone(), + sel_zero: sel.clone(), }; // Default expression, will be updated in build_layer_logic @@ -752,6 +753,7 @@ pub fn run_weierstrass_add< } let span = entered_span!("create_proof", profiling_2 = true); + let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -761,7 +763,7 @@ pub fn run_weierstrass_add< &[], &challenges, &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -786,7 +788,7 @@ pub fn run_weierstrass_add< &[], &challenges, &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index de03a829e..9f37a26c7 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -36,7 +36,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; @@ -159,11 +159,12 @@ impl descending: false, }, ); + let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_mem_write: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_lookup: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_zero: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), + sel_mem_read: sel.clone(), + sel_mem_write: sel.clone(), + sel_lookup: sel.clone(), + sel_zero: sel.clone(), }; let input32_exprs: GenericArray< @@ -732,6 +733,7 @@ pub fn run_weierstrass_decompress< } let span = entered_span!("create_proof", profiling_2 = true); + let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -741,7 +743,7 @@ pub fn run_weierstrass_decompress< &[], &challenges, &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -766,7 +768,7 @@ pub fn run_weierstrass_decompress< &[], &challenges, &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index 1260fae33..7f9a02997 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -36,7 +36,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; @@ -143,11 +143,12 @@ impl descending: false, }, ); + let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_mem_write: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_lookup: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_zero: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), + sel_mem_read: sel.clone(), + sel_mem_write: sel.clone(), + sel_lookup: sel.clone(), + sel_zero: sel.clone(), }; let input32_exprs: GenericArray< @@ -754,6 +755,7 @@ pub fn run_weierstrass_double< } let span = entered_span!("create_proof", profiling_2 = true); + let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -763,7 +765,7 @@ pub fn run_weierstrass_double< &[], &challenges, &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -788,7 +790,7 @@ pub fn run_weierstrass_double< &[], &challenges, &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index b36759d10..aa3928153 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -1,3 +1,4 @@ +use crate::structs::EccQuarkProof; use ff_ext::ExtensionField; use gkr_iop::gkr::GKRProof; use itertools::Itertools; @@ -29,6 +30,7 @@ pub mod cpu; pub mod gpu; pub mod hal; pub mod prover; +pub mod septic_curve; pub mod utils; pub mod verifier; @@ -58,8 +60,10 @@ pub struct ZKVMChipProof { pub gkr_iop_proof: Option>, pub tower_proof: TowerProofs, + pub ecc_proof: Option>, + + pub num_instances: Vec, - pub num_instances: usize, pub fixed_in_evals: Vec, pub wits_in_evals: Vec, } @@ -74,9 +78,11 @@ pub struct PublicValues { end_cycle: u64, shard_id: u32, public_io: Vec, + global_sum: Vec, } impl PublicValues { + #[allow(clippy::too_many_arguments)] pub fn new( exit_code: u32, init_pc: u32, @@ -85,6 +91,7 @@ impl PublicValues { end_cycle: u64, shard_id: u32, public_io: Vec, + global_sum: Vec, ) -> Self { Self { exit_code, @@ -94,6 +101,7 @@ impl PublicValues { end_cycle, shard_id, public_io, + global_sum, } } pub fn to_vec(&self) -> Vec> { @@ -124,6 +132,12 @@ impl PublicValues { }) .collect_vec(), ) + .chain( + self.global_sum + .iter() + .map(|value| vec![E::BaseField::from_canonical_u32(*value)]) + .collect_vec(), + ) .collect::>() } } @@ -197,7 +211,7 @@ impl> ZKVMProof { let halt_instance_count = self .chip_proofs .get(&halt_circuit_index) - .map_or(0, |proof| proof.num_instances); + .map_or(0, |proof| proof.num_instances.iter().sum()); if halt_instance_count > 0 { assert_eq!( halt_instance_count, 1, diff --git a/ceno_zkvm/src/scheme/constants.rs b/ceno_zkvm/src/scheme/constants.rs index 191fdf103..20687183e 100644 --- a/ceno_zkvm/src/scheme/constants.rs +++ b/ceno_zkvm/src/scheme/constants.rs @@ -6,3 +6,6 @@ pub const NUM_FANIN_LOGUP: usize = 2; pub const MAX_NUM_VARIABLES: usize = 24; pub const DYNAMIC_RANGE_MAX_BITS: usize = 18; + +pub const SEPTIC_EXTENSION_DEGREE: usize = 7; +pub const SEPTIC_JACOBIAN_NUM_MLES: usize = 3 * SEPTIC_EXTENSION_DEGREE; diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index b4972e3e2..cebe79899 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -5,14 +5,15 @@ use crate::{ circuit_builder::ConstraintSystem, error::ZKVMError, scheme::{ - constants::{NUM_FANIN, NUM_FANIN_LOGUP}, - hal::{DeviceProvingKey, MainSumcheckEvals, ProofInput, TowerProverSpec}, + constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE}, + hal::{DeviceProvingKey, EccQuarkProver, MainSumcheckEvals, ProofInput, TowerProverSpec}, + septic_curve::{SepticExtension, SepticPoint, SymbolicSepticExtension}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, masked_mle_split_to_chunks, wit_infer_by_expr, }, }, - structs::{ComposedConstrainSystem, PointAndEval, TowerProofs}, + structs::{ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs}, }; use either::Either; use ff_ext::ExtensionField; @@ -20,6 +21,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, gkr::{self, Evaluation, GKRProof, GKRProverOutput, layer::LayerWitness}, hal::ProverBackend, + selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, chain}; use mpcs::{Point, PolynomialCommitmentScheme}; @@ -30,7 +32,10 @@ use multilinear_extensions::{ virtual_poly::build_eq_x_r_vec, virtual_polys::VirtualPolynomialsBuilder, }; -use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, + IntoParallelRefMutIterator, ParallelIterator, +}; use std::{collections::BTreeMap, sync::Arc}; use sumcheck::{ macros::{entered_span, exit_span}, @@ -47,6 +52,258 @@ pub type TowerRelationOutput = ( Vec>, Vec>, ); + +// accumulate N=2^n EC points into one EC point using affine coordinates +// in one layer which borrows ideas from the [Quark paper](https://eprint.iacr.org/2020/1275.pdf) +pub struct CpuEccProver; + +impl CpuEccProver { + pub fn create_ecc_proof<'a, E: ExtensionField>( + num_instances: usize, + xs: Vec>>, + ys: Vec>>, + invs: Vec>>, + transcript: &mut impl Transcript, + ) -> EccQuarkProof { + assert_eq!(xs.len(), SEPTIC_EXTENSION_DEGREE); + assert_eq!(ys.len(), SEPTIC_EXTENSION_DEGREE); + + let n = xs[0].num_vars() - 1; + tracing::debug!( + "Creating EC Summation Quark proof with {} points in {n} variables", + num_instances + ); + + let out_rt = transcript.sample_and_append_vec(b"ecc", n); + let num_threads = optimal_sumcheck_threads(out_rt.len()); + + // expression with add (3 zero constrains) and bypass (2 zero constrains) + let alpha_pows = transcript.sample_and_append_challenge_pows( + SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2, + b"ecc_alpha", + ); + let mut alpha_pows_iter = alpha_pows.iter(); + + let mut expr_builder = VirtualPolynomialsBuilder::new(num_threads, out_rt.len()); + + let sel_add = SelectorType::QuarkBinaryTreeLessThan(0.into()); + let sel_add_ctx = SelectorContext { + offset: 0, + num_instances, + num_vars: n, + }; + let mut sel_add_mle: MultilinearExtension<'_, E> = + sel_add.compute(&out_rt, &sel_add_ctx).unwrap(); + // we construct sel_bypass witness here + // verifier can derive it via `sel_bypass = eq - sel_add - sel_last_onehot` + let mut sel_bypass_mle: Vec = build_eq_x_r_vec(&out_rt); + match sel_add_mle.evaluations() { + FieldType::Ext(sel_add_mle) => sel_add_mle + .par_iter() + .zip_eq(sel_bypass_mle.par_iter_mut()) + .for_each(|(sel_add, sel_bypass)| { + if *sel_add != E::ZERO { + *sel_bypass = E::ZERO; + } + }), + _ => unreachable!(), + } + *sel_bypass_mle.last_mut().unwrap() = E::ZERO; + let mut sel_bypass_mle = sel_bypass_mle.into_mle(); + let sel_add_expr = expr_builder.lift(sel_add_mle.to_either()); + let sel_bypass_expr = expr_builder.lift(sel_bypass_mle.to_either()); + + let mut exprs_add = vec![]; + let mut exprs_bypass = vec![]; + + let filter_bj = |v: &[Arc>], j: usize| { + v.iter() + .map(|v| { + v.get_base_field_vec() + .iter() + .enumerate() + .filter(|(i, _)| *i % 2 == j) + .map(|(_, v)| v) + .cloned() + .collect_vec() + .into_mle() + }) + .collect_vec() + }; + // build x[b,0], x[b,1], y[b,0], y[b,1] + let mut x0 = filter_bj(&xs, 0); + let mut y0 = filter_bj(&ys, 0); + let mut x1 = filter_bj(&xs, 1); + let mut y1 = filter_bj(&ys, 1); + // build x[1,b], y[1,b], s[1,b] + let mut x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + let mut y3 = ys.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + let mut s = invs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + + let s = SymbolicSepticExtension::new( + s.iter_mut() + .map(|s| expr_builder.lift(s.to_either())) + .collect(), + ); + let x0 = SymbolicSepticExtension::new( + x0.iter_mut() + .map(|x| expr_builder.lift(x.to_either())) + .collect(), + ); + let y0 = SymbolicSepticExtension::new( + y0.iter_mut() + .map(|y| expr_builder.lift(y.to_either())) + .collect(), + ); + let x1 = SymbolicSepticExtension::new( + x1.iter_mut() + .map(|x| expr_builder.lift(x.to_either())) + .collect(), + ); + let y1 = SymbolicSepticExtension::new( + y1.iter_mut() + .map(|y| expr_builder.lift(y.to_either())) + .collect(), + ); + let x3 = SymbolicSepticExtension::new( + x3.iter_mut() + .map(|x| expr_builder.lift(x.to_either())) + .collect(), + ); + let y3 = SymbolicSepticExtension::new( + y3.iter_mut() + .map(|y| expr_builder.lift(y.to_either())) + .collect(), + ); + // affine addition + // zerocheck: 0 = s[1,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) with b != (1,...,1) + exprs_add.extend( + (s.clone() * (&x0 - &x1) - (&y0 - &y1)) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + // zerocheck: 0 = s[1,b]^2 - x[b,0] - x[b,1] - x[1,b] with b != (1,...,1) + exprs_add.extend( + ((&s * &s) - &x0 - &x1 - &x3) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + // zerocheck: 0 = s[1,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) with b != (1,...,1) + exprs_add.extend( + (s.clone() * (&x0 - &x3) - (&y0 + &y3)) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + let exprs_add = exprs_add.into_iter().sum::>() * sel_add_expr; + + // deal with bypass + // 0 = (x[1,b] - x[b,0]) + exprs_bypass.extend( + (&x3 - &x0) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + // 0 = (y[1,b] - y[b,0]) + exprs_bypass.extend( + (&y3 - &y0) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + assert!(alpha_pows_iter.next().is_none()); + + let exprs_bypass = exprs_bypass.into_iter().sum::>() * sel_bypass_expr; + + let (zerocheck_proof, state) = IOPProverState::prove( + expr_builder.to_virtual_polys(&[exprs_add + exprs_bypass], &[]), + transcript, + ); + + let rt = state.collect_raw_challenges(); + let evals = state.get_mle_flatten_final_evaluations(); + + assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); + // 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[1,rt] + assert_eq!(evals.len(), 2 + SEPTIC_EXTENSION_DEGREE * 7); + + let last_evaluation_index = (1 << n) - 1; + let x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + let y3 = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec(); + let final_sum_x: SepticExtension = (x3.iter()) + .map(|x| x.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0] + .collect_vec() + .into(); + let final_sum_y: SepticExtension = (y3.iter()) + .map(|y| y.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0] + .collect_vec() + .into(); + let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y); + + #[cfg(feature = "sanity-check")] + { + let s = invs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + let x0 = filter_bj(&xs, 0); + let y0 = filter_bj(&ys, 0); + let x1 = filter_bj(&xs, 1); + let y1 = filter_bj(&ys, 1); + + let evals = &evals[2..]; + // check evaluations + for i in 0..SEPTIC_EXTENSION_DEGREE { + assert_eq!(s[i].evaluate(&rt), evals[i]); + assert_eq!(x0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE + i]); + assert_eq!(y0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 2 + i]); + assert_eq!(x1[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 3 + i]); + assert_eq!(y1[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 4 + i]); + assert_eq!(x3[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 5 + i]); + assert_eq!(y3[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 6 + i]); + } + } + + EccQuarkProof { + zerocheck_proof, + num_instances, + evals, + rt, + sum: final_sum, + } + } +} + +impl> EccQuarkProver> + for CpuProver> +{ + fn prove_ec_sum_quark<'a>( + &self, + num_instances: usize, + xs: Vec>>, + ys: Vec>>, + invs: Vec>>, + transcript: &mut impl Transcript, + ) -> Result, ZKVMError> { + Ok(CpuEccProver::create_ecc_proof( + num_instances, + xs, + ys, + invs, + transcript, + )) + } +} + pub struct CpuTowerProver; impl CpuTowerProver { @@ -59,7 +316,7 @@ impl CpuTowerProver { #[derive(Debug, Clone)] enum GroupedMLE<'a, E: ExtensionField> { Prod((usize, Vec>)), // usize is the index in prod_specs - Logup((usize, Vec>)), /* usize is the index in logup_specs */ + Logup((usize, Vec>)), // usize is the index in logup_specs } // XXX to sumcheck batched product argument with logup, we limit num_product_fanin to 2 @@ -311,8 +568,8 @@ impl> TowerProver> MainSumcheckProver> MainSumcheckProver> MainSumcheckProver> = { + // sample N = 2^n points + let mut points = (0..n_points) + .map(|_| SepticPoint::::random(&mut rng)) + .collect_vec(); + points.extend(repeat_n( + SepticPoint::point_at_infinity(), + (1 << log2_n) - points.len(), + )); + let mut s = Vec::with_capacity(1 << (log2_n + 1)); + s.extend(repeat_n(SepticExtension::zero(), 1 << log2_n)); + + for layer in (1..=log2_n).rev() { + let num_inputs = 1 << layer; + let inputs = &points[points.len() - num_inputs..]; + + s.extend(inputs.chunks_exact(2).map(|chunk| { + let p = &chunk[0]; + let q = &chunk[1]; + if q.is_infinity { + SepticExtension::zero() + } else { + (&p.y - &q.y) * (&p.x - &q.x).inverse().unwrap() + } + })); + + points.extend( + inputs + .chunks_exact(2) + .map(|chunk| { + let p = chunk[0].clone(); + let q = chunk[1].clone(); + p + q + }) + .collect_vec(), + ); + } + final_sum = points.last().cloned().unwrap(); + + // padding to 2*N + s.push(SepticExtension::zero()); + points.push(SepticPoint::point_at_infinity()); + + assert_eq!(s.len(), 1 << (log2_n + 1)); + assert_eq!(points.len(), 1 << (log2_n + 1)); + + // transform points to row major matrix + let trace = points + .iter() + .zip_eq(s.iter()) + .map(|(p, s)| { + p.x.iter() + .chain(p.y.iter()) + .chain(s.iter()) + .copied() + .collect_vec() + }) + .collect_vec(); + + // transpose row major matrix to column major matrix + transpose(trace) + .into_iter() + .map(|v| v.into_mle()) + .collect_vec() + }; + let (xs, rest) = ecc_spec.split_at(SEPTIC_EXTENSION_DEGREE); + let (ys, s) = rest.split_at(SEPTIC_EXTENSION_DEGREE); + + let mut transcript = BasicTranscript::new(b"test"); + let quark_proof = CpuEccProver::create_ecc_proof( + n_points, + xs.iter().cloned().map(Arc::new).collect_vec(), + ys.iter().cloned().map(Arc::new).collect_vec(), + s.iter().cloned().map(Arc::new).collect_vec(), + &mut transcript, + ); + + assert_eq!(quark_proof.sum, final_sum); + let mut transcript = BasicTranscript::new(b"test"); + assert!( + EccVerifier::verify_ecc_proof(&quark_proof, &mut transcript) + .inspect_err(|err| println!("err {:?}", err)) + .is_ok() + ); + } +} diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index 455b6786d..07e5adb4d 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -203,7 +203,7 @@ fn build_tower_witness_gpu<'buf, E: ExtensionField>( zkvm_v1_css: cs, .. } = composed_cs; let num_instances_with_rotation = - input.num_instances << composed_cs.rotation_vars().unwrap_or(0); + input.num_instances() << composed_cs.rotation_vars().unwrap_or(0); let chip_record_alpha = challenges[0]; // TODO: safety ? @@ -653,9 +653,7 @@ impl> MainSumcheckProver: + OpeningProver + DeviceTransporter + ProtocolWitnessGeneratorProver + + EccQuarkProver // + FixedMLEPadder where PB: ProverBackend, @@ -37,16 +38,30 @@ pub struct ProofInput<'a, PB: ProverBackend> { pub structural_witness: Vec>>, pub fixed: Vec>>, pub public_input: Vec>>, - pub num_instances: usize, + pub num_instances: Vec, + pub has_ecc_ops: bool, } impl<'a, PB: ProverBackend> ProofInput<'a, PB> { + pub fn num_instances(&self) -> usize { + self.num_instances.iter().sum() + } + #[inline] pub fn log2_num_instances(&self) -> usize { - ceil_log2(next_pow2_instance_padding(self.num_instances)) + let num_instance = self.num_instances(); + let log2 = ceil_log2(next_pow2_instance_padding(num_instance)); + if self.has_ecc_ops { + // the mles have one extra variable to store + // the internal partial sums for ecc additions + log2 + 1 + } else { + log2 + } } } +#[derive(Clone)] pub struct TowerProverSpec<'a, PB: ProverBackend> { pub witness: Vec>>, } @@ -65,6 +80,23 @@ pub trait TraceCommitter { ); } +/// Accumulate N (not necessarily power of 2) EC points into one EC point using affine coordinates +/// in one layer which borrows ideas from the [Quark paper](https://eprint.iacr.org/2020/1275.pdf) +/// Note that these points are defined over the septic extension field of BabyBear. +/// +/// The main constraint enforced in this quark layer is: +/// p[1,b] = affine_add(p[b,0], p[b,1]) for all b < N +pub trait EccQuarkProver { + fn prove_ec_sum_quark<'a>( + &self, + num_instances: usize, + xs: Vec>>, + ys: Vec>>, + invs: Vec>>, + transcript: &mut impl Transcript, + ) -> Result, ZKVMError>; +} + pub trait TowerProver { // infer read/write/logup records from the read/write/logup expressions and then // build multiple complete binary trees (tower tree) to accumulate these records diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index ace2215a1..187d2a708 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -9,12 +9,12 @@ use std::{ sync::Arc, }; -use crate::scheme::hal::MainSumcheckEvals; +use crate::scheme::{constants::SEPTIC_EXTENSION_DEGREE, hal::MainSumcheckEvals}; use gkr_iop::hal::MultilinearPolynomial; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - Instance, + Expression, Instance, mle::{IntoMLE, MultilinearExtension}, }; use p3::field::FieldAlgebra; @@ -118,17 +118,11 @@ impl< { // num_instance from witness might include rotation if let Some(num_instance) = witnesses - .get_opcode_witness(circuit_name) - .or_else(|| witnesses.get_table_witness(circuit_name)) - .map(|rmms| { - if rmms[0].num_instances() == 0 { - rmms[1].num_instances() - } else { - rmms[0].num_instances() - } - }) + .num_instances + .get(circuit_name) + .cloned() .and_then(|num_instance| { - if num_instance > 0 { + if num_instance.iter().sum::() > 0 { Some(num_instance) } else { None @@ -140,26 +134,28 @@ impl< .circuit_index_fixed_num_instances .get(&index) .copied() - .unwrap_or(0) + .map(|num_instance| vec![num_instance]) + .unwrap_or(vec![]) }) }) { - num_instances.push(( - index, - num_instance >> vk.get_cs().rotation_vars().unwrap_or(0), - )); + let num_instance_exclude_rotation = num_instance + .iter() + .map(|num_instance| num_instance >> vk.get_cs().rotation_vars().unwrap_or(0)) + .collect_vec(); + num_instances.push((index, num_instance_exclude_rotation.clone())); + circuit_name_num_instances_mapping + .insert(circuit_name, num_instance_exclude_rotation); num_instances_with_rotation.push((index, num_instance)); - circuit_name_num_instances_mapping.insert( - circuit_name, - num_instance >> vk.get_cs().rotation_vars().unwrap_or(0), - ); } } // write (circuit_idx, num_var) to transcript for (circuit_idx, num_instance) in &num_instances { transcript.append_message(&circuit_idx.to_le_bytes()); - transcript.append_message(&num_instance.to_le_bytes()); + for num_instance in num_instance { + transcript.append_message(&num_instance.to_le_bytes()); + } } let commit_to_traces_span = entered_span!("batch commit to traces", profiling_1 = true); @@ -216,10 +212,10 @@ impl< |(mut points, mut evaluations), (index, (circuit_name, pk))| { let num_instances = circuit_name_num_instances_mapping .get(&circuit_name) - .copied() - .unwrap_or(0); + .cloned() + .unwrap_or_default(); let cs = pk.get_cs(); - if num_instances == 0 { + if num_instances.is_empty() { // we need to drain respective fixed when num_instances is 0 if cs.num_fixed() > 0 { let _ = fixed_mles.drain(..cs.num_fixed()).collect_vec(); @@ -249,7 +245,8 @@ impl< fixed, structural_witness, public_input: public_input.clone(), - num_instances, + num_instances: num_instances.clone(), + has_ecc_ops: cs.has_ecc_ops(), }; if cs.is_opcode_circuit() { @@ -261,7 +258,7 @@ impl< &challenges, )?; tracing::trace!( - "generated proof for opcode {} with num_instances={}", + "generated proof for opcode {} with num_instances={:?}", circuit_name, num_instances ); @@ -341,7 +338,38 @@ impl< let log2_num_instances = input.log2_num_instances(); let num_var_with_rotation = log2_num_instances + cs.rotation_vars().unwrap_or(0); - // println!("create_chip_proof: {}", name); + // run ecc quark prover + let ecc_proof = if !cs.zkvm_v1_css.ec_final_sum.is_empty() { + let ec_point_exprs = &cs.zkvm_v1_css.ec_point_exprs; + assert_eq!(ec_point_exprs.len(), SEPTIC_EXTENSION_DEGREE * 2); + let mut xs_ys = ec_point_exprs + .iter() + .map(|expr| match expr { + Expression::WitIn(id) => input.witness[*id as usize].clone(), + _ => unreachable!("ec point's expression must be WitIn"), + }) + .collect_vec(); + let ys = xs_ys.split_off(SEPTIC_EXTENSION_DEGREE); + let xs = xs_ys; + let slopes = cs + .zkvm_v1_css + .ec_slope_exprs + .iter() + .map(|expr| match expr { + Expression::WitIn(id) => input.witness[*id as usize].clone(), + _ => unreachable!("slope's expression must be WitIn"), + }) + .collect_vec(); + Some(self.device.prove_ec_sum_quark( + input.num_instances(), + xs, + ys, + slopes, + transcript, + )?) + } else { + None + }; // build main witness let (records, is_padded) = @@ -360,6 +388,15 @@ impl< num_var_with_rotation, ); + // TODO: batch reduction into main sumcheck + // x[rt,0] = \sum_b eq([rt,0], b) * x[b] + // x[rt,1] = \sum_b eq([rt,1], b) * x[b] + // x[1,rt] = \sum_b eq([1,rt], b) * x[b] + // y[rt,0] = \sum_b eq([rt,0], b) * y[b] + // y[rt,1] = \sum_b eq([rt,1], b) * y[b] + // y[1,rt] = \sum_b eq([1,rt], b) * y[b] + // s[0,rt] = \sum_b eq([0,rt], b) * s[b] + // 1. prove the main constraints among witness polynomials // 2. prove the relation between last layer in the tower and read/write/logup records let span = entered_span!("prove_main_constraints", profiling_2 = true); @@ -394,6 +431,7 @@ impl< main_sumcheck_proofs, gkr_iop_proof, tower_proof, + ecc_proof, fixed_in_evals, wits_in_evals, num_instances: input.num_instances, diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs new file mode 100644 index 000000000..f9b6b4f76 --- /dev/null +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -0,0 +1,1174 @@ +use either::Either; +use ff_ext::{ExtensionField, FromUniformBytes}; +use multilinear_extensions::Expression; +// The extension field and curve definition are adapted from +// https://github.com/succinctlabs/sp1/blob/v5.2.1/crates/stark/src/septic_curve.rs +use p3::field::{Field, FieldAlgebra}; +use rand::RngCore; +use serde::{Deserialize, Serialize}; +use std::{ + iter::Sum, + ops::{Add, Deref, Mul, MulAssign, Neg, Sub}, +}; + +/// F[z] / (z^6 - z - 4) +/// +/// ```sage +/// # finite field F = GF(2^31 - 2^27 + 1) +/// p = 2^31 - 2^27 + 1 +/// F = GF(p) +/// +/// # polynomial ring over F +/// R. = PolynomialRing(F) +/// f = x^6 - x - 4 +/// +/// # check if f(x) is irreducible +/// print(f.is_irreducible()) +/// ``` +pub struct SexticExtension([F; 6]); + +/// F[z] / (z^7 - 2z - 5) +/// +/// ```sage +/// # finite field F = GF(2^31 - 2^27 + 1) +/// p = 2^31 - 2^27 + 1 +/// F = GF(p) +/// +/// # polynomial ring over F +/// R. = PolynomialRing(F) +/// f = x^7 - 2x - 5 +/// +/// # check if f(x) is irreducible +/// print(f.is_irreducible()) +/// ``` +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub struct SepticExtension(pub [F; 7]); + +impl From<&[F]> for SepticExtension { + fn from(slice: &[F]) -> Self { + assert!(slice.len() == 7); + let mut arr = [F::default(); 7]; + arr.copy_from_slice(&slice[0..7]); + Self(arr) + } +} + +impl From> for SepticExtension { + fn from(v: Vec) -> Self { + assert!(v.len() == 7); + let mut arr = [F::default(); 7]; + arr.copy_from_slice(&v[0..7]); + Self(arr) + } +} + +impl Deref for SepticExtension { + type Target = [F]; + + fn deref(&self) -> &[F] { + &self.0 + } +} + +impl SepticExtension { + pub fn is_zero(&self) -> bool { + self.0.iter().all(|c| *c == F::ZERO) + } + + pub fn zero() -> Self { + Self([F::ZERO; 7]) + } + + pub fn one() -> Self { + let mut arr = [F::ZERO; 7]; + arr[0] = F::ONE; + Self(arr) + } + + // returns z^{i*p} for i = 0..6 + // + // The sage script to compute z^{i*p} is as follows: + // ```sage + // p = 2^31 - 2^27 + 1 + // Fp = GF(p) + // R. = PolynomialRing(Fp) + // mod_poly = z^7 - 2*z - 5 + // Q = R.quotient(mod_poly) + // + // # compute z^(i*p) for i = 1..6 + // for k in range(1, 7): + // power = k * p + // z_power = Q(z)^power + // print(f"z^({k}*p) = {z_power}") + // ``` + fn z_pow_p(i: usize) -> Self { + match i { + 0 => [1, 0, 0, 0, 0, 0, 0].into(), + 1 => [ + 954599710, 1359279693, 566669999, 1982781815, 1735718361, 1174868538, 1120871770, + ] + .into(), + 2 => [ + 862825265, 597046311, 978840770, 1790138282, 1044777201, 835869808, 1342179023, + ] + .into(), + 3 => [ + 596273169, 658837454, 1515468261, 367059247, 781278880, 1544222616, 155490465, + ] + .into(), + 4 => [ + 557608863, 1173670028, 1749546888, 1086464137, 803900099, 1288818584, 1184677604, + ] + .into(), + 5 => [ + 763416381, 1252567168, 628856225, 1771903394, 650712211, 19417363, 57990258, + ] + .into(), + 6 => [ + 1734711039, 1749813853, 1227235221, 1707730636, 424560395, 1007029514, 498034669, + ] + .into(), + _ => unimplemented!("i should be in [0, 7]"), + } + } + + // returns z^{i*p^2} for i = 0..6 + // we can change the above sage script to compute z^{i*p^2} by replacing + // `power = k * p` with `power = k * p * p` + fn z_pow_p_square(i: usize) -> Self { + match i { + 0 => [1, 0, 0, 0, 0, 0, 0].into(), + 1 => [ + 1013489358, 1619071628, 304593143, 1949397349, 1564307636, 327761151, 415430835, + ] + .into(), + 2 => [ + 209824426, 1313900768, 38410482, 256593180, 1708830551, 1244995038, 1555324019, + ] + .into(), + 3 => [ + 1475628651, 777565847, 704492386, 1218528120, 1245363405, 475884575, 649166061, + ] + .into(), + 4 => [ + 550038364, 948935655, 68722023, 1251345762, 1692456177, 1177958698, 350232928, + ] + .into(), + 5 => [ + 882720258, 821925756, 199955840, 812002876, 1484951277, 1063138035, 491712810, + ] + .into(), + 6 => [ + 738287111, 1955364991, 552724293, 1175775744, 341623997, 1454022463, 408193320, + ] + .into(), + _ => unimplemented!("i should be in [0, 7]"), + } + } + + // returns self^p = (a0 + a1*z^p + ... + a6*z^(6p)) + pub fn frobenius(&self) -> Self { + Self::z_pow_p(0) * self.0[0] + + Self::z_pow_p(1) * self.0[1] + + Self::z_pow_p(2) * self.0[2] + + Self::z_pow_p(3) * self.0[3] + + Self::z_pow_p(4) * self.0[4] + + Self::z_pow_p(5) * self.0[5] + + Self::z_pow_p(6) * self.0[6] + } + + // returns self^(p^2) = (a0 + a1*z^(p^2) + ... + a6*z^(6*p^2)) + pub fn double_frobenius(&self) -> Self { + Self::z_pow_p_square(0) * self.0[0] + + Self::z_pow_p_square(1) * self.0[1] + + Self::z_pow_p_square(2) * self.0[2] + + Self::z_pow_p_square(3) * self.0[3] + + Self::z_pow_p_square(4) * self.0[4] + + Self::z_pow_p_square(5) * self.0[5] + + Self::z_pow_p_square(6) * self.0[6] + } + + // returns self^(p + p^2 + ... + p^6) + fn norm_sub(&self) -> Self { + let a = self.frobenius() * self.double_frobenius(); + let b = a.double_frobenius(); + let c = b.double_frobenius(); + + a * b * c + } + + // norm = self^(1 + p + ... + p^6) + // = self^((p^7-1)/(p-1)) + // it's a field element in F since norm^p = norm + fn norm(&self) -> F { + (self.norm_sub() * self).0[0] + } + + pub fn is_square(&self) -> bool { + // since a^((p^7 - 1)/2) = norm(a)^((p-1)/2) + // to test if self^((p^7 - 1) / 2) == 1? + // we can just test if norm(a)^((p-1)/2) == 1? + let exp_digits = ((F::order() - 1u32) / 2u32).to_u64_digits(); + debug_assert!(exp_digits.len() == 1); + let exp = exp_digits[0]; + + self.norm().exp_u64(exp) == F::ONE + } + + pub fn inverse(&self) -> Option { + match self.is_zero() { + true => None, + false => { + // since norm(a)^(-1) * a^(p + p^2 + ... + p^6) * a = 1 + // it's easy to see a^(-1) = norm(a)^(-1) * a^(p + p^2 + ... + p^6) + let x = self.norm_sub(); + let norm = (self * &x).0[0]; + // since self is not zero, norm is not zero + let norm_inv = norm.try_inverse().unwrap(); + + Some(x * norm_inv) + } + } + } + + pub fn square(&self) -> Self { + let mut result = [F::ZERO; 7]; + let two = F::from_canonical_u32(2); + let five = F::from_canonical_u32(5); + + // i < j + for i in 0..7 { + for j in (i + 1)..7 { + let term = two * self.0[i] * self.0[j]; + let mut index = i + j; + if index < 7 { + result[index] += term; + } else { + index -= 7; + // x^7 = 2x + 5 + result[index] += five * term; + result[index + 1] += two * term; + } + } + } + // i == j: i \in [0, 3] + result[0] += self.0[0] * self.0[0]; + result[2] += self.0[1] * self.0[1]; + result[4] += self.0[2] * self.0[2]; + result[6] += self.0[3] * self.0[3]; + // a4^2 * x^8 = a4^2 * (2x + 5)x = 5a4^2 * x + 2a4^2 * x^2 + let term = self.0[4] * self.0[4]; + result[1] += five * term; + result[2] += two * term; + // a5^2 * x^10 = a5^2 * (2x + 5)x^3 = 5a5^2 * x^3 + 2a5^2 * x^4 + let term = self.0[5] * self.0[5]; + result[3] += five * term; + result[4] += two * term; + // a6^2 * x^12 = a6^2 * (2x + 5)x^5 = 5a6^2 * x^5 + 2a6^2 * x^6 + let term = self.0[6] * self.0[6]; + result[5] += five * term; + result[6] += two * term; + + Self(result) + } + + pub fn pow(&self, exp: u64) -> Self { + let mut result = Self::one(); + let num_bits = 64 - exp.leading_zeros(); + for j in (0..num_bits).rev() { + result = result.square(); + if (exp >> j) & 1u64 == 1u64 { + result = result * self; + } + } + result + } + + pub fn sqrt(&self) -> Option { + // the algorithm is adapted from [Cipolla's algorithm](https://en.wikipedia.org/wiki/Cipolla%27s_algorithm + // the code is taken from https://github.com/succinctlabs/sp1/blob/dev/crates/stark/src/septic_extension.rs#L623 + let n = self.clone(); + + if n == Self::zero() || n == Self::one() { + return Some(n); + } + + // norm = n^(1 + p + ... + p^6) = n^(p^7-1)/(p-1) + let norm = n.norm(); + let exp = ((F::order() - 1u32) / 2u32).to_u64_digits()[0]; + // euler's criterion n^((p^7-1)/2) == 1 iff n is quadratic residue + if norm.exp_u64(exp) != F::ONE { + // it's not a square + return None; + }; + + // n_power = n^((p+1)/2) + let exp = ((F::order() + 1u32) / 2u32).to_u64_digits()[0]; + let n_power = self.pow(exp); + + // n^((p^2 + p)/2) + let mut n_frobenius = n_power.frobenius(); + let mut denominator = n_frobenius.clone(); + + // n^((p^4 + p^3)/2) + n_frobenius = n_frobenius.double_frobenius(); + denominator *= n_frobenius.clone(); + // n^((p^6 + p^5)/2) + n_frobenius = n_frobenius.double_frobenius(); + // d = n^((p^6 + p^5 + p^4 + p^3 + p^2 + p) / 2) + // d^2 * n = norm + denominator *= n_frobenius; + // d' = d*n + denominator *= n; + + let base = norm.inverse(); // norm^(-1) + let g = F::GENERATOR; + let mut a = F::ONE; + let mut non_residue = F::ONE - base; + let legendre_exp = (F::order() - 1u32) / 2u32; // (p-1)/2 + + // non_residue = a^2 - 1/norm + // find `a` such that non_residue is not a square in F + while non_residue.exp_u64(legendre_exp.to_u64_digits()[0]) == F::ONE { + a *= g; + non_residue = a.square() - base; + } + + // (p+1)/2 + let cipolla_exp = ((F::order() + 1u32) / 2u32).to_u64_digits()[0]; + // x = (a+i)^((p+1)/2) where a in Fp + // x^2 = (a+i) * (a+i)^p = (a+i)*(a-i) = a^2 - i^2 + // = a^2 - non_residue = 1/norm + // therefore, x is the square root of 1/norm + let mut x = QuadraticExtension::new(a, F::ONE, non_residue); + x = x.pow(cipolla_exp); + + // (x*d')^2 = x^2 * d^2 * n^2 = 1/norm * norm * n + Some(denominator * x.real) + } +} + +// a + bi where i^2 = non_residue +#[derive(Clone, Debug)] +pub struct QuadraticExtension { + pub real: F, + pub imag: F, + pub non_residue: F, +} + +impl QuadraticExtension { + pub fn new(real: F, imag: F, non_residue: F) -> Self { + Self { + real, + imag, + non_residue, + } + } + + pub fn square(&self) -> Self { + // (a + bi)^2 = (a^2 + b^2*i^2) + 2ab*i + let real = self.real * self.real + self.non_residue * self.imag * self.imag; + let mut imag = self.real * self.imag; + imag += imag; + + Self { + real, + imag, + non_residue: self.non_residue, + } + } + + pub fn mul(&self, other: &Self) -> Self { + // (a + bi)(c + di) = (ac + bd*i^2) + (ad + bc)i + let real = self.real * other.real + self.non_residue * self.imag * other.imag; + let imag = self.real * other.imag + self.imag * other.real; + + Self { + real, + imag, + non_residue: self.non_residue, + } + } + + pub fn pow(&self, exp: u64) -> Self { + let mut result = Self { + real: F::ONE, + imag: F::ZERO, + non_residue: self.non_residue, + }; + + let num_bits = 64 - exp.leading_zeros(); + for j in (0..num_bits).rev() { + result = result.square(); + if (exp >> j) & 1u64 == 1u64 { + result = result.mul(self); + } + } + + result + } +} + +impl SepticExtension { + pub fn random(mut rng: impl RngCore) -> Self { + let mut arr = [F::ZERO; 7]; + for item in arr.iter_mut() { + *item = F::random(&mut rng); + } + Self(arr) + } +} + +impl From<[u32; 7]> for SepticExtension { + fn from(arr: [u32; 7]) -> Self { + let mut result = [F::ZERO; 7]; + for i in 0..7 { + result[i] = F::from_canonical_u32(arr[i]); + } + Self(result) + } +} + +impl Add<&Self> for SepticExtension { + type Output = SepticExtension; + + fn add(self, other: &Self) -> Self { + let mut result = [F::ZERO; 7]; + for (i, res) in result.iter_mut().enumerate() { + *res = self.0[i] + other.0[i]; + } + Self(result) + } +} + +impl Add for &SepticExtension { + type Output = SepticExtension; + + fn add(self, other: Self) -> SepticExtension { + let mut result = [F::ZERO; 7]; + for (i, res) in result.iter_mut().enumerate() { + *res = self.0[i] + other.0[i]; + } + SepticExtension(result) + } +} + +impl Add for SepticExtension { + type Output = Self; + + fn add(self, other: Self) -> Self { + self.add(&other) + } +} + +impl Neg for SepticExtension { + type Output = Self; + + fn neg(self) -> Self { + let mut result = [F::ZERO; 7]; + for (res, src) in result.iter_mut().zip(self.0.iter()) { + *res = -(*src); + } + Self(result) + } +} + +impl Sub<&Self> for SepticExtension { + type Output = SepticExtension; + + fn sub(self, other: &Self) -> Self { + let mut result = [F::ZERO; 7]; + for (i, res) in result.iter_mut().enumerate() { + *res = self.0[i] - other.0[i]; + } + Self(result) + } +} + +impl Sub for &SepticExtension { + type Output = SepticExtension; + + fn sub(self, other: Self) -> SepticExtension { + let mut result = [F::ZERO; 7]; + for (i, res) in result.iter_mut().enumerate() { + *res = self.0[i] - other.0[i]; + } + SepticExtension(result) + } +} + +impl Sub for SepticExtension { + type Output = Self; + + fn sub(self, other: Self) -> Self { + self.sub(&other) + } +} + +impl Add for &SepticExtension { + type Output = SepticExtension; + + fn add(self, other: F) -> Self::Output { + let mut result = self.clone(); + result.0[0] += other; + + result + } +} + +impl Add for SepticExtension { + type Output = SepticExtension; + + fn add(self, other: F) -> Self::Output { + (&self).add(other) + } +} + +impl Mul for &SepticExtension { + type Output = SepticExtension; + + fn mul(self, other: F) -> Self::Output { + let mut result = [F::ZERO; 7]; + for (i, res) in result.iter_mut().enumerate() { + *res = self.0[i] * other; + } + SepticExtension(result) + } +} + +impl Mul for SepticExtension { + type Output = SepticExtension; + + fn mul(self, other: F) -> Self::Output { + (&self).mul(other) + } +} + +impl Mul for &SepticExtension { + type Output = SepticExtension; + + fn mul(self, other: Self) -> Self::Output { + let mut result = [F::ZERO; 7]; + let five = F::from_canonical_u32(5); + let two = F::from_canonical_u32(2); + for i in 0..7 { + for j in 0..7 { + let term = self.0[i] * other.0[j]; + let mut index = i + j; + if index < 7 { + result[index] += term; + } else { + index -= 7; + // x^7 = 2x + 5 + result[index] += five * term; + result[index + 1] += two * term; + } + } + } + SepticExtension(result) + } +} + +impl Mul for SepticExtension { + type Output = Self; + + fn mul(self, other: Self) -> Self { + (&self).mul(&other) + } +} + +impl Mul<&Self> for SepticExtension { + type Output = Self; + + fn mul(self, other: &Self) -> Self { + (&self).mul(other) + } +} + +impl MulAssign for SepticExtension { + fn mul_assign(&mut self, other: Self) { + *self = (&*self).mul(&other); + } +} + +#[derive(Clone, Debug)] +pub struct SymbolicSepticExtension(pub Vec>); + +impl SymbolicSepticExtension { + pub fn mul_scalar(&self, scalar: Either) -> Self { + let res = self + .0 + .iter() + .map(|a| a.clone() * Expression::Constant(scalar)) + .collect(); + + SymbolicSepticExtension(res) + } + + pub fn add_scalar(&self, scalar: Either) -> Self { + let res = self + .0 + .iter() + .map(|a| a.clone() + Expression::Constant(scalar)) + .collect(); + + SymbolicSepticExtension(res) + } +} + +impl Add for &SymbolicSepticExtension { + type Output = SymbolicSepticExtension; + + fn add(self, other: Self) -> Self::Output { + let res = self + .0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| a.clone() + b.clone()) + .collect(); + + SymbolicSepticExtension(res) + } +} + +impl Add<&Self> for SymbolicSepticExtension { + type Output = Self; + + fn add(self, other: &Self) -> Self { + (&self).add(other) + } +} + +impl Add for SymbolicSepticExtension { + type Output = Self; + + fn add(self, other: Self) -> Self { + (&self).add(&other) + } +} + +impl Sub for &SymbolicSepticExtension { + type Output = SymbolicSepticExtension; + + fn sub(self, other: Self) -> Self::Output { + let res = self + .0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| a.clone() - b.clone()) + .collect(); + + SymbolicSepticExtension(res) + } +} + +impl Sub<&Self> for SymbolicSepticExtension { + type Output = Self; + + fn sub(self, other: &Self) -> Self { + (&self).sub(other) + } +} + +impl Sub for SymbolicSepticExtension { + type Output = Self; + + fn sub(self, other: Self) -> Self { + (&self).sub(&other) + } +} + +impl Mul for &SymbolicSepticExtension { + type Output = SymbolicSepticExtension; + + fn mul(self, other: Self) -> Self::Output { + let mut result = vec![Expression::Constant(Either::Left(E::BaseField::ZERO)); 7]; + let five = Expression::Constant(Either::Left(E::BaseField::from_canonical_u32(5))); + let two = Expression::Constant(Either::Left(E::BaseField::from_canonical_u32(2))); + + for i in 0..7 { + for j in 0..7 { + let term = self.0[i].clone() * other.0[j].clone(); + let mut index = i + j; + if index < 7 { + result[index] += term; + } else { + index -= 7; + // x^7 = 2x + 5 + result[index] += five.clone() * term.clone(); + result[index + 1] += two.clone() * term.clone(); + } + } + } + SymbolicSepticExtension(result) + } +} + +impl Mul<&Self> for SymbolicSepticExtension { + type Output = Self; + + fn mul(self, other: &Self) -> Self { + (&self).mul(other) + } +} + +impl Mul for SymbolicSepticExtension { + type Output = Self; + + fn mul(self, other: Self) -> Self { + (&self).mul(&other) + } +} + +impl Mul<&Expression> for SymbolicSepticExtension { + type Output = SymbolicSepticExtension; + + fn mul(self, other: &Expression) -> Self::Output { + let res = self.0.iter().map(|a| a.clone() * other.clone()).collect(); + SymbolicSepticExtension(res) + } +} + +impl Mul> for SymbolicSepticExtension { + type Output = SymbolicSepticExtension; + + fn mul(self, other: Expression) -> Self::Output { + self.mul(&other) + } +} + +impl SymbolicSepticExtension { + pub fn new(exprs: Vec>) -> Self { + assert!( + exprs.len() == 7, + "exprs length must be 7, but got {}", + exprs.len() + ); + Self(exprs) + } + + pub fn to_exprs(&self) -> Vec> { + self.0.clone() + } +} + +/// A point on the short Weierstrass curve defined by +/// y^2 = x^3 + 2x + 26z^5 +/// over the extension field F[z] / (z^7 - 2z - 5). +/// +/// Note that +/// 1. The curve's cofactor is 1 +/// 2. The curve's order is a large prime number of 31x7 bits +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub struct SepticPoint { + pub x: SepticExtension, + pub y: SepticExtension, + pub is_infinity: bool, +} + +impl SepticPoint { + // if there exists y such that (x, y) is on the curve, return one of them + pub fn from_x(x: SepticExtension) -> Option { + let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); + let a: F = F::from_canonical_u32(2); + + let y2 = x.square() * &x + (&x * a) + &b; + if y2.is_square() { + let y = y2.sqrt().unwrap(); + + Some(Self { + x, + y, + is_infinity: false, + }) + } else { + None + } + } + + pub fn from_affine(x: SepticExtension, y: SepticExtension) -> Self { + let is_infinity = x.is_zero() && y.is_zero(); + + Self { x, y, is_infinity } + } + pub fn double(&self) -> Self { + let a = F::from_canonical_u32(2); + let three = F::from_canonical_u32(3); + let two = F::from_canonical_u32(2); + + let x1 = &self.x; + let y1 = &self.y; + let x1_sqr = x1.square(); + + // x3 = (3*x1^2 + a)^2 / (2*y1)^2 - x1 - x1 + let slope = (x1_sqr * three + a) * (y1 * two).inverse().unwrap(); + let x3 = slope.square() - x1 - x1; + // y3 = slope * (x1 - x3) - y1 + let y3 = slope * (x1 - &x3) - y1; + + Self { + x: x3, + y: y3, + is_infinity: false, + } + } +} + +impl Default for SepticPoint { + fn default() -> Self { + Self { + x: SepticExtension::zero(), + y: SepticExtension::zero(), + is_infinity: true, + } + } +} + +impl Neg for SepticPoint { + type Output = SepticPoint; + + fn neg(self) -> Self::Output { + if self.is_infinity { + return self; + } + + Self { + x: self.x, + y: -self.y, + is_infinity: false, + } + } +} + +impl Add for SepticPoint { + type Output = Self; + + fn add(self, other: Self) -> Self { + if self.is_infinity { + return other; + } + + if other.is_infinity { + return self; + } + + if self.x == other.x { + if self.y == other.y { + return self.double(); + } else { + assert!((self.y + other.y).is_zero()); + + return Self { + x: SepticExtension::zero(), + y: SepticExtension::zero(), + is_infinity: true, + }; + } + } + + let slope = (other.y - &self.y) * (other.x.clone() - &self.x).inverse().unwrap(); + let x = slope.square() - (&self.x + &other.x); + let y = slope * (self.x - &x) - self.y; + + Self { + x, + y, + is_infinity: false, + } + } +} + +impl Sum for SepticPoint { + fn sum>(iter: I) -> Self { + iter.fold(Self::default(), |acc, p| acc + p) + } +} + +impl SepticPoint { + pub fn is_on_curve(&self) -> bool { + if self.is_infinity && self.x.is_zero() && self.y.is_zero() { + return true; + } + + let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); + let a: F = F::from_canonical_u32(2); + + self.y.square() == self.x.square() * &self.x + (&self.x * a) + b + } + + pub fn point_at_infinity() -> Self { + Self::default() + } +} + +impl SepticPoint { + pub fn random(mut rng: impl RngCore) -> Self { + loop { + let x = SepticExtension::random(&mut rng); + if let Some(point) = Self::from_x(x) { + return point; + } + } + } +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct SepticJacobianPoint { + pub x: SepticExtension, + pub y: SepticExtension, + pub z: SepticExtension, +} + +impl From> for SepticJacobianPoint { + fn from(p: SepticPoint) -> Self { + if p.is_infinity { + Self::default() + } else { + Self { + x: p.x, + y: p.y, + z: SepticExtension::one(), + } + } + } +} + +impl Default for SepticJacobianPoint { + fn default() -> Self { + // return the point at infinity + Self { + x: SepticExtension::zero(), + y: SepticExtension::one(), + z: SepticExtension::zero(), + } + } +} + +impl SepticJacobianPoint { + pub fn point_at_infinity() -> Self { + Self::default() + } + + pub fn is_on_curve(&self) -> bool { + if self.z.is_zero() { + return self.x.is_zero() && !self.y.is_zero(); + } + + let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); + let a: F = F::from_canonical_u32(2); + + let z2 = self.z.square(); + let z4 = z2.square(); + let z6 = &z4 * &z2; + + // y^2 = x^3 + 2x*z^4 + b*z^6 + self.y.square() == self.x.square() * &self.x + (&self.x * a * z4) + (b * &z6) + } + + pub fn into_affine(self) -> SepticPoint { + if self.z.is_zero() { + return SepticPoint::point_at_infinity(); + } + + let z_inv = self.z.inverse().unwrap(); + let z_inv2 = z_inv.square(); + let z_inv3 = &z_inv2 * &z_inv; + + let x = &self.x * &z_inv2; + let y = &self.y * &z_inv3; + + SepticPoint { + x, + y, + is_infinity: false, + } + } +} + +impl Add for &SepticJacobianPoint { + type Output = SepticJacobianPoint; + + fn add(self, rhs: Self) -> Self::Output { + // https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#addition-add-2007-bl + if self.z.is_zero() { + return rhs.clone(); + } + + if rhs.z.is_zero() { + return self.clone(); + } + + let z1z1 = self.z.square(); + let z2z2 = rhs.z.square(); + + let u1 = &self.x * &z2z2; + let u2 = &rhs.x * &z1z1; + + let s1 = &self.y * &z2z2 * &rhs.z; + let s2 = &rhs.y * &z1z1 * &self.z; + + if u1 == u2 { + if s1 == s2 { + return self.double(); + } else { + return SepticJacobianPoint::point_at_infinity(); + } + } + + let two = F::from_canonical_u32(2); + let h = u2 - &u1; + let i = (&h * two).square(); + let j = &h * &i; + let r = (s2 - &s1) * two; + let v = u1 * &i; + + let x3 = r.square() - &j - &v * two; + let y3 = r * (v - &x3) - s1 * &j * two; + let z3 = (&self.z + &rhs.z).square() - &z1z1 - &z2z2; + let z3 = z3 * h; + + Self::Output { + x: x3, + y: y3, + z: z3, + } + } +} + +impl Add for SepticJacobianPoint { + type Output = SepticJacobianPoint; + + fn add(self, rhs: Self) -> Self::Output { + (&self).add(&rhs) + } +} + +impl SepticJacobianPoint { + pub fn double(&self) -> Self { + // https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#doubling-dbl-2007-bl + + // y = 0 means self.order = 2 + if self.y.is_zero() { + return SepticJacobianPoint::point_at_infinity(); + } + + let two = F::from_canonical_u32(2); + let three = F::from_canonical_u32(3); + let eight = F::from_canonical_u32(8); + let a = F::from_canonical_u32(2); // The curve coefficient a + + // xx = x1^2 + let xx = self.x.square(); + + // yy = y1^2 + let yy = self.y.square(); + + // yyyy = yy^2 + let yyyy = yy.square(); + + // zz = z1^2 + let zz = self.z.square(); + + // S = 2*((x1 + y1^2)^2 - x1^2 - y1^4) + let s = (&self.x + &yy).square() - &xx - &yyyy; + let s = s * two; + + // M = 3*x1^2 + a*z1^4 + let m = &xx * three + zz.square() * a; + + // T = M^2 - 2*S + let t = m.square() - &s * two; + + // Y3 = M*(S-T)-8*y^4 + let y3 = m * (&s - &t) - &yyyy * eight; + + // X3 = T + let x3 = t; + + // Z3 = (y1+z1)^2 - y1^2 - z1^2 + let z3 = (&self.y + &self.z).square() - &yy - &zz; + + Self { + x: x3, + y: y3, + z: z3, + } + } +} + +impl Sum for SepticJacobianPoint { + fn sum>(iter: I) -> Self { + iter.fold(Self::default(), |acc, p| acc + p) + } +} + +impl SepticJacobianPoint { + pub fn random(rng: impl RngCore) -> Self { + SepticPoint::random(rng).into() + } +} + +#[cfg(test)] +mod tests { + use super::SepticExtension; + use crate::scheme::septic_curve::{SepticJacobianPoint, SepticPoint}; + use p3::{babybear::BabyBear, field::Field}; + use rand::thread_rng; + + type F = BabyBear; + #[test] + fn test_septic_extension_arithmetic() { + let mut rng = thread_rng(); + // a = z, b = z^6 + z^5 + z^4 + let a: SepticExtension = SepticExtension::from([0, 1, 0, 0, 0, 0, 0]); + let b: SepticExtension = SepticExtension::from([0, 0, 0, 0, 1, 1, 1]); + + let c = SepticExtension::from([5, 2, 0, 0, 0, 1, 1]); + assert_eq!(a * b, c); + + // a^(p^2) = (a^p)^p + assert_eq!(c.double_frobenius(), c.frobenius().frobenius()); + + // norm_sub(a) * a must be in F + let norm = c.norm_sub() * &c; + assert!(norm.0[1..7].iter().all(|x| x.is_zero())); + + let d: SepticExtension = SepticExtension::random(&mut rng); + let e = d.square(); + assert!(e.is_square()); + + let f = e.sqrt().unwrap(); + let zero = SepticExtension::zero(); + assert!(f == d || f == zero - d); + } + + #[test] + fn test_septic_curve_arithmetic() { + let mut rng = thread_rng(); + let p1 = SepticPoint::::random(&mut rng); + let p2 = SepticPoint::::random(&mut rng); + + let j1 = SepticJacobianPoint::from(p1.clone()); + let j2 = SepticJacobianPoint::from(p2.clone()); + + let p3 = p1 + p2; + let j3 = &j1 + &j2; + + assert!(j1.is_on_curve()); + assert!(j2.is_on_curve()); + + assert!(j3.is_on_curve()); + assert!(p3.is_on_curve()); + + assert_eq!(p3, j3.clone().into_affine()); + + // 2*p3 - p3 = p3 + let p4 = p3.double(); + assert_eq!((-p3.clone() + p4.clone()), p3); + + // 2*j3 = 2*p3 + let j4 = j3.double(); + assert!(j4.is_on_curve()); + assert_eq!(j4.into_affine(), p4); + } +} diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 5cce8f4db..73355017c 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -55,6 +55,7 @@ use transcript::{BasicTranscript, Transcript}; struct TestConfig { pub(crate) reg_id: WitIn, } + struct TestCircuit { phantom: PhantomData, } @@ -197,7 +198,8 @@ fn test_rw_lk_expression_combination() { witness: wits_in, structural_witness: structural_in, public_input: vec![], - num_instances, + num_instances: vec![num_instances], + has_ecc_ops: false, }; let (proof, _, _) = prover .create_chip_proof( @@ -370,7 +372,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); - let pi = PublicValues::new(0, 0, 0, 0, 0, 0, vec![0]); + let pi = PublicValues::new(0, 0, 0, 0, 0, 0, vec![0], vec![0; 14]); let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover .create_proof(zkvm_witness, pi, transcript) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 637fa09b1..cfa88175f 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -156,6 +156,11 @@ macro_rules! tower_mle_4 { }}; } +pub fn log2_strict_usize(n: usize) -> usize { + assert!(n.is_power_of_two()); + n.trailing_zeros() as usize +} + /// infer logup witness from last layer /// return is the ([p1,p2], [q1,q2]) for each layer pub(crate) fn infer_tower_logup_witness<'a, E: ExtensionField>( @@ -254,45 +259,80 @@ pub(crate) fn infer_tower_logup_witness<'a, E: ExtensionField>( .collect_vec() } -/// infer tower witness from last layer -pub(crate) fn infer_tower_product_witness( +/// Infer tower witness from input layer (layer 0 is the output layer and layer n is the input layer). +/// The relation between layer i and layer i+1 is as follows: +/// prod[i][b] = ∏_s prod[i+1][s,b] +/// where 2^s is the fanin of the product gate `num_product_fanin`. +pub fn infer_tower_product_witness( num_vars: usize, last_layer: Vec>, num_product_fanin: usize, ) -> Vec>> { + // sanity check assert!(last_layer.len() == num_product_fanin); - assert_eq!(num_product_fanin % 2, 0); - let log2_num_product_fanin = ceil_log2(num_product_fanin); - let mut wit_layers = - (0..(num_vars / log2_num_product_fanin) - 1).fold(vec![last_layer], |mut acc, _| { - let next_layer = acc.last().unwrap(); - let cur_len = next_layer[0].evaluations().len() / num_product_fanin; - let cur_layer: Vec> = (0..num_product_fanin) - .map(|index| { - let mut evaluations = vec![E::ONE; cur_len]; - next_layer.chunks_exact(2).for_each(|f| { - match (f[0].evaluations(), f[1].evaluations()) { - (FieldType::Ext(f1), FieldType::Ext(f2)) => { - let start: usize = index * cur_len; - (start..(start + cur_len)) + assert!(num_product_fanin.is_power_of_two()); + + let log2_num_product_fanin = log2_strict_usize(num_product_fanin); + assert!(num_vars.is_multiple_of(log2_num_product_fanin)); + assert!( + last_layer + .iter() + .all(|p| p.num_vars() == num_vars - log2_num_product_fanin) + ); + + let num_layers = num_vars / log2_num_product_fanin; + + let mut wit_layers = Vec::with_capacity(num_layers); + wit_layers.push(last_layer); + + for _ in (0..num_layers - 1).rev() { + let input_layer = wit_layers.last().unwrap(); + let output_len = input_layer[0].evaluations().len() / num_product_fanin; + + let output_layer: Vec> = (0..num_product_fanin) + .map(|index| { + // avoid the overhead of vector initialization + let mut evaluations: Vec = Vec::with_capacity(output_len); + let remaining = evaluations.spare_capacity_mut(); + + input_layer.chunks_exact(2).enumerate().for_each(|(i, f)| { + match (f[0].evaluations(), f[1].evaluations()) { + (FieldType::Ext(f1), FieldType::Ext(f2)) => { + let start: usize = index * output_len; + + if i == 0 { + (start..(start + output_len)) + .into_par_iter() + .zip(remaining.par_iter_mut()) + .with_min_len(MIN_PAR_SIZE) + .for_each(|(index, evaluations)| { + evaluations.write(f1[index] * f2[index]); + }); + } else { + (start..(start + output_len)) .into_par_iter() - .zip(evaluations.par_iter_mut()) + .zip(remaining.par_iter_mut()) .with_min_len(MIN_PAR_SIZE) - .map(|(index, evaluations)| { - *evaluations *= f1[index] * f2[index] - }) - .collect() + .for_each(|(index, evaluations)| { + evaluations.write(f1[index] * f2[index]); + }); } - _ => unreachable!("must be extension field"), } - }); - evaluations.into_mle() - }) - .collect_vec(); - acc.push(cur_layer); - acc - }); + _ => unreachable!("must be extension field"), + } + }); + + unsafe { + evaluations.set_len(output_len); + } + evaluations.into_mle() + }) + .collect_vec(); + wit_layers.push(output_layer); + } + wit_layers.reverse(); + wit_layers } @@ -374,7 +414,7 @@ pub fn build_main_witness< } else { ( >::table_witness(device, input, cs, challenges), - input.num_instances > 1 && input.num_instances.is_power_of_two(), + input.num_instances() > 1 && input.num_instances().is_power_of_two(), ) } }; diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index b38c6e589..4ed5a89e9 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -5,21 +5,30 @@ use std::marker::PhantomData; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use super::{ZKVMChipProof, ZKVMProof}; use crate::{ error::ZKVMError, - scheme::constants::{NUM_FANIN, NUM_FANIN_LOGUP}, - structs::{ComposedConstrainSystem, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, + scheme::{ + constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE}, + septic_curve::SepticExtension, + }, + structs::{ + ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VerifyingKey, + ZKVMVerifyingKey, + }, utils::{ eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, }, }; -use gkr_iop::gkr::GKRClaims; +use gkr_iop::{ + gkr::GKRClaims, + selector::{SelectorContext, SelectorType}, +}; use itertools::{Itertools, chain, interleave, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - Instance, StructuralWitIn, StructuralWitInType, + Expression, Instance, StructuralWitIn, StructuralWitInType, + StructuralWitInType::StackedConstantSequence, mle::IntoMLE, util::ceil_log2, utils::eval_by_expr_with_instance, @@ -33,6 +42,8 @@ use sumcheck::{ use transcript::{ForkableTranscript, Transcript}; use witness::next_pow2_instance_padding; +use super::{ZKVMChipProof, ZKVMProof}; + pub struct ZKVMVerifier> { pub vk: ZKVMVerifyingKey, } @@ -61,15 +72,15 @@ impl> ZKVMVerifier &self, vm_proof: ZKVMProof, transcript: impl ForkableTranscript, - expect_halt: bool, + _expect_halt: bool, ) -> Result { // require ecall/halt proof to exist, depending whether we expect a halt. - let has_halt = vm_proof.has_halt(&self.vk); - if has_halt != expect_halt { - return Err(ZKVMError::VerifyError( - format!("ecall/halt mismatch: expected {expect_halt} != {has_halt}",).into(), - )); - } + // let has_halt = vm_proof.has_halt(&self.vk); + // if has_halt != expect_halt { + // return Err(ZKVMError::VerifyError( + // format!("ecall/halt mismatch: expected {expect_halt} != {has_halt}",).into(), + // )); + // } self.verify_proof_validity(vm_proof, transcript) } @@ -131,7 +142,9 @@ impl> ZKVMVerifier // write (circuit_idx, num_instance) to transcript for (circuit_idx, proof) in &vm_proof.chip_proofs { transcript.append_message(&circuit_idx.to_le_bytes()); - transcript.append_message(&proof.num_instances.to_le_bytes()); + for num_instance in &proof.num_instances { + transcript.append_message(&num_instance.to_le_bytes()); + } } // write witin commitment to transcript @@ -158,7 +171,8 @@ impl> ZKVMVerifier let mut witin_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); let mut fixed_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); for (index, proof) in &vm_proof.chip_proofs { - assert!(proof.num_instances > 0); + let num_instance: usize = proof.num_instances.iter().sum(); + assert!(num_instance > 0); let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; @@ -216,11 +230,10 @@ impl> ZKVMVerifier // getting the number of dummy padding item that we used in this opcode circuit let num_lks = circuit_vk.get_cs().num_lks(); // each padding instance contribute to (2^rotation_vars) dummy lookup padding - let num_padded_instance = (next_pow2_instance_padding(proof.num_instances) - - proof.num_instances) + let num_padded_instance = (next_pow2_instance_padding(num_instance) - num_instance) * (1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0)); // each instance contribute to (2^rotation_vars - rotated) dummy lookup padding - let num_instance_non_selected = proof.num_instances + let num_instance_non_selected = num_instance * ((1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0)) - (circuit_vk.get_cs().rotation_subgroup_size().unwrap_or(0) + 1)); dummy_table_item_multiplicity += @@ -346,7 +359,7 @@ impl> ZKVMVerifier zkvm_v1_css: cs, gkr_circuit, } = &composed_cs; - let num_instances = proof.num_instances; + let num_instances = proof.num_instances.iter().sum(); let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = ( cs.r_expressions.len() + cs.r_table_expressions.len(), cs.w_expressions.len() + cs.w_table_expressions.len(), @@ -355,9 +368,47 @@ impl> ZKVMVerifier let num_batched = r_counts_per_instance + w_counts_per_instance + lk_counts_per_instance; let next_pow2_instance = next_pow2_instance_padding(num_instances); - let log2_num_instances = ceil_log2(next_pow2_instance); + let mut log2_num_instances = ceil_log2(next_pow2_instance); + if composed_cs.has_ecc_ops() { + // for opcode circuit with ecc ops, the mles have one extra variable + // to store the internal partial sums for ecc additions + log2_num_instances += 1; + } let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0); + // verify ecc proof if exists + if composed_cs.has_ecc_ops() { + tracing::debug!("verifying ecc proof..."); + assert!(proof.ecc_proof.is_some()); + let ecc_proof = proof.ecc_proof.as_ref().unwrap(); + + // TODO: enable this + // let xy = cs + // .ec_final_sum + // .iter() + // .map(|expr| { + // eval_by_expr_with_instance(&[], &[], &[], pi, challenges, &expr) + // .right() + // .and_then(|v| v.as_base()) + // .unwrap() + // }) + // .collect_vec(); + // let x: SepticExtension = xy[0..SEPTIC_EXTENSION_DEGREE].into(); + // let y: SepticExtension = xy[SEPTIC_EXTENSION_DEGREE..].into(); + + // assert_eq!( + // SepticPoint { + // x, + // y, + // is_infinity: false, + // }, + // ecc_proof.sum + // ); + // assert ec sum in public input matches that in ecc proof + EccVerifier::verify_ecc_proof(ecc_proof, transcript)?; + tracing::debug!("ecc proof verified."); + } + // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; @@ -402,6 +453,44 @@ impl> ZKVMVerifier debug_assert_eq!(logup_q_evals.len(), lk_counts_per_instance); let gkr_circuit = gkr_circuit.as_ref().unwrap(); + let selector_ctxs = if cs.ec_final_sum.is_empty() { + assert_eq!(proof.num_instances.len(), 1); + // it's not global chip + vec![ + SelectorContext::new(0, num_instances, num_var_with_rotation); + gkr_circuit + .layers + .first() + .map(|layer| layer.out_sel_and_eval_exprs.len()) + .unwrap_or(0) + ] + } else { + assert_eq!(proof.num_instances.len(), 2); + // it's global chip + tracing::debug!( + "num_reads: {}, num_writes: {}, total: {}", + proof.num_instances[0], + proof.num_instances[1], + proof.num_instances[0] + proof.num_instances[1], + ); + vec![ + SelectorContext { + offset: 0, + num_instances: proof.num_instances[0], + num_vars: num_var_with_rotation, + }, + SelectorContext { + offset: proof.num_instances[0], + num_instances: proof.num_instances[1], + num_vars: num_var_with_rotation, + }, + SelectorContext { + offset: 0, + num_instances: proof.num_instances[0] + proof.num_instances[1], + num_vars: num_var_with_rotation, + }, + ] + }; let GKRClaims(opening_evaluations) = gkr_circuit.verify( num_var_with_rotation, proof.gkr_iop_proof.clone().unwrap(), @@ -409,7 +498,7 @@ impl> ZKVMVerifier pi, challenges, transcript, - num_instances, + &selector_ctxs, )?; Ok(opening_evaluations[0].point.clone()) } @@ -439,7 +528,8 @@ impl> ZKVMVerifier .all(|(r, w)| r.table_spec.len == w.table_spec.len) ); } - let log2_num_instances = next_pow2_instance_padding(proof.num_instances).ilog2() as usize; + let num_instances = proof.num_instances.iter().sum(); + let log2_num_instances = next_pow2_instance_padding(num_instances).ilog2() as usize; // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; @@ -708,6 +798,8 @@ impl TowerVerify { ) }) .unzip::<_, _, Vec<_>, Vec<_>>(); + + // initial claim = \sum_j alpha^j * out_j[rt] let initial_claim = izip!(&prod_spec_point_n_eval, &alpha_pows) .map(|(point_n_eval, alpha)| point_n_eval.eval * *alpha) .sum::() @@ -720,7 +812,7 @@ impl TowerVerify { let max_num_variables = num_variables.iter().max().unwrap(); - let (next_rt, _) = (0..(max_num_variables-1)).try_fold( + let (next_rt, _) = (0..(max_num_variables - 1)).try_fold( ( PointAndEval { point: initial_rt, @@ -745,33 +837,40 @@ impl TowerVerify { // check expected_evaluation let rt: Point = sumcheck_claim.point.iter().map(|c| c.elements).collect(); + let eq = eq_eval(out_rt, &rt); let expected_evaluation: E = (0..num_prod_spec) .zip(alpha_pows.iter()) .zip(num_variables.iter()) .map(|((spec_index, alpha), max_round)| { - eq_eval(out_rt, &rt) - * *alpha - * if round < *max_round-1 {tower_proofs.prod_specs_eval[spec_index][round].iter().copied().product()} else { - E::ZERO - } + // prod'[b] = prod[0,b] * prod[1,b] + // prod'[out_rt] = \sum_b eq(out_rt,b) * prod'[b] = \sum_b eq(out_rt,b) * prod[0,b] * prod[1,b] + eq * *alpha + * if round < *max_round - 1 { tower_proofs.prod_specs_eval[spec_index][round].iter().copied().product() } else { + E::ZERO + } }) .sum::() + (0..num_logup_spec) - .zip_eq(alpha_pows[num_prod_spec..].chunks(2)) - .zip_eq(num_variables[num_prod_spec..].iter()) - .map(|((spec_index, alpha), max_round)| { - let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); - eq_eval(out_rt, &rt) * if round < *max_round-1 { - let evals = &tower_proofs.logup_specs_eval[spec_index][round]; - let (p1, p2, q1, q2) = - (evals[0], evals[1], evals[2], evals[3]); - *alpha_numerator * (p1 * q2 + p2 * q1) - + *alpha_denominator * (q1 * q2) - } else { - E::ZERO - } - }) - .sum::(); + .zip_eq(alpha_pows[num_prod_spec..].chunks(2)) + .zip_eq(num_variables[num_prod_spec..].iter()) + .map(|((spec_index, alpha), max_round)| { + // logup_q'[b] = logup_q[0,b] * logup_q[1,b] + // logup_p'[b] = logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b] + // logup_p'[out_rt] = \sum_b eq(out_rt,b) * (logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b]) + // logup_q'[out_rt] = \sum_b eq(out_rt,b) * logup_q[0,b] * logup_q[1,b] + let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); + eq * if round < *max_round - 1 { + let evals = &tower_proofs.logup_specs_eval[spec_index][round]; + let (p1, p2, q1, q2) = + (evals[0], evals[1], evals[2], evals[3]); + *alpha_numerator * (p1 * q2 + p2 * q1) + + *alpha_denominator * (q1 * q2) + } else { + E::ZERO + } + }) + .sum::(); + if expected_evaluation != sumcheck_claim.expected_evaluation { return Err(ZKVMError::VerifyError("mismatch tower evaluation".into())); } @@ -779,7 +878,7 @@ impl TowerVerify { // derive single eval // rt' = r_merge || rt // r_merge.len() == ceil_log2(num_product_fanin) - let r_merge =transcript.sample_and_append_vec(b"merge", log2_num_fanin); + let r_merge = transcript.sample_and_append_vec(b"merge", log2_num_fanin); let coeffs = build_eq_x_r_vec_sequential(&r_merge); assert_eq!(coeffs.len(), num_fanin); let rt_prime = [rt, r_merge].concat(); @@ -794,17 +893,18 @@ impl TowerVerify { .zip(next_alpha_pows.iter()) .zip(num_variables.iter()) .map(|((spec_index, alpha), max_round)| { - if round < max_round -1 { + // prod'[rt,r_merge] = \sum_b eq(r_merge, b) * prod'[b,rt] + if round < max_round - 1 { // merged evaluation let evals = izip!( tower_proofs.prod_specs_eval[spec_index][round].iter(), coeffs.iter() ) - .map(|(a, b)| *a * *b) - .sum::(); + .map(|(a, b)| *a * *b) + .sum::(); // this will keep update until round > evaluation prod_spec_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), evals); - if next_round < max_round -1 { + if next_round < max_round - 1 { *alpha * evals } else { E::ZERO @@ -818,28 +918,28 @@ impl TowerVerify { .zip_eq(next_alpha_pows[num_prod_spec..].chunks(2)) .zip_eq(num_variables[num_prod_spec..].iter()) .map(|((spec_index, alpha), max_round)| { - if round < max_round -1 { + if round < max_round - 1 { let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); // merged evaluation let p_evals = izip!( tower_proofs.logup_specs_eval[spec_index][round][0..2].iter(), coeffs.iter() ) - .map(|(a, b)| *a * *b) - .sum::(); + .map(|(a, b)| *a * *b) + .sum::(); let q_evals = izip!( tower_proofs.logup_specs_eval[spec_index][round][2..4].iter(), coeffs.iter() ) - .map(|(a, b)| *a * *b) - .sum::(); + .map(|(a, b)| *a * *b) + .sum::(); // this will keep update until round > evaluation logup_spec_p_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), p_evals); logup_spec_q_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), q_evals); - if next_round < max_round -1 { + if next_round < max_round - 1 { *alpha_numerator * p_evals + *alpha_denominator * q_evals } else { E::ZERO @@ -849,8 +949,10 @@ impl TowerVerify { } }) .sum::(); + // sum evaluation from different specs let next_eval = next_prod_spec_evals + next_logup_spec_evals; + Ok((PointAndEval { point: rt_prime, eval: next_eval, @@ -866,3 +968,134 @@ impl TowerVerify { )) } } + +pub struct EccVerifier; + +impl EccVerifier { + pub fn verify_ecc_proof( + proof: &EccQuarkProof, + transcript: &mut impl Transcript, + ) -> Result<(), ZKVMError> { + let num_vars = next_pow2_instance_padding(proof.num_instances).ilog2() as usize; + let out_rt = transcript.sample_and_append_vec(b"ecc", num_vars); + let alpha_pows = transcript.sample_and_append_challenge_pows( + SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2, + b"ecc_alpha", + ); + let mut alpha_pows_iter = alpha_pows.iter(); + + let sumcheck_claim = IOPVerifierState::verify( + E::ZERO, + &proof.zerocheck_proof, + &VPAuxInfo { + max_degree: 3, + max_num_variables: num_vars, + phantom: PhantomData, + }, + transcript, + ); + + let s0: SepticExtension = proof.evals[2..][0..][..SEPTIC_EXTENSION_DEGREE].into(); + let x0: SepticExtension = + proof.evals[2..][SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let y0: SepticExtension = + proof.evals[2..][2 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let x1: SepticExtension = + proof.evals[2..][3 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let y1: SepticExtension = + proof.evals[2..][4 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let x3: SepticExtension = + proof.evals[2..][5 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let y3: SepticExtension = + proof.evals[2..][6 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + + let rt = sumcheck_claim + .point + .iter() + .map(|c| c.elements) + .collect_vec(); + + // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) + // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] + // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) + // zerocheck: 0 = (x[1,b] - x[b,0]) + // zerocheck: 0 = (y[1,b] - y[b,0]) + // + // note that they are not septic extension field elements, + // we just want to reuse the multiply/add/sub formulas + let v1: SepticExtension = s0.clone() * (&x0 - &x1) - (&y0 - &y1); + let v2: SepticExtension = s0.square() - &x0 - &x1 - &x3; + let v3: SepticExtension = s0 * (&x0 - &x3) - (&y0 + &y3); + + let v4: SepticExtension = &x3 - &x0; + let v5: SepticExtension = &y3 - &y0; + + let [v1, v2, v3, v4, v5] = [v1, v2, v3, v4, v5].map(|v| { + v.0.into_iter() + .zip(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(c, alpha)| c * *alpha) + .collect_vec() + }); + + let sel_add_expr = SelectorType::::QuarkBinaryTreeLessThan(Expression::StructuralWitIn( + 0, + // this value doesn't matter, as we only need structural id + StackedConstantSequence { max_value: 0 }, + )); + let mut sel_evals = vec![E::ZERO]; + sel_add_expr.evaluate( + &mut sel_evals, + &out_rt, + &rt, + &SelectorContext { + offset: 0, + num_instances: proof.num_instances, + num_vars, + }, + 0, + ); + let expected_sel_add = sel_evals[0]; + + if proof.evals[0] != expected_sel_add { + return Err(ZKVMError::VerifyError( + (format!( + "sel_add evaluation mismatch, expected {}, got {}", + expected_sel_add, proof.evals[0] + )) + .into(), + )); + } + + // derive `sel_bypass = eq - sel_add - sel_last_onehot` + let expected_sel_bypass = eq_eval(&out_rt, &rt) + - expected_sel_add + - (out_rt.iter().copied().product::() * rt.iter().copied().product::()); + + if proof.evals[1] != expected_sel_bypass { + return Err(ZKVMError::VerifyError( + (format!( + "sel_bypass evaluation mismatch, expected {}, got {}", + expected_sel_bypass, proof.evals[1] + )) + .into(), + )); + } + + let add_evaluations = vec![v1, v2, v3].into_iter().flatten().sum::(); + let bypass_evaluations = vec![v4, v5].into_iter().flatten().sum::(); + if sumcheck_claim.expected_evaluation + != add_evaluations * expected_sel_add + bypass_evaluations * expected_sel_bypass + { + return Err(ZKVMError::VerifyError( + (format!( + "ecc zerocheck failed: mismatched evaluation, expected {}, got {}", + sumcheck_claim.expected_evaluation, + add_evaluations * expected_sel_add + bypass_evaluations * expected_sel_bypass + )) + .into(), + )); + } + + Ok(()) + } +} diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 8c92036ae..79661d728 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -2,24 +2,46 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, error::ZKVMError, - instructions::Instruction, + instructions::{ + Instruction, + global::{GlobalChip, GlobalChipInput, GlobalPoint, GlobalRecord}, + }, + scheme::septic_curve::SepticPoint, state::StateCircuit, tables::{RMMCollections, TableCircuit}, }; use ceno_emul::{CENO_PLATFORM, Platform, StepRecord}; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, PoseidonField}; use gkr_iop::{gkr::GKRCircuit, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{Expression, Instance}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ collections::{BTreeMap, HashMap}, sync::Arc, }; -use sumcheck::structs::IOPProverMessage; +use sumcheck::structs::{IOPProof, IOPProverMessage}; use witness::RowMajorMatrix; +/// proof that the sum of N=2^n EC points is equal to `sum` +/// in one layer instead of GKR layered circuit approach +/// note that this one layer IOP borrowed ideas from +/// [Quark paper](https://eprint.iacr.org/2020/1275.pdf) +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct EccQuarkProof { + pub zerocheck_proof: IOPProof, + pub num_instances: usize, + pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[0,rt], y[0,rt], s[0,rt] + pub rt: Point, + pub sum: SepticPoint, +} + #[derive(Clone, Serialize, Deserialize)] #[serde(bound( serialize = "E::BaseField: Serialize", @@ -130,12 +152,17 @@ impl ComposedConstrainSystem { self.zkvm_v1_css.w_expressions.len() + self.zkvm_v1_css.w_table_expressions.len() } + pub fn has_ecc_ops(&self) -> bool { + !self.zkvm_v1_css.ec_final_sum.is_empty() + } + pub fn instance_name_map(&self) -> &HashMap { &self.zkvm_v1_css.instance_name_map } pub fn is_opcode_circuit(&self) -> bool { - self.gkr_circuit.is_some() + // TODO: is global chip opcode circuit?? + self.gkr_circuit.is_some() || self.has_ecc_ops() } /// return number of lookup operation @@ -295,6 +322,8 @@ pub struct ZKVMWitnesses { witnesses_tables: BTreeMap>, lk_mlts: BTreeMap>, combined_lk_mlt: Option>>, + // in ram bus chip, num_instances length would be > 1 + pub num_instances: BTreeMap>, } impl ZKVMWitnesses { @@ -327,6 +356,11 @@ impl ZKVMWitnesses { cs.zkvm_v1_css.num_structural_witin as usize, records, )?; + assert!( + self.num_instances + .insert(OC::name(), vec![witness[0].num_instances()]) + .is_none() + ); assert!(self.witnesses_opcodes.insert(OC::name(), witness).is_none()); assert!(!self.witnesses_tables.contains_key(&OC::name())); assert!( @@ -380,12 +414,101 @@ impl ZKVMWitnesses { self.combined_lk_mlt.as_ref().unwrap(), input, )?; + let num_instances = std::cmp::max(witness[0].num_instances(), witness[1].num_instances()); + assert!( + self.num_instances + .insert(TC::name(), vec![num_instances]) + .is_none() + ); assert!(self.witnesses_tables.insert(TC::name(), witness).is_none()); assert!(!self.witnesses_opcodes.contains_key(&TC::name())); Ok(()) } + pub fn assign_global_chip_circuit( + &mut self, + cs: &ZKVMConstraintSystem, + shard_ctx: &ShardContext, + config: & as TableCircuit>::TableConfig, + ) -> Result<(), ZKVMError> { + let perm = ::get_default_perm(); + let global_input = shard_ctx + .write_records() + .par_iter() + .flat_map_iter(|records| { + // global write -> local reads + records.iter().map(|(vma, record)| { + let global_write: GlobalRecord = (vma, record, true).into(); + let ec_point: GlobalPoint = global_write.to_ec_point(&perm); + GlobalChipInput { + record: global_write, + ec_point, + } + }) + }) + .chain( + shard_ctx + .read_records() + .par_iter() + .flat_map_iter(|records| { + // global read -> local write + records.iter().map(|(vma, record)| { + let global_read: GlobalRecord = (vma, record, false).into(); + let ec_point: GlobalPoint = global_read.to_ec_point(&perm); + GlobalChipInput { + record: global_read, + ec_point, + } + }) + }), + ) + .collect::>(); + assert!(self.combined_lk_mlt.is_some()); + let cs = cs.get_cs(&GlobalChip::::name()).unwrap(); + let witness = GlobalChip::assign_instances( + config, + cs.zkvm_v1_css.num_witin as usize, + cs.zkvm_v1_css.num_structural_witin as usize, + self.combined_lk_mlt.as_ref().unwrap(), + &global_input, + )?; + // set num_read, num_write as separate instance + assert!( + self.num_instances + .insert( + GlobalChip::::name(), + vec![ + // global write -> local read + shard_ctx + .write_records() + .iter() + .map(|records| records.len()) + .sum(), + // global read -> local write + shard_ctx + .read_records() + .iter() + .map(|records| records.len()) + .sum(), + ] + ) + .is_none() + ); + assert!( + self.witnesses_tables + .insert(GlobalChip::::name(), witness) + .is_none() + ); + assert!( + !self + .witnesses_opcodes + .contains_key(&GlobalChip::::name()) + ); + + Ok(()) + } + /// Iterate opcode/table circuits, sorted by alphabetical order. pub fn into_iter_sorted( self, diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 8fc43e348..344a8d891 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -20,7 +20,6 @@ use gkr_iop::{ }; use itertools::Itertools; use multilinear_extensions::{StructuralWitInType, ToExpr}; -use p3::field::FieldAlgebra; use witness::{InstancePaddingStrategy, RowMajorMatrix}; #[derive(Clone, Debug)] @@ -321,7 +320,7 @@ impl<'a, E: ExtensionField, const V_LIMBS: usize> TableCircuit descending: false, }, ); - let selector_type = SelectorType::Prefix(E::BaseField::ZERO, selector.expr()); + let selector_type = SelectorType::Prefix(selector.expr()); // all shared the same selector let (out_evals, mut chip) = ( diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 395b9e6c9..70de7f171 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -103,6 +103,10 @@ pub struct ConstraintSystem { pub instance_name_map: HashMap, + pub ec_point_exprs: Vec>, + pub ec_slope_exprs: Vec>, + pub ec_final_sum: Vec>, + pub r_selector: Option>, pub r_expressions: Vec>, pub r_expressions_namespace_map: Vec, @@ -167,6 +171,9 @@ impl ConstraintSystem { fixed_namespace_map: vec![], ns: NameSpace::new(root_name_fn), instance_name_map: HashMap::new(), + ec_final_sum: vec![], + ec_slope_exprs: vec![], + ec_point_exprs: vec![], r_selector: None, r_expressions: vec![], r_expressions_namespace_map: vec![], @@ -412,12 +419,22 @@ impl ConstraintSystem { record: Vec>, ) -> Result<(), CircuitBuilderError> { let rlc_record = self.rlc_chip_record(record.clone()); + self.read_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) + } + + pub fn read_rlc_record, N: FnOnce() -> NR>( + &mut self, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> { self.r_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); self.r_expressions_namespace_map.push(path); // Since r_expression is RLC(record) and when we're debugging // it's helpful to recover the value of record itself. - self.r_ram_types.push(((ram_type as u64).into(), record)); + self.r_ram_types.push((ram_type, record)); Ok(()) } @@ -428,13 +445,45 @@ impl ConstraintSystem { record: Vec>, ) -> Result<(), CircuitBuilderError> { let rlc_record = self.rlc_chip_record(record.clone()); + self.write_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) + } + + pub fn write_rlc_record, N: FnOnce() -> NR>( + &mut self, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> { self.w_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); self.w_expressions_namespace_map.push(path); - self.w_ram_types.push(((ram_type as u64).into(), record)); + // Since w_expression is RLC(record) and when we're debugging + // it's helpful to recover the value of record itself. + self.w_ram_types.push((ram_type, record)); Ok(()) } + pub fn ec_sum( + &mut self, + xs: Vec>, + ys: Vec>, + slopes: Vec>, + final_sum: Vec>, + ) { + assert_eq!(xs.len(), 7); + assert_eq!(ys.len(), 7); + assert_eq!(slopes.len(), 7); + assert_eq!(final_sum.len(), 7 * 2); + + assert_eq!(self.ec_point_exprs.len(), 0); + self.ec_point_exprs.extend(xs); + self.ec_point_exprs.extend(ys); + + self.ec_slope_exprs = slopes; + self.ec_final_sum = final_sum; + } + pub fn require_zero, N: FnOnce() -> NR>( &mut self, name_fn: N, @@ -669,6 +718,21 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.read_record(name_fn, ram_type, record) } + pub fn read_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .read_rlc_record(name_fn, ram_type, record, rlc_record) + } + pub fn write_record( &mut self, name_fn: N, @@ -682,10 +746,35 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.write_record(name_fn, ram_type, record) } + pub fn write_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .write_rlc_record(name_fn, ram_type, record, rlc_record) + } + pub fn rlc_chip_record(&self, records: Vec>) -> Expression { self.cs.rlc_chip_record(records) } + pub fn ec_sum( + &mut self, + xs: Vec>, + ys: Vec>, + slope: Vec>, + final_sum: Vec>, + ) { + self.cs.ec_sum(xs, ys, slope, final_sum); + } + pub fn create_bit(&mut self, name_fn: N) -> Result where NR: Into, diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 7d80229fd..b06e8fe71 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -11,6 +11,7 @@ use transcript::Transcript; use crate::{ error::BackendError, hal::{ProverBackend, ProverDevice}, + selector::SelectorContext, }; pub mod booleanhypercube; @@ -77,7 +78,7 @@ impl GKRCircuit { pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result>, BackendError> { let mut running_evals = out_evals.to_vec(); // running evals is a global referable within chip @@ -97,7 +98,7 @@ impl GKRCircuit { pub_io_evals, &mut challenges, transcript, - num_instances, + selector_ctxs, ); exit_span!(span); res @@ -122,7 +123,7 @@ impl GKRCircuit { pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result>, BackendError> where E: ExtensionField, @@ -141,7 +142,7 @@ impl GKRCircuit { pub_io_evals, &mut challenges, transcript, - num_instances, + selector_ctxs, )?; } diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 6bd76af68..22312497d 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -21,7 +21,7 @@ use crate::{ error::BackendError, evaluation::EvalExpression, hal::{MultilinearPolynomial, ProverBackend, ProverDevice}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; pub mod cpu; @@ -184,7 +184,7 @@ impl Layer { pub_io_evals: &[E], challenges: &mut Vec, transcript: &mut T, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> LayerProof { self.update_challenges(challenges, transcript); let mut eval_and_dedup_points = self.extract_claim_and_point(claims, challenges); @@ -204,7 +204,7 @@ impl Layer { pub_io_evals, challenges, transcript, - num_instances, + selector_ctxs, ) } LayerType::Linear => { @@ -232,7 +232,7 @@ impl Layer { pub_io_evals: &[E], challenges: &mut Vec, transcript: &mut Trans, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result<(), BackendError> { self.update_challenges(challenges, transcript); let mut eval_and_dedup_points = self.extract_claim_and_point(claims, challenges); @@ -246,7 +246,7 @@ impl Layer { pub_io_evals, challenges, transcript, - num_instances, + selector_ctxs, )?, LayerType::Linear => { assert_eq!(eval_and_dedup_points.len(), 1); diff --git a/gkr_iop/src/gkr/layer/cpu/mod.rs b/gkr_iop/src/gkr/layer/cpu/mod.rs index 95d315f25..255daeed3 100644 --- a/gkr_iop/src/gkr/layer/cpu/mod.rs +++ b/gkr_iop/src/gkr/layer/cpu/mod.rs @@ -8,6 +8,7 @@ use crate::{ zerocheck_layer::RotationPoints, }, }, + selector::SelectorContext, utils::{rotation_next_base_mle, rotation_selector}, }; use either::Either; @@ -113,7 +114,7 @@ impl> ZerocheckLayerProver pub_io_evals: &[ as ProverBackend>::E], challenges: &[ as ProverBackend>::E], transcript: &mut impl Transcript< as ProverBackend>::E>, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> ( LayerProof< as ProverBackend>::E>, Point< as ProverBackend>::E>, @@ -126,6 +127,12 @@ impl> ZerocheckLayerProver layer.out_sel_and_eval_exprs.len(), out_points.len(), ); + assert_eq!( + layer.out_sel_and_eval_exprs.len(), + selector_ctxs.len(), + "selector_ctxs length {}", + selector_ctxs.len() + ); let (_, raw_rotation_exprs) = &layer.rotation_exprs; let (rotation_proof, rotation_left, rotation_right, rotation_point) = @@ -174,7 +181,10 @@ impl> ZerocheckLayerProver .out_sel_and_eval_exprs .par_iter() .zip(out_points.par_iter()) - .filter_map(|((sel_type, _), point)| sel_type.compute(point, num_instances)) + .zip(selector_ctxs.par_iter()) + .filter_map(|(((sel_type, _), point), selector_ctx)| { + sel_type.compute(point, selector_ctx) + }) // for rotation left point .chain(rotation_left.par_iter().map(|rotation_left| { MultilinearExtension::from_evaluations_ext_vec( diff --git a/gkr_iop/src/gkr/layer/hal.rs b/gkr_iop/src/gkr/layer/hal.rs index 06508e298..c6cce26a0 100644 --- a/gkr_iop/src/gkr/layer/hal.rs +++ b/gkr_iop/src/gkr/layer/hal.rs @@ -4,6 +4,7 @@ use transcript::Transcript; use crate::{ gkr::layer::{Layer, LayerWitness, sumcheck_layer::LayerProof}, hal::ProverBackend, + selector::SelectorContext, }; pub trait LinearLayerProver { @@ -37,6 +38,6 @@ pub trait ZerocheckLayerProver { pub_io_evals: &[PB::E], challenges: &[PB::E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> (LayerProof, Point); } diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index 1d4e6c56a..d9f13a2a9 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -27,7 +27,7 @@ use crate::{ }, }, hal::{ProverBackend, ProverDevice}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, utils::rotation_selector_eval, }; @@ -58,7 +58,7 @@ pub trait ZerocheckLayer { pub_io_evals: &[PB::E], challenges: &[PB::E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> (LayerProof, Point); #[allow(clippy::too_many_arguments)] @@ -70,7 +70,7 @@ pub trait ZerocheckLayer { pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result, BackendError>; } @@ -177,7 +177,7 @@ impl ZerocheckLayer for Layer { pub_io_evals: &[PB::E], challenges: &[PB::E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> (LayerProof, Point) { >::prove( self, @@ -188,7 +188,7 @@ impl ZerocheckLayer for Layer { pub_io_evals, challenges, transcript, - num_instances, + selector_ctxs, ) } @@ -200,7 +200,7 @@ impl ZerocheckLayer for Layer { pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result, BackendError> { assert_eq!( self.out_sel_and_eval_exprs.len(), @@ -284,17 +284,20 @@ impl ZerocheckLayer for Layer { let in_point = in_point.into_iter().map(|c| c.elements).collect_vec(); // eval eq and set to respective witin - izip!(&self.out_sel_and_eval_exprs, &eval_and_dedup_points).for_each( - |((sel_type, _), (_, out_point))| { - sel_type.evaluate( - &mut main_evals, - out_point.as_ref().unwrap(), - &in_point, - num_instances, - self.n_witin, - ); - }, - ); + izip!( + &self.out_sel_and_eval_exprs, + &eval_and_dedup_points, + selector_ctxs.iter() + ) + .for_each(|((sel_type, _), (_, out_point), selector_ctx)| { + sel_type.evaluate( + &mut main_evals, + out_point.as_ref().unwrap(), + &in_point, + selector_ctx, + self.n_witin, + ); + }); let got_claim = eval_by_expr_with_instance( &[], @@ -450,10 +453,11 @@ pub fn extend_exprs_with_rotation( let expr = match sel_type { SelectorType::None => zero_check_expr, SelectorType::Whole(sel) - | SelectorType::Prefix(_, sel) + | SelectorType::Prefix(sel) | SelectorType::OrderedSparse32 { expression: sel, .. - } => match_expr(sel) * zero_check_expr, + } + | SelectorType::QuarkBinaryTreeLessThan(sel) => match_expr(sel) * zero_check_expr, }; zero_check_exprs.push(expr); } diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index bc57295f1..9f10d2249 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -1,16 +1,41 @@ +use std::iter::repeat_n; + use rayon::iter::IndexedParallelIterator; use ff_ext::ExtensionField; use multilinear_extensions::{ Expression, mle::{IntoMLE, MultilinearExtension, Point}, + util::ceil_log2, virtual_poly::{build_eq_x_r_vec, eq_eval}, }; -use rayon::{iter::ParallelIterator, slice::ParallelSliceMut}; +use p3::field::FieldAlgebra; +use rayon::{ + iter::{IntoParallelIterator, ParallelIterator}, + slice::ParallelSliceMut, +}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use crate::{gkr::booleanhypercube::CYCLIC_POW2_5, utils::eq_eval_less_or_equal_than}; +/// Provide context for selector's instantiation at runtime +#[derive(Clone, Debug)] +pub struct SelectorContext { + pub offset: usize, + pub num_instances: usize, + pub num_vars: usize, +} + +impl SelectorContext { + pub fn new(offset: usize, num_instances: usize, num_vars: usize) -> Self { + Self { + offset, + num_instances, + num_vars, + } + } +} + /// Selector selects part of the witnesses in the sumcheck protocol. #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] #[serde(bound( @@ -20,42 +45,122 @@ use crate::{gkr::booleanhypercube::CYCLIC_POW2_5, utils::eq_eval_less_or_equal_t pub enum SelectorType { None, Whole(Expression), - /// Select a prefix as the instances, padded with a field element. - Prefix(E::BaseField, Expression), + /// Select part of the instances, other parts padded with a field element. + Prefix(Expression), /// selector activates on the specified `indices`, which are assumed to be in ascending order. /// each index corresponds to a position within a fixed-size chunk (e.g., size 32), OrderedSparse32 { indices: Vec, expression: Expression, }, + /// binary tree [`quark`] from paper + QuarkBinaryTreeLessThan(Expression), } impl SelectorType { + /// Returns an MultilinearExtension with `ctx.num_vars` variables whenever applicable + pub fn to_mle(&self, ctx: &SelectorContext) -> Option> { + match self { + SelectorType::None => None, + SelectorType::Whole(_) => { + assert_eq!(ceil_log2(ctx.num_instances), ctx.num_vars); + Some( + (0..(1 << ctx.num_vars)) + .into_par_iter() + .map(|_| E::BaseField::ONE) + .collect::>() + .into_mle(), + ) + } + SelectorType::Prefix(_) => { + assert!(ctx.offset + ctx.num_instances <= (1 << ctx.num_vars)); + let start = ctx.offset; + let end = start + ctx.num_instances; + Some( + (0..start) + .into_par_iter() + .map(|_| E::BaseField::ZERO) + .chain((start..end).into_par_iter().map(|_| E::BaseField::ONE)) + .chain( + (end..(1 << ctx.num_vars)) + .into_par_iter() + .map(|_| E::BaseField::ZERO), + ) + .collect::>() + .into_mle(), + ) + } + SelectorType::OrderedSparse32 { + indices, + expression: _, + } => { + assert_eq!(ceil_log2(ctx.num_instances) + 5, ctx.num_vars); + Some( + (0..(1 << (ctx.num_vars - 5))) + .into_par_iter() + .flat_map(|chunk_index| { + if chunk_index >= ctx.num_instances { + vec![E::ZERO; 32] + } else { + let mut chunk = vec![E::ZERO; 32]; + let mut indices_iter = indices.iter().copied(); + let mut next_keep = indices_iter.next(); + + for (i, e) in chunk.iter_mut().enumerate() { + if let Some(idx) = next_keep + && i == idx + { + *e = E::ONE; + next_keep = indices_iter.next(); // Keep this one + } + } + chunk + } + }) + .collect::>() + .into_mle(), + ) + } + SelectorType::QuarkBinaryTreeLessThan(..) => unimplemented!(), + } + } + /// Compute true and false mle eq(1; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (eq() - sel(y; b[5..])) pub fn compute( &self, out_point: &Point, - num_instances: usize, + ctx: &SelectorContext, ) -> Option> { + assert_eq!(out_point.len(), ctx.num_vars); + match self { SelectorType::None => None, - SelectorType::Whole(_expr) => Some(build_eq_x_r_vec(out_point).into_mle()), - SelectorType::Prefix(_, _expr) => { + SelectorType::Whole(_) => Some(build_eq_x_r_vec(out_point).into_mle()), + SelectorType::Prefix(_) => { + let start = ctx.offset; + let end = start + ctx.num_instances; + assert!( + end <= (1 << ctx.num_vars), + "start: {}, num_instances: {}, num_vars: {}", + start, + ctx.num_instances, + ctx.num_vars + ); + let mut sel = build_eq_x_r_vec(out_point); - if num_instances < sel.len() { - sel.splice( - num_instances..sel.len(), - std::iter::repeat_n(E::ZERO, sel.len() - num_instances), - ); - } + sel.splice(0..start, repeat_n(E::ZERO, start)); + sel.splice(end..sel.len(), repeat_n(E::ZERO, sel.len() - end)); Some(sel.into_mle()) } + // compute true and false mle eq(1; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (eq() - sel(y; b[5..])) SelectorType::OrderedSparse32 { indices, .. } => { + assert_eq!(out_point.len(), ceil_log2(ctx.num_instances) + 5); + let mut sel = build_eq_x_r_vec(out_point); sel.par_chunks_exact_mut(CYCLIC_POW2_5.len()) .enumerate() .for_each(|(chunk_index, chunk)| { - if chunk_index >= num_instances { + if chunk_index >= ctx.num_instances { // Zero out the entire chunk if out of instance range chunk.iter_mut().for_each(|e| *e = E::ZERO); return; @@ -75,31 +180,107 @@ impl SelectorType { }); Some(sel.into_mle()) } + // also see evaluate() function for more explanation + SelectorType::QuarkBinaryTreeLessThan(_) => { + assert_eq!(ctx.offset, 0); + // num_instances: number of prefix one in leaf layer + let mut sel: Vec = build_eq_x_r_vec(out_point); + let n = sel.len(); + + let num_instances_sequence = (0..out_point.len()) + // clean up sig bits + .scan(ctx.num_instances, |n_instance, _| { + // n points to sum means we have n/2 addition pairs + let cur = *n_instance / 2; + // the next layer has ceil(n/2) points to sum + *n_instance = (*n_instance).div_ceil(2); + Some(cur) + }) + .collect::>(); + + // split sel into different size of region, set tailing 0 of respective chunk size + // 1st round: take v = sel[0..sel.len()/2], zero out v[num_instances_sequence[0]..] + // 2nd round: take v = sel[sel.len()/2 .. sel.len()/4], zero out v[num_instances_sequence[1]..] + // ... + // each round: progressively smaller chunk + // example: round 0 uses first half, round 1 uses next quarter, etc. + // compute cumulative start indices: + // e.g. chunk = n/2, then start = 0, chunk, chunk + chunk/2, chunk + chunk/2 + chunk/4, ... + // compute disjoint start indices and lengths + let chunks: Vec<(usize, usize)> = { + let mut result = Vec::new(); + let mut start = 0; + let mut chunk_len = n / 2; + while chunk_len > 0 { + result.push((start, chunk_len)); + start += chunk_len; + chunk_len /= 2; + } + result + }; + + for (i, (start, len)) in chunks.into_iter().enumerate() { + let slice = &mut sel[start..start + len]; + + // determine from which index to zero + let zero_start = num_instances_sequence.get(i).copied().unwrap_or(0).min(len); + + for x in &mut slice[zero_start..] { + *x = E::ZERO; + } + } + + // zero out last bh evaluations + *sel.last_mut().unwrap() = E::ZERO; + Some(sel.into_mle()) + } } } - /// Evaluate true and false mle eq(CYCLIC_POW2_5[round]; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (1 - sel(y; b[5..])) pub fn evaluate( &self, evals: &mut Vec, out_point: &Point, in_point: &Point, - num_instances: usize, + ctx: &SelectorContext, offset_eq_id: usize, ) { + assert_eq!(in_point.len(), ctx.num_vars); + assert_eq!(out_point.len(), ctx.num_vars); + let (expr, eval) = match self { SelectorType::None => return, SelectorType::Whole(expr) => { debug_assert_eq!(out_point.len(), in_point.len()); (expr, eq_eval(out_point, in_point)) } - SelectorType::Prefix(_, expr) => { - debug_assert!(num_instances <= (1 << out_point.len())); - ( - expr, - eq_eval_less_or_equal_than(num_instances - 1, out_point, in_point), - ) + SelectorType::Prefix(expression) => { + let start = ctx.offset; + let end = start + ctx.num_instances; + + assert_eq!(in_point.len(), out_point.len()); + assert!( + end <= (1 << out_point.len()), + "start: {}, num_instances: {}, num_vars: {}", + start, + ctx.num_instances, + ctx.num_vars + ); + + if end == 0 { + (expression, E::ZERO) + } else { + let eq_end = eq_eval_less_or_equal_than(end - 1, out_point, in_point); + let sel = if start > 0 { + let eq_start = eq_eval_less_or_equal_than(start - 1, out_point, in_point); + eq_end - eq_start + } else { + eq_end + }; + (expression, sel) + } } + // evaluate true and false mle eq(CYCLIC_POW2_5[round]; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (1 - sel(y; b[5..])) SelectorType::OrderedSparse32 { indices, expression, @@ -110,10 +291,64 @@ impl SelectorType { for index in indices { eval += out_subgroup_eq[*index] * in_subgroup_eq[*index]; } - let sel = - eq_eval_less_or_equal_than(num_instances - 1, &out_point[5..], &in_point[5..]); + let sel = eq_eval_less_or_equal_than( + ctx.num_instances - 1, + &out_point[5..], + &in_point[5..], + ); (expression, eval * sel) } + SelectorType::QuarkBinaryTreeLessThan(expr) => { + // num_instances count on leaf layer + // where nodes size is 2^(N) / 2 + // out_point.len() is also log(2^(N)) - 1 + // so num_instances and 1 << out_point.len() are on same scaling + assert!(ctx.num_instances > 0); + assert!(ctx.num_instances <= (1 << out_point.len())); + assert!(!out_point.is_empty()); + assert_eq!(out_point.len(), in_point.len()); + + // we break down this special selector evaluation into recursive structure + // iterating through out_point and in_point, for each i + // next_eval = lhs * (1-out_point[i]) * (1 - in_point[i]) + prev_eval * out_point[i] * in_point[i] + // where the lhs is in consecutive prefix 1 follow by 0 + + // calculate prefix 1 length of each layer + let mut prefix_one_seq = (0..out_point.len()) + .scan(ctx.num_instances, |n_instance, _| { + // n points to sum means we have n/2 addition pairs + let cur = *n_instance / 2; + // next layer has ceil(n/2) points to sum + *n_instance = (*n_instance).div_ceil(2); + Some(cur) + }) + .collect::>(); + prefix_one_seq.reverse(); + + let mut res = if prefix_one_seq[0] == 0 { + E::ZERO + } else { + assert_eq!(prefix_one_seq[0], 1); + (E::ONE - out_point[0]) * (E::ONE - in_point[0]) + }; + for i in 1..out_point.len() { + let num_prefix_one_lhs = prefix_one_seq[i]; + let lhs_res = if num_prefix_one_lhs == 0 { + E::ZERO + } else { + (E::ONE - out_point[i]) + * (E::ONE - in_point[i]) + * eq_eval_less_or_equal_than( + num_prefix_one_lhs - 1, + &out_point[..i], + &in_point[..i], + ) + }; + let rhs_res = (out_point[i] * in_point[i]) * res; + res = lhs_res + rhs_res; + } + (expr, res) + } }; let Expression::StructuralWitIn(wit_id, _) = expr else { panic!("Wrong selector expression format"); @@ -137,8 +372,63 @@ impl SelectorType { match self { Self::OrderedSparse32 { expression, .. } | Self::Whole(expression) - | Self::Prefix(_, expression) => expression, + | Self::Prefix(expression) => expression, e => unimplemented!("no selector expression in {:?}", e), } } } + +#[cfg(test)] +mod tests { + use ff_ext::{BabyBearExt4, FromUniformBytes}; + use multilinear_extensions::{ + StructuralWitIn, ToExpr, util::ceil_log2, virtual_poly::build_eq_x_r_vec, + }; + use p3::field::FieldAlgebra; + use rand::thread_rng; + + use crate::selector::{SelectorContext, SelectorType}; + + type E = BabyBearExt4; + + #[test] + fn test_quark_lt_selector() { + let mut rng = thread_rng(); + let n_points = 5; + let n_vars = ceil_log2(n_points); + let witin = StructuralWitIn { + id: 0, + witin_type: multilinear_extensions::StructuralWitInType::EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: 0, + descending: false, + }, + }; + let selector = SelectorType::QuarkBinaryTreeLessThan(witin.expr()); + let ctx = SelectorContext::new(0, n_points, n_vars); + let out_rt = E::random_vec(n_vars, &mut rng); + let sel_mle = selector.compute(&out_rt, &ctx).unwrap(); + + // if we have 5 points to sum, then + // in 1st layer: two additions p12 = p1 + p2, p34 = p3 + p4, p5 kept + // in 2nd layer: one addition p14 = p12 + p34, p5 kept + // in 3rd layer: one addition p15 = p14 + p5 + let eq = build_eq_x_r_vec(&out_rt); + let vec = sel_mle.get_ext_field_vec(); + assert_eq!(vec[0], eq[0]); // p1+p2 + assert_eq!(vec[1], eq[1]); // p3+p4 + assert_eq!(vec[2], E::ZERO); // p5 + assert_eq!(vec[3], E::ZERO); + assert_eq!(vec[4], eq[4]); // p1+p2+p3+p4 + assert_eq!(vec[5], E::ZERO); // p5 + assert_eq!(vec[6], eq[6]); // p1+p2+p3+p4+p5 + assert_eq!(vec[7], E::ZERO); + + let in_rt = E::random_vec(n_vars, &mut rng); + let mut evals = vec![]; + // TODO: avoid the param evals when we evaluate a selector + selector.evaluate(&mut evals, &out_rt, &in_rt, &ctx, 0); + assert_eq!(sel_mle.evaluate(&in_rt), evals[0]); + } +}