diff --git a/crates/circuit/src/equivalence.rs b/crates/circuit/src/equivalence.rs index 14fddca6a3b2..cad527f62cc4 100644 --- a/crates/circuit/src/equivalence.rs +++ b/crates/circuit/src/equivalence.rs @@ -11,6 +11,9 @@ // that they have been altered from the originals. use itertools::Itertools; +use rustworkx_core::petgraph::stable_graph::StableDiGraph; +use rustworkx_core::petgraph::visit::IntoEdgeReferences; + use smallvec::{smallvec, SmallVec}; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; @@ -22,7 +25,7 @@ use pyo3::types::PyDict; use pyo3::{prelude::*, types::IntoPyDict}; use rustworkx_core::petgraph::{ - graph::{DiGraph, EdgeIndex, NodeIndex}, + graph::{EdgeIndex, NodeIndex}, visit::EdgeRef, }; @@ -383,7 +386,7 @@ impl Default for CircuitRep { } // Custom Types -type GraphType = DiGraph; +type GraphType = StableDiGraph; type KTIType = HashMap; #[pyclass( @@ -501,7 +504,7 @@ impl EquivalenceLibrary { /// the library, from earliest to latest, from top to base. The /// ordering of the StandardEquivalenceLibrary will not generally be /// consistent across Qiskit versions. - pub fn get_entry(&self, py: Python<'_>, gate: CircuitInstruction) -> Vec { + pub fn get_entry(&self, gate: CircuitInstruction) -> Vec { let key = Key { name: gate.operation().name().to_string(), num_qubits: gate.operation().num_qubits(), @@ -510,7 +513,7 @@ impl EquivalenceLibrary { self.get_equivalences(&key) .into_iter() - .filter_map(|equivalence| rebind_equiv(py, equivalence, query_params).ok()) + .filter_map(|equivalence| rebind_equiv(equivalence, query_params)) .collect() } @@ -764,29 +767,29 @@ fn raise_if_shape_mismatch( Ok(()) } -fn rebind_equiv( - py: Python<'_>, - equiv: Equivalence, - query_params: &[Param], -) -> PyResult { - let (equiv_params, equiv_circuit) = (equiv.params, equiv.circuit); - let param_map: Vec<(Param, Param)> = equiv_params - .into_iter() - .filter_map(|param| match param { - Param::ParameterExpression(_) => Some(param), - _ => None, - }) - .zip(query_params.iter().cloned()) - .collect(); - let dict = param_map.as_slice().into_py_dict_bound(py); - let kwargs = [("inplace", false), ("flat_input", true)].into_py_dict_bound(py); - let new_equiv = - equiv_circuit - .object - .call_method_bound(py, "assign_parameters", (dict,), Some(&kwargs))?; - new_equiv.extract::(py) +fn rebind_equiv(equiv: Equivalence, query_params: &[Param]) -> Option { + Python::with_gil(|py| -> PyResult { + let (equiv_params, equiv_circuit) = (equiv.params, equiv.circuit); + let param_map: Vec<(Param, Param)> = equiv_params + .into_iter() + .filter_map(|param| match param { + Param::ParameterExpression(_) => Some(param), + _ => None, + }) + .zip(query_params.iter().cloned()) + .collect(); + let dict = param_map.as_slice().into_py_dict_bound(py); + let kwargs = [("inplace", false), ("flat_input", true)].into_py_dict_bound(py); + let new_equiv = equiv_circuit.object.call_method_bound( + py, + "assign_parameters", + (dict,), + Some(&kwargs), + )?; + new_equiv.extract::(py) + }) + .ok() } - // Errors #[derive(Debug, Clone)] @@ -808,28 +811,25 @@ impl Display for EquivalenceError { } } -fn to_pygraph(py: Python<'_>, pet_graph: &DiGraph) -> PyResult +fn to_pygraph(py: Python<'_>, pet_graph: &StableDiGraph) -> PyResult where N: IntoPy + Clone, E: IntoPy + Clone, { let graph = PYDIGRAPH.get_bound(py).call0()?; - let node_weights = pet_graph.node_weights(); - for node in node_weights { - graph.call_method1("add_node", (node.to_owned(),))?; - } - let edge_weights = pet_graph.edge_references().map(|edge| { - ( - pet_graph.edge_endpoints(edge.id()).unwrap(), - pet_graph.edge_weight(edge.id()).unwrap(), - ) - }); - for ((source, target), weight) in edge_weights { - graph.call_method1( - "add_edge", - (source.index(), target.index(), weight.to_owned()), - )?; - } + let node_weights: Vec = pet_graph.node_weights().cloned().collect(); + graph.call_method1("add_nodes_from", (node_weights,))?; + let edge_weights: Vec<(usize, usize, E)> = pet_graph + .edge_references() + .map(|edge| { + ( + edge.source().index(), + edge.target().index(), + edge.weight().to_owned(), + ) + }) + .collect(); + graph.call_method1("add_edges_from", (edge_weights,))?; Ok(graph.unbind()) }