Skip to content

Commit

Permalink
shared vector pool across whole witness inference
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Dec 19, 2024
1 parent c7263ab commit dd98c57
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 101 deletions.
24 changes: 17 additions & 7 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,24 +143,34 @@ impl<E: ExtensionField> Expression<E> {
}

#[allow(clippy::too_many_arguments)]
pub fn evaluate_with_instance_pool<T>(
pub fn evaluate_with_instance_pool<T, PF1: Fn() -> Vec<E>, PF2: Fn() -> Vec<E::BaseField>>(
&self,
fixed_in: &impl Fn(&Fixed) -> T,
wit_in: &impl Fn(WitnessId) -> T, // witin id
instance: &impl Fn(Instance) -> T,
constant: &impl Fn(E::BaseField) -> T,
challenge: &impl Fn(ChallengeId, usize, E, E) -> T,
sum: &impl Fn(T, T, &mut SimpleVecPool<Vec<E>>, &mut SimpleVecPool<Vec<E::BaseField>>) -> T,
product: &impl Fn(T, T, &mut SimpleVecPool<Vec<E>>, &mut SimpleVecPool<Vec<E::BaseField>>) -> T,
sum: &impl Fn(
T,
T,
&mut SimpleVecPool<Vec<E>, PF1>,
&mut SimpleVecPool<Vec<E::BaseField>, PF2>,
) -> T,
product: &impl Fn(
T,
T,
&mut SimpleVecPool<Vec<E>, PF1>,
&mut SimpleVecPool<Vec<E::BaseField>, PF2>,
) -> T,
scaled: &impl Fn(
T,
T,
T,
&mut SimpleVecPool<Vec<E>>,
&mut SimpleVecPool<Vec<E::BaseField>>,
&mut SimpleVecPool<Vec<E>, PF1>,
&mut SimpleVecPool<Vec<E::BaseField>, PF2>,
) -> T,
pool_e: &mut SimpleVecPool<Vec<E>>,
pool_b: &mut SimpleVecPool<Vec<E::BaseField>>,
pool_e: &mut SimpleVecPool<Vec<E>, PF1>,
pool_b: &mut SimpleVecPool<Vec<E::BaseField>, PF2>,
) -> T {
match self {
Expression::Fixed(f) => fixed_in(f),
Expand Down
77 changes: 22 additions & 55 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ use ff_ext::ExtensionField;
use generic_static::StaticTypeMap;
use goldilocks::SmallField;
use itertools::{Itertools, enumerate, izip};
use multilinear_extensions::{
mle::IntoMLEs, util::max_usable_threads, virtual_poly_v2::ArcMultilinearExtension,
};
use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension};
use rand::thread_rng;
use std::{
collections::{HashMap, HashSet},
Expand Down Expand Up @@ -428,7 +426,6 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
challenge: Option<[E; 2]>,
lkm: Option<LkMultiplicity>,
) -> Result<(), Vec<MockProverError<E>>> {
let n_threads = max_usable_threads();
let program = Program::new(
CENO_PLATFORM.pc_base(),
CENO_PLATFORM.pc_base(),
Expand Down Expand Up @@ -476,12 +473,10 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
let (left, right) = expr.unpack_sum().unwrap();
let right = right.neg();

let left_evaluated =
wit_infer_by_expr(&[], wits_in, pi, &challenge, &left, n_threads);
let left_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &left);
let left_evaluated = left_evaluated.get_base_field_vec();

let right_evaluated =
wit_infer_by_expr(&[], wits_in, pi, &challenge, &right, n_threads);
let right_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &right);
let right_evaluated = right_evaluated.get_base_field_vec();

// left_evaluated.len() ?= right_evaluated.len() due to padding instance
Expand All @@ -501,8 +496,7 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
}
} else {
// contains require_zero
let expr_evaluated =
wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads);
let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr);
let expr_evaluated = expr_evaluated.get_base_field_vec();

for (inst_id, element) in enumerate(expr_evaluated) {
Expand All @@ -525,7 +519,7 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
.iter()
.zip_eq(cb.cs.lk_expressions_namespace_map.iter())
{
let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads);
let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr);
let expr_evaluated = expr_evaluated.get_ext_field_vec();

// Check each lookup expr exists in t vec
Expand Down Expand Up @@ -556,7 +550,7 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
.map(|expr| {
// TODO generalized to all inst_id
let inst_id = 0;
wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads)
wit_infer_by_expr(&[], wits_in, pi, &challenge, expr)
.get_base_field_vec()[inst_id]
.to_canonical_u64()
})
Expand Down Expand Up @@ -748,7 +742,6 @@ Hints:
witnesses: &ZKVMWitnesses<E>,
pi: &PublicValues<u32>,
) {
let n_threads = max_usable_threads();
let instance = pi
.to_vec::<E>()
.concat()
Expand Down Expand Up @@ -822,16 +815,10 @@ Hints:
.zip(cs.lk_expressions_namespace_map.clone().into_iter())
.zip(cs.lk_expressions_items_map.clone().into_iter())
{
let lk_input = (wit_infer_by_expr(
&fixed,
&witness,
&pi_mles,
&challenges,
expr,
n_threads,
)
.get_ext_field_vec())[..num_rows]
.to_vec();
let lk_input =
(wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, expr)
.get_ext_field_vec())[..num_rows]
.to_vec();
rom_inputs.entry(rom_type).or_default().push((
lk_input,
circuit_name.clone(),
Expand All @@ -851,24 +838,17 @@ Hints:
.iter()
.zip(cs.lk_expressions_items_map.clone().into_iter())
{
let lk_table = wit_infer_by_expr(
&fixed,
&witness,
&pi_mles,
&challenges,
&expr.values,
n_threads,
)
.get_ext_field_vec()
.to_vec();
let lk_table =
wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, &expr.values)
.get_ext_field_vec()
.to_vec();

let multiplicity = wit_infer_by_expr(
&fixed,
&witness,
&pi_mles,
&challenges,
&expr.multiplicity,
n_threads,
)
.get_base_field_vec()
.to_vec();
Expand Down Expand Up @@ -988,16 +968,10 @@ Hints:
.zip_eq(cs.w_ram_types.iter())
.filter(|((_, _), (ram_type, _))| *ram_type == $ram_type)
{
let write_rlc_records = (wit_infer_by_expr(
fixed,
witness,
&pi_mles,
&challenges,
w_rlc_expr,
n_threads,
)
.get_ext_field_vec())[..*num_rows]
.to_vec();
let write_rlc_records =
(wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, w_rlc_expr)
.get_ext_field_vec())[..*num_rows]
.to_vec();

if $ram_type == RAMType::GlobalState {
// w_exprs = [GlobalState, pc, timestamp]
Expand All @@ -1012,7 +986,6 @@ Hints:
&pi_mles,
&challenges,
expr,
n_threads,
);
v.get_base_field_vec()[..*num_rows].to_vec()
})
Expand Down Expand Up @@ -1057,16 +1030,10 @@ Hints:
.zip_eq(cs.r_ram_types.iter())
.filter(|((_, _), (ram_type, _))| *ram_type == $ram_type)
{
let read_records = wit_infer_by_expr(
fixed,
witness,
&pi_mles,
&challenges,
r_expr,
n_threads,
)
.get_ext_field_vec()[..*num_rows]
.to_vec();
let read_records =
wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, r_expr)
.get_ext_field_vec()[..*num_rows]
.to_vec();
let mut records = vec![];
for (row, record) in enumerate(read_records) {
// TODO: return error
Expand Down
78 changes: 63 additions & 15 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use multilinear_extensions::{
virtual_poly::build_eq_x_r_vec,
virtual_poly_v2::ArcMultilinearExtension,
};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
};
use sumcheck::{
macros::{entered_span, exit_span},
structs::{IOPProverMessage, IOPProverStateV2},
Expand All @@ -25,16 +27,18 @@ use crate::{
error::ZKVMError,
expression::Instance,
scheme::{
constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP},
constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, MIN_PAR_SIZE, NUM_FANIN, NUM_FANIN_LOGUP},
utils::{
infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles,
wit_infer_by_expr,
wit_infer_by_expr, wit_infer_by_expr_pool,
},
},
structs::{
Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses,
},
utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads},
utils::{
SimpleVecPool, get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads,
},
virtual_polys::VirtualPolynomials,
};

Expand Down Expand Up @@ -238,6 +242,21 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
let wit_inference_span = entered_span!("wit_inference", profiling_3 = true);
// main constraint: read/write record witness inference
let record_span = entered_span!("record");
let len = witnesses[0].evaluations().len();
let mut pool_e: SimpleVecPool<Vec<_>, _> = SimpleVecPool::new(|| {
(0..len)
.into_par_iter()
.with_min_len(MIN_PAR_SIZE)
.map(|_| E::ZERO)
.collect::<Vec<E>>()
});
let mut pool_b: SimpleVecPool<Vec<_>, _> = SimpleVecPool::new(|| {
(0..len)
.into_par_iter()
.with_min_len(MIN_PAR_SIZE)
.map(|_| E::BaseField::ZERO)
.collect::<Vec<E::BaseField>>()
});
let n_threads = max_usable_threads();
let records_wit: Vec<ArcMultilinearExtension<'_, E>> = cs
.r_expressions
Expand All @@ -246,7 +265,16 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.chain(cs.lk_expressions.iter())
.map(|expr| {
assert_eq!(expr.degree(), 1);
wit_infer_by_expr(&[], &witnesses, pi, challenges, expr, n_threads)
wit_infer_by_expr_pool(
&[],
&witnesses,
pi,
challenges,
expr,
n_threads,
&mut pool_e,
&mut pool_b,
)
})
.collect();
let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len());
Expand Down Expand Up @@ -526,7 +554,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
// sanity check in debug build and output != instance index for zero check sumcheck poly
if cfg!(debug_assertions) {
let expected_zero_poly =
wit_infer_by_expr(&[], &witnesses, pi, challenges, expr, n_threads);
wit_infer_by_expr(&[], &witnesses, pi, challenges, expr);
let top_100_errors = expected_zero_poly
.get_base_field_vec()
.iter()
Expand Down Expand Up @@ -702,21 +730,41 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
let wit_inference_span = entered_span!("wit_inference");
// main constraint: lookup denominator and numerator record witness inference
let record_span = entered_span!("record");
let len = witnesses[0].evaluations().len();
let mut pool_e: SimpleVecPool<Vec<_>, _> = SimpleVecPool::new(|| {
(0..len)
.into_par_iter()
.with_min_len(MIN_PAR_SIZE)
.map(|_| E::ZERO)
.collect::<Vec<E>>()
});
let mut pool_b: SimpleVecPool<Vec<_>, _> = SimpleVecPool::new(|| {
(0..len)
.into_par_iter()
.with_min_len(MIN_PAR_SIZE)
.map(|_| E::BaseField::ZERO)
.collect::<Vec<E::BaseField>>()
});
let n_threads = max_usable_threads();
let mut records_wit: Vec<ArcMultilinearExtension<'_, E>> = cs
.r_table_expressions
.par_iter()
.iter()
.map(|r| &r.expr)
.chain(cs.w_table_expressions.par_iter().map(|w| &w.expr))
.chain(
cs.lk_table_expressions
.par_iter()
.map(|lk| &lk.multiplicity),
)
.chain(cs.lk_table_expressions.par_iter().map(|lk| &lk.values))
.chain(cs.w_table_expressions.iter().map(|w| &w.expr))
.chain(cs.lk_table_expressions.iter().map(|lk| &lk.multiplicity))
.chain(cs.lk_table_expressions.iter().map(|lk| &lk.values))
.map(|expr| {
assert_eq!(expr.degree(), 1);
wit_infer_by_expr(&fixed, &witnesses, pi, challenges, expr, n_threads)
wit_infer_by_expr_pool(
&fixed,
&witnesses,
pi,
challenges,
expr,
n_threads,
&mut pool_e,
&mut pool_b,
)
})
.collect();
let max_log2_num_instance = records_wit.iter().map(|mle| mle.num_vars()).max().unwrap();
Expand Down
Loading

0 comments on commit dd98c57

Please sign in to comment.