Skip to content

Commit

Permalink
Allow path compression to be disabled by undo log
Browse files Browse the repository at this point in the history
  • Loading branch information
dewert99 authored Dec 21, 2024
1 parent da3f0a2 commit b2b8397
Show file tree
Hide file tree
Showing 11 changed files with 128 additions and 72 deletions.
13 changes: 7 additions & 6 deletions src/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Use the [`Dot`] struct to visualize an [`EGraph`](crate::EGraph)
use no_std_compat::prelude::v1::*;
use std::fmt::{self, Debug, Display, Formatter};

use crate::raw::reflect_const::PathCompressT;
use crate::{raw::EGraphResidual, raw::Language};

/**
Expand Down Expand Up @@ -48,16 +49,16 @@ instead of to its own eclass.
[GraphViz]: https://graphviz.gitlab.io/
**/
pub struct Dot<'a, L: Language> {
pub(crate) egraph: &'a EGraphResidual<L>,
pub struct Dot<'a, L: Language, P: PathCompressT> {
pub(crate) egraph: &'a EGraphResidual<L, P>,
/// A list of strings to be output top part of the dot file.
pub config: Vec<String>,
/// Whether or not to anchor the edges in the output.
/// True by default.
pub use_anchors: bool,
}

impl<'a, L> Dot<'a, L>
impl<'a, L, P: PathCompressT> Dot<'a, L, P>
where
L: Language + Display,
{
Expand Down Expand Up @@ -100,7 +101,7 @@ mod std_only {
use std::io::{Error, ErrorKind, Result, Write};
use std::path::Path;

impl<'a, L: Language + Display> Dot<'a, L> {
impl<'a, L: Language + Display, P: PathCompressT> Dot<'a, L, P> {
/// Writes the `Dot` to a .dot file with the given filename.
/// Does _not_ require a `dot` binary.
pub fn to_dot(&self, filename: impl AsRef<Path>) -> Result<()> {
Expand Down Expand Up @@ -177,13 +178,13 @@ mod std_only {
}
}

impl<'a, L: Language> Debug for Dot<'a, L> {
impl<'a, L: Language, P: PathCompressT> Debug for Dot<'a, L, P> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_tuple("Dot").field(self.egraph).finish()
}
}

impl<'a, L> Display for Dot<'a, L>
impl<'a, L, P: PathCompressT> Display for Dot<'a, L, P>
where
L: Language + Display,
{
Expand Down
3 changes: 2 additions & 1 deletion src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use raw::semi_persistent1 as sp;
#[cfg(not(feature = "push-pop-alt"))]
use raw::semi_persistent2 as sp;

use crate::raw::UndoLogPC;
use sp::UndoLog;
type PushInfo = (sp::PushInfo, explain::PushInfo, usize);
/** A data structure to keep track of equalities between expressions.
Expand Down Expand Up @@ -108,7 +109,7 @@ impl<L: Language, N: Analysis<L>> Debug for EGraph<L, N> {
}

impl<L: Language, N: Analysis<L>> Deref for EGraph<L, N> {
type Target = EGraphResidual<L>;
type Target = EGraphResidual<L, <UndoLog as UndoLogPC>::AllowPathCompress>;

#[inline]
fn deref(&self) -> &Self::Target {
Expand Down
4 changes: 2 additions & 2 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::mem;
use std::ops::{Deref, DerefMut};
use std::rc::Rc;

use crate::raw::RawEGraph;
use crate::raw::{RawEGraph, UndoLogT};
use symbolic_expressions::Sexp;

type ProofCost = Saturating<usize>;
Expand Down Expand Up @@ -1094,7 +1094,7 @@ impl<'a, L: Language, X> DerefMut for ExplainWith<'a, L, X> {
}
}

impl<'x, L: Language, D, U> ExplainWith<'x, L, &'x RawEGraph<L, D, U>> {
impl<'x, L: Language, D, U: UndoLogT<L, D>> ExplainWith<'x, L, &'x RawEGraph<L, D, U>> {
pub(crate) fn node(&self, node_id: Id) -> &L {
self.raw.id_to_node(node_id)
}
Expand Down
5 changes: 3 additions & 2 deletions src/explain/semi_persistent.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::explain::{Connection, Explain};
use crate::raw::reflect_const::PathCompressT;
use crate::raw::EGraphResidual;
use crate::{Id, Language};
use no_std_compat::prelude::v1::*;
Expand Down Expand Up @@ -28,11 +29,11 @@ impl<L: Language> Explain<L> {
PushInfo(self.undo_log.as_ref().unwrap().len())
}

pub(crate) fn pop(
pub(crate) fn pop<P: PathCompressT>(
&mut self,
info: PushInfo,
number_of_uncanon_nodes: usize,
egraph: &EGraphResidual<L>,
egraph: &EGraphResidual<L, P>,
) {
for id in self.undo_log.as_mut().unwrap().drain(info.0..).rev() {
let node1 = &mut self.explainfind[usize::from(id)];
Expand Down
5 changes: 4 additions & 1 deletion src/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ pub mod semi_persistent2;
mod unionfind;
pub(crate) mod util;

/// Types and traits for specify whether path compression is supported
pub mod reflect_const;

pub use eclass::RawEClass;
pub use egraph::{EGraphResidual, RawEGraph, UnionInfo};
pub use language::*;
use semi_persistent::Sealed;
pub use semi_persistent::{AsUnwrap, UndoLogT};
pub use semi_persistent::{AsUnwrap, UndoLogPC, UndoLogT};
pub use unionfind::UnionFind;
58 changes: 38 additions & 20 deletions src/raw/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::{
};

use crate::raw::dhashmap::*;
use crate::raw::reflect_const::{PathCompress, PathCompressT};
use crate::raw::UndoLogT;
use default_vec2::BitSet;
#[cfg(feature = "serde-1")]
Expand All @@ -35,8 +36,12 @@ impl<'a> IntoIterator for Parents<'a> {
/// See [`RawEGraph::classes_mut`], [`RawEGraph::get_class_mut`]
#[derive(Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub struct EGraphResidual<L: Language> {
pub(super) unionfind: UnionFind,
#[cfg_attr(
feature = "serde-1",
serde(bound(serialize = "L: Serialize", deserialize = "L: Deserialize<'de>"))
)]
pub struct EGraphResidual<L: Language, P: PathCompressT = PathCompress<true>> {
pub(super) unionfind: UnionFind<P>,
/// Stores the original node represented by each non-canonical id
pub(super) nodes: Vec<L>,
/// Stores each enode's `Id`, not the `Id` of the eclass.
Expand All @@ -46,7 +51,7 @@ pub struct EGraphResidual<L: Language> {
pub(super) memo: DHashMap<L, Id>,
}

impl<L: Language> EGraphResidual<L> {
impl<L: Language, P: PathCompressT> EGraphResidual<L, P> {
/// Pick a representative term for a given Id.
///
/// Calling this function on an uncanonical `Id` returns a representative based on how it
Expand Down Expand Up @@ -308,7 +313,7 @@ impl<L: Language> EGraphResidual<L> {
}

/// Creates a [`Dot`] to visualize this egraph. See [`Dot`].
pub fn dot(&self) -> Dot<'_, L> {
pub fn dot(&self) -> Dot<'_, L, P> {
Dot {
egraph: self,
config: vec![],
Expand All @@ -317,8 +322,15 @@ impl<L: Language> EGraphResidual<L> {
}
}

impl<L: Language> EGraphResidual<L, PathCompress<false>> {
/// Return the direct parent from the union find without path compression
pub fn find_direct_parent(&self, id: Id) -> Id {
self.unionfind.parent_id(id)
}
}

// manual debug impl to avoid L: Language bound on EGraph defn
impl<L: Language> Debug for EGraphResidual<L> {
impl<L: Language, P: PathCompressT> Debug for EGraphResidual<L, P> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("EGraphResidual")
.field("unionfind", &self.unionfind)
Expand Down Expand Up @@ -356,9 +368,9 @@ to properly handle this data
**/
#[derive(Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub struct RawEGraph<L: Language, D, U = ()> {
pub struct RawEGraph<L: Language, D, U: UndoLogT<L, D> = ()> {
#[cfg_attr(feature = "serde-1", serde(flatten))]
pub(super) residual: EGraphResidual<L>,
pub(super) residual: EGraphResidual<L, U::AllowPathCompress>,
/// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode,
/// not the canonical id of the eclass.
pub(super) pending: Vec<Id>,
Expand All @@ -368,7 +380,7 @@ pub struct RawEGraph<L: Language, D, U = ()> {
pub(super) undo_log: U,
}

impl<L: Language, D, U: Default> Default for RawEGraph<L, D, U> {
impl<L: Language, D, U: Default + UndoLogT<L, D>> Default for RawEGraph<L, D, U> {
fn default() -> Self {
let residual = EGraphResidual {
unionfind: Default::default(),
Expand All @@ -385,24 +397,24 @@ impl<L: Language, D, U: Default> Default for RawEGraph<L, D, U> {
}
}

impl<L: Language, D, U> Deref for RawEGraph<L, D, U> {
type Target = EGraphResidual<L>;
impl<L: Language, D, U: UndoLogT<L, D>> Deref for RawEGraph<L, D, U> {
type Target = EGraphResidual<L, U::AllowPathCompress>;

#[inline]
fn deref(&self) -> &Self::Target {
&self.residual
}
}

impl<L: Language, D, U> DerefMut for RawEGraph<L, D, U> {
impl<L: Language, D, U: UndoLogT<L, D>> DerefMut for RawEGraph<L, D, U> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.residual
}
}

// manual debug impl to avoid L: Language bound on EGraph defn
impl<L: Language, D: Debug, U> Debug for RawEGraph<L, D, U> {
impl<L: Language, D: Debug, U: UndoLogT<L, D>> Debug for RawEGraph<L, D, U> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let classes: BTreeMap<_, _> = self
.classes
Expand All @@ -428,7 +440,7 @@ impl<L: Language, D: Debug, U> Debug for RawEGraph<L, D, U> {
}
}

impl<L: Language, D, U> RawEGraph<L, D, U> {
impl<L: Language, D, U: UndoLogT<L, D>> RawEGraph<L, D, U> {
/// Returns an iterator over the eclasses in the egraph.
pub fn classes(&self) -> impl ExactSizeIterator<Item = &RawEClass<D>> {
self.classes.iter()
Expand All @@ -440,7 +452,7 @@ impl<L: Language, D, U> RawEGraph<L, D, U> {
&mut self,
) -> (
impl ExactSizeIterator<Item = &mut RawEClass<D>>,
&mut EGraphResidual<L>,
&mut EGraphResidual<L, U::AllowPathCompress>,
) {
let iter = self.classes.iter_mut();
(iter, &mut self.residual)
Expand Down Expand Up @@ -470,7 +482,10 @@ impl<L: Language, D, U> RawEGraph<L, D, U> {
pub fn get_class_mut<I: BorrowMut<Id>>(
&mut self,
mut id: I,
) -> (&mut RawEClass<D>, &mut EGraphResidual<L>) {
) -> (
&mut RawEClass<D>,
&mut EGraphResidual<L, U::AllowPathCompress>,
) {
let id = id.borrow_mut();
let (nid, cid) = self.unionfind.find_mut_full(*id);
*id = nid;
Expand All @@ -481,7 +496,10 @@ impl<L: Language, D, U> RawEGraph<L, D, U> {
pub fn get_class_mut_with_cannon(
&mut self,
id: Id,
) -> (&mut RawEClass<D>, &mut EGraphResidual<L>) {
) -> (
&mut RawEClass<D>,
&mut EGraphResidual<L, U::AllowPathCompress>,
) {
let cid = self.unionfind.find_canon(id);
(&mut self.classes[cid.idx()], &mut self.residual)
}
Expand Down Expand Up @@ -900,9 +918,9 @@ impl<L: Language, U: UndoLogT<L, ()>> RawEGraph<L, (), U> {
}
}

struct EGraphUncanonicalDump<'a, L: Language>(&'a EGraphResidual<L>);
struct EGraphUncanonicalDump<'a, L: Language, P: PathCompressT>(&'a EGraphResidual<L, P>);

impl<'a, L: Language> Debug for EGraphUncanonicalDump<'a, L> {
impl<'a, L: Language, P: PathCompressT> Debug for EGraphUncanonicalDump<'a, L, P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for (id, node) in self.0.uncanonical_nodes() {
writeln!(f, "{}: {:?} (root={})", id, node, self.0.find(id))?
Expand All @@ -911,9 +929,9 @@ impl<'a, L: Language> Debug for EGraphUncanonicalDump<'a, L> {
}
}

struct EGraphDump<'a, L: Language, D, U>(&'a RawEGraph<L, D, U>);
struct EGraphDump<'a, L: Language, D, U: UndoLogT<L, D>>(&'a RawEGraph<L, D, U>);

impl<'a, L: Language, D: Debug, U> Debug for EGraphDump<'a, L, D, U> {
impl<'a, L: Language, D: Debug, U: UndoLogT<L, D>> Debug for EGraphDump<'a, L, D, U> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut ids: Vec<Id> = self.0.classes().map(|c| c.id).collect();
ids.sort();
Expand Down
14 changes: 14 additions & 0 deletions src/raw/reflect_const.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#![allow(missing_docs)]
use core::fmt::Debug;

#[derive(Copy, Clone, Eq, PartialEq, Default, Debug)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub struct PathCompress<const B: bool>;

impl<const B: bool> PathCompressT for PathCompress<B> {
const PATH_COMPRESS: bool = B;
}

pub trait PathCompressT: Copy + Clone + Eq + PartialEq + Default + Debug {
const PATH_COMPRESS: bool;
}
18 changes: 17 additions & 1 deletion src/raw/semi_persistent.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::raw::reflect_const::{PathCompress, PathCompressT};
use crate::raw::{Language, RawEGraph};
use crate::{ClassId, Id};
use no_std_compat::prelude::v1::*;
Expand All @@ -9,7 +10,14 @@ impl<U: Sealed> Sealed for Option<U> {}

/// A sealed trait for types that can be used for `push`/`pop` APIs
/// It is trivially implemented for `()`
pub trait UndoLogT<L, D>: Default + Debug + Sealed {
pub trait UndoLogPC {
/// When this type of undo log allows for path compression
type AllowPathCompress: PathCompressT;
}

/// A sealed trait for types that can be used for `push`/`pop` APIs
/// It is trivially implemented for `()`
pub trait UndoLogT<L, D>: Default + Debug + Sealed + UndoLogPC {
#[doc(hidden)]
fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id, cid: ClassId);

Expand All @@ -32,6 +40,10 @@ pub trait UndoLogT<L, D>: Default + Debug + Sealed {
fn is_enabled(&self) -> bool;
}

impl UndoLogPC for () {
type AllowPathCompress = PathCompress<true>;
}

impl<L, D> UndoLogT<L, D> for () {
#[inline]
fn add_node(&mut self, _: &L, _: &[Id], _: Id, _: ClassId) {}
Expand All @@ -54,6 +66,10 @@ impl<L, D> UndoLogT<L, D> for () {
}
}

impl<U: UndoLogPC> UndoLogPC for Option<U> {
type AllowPathCompress = U::AllowPathCompress;
}

impl<L, D, U: UndoLogT<L, D>> UndoLogT<L, D> for Option<U> {
#[inline]
fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id, cid: ClassId) {
Expand Down
Loading

0 comments on commit b2b8397

Please sign in to comment.