Skip to content

Commit

Permalink
Fix: Further improvements to pickling
Browse files Browse the repository at this point in the history
- Use python structures to avoid extra conversions.
- Add rust native `EquivalenceLibrary.keys()` and have the python method use it.
  • Loading branch information
raynelfss committed Jul 1, 2024
1 parent a072635 commit 77dbec8
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions crates/circuit/src/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::{error::Error, fmt::Display};

use exceptions::CircuitError;
use hashbrown::{HashMap, HashSet};
use pyo3::types::{PyDict, PyString};
use pyo3::types::{PyDict, PyList, PySet, PyString};
use pyo3::{prelude::*, types::IntoPyDict};

use rustworkx_core::petgraph::{
Expand Down Expand Up @@ -121,8 +121,9 @@ impl Equivalence {
self.eq(&other)
}

fn __getnewargs__(&self, py: Python) -> (PyObject, PyObject) {
(self.params.to_object(py), self.circuit.object.clone_ref(py))
fn __getnewargs__(&self, py: Python) -> (Py<PyList>, PyObject) {
let params = PyList::new_bound(py, self.params.iter().map(|param| param.to_object(py)));
(params.unbind(), self.circuit.object.clone_ref(py))
}
}

Expand Down Expand Up @@ -313,10 +314,8 @@ impl FromPyObject<'_> for CircuitRep {
impl Display for CircuitRep {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let py_rep_str = Python::with_gil(|py| -> PyResult<String> {
match self.object.call_method0(py, "__repr__") {
Ok(str_obj) => str_obj.extract::<String>(py),
Err(_) => Ok("None".to_string()),
}
let bound = self.object.bind(py);
bound.repr().map(|pystring| pystring.to_string())
})
.unwrap();
write!(f, "{}", py_rep_str)
Expand Down Expand Up @@ -396,6 +395,7 @@ impl EquivalenceLibrary {
/// gate (Gate): A Gate instance.
/// equivalent_circuit (QuantumCircuit): A circuit equivalently
/// implementing the given Gate.
#[pyo3(text_signature = "(gate, equivalent_circuit, /,")]
fn add_equivalence(&mut self, gate: GateOper, equivalent_circuit: CircuitRep) -> PyResult<()> {
match self.add_equiv(gate, equivalent_circuit) {
Ok(_) => Ok(()),
Expand All @@ -411,6 +411,7 @@ impl EquivalenceLibrary {
/// Returns:
/// Bool: True if gate has a known decomposition in the library.
/// False otherwise.
#[pyo3(text_signature = "(gate, /,)")]
pub fn has_entry(&self, gate: GateOper) -> bool {
let key = Key {
name: gate.operation.name().to_string(),
Expand All @@ -430,6 +431,7 @@ impl EquivalenceLibrary {
/// gate (Gate): A Gate instance.
/// entry (List['QuantumCircuit']) : A list of QuantumCircuits, each
/// equivalently implementing the given Gate.
#[pyo3(text_signature = "(gate, entry, /,)")]
fn set_entry(&mut self, gate: GateOper, entry: Vec<CircuitRep>) -> PyResult<()> {
match self.set_entry_native(gate, entry) {
Ok(_) => Ok(()),
Expand All @@ -454,6 +456,7 @@ impl EquivalenceLibrary {
/// the library, from earliest to latest, from top to base. The
/// ordering of the StandardEquivalenceLibrary will not generally be
/// consistent across Qiskit versions.
#[pyo3(text_signature = "(gate, /,)")]
pub fn get_entry(&self, gate: GateOper) -> Vec<CircuitRep> {
let key = Key {
name: gate.operation.name().to_string(),
Expand Down Expand Up @@ -490,8 +493,13 @@ impl EquivalenceLibrary {
}
}

fn keys(&self) -> HashSet<Key> {
self.key_to_node_index.keys().cloned().collect()
#[pyo3(name = "keys", text_signature = "()")]
pub fn py_keys(slf: PyRef<Self>) -> PyResult<Bound<PySet>> {
let py_set = PySet::empty_bound(slf.py())?;
for key in slf.keys() {
py_set.add(key.clone().into_py(slf.py()))?;
}
Ok(py_set)
}

fn node_index(&self, key: Key) -> usize {
Expand All @@ -501,12 +509,11 @@ impl EquivalenceLibrary {
fn __getstate__(slf: PyRef<Self>) -> PyResult<Bound<'_, PyDict>> {
let ret = PyDict::new_bound(slf.py());
ret.set_item("rule_id", slf.rule_id)?;
let key_to_usize_node: HashMap<(String, u32), usize> = HashMap::from_iter(
slf.key_to_node_index
.iter()
.map(|(key, val)| ((key.name.to_string(), key.num_qubits), val.index())),
);
ret.set_item("key_to_node_index", key_to_usize_node.into_py(slf.py()))?;
let key_to_usize_node: Bound<PyDict> = PyDict::new_bound(slf.py());
for (key, val) in slf.key_to_node_index.iter() {
key_to_usize_node.set_item((&key.name, key.num_qubits), val.index())?;
}
ret.set_item("key_to_node_index", key_to_usize_node)?;
let graph_nodes: Vec<NodeData> = slf._graph.node_weights().cloned().collect();
ret.set_item("graph_nodes", graph_nodes.into_py(slf.py()))?;
let graph_edges: Vec<(usize, usize, EdgeData)> = slf
Expand All @@ -526,14 +533,16 @@ impl EquivalenceLibrary {

fn __setstate__(mut slf: PyRefMut<Self>, state: &Bound<'_, PyDict>) -> PyResult<()> {
slf.rule_id = state.get_item("rule_id")?.unwrap().extract()?;
let graph_nodes: Vec<NodeData> = state.get_item("graph_nodes")?.unwrap().extract()?;
let graph_edges: Vec<(usize, usize, EdgeData)> =
state.get_item("graph_edges")?.unwrap().extract()?;
let graph_nodes_ref: Bound<PyAny> = state.get_item("graph_nodes")?.unwrap();
let graph_nodes: &Bound<PyList> = graph_nodes_ref.downcast()?;
let graph_edge_ref: Bound<PyAny> = state.get_item("graph_edges")?.unwrap();
let graph_edges: &Bound<PyList> = graph_edge_ref.downcast()?;
slf._graph = GraphType::new();
for node_weight in graph_nodes {
slf._graph.add_node(node_weight);
slf._graph.add_node(node_weight.extract()?);
}
for (source_node, target_node, edge_weight) in graph_edges {
for edge in graph_edges {
let (source_node, target_node, edge_weight) = edge.extract()?;
slf._graph.add_edge(
NodeIndex::new(source_node),
NodeIndex::new(target_node),
Expand All @@ -554,6 +563,10 @@ impl EquivalenceLibrary {

// Rust native methods
impl EquivalenceLibrary {
pub fn keys(&self) -> impl Iterator<Item = &Key> {
self.key_to_node_index.keys()
}

/// Create a new node if key not found
fn set_default_node(&mut self, key: Key) -> NodeIndex {
if let Some(value) = self.key_to_node_index.get(&key) {
Expand Down

0 comments on commit 77dbec8

Please sign in to comment.