Skip to content

Commit

Permalink
Merge pull request Qiskit#14 from raynelfss/oxidize-dag-family-tree
Browse files Browse the repository at this point in the history
[Oxidize DAGCircuit] Add: ancestors, descendants, bfs_successors to oxidized `DAGCircuit`
  • Loading branch information
kevinhartman authored Jul 2, 2024
2 parents 7523def + 3c402fe commit a7c697b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 22 deletions.
66 changes: 56 additions & 10 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ use rustworkx_core::petgraph::prelude::StableDiGraph;
use rustworkx_core::petgraph::stable_graph::{DefaultIx, IndexType, Neighbors, NodeIndex};
use rustworkx_core::petgraph::visit::{IntoNodeReferences, NodeCount, NodeRef};
use rustworkx_core::petgraph::Incoming;
use rustworkx_core::traversal::{
ancestors as core_ancestors, bfs_successors as core_bfs_successors,
descendants as core_descendants,
};
use std::borrow::Borrow;
use std::convert::Infallible;
use std::f64::consts::PI;
use std::ffi::c_double;
Expand Down Expand Up @@ -1920,7 +1925,6 @@ def _format(operand):
// ))
}


/// Yield nodes in topological order.
///
/// Args:
Expand Down Expand Up @@ -2803,22 +2807,45 @@ def _format(operand):
}

/// Returns set of the ancestors of a node as DAGOpNodes and DAGInNodes.
fn ancestors(&self, py: Python, node: &DAGNode) -> PyResult<Py<PySet>> {
// return {self._multi_graph[x] for x in rx.ancestors(self._multi_graph, node._node_id)}
todo!()
#[pyo3(name = "ancestors")]
fn py_ancestors(&self, py: Python, node: &DAGNode) -> PyResult<Py<PySet>> {
let ancestors: PyResult<Vec<PyObject>> = self
.ancestors(node.node.unwrap())
.map(|node| self.get_node(py, node))
.collect();
Ok(PySet::new_bound(py, &ancestors?)?.unbind())
}

/// Returns set of the descendants of a node as DAGOpNodes and DAGOutNodes.
fn descendants(&self, py: Python, node: &DAGNode) -> PyResult<Py<PySet>> {
// return {self._multi_graph[x] for x in rx.descendants(self._multi_graph, node._node_id)}
todo!()
#[pyo3(name = "descendants")]
fn py_descendants(&self, py: Python, node: &DAGNode) -> PyResult<Py<PySet>> {
let descendants: PyResult<Vec<PyObject>> = self
.descendants(node.node.unwrap())
.map(|node| self.get_node(py, node))
.collect();
Ok(PySet::new_bound(py, &descendants?)?.unbind())
}

/// Returns an iterator of tuples of (DAGNode, [DAGNodes]) where the DAGNode is the current node
/// and [DAGNode] is its successors in BFS order.
fn bfs_successors(&self, py: Python, node: &DAGNode) -> PyResult<Py<PySet>> {
// return iter(rx.bfs_successors(self._multi_graph, node._node_id))
todo!()
#[pyo3(name = "bfs_successors")]
fn py_bfs_successors(&self, py: Python, node: &DAGNode) -> PyResult<Py<PyIterator>> {
let successor_index: PyResult<Vec<(PyObject, Vec<PyObject>)>> = self
.bfs_successors(node.node.unwrap())
.map(|(node, nodes)| -> PyResult<(PyObject, Vec<PyObject>)> {
Ok((
self.get_node(py, node)?,
nodes
.iter()
.map(|sub_node| self.get_node(py, *sub_node))
.collect::<PyResult<Vec<_>>>()?,
))
})
.collect();
Ok(PyList::new_bound(py, successor_index?)
.into_any()
.iter()?
.unbind())
}

/// Returns iterator of the successors of a node that are
Expand Down Expand Up @@ -3571,6 +3598,25 @@ impl DAGCircuit {
}
}

/// Returns an iterator of the ancestors indices of a node.
pub fn ancestors<'a>(&'a self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> + 'a {
core_ancestors(&self.dag, node).filter(move |next| next != &node)
}

/// Returns an iterator of the descendants of a node as DAGOpNodes and DAGOutNodes.
pub fn descendants<'a>(&'a self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> + 'a {
core_descendants(&self.dag, node).filter(move |next| next != &node)
}

/// Returns an iterator of tuples of (DAGNode, [DAGNodes]) where the DAGNode is the current node
/// and [DAGNode] is its successors in BFS order.
pub fn bfs_successors<'a>(
&'a self,
node: NodeIndex,
) -> impl Iterator<Item = (NodeIndex, Vec<NodeIndex>)> + 'a {
core_bfs_successors(&self.dag, node).filter(move |(_, others)| !others.is_empty())
}

fn unpack_into(&self, py: Python, id: NodeIndex, weight: &NodeType) -> PyResult<Py<PyAny>> {
let dag_node = match weight {
NodeType::QubitIn(qubit) => Py::new(
Expand Down
21 changes: 9 additions & 12 deletions crates/circuit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ mod interner;

use pyo3::prelude::*;
use pyo3::types::{PySequence, PySlice, PyTuple};
use std::ops::Deref;
use pyo3::DowncastError;
use std::ops::Deref;

/// A private enumeration type used to extract arguments to pymethod
/// that may be either an index or a slice
Expand All @@ -51,18 +51,15 @@ pub struct TupleLikeArg<'py> {
impl<'py> FromPyObject<'py> for TupleLikeArg<'py> {
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
let value = match ob.downcast::<PySequence>() {
Ok(seq) => {seq.to_tuple()?}
Err(_) => {
PyTuple::new_bound(
ob.py(),
ob.iter()?
.map(|o| Ok(o?.unbind()))
.collect::<PyResult<Vec<PyObject>>>()?)
}
Ok(seq) => seq.to_tuple()?,
Err(_) => PyTuple::new_bound(
ob.py(),
ob.iter()?
.map(|o| Ok(o?.unbind()))
.collect::<PyResult<Vec<PyObject>>>()?,
),
};
Ok(TupleLikeArg {
value,
})
Ok(TupleLikeArg { value })
}
}

Expand Down

0 comments on commit a7c697b

Please sign in to comment.