Skip to content

Commit

Permalink
Fix: Use StableDiGraph for more stable indexing.
Browse files Browse the repository at this point in the history
- Remove required py argument for get_entry.
- Reformat `to_pygraph` to use `add_nodes_from` and `add_edges_from`.
- Other small fixes.
  • Loading branch information
raynelfss committed Jun 21, 2024
1 parent 5a3bf2b commit 38436cd
Showing 1 changed file with 43 additions and 43 deletions.
86 changes: 43 additions & 43 deletions crates/circuit/src/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
// that they have been altered from the originals.

use itertools::Itertools;
use rustworkx_core::petgraph::stable_graph::StableDiGraph;
use rustworkx_core::petgraph::visit::IntoEdgeReferences;

use smallvec::{smallvec, SmallVec};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
Expand All @@ -22,7 +25,7 @@ use pyo3::types::PyDict;
use pyo3::{prelude::*, types::IntoPyDict};

use rustworkx_core::petgraph::{
graph::{DiGraph, EdgeIndex, NodeIndex},
graph::{EdgeIndex, NodeIndex},
visit::EdgeRef,
};

Expand Down Expand Up @@ -383,7 +386,7 @@ impl Default for CircuitRep {
}

// Custom Types
type GraphType = DiGraph<NodeData, EdgeData>;
type GraphType = StableDiGraph<NodeData, EdgeData>;
type KTIType = HashMap<Key, NodeIndex>;

#[pyclass(
Expand Down Expand Up @@ -501,7 +504,7 @@ impl EquivalenceLibrary {
/// the library, from earliest to latest, from top to base. The
/// ordering of the StandardEquivalenceLibrary will not generally be
/// consistent across Qiskit versions.
pub fn get_entry(&self, py: Python<'_>, gate: CircuitInstruction) -> Vec<CircuitRep> {
pub fn get_entry(&self, gate: CircuitInstruction) -> Vec<CircuitRep> {
let key = Key {
name: gate.operation().name().to_string(),
num_qubits: gate.operation().num_qubits(),
Expand All @@ -510,7 +513,7 @@ impl EquivalenceLibrary {

self.get_equivalences(&key)
.into_iter()
.filter_map(|equivalence| rebind_equiv(py, equivalence, query_params).ok())
.filter_map(|equivalence| rebind_equiv(equivalence, query_params))
.collect()
}

Expand Down Expand Up @@ -764,29 +767,29 @@ fn raise_if_shape_mismatch(
Ok(())
}

fn rebind_equiv(
py: Python<'_>,
equiv: Equivalence,
query_params: &[Param],
) -> PyResult<CircuitRep> {
let (equiv_params, equiv_circuit) = (equiv.params, equiv.circuit);
let param_map: Vec<(Param, Param)> = equiv_params
.into_iter()
.filter_map(|param| match param {
Param::ParameterExpression(_) => Some(param),
_ => None,
})
.zip(query_params.iter().cloned())
.collect();
let dict = param_map.as_slice().into_py_dict_bound(py);
let kwargs = [("inplace", false), ("flat_input", true)].into_py_dict_bound(py);
let new_equiv =
equiv_circuit
.object
.call_method_bound(py, "assign_parameters", (dict,), Some(&kwargs))?;
new_equiv.extract::<CircuitRep>(py)
fn rebind_equiv(equiv: Equivalence, query_params: &[Param]) -> Option<CircuitRep> {
Python::with_gil(|py| -> PyResult<CircuitRep> {
let (equiv_params, equiv_circuit) = (equiv.params, equiv.circuit);
let param_map: Vec<(Param, Param)> = equiv_params
.into_iter()
.filter_map(|param| match param {
Param::ParameterExpression(_) => Some(param),
_ => None,
})
.zip(query_params.iter().cloned())
.collect();
let dict = param_map.as_slice().into_py_dict_bound(py);
let kwargs = [("inplace", false), ("flat_input", true)].into_py_dict_bound(py);
let new_equiv = equiv_circuit.object.call_method_bound(
py,
"assign_parameters",
(dict,),
Some(&kwargs),
)?;
new_equiv.extract::<CircuitRep>(py)
})
.ok()
}

// Errors

#[derive(Debug, Clone)]
Expand All @@ -808,28 +811,25 @@ impl Display for EquivalenceError {
}
}

fn to_pygraph<N, E>(py: Python<'_>, pet_graph: &DiGraph<N, E>) -> PyResult<PyObject>
fn to_pygraph<N, E>(py: Python<'_>, pet_graph: &StableDiGraph<N, E>) -> PyResult<PyObject>
where
N: IntoPy<PyObject> + Clone,
E: IntoPy<PyObject> + Clone,
{
let graph = PYDIGRAPH.get_bound(py).call0()?;
let node_weights = pet_graph.node_weights();
for node in node_weights {
graph.call_method1("add_node", (node.to_owned(),))?;
}
let edge_weights = pet_graph.edge_references().map(|edge| {
(
pet_graph.edge_endpoints(edge.id()).unwrap(),
pet_graph.edge_weight(edge.id()).unwrap(),
)
});
for ((source, target), weight) in edge_weights {
graph.call_method1(
"add_edge",
(source.index(), target.index(), weight.to_owned()),
)?;
}
let node_weights: Vec<N> = pet_graph.node_weights().cloned().collect();
graph.call_method1("add_nodes_from", (node_weights,))?;
let edge_weights: Vec<(usize, usize, E)> = pet_graph
.edge_references()
.map(|edge| {
(
edge.source().index(),
edge.target().index(),
edge.weight().to_owned(),
)
})
.collect();
graph.call_method1("add_edges_from", (edge_weights,))?;
Ok(graph.unbind())
}

Expand Down

0 comments on commit 38436cd

Please sign in to comment.