Skip to content

Commit

Permalink
code cosmetics
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Dec 19, 2024
1 parent dd98c57 commit d3a2f4b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 61 deletions.
46 changes: 15 additions & 31 deletions ceno_zkvm/src/scheme/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use ark_std::iterable::Iterable;
use ff_ext::ExtensionField;
use itertools::Itertools;
use multilinear_extensions::{
commutative_op_mle_pair_pool,
commutative_op_mle_pair,
mle::{DenseMultilinearExtension, FieldType, IntoMLE},
op_mle_xa_b_pool, op_mle3_range_pool,
op_mle_xa_b, op_mle3_range,
util::{ceil_log2, max_usable_threads},
virtual_poly_v2::ArcMultilinearExtension,
};
Expand Down Expand Up @@ -238,28 +238,12 @@ pub(crate) fn infer_tower_product_witness<E: ExtensionField>(
wit_layers
}

fn try_recycle_arcpoly<E: ExtensionField, PF1: Fn() -> Vec<E>, PF2: Fn() -> Vec<E::BaseField>>(
fn optional_arcpoly_unwrap_pushback<E: ExtensionField>(
poly: Cow<ArcMultilinearExtension<'_, E>>,
pool_e: &mut SimpleVecPool<Vec<E>, PF1>,
pool_b: &mut SimpleVecPool<Vec<E::BaseField>, PF2>,
pool_e: &mut SimpleVecPool<Vec<E>, impl Fn() -> Vec<E>>,
pool_b: &mut SimpleVecPool<Vec<E::BaseField>, impl Fn() -> Vec<E::BaseField>>,
pool_expected_size_vec: usize,
) {
// fn downcast_arc<E: ExtensionField>(
// arc: ArcMultilinearExtension<'_, E>,
// ) -> DenseMultilinearExtension<E> {
// unsafe {
// // get the raw pointer from the Arc
// assert_eq!(Arc::strong_count(&arc), 1);
// let raw = Arc::into_raw(arc);
// // cast the raw pointer to the desired concrete type
// let typed_ptr = raw as *const DenseMultilinearExtension<E>;
// // manually drop the Arc without dropping the value
// Arc::decrement_strong_count(raw);
// // reconstruct the Arc with the concrete type
// // Move the value out
// ptr::read(typed_ptr)
// }
// }
let len = poly.evaluations().len();
if len == pool_expected_size_vec {
match poly {
Expand Down Expand Up @@ -348,7 +332,7 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>(
&|cow_a, cow_b, pool_e, pool_b| {
let (a, b) = (cow_a.as_ref(), cow_b.as_ref());
let poly =
commutative_op_mle_pair_pool!(
commutative_op_mle_pair!(
|a, b, res| {
match (a.len(), b.len()) {
(1, 1) => {
Expand Down Expand Up @@ -401,14 +385,14 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>(
pool_e,
pool_b
);
try_recycle_arcpoly(cow_a, pool_e, pool_b, len);
try_recycle_arcpoly(cow_b, pool_e, pool_b, len);
optional_arcpoly_unwrap_pushback(cow_a, pool_e, pool_b, len);
optional_arcpoly_unwrap_pushback(cow_b, pool_e, pool_b, len);
poly
},
&|cow_a, cow_b, pool_e, pool_b| {
let (a, b) = (cow_a.as_ref(), cow_b.as_ref());
let poly =
commutative_op_mle_pair_pool!(
commutative_op_mle_pair!(
|a, b, res| {
match (a.len(), b.len()) {
(1, 1) => {
Expand Down Expand Up @@ -464,13 +448,13 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>(
pool_e,
pool_b
);
try_recycle_arcpoly(cow_a, pool_e, pool_b, len);
try_recycle_arcpoly(cow_b, pool_e, pool_b, len);
optional_arcpoly_unwrap_pushback(cow_a, pool_e, pool_b, len);
optional_arcpoly_unwrap_pushback(cow_b, pool_e, pool_b, len);
poly
},
&|cow_x, cow_a, cow_b, pool_e, pool_b| {
let (x, a, b) = (cow_x.as_ref(), cow_a.as_ref(), cow_b.as_ref());
let poly = op_mle_xa_b_pool!(
let poly = op_mle_xa_b!(
|x, a, b, res| {
let res = SyncUnsafeCell::new(res);
assert_eq!(a.len(), 1);
Expand All @@ -490,9 +474,9 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>(
pool_e,
pool_b
);
try_recycle_arcpoly(cow_a, pool_e, pool_b, len);
try_recycle_arcpoly(cow_b, pool_e, pool_b, len);
try_recycle_arcpoly(cow_x, pool_e, pool_b, len);
optional_arcpoly_unwrap_pushback(cow_a, pool_e, pool_b, len);
optional_arcpoly_unwrap_pushback(cow_b, pool_e, pool_b, len);
optional_arcpoly_unwrap_pushback(cow_x, pool_e, pool_b, len);
poly
},
pool_e,
Expand Down
19 changes: 7 additions & 12 deletions ceno_zkvm/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,33 +230,28 @@ where
result
}

/// a simple vector pool
/// not support multi-thread access
pub struct SimpleVecPool<T, F: Fn() -> T> {
pool: VecDeque<T>,
factory_fn: F,
}

impl<T, F: Fn() -> T> SimpleVecPool<T, F> {
// Create a new pool with a factory closure
// new pool with a factory closure
pub fn new(init: F) -> Self {
let pool = SimpleVecPool {
SimpleVecPool {
pool: VecDeque::new(),
factory_fn: init,
};

pool
}

// Add a new item to the pool
pub fn add(&mut self, item: T) {
self.pool.push_back(item);
}
}

// Borrow an item from the pool, or create a new one if empty
// borrow an item from the pool, or create a new one if empty
pub fn borrow(&mut self) -> T {
self.pool.pop_front().unwrap_or_else(|| (self.factory_fn)())
}

// Return an item to the pool
// push an item to the pool
pub fn return_to_pool(&mut self, item: T) {
self.pool.push_back(item);
}
Expand Down
27 changes: 9 additions & 18 deletions multilinear_extensions/src/mle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1062,10 +1062,7 @@ macro_rules! op_mle3_range {
let $bb_out = $op;
$op_bb_out
}};
}

#[macro_export]
macro_rules! op_mle3_range_pool {
($x:ident, $a:ident, $b:ident, $res:ident, $x_vec:ident, $a_vec:ident, $b_vec:ident, $res_vec:ident, $op:expr, |$bb_out:ident| $op_bb_out:expr) => {{
let $x = if let Some((start, offset)) = $x.evaluations_range() {
&$x_vec[start..][..offset]
Expand All @@ -1091,7 +1088,7 @@ macro_rules! op_mle3_range_pool {

/// deal with x * a + b
#[macro_export]
macro_rules! op_mle_xa_b_pool {
macro_rules! op_mle_xa_b {
(|$x:ident, $a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident, |$bb_out:ident| $op_bb_out:expr) => {
match (&$x.evaluations(), &$a.evaluations(), &$b.evaluations()) {
(
Expand All @@ -1100,7 +1097,7 @@ macro_rules! op_mle_xa_b_pool {
$crate::mle::FieldType::Base(b_vec),
) => {
let res_vec = $pool_b.borrow();
op_mle3_range_pool!(
op_mle3_range!(
$x,
$a,
$b,
Expand All @@ -1119,7 +1116,7 @@ macro_rules! op_mle_xa_b_pool {
$crate::mle::FieldType::Base(b_vec),
) => {
let res_vec = $pool_e.borrow();
op_mle3_range_pool!(
op_mle3_range!(
$x,
$a,
$b,
Expand All @@ -1138,7 +1135,7 @@ macro_rules! op_mle_xa_b_pool {
$crate::mle::FieldType::Ext(b_vec),
) => {
let res_vec = $pool_e.borrow();
op_mle3_range_pool!(
op_mle3_range!(
$x,
$a,
$b,
Expand All @@ -1160,7 +1157,7 @@ macro_rules! op_mle_xa_b_pool {
}
};
(|$x:ident, $a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident) => {
op_mle_xa_b_pool!(|$x, $a, $b, $res| $op, $pool_e, $pool_b, |out| out)
op_mle_xa_b!(|$x, $a, $b, $res| $op, $pool_e, $pool_b, |out| out)
};
}

Expand Down Expand Up @@ -1297,15 +1294,6 @@ macro_rules! commutative_op_mle_pair {
_ => unreachable!(),
}
};
(|$a:ident, $b:ident| $op:expr) => {
commutative_op_mle_pair!(|$a, $b| $op, |out| out)
};
}

/// macro support op(a, b) and tackles type matching internally.
/// Please noted that op must satisfy commutative rule w.r.t op(b, a) operand swap.
#[macro_export]
macro_rules! commutative_op_mle_pair_pool {
(|$first:ident, $second:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident, |$bb_out:ident| $op_bb_out:expr) => {
match (&$first.evaluations(), &$second.evaluations()) {
($crate::mle::FieldType::Base(base1), $crate::mle::FieldType::Base(base2)) => {
Expand Down Expand Up @@ -1374,6 +1362,9 @@ macro_rules! commutative_op_mle_pair_pool {
}
};
(|$a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident) => {
commutative_op_mle_pair_pool!(|$a, $b, $res| $op, $pool_e, $pool_b, |out| out)
commutative_op_mle_pair!(|$a, $b, $res| $op, $pool_e, $pool_b, |out| out)
};
(|$a:ident, $b:ident| $op:expr) => {
commutative_op_mle_pair!(|$a, $b| $op, |out| out)
};
}

0 comments on commit d3a2f4b

Please sign in to comment.