diff --git a/crates/circuit/src/equivalence.rs b/crates/circuit/src/equivalence.rs index 0c2f4d8ad5b1..cd44f1b6ed78 100644 --- a/crates/circuit/src/equivalence.rs +++ b/crates/circuit/src/equivalence.rs @@ -23,7 +23,7 @@ use std::{error::Error, fmt::Display}; use exceptions::CircuitError; use hashbrown::{HashMap, HashSet}; -use pyo3::types::{PyDict, PyString}; +use pyo3::types::PyDict; use pyo3::{prelude::*, types::IntoPyDict}; use rustworkx_core::petgraph::{ @@ -31,7 +31,7 @@ use rustworkx_core::petgraph::{ visit::EdgeRef, }; -use crate::circuit_instruction::convert_py_to_operation_type; +use crate::circuit_instruction::{convert_py_to_operation_type, CircuitInstruction}; use crate::imports::ImportOnceCell; use crate::operations::Param; use crate::operations::{Operation, OperationType}; @@ -278,40 +278,82 @@ impl<'py> FromPyObject<'py> for GateOper { #[derive(Debug, Clone)] pub struct CircuitRep { object: PyObject, - pub num_qubits: u32, - pub num_clbits: u32, - pub label: Option, - pub params: SmallVec<[Param; 3]>, + num_qubits: Option, + num_clbits: Option, + params: Option>, + data: Option>, // TODO: Have a valid implementation of CircuiData that's usable in Rust. } -impl FromPyObject<'_> for CircuitRep { - fn extract(ob: &'_ PyAny) -> PyResult { - let num_qubits = match ob.getattr("num_qubits") { - Ok(num_qubits) => num_qubits.extract::().ok(), - Err(_) => None, +impl CircuitRep { + #[inline] + pub fn num_qubits(&mut self) -> u32 { + match &self.num_qubits { + Some(num_qubits) => *num_qubits, + None => { + let num_qubits = Python::with_gil(|py| -> PyResult { + self.object.getattr(py, "num_qubits")?.extract(py) + }) + .unwrap_or_default(); + self.num_qubits = Some(num_qubits); + num_qubits + } + } + } + + #[inline] + pub fn num_clbits(&mut self) -> u32 { + match &self.num_clbits { + Some(num_clbits) => *num_clbits, + None => { + let num_clbits = Python::with_gil(|py| -> PyResult { + self.object.getattr(py, "num_clbits")?.extract(py) + }) + .unwrap_or_default(); + self.num_clbits = Some(num_clbits); + num_clbits + } } + } + + #[inline] + pub fn params(&mut self) -> &[Param] { + if self.params.is_some() { + return self.params.as_ref().unwrap(); + } + let params = Python::with_gil(|py| -> PyResult> { + self.object + .getattr(py, "params")? + .getattr(py, "data")? + .extract(py) + }) .unwrap_or_default(); - let num_clbits = match ob.getattr("num_clbits") { - Ok(num_clbits) => num_clbits.extract::().ok(), - Err(_) => None, + self.params = Some(params); + self.params.as_ref().unwrap() + } + + #[inline] + pub fn data(&mut self) -> &[CircuitInstruction] { + if self.data.is_some() { + return self.data.as_ref().unwrap(); } + let data = Python::with_gil(|py| -> PyResult> { + self.object.getattr(py, "data")?.extract(py) + }) .unwrap_or_default(); - let label = match ob.getattr("label") { - Ok(label) => label.extract::().ok(), - Err(_) => None, - }; - let params = ob - .getattr("parameters")? - .getattr("data")? - .extract::>() - .unwrap_or_default(); + self.data = Some(data); + self.data.as_ref().unwrap() + } +} + +impl FromPyObject<'_> for CircuitRep { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { Ok(Self { - object: ob.into(), - num_qubits, - num_clbits, - label, - params, + object: ob.to_object(ob.py()), + num_qubits: None, + num_clbits: None, + params: None, + data: None, }) } } @@ -349,10 +391,10 @@ impl Default for CircuitRep { fn default() -> Self { Self { object: Python::with_gil(|py| py.None()), - num_qubits: 0, - num_clbits: 0, - label: None, - params: smallvec![], + num_qubits: None, + num_clbits: None, + params: None, + data: None, } } } @@ -449,7 +491,8 @@ impl EquivalenceLibrary { /// entry (List['QuantumCircuit']) : A list of QuantumCircuits, each /// equivalently implementing the given Gate. fn set_entry(&mut self, gate: GateOper, entry: Vec) -> PyResult<()> { - match self.set_entry_native(&gate, &entry) { + let mut entry = entry; + match self.set_entry_native(&gate, &mut entry) { Ok(_) => Ok(()), Err(e) => Err(CircuitError::new_err(e.message)), } @@ -519,10 +562,10 @@ impl EquivalenceLibrary { fn __getstate__(slf: PyRef) -> PyResult> { let ret = PyDict::new_bound(slf.py()); ret.set_item("rule_id", slf.rule_id)?; - let key_to_usize_node: HashMap<(String, u32), usize> = HashMap::from_iter( + let key_to_usize_node: HashMap = HashMap::from_iter( slf.key_to_node_index .iter() - .map(|(key, val)| ((key.name.to_string(), key.num_qubits), val.index())), + .map(|(key, val)| (key.clone(), val.index())), ); ret.set_item("key_to_node_index", key_to_usize_node.into_py(slf.py()))?; let graph_nodes: Vec = slf._graph.node_weights().cloned().collect(); @@ -544,13 +587,21 @@ impl EquivalenceLibrary { fn __setstate__(mut slf: PyRefMut, state: &Bound<'_, PyDict>) -> PyResult<()> { slf.rule_id = state.get_item("rule_id")?.unwrap().extract()?; - slf.key_to_node_index = state + state .get_item("key_to_node_index")? .unwrap() - .extract::>()? - .into_iter() - .map(|((name, num_qubits), val)| (Key::new(name, num_qubits), NodeIndex::new(val))) - .collect(); + .downcast::()? + .items() + .iter() + .filter_map( + |item| match (item.extract::().ok(), item.extract::().ok()) { + (Some(key), Some(value)) => Some((key, value)), + _ => None, + }, + ) + .for_each(|(key, value)| { + slf.key_to_node_index.insert(key, NodeIndex::new(value)); + }); let graph_nodes: Vec = state.get_item("graph_nodes")?.unwrap().extract()?; let graph_edges: Vec<(usize, usize, EdgeData)> = state.get_item("graph_edges")?.unwrap().extract()?; @@ -603,10 +654,10 @@ impl EquivalenceLibrary { pub fn add_equiv( &mut self, gate: GateOper, - equivalent_circuit: CircuitRep, + mut equivalent_circuit: CircuitRep, ) -> Result<(), EquivalenceError> { - raise_if_shape_mismatch(&gate, &equivalent_circuit)?; - raise_if_param_mismatch(&gate.params, &equivalent_circuit.params)?; + raise_if_shape_mismatch(&gate, &mut equivalent_circuit)?; + raise_if_param_mismatch(&gate.params, equivalent_circuit.params())?; let key: Key = Key { name: gate.operation.name().to_string(), @@ -621,7 +672,11 @@ impl EquivalenceLibrary { if let Some(node) = self._graph.node_weight_mut(target) { node.equivs.push(equiv.clone()); } - let sources: HashSet = get_sources_from_circuit_rep(&equivalent_circuit); + let sources: HashSet = + HashSet::from_iter(equivalent_circuit.data().iter().map(|inst| Key { + name: inst.operation.name().to_string(), + num_qubits: inst.operation.num_qubits(), + })); let edges = Vec::from_iter(sources.iter().map(|source| { ( self.set_default_node(source.clone()), @@ -657,11 +712,11 @@ impl EquivalenceLibrary { pub fn set_entry_native( &mut self, gate: &GateOper, - entry: &Vec, + entry: &mut Vec, ) -> Result<(), EquivalenceError> { - for equiv in entry { + for equiv in &mut *entry { raise_if_shape_mismatch(gate, equiv)?; - raise_if_param_mismatch(&gate.params, &equiv.params)?; + raise_if_param_mismatch(&gate.params, equiv.params())?; } let key = Key { @@ -711,9 +766,12 @@ fn raise_if_param_mismatch( Ok(()) } -fn raise_if_shape_mismatch(gate: &GateOper, circuit: &CircuitRep) -> Result<(), EquivalenceError> { - if gate.operation.num_qubits() != circuit.num_qubits - || gate.operation.num_clbits() != circuit.num_clbits +fn raise_if_shape_mismatch( + gate: &GateOper, + circuit: &mut CircuitRep, +) -> Result<(), EquivalenceError> { + if gate.operation.num_qubits() != circuit.num_qubits() + || gate.operation.num_clbits() != circuit.num_clbits() { return Err(EquivalenceError::new_err(format!( "Cannot add equivalence between circuit and gate \ @@ -721,8 +779,8 @@ fn raise_if_shape_mismatch(gate: &GateOper, circuit: &CircuitRep) -> Result<(), Circuit: {} qubits and {} clbits.", gate.operation.num_qubits(), gate.operation.num_clbits(), - circuit.num_qubits, - circuit.num_clbits + circuit.num_qubits(), + circuit.num_clbits() ))); } Ok(()) @@ -752,32 +810,6 @@ fn rebind_equiv(equiv: Equivalence, query_params: &[Param]) -> Option HashSet { - let raw_sources = Python::with_gil(|py| -> PyResult> { - Ok(circuit - .object - .bind(py) - .getattr("data")? - .iter()? - .flat_map(|inst| -> PyResult<(String, u32)> { - let operation = inst?.getattr("operation")?; - Ok(( - operation - .getattr("name")? - .downcast::()? - .to_string(), - operation.getattr("num_qubits")?.extract::()?, - )) - }) - .collect()) - }) - .unwrap_or(vec![]); - // println!("{:#?}", raw_sources); - HashSet::from_iter(raw_sources.iter().map(|(name, num_qubits)| Key { - name: name.to_string(), - num_qubits: *num_qubits, - })) -} // Errors #[derive(Debug, Clone)]