Skip to content

Commit

Permalink
Start using Rust DAGCircuit. Implement most of __eq__.
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinhartman committed Jun 25, 2024
1 parent 507f88c commit 5834c1f
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 85 deletions.
237 changes: 153 additions & 84 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use indexmap::set::Slice;
use indexmap::{IndexMap, IndexSet};
use petgraph::prelude::*;
use pyo3::callback::IntoPyCallbackOutput;
use pyo3::exceptions::{PyIndexError, PyKeyError, PyRuntimeError, PyValueError};
use pyo3::exceptions::{PyIndexError, PyKeyError, PyRuntimeError, PyTypeError, PyValueError};
use pyo3::ffi::PyCFunction;
use pyo3::prelude::*;
use pyo3::types::iter::BoundTupleIterator;
Expand Down Expand Up @@ -125,9 +125,9 @@ enum Wire {
#[pyclass(module = "qiskit._accelerate.circuit")]
#[derive(Clone, Debug)]
pub struct DAGCircuit {
#[pyo3(get)]
#[pyo3(get, set)]
name: Option<Py<PyString>>,
#[pyo3(get)]
#[pyo3(get, set)]
metadata: Option<Py<PyDict>>,
calibrations: HashMap<String, Py<PyDict>>,

Expand All @@ -147,8 +147,10 @@ pub struct DAGCircuit {
/// Global phase.
global_phase: PyObject,
/// Duration.
#[pyo3(get, set)]
duration: Option<PyObject>,
/// Unit of duration.
#[pyo3(get, set)]
unit: String,

// Note: these are tracked separately from `qubits` and `clbits`
Expand Down Expand Up @@ -226,61 +228,42 @@ impl PyControlFlowModule {

#[derive(Clone, Debug)]
struct PyCircuitModule {
clbit: Py<PyType>,
qubit: Py<PyType>,
classical_register: Py<PyType>,
quantum_register: Py<PyType>,
control_flow_op: Py<PyType>,
for_loop_op: Py<PyType>,
if_else_op: Py<PyType>,
while_loop_op: Py<PyType>,
switch_case_op: Py<PyType>,
operation: Py<PyType>,
store: Py<PyType>,
gate: Py<PyType>,
parameter_expression: Py<PyType>,
variable_mapper: Py<PyType>,
clbit: Py<PyAny>,
qubit: Py<PyAny>,
classical_register: Py<PyAny>,
quantum_register: Py<PyAny>,
control_flow_op: Py<PyAny>,
for_loop_op: Py<PyAny>,
if_else_op: Py<PyAny>,
while_loop_op: Py<PyAny>,
switch_case_op: Py<PyAny>,
operation: Py<PyAny>,
store: Py<PyAny>,
gate: Py<PyAny>,
parameter_expression: Py<PyAny>,
variable_mapper: Py<PyAny>,
}

impl PyCircuitModule {
fn new(py: Python) -> PyResult<Self> {
let module = PyModule::import_bound(py, "qiskit.circuit")?;
Ok(PyCircuitModule {
clbit: module.getattr("Clbit")?.downcast_into_exact()?.unbind(),
qubit: module.getattr("Qubit")?.downcast_into_exact()?.unbind(),
classical_register: module
.getattr("ClassicalRegister")?
.downcast_into_exact()?
.unbind(),
quantum_register: module
.getattr("QuantumRegsiter")?
.downcast_into_exact()?
.unbind(),
control_flow_op: module
.getattr("ControlFlowOp")?
.downcast_into_exact()?
.unbind(),
for_loop_op: module.getattr("ForLoopOp")?.downcast_into_exact()?.unbind(),
if_else_op: module.getattr("IfElseOp")?.downcast_into_exact()?.unbind(),
while_loop_op: module
.getattr("WhileLoopOp")?
.downcast_into_exact()?
.unbind(),
switch_case_op: module
.getattr("SwitchCaseOp")?
.downcast_into_exact()?
.unbind(),
operation: module.getattr("Operation")?.downcast_into_exact()?.unbind(),
store: module.getattr("Store")?.downcast_into_exact()?.unbind(),
gate: module.getattr("Gate")?.downcast_into_exact()?.unbind(),
parameter_expression: module
.getattr("ParameterExpression")?
.downcast_into_exact()?
.unbind(),
clbit: module.getattr("Clbit")?.unbind(),
qubit: module.getattr("Qubit")?.unbind(),
classical_register: module.getattr("ClassicalRegister")?.unbind(),
quantum_register: module.getattr("QuantumRegister")?.unbind(),
control_flow_op: module.getattr("ControlFlowOp")?.unbind(),
for_loop_op: module.getattr("ForLoopOp")?.unbind(),
if_else_op: module.getattr("IfElseOp")?.unbind(),
while_loop_op: module.getattr("WhileLoopOp")?.unbind(),
switch_case_op: module.getattr("SwitchCaseOp")?.unbind(),
operation: module.getattr("Operation")?.unbind(),
store: module.getattr("Store")?.unbind(),
gate: module.getattr("Gate")?.unbind(),
parameter_expression: module.getattr("ParameterExpression")?.unbind(),
variable_mapper: module
.getattr("_classical_resource_map")?
.getattr("VariableMapper")?
.downcast_into_exact()?
.unbind(),
})
}
Expand Down Expand Up @@ -1769,41 +1752,127 @@ def _format(operand):
todo!()
}

fn __eq__(&self, other: &DAGCircuit) -> PyResult<bool> {
// # Try to convert to float, but in case of unbound ParameterExpressions
// # a TypeError will be raise, fallback to normal equality in those
// # cases
// try:
// self_phase = float(self.global_phase)
// other_phase = float(other.global_phase)
// if (
// abs((self_phase - other_phase + np.pi) % (2 * np.pi) - np.pi) > 1.0e-10
// ): # TODO: atol?
// return False
// except TypeError:
// if self.global_phase != other.global_phase:
// return False
// if self.calibrations != other.calibrations:
// return False
//
// self_bit_indices = {bit: idx for idx, bit in enumerate(self.qubits + self.clbits)}
// other_bit_indices = {bit: idx for idx, bit in enumerate(other.qubits + other.clbits)}
//
// self_qreg_indices = {
// regname: [self_bit_indices[bit] for bit in reg] for regname, reg in self.qregs.items()
// }
// self_creg_indices = {
// regname: [self_bit_indices[bit] for bit in reg] for regname, reg in self.cregs.items()
// }
//
// other_qreg_indices = {
// regname: [other_bit_indices[bit] for bit in reg] for regname, reg in other.qregs.items()
// }
// other_creg_indices = {
// regname: [other_bit_indices[bit] for bit in reg] for regname, reg in other.cregs.items()
// }
// if self_qreg_indices != other_qreg_indices or self_creg_indices != other_creg_indices:
// return False
fn __eq__(&self, py: Python, other: &DAGCircuit) -> PyResult<bool> {
// Try to convert to float, but in case of unbound ParameterExpressions
// a TypeError will be raise, fallback to normal equality in those
// cases.
let self_phase = match self
.global_phase
.bind(py)
.call_method0(intern!(py, "__float__"))
{
Err(e) if !e.is_instance_of::<PyTypeError>(py) => {
return Err(e);
}
res => res.ok(),
};
let other_phase = match other
.global_phase
.bind(py)
.call_method0(intern!(py, "__float__"))
{
Err(e) if !e.is_instance_of::<PyTypeError>(py) => {
return Err(e);
}
res => res.ok(),
};
match (self_phase, other_phase) {
(Some(self_phase), Some(other_phase)) => {
let self_phase: f64 = self_phase.extract()?;
let other_phase: f64 = other_phase.extract()?;
if (((self_phase - other_phase + PI) % (2.0 * PI)) - PI).abs() > 1.0e-10 {
return Ok(false);
}
}
_ => {
if !self.global_phase.bind(py).eq(other.global_phase.bind(py))? {
return Ok(false);
}
}
}

if self.calibrations.len() != other.calibrations.len() {
return Ok(false);
}

for (k, v1) in &self.calibrations {
match other.calibrations.get(k) {
Some(v2) => {
if !v1.bind(py).eq(v2.bind(py))? {
return Ok(false);
}
}
None => {
return Ok(false);
}
}
}

let self_bit_indices = {
let indices = self
.qubits
.bits()
.iter()
.chain(self.clbits.bits())
.enumerate()
.map(|(idx, bit)| (bit, idx));
indices.into_py_dict_bound(py)
};

let other_bit_indices = {
let indices = other
.qubits
.bits()
.iter()
.chain(other.clbits.bits())
.enumerate()
.map(|(idx, bit)| (bit, idx));
indices.into_py_dict_bound(py)
};

// Check if qregs are the same.
let self_qregs = self.qregs.bind(py);
let other_qregs = other.qregs.bind(py);
if self_qregs.len() != other_qregs.len() {
return Ok(false);
}
for (regname, self_bits) in self_qregs {
let self_bits = self_bits.downcast_into_exact::<PyList>()?;
let other_bits = match other_qregs.get_item(regname)? {
Some(bits) => bits.downcast_into_exact::<PyList>()?,
None => return Ok(false),
};
if !self
.qubits
.map_bits(self_bits)?
.eq(other.qubits.map_bits(other_bits)?)
{
return Ok(false);
}
}

// Check if cregs are the same.
let self_cregs = self.cregs.bind(py);
let other_cregs = other.cregs.bind(py);
if self_cregs.len() != other_cregs.len() {
return Ok(false);
}

for (regname, self_bits) in self_cregs {
let self_bits = self_bits.downcast_into_exact::<PyList>()?;
let other_bits = match other_cregs.get_item(regname)? {
Some(bits) => bits.downcast_into_exact::<PyList>()?,
None => return Ok(false),
};
if !self
.clbits
.map_bits(self_bits)?
.eq(other.clbits.map_bits(other_bits)?)
{
return Ok(false);
}
}

//
// def node_eq(node_self, node_other):
// return DAGNode.semantic_eq(node_self, node_other, self_bit_indices, other_bit_indices)
Expand Down
7 changes: 6 additions & 1 deletion qiskit/dagcircuit/dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,13 @@
# The allowable arguments to :meth:`DAGCircuit.copy_empty_like`'s ``vars_mode``.
_VarsMode = Literal["alike", "captures", "drop"]

import qiskit._accelerate.circuit

class DAGCircuit:

DAGCircuit = qiskit._accelerate.circuit.DAGCircuit


class _OldDAGCircuit:
"""
Quantum circuit as a directed acyclic graph.
Expand Down

0 comments on commit 5834c1f

Please sign in to comment.