From 6dc6ca4413f7a3028d2e39d46900fca3bb6998da Mon Sep 17 00:00:00 2001 From: dewert99 Date: Fri, 20 Dec 2024 17:49:27 -0800 Subject: [PATCH] Allow path compression to be disabled but undo log --- src/dot.rs | 13 ++++----- src/explain.rs | 4 +-- src/raw.rs | 3 +++ src/raw/egraph.rs | 54 +++++++++++++++++++++++-------------- src/raw/reflect_const.rs | 14 ++++++++++ src/raw/semi_persistent.rs | 7 +++++ src/raw/semi_persistent1.rs | 53 +++++++++++++----------------------- src/raw/semi_persistent2.rs | 7 +++-- src/raw/unionfind.rs | 10 +++++-- 9 files changed, 99 insertions(+), 66 deletions(-) create mode 100644 src/raw/reflect_const.rs diff --git a/src/dot.rs b/src/dot.rs index c2dd9de6..8da52f06 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -9,6 +9,7 @@ Use the [`Dot`] struct to visualize an [`EGraph`](crate::EGraph) use no_std_compat::prelude::v1::*; use std::fmt::{self, Debug, Display, Formatter}; +use crate::raw::reflect_const::PathCompressT; use crate::{raw::EGraphResidual, raw::Language}; /** @@ -48,8 +49,8 @@ instead of to its own eclass. [GraphViz]: https://graphviz.gitlab.io/ **/ -pub struct Dot<'a, L: Language> { - pub(crate) egraph: &'a EGraphResidual, +pub struct Dot<'a, L: Language, P: PathCompressT> { + pub(crate) egraph: &'a EGraphResidual, /// A list of strings to be output top part of the dot file. pub config: Vec, /// Whether or not to anchor the edges in the output. @@ -57,7 +58,7 @@ pub struct Dot<'a, L: Language> { pub use_anchors: bool, } -impl<'a, L> Dot<'a, L> +impl<'a, L, P: PathCompressT> Dot<'a, L, P> where L: Language + Display, { @@ -100,7 +101,7 @@ mod std_only { use std::io::{Error, ErrorKind, Result, Write}; use std::path::Path; - impl<'a, L: Language + Display> Dot<'a, L> { + impl<'a, L: Language + Display, P: PathCompressT> Dot<'a, L, P> { /// Writes the `Dot` to a .dot file with the given filename. /// Does _not_ require a `dot` binary. 
pub fn to_dot(&self, filename: impl AsRef) -> Result<()> { @@ -177,13 +178,13 @@ mod std_only { } } -impl<'a, L: Language> Debug for Dot<'a, L> { +impl<'a, L: Language, P: PathCompressT> Debug for Dot<'a, L, P> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_tuple("Dot").field(self.egraph).finish() } } -impl<'a, L> Display for Dot<'a, L> +impl<'a, L, P: PathCompressT> Display for Dot<'a, L, P> where L: Language + Display, { diff --git a/src/explain.rs b/src/explain.rs index 8a5efe17..53006649 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -15,7 +15,7 @@ use std::mem; use std::ops::{Deref, DerefMut}; use std::rc::Rc; -use crate::raw::RawEGraph; +use crate::raw::{RawEGraph, UndoLogT}; use symbolic_expressions::Sexp; type ProofCost = Saturating; @@ -1094,7 +1094,7 @@ impl<'a, L: Language, X> DerefMut for ExplainWith<'a, L, X> { } } -impl<'x, L: Language, D, U> ExplainWith<'x, L, &'x RawEGraph> { +impl<'x, L: Language, D, U: UndoLogT> ExplainWith<'x, L, &'x RawEGraph> { pub(crate) fn node(&self, node_id: Id) -> &L { self.raw.id_to_node(node_id) } diff --git a/src/raw.rs b/src/raw.rs index 11e9a80f..86f02fd5 100644 --- a/src/raw.rs +++ b/src/raw.rs @@ -12,6 +12,9 @@ pub mod semi_persistent2; mod unionfind; pub(crate) mod util; +/// Types and traits for specifying whether path compression is supported +pub mod reflect_const; + pub use eclass::RawEClass; pub use egraph::{EGraphResidual, RawEGraph, UnionInfo}; pub use language::*; diff --git a/src/raw/egraph.rs b/src/raw/egraph.rs index f3e47905..e3449d1a 100644 --- a/src/raw/egraph.rs +++ b/src/raw/egraph.rs @@ -11,6 +11,7 @@ use std::{ }; use crate::raw::dhashmap::*; +use crate::raw::reflect_const::{PathCompress, PathCompressT}; use crate::raw::UndoLogT; use default_vec2::BitSet; #[cfg(feature = "serde-1")] @@ -35,8 +36,8 @@ impl<'a> IntoIterator for Parents<'a> { /// See [`RawEGraph::classes_mut`], [`RawEGraph::get_class_mut`] #[derive(Clone)] #[cfg_attr(feature = "serde-1", derive(Serialize, 
Deserialize))] -pub struct EGraphResidual { - pub(super) unionfind: UnionFind, +pub struct EGraphResidual> { + pub(super) unionfind: UnionFind

, /// Stores the original node represented by each non-canonical id pub(super) nodes: Vec, /// Stores each enode's `Id`, not the `Id` of the eclass. @@ -46,7 +47,7 @@ pub struct EGraphResidual { pub(super) memo: DHashMap, } -impl EGraphResidual { +impl EGraphResidual { /// Pick a representative term for a given Id. /// /// Calling this function on an uncanonical `Id` returns a representative based on how it @@ -308,7 +309,7 @@ impl EGraphResidual { } /// Creates a [`Dot`] to visualize this egraph. See [`Dot`]. - pub fn dot(&self) -> Dot<'_, L> { + pub fn dot(&self) -> Dot<'_, L, P> { Dot { egraph: self, config: vec![], @@ -317,8 +318,15 @@ impl EGraphResidual { } } +impl EGraphResidual> { + /// Return the direct parent from the union find without path compression + pub fn find_direct_parent(&self, id: Id) -> Id { + self.unionfind.parent_id(id) + } +} + // manual debug impl to avoid L: Language bound on EGraph defn -impl Debug for EGraphResidual { +impl Debug for EGraphResidual { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("EGraphResidual") .field("unionfind", &self.unionfind) @@ -356,9 +364,9 @@ to properly handle this data **/ #[derive(Clone)] #[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] -pub struct RawEGraph { +pub struct RawEGraph = ()> { #[cfg_attr(feature = "serde-1", serde(flatten))] - pub(super) residual: EGraphResidual, + pub(super) residual: EGraphResidual, /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, /// not the canonical id of the eclass. 
pub(super) pending: Vec, @@ -368,7 +376,7 @@ pub struct RawEGraph { pub(super) undo_log: U, } -impl Default for RawEGraph { +impl> Default for RawEGraph { fn default() -> Self { let residual = EGraphResidual { unionfind: Default::default(), @@ -385,8 +393,8 @@ impl Default for RawEGraph { } } -impl Deref for RawEGraph { - type Target = EGraphResidual; +impl> Deref for RawEGraph { + type Target = EGraphResidual; #[inline] fn deref(&self) -> &Self::Target { @@ -394,7 +402,7 @@ impl Deref for RawEGraph { } } -impl DerefMut for RawEGraph { +impl> DerefMut for RawEGraph { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { &mut self.residual @@ -402,7 +410,7 @@ impl DerefMut for RawEGraph { } // manual debug impl to avoid L: Language bound on EGraph defn -impl Debug for RawEGraph { +impl> Debug for RawEGraph { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let classes: BTreeMap<_, _> = self .classes @@ -428,7 +436,7 @@ impl Debug for RawEGraph { } } -impl RawEGraph { +impl> RawEGraph { /// Returns an iterator over the eclasses in the egraph. 
pub fn classes(&self) -> impl ExactSizeIterator> { self.classes.iter() @@ -440,7 +448,7 @@ impl RawEGraph { &mut self, ) -> ( impl ExactSizeIterator>, - &mut EGraphResidual, + &mut EGraphResidual, ) { let iter = self.classes.iter_mut(); (iter, &mut self.residual) @@ -470,7 +478,10 @@ impl RawEGraph { pub fn get_class_mut>( &mut self, mut id: I, - ) -> (&mut RawEClass, &mut EGraphResidual) { + ) -> ( + &mut RawEClass, + &mut EGraphResidual, + ) { let id = id.borrow_mut(); let (nid, cid) = self.unionfind.find_mut_full(*id); *id = nid; @@ -481,7 +492,10 @@ impl RawEGraph { pub fn get_class_mut_with_cannon( &mut self, id: Id, - ) -> (&mut RawEClass, &mut EGraphResidual) { + ) -> ( + &mut RawEClass, + &mut EGraphResidual, + ) { let cid = self.unionfind.find_canon(id); (&mut self.classes[cid.idx()], &mut self.residual) } @@ -900,9 +914,9 @@ impl> RawEGraph { } } -struct EGraphUncanonicalDump<'a, L: Language>(&'a EGraphResidual); +struct EGraphUncanonicalDump<'a, L: Language, P: PathCompressT>(&'a EGraphResidual); -impl<'a, L: Language> Debug for EGraphUncanonicalDump<'a, L> { +impl<'a, L: Language, P: PathCompressT> Debug for EGraphUncanonicalDump<'a, L, P> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { for (id, node) in self.0.uncanonical_nodes() { writeln!(f, "{}: {:?} (root={})", id, node, self.0.find(id))? 
@@ -911,9 +925,9 @@ impl<'a, L: Language> Debug for EGraphUncanonicalDump<'a, L> { } } -struct EGraphDump<'a, L: Language, D, U>(&'a RawEGraph); +struct EGraphDump<'a, L: Language, D, U: UndoLogT>(&'a RawEGraph); -impl<'a, L: Language, D: Debug, U> Debug for EGraphDump<'a, L, D, U> { +impl<'a, L: Language, D: Debug, U: UndoLogT> Debug for EGraphDump<'a, L, D, U> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut ids: Vec = self.0.classes().map(|c| c.id).collect(); ids.sort(); diff --git a/src/raw/reflect_const.rs b/src/raw/reflect_const.rs new file mode 100644 index 00000000..a350b638 --- /dev/null +++ b/src/raw/reflect_const.rs @@ -0,0 +1,14 @@ +#![allow(missing_docs)] +use core::fmt::Debug; + +#[derive(Copy, Clone, Eq, PartialEq, Default, Debug)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct PathCompress; + +impl PathCompressT for PathCompress { + const PATH_COMPRESS: bool = B; +} + +pub trait PathCompressT: Copy + Clone + Eq + PartialEq + Default + Debug { + const PATH_COMPRESS: bool; +} diff --git a/src/raw/semi_persistent.rs b/src/raw/semi_persistent.rs index 5b88b5c5..cf69a5c6 100644 --- a/src/raw/semi_persistent.rs +++ b/src/raw/semi_persistent.rs @@ -1,3 +1,4 @@ +use crate::raw::reflect_const::{PathCompress, PathCompressT}; use crate::raw::{Language, RawEGraph}; use crate::{ClassId, Id}; use no_std_compat::prelude::v1::*; @@ -10,6 +11,8 @@ impl Sealed for Option {} /// A sealed trait for types that can be used for `push`/`pop` APIs /// It is trivially implemented for `()` pub trait UndoLogT: Default + Debug + Sealed { + /// Whether this type of undo log allows for path compression + type AllowPathCompress: PathCompressT; #[doc(hidden)] fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id, cid: ClassId); @@ -33,6 +36,8 @@ pub trait UndoLogT: Default + Debug + Sealed { } impl UndoLogT for () { + type AllowPathCompress = PathCompress; + #[inline] fn add_node(&mut self, _: &L, _: &[Id], 
_: Id, _: ClassId) {} @@ -55,6 +60,8 @@ impl UndoLogT for () { } impl> UndoLogT for Option { + type AllowPathCompress = U::AllowPathCompress; + #[inline] fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id, cid: ClassId) { if let Some(undo) = self { diff --git a/src/raw/semi_persistent1.rs b/src/raw/semi_persistent1.rs index 572c4434..2b853a80 100644 --- a/src/raw/semi_persistent1.rs +++ b/src/raw/semi_persistent1.rs @@ -1,4 +1,5 @@ -use crate::raw::{AsUnwrap, Language, RawEClass, RawEGraph, Sealed, UndoLogT, UnionFind}; +use crate::raw::reflect_const::PathCompress; +use crate::raw::{AsUnwrap, Language, RawEClass, RawEGraph, Sealed, UndoLogT}; use crate::{ClassId, Id}; use core::mem; use no_std_compat::prelude::v1::*; @@ -44,8 +45,6 @@ struct UnionInfo { #[derive(Clone, Debug)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] pub struct UndoLog { - // Mirror of the union find without path compression - undo_find: UnionFind, pop_parents: Vec, union_log: Vec, memo_log: Vec, @@ -55,7 +54,6 @@ pub struct UndoLog { impl Default for UndoLog { fn default() -> Self { UndoLog { - undo_find: Default::default(), pop_parents: Default::default(), union_log: vec![UnionInfo { old_id: Id::from(0), @@ -73,17 +71,16 @@ impl Default for UndoLog { impl Sealed for UndoLog {} impl UndoLogT for UndoLog { - fn add_node(&mut self, _: &L, canon_children: &[Id], node_id: Id, cid: ClassId) { - let new = self.undo_find.make_set_with_id(cid); - debug_assert_eq!(new, node_id); + type AllowPathCompress = PathCompress; + + fn add_node(&mut self, _: &L, canon_children: &[Id], _: Id, _: ClassId) { self.pop_parents.extend(canon_children); let log = self.union_log.last_mut().unwrap(); log.parents_added_after += canon_children.len() as u32; log.nodes_added_after += 1; } - fn union(&mut self, id1: Id, id2: Id, old_parents: Vec, old_cid: ClassId) { - self.undo_find.union(id1, id2); + fn union(&mut self, _: Id, id2: Id, old_parents: Vec, old_cid: ClassId) { 
self.union_log.push(UnionInfo { old_id: id2, parents_added_after: 0, @@ -93,9 +90,7 @@ impl UndoLogT for UndoLog { }) } - fn fix_id(&mut self, fixup_id: Id, cid: ClassId) { - self.undo_find.reset_root(fixup_id, cid) - } + fn fix_id(&mut self, _: Id, _: ClassId) {} fn insert_memo(&mut self, hash: u64) { self.memo_log.push(hash); @@ -109,7 +104,6 @@ impl UndoLogT for UndoLog { self.union_log.truncate(1); self.union_log[0].parents_added_after = 0; self.memo_log.clear(); - self.undo_find.clear(); self.congr_dup_log.clear(); } @@ -119,7 +113,7 @@ impl UndoLogT for UndoLog { } } -impl> RawEGraph { +impl + UndoLogT> RawEGraph { /// Create a [`PushInfo`] representing the current state of the egraph /// which can later be passed into [`raw_pop1`](RawEGraph::raw_pop1) /// @@ -193,11 +187,6 @@ impl> RawEGraph { self.pop_nodes1(node_count as usize); } - /// Return the direct parent from the union find without path compression - pub fn find_direct_parent(&self, id: Id) -> Id { - self.undo_log.as_unwrap().undo_find.parent_id(id) - } - fn pop_memo1(&mut self, old_count: usize) { assert!(self.memo.len() >= old_count); let memo_log = &mut self.undo_log.as_mut_unwrap().memo_log; @@ -215,12 +204,12 @@ impl> RawEGraph { mut split: impl FnMut(&mut D, Id, Id) -> D, ) { let undo = self.undo_log.as_mut_unwrap(); - let mut node_count_rem = (undo.undo_find.size() - node_count) as u32; + let mut node_count_rem = (self.residual.unionfind.size() - node_count) as u32; assert!(self.residual.number_of_uncanonical_nodes() >= old_count); for info in undo.union_log.drain(old_count..).rev() { for _ in 0..info.parents_added_after { let id = undo.pop_parents.pop().unwrap(); - self.classes[undo.undo_find.find_canon(id).idx()] + self.classes[self.residual.unionfind.find_canon(id).idx()] .parents .pop(); } @@ -228,11 +217,11 @@ impl> RawEGraph { .truncate(self.classes.len() - info.nodes_added_after as usize); node_count_rem -= info.nodes_added_after; let old_id = info.old_id; - let new_id = 
undo.undo_find.parent_id(old_id); + let new_id = self.residual.unionfind.parent_id(old_id); debug_assert_ne!(new_id, old_id); - debug_assert_eq!(undo.undo_find.find(new_id), new_id); - let new_cid = undo.undo_find.find_canon(new_id); - undo.undo_find.reset_root(old_id, info.old_cid); + debug_assert_eq!(self.residual.unionfind.find(new_id), new_id); + let new_cid = self.residual.unionfind.find_canon(new_id); + self.residual.unionfind.reset_root(old_id, info.old_cid); let new_class = &mut self.classes[new_cid.idx()]; debug_assert_eq!(new_class.id, new_id); let cut = new_class.parents.len() - info.old_parents.len(); @@ -246,7 +235,8 @@ impl> RawEGraph { }; if info.old_cid.idx() != self.classes.len() { mem::swap(&mut self.classes[info.old_cid.idx()], &mut class); - undo.undo_find + self.residual + .unionfind .reset_root(class.id, ClassId::from(self.classes.len())) } self.classes.push(class) @@ -254,7 +244,7 @@ impl> RawEGraph { let parent_rem = undo.pop_parents.len() - pop_parents_count; for _ in 0..parent_rem { let id = undo.pop_parents.pop().unwrap(); - self.classes[undo.undo_find.find_canon(id).idx()] + self.classes[self.residual.unionfind.find_canon(id).idx()] .parents .pop(); } @@ -267,17 +257,12 @@ impl> RawEGraph { fn pop_nodes1(&mut self, old_count: usize) { assert!(self.number_of_uncanonical_nodes() >= old_count); - let undo = self.undo_log.as_mut_unwrap(); - undo.undo_find.parents.truncate(old_count); - self.residual - .unionfind - .parents - .clone_from(&undo.undo_find.parents); + self.residual.unionfind.parents.truncate(old_count); self.residual.nodes.truncate(old_count); } } -impl> RawEGraph { +impl + UndoLogT> RawEGraph { /// Simplified version of [`raw_pop1`](RawEGraph::raw_pop1) for egraphs without eclass data pub fn pop1(&mut self, info: PushInfo) { self.raw_pop1(info, |_, _, _| ()) diff --git a/src/raw/semi_persistent2.rs b/src/raw/semi_persistent2.rs index 437e0ee2..34f82aef 100644 --- a/src/raw/semi_persistent2.rs +++ b/src/raw/semi_persistent2.rs 
@@ -1,3 +1,4 @@ +use crate::raw::reflect_const::PathCompress; use crate::raw::unionfind::UnionFindElt; use crate::raw::util::HashSet; use crate::raw::{AsUnwrap, Language, RawEClass, RawEGraph, Sealed, UndoLogT}; @@ -65,6 +66,8 @@ pub struct UndoLog { impl Sealed for UndoLog {} impl UndoLogT for UndoLog { + type AllowPathCompress = PathCompress; + fn add_node(&mut self, _: &L, canon: &[Id], node_id: Id, _: ClassId) { debug_assert_eq!(self.undo_find.len(), usize::from(node_id)); self.undo_find.push(UndoNode::default()); @@ -101,7 +104,7 @@ impl UndoLogT for UndoLog { } } -impl> RawEGraph { +impl + UndoLogT> RawEGraph { /// Create a [`PushInfo`] representing the current state of the egraph /// which can later be passed into [`raw_pop2`](RawEGraph::raw_pop2) /// @@ -357,7 +360,7 @@ impl<'a, L> UndoCtx<'a, L> { } } -impl> RawEGraph { +impl + UndoLogT> RawEGraph { /// Simplified version of [`raw_pop2`](RawEGraph::raw_pop2) for egraphs without eclass data pub fn pop2(&mut self, info: PushInfo) { self.raw_pop2( diff --git a/src/raw/unionfind.rs b/src/raw/unionfind.rs index fb4c199e..7f49b3d3 100644 --- a/src/raw/unionfind.rs +++ b/src/raw/unionfind.rs @@ -1,4 +1,6 @@ +use crate::raw::reflect_const::{PathCompress, PathCompressT}; use crate::{ClassId, Id, U31_MAX}; +use core::marker::PhantomData; use no_std_compat::prelude::v1::*; use std::fmt::{Debug, Formatter}; @@ -50,11 +52,12 @@ impl Debug for RawUnionFindElt { #[derive(Debug, Clone, Default)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] /// Data structure that stores disjoint sets of `Id`s each with a representative -pub struct UnionFind { +pub struct UnionFind> { pub(super) parents: Vec, + phantom: PhantomData

, } -impl UnionFind { +impl UnionFind

{ /// Creates a singleton set and returns its representative pub fn make_set(&mut self) -> Id { self.make_set_with_id(0.into()) @@ -123,6 +126,9 @@ impl UnionFind { } pub(super) fn find_mut_full(&mut self, mut current: Id) -> (Id, ClassId) { + if !P::PATH_COMPRESS { + return self.find_full(current); + } let canon = self.find(current); loop { match self.parent(current) {