diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index a19600fa00f..aceb80e8f4c 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -3058,43 +3058,80 @@ def _format(operand): /// Remove all of the ancestor operation nodes of node. fn remove_ancestors_of(&mut self, node: &DAGNode) -> PyResult<()> { - // anc = rx.ancestors(self._multi_graph, node) - // # TODO: probably better to do all at once using - // # multi_graph.remove_nodes_from; same for related functions ... - // - // for anc_node in anc: - // if isinstance(anc_node, DAGOpNode): - // self.remove_op_node(anc_node) - todo!() + let ancestors: Vec<_> = core_ancestors(&self.dag, node.node.unwrap()) + .filter(|next| { + next != &node.node.unwrap() + && match self.dag.node_weight(*next) { + Some(NodeType::Operation(_)) => true, + _ => false, + } + }) + .collect(); + for a in ancestors { + self.dag.remove_node(a); + } + Ok(()) } /// Remove all of the descendant operation nodes of node. fn remove_descendants_of(&mut self, node: &DAGNode) -> PyResult<()> { - // desc = rx.descendants(self._multi_graph, node) - // for desc_node in desc: - // if isinstance(desc_node, DAGOpNode): - // self.remove_op_node(desc_node) - todo!() + let descendants: Vec<_> = core_descendants(&self.dag, node.node.unwrap()) + .filter(|next| { + next != &node.node.unwrap() + && match self.dag.node_weight(*next) { + Some(NodeType::Operation(_)) => true, + _ => false, + } + }) + .collect(); + for d in descendants { + self.dag.remove_node(d); + } + Ok(()) } /// Remove all of the non-ancestors operation nodes of node. fn remove_nonancestors_of(&mut self, node: &DAGNode) -> PyResult<()> { - // anc = rx.ancestors(self._multi_graph, node) - // comp = list(set(self._multi_graph.nodes()) - set(anc)) - // for n in comp: - // if isinstance(n, DAGOpNode): - // self.remove_op_node(n) - todo!() + let ancestors: HashSet<_> = core_ancestors(&self.dag, node.node.unwrap()) + .filter(|next| { + next != &node.node.unwrap() + && match self.dag.node_weight(*next) { + Some(NodeType::Operation(_)) => true, + _ => false, + } + }) + .collect(); + let non_ancestors: Vec<_> = self + .dag + .node_indices() + .filter(|node_id| !ancestors.contains(node_id)) + .collect(); + for na in non_ancestors { + self.dag.remove_node(na); + } + Ok(()) } /// Remove all of the non-descendants operation nodes of node. fn remove_nondescendants_of(&mut self, node: &DAGNode) -> PyResult<()> { - // dec = rx.descendants(self._multi_graph, node) - // comp = list(set(self._multi_graph.nodes()) - set(dec)) - // for n in comp: - // if isinstance(n, DAGOpNode): - // self.remove_op_node(n) - todo!() + let descendants: HashSet<_> = core_descendants(&self.dag, node.node.unwrap()) + .filter(|next| { + next != &node.node.unwrap() + && match self.dag.node_weight(*next) { + Some(NodeType::Operation(_)) => true, + _ => false, + } + }) + .collect(); + let non_descendants: Vec<_> = self + .dag + .node_indices() + .filter(|node_id| !descendants.contains(node_id)) + .collect(); + for nd in non_descendants { + self.dag.remove_node(nd); + } + Ok(()) } /// Return a list of op nodes in the first layer of this dag.