diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index e0fb49840..af3ed8f8d 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -4,9 +4,9 @@ use ark_std::iterable::Iterable; use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::{ - commutative_op_mle_pair_pool, + commutative_op_mle_pair, mle::{DenseMultilinearExtension, FieldType, IntoMLE}, - op_mle_xa_b_pool, op_mle3_range_pool, + op_mle_xa_b, op_mle3_range, util::{ceil_log2, max_usable_threads}, virtual_poly_v2::ArcMultilinearExtension, }; @@ -238,28 +238,12 @@ pub(crate) fn infer_tower_product_witness( wit_layers } -fn try_recycle_arcpoly Vec, PF2: Fn() -> Vec>( +fn optional_arcpoly_unwrap_pushback( poly: Cow>, - pool_e: &mut SimpleVecPool, PF1>, - pool_b: &mut SimpleVecPool, PF2>, + pool_e: &mut SimpleVecPool, impl Fn() -> Vec>, + pool_b: &mut SimpleVecPool, impl Fn() -> Vec>, pool_expected_size_vec: usize, ) { - // fn downcast_arc( - // arc: ArcMultilinearExtension<'_, E>, - // ) -> DenseMultilinearExtension { - // unsafe { - // // get the raw pointer from the Arc - // assert_eq!(Arc::strong_count(&arc), 1); - // let raw = Arc::into_raw(arc); - // // cast the raw pointer to the desired concrete type - // let typed_ptr = raw as *const DenseMultilinearExtension; - // // manually drop the Arc without dropping the value - // Arc::decrement_strong_count(raw); - // // reconstruct the Arc with the concrete type - // // Move the value out - // ptr::read(typed_ptr) - // } - // } let len = poly.evaluations().len(); if len == pool_expected_size_vec { match poly { @@ -348,7 +332,7 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>( &|cow_a, cow_b, pool_e, pool_b| { let (a, b) = (cow_a.as_ref(), cow_b.as_ref()); let poly = - commutative_op_mle_pair_pool!( + commutative_op_mle_pair!( |a, b, res| { match (a.len(), b.len()) { (1, 1) => { @@ -401,14 +385,14 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>( pool_e, pool_b ); - try_recycle_arcpoly(cow_a, pool_e, pool_b, len); - try_recycle_arcpoly(cow_b, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_a, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_b, pool_e, pool_b, len); poly }, &|cow_a, cow_b, pool_e, pool_b| { let (a, b) = (cow_a.as_ref(), cow_b.as_ref()); let poly = - commutative_op_mle_pair_pool!( + commutative_op_mle_pair!( |a, b, res| { match (a.len(), b.len()) { (1, 1) => { @@ -464,13 +448,13 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>( pool_e, pool_b ); - try_recycle_arcpoly(cow_a, pool_e, pool_b, len); - try_recycle_arcpoly(cow_b, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_a, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_b, pool_e, pool_b, len); poly }, &|cow_x, cow_a, cow_b, pool_e, pool_b| { let (x, a, b) = (cow_x.as_ref(), cow_a.as_ref(), cow_b.as_ref()); - let poly = op_mle_xa_b_pool!( + let poly = op_mle_xa_b!( |x, a, b, res| { let res = SyncUnsafeCell::new(res); assert_eq!(a.len(), 1); @@ -490,9 +474,9 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>( pool_e, pool_b ); - try_recycle_arcpoly(cow_a, pool_e, pool_b, len); - try_recycle_arcpoly(cow_b, pool_e, pool_b, len); - try_recycle_arcpoly(cow_x, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_a, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_b, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_x, pool_e, pool_b, len); poly }, pool_e, diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 898e8364c..79fa3c490 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -230,33 +230,28 @@ where result } +/// a simple vector pool +/// not support multi-thread access pub struct SimpleVecPool T> { pool: VecDeque, factory_fn: F, } impl T> SimpleVecPool { - // Create a new pool with a factory closure + // new pool with a factory closure pub fn new(init: F) -> Self { - let pool = SimpleVecPool { + SimpleVecPool { pool: VecDeque::new(), factory_fn: init, - }; - - pool - } - - // Add a new item to the pool - pub fn add(&mut self, item: T) { - self.pool.push_back(item); + } } - // Borrow an item from the pool, or create a new one if empty + // borrow an item from the pool, or create a new one if empty pub fn borrow(&mut self) -> T { self.pool.pop_front().unwrap_or_else(|| (self.factory_fn)()) } - // Return an item to the pool + // push an item to the pool pub fn return_to_pool(&mut self, item: T) { self.pool.push_back(item); } diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 0c8d51393..85947721d 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -1062,10 +1062,7 @@ macro_rules! op_mle3_range { let $bb_out = $op; $op_bb_out }}; -} -#[macro_export] -macro_rules! op_mle3_range_pool { ($x:ident, $a:ident, $b:ident, $res:ident, $x_vec:ident, $a_vec:ident, $b_vec:ident, $res_vec:ident, $op:expr, |$bb_out:ident| $op_bb_out:expr) => {{ let $x = if let Some((start, offset)) = $x.evaluations_range() { &$x_vec[start..][..offset] @@ -1091,7 +1088,7 @@ macro_rules! op_mle3_range_pool { /// deal with x * a + b #[macro_export] -macro_rules! op_mle_xa_b_pool { +macro_rules! op_mle_xa_b { (|$x:ident, $a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident, |$bb_out:ident| $op_bb_out:expr) => { match (&$x.evaluations(), &$a.evaluations(), &$b.evaluations()) { ( @@ -1100,7 +1097,7 @@ macro_rules! op_mle_xa_b_pool { $crate::mle::FieldType::Base(b_vec), ) => { let res_vec = $pool_b.borrow(); - op_mle3_range_pool!( + op_mle3_range!( $x, $a, $b, @@ -1119,7 +1116,7 @@ macro_rules! op_mle_xa_b_pool { $crate::mle::FieldType::Base(b_vec), ) => { let res_vec = $pool_e.borrow(); - op_mle3_range_pool!( + op_mle3_range!( $x, $a, $b, @@ -1138,7 +1135,7 @@ macro_rules! op_mle_xa_b_pool { $crate::mle::FieldType::Ext(b_vec), ) => { let res_vec = $pool_e.borrow(); - op_mle3_range_pool!( + op_mle3_range!( $x, $a, $b, @@ -1160,7 +1157,7 @@ macro_rules! op_mle_xa_b_pool { } }; (|$x:ident, $a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident) => { - op_mle_xa_b_pool!(|$x, $a, $b, $res| $op, $pool_e, $pool_b, |out| out) + op_mle_xa_b!(|$x, $a, $b, $res| $op, $pool_e, $pool_b, |out| out) }; } @@ -1297,15 +1294,6 @@ macro_rules! commutative_op_mle_pair { _ => unreachable!(), } }; - (|$a:ident, $b:ident| $op:expr) => { - commutative_op_mle_pair!(|$a, $b| $op, |out| out) - }; -} - -/// macro support op(a, b) and tackles type matching internally. -/// Please noted that op must satisfy commutative rule w.r.t op(b, a) operand swap. -#[macro_export] -macro_rules! commutative_op_mle_pair_pool { (|$first:ident, $second:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident, |$bb_out:ident| $op_bb_out:expr) => { match (&$first.evaluations(), &$second.evaluations()) { ($crate::mle::FieldType::Base(base1), $crate::mle::FieldType::Base(base2)) => { @@ -1374,6 +1362,9 @@ macro_rules! commutative_op_mle_pair_pool { } }; (|$a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident) => { - commutative_op_mle_pair_pool!(|$a, $b, $res| $op, $pool_e, $pool_b, |out| out) + commutative_op_mle_pair!(|$a, $b, $res| $op, $pool_e, $pool_b, |out| out) + }; + (|$a:ident, $b:ident| $op:expr) => { + commutative_op_mle_pair!(|$a, $b| $op, |out| out) }; }