From 859862883504304d8de727b922fd8cc24a952276 Mon Sep 17 00:00:00 2001 From: Raynel Sanchez <87539502+raynelfss@users.noreply.github.com> Date: Mon, 1 Jul 2024 11:10:34 -0400 Subject: [PATCH] Initial: Extend functionality of `Param` - Implement `PartialEq` using `Python::with_gil()` to compare parameters through Python. - Add display method for debugging. --- crates/circuit/src/operations.rs | 36 ++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/crates/circuit/src/operations.rs b/crates/circuit/src/operations.rs index af7dabc86216..ef7609fac65c 100644 --- a/crates/circuit/src/operations.rs +++ b/crates/circuit/src/operations.rs @@ -176,6 +176,42 @@ impl ToPyObject for Param { } } +impl PartialEq for Param { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Param::Float(s), Param::Float(other)) => s == other, + (Param::ParameterExpression(one), Param::ParameterExpression(other)) => { + compare(one, other) + } + (Param::Obj(one), Param::Obj(other)) => compare(one, other), + _ => false, + } + } +} + +impl std::fmt::Display for Param { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let display_name: String = Python::with_gil(|py| -> PyResult { + match self { + Param::ParameterExpression(obj) => obj.call_method0(py, "__repr__")?.extract(py), + Param::Float(float_param) => Ok(format!("Parameter({})", float_param)), + Param::Obj(obj) => obj.call_method0(py, "__repr__")?.extract(py), + } + }) + .unwrap_or("None".to_owned()); + write!(f, "{}", display_name) + } +} + +/// Perform comparison between two Python objects +pub fn compare(one: &PyObject, other: &PyObject) -> bool { + Python::with_gil(|py| -> PyResult { + let other_bound = other.bind(py); + Ok(other_bound.eq(one)? || other_bound.is(one)) + }) + .unwrap_or_default() +} + #[derive(Clone, Debug, Copy, Eq, PartialEq, Hash)] #[pyclass(module = "qiskit._accelerate.circuit")] pub enum StandardGate {