From cb653baf14a74ce314d7ea5e93a18d0c10145f42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Fri, 12 Jul 2024 11:25:37 +0200 Subject: [PATCH 1/4] Implement removal of ascestors, descendants, non-ancestors and non-descendants. --- crates/circuit/src/dag_circuit.rs | 76 +++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 25 deletions(-) diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index a9d6637f774..971ac0126ad 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -2919,43 +2919,69 @@ def _format(operand): /// Remove all of the ancestor operation nodes of node. fn remove_ancestors_of(&mut self, node: &DAGNode) -> PyResult<()> { - // anc = rx.ancestors(self._multi_graph, node) - // # TODO: probably better to do all at once using - // # multi_graph.remove_nodes_from; same for related functions ... - // - // for anc_node in anc: - // if isinstance(anc_node, DAGOpNode): - // self.remove_op_node(anc_node) - todo!() + let dag_binding = self.dag.clone(); + for ancestor in core_ancestors(&dag_binding, node.node.unwrap()) + .filter(|next| next != &node.node.unwrap()) + .filter(|next| match dag_binding.node_weight(*next) { + Some(NodeType::Operation(_)) => true, + _ => false, + }) + { + self.dag.remove_node(ancestor); + } + Ok(()) } /// Remove all of the descendant operation nodes of node. fn remove_descendants_of(&mut self, node: &DAGNode) -> PyResult<()> { - // desc = rx.descendants(self._multi_graph, node) - // for desc_node in desc: - // if isinstance(desc_node, DAGOpNode): - // self.remove_op_node(desc_node) - todo!() + let dag_binding = self.dag.clone(); + for descendant in core_descendants(&dag_binding, node.node.unwrap()) + .filter(|next| next != &node.node.unwrap()) + .filter(|next| match dag_binding.node_weight(*next) { + Some(NodeType::Operation(_)) => true, + _ => false, + }) + { + self.dag.remove_node(descendant); + } + Ok(()) } /// Remove all of the non-ancestors operation nodes of node. fn remove_nonancestors_of(&mut self, node: &DAGNode) -> PyResult<()> { - // anc = rx.ancestors(self._multi_graph, node) - // comp = list(set(self._multi_graph.nodes()) - set(anc)) - // for n in comp: - // if isinstance(n, DAGOpNode): - // self.remove_op_node(n) - todo!() + let dag_binding = self.dag.clone(); + let mut ancestors = core_ancestors(&dag_binding, node.node.unwrap()) + .filter(|next| next != &node.node.unwrap()) + .filter(|next| match dag_binding.node_weight(*next) { + Some(NodeType::Operation(_)) => true, + _ => false, + }); + for node_id in dag_binding.node_indices() { + if ancestors.find(|anc| *anc == node_id).is_some() { + continue; + } + self.dag.remove_node(node_id); + } + Ok(()) } /// Remove all of the non-descendants operation nodes of node. fn remove_nondescendants_of(&mut self, node: &DAGNode) -> PyResult<()> { - // dec = rx.descendants(self._multi_graph, node) - // comp = list(set(self._multi_graph.nodes()) - set(dec)) - // for n in comp: - // if isinstance(n, DAGOpNode): - // self.remove_op_node(n) - todo!() + let dag_binding = self.dag.clone(); + let mut descendants = core_descendants(&dag_binding, node.node.unwrap()) + .filter(|next| next != &node.node.unwrap()) + .filter(|next| match dag_binding.node_weight(*next) { + Some(NodeType::Operation(_)) => true, + _ => false, + }); + + for node_id in dag_binding.node_indices() { + if descendants.find(|desc| *desc == node_id).is_some() { + continue; + } + self.dag.remove_node(node_id); + } + Ok(()) } /// Return a list of op nodes in the first layer of this dag. From 938141b4b61035741cbb33efe99af5292544154b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= <57907331+ElePT@users.noreply.github.com> Date: Mon, 15 Jul 2024 09:45:55 +0200 Subject: [PATCH 2/4] Apply suggestions from John's code review Co-authored-by: John Lapeyre --- crates/circuit/src/dag_circuit.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index 971ac0126ad..ab1b88ba5de 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -2919,16 +2919,16 @@ def _format(operand): /// Remove all of the ancestor operation nodes of node. fn remove_ancestors_of(&mut self, node: &DAGNode) -> PyResult<()> { - let dag_binding = self.dag.clone(); - for ancestor in core_ancestors(&dag_binding, node.node.unwrap()) - .filter(|next| next != &node.node.unwrap()) - .filter(|next| match dag_binding.node_weight(*next) { - Some(NodeType::Operation(_)) => true, - _ => false, + let ancestors: Vec<_> = core_ancestors(&self.dag, node.node.unwrap()) + .filter(|next| { + next != &node.node.unwrap() + && match self.dag.node_weight(*next) { + Some(NodeType::Operation(_)) => true, + _ => false, + } }) - { - self.dag.remove_node(ancestor); - } + .collect(); + ancestors.iter().map(|a| self.dag.remove_node(*a)); Ok(()) } From d6ffa285affc4f45392a0863326f34976562b1a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Mon, 15 Jul 2024 10:13:22 +0200 Subject: [PATCH 3/4] Apply suggestions from code review --- crates/circuit/src/dag_circuit.rs | 79 ++++++++++++++++++------------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index ab1b88ba5de..3da8613bae6 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -2928,58 +2928,69 @@ def _format(operand): } }) .collect(); - ancestors.iter().map(|a| self.dag.remove_node(*a)); + for a in ancestors { + self.dag.remove_node(a); + } Ok(()) } /// Remove all of the descendant operation nodes of node. fn remove_descendants_of(&mut self, node: &DAGNode) -> PyResult<()> { - let dag_binding = self.dag.clone(); - for descendant in core_descendants(&dag_binding, node.node.unwrap()) - .filter(|next| next != &node.node.unwrap()) - .filter(|next| match dag_binding.node_weight(*next) { - Some(NodeType::Operation(_)) => true, - _ => false, + let descendants: Vec<_> = core_descendants(&self.dag, node.node.unwrap()) + .filter(|next| { + next != &node.node.unwrap() + && match self.dag.node_weight(*next) { + Some(NodeType::Operation(_)) => true, + _ => false, + } }) - { - self.dag.remove_node(descendant); + .collect(); + for d in descendants { + self.dag.remove_node(d); } Ok(()) } /// Remove all of the non-ancestors operation nodes of node. fn remove_nonancestors_of(&mut self, node: &DAGNode) -> PyResult<()> { - let dag_binding = self.dag.clone(); - let mut ancestors = core_ancestors(&dag_binding, node.node.unwrap()) - .filter(|next| next != &node.node.unwrap()) - .filter(|next| match dag_binding.node_weight(*next) { - Some(NodeType::Operation(_)) => true, - _ => false, - }); - for node_id in dag_binding.node_indices() { - if ancestors.find(|anc| *anc == node_id).is_some() { - continue; - } - self.dag.remove_node(node_id); + let ancestors: Vec<_> = core_ancestors(&self.dag, node.node.unwrap()) + .filter(|next| { + next != &node.node.unwrap() + && match self.dag.node_weight(*next) { + Some(NodeType::Operation(_)) => true, + _ => false, + } + }) + .collect(); + let non_ancestors: Vec<_> = self + .dag + .node_indices() + .filter(|node_id| ancestors.iter().find(|anc| *anc == node_id).is_none()) + .collect(); + for na in non_ancestors { + self.dag.remove_node(na); } Ok(()) } /// Remove all of the non-descendants operation nodes of node. fn remove_nondescendants_of(&mut self, node: &DAGNode) -> PyResult<()> { - let dag_binding = self.dag.clone(); - let mut descendants = core_descendants(&dag_binding, node.node.unwrap()) - .filter(|next| next != &node.node.unwrap()) - .filter(|next| match dag_binding.node_weight(*next) { - Some(NodeType::Operation(_)) => true, - _ => false, - }); - - for node_id in dag_binding.node_indices() { - if descendants.find(|desc| *desc == node_id).is_some() { - continue; - } - self.dag.remove_node(node_id); + let descendants: Vec<_> = core_descendants(&self.dag, node.node.unwrap()) + .filter(|next| { + next != &node.node.unwrap() + && match self.dag.node_weight(*next) { + Some(NodeType::Operation(_)) => true, + _ => false, + } + }) + .collect(); + let non_descendants: Vec<_> = self + .dag + .node_indices() + .filter(|node_id| descendants.iter().find(|desc| *desc == node_id).is_none()) + .collect(); + for nd in non_descendants { + self.dag.remove_node(nd); } Ok(()) } From 8bb71897b529a77b41bbb696246549d985a651e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Tue, 16 Jul 2024 15:47:00 +0200 Subject: [PATCH 4/4] Use HashSet in nonancestors and nondescendants --- crates/circuit/src/dag_circuit.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index 3da8613bae6..75190c44ba2 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -2953,7 +2953,7 @@ def _format(operand): /// Remove all of the non-ancestors operation nodes of node. fn remove_nonancestors_of(&mut self, node: &DAGNode) -> PyResult<()> { - let ancestors: Vec<_> = core_ancestors(&self.dag, node.node.unwrap()) + let ancestors: HashSet<_> = core_ancestors(&self.dag, node.node.unwrap()) .filter(|next| { next != &node.node.unwrap() && match self.dag.node_weight(*next) { @@ -2965,7 +2965,7 @@ def _format(operand): let non_ancestors: Vec<_> = self .dag .node_indices() - .filter(|node_id| ancestors.iter().find(|anc| *anc == node_id).is_none()) + .filter(|node_id| !ancestors.contains(node_id)) .collect(); for na in non_ancestors { self.dag.remove_node(na); @@ -2975,7 +2975,7 @@ def _format(operand): /// Remove all of the non-descendants operation nodes of node. fn remove_nondescendants_of(&mut self, node: &DAGNode) -> PyResult<()> { - let descendants: Vec<_> = core_descendants(&self.dag, node.node.unwrap()) + let descendants: HashSet<_> = core_descendants(&self.dag, node.node.unwrap()) .filter(|next| { next != &node.node.unwrap() && match self.dag.node_weight(*next) { @@ -2987,7 +2987,7 @@ def _format(operand): let non_descendants: Vec<_> = self .dag .node_indices() - .filter(|node_id| descendants.iter().find(|desc| *desc == node_id).is_none()) + .filter(|node_id| !descendants.contains(node_id)) .collect(); for nd in non_descendants { self.dag.remove_node(nd);