Skip to content

Commit

Permalink
Fix: Accept Param references in assign_parameters_inner
Browse files Browse the repository at this point in the history
  • Loading branch information
raynelfss committed Aug 12, 2024
1 parent 9e84225 commit a4c653c
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions crates/circuit/src/circuit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -898,12 +898,13 @@ impl CircuitData {
)));
}
let mut old_table = std::mem::take(&mut self.param_table);
let owned_iter: Vec<Param> = array.iter().map(|value| Param::Float(*value)).collect();
self.assign_parameters_inner(
sequence.py(),
array
owned_iter
.iter()
.zip(old_table.drain_ordered())
.map(|(value, (param_ob, uses))| (param_ob, Param::Float(*value), uses)),
.map(|(value, (obj, uses))| (obj, value, uses)),
)
} else {
let values = sequence
Expand All @@ -918,12 +919,21 @@ impl CircuitData {
fn assign_parameters_mapping(&mut self, mapping: Bound<PyAny>) -> PyResult<()> {
let py = mapping.py();
let mut items = Vec::new();
let mut objs = Vec::new();
for item in mapping.call_method0("items")?.iter()? {
let (param_ob, value) = item?.extract::<(Py<PyAny>, AssignParam)>()?;
let uuid = ParameterUuid::from_parameter(param_ob.bind(py))?;
items.push((param_ob, value.0, self.param_table.pop(uuid)?));
items.push(value);
objs.push(param_ob); // We need to separate the objects to avoid cloning.
}
self.assign_parameters_inner(py, items)
let borrowed_iterator: PyResult<Vec<_>> = items
.iter()
.zip(objs.into_iter())
.map(|(value, param_obj)| -> PyResult<_> {
let uuid = ParameterUuid::from_parameter(param_obj.bind(py))?;
Ok((param_obj, &value.0, self.param_table.pop(uuid)?))
})
.collect();
self.assign_parameters_inner(py, borrowed_iterator?)
}

pub fn clear(&mut self) {
Expand Down Expand Up @@ -1086,7 +1096,6 @@ impl CircuitData {
py,
slice
.iter()
.cloned()
.zip(old_table.drain_ordered())
.map(|(value, (param_ob, uses))| (param_ob, value, uses)),
)
Expand All @@ -1095,9 +1104,9 @@ impl CircuitData {
/// Assigns parameters to circuit data based on a mapping of `ParameterUuid` : `Param`.
/// This mapping assumes that the provided `ParameterUuid` keys are instances
/// of `ParameterExpression`.
pub fn assign_parameters_from_mapping<I>(&mut self, py: Python, iter: I) -> PyResult<()>
pub fn assign_parameters_from_mapping<'a, I>(&mut self, py: Python, iter: I) -> PyResult<()>
where
I: IntoIterator<Item = (ParameterUuid, Param)>,
I: IntoIterator<Item = (ParameterUuid, &'a Param)>,
{
let mut items = Vec::new();
for (param_uuid, value) in iter {
Expand All @@ -1112,9 +1121,9 @@ impl CircuitData {
self.assign_parameters_inner(py, items)
}

fn assign_parameters_inner<I>(&mut self, py: Python, iter: I) -> PyResult<()>
fn assign_parameters_inner<'a, I>(&mut self, py: Python, iter: I) -> PyResult<()>
where
I: IntoIterator<Item = (Py<PyAny>, Param, HashSet<ParameterUse>)>,
I: IntoIterator<Item = (Py<PyAny>, &'a Param, HashSet<ParameterUse>)>,
{
let inconsistent =
|| PyRuntimeError::new_err("internal error: circuit parameter table is inconsistent");
Expand Down Expand Up @@ -1162,7 +1171,7 @@ impl CircuitData {
};
self.set_global_phase(
py,
bind_expr(expr.bind_borrowed(py), &param_ob, &value, true)?,
bind_expr(expr.bind_borrowed(py), &param_ob, value, true)?,
)?;
}
ParameterUse::Index {
Expand All @@ -1177,7 +1186,7 @@ impl CircuitData {
return Err(inconsistent());
};
params[parameter] =
match bind_expr(expr.bind_borrowed(py), &param_ob, &value, true)? {
match bind_expr(expr.bind_borrowed(py), &param_ob, value, true)? {
Param::Obj(obj) => {
return Err(CircuitError::new_err(format!(
"bad type after binding for gate '{}': '{}'",
Expand Down Expand Up @@ -1215,12 +1224,8 @@ impl CircuitData {
Param::ParameterExpression(expr) => {
// For user gates, we don't coerce floats to integers in `Param`
// so that users can use them if they choose.
let new_param = bind_expr(
expr.bind_borrowed(py),
&param_ob,
&value,
false,
)?;
let new_param =
bind_expr(expr.bind_borrowed(py), &param_ob, value, false)?;
// Historically, `assign_parameters` called `validate_parameter`
// only when a `ParameterExpression` became fully bound. Some
// "generalised" (or user) gates fail without this, though
Expand Down

0 comments on commit a4c653c

Please sign in to comment.