From bf9aedd72b019600b223e71f46e0107a295457b3 Mon Sep 17 00:00:00 2001 From: Raynel Sanchez <87539502+raynelfss@users.noreply.github.com> Date: Fri, 5 Jul 2024 12:15:24 -0400 Subject: [PATCH] Remove: Undo changes to Param - Fix comparison methods for `Key`, `Equivalence`, `EdgeData` and `NodeData` to account for the removal of `PartialEq` for `Param`. --- crates/circuit/src/equivalence.rs | 91 ++++++++++++++++++------------- crates/circuit/src/operations.rs | 37 ------------- 2 files changed, 54 insertions(+), 74 deletions(-) diff --git a/crates/circuit/src/equivalence.rs b/crates/circuit/src/equivalence.rs index b7456f470999..fa720ef107a4 100644 --- a/crates/circuit/src/equivalence.rs +++ b/crates/circuit/src/equivalence.rs @@ -64,8 +64,8 @@ impl Key { Self { name, num_qubits } } - fn __eq__(&self, other: Self) -> bool { - self.eq(&other) + fn __eq__(&self, other: &Self) -> bool { + self.eq(other) } fn __hash__(&self) -> u64 { @@ -97,7 +97,7 @@ impl Display for Key { } #[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")] -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct Equivalence { #[pyo3(get)] pub params: SmallVec<[Param; 3]>, @@ -117,8 +117,20 @@ impl Equivalence { self.to_string() } - fn __eq__(&self, other: Self) -> bool { - self.eq(&other) + fn __eq__(&self, py: Python, other: &Self) -> PyResult { + let bound_circ = self.circuit.object.bind(py); + let other_len = other.params.len(); + let mut other_iter = other.params.iter(); + + Ok(bound_circ.eq(&other.circuit.object)? + && self.params.len() == other_len + && self.params.iter().all(|param| { + let param = param.to_object(py); + let param_bound = param.bind(py); + param_bound + .eq(other_iter.next().unwrap()) + .unwrap_or_default() + })) } fn __getnewargs__(&self, py: Python) -> (Py, PyObject) { @@ -142,7 +154,7 @@ impl Display for Equivalence { } #[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")] -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct NodeData { #[pyo3(get)] key: Key, @@ -162,8 +174,16 @@ impl NodeData { self.to_string() } - fn __eq__(&self, other: Self) -> bool { - self.eq(&other) + fn __eq__(&self, py: Python, other: &Self) -> PyResult { + let other_len = other.equivs.len(); + let mut other_iter = other.equivs.iter(); + Ok(self.key == other.key + && self.equivs.len() == other_len + && self.equivs.iter().all(|equiv| { + equiv + .__eq__(py, other_iter.next().unwrap()) + .unwrap_or_default() + })) } fn __getnewargs__(&self, py: Python) -> (Key, Py) { @@ -190,7 +210,7 @@ impl Display for NodeData { } #[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")] -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct EdgeData { #[pyo3(get)] pub index: usize, @@ -219,8 +239,11 @@ impl EdgeData { self.to_string() } - fn __eq__(&self, other: Self) -> bool { - self.eq(&other) + fn __eq__(&self, py: Python, other: Self) -> PyResult { + Ok(self.index == other.index + && self.num_gates == other.num_gates + && self.source == other.source + && self.rule.__eq__(py, &other.rule)?) } fn __getnewargs__(slf: PyRef) -> (usize, usize, Equivalence, Key) { @@ -243,7 +266,7 @@ impl Display for EdgeData { } } -// Enum to extract circuit instructions more broadly +/// Enum that helps extract the Operation and Parameters on a Gate. #[derive(Debug, Clone)] pub struct GateOper { operation: OperationType, @@ -266,38 +289,32 @@ pub struct CircuitRep { object: PyObject, pub num_qubits: u32, pub num_clbits: u32, - params: Option>, + params: Option, data: Option>, // TODO: Have a valid implementation of CircuiData that's usable in Rust. } impl CircuitRep { - pub fn parameters(&mut self) -> &[Param] { + pub fn parameters(&mut self, py: Python) -> PyResult { if self.params.is_none() { - let params = Python::with_gil(|py| -> PyResult> { - self.object - .bind(py) - .getattr("parameters")? - .getattr("data")? - .extract() - }) - .unwrap_or_default(); - self.params = Some(params); - return self.params.as_ref().unwrap(); + let params = self.object.getattr(py, "parameters")?; + self.params = Some(params.clone_ref(py)); + return Ok(params); } - return self.params.as_ref().unwrap(); + return Ok(self + .params + .as_ref() + .map(|params| params.clone_ref(py)) + .unwrap()); } - pub fn data(&mut self) -> &[CircuitInstruction] { + pub fn data(&mut self, py: Python) -> PyResult<&[CircuitInstruction]> { if self.data.is_none() { - let data = Python::with_gil(|py| -> PyResult> { - self.object.bind(py).getattr("data")?.extract() - }) - .unwrap_or_default(); + let data = self.object.bind(py).getattr("data")?.extract()?; self.data = Some(data); - return self.data.as_ref().unwrap(); + return Ok(self.data.as_ref().unwrap()); } - return self.data.as_ref().unwrap(); + return Ok(self.data.as_ref().unwrap()); } } @@ -413,7 +430,7 @@ impl EquivalenceLibrary { mut equivalent_circuit: CircuitRep, ) -> PyResult<()> { raise_if_shape_mismatch(&gate, &equivalent_circuit)?; - raise_if_param_mismatch(py, &gate.params, equivalent_circuit.parameters())?; + raise_if_param_mismatch(py, &gate.params, equivalent_circuit.parameters(py)?)?; let key: Key = Key { name: gate.operation.name().to_string(), @@ -429,7 +446,7 @@ impl EquivalenceLibrary { node.equivs.push(equiv.clone()); } let sources: HashSet = - HashSet::from_iter(equivalent_circuit.data().iter().map(|inst| Key { + HashSet::from_iter(equivalent_circuit.data(py)?.iter().map(|inst| Key { name: inst.operation.name().to_string(), num_qubits: inst.operation.num_qubits(), })); @@ -485,7 +502,7 @@ impl EquivalenceLibrary { ) -> PyResult<()> { for equiv in entry.iter_mut() { raise_if_shape_mismatch(&gate, equiv)?; - raise_if_param_mismatch(py, &gate.params, equiv.parameters())?; + raise_if_param_mismatch(py, &gate.params, equiv.parameters(py)?)?; } let key = Key { @@ -696,7 +713,7 @@ impl EquivalenceLibrary { fn raise_if_param_mismatch( py: Python, gate_params: &[Param], - circuit_parameters: &[Param], + circuit_parameters: PyObject, ) -> PyResult<()> { let gate_params_obj = PySet::new_bound( py, @@ -704,7 +721,7 @@ fn raise_if_param_mismatch( .iter() .filter(|param| matches!(param, Param::ParameterExpression(_))), )?; - if !gate_params_obj.eq(PySet::new_bound(py, circuit_parameters)?)? { + if !gate_params_obj.eq(&circuit_parameters)? { return Err(CircuitError::new_err(format!( "Cannot add equivalence between circuit and gate \ of different parameters. Gate params: {:?}. \ diff --git a/crates/circuit/src/operations.rs b/crates/circuit/src/operations.rs index 6cb7ed7893bf..3bfef81d29ce 100644 --- a/crates/circuit/src/operations.rs +++ b/crates/circuit/src/operations.rs @@ -184,43 +184,6 @@ impl ToPyObject for Param { } } -impl Param { - fn compare(one: &PyObject, other: &PyObject) -> bool { - Python::with_gil(|py| -> PyResult { - let other_bound = other.bind(py); - Ok(other_bound.eq(one)? || other_bound.is(one)) - }) - .unwrap_or_default() - } -} - -impl PartialEq for Param { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Param::Float(s), Param::Float(other)) => s == other, - (Param::ParameterExpression(one), Param::ParameterExpression(other)) => { - Self::compare(one, other) - } - (Param::Obj(one), Param::Obj(other)) => Self::compare(one, other), - _ => false, - } - } -} - -impl std::fmt::Display for Param { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let display_name: String = Python::with_gil(|py| -> PyResult { - match self { - Param::ParameterExpression(obj) => obj.call_method0(py, "__repr__")?.extract(py), - Param::Float(float_param) => Ok(format!("Parameter({})", float_param)), - Param::Obj(obj) => obj.call_method0(py, "__repr__")?.extract(py), - } - }) - .unwrap_or("None".to_owned()); - write!(f, "{}", display_name) - } -} - #[derive(Clone, Debug, Copy, Eq, PartialEq, Hash)] #[pyclass(module = "qiskit._accelerate.circuit")] pub enum StandardGate {