diff --git a/crates/circuit/src/operations.rs b/crates/circuit/src/operations.rs index af7dabc86216..ef7609fac65c 100644 --- a/crates/circuit/src/operations.rs +++ b/crates/circuit/src/operations.rs @@ -176,6 +176,42 @@ impl ToPyObject for Param { } } +impl PartialEq for Param { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Param::Float(s), Param::Float(other)) => s == other, + (Param::ParameterExpression(one), Param::ParameterExpression(other)) => { + compare(one, other) + } + (Param::Obj(one), Param::Obj(other)) => compare(one, other), + _ => false, + } + } +} + +impl std::fmt::Display for Param { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let display_name: String = Python::with_gil(|py| -> PyResult { + match self { + Param::ParameterExpression(obj) => obj.call_method0(py, "__repr__")?.extract(py), + Param::Float(float_param) => Ok(format!("Parameter({})", float_param)), + Param::Obj(obj) => obj.call_method0(py, "__repr__")?.extract(py), + } + }) + .unwrap_or("None".to_owned()); + write!(f, "{}", display_name) + } +} + +/// Perform comparison between two Python objects +pub fn compare(one: &PyObject, other: &PyObject) -> bool { + Python::with_gil(|py| -> PyResult { + let other_bound = other.bind(py); + Ok(other_bound.eq(one)? || other_bound.is(one)) + }) + .unwrap_or_default() +} + #[derive(Clone, Debug, Copy, Eq, PartialEq, Hash)] #[pyclass(module = "qiskit._accelerate.circuit")] pub enum StandardGate {