Skip to content

Commit

Permalink
Add: data() method to avoid extracting CircuitData
Browse files Browse the repository at this point in the history
- Add `py_clone` to perform shallow clones of a `CircuitRef` object by cloning the references to the `QuantumCircuit` object.
- Extract `num_qubits` and `num_clbits` for CircuitRep.
- Add wrapper over `add_equivalence` to be able to accept references and avoid unnecessary cloning of `GateRep` objects in `set_entry`.
- Remove stray mutability of `entry` in `set_entry`.
  • Loading branch information
raynelfss committed Jul 15, 2024
1 parent 08915c0 commit d24c134
Showing 1 changed file with 111 additions and 82 deletions.
193 changes: 111 additions & 82 deletions crates/circuit/src/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use rustworkx_core::petgraph::{
};

use crate::circuit_data::CircuitData;
use crate::circuit_instruction::{convert_py_to_operation_type, PackedInstruction};
use crate::circuit_instruction::convert_py_to_operation_type;
use crate::imports::ImportOnceCell;
use crate::operations::Param;
use crate::operations::{Operation, OperationType};
Expand Down Expand Up @@ -125,9 +125,10 @@ impl Equivalence {
&& other_circuit.eq(&slf.getattr("circuit")?)?)
}

fn __getnewargs__(&self, py: Python) -> (Py<PyList>, CircuitRep) {
let params = PyList::new_bound(py, self.params.iter().map(|param| param.to_object(py)));
(params.unbind(), self.circuit.clone())
fn __getnewargs__<'py>(
slf: &'py Bound<'py, Self>,
) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
Ok((slf.getattr("params")?, slf.getattr("circuit")?))
}
}

Expand Down Expand Up @@ -171,15 +172,10 @@ impl NodeData {
&& slf.getattr("equivs")?.eq(other.getattr("equivs")?)?)
}

fn __getnewargs__(&self, py: Python) -> (Key, Py<PyList>) {
(
self.key.clone(),
PyList::new_bound(
py,
self.equivs.iter().map(|equiv| equiv.clone().into_py(py)),
)
.unbind(),
)
fn __getnewargs__<'py>(
slf: &'py Bound<'py, Self>,
) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
Ok((slf.getattr("key")?, slf.getattr("equivs")?))
}
}

Expand Down Expand Up @@ -224,11 +220,13 @@ impl EdgeData {
self.to_string()
}

fn __eq__(slf: &Bound<Self>, other: &Bound<PyAny>) -> PyResult<bool> {
Ok(slf.getattr("index")?.eq(other.getattr("index")?)?
&& slf.getattr("num_gates")?.eq(other.getattr("num_gates")?)?
&& slf.getattr("rule")?.eq(other.getattr("rule")?)?
&& slf.getattr("source")?.eq(other.getattr("source")?)?)
fn __eq__(slf: &Bound<Self>, other: &Bound<Self>) -> PyResult<bool> {
let other_borrowed = other.borrow();
let slf_borrowed = slf.borrow();
Ok(slf_borrowed.index == other_borrowed.index
&& slf_borrowed.num_gates == other_borrowed.num_gates
&& slf_borrowed.source == other_borrowed.source
&& other.getattr("rule")?.eq(slf.getattr("rule")?)?)
}

fn __getnewargs__(slf: PyRef<Self>) -> (usize, usize, Equivalence, Key) {
Expand Down Expand Up @@ -273,22 +271,41 @@ impl<'py> FromPyObject<'py> for GateOper {
#[derive(Debug, Clone)]
pub struct CircuitRep {
object: PyObject,
data: CircuitData,
pub num_qubits: usize,
pub num_clbits: usize,
data: Py<CircuitData>,
}

impl CircuitRep {
pub fn iter(&self) -> impl Iterator<Item = &PackedInstruction> {
self.data.iter()
/// Allows access to the circuit's data through a Python reference.
pub fn data<'py>(&'py self, py: Python<'py>) -> PyRef<'py, CircuitData> {
self.data.borrow(py)
}

/// Performs a shallow cloning of the structure by using `clone_ref()`.
pub fn py_clone(&self, py: Python) -> Self {
Self {
object: self.object.clone_ref(py),
num_qubits: self.num_qubits,
num_clbits: self.num_clbits,
data: self.data.clone_ref(py),
}
}
}

impl FromPyObject<'_> for CircuitRep {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
if ob.is_instance(QUANTUMCIRCUIT.get_bound(ob.py()))? {
let data = ob.getattr("_data")?.extract()?;
let data: Bound<PyAny> = ob.getattr("_data")?;
let data_downcast: Bound<CircuitData> = data.downcast_into()?;
let data_borrow: PyRef<CircuitData> = data_downcast.borrow();
let num_qubits: usize = data_borrow.num_qubits();
let num_clbits: usize = data_borrow.num_clbits();
Ok(Self {
object: ob.into_py(ob.py()),
data,
num_qubits,
num_clbits,
data: data_downcast.unbind(),
})
} else {
Err(PyTypeError::new_err(
Expand All @@ -304,6 +321,12 @@ impl IntoPy<PyObject> for CircuitRep {
}
}

impl ToPyObject for CircuitRep {
fn to_object(&self, py: Python<'_>) -> PyObject {
self.object.clone_ref(py)
}
}

// Custom Types
type GraphType = StableDiGraph<NodeData, EdgeData>;
type KTIType = HashMap<Key, NodeIndex>;
Expand Down Expand Up @@ -360,54 +383,14 @@ impl EquivalenceLibrary {
/// gate (Gate): A Gate instance.
/// equivalent_circuit (QuantumCircuit): A circuit equivalently
/// implementing the given Gate.
fn add_equivalence(
#[pyo3(name = "add_equivalence")]
fn py_add_equivalence(
&mut self,
py: Python,
gate: GateOper,
equivalent_circuit: CircuitRep,
) -> PyResult<()> {
raise_if_shape_mismatch(&gate, &equivalent_circuit)?;
raise_if_param_mismatch(
py,
&gate.params,
equivalent_circuit.data.get_params_unsorted(py)?,
)?;

let key: Key = Key {
name: gate.operation.name().to_string(),
num_qubits: gate.operation.num_qubits(),
};
let equiv = Equivalence {
params: gate.params,
circuit: equivalent_circuit.clone(),
};

let target = self.set_default_node(key);
if let Some(node) = self._graph.node_weight_mut(target) {
node.equivs.push(equiv.clone());
}
let sources: HashSet<Key> = HashSet::from_iter(equivalent_circuit.iter().map(|inst| Key {
name: inst.op.name().to_string(),
num_qubits: inst.op.num_qubits(),
}));
let edges = Vec::from_iter(sources.iter().map(|source| {
(
self.set_default_node(source.clone()),
target,
EdgeData {
index: self.rule_id,
num_gates: sources.len(),
rule: equiv.clone(),
source: source.clone(),
},
)
}));
for edge in edges {
self._graph.add_edge(edge.0, edge.1, edge.2);
}
self.rule_id += 1;
self.graph = None;
Ok(())
self.add_equivalence(py, &gate, equivalent_circuit)
}

/// Check if a library contains any decompositions for gate.
Expand All @@ -434,15 +417,10 @@ impl EquivalenceLibrary {
/// gate (Gate): A Gate instance.
/// entry (List['QuantumCircuit']) : A list of QuantumCircuits, each
/// equivalently implementing the given Gate.
fn set_entry(
&mut self,
py: Python,
gate: GateOper,
mut entry: Vec<CircuitRep>,
) -> PyResult<()> {
for equiv in entry.iter_mut() {
fn set_entry(&mut self, py: Python, gate: GateOper, entry: Vec<CircuitRep>) -> PyResult<()> {
for equiv in entry.iter() {
raise_if_shape_mismatch(&gate, equiv)?;
raise_if_param_mismatch(py, &gate.params, equiv.data.get_params_unsorted(py)?)?;
raise_if_param_mismatch(py, &gate.params, equiv.data(py).get_params_unsorted(py)?)?;
}

let key = Key {
Expand All @@ -464,7 +442,7 @@ impl EquivalenceLibrary {
self._graph.remove_edge(edge);
}
for equiv in entry {
self.add_equivalence(py, gate.clone(), equiv.clone())?
self.add_equivalence(py, &gate, equiv)?
}
self.graph = None;
Ok(())
Expand Down Expand Up @@ -547,7 +525,7 @@ impl EquivalenceLibrary {
ret.set_item("rule_id", slf.rule_id)?;
let key_to_usize_node: Bound<PyDict> = PyDict::new_bound(slf.py());
for (key, val) in slf.key_to_node_index.iter() {
key_to_usize_node.set_item((&key.name, key.num_qubits), val.index())?;
key_to_usize_node.set_item(key.clone().into_py(slf.py()), val.index())?;
}
ret.set_item("key_to_node_index", key_to_usize_node)?;
let graph_nodes: Bound<PyList> = PyList::empty_bound(slf.py());
Expand Down Expand Up @@ -591,9 +569,9 @@ impl EquivalenceLibrary {
slf.key_to_node_index = state
.get_item("key_to_node_index")?
.unwrap()
.extract::<HashMap<(String, u32), usize>>()?
.extract::<HashMap<Key, usize>>()?
.into_iter()
.map(|((name, num_qubits), val)| (Key::new(name, num_qubits), NodeIndex::new(val)))
.map(|(key, val)| (key, NodeIndex::new(val)))
.collect();
slf.graph = None;
Ok(())
Expand All @@ -602,6 +580,57 @@ impl EquivalenceLibrary {

// Rust native methods
impl EquivalenceLibrary {
fn add_equivalence(
&mut self,
py: Python,
gate: &GateOper,
equivalent_circuit: CircuitRep,
) -> PyResult<()> {
raise_if_shape_mismatch(gate, &equivalent_circuit)?;
raise_if_param_mismatch(
py,
&gate.params,
equivalent_circuit.data(py).get_params_unsorted(py)?,
)?;

let key: Key = Key {
name: gate.operation.name().to_string(),
num_qubits: gate.operation.num_qubits(),
};
let equiv = Equivalence {
params: gate.params.clone(),
circuit: equivalent_circuit.py_clone(py),
};

let target = self.set_default_node(key);
if let Some(node) = self._graph.node_weight_mut(target) {
node.equivs.push(equiv.clone());
}
let sources: HashSet<Key> =
HashSet::from_iter(equivalent_circuit.data(py).iter().map(|inst| Key {
name: inst.op.name().to_string(),
num_qubits: inst.op.num_qubits(),
}));
let edges = Vec::from_iter(sources.iter().map(|source| {
(
self.set_default_node(source.clone()),
target,
EdgeData {
index: self.rule_id,
num_gates: sources.len(),
rule: equiv.clone(),
source: source.clone(),
},
)
}));
for edge in edges {
self._graph.add_edge(edge.0, edge.1, edge.2);
}
self.rule_id += 1;
self.graph = None;
Ok(())
}

/// Rust native equivalent to `EquivalenceLibrary.has_entry()`
///
/// Check if a library contains any decompositions for gate.
Expand Down Expand Up @@ -673,17 +702,17 @@ fn raise_if_param_mismatch(
}

fn raise_if_shape_mismatch(gate: &GateOper, circuit: &CircuitRep) -> PyResult<()> {
if gate.operation.num_qubits() != circuit.data.num_qubits() as u32
|| gate.operation.num_clbits() != circuit.data.num_clbits() as u32
if gate.operation.num_qubits() != circuit.num_qubits as u32
|| gate.operation.num_clbits() != circuit.num_clbits as u32
{
return Err(CircuitError::new_err(format!(
"Cannot add equivalence between circuit and gate \
of different shapes. Gate: {} qubits and {} clbits. \
Circuit: {} qubits and {} clbits.",
gate.operation.num_qubits(),
gate.operation.num_clbits(),
circuit.data.num_qubits(),
circuit.data.num_clbits()
circuit.num_qubits,
circuit.num_clbits
)));
}
Ok(())
Expand Down

0 comments on commit d24c134

Please sign in to comment.