Skip to content

Commit

Permalink
Allow path compression to be disabled but undo log
Browse files Browse the repository at this point in the history
  • Loading branch information
dewert99 committed Dec 21, 2024
1 parent da3f0a2 commit 6dc6ca4
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 66 deletions.
13 changes: 7 additions & 6 deletions src/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Use the [`Dot`] struct to visualize an [`EGraph`](crate::EGraph)
use no_std_compat::prelude::v1::*;
use std::fmt::{self, Debug, Display, Formatter};

use crate::raw::reflect_const::PathCompressT;
use crate::{raw::EGraphResidual, raw::Language};

/**
Expand Down Expand Up @@ -48,16 +49,16 @@ instead of to its own eclass.
[GraphViz]: https://graphviz.gitlab.io/
**/
pub struct Dot<'a, L: Language> {
pub(crate) egraph: &'a EGraphResidual<L>,
pub struct Dot<'a, L: Language, P: PathCompressT> {
pub(crate) egraph: &'a EGraphResidual<L, P>,
/// A list of strings to be output top part of the dot file.
pub config: Vec<String>,
/// Whether or not to anchor the edges in the output.
/// True by default.
pub use_anchors: bool,
}

impl<'a, L> Dot<'a, L>
impl<'a, L, P: PathCompressT> Dot<'a, L, P>
where
L: Language + Display,
{
Expand Down Expand Up @@ -100,7 +101,7 @@ mod std_only {
use std::io::{Error, ErrorKind, Result, Write};
use std::path::Path;

impl<'a, L: Language + Display> Dot<'a, L> {
impl<'a, L: Language + Display, P: PathCompressT> Dot<'a, L, P> {
/// Writes the `Dot` to a .dot file with the given filename.
/// Does _not_ require a `dot` binary.
pub fn to_dot(&self, filename: impl AsRef<Path>) -> Result<()> {
Expand Down Expand Up @@ -177,13 +178,13 @@ mod std_only {
}
}

impl<'a, L: Language> Debug for Dot<'a, L> {
impl<'a, L: Language, P: PathCompressT> Debug for Dot<'a, L, P> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_tuple("Dot").field(self.egraph).finish()
}
}

impl<'a, L> Display for Dot<'a, L>
impl<'a, L, P: PathCompressT> Display for Dot<'a, L, P>
where
L: Language + Display,
{
Expand Down
4 changes: 2 additions & 2 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::mem;
use std::ops::{Deref, DerefMut};
use std::rc::Rc;

use crate::raw::RawEGraph;
use crate::raw::{RawEGraph, UndoLogT};
use symbolic_expressions::Sexp;

type ProofCost = Saturating<usize>;
Expand Down Expand Up @@ -1094,7 +1094,7 @@ impl<'a, L: Language, X> DerefMut for ExplainWith<'a, L, X> {
}
}

impl<'x, L: Language, D, U> ExplainWith<'x, L, &'x RawEGraph<L, D, U>> {
impl<'x, L: Language, D, U: UndoLogT<L, D>> ExplainWith<'x, L, &'x RawEGraph<L, D, U>> {
pub(crate) fn node(&self, node_id: Id) -> &L {
self.raw.id_to_node(node_id)
}
Expand Down
3 changes: 3 additions & 0 deletions src/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ pub mod semi_persistent2;
mod unionfind;
pub(crate) mod util;

/// Types and traits for specify whether path compression is supported
pub mod reflect_const;

pub use eclass::RawEClass;
pub use egraph::{EGraphResidual, RawEGraph, UnionInfo};
pub use language::*;
Expand Down
54 changes: 34 additions & 20 deletions src/raw/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::{
};

use crate::raw::dhashmap::*;
use crate::raw::reflect_const::{PathCompress, PathCompressT};
use crate::raw::UndoLogT;
use default_vec2::BitSet;
#[cfg(feature = "serde-1")]
Expand All @@ -35,8 +36,8 @@ impl<'a> IntoIterator for Parents<'a> {
/// See [`RawEGraph::classes_mut`], [`RawEGraph::get_class_mut`]
#[derive(Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub struct EGraphResidual<L: Language> {
pub(super) unionfind: UnionFind,
pub struct EGraphResidual<L: Language, P: PathCompressT = PathCompress<true>> {
pub(super) unionfind: UnionFind<P>,
/// Stores the original node represented by each non-canonical id
pub(super) nodes: Vec<L>,
/// Stores each enode's `Id`, not the `Id` of the eclass.
Expand All @@ -46,7 +47,7 @@ pub struct EGraphResidual<L: Language> {
pub(super) memo: DHashMap<L, Id>,
}

impl<L: Language> EGraphResidual<L> {
impl<L: Language, P: PathCompressT> EGraphResidual<L, P> {
/// Pick a representative term for a given Id.
///
/// Calling this function on an uncanonical `Id` returns a representative based on how it
Expand Down Expand Up @@ -308,7 +309,7 @@ impl<L: Language> EGraphResidual<L> {
}

/// Creates a [`Dot`] to visualize this egraph. See [`Dot`].
pub fn dot(&self) -> Dot<'_, L> {
pub fn dot(&self) -> Dot<'_, L, P> {
Dot {
egraph: self,
config: vec![],
Expand All @@ -317,8 +318,15 @@ impl<L: Language> EGraphResidual<L> {
}
}

impl<L: Language> EGraphResidual<L, PathCompress<false>> {
/// Return the direct parent from the union find without path compression
pub fn find_direct_parent(&self, id: Id) -> Id {
self.unionfind.parent_id(id)
}
}

// manual debug impl to avoid L: Language bound on EGraph defn
impl<L: Language> Debug for EGraphResidual<L> {
impl<L: Language, P: PathCompressT> Debug for EGraphResidual<L, P> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("EGraphResidual")
.field("unionfind", &self.unionfind)
Expand Down Expand Up @@ -356,9 +364,9 @@ to properly handle this data
**/
#[derive(Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub struct RawEGraph<L: Language, D, U = ()> {
pub struct RawEGraph<L: Language, D, U: UndoLogT<L, D> = ()> {
#[cfg_attr(feature = "serde-1", serde(flatten))]
pub(super) residual: EGraphResidual<L>,
pub(super) residual: EGraphResidual<L, U::AllowPathCompress>,
/// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode,
/// not the canonical id of the eclass.
pub(super) pending: Vec<Id>,
Expand All @@ -368,7 +376,7 @@ pub struct RawEGraph<L: Language, D, U = ()> {
pub(super) undo_log: U,
}

impl<L: Language, D, U: Default> Default for RawEGraph<L, D, U> {
impl<L: Language, D, U: Default + UndoLogT<L, D>> Default for RawEGraph<L, D, U> {
fn default() -> Self {
let residual = EGraphResidual {
unionfind: Default::default(),
Expand All @@ -385,24 +393,24 @@ impl<L: Language, D, U: Default> Default for RawEGraph<L, D, U> {
}
}

impl<L: Language, D, U> Deref for RawEGraph<L, D, U> {
type Target = EGraphResidual<L>;
impl<L: Language, D, U: UndoLogT<L, D>> Deref for RawEGraph<L, D, U> {
type Target = EGraphResidual<L, U::AllowPathCompress>;

#[inline]
fn deref(&self) -> &Self::Target {
&self.residual
}
}

impl<L: Language, D, U> DerefMut for RawEGraph<L, D, U> {
impl<L: Language, D, U: UndoLogT<L, D>> DerefMut for RawEGraph<L, D, U> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.residual
}
}

// manual debug impl to avoid L: Language bound on EGraph defn
impl<L: Language, D: Debug, U> Debug for RawEGraph<L, D, U> {
impl<L: Language, D: Debug, U: UndoLogT<L, D>> Debug for RawEGraph<L, D, U> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let classes: BTreeMap<_, _> = self
.classes
Expand All @@ -428,7 +436,7 @@ impl<L: Language, D: Debug, U> Debug for RawEGraph<L, D, U> {
}
}

impl<L: Language, D, U> RawEGraph<L, D, U> {
impl<L: Language, D, U: UndoLogT<L, D>> RawEGraph<L, D, U> {
/// Returns an iterator over the eclasses in the egraph.
pub fn classes(&self) -> impl ExactSizeIterator<Item = &RawEClass<D>> {
self.classes.iter()
Expand All @@ -440,7 +448,7 @@ impl<L: Language, D, U> RawEGraph<L, D, U> {
&mut self,
) -> (
impl ExactSizeIterator<Item = &mut RawEClass<D>>,
&mut EGraphResidual<L>,
&mut EGraphResidual<L, U::AllowPathCompress>,
) {
let iter = self.classes.iter_mut();
(iter, &mut self.residual)
Expand Down Expand Up @@ -470,7 +478,10 @@ impl<L: Language, D, U> RawEGraph<L, D, U> {
pub fn get_class_mut<I: BorrowMut<Id>>(
&mut self,
mut id: I,
) -> (&mut RawEClass<D>, &mut EGraphResidual<L>) {
) -> (
&mut RawEClass<D>,
&mut EGraphResidual<L, U::AllowPathCompress>,
) {
let id = id.borrow_mut();
let (nid, cid) = self.unionfind.find_mut_full(*id);
*id = nid;
Expand All @@ -481,7 +492,10 @@ impl<L: Language, D, U> RawEGraph<L, D, U> {
pub fn get_class_mut_with_cannon(
&mut self,
id: Id,
) -> (&mut RawEClass<D>, &mut EGraphResidual<L>) {
) -> (
&mut RawEClass<D>,
&mut EGraphResidual<L, U::AllowPathCompress>,
) {
let cid = self.unionfind.find_canon(id);
(&mut self.classes[cid.idx()], &mut self.residual)
}
Expand Down Expand Up @@ -900,9 +914,9 @@ impl<L: Language, U: UndoLogT<L, ()>> RawEGraph<L, (), U> {
}
}

struct EGraphUncanonicalDump<'a, L: Language>(&'a EGraphResidual<L>);
struct EGraphUncanonicalDump<'a, L: Language, P: PathCompressT>(&'a EGraphResidual<L, P>);

impl<'a, L: Language> Debug for EGraphUncanonicalDump<'a, L> {
impl<'a, L: Language, P: PathCompressT> Debug for EGraphUncanonicalDump<'a, L, P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for (id, node) in self.0.uncanonical_nodes() {
writeln!(f, "{}: {:?} (root={})", id, node, self.0.find(id))?
Expand All @@ -911,9 +925,9 @@ impl<'a, L: Language> Debug for EGraphUncanonicalDump<'a, L> {
}
}

struct EGraphDump<'a, L: Language, D, U>(&'a RawEGraph<L, D, U>);
struct EGraphDump<'a, L: Language, D, U: UndoLogT<L, D>>(&'a RawEGraph<L, D, U>);

impl<'a, L: Language, D: Debug, U> Debug for EGraphDump<'a, L, D, U> {
impl<'a, L: Language, D: Debug, U: UndoLogT<L, D>> Debug for EGraphDump<'a, L, D, U> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut ids: Vec<Id> = self.0.classes().map(|c| c.id).collect();
ids.sort();
Expand Down
14 changes: 14 additions & 0 deletions src/raw/reflect_const.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#![allow(missing_docs)]
use core::fmt::Debug;

#[derive(Copy, Clone, Eq, PartialEq, Default, Debug)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub struct PathCompress<const B: bool>;

impl<const B: bool> PathCompressT for PathCompress<B> {
const PATH_COMPRESS: bool = B;
}

pub trait PathCompressT: Copy + Clone + Eq + PartialEq + Default + Debug {
const PATH_COMPRESS: bool;
}
7 changes: 7 additions & 0 deletions src/raw/semi_persistent.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::raw::reflect_const::{PathCompress, PathCompressT};
use crate::raw::{Language, RawEGraph};
use crate::{ClassId, Id};
use no_std_compat::prelude::v1::*;
Expand All @@ -10,6 +11,8 @@ impl<U: Sealed> Sealed for Option<U> {}
/// A sealed trait for types that can be used for `push`/`pop` APIs
/// It is trivially implemented for `()`
pub trait UndoLogT<L, D>: Default + Debug + Sealed {
/// When this type of undo log allows for path compression
type AllowPathCompress: PathCompressT;
#[doc(hidden)]
fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id, cid: ClassId);

Expand All @@ -33,6 +36,8 @@ pub trait UndoLogT<L, D>: Default + Debug + Sealed {
}

impl<L, D> UndoLogT<L, D> for () {
type AllowPathCompress = PathCompress<true>;

#[inline]
fn add_node(&mut self, _: &L, _: &[Id], _: Id, _: ClassId) {}

Expand All @@ -55,6 +60,8 @@ impl<L, D> UndoLogT<L, D> for () {
}

impl<L, D, U: UndoLogT<L, D>> UndoLogT<L, D> for Option<U> {
type AllowPathCompress = U::AllowPathCompress;

#[inline]
fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id, cid: ClassId) {
if let Some(undo) = self {
Expand Down
Loading

0 comments on commit 6dc6ca4

Please sign in to comment.