From 934b82596380d4e0ef1aa5f030c54f1ee3d0c285 Mon Sep 17 00:00:00 2001 From: Raynel Sanchez <87539502+raynelfss@users.noreply.github.com> Date: Sat, 15 Jun 2024 00:30:05 +0200 Subject: [PATCH] Add: Comparison methods for `Param` --- crates/circuit/src/operations.rs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/crates/circuit/src/operations.rs b/crates/circuit/src/operations.rs index ceca4e942aa9..eaed5402f24d 100644 --- a/crates/circuit/src/operations.rs +++ b/crates/circuit/src/operations.rs @@ -177,7 +177,7 @@ impl Param { fn compare(one: &PyObject, other: &PyObject) -> bool { Python::with_gil(|py| -> PyResult { let other_bound = other.bind(py); - other_bound.eq(one) + Ok(other_bound.eq(one)? || other_bound.is(one)) }) .unwrap_or_default() } @@ -187,16 +187,11 @@ impl PartialEq for Param { fn eq(&self, other: &Self) -> bool { match (self, other) { (Param::Float(s), Param::Float(other)) => s == other, - (Param::Float(_), Param::ParameterExpression(_)) => false, - (Param::ParameterExpression(_), Param::Float(_)) => false, - (Param::ParameterExpression(s), Param::ParameterExpression(other)) => { - Self::compare(s, other) + (Param::ParameterExpression(one), Param::ParameterExpression(other)) => { + Self::compare(one, other) } - (Param::ParameterExpression(_), Param::Obj(_)) => false, - (Param::Float(_), Param::Obj(_)) => false, - (Param::Obj(_), Param::ParameterExpression(_)) => false, - (Param::Obj(_), Param::Float(_)) => false, (Param::Obj(one), Param::Obj(other)) => Self::compare(one, other), + _ => false, } } }