diff --git a/.github/workflows/lints.yml b/.github/workflows/lints.yml index 29ff53880..8c08a3662 100644 --- a/.github/workflows/lints.yml +++ b/.github/workflows/lints.yml @@ -66,3 +66,5 @@ jobs: run: taplo --version || cargo install taplo-cli - name: Run taplo run: taplo fmt --check --diff + - name: Ensure Cargo.lock not modified by build + run: git diff --exit-code Cargo.lock diff --git a/Cargo.lock b/Cargo.lock index d3cb607be..22db24a3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -939,7 +939,7 @@ dependencies = [ [[package]] name = "ceno_crypto_primitives" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno#4a9dff21fd408e93c21edb6e874a09b0171b0c8b" +source = "git+https://github.com/scroll-tech/ceno#050108047aad24101fcb010da4e7d29e9d72678a" dependencies = [ "ceno_syscall 0.1.0 (git+https://github.com/scroll-tech/ceno)", "elliptic-curve", @@ -958,9 +958,12 @@ dependencies = [ "multilinear_extensions", "num-derive", "num-traits", + "rayon", "rrs-succinct", + "rustc-hash", "secp", "serde", + "smallvec", "strum", "strum_macros", "substrate-bn 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1013,7 +1016,7 @@ version = "0.1.0" [[package]] name = "ceno_syscall" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno#4a9dff21fd408e93c21edb6e874a09b0171b0c8b" +source = "git+https://github.com/scroll-tech/ceno#050108047aad24101fcb010da4e7d29e9d72678a" [[package]] name = "ceno_zkvm" @@ -1049,6 +1052,7 @@ dependencies = [ "proptest", "rand 0.8.5", "rayon", + "rustc-hash", "serde", "serde_json", "sp1-curves", @@ -1851,7 +1855,7 @@ dependencies = [ "ceno_syscall 0.1.0", "getrandom 0.3.2", "rand 0.8.5", - "revm-precompile 28.1.0", + "revm-precompile 28.1.1", "rkyv", "substrate-bn 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", "substrate-bn 0.6.0 (git+https://github.com/scroll-tech/bn?branch=ceno)", @@ -3874,9 +3878,9 @@ dependencies = [ [[package]] name = "revm-precompile" -version = "28.1.0" +version = "28.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176169b39beb1f57b11f2ea3900c404b8498a56dfd8394e66f4d24f66cea368e" +checksum = "e57aadd7a2087705f653b5aaacc8ad4f8e851f5d330661e3f4c43b5475bbceae" dependencies = [ "ark-bls12-381", "ark-bn254", @@ -3888,7 +3892,7 @@ dependencies = [ "cfg-if", "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", "p256", - "revm-primitives 21.0.0", + "revm-primitives 21.0.1", "ripemd", "sha2 0.10.9", ] @@ -3907,9 +3911,9 @@ dependencies = [ [[package]] name = "revm-primitives" -version = "21.0.0" +version = "21.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38271b8b85f00154bdcf9f2ab0a3ec7a8100377d2c7a0d8eb23e19389b42c795" +checksum = "536f30e24c3c2bf0d3d7d20fa9cf99b93040ed0f021fd9301c78cddb0dacda13" dependencies = [ "alloy-primitives", "num_enum 0.7.4", @@ -4432,6 +4436,9 @@ name = "smallvec" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" +dependencies = [ + "serde", +] [[package]] name = "snowbridge-amcl" diff --git a/Cargo.toml b/Cargo.toml index 7acdd2d7a..733390e07 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,9 +57,17 @@ rand_chacha = { version = "0.3", features = ["serde1"] } rand_core = "0.6" rayon = "1.10" rkyv = { version = "0.8", features = ["pointer_width_32"] } +rustc-hash = "2.0.0" secp = "0.4.1" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" +smallvec = { version = "1.13.2", features = [ + "const_generics", + "const_new", + "serde", + "union", + "write", +] } strum = "0.26" strum_macros = "0.26" substrate-bn = { version = "0.6.0" } diff --git a/ceno_cli/src/commands/common_args/ceno.rs b/ceno_cli/src/commands/common_args/ceno.rs index 9632986a8..d73841080 100644 --- a/ceno_cli/src/commands/common_args/ceno.rs +++ b/ceno_cli/src/commands/common_args/ceno.rs @@ -78,6 +78,14 @@ pub struct CenoOptions { #[arg(long)] pub out_vk: Option, + /// shard id + #[arg(long, default_value = "0")] + shard_id: u32, + + /// number of total shards. + #[arg(long, default_value = "1")] + max_num_shards: u32, + /// Profiling granularity. /// Setting any value restricts logs to profiling information #[arg(long)] @@ -337,6 +345,7 @@ fn run_elf_inner< std::fs::read(elf_path).context(format!("failed to read {}", elf_path.display()))?; let program = Program::load_elf(&elf_bytes, u32::MAX).context("failed to load elf")?; print_cargo_message("Loaded", format_args!("{}", elf_path.display())); + let shards = Shards::new(options.shard_id as usize, options.max_num_shards as usize); let public_io = options .read_public_io() @@ -385,6 +394,7 @@ fn run_elf_inner< create_prover(backend.clone()), program, platform, + shards, &hints, &public_io, options.max_steps, diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index b0af43fe3..6cc12cd17 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -19,9 +19,12 @@ itertools.workspace = true multilinear_extensions.workspace = true num-derive.workspace = true num-traits.workspace = true +rayon.workspace = true rrs_lib = { package = "rrs-succinct", version = "0.1.0" } +rustc-hash.workspace = true secp.workspace = true serde.workspace = true +smallvec.workspace = true strum.workspace = true strum_macros.workspace = true substrate-bn.workspace = true diff --git a/ceno_emul/src/chunked_vec.rs b/ceno_emul/src/chunked_vec.rs new file mode 100644 index 000000000..e53d51a73 --- /dev/null +++ b/ceno_emul/src/chunked_vec.rs @@ -0,0 +1,89 @@ +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use std::ops::{Index, IndexMut}; + +/// a chunked vector that grows in fixed-size chunks. +#[derive(Default, Debug, Clone)] +pub struct ChunkedVec { + chunks: Vec>, + chunk_size: usize, + len: usize, +} + +impl ChunkedVec { + /// create a new ChunkedVec with a given chunk size. + pub fn new(chunk_size: usize) -> Self { + assert!(chunk_size > 0, "chunk_size must be > 0"); + Self { + chunks: Vec::new(), + chunk_size, + len: 0, + } + } + + /// get the current number of elements. + pub fn len(&self) -> usize { + self.len + } + + /// returns true if the vector is empty. + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// access element by index (immutable). + pub fn get(&self, index: usize) -> Option<&T> { + if index >= self.len { + return None; + } + let chunk_idx = index / self.chunk_size; + let within_idx = index % self.chunk_size; + self.chunks.get(chunk_idx)?.get(within_idx) + } + + /// access element by index (mutable). + /// get mutable reference to element at index, auto-creating chunks as needed + pub fn get_or_create(&mut self, index: usize) -> &mut T { + let chunk_idx = index / self.chunk_size; + let within_idx = index % self.chunk_size; + + // Ensure enough chunks exist + if chunk_idx >= self.chunks.len() { + let to_create = chunk_idx + 1 - self.chunks.len(); + + // Use rayon to create all missing chunks in parallel + let mut new_chunks: Vec> = (0..to_create) + .map(|_| { + (0..self.chunk_size) + .into_par_iter() + .map(|_| Default::default()) + .collect::>() + }) + .collect(); + + self.chunks.append(&mut new_chunks); + } + + let chunk = &mut self.chunks[chunk_idx]; + + // Update the overall length + if index >= self.len { + self.len = index + 1; + } + + &mut chunk[within_idx] + } +} + +impl Index for ChunkedVec { + type Output = T; + + fn index(&self, index: usize) -> &Self::Output { + self.get(index).expect("index out of bounds") + } +} + +impl IndexMut for ChunkedVec { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + self.get_or_create(index) + } +} diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 8f439d036..3d88484fa 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -7,7 +7,9 @@ mod platform; pub use platform::{CENO_PLATFORM, Platform}; mod tracer; -pub use tracer::{Change, MemOp, ReadOp, StepRecord, Tracer, WriteOp}; +pub use tracer::{ + Change, MemOp, NextAccessPair, NextCycleAccess, ReadOp, StepRecord, Tracer, WriteOp, +}; mod vm_state; pub use vm_state::VMState; @@ -44,4 +46,5 @@ pub mod utils; pub mod test_utils; +mod chunked_vec; pub mod host_utils; diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 8280e8351..c36bd5bef 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -1,13 +1,13 @@ -use std::{ - collections::{BTreeMap, HashMap}, - fmt, mem, -}; +use rustc_hash::FxHashMap; +use smallvec::SmallVec; +use std::{collections::BTreeMap, fmt, mem}; use ceno_rt::WORD_SIZE; use crate::{ CENO_PLATFORM, InsnKind, Instruction, PC_STEP_SIZE, Platform, addr::{ByteAddr, Cycle, RegIdx, Word, WordAddr}, + chunked_vec::ChunkedVec, encode_rv32, syscalls::{SyscallEffects, SyscallWitness}, }; @@ -39,6 +39,10 @@ pub struct StepRecord { syscall: Option, } +pub type NextAccessPair = SmallVec<[(WordAddr, Cycle); 1]>; +pub type NextCycleAccess = ChunkedVec; +const ACCESSED_CHUNK_SIZE: usize = 1 << 20; + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct MemOp { /// Virtual Memory Address. @@ -305,7 +309,8 @@ pub struct Tracer { // record each section max access address // (start_addr -> (start_addr, end_addr, min_access_addr, max_access_addr)) mmio_min_max_access: Option>, - latest_accesses: HashMap, + latest_accesses: FxHashMap, + next_accesses: NextCycleAccess, } impl Default for Tracer { @@ -362,7 +367,8 @@ impl Tracer { cycle: Self::SUBCYCLES_PER_INSN, ..StepRecord::default() }, - latest_accesses: HashMap::new(), + latest_accesses: FxHashMap::default(), + next_accesses: NextCycleAccess::new(ACCESSED_CHUNK_SIZE), } } @@ -471,16 +477,24 @@ impl Tracer { /// - Record the current instruction as the origin of the latest access. /// - Accesses within the same instruction are distinguished by `subcycle ∈ [0, 3]`. pub fn track_access(&mut self, addr: WordAddr, subcycle: Cycle) -> Cycle { - self.latest_accesses - .insert(addr, self.record.cycle + subcycle) - .unwrap_or(0) + let cur_cycle = self.record.cycle + subcycle; + let prev_cycle = self.latest_accesses.insert(addr, cur_cycle).unwrap_or(0); + self.next_accesses + .get_or_create(prev_cycle as usize) + .push((addr, cur_cycle)); + prev_cycle } /// Return all the addresses that were accessed and the cycle when they were last accessed. - pub fn final_accesses(&self) -> &HashMap { + pub fn final_accesses(&self) -> &FxHashMap { &self.latest_accesses } + /// Return all the addresses that were accessed and the cycle when they were last accessed. + pub fn next_accesses(self) -> NextCycleAccess { + self.next_accesses + } + /// Return the cycle of the pending instruction (after the last completed step). pub fn cycle(&self) -> Cycle { self.record.cycle diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 51057c2b0..eaac9d639 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -68,6 +68,10 @@ impl VMState { &self.tracer } + pub fn take_tracer(self) -> Tracer { + self.tracer + } + pub fn platform(&self) -> &Platform { &self.platform } diff --git a/ceno_emul/tests/test_vm_trace.rs b/ceno_emul/tests/test_vm_trace.rs index 74cc83d4e..14bf7a1fe 100644 --- a/ceno_emul/tests/test_vm_trace.rs +++ b/ceno_emul/tests/test_vm_trace.rs @@ -1,9 +1,7 @@ #![allow(clippy::unusual_byte_groupings)] use anyhow::Result; -use std::{ - collections::{BTreeMap, HashMap}, - sync::Arc, -}; +use rustc_hash::FxHashMap; +use std::{collections::BTreeMap, sync::Arc}; use ceno_emul::{ CENO_PLATFORM, Cycle, EmuContext, InsnKind, Instruction, Platform, Program, StepRecord, Tracer, @@ -111,8 +109,8 @@ fn expected_ops_fibonacci_20() -> Vec { } /// Reconstruct the last access of each register. -fn expected_final_accesses_fibonacci_20() -> HashMap { - let mut accesses = HashMap::new(); +fn expected_final_accesses_fibonacci_20() -> FxHashMap { + let mut accesses = FxHashMap::default(); let x = |i| WordAddr::from(Platform::register_vma(i)); const C: Cycle = Tracer::SUBCYCLES_PER_INSN; diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 3347b38cc..3c1c99ed4 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -34,6 +34,7 @@ witness.workspace = true itertools.workspace = true ndarray.workspace = true prettytable-rs.workspace = true +rustc-hash.workspace = true strum.workspace = true strum_macros.workspace = true tracing.workspace = true diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index 878502f8e..325c59f46 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -13,7 +13,7 @@ use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; -use ceno_zkvm::scheme::verifier::ZKVMVerifier; +use ceno_zkvm::{e2e::Shards, scheme::verifier::ZKVMVerifier}; use mpcs::BasefoldDefault; use transcript::BasicTranscript; @@ -54,6 +54,7 @@ fn fibonacci_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, @@ -91,6 +92,7 @@ fn fibonacci_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, diff --git a/ceno_zkvm/benches/fibonacci_witness.rs b/ceno_zkvm/benches/fibonacci_witness.rs index 483b690d5..d942743db 100644 --- a/ceno_zkvm/benches/fibonacci_witness.rs +++ b/ceno_zkvm/benches/fibonacci_witness.rs @@ -9,6 +9,7 @@ use std::{fs, path::PathBuf, time::Duration}; mod alloc; use criterion::*; +use ceno_zkvm::e2e::Shards; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; use mpcs::BasefoldDefault; @@ -65,6 +66,7 @@ fn fibonacci_witness(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, diff --git a/ceno_zkvm/benches/is_prime.rs b/ceno_zkvm/benches/is_prime.rs index b55805fb7..6d66ff859 100644 --- a/ceno_zkvm/benches/is_prime.rs +++ b/ceno_zkvm/benches/is_prime.rs @@ -8,6 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; +use ceno_zkvm::e2e::Shards; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; @@ -62,6 +63,7 @@ fn is_prime_1(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &hints, &[], max_steps, diff --git a/ceno_zkvm/benches/keccak.rs b/ceno_zkvm/benches/keccak.rs index c1a889594..19011d460 100644 --- a/ceno_zkvm/benches/keccak.rs +++ b/ceno_zkvm/benches/keccak.rs @@ -8,7 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; -use ceno_zkvm::scheme::verifier::ZKVMVerifier; +use ceno_zkvm::{e2e::Shards, scheme::verifier::ZKVMVerifier}; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; @@ -51,6 +51,7 @@ fn keccak_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, @@ -85,6 +86,7 @@ fn keccak_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, diff --git a/ceno_zkvm/benches/quadratic_sorting.rs b/ceno_zkvm/benches/quadratic_sorting.rs index dc234a03a..93389c388 100644 --- a/ceno_zkvm/benches/quadratic_sorting.rs +++ b/ceno_zkvm/benches/quadratic_sorting.rs @@ -8,6 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; +use ceno_zkvm::e2e::Shards; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; @@ -63,6 +64,7 @@ fn quadratic_sorting_1(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &hints, &[], max_steps, diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index c7ec2b310..52df7e6da 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -4,7 +4,7 @@ use ceno_host::{CenoStdin, memory_from_file}; use ceno_zkvm::print_allocated_bytes; use ceno_zkvm::{ e2e::{ - Checkpoint, FieldType, PcsKind, Preset, run_e2e_with_checkpoint, setup_platform, + Checkpoint, FieldType, PcsKind, Preset, Shards, run_e2e_with_checkpoint, setup_platform, setup_platform_debug, verify, }, scheme::{ @@ -108,6 +108,14 @@ struct Args { /// The security level to use. #[arg(short, long, value_enum, default_value_t = SecurityLevel::default())] security_level: SecurityLevel, + + // shard id + #[arg(long, default_value = "0")] + shard_id: u32, + + // number of total shards + #[arg(long, default_value = "1")] + max_num_shards: u32, } fn main() { @@ -240,6 +248,7 @@ fn main() { .unwrap_or_default(); let max_steps = args.max_steps.unwrap_or(usize::MAX); + let shards = Shards::new(args.shard_id as usize, args.max_num_shards as usize); match (args.pcs, args.field) { (PcsKind::Basefold, FieldType::Goldilocks) => { @@ -249,6 +258,7 @@ fn main() { prover, program, platform, + shards, &hints, &public_io, max_steps, @@ -264,6 +274,7 @@ fn main() { prover, program, platform, + shards, &hints, &public_io, max_steps, @@ -279,6 +290,7 @@ fn main() { prover, program, platform, + shards, &hints, &public_io, max_steps, @@ -294,6 +306,7 @@ fn main() { prover, program, platform, + shards, &hints, &public_io, max_steps, @@ -320,6 +333,7 @@ fn run_inner< pd: PD, program: Program, platform: Platform, + shards: Shards, hints: &[u32], public_io: &[u32], max_steps: usize, @@ -328,7 +342,7 @@ fn run_inner< checkpoint: Checkpoint, ) { let result = run_e2e_with_checkpoint::( - pd, program, platform, hints, public_io, max_steps, checkpoint, + pd, program, platform, shards, hints, public_io, max_steps, checkpoint, ); let zkvm_proof = result diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index e1ace19d0..dbd9961a9 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -4,8 +4,8 @@ use gkr_iop::{error::CircuitBuilderError, tables::LookupTable}; use crate::{ circuit_builder::CircuitBuilder, instructions::riscv::constants::{ - END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX, - UINT_LIMBS, + END_CYCLE_IDX, END_PC_IDX, END_SHARD_ID_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, + PUBLIC_IO_IDX, UINT_LIMBS, }, tables::InsnRecord, }; @@ -22,6 +22,8 @@ pub trait PublicIOQuery { fn query_end_pc(&mut self) -> Result; fn query_end_cycle(&mut self) -> Result; fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; + #[allow(dead_code)] + fn query_shard_id(&mut self) -> Result; } impl<'a, E: ExtensionField> InstFetch for CircuitBuilder<'a, E> { @@ -60,6 +62,10 @@ impl<'a, E: ExtensionField> PublicIOQuery for CircuitBuilder<'a, E> { self.cs.query_instance(|| "end_cycle", END_CYCLE_IDX) } + fn query_shard_id(&mut self) -> Result { + self.cs.query_instance(|| "shard_id", END_SHARD_ID_IDX) + } + fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError> { Ok([ self.cs.query_instance(|| "public_io_low", PUBLIC_IO_IDX)?, diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 226231c2b..712f3b7a1 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -16,22 +16,28 @@ use crate::{ tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig}, }; use ceno_emul::{ - Addr, ByteAddr, CENO_PLATFORM, EmuContext, InsnKind, IterAddresses, Platform, Program, - StepRecord, Tracer, VMState, WORD_SIZE, WordAddr, host_utils::read_all_messages, + Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, IterAddresses, NextCycleAccess, + Platform, Program, StepRecord, Tracer, VMState, WORD_SIZE, Word, WordAddr, + host_utils::read_all_messages, }; use clap::ValueEnum; +use either::Either; use ff_ext::ExtensionField; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use gkr_iop::hal::ProverBackend; +use gkr_iop::{RAMType, hal::ProverBackend}; use itertools::{Itertools, MinMaxResult, chain}; use mpcs::{PolynomialCommitmentScheme, SecurityLevel}; +use multilinear_extensions::util::max_usable_threads; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use serde::Serialize; use std::{ - collections::{BTreeSet, HashMap, HashSet}, + borrow::Cow, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, sync::Arc, }; use transcript::BasicTranscript as Transcript; +use witness::next_pow2_instance_padding; /// The polynomial commitment scheme kind #[derive( @@ -87,19 +93,288 @@ pub struct FullMemState { type InitMemState = FullMemState; type FinalMemState = FullMemState; -pub struct EmulationResult { +pub struct EmulationResult<'a> { pub exit_code: Option, pub all_records: Vec, pub final_mem_state: FinalMemState, pub pi: PublicValues, + pub shard_ctx: ShardContext<'a>, } -pub fn emulate_program( +pub struct RAMRecord { + pub ram_type: RAMType, + pub id: u64, + pub addr: WordAddr, + pub prev_cycle: Cycle, + pub cycle: Cycle, + pub prev_value: Option, + pub value: Word, +} + +#[derive(Clone, Debug)] +pub struct Shards { + pub shard_id: usize, + pub max_num_shards: usize, +} + +impl Shards { + pub fn new(shard_id: usize, max_num_shards: usize) -> Self { + assert!(shard_id < max_num_shards); + Self { + shard_id, + max_num_shards, + } + } + + pub fn is_first_shard(&self) -> bool { + self.shard_id == 0 + } + + pub fn is_last_shard(&self) -> bool { + self.shard_id == self.max_num_shards - 1 + } +} + +impl Default for Shards { + fn default() -> Self { + Self { + shard_id: 0, + max_num_shards: 1, + } + } +} + +pub struct ShardContext<'a> { + shards: Shards, + max_cycle: Cycle, + // TODO optimize this map as it's super huge + addr_future_accesses: Cow<'a, NextCycleAccess>, + read_thread_based_record_storage: + Either>, &'a mut BTreeMap>, + write_thread_based_record_storage: + Either>, &'a mut BTreeMap>, + pub cur_shard_cycle_range: std::ops::Range, +} + +impl<'a> Default for ShardContext<'a> { + fn default() -> Self { + let max_threads = max_usable_threads(); + Self { + shards: Shards::default(), + max_cycle: Cycle::default(), + addr_future_accesses: Cow::Owned(Default::default()), + read_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| BTreeMap::new()) + .collect::>(), + ), + write_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| BTreeMap::new()) + .collect::>(), + ), + cur_shard_cycle_range: Tracer::SUBCYCLES_PER_INSN as usize..usize::MAX, + } + } +} + +impl<'a> ShardContext<'a> { + pub fn new( + shards: Shards, + executed_instructions: usize, + addr_future_accesses: NextCycleAccess, + ) -> Self { + // current strategy: at least each shard deal with one instruction + let max_num_shards = shards.max_num_shards.min(executed_instructions); + assert!( + shards.shard_id < max_num_shards, + "implement mechanism to skip current shard proof" + ); + + let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN as usize; + let max_threads = max_usable_threads(); + let expected_inst_per_shard = executed_instructions.div_ceil(max_num_shards); + let max_cycle = (executed_instructions + 1) * subcycle_per_insn; // cycle start from subcycle_per_insn + let cur_shard_cycle_range = (shards.shard_id * expected_inst_per_shard * subcycle_per_insn + + subcycle_per_insn) + ..((shards.shard_id + 1) * expected_inst_per_shard * subcycle_per_insn + + subcycle_per_insn) + .min(max_cycle); + + ShardContext { + shards, + max_cycle: max_cycle as Cycle, + addr_future_accesses: Cow::Owned(addr_future_accesses), + // TODO with_capacity optimisation + read_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| BTreeMap::new()) + .collect::>(), + ), + // TODO with_capacity optimisation + write_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| BTreeMap::new()) + .collect::>(), + ), + cur_shard_cycle_range, + } + } + + pub fn get_forked(&mut self) -> Vec> { + match ( + &mut self.read_thread_based_record_storage, + &mut self.write_thread_based_record_storage, + ) { + ( + Either::Left(read_thread_based_record_storage), + Either::Left(write_thread_based_record_storage), + ) => read_thread_based_record_storage + .iter_mut() + .zip(write_thread_based_record_storage.iter_mut()) + .map(|(read, write)| ShardContext { + shards: self.shards.clone(), + max_cycle: self.max_cycle, + addr_future_accesses: Cow::Borrowed(self.addr_future_accesses.as_ref()), + read_thread_based_record_storage: Either::Right(read), + write_thread_based_record_storage: Either::Right(write), + cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), + }) + .collect_vec(), + _ => panic!("invalid type"), + } + } + + pub fn read_records(&self) -> &[BTreeMap] { + match &self.read_thread_based_record_storage { + Either::Left(m) => m, + Either::Right(_) => panic!("undefined behaviour"), + } + } + + pub fn write_records(&self) -> &[BTreeMap] { + match &self.write_thread_based_record_storage { + Either::Left(m) => m, + Either::Right(_) => panic!("undefined behaviour"), + } + } + + #[inline(always)] + pub fn is_first_shard(&self) -> bool { + self.shards.shard_id == 0 + } + + #[inline(always)] + pub fn is_last_shard(&self) -> bool { + self.shards.shard_id == self.shards.max_num_shards - 1 + } + + #[inline(always)] + pub fn is_current_shard_cycle(&self, cycle: Cycle) -> bool { + self.cur_shard_cycle_range.contains(&(cycle as usize)) + } + + #[inline(always)] + pub fn aligned_prev_ts(&self, prev_cycle: Cycle) -> Cycle { + let mut ts = prev_cycle - self.current_shard_offset_cycle(); + if ts < Tracer::SUBCYCLES_PER_INSN { + ts = 0 + } + ts + } + + pub fn current_shard_offset_cycle(&self) -> Cycle { + // cycle of each local shard start from Tracer::SUBCYCLES_PER_INSN + (self.cur_shard_cycle_range.start as Cycle) - Tracer::SUBCYCLES_PER_INSN + } + + #[inline(always)] + #[allow(clippy::too_many_arguments)] + pub fn send( + &mut self, + ram_type: crate::structs::RAMType, + addr: WordAddr, + id: u64, + cycle: Cycle, + prev_cycle: Cycle, + value: Word, + prev_value: Option, + ) { + // check read from external mem bus + // exclude first shard + if prev_cycle < self.cur_shard_cycle_range.start as Cycle + && self.is_current_shard_cycle(cycle) + && !self.is_first_shard() + { + let ram_record = self + .read_thread_based_record_storage + .as_mut() + .right() + .expect("illegal type"); + ram_record.insert( + addr, + RAMRecord { + ram_type, + id, + addr, + prev_cycle, + cycle, + prev_value, + value, + }, + ); + } + + // check write to external mem bus + if let Some(future_touch_cycle) = + self.addr_future_accesses + .get(cycle as usize) + .and_then(|res| { + if res.len() == 1 { + Some(res[0].1) + } else if res.len() > 1 { + res.iter() + .find(|(m_addr, _)| *m_addr == addr) + .map(|(_, cycle)| *cycle) + } else { + None + } + }) + && future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle + && self.is_current_shard_cycle(cycle) + { + let ram_record = self + .write_thread_based_record_storage + .as_mut() + .right() + .expect("illegal type"); + ram_record.insert( + addr, + RAMRecord { + ram_type, + id, + addr, + prev_cycle, + cycle, + prev_value, + value, + }, + ); + } + } +} + +pub fn emulate_program<'a>( program: Arc, max_steps: usize, init_mem_state: &InitMemState, platform: &Platform, -) -> EmulationResult { + shards: &Shards, +) -> EmulationResult<'a> { let InitMemState { mem: mem_init, io: io_init, @@ -156,6 +431,7 @@ pub fn emulate_program( Tracer::SUBCYCLES_PER_INSN, vm.get_pc().into(), end_cycle, + shards.shard_id as u32, io_init.iter().map(|rec| rec.value).collect_vec(), ); @@ -167,6 +443,7 @@ pub fn emulate_program( if index < VMState::REG_COUNT { let vma: WordAddr = Platform::register_vma(index).into(); MemFinalRecord { + ram_type: RAMType::Register, addr: rec.addr, value: vm.peek_register(index), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -174,6 +451,7 @@ pub fn emulate_program( } else { // The table is padded beyond the number of registers. MemFinalRecord { + ram_type: RAMType::Register, addr: rec.addr, value: 0, cycle: 0, @@ -188,6 +466,7 @@ pub fn emulate_program( .map(|rec| { let vma: WordAddr = rec.addr.into(); MemFinalRecord { + ram_type: RAMType::Memory, addr: rec.addr, value: vm.peek_memory(vma), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -199,6 +478,7 @@ pub fn emulate_program( let io_final = io_init .iter() .map(|rec| MemFinalRecord { + ram_type: RAMType::Memory, addr: rec.addr, value: rec.value, cycle: *final_access.get(&rec.addr.into()).unwrap_or(&0), @@ -209,6 +489,7 @@ pub fn emulate_program( let hints_final = hints_init .iter() .map(|rec| MemFinalRecord { + ram_type: RAMType::Memory, addr: rec.addr, value: rec.value, cycle: *final_access.get(&rec.addr.into()).unwrap_or(&0), @@ -226,6 +507,7 @@ pub fn emulate_program( .map(|vma| { let byte_addr = vma.baddr(); MemFinalRecord { + ram_type: RAMType::Memory, addr: byte_addr.0, value: vm.peek_memory(vma), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -249,6 +531,7 @@ pub fn emulate_program( .map(|vma| { let byte_addr = vma.baddr(); MemFinalRecord { + ram_type: RAMType::Memory, addr: byte_addr.0, value: vm.peek_memory(vma), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -270,10 +553,13 @@ pub fn emulate_program( ), ); + let shard_ctx = ShardContext::new(shards.clone(), insts, vm.take_tracer().next_accesses()); + EmulationResult { pi, exit_code, all_records, + shard_ctx, final_mem_state: FinalMemState { reg: reg_final, io: io_final, @@ -389,17 +675,17 @@ pub fn init_static_addrs(program: &Program) -> Vec { program_addrs } -pub struct ConstraintSystemConfig { +pub struct ConstraintSystemConfig<'a, E: ExtensionField> { pub zkvm_cs: ZKVMConstraintSystem, pub config: Rv32imConfig, - pub mmu_config: MmuConfig, + pub mmu_config: MmuConfig<'a, E>, pub dummy_config: DummyExtraConfig, pub prog_config: ProgramTableConfig, } -pub fn construct_configs( +pub fn construct_configs<'a, E: ExtensionField>( program_params: ProgramParams, -) -> ConstraintSystemConfig { +) -> ConstraintSystemConfig<'a, E> { let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params); let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); @@ -450,7 +736,7 @@ pub fn generate_fixed_traces( pub fn generate_witness( system_config: &ConstraintSystemConfig, - emul_result: EmulationResult, + mut emul_result: EmulationResult, program: &Program, ) -> ZKVMWitnesses { let mut zkvm_witness = ZKVMWitnesses::default(); @@ -459,13 +745,19 @@ pub fn generate_witness( .config .assign_opcode_circuit( &system_config.zkvm_cs, + &mut emul_result.shard_ctx, &mut zkvm_witness, emul_result.all_records, ) .unwrap(); system_config .dummy_config - .assign_opcode_circuit(&system_config.zkvm_cs, &mut zkvm_witness, dummy_records) + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut emul_result.shard_ctx, + &mut zkvm_witness, + dummy_records, + ) .unwrap(); zkvm_witness.finalize_lk_multiplicities(); @@ -478,6 +770,7 @@ pub fn generate_witness( .mmu_config .assign_table_circuit( &system_config.zkvm_cs, + &emul_result.shard_ctx, &mut zkvm_witness, &emul_result.final_mem_state.reg, &emul_result.final_mem_state.mem, @@ -519,12 +812,13 @@ pub enum Checkpoint { pub type IntermediateState = (Option>, Option>); /// Context construct from a program and given platform -pub struct E2EProgramCtx { +pub struct E2EProgramCtx<'a, E: ExtensionField> { pub program: Arc, pub platform: Platform, + pub shards: Shards, pub static_addrs: Vec, pub pubio_len: usize, - pub system_config: ConstraintSystemConfig, + pub system_config: ConstraintSystemConfig<'a, E>, pub reg_init: Vec, pub io_init: Vec, pub zkvm_fixed_traces: ZKVMFixedTraces, @@ -549,12 +843,16 @@ impl> E2ECheckpointResult< } /// Set up a program with the given platform -pub fn setup_program(program: Program, platform: Platform) -> E2EProgramCtx { +pub fn setup_program<'a, E: ExtensionField>( + program: Program, + platform: Platform, + shards: Shards, +) -> E2EProgramCtx<'a, E> { let static_addrs = init_static_addrs(&program); let pubio_len = platform.public_io.iter_addresses().len(); let program_params = ProgramParams { platform: platform.clone(), - program_size: program.instructions.len(), + program_size: next_pow2_instance_padding(program.instructions.len()), static_memory_len: static_addrs.len(), pubio_len, }; @@ -574,6 +872,7 @@ pub fn setup_program(program: Program, platform: Platform) -> E2EProgramCtx { program: Arc::new(program), platform, + shards, static_addrs, pubio_len, system_config, @@ -583,7 +882,7 @@ pub fn setup_program(program: Program, platform: Platform) -> } } -impl E2EProgramCtx { +impl E2EProgramCtx<'_, E> { pub fn keygen + 'static>( &self, max_num_variables: usize, @@ -666,13 +965,14 @@ pub fn run_e2e_with_checkpoint< device: PD, program: Program, platform: Platform, + shards: Shards, hints: &[u32], public_io: &[u32], max_steps: usize, checkpoint: Checkpoint, ) -> E2ECheckpointResult { let start = std::time::Instant::now(); - let ctx = setup_program::(program, platform); + let ctx = setup_program::(program, platform, shards); tracing::debug!("setup_program done in {:?}", start.elapsed()); // Keygen @@ -710,6 +1010,7 @@ pub fn run_e2e_with_checkpoint< max_steps, &init_full_mem, &ctx.platform, + &ctx.shards, ); tracing::debug!("emulate done in {:?}", start.elapsed()); @@ -793,7 +1094,13 @@ pub fn run_e2e_proof< is_mock_proving: bool, ) -> ZKVMProof { // Emulate program - let emul_result = emulate_program(ctx.program.clone(), max_steps, init_full_mem, &ctx.platform); + let emul_result = emulate_program( + ctx.program.clone(), + max_steps, + init_full_mem, + &ctx.platform, + &ctx.shards, + ); // clone pi before consuming let pi = emul_result.pi.clone(); diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 4591c47e3..13a3ed22b 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -1,5 +1,5 @@ use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, structs::ProgramParams, + circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, structs::ProgramParams, tables::RMMCollections, witness::LkMultiplicity, }; use ceno_emul::StepRecord; @@ -93,8 +93,9 @@ pub trait Instruction { } // assign single instance giving step from trace - fn assign_instance( + fn assign_instance<'a>( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext<'a>, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -102,6 +103,7 @@ pub trait Instruction { fn assign_instances( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, steps: Vec, @@ -131,22 +133,32 @@ pub trait Instruction { let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); let raw_structual_witin_iter = raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(raw_structual_witin_iter) .zip_eq(steps.par_chunks(num_instance_per_batch)) - .flat_map(|((instances, structural_instance), steps)| { - let mut lk_multiplicity = lk_multiplicity.clone(); - instances - .chunks_mut(num_witin) - .zip_eq(structural_instance.chunks_mut(num_structural_witin)) - .zip_eq(steps) - .map(|((instance, structural_instance), step)| { - set_val!(structural_instance, selector_witin, E::BaseField::ONE); - Self::assign_instance(config, instance, &mut lk_multiplicity, step) - }) - .collect::>() - }) + .zip(shard_ctx_vec) + .flat_map( + |(((instances, structural_instance), steps), mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip_eq(structural_instance.chunks_mut(num_structural_witin)) + .zip_eq(steps) + .map(|((instance, structural_instance), step)| { + set_val!(structural_instance, selector_witin, E::BaseField::ONE); + Self::assign_instance( + config, + &mut shard_ctx, + instance, + &mut lk_multiplicity, + step, + ) + }) + .collect::>() + }, + ) .collect::>()?; raw_witin.padding_by_strategy(); diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index b73abcda4..a94024b4a 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -2,8 +2,8 @@ use std::marker::PhantomData; use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, instructions::Instruction, - structs::ProgramParams, uint::Value, witness::LkMultiplicity, + circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + instructions::Instruction, structs::ProgramParams, uint::Value, witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; @@ -87,13 +87,14 @@ impl Instruction for ArithInstruction::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs2_read = Value::new_unchecked(step.rs2().unwrap().value); config @@ -186,6 +187,7 @@ mod test { let insn_code = encode_rv32(I::INST_KIND, 2, 3, 4, 0); let (raw_witin, lkm) = ArithInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index a040681bc..4de4069d0 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -21,6 +21,7 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::constants::UInt}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -63,6 +64,7 @@ mod test { let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm); let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs index 8a4722a08..11d93242c 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -58,6 +59,7 @@ impl Instruction for AddiInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -77,7 +79,7 @@ impl Instruction for AddiInstruction { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index f969a68b0..8ed175d58 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -70,6 +71,7 @@ impl Instruction for AddiInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -93,7 +95,7 @@ impl Instruction for AddiInstruction { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 7957f7003..3244c5d60 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -4,6 +4,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -142,13 +143,14 @@ impl Instruction for AuipcInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_written = split_to_u8(step.rd().unwrap().value.after); config.rd_written.assign_limbs(instance, &rd_written); @@ -189,6 +191,7 @@ mod tests { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{auipc::AuipcInstruction, constants::UInt}, @@ -239,6 +242,7 @@ mod tests { let insn_code = encode_rv32(InsnKind::AUIPC, 0, 0, 4, imm); let (raw_witin, lkm) = AuipcInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index 798902754..cdc1db56d 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -5,6 +5,7 @@ use super::constants::PC_STEP_SIZE; use crate::{ chip_handler::{RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, tables::InsnRecord, @@ -12,7 +13,6 @@ use crate::{ }; use ff_ext::FieldInto; use multilinear_extensions::{Expression, ToExpr, WitIn}; - // Opcode: 1100011 // Funct3: // 000 BEQ @@ -89,12 +89,15 @@ impl BInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rs2.assign_instance(instance, lk_multiplicity, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rs2 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Immediate set_val!( diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs index 8aecd50f8..2c97a12ee 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs @@ -6,6 +6,7 @@ use ff_ext::ExtensionField; use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{IsEqualConfig, IsLtConfig, SignedLtConfig}, instructions::{ @@ -137,13 +138,14 @@ impl Instruction for BranchCircuit Result<(), ZKVMError> { config .b_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs1 = Value::new_unchecked(step.rs1().unwrap().value); let rs2 = Value::new_unchecked(step.rs2().unwrap().value); diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 94abb56d1..386d2c286 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, instructions::{ @@ -68,13 +69,14 @@ impl Instruction for BranchCircuit Result<(), ZKVMError> { config .b_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs1 = Value::new_unchecked(step.rs1().unwrap().value); let rs1_limbs = rs1.as_u16_limbs(); diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index aaf468127..82dbcffac 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -6,6 +6,7 @@ use ff_ext::{ExtensionField, GoldilocksExt2}; use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, error::ZKVMError, instructions::Instruction, scheme::mock_prover::{MOCK_PC_START, MockProver}, @@ -39,6 +40,7 @@ fn impl_opcode_beq(equal: bool) { let pc_offset = if equal { 8 } else { PC_STEP_SIZE }; let (raw_witin, lkm) = BeqInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -79,6 +81,7 @@ fn impl_opcode_bne(equal: bool) { let pc_offset = if equal { PC_STEP_SIZE } else { 8 }; let (raw_witin, lkm) = BneInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -122,6 +125,7 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { let insn_code = encode_rv32(InsnKind::BLTU, 2, 3, 0, -8); let (raw_witin, lkm) = BltuInstruction::assign_instances( &config, + &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -166,6 +170,7 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { let insn_code = encode_rv32(InsnKind::BGEU, 2, 3, 0, -8); let (raw_witin, lkm) = BgeuInstruction::assign_instances( &config, + &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -217,6 +222,7 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<() let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, -8); let (raw_witin, lkm) = BltInstruction::assign_instances( &config, + &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -268,6 +274,7 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<() let insn_code = encode_rv32(InsnKind::BGE, 2, 3, 0, -8); let (raw_witin, lkm) = BgeInstruction::assign_instances( &config, + &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 1992f4fa3..4e3786235 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -9,7 +9,8 @@ pub const INIT_PC_IDX: usize = 2; pub const INIT_CYCLE_IDX: usize = 3; pub const END_PC_IDX: usize = 4; pub const END_CYCLE_IDX: usize = 5; -pub const PUBLIC_IO_IDX: usize = 6; +pub const END_SHARD_ID_IDX: usize = 6; +pub const PUBLIC_IO_IDX: usize = 7; pub const LIMB_BITS: usize = 16; pub const LIMB_MASK: u32 = 0xFFFF; diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 7ca30d2b8..966320407 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -53,6 +53,7 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{ @@ -179,6 +180,7 @@ mod test { // values assignment let ([raw_witin, _], lkm) = Insn::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs index ef5b9d936..99a73a8a4 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs @@ -75,6 +75,7 @@ use super::{ }; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{AssertLtConfig, IsEqualConfig, IsLtConfig, IsZeroConfig, Signed}, instructions::{Instruction, riscv::constants::LIMB_BITS}, @@ -310,6 +311,7 @@ impl Instruction for ArithInstruction Instruction for ArithInstruction Instruction for ArithInstruction Instruction for ArithInstruction (true, true), diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index 7c98e2159..1df279dd9 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -9,9 +9,9 @@ use super::super::{ insn_base::{ReadMEM, ReadRS1, ReadRS2, StateInOut, WriteMEM, WriteRD}, }; use crate::{ - chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, error::ZKVMError, - instructions::Instruction, structs::ProgramParams, tables::InsnRecord, uint::Value, - witness::LkMultiplicity, + chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, e2e::ShardContext, + error::ZKVMError, instructions::Instruction, structs::ProgramParams, tables::InsnRecord, + uint::Value, witness::LkMultiplicity, }; use ff_ext::FieldInto; use multilinear_extensions::{ToExpr, WitIn}; @@ -70,11 +70,12 @@ impl Instruction for DummyInstruction::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.assign_instance(instance, lk_multiplicity, step) + config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } } @@ -242,30 +243,31 @@ impl DummyConfig { pub(super) fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { // State in and out - self.vm_state.assign_instance(instance, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); // Registers if let Some((rs1_op, rs1_read)) = &self.rs1 { - rs1_op.assign_instance(instance, lk_multiplicity, step)?; + rs1_op.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs1_val = Value::new_unchecked(step.rs1().expect("rs1 value").value); rs1_read.assign_value(instance, rs1_val); } if let Some((rs2_op, rs2_read)) = &self.rs2 { - rs2_op.assign_instance(instance, lk_multiplicity, step)?; + rs2_op.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs2_val = Value::new_unchecked(step.rs2().expect("rs2 value").value); rs2_read.assign_value(instance, rs2_val); } if let Some((rd_op, rd_written)) = &self.rd { - rd_op.assign_instance(instance, lk_multiplicity, step)?; + rd_op.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_val = Value::new_unchecked(step.rd().expect("rd value").value.after); rd_written.assign_value(instance, rd_val); @@ -284,10 +286,10 @@ impl DummyConfig { mem_after.assign_value(instance, Value::new(mem_op.value.after, lk_multiplicity)); } if let Some(mem_read) = &self.mem_read { - mem_read.assign_instance(instance, lk_multiplicity, step)?; + mem_read.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; } if let Some(mem_write) = &self.mem_write { - mem_write.assign_instance::(instance, lk_multiplicity, step)?; + mem_write.assign_instance::(instance, shard_ctx, lk_multiplicity, step)?; } let imm = InsnRecord::::imm_internal(&step.insn()).1; diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 69bdd1648..9cd5cb0f3 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -8,6 +8,7 @@ use super::{super::insn_base::WriteMEM, dummy_circuit::DummyConfig}; use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -84,6 +85,7 @@ impl Instruction for LargeEcallDummy fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -93,14 +95,14 @@ impl Instruction for LargeEcallDummy // Assign instruction. config .dummy_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; set_val!(instance, config.start_addr, u64::from(ops.mem_ops[0].addr)); // Assign registers. for ((value, writer), op) in config.reg_writes.iter().zip_eq(&ops.reg_ops) { value.assign_value(instance, Value::new_unchecked(op.value.after)); - writer.assign_op(instance, lk_multiplicity, step.cycle(), op)?; + writer.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op)?; } // Assign memory. @@ -112,7 +114,7 @@ impl Instruction for LargeEcallDummy .after .assign_value(instance, Value::new(op.value.after, lk_multiplicity)); set_val!(instance, addr, u64::from(op.addr)); - writer.assign_op(instance, lk_multiplicity, step.cycle(), op)?; + writer.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op)?; } Ok(()) diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index 6f7a89f73..c6f51d142 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -4,6 +4,7 @@ use ff_ext::GoldilocksExt2; use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{arith::AddOp, branch::BeqOp, ecall::EcallDummy}, @@ -34,6 +35,7 @@ fn test_dummy_ecall() { let insn_code = step.insn(); let (raw_witin, lkm) = EcallDummy::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![step], @@ -63,6 +65,7 @@ fn test_dummy_keccak() { let (step, program) = ceno_emul::test_utils::keccak_step(); let (raw_witin, lkm) = KeccakDummy::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![step], @@ -90,6 +93,7 @@ fn test_dummy_r() { let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); let (raw_witin, lkm) = AddDummy::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -125,6 +129,7 @@ fn test_dummy_b() { let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8); let (raw_witin, lkm) = BeqDummy::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index e14585727..bf38a67c4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -1,6 +1,7 @@ use crate::{ chip_handler::{RegisterChipOperations, general::PublicIOQuery}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, instructions::{ @@ -70,6 +71,7 @@ impl Instruction for HaltInstruction { fn assign_instance( config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index b0ac2a505..dccdf34a2 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -21,6 +21,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -156,6 +157,7 @@ impl Instruction for KeccakInstruction { fn assign_instance( _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, _instance: &mut [::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -165,6 +167,7 @@ impl Instruction for KeccakInstruction { fn assign_instances( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, steps: Vec, @@ -196,11 +199,13 @@ impl Instruction for KeccakInstruction { // each instance are composed of KECCAK_ROUNDS.next_power_of_two() let raw_witin_iter = raw_witin .par_batch_iter_mut(num_instance_per_batch * KECCAK_ROUNDS.next_power_of_two()); + let shard_ctx_vec = shard_ctx.get_forked(); // 1st pass: assign witness outside of gkr-iop scope raw_witin_iter .zip_eq(steps.par_chunks(num_instance_per_batch)) - .flat_map(|(instances, steps)| { + .zip(shard_ctx_vec) + .flat_map(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances @@ -218,10 +223,13 @@ impl Instruction for KeccakInstruction { [round_index as usize * num_witin..][..num_witin]; // vm_state - config.vm_state.assign_instance(instance, step)?; + config + .vm_state + .assign_instance(instance, &shard_ctx, step)?; config.ecall_id.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &WriteOp::new_register_op( @@ -238,6 +246,7 @@ impl Instruction for KeccakInstruction { )?; config.state_ptr.0.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &ops.reg_ops[0], @@ -246,6 +255,7 @@ impl Instruction for KeccakInstruction { for (writer, op) in config.mem_rw.iter().zip_eq(&ops.mem_ops) { writer.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), op, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 6365cfcd2..adf52683f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -24,6 +24,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -207,6 +208,7 @@ impl Instruction fn assign_instance( _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, _instance: &mut [::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -216,6 +218,7 @@ impl Instruction fn assign_instances( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, steps: Vec, @@ -255,11 +258,13 @@ impl Instruction ); let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); // 1st pass: assign witness outside of gkr-iop scope raw_witin_iter .zip_eq(steps.par_chunks(num_instance_per_batch)) - .flat_map(|(instances, steps)| { + .zip(shard_ctx_vec) + .flat_map(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances @@ -269,10 +274,13 @@ impl Instruction let ops = &step.syscall().expect("syscall step"); // vm_state - config.vm_state.assign_instance(instance, step)?; + config + .vm_state + .assign_instance(instance, &shard_ctx, step)?; config.ecall_id.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &WriteOp::new_register_op( @@ -289,6 +297,7 @@ impl Instruction )?; config.point_ptr_0.0.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &ops.reg_ops[0], @@ -301,12 +310,19 @@ impl Instruction )?; config.point_ptr_1.0.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &ops.reg_ops[1], )?; for (writer, op) in config.mem_rw.iter().zip_eq(&ops.mem_ops) { - writer.assign_op(instance, &mut lk_multiplicity, step.cycle(), op)?; + writer.assign_op( + instance, + &mut shard_ctx, + &mut lk_multiplicity, + step.cycle(), + op, + )?; } // fetch lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index 6003f9794..250141669 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -31,6 +31,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -208,6 +209,7 @@ impl Instruction::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -217,6 +219,7 @@ impl Instruction, @@ -254,12 +257,14 @@ impl Instruction::WordsFieldElement::USIZE; // 1st pass: assign witness outside of gkr-iop scope let sign_bit_and_y_words = raw_witin_iter .zip_eq(steps.par_chunks(num_instance_per_batch)) - .flat_map(|(instances, steps)| { + .zip(shard_ctx_vec) + .flat_map(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances @@ -269,9 +274,12 @@ impl Instruction Instruction Instruction Instruction::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -188,6 +190,7 @@ impl Instruction, @@ -227,11 +230,13 @@ impl Instruction Instruction Instruction OpFixedRS Result<(), ZKVMError> { - set_val!(instance, self.prev_ts, op.previous_cycle); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = cycle - current_shard_offset_cycle; + set_val!(instance, self.prev_ts, shard_prev_cycle); // Register state if let Some(prev_value) = self.prev_value.as_ref() { @@ -76,17 +82,30 @@ impl OpFixedRS IInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rd + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 5fa6cd501..c7f6cace0 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -7,6 +7,7 @@ use crate::{ witness::LkMultiplicity, }; +use crate::e2e::ShardContext; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use multilinear_extensions::{Expression, ToExpr}; @@ -67,14 +68,17 @@ impl IMInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rd + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; self.mem_read - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 43a72f739..4877df9d1 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -10,8 +10,10 @@ use crate::{ RegisterChipOperations, RegisterExpr, }, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, + structs::RAMType, uint::Value, witness::{LkMultiplicity, set_val}, }; @@ -58,14 +60,17 @@ impl StateInOut { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &ShardContext, // lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + set_val!(instance, self.pc, step.pc().before.0 as u64); if let Some(n_pc) = self.next_pc { set_val!(instance, n_pc, step.pc().after.0 as u64); } - set_val!(instance, self.ts, step.cycle()); + set_val!(instance, self.ts, step.cycle() - current_shard_offset_cycle); Ok(()) } @@ -106,20 +111,33 @@ impl ReadRS1 { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.rs1().expect("rs1 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = step.cycle() - current_shard_offset_cycle; set_val!(instance, self.id, op.register_index() as u64); - set_val!(instance, self.prev_ts, op.previous_cycle); + set_val!(instance, self.prev_ts, shard_prev_cycle); // Register read self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - step.cycle() + Tracer::SUBCYCLE_RS1, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS1, )?; + shard_ctx.send( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS1, + op.previous_cycle, + op.value, + None, + ); Ok(()) } @@ -160,21 +178,35 @@ impl ReadRS2 { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.rs2().expect("rs2 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = step.cycle() - current_shard_offset_cycle; set_val!(instance, self.id, op.register_index() as u64); - set_val!(instance, self.prev_ts, op.previous_cycle); + set_val!(instance, self.prev_ts, shard_prev_cycle); // Register read self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - step.cycle() + Tracer::SUBCYCLE_RS2, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS2, )?; + shard_ctx.send( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS2, + op.previous_cycle, + op.value, + None, + ); + Ok(()) } } @@ -216,22 +248,27 @@ impl WriteRD { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.rd().expect("rd op"); - self.assign_op(instance, lk_multiplicity, step.cycle(), &op) + self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), &op) } pub fn assign_op( &self, instance: &mut [E::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, cycle: Cycle, op: &WriteOp, ) -> Result<(), ZKVMError> { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = cycle - current_shard_offset_cycle; set_val!(instance, self.id, op.register_index() as u64); - set_val!(instance, self.prev_ts, op.previous_cycle); + set_val!(instance, self.prev_ts, shard_prev_cycle); // Register state self.prev_value.assign_limbs( @@ -243,9 +280,18 @@ impl WriteRD { self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - cycle + Tracer::SUBCYCLE_RD, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RD, )?; + shard_ctx.send( + RAMType::Register, + op.addr, + op.register_index() as u64, + cycle + Tracer::SUBCYCLE_RD, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); Ok(()) } @@ -284,24 +330,35 @@ impl ReadMEM { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { + let op = step.memory_op().unwrap(); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = step.cycle() - current_shard_offset_cycle; // Memory state - set_val!( - instance, - self.prev_ts, - step.memory_op().unwrap().previous_cycle - ); + set_val!(instance, self.prev_ts, shard_prev_cycle); // Memory read self.lt_cfg.assign_instance( instance, lk_multiplicity, - step.memory_op().unwrap().previous_cycle, - step.cycle() + Tracer::SUBCYCLE_MEM, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, )?; + shard_ctx.send( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + step.cycle() + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + None, + ); + Ok(()) } } @@ -337,29 +394,44 @@ impl WriteMEM { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.memory_op().unwrap(); - self.assign_op(instance, lk_multiplicity, step.cycle(), &op) + self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), &op) } pub fn assign_op( &self, instance: &mut [F], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, cycle: Cycle, op: &WriteOp, ) -> Result<(), ZKVMError> { - set_val!(instance, self.prev_ts, op.previous_cycle); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = cycle - current_shard_offset_cycle; + set_val!(instance, self.prev_ts, shard_prev_cycle); self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - cycle + Tracer::SUBCYCLE_MEM, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, )?; + shard_ctx.send( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + cycle + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); + Ok(()) } } diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index 156aa1cd1..84cb84679 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -4,13 +4,13 @@ use ff_ext::ExtensionField; use crate::{ chip_handler::{RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{StateInOut, WriteRD}, tables::InsnRecord, witness::LkMultiplicity, }; use multilinear_extensions::ToExpr; - // Opcode: 1101111 /// This config handles the common part of the J-type instruction (JAL): @@ -55,11 +55,13 @@ impl JInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; + self.rd + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch the instruction. lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal.rs b/ceno_zkvm/src/instructions/riscv/jump/jal.rs index a4c0a96f4..c8abc77ac 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal.rs @@ -5,6 +5,7 @@ use ff_ext::ExtensionField; use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -64,13 +65,14 @@ impl Instruction for JalInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config .j_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); config.rd_written.assign_value(instance, rd_written); diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 0f67be424..545adf275 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -4,6 +4,7 @@ use ff_ext::ExtensionField; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -88,13 +89,14 @@ impl Instruction for JalInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config .j_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_written = split_to_u8(step.rd().unwrap().value.after); config.rd_written.assign_limbs(instance, &rd_written); diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index f1ba94aa7..77f6ad1f8 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -5,6 +5,7 @@ use ff_ext::ExtensionField; use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -111,6 +112,7 @@ impl Instruction for JalrInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, @@ -150,7 +152,7 @@ impl Instruction for JalrInstruction { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index bfec3a099..7f23ac9b6 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -5,6 +5,7 @@ use crate::{ Value, chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -135,6 +136,7 @@ impl Instruction for JalrInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, @@ -177,7 +179,7 @@ impl Instruction for JalrInstruction { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index 0b379f250..899e5a035 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -2,6 +2,7 @@ use super::{JalInstruction, JalrInstruction}; use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::constants::UInt}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -42,6 +43,7 @@ fn verify_test_opcode_jal(pc_offset: i32) { let insn_code = encode_rv32(InsnKind::JAL, 0, 0, 4, pc_offset); let (raw_witin, lkm) = JalInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_j_instruction( @@ -117,6 +119,7 @@ fn verify_test_opcode_jalr(rs1_read: Word, imm: i32) { let (raw_witin, lkm) = JalrInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index f761f6102..5a2d8e404 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -6,6 +6,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -53,6 +54,7 @@ impl Instruction for LogicInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -63,7 +65,7 @@ impl Instruction for LogicInstruction { step.rs2().unwrap().value as u64, ); - config.assign_instance(instance, lk_multiplicity, step) + config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } } @@ -106,11 +108,12 @@ impl LogicConfig { fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { self.r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs1_read = split_to_u8(step.rs1().unwrap().value); self.rs1_read.assign_limbs(instance, &rs1_read); diff --git a/ceno_zkvm/src/instructions/riscv/logic/test.rs b/ceno_zkvm/src/instructions/riscv/logic/test.rs index dc01487d9..f68135c72 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/test.rs @@ -1,16 +1,16 @@ use ceno_emul::{Change, StepRecord, Word, encode_rv32}; use ff_ext::GoldilocksExt2; +use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::constants::UInt8}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, utils::split_to_u8, }; -use super::*; - const A: Word = 0xbead1010; const B: Word = 0xef552020; @@ -32,6 +32,7 @@ fn test_opcode_and() { let insn_code = encode_rv32(InsnKind::AND, 2, 3, 4, 0); let (raw_witin, lkm) = AndInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -74,6 +75,7 @@ fn test_opcode_or() { let insn_code = encode_rv32(InsnKind::OR, 2, 3, 4, 0); let (raw_witin, lkm) = OrInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -116,6 +118,7 @@ fn test_opcode_xor() { let insn_code = encode_rv32(InsnKind::XOR, 2, 3, 4, 0); let (raw_witin, lkm) = XorInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index aad60b43b..596792ad8 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -6,6 +6,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -48,6 +49,7 @@ impl Instruction for LogicInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, @@ -58,7 +60,7 @@ impl Instruction for LogicInstruction { InsnRecord::::imm_internal(&step.insn()).0 as u64, ); - config.assign_instance(instance, lkm, step) + config.assign_instance(instance, shard_ctx, lkm, step) } } @@ -102,10 +104,12 @@ impl LogicConfig { fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.i_insn.assign_instance(instance, lkm, step)?; + self.i_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1_read = split_to_u8(step.rs1().unwrap().value); self.rs1_read.assign_limbs(instance, &rs1_read); diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index c72f31efe..b48af7f5f 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -7,6 +7,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -94,6 +95,7 @@ impl Instruction for LogicInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, @@ -115,7 +117,7 @@ impl Instruction for LogicInstruction { imm_hi.into(), ); - config.assign_instance(instance, lkm, step) + config.assign_instance(instance, shard_ctx, lkm, step) } } @@ -163,11 +165,13 @@ impl LogicConfig { fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let num_limbs = LIMB_BITS / 8; - self.i_insn.assign_instance(instance, lkm, step)?; + self.i_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1_read = split_to_u8(step.rs1().unwrap().value); self.rs1_read.assign_limbs(instance, &rs1_read); diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs index 23aa2d77c..68032fd41 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs @@ -4,6 +4,7 @@ use gkr_iop::circuit_builder::DebugIndex; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{ @@ -70,6 +71,7 @@ fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_w let insn_code = encode_rv32u(I::INST_KIND, 2, 0, 4, imm); let (raw_witin, lkm) = LogicInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 2cc280f04..198bafbc5 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -4,6 +4,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -88,13 +89,14 @@ impl Instruction for LuiInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_written = split_to_u8(step.rd().unwrap().value.after); for (val, witin) in izip!(rd_written.iter().skip(1), config.rd_written) { @@ -117,6 +119,7 @@ mod tests { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{constants::UInt, lui::LuiInstruction}, @@ -153,6 +156,7 @@ mod tests { let insn_code = encode_rv32(InsnKind::LUI, 0, 0, 4, imm); let (raw_witin, lkm) = LuiInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 5945f26bd..41fbf0059 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -165,6 +166,7 @@ impl Instruction for LoadInstruction Instruction for LoadInstruction Instruction for LoadInstruction Instruction for LoadInstruction Instruction fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -124,7 +126,7 @@ impl Instruction let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config .s_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; config.rs1_read.assign_value(instance, rs1); config.rs2_read.assign_value(instance, rs2); set_val!(instance, config.imm, imm.1); diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index f07968d19..cb512975b 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -127,6 +128,7 @@ impl Instruction fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -147,7 +149,7 @@ impl Instruction let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config .s_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; config.rs1_read.assign_value(instance, rs1); config.rs2_read.assign_value(instance, rs2); set_val!(instance, config.imm, imm.1); diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index 90c5a0273..b2a04326b 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{ @@ -102,6 +103,7 @@ fn impl_opcode_store::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -217,6 +219,7 @@ mod test { let insn_code = encode_rv32(InsnKind::MULH, 2, 3, 4, 0); let (raw_witin, lkm) = MulhInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -300,6 +303,7 @@ mod test { let insn_code = encode_rv32(InsnKind::MULHSU, 2, 3, 4, 0); let (raw_witin, lkm) = MulhsuInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs index bc5bc9ed4..dd919dd3e 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs @@ -86,6 +86,7 @@ use p3::{field::FieldAlgebra, goldilocks::Goldilocks}; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{IsEqualConfig, Signed}, instructions::{ @@ -286,6 +287,7 @@ impl Instruction for MulhInstructionBas fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -312,7 +314,7 @@ impl Instruction for MulhInstructionBas // R-type instruction config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Assign signed values, if any, and compute low 32-bit limb of product let prod_lo_hi = match &config.sign_deps { diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index c1853d7a8..a94f63e74 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -19,6 +19,7 @@ use multilinear_extensions::{Expression, ToExpr as _, WitIn}; use p3::field::{Field, FieldAlgebra}; use witness::set_val; +use crate::e2e::ShardContext; use itertools::Itertools; use std::{array, marker::PhantomData}; @@ -223,6 +224,7 @@ impl Instruction for MulhInstructionBas fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -241,7 +243,7 @@ impl Instruction for MulhInstructionBas // R-type instruction config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let (rd_high, rd_low, carry, rs1_ext, rs2_ext) = run_mulh::( I::INST_KIND, diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index 540ccaffe..a4b9bb128 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -4,6 +4,7 @@ use ff_ext::ExtensionField; use crate::{ chip_handler::{RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, tables::InsnRecord, @@ -63,13 +64,17 @@ impl RInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rs2.assign_instance(instance, lk_multiplicity, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rs2 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rd + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index b953fc2af..9957f2122 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -9,6 +9,7 @@ use crate::instructions::riscv::lui::LuiInstruction; #[cfg(not(feature = "u16limb_circuit"))] use crate::tables::PowTableCircuit; use crate::{ + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -409,6 +410,7 @@ impl Rv32imConfig { pub fn assign_opcode_circuit( &self, cs: &ZKVMConstraintSystem, + shard_ctx: &mut ShardContext, witness: &mut ZKVMWitnesses, steps: Vec, ) -> Result { @@ -422,38 +424,49 @@ impl Rv32imConfig { let mut secp256k1_add_records = Vec::new(); let mut secp256k1_double_records = Vec::new(); let mut secp256k1_decompress_records = Vec::new(); - steps.into_iter().for_each(|record| { - let insn_kind = record.insn.kind; - match insn_kind { - // ecall / halt - InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => { - halt_records.push(record); + steps + .into_iter() + .filter_map(|step| { + if shard_ctx.is_current_shard_cycle(step.cycle()) { + Some(step) + } else { + None } - InsnKind::ECALL if record.rs1().unwrap().value == KeccakSpec::CODE => { - keccak_records.push(record); + }) + .for_each(|record| { + let insn_kind = record.insn.kind; + match insn_kind { + // ecall / halt + InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => { + halt_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == KeccakSpec::CODE => { + keccak_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Bn254AddSpec::CODE => { + bn254_add_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Bn254DoubleSpec::CODE => { + bn254_double_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1AddSpec::CODE => { + secp256k1_add_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DoubleSpec::CODE => { + secp256k1_double_records.push(record); + } + InsnKind::ECALL + if record.rs1().unwrap().value == Secp256k1DecompressSpec::CODE => + { + secp256k1_decompress_records.push(record); + } + // other type of ecalls are handled by dummy ecall instruction + _ => { + // it's safe to unwrap as all_records are initialized with Vec::new() + all_records.get_mut(&insn_kind).unwrap().push(record); + } } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254AddSpec::CODE => { - bn254_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254DoubleSpec::CODE => { - bn254_double_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1AddSpec::CODE => { - secp256k1_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DoubleSpec::CODE => { - secp256k1_double_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DecompressSpec::CODE => { - secp256k1_decompress_records.push(record); - } - // other type of ecalls are handled by dummy ecall instruction - _ => { - // it's safe to unwrap as all_records are initialized with Vec::new() - all_records.get_mut(&insn_kind).unwrap().push(record); - } - } - }); + }); for (insn_kind, (_, records)) in izip!(InsnKind::iter(), &all_records).sorted_by_key(|(_, (_, a))| Reverse(a.len())) @@ -465,6 +478,7 @@ impl Rv32imConfig { ($insn_kind:ident,$instruction:ty,$config:ident) => { witness.assign_opcode_circuit::<$instruction>( cs, + shard_ctx, &self.$config, all_records.remove(&($insn_kind)).unwrap(), )?; @@ -524,35 +538,46 @@ impl Rv32imConfig { assign_opcode!(SB, SbInstruction, sb_config); // ecall / halt - witness.assign_opcode_circuit::>(cs, &self.halt_config, halt_records)?; + witness.assign_opcode_circuit::>( + cs, + shard_ctx, + &self.halt_config, + halt_records, + )?; witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.keccak_config, keccak_records, )?; witness.assign_opcode_circuit::>>( cs, + shard_ctx, &self.bn254_add_config, bn254_add_records, )?; witness.assign_opcode_circuit::>>( cs, + shard_ctx, &self.bn254_double_config, bn254_double_records, )?; witness.assign_opcode_circuit::>>( cs, + shard_ctx, &self.secp256k1_add_config, secp256k1_add_records, )?; witness .assign_opcode_circuit::>>( cs, + shard_ctx, &self.secp256k1_double_config, secp256k1_double_records, )?; witness.assign_opcode_circuit::>>( cs, + shard_ctx, &self.secp256k1_decompress_config, secp256k1_decompress_records, )?; @@ -671,6 +696,7 @@ impl DummyExtraConfig { pub fn assign_opcode_circuit( &self, cs: &ZKVMConstraintSystem, + shard_ctx: &mut ShardContext, witness: &mut ZKVMWitnesses, steps: GroupedSteps, ) -> Result<(), ZKVMError> { @@ -700,35 +726,46 @@ impl DummyExtraConfig { witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.secp256k1_decompress_config, secp256k1_decompress_steps, )?; witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.sha256_extend_config, sha256_extend_steps, )?; witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.bn254_fp_add_config, bn254_fp_add_steps, )?; witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.bn254_fp_mul_config, bn254_fp_mul_steps, )?; witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.bn254_fp2_add_config, bn254_fp2_add_steps, )?; witness.assign_opcode_circuit::>( cs, + shard_ctx, &self.bn254_fp2_mul_config, bn254_fp2_mul_steps, )?; - witness.assign_opcode_circuit::>(cs, &self.ecall_config, other_steps)?; + witness.assign_opcode_circuit::>( + cs, + shard_ctx, + &self.ecall_config, + other_steps, + )?; let _ = steps.remove(&INVALID); let keys: Vec<&InsnKind> = steps.keys().collect::>(); diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index d8c032c7b..82a8d0c91 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -1,54 +1,63 @@ -use std::{collections::HashSet, iter::zip, ops::Range}; - -use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; -use ff_ext::ExtensionField; -use itertools::{Itertools, chain}; - use crate::{ + e2e::ShardContext, error::ZKVMError, structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ - HeapCircuit, HintsCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, - PubIOTable, RegTable, RegTableCircuit, StackCircuit, StaticMemCircuit, StaticMemTable, - TableCircuit, + DynVolatileRamTable, HeapInitCircuit, HeapTable, HintsCircuit, LocalFinalCircuit, + MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RBCircuit, + RegTable, RegTableInitCircuit, StackInitCircuit, StackTable, StaticMemInitCircuit, + StaticMemTable, TableCircuit, }, }; +use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; +use ff_ext::ExtensionField; +use itertools::{Itertools, chain}; +use std::{collections::HashSet, iter::zip, ops::Range, sync::Arc}; +use witness::InstancePaddingStrategy; -pub struct MmuConfig { +pub struct MmuConfig<'a, E: ExtensionField> { /// Initialization of registers. - pub reg_config: as TableCircuit>::TableConfig, + pub reg_init_config: as TableCircuit>::TableConfig, /// Initialization of memory with static addresses. - pub static_mem_config: as TableCircuit>::TableConfig, + pub static_mem_init_config: as TableCircuit>::TableConfig, /// Initialization of public IO. pub public_io_config: as TableCircuit>::TableConfig, /// Initialization of hints. pub hints_config: as TableCircuit>::TableConfig, /// Initialization of heap. - pub heap_config: as TableCircuit>::TableConfig, + pub heap_init_config: as TableCircuit>::TableConfig, /// Initialization of stack. - pub stack_config: as TableCircuit>::TableConfig, + pub stack_init_config: as TableCircuit>::TableConfig, + /// finalized circuit for all MMIO + pub local_final_circuit: as TableCircuit>::TableConfig, + /// ram bus to deal with cross shard read/write + pub ram_bus_circuit: as TableCircuit>::TableConfig, pub params: ProgramParams, } -impl MmuConfig { +impl MmuConfig<'_, E> { pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { - let reg_config = cs.register_table_circuit::>(); + let reg_init_config = cs.register_table_circuit::>(); - let static_mem_config = cs.register_table_circuit::>(); + let static_mem_init_config = cs.register_table_circuit::>(); let public_io_config = cs.register_table_circuit::>(); let hints_config = cs.register_table_circuit::>(); - let stack_config = cs.register_table_circuit::>(); - let heap_config = cs.register_table_circuit::>(); + let stack_init_config = cs.register_table_circuit::>(); + let heap_init_config = cs.register_table_circuit::>(); + let local_final_circuit = cs.register_table_circuit::>(); + let ram_bus_circuit = cs.register_table_circuit::>(); Self { - reg_config, - static_mem_config, + reg_init_config, + static_mem_init_config, public_io_config, hints_config, - stack_config, - heap_config, + stack_init_config, + heap_init_config, + local_final_circuit, + ram_bus_circuit, params: cs.params.clone(), } } @@ -72,24 +81,27 @@ impl MmuConfig { "memory addresses must be unique" ); - fixed.register_table_circuit::>(cs, &self.reg_config, reg_init); + fixed.register_table_circuit::>(cs, &self.reg_init_config, reg_init); - fixed.register_table_circuit::>( + fixed.register_table_circuit::>( cs, - &self.static_mem_config, + &self.static_mem_init_config, static_mem_init, ); fixed.register_table_circuit::>(cs, &self.public_io_config, io_addrs); fixed.register_table_circuit::>(cs, &self.hints_config, &()); - fixed.register_table_circuit::>(cs, &self.stack_config, &()); - fixed.register_table_circuit::>(cs, &self.heap_config, &()); + fixed.register_table_circuit::>(cs, &self.stack_init_config, &()); + fixed.register_table_circuit::>(cs, &self.heap_init_config, &()); + fixed.register_table_circuit::>(cs, &self.local_final_circuit, &()); + fixed.register_table_circuit::>(cs, &self.ram_bus_circuit, &()); } #[allow(clippy::too_many_arguments)] pub fn assign_table_circuit( &self, cs: &ZKVMConstraintSystem, + shard_ctx: &ShardContext, witness: &mut ZKVMWitnesses, reg_final: &[MemFinalRecord], static_mem_final: &[MemFinalRecord], @@ -98,18 +110,60 @@ impl MmuConfig { stack_final: &[MemFinalRecord], heap_final: &[MemFinalRecord], ) -> Result<(), ZKVMError> { - witness.assign_table_circuit::>(cs, &self.reg_config, reg_final)?; + witness.assign_table_circuit::>( + cs, + &self.reg_init_config, + reg_final, + )?; - witness.assign_table_circuit::>( + witness.assign_table_circuit::>( cs, - &self.static_mem_config, + &self.static_mem_init_config, static_mem_final, )?; witness.assign_table_circuit::>(cs, &self.public_io_config, io_cycles)?; witness.assign_table_circuit::>(cs, &self.hints_config, hints_final)?; - witness.assign_table_circuit::>(cs, &self.stack_config, stack_final)?; - witness.assign_table_circuit::>(cs, &self.heap_config, heap_final)?; + witness.assign_table_circuit::>( + cs, + &self.stack_init_config, + stack_final, + )?; + witness.assign_table_circuit::>( + cs, + &self.heap_init_config, + heap_final, + )?; + + let all_records = vec![ + (InstancePaddingStrategy::Default, reg_final), + (InstancePaddingStrategy::Default, static_mem_final), + ( + InstancePaddingStrategy::Custom({ + let params = cs.params.clone(); + Arc::new(move |row: u64, _: u64| StackTable::addr(¶ms, row as usize) as u64) + }), + stack_final, + ), + ( + InstancePaddingStrategy::Custom({ + let params = cs.params.clone(); + Arc::new(move |row: u64, _: u64| HeapTable::addr(¶ms, row as usize) as u64) + }), + heap_final, + ), + ] + .into_iter() + .filter(|(_, record)| !record.is_empty()) + .collect_vec(); + // take all mem result and + witness.assign_table_circuit::>( + cs, + &self.local_final_circuit, + &(shard_ctx, all_records.as_slice()), + )?; + + witness.assign_table_circuit::>(cs, &self.ram_bus_circuit, shard_ctx)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index f46cf4c5d..f252a7c60 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -1,6 +1,7 @@ use crate::{ chip_handler::{AddressExpr, MemoryExpr, RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, tables::InsnRecord, @@ -73,14 +74,17 @@ impl SInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rs2.assign_instance(instance, lk_multiplicity, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rs2 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; self.mem_write - .assign_instance::(instance, lk_multiplicity, step)?; + .assign_instance::(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 0c53f1a4c..d09b98c89 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -45,6 +45,7 @@ mod tests { use crate::utils::split_to_u8; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::RIVInstruction}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -173,6 +174,7 @@ mod tests { let (raw_witin, lkm) = ShiftLogicalInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs index 87374b20e..c1d83ce87 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs @@ -1,5 +1,6 @@ use crate::{ Value, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -151,6 +152,7 @@ impl Instruction for ShiftLogicalInstru fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut crate::witness::LkMultiplicity, step: &ceno_emul::StepRecord, @@ -211,7 +213,7 @@ impl Instruction for ShiftLogicalInstru config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 4e929670c..fac05279e 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,3 +1,4 @@ +use crate::e2e::ShardContext; /// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ instructions::{ @@ -321,6 +322,7 @@ impl Instruction for ShiftLogicalInstru fn assign_instance( config: &ShiftRTypeConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut crate::witness::LkMultiplicity, step: &ceno_emul::StepRecord, @@ -352,7 +354,7 @@ impl Instruction for ShiftLogicalInstru ); config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } @@ -419,6 +421,7 @@ impl Instruction for ShiftImmInstructio fn assign_instance( config: &ShiftImmConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut crate::witness::LkMultiplicity, step: &ceno_emul::StepRecord, @@ -449,7 +452,7 @@ impl Instruction for ShiftImmInstructio ); config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 4cf7ac155..1757a0fc7 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -43,6 +43,7 @@ mod test { use crate::utils::split_to_u8; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::RIVInstruction}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -170,6 +171,7 @@ mod test { let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs index 0bba35411..a2fa8d032 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -132,6 +133,7 @@ impl Instruction for ShiftImmInstructio fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -168,7 +170,7 @@ impl Instruction for ShiftImmInstructio config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index 7b27617ad..3ba12bb39 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -38,6 +38,7 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::constants::UInt}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -72,6 +73,7 @@ mod test { let insn_code = encode_rv32(I::INST_KIND, 2, 3, 4, 0); let (raw_witin, lkm) = SetLessThanInstruction::<_, I>::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs index 3ffd9de69..b9b63acaf 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs @@ -1,5 +1,6 @@ use crate::{ Value, + e2e::ShardContext, error::ZKVMError, gadgets::SignedLtConfig, instructions::{ @@ -92,11 +93,14 @@ impl Instruction for SetLessThanInstruc fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.r_insn.assign_instance(instance, lkm, step)?; + config + .r_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1 = step.rs1().unwrap().value; let rs2 = step.rs2().unwrap().value; diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index 391dffb89..cd0b97ce4 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, instructions::{ @@ -75,11 +76,14 @@ impl Instruction for SetLessThanInstruc fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.r_insn.assign_instance(instance, lkm, step)?; + config + .r_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1 = step.rs1().unwrap().value; let rs2 = step.rs2().unwrap().value; diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 5802c4229..ff3a78043 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -35,6 +35,7 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{ @@ -185,6 +186,7 @@ mod test { let (raw_witin, lkm) = SetLessThanImmInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs index 266faeed3..8b93f593c 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -94,11 +95,14 @@ impl Instruction for SetLessThanImmInst fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.i_insn.assign_instance(instance, lkm, step)?; + config + .i_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1 = step.rs1().unwrap().value; let rs1_value = Value::new_unchecked(rs1 as Word); diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 1085561fb..914424247 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, instructions::{ @@ -92,11 +93,14 @@ impl Instruction for SetLessThanImmInst fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.i_insn.assign_instance(instance, lkm, step)?; + config + .i_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1 = step.rs1().unwrap().value; let rs1_value = Value::new_unchecked(rs1 as Word); diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index 17ab9e72c..0ced182b8 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -26,8 +26,11 @@ impl ZKVMConstraintSystem { .remove(&c_name) .flatten() .ok_or(ZKVMError::FixedTraceNotFound(c_name.clone().into()))?; + vm_pk + .circuit_index_fixed_num_instances + .insert(circuit_index, fixed_trace_rmm.num_instances()); fixed_traces.insert(circuit_index, fixed_trace_rmm); - }; + } let circuit_pk = cs.key_gen(); assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 16e7ee821..a72c0ffe6 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -1,6 +1,7 @@ #![deny(clippy::cargo)] #![feature(box_patterns)] #![feature(stmt_expr_attributes)] +#![feature(variant_count)] pub mod error; pub mod instructions; diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 2fcd8de79..5b2c1867f 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -40,6 +40,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{StateInOut, WriteMEM}, precompiles::{ @@ -1025,6 +1026,7 @@ pub fn run_faster_keccakf verify: bool, test_outputs: bool, ) -> Result, BackendError> { + let mut shard_ctx = ShardContext::default(); let num_instances = states.len(); let num_instances_rounds = num_instances * ROUNDS.next_power_of_two(); let log2_num_instance_rounds = ceil_log2(num_instances_rounds); @@ -1073,9 +1075,11 @@ pub fn run_faster_keccakf ); let raw_witin_iter = phase1_witness.par_batch_iter_mut(num_instance_per_batch * ROUNDS.next_power_of_two()); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(instances.par_chunks(num_instance_per_batch)) - .for_each(|(instances, steps)| { + .zip(shard_ctx_vec) + .for_each(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin as usize * ROUNDS.next_power_of_two()) @@ -1087,6 +1091,7 @@ pub fn run_faster_keccakf .vm_state .assign_instance( instance, + &shard_ctx, &StepRecord::new_ecall_any(10, ByteAddr::from(0)), ) .expect("assign vm_state error"); @@ -1095,6 +1100,7 @@ pub fn run_faster_keccakf mem_config .assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, 10, &MemOp { diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index 3be11dbd0..18e1a205b 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -63,6 +63,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, + e2e::ShardContext, error::ZKVMError, gadgets::{FieldOperation, field_op::FieldOpCols}, instructions::riscv::insn_base::{StateInOut, WriteMEM}, @@ -559,6 +560,7 @@ pub fn run_weierstrass_add< verify: bool, test_outputs: bool, ) -> Result, BackendError> { + let mut shard_ctx = ShardContext::default(); let num_instances = points.len(); let log2_num_instance = ceil_log2(num_instances); let num_threads = optimal_sumcheck_threads(log2_num_instance); @@ -591,9 +593,11 @@ pub fn run_weierstrass_add< InstancePaddingStrategy::Default, ); let raw_witin_iter = phase1_witness.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(instances.par_chunks(num_instance_per_batch)) - .for_each(|(instances, steps)| { + .zip(shard_ctx_vec) + .for_each(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin as usize) @@ -603,6 +607,7 @@ pub fn run_weierstrass_add< .vm_state .assign_instance( instance, + &shard_ctx, &StepRecord::new_ecall_any(10, ByteAddr::from(0)), ) .expect("assign vm_state error"); @@ -610,6 +615,7 @@ pub fn run_weierstrass_add< mem_config .assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, 10, &MemOp { diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index 52496e869..de03a829e 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -67,6 +67,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, + e2e::ShardContext, error::ZKVMError, gadgets::{ FieldOperation, field_inner_product::FieldInnerProductCols, field_op::FieldOpCols, @@ -557,6 +558,7 @@ pub fn run_weierstrass_decompress< test_outputs: bool, verify: bool, ) -> Result, BackendError> { + let mut shard_ctx = ShardContext::default(); let num_instances = instances.len(); let log2_num_instance = ceil_log2(num_instances); let num_threads = optimal_sumcheck_threads(log2_num_instance); @@ -577,9 +579,11 @@ pub fn run_weierstrass_decompress< InstancePaddingStrategy::Default, ); let raw_witin_iter = phase1_witness.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(instances.par_chunks(num_instance_per_batch)) - .for_each(|(instances, steps)| { + .zip(shard_ctx_vec) + .for_each(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin as usize) @@ -589,6 +593,7 @@ pub fn run_weierstrass_decompress< .vm_state .assign_instance( instance, + &shard_ctx, &StepRecord::new_ecall_any(10, ByteAddr::from(0)), ) .expect("assign vm_state error"); @@ -596,6 +601,7 @@ pub fn run_weierstrass_decompress< mem_config .assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, 10, &MemOp { diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index e5f16ba2f..1260fae33 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -64,6 +64,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, + e2e::ShardContext, error::ZKVMError, gadgets::{FieldOperation, field_op::FieldOpCols}, instructions::riscv::insn_base::{StateInOut, WriteMEM}, @@ -564,6 +565,7 @@ pub fn run_weierstrass_double< verify: bool, test_outputs: bool, ) -> Result, BackendError> { + let mut shard_ctx = ShardContext::default(); let num_instances = points.len(); let log2_num_instance = ceil_log2(num_instances); let num_threads = optimal_sumcheck_threads(log2_num_instance); @@ -593,9 +595,11 @@ pub fn run_weierstrass_double< InstancePaddingStrategy::Default, ); let raw_witin_iter = phase1_witness.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(instances.par_chunks(num_instance_per_batch)) - .for_each(|(instances, steps)| { + .zip(shard_ctx_vec) + .for_each(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin as usize) @@ -605,6 +609,7 @@ pub fn run_weierstrass_double< .vm_state .assign_instance( instance, + &shard_ctx, &StepRecord::new_ecall_any(10, ByteAddr::from(0)), ) .expect("assign vm_state error"); @@ -612,6 +617,7 @@ pub fn run_weierstrass_double< mem_config .assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, 10, &MemOp { diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 58a9aae89..b36759d10 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -72,6 +72,7 @@ pub struct PublicValues { init_cycle: u64, end_pc: u32, end_cycle: u64, + shard_id: u32, public_io: Vec, } @@ -82,6 +83,7 @@ impl PublicValues { init_cycle: u64, end_pc: u32, end_cycle: u64, + shard_id: u32, public_io: Vec, ) -> Self { Self { @@ -90,6 +92,7 @@ impl PublicValues { init_cycle, end_pc, end_cycle, + shard_id, public_io, } } @@ -103,6 +106,7 @@ impl PublicValues { vec![E::BaseField::from_canonical_u64(self.init_cycle)], vec![E::BaseField::from_canonical_u32(self.end_pc)], vec![E::BaseField::from_canonical_u64(self.end_cycle)], + vec![E::BaseField::from_canonical_u32(self.shard_id)], ] .into_iter() .chain( diff --git a/ceno_zkvm/src/scheme/constants.rs b/ceno_zkvm/src/scheme/constants.rs index 3cc212e9f..191fdf103 100644 --- a/ceno_zkvm/src/scheme/constants.rs +++ b/ceno_zkvm/src/scheme/constants.rs @@ -1,5 +1,4 @@ pub(crate) const MIN_PAR_SIZE: usize = 64; -pub(crate) const SEL_DEGREE: usize = 2; pub const NUM_FANIN: usize = 2; pub const NUM_FANIN_LOGUP: usize = 2; diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 9b0020116..b4972e3e2 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -567,7 +567,6 @@ impl> MainSumcheckProver> MainSumcheckProver>(); - let fixed_in_evals = evals.split_off(input.witness.len()); - let wits_in_evals = evals; - exit_span!(span); + let (wits_in_evals, fixed_in_evals, main_sumcheck_proof, rt) = { + let span = entered_span!("fixed::evals + witin::evals"); + let mut evals = input + .witness + .par_iter() + .chain(input.fixed.par_iter()) + .map(|poly| poly.evaluate(&rt_tower[..poly.num_vars()])) + .collect::>(); + let fixed_in_evals = evals.split_off(input.witness.len()); + let wits_in_evals = evals; + exit_span!(span); + (wits_in_evals, fixed_in_evals, None, rt_tower) + }; Ok(( - rt_tower, + rt, MainSumcheckEvals { wits_in_evals, fixed_in_evals, }, - None, + main_sumcheck_proof, None, )) } @@ -713,38 +713,38 @@ impl> OpeningProver>, points: Vec>, - mut evals: Vec>, // where each inner Vec = wit_evals + fixed_evals - circuit_num_polys: &[(usize, usize)], - num_instances: &[(usize, usize)], + mut evals: Vec>>, // where each inner vec![wit_evals, fixed_evals] transcript: &mut impl Transcript, ) -> PCS::Proof { let mut rounds = vec![]; - rounds.push(( - &witness_data, - points - .iter() - .zip_eq(evals.iter_mut()) - .zip_eq(num_instances.iter()) - .map(|((point, evals), (chip_idx, _))| { - let (num_witin, _) = circuit_num_polys[*chip_idx]; - (point.clone(), evals.drain(..num_witin).collect_vec()) + rounds.push((&witness_data, { + evals + .iter_mut() + .zip(&points) + .filter_map(|(evals, point)| { + let witin_evals = evals.remove(0); + if !witin_evals.is_empty() { + Some((point.clone(), witin_evals)) + } else { + None + } }) - .collect_vec(), - )); + .collect_vec() + })); if let Some(fixed_data) = fixed_data.as_ref().map(|f| f.as_ref()) { - rounds.push(( - fixed_data, - points - .iter() - .zip_eq(evals.iter_mut()) - .zip_eq(num_instances.iter()) - .filter(|(_, (chip_idx, _))| { - let (_, num_fixed) = circuit_num_polys[*chip_idx]; - num_fixed > 0 + rounds.push((fixed_data, { + evals + .iter_mut() + .zip(points) + .filter_map(|(evals, point)| { + if !evals.is_empty() && !evals[0].is_empty() { + Some((point.clone(), evals.remove(0))) + } else { + None + } }) - .map(|((point, evals), _)| (point.clone(), evals.to_vec())) - .collect_vec(), - )); + .collect_vec() + })); } PCS::batch_open(&self.backend.pp, rounds, transcript).unwrap() } diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index 023686e36..455b6786d 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -755,8 +755,6 @@ impl> OpeningProver as ProverBackend>::PcsData>>, points: Vec>, mut evals: Vec>, // where each inner Vec = wit_evals + fixed_evals - circuit_num_polys: &[(usize, usize)], - num_instances: &[(usize, usize)], transcript: &mut (impl Transcript + 'static), ) -> PCS::Proof { if std::any::TypeId::of::() != std::any::TypeId::of::() { @@ -764,32 +762,34 @@ impl> OpeningProver 0 + rounds.push((fixed_data, { + evals + .iter_mut() + .zip(points) + .filter_map(|(evals, point)| { + if !evals.is_empty() && !evals[0].is_empty() { + Some((point.clone(), evals.remove(0))) + } else { + None + } }) - .map(|((point, evals), _)| (point.clone(), evals.to_vec())) - .collect_vec(), - )); + .collect_vec() + })); } // use ceno_gpu::{ diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index 17ad6b92a..840ef1788 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -147,9 +147,7 @@ pub trait OpeningProver { witness_data: PB::PcsData, fixed_data: Option>, points: Vec>, - evals: Vec>, - circuit_num_polys: &[(usize, usize)], - num_instances: &[(usize, usize)], + evals: Vec>>, transcript: &mut (impl Transcript + 'static), ) -> >::Proof; } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 028f844a6..edf7a63f1 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -26,13 +26,14 @@ use itertools::{Itertools, chain, enumerate, izip}; use multilinear_extensions::{ Expression, WitnessId, fmt, mle::{ArcMultilinearExtension, IntoMLEs, MultilinearExtension}, + util::ceil_log2, utils::{eval_by_expr, eval_by_expr_with_fixed, eval_by_expr_with_instance}, }; use p3::field::{Field, FieldAlgebra}; use rand::thread_rng; use std::{ cmp::max, - collections::{BTreeSet, HashMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, fmt::Debug, fs::File, hash::Hash, @@ -42,6 +43,7 @@ use std::{ }; use strum::IntoEnumIterator; use tiny_keccak::{Hasher, Keccak}; +use witness::next_pow2_instance_padding; const MAX_CONSTRAINT_DEGREE: usize = 3; const MOCK_PROGRAM_SIZE: usize = 32; @@ -828,7 +830,10 @@ impl<'a, E: ExtensionField + Hash> MockProver { let mut cs = ConstraintSystem::::new(|| "mock_program"); let params = ProgramParams { platform: CENO_PLATFORM, - program_size: max(program.instructions.len(), MOCK_PROGRAM_SIZE), + program_size: max( + next_pow2_instance_padding(program.instructions.len()), + MOCK_PROGRAM_SIZE, + ), ..ProgramParams::default() }; let mut cb = CircuitBuilder::new(&mut cs); @@ -974,30 +979,48 @@ Hints: let mut fixed_mles = HashMap::new(); let mut num_instances = HashMap::new(); + let circuit_index_fixed_num_instances: BTreeMap = fixed_trace + .circuit_fixed_traces + .iter() + .map(|(circuit_name, rmm)| { + ( + circuit_name.clone(), + rmm.as_ref().map(|rmm| rmm.num_instances()).unwrap_or(0), + ) + }) + .collect(); let mut lkm_tables = LkMultiplicityRaw::::default(); let mut lkm_opcodes = LkMultiplicityRaw::::default(); // Process all circuits. - for ( - circuit_name, - ComposedConstrainSystem { + for (circuit_name, composed_cs) in &cs.circuit_css { + let ComposedConstrainSystem { zkvm_v1_css: cs, gkr_circuit, - }, - ) in &cs.circuit_css - { + } = &composed_cs; let is_opcode = gkr_circuit.is_some(); let [witness, structural_witness] = witnesses .get_opcode_witness(circuit_name) .or_else(|| witnesses.get_table_witness(circuit_name)) .unwrap_or_else(|| panic!("witness for {} should not be None", circuit_name)); - let num_rows = witness.num_instances(); + let num_rows = if witness.num_instances() > 0 { + witness.num_instances() + } else if structural_witness.num_instances() > 0 { + structural_witness.num_instances() + } else if composed_cs.is_static_circuit() { + circuit_index_fixed_num_instances + .get(circuit_name) + .copied() + .unwrap_or(0) + } else { + 0 + }; - if witness.num_instances() == 0 { + if num_rows == 0 { wit_mles.insert(circuit_name.clone(), vec![]); structural_wit_mles.insert(circuit_name.clone(), vec![]); fixed_mles.insert(circuit_name.clone(), vec![]); - num_instances.insert(circuit_name.clone(), num_rows); + num_instances.insert(circuit_name.clone(), 0); continue; } let mut witness = witness @@ -1133,21 +1156,20 @@ Hints: if *num_rows == 0 { continue; } - let w_selector: ArcMultilinearExtension<_> = if let Some(w_selector) = &cs.w_selector { structural_witness[w_selector.selector_expr().id()].clone() } else { let mut selector = vec![E::BaseField::ONE; *num_rows]; - selector.resize(witness[0].evaluations().len(), E::BaseField::ZERO); + selector.resize(next_pow2_instance_padding(*num_rows), E::BaseField::ZERO); MultilinearExtension::from_evaluation_vec_smart( - witness[0].num_vars(), + ceil_log2(next_pow2_instance_padding(*num_rows)), selector, ) .into() }; - for ((w_rlc_expr, annotation), _) in (cs + for ((w_rlc_expr, annotation), (ram_type_expr, _)) in (cs .w_expressions .iter() .chain(cs.w_table_expressions.iter().map(|expr| &expr.expr))) @@ -1157,8 +1179,19 @@ Hints: .chain(cs.w_table_expressions_namespace_map.iter()), ) .zip_eq(cs.w_ram_types.iter()) - .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { + let ram_type_mle = wit_infer_by_expr( + ram_type_expr, + cs.num_witin, + cs.num_structural_witin, + cs.num_fixed as WitnessId, + fixed, + witness, + structural_witness, + &pi_mles, + &challenges, + ); + let ram_type_vec = ram_type_mle.get_ext_field_vec(); let write_rlc_records = wit_infer_by_expr( w_rlc_expr, cs.num_witin, @@ -1170,13 +1203,34 @@ Hints: &pi_mles, &challenges, ); + let w_selector_vec = w_selector.get_base_field_vec(); let write_rlc_records = - filter_mle_by_selector_mle(write_rlc_records, w_selector.clone()); + filter_mle_by_predicate(write_rlc_records, |i, _v| { + ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) + && w_selector_vec[i] == E::BaseField::ONE + }); + if write_rlc_records.is_empty() { + continue; + } let mut records = vec![]; + let mut writes_within_expr_dedup = HashSet::new(); for (row, record_rlc) in enumerate(write_rlc_records) { // TODO: report error - assert_eq!(writes.insert(record_rlc), true); + assert_eq!( + writes_within_expr_dedup.insert(record_rlc), + true, + "within expression write duplicated on RAMType {:?} annotation {:?}", + $ram_type, + annotation + ); + assert_eq!( + writes.insert(record_rlc), + true, + "crossing-chip write duplicated on RAMType {:?} annotation {:?}", + $ram_type, + annotation + ); records.push((record_rlc, row)); } writes_grp_by_annotations @@ -1205,14 +1259,14 @@ Hints: structural_witness[r_selector.selector_expr().id()].clone() } else { let mut selector = vec![E::BaseField::ONE; *num_rows]; - selector.resize(witness[0].evaluations().len(), E::BaseField::ZERO); + selector.resize(next_pow2_instance_padding(*num_rows), E::BaseField::ZERO); MultilinearExtension::from_evaluation_vec_smart( - witness[0].num_vars(), + ceil_log2(next_pow2_instance_padding(*num_rows)), selector, ) .into() }; - for ((r_rlc_expr, annotation), (_, r_exprs)) in (cs + for ((r_rlc_expr, annotation), (ram_type_expr, r_exprs)) in (cs .r_expressions .iter() .chain(cs.r_table_expressions.iter().map(|expr| &expr.expr))) @@ -1222,8 +1276,19 @@ Hints: .chain(cs.r_table_expressions_namespace_map.iter()), ) .zip_eq(cs.r_ram_types.iter()) - .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { + let ram_type_mle = wit_infer_by_expr( + ram_type_expr, + cs.num_witin, + cs.num_structural_witin, + cs.num_fixed as WitnessId, + fixed, + witness, + structural_witness, + &pi_mles, + &challenges, + ); + let ram_type_vec = ram_type_mle.get_ext_field_vec(); let read_records = wit_infer_by_expr( r_rlc_expr, cs.num_witin, @@ -1235,8 +1300,14 @@ Hints: &pi_mles, &challenges, ); - let read_records = - filter_mle_by_selector_mle(read_records, r_selector.clone()); + let r_selector_vec = r_selector.get_base_field_vec(); + let read_records = filter_mle_by_predicate(read_records, |i, _v| { + ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) + && r_selector_vec[i] == E::BaseField::ONE + }); + if read_records.is_empty() { + continue; + } if $ram_type == RAMType::GlobalState { // r_exprs = [GlobalState, pc, timestamp] @@ -1269,9 +1340,23 @@ Hints: }; let mut records = vec![]; + let mut reads_within_expr_dedup = HashSet::new(); for (row, record) in enumerate(read_records) { // TODO: return error - assert_eq!(reads.insert(record), true); + assert_eq!( + reads_within_expr_dedup.insert(record), + true, + "within expression read duplicated on RAMType {:?} annotation {:?}", + $ram_type, + annotation, + ); + assert_eq!( + reads.insert(record), + true, + "crossing-chip read duplicated on RAMType {:?} annotation {:?}", + $ram_type, + annotation, + ); records.push((record, row)); } reads_grp_by_annotations @@ -1467,6 +1552,19 @@ fn print_errors( } } +fn filter_mle_by_predicate(target_mle: ArcMultilinearExtension, mut predicate: F) -> Vec +where + E: ExtensionField, + F: FnMut(usize, &E) -> bool, +{ + target_mle + .get_ext_field_vec() + .iter() + .enumerate() + .filter_map(|(i, v)| if predicate(i, v) { Some(*v) } else { None }) + .collect_vec() +} + fn filter_mle_by_selector_mle( target_mle: ArcMultilinearExtension, selector: ArcMultilinearExtension, @@ -1487,7 +1585,6 @@ fn filter_mle_by_selector_mle( #[cfg(test)] mod tests { - use super::*; use crate::{ ROMType, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index e1094d77f..ace2215a1 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -113,14 +113,20 @@ impl< // only keep track of circuits that have non-zero instances let mut num_instances = Vec::with_capacity(self.pk.circuit_pks.len()); let mut num_instances_with_rotation = Vec::with_capacity(self.pk.circuit_pks.len()); + let mut circuit_name_num_instances_mapping = BTreeMap::new(); for (index, (circuit_name, ProvingKey { vk, .. })) in self.pk.circuit_pks.iter().enumerate() { // num_instance from witness might include rotation if let Some(num_instance) = witnesses .get_opcode_witness(circuit_name) .or_else(|| witnesses.get_table_witness(circuit_name)) - .map(|rmms| &rmms[0]) - .map(|rmm| rmm.num_instances()) + .map(|rmms| { + if rmms[0].num_instances() == 0 { + rmms[1].num_instances() + } else { + rmms[0].num_instances() + } + }) .and_then(|num_instance| { if num_instance > 0 { Some(num_instance) @@ -128,12 +134,25 @@ impl< None } }) + .or_else(|| { + vk.get_cs().is_static_circuit().then(|| { + self.pk + .circuit_index_fixed_num_instances + .get(&index) + .copied() + .unwrap_or(0) + }) + }) { num_instances.push(( index, num_instance >> vk.get_cs().rotation_vars().unwrap_or(0), )); - num_instances_with_rotation.push((index, num_instance)) + num_instances_with_rotation.push((index, num_instance)); + circuit_name_num_instances_mapping.insert( + circuit_name, + num_instance >> vk.get_cs().rotation_vars().unwrap_or(0), + ); } } @@ -144,7 +163,6 @@ impl< } let commit_to_traces_span = entered_span!("batch commit to traces", profiling_1 = true); - let mut wits_instances = BTreeMap::new(); let mut wits_rmms = BTreeMap::new(); let mut structural_wits = BTreeMap::new(); @@ -157,31 +175,19 @@ impl< } else { RowMajorMatrix::empty() }; - let rotation_vars = self - .pk - .circuit_pks - .get(&circuit_name) - .unwrap() - .vk - .get_cs() - .rotation_vars(); - let num_instances = witness_rmm.num_instances() >> (rotation_vars.unwrap_or(0)); - assert!( - wits_instances - .insert(circuit_name.clone(), num_instances) - .is_none() - ); - if num_instances == 0 { - continue; - } - let structural_witness = structural_witness_rmm.to_mles(); - wits_rmms.insert(circuit_name_index_mapping[&circuit_name], witness_rmm); - structural_wits.insert(circuit_name, (structural_witness, num_instances)); + if witness_rmm.num_instances() > 0 { + wits_rmms.insert(circuit_name_index_mapping[&circuit_name], witness_rmm); + } + if structural_witness_rmm.num_instances() > 0 { + let num_instances = circuit_name_num_instances_mapping + .get(&circuit_name) + .unwrap(); + let structural_witness = structural_witness_rmm.to_mles(); + structural_wits.insert(circuit_name, (structural_witness, num_instances)); + } } - debug_assert_eq!(num_instances.len(), wits_rmms.len()); - // commit to witness traces in batch let (mut witness_mles, witness_data, witin_commit) = self.device.commit_traces(wits_rmms); PCS::write_commitment(&witin_commit, &mut transcript).map_err(ZKVMError::PCSError)?; @@ -208,9 +214,10 @@ impl< let (points, evaluations) = self.pk.circuit_pks.iter().enumerate().try_fold( (vec![], vec![]), |(mut points, mut evaluations), (index, (circuit_name, pk))| { - let num_instances = *wits_instances - .get(circuit_name) - .ok_or(ZKVMError::WitnessNotFound(circuit_name.to_string().into()))?; + let num_instances = circuit_name_num_instances_mapping + .get(&circuit_name) + .copied() + .unwrap_or(0); let cs = pk.get_cs(); if num_instances == 0 { // we need to drain respective fixed when num_instances is 0 @@ -237,8 +244,7 @@ impl< exit_span!(structural_witness_span); let fixed = fixed_mles.drain(..cs.num_fixed()).collect_vec(); - - let mut input = ProofInput { + let input = ProofInput { witness: witness_mle, fixed, structural_witness, @@ -260,23 +266,30 @@ impl< num_instances ); points.push(input_opening_point); - evaluations.push(opcode_proof.wits_in_evals.clone()); + evaluations.push(vec![opcode_proof.wits_in_evals.clone()]); chip_proofs.insert(index, opcode_proof); } else { // FIXME: PROGRAM table circuit is not guaranteed to have 2^n instances - input.num_instances = 1 << input.log2_num_instances(); - let (mut table_proof, pi_in_evals, input_opening_point) = self - .create_chip_proof(circuit_name, pk, input, &mut transcript, &challenges)?; - points.push(input_opening_point); - evaluations.push( - [ + // input.num_instances = 1 << input.log2_num_instances(); + let (table_proof, pi_in_evals, input_opening_point) = self.create_chip_proof( + circuit_name, + pk, + input, + &mut transcript, + &challenges, + )?; + if cs.num_witin() > 0 || cs.num_fixed() > 0 { + points.push(input_opening_point); + evaluations.push(vec![ table_proof.wits_in_evals.clone(), table_proof.fixed_in_evals.clone(), - ] - .concat(), - ); + ]); + } else { + assert!(table_proof.wits_in_evals.is_empty()); + assert!(table_proof.fixed_in_evals.is_empty()); + } // FIXME: PROGRAM table circuit is not guaranteed to have 2^n instances - table_proof.num_instances = num_instances; + // table_proof.num_instances = num_instances; chip_proofs.insert(index, table_proof); for (idx, eval) in pi_in_evals { pi_evals[idx] = eval; @@ -289,20 +302,12 @@ impl< // batch opening pcs // generate static info from prover key for expected num variable - let circuit_num_polys = self - .pk - .circuit_pks - .values() - .map(|pk| (pk.get_cs().num_witin(), pk.get_cs().num_fixed())) - .collect_vec(); let pcs_opening = entered_span!("pcs_opening", profiling_1 = true); let mpcs_opening_proof = self.device.open( witness_data, Some(device_pk.pcs_data), points, evaluations, - &circuit_num_polys, - &num_instances_with_rotation, &mut transcript, ); exit_span!(pcs_opening); diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 60dff6a99..5cce8f4db 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -42,7 +42,7 @@ use super::{ utils::infer_tower_product_witness, verifier::{TowerVerify, ZKVMVerifier}, }; -use crate::tables::DynamicRangeTableCircuit; +use crate::{e2e::ShardContext, tables::DynamicRangeTableCircuit}; use itertools::Itertools; use mpcs::{ PolynomialCommitmentScheme, SecurityLevel, SecurityLevel::Conjecture100bits, WhirDefault, @@ -90,6 +90,7 @@ impl Instruction for Test fn assign_instance( config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -118,6 +119,7 @@ fn test_rw_lk_expression_combination() { let name = TestCircuit::::name(); let mut zkvm_cs = ZKVMConstraintSystem::default(); let config = zkvm_cs.register_opcode_circuit::>(); + let mut shard_ctx = ShardContext::default(); // generate fixed traces let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); @@ -140,6 +142,7 @@ fn test_rw_lk_expression_combination() { zkvm_witness .assign_opcode_circuit::>( &zkvm_cs, + &mut shard_ctx, &config, vec![StepRecord::default(); num_instances], ) @@ -274,6 +277,7 @@ fn test_single_add_instance_e2e() { Pcs::setup(1 << MAX_NUM_VARIABLES, SecurityLevel::default()).expect("Basefold PCS setup"); let (pp, vp) = Pcs::trim((), 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = ZKVMConstraintSystem::default(); + let mut shard_ctx = ShardContext::default(); // opcode circuits let add_config = zkvm_cs.register_opcode_circuit::>(); let halt_config = zkvm_cs.register_opcode_circuit::>(); @@ -339,10 +343,20 @@ fn test_single_add_instance_e2e() { let mut zkvm_witness = ZKVMWitnesses::default(); // assign opcode circuits zkvm_witness - .assign_opcode_circuit::>(&zkvm_cs, &add_config, add_records) + .assign_opcode_circuit::>( + &zkvm_cs, + &mut shard_ctx, + &add_config, + add_records, + ) .unwrap(); zkvm_witness - .assign_opcode_circuit::>(&zkvm_cs, &halt_config, halt_records) + .assign_opcode_circuit::>( + &zkvm_cs, + &mut shard_ctx, + &halt_config, + halt_records, + ) .unwrap(); zkvm_witness.finalize_lk_multiplicities(); zkvm_witness @@ -356,7 +370,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); - let pi = PublicValues::new(0, 0, 0, 0, 0, vec![0]); + let pi = PublicValues::new(0, 0, 0, 0, 0, 0, vec![0]); let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover .create_proof(zkvm_witness, pi, transcript) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index c8b67929e..637fa09b1 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -350,12 +350,16 @@ pub fn build_main_witness< } if let Some(gkr_circuit) = gkr_circuit { - // opcode must have at least one read/write/lookup + // circuit must have at least one read/write/lookup assert!( - cs.lk_expressions.is_empty() - || !cs.r_expressions.is_empty() - || !cs.w_expressions.is_empty(), - "assert opcode circuit" + cs.r_expressions.len() + + cs.w_expressions.len() + + cs.lk_expressions.len() + + cs.r_table_expressions.len() + + cs.w_table_expressions.len() + + cs.lk_table_expressions.len() + > 0, + "assert circuit" ); let (_, gkr_circuit_out) = gkr_witness::( @@ -370,7 +374,7 @@ pub fn build_main_witness< } else { ( >::table_witness(device, input, cs, challenges), - false, + input.num_instances > 1 && input.num_instances.is_power_of_two(), ) } }; diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index f8c1c8a2a..b38c6e589 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,10 +1,20 @@ -use std::marker::PhantomData; - +use either::Either; use ff_ext::ExtensionField; +use std::marker::PhantomData; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; +use super::{ZKVMChipProof, ZKVMProof}; +use crate::{ + error::ZKVMError, + scheme::constants::{NUM_FANIN, NUM_FANIN_LOGUP}, + structs::{ComposedConstrainSystem, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, + utils::{ + eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, + eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, + }, +}; use gkr_iop::gkr::GKRClaims; use itertools::{Itertools, chain, interleave, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; @@ -23,18 +33,6 @@ use sumcheck::{ use transcript::{ForkableTranscript, Transcript}; use witness::next_pow2_instance_padding; -use crate::{ - error::ZKVMError, - scheme::constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, - structs::{ComposedConstrainSystem, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, - utils::{ - eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, - eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, - }, -}; - -use super::{ZKVMChipProof, ZKVMProof}; - pub struct ZKVMVerifier> { pub vk: ZKVMVerifyingKey, } @@ -157,11 +155,10 @@ impl> ZKVMVerifier let dummy_table_item = challenges[0]; let mut dummy_table_item_multiplicity = 0; let point_eval = PointAndEval::default(); - let mut rt_points = Vec::with_capacity(vm_proof.chip_proofs.len()); - let mut evaluations = Vec::with_capacity(vm_proof.chip_proofs.len()); let mut witin_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); let mut fixed_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); for (index, proof) in &vm_proof.chip_proofs { + assert!(proof.num_instances > 0); let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; @@ -254,22 +251,18 @@ impl> ZKVMVerifier &challenges, )? }; - rt_points.push((*index, input_opening_point.clone())); - evaluations.push(( - *index, - [proof.wits_in_evals.clone(), proof.fixed_in_evals.clone()].concat(), - )); - witin_openings.push(( - input_opening_point.len(), - (input_opening_point.clone(), proof.wits_in_evals.clone()), - )); - if !proof.fixed_in_evals.is_empty() { + if circuit_vk.get_cs().num_witin() > 0 { + witin_openings.push(( + input_opening_point.len(), + (input_opening_point.clone(), proof.wits_in_evals.clone()), + )); + } + if circuit_vk.get_cs().num_fixed() > 0 { fixed_openings.push(( input_opening_point.len(), (input_opening_point.clone(), proof.fixed_in_evals.clone()), )); } - prod_w *= proof.w_out_evals.iter().flatten().copied().product::(); prod_r *= proof.r_out_evals.iter().flatten().copied().product::(); tracing::debug!("verified proof for circuit {}", circuit_name); @@ -355,9 +348,9 @@ impl> ZKVMVerifier } = &composed_cs; let num_instances = proof.num_instances; let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = ( - cs.r_expressions.len(), - cs.w_expressions.len(), - cs.lk_expressions.len(), + cs.r_expressions.len() + cs.r_table_expressions.len(), + cs.w_expressions.len() + cs.w_table_expressions.len(), + cs.lk_expressions.len() + cs.lk_table_expressions.len() * 2, ); let num_batched = r_counts_per_instance + w_counts_per_instance + lk_counts_per_instance; @@ -437,43 +430,43 @@ impl> ZKVMVerifier let ComposedConstrainSystem { zkvm_v1_css: cs, .. } = circuit_vk.get_cs(); - debug_assert!( - cs.r_table_expressions - .iter() - .zip_eq(cs.w_table_expressions.iter()) - .all(|(r, w)| r.table_spec.len == w.table_spec.len) - ); - + let with_rw = !cs.r_table_expressions.is_empty() && !cs.w_table_expressions.is_empty(); + if with_rw { + debug_assert!( + cs.r_table_expressions + .iter() + .zip_eq(cs.w_table_expressions.iter()) + .all(|(r, w)| r.table_spec.len == w.table_spec.len) + ); + } let log2_num_instances = next_pow2_instance_padding(proof.num_instances).ilog2() as usize; - // in table proof, we always skip same point sumcheck for now - // as tower sumcheck batch product argument/logup in same length - let is_skip_same_point_sumcheck = true; - // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; // NOTE: for all structural witness within same constrain system should got same hints num variable via `log2_num_instances` - let expected_rounds = cs - // only iterate r set, as read/write set round should match - .r_table_expressions - .iter() - .flat_map(|r| { + let expected_rounds = interleave(&cs.r_table_expressions, &cs.w_table_expressions) + .map(|set_table_expr| { // iterate through structural witins and collect max round. - let num_vars = r.table_spec.len.map(ceil_log2).unwrap_or_else(|| { - r.table_spec - .structural_witins - .iter() - .map(|StructuralWitIn { witin_type, .. }| { - let hint_num_vars = log2_num_instances; - assert!((1 << hint_num_vars) <= witin_type.max_len()); - hint_num_vars - }) - .max() - .unwrap() - }); + let num_vars = set_table_expr + .table_spec + .len + .map(ceil_log2) + .unwrap_or_else(|| { + set_table_expr + .table_spec + .structural_witins + .iter() + .map(|StructuralWitIn { witin_type, .. }| { + let hint_num_vars = log2_num_instances; + assert!((1 << hint_num_vars) <= witin_type.max_len()); + hint_num_vars + }) + .max() + .unwrap() + }); assert_eq!(num_vars, log2_num_instances); - [num_vars, num_vars] // format: [read_round, write_round] + num_vars }) .chain(cs.lk_table_expressions.iter().map(|l| { // iterate through structural witins and collect max round. @@ -494,14 +487,10 @@ impl> ZKVMVerifier })) .collect_vec(); - let expected_max_rounds = expected_rounds.iter().cloned().max().unwrap(); let (rt_tower, prod_point_and_eval, logup_p_point_and_eval, logup_q_point_and_eval) = TowerVerify::verify( - proof - .r_out_evals - .iter() - .zip(proof.w_out_evals.iter()) - .flat_map(|(r_evals, w_evals)| [r_evals.to_vec(), w_evals.to_vec()]) + interleave(&proof.r_out_evals, &proof.w_out_evals) + .map(|eval| eval.to_vec()) .collect_vec(), proof .lk_out_evals @@ -530,13 +519,19 @@ impl> ZKVMVerifier cs.r_table_expressions.len() + cs.w_table_expressions.len(), "[prod_record] mismatch length" ); - let num_rw_records = cs.r_table_expressions.len() + cs.w_table_expressions.len(); - // evaluate the evaluation of structural mles at input_opening_point by verifier - let structural_evals = cs - .r_table_expressions - .iter() - .map(|r| &r.table_spec) + // TODO differentiate `ram_bus` via cs + let is_shard_ram_bus_circuit = false; + + let input_opening_point = if !is_shard_ram_bus_circuit { + // evaluate the evaluation of structural mles at input_opening_point by verifier + let structural_evals = if with_rw { + // only iterate r set, as read/write set round should match + Either::Left(cs.r_table_expressions.iter()) + } else { + Either::Right(cs.r_table_expressions.iter().chain(&cs.w_table_expressions)) + } + .map(|set_table_expr| &set_table_expr.table_spec) .chain(cs.lk_table_expressions.iter().map(|r| &r.table_spec)) .flat_map(|table_spec| { table_spec @@ -571,32 +566,30 @@ impl> ZKVMVerifier }) .collect_vec(); - // verify records (degree = 1) statement, thus no sumcheck - let expected_evals = interleave( - &cs.r_table_expressions, // r - &cs.w_table_expressions, // w - ) - .map(|rw| &rw.expr) - .chain( - cs.lk_table_expressions - .iter() - .flat_map(|lk| vec![&lk.multiplicity, &lk.values]), // p, q - ) - .map(|expr| { - eval_by_expr_with_instance( - &proof.fixed_in_evals, - &proof.wits_in_evals, - &structural_evals, - pi, - challenges, - expr, + // verify records (degree = 1) statement, thus no sumcheck + let expected_evals = interleave( + &cs.r_table_expressions, // r + &cs.w_table_expressions, // w ) - .right() - .unwrap() - }) - .collect_vec(); - - let input_opening_point = if is_skip_same_point_sumcheck { + .map(|rw| &rw.expr) + .chain( + cs.lk_table_expressions + .iter() + .flat_map(|lk| vec![&lk.multiplicity, &lk.values]), // p, q + ) + .map(|expr| { + eval_by_expr_with_instance( + &proof.fixed_in_evals, + &proof.wits_in_evals, + &structural_evals, + pi, + challenges, + expr, + ) + .right() + .unwrap() + }) + .collect_vec(); for (expected_eval, eval) in expected_evals.iter().zip( prod_point_and_eval .into_iter() @@ -619,84 +612,7 @@ impl> ZKVMVerifier } rt_tower } else { - assert!(proof.main_sumcheck_proofs.is_some()); - - // verify opening same point layer sumcheck - let alpha_pow = get_challenge_pows( - cs.r_table_expressions.len() - + cs.w_table_expressions.len() - + cs.lk_table_expressions.len() * 2, // 2 for lk numerator and denominator - transcript, - ); - - // \sum_i alpha_{i} * (out_r_eval{i}) - // + \sum_i alpha_{i} * (out_w_eval{i}) - // + \sum_i alpha_{i} * (out_lk_n{i}) - // + \sum_i alpha_{i} * (out_lk_d{i}) - let claim_sum = prod_point_and_eval - .iter() - .zip(alpha_pow.iter()) - .map(|(point_and_eval, alpha)| *alpha * point_and_eval.eval) - .sum::() - + interleave(&logup_p_point_and_eval, &logup_q_point_and_eval) - .zip_eq(alpha_pow.iter().skip(num_rw_records)) - .map(|(point_n_eval, alpha)| *alpha * point_n_eval.eval) - .sum::(); - let sel_subclaim = IOPVerifierState::verify( - claim_sum, - &IOPProof { - proofs: proof.main_sumcheck_proofs.clone().unwrap(), - }, - &VPAuxInfo { - max_degree: SEL_DEGREE, - max_num_variables: expected_max_rounds, - phantom: PhantomData, - }, - transcript, - ); - let (input_opening_point, expected_evaluation) = ( - sel_subclaim.point.iter().map(|c| c.elements).collect_vec(), - sel_subclaim.expected_evaluation, - ); - - let computed_evals = [ - // r, w - prod_point_and_eval - .into_iter() - .zip_eq(&expected_evals[0..num_rw_records]) - .zip(alpha_pow.iter()) - .map(|((point_and_eval, in_eval), alpha)| { - let eq = eq_eval( - &point_and_eval.point, - &input_opening_point[0..point_and_eval.point.len()], - ); - // TODO times multiplication factor - *alpha * eq * *in_eval - }) - .sum::(), - interleave(logup_p_point_and_eval, logup_q_point_and_eval) - .zip_eq(&expected_evals[num_rw_records..]) - .zip_eq(alpha_pow.iter().skip(num_rw_records)) - .map(|((point_and_eval, in_eval), alpha)| { - let eq = eq_eval( - &point_and_eval.point, - &input_opening_point[0..point_and_eval.point.len()], - ); - // TODO times multiplication factor - *alpha * eq * *in_eval - }) - .sum::(), - ] - .iter() - .copied() - .sum::(); - - if computed_evals != expected_evaluation { - return Err(ZKVMError::VerifyError( - "sel evaluation verify failed".into(), - )); - } - input_opening_point + unimplemented!("shard ram bus circuit go here"); }; // assume public io is tiny vector, so we evaluate it directly without PCS @@ -749,9 +665,9 @@ impl TowerVerify { let log2_num_fanin = ceil_log2(num_fanin); // sanity check - assert!(num_prod_spec == tower_proofs.prod_spec_size()); + assert_eq!(num_prod_spec, tower_proofs.prod_spec_size()); assert!(prod_out_evals.iter().all(|evals| evals.len() == num_fanin)); - assert!(num_logup_spec == tower_proofs.logup_spec_size()); + assert_eq!(num_logup_spec, tower_proofs.logup_spec_size()); assert!(logup_out_evals.iter().all(|evals| { evals.len() == 4 // [p1, p2, q1, q2] })); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index cd76d6fcd..8c92036ae 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -1,5 +1,6 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, error::ZKVMError, instructions::Instruction, state::StateCircuit, @@ -108,10 +109,19 @@ impl ComposedConstrainSystem { self.zkvm_v1_css.num_witin.into() } + pub fn num_structural_witin(&self) -> usize { + self.zkvm_v1_css.num_structural_witin.into() + } + pub fn num_fixed(&self) -> usize { self.zkvm_v1_css.num_fixed } + /// static circuit means there is only fixed column + pub fn is_static_circuit(&self) -> bool { + (self.num_witin() + self.num_structural_witin()) == 0 && self.num_fixed() > 0 + } + pub fn num_reads(&self) -> usize { self.zkvm_v1_css.r_expressions.len() + self.zkvm_v1_css.r_table_expressions.len() } @@ -125,9 +135,7 @@ impl ComposedConstrainSystem { } pub fn is_opcode_circuit(&self) -> bool { - self.zkvm_v1_css.lk_table_expressions.is_empty() - && self.zkvm_v1_css.r_table_expressions.is_empty() - && self.zkvm_v1_css.w_table_expressions.is_empty() + self.gkr_circuit.is_some() } /// return number of lookup operation @@ -209,18 +217,13 @@ impl ZKVMConstraintSystem { pub fn register_table_circuit>(&mut self) -> TC::TableConfig { let mut cs = ConstraintSystem::new(|| format!("riscv_table/{}", TC::name())); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let config = TC::construct_circuit(&mut circuit_builder, &self.params).unwrap(); - assert!( - self.circuit_css - .insert( - TC::name(), - ComposedConstrainSystem { - zkvm_v1_css: cs, - gkr_circuit: None - } - ) - .is_none() - ); + let (config, gkr_iop_circuit) = + TC::build_gkr_iop_circuit(&mut circuit_builder, &self.params).unwrap(); + let cs = ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit: gkr_iop_circuit, + }; + assert!(self.circuit_css.insert(TC::name(), cs).is_none()); config } @@ -310,6 +313,7 @@ impl ZKVMWitnesses { pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, + shard_ctx: &mut ShardContext, config: &OC::InstructionConfig, records: Vec, ) -> Result<(), ZKVMError> { @@ -318,6 +322,7 @@ impl ZKVMWitnesses { let cs = cs.get_cs(&OC::name()).unwrap(); let (witness, logup_multiplicity) = OC::assign_instances( config, + shard_ctx, cs.zkvm_v1_css.num_witin as usize, cs.zkvm_v1_css.num_structural_witin as usize, records, @@ -404,6 +409,7 @@ pub struct ZKVMProvingKey> pub circuit_pks: BTreeMap>, pub fixed_commit_wd: Option>::CommitmentWithWitness>>, pub fixed_commit: Option<>::Commitment>, + pub circuit_index_fixed_num_instances: BTreeMap, // expression for global state in/out pub initial_global_state_expr: Expression, @@ -418,6 +424,7 @@ impl> ZKVMProvingKey { params: &ProgramParams, ) -> Result; + fn build_gkr_iop_circuit( + cb: &mut CircuitBuilder, + param: &ProgramParams, + ) -> Result<(Self::TableConfig, Option>), ZKVMError> { + let config = Self::construct_circuit(cb, param)?; + Ok((config, None)) + } + fn generate_fixed_traces( config: &Self::TableConfig, num_fixed: usize, diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 41890200e..833663e74 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -182,6 +182,7 @@ impl TableCircuit for ProgramTableCircuit { cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { + assert!(params.program_size.is_power_of_two()); #[cfg(not(feature = "u16limb_circuit"))] let record = InsnRecord([ cb.create_fixed(|| "pc"), @@ -214,7 +215,7 @@ impl TableCircuit for ProgramTableCircuit { cb.lk_table_record( || "prog table", SetTableSpec { - len: Some(params.program_size.next_power_of_two()), + len: Some(params.program_size), structural_witins: vec![], }, ROMType::Instruction, diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index e34ce1dcc..6075b0440 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -8,6 +8,12 @@ use crate::{ mod ram_circuit; mod ram_impl; +use crate::tables::ram::{ + ram_circuit::{LocalFinalRamCircuit, RamBusCircuit}, + ram_impl::{ + DynVolatileRamTableConfig, DynVolatileRamTableInitConfig, NonVolatileInitTableConfig, + }, +}; pub use ram_circuit::{DynVolatileRamTable, MemFinalRecord, MemInitRecord, NonVolatileTable}; #[derive(Clone)] @@ -32,7 +38,8 @@ impl DynVolatileRamTable for HeapTable { } } -pub type HeapCircuit = DynVolatileRamCircuit; +pub type HeapInitCircuit = + DynVolatileRamCircuit>; #[derive(Clone)] pub struct StackTable; @@ -66,7 +73,8 @@ impl DynVolatileRamTable for StackTable { } } -pub type StackCircuit = DynVolatileRamCircuit; +pub type StackInitCircuit = + DynVolatileRamCircuit>; #[derive(Clone)] pub struct HintsTable; @@ -88,7 +96,8 @@ impl DynVolatileRamTable for HintsTable { "HintsTable" } } -pub type HintsCircuit = DynVolatileRamCircuit; +pub type HintsCircuit = + DynVolatileRamCircuit>; /// RegTable, fix size without offset #[derive(Clone)] @@ -108,7 +117,8 @@ impl NonVolatileTable for RegTable { } } -pub type RegTableCircuit = NonVolatileRamCircuit; +pub type RegTableInitCircuit = + NonVolatileRamCircuit>; #[derive(Clone)] pub struct StaticMemTable; @@ -127,7 +137,8 @@ impl NonVolatileTable for StaticMemTable { } } -pub type StaticMemCircuit = NonVolatileRamCircuit; +pub type StaticMemInitCircuit = + NonVolatileRamCircuit>; #[derive(Clone)] pub struct PubIOTable; @@ -147,3 +158,5 @@ impl NonVolatileTable for PubIOTable { } pub type PubIOCircuit = PubIORamCircuit; +pub type LocalFinalCircuit<'a, E> = LocalFinalRamCircuit<'a, UINT_LIMBS, E>; +pub type RBCircuit<'a, E> = RamBusCircuit<'a, UINT_LIMBS, E>; diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 0a8b6bf97..8fc43e348 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -1,17 +1,27 @@ use std::{collections::HashMap, marker::PhantomData}; -use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word}; -use ff_ext::ExtensionField; -use witness::{InstancePaddingStrategy, RowMajorMatrix}; - +use super::ram_impl::{ + LocalFinalRAMTableConfig, NonVolatileTableConfigTrait, PubIOTableConfig, RAMBusConfig, +}; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, structs::{ProgramParams, RAMType}, tables::{RMMCollections, TableCircuit}, }; - -use super::ram_impl::{DynVolatileRamTableConfig, NonVolatileTableConfig, PubIOTableConfig}; +use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word}; +use ff_ext::{ExtensionField, SmallField}; +use gkr_iop::{ + chip::Chip, + error::CircuitBuilderError, + gkr::{GKRCircuit, layer::Layer}, + selector::SelectorType, +}; +use itertools::Itertools; +use multilinear_extensions::{StructuralWitInType, ToExpr}; +use p3::field::FieldAlgebra; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; #[derive(Clone, Debug)] pub struct MemInitRecord { @@ -21,6 +31,7 @@ pub struct MemInitRecord { #[derive(Clone, Debug)] pub struct MemFinalRecord { + pub ram_type: RAMType, pub addr: Addr, pub cycle: Cycle, pub value: Word, @@ -60,12 +71,15 @@ pub trait NonVolatileTable { /// - with fixed initial content, /// - with witnessed final content that the program wrote, if WRITABLE, /// - or final content equal to initial content, if not WRITABLE. -pub struct NonVolatileRamCircuit(PhantomData<(E, R)>); +pub struct NonVolatileRamCircuit(PhantomData<(E, R, C)>); -impl TableCircuit - for NonVolatileRamCircuit +impl< + E: ExtensionField, + NVRAM: NonVolatileTable + Send + Sync + Clone, + C: NonVolatileTableConfigTrait, +> TableCircuit for NonVolatileRamCircuit { - type TableConfig = NonVolatileTableConfig; + type TableConfig = C::Config; type FixedInput = [MemInitRecord]; type WitnessInput = [MemFinalRecord]; @@ -77,10 +91,7 @@ impl TableCirc cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { - Ok(cb.namespace( - || Self::name(), - |cb| Self::TableConfig::construct_circuit(cb, params), - )?) + Ok(cb.namespace(|| Self::name(), |cb| C::construct_circuit(cb, params))?) } fn generate_fixed_traces( @@ -89,7 +100,7 @@ impl TableCirc init_v: &Self::FixedInput, ) -> RowMajorMatrix { // assume returned table is well-formed include padding - config.gen_init_state(num_fixed, init_v) + C::gen_init_state(config, num_fixed, init_v) } fn assign_instances( @@ -100,7 +111,12 @@ impl TableCirc final_v: &Self::WitnessInput, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding - Ok(config.assign_instances(num_witin, num_structural_witin, final_v)?) + Ok(C::assign_instances( + config, + num_witin, + num_structural_witin, + final_v, + )?) } } @@ -189,6 +205,20 @@ pub trait DynVolatileRamTable { } } +pub trait DynVolatileRamTableConfigTrait: Sized + Send + Sync { + type Config: Sized + Send + Sync; + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result; + fn assign_instances( + config: &Self::Config, + num_witin: usize, + num_structural_witin: usize, + final_mem: &[MemFinalRecord], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError>; +} + /// DynVolatileRamCircuit initializes and finalizes memory /// - at witnessed addresses, in a contiguous range chosen by the prover, /// - with zeros as initial content if ZERO_INIT, @@ -197,12 +227,15 @@ pub trait DynVolatileRamTable { /// If not ZERO_INIT: /// - The initial content is an unconstrained prover hint. /// - The final content is equal to this initial content. -pub struct DynVolatileRamCircuit(PhantomData<(E, R)>); +pub struct DynVolatileRamCircuit(PhantomData<(E, R, C)>); -impl TableCircuit - for DynVolatileRamCircuit +impl< + E: ExtensionField, + DVRAM: DynVolatileRamTable + Send + Sync + Clone, + C: DynVolatileRamTableConfigTrait, +> TableCircuit for DynVolatileRamCircuit { - type TableConfig = DynVolatileRamTableConfig; + type TableConfig = C::Config; type FixedInput = (); type WitnessInput = [MemFinalRecord]; @@ -210,6 +243,57 @@ impl TableC format!("RAM_{:?}_{}", DVRAM::RAM_TYPE, DVRAM::name()) } + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + Ok(cb.namespace(|| Self::name(), |cb| C::construct_circuit(cb, params))?) + } + + fn generate_fixed_traces( + _config: &Self::TableConfig, + _num_fixed: usize, + _init_v: &Self::FixedInput, + ) -> RowMajorMatrix { + RowMajorMatrix::::new(0, 0, InstancePaddingStrategy::Default) + } + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + num_structural_witin: usize, + _multiplicity: &[HashMap], + final_v: &Self::WitnessInput, + ) -> Result, ZKVMError> { + // assume returned table is well-formed include padding + Ok( + >::assign_instances( + config, + num_witin, + num_structural_witin, + final_v, + )?, + ) + } +} + +/// This circuit is generalized version to handle all mmio records +pub struct LocalFinalRamCircuit<'a, const V_LIMBS: usize, E>(PhantomData<(&'a (), E)>); + +impl<'a, E: ExtensionField, const V_LIMBS: usize> TableCircuit + for LocalFinalRamCircuit<'a, V_LIMBS, E> +{ + type TableConfig = LocalFinalRAMTableConfig; + type FixedInput = (); + type WitnessInput = ( + &'a ShardContext<'a>, + &'a [(InstancePaddingStrategy, &'a [MemFinalRecord])], + ); + + fn name() -> String { + "LocalRAMTableFinal".to_string() + } + fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, @@ -220,6 +304,49 @@ impl TableC )?) } + fn build_gkr_iop_circuit( + cb: &mut CircuitBuilder, + param: &ProgramParams, + ) -> Result<(Self::TableConfig, Option>), ZKVMError> { + let config = Self::construct_circuit(cb, param)?; + let r_table_len = cb.cs.r_table_expressions.len(); + + let selector = cb.create_structural_witin( + || "selector", + StructuralWitInType::EqualDistanceSequence { + // TODO determin proper size of max length + max_len: u32::MAX as usize, + offset: 0, + multi_factor: 0, + descending: false, + }, + ); + let selector_type = SelectorType::Prefix(E::BaseField::ZERO, selector.expr()); + + // all shared the same selector + let (out_evals, mut chip) = ( + [ + // r_record + (0..r_table_len).collect_vec(), + // w_record + vec![], + // lk_record + vec![], + // zero_record + vec![], + ], + Chip::new_from_cb(cb, 0), + ); + + // register selector to legacy constrain system + cb.cs.r_selector = Some(selector_type.clone()); + + let layer = Layer::from_circuit_builder(cb, "Rounds".to_string(), 0, out_evals); + chip.add_layer(layer); + + Ok((config, Some(chip.gkr_circuit()))) + } + fn generate_fixed_traces( _config: &Self::TableConfig, _num_fixed: usize, @@ -233,9 +360,64 @@ impl TableC num_witin: usize, num_structural_witin: usize, _multiplicity: &[HashMap], - final_v: &Self::WitnessInput, + (shard_ctx, final_mem): &Self::WitnessInput, + ) -> Result, ZKVMError> { + // assume returned table is well-formed include padding + Ok(Self::TableConfig::assign_instances( + config, + shard_ctx, + num_witin, + num_structural_witin, + final_mem, + )?) + } +} + +/// This circuit is generalized version to handle all mmio records +pub struct RamBusCircuit<'a, const V_LIMBS: usize, E>(PhantomData<(&'a (), E)>); + +impl<'a, E: ExtensionField, const V_LIMBS: usize> TableCircuit + for RamBusCircuit<'a, V_LIMBS, E> +{ + type TableConfig = RAMBusConfig; + type FixedInput = (); + type WitnessInput = ShardContext<'a>; + + fn name() -> String { + "RamBusCircuit".to_string() + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + Ok(cb.namespace( + || Self::name(), + |cb| Self::TableConfig::construct_circuit(cb, params), + )?) + } + + fn generate_fixed_traces( + _config: &Self::TableConfig, + _num_fixed: usize, + _init_v: &Self::FixedInput, + ) -> RowMajorMatrix { + RowMajorMatrix::::new(0, 0, InstancePaddingStrategy::Default) + } + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + num_structural_witin: usize, + _multiplicity: &[HashMap], + shard_ctx: &Self::WitnessInput, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding - Ok(config.assign_instances(num_witin, num_structural_witin, final_v)?) + Ok(Self::TableConfig::assign_instances( + config, + shard_ctx, + num_witin, + num_structural_witin, + )?) } } diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index f92dc37cc..554c71235 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -1,62 +1,82 @@ -use std::{marker::PhantomData, sync::Arc}; - use ceno_emul::{Addr, Cycle, WORD_SIZE}; +use either::Either; use ff_ext::{ExtensionField, SmallField}; use gkr_iop::error::CircuitBuilderError; use itertools::Itertools; -use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use witness::{InstancePaddingStrategy, RowMajorMatrix, set_fixed_val, set_val}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; +use std::marker::PhantomData; +use witness::{ + InstancePaddingStrategy, RowMajorMatrix, next_pow2_instance_padding, set_fixed_val, set_val, +}; +use super::{ + MemInitRecord, + ram_circuit::{DynVolatileRamTable, MemFinalRecord, NonVolatileTable}, +}; use crate::{ chip_handler::general::PublicIOQuery, circuit_builder::{CircuitBuilder, SetTableSpec}, + e2e::ShardContext, instructions::riscv::constants::{LIMB_BITS, LIMB_MASK}, structs::ProgramParams, + tables::ram::ram_circuit::DynVolatileRamTableConfigTrait, }; use ff_ext::FieldInto; +use gkr_iop::RAMType; use multilinear_extensions::{ Expression, Fixed, StructuralWitIn, StructuralWitInType, ToExpr, WitIn, }; +use p3::field::FieldAlgebra; -use super::{ - MemInitRecord, - ram_circuit::{DynVolatileRamTable, MemFinalRecord, NonVolatileTable}, -}; +pub trait NonVolatileTableConfigTrait: Sized + Send + Sync { + type Config: Sized + Send + Sync; + + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result; + + fn gen_init_state( + config: &Self::Config, + num_fixed: usize, + init_mem: &[MemInitRecord], + ) -> RowMajorMatrix; + + fn assign_instances( + config: &Self::Config, + num_witin: usize, + num_structural_witin: usize, + final_mem: &[MemFinalRecord], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError>; +} /// define a non-volatile memory with init value #[derive(Clone, Debug)] -pub struct NonVolatileTableConfig { +pub struct NonVolatileInitTableConfig { init_v: Vec, addr: Fixed, - final_v: Option>, - final_cycle: WitIn, - phantom: PhantomData, params: ProgramParams, } -impl NonVolatileTableConfig { - pub fn construct_circuit( +impl NonVolatileTableConfigTrait + for NonVolatileInitTableConfig +{ + type Config = NonVolatileInitTableConfig; + + fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { + assert!(NVRAM::WRITABLE); let init_v = (0..NVRAM::V_LIMBS) .map(|i| cb.create_fixed(|| format!("init_v_limb_{i}"))) .collect_vec(); let addr = cb.create_fixed(|| "addr"); - let final_cycle = cb.create_witin(|| "final_cycle"); - let final_v = if NVRAM::WRITABLE { - Some( - (0..NVRAM::V_LIMBS) - .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) - .collect::>(), - ) - } else { - None - }; - let init_table = [ vec![(NVRAM::RAM_TYPE as usize).into()], vec![Expression::Fixed(addr)], @@ -65,18 +85,6 @@ impl NonVolatileTableConfig NonVolatileTableConfig( - &self, + fn gen_init_state( + config: &Self::Config, num_fixed: usize, init_mem: &[MemInitRecord], ) -> RowMajorMatrix { assert!( - NVRAM::len(&self.params).is_power_of_two(), + NVRAM::len(&config.params).is_power_of_two(), "{} len {} must be a power of 2", NVRAM::name(), - NVRAM::len(&self.params) + NVRAM::len(&config.params) ); let mut init_table = RowMajorMatrix::::new( - NVRAM::len(&self.params), + NVRAM::len(&config.params), num_fixed, InstancePaddingStrategy::Default, ); @@ -129,56 +126,31 @@ impl NonVolatileTableConfig> (l * LIMB_BITS)) & LIMB_MASK; set_fixed_val!(row, limb, (val as u64).into_f()); }); } - set_fixed_val!(row, self.addr, (rec.addr as u64).into_f()); + set_fixed_val!(row, config.addr, (rec.addr as u64).into_f()); }); init_table } /// TODO consider taking RowMajorMatrix as argument to save allocations. - pub fn assign_instances( - &self, - num_witin: usize, + fn assign_instances( + _config: &Self::Config, + _num_witin: usize, num_structural_witin: usize, - final_mem: &[MemFinalRecord], + _final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { assert_eq!(num_structural_witin, 0); - let mut final_table = RowMajorMatrix::::new( - NVRAM::len(&self.params), - num_witin, - InstancePaddingStrategy::Default, - ); - - final_table - .par_rows_mut() - .zip_eq(final_mem) - .for_each(|(row, rec)| { - if let Some(final_v) = &self.final_v { - if final_v.len() == 1 { - // Assign value directly. - set_val!(row, final_v[0], rec.value as u64); - } else { - // Assign value limbs. - final_v.iter().enumerate().for_each(|(l, limb)| { - let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; - set_val!(row, limb, val as u64); - }); - } - } - set_val!(row, self.final_cycle, rec.cycle); - }); - - Ok([final_table, RowMajorMatrix::empty()]) + Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]) } } @@ -311,8 +283,11 @@ pub struct DynVolatileRamTableConfig DynVolatileRamTableConfig { - pub fn construct_circuit( +impl DynVolatileRamTableConfigTrait + for DynVolatileRamTableConfig +{ + type Config = DynVolatileRamTableConfig; + fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { @@ -385,59 +360,664 @@ impl DynVolatileRamTableConfig } /// TODO consider taking RowMajorMatrix as argument to save allocations. - pub fn assign_instances( - &self, + fn assign_instances( + config: &Self::Config, num_witin: usize, num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - assert!(final_mem.len() <= DVRAM::max_len(&self.params)); - assert!(DVRAM::max_len(&self.params).is_power_of_two()); - - let params = self.params.clone(); - let addr_id = self.addr.id as u64; - let addr_padding_fn = move |row: u64, col: u64| { - assert_eq!(col, addr_id); - DVRAM::addr(¶ms, row as usize) as u64 - }; + if final_mem.is_empty() { + return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); + } - let mut witness = - RowMajorMatrix::::new(final_mem.len(), num_witin, InstancePaddingStrategy::Default); + let num_instances_padded = next_pow2_instance_padding(final_mem.len()); + assert!(num_instances_padded <= DVRAM::max_len(&config.params)); + assert!(DVRAM::max_len(&config.params).is_power_of_two()); + + let mut witness = RowMajorMatrix::::new( + num_instances_padded, + num_witin, + InstancePaddingStrategy::Default, + ); let mut structural_witness = RowMajorMatrix::::new( - final_mem.len(), + num_instances_padded, num_structural_witin, - InstancePaddingStrategy::Custom(Arc::new(addr_padding_fn)), + InstancePaddingStrategy::Default, ); witness .par_rows_mut() - .zip(structural_witness.par_rows_mut()) - .zip(final_mem) + .zip_eq(structural_witness.par_rows_mut()) .enumerate() - .for_each(|(i, ((row, structural_row), rec))| { - assert_eq!( - rec.addr, - DVRAM::addr(&self.params, i), - "rec.addr {:x} != expected {:x}", - rec.addr, - DVRAM::addr(&self.params, i), - ); + .for_each(|(i, (row, structural_row))| { + if cfg!(debug_assertions) + && let Some(addr) = final_mem.get(i).map(|rec| rec.addr) + { + debug_assert_eq!( + addr, + DVRAM::addr(&config.params, i), + "rec.addr {:x} != expected {:x}", + addr, + DVRAM::addr(&config.params, i), + ); + } - if self.final_v.len() == 1 { - // Assign value directly. - set_val!(row, self.final_v[0], rec.value as u64); - } else { - // Assign value limbs. - self.final_v.iter().enumerate().for_each(|(l, limb)| { - let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; - set_val!(row, limb, val as u64); - }); + if let Some(rec) = final_mem.get(i) { + if config.final_v.len() == 1 { + // Assign value directly. + set_val!(row, config.final_v[0], rec.value as u64); + } else { + // Assign value limbs. + config.final_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }); + } + set_val!(row, config.final_cycle, rec.cycle); } - set_val!(row, self.final_cycle, rec.cycle); + set_val!( + structural_row, + config.addr, + DVRAM::addr(&config.params, i) as u64 + ); + }); + + Ok([witness, structural_witness]) + } +} + +/// volatile with all init value as 0 +/// dynamic address as witin, relied on augment of knowledge to prove address form +#[derive(Clone, Debug)] +pub struct DynVolatileRamTableInitConfig { + addr: StructuralWitIn, + + phantom: PhantomData, + params: ProgramParams, +} + +impl DynVolatileRamTableConfigTrait + for DynVolatileRamTableInitConfig +{ + type Config = DynVolatileRamTableInitConfig; + + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + let max_len = DVRAM::max_len(params); + let addr = cb.create_structural_witin( + || "addr", + StructuralWitInType::EqualDistanceSequence { + max_len, + offset: DVRAM::offset_addr(params), + multi_factor: WORD_SIZE, + descending: DVRAM::DESCENDING, + }, + ); + + assert!(DVRAM::ZERO_INIT); + + let init_expr = vec![Expression::ZERO; DVRAM::V_LIMBS]; + + let init_table = [ + vec![(DVRAM::RAM_TYPE as usize).into()], + vec![addr.expr()], + init_expr, + vec![Expression::ZERO], // Initial cycle. + ] + .concat(); + + cb.w_table_record( + || "init_table", + DVRAM::RAM_TYPE, + SetTableSpec { + len: None, + structural_witins: vec![addr], + }, + init_table, + )?; + + Ok(Self { + addr, + phantom: PhantomData, + params: params.clone(), + }) + } + + /// TODO consider taking RowMajorMatrix as argument to save allocations. + fn assign_instances( + config: &Self::Config, + _num_witin: usize, + num_structural_witin: usize, + final_mem: &[MemFinalRecord], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + if final_mem.is_empty() { + return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); + } + + let num_instances_padded = next_pow2_instance_padding(final_mem.len()); + assert!(num_instances_padded <= DVRAM::max_len(&config.params)); + assert!(DVRAM::max_len(&config.params).is_power_of_two()); - set_val!(structural_row, self.addr, rec.addr as u64); + let mut structural_witness = RowMajorMatrix::::new( + num_instances_padded, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + + structural_witness + .par_rows_mut() + .enumerate() + .for_each(|(i, structural_row)| { + if cfg!(debug_assertions) + && let Some(addr) = final_mem.get(i).map(|rec| rec.addr) + { + debug_assert_eq!( + addr, + DVRAM::addr(&config.params, i), + "rec.addr {:x} != expected {:x}", + addr, + DVRAM::addr(&config.params, i), + ); + } + set_val!( + structural_row, + config.addr, + DVRAM::addr(&config.params, i) as u64 + ); }); + Ok([RowMajorMatrix::empty(), structural_witness]) + } +} + +/// This table is generalized version to handle all mmio records +#[derive(Clone, Debug)] +pub struct LocalFinalRAMTableConfig { + addr_subset: WitIn, + ram_type: WitIn, + + final_v: Vec, + final_cycle: WitIn, +} + +impl LocalFinalRAMTableConfig { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let addr_subset = cb.create_witin(|| "addr_subset"); + let ram_type = cb.create_witin(|| "ram_type"); + + let final_v = (0..V_LIMBS) + .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) + .collect::>(); + let final_cycle = cb.create_witin(|| "final_cycle"); + + let final_expr = final_v.iter().map(|v| v.expr()).collect_vec(); + let raw_final_table = [ + // a v t + vec![ram_type.expr()], + vec![addr_subset.expr()], + final_expr, + vec![final_cycle.expr()], + ] + .concat(); + let rlc_record = cb.rlc_chip_record(raw_final_table.clone()); + cb.r_table_rlc_record( + || "final_table", + // XXX we mixed all ram type here to save column allocation + ram_type.expr(), + SetTableSpec { + len: None, + structural_witins: vec![], + }, + raw_final_table, + rlc_record, + )?; + + Ok(Self { + addr_subset, + ram_type, + final_v, + final_cycle, + }) + } + + /// TODO consider taking RowMajorMatrix as argument to save allocations. + pub fn assign_instances( + &self, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + final_mem: &[(InstancePaddingStrategy, &[MemFinalRecord])], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + assert!(num_structural_witin == 0 || num_structural_witin == 1); + let num_structural_witin = num_structural_witin.max(1); + let selector_witin = WitIn { id: 0 }; + + let is_current_shard_mem_record = |record: &&MemFinalRecord| -> bool { + (shard_ctx.is_first_shard() && record.cycle == 0) + || shard_ctx.is_current_shard_cycle(record.cycle) + }; + + // collect each raw mem belong to this shard, BEFORE padding length + let current_shard_mems_len: Vec = final_mem + .par_iter() + .map(|(_, mem)| mem.par_iter().filter(is_current_shard_mem_record).count()) + .collect(); + + // deal with non-pow2 padding for first shard + // format Vec<(pad_len, pad_start_index)> + let padding_info = if shard_ctx.is_first_shard() { + final_mem + .iter() + .map(|(_, mem)| { + assert!(!mem.is_empty()); + ( + next_pow2_instance_padding(mem.len()) - mem.len(), + mem.len(), + mem[0].ram_type, + ) + }) + .collect_vec() + } else { + vec![(0, 0, RAMType::Undefined); final_mem.len()] + }; + + // calculate mem length + let mem_lens = current_shard_mems_len + .iter() + .zip_eq(&padding_info) + .map(|(raw_len, (pad_len, _, _))| raw_len + pad_len) + .collect_vec(); + let total_records = mem_lens.iter().sum(); + + let mut witness = + RowMajorMatrix::::new(total_records, num_witin, InstancePaddingStrategy::Default); + let mut structural_witness = RowMajorMatrix::::new( + total_records, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + + let mut witness_mut_slices = Vec::with_capacity(final_mem.len()); + let mut structural_witness_mut_slices = Vec::with_capacity(final_mem.len()); + let mut witness_value_rest = witness.values.as_mut_slice(); + let mut structural_witness_value_rest = structural_witness.values.as_mut_slice(); + + for mem_len in mem_lens { + let witness_length = mem_len * num_witin; + let structural_witness_length = mem_len * num_structural_witin; + assert!( + witness_length <= witness_value_rest.len(), + "chunk size exceeds remaining data" + ); + assert!( + structural_witness_length <= structural_witness_value_rest.len(), + "chunk size exceeds remaining data" + ); + let (witness_left, witness_r) = witness_value_rest.split_at_mut(witness_length); + let (structural_witness_left, structural_witness_r) = + structural_witness_value_rest.split_at_mut(structural_witness_length); + witness_mut_slices.push(witness_left); + structural_witness_mut_slices.push(structural_witness_left); + witness_value_rest = witness_r; + structural_witness_value_rest = structural_witness_r; + } + + witness_mut_slices + .par_iter_mut() + .zip_eq(structural_witness_mut_slices.par_iter_mut()) + .zip_eq(final_mem.par_iter()) + .zip_eq(padding_info.par_iter()) + .for_each( + |( + ((witness, structural_witness), (padding_strategy, final_mem)), + (pad_size, pad_start_index, ram_type), + )| { + let mem_record_count = witness + .chunks_mut(num_witin) + .zip_eq(structural_witness.chunks_mut(num_structural_witin)) + .zip(final_mem.iter().filter(is_current_shard_mem_record)) + .map(|((row, structural_row), rec)| { + if self.final_v.len() == 1 { + // Assign value directly. + set_val!(row, self.final_v[0], rec.value as u64); + } else { + // Assign value limbs. + self.final_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }); + } + set_val!(row, self.final_cycle, rec.cycle); + + set_val!(row, self.ram_type, rec.ram_type as u64); + set_val!(row, self.addr_subset, rec.addr as u64); + set_val!(structural_row, selector_witin, 1u64); + }) + .count(); + + if *pad_size > 0 && shard_ctx.is_first_shard() { + match padding_strategy { + InstancePaddingStrategy::Custom(pad_func) => { + witness[mem_record_count * num_witin..] + .chunks_mut(num_witin) + .zip_eq( + structural_witness + [mem_record_count * num_structural_witin..] + .chunks_mut(num_structural_witin), + ) + .zip_eq( + std::iter::successors(Some(*pad_start_index), |n| { + Some(*n + 1) + }) + .take(*pad_size), + ) + .for_each(|((row, structural_row), pad_index)| { + set_val!( + row, + self.addr_subset, + pad_func(pad_index as u64, self.addr_subset.id as u64) + ); + set_val!(row, self.ram_type, *ram_type as u64); + set_val!(structural_row, selector_witin, 1u64); + }); + } + _ => unimplemented!(), + } + } + }, + ); + + Ok([witness, structural_witness]) + } +} + +/// The general config to handle ram bus across all records +#[derive(Clone, Debug)] +pub struct RAMBusConfig { + addr_subset: WitIn, + + sel_read: StructuralWitIn, + sel_write: StructuralWitIn, + local_write_v: Vec, + local_read_v: Vec, + local_read_cycle: WitIn, +} + +impl RAMBusConfig { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let ram_type = cb.create_witin(|| "ram_type"); + let one = Expression::Constant(Either::Left(E::BaseField::ONE)); + let addr_subset = cb.create_witin(|| "addr_subset"); + // TODO add new selector to support sel_rw + let sel_read = cb.create_structural_witin( + || "sel_read", + StructuralWitInType::EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: WORD_SIZE, + descending: false, + }, + ); + let sel_write = cb.create_structural_witin( + || "sel_write", + StructuralWitInType::EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: WORD_SIZE, + descending: false, + }, + ); + + // local write + let local_write_v = (0..V_LIMBS) + .map(|i| cb.create_witin(|| format!("local_write_v_limb_{i}"))) + .collect::>(); + let local_write_v_expr = local_write_v.iter().map(|v| v.expr()).collect_vec(); + + // local read + let local_read_v = (0..V_LIMBS) + .map(|i| cb.create_witin(|| format!("local_read_v_limb_{i}"))) + .collect::>(); + let local_read_v_expr: Vec> = + local_read_v.iter().map(|v| v.expr()).collect_vec(); + let local_read_cycle = cb.create_witin(|| "local_read_cycle"); + + // TODO global write + // TODO global read + + // constraints + // read from global, write to local + // W_{local} = sel_read * local_write_record + (1 - sel_read) * ONE + let local_raw_write_record = [ + vec![ram_type.expr()], + vec![addr_subset.expr()], + local_write_v_expr.clone(), + vec![Expression::ZERO], // mem bus local init cycle always 0. + ] + .concat(); + let local_write_record = cb.rlc_chip_record(local_raw_write_record.clone()); + let local_write = + sel_read.expr() * local_write_record + (one.clone() - sel_read.expr()).expr(); + // local write, global read + cb.w_table_rlc_record( + || "local_write_record", + ram_type.expr(), + SetTableSpec { + len: None, + structural_witins: vec![sel_read], + }, + local_raw_write_record, + local_write, + )?; + // TODO R_{global} = mem_bus_with_read * (sel_read * global_read + (1-sel_read) * EC_INFINITY) + (1 - mem_bus_with_read) * EC_INFINITY + + // write to global, read from local + // R_{local} = sel_write * local_read_record + (1 - sel_write) * ONE + let local_raw_read_record = [ + vec![ram_type.expr()], + vec![addr_subset.expr()], + local_read_v_expr.clone(), + vec![local_read_cycle.expr()], + ] + .concat(); + let local_read_record = cb.rlc_chip_record(local_raw_read_record.clone()); + let local_read: Expression = + sel_write.expr() * local_read_record + (one.clone() - sel_write.expr()); + + // local read, global write + cb.r_table_rlc_record( + || "local_read_record", + ram_type.expr(), + SetTableSpec { + len: None, + structural_witins: vec![sel_write], + }, + local_raw_read_record, + local_read, + )?; + // TODO W_{local} = mem_bus_with_write * (sel_write * global_write + (1 - sel_write) * EC_INFINITY) + (1 - mem_bus_with_write) * EC_INFINITY + + Ok(Self { + addr_subset, + sel_write, + sel_read, + local_write_v, + local_read_v, + local_read_cycle, + }) + } + + /// TODO consider taking RowMajorMatrix as argument to save allocations. + pub fn assign_instances( + &self, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + let (global_read_records, global_write_records) = + (shard_ctx.read_records(), shard_ctx.write_records()); + assert_eq!(global_read_records.len(), global_write_records.len()); + let raw_write_len: usize = global_write_records.iter().map(|m| m.len()).sum(); + let raw_read_len: usize = global_read_records.iter().map(|m| m.len()).sum(); + if raw_read_len + raw_write_len == 0 { + return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); + } + // TODO refactor to deal with only read/write + + let witness_length = { + let max_len = raw_read_len.max(raw_write_len); + // first half write, second half read + next_pow2_instance_padding(max_len) * 2 + }; + let mut witness = + RowMajorMatrix::::new(witness_length, num_witin, InstancePaddingStrategy::Default); + let mut structural_witness = RowMajorMatrix::::new( + witness_length, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + let witness_mid = witness.values.len() / 2; + let (witness_write, witness_read) = witness.values.split_at_mut(witness_mid); + let structural_witness_mid = structural_witness.values.len() / 2; + let (structural_witness_write, structural_witness_read) = structural_witness + .values + .split_at_mut(structural_witness_mid); + + let mut witness_write_mut_slices = Vec::with_capacity(global_write_records.len()); + let mut witness_read_mut_slices = Vec::with_capacity(global_read_records.len()); + let mut structural_witness_write_mut_slices = + Vec::with_capacity(global_write_records.len()); + let mut structural_witness_read_mut_slices = Vec::with_capacity(global_read_records.len()); + let mut witness_write_value_rest = witness_write; + let mut witness_read_value_rest = witness_read; + let mut structural_witness_write_value_rest = structural_witness_write; + let mut structural_witness_read_value_rest = structural_witness_read; + + for (global_read_record, global_write_record) in + global_read_records.iter().zip_eq(global_write_records) + { + let witness_write_length = global_write_record.len() * num_witin; + let witness_read_length = global_read_record.len() * num_witin; + let structural_witness_write_length = global_write_record.len() * num_structural_witin; + let structural_witness_read_length = global_read_record.len() * num_structural_witin; + assert!( + witness_write_length <= witness_write_value_rest.len(), + "chunk size exceeds remaining data" + ); + assert!( + witness_read_length <= witness_read_value_rest.len(), + "chunk size exceeds remaining data" + ); + assert!( + structural_witness_write_length <= structural_witness_write_value_rest.len(), + "chunk size exceeds remaining data" + ); + assert!( + structural_witness_read_length <= structural_witness_read_value_rest.len(), + "chunk size exceeds remaining data" + ); + let (witness_write, witness_write_r) = + witness_write_value_rest.split_at_mut(witness_write_length); + witness_write_mut_slices.push(witness_write); + witness_write_value_rest = witness_write_r; + + let (witness_read, witness_read_r) = + witness_read_value_rest.split_at_mut(witness_read_length); + witness_read_mut_slices.push(witness_read); + witness_read_value_rest = witness_read_r; + + let (structural_witness_write, structural_witness_write_r) = + structural_witness_write_value_rest.split_at_mut(structural_witness_write_length); + structural_witness_write_mut_slices.push(structural_witness_write); + structural_witness_write_value_rest = structural_witness_write_r; + + let (structural_witness_read, structural_witness_read_r) = + structural_witness_read_value_rest.split_at_mut(structural_witness_read_length); + structural_witness_read_mut_slices.push(structural_witness_read); + structural_witness_read_value_rest = structural_witness_read_r; + } + + rayon::join( + // global write, local read + || { + witness_write_mut_slices + .par_iter_mut() + .zip_eq(structural_witness_write_mut_slices.par_iter_mut()) + .zip_eq(global_write_records.par_iter()) + .for_each( + |((witness_write, structural_witness_write), global_write_mem)| { + witness_write + .chunks_mut(num_witin) + .zip_eq(structural_witness_write.chunks_mut(num_structural_witin)) + .zip_eq(global_write_mem.values()) + .for_each(|((row, structural_row), rec)| { + if self.local_read_v.len() == 1 { + // Assign value directly. + set_val!(row, self.local_read_v[0], rec.value as u64); + } else { + // Assign value limbs. + self.local_read_v.iter().enumerate().for_each( + |(l, limb)| { + let val = + (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }, + ); + } + set_val!(row, self.local_read_cycle, rec.cycle); + + set_val!(row, self.addr_subset, rec.addr.baddr().0 as u64); + set_val!(structural_row, self.sel_write, 1u64); + + // TODO assign W_{global} + }); + }, + ); + }, + // global read, local write + || { + witness_read_mut_slices + .par_iter_mut() + .zip_eq(structural_witness_read_mut_slices.par_iter_mut()) + .zip_eq(global_read_records.par_iter()) + .for_each( + |((witness_read, structural_witness_read), global_read_mem)| { + witness_read + .chunks_mut(num_witin) + .zip_eq(structural_witness_read.chunks_mut(num_structural_witin)) + .zip_eq(global_read_mem.values()) + .for_each(|((row, structural_row), rec)| { + if self.local_write_v.len() == 1 { + // Assign value directly. + set_val!(row, self.local_write_v[0], rec.value as u64); + } else { + // Assign value limbs. + self.local_write_v.iter().enumerate().for_each( + |(l, limb)| { + let val = + (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }, + ); + } + set_val!(row, self.addr_subset, rec.addr.baddr().0 as u64); + set_val!(structural_row, self.sel_read, 1u64); + + // TODO assign R_{global} + }); + }, + ); + }, + ); + structural_witness.padding_by_strategy(); Ok([witness, structural_witness]) } @@ -456,6 +1036,7 @@ mod tests { use ceno_emul::WORD_SIZE; use ff_ext::GoldilocksExt2 as E; + use gkr_iop::RAMType; use itertools::Itertools; use multilinear_extensions::mle::MultilinearExtension; use p3::{field::FieldAlgebra, goldilocks::Goldilocks as F}; @@ -474,6 +1055,7 @@ mod tests { let some_non_2_pow = 26; let input = (0..some_non_2_pow) .map(|i| MemFinalRecord { + ram_type: RAMType::Memory, addr: HintsTable::addr(&def_params, i), cycle: 0, value: 0, diff --git a/gkr_iop/src/chip.rs b/gkr_iop/src/chip.rs index 1b33bb1de..10048418e 100644 --- a/gkr_iop/src/chip.rs +++ b/gkr_iop/src/chip.rs @@ -40,11 +40,17 @@ impl Chip { n_evaluations: cb.cs.w_expressions.len() + cb.cs.r_expressions.len() + cb.cs.lk_expressions.len() + + cb.cs.w_table_expressions.len() + + cb.cs.r_table_expressions.len() + + cb.cs.lk_table_expressions.len() * 2 + cb.cs.num_fixed + cb.cs.num_witin as usize, final_out_evals: (0..cb.cs.w_expressions.len() + cb.cs.r_expressions.len() - + cb.cs.lk_expressions.len()) + + cb.cs.lk_expressions.len() + + cb.cs.w_table_expressions.len() + + cb.cs.r_table_expressions.len() + + cb.cs.lk_table_expressions.len() * 2) .collect_vec(), layers: vec![], } diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index e4129bfe8..395b9e6c9 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -108,14 +108,14 @@ pub struct ConstraintSystem { pub r_expressions_namespace_map: Vec, // for each read expression we store its ram type and original value before doing RLC // the original value will be used for debugging - pub r_ram_types: Vec<(RAMType, Vec>)>, + pub r_ram_types: Vec<(Expression, Vec>)>, pub w_selector: Option>, pub w_expressions: Vec>, pub w_expressions_namespace_map: Vec, // for each write expression we store its ram type and original value before doing RLC // the original value will be used for debugging - pub w_ram_types: Vec<(RAMType, Vec>)>, + pub w_ram_types: Vec<(Expression, Vec>)>, /// init/final ram expression pub r_table_expressions: Vec>, @@ -329,12 +329,27 @@ impl ConstraintSystem { N: FnOnce() -> NR, { let rlc_record = self.rlc_chip_record(record.clone()); - assert_eq!( - rlc_record.degree(), - 1, - "rlc record degree {} != 1", - rlc_record.degree() - ); + self.r_table_rlc_record( + name_fn, + (ram_type as u64).into(), + table_spec, + record, + rlc_record, + ) + } + + pub fn r_table_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { self.r_table_expressions.push(SetTableExpression { expr: rlc_record, table_spec, @@ -358,12 +373,27 @@ impl ConstraintSystem { N: FnOnce() -> NR, { let rlc_record = self.rlc_chip_record(record.clone()); - assert_eq!( - rlc_record.degree(), - 1, - "rlc record degree {} != 1", - rlc_record.degree() - ); + self.w_table_rlc_record( + name_fn, + (ram_type as u64).into(), + table_spec, + record, + rlc_record, + ) + } + + pub fn w_table_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { self.w_table_expressions.push(SetTableExpression { expr: rlc_record, table_spec, @@ -387,7 +417,7 @@ impl ConstraintSystem { self.r_expressions_namespace_map.push(path); // Since r_expression is RLC(record) and when we're debugging // it's helpful to recover the value of record itself. - self.r_ram_types.push((ram_type, record)); + self.r_ram_types.push(((ram_type as u64).into(), record)); Ok(()) } @@ -401,7 +431,7 @@ impl ConstraintSystem { self.w_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); self.w_expressions_namespace_map.push(path); - self.w_ram_types.push((ram_type, record)); + self.w_ram_types.push(((ram_type as u64).into(), record)); Ok(()) } @@ -579,6 +609,22 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { .r_table_record(name_fn, ram_type, table_spec, record) } + pub fn r_table_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .r_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record) + } + pub fn w_table_record( &mut self, name_fn: N, @@ -594,6 +640,22 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { .w_table_record(name_fn, ram_type, table_spec, record) } + pub fn w_table_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .w_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record) + } + pub fn read_record( &mut self, name_fn: N, diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index a337dde30..6bd76af68 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -1,3 +1,4 @@ +use either::Either; use ff_ext::ExtensionField; use itertools::{Itertools, chain, izip}; use linear_layer::{LayerClaims, LinearLayer}; @@ -319,9 +320,9 @@ impl Layer { n_challenges: usize, out_evals: OutEvalGroups, ) -> Layer { - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); + let w_len = cb.cs.w_expressions.len() + cb.cs.w_table_expressions.len(); + let r_len = cb.cs.r_expressions.len() + cb.cs.r_table_expressions.len(); + let lk_len = cb.cs.lk_expressions.len() + cb.cs.lk_table_expressions.len() * 2; // logup lk table include p, q let zero_len = cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); @@ -331,9 +332,12 @@ impl Layer { assert_eq!(lookup_evals.len(), lk_len); assert_eq!(zero_evals.len(), zero_len); - let non_zero_expr_len = cb.cs.w_expressions_namespace_map.len() - + cb.cs.r_expressions_namespace_map.len() - + cb.cs.lk_expressions.len(); + let non_zero_expr_len = cb.cs.w_expressions.len() + + cb.cs.w_table_expressions.len() + + cb.cs.r_expressions.len() + + cb.cs.r_table_expressions.len() + + cb.cs.lk_expressions.len() + + cb.cs.lk_table_expressions.len() * 2; let zero_expr_len = cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); @@ -341,88 +345,116 @@ impl Layer { let mut expr_names = Vec::with_capacity(non_zero_expr_len + zero_expr_len); let mut expressions = Vec::with_capacity(non_zero_expr_len + zero_expr_len); - // process r_record - let evals = - Self::dedup_last_selector_evals(cb.cs.r_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, ((ram_expr, name), ram_eval)) in cb - .cs - .r_expressions - .iter() - .zip_eq(&cb.cs.r_expressions_namespace_map) + if let Some(r_selector) = cb.cs.r_selector.as_ref() { + // process r_record + let evals = Self::dedup_last_selector_evals(r_selector, &mut expr_evals); + for (idx, ((ram_expr, name), ram_eval)) in (cb + .cs + .r_expressions + .iter() + .chain(cb.cs.r_table_expressions.iter().map(|t| &t.expr))) + .zip_eq( + cb.cs + .r_expressions_namespace_map + .iter() + .chain(&cb.cs.r_table_expressions_namespace_map), + ) .zip_eq(&r_record_evals) .enumerate() - { - expressions.push(ram_expr - E::BaseField::ONE.expr()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - one (padding) - *ram_eval, - E::BaseField::ONE.expr().into(), - E::BaseField::ONE.neg().expr().into(), - )); - expr_names.push(format!("{}/{idx}", name)); + { + expressions.push(ram_expr - E::BaseField::ONE.expr()); + evals.push(EvalExpression::::Linear( + // evaluation = claim * one - one (padding) + *ram_eval, + E::BaseField::ONE.expr().into(), + E::BaseField::ONE.neg().expr().into(), + )); + expr_names.push(format!("{}/{idx}", name)); + } } - // process w_record - let evals = - Self::dedup_last_selector_evals(cb.cs.w_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, ((ram_expr, name), ram_eval)) in cb - .cs - .w_expressions - .iter() - .zip_eq(&cb.cs.w_expressions_namespace_map) + if let Some(w_selector) = cb.cs.w_selector.as_ref() { + // process w_record + let evals = Self::dedup_last_selector_evals(w_selector, &mut expr_evals); + for (idx, ((ram_expr, name), ram_eval)) in (cb + .cs + .w_expressions + .iter() + .chain(cb.cs.w_table_expressions.iter().map(|t| &t.expr))) + .zip_eq( + cb.cs + .w_expressions_namespace_map + .iter() + .chain(&cb.cs.w_table_expressions_namespace_map), + ) .zip_eq(&w_record_evals) .enumerate() - { - expressions.push(ram_expr - E::BaseField::ONE.expr()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - one (padding) - *ram_eval, - E::BaseField::ONE.expr().into(), - E::BaseField::ONE.neg().expr().into(), - )); - expr_names.push(format!("{}/{idx}", name)); + { + expressions.push(ram_expr - E::BaseField::ONE.expr()); + evals.push(EvalExpression::::Linear( + // evaluation = claim * one - one (padding) + *ram_eval, + E::BaseField::ONE.expr().into(), + E::BaseField::ONE.neg().expr().into(), + )); + expr_names.push(format!("{}/{idx}", name)); + } } - // process lookup records - let evals = - Self::dedup_last_selector_evals(cb.cs.lk_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, ((lookup, name), lookup_eval)) in cb - .cs - .lk_expressions - .iter() - .zip_eq(&cb.cs.lk_expressions_namespace_map) + if let Some(lk_selector) = cb.cs.lk_selector.as_ref() { + // process lookup records + let evals = Self::dedup_last_selector_evals(lk_selector, &mut expr_evals); + for (idx, ((lookup, name), lookup_eval)) in (cb + .cs + .lk_expressions + .iter() + .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.multiplicity)) + .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.values))) + .zip_eq(if cb.cs.lk_table_expressions.is_empty() { + Either::Left(cb.cs.lk_expressions_namespace_map.iter()) + } else { + // repeat expressions_namespace_map twice to deal with lk p, q + Either::Right( + cb.cs + .lk_expressions_namespace_map + .iter() + .chain(&cb.cs.lk_expressions_namespace_map), + ) + }) .zip_eq(&lookup_evals) .enumerate() - { - expressions.push(lookup - cb.cs.chip_record_alpha.clone()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - alpha (padding) - *lookup_eval, - E::BaseField::ONE.expr().into(), - cb.cs.chip_record_alpha.clone().neg().into(), - )); - expr_names.push(format!("{}/{idx}", name)); + { + expressions.push(lookup - cb.cs.chip_record_alpha.clone()); + evals.push(EvalExpression::::Linear( + // evaluation = claim * one - alpha (padding) + *lookup_eval, + E::BaseField::ONE.expr().into(), + cb.cs.chip_record_alpha.clone().neg().into(), + )); + expr_names.push(format!("{}/{idx}", name)); + } } - // process zero_record - let evals = - Self::dedup_last_selector_evals(cb.cs.zero_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, (zero_expr, name)) in izip!( - 0.., - chain!( - cb.cs - .assert_zero_expressions - .iter() - .zip_eq(&cb.cs.assert_zero_expressions_namespace_map), - cb.cs - .assert_zero_sumcheck_expressions - .iter() - .zip_eq(&cb.cs.assert_zero_sumcheck_expressions_namespace_map) - ) - ) { - expressions.push(zero_expr.clone()); - evals.push(EvalExpression::Zero); - expr_names.push(format!("{}/{idx}", name)); + if let Some(zero_selector) = cb.cs.zero_selector.as_ref() { + // process zero_record + let evals = Self::dedup_last_selector_evals(zero_selector, &mut expr_evals); + for (idx, (zero_expr, name)) in izip!( + 0.., + chain!( + cb.cs + .assert_zero_expressions + .iter() + .zip_eq(&cb.cs.assert_zero_expressions_namespace_map), + cb.cs + .assert_zero_sumcheck_expressions + .iter() + .zip_eq(&cb.cs.assert_zero_sumcheck_expressions_namespace_map) + ) + ) { + expressions.push(zero_expr.clone()); + evals.push(EvalExpression::Zero); + expr_names.push(format!("{}/{idx}", name)); + } } // Sort expressions, expr_names, and evals according to eval.0 and classify evals. diff --git a/gkr_iop/src/gkr/layer/cpu/mod.rs b/gkr_iop/src/gkr/layer/cpu/mod.rs index fa4c33c5e..95d315f25 100644 --- a/gkr_iop/src/gkr/layer/cpu/mod.rs +++ b/gkr_iop/src/gkr/layer/cpu/mod.rs @@ -168,6 +168,7 @@ impl> ZerocheckLayerProver ) ) .collect_vec(); + // zero check eq || rotation eq let mut eqs = layer .out_sel_and_eval_exprs @@ -221,15 +222,16 @@ impl> ZerocheckLayerProver layer.n_structural_witin, layer.n_fixed, ); + let builder = VirtualPolynomialsBuilder::new_with_mles(num_threads, max_num_variables, all_witins); let span = entered_span!("IOPProverState::prove", profiling_4 = true); let (proof, prover_state) = IOPProverState::prove( builder.to_virtual_polys_with_monomial_terms( - &layer + layer .main_sumcheck_expression_monomial_terms - .clone() + .as_ref() .unwrap(), pub_io_evals, &main_sumcheck_challenges, diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index fc69037ff..a5e20f704 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -7,6 +7,7 @@ use either::Either; use ff_ext::ExtensionField; use multilinear_extensions::{Expression, impl_expr_from_unsigned, mle::ArcMultilinearExtension}; use std::marker::PhantomData; +use strum_macros::EnumIter; use transcript::Transcript; use witness::RowMajorMatrix; @@ -77,12 +78,13 @@ pub struct ProtocolVerifier, PCS>( PhantomData<(E, Trans, PCS)>, ); -#[derive(Clone, Debug, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, Copy, EnumIter, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[repr(usize)] pub enum RAMType { - GlobalState, + GlobalState = 0, Register, Memory, + Undefined, } impl_expr_from_unsigned!(RAMType);