From 7872ff110629aa6b7dadaba961fb3039178df284 Mon Sep 17 00:00:00 2001 From: Andrew Cheung Date: Thu, 18 Jul 2024 18:06:49 -0700 Subject: [PATCH 1/4] Initial commit --- src/main.rs | 10 ++++++++ src/to_egraph_serialized.rs | 49 +++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 src/to_egraph_serialized.rs diff --git a/src/main.rs b/src/main.rs index e3dd1c0..6d6704d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,8 @@ pub use extract::*; use egraph_serialize::*; +mod to_egraph_serialized; + use indexmap::IndexMap; use ordered_float::NotNan; @@ -128,6 +130,8 @@ fn main() { .unwrap() .unwrap_or_else(|| "out.json".into()); + let pruned_filename: Option = args.opt_value_from_str("--pruned").unwrap(); + let filename: String = args.free_from_str().unwrap(); let rest = args.finish(); @@ -152,6 +156,12 @@ fn main() { result.check(&egraph); + if let Some(pruned_filename) = pruned_filename { + let egraph = to_egraph_serialized::get_term(&egraph, &result); + egraph.to_json_file(pruned_filename.clone()).unwrap(); + println!("Wrote pruned egraph to {}", pruned_filename.display()); + } + let tree = result.tree_cost(&egraph, &egraph.root_eclasses); let dag = result.dag_cost(&egraph, &egraph.root_eclasses); diff --git a/src/to_egraph_serialized.rs b/src/to_egraph_serialized.rs new file mode 100644 index 0000000..4b8e727 --- /dev/null +++ b/src/to_egraph_serialized.rs @@ -0,0 +1,49 @@ +use egraph_serialize::{ClassId, NodeId}; +use indexmap::IndexMap; + +use crate::ExtractionResult; + +pub fn get_term( + egraph: &egraph_serialize::EGraph, + result: &ExtractionResult, +) -> egraph_serialize::EGraph { + let choices = &result.choices; + assert!( + egraph.root_eclasses.len() == 1, + "expected exactly one root eclass", + ); + let root_cid = egraph.root_eclasses[0].clone(); + let mut result_egraph = egraph_serialize::EGraph::default(); + // populate_egraph(egraph, &mut result_egraph, choices, root_cid); + for cid in choices.keys() { + let node = &choices[cid]; + // add the node to the result egraph + if !result_egraph.nodes.contains_key(node) { + result_egraph.add_node(node.clone(), egraph.nodes[node].clone()); + } + } + // find number of eclasses in the original egraph + let mut eclasses = std::collections::HashSet::new(); + for enode in egraph.nodes.values() { + eclasses.insert(enode.eclass.clone()); + } + println!("eclasses in original: {}", eclasses.len()); + println!("eclasses in result: {}", result.choices.len()); + println!("original egraph size: {}", egraph.nodes.len()); + println!("result egraph size: {}", result_egraph.nodes.len()); + result_egraph +} + +fn populate_egraph( + egraph: &egraph_serialize::EGraph, + result_egraph: &mut egraph_serialize::EGraph, + choices: &IndexMap, + cid: ClassId, +) { + // get the node for the eclass + let node = &choices[&cid]; + // add the node to the result egraph + if !result_egraph.nodes.contains_key(node) { + result_egraph.add_node(node.clone(), egraph.nodes[node].clone()); + } +} From 5cb38be722e8a053f68b3a694fe4ea4aefd141d8 Mon Sep 17 00:00:00 2001 From: Andrew Cheung Date: Thu, 18 Jul 2024 19:13:27 -0700 Subject: [PATCH 2/4] Add root eclass back into result egraph --- src/to_egraph_serialized.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/to_egraph_serialized.rs b/src/to_egraph_serialized.rs index 4b8e727..decc458 100644 --- a/src/to_egraph_serialized.rs +++ b/src/to_egraph_serialized.rs @@ -22,11 +22,13 @@ pub fn get_term( result_egraph.add_node(node.clone(), egraph.nodes[node].clone()); } } + // find number of eclasses in the original egraph let mut eclasses = std::collections::HashSet::new(); for enode in egraph.nodes.values() { eclasses.insert(enode.eclass.clone()); } + result_egraph.root_eclasses = egraph.root_eclasses.clone(); println!("eclasses in original: {}", eclasses.len()); println!("eclasses in result: {}", result.choices.len()); println!("original egraph size: {}", egraph.nodes.len()); From 1eb4002ac038a82eca7c653c6ebd67d96a1d2d83 Mon Sep 17 00:00:00 2001 From: Andrew Cheung Date: Thu, 18 Jul 2024 19:52:03 -0700 Subject: [PATCH 3/4] Remove all traces of the old egraph --- src/to_egraph_serialized.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/to_egraph_serialized.rs b/src/to_egraph_serialized.rs index decc458..ad290b8 100644 --- a/src/to_egraph_serialized.rs +++ b/src/to_egraph_serialized.rs @@ -19,7 +19,14 @@ pub fn get_term( let node = &choices[cid]; // add the node to the result egraph if !result_egraph.nodes.contains_key(node) { - result_egraph.add_node(node.clone(), egraph.nodes[node].clone()); + let mut new_node = egraph.nodes[node].clone(); + new_node.children = egraph.nodes[node] + .children + .iter() + .map(|child| choices[egraph.nid_to_cid(&child)].clone()) + .collect(); + + result_egraph.add_node(node.clone(), new_node); } } From af820548dc4afa460ebf009a38d5c2c84de00b82 Mon Sep 17 00:00:00 2001 From: Andrew Cheung Date: Wed, 28 Aug 2024 14:05:45 -0700 Subject: [PATCH 4/4] Remove printlns --- src/to_egraph_serialized.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/to_egraph_serialized.rs b/src/to_egraph_serialized.rs index ad290b8..2d30973 100644 --- a/src/to_egraph_serialized.rs +++ b/src/to_egraph_serialized.rs @@ -36,10 +36,6 @@ pub fn get_term( eclasses.insert(enode.eclass.clone()); } result_egraph.root_eclasses = egraph.root_eclasses.clone(); - println!("eclasses in original: {}", eclasses.len()); - println!("eclasses in result: {}", result.choices.len()); - println!("original egraph size: {}", egraph.nodes.len()); - println!("result egraph size: {}", result_egraph.nodes.len()); result_egraph }