Skip to content

Commit

Permalink
Remove: Undo changes to Param
Browse files Browse the repository at this point in the history
- Fix comparison methods for `Key`, `Equivalence`, `EdgeData` and `NodeData` to account for the removal of `PartialEq` for `Param`.
  • Loading branch information
raynelfss committed Jul 5, 2024
1 parent 5519659 commit bf9aedd
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 74 deletions.
91 changes: 54 additions & 37 deletions crates/circuit/src/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ impl Key {
Self { name, num_qubits }
}

fn __eq__(&self, other: Self) -> bool {
self.eq(&other)
fn __eq__(&self, other: &Self) -> bool {
self.eq(other)
}

fn __hash__(&self) -> u64 {
Expand Down Expand Up @@ -97,7 +97,7 @@ impl Display for Key {
}

#[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")]
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone)]
pub struct Equivalence {
#[pyo3(get)]
pub params: SmallVec<[Param; 3]>,
Expand All @@ -117,8 +117,20 @@ impl Equivalence {
self.to_string()
}

fn __eq__(&self, other: Self) -> bool {
self.eq(&other)
fn __eq__(&self, py: Python, other: &Self) -> PyResult<bool> {
let bound_circ = self.circuit.object.bind(py);
let other_len = other.params.len();
let mut other_iter = other.params.iter();

Ok(bound_circ.eq(&other.circuit.object)?
&& self.params.len() == other_len
&& self.params.iter().all(|param| {
let param = param.to_object(py);
let param_bound = param.bind(py);
param_bound
.eq(other_iter.next().unwrap())
.unwrap_or_default()
}))
}

fn __getnewargs__(&self, py: Python) -> (Py<PyList>, PyObject) {
Expand All @@ -142,7 +154,7 @@ impl Display for Equivalence {
}

#[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")]
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone)]
pub struct NodeData {
#[pyo3(get)]
key: Key,
Expand All @@ -162,8 +174,16 @@ impl NodeData {
self.to_string()
}

fn __eq__(&self, other: Self) -> bool {
self.eq(&other)
fn __eq__(&self, py: Python, other: &Self) -> PyResult<bool> {
let other_len = other.equivs.len();
let mut other_iter = other.equivs.iter();
Ok(self.key == other.key
&& self.equivs.len() == other_len
&& self.equivs.iter().all(|equiv| {
equiv
.__eq__(py, other_iter.next().unwrap())
.unwrap_or_default()
}))
}

fn __getnewargs__(&self, py: Python) -> (Key, Py<PyList>) {
Expand All @@ -190,7 +210,7 @@ impl Display for NodeData {
}

#[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")]
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone)]
pub struct EdgeData {
#[pyo3(get)]
pub index: usize,
Expand Down Expand Up @@ -219,8 +239,11 @@ impl EdgeData {
self.to_string()
}

fn __eq__(&self, other: Self) -> bool {
self.eq(&other)
fn __eq__(&self, py: Python, other: Self) -> PyResult<bool> {
Ok(self.index == other.index
&& self.num_gates == other.num_gates
&& self.source == other.source
&& self.rule.__eq__(py, &other.rule)?)
}

fn __getnewargs__(slf: PyRef<Self>) -> (usize, usize, Equivalence, Key) {
Expand All @@ -243,7 +266,7 @@ impl Display for EdgeData {
}
}

// Enum to extract circuit instructions more broadly
/// Enum that helps extract the Operation and Parameters on a Gate.
#[derive(Debug, Clone)]
pub struct GateOper {
operation: OperationType,
Expand All @@ -266,38 +289,32 @@ pub struct CircuitRep {
object: PyObject,
pub num_qubits: u32,
pub num_clbits: u32,
params: Option<SmallVec<[Param; 3]>>,
params: Option<PyObject>,
data: Option<Vec<CircuitInstruction>>,
// TODO: Have a valid implementation of CircuiData that's usable in Rust.
}

impl CircuitRep {
pub fn parameters(&mut self) -> &[Param] {
pub fn parameters(&mut self, py: Python) -> PyResult<PyObject> {
if self.params.is_none() {
let params = Python::with_gil(|py| -> PyResult<SmallVec<[Param; 3]>> {
self.object
.bind(py)
.getattr("parameters")?
.getattr("data")?
.extract()
})
.unwrap_or_default();
self.params = Some(params);
return self.params.as_ref().unwrap();
let params = self.object.getattr(py, "parameters")?;
self.params = Some(params.clone_ref(py));
return Ok(params);
}
return self.params.as_ref().unwrap();
return Ok(self
.params
.as_ref()
.map(|params| params.clone_ref(py))
.unwrap());
}

pub fn data(&mut self) -> &[CircuitInstruction] {
pub fn data(&mut self, py: Python) -> PyResult<&[CircuitInstruction]> {
if self.data.is_none() {
let data = Python::with_gil(|py| -> PyResult<Vec<CircuitInstruction>> {
self.object.bind(py).getattr("data")?.extract()
})
.unwrap_or_default();
let data = self.object.bind(py).getattr("data")?.extract()?;
self.data = Some(data);
return self.data.as_ref().unwrap();
return Ok(self.data.as_ref().unwrap());
}
return self.data.as_ref().unwrap();
return Ok(self.data.as_ref().unwrap());
}
}

Expand Down Expand Up @@ -413,7 +430,7 @@ impl EquivalenceLibrary {
mut equivalent_circuit: CircuitRep,
) -> PyResult<()> {
raise_if_shape_mismatch(&gate, &equivalent_circuit)?;
raise_if_param_mismatch(py, &gate.params, equivalent_circuit.parameters())?;
raise_if_param_mismatch(py, &gate.params, equivalent_circuit.parameters(py)?)?;

let key: Key = Key {
name: gate.operation.name().to_string(),
Expand All @@ -429,7 +446,7 @@ impl EquivalenceLibrary {
node.equivs.push(equiv.clone());
}
let sources: HashSet<Key> =
HashSet::from_iter(equivalent_circuit.data().iter().map(|inst| Key {
HashSet::from_iter(equivalent_circuit.data(py)?.iter().map(|inst| Key {
name: inst.operation.name().to_string(),
num_qubits: inst.operation.num_qubits(),
}));
Expand Down Expand Up @@ -485,7 +502,7 @@ impl EquivalenceLibrary {
) -> PyResult<()> {
for equiv in entry.iter_mut() {
raise_if_shape_mismatch(&gate, equiv)?;
raise_if_param_mismatch(py, &gate.params, equiv.parameters())?;
raise_if_param_mismatch(py, &gate.params, equiv.parameters(py)?)?;
}

let key = Key {
Expand Down Expand Up @@ -696,15 +713,15 @@ impl EquivalenceLibrary {
fn raise_if_param_mismatch(
py: Python,
gate_params: &[Param],
circuit_parameters: &[Param],
circuit_parameters: PyObject,
) -> PyResult<()> {
let gate_params_obj = PySet::new_bound(
py,
gate_params
.iter()
.filter(|param| matches!(param, Param::ParameterExpression(_))),
)?;
if !gate_params_obj.eq(PySet::new_bound(py, circuit_parameters)?)? {
if !gate_params_obj.eq(&circuit_parameters)? {
return Err(CircuitError::new_err(format!(
"Cannot add equivalence between circuit and gate \
of different parameters. Gate params: {:?}. \
Expand Down
37 changes: 0 additions & 37 deletions crates/circuit/src/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,43 +184,6 @@ impl ToPyObject for Param {
}
}

impl Param {
fn compare(one: &PyObject, other: &PyObject) -> bool {
Python::with_gil(|py| -> PyResult<bool> {
let other_bound = other.bind(py);
Ok(other_bound.eq(one)? || other_bound.is(one))
})
.unwrap_or_default()
}
}

impl PartialEq for Param {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Param::Float(s), Param::Float(other)) => s == other,
(Param::ParameterExpression(one), Param::ParameterExpression(other)) => {
Self::compare(one, other)
}
(Param::Obj(one), Param::Obj(other)) => Self::compare(one, other),
_ => false,
}
}
}

impl std::fmt::Display for Param {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let display_name: String = Python::with_gil(|py| -> PyResult<String> {
match self {
Param::ParameterExpression(obj) => obj.call_method0(py, "__repr__")?.extract(py),
Param::Float(float_param) => Ok(format!("Parameter({})", float_param)),
Param::Obj(obj) => obj.call_method0(py, "__repr__")?.extract(py),
}
})
.unwrap_or("None".to_owned());
write!(f, "{}", display_name)
}
}

#[derive(Clone, Debug, Copy, Eq, PartialEq, Hash)]
#[pyclass(module = "qiskit._accelerate.circuit")]
pub enum StandardGate {
Expand Down

0 comments on commit bf9aedd

Please sign in to comment.