Skip to content

Commit

Permalink
Remove: Default initialization methods from custom datatypes.
Browse files Browse the repository at this point in the history
- Use `__getnewargs__ instead.
  • Loading branch information
raynelfss committed Jun 28, 2024
1 parent e2c8dcb commit dc3041e
Showing 1 changed file with 35 additions and 34 deletions.
69 changes: 35 additions & 34 deletions crates/circuit/src/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use rustworkx_core::petgraph::csr::IndexType;
use rustworkx_core::petgraph::stable_graph::StableDiGraph;
use rustworkx_core::petgraph::visit::IntoEdgeReferences;

use smallvec::{smallvec, SmallVec};
use smallvec::SmallVec;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::{error::Error, fmt::Display};
Expand All @@ -42,7 +42,8 @@ mod exceptions {
import_exception_bound! {qiskit.circuit.exceptions, CircuitError}
}
pub static PYDIGRAPH: ImportOnceCell = ImportOnceCell::new("rustworkx", "PyDiGraph");
pub static QUANTUMCIRCUIT: ImportOnceCell = ImportOnceCell::new("qiskit.circuit", "QuantumCircuit");
pub static QUANTUMCIRCUIT: ImportOnceCell =
ImportOnceCell::new("qiskit.circuit.quantumcircuit", "QuantumCircuit");

// Custom Structs

Expand All @@ -58,7 +59,7 @@ pub struct Key {
#[pymethods]
impl Key {
#[new]
#[pyo3(signature = (name="".to_string(), num_qubits=0))]
#[pyo3(signature = (name, num_qubits))]
fn new(name: String, num_qubits: u32) -> Self {
Self { name, num_qubits }
}
Expand All @@ -81,6 +82,10 @@ impl Key {
(slf.name.clone(), slf.num_qubits)
}

fn __getnewargs__(slf: PyRef<Self>) -> (String, u32) {
Key::__getstate__(slf)
}

fn __setstate__(mut slf: PyRefMut<Self>, state: (String, u32)) {
slf.name = state.0;
slf.num_qubits = state.1;
Expand All @@ -97,17 +102,8 @@ impl Display for Key {
}
}

impl Default for Key {
fn default() -> Self {
Self {
name: "".to_string(),
num_qubits: 0,
}
}
}

#[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")]
#[derive(Debug, Clone, PartialEq, Default)]
#[derive(Debug, Clone, PartialEq)]
pub struct Equivalence {
#[pyo3(get)]
pub params: SmallVec<[Param; 3]>,
Expand All @@ -118,7 +114,7 @@ pub struct Equivalence {
#[pymethods]
impl Equivalence {
#[new]
#[pyo3(signature = (params=smallvec![], circuit=CircuitRep::default()))]
#[pyo3(signature = (params, circuit))]
fn new(params: SmallVec<[Param; 3]>, circuit: CircuitRep) -> Self {
Self { circuit, params }
}
Expand All @@ -131,8 +127,12 @@ impl Equivalence {
self.eq(&other)
}

fn __getstate__(slf: PyRef<Self>) -> (SmallVec<[Param; 3]>, CircuitRep) {
(slf.params.clone(), slf.circuit.clone())
fn __getstate__(&self, py: Python) -> (PyObject, PyObject) {
(self.params.to_object(py), self.circuit.object.clone_ref(py))
}

fn __getnewargs__(&self, py: Python) -> (PyObject, PyObject) {
self.__getstate__(py)
}

fn __setstate__(mut slf: PyRefMut<Self>, state: (SmallVec<[Param; 3]>, CircuitRep)) {
Expand All @@ -153,7 +153,7 @@ impl Display for Equivalence {
}

#[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")]
#[derive(Debug, Clone, PartialEq, Default)]
#[derive(Debug, Clone, PartialEq)]
pub struct NodeData {
#[pyo3(get)]
key: Key,
Expand All @@ -164,7 +164,7 @@ pub struct NodeData {
#[pymethods]
impl NodeData {
#[new]
#[pyo3(signature = (key=Key::default(), equivs=vec![]))]
#[pyo3(signature = (key, equivs))]
fn new(key: Key, equivs: Vec<Equivalence>) -> Self {
Self { key, equivs }
}
Expand All @@ -177,8 +177,12 @@ impl NodeData {
self.eq(&other)
}

fn __getstate__(slf: PyRef<Self>) -> (Key, Vec<Equivalence>) {
(slf.key.clone(), slf.equivs.clone())
fn __getstate__(&self) -> (Key, Vec<Equivalence>) {
(self.key.clone(), self.equivs.clone())
}

fn __getnewargs__(&self) -> (Key, Vec<Equivalence>) {
self.__getstate__()
}

fn __setstate__(mut slf: PyRefMut<Self>, state: (Key, Vec<Equivalence>)) {
Expand All @@ -199,7 +203,7 @@ impl Display for NodeData {
}

#[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")]
#[derive(Debug, Clone, PartialEq, Default)]
#[derive(Debug, Clone, PartialEq)]
pub struct EdgeData {
#[pyo3(get)]
pub index: usize,
Expand All @@ -214,7 +218,7 @@ pub struct EdgeData {
#[pymethods]
impl EdgeData {
#[new]
#[pyo3(signature = (index=0, num_gates=0, rule=Equivalence::default(), source=Key::default()))]
#[pyo3(signature = (index, num_gates, rule, source))]
fn new(index: usize, num_gates: usize, rule: Equivalence, source: Key) -> Self {
Self {
index,
Expand All @@ -241,6 +245,15 @@ impl EdgeData {
)
}

fn __getnewargs__(slf: PyRef<Self>) -> (usize, usize, Equivalence, Key) {
(
slf.index,
slf.num_gates,
slf.rule.clone(),
slf.source.clone(),
)
}

fn __setstate__(mut slf: PyRefMut<Self>, state: (usize, usize, Equivalence, Key)) {
slf.index = state.0;
slf.num_gates = state.1;
Expand Down Expand Up @@ -366,18 +379,6 @@ impl IntoPy<PyObject> for CircuitRep {
}
}

impl Default for CircuitRep {
fn default() -> Self {
Self {
object: Python::with_gil(|py| py.None()),
num_qubits: 0,
num_clbits: 0,
params: None,
data: None,
}
}
}

// Custom Types
type GraphType = StableDiGraph<NodeData, EdgeData>;
type KTIType = HashMap<Key, NodeIndex>;
Expand Down

0 comments on commit dc3041e

Please sign in to comment.