Skip to content

Commit

Permalink
Optimize rebuild
Browse files Browse the repository at this point in the history
  • Loading branch information
dewert99 committed Apr 9, 2024
1 parent 9ba0fef commit 1c7ecbb
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod semi_persistent;
/// One variant of semi_persistence
pub mod semi_persistent1;

mod bitset;
/// Another variant of semi_persistence
pub mod semi_persistent2;
mod unionfind;
Expand Down
96 changes: 96 additions & 0 deletions src/raw/bitset.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#[cfg(feature = "serde-1")]
use serde::{Deserialize, Serialize};
use std::fmt::{Debug, Formatter};
use std::mem;

#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub struct DefaultVec<T>(Box<[T]>);

impl<T: Default> DefaultVec<T> {
#[cold]
#[inline(never)]
fn reserve(&mut self, i: usize) {
let mut v = mem::take(&mut self.0).into_vec();
v.reserve(i + 1 - v.len());
v.resize_with(v.capacity(), T::default);
self.0 = v.into_boxed_slice();
assert!(i < self.0.len())
}

pub fn get_mut(&mut self, i: usize) -> &mut T {
if i < self.0.len() {
&mut self.0[i]
} else {
self.reserve(i);
&mut self.0[i]
}
}

pub fn get(&self, i: usize) -> T
where
T: Copy,
{
self.0.get(i).copied().unwrap_or_default()
}

pub fn clear(&mut self) {
self.0.fill_with(Default::default)
}
}

type Elt = u32;

#[derive(Default, Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub struct BitSet(DefaultVec<Elt>);

#[inline]
fn split(x: usize) -> (usize, Elt) {
let offset = (x % Elt::BITS as usize) as u32;
(x / Elt::BITS as usize, 1 << offset)
}

impl BitSet {
pub fn insert(&mut self, x: usize) -> bool {
let (chunk_idx, mask) = split(x);
let chunk = self.0.get_mut(chunk_idx);
let res = (*chunk & mask) != 0;
*chunk |= mask;
res
}
pub fn remove(&mut self, x: usize) -> bool {
let (chunk_idx, mask) = split(x);
let chunk = self.0.get_mut(chunk_idx);
let res = (*chunk & mask) != 0;
*chunk &= !mask;
res
}
pub fn contains(&self, x: usize) -> bool {
let (chunk_idx, mask) = split(x);
let chunk = self.0.get(chunk_idx);
(chunk & mask) != 0
}

/// Same as contains but already reserves space for `x`
pub fn contains_mut(&mut self, x: usize) -> bool {
let (chunk_idx, mask) = split(x);
let chunk = *self.0.get_mut(chunk_idx);
(chunk & mask) != 0
}

pub fn clear(&mut self) {
self.0.clear()
}

pub fn iter(&self) -> impl Iterator<Item = usize> + '_ {
let max = self.0 .0.len() * (Elt::BITS as usize);
(0..max).filter(|x| self.contains(*x))
}
}

impl Debug for BitSet {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_set().entries(self.iter()).finish()
}
}
29 changes: 28 additions & 1 deletion src/raw/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::{
iter, slice,
};

use crate::raw::bitset::BitSet;
use crate::raw::dhashmap::*;
use crate::raw::UndoLogT;
#[cfg(feature = "serde-1")]
Expand Down Expand Up @@ -359,6 +360,8 @@ pub struct RawEGraph<L: Language, D, U = ()> {
/// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode,
/// not the canonical id of the eclass.
pub(super) pending: Vec<Id>,
/// `Id`s that are congruently equivalent to another `Id` that is not in this set
pub(super) congruence_duplicates: BitSet,
pub(super) classes: HashMap<Id, RawEClass<D>>,
pub(super) undo_log: U,
}
Expand All @@ -373,6 +376,7 @@ impl<L: Language, D, U: Default> Default for RawEGraph<L, D, U> {
RawEGraph {
residual,
pending: Default::default(),
congruence_duplicates: Default::default(),
classes: Default::default(),
undo_log: Default::default(),
}
Expand Down Expand Up @@ -417,6 +421,7 @@ impl<L: Language, D: Debug, U> Debug for RawEGraph<L, D, U> {
f.debug_struct("EGraph")
.field("memo", &self.residual.memo)
.field("classes", &classes)
.field("congruence_duplicates", &self.congruence_duplicates)
.finish()
}
}
Expand Down Expand Up @@ -767,6 +772,13 @@ impl<L: Language, D, U: UndoLogT<L, D>> RawEGraph<L, D, U> {
loop {
let this = get_self(outer);
if let Some(class) = this.pending.pop() {
if this.congruence_duplicates.contains_mut(class.into()) {
// `class` is congruently equivalent to another node `croot`,
// so each node that has `class` as a parent also has `croot` as a parent,
// and they will always be added to pending together, so we only need to handle
// `croot` when it comes up, but not `class`
continue;
}
let mut node = this.id_to_node(class).clone();
node.update_children(|id| this.find_mut(id));
handle_pending(outer, class, &node);
Expand All @@ -777,7 +789,20 @@ impl<L: Language, D, U: UndoLogT<L, D>> RawEGraph<L, D, U> {
let memo_class = *id;
let pre = pre_union(orig, &new);
match perform_union(outer, pre, memo_class, class) {
Ok(()) => {}
Ok(()) => {
let this = get_self(outer);
debug_assert_eq!(
this.find(memo_class),
this.find(class),
"`perform_union` didn't perform_union"
);
// class is congruently equivalent to memo_class which isn't in
// congruence_duplicates
if class != memo_class {
this.congruence_duplicates.insert(class.into());
this.undo_log.add_congruence_duplicate(class);
}
}
Err(e) => {
get_self(outer).pending.push(class);
return Err(e);
Expand Down Expand Up @@ -816,6 +841,8 @@ impl<L: Language, D, U: UndoLogT<L, D>> RawEGraph<L, D, U> {
self.residual.memo.clear();
self.residual.unionfind.clear();
self.pending.clear();
self.congruence_duplicates.clear();
self.classes.clear();
self.undo_log.clear();
}
}
Expand Down
11 changes: 11 additions & 0 deletions src/raw/semi_persistent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ pub trait UndoLogT<L, D>: Default + Debug + Sealed {
#[doc(hidden)]
fn insert_memo(&mut self, hash: u64);

#[doc(hidden)]
fn add_congruence_duplicate(&mut self, id: Id);

#[doc(hidden)]
fn clear(&mut self);

Expand All @@ -35,6 +38,8 @@ impl<L, D> UndoLogT<L, D> for () {
#[inline]
fn insert_memo(&mut self, _: u64) {}

fn add_congruence_duplicate(&mut self, _: Id) {}

#[inline]
fn clear(&mut self) {}

Expand Down Expand Up @@ -65,6 +70,12 @@ impl<L, D, U: UndoLogT<L, D>> UndoLogT<L, D> for Option<U> {
}
}

fn add_congruence_duplicate(&mut self, id: Id) {
if let Some(undo) = self {
undo.add_congruence_duplicate(id)
}
}

#[inline]
fn clear(&mut self) {
if let Some(undo) = self {
Expand Down
18 changes: 18 additions & 0 deletions src/raw/semi_persistent1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub struct PushInfo {
union_count: u32,
memo_log_count: u32,
pop_parents_count: u32,
congr_dup_count: u32,
}

impl PushInfo {
Expand Down Expand Up @@ -44,6 +45,7 @@ pub struct UndoLog {
pop_parents: Vec<Id>,
union_log: Vec<UnionInfo>,
memo_log: Vec<u64>,
congr_dup_log: Vec<Id>,
}

impl Default for UndoLog {
Expand All @@ -57,6 +59,7 @@ impl Default for UndoLog {
added_after: 0,
}],
memo_log: Default::default(),
congr_dup_log: vec![],
}
}
}
Expand Down Expand Up @@ -84,11 +87,16 @@ impl<L: Language, D> UndoLogT<L, D> for UndoLog {
self.memo_log.push(hash);
}

fn add_congruence_duplicate(&mut self, id: Id) {
self.congr_dup_log.push(id);
}

fn clear(&mut self) {
self.union_log.truncate(1);
self.union_log[0].added_after = 0;
self.memo_log.clear();
self.undo_find.clear();
self.congr_dup_log.clear();
}

#[inline]
Expand Down Expand Up @@ -131,6 +139,7 @@ impl<L: Language, D, U: AsUnwrap<UndoLog>> RawEGraph<L, D, U> {
union_count: undo.union_log.len() as u32,
memo_log_count: undo.memo_log.len() as u32,
pop_parents_count: undo.pop_parents.len() as u32,
congr_dup_count: undo.congr_dup_log.len() as u32,
}
}

Expand All @@ -149,8 +158,17 @@ impl<L: Language, D, U: AsUnwrap<UndoLog>> RawEGraph<L, D, U> {
union_count,
memo_log_count,
pop_parents_count,
congr_dup_count,
} = info;
self.pending.clear();
for id in self
.undo_log
.as_mut_unwrap()
.congr_dup_log
.drain(congr_dup_count as usize..)
{
self.congruence_duplicates.remove(id.into());
}
self.pop_memo1(memo_log_count as usize);
self.pop_unions1(union_count as usize, pop_parents_count as usize, split);
self.pop_nodes1(node_count as usize);
Expand Down
22 changes: 20 additions & 2 deletions src/raw/semi_persistent2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub struct PushInfo {
union_count: u32,
memo_log_count: u32,
pop_parents_count: u32,
congr_dup_count: u32,
}

impl PushInfo {
Expand All @@ -53,6 +54,7 @@ pub struct UndoLog {
union_log: Vec<Id>,
memo_log: Vec<u64>,
pop_parents: Vec<Id>,
congr_dup_log: Vec<Id>,
// Scratch space, should be empty other that when inside `pop`
#[cfg_attr(feature = "serde-1", serde(skip))]
dirty: HashSet<Id>,
Expand All @@ -79,10 +81,15 @@ impl<L: Language, D> UndoLogT<L, D> for UndoLog {
self.memo_log.push(hash);
}

fn add_congruence_duplicate(&mut self, id: Id) {
self.congr_dup_log.push(id);
}

fn clear(&mut self) {
self.union_log.clear();
self.memo_log.clear();
self.undo_find.clear();
self.congr_dup_log.clear();
}

fn is_enabled(&self) -> bool {
Expand Down Expand Up @@ -135,6 +142,7 @@ impl<L: Language, D, U: AsUnwrap<UndoLog>> RawEGraph<L, D, U> {
union_count: undo.union_log.len() as u32,
memo_log_count: undo.memo_log.len() as u32,
pop_parents_count: undo.pop_parents.len() as u32,
congr_dup_count: undo.congr_dup_log.len() as u32,
}
}

Expand Down Expand Up @@ -164,10 +172,20 @@ impl<L: Language, D, U: AsUnwrap<UndoLog>> RawEGraph<L, D, U> {
union_count,
memo_log_count,
pop_parents_count,
congr_dup_count,
} = info;
self.pending.clear();
for id in self
.undo_log
.as_mut_unwrap()
.congr_dup_log
.drain(congr_dup_count as usize..)
{
self.congruence_duplicates.remove(id.into());
}
self.pending.clear();
self.pop_memo2(memo_log_count as usize);
self.pop_parents2(pop_parents_count as usize, node_count as usize);
self.pop_parents2(pop_parents_count as usize);
self.pop_unions2(
union_count as usize,
node_count as usize,
Expand All @@ -188,7 +206,7 @@ impl<L: Language, D, U: AsUnwrap<UndoLog>> RawEGraph<L, D, U> {
}
}

fn pop_parents2(&mut self, old_count: usize, node_count: usize) {
fn pop_parents2(&mut self, old_count: usize) {
let undo = self.undo_log.as_mut_unwrap();

for id in undo.pop_parents.drain(old_count..) {
Expand Down
1 change: 1 addition & 0 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub fn test_runner<L, A>(
// Test push if feature is on
if cfg!(feature = "test-push-pop") {
runner.egraph = runner.egraph.with_push_pop_enabled();
history2.borrow_mut().push(EGraph::clone(&runner.egraph));
runner = runner.with_hook(move |runner| {
runner.egraph.push();
history2.borrow_mut().push(EGraph::clone(&runner.egraph));
Expand Down

0 comments on commit 1c7ecbb

Please sign in to comment.