diff --git a/Cargo.toml b/Cargo.toml index e2b3af6b..9decbb29 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ authors = ["Max Willsey "] categories = ["data-structures"] description = "An implementation of egraphs" -edition = "2018" +edition = "2021" keywords = ["e-graphs"] license = "MIT" name = "egg" diff --git a/src/dot.rs b/src/dot.rs index cefaf440..b68028ce 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -1,7 +1,7 @@ /*! EGraph visualization with [GraphViz] -Use the [`Dot`] struct to visualize an [`EGraph`] +Use the [`Dot`] struct to visualize an [`EGraph`](crate::EGraph) [GraphViz]: https://graphviz.gitlab.io/ !*/ @@ -11,13 +11,13 @@ use std::fmt::{self, Debug, Display, Formatter}; use std::io::{Error, ErrorKind, Result, Write}; use std::path::Path; -use crate::{egraph::EGraph, Analysis, Language}; +use crate::{raw::EGraphResidual, Language}; /** -A wrapper for an [`EGraph`] that can output [GraphViz] for +A wrapper for an [`EGraphResidual`] that can output [GraphViz] for visualization. -The [`EGraph::dot`](EGraph::dot()) method creates `Dot`s. +The [`EGraphResidual::dot`] method creates `Dot`s. # Example @@ -50,8 +50,8 @@ instead of to its own eclass. [GraphViz]: https://graphviz.gitlab.io/ **/ -pub struct Dot<'a, L: Language, N: Analysis> { - pub(crate) egraph: &'a EGraph, +pub struct Dot<'a, L: Language> { + pub(crate) egraph: &'a EGraphResidual, /// A list of strings to be output top part of the dot file. pub config: Vec, /// Whether or not to anchor the edges in the output. @@ -59,10 +59,9 @@ pub struct Dot<'a, L: Language, N: Analysis> { pub use_anchors: bool, } -impl<'a, L, N> Dot<'a, L, N> +impl<'a, L> Dot<'a, L> where L: Language + Display, - N: Analysis, { /// Writes the `Dot` to a .dot file with the given filename. /// Does _not_ require a `dot` binary. @@ -170,16 +169,15 @@ where } } -impl<'a, L: Language, N: Analysis> Debug for Dot<'a, L, N> { +impl<'a, L: Language> Debug for Dot<'a, L> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_tuple("Dot").field(self.egraph).finish() } } -impl<'a, L, N> Display for Dot<'a, L, N> +impl<'a, L> Display for Dot<'a, L> where L: Language + Display, - N: Analysis, { fn fmt(&self, f: &mut Formatter) -> fmt::Result { writeln!(f, "digraph egraph {{")?; @@ -192,17 +190,19 @@ where writeln!(f, " {}", line)?; } + let classes = self.egraph.generate_class_nodes(); + // define all the nodes, clustered by eclass - for class in self.egraph.classes() { - writeln!(f, " subgraph cluster_{} {{", class.id)?; + for (&id, class) in &classes { + writeln!(f, " subgraph cluster_{} {{", id)?; writeln!(f, " style=dotted")?; for (i, node) in class.iter().enumerate() { - writeln!(f, " {}.{}[label = \"{}\"]", class.id, i, node)?; + writeln!(f, " {}.{}[label = \"{}\"]", id, i, node)?; } writeln!(f, " }}")?; } - for class in self.egraph.classes() { + for (&id, class) in &classes { for (i_in_class, node) in class.iter().enumerate() { let mut arg_i = 0; node.try_for_each(|child| { @@ -210,19 +210,19 @@ where let (anchor, label) = self.edge(arg_i, node.len()); let child_leader = self.egraph.find(child); - if child_leader == class.id { + if child_leader == id { writeln!( f, // {}.0 to pick an arbitrary node in the cluster " {}.{}{} -> {}.{}:n [lhead = cluster_{}, {}]", - class.id, i_in_class, anchor, class.id, i_in_class, class.id, label + id, i_in_class, anchor, id, i_in_class, id, label )?; } else { writeln!( f, // {}.0 to pick an arbitrary node in the cluster " {}.{}{} -> {}.0 [lhead = cluster_{}, {}]", - class.id, i_in_class, anchor, child, child_leader, label + id, i_in_class, anchor, child, child_leader, label )?; } arg_i += 1; diff --git a/src/eclass.rs b/src/eclass.rs index 5f74b2c2..e235d58e 100644 --- a/src/eclass.rs +++ b/src/eclass.rs @@ -1,15 +1,13 @@ -use std::fmt::Debug; +use std::fmt::{Debug, Formatter}; use std::iter::ExactSizeIterator; use crate::*; -/// An equivalence class of enodes. +/// The additional data required to turn a [`raw::RawEClass`] into a [`EClass`] #[non_exhaustive] -#[derive(Debug, Clone)] +#[derive(Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] -pub struct EClass { - /// This eclass's id. - pub id: Id, +pub struct EClassData { /// The equivalent enodes in this equivalence class. pub nodes: Vec, /// The analysis data associated with this eclass. @@ -17,10 +15,19 @@ pub struct EClass { /// Modifying this field will _not_ cause changes to propagate through the e-graph. /// Prefer [`EGraph::set_analysis_data`] instead. pub data: D, - /// The parent enodes and their original Ids. - pub(crate) parents: Vec<(L, Id)>, } +impl Debug for EClassData { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut nodes = self.nodes.clone(); + nodes.sort(); + write!(f, "({:?}): {:?}", self.data, nodes) + } +} + +/// An equivalence class of enodes +pub type EClass = raw::RawEClass>; + impl EClass { /// Returns `true` if the `eclass` is empty. pub fn is_empty(&self) -> bool { @@ -36,11 +43,6 @@ impl EClass { pub fn iter(&self) -> impl ExactSizeIterator { self.nodes.iter() } - - /// Iterates over the parent enodes of this eclass. - pub fn parents(&self) -> impl ExactSizeIterator { - self.parents.iter().map(|(node, id)| (node, *id)) - } } impl EClass { diff --git a/src/egraph.rs b/src/egraph.rs index 6af452b2..3f35eddb 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1,12 +1,12 @@ use crate::*; -use std::{ - borrow::BorrowMut, - fmt::{self, Debug, Display}, -}; +use std::fmt::{self, Debug, Display}; +use std::ops::Deref; #[cfg(feature = "serde-1")] use serde::{Deserialize, Serialize}; +use crate::eclass::EClassData; +use crate::raw::{EGraphResidual, RawEGraph}; use log::*; /** A data structure to keep track of equalities between expressions. @@ -56,16 +56,7 @@ pub struct EGraph> { pub analysis: N, /// The `Explain` used to explain equivalences in this `EGraph`. pub(crate) explain: Option>, - unionfind: UnionFind, - /// Stores each enode's `Id`, not the `Id` of the eclass. - /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new - /// unions can cause them to become out of date. - #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] - memo: HashMap, - /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, - /// not the canonical id of the eclass. - pending: Vec<(L, Id)>, - analysis_pending: UniqueQueue<(L, Id)>, + analysis_pending: UniqueQueue, #[cfg_attr( feature = "serde-1", serde(bound( @@ -73,7 +64,7 @@ pub struct EGraph> { deserialize = "N::Data: for<'a> Deserialize<'a>", )) )] - pub(crate) classes: HashMap>, + pub(crate) inner: RawEGraph>, #[cfg_attr(feature = "serde-1", serde(skip))] #[cfg_attr(feature = "serde-1", serde(default = "default_classes_by_op"))] pub(crate) classes_by_op: HashMap>, @@ -100,10 +91,16 @@ impl + Default> Default for EGraph { // manual debug impl to avoid L: Language bound on EGraph defn impl> Debug for EGraph { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("EGraph") - .field("memo", &self.memo) - .field("classes", &self.classes) - .finish() + self.inner.fmt(f) + } +} + +impl> Deref for EGraph { + type Target = EGraphResidual; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.inner } } @@ -112,12 +109,9 @@ impl> EGraph { pub fn new(analysis: N) -> Self { Self { analysis, - classes: Default::default(), - unionfind: Default::default(), clean: false, explain: None, - pending: Default::default(), - memo: Default::default(), + inner: Default::default(), analysis_pending: Default::default(), classes_by_op: Default::default(), } @@ -125,25 +119,12 @@ impl> EGraph { /// Returns an iterator over the eclasses in the egraph. pub fn classes(&self) -> impl ExactSizeIterator> { - self.classes.values() + self.inner.classes() } /// Returns an mutating iterator over the eclasses in the egraph. pub fn classes_mut(&mut self) -> impl ExactSizeIterator> { - self.classes.values_mut() - } - - /// Returns `true` if the egraph is empty - /// # Example - /// ``` - /// use egg::{*, SymbolLang as S}; - /// let mut egraph = EGraph::::default(); - /// assert!(egraph.is_empty()); - /// egraph.add(S::leaf("foo")); - /// assert!(!egraph.is_empty()); - /// ``` - pub fn is_empty(&self) -> bool { - self.memo.is_empty() + self.inner.classes_mut().0 } /// Returns the number of enodes in the `EGraph`. @@ -163,7 +144,7 @@ impl> EGraph { /// assert_eq!(egraph.number_of_classes(), 1); /// ``` pub fn total_size(&self) -> usize { - self.memo.len() + self.inner.total_size() } /// Iterates over the classes, returning the total number of nodes. @@ -173,7 +154,7 @@ impl> EGraph { /// Returns the number of eclasses in the egraph. pub fn number_of_classes(&self) -> usize { - self.classes.len() + self.classes().len() } /// Enable explanations for this `EGraph`. @@ -214,12 +195,14 @@ impl> EGraph { /// Make a copy of the egraph with the same nodes, but no unions between them. pub fn copy_without_unions(&self, analysis: N) -> Self { - if let Some(explain) = &self.explain { - let egraph = Self::new(analysis); - explain.populate_enodes(egraph) - } else { + if self.explain.is_none() { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions"); } + let mut egraph = Self::new(analysis); + for (_, node) in self.uncanonical_nodes() { + egraph.add(node.clone()); + } + egraph } /// Performs the union between two egraphs. @@ -310,8 +293,8 @@ impl> EGraph { product_map: &mut HashMap<(Id, Id), Id>, ) { let res_id = Self::get_product_id(class1, class2, product_map); - for node1 in &self.classes[&class1].nodes { - for node2 in &other.classes[&class2].nodes { + for node1 in &self[class1].nodes { + for node2 in &other[class2].nodes { if node1.matches(node2) { let children1 = node1.children(); let children2 = node2.children(); @@ -333,38 +316,41 @@ impl> EGraph { } } - /// Pick a representative term for a given Id. - /// - /// Calling this function on an uncanonical `Id` returns a representative based on the how it - /// was obtained (see [`add_uncanoncial`](EGraph::add_uncanonical), - /// [`add_expr_uncanonical`](EGraph::add_expr_uncanonical)) - pub fn id_to_expr(&self, id: Id) -> RecExpr { - if let Some(explain) = &self.explain { - explain.node_to_recexpr(id) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id"); - } - } - - /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep - pub fn id_to_node(&self, id: Id) -> &L { - if let Some(explain) = &self.explain { - explain.node(id) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id"); - } - } - - /// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term. + /// Like [`id_to_expr`](EGraphResidual::id_to_expr), but creates a pattern instead of a term. /// When an eclass listed in the given substitutions is found, it creates a variable. /// It also adds this variable and the corresponding Id value to the resulting [`Subst`] - /// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr). + /// Otherwise it behaves like [`id_to_expr`](EGraphResidual::id_to_expr). pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap) -> (Pattern, Subst) { - if let Some(explain) = &self.explain { - explain.node_to_pattern(id, substitutions) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique patterns per id"); + let mut res = Default::default(); + let mut subst = Default::default(); + let mut cache = Default::default(); + self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache); + (Pattern::new(res), subst) + } + + fn id_to_pattern_internal( + &self, + res: &mut PatternAst, + node_id: Id, + var_substitutions: &HashMap, + subst: &mut Subst, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; } + let res_id = if let Some(existing) = var_substitutions.get(&node_id) { + let var = format!("?{}", node_id).parse().unwrap(); + subst.insert(var, *existing); + res.add(ENodeOrVar::Var(var)) + } else { + let new_node = self.id_to_node(node_id).clone().map_children(|child| { + self.id_to_pattern_internal(res, child, var_substitutions, subst, cache) + }); + res.add(ENodeOrVar::ENode(new_node)) + }; + cache.insert(node_id, res_id); + res_id } /// Get all the unions ever found in the egraph in terms of enode ids. @@ -390,8 +376,8 @@ impl> EGraph { /// Get the number of congruences between nodes in the egraph. /// Only available when explanations are enabled. pub fn get_num_congr(&mut self) -> usize { - if let Some(explain) = &self.explain { - explain.get_num_congr::(&self.classes, &self.unionfind) + if let Some(explain) = &mut self.explain { + explain.with_raw_egraph(&self.inner).get_num_congr() } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -399,11 +385,7 @@ impl> EGraph { /// Get the number of nodes in the egraph used for explanations. pub fn get_explanation_num_nodes(&mut self) -> usize { - if let Some(explain) = &self.explain { - explain.get_num_nodes() - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") - } + self.number_of_uncanonical_nodes() } /// When explanations are enabled, this function @@ -423,10 +405,10 @@ impl> EGraph { self.explain_id_equivalence(left, right) } - /// Equivalent to calling [`explain_equivalence`](EGraph::explain_equivalence)`(`[`id_to_expr`](EGraph::id_to_expr)`(left),` - /// [`id_to_expr`](EGraph::id_to_expr)`(right))` but more efficient + /// Equivalent to calling [`explain_equivalence`](EGraph::explain_equivalence)`(`[`id_to_expr`](EGraphResidual::id_to_expr)`(left),` + /// [`id_to_expr`](EGraphResidual::id_to_expr)`(right))` but more efficient /// - /// This function picks representatives using [`id_to_expr`](EGraph::id_to_expr) so choosing + /// This function picks representatives using [`id_to_expr`](EGraphResidual::id_to_expr) so choosing /// `Id`s returned by functions like [`add_uncanonical`](EGraph::add_uncanonical) is important /// to control explanations pub fn explain_id_equivalence(&mut self, left: Id, right: Id) -> Explanation { @@ -438,7 +420,9 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.explain_equivalence::(left, right, &mut self.unionfind, &self.classes) + explain + .with_raw_egraph(&self.inner) + .explain_equivalence(left, right) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -457,11 +441,11 @@ impl> EGraph { self.explain_existance_id(id) } - /// Equivalent to calling [`explain_existance`](EGraph::explain_existance)`(`[`id_to_expr`](EGraph::id_to_expr)`(id))` + /// Equivalent to calling [`explain_existance`](EGraph::explain_existance)`(`[`id_to_expr`](EGraphResidual::id_to_expr)`(id))` /// but more efficient fn explain_existance_id(&mut self, id: Id) -> Explanation { if let Some(explain) = &mut self.explain { - explain.explain_existance(id) + explain.with_raw_egraph(&self.inner).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -475,7 +459,7 @@ impl> EGraph { ) -> Explanation { let id = self.add_instantiation_noncanonical(pattern, subst); if let Some(explain) = &mut self.explain { - explain.explain_existance(id) + explain.with_raw_egraph(&self.inner).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -498,58 +482,20 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.explain_equivalence::(left, right, &mut self.unionfind, &self.classes) + explain + .with_raw_egraph(&self.inner) + .explain_equivalence(left, right) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations."); } } - - /// Canonicalizes an eclass id. - /// - /// This corresponds to the `find` operation on the egraph's - /// underlying unionfind data structure. - /// - /// # Example - /// ``` - /// use egg::{*, SymbolLang as S}; - /// let mut egraph = EGraph::::default(); - /// let x = egraph.add(S::leaf("x")); - /// let y = egraph.add(S::leaf("y")); - /// assert_ne!(egraph.find(x), egraph.find(y)); - /// - /// egraph.union(x, y); - /// egraph.rebuild(); - /// assert_eq!(egraph.find(x), egraph.find(y)); - /// ``` - pub fn find(&self, id: Id) -> Id { - self.unionfind.find(id) - } - - /// This is private, but internals should use this whenever - /// possible because it does path compression. - fn find_mut(&mut self, id: Id) -> Id { - self.unionfind.find_mut(id) - } - - /// Creates a [`Dot`] to visualize this egraph. See [`Dot`]. - /// - pub fn dot(&self) -> Dot { - Dot { - egraph: self, - config: vec![], - use_anchors: true, - } - } } /// Given an `Id` using the `egraph[id]` syntax, retrieve the e-class. impl> std::ops::Index for EGraph { type Output = EClass; fn index(&self, id: Id) -> &Self::Output { - let id = self.find(id); - self.classes - .get(&id) - .unwrap_or_else(|| panic!("Invalid id {}", id)) + self.inner.get_class(id) } } @@ -557,10 +503,7 @@ impl> std::ops::Index for EGraph { /// reference to the e-class. impl> std::ops::IndexMut for EGraph { fn index_mut(&mut self, id: Id) -> &mut Self::Output { - let id = self.find_mut(id); - self.classes - .get_mut(&id) - .unwrap_or_else(|| panic!("Invalid id {}", id)) + self.inner.get_class_mut(id).0 } } @@ -586,16 +529,16 @@ impl> EGraph { /// Similar to [`add_expr`](EGraph::add_expr) but the `Id` returned may not be canonical /// - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` + /// Calling [`id_to_expr`](EGraphResidual::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled pub fn add_expr_uncanonical(&mut self, expr: &RecExpr) -> Id { let nodes = expr.as_ref(); let mut new_ids = Vec::with_capacity(nodes.len()); let mut new_node_q = Vec::with_capacity(nodes.len()); for node in nodes { let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]); - let size_before = self.unionfind.size(); + let size_before = self.inner.number_of_uncanonical_nodes(); let next_id = self.add_uncanonical(new_node); - if self.unionfind.size() > size_before { + if self.inner.number_of_uncanonical_nodes() > size_before { new_node_q.push(true); } else { new_node_q.push(false); @@ -624,7 +567,7 @@ impl> EGraph { /// canonical /// /// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an corrispond to the + /// Calling [`id_to_expr`](EGraphResidual::id_to_expr) on this `Id` return an correspond to the /// instantiation of the pattern fn add_instantiation_noncanonical(&mut self, pat: &PatternAst, subst: &Subst) -> Id { let nodes = pat.as_ref(); @@ -639,9 +582,9 @@ impl> EGraph { } ENodeOrVar::ENode(node) => { let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]); - let size_before = self.unionfind.size(); + let size_before = self.inner.number_of_uncanonical_nodes(); let next_id = self.add_uncanonical(new_node); - if self.unionfind.size() > size_before { + if self.inner.number_of_uncanonical_nodes() > size_before { new_node_q.push(true); } else { new_node_q.push(false); @@ -661,67 +604,6 @@ impl> EGraph { *new_ids.last().unwrap() } - /// Lookup the eclass of the given enode. - /// - /// You can pass in either an owned enode or a `&mut` enode, - /// in which case the enode's children will be canonicalized. - /// - /// # Example - /// ``` - /// # use egg::*; - /// let mut egraph: EGraph = Default::default(); - /// let a = egraph.add(SymbolLang::leaf("a")); - /// let b = egraph.add(SymbolLang::leaf("b")); - /// - /// // lookup will find this node if its in the egraph - /// let mut node_f_ab = SymbolLang::new("f", vec![a, b]); - /// assert_eq!(egraph.lookup(node_f_ab.clone()), None); - /// let id = egraph.add(node_f_ab.clone()); - /// assert_eq!(egraph.lookup(node_f_ab.clone()), Some(id)); - /// - /// // if the query node isn't canonical, and its passed in by &mut instead of owned, - /// // its children will be canonicalized - /// egraph.union(a, b); - /// egraph.rebuild(); - /// assert_eq!(egraph.lookup(&mut node_f_ab), Some(id)); - /// assert_eq!(node_f_ab, SymbolLang::new("f", vec![a, a])); - /// ``` - pub fn lookup(&self, enode: B) -> Option - where - B: BorrowMut, - { - self.lookup_internal(enode).map(|id| self.find(id)) - } - - fn lookup_internal(&self, mut enode: B) -> Option - where - B: BorrowMut, - { - let enode = enode.borrow_mut(); - enode.update_children(|id| self.find(id)); - self.memo.get(enode).copied() - } - - /// Lookup the eclass of the given [`RecExpr`]. - /// - /// Equivalent to the last value in [`EGraph::lookup_expr_ids`]. - pub fn lookup_expr(&self, expr: &RecExpr) -> Option { - self.lookup_expr_ids(expr) - .and_then(|ids| ids.last().copied()) - } - - /// Lookup the eclasses of all the nodes in the given [`RecExpr`]. - pub fn lookup_expr_ids(&self, expr: &RecExpr) -> Option> { - let nodes = expr.as_ref(); - let mut new_ids = Vec::with_capacity(nodes.len()); - for node in nodes { - let node = node.clone().map_children(|i| new_ids[usize::from(i)]); - let id = self.lookup(node)?; - new_ids.push(id) - } - Some(new_ids) - } - /// Adds an enode to the [`EGraph`]. /// /// When adding an enode, to the egraph, [`add`] it performs @@ -741,10 +623,10 @@ impl> EGraph { /// Similar to [`add`](EGraph::add) but the `Id` returned may not be canonical /// - /// When explanations are enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will + /// When explanations are enabled calling [`id_to_expr`](EGraphResidual::id_to_expr) on this `Id` will /// correspond to the parameter `enode` /// - /// # Example + /// ## Example /// ``` /// # use egg::*; /// let mut egraph: EGraph = EGraph::default().with_explanations_enabled(); @@ -759,60 +641,67 @@ impl> EGraph { /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); /// assert_eq!(egraph.id_to_expr(fb), "(f b)".parse().unwrap()); /// ``` - pub fn add_uncanonical(&mut self, mut enode: L) -> Id { - let original = enode.clone(); - if let Some(existing_id) = self.lookup_internal(&mut enode) { - let id = self.find(existing_id); - // when explanations are enabled, we need a new representative for this expr - if let Some(explain) = self.explain.as_mut() { - if let Some(existing_explain) = explain.uncanon_memo.get(&original) { - *existing_explain + /// + /// When explanations are not enabled calling [`id_to_expr`](EGraphResidual::id_to_expr) on this `Id` will + /// produce an expression with equivalent but not necessarily identical children + /// + /// # Example + /// ``` + /// # use egg::*; + /// let mut egraph: EGraph = EGraph::default().with_explanations_disabled(); + /// let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + /// let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + /// egraph.union(a, b); + /// egraph.rebuild(); + /// + /// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a])); + /// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b])); + /// + /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); + /// assert_eq!(egraph.id_to_expr(fb), "(f a)".parse().unwrap()); + /// ``` + pub fn add_uncanonical(&mut self, enode: L) -> Id { + let mut added = false; + let id = RawEGraph::raw_add( + self, + |x| &mut x.inner, + enode, + |this, existing_id, enode| { + if let Some(explain) = this.explain.as_mut() { + if let Some(existing_id) = explain.uncanon_memo.get(enode) { + Some(*existing_id) + } else { + None + } } else { - let new_id = self.unionfind.make_set(); - explain.add(original, new_id, new_id); - self.unionfind.union(id, new_id); + Some(existing_id) + } + }, + |this, existing_id, new_id| { + if let Some(explain) = this.explain.as_mut() { + explain.add(this.inner.id_to_node(new_id).clone(), new_id, new_id); explain.union(existing_id, new_id, Justification::Congruence, true); - new_id } - } else { - existing_id - } - } else { - let id = self.make_new_eclass(enode); + }, + |this, id, _| { + added = true; + let node = this.id_to_node(id).clone(); + let data = N::make(this, &node); + EClassData { + nodes: vec![node], + data, + } + }, + ); + if added { if let Some(explain) = self.explain.as_mut() { - explain.add(original, id, id); + explain.add(self.inner.id_to_node(id).clone(), id, id); } // now that we updated explanations, run the analysis for the new eclass N::modify(self, id); self.clean = false; - id } - } - - /// This function makes a new eclass in the egraph (but doesn't touch explanations) - fn make_new_eclass(&mut self, enode: L) -> Id { - let id = self.unionfind.make_set(); - log::trace!(" ...adding to {}", id); - let class = EClass { - id, - nodes: vec![enode.clone()], - data: N::make(self, &enode), - parents: Default::default(), - }; - - // add this enode to the parent lists of its children - enode.for_each(|child| { - let tup = (enode.clone(), id); - self[child].parents.push(tup); - }); - - // TODO is this needed? - self.pending.push((enode.clone(), id)); - - self.classes.insert(id, class); - assert!(self.memo.insert(enode, id).is_none()); - id } @@ -858,9 +747,9 @@ impl> EGraph { rule_name: impl Into, ) -> (Id, bool) { let id1 = self.add_instantiation_noncanonical(from_pat, subst); - let size_before = self.unionfind.size(); + let size_before = self.number_of_uncanonical_nodes(); let id2 = self.add_instantiation_noncanonical(to_pat, subst); - let rhs_new = self.unionfind.size() > size_before; + let rhs_new = self.number_of_uncanonical_nodes() > size_before; let did_union = self.perform_union( id1, @@ -873,7 +762,7 @@ impl> EGraph { /// Unions two e-classes, using a given reason to justify it. /// - /// This function picks representatives using [`id_to_expr`](EGraph::id_to_expr) so choosing + /// This function picks representatives using [`id_to_expr`](EGraphResidual::id_to_expr) so choosing /// `Id`s returned by functions like [`add_uncanonical`](EGraph::add_uncanonical) is important /// to control explanations pub fn union_trusted(&mut self, from: Id, to: Id, reason: impl Into) -> bool { @@ -914,49 +803,37 @@ impl> EGraph { N::pre_union(self, enode_id1, enode_id2, &rule); self.clean = false; - let mut id1 = self.find_mut(enode_id1); - let mut id2 = self.find_mut(enode_id2); - if id1 == id2 { + let mut new_root = None; + self.inner.raw_union(enode_id1, enode_id2, |info| { + new_root = Some(info.id1); + + let did_merge = self.analysis.merge(&mut info.data1.data, info.data2.data); + if did_merge.0 { + self.analysis_pending + .extend(info.parents1.into_iter().copied()); + } + if did_merge.1 { + self.analysis_pending + .extend(info.parents2.into_iter().copied()); + } + + concat_vecs(&mut info.data1.nodes, info.data2.nodes); + }); + if let Some(id) = new_root { + if let Some(explain) = &mut self.explain { + explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs); + } + N::modify(self, id); + + true + } else { if let Some(Justification::Rule(_)) = rule { if let Some(explain) = &mut self.explain { - explain.alternate_rewrite(enode_id1, enode_id2, rule.unwrap()); + explain.alternate_rewrite(enode_id1, enode_id2, rule.unwrap()) } } - return false; - } - // make sure class2 has fewer parents - let class1_parents = self.classes[&id1].parents.len(); - let class2_parents = self.classes[&id2].parents.len(); - if class1_parents < class2_parents { - std::mem::swap(&mut id1, &mut id2); + false } - - if let Some(explain) = &mut self.explain { - explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs); - } - - // make id1 the new root - self.unionfind.union(id1, id2); - - assert_ne!(id1, id2); - let class2 = self.classes.remove(&id2).unwrap(); - let class1 = self.classes.get_mut(&id1).unwrap(); - assert_eq!(id1, class1.id); - - self.pending.extend(class2.parents.iter().cloned()); - let did_merge = self.analysis.merge(&mut class1.data, class2.data); - if did_merge.0 { - self.analysis_pending.extend(class1.parents.iter().cloned()); - } - if did_merge.1 { - self.analysis_pending.extend(class2.parents.iter().cloned()); - } - - concat_vecs(&mut class1.nodes, class2.nodes); - concat_vecs(&mut class1.parents, class2.parents); - - N::modify(self, id1); - true } /// Update the analysis data of an e-class. @@ -965,10 +842,9 @@ impl> EGraph { /// so [`Analysis::make`] and [`Analysis::merge`] will get /// called for other parts of the e-graph on rebuild. pub fn set_analysis_data(&mut self, id: Id, new_data: N::Data) { - let id = self.find_mut(id); - let class = self.classes.get_mut(&id).unwrap(); + let class = self.inner.get_class_mut(id).0; class.data = new_data; - self.analysis_pending.extend(class.parents.iter().cloned()); + self.analysis_pending.extend(class.parents()); N::modify(self, id) } @@ -981,7 +857,7 @@ impl> EGraph { /// /// [`Debug`]: std::fmt::Debug pub fn dump(&self) -> impl Debug + '_ { - EGraphDump(self) + self.inner.dump_classes() } } @@ -1020,9 +896,9 @@ impl> EGraph { classes_by_op.values_mut().for_each(|ids| ids.clear()); let mut trimmed = 0; - let uf = &mut self.unionfind; + let (classes, uf) = self.inner.classes_mut(); - for class in self.classes.values_mut() { + for class in classes { let old_len = class.len(); class .nodes @@ -1068,8 +944,8 @@ impl> EGraph { fn check_memo(&self) -> bool { let mut test_memo = HashMap::default(); - for (&id, class) in self.classes.iter() { - assert_eq!(class.id, id); + for class in self.classes() { + let id = class.id; for node in &class.nodes { if let Some(old) = test_memo.insert(node, id) { assert_eq!( @@ -1088,7 +964,7 @@ impl> EGraph { assert_eq!(e, self.find(e)); assert_eq!( Some(e), - self.memo.get(n).map(|id| self.find(*id)), + self.lookup(n.clone()), "Entry for {:?} at {} in test_memo was incorrect", n, e @@ -1102,34 +978,32 @@ impl> EGraph { fn process_unions(&mut self) -> usize { let mut n_unions = 0; - while !self.pending.is_empty() || !self.analysis_pending.is_empty() { - while let Some((mut node, class)) = self.pending.pop() { - node.update_children(|id| self.find_mut(id)); - if let Some(memo_class) = self.memo.insert(node, class) { - let did_something = self.perform_union( - memo_class, - class, - Some(Justification::Congruence), - false, - ); + while !self.inner.is_clean() || !self.analysis_pending.is_empty() { + RawEGraph::raw_rebuild( + self, + |this| &mut this.inner, + |this, id1, id2| { + let did_something = + this.perform_union(id1, id2, Some(Justification::Congruence), false); n_unions += did_something as usize; - } - } + }, + |_, _, _| {}, + ); - while let Some((node, class_id)) = self.analysis_pending.pop() { - let class_id = self.find_mut(class_id); + while let Some(mut class_id) = self.analysis_pending.pop() { + let node = self.id_to_node(class_id).clone(); let node_data = N::make(self, &node); - let class = self.classes.get_mut(&class_id).unwrap(); + let class = self.inner.get_class_mut(&mut class_id).0; let did_merge = self.analysis.merge(&mut class.data, node_data); if did_merge.0 { - self.analysis_pending.extend(class.parents.iter().cloned()); + self.analysis_pending.extend(class.parents()); N::modify(self, class_id) } } } - assert!(self.pending.is_empty()); + assert!(self.inner.is_clean()); assert!(self.analysis_pending.is_empty()); n_unions @@ -1173,7 +1047,7 @@ impl> EGraph { /// assert_eq!(egraph.find(ax), egraph.find(ay)); /// ``` pub fn rebuild(&mut self) -> usize { - let old_hc_size = self.memo.len(); + let old_hc_size = self.total_size(); let old_n_eclasses = self.number_of_classes(); let start = Instant::now(); @@ -1193,7 +1067,7 @@ impl> EGraph { elapsed.subsec_millis(), old_hc_size, old_n_eclasses, - self.memo.len(), + self.total_size(), self.number_of_classes(), n_unions, trimmed_nodes, @@ -1204,30 +1078,17 @@ impl> EGraph { n_unions } - pub(crate) fn check_each_explain(&self, rules: &[&Rewrite]) -> bool { - if let Some(explain) = &self.explain { - explain.check_each_explain(rules) + pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite]) -> bool { + if let Some(explain) = &mut self.explain { + explain + .with_raw_egraph(&self.inner) + .check_each_explain(rules) } else { panic!("Can't check explain when explanations are off"); } } } -struct EGraphDump<'a, L: Language, N: Analysis>(&'a EGraph); - -impl<'a, L: Language, N: Analysis> Debug for EGraphDump<'a, L, N> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut ids: Vec = self.0.classes().map(|c| c.id).collect(); - ids.sort(); - for id in ids { - let mut nodes = self.0[id].nodes.clone(); - nodes.sort(); - writeln!(f, "{} ({:?}): {:?}", id, self.0[id].data, nodes)? - } - Ok(()) - } -} - #[cfg(test)] mod tests { diff --git a/src/explain.rs b/src/explain.rs index 187aecfc..53a8ec61 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -1,14 +1,16 @@ use crate::Symbol; use crate::{ - util::pretty_print, Analysis, EClass, EGraph, ENodeOrVar, FromOp, HashMap, HashSet, Id, - Language, Pattern, PatternAst, RecExpr, Rewrite, Subst, UnionFind, Var, + util::pretty_print, Analysis, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, PatternAst, + RecExpr, Rewrite, UnionFind, Var, }; use saturating::Saturating; use std::cmp::Ordering; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::{self, Debug, Display, Formatter}; +use std::ops::{Deref, DerefMut}; use std::rc::Rc; +use crate::raw::RawEGraph; use symbolic_expressions::Sexp; type ProofCost = Saturating; @@ -38,8 +40,7 @@ struct Connection { #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] -struct ExplainNode { - node: L, +struct ExplainNode { // neighbors includes parent connections neighbors: Vec, parent_connection: Connection, @@ -54,8 +55,15 @@ struct ExplainNode { #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] pub struct Explain { - explainfind: Vec>, + explainfind: Vec, #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] + #[cfg_attr( + feature = "serde-1", + serde(bound( + serialize = "L: serde::Serialize", + deserialize = "L: serde::Deserialize<'de>", + )) + )] pub uncanon_memo: HashMap, /// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted. pub optimize_explanation_lengths: bool, @@ -69,6 +77,11 @@ pub struct Explain { shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>, } +pub(crate) struct ExplainWith<'a, L: Language, X> { + explain: &'a mut Explain, + raw: X, +} + #[derive(Default)] struct DistanceMemo { parent_distance: Vec<(Id, ProofCost)>, @@ -883,97 +896,6 @@ impl PartialOrd for HeapState { } impl Explain { - pub(crate) fn node(&self, node_id: Id) -> &L { - &self.explainfind[usize::from(node_id)].node - } - fn node_to_explanation( - &self, - node_id: Id, - cache: &mut NodeExplanationCache, - ) -> Rc> { - if let Some(existing) = cache.get(&node_id) { - existing.clone() - } else { - let node = self.node(node_id).clone(); - let children = node.fold(vec![], |mut sofar, child| { - sofar.push(vec![self.node_to_explanation(child, cache)]); - sofar - }); - let res = Rc::new(TreeTerm::new(node, children)); - cache.insert(node_id, res.clone()); - res - } - } - - pub(crate) fn node_to_recexpr(&self, node_id: Id) -> RecExpr { - let mut res = Default::default(); - let mut cache = Default::default(); - self.node_to_recexpr_internal(&mut res, node_id, &mut cache); - res - } - fn node_to_recexpr_internal( - &self, - res: &mut RecExpr, - node_id: Id, - cache: &mut HashMap, - ) { - let new_node = self.node(node_id).clone().map_children(|child| { - if let Some(existing) = cache.get(&child) { - *existing - } else { - self.node_to_recexpr_internal(res, child, cache); - Id::from(res.as_ref().len() - 1) - } - }); - res.add(new_node); - } - - pub(crate) fn node_to_pattern( - &self, - node_id: Id, - substitutions: &HashMap, - ) -> (Pattern, Subst) { - let mut res = Default::default(); - let mut subst = Default::default(); - let mut cache = Default::default(); - self.node_to_pattern_internal(&mut res, node_id, substitutions, &mut subst, &mut cache); - (Pattern::new(res), subst) - } - - fn node_to_pattern_internal( - &self, - res: &mut PatternAst, - node_id: Id, - var_substitutions: &HashMap, - subst: &mut Subst, - cache: &mut HashMap, - ) { - if let Some(existing) = var_substitutions.get(&node_id) { - let var = format!("?{}", node_id).parse().unwrap(); - res.add(ENodeOrVar::Var(var)); - subst.insert(var, *existing); - } else { - let new_node = self.node(node_id).clone().map_children(|child| { - if let Some(existing) = cache.get(&child) { - *existing - } else { - self.node_to_pattern_internal(res, child, var_substitutions, subst, cache); - Id::from(res.as_ref().len() - 1) - } - }); - res.add(ENodeOrVar::ENode(new_node)); - } - } - - fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm { - let node = self.node(node_id).clone(); - let children = node.fold(vec![], |mut sofar, child| { - sofar.push(self.node_to_flat_explanation(child)); - sofar - }); - FlatTerm::new(node, children) - } - fn make_rule_table<'a, N: Analysis>( rules: &[&'a Rewrite], ) -> HashMap> { @@ -983,52 +905,6 @@ impl Explain { } table } - - pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { - let rule_table = Explain::make_rule_table(rules); - for i in 0..self.explainfind.len() { - let explain_node = &self.explainfind[i]; - - // check that explanation reasons never form a cycle - let mut existance = i; - let mut seen_existance: HashSet = Default::default(); - loop { - seen_existance.insert(existance); - let next = usize::from(self.explainfind[existance].existance_node); - if existance == next { - break; - } - existance = next; - if seen_existance.contains(&existance) { - panic!("Cycle in existance!"); - } - } - - if explain_node.parent_connection.next != Id::from(i) { - let mut current_explanation = self.node_to_flat_explanation(Id::from(i)); - let mut next_explanation = - self.node_to_flat_explanation(explain_node.parent_connection.next); - if let Justification::Rule(rule_name) = - &explain_node.parent_connection.justification - { - if let Some(rule) = rule_table.get(rule_name) { - if !explain_node.parent_connection.is_rewrite_forward { - std::mem::swap(&mut current_explanation, &mut next_explanation); - } - if !Explanation::check_rewrite( - ¤t_explanation, - &next_explanation, - rule, - ) { - return false; - } - } - } - } - } - true - } - pub fn new() -> Self { Explain { explainfind: vec![], @@ -1044,9 +920,8 @@ impl Explain { pub(crate) fn add(&mut self, node: L, set: Id, existance_node: Id) -> Id { assert_eq!(self.explainfind.len(), usize::from(set)); - self.uncanon_memo.insert(node.clone(), set); + self.uncanon_memo.insert(node, set); self.explainfind.push(ExplainNode { - node, neighbors: vec![], parent_connection: Connection { justification: Justification::Congruence, @@ -1119,7 +994,7 @@ impl Explain { new_rhs: bool, ) { if let Justification::Congruence = justification { - assert!(self.node(node1).matches(self.node(node2))); + // assert!(self.node(node1).matches(self.node(node2))); } if new_rhs { self.set_existance_reason(node2, node1) @@ -1155,7 +1030,6 @@ impl Explain { .push(other_pconnection); self.explainfind[usize::from(node1)].parent_connection = pconnection; } - pub(crate) fn get_union_equalities(&self) -> UnionEqualities { let mut equalities = vec![]; for node in &self.explainfind { @@ -1170,24 +1044,105 @@ impl Explain { equalities } - pub(crate) fn populate_enodes>(&self, mut egraph: EGraph) -> EGraph { - for i in 0..self.explainfind.len() { - let node = &self.explainfind[i]; - egraph.add(node.node.clone()); + pub(crate) fn with_raw_egraph<'a, X>(&'a mut self, raw: X) -> ExplainWith<'a, L, X> { + ExplainWith { explain: self, raw } + } +} + +impl<'a, L: Language, X> Deref for ExplainWith<'a, L, X> { + type Target = Explain; + + fn deref(&self) -> &Self::Target { + self.explain + } +} + +impl<'a, L: Language, X> DerefMut for ExplainWith<'a, L, X> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.explain + } +} + +impl<'x, L: Language, D> ExplainWith<'x, L, &'x RawEGraph> { + pub(crate) fn node(&self, node_id: Id) -> &L { + self.raw.id_to_node(node_id) + } + fn node_to_explanation( + &self, + node_id: Id, + cache: &mut NodeExplanationCache, + ) -> Rc> { + if let Some(existing) = cache.get(&node_id) { + existing.clone() + } else { + let node = self.node(node_id).clone(); + let children = node.fold(vec![], |mut sofar, child| { + sofar.push(vec![self.node_to_explanation(child, cache)]); + sofar + }); + let res = Rc::new(TreeTerm::new(node, children)); + cache.insert(node_id, res.clone()); + res } + } - egraph + fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm { + let node = self.node(node_id).clone(); + let children = node.fold(vec![], |mut sofar, child| { + sofar.push(self.node_to_flat_explanation(child)); + sofar + }); + FlatTerm::new(node, children) } - pub(crate) fn explain_equivalence>( - &mut self, - left: Id, - right: Id, - unionfind: &mut UnionFind, - classes: &HashMap>, - ) -> Explanation { + pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { + let rule_table = Explain::make_rule_table(rules); + for i in 0..self.explainfind.len() { + let explain_node = &self.explainfind[i]; + + // check that explanation reasons never form a cycle + let mut existance = i; + let mut seen_existance: HashSet = Default::default(); + loop { + seen_existance.insert(existance); + let next = usize::from(self.explainfind[existance].existance_node); + if existance == next { + break; + } + existance = next; + if seen_existance.contains(&existance) { + panic!("Cycle in existance!"); + } + } + + if explain_node.parent_connection.next != Id::from(i) { + let mut current_explanation = self.node_to_flat_explanation(Id::from(i)); + let mut next_explanation = + self.node_to_flat_explanation(explain_node.parent_connection.next); + if let Justification::Rule(rule_name) = + &explain_node.parent_connection.justification + { + if let Some(rule) = rule_table.get(rule_name) { + if !explain_node.parent_connection.is_rewrite_forward { + std::mem::swap(&mut current_explanation, &mut next_explanation); + } + if !Explanation::check_rewrite( + ¤t_explanation, + &next_explanation, + rule, + ) { + return false; + } + } + } + } + } + true + } + + pub(crate) fn explain_equivalence(&mut self, left: Id, right: Id) -> Explanation { if self.optimize_explanation_lengths { - self.calculate_shortest_explanations::(left, right, classes, unionfind); + self.calculate_shortest_explanations(left, right); } let mut cache = Default::default(); @@ -1328,7 +1283,7 @@ impl Explain { let mut new_rest_of_proof = (*self.node_to_explanation(existance, enode_cache)).clone(); let mut index_of_child = 0; let mut found = false; - existance_node.node.for_each(|child| { + self.node(existance).for_each(|child| { if found { return; } @@ -1625,12 +1580,7 @@ impl Explain { distance_memo.parent_distance[usize::from(enode)].1 } - fn find_congruence_neighbors>( - &self, - classes: &HashMap>, - congruence_neighbors: &mut [Vec], - unionfind: &UnionFind, - ) { + fn find_congruence_neighbors(&self, congruence_neighbors: &mut [Vec]) { let mut counter = 0; // add the normal congruence edges first for node in &self.explainfind { @@ -1643,15 +1593,15 @@ impl Explain { } } - 'outer: for eclass in classes.keys() { - let enodes = self.find_all_enodes(*eclass); + 'outer: for eclass in self.raw.classes().map(|x| x.id) { + let enodes = self.find_all_enodes(eclass); // find all congruence nodes let mut cannon_enodes: HashMap> = Default::default(); for enode in &enodes { let cannon = self .node(*enode) .clone() - .map_children(|child| unionfind.find(child)); + .map_children(|child| self.raw.find(child)); if let Some(others) = cannon_enodes.get_mut(&cannon) { for other in others.iter() { congruence_neighbors[usize::from(*enode)].push(*other); @@ -1671,13 +1621,9 @@ impl Explain { } } - pub fn get_num_congr>( - &self, - classes: &HashMap>, - unionfind: &UnionFind, - ) -> usize { + pub fn get_num_congr(&self) -> usize { let mut congruence_neighbors = vec![vec![]; self.explainfind.len()]; - self.find_congruence_neighbors::(classes, &mut congruence_neighbors, unionfind); + self.find_congruence_neighbors(&mut congruence_neighbors); let mut count = 0; for v in congruence_neighbors { count += v.len(); @@ -1686,10 +1632,6 @@ impl Explain { count / 2 } - pub fn get_num_nodes(&self) -> usize { - self.explainfind.len() - } - fn shortest_path_modulo_congruence( &mut self, start: Id, @@ -1888,11 +1830,7 @@ impl Explain { self.explainfind[usize::from(enode)].parent_connection.next } - fn calculate_common_ancestor>( - &self, - classes: &HashMap>, - congruence_neighbors: &[Vec], - ) -> HashMap<(Id, Id), Id> { + fn calculate_common_ancestor(&self, congruence_neighbors: &[Vec]) -> HashMap<(Id, Id), Id> { let mut common_ancestor_queries = HashMap::default(); for (s_int, others) in congruence_neighbors.iter().enumerate() { let start = &Id::from(s_int); @@ -1924,8 +1862,8 @@ impl Explain { unionfind.make_set(); ancestor.push(Id::from(i)); } - for (eclass, _) in classes.iter() { - let enodes = self.find_all_enodes(*eclass); + for eclass in self.raw.classes().map(|x| x.id) { + let enodes = self.find_all_enodes(eclass); let mut children: HashMap> = HashMap::default(); for enode in &enodes { children.insert(*enode, vec![]); @@ -1956,15 +1894,9 @@ impl Explain { common_ancestor } - fn calculate_shortest_explanations>( - &mut self, - start: Id, - end: Id, - classes: &HashMap>, - unionfind: &UnionFind, - ) { + fn calculate_shortest_explanations(&mut self, start: Id, end: Id) { let mut congruence_neighbors = vec![vec![]; self.explainfind.len()]; - self.find_congruence_neighbors::(classes, &mut congruence_neighbors, unionfind); + self.find_congruence_neighbors(&mut congruence_neighbors); let mut parent_distance = vec![(Id::from(0), Saturating(0)); self.explainfind.len()]; for (i, entry) in parent_distance.iter_mut().enumerate() { entry.0 = Id::from(i); @@ -1972,7 +1904,7 @@ impl Explain { let mut distance_memo = DistanceMemo { parent_distance, - common_ancestor: self.calculate_common_ancestor::(classes, &congruence_neighbors), + common_ancestor: self.calculate_common_ancestor(&congruence_neighbors), tree_depth: self.calculate_tree_depths(), }; @@ -2092,7 +2024,7 @@ mod tests { #[test] fn simple_explain_union_trusted() { - use crate::SymbolLang; + use crate::{EGraph, SymbolLang}; crate::init_logger(); let mut egraph = EGraph::new(()).with_explanations_enabled(); diff --git a/src/lib.rs b/src/lib.rs index 5a293a58..298b3518 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,9 @@ mod lp_extract; mod machine; mod multipattern; mod pattern; + +/// Lower level egraph API +pub mod raw; mod rewrite; mod run; mod subst; @@ -89,7 +92,7 @@ pub(crate) use {explain::Explain, unionfind::UnionFind}; pub use { dot::Dot, - eclass::EClass, + eclass::{EClass, EClassData}, egraph::EGraph, explain::{ Explanation, FlatExplanation, FlatTerm, Justification, TreeExplanation, TreeTerm, diff --git a/src/raw.rs b/src/raw.rs new file mode 100644 index 00000000..22395db6 --- /dev/null +++ b/src/raw.rs @@ -0,0 +1,5 @@ +mod eclass; +mod egraph; + +pub use eclass::RawEClass; +pub use egraph::{EGraphResidual, RawEGraph}; diff --git a/src/raw/eclass.rs b/src/raw/eclass.rs new file mode 100644 index 00000000..dd6e43be --- /dev/null +++ b/src/raw/eclass.rs @@ -0,0 +1,43 @@ +use crate::Id; +use std::fmt::Debug; +use std::iter::ExactSizeIterator; +use std::ops::{Deref, DerefMut}; + +/// An equivalence class of enodes. +#[non_exhaustive] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct RawEClass { + /// This eclass's id. + pub id: Id, + /// Arbitrary data associated with this eclass. + pub(super) raw_data: D, + /// The original Ids of parent enodes. + pub(super) parents: Vec, +} + +impl RawEClass { + /// Iterates over the non-canonical ids of parent enodes of this eclass. + pub fn parents(&self) -> impl ExactSizeIterator + '_ { + self.parents.iter().copied() + } + + /// Consumes `self` returning the stored data and an iterator similar to [`parents`](RawEClass::parents) + pub fn destruct(self) -> (D, impl ExactSizeIterator) { + (self.raw_data, self.parents.into_iter()) + } +} + +impl Deref for RawEClass { + type Target = D; + + fn deref(&self) -> &D { + &self.raw_data + } +} + +impl DerefMut for RawEClass { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.raw_data + } +} diff --git a/src/raw/egraph.rs b/src/raw/egraph.rs new file mode 100644 index 00000000..383b1361 --- /dev/null +++ b/src/raw/egraph.rs @@ -0,0 +1,728 @@ +use crate::{raw::RawEClass, Dot, HashMap, Id, Language, RecExpr, UnionFind}; +use std::convert::Infallible; +use std::ops::{Deref, DerefMut}; +use std::{ + borrow::BorrowMut, + fmt::{self, Debug}, + iter, slice, +}; + +#[cfg(feature = "serde-1")] +use serde::{Deserialize, Serialize}; + +pub struct Parents<'a>(&'a [Id]); + +impl<'a> IntoIterator for Parents<'a> { + type Item = Id; + type IntoIter = iter::Copied>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter().copied() + } +} + +/// A [`RawEGraph`] without its classes that can be obtained by dereferencing a [`RawEGraph`]. +/// +/// It exists as a separate type so that it can still be used while mutably borrowing a [`RawEClass`] +/// +/// See [`RawEGraph::classes_mut`], [`RawEGraph::get_class_mut`] +#[derive(Clone)] +#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] +pub struct EGraphResidual { + unionfind: UnionFind, + /// Stores the original node represented by each non-canonical id + nodes: Vec, + /// Stores each enode's `Id`, not the `Id` of the eclass. + /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new + /// unions can cause them to become out of date. + #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] + memo: HashMap, +} + +impl EGraphResidual { + /// Pick a representative term for a given Id. + /// + /// Calling this function on an uncanonical `Id` returns a representative based on how it + /// was obtained + pub fn id_to_expr(&self, id: Id) -> RecExpr { + let mut res = Default::default(); + let mut cache = Default::default(); + self.id_to_expr_internal(&mut res, id, &mut cache); + res + } + + fn id_to_expr_internal( + &self, + res: &mut RecExpr, + node_id: Id, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; + } + let new_node = self + .id_to_node(node_id) + .clone() + .map_children(|child| self.id_to_expr_internal(res, child, cache)); + let res_id = res.add(new_node); + cache.insert(node_id, res_id); + res_id + } + + /// Like [`id_to_expr`](EGraphResidual::id_to_expr) but only goes one layer deep + pub fn id_to_node(&self, id: Id) -> &L { + &self.nodes[usize::from(id)] + } + + /// Canonicalizes an eclass id. + /// + /// This corresponds to the `find` operation on the egraph's + /// underlying unionfind data structure. + /// + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// let mut egraph = RawEGraph::::default(); + /// let x = egraph.add_uncanonical(S::leaf("x")); + /// let y = egraph.add_uncanonical(S::leaf("y")); + /// assert_ne!(egraph.find(x), egraph.find(y)); + /// + /// egraph.union(x, y); + /// egraph.rebuild(); + /// assert_eq!(egraph.find(x), egraph.find(y)); + /// ``` + pub fn find(&self, id: Id) -> Id { + self.unionfind.find(id) + } + + /// Same as [`find`](EGraphResidual::find) but requires mutable access since it does path compression + pub fn find_mut(&mut self, id: Id) -> Id { + self.unionfind.find_mut(id) + } + + /// Returns `true` if the egraph is empty + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// let mut egraph = RawEGraph::::default(); + /// assert!(egraph.is_empty()); + /// egraph.add_uncanonical(S::leaf("foo")); + /// assert!(!egraph.is_empty()); + /// ``` + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + /// Returns the number of uncanonical enodes in the `EGraph`. + /// + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// let mut egraph = RawEGraph::::default(); + /// let x = egraph.add_uncanonical(S::leaf("x")); + /// let y = egraph.add_uncanonical(S::leaf("y")); + /// let fx = egraph.add_uncanonical(S::new("f", vec![x])); + /// let fy = egraph.add_uncanonical(S::new("f", vec![y])); + /// // only one eclass + /// egraph.union(x, y); + /// egraph.rebuild(); + /// + /// assert_eq!(egraph.number_of_uncanonical_nodes(), 4); + /// assert_eq!(egraph.number_of_classes(), 2); + /// ``` + pub fn number_of_uncanonical_nodes(&self) -> usize { + self.nodes.len() + } + + /// Returns an iterator over the uncanonical ids in the egraph and the node + /// that would be obtained by calling [`id_to_node`](EGraphResidual::id_to_node) on each of them + pub fn uncanonical_nodes(&self) -> impl ExactSizeIterator { + self.nodes + .iter() + .enumerate() + .map(|(id, node)| (Id::from(id), node)) + } + + /// Returns the number of enodes in the `EGraph`. + /// + /// Actually returns the size of the hashcons index. + /// # Example + /// ``` + /// use egg::{*, SymbolLang as S}; + /// let mut egraph = EGraph::::default(); + /// let x = egraph.add(S::leaf("x")); + /// let y = egraph.add(S::leaf("y")); + /// // only one eclass + /// egraph.union(x, y); + /// egraph.rebuild(); + /// + /// assert_eq!(egraph.total_size(), 2); + /// assert_eq!(egraph.number_of_classes(), 1); + /// ``` + pub fn total_size(&self) -> usize { + self.memo.len() + } + + /// Lookup the eclass of the given enode. + /// + /// You can pass in either an owned enode or a `&mut` enode, + /// in which case the enode's children will be canonicalized. + /// + /// # Example + /// ``` + /// # use egg::*; + /// let mut egraph: EGraph = Default::default(); + /// let a = egraph.add(SymbolLang::leaf("a")); + /// let b = egraph.add(SymbolLang::leaf("b")); + /// + /// // lookup will find this node if its in the egraph + /// let mut node_f_ab = SymbolLang::new("f", vec![a, b]); + /// assert_eq!(egraph.lookup(node_f_ab.clone()), None); + /// let id = egraph.add(node_f_ab.clone()); + /// assert_eq!(egraph.lookup(node_f_ab.clone()), Some(id)); + /// + /// // if the query node isn't canonical, and its passed in by &mut instead of owned, + /// // its children will be canonicalized + /// egraph.union(a, b); + /// egraph.rebuild(); + /// assert_eq!(egraph.lookup(&mut node_f_ab), Some(id)); + /// assert_eq!(node_f_ab, SymbolLang::new("f", vec![a, a])); + /// ``` + pub fn lookup(&self, enode: B) -> Option + where + B: BorrowMut, + { + self.lookup_internal(enode).map(|id| self.find(id)) + } + + #[inline] + fn lookup_internal(&self, mut enode: B) -> Option + where + B: BorrowMut, + { + let enode = enode.borrow_mut(); + enode.update_children(|id| self.find(id)); + self.memo.get(enode).copied() + } + + /// Lookup the eclass of the given [`RecExpr`]. + /// + /// Equivalent to the last value in [`EGraphResidual::lookup_expr_ids`]. + pub fn lookup_expr(&self, expr: &RecExpr) -> Option { + self.lookup_expr_ids(expr) + .and_then(|ids| ids.last().copied()) + } + + /// Lookup the eclasses of all the nodes in the given [`RecExpr`]. + pub fn lookup_expr_ids(&self, expr: &RecExpr) -> Option> { + let nodes = expr.as_ref(); + let mut new_ids = Vec::with_capacity(nodes.len()); + for node in nodes { + let node = node.clone().map_children(|i| new_ids[usize::from(i)]); + let id = self.lookup(node)?; + new_ids.push(id) + } + Some(new_ids) + } + + /// Generate a mapping from canonical ids to the list of nodes they represent + pub fn generate_class_nodes(&self) -> HashMap> { + let mut classes = HashMap::default(); + let find = |id| self.find(id); + for (id, node) in self.uncanonical_nodes() { + let id = find(id); + let node = node.clone().map_children(find); + match classes.get_mut(&id) { + None => { + classes.insert(id, vec![node]); + } + Some(x) => x.push(node), + } + } + + // define all the nodes, clustered by eclass + for class in classes.values_mut() { + class.sort_unstable(); + class.dedup(); + } + classes + } + + /// Returns a more debug-able representation of the egraph focusing on its uncanonical ids and nodes. + /// + /// [`RawEGraph`]s implement [`Debug`], but it's not pretty. It + /// prints a lot of stuff you probably don't care about. + /// This method returns a wrapper that implements [`Debug`] in a + /// slightly nicer way, just dumping enodes in each eclass. + /// + /// [`Debug`]: std::fmt::Debug + pub fn dump_uncanonical(&self) -> impl Debug + '_ { + EGraphUncanonicalDump(self) + } + + /// Creates a [`Dot`] to visualize this egraph. See [`Dot`]. + pub fn dot(&self) -> Dot<'_, L> { + Dot { + egraph: self, + config: vec![], + use_anchors: true, + } + } +} + +// manual debug impl to avoid L: Language bound on EGraph defn +impl Debug for EGraphResidual { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("EGraphResidual") + .field("unionfind", &self.unionfind) + .field("nodes", &self.nodes) + .field("memo", &self.memo) + .finish() + } +} + +/** A data structure to keep track of equalities between expressions. + +Check out the [background tutorial](crate::tutorials::_01_background) +for more information on e-graphs in general. + +# E-graphs in `egg::raw` + +In `egg::raw`, the main types associated with e-graphs are +[`RawEGraph`], [`RawEClass`], [`Language`], and [`Id`]. + +[`RawEGraph`] and [`RawEClass`] are all generic over a +[`Language`], meaning that types actually floating around in the +egraph are all user-defined. +In particular, the e-nodes are elements of your [`Language`]. +[`RawEGraph`]s and [`RawEClass`]es are additionally parameterized by some +abritrary data associated with each e-class. + +Many methods of [`RawEGraph`] deal with [`Id`]s, which represent e-classes. +Because eclasses are frequently merged, many [`Id`]s will refer to the +same e-class. + +[`RawEGraph`] provides a low level API for dealing with egraphs, in particular with handling the data +stored in each [`RawEClass`] so user will likely want to implemented wrappers around +[`raw_add`](RawEGraph::raw_add), [`raw_union`](RawEGraph::raw_union), and [`raw_rebuild`](RawEGraph::raw_rebuild) +to properly handle this data + **/ +#[derive(Clone)] +#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] +pub struct RawEGraph { + #[cfg_attr(feature = "serde-1", serde(flatten))] + residual: EGraphResidual, + /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, + /// not the canonical id of the eclass. + pending: Vec, + classes: HashMap>, +} + +impl Default for RawEGraph { + fn default() -> Self { + let residual = EGraphResidual { + unionfind: Default::default(), + nodes: Default::default(), + memo: Default::default(), + }; + RawEGraph { + residual, + pending: Default::default(), + classes: Default::default(), + } + } +} + +impl Deref for RawEGraph { + type Target = EGraphResidual; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.residual + } +} + +impl DerefMut for RawEGraph { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.residual + } +} + +// manual debug impl to avoid L: Language bound on EGraph defn +impl Debug for RawEGraph { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("EGraph") + .field("memo", &self.residual.memo) + .field("classes", &self.classes) + .finish() + } +} + +impl RawEGraph { + /// Returns an iterator over the eclasses in the egraph. + pub fn classes(&self) -> impl ExactSizeIterator> { + self.classes.iter().map(|(id, class)| { + debug_assert_eq!(*id, class.id); + class + }) + } + + /// Returns a mutating iterator over the eclasses in the egraph. + /// Also returns the [`EGraphResidual`] so it can still be used while `self` is borrowed + pub fn classes_mut( + &mut self, + ) -> ( + impl ExactSizeIterator>, + &mut EGraphResidual, + ) { + let iter = self.classes.iter_mut().map(|(id, class)| { + debug_assert_eq!(*id, class.id); + class + }); + (iter, &mut self.residual) + } + + /// Returns the number of eclasses in the egraph. + pub fn number_of_classes(&self) -> usize { + self.classes().len() + } + + /// Returns the eclass corresponding to `id` + pub fn get_class>(&self, mut id: I) -> &RawEClass { + let id = id.borrow_mut(); + *id = self.find(*id); + self.get_class_with_cannon(*id) + } + + /// Like [`get_class`](RawEGraph::get_class) but panics if `id` is not canonical + pub fn get_class_with_cannon(&self, id: Id) -> &RawEClass { + self.classes + .get(&id) + .unwrap_or_else(|| panic!("Invalid id {}", id)) + } + + /// Returns the eclass corresponding to `id` + /// Also returns the [`EGraphResidual`] so it can still be used while `self` is borrowed + pub fn get_class_mut>( + &mut self, + mut id: I, + ) -> (&mut RawEClass, &mut EGraphResidual) { + let id = id.borrow_mut(); + *id = self.find_mut(*id); + self.get_class_mut_with_cannon(*id) + } + + /// Like [`get_class_mut`](RawEGraph::get_class_mut) but panics if `id` is not canonical + pub fn get_class_mut_with_cannon( + &mut self, + id: Id, + ) -> (&mut RawEClass, &mut EGraphResidual) { + ( + self.classes + .get_mut(&id) + .unwrap_or_else(|| panic!("Invalid id {}", id)), + &mut self.residual, + ) + } +} + +/// Information for [`RawEGraph::raw_union`] callback +#[non_exhaustive] +pub struct MergeInfo<'a, D: 'a> { + /// id that will be the root for the newly merged eclass + pub id1: Id, + /// data associated with `id1` that can be modified to reflect `data2` being merged into it + pub data1: &'a mut D, + /// parents of `id1` before the merge + pub parents1: &'a [Id], + /// id that used to be a root but will now be in `id1` eclass + pub id2: Id, + /// data associated with `id2` + pub data2: D, + /// parents of `id2` before the merge + pub parents2: &'a [Id], + /// true if `id1` was the root of the second id passed to [`RawEGraph::raw_union`] + /// false if `id1` was the root of the first id passed to [`RawEGraph::raw_union`] + pub swapped_ids: bool, +} + +impl RawEGraph { + /// Adds `enode` to a [`RawEGraph`] contained within a wrapper type `T` + /// + /// ## Parameters + /// + /// ### `get_self` + /// Called to extract the [`RawEGraph`] from the wrapper type, and should not perform any mutation. + /// + /// This will likely be a simple field access or just the identity function if there is no wrapper type. + /// + /// ### `handle_equiv` + /// When there already exists a node that is congruently equivalent to `enode` in the egraph + /// this function is called with the uncanonical id of a equivalent node, and a reference to `enode` + /// + /// Returning `Some(id)` will cause `raw_add` to immediately return `id` + /// (in this case `id` should represent an enode that is equivalent to the one being inserted). + /// + /// Returning `None` will cause `raw_add` to create a new id for `enode`, union it to the equivalent node, + /// and then return it. + /// + /// ### `handle_union` + /// Called after `handle_equiv` returns `None` with the uncanonical id of the equivalent node + /// and the new `id` assigned to `enode` + /// + /// Calling [`id_to_node`](EGraphResidual::id_to_node) on the new `id` will return a reference to `enode` + /// + /// ### `mk_data` + /// When there does not already exist a node is congruently equivalent to `enode` in the egraph + /// this function is called with the new `id` assigned to `enode` and a reference to the canonicalized version of + /// `enode` to create to data that will be stored in the [`RawEClass`] associated with it + /// + /// Calling [`id_to_node`](EGraphResidual::id_to_node) on the new `id` will return a reference to `enode` + /// + /// Calling [`get_class`](RawEGraph::get_class) on the new `id` will cause a panic since the [`RawEClass`] is + /// still being built + #[inline] + pub fn raw_add( + outer: &mut T, + get_self: impl Fn(&mut T) -> &mut Self, + mut enode: L, + handle_equiv: impl FnOnce(&mut T, Id, &L) -> Option, + handle_union: impl FnOnce(&mut T, Id, Id), + mk_data: impl FnOnce(&mut T, Id, &L) -> D, + ) -> Id { + let this = get_self(outer); + let original = enode.clone(); + if let Some(existing_id) = this.lookup_internal(&mut enode) { + let canon_id = this.find(existing_id); + // when explanations are enabled, we need a new representative for this expr + if let Some(existing_id) = handle_equiv(outer, existing_id, &original) { + existing_id + } else { + let this = get_self(outer); + let new_id = this.residual.unionfind.make_set(); + debug_assert_eq!(Id::from(this.nodes.len()), new_id); + this.residual.nodes.push(original); + this.residual.unionfind.union(canon_id, new_id); + handle_union(outer, existing_id, new_id); + new_id + } + } else { + let id = this.residual.unionfind.make_set(); + debug_assert_eq!(Id::from(this.nodes.len()), id); + this.residual.nodes.push(original); + + log::trace!(" ...adding to {}", id); + let class = RawEClass { + id, + raw_data: mk_data(outer, id, &enode), + parents: Default::default(), + }; + let this = get_self(outer); + + // add this enode to the parent lists of its children + enode.for_each(|child| { + this.get_class_mut(child).0.parents.push(id); + }); + + // TODO is this needed? + this.pending.push(id); + + this.classes.insert(id, class); + assert!(this.residual.memo.insert(enode, id).is_none()); + + id + } + } + + /// Unions two eclasses given their ids. + /// + /// The given ids need not be canonical. + /// + /// If a union occurs, `merge` is called with the data, id, and parents of the two eclasses being merged + #[inline] + pub fn raw_union( + &mut self, + enode_id1: Id, + enode_id2: Id, + merge: impl FnOnce(MergeInfo<'_, D>), + ) { + let mut id1 = self.find_mut(enode_id1); + let mut id2 = self.find_mut(enode_id2); + if id1 == id2 { + return; + } + // make sure class2 has fewer parents + let class1_parents = self.classes[&id1].parents.len(); + let class2_parents = self.classes[&id2].parents.len(); + let mut swapped = false; + if class1_parents < class2_parents { + swapped = true; + std::mem::swap(&mut id1, &mut id2); + } + + // make id1 the new root + self.residual.unionfind.union(id1, id2); + + assert_ne!(id1, id2); + let class2 = self.classes.remove(&id2).unwrap(); + let class1 = self.classes.get_mut(&id1).unwrap(); + assert_eq!(id1, class1.id); + let info = MergeInfo { + id1: class1.id, + data1: &mut class1.raw_data, + parents1: &class1.parents, + id2: class2.id, + data2: class2.raw_data, + parents2: &class2.parents, + swapped_ids: swapped, + }; + merge(info); + + self.pending.extend(&class2.parents); + + class1.parents.extend(class2.parents); + } + + /// Rebuild to [`RawEGraph`] to restore congruence closure + /// + /// ## Parameters + /// + /// ### `get_self` + /// Called to extract the [`RawEGraph`] from the wrapper type, and should not perform any mutation. + /// + /// This will likely be a simple field access or just the identity function if there is no wrapper type. + /// + /// ### `perform_union` + /// Called on each pair of ids that needs to be unioned + /// + /// In order to be correct `perform_union` should call [`raw_union`](RawEGraph::raw_union) + /// + /// ### `handle_pending` + /// Called with the uncanonical id of each enode whose canonical children have changed, along with a canonical + /// version of it + #[inline] + pub fn raw_rebuild( + outer: &mut T, + get_self: impl Fn(&mut T) -> &mut Self, + mut perform_union: impl FnMut(&mut T, Id, Id), + handle_pending: impl FnMut(&mut T, Id, &L), + ) { + let _: Result<(), Infallible> = RawEGraph::try_raw_rebuild( + outer, + get_self, + |this, id1, id2| Ok(perform_union(this, id1, id2)), + handle_pending, + ); + } + + /// Similar to [`raw_rebuild`] but allows for the union operation to fail and abort the rebuild + #[inline] + pub fn try_raw_rebuild( + outer: &mut T, + get_self: impl Fn(&mut T) -> &mut Self, + mut perform_union: impl FnMut(&mut T, Id, Id) -> Result<(), E>, + mut handle_pending: impl FnMut(&mut T, Id, &L), + ) -> Result<(), E> { + loop { + let this = get_self(outer); + if let Some(class) = this.pending.pop() { + let mut node = this.id_to_node(class).clone(); + node.update_children(|id| this.find_mut(id)); + handle_pending(outer, class, &node); + if let Some(memo_class) = get_self(outer).residual.memo.insert(node, class) { + match perform_union(outer, memo_class, class) { + Ok(()) => {} + Err(e) => { + get_self(outer).pending.push(class); + return Err(e); + } + } + } + } else { + break Ok(()); + } + } + } + + /// Returns whether `self` is congruently closed + /// + /// This will always be true after calling [`raw_rebuild`](RawEGraph::raw_rebuild) + pub fn is_clean(&self) -> bool { + self.pending.is_empty() + } + + /// Returns a more debug-able representation of the egraph focusing on its classes. + /// + /// [`RawEGraph`]s implement [`Debug`], but it's not pretty. It + /// prints a lot of stuff you probably don't care about. + /// This method returns a wrapper that implements [`Debug`] in a + /// slightly nicer way, just dumping enodes in each eclass. + /// + /// [`Debug`]: std::fmt::Debug + pub fn dump_classes(&self) -> impl Debug + '_ + where + D: Debug, + { + EGraphDump(self) + } +} + +impl RawEGraph { + /// Simplified version of [`raw_add`](RawEGraph::raw_add) for egraphs without eclass data + pub fn add_uncanonical(&mut self, enode: L) -> Id { + Self::raw_add( + self, + |x| x, + enode, + |_, id, _| Some(id), + |_, _, _| {}, + |_, _, _| (), + ) + } + + /// Simplified version of [`raw_union`](RawEGraph::raw_union) for egraphs without eclass data + pub fn union(&mut self, id1: Id, id2: Id) -> bool { + let mut unioned = false; + self.raw_union(id1, id2, |_| { + unioned = true; + }); + unioned + } + + /// Simplified version of [`raw_rebuild`](RawEGraph::raw_rebuild) for egraphs without eclass data + pub fn rebuild(&mut self) { + Self::raw_rebuild( + self, + |x| x, + |this, id1, id2| { + this.union(id1, id2); + }, + |_, _, _| {}, + ); + } +} + +struct EGraphUncanonicalDump<'a, L: Language>(&'a EGraphResidual); + +impl<'a, L: Language> Debug for EGraphUncanonicalDump<'a, L> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for (id, node) in self.0.uncanonical_nodes() { + writeln!(f, "{}: {:?} (root={})", id, node, self.0.find(id))? + } + Ok(()) + } +} + +struct EGraphDump<'a, L: Language, D>(&'a RawEGraph); + +impl<'a, L: Language, D: Debug> Debug for EGraphDump<'a, L, D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut ids: Vec = self.0.classes().map(|c| c.id).collect(); + ids.sort(); + for id in ids { + writeln!(f, "{} {:?}", id, self.0.get_class(id).raw_data)? + } + Ok(()) + } +}