Skip to content

Commit

Permalink
Add: Stable infrastructure for EquivalenceLibrary
Browse files Browse the repository at this point in the history
- TODO: Make elements pickleable.
  • Loading branch information
raynelfss committed Jun 13, 2024
1 parent 9eddbbf commit 11611e2
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 27 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/circuit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ doctest = false

[dependencies]
hashbrown.workspace = true
itertools = "0.13.0"
pyo3.workspace = true
rustworkx-core = "0.14.2"
159 changes: 142 additions & 17 deletions crates/circuit/src/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use itertools::Itertools;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::{error::Error, fmt::Display};

use exceptions::CircuitError;
use hashbrown::{HashMap, HashSet};
use pyo3::sync::GILOnceCell;
use pyo3::types::PyDict;
use pyo3::{prelude::*, types::IntoPyDict};

use rustworkx_core::petgraph::{
graph::{DiGraph, EdgeIndex, NodeIndex},
visit::EdgeRef,
Expand Down Expand Up @@ -73,7 +76,7 @@ pub static PYDIGRAPH: ImportOnceCell = ImportOnceCell::new("rustworkx", "PyDiGra

// Custom Structs

#[pyclass(sequence)]
#[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Key {
#[pyo3(get)]
Expand All @@ -85,6 +88,7 @@ pub struct Key {
#[pymethods]
impl Key {
#[new]
#[pyo3(signature = (name, num_qubits))]
fn new(name: String, num_qubits: usize) -> Self {
Self { name, num_qubits }
}
Expand All @@ -102,6 +106,15 @@ impl Key {
fn __repr__(slf: PyRef<'_, Self>) -> String {
slf.to_string()
}

fn __getstate__(slf: PyRef<Self>) -> (String, usize) {
(slf.name.to_owned(), slf.num_qubits)
}

fn __setstate__(mut slf: PyRefMut<Self>, state: (String, usize)) {
slf.name = state.0;
slf.num_qubits = state.1;
}
}

impl Display for Key {
Expand All @@ -114,8 +127,8 @@ impl Display for Key {
}
}

#[pyclass(sequence)]
#[derive(Debug, Clone)]
#[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")]
#[derive(Debug, Clone, PartialEq)]
pub struct Equivalence {
#[pyo3(get)]
pub params: Vec<Param>,
Expand All @@ -126,13 +139,27 @@ pub struct Equivalence {
#[pymethods]
impl Equivalence {
#[new]
#[pyo3(signature = (params, circuit))]
fn new(params: Vec<Param>, circuit: CircuitRep) -> Self {
Self { circuit, params }
}

fn __repr__(&self) -> String {
self.to_string()
}

fn __eq__(&self, other: Self) -> bool {
self.eq(&other)
}

fn __getstate__(slf: PyRef<Self>) -> (Vec<Param>, CircuitRep) {
(slf.params.to_owned(), slf.circuit.to_owned())
}

fn __setstate__(mut slf: PyRefMut<Self>, state: (Vec<Param>, CircuitRep)) {
slf.params = state.0;
slf.circuit = state.1;
}
}

impl Display for Equivalence {
Expand All @@ -145,8 +172,8 @@ impl Display for Equivalence {
}
}

#[pyclass(sequence)]
#[derive(Debug, Clone)]
#[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")]
#[derive(Debug, Clone, PartialEq)]
pub struct NodeData {
#[pyo3(get)]
key: Key,
Expand All @@ -157,26 +184,45 @@ pub struct NodeData {
#[pymethods]
impl NodeData {
#[new]
#[pyo3(signature = (key, equivs))]
fn new(key: Key, equivs: Vec<Equivalence>) -> Self {
Self { key, equivs }
}

fn __repr__(&self) -> String {
self.to_string()
}

fn __eq__(&self, other: Self) -> bool {
self.eq(&other)
}

fn __getstate__(slf: PyRef<Self>) -> (Key, Vec<Equivalence>) {
(slf.key.to_owned(), slf.equivs.to_owned())
}

fn __setstate__(mut slf: PyRefMut<Self>, state: (Key, Vec<Equivalence>)) {
slf.key = state.0;
slf.equivs = state.1;
}
}

impl Display for NodeData {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "NodeData(key={}, equivs={:#?})", self.key, self.equivs)
write!(
f,
"NodeData(key={}, equivs=[{}])",
self.key,
self.equivs.iter().format(", ")
)
}
}

#[pyclass(sequence)]
#[derive(Debug, Clone)]
#[pyclass(sequence, module = "qiskit._accelerate.circuit.equivalence")]
#[derive(Debug, Clone, PartialEq)]
pub struct EdgeData {
#[pyo3(get)]
pub index: u32,
pub index: usize,
#[pyo3(get)]
pub num_gates: usize,
#[pyo3(get)]
Expand All @@ -188,7 +234,8 @@ pub struct EdgeData {
#[pymethods]
impl EdgeData {
#[new]
fn new(index: u32, num_gates: usize, rule: Equivalence, source: Key) -> Self {
#[pyo3(signature = (index, num_gates, rule, source))]
fn new(index: usize, num_gates: usize, rule: Equivalence, source: Key) -> Self {
Self {
index,
num_gates,
Expand All @@ -200,6 +247,26 @@ impl EdgeData {
fn __repr__(&self) -> String {
self.to_string()
}

fn __eq__(&self, other: Self) -> bool {
self.eq(&other)
}

fn __getstate__(slf: PyRef<Self>) -> (usize, usize, Equivalence, Key) {
(
slf.index,
slf.num_gates,
slf.rule.to_owned(),
slf.source.to_owned(),
)
}

fn __setstate__(mut slf: PyRefMut<Self>, state: (usize, usize, Equivalence, Key)) {
slf.index = state.0;
slf.num_gates = state.1;
slf.rule = state.2;
slf.source = state.3;
}
}

impl Display for EdgeData {
Expand Down Expand Up @@ -390,6 +457,12 @@ impl FromPyObject<'_> for CircuitRep {
}
}

impl PartialEq for CircuitRep {
fn eq(&self, other: &Self) -> bool {
self.object.is(&other.object)
}
}

impl IntoPy<PyObject> for CircuitRep {
fn into_py(self, _py: Python<'_>) -> PyObject {
self.object
Expand All @@ -415,7 +488,11 @@ impl FromPyObject<'_> for CircuitInstructionRep {
type GraphType = DiGraph<NodeData, EdgeData>;
type KTIType = HashMap<Key, NodeIndex>;

#[pyclass(subclass, name = "BaseEquivalenceLibrary")]
#[pyclass(
subclass,
name = "BaseEquivalenceLibrary",
module = "qiskit._accelerate.circuit.equivalence"
)]
#[derive(Debug, Clone)]
pub struct EquivalenceLibrary {
_graph: GraphType,
Expand Down Expand Up @@ -552,6 +629,59 @@ impl EquivalenceLibrary {
fn node_index(&self, key: Key) -> usize {
self.key_to_node_index[&key].index()
}

fn __getstate__(slf: PyRef<Self>) -> PyResult<Bound<'_, PyDict>> {
let ret = PyDict::new_bound(slf.py());
ret.set_item("rule_id", slf.rule_id)?;
let key_to_usize_node: HashMap<(String, usize), usize> = HashMap::from_iter(
slf.key_to_node_index
.iter()
.map(|(key, val)| ((key.name.to_string(), key.num_qubits), val.index())),
);
ret.set_item("key_to_node_index", key_to_usize_node.into_py(slf.py()))?;
let graph_nodes: Vec<NodeData> = slf._graph.node_weights().cloned().collect();
ret.set_item("graph_nodes", graph_nodes.into_py(slf.py()))?;
let graph_edges: Vec<(usize, usize, EdgeData)> = slf
._graph
.edge_indices()
.map(|edge_id| {
(
slf._graph.edge_endpoints(edge_id).unwrap(),
slf._graph.edge_weight(edge_id).unwrap(),
)
})
.map(|((source, target), weight)| (source.index(), target.index(), weight.to_owned()))
.collect_vec();
ret.set_item("graph_edges", graph_edges.into_py(slf.py()))?;
Ok(ret)
}

fn __setstate__(mut slf: PyRefMut<Self>, state: &Bound<'_, PyDict>) -> PyResult<()> {
slf.rule_id = state.get_item("rule_id")?.unwrap().extract()?;
slf.key_to_node_index = state
.get_item("key_to_node_index")?
.unwrap()
.extract::<HashMap<Key, usize>>()?
.into_iter()
.map(|(key, val)| (key, NodeIndex::new(val)))
.collect();
let graph_nodes: Vec<NodeData> = state.get_item("graph_nodes")?.unwrap().extract()?;
let graph_edges: Vec<(usize, usize, EdgeData)> =
state.get_item("graph_edges")?.unwrap().extract()?;
slf._graph = GraphType::new();
for node_weight in graph_nodes {
slf._graph.add_node(node_weight);
}
for (source_node, target_node, edge_weight) in graph_edges {
slf._graph.add_edge(
NodeIndex::new(source_node),
NodeIndex::new(target_node),
edge_weight,
);
}
slf.graph = None;
Ok(())
}
}

// Rust native methods
Expand Down Expand Up @@ -622,7 +752,7 @@ impl EquivalenceLibrary {
self.set_default_node(source.to_owned()),
target,
EdgeData {
index: self.rule_id as u32,
index: self.rule_id,
num_gates: sources.len(),
rule: equiv.to_owned(),
source: source.to_owned(),
Expand Down Expand Up @@ -749,11 +879,6 @@ fn rebind_equiv(
let param_map: Vec<(Param, Param)> = equiv_params
.into_iter()
.filter_map(|param| {
println!(
"{:#?}: is expr: {}",
param,
matches!(param, Param::ParameterExpression(_))
);
if matches!(param, Param::ParameterExpression(_)) {
Some(param)
} else {
Expand Down
8 changes: 2 additions & 6 deletions crates/circuit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ mod bit_data;
mod interner;
mod packed_instruction;

use pyo3::prelude::*;
use pyo3::types::PySlice;
use pyo3::{prelude::*, wrap_pymodule};

/// A private enumeration type used to extract arguments to pymethod
/// that may be either an index or a slice
Expand Down Expand Up @@ -64,12 +64,8 @@ impl From<Clbit> for BitType {

#[pymodule]
pub fn circuit(m: Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pymodule!(equivalence::equivalence))?;
m.add_class::<circuit_data::CircuitData>()?;
m.add_class::<equivalence::EquivalenceLibrary>()?;
m.add_class::<equivalence::EdgeData>()?;
m.add_class::<equivalence::NodeData>()?;
m.add_class::<equivalence::Key>()?;
m.add_class::<equivalence::Equivalence>()?;
m.add_class::<dag_node::DAGNode>()?;
m.add_class::<dag_node::DAGInNode>()?;
m.add_class::<dag_node::DAGOutNode>()?;
Expand Down
1 change: 1 addition & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
# We manually define them on import so people can directly import qiskit._accelerate.* submodules
# and not have to rely on attribute access. No action needed for top-level extension packages.
sys.modules["qiskit._accelerate.circuit"] = _accelerate.circuit
sys.modules["qiskit._accelerate.circuit.equivalence"] = _accelerate.circuit.equivalence
sys.modules["qiskit._accelerate.convert_2q_block_matrix"] = _accelerate.convert_2q_block_matrix
sys.modules["qiskit._accelerate.dense_layout"] = _accelerate.dense_layout
sys.modules["qiskit._accelerate.error_map"] = _accelerate.error_map
Expand Down
8 changes: 7 additions & 1 deletion qiskit/circuit/equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
from .exceptions import CircuitError
from .parameter import Parameter
from .parameterexpression import ParameterExpression
from qiskit._accelerate.circuit import BaseEquivalenceLibrary, Key, Equivalence, NodeData, EdgeData
from qiskit._accelerate.circuit.equivalence import (
BaseEquivalenceLibrary,
Key,
Equivalence,
NodeData,
EdgeData,
)


class EquivalenceLibrary(BaseEquivalenceLibrary):
Expand Down
11 changes: 8 additions & 3 deletions qiskit/transpiler/passes/basis/basis_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from qiskit.dagcircuit import DAGCircuit
from qiskit.converters import circuit_to_dag, dag_to_circuit
from qiskit.circuit.equivalence import Key, NodeData
from qiskit.circuit.equivalence import Key, NodeData, Equivalence
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler.exceptions import TranspilerError

Expand Down Expand Up @@ -541,7 +541,7 @@ def _basis_search(equiv_lib, source_basis, target_basis):
logger.debug("Begining basis search from %s to %s.", source_basis, target_basis)

source_basis = {
(gate_name, gate_num_qubits)
Key(gate_name, gate_num_qubits)
for gate_name, gate_num_qubits in source_basis
if gate_name not in target_basis
}
Expand All @@ -559,7 +559,12 @@ def _basis_search(equiv_lib, source_basis, target_basis):

# we add a dummy node and connect it with gates in the target basis.
# we'll start the search from this dummy node.
dummy = graph.add_node(NodeData(key="key", equivs=[("dummy starting node", 0)]))
dummy = graph.add_node(
NodeData(
key=Key("key", 0),
equivs=[Equivalence([], QuantumCircuit(0, name="dummy starting node"))],
)
)

try:
graph.add_edges_from_no_data(
Expand Down

0 comments on commit 11611e2

Please sign in to comment.