forked from Qiskit/qiskit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into move-target
- Loading branch information
Showing
40 changed files
with
1,171 additions
and
548 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
// This code is part of Qiskit. | ||
// | ||
// (C) Copyright IBM 2024 | ||
// | ||
// This code is licensed under the Apache License, Version 2.0. You may | ||
// obtain a copy of this license in the LICENSE.txt file in the root directory | ||
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
// | ||
// Any modifications or derivative works of this code must retain this | ||
// copyright notice, and modified files need to carry a notice indicating | ||
// that they have been altered from the originals. | ||
|
||
use ndarray::{Array1, ArrayView1}; | ||
use numpy::PyArrayLike1; | ||
use pyo3::exceptions::PyValueError; | ||
use pyo3::prelude::*; | ||
use std::vec::Vec; | ||
|
||
fn validate_permutation(pattern: &ArrayView1<i64>) -> PyResult<()> { | ||
let n = pattern.len(); | ||
let mut seen: Vec<bool> = vec![false; n]; | ||
|
||
for &x in pattern { | ||
if x < 0 { | ||
return Err(PyValueError::new_err( | ||
"Invalid permutation: input contains a negative number.", | ||
)); | ||
} | ||
|
||
if x as usize >= n { | ||
return Err(PyValueError::new_err(format!( | ||
"Invalid permutation: input has length {} and contains {}.", | ||
n, x | ||
))); | ||
} | ||
|
||
if seen[x as usize] { | ||
return Err(PyValueError::new_err(format!( | ||
"Invalid permutation: input contains {} more than once.", | ||
x | ||
))); | ||
} | ||
|
||
seen[x as usize] = true; | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
fn invert(pattern: &ArrayView1<i64>) -> Array1<usize> { | ||
let mut inverse: Array1<usize> = Array1::zeros(pattern.len()); | ||
pattern.iter().enumerate().for_each(|(ii, &jj)| { | ||
inverse[jj as usize] = ii; | ||
}); | ||
inverse | ||
} | ||
|
||
fn get_ordered_swap(pattern: &ArrayView1<i64>) -> Vec<(i64, i64)> { | ||
let mut permutation: Vec<usize> = pattern.iter().map(|&x| x as usize).collect(); | ||
let mut index_map = invert(pattern); | ||
|
||
let n = permutation.len(); | ||
let mut swaps: Vec<(i64, i64)> = Vec::with_capacity(n); | ||
for ii in 0..n { | ||
let val = permutation[ii]; | ||
if val == ii { | ||
continue; | ||
} | ||
let jj = index_map[ii]; | ||
swaps.push((ii as i64, jj as i64)); | ||
(permutation[ii], permutation[jj]) = (permutation[jj], permutation[ii]); | ||
index_map[val] = jj; | ||
index_map[ii] = ii; | ||
} | ||
|
||
swaps[..].reverse(); | ||
swaps | ||
} | ||
|
||
/// Checks whether an array of size N is a permutation of 0, 1, ..., N - 1. | ||
#[pyfunction] | ||
#[pyo3(signature = (pattern))] | ||
fn _validate_permutation(py: Python, pattern: PyArrayLike1<i64>) -> PyResult<PyObject> { | ||
let view = pattern.as_array(); | ||
validate_permutation(&view)?; | ||
Ok(py.None()) | ||
} | ||
|
||
/// Finds inverse of a permutation pattern. | ||
#[pyfunction] | ||
#[pyo3(signature = (pattern))] | ||
fn _inverse_pattern(py: Python, pattern: PyArrayLike1<i64>) -> PyResult<PyObject> { | ||
let view = pattern.as_array(); | ||
let inverse_i64: Vec<i64> = invert(&view).iter().map(|&x| x as i64).collect(); | ||
Ok(inverse_i64.to_object(py)) | ||
} | ||
|
||
/// Sorts the input permutation by iterating through the permutation list | ||
/// and putting each element to its correct position via a SWAP (if it's not | ||
/// at the correct position already). If ``n`` is the length of the input | ||
/// permutation, this requires at most ``n`` SWAPs. | ||
/// | ||
/// More precisely, if the input permutation is a cycle of length ``m``, | ||
/// then this creates a quantum circuit with ``m-1`` SWAPs (and of depth ``m-1``); | ||
/// if the input permutation consists of several disjoint cycles, then each cycle | ||
/// is essentially treated independently. | ||
#[pyfunction] | ||
#[pyo3(signature = (permutation_in))] | ||
fn _get_ordered_swap(py: Python, permutation_in: PyArrayLike1<i64>) -> PyResult<PyObject> { | ||
let view = permutation_in.as_array(); | ||
Ok(get_ordered_swap(&view).to_object(py)) | ||
} | ||
|
||
#[pymodule] | ||
pub fn permutation(m: &Bound<PyModule>) -> PyResult<()> { | ||
m.add_function(wrap_pyfunction!(_validate_permutation, m)?)?; | ||
m.add_function(wrap_pyfunction!(_inverse_pattern, m)?)?; | ||
m.add_function(wrap_pyfunction!(_get_ordered_swap, m)?)?; | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
// This code is part of Qiskit. | ||
// | ||
// (C) Copyright IBM 2024 | ||
// | ||
// This code is licensed under the Apache License, Version 2.0. You may | ||
// obtain a copy of this license in the LICENSE.txt file in the root directory | ||
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
// | ||
// Any modifications or derivative works of this code must retain this | ||
// copyright notice, and modified files need to carry a notice indicating | ||
// that they have been altered from the originals. | ||
|
||
use crate::BitType; | ||
use hashbrown::HashMap; | ||
use pyo3::exceptions::{PyRuntimeError, PyValueError}; | ||
use pyo3::prelude::*; | ||
use pyo3::types::PyList; | ||
use std::fmt::Debug; | ||
use std::hash::{Hash, Hasher}; | ||
|
||
/// Private wrapper for Python-side Bit instances that implements | ||
/// [Hash] and [Eq], allowing them to be used in Rust hash-based | ||
/// sets and maps. | ||
/// | ||
/// Python's `hash()` is called on the wrapped Bit instance during | ||
/// construction and returned from Rust's [Hash] trait impl. | ||
/// The impl of [PartialEq] first compares the native Py pointers | ||
/// to determine equality. If these are not equal, only then does | ||
/// it call `repr()` on both sides, which has a significant | ||
/// performance advantage. | ||
#[derive(Clone, Debug)] | ||
struct BitAsKey { | ||
/// Python's `hash()` of the wrapped instance. | ||
hash: isize, | ||
/// The wrapped instance. | ||
bit: PyObject, | ||
} | ||
|
||
impl BitAsKey { | ||
pub fn new(bit: &Bound<PyAny>) -> Self { | ||
BitAsKey { | ||
// This really shouldn't fail, but if it does, | ||
// we'll just use 0. | ||
hash: bit.hash().unwrap_or(0), | ||
bit: bit.clone().unbind(), | ||
} | ||
} | ||
} | ||
|
||
impl Hash for BitAsKey { | ||
fn hash<H: Hasher>(&self, state: &mut H) { | ||
state.write_isize(self.hash); | ||
} | ||
} | ||
|
||
impl PartialEq for BitAsKey { | ||
fn eq(&self, other: &Self) -> bool { | ||
self.bit.is(&other.bit) | ||
|| Python::with_gil(|py| { | ||
self.bit | ||
.bind(py) | ||
.repr() | ||
.unwrap() | ||
.eq(other.bit.bind(py).repr().unwrap()) | ||
.unwrap() | ||
}) | ||
} | ||
} | ||
|
||
impl Eq for BitAsKey {} | ||
|
||
#[derive(Clone, Debug)] | ||
pub(crate) struct BitData<T> { | ||
/// The public field name (i.e. `qubits` or `clbits`). | ||
description: String, | ||
/// Registered Python bits. | ||
bits: Vec<PyObject>, | ||
/// Maps Python bits to native type. | ||
indices: HashMap<BitAsKey, T>, | ||
/// The bits registered, cached as a PyList. | ||
cached: Py<PyList>, | ||
} | ||
|
||
pub(crate) struct BitNotFoundError<'py>(pub(crate) Bound<'py, PyAny>); | ||
|
||
impl<T> BitData<T> | ||
where | ||
T: From<BitType> + Copy, | ||
BitType: From<T>, | ||
{ | ||
pub fn new(py: Python<'_>, description: String) -> Self { | ||
BitData { | ||
description, | ||
bits: Vec::new(), | ||
indices: HashMap::new(), | ||
cached: PyList::empty_bound(py).unbind(), | ||
} | ||
} | ||
|
||
/// Gets the number of bits. | ||
pub fn len(&self) -> usize { | ||
self.bits.len() | ||
} | ||
|
||
/// Gets a reference to the underlying vector of Python bits. | ||
#[inline] | ||
pub fn bits(&self) -> &Vec<PyObject> { | ||
&self.bits | ||
} | ||
|
||
/// Gets a reference to the cached Python list, maintained by | ||
/// this instance. | ||
#[inline] | ||
pub fn cached(&self) -> &Py<PyList> { | ||
&self.cached | ||
} | ||
|
||
/// Finds the native bit index of the given Python bit. | ||
#[inline] | ||
pub fn find(&self, bit: &Bound<PyAny>) -> Option<T> { | ||
self.indices.get(&BitAsKey::new(bit)).copied() | ||
} | ||
|
||
/// Map the provided Python bits to their native indices. | ||
/// An error is returned if any bit is not registered. | ||
pub fn map_bits<'py>( | ||
&self, | ||
bits: impl IntoIterator<Item = Bound<'py, PyAny>>, | ||
) -> Result<impl Iterator<Item = T>, BitNotFoundError<'py>> { | ||
let v: Result<Vec<_>, _> = bits | ||
.into_iter() | ||
.map(|b| { | ||
self.indices | ||
.get(&BitAsKey::new(&b)) | ||
.copied() | ||
.ok_or_else(|| BitNotFoundError(b)) | ||
}) | ||
.collect(); | ||
v.map(|x| x.into_iter()) | ||
} | ||
|
||
/// Map the provided native indices to the corresponding Python | ||
/// bit instances. | ||
/// Panics if any of the indices are out of range. | ||
pub fn map_indices(&self, bits: &[T]) -> impl Iterator<Item = &Py<PyAny>> + ExactSizeIterator { | ||
let v: Vec<_> = bits.iter().map(|i| self.get(*i).unwrap()).collect(); | ||
v.into_iter() | ||
} | ||
|
||
/// Gets the Python bit corresponding to the given native | ||
/// bit index. | ||
#[inline] | ||
pub fn get(&self, index: T) -> Option<&PyObject> { | ||
self.bits.get(<BitType as From<T>>::from(index) as usize) | ||
} | ||
|
||
/// Adds a new Python bit. | ||
pub fn add(&mut self, py: Python, bit: &Bound<PyAny>, strict: bool) -> PyResult<()> { | ||
if self.bits.len() != self.cached.bind(bit.py()).len() { | ||
return Err(PyRuntimeError::new_err( | ||
format!("This circuit's {} list has become out of sync with the circuit data. Did something modify it?", self.description) | ||
)); | ||
} | ||
let idx: BitType = self.bits.len().try_into().map_err(|_| { | ||
PyRuntimeError::new_err(format!( | ||
"The number of {} in the circuit has exceeded the maximum capacity", | ||
self.description | ||
)) | ||
})?; | ||
if self | ||
.indices | ||
.try_insert(BitAsKey::new(bit), idx.into()) | ||
.is_ok() | ||
{ | ||
self.bits.push(bit.into_py(py)); | ||
self.cached.bind(py).append(bit)?; | ||
} else if strict { | ||
return Err(PyValueError::new_err(format!( | ||
"Existing bit {:?} cannot be re-added in strict mode.", | ||
bit | ||
))); | ||
} | ||
Ok(()) | ||
} | ||
|
||
/// Called during Python garbage collection, only!. | ||
/// Note: INVALIDATES THIS INSTANCE. | ||
pub fn dispose(&mut self) { | ||
self.indices.clear(); | ||
self.bits.clear(); | ||
} | ||
} |
Oops, something went wrong.