diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 29b800303..507507397 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -143,24 +143,34 @@ impl Expression { } #[allow(clippy::too_many_arguments)] - pub fn evaluate_with_instance_pool( + pub fn evaluate_with_instance_pool Vec, PF2: Fn() -> Vec>( &self, fixed_in: &impl Fn(&Fixed) -> T, wit_in: &impl Fn(WitnessId) -> T, // witin id instance: &impl Fn(Instance) -> T, constant: &impl Fn(E::BaseField) -> T, challenge: &impl Fn(ChallengeId, usize, E, E) -> T, - sum: &impl Fn(T, T, &mut SimpleVecPool>, &mut SimpleVecPool>) -> T, - product: &impl Fn(T, T, &mut SimpleVecPool>, &mut SimpleVecPool>) -> T, + sum: &impl Fn( + T, + T, + &mut SimpleVecPool, PF1>, + &mut SimpleVecPool, PF2>, + ) -> T, + product: &impl Fn( + T, + T, + &mut SimpleVecPool, PF1>, + &mut SimpleVecPool, PF2>, + ) -> T, scaled: &impl Fn( T, T, T, - &mut SimpleVecPool>, - &mut SimpleVecPool>, + &mut SimpleVecPool, PF1>, + &mut SimpleVecPool, PF2>, ) -> T, - pool_e: &mut SimpleVecPool>, - pool_b: &mut SimpleVecPool>, + pool_e: &mut SimpleVecPool, PF1>, + pool_b: &mut SimpleVecPool, PF2>, ) -> T { match self { Expression::Fixed(f) => fixed_in(f), diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index e11ebdee5..037f8a130 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -23,9 +23,7 @@ use ff_ext::ExtensionField; use generic_static::StaticTypeMap; use goldilocks::SmallField; use itertools::{Itertools, enumerate, izip}; -use multilinear_extensions::{ - mle::IntoMLEs, util::max_usable_threads, virtual_poly_v2::ArcMultilinearExtension, -}; +use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension}; use rand::thread_rng; use std::{ collections::{HashMap, HashSet}, @@ -428,7 +426,6 @@ impl<'a, E: ExtensionField + Hash> MockProver { challenge: Option<[E; 2]>, lkm: Option, ) -> Result<(), Vec>> { - let n_threads = max_usable_threads(); let program = Program::new( CENO_PLATFORM.pc_base(), CENO_PLATFORM.pc_base(), @@ -476,12 +473,10 @@ impl<'a, E: ExtensionField + Hash> MockProver { let (left, right) = expr.unpack_sum().unwrap(); let right = right.neg(); - let left_evaluated = - wit_infer_by_expr(&[], wits_in, pi, &challenge, &left, n_threads); + let left_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &left); let left_evaluated = left_evaluated.get_base_field_vec(); - let right_evaluated = - wit_infer_by_expr(&[], wits_in, pi, &challenge, &right, n_threads); + let right_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &right); let right_evaluated = right_evaluated.get_base_field_vec(); // left_evaluated.len() ?= right_evaluated.len() due to padding instance @@ -501,8 +496,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { } } else { // contains require_zero - let expr_evaluated = - wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads); + let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr); let expr_evaluated = expr_evaluated.get_base_field_vec(); for (inst_id, element) in enumerate(expr_evaluated) { @@ -525,7 +519,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { .iter() .zip_eq(cb.cs.lk_expressions_namespace_map.iter()) { - let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads); + let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr); let expr_evaluated = expr_evaluated.get_ext_field_vec(); // Check each lookup expr exists in t vec @@ -556,7 +550,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { .map(|expr| { // TODO generalized to all inst_id let inst_id = 0; - wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads) + wit_infer_by_expr(&[], wits_in, pi, &challenge, expr) .get_base_field_vec()[inst_id] .to_canonical_u64() }) @@ -748,7 +742,6 @@ Hints: witnesses: &ZKVMWitnesses, pi: &PublicValues, ) { - let n_threads = max_usable_threads(); let instance = pi .to_vec::() .concat() @@ -822,16 +815,10 @@ Hints: .zip(cs.lk_expressions_namespace_map.clone().into_iter()) .zip(cs.lk_expressions_items_map.clone().into_iter()) { - let lk_input = (wit_infer_by_expr( - &fixed, - &witness, - &pi_mles, - &challenges, - expr, - n_threads, - ) - .get_ext_field_vec())[..num_rows] - .to_vec(); + let lk_input = + (wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, expr) + .get_ext_field_vec())[..num_rows] + .to_vec(); rom_inputs.entry(rom_type).or_default().push(( lk_input, circuit_name.clone(), @@ -851,16 +838,10 @@ Hints: .iter() .zip(cs.lk_expressions_items_map.clone().into_iter()) { - let lk_table = wit_infer_by_expr( - &fixed, - &witness, - &pi_mles, - &challenges, - &expr.values, - n_threads, - ) - .get_ext_field_vec() - .to_vec(); + let lk_table = + wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, &expr.values) + .get_ext_field_vec() + .to_vec(); let multiplicity = wit_infer_by_expr( &fixed, @@ -868,7 +849,6 @@ Hints: &pi_mles, &challenges, &expr.multiplicity, - n_threads, ) .get_base_field_vec() .to_vec(); @@ -988,16 +968,10 @@ Hints: .zip_eq(cs.w_ram_types.iter()) .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { - let write_rlc_records = (wit_infer_by_expr( - fixed, - witness, - &pi_mles, - &challenges, - w_rlc_expr, - n_threads, - ) - .get_ext_field_vec())[..*num_rows] - .to_vec(); + let write_rlc_records = + (wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, w_rlc_expr) + .get_ext_field_vec())[..*num_rows] + .to_vec(); if $ram_type == RAMType::GlobalState { // w_exprs = [GlobalState, pc, timestamp] @@ -1012,7 +986,6 @@ Hints: &pi_mles, &challenges, expr, - n_threads, ); v.get_base_field_vec()[..*num_rows].to_vec() }) @@ -1057,16 +1030,10 @@ Hints: .zip_eq(cs.r_ram_types.iter()) .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { - let read_records = wit_infer_by_expr( - fixed, - witness, - &pi_mles, - &challenges, - r_expr, - n_threads, - ) - .get_ext_field_vec()[..*num_rows] - .to_vec(); + let read_records = + wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, r_expr) + .get_ext_field_vec()[..*num_rows] + .to_vec(); let mut records = vec![]; for (row, record) in enumerate(read_records) { // TODO: return error diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 20c106c3d..f728cd5ad 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -13,7 +13,9 @@ use multilinear_extensions::{ virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, }; -use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, +}; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProverMessage, IOPProverStateV2}, @@ -25,16 +27,18 @@ use crate::{ error::ZKVMError, expression::Instance, scheme::{ - constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, + constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, MIN_PAR_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, - wit_infer_by_expr, + wit_infer_by_expr, wit_infer_by_expr_pool, }, }, structs::{ Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, - utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, + utils::{ + SimpleVecPool, get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads, + }, virtual_polys::VirtualPolynomials, }; @@ -238,6 +242,21 @@ impl> ZKVMProver { let wit_inference_span = entered_span!("wit_inference", profiling_3 = true); // main constraint: read/write record witness inference let record_span = entered_span!("record"); + let len = witnesses[0].evaluations().len(); + let mut pool_e: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::ZERO) + .collect::>() + }); + let mut pool_b: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::BaseField::ZERO) + .collect::>() + }); let n_threads = max_usable_threads(); let records_wit: Vec> = cs .r_expressions @@ -246,7 +265,16 @@ impl> ZKVMProver { .chain(cs.lk_expressions.iter()) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&[], &witnesses, pi, challenges, expr, n_threads) + wit_infer_by_expr_pool( + &[], + &witnesses, + pi, + challenges, + expr, + n_threads, + &mut pool_e, + &mut pool_b, + ) }) .collect(); let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len()); @@ -526,7 +554,7 @@ impl> ZKVMProver { // sanity check in debug build and output != instance index for zero check sumcheck poly if cfg!(debug_assertions) { let expected_zero_poly = - wit_infer_by_expr(&[], &witnesses, pi, challenges, expr, n_threads); + wit_infer_by_expr(&[], &witnesses, pi, challenges, expr); let top_100_errors = expected_zero_poly .get_base_field_vec() .iter() @@ -702,21 +730,41 @@ impl> ZKVMProver { let wit_inference_span = entered_span!("wit_inference"); // main constraint: lookup denominator and numerator record witness inference let record_span = entered_span!("record"); + let len = witnesses[0].evaluations().len(); + let mut pool_e: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::ZERO) + .collect::>() + }); + let mut pool_b: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::BaseField::ZERO) + .collect::>() + }); let n_threads = max_usable_threads(); let mut records_wit: Vec> = cs .r_table_expressions - .par_iter() + .iter() .map(|r| &r.expr) - .chain(cs.w_table_expressions.par_iter().map(|w| &w.expr)) - .chain( - cs.lk_table_expressions - .par_iter() - .map(|lk| &lk.multiplicity), - ) - .chain(cs.lk_table_expressions.par_iter().map(|lk| &lk.values)) + .chain(cs.w_table_expressions.iter().map(|w| &w.expr)) + .chain(cs.lk_table_expressions.iter().map(|lk| &lk.multiplicity)) + .chain(cs.lk_table_expressions.iter().map(|lk| &lk.values)) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&fixed, &witnesses, pi, challenges, expr, n_threads) + wit_infer_by_expr_pool( + &fixed, + &witnesses, + pi, + challenges, + expr, + n_threads, + &mut pool_e, + &mut pool_b, + ) }) .collect(); let max_log2_num_instance = records_wit.iter().map(|mle| mle.num_vars()).max().unwrap(); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 4d9f81d52..e0fb49840 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -7,14 +7,12 @@ use multilinear_extensions::{ commutative_op_mle_pair_pool, mle::{DenseMultilinearExtension, FieldType, IntoMLE}, op_mle_xa_b_pool, op_mle3_range_pool, - util::ceil_log2, + util::{ceil_log2, max_usable_threads}, virtual_poly_v2::ArcMultilinearExtension, }; use ff::Field; -const POOL_CAP: usize = 3; - use rayon::{ iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, @@ -240,10 +238,10 @@ pub(crate) fn infer_tower_product_witness( wit_layers } -fn try_recycle_arcpoly( +fn try_recycle_arcpoly Vec, PF2: Fn() -> Vec>( poly: Cow>, - pool_e: &mut SimpleVecPool>, - pool_b: &mut SimpleVecPool>, + pool_e: &mut SimpleVecPool, PF1>, + pool_b: &mut SimpleVecPool, PF2>, pool_expected_size_vec: usize, ) { // fn downcast_arc( @@ -284,25 +282,49 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( instance: &[ArcMultilinearExtension<'a, E>], challenges: &[E; N], expr: &Expression, - n_threads: usize, ) -> ArcMultilinearExtension<'a, E> { + let n_threads = max_usable_threads(); let len = witnesses[0].evaluations().len(); - let mut pool_e: SimpleVecPool> = SimpleVecPool::new(POOL_CAP, || { + let mut pool_e: SimpleVecPool, _> = SimpleVecPool::new(|| { (0..len) .into_par_iter() .with_min_len(MIN_PAR_SIZE) .map(|_| E::ZERO) .collect::>() }); - let mut pool_b: SimpleVecPool> = SimpleVecPool::new(POOL_CAP, || { + let mut pool_b: SimpleVecPool, _> = SimpleVecPool::new(|| { (0..len) .into_par_iter() .with_min_len(MIN_PAR_SIZE) .map(|_| E::BaseField::ZERO) .collect::>() }); + wit_infer_by_expr_pool( + fixed, + witnesses, + instance, + challenges, + expr, + n_threads, + &mut pool_e, + &mut pool_b, + ) +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>( + fixed: &[ArcMultilinearExtension<'a, E>], + witnesses: &[ArcMultilinearExtension<'a, E>], + instance: &[ArcMultilinearExtension<'a, E>], + challenges: &[E; N], + expr: &Expression, + n_threads: usize, + pool_e: &mut SimpleVecPool, impl Fn() -> Vec>, + pool_b: &mut SimpleVecPool, impl Fn() -> Vec>, +) -> ArcMultilinearExtension<'a, E> { + let len = witnesses[0].evaluations().len(); let poly = - expr.evaluate_with_instance_pool::>>( + expr.evaluate_with_instance_pool::>, _, _>( &|f| Cow::Borrowed(&fixed[f.0]), &|witness_id| Cow::Borrowed(&witnesses[witness_id as usize]), &|i| Cow::Borrowed(&instance[i.0]), @@ -473,8 +495,8 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( try_recycle_arcpoly(cow_x, pool_e, pool_b, len); poly }, - &mut pool_e, - &mut pool_b, + pool_e, + pool_b, ); match poly { Cow::Borrowed(poly) => poly.clone(), @@ -816,7 +838,6 @@ mod tests { &[], &[], &expr, - 1, ); res.get_base_field_vec(); } @@ -847,7 +868,6 @@ mod tests { &[], &[E::ONE], &expr, - 1, ); res.get_ext_field_vec(); } diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 8d89431e3..898e8364c 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -230,19 +230,19 @@ where result } -pub struct SimpleVecPool { +pub struct SimpleVecPool T> { pool: VecDeque, + factory_fn: F, } -impl SimpleVecPool { +impl T> SimpleVecPool { // Create a new pool with a factory closure - pub fn new T>(cap: usize, init: F) -> Self { - let mut pool = SimpleVecPool { + pub fn new(init: F) -> Self { + let pool = SimpleVecPool { pool: VecDeque::new(), + factory_fn: init, }; - (0..cap).for_each(|_| { - pool.add(init()); - }); + pool } @@ -253,9 +253,7 @@ impl SimpleVecPool { // Borrow an item from the pool, or create a new one if empty pub fn borrow(&mut self) -> T { - self.pool - .pop_front() - .expect("pool is empty, consider increase cap size") + self.pool.pop_front().unwrap_or_else(|| (self.factory_fn)()) } // Return an item to the pool