From a4c653c2763719636af977751874be1fd8e59d06 Mon Sep 17 00:00:00 2001 From: Raynel Sanchez Date: Mon, 12 Aug 2024 14:40:09 -0400 Subject: [PATCH] Fix: Accept Param references in `assign_parameters_inner` --- crates/circuit/src/circuit_data.rs | 41 +++++++++++++++++------------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/crates/circuit/src/circuit_data.rs b/crates/circuit/src/circuit_data.rs index ec17b4f8a9e..02e2021518b 100644 --- a/crates/circuit/src/circuit_data.rs +++ b/crates/circuit/src/circuit_data.rs @@ -898,12 +898,13 @@ impl CircuitData { ))); } let mut old_table = std::mem::take(&mut self.param_table); + let owned_iter: Vec = array.iter().map(|value| Param::Float(*value)).collect(); self.assign_parameters_inner( sequence.py(), - array + owned_iter .iter() .zip(old_table.drain_ordered()) - .map(|(value, (param_ob, uses))| (param_ob, Param::Float(*value), uses)), + .map(|(value, (obj, uses))| (obj, value, uses)), ) } else { let values = sequence @@ -918,12 +919,21 @@ impl CircuitData { fn assign_parameters_mapping(&mut self, mapping: Bound) -> PyResult<()> { let py = mapping.py(); let mut items = Vec::new(); + let mut objs = Vec::new(); for item in mapping.call_method0("items")?.iter()? { let (param_ob, value) = item?.extract::<(Py, AssignParam)>()?; - let uuid = ParameterUuid::from_parameter(param_ob.bind(py))?; - items.push((param_ob, value.0, self.param_table.pop(uuid)?)); + items.push(value); + objs.push(param_ob); // We need to separate the objects to avoid cloning. } - self.assign_parameters_inner(py, items) + let borrowed_iterator: PyResult> = items + .iter() + .zip(objs.into_iter()) + .map(|(value, param_obj)| -> PyResult<_> { + let uuid = ParameterUuid::from_parameter(param_obj.bind(py))?; + Ok((param_obj, &value.0, self.param_table.pop(uuid)?)) + }) + .collect(); + self.assign_parameters_inner(py, borrowed_iterator?) } pub fn clear(&mut self) { @@ -1086,7 +1096,6 @@ impl CircuitData { py, slice .iter() - .cloned() .zip(old_table.drain_ordered()) .map(|(value, (param_ob, uses))| (param_ob, value, uses)), ) @@ -1095,9 +1104,9 @@ impl CircuitData { /// Assigns parameters to circuit data based on a mapping of `ParameterUuid` : `Param`. /// This mapping assumes that the provided `ParameterUuid` keys are instances /// of `ParameterExpression`. - pub fn assign_parameters_from_mapping(&mut self, py: Python, iter: I) -> PyResult<()> + pub fn assign_parameters_from_mapping<'a, I>(&mut self, py: Python, iter: I) -> PyResult<()> where - I: IntoIterator, + I: IntoIterator, { let mut items = Vec::new(); for (param_uuid, value) in iter { @@ -1112,9 +1121,9 @@ impl CircuitData { self.assign_parameters_inner(py, items) } - fn assign_parameters_inner(&mut self, py: Python, iter: I) -> PyResult<()> + fn assign_parameters_inner<'a, I>(&mut self, py: Python, iter: I) -> PyResult<()> where - I: IntoIterator, Param, HashSet)>, + I: IntoIterator, &'a Param, HashSet)>, { let inconsistent = || PyRuntimeError::new_err("internal error: circuit parameter table is inconsistent"); @@ -1162,7 +1171,7 @@ impl CircuitData { }; self.set_global_phase( py, - bind_expr(expr.bind_borrowed(py), ¶m_ob, &value, true)?, + bind_expr(expr.bind_borrowed(py), ¶m_ob, value, true)?, )?; } ParameterUse::Index { @@ -1177,7 +1186,7 @@ impl CircuitData { return Err(inconsistent()); }; params[parameter] = - match bind_expr(expr.bind_borrowed(py), ¶m_ob, &value, true)? { + match bind_expr(expr.bind_borrowed(py), ¶m_ob, value, true)? { Param::Obj(obj) => { return Err(CircuitError::new_err(format!( "bad type after binding for gate '{}': '{}'", @@ -1215,12 +1224,8 @@ impl CircuitData { Param::ParameterExpression(expr) => { // For user gates, we don't coerce floats to integers in `Param` // so that users can use them if they choose. - let new_param = bind_expr( - expr.bind_borrowed(py), - ¶m_ob, - &value, - false, - )?; + let new_param = + bind_expr(expr.bind_borrowed(py), ¶m_ob, value, false)?; // Historically, `assign_parameters` called `validate_parameter` // only when a `ParameterExpression` became fully bound. Some // "generalised" (or user) gates fail without this, though