Skip to content

Commit

Permalink
Fix: Make CircuitRep attributes OneCell-like.
Browse files Browse the repository at this point in the history
- Attributes from CircuitRep are only written once, reducing the overhead.
- Modify `__setstate__` to avoid extra conversion.
- Remove `get_sources_from_circuit_rep`.
  • Loading branch information
raynelfss committed Jun 26, 2024
1 parent 785564f commit 3b954e4
Showing 1 changed file with 110 additions and 78 deletions.
188 changes: 110 additions & 78 deletions crates/circuit/src/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ use std::{error::Error, fmt::Display};

use exceptions::CircuitError;
use hashbrown::{HashMap, HashSet};
use pyo3::types::{PyDict, PyString};
use pyo3::types::PyDict;
use pyo3::{prelude::*, types::IntoPyDict};

use rustworkx_core::petgraph::{
graph::{EdgeIndex, NodeIndex},
visit::EdgeRef,
};

use crate::circuit_instruction::convert_py_to_operation_type;
use crate::circuit_instruction::{convert_py_to_operation_type, CircuitInstruction};
use crate::imports::ImportOnceCell;
use crate::operations::Param;
use crate::operations::{Operation, OperationType};
Expand Down Expand Up @@ -278,40 +278,82 @@ impl<'py> FromPyObject<'py> for GateOper {
#[derive(Debug, Clone)]
pub struct CircuitRep {
object: PyObject,
pub num_qubits: u32,
pub num_clbits: u32,
pub label: Option<String>,
pub params: SmallVec<[Param; 3]>,
num_qubits: Option<u32>,
num_clbits: Option<u32>,
params: Option<SmallVec<[Param; 3]>>,
data: Option<Vec<CircuitInstruction>>,
// TODO: Have a valid implementation of CircuiData that's usable in Rust.
}

impl FromPyObject<'_> for CircuitRep {
fn extract(ob: &'_ PyAny) -> PyResult<Self> {
let num_qubits = match ob.getattr("num_qubits") {
Ok(num_qubits) => num_qubits.extract::<u32>().ok(),
Err(_) => None,
impl CircuitRep {
#[inline]
pub fn num_qubits(&mut self) -> u32 {
match &self.num_qubits {
Some(num_qubits) => *num_qubits,
None => {
let num_qubits = Python::with_gil(|py| -> PyResult<u32> {
self.object.getattr(py, "num_qubits")?.extract(py)
})
.unwrap_or_default();
self.num_qubits = Some(num_qubits);
num_qubits
}
}
}

#[inline]
pub fn num_clbits(&mut self) -> u32 {
match &self.num_clbits {
Some(num_clbits) => *num_clbits,
None => {
let num_clbits = Python::with_gil(|py| -> PyResult<u32> {
self.object.getattr(py, "num_clbits")?.extract(py)
})
.unwrap_or_default();
self.num_clbits = Some(num_clbits);
num_clbits
}
}
}

#[inline]
pub fn params(&mut self) -> &[Param] {
if self.params.is_some() {
return self.params.as_ref().unwrap();
}
let params = Python::with_gil(|py| -> PyResult<SmallVec<[Param; 3]>> {
self.object
.getattr(py, "params")?
.getattr(py, "data")?
.extract(py)
})
.unwrap_or_default();
let num_clbits = match ob.getattr("num_clbits") {
Ok(num_clbits) => num_clbits.extract::<u32>().ok(),
Err(_) => None,
self.params = Some(params);
self.params.as_ref().unwrap()
}

#[inline]
pub fn data(&mut self) -> &[CircuitInstruction] {
if self.data.is_some() {
return self.data.as_ref().unwrap();
}
let data = Python::with_gil(|py| -> PyResult<Vec<CircuitInstruction>> {
self.object.getattr(py, "data")?.extract(py)
})
.unwrap_or_default();
let label = match ob.getattr("label") {
Ok(label) => label.extract::<String>().ok(),
Err(_) => None,
};
let params = ob
.getattr("parameters")?
.getattr("data")?
.extract::<SmallVec<[Param; 3]>>()
.unwrap_or_default();
self.data = Some(data);
self.data.as_ref().unwrap()
}
}

impl FromPyObject<'_> for CircuitRep {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
Ok(Self {
object: ob.into(),
num_qubits,
num_clbits,
label,
params,
object: ob.to_object(ob.py()),
num_qubits: None,
num_clbits: None,
params: None,
data: None,
})
}
}
Expand Down Expand Up @@ -349,10 +391,10 @@ impl Default for CircuitRep {
fn default() -> Self {
Self {
object: Python::with_gil(|py| py.None()),
num_qubits: 0,
num_clbits: 0,
label: None,
params: smallvec![],
num_qubits: None,
num_clbits: None,
params: None,
data: None,
}
}
}
Expand Down Expand Up @@ -449,7 +491,8 @@ impl EquivalenceLibrary {
/// entry (List['QuantumCircuit']) : A list of QuantumCircuits, each
/// equivalently implementing the given Gate.
fn set_entry(&mut self, gate: GateOper, entry: Vec<CircuitRep>) -> PyResult<()> {
match self.set_entry_native(&gate, &entry) {
let mut entry = entry;
match self.set_entry_native(&gate, &mut entry) {
Ok(_) => Ok(()),
Err(e) => Err(CircuitError::new_err(e.message)),
}
Expand Down Expand Up @@ -519,10 +562,10 @@ impl EquivalenceLibrary {
fn __getstate__(slf: PyRef<Self>) -> PyResult<Bound<'_, PyDict>> {
let ret = PyDict::new_bound(slf.py());
ret.set_item("rule_id", slf.rule_id)?;
let key_to_usize_node: HashMap<(String, u32), usize> = HashMap::from_iter(
let key_to_usize_node: HashMap<Key, usize> = HashMap::from_iter(
slf.key_to_node_index
.iter()
.map(|(key, val)| ((key.name.to_string(), key.num_qubits), val.index())),
.map(|(key, val)| (key.clone(), val.index())),
);
ret.set_item("key_to_node_index", key_to_usize_node.into_py(slf.py()))?;
let graph_nodes: Vec<NodeData> = slf._graph.node_weights().cloned().collect();
Expand All @@ -544,13 +587,21 @@ impl EquivalenceLibrary {

fn __setstate__(mut slf: PyRefMut<Self>, state: &Bound<'_, PyDict>) -> PyResult<()> {
slf.rule_id = state.get_item("rule_id")?.unwrap().extract()?;
slf.key_to_node_index = state
state
.get_item("key_to_node_index")?
.unwrap()
.extract::<HashMap<(String, u32), usize>>()?
.into_iter()
.map(|((name, num_qubits), val)| (Key::new(name, num_qubits), NodeIndex::new(val)))
.collect();
.downcast::<PyDict>()?
.items()
.iter()
.filter_map(
|item| match (item.extract::<Key>().ok(), item.extract::<usize>().ok()) {
(Some(key), Some(value)) => Some((key, value)),
_ => None,
},
)
.for_each(|(key, value)| {
slf.key_to_node_index.insert(key, NodeIndex::new(value));
});
let graph_nodes: Vec<NodeData> = state.get_item("graph_nodes")?.unwrap().extract()?;
let graph_edges: Vec<(usize, usize, EdgeData)> =
state.get_item("graph_edges")?.unwrap().extract()?;
Expand Down Expand Up @@ -603,10 +654,10 @@ impl EquivalenceLibrary {
pub fn add_equiv(
&mut self,
gate: GateOper,
equivalent_circuit: CircuitRep,
mut equivalent_circuit: CircuitRep,
) -> Result<(), EquivalenceError> {
raise_if_shape_mismatch(&gate, &equivalent_circuit)?;
raise_if_param_mismatch(&gate.params, &equivalent_circuit.params)?;
raise_if_shape_mismatch(&gate, &mut equivalent_circuit)?;
raise_if_param_mismatch(&gate.params, equivalent_circuit.params())?;

let key: Key = Key {
name: gate.operation.name().to_string(),
Expand All @@ -621,7 +672,11 @@ impl EquivalenceLibrary {
if let Some(node) = self._graph.node_weight_mut(target) {
node.equivs.push(equiv.clone());
}
let sources: HashSet<Key> = get_sources_from_circuit_rep(&equivalent_circuit);
let sources: HashSet<Key> =
HashSet::from_iter(equivalent_circuit.data().iter().map(|inst| Key {
name: inst.operation.name().to_string(),
num_qubits: inst.operation.num_qubits(),
}));
let edges = Vec::from_iter(sources.iter().map(|source| {
(
self.set_default_node(source.clone()),
Expand Down Expand Up @@ -657,11 +712,11 @@ impl EquivalenceLibrary {
pub fn set_entry_native(
&mut self,
gate: &GateOper,
entry: &Vec<CircuitRep>,
entry: &mut Vec<CircuitRep>,
) -> Result<(), EquivalenceError> {
for equiv in entry {
for equiv in &mut *entry {
raise_if_shape_mismatch(gate, equiv)?;
raise_if_param_mismatch(&gate.params, &equiv.params)?;
raise_if_param_mismatch(&gate.params, equiv.params())?;
}

let key = Key {
Expand Down Expand Up @@ -711,18 +766,21 @@ fn raise_if_param_mismatch(
Ok(())
}

fn raise_if_shape_mismatch(gate: &GateOper, circuit: &CircuitRep) -> Result<(), EquivalenceError> {
if gate.operation.num_qubits() != circuit.num_qubits
|| gate.operation.num_clbits() != circuit.num_clbits
fn raise_if_shape_mismatch(
gate: &GateOper,
circuit: &mut CircuitRep,
) -> Result<(), EquivalenceError> {
if gate.operation.num_qubits() != circuit.num_qubits()
|| gate.operation.num_clbits() != circuit.num_clbits()
{
return Err(EquivalenceError::new_err(format!(
"Cannot add equivalence between circuit and gate \
of different shapes. Gate: {} qubits and {} clbits. \
Circuit: {} qubits and {} clbits.",
gate.operation.num_qubits(),
gate.operation.num_clbits(),
circuit.num_qubits,
circuit.num_clbits
circuit.num_qubits(),
circuit.num_clbits()
)));
}
Ok(())
Expand Down Expand Up @@ -752,32 +810,6 @@ fn rebind_equiv(equiv: Equivalence, query_params: &[Param]) -> Option<CircuitRep
.ok()
}

fn get_sources_from_circuit_rep(circuit: &CircuitRep) -> HashSet<Key> {
let raw_sources = Python::with_gil(|py| -> PyResult<Vec<(String, u32)>> {
Ok(circuit
.object
.bind(py)
.getattr("data")?
.iter()?
.flat_map(|inst| -> PyResult<(String, u32)> {
let operation = inst?.getattr("operation")?;
Ok((
operation
.getattr("name")?
.downcast::<PyString>()?
.to_string(),
operation.getattr("num_qubits")?.extract::<u32>()?,
))
})
.collect())
})
.unwrap_or(vec![]);
// println!("{:#?}", raw_sources);
HashSet::from_iter(raw_sources.iter().map(|(name, num_qubits)| Key {
name: name.to_string(),
num_qubits: *num_qubits,
}))
}
// Errors

#[derive(Debug, Clone)]
Expand Down

0 comments on commit 3b954e4

Please sign in to comment.