Skip to content

Commit

Permalink
Add new BitData to CircuitData
Browse files Browse the repository at this point in the history
  • Loading branch information
raynelfss committed Jan 17, 2025
1 parent 3a0c5bb commit f8bdfa6
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 125 deletions.
198 changes: 171 additions & 27 deletions crates/circuit/src/bit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@ use crate::imports::{CLASSICAL_REGISTER, QUANTUM_REGISTER, REGISTER};
use crate::register::{Register, RegisterAsKey};
use crate::{BitType, ToPyBit};
use hashbrown::HashMap;
use indexmap::{Equivalent, IndexSet};
use pyo3::exceptions::{PyKeyError, PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use std::borrow::Borrow;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
Expand Down Expand Up @@ -245,25 +243,29 @@ pub struct NewBitData<T: From<BitType>, R: Register + Hash + Eq> {
bits: Vec<OnceLock<PyObject>>,
/// Maps Python bits to native type.
indices: HashMap<BitAsKey, T>,
/// Maps Register keys to indices
reg_keys: HashMap<RegisterAsKey, u32>,
/// Mapping between bit index and its register info
bit_info: Vec<Option<BitInfo>>,
/// Registers in the circuit
registry: IndexSet<R>,
registry: Vec<R>,
/// Registers in Python
registers: Vec<OnceLock<PyObject>>,
/// Cached Python bits
cached_py_bits: OnceLock<Py<PyList>>,
/// Cached Python registers
cached_py_regs: OnceLock<Py<PyList>>,
}

impl<T, R> NewBitData<T, R>
where
T: From<BitType> + Copy + Debug + ToPyBit,
R: Register<Bit = T>
+ Equivalent<RegisterAsKey>
+ for<'a> Borrow<&'a RegisterAsKey>
+ Hash
+ Eq
+ From<(usize, Option<String>)>
+ for<'a> From<&'a [T]>
+ for<'a> From<(&'a [T], Option<String>)>,
+ for<'a> From<(&'a [T], String)>,
BitType: From<T>,
{
pub fn new(description: String) -> Self {
Expand All @@ -272,19 +274,25 @@ where
bits: Vec::new(),
indices: HashMap::new(),
bit_info: Vec::new(),
registry: IndexSet::new(),
registry: Vec::new(),
registers: Vec::new(),
cached_py_bits: OnceLock::new(),
cached_py_regs: OnceLock::new(),
reg_keys: HashMap::new(),
}
}

pub fn with_capacity(description: String, capacity: usize) -> Self {
pub fn with_capacity(description: String, bit_capacity: usize, reg_capacity: usize) -> Self {
NewBitData {
description,
bits: Vec::with_capacity(capacity),
indices: HashMap::with_capacity(capacity),
bit_info: Vec::with_capacity(capacity),
registry: IndexSet::with_capacity(capacity),
registers: Vec::with_capacity(capacity),
bits: Vec::with_capacity(bit_capacity),
indices: HashMap::with_capacity(bit_capacity),
bit_info: Vec::with_capacity(bit_capacity),
registry: Vec::with_capacity(reg_capacity),
registers: Vec::with_capacity(reg_capacity),
cached_py_bits: OnceLock::new(),
cached_py_regs: OnceLock::new(),
reg_keys: HashMap::with_capacity(reg_capacity),
}
}

Expand All @@ -302,12 +310,6 @@ where
self.bits.is_empty()
}

/// Gets a reference to the underlying vector of Python bits.
#[inline]
pub fn bits(&self) -> &Vec<OnceLock<PyObject>> {
&self.bits
}

/// Adds a register onto the [BitData] of the circuit.
pub fn add_register(
&mut self,
Expand All @@ -318,7 +320,11 @@ where
match (size, bits) {
(None, None) => panic!("You should at least provide either a size or the bit indices."),
(None, Some(bits)) => {
let reg: R = (bits, name).into();
let reg: R = if let Some(name) = name {
(bits, name).into()
} else {
bits.into()
};
let idx = self.registry.len().try_into().unwrap_or_else(|_| {
panic!(
"The {} registry in this circuit has reached its maximum capacity.",
Expand All @@ -345,20 +351,26 @@ where
))
}
}
self.registry.insert(reg);
self.reg_keys.insert(reg.as_key().clone(), idx);
self.registry.push(reg);
self.registers.push(OnceLock::new());
idx
}
(Some(size), None) => {
let bits: Vec<T> = (0..size).map(|_| self.add_bit()).collect();
let reg = (bits.as_slice(), name).into();
let reg: R = if let Some(name) = name {
(bits.as_slice(), name).into()
} else {
bits.as_slice().into()
};
let idx = self.registry.len().try_into().unwrap_or_else(|_| {
panic!(
"The {} registry in this circuit has reached its maximum capacity.",
self.description
)
});
self.registry.insert(reg);
self.reg_keys.insert(reg.as_key().clone(), idx);
self.registry.push(reg);
self.registers.push(OnceLock::new());
idx
}
Expand All @@ -383,13 +395,33 @@ where
idx.into()
}

/// Retrieves the register info of a bit. Will panic if the index is out of range.
pub fn get_bit_info(&self, index: T) -> Option<&BitInfo> {
self.bit_info[BitType::from(index) as usize].as_ref()
}

/// Retrieves a register by its index within the circuit
#[inline]
pub fn get_register(&self, index: u32) -> Option<&R> {
self.registry.get_index(index as usize)
self.registry.get(index as usize)
}

#[inline]
pub fn get_register_by_key(&self, key: &RegisterAsKey) -> Option<&R> {
self.registry.get(&key)
self.reg_keys
.get(key)
.and_then(|idx| self.get_register(*idx))
}

/// Checks if a register is in the circuit
#[inline]
pub fn contains_register(&self, reg: &R) -> bool {
self.contains_register_by_key(reg.as_key())
}

#[inline]
pub fn contains_register_by_key(&self, reg: &RegisterAsKey) -> bool {
self.reg_keys.contains_key(reg)
}

// =======================
Expand All @@ -402,6 +434,34 @@ where
self.indices.get(&BitAsKey::new(bit)).copied()
}

/// Gets a reference to the cached Python list, maintained by
/// this instance.
#[inline]
pub fn py_cached_bits(&self, py: Python) -> &Py<PyList> {
self
.cached_py_bits
.get_or_init(|| PyList::empty_bound(py).into())
}

/// Gets a reference to the underlying vector of Python bits.
#[inline]
pub fn py_bits(&self, py: Python) -> PyResult<Vec<&PyObject>> {
(0..self.len())
.map(|idx| {
self.py_get_bit(py, (idx as u32).into())
.map(|bit| bit.unwrap())
})
.collect::<PyResult<_>>()
}

/// Gets a reference to the underlying vector of Python registers.
#[inline]
pub fn py_registers(&self, py: Python) -> PyResult<Vec<&PyObject>> {
(0..self.len_regs() as u32)
.map(|idx| self.py_get_register(py, idx).map(|reg| reg.unwrap()))
.collect::<PyResult<_>>()
}

/// Map the provided Python bits to their native indices.
/// An error is returned if any bit is not registered.
pub fn py_map_bits<'py>(
Expand All @@ -425,10 +485,25 @@ where
v.map(|x| x.into_iter())
}

/// Map the provided native indices to the corresponding Python
/// bit instances.
/// Panics if any of the indices are out of range.
pub fn py_map_indices(
&mut self,
py: Python,
bits: &[T],
) -> PyResult<impl ExactSizeIterator<Item = &Py<PyAny>>> {
let v: Vec<_> = bits
.iter()
.map(|i| -> PyResult<&PyObject> { Ok(self.py_get_bit(py, *i)?.unwrap()) })
.collect::<PyResult<_>>()?;
Ok(v.into_iter())
}

/// Gets the Python bit corresponding to the given native
/// bit index.
#[inline]
pub fn py_get_bit(&mut self, py: Python, index: T) -> PyResult<Option<&PyObject>> {
pub fn py_get_bit(&self, py: Python, index: T) -> PyResult<Option<&PyObject>> {
/*
For this method we want to make sure a couple of things are done first:
Expand Down Expand Up @@ -472,7 +547,7 @@ where
}

/// Retrieves a register instance from Python based on the rust description.
pub fn py_get_register(&mut self, py: Python, index: u32) -> PyResult<Option<&PyObject>> {
pub fn py_get_register(&self, py: Python, index: u32) -> PyResult<Option<&PyObject>> {
let index_as_usize = index as usize;
// First check if the cell is in range if not, return none
if self.registers.get(index_as_usize).is_none() {
Expand Down Expand Up @@ -533,6 +608,18 @@ where
pub fn py_add_bit(&mut self, bit: &Bound<PyAny>, strict: bool) -> PyResult<T> {
let py: Python<'_> = bit.py();

if self.bits.len()
!= self
.cached_py_bits
.get_or_init(|| PyList::empty_bound(py).into())
.bind(bit.py())
.len()
{
return Err(PyRuntimeError::new_err(
format!("This circuit's {} list has become out of sync with the circuit data. Did something modify it?", self.description)
));
}

let idx: BitType = self.bits.len().try_into().map_err(|_| {
PyRuntimeError::new_err(format!(
"The number of {} in the circuit has exceeded the maximum capacity",
Expand All @@ -546,6 +633,10 @@ where
{
self.bit_info.push(None);
self.bits.push(bit.into_py(py).into());
self.cached_py_bits
.get_or_init(|| PyList::empty_bound(py).into())
.bind(py)
.append(bit)?;
// self.cached.bind(py).append(bit)?;
} else if strict {
return Err(PyValueError::new_err(format!(
Expand All @@ -557,6 +648,19 @@ where
}

pub fn py_add_register(&mut self, register: &Bound<PyAny>) -> PyResult<u32> {
let py = register.py();
if self.registers.len()
!= self
.cached_py_regs
.get_or_init(|| PyList::empty_bound(py).into())
.bind(py)
.len()
{
return Err(PyRuntimeError::new_err(
format!("This circuit's {} list has become out of sync with the circuit data. Did something modify it?", self.description)
));
}

// let index: u32 = self.registers.len().try_into().map_err(|_| {
// PyRuntimeError::new_err(format!(
// "The number of {} registers in the circuit has exceeded the maximum capacity",
Expand All @@ -578,6 +682,10 @@ where

let name: String = register.getattr("name")?.extract()?;
self.registers.push(register.clone().unbind().into());
self.cached_py_regs
.get_or_init(|| PyList::empty_bound(py).into())
.bind(py)
.append(register)?;
Ok(self.add_register(Some(name), None, Some(&bits)))
}

Expand Down Expand Up @@ -613,6 +721,22 @@ where
Ok(())
}

pub(crate) fn py_bits_raw(&self) -> &[OnceLock<PyObject>] {
&self.bits
}

pub(crate) fn py_bits_cached_raw(&self) -> Option<&Py<PyList>> {
self.cached_py_bits.get()
}

pub(crate) fn py_regs_raw(&self) -> &[OnceLock<PyObject>] {
&self.bits
}

pub(crate) fn py_regs_cached_raw(&self) -> Option<&Py<PyList>> {
self.cached_py_bits.get()
}

/// Called during Python garbage collection, only!.
/// Note: INVALIDATES THIS INSTANCE.
pub fn dispose(&mut self) {
Expand All @@ -622,4 +746,24 @@ where
self.bit_info.clear();
self.registry.clear();
}

/// To convert [BitData] into [NewBitData]. If the structure the original comes from contains register
/// info. Make sure to add it manually after.
pub fn from_bit_data(py: Python, bit_data: &BitData<T>) -> Self {
Self {
description: bit_data.description.clone(),
bits: bit_data
.bits
.iter()
.map(|bit| bit.clone_ref(py).into())
.collect(),
indices: bit_data.indices.clone(),
reg_keys: HashMap::new(),
bit_info: (0..bit_data.len()).map(|_| None).collect(),
registry: Vec::new(),
registers: Vec::new(),
cached_py_bits: bit_data.cached().clone_ref(py).into(),
cached_py_regs: OnceLock::new(),
}
}
}
Loading

0 comments on commit f8bdfa6

Please sign in to comment.