Skip to content

Commit

Permalink
Allow arbitrary iterables in assign_parameters (Qiskit#12887)
Browse files Browse the repository at this point in the history
In Qiskit 1.1, it was possible to give any object that was iterable and
had a `__len__` as the binding sequence for `assign_parameters`.  The
move to Rust space inadvertantly limited that to things that fulfilled
the sequence API.  This commit restores the ability to use general
iterables, and removes the need to have a `__len__`.
  • Loading branch information
jakelishman authored Aug 2, 2024
1 parent 6663db1 commit 9de8119
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 15 deletions.
32 changes: 21 additions & 11 deletions crates/circuit/src/circuit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -834,19 +834,19 @@ impl CircuitData {
Ok(())
}

/// Assign all the circuit parameters, given a sequence-like input of `Param` instances.
fn assign_parameters_sequence(&mut self, sequence: Bound<PyAny>) -> PyResult<()> {
if sequence.len()? != self.param_table.num_parameters() {
return Err(PyValueError::new_err(concat!(
"Mismatching number of values and parameters. For partial binding ",
"please pass a dictionary of {parameter: value} pairs."
)));
}
let mut old_table = std::mem::take(&mut self.param_table);
/// Assign all the circuit parameters, given an iterable input of `Param` instances.
fn assign_parameters_iterable(&mut self, sequence: Bound<PyAny>) -> PyResult<()> {
if let Ok(readonly) = sequence.extract::<PyReadonlyArray1<f64>>() {
// Fast path for Numpy arrays; in this case we can easily handle them without copying
// the data across into a Rust-space `Vec` first.
let array = readonly.as_array();
if array.len() != self.param_table.num_parameters() {
return Err(PyValueError::new_err(concat!(
"Mismatching number of values and parameters. For partial binding ",
"please pass a dictionary of {parameter: value} pairs."
)));
}
let mut old_table = std::mem::take(&mut self.param_table);
self.assign_parameters_inner(
sequence.py(),
array
Expand All @@ -855,13 +855,23 @@ impl CircuitData {
.map(|(value, (param_ob, uses))| (param_ob, Param::Float(*value), uses)),
)
} else {
let values = sequence.extract::<Vec<AssignParam>>()?;
let values = sequence
.iter()?
.map(|ob| Param::extract_no_coerce(&ob?))
.collect::<PyResult<Vec<_>>>()?;
if values.len() != self.param_table.num_parameters() {
return Err(PyValueError::new_err(concat!(
"Mismatching number of values and parameters. For partial binding ",
"please pass a dictionary of {parameter: value} pairs."
)));
}
let mut old_table = std::mem::take(&mut self.param_table);
self.assign_parameters_inner(
sequence.py(),
values
.into_iter()
.zip(old_table.drain_ordered())
.map(|(value, (param_ob, uses))| (param_ob, value.0, uses)),
.map(|(value, (param_ob, uses))| (param_ob, value, uses)),
)
}
}
Expand Down
8 changes: 4 additions & 4 deletions qiskit/circuit/quantumcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4188,7 +4188,7 @@ def _unsorted_parameters(self) -> set[Parameter]:
@overload
def assign_parameters(
self,
parameters: Union[Mapping[Parameter, ParameterValueType], Sequence[ParameterValueType]],
parameters: Union[Mapping[Parameter, ParameterValueType], Iterable[ParameterValueType]],
inplace: Literal[False] = ...,
*,
flat_input: bool = ...,
Expand All @@ -4198,7 +4198,7 @@ def assign_parameters(
@overload
def assign_parameters(
self,
parameters: Union[Mapping[Parameter, ParameterValueType], Sequence[ParameterValueType]],
parameters: Union[Mapping[Parameter, ParameterValueType], Iterable[ParameterValueType]],
inplace: Literal[True] = ...,
*,
flat_input: bool = ...,
Expand All @@ -4207,7 +4207,7 @@ def assign_parameters(

def assign_parameters( # pylint: disable=missing-raises-doc
self,
parameters: Union[Mapping[Parameter, ParameterValueType], Sequence[ParameterValueType]],
parameters: Union[Mapping[Parameter, ParameterValueType], Iterable[ParameterValueType]],
inplace: bool = False,
*,
flat_input: bool = False,
Expand Down Expand Up @@ -4317,7 +4317,7 @@ def assign_parameters( # pylint: disable=missing-raises-doc
target._data.assign_parameters_mapping(parameter_binds)
else:
parameter_binds = _ParameterBindsSequence(target._data.parameters, parameters)
target._data.assign_parameters_sequence(parameters)
target._data.assign_parameters_iterable(parameters)

# Finally, assign the parameters inside any of the calibrations. We don't track these in
# the `ParameterTable`, so we manually reconstruct things.
Expand Down
14 changes: 14 additions & 0 deletions test/python/circuit/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,20 @@ def test_assign_parameters_by_name(self):
qc.assign_parameters({a: 1, b: 2, c: 3}), qc.assign_parameters({"a": 1, "b": 2, "c": 3})
)

def test_assign_parameters_by_iterable(self):
"""Assignment works with weird iterables."""
a, b, c = Parameter("a"), Parameter("b"), Parameter("c")
qc = QuantumCircuit(1)
qc.rz(a, 0)
qc.rz(b + c, 0)

binds = [1.25, 2.5, 0.125]
expected = qc.assign_parameters(dict(zip(qc.parameters, binds)))
self.assertEqual(qc.assign_parameters(iter(binds)), expected)
self.assertEqual(qc.assign_parameters(dict.fromkeys(binds).keys()), expected)
self.assertEqual(qc.assign_parameters(dict(zip(qc.parameters, binds)).values()), expected)
self.assertEqual(qc.assign_parameters(bind for bind in binds), expected)

def test_bind_parameters_custom_definition_global_phase(self):
"""Test that a custom gate with a parametrized `global_phase` is assigned correctly."""
x = Parameter("x")
Expand Down

0 comments on commit 9de8119

Please sign in to comment.