Skip to content
Draft
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
structs::{ChallengeId, RAMType, WitnessId},
uint::util::SimpleVecPool,
};

#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
Expand Down Expand Up @@ -141,6 +142,74 @@ impl<E: ExtensionField> Expression<E> {
}
}

#[allow(clippy::too_many_arguments)]
pub fn evaluate_with_instance_pool<T>(
&self,
fixed_in: &impl Fn(&Fixed) -> T,
wit_in: &impl Fn(WitnessId) -> T, // witin id
instance: &impl Fn(Instance) -> T,
constant: &impl Fn(E::BaseField) -> T,
challenge: &impl Fn(ChallengeId, usize, E, E) -> T,
sum: &impl Fn(T, T, &mut SimpleVecPool<Vec<E>>, &mut SimpleVecPool<Vec<E::BaseField>>) -> T,
product: &impl Fn(T, T, &mut SimpleVecPool<Vec<E>>, &mut SimpleVecPool<Vec<E::BaseField>>) -> T,
scaled: &impl Fn(
T,
T,
T,
&mut SimpleVecPool<Vec<E>>,
&mut SimpleVecPool<Vec<E::BaseField>>,
) -> T,
pool_e: &mut SimpleVecPool<Vec<E>>,
pool_b: &mut SimpleVecPool<Vec<E::BaseField>>,
) -> T {
match self {
Expression::Fixed(f) => fixed_in(f),
Expression::WitIn(witness_id) => wit_in(*witness_id),
Expression::Instance(i) => instance(*i),
Expression::Constant(scalar) => constant(*scalar),
Expression::Sum(a, b) => {
let a = a.evaluate_with_instance_pool(
fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e,
pool_b,
);
let b = b.evaluate_with_instance_pool(
fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e,
pool_b,
);
sum(a, b, pool_e, pool_b)
}
Expression::Product(a, b) => {
let a = a.evaluate_with_instance_pool(
fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e,
pool_b,
);
let b = b.evaluate_with_instance_pool(
fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e,
pool_b,
);
product(a, b, pool_e, pool_b)
}
Expression::ScaledSum(x, a, b) => {
let x = x.evaluate_with_instance_pool(
fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e,
pool_b,
);
let a = a.evaluate_with_instance_pool(
fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e,
pool_b,
);
let b = b.evaluate_with_instance_pool(
fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e,
pool_b,
);
scaled(x, a, b, pool_e, pool_b)
}
Expression::Challenge(challenge_id, pow, scalar, offset) => {
challenge(*challenge_id, *pow, *scalar, *offset)
}
}
}

pub fn is_monomial_form(&self) -> bool {
Self::is_monomial_form_inner(MonomialState::SumTerm, self)
}
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#![feature(stmt_expr_attributes)]
#![feature(variant_count)]
#![feature(strict_overflow_ops)]
#![feature(sync_unsafe_cell)]

pub mod error;
pub mod instructions;
Expand Down
77 changes: 55 additions & 22 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ use ff_ext::ExtensionField;
use generic_static::StaticTypeMap;
use goldilocks::SmallField;
use itertools::{Itertools, enumerate, izip};
use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension};
use multilinear_extensions::{
mle::IntoMLEs, util::max_usable_threads, virtual_poly_v2::ArcMultilinearExtension,
};
use rand::thread_rng;
use std::{
collections::{HashMap, HashSet},
Expand Down Expand Up @@ -426,6 +428,7 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
challenge: Option<[E; 2]>,
lkm: Option<LkMultiplicity>,
) -> Result<(), Vec<MockProverError<E>>> {
let n_threads = max_usable_threads();
let program = Program::new(
CENO_PLATFORM.pc_base(),
CENO_PLATFORM.pc_base(),
Expand Down Expand Up @@ -473,10 +476,12 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
let (left, right) = expr.unpack_sum().unwrap();
let right = right.neg();

let left_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &left);
let left_evaluated =
wit_infer_by_expr(&[], wits_in, pi, &challenge, &left, n_threads);
let left_evaluated = left_evaluated.get_base_field_vec();

let right_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &right);
let right_evaluated =
wit_infer_by_expr(&[], wits_in, pi, &challenge, &right, n_threads);
let right_evaluated = right_evaluated.get_base_field_vec();

// left_evaluated.len() ?= right_evaluated.len() due to padding instance
Expand All @@ -496,7 +501,8 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
}
} else {
// contains require_zero
let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr);
let expr_evaluated =
wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads);
let expr_evaluated = expr_evaluated.get_base_field_vec();

for (inst_id, element) in enumerate(expr_evaluated) {
Expand All @@ -519,7 +525,7 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
.iter()
.zip_eq(cb.cs.lk_expressions_namespace_map.iter())
{
let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr);
let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads);
let expr_evaluated = expr_evaluated.get_ext_field_vec();

// Check each lookup expr exists in t vec
Expand Down Expand Up @@ -550,7 +556,7 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
.map(|expr| {
// TODO generalized to all inst_id
let inst_id = 0;
wit_infer_by_expr(&[], wits_in, pi, &challenge, expr)
wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads)
.get_base_field_vec()[inst_id]
.to_canonical_u64()
})
Expand Down Expand Up @@ -742,6 +748,7 @@ Hints:
witnesses: &ZKVMWitnesses<E>,
pi: &PublicValues<u32>,
) {
let n_threads = max_usable_threads();
let instance = pi
.to_vec::<E>()
.concat()
Expand Down Expand Up @@ -815,10 +822,16 @@ Hints:
.zip(cs.lk_expressions_namespace_map.clone().into_iter())
.zip(cs.lk_expressions_items_map.clone().into_iter())
{
let lk_input =
(wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, expr)
.get_ext_field_vec())[..num_rows]
.to_vec();
let lk_input = (wit_infer_by_expr(
&fixed,
&witness,
&pi_mles,
&challenges,
expr,
n_threads,
)
.get_ext_field_vec())[..num_rows]
.to_vec();
rom_inputs.entry(rom_type).or_default().push((
lk_input,
circuit_name.clone(),
Expand All @@ -838,17 +851,24 @@ Hints:
.iter()
.zip(cs.lk_expressions_items_map.clone().into_iter())
{
let lk_table =
wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, &expr.values)
.get_ext_field_vec()
.to_vec();
let lk_table = wit_infer_by_expr(
&fixed,
&witness,
&pi_mles,
&challenges,
&expr.values,
n_threads,
)
.get_ext_field_vec()
.to_vec();

let multiplicity = wit_infer_by_expr(
&fixed,
&witness,
&pi_mles,
&challenges,
&expr.multiplicity,
n_threads,
)
.get_base_field_vec()
.to_vec();
Expand Down Expand Up @@ -968,10 +988,16 @@ Hints:
.zip_eq(cs.w_ram_types.iter())
.filter(|((_, _), (ram_type, _))| *ram_type == $ram_type)
{
let write_rlc_records =
(wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, w_rlc_expr)
.get_ext_field_vec())[..*num_rows]
.to_vec();
let write_rlc_records = (wit_infer_by_expr(
fixed,
witness,
&pi_mles,
&challenges,
w_rlc_expr,
n_threads,
)
.get_ext_field_vec())[..*num_rows]
.to_vec();

if $ram_type == RAMType::GlobalState {
// w_exprs = [GlobalState, pc, timestamp]
Expand All @@ -986,6 +1012,7 @@ Hints:
&pi_mles,
&challenges,
expr,
n_threads,
);
v.get_base_field_vec()[..*num_rows].to_vec()
})
Expand Down Expand Up @@ -1030,10 +1057,16 @@ Hints:
.zip_eq(cs.r_ram_types.iter())
.filter(|((_, _), (ram_type, _))| *ram_type == $ram_type)
{
let read_records =
wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, r_expr)
.get_ext_field_vec()[..*num_rows]
.to_vec();
let read_records = wit_infer_by_expr(
fixed,
witness,
&pi_mles,
&challenges,
r_expr,
n_threads,
)
.get_ext_field_vec()[..*num_rows]
.to_vec();
let mut records = vec![];
for (row, record) in enumerate(read_records) {
// TODO: return error
Expand Down
16 changes: 9 additions & 7 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use itertools::{Itertools, enumerate, izip};
use mpcs::PolynomialCommitmentScheme;
use multilinear_extensions::{
mle::{IntoMLE, MultilinearExtension},
util::ceil_log2,
util::{ceil_log2, max_usable_threads},
virtual_poly::build_eq_x_r_vec,
virtual_poly_v2::ArcMultilinearExtension,
};
Expand Down Expand Up @@ -238,14 +238,15 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
let wit_inference_span = entered_span!("wit_inference", profiling_3 = true);
// main constraint: read/write record witness inference
let record_span = entered_span!("record");
let n_threads = max_usable_threads();
let records_wit: Vec<ArcMultilinearExtension<'_, E>> = cs
.r_expressions
.par_iter()
.chain(cs.w_expressions.par_iter())
.chain(cs.lk_expressions.par_iter())
.iter()
.chain(cs.w_expressions.iter())
.chain(cs.lk_expressions.iter())
.map(|expr| {
assert_eq!(expr.degree(), 1);
wit_infer_by_expr(&[], &witnesses, pi, challenges, expr)
wit_infer_by_expr(&[], &witnesses, pi, challenges, expr, n_threads)
})
.collect();
let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len());
Expand Down Expand Up @@ -525,7 +526,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
// sanity check in debug build and output != instance index for zero check sumcheck poly
if cfg!(debug_assertions) {
let expected_zero_poly =
wit_infer_by_expr(&[], &witnesses, pi, challenges, expr);
wit_infer_by_expr(&[], &witnesses, pi, challenges, expr, n_threads);
let top_100_errors = expected_zero_poly
.get_base_field_vec()
.iter()
Expand Down Expand Up @@ -701,6 +702,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
let wit_inference_span = entered_span!("wit_inference");
// main constraint: lookup denominator and numerator record witness inference
let record_span = entered_span!("record");
let n_threads = max_usable_threads();
let mut records_wit: Vec<ArcMultilinearExtension<'_, E>> = cs
.r_table_expressions
.par_iter()
Expand All @@ -714,7 +716,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.chain(cs.lk_table_expressions.par_iter().map(|lk| &lk.values))
.map(|expr| {
assert_eq!(expr.degree(), 1);
wit_infer_by_expr(&fixed, &witnesses, pi, challenges, expr)
wit_infer_by_expr(&fixed, &witnesses, pi, challenges, expr, n_threads)
})
.collect();
let max_log2_num_instance = records_wit.iter().map(|mle| mle.num_vars()).max().unwrap();
Expand Down
Loading
Loading